├── README.md ├── confPath.m ├── data └── afew │ └── spddb_afew_train_spd400_int_histeq.mat ├── manopt ├── CLA.txt ├── COPYING.txt ├── CREDITS.txt ├── LICENSE.txt ├── README.txt ├── checkinstall │ └── basicexample.m ├── examples │ ├── dominant_invariant_subspace.m │ ├── generalized_procrustes.m │ ├── low_rank_matrix_completion.m │ ├── maxcut.m │ ├── maxcut_octave.m │ ├── packing_on_the_sphere.m │ ├── positive_definite_karcher_mean.m │ ├── sparse_pca.m │ └── truncated_svd.m ├── importmanopt.m ├── manopt │ ├── manifolds │ │ ├── complexcircle │ │ │ └── complexcirclefactory.m │ │ ├── euclidean │ │ │ ├── euclideanfactory.m │ │ │ └── symmetricfactory.m │ │ ├── fixedrank │ │ │ ├── fixedrankMNquotientfactory.m │ │ │ ├── fixedrankembeddedfactory.m │ │ │ ├── fixedrankfactory_2factors.m │ │ │ ├── fixedrankfactory_2factors_preconditioned.m │ │ │ ├── fixedrankfactory_2factors_subspace_projection.m │ │ │ ├── fixedrankfactory_3factors.m │ │ │ └── fixedrankfactory_3factors_preconditioned.m │ │ ├── grassmann │ │ │ └── grassmannfactory.m │ │ ├── oblique │ │ │ └── obliquefactory.m │ │ ├── rotations │ │ │ ├── randrot.m │ │ │ ├── randskew.m │ │ │ └── rotationsfactory.m │ │ ├── sphere │ │ │ ├── spherecomplexfactory.m │ │ │ └── spherefactory.m │ │ ├── stiefel │ │ │ └── stiefelfactory.m │ │ └── symfixedrank │ │ │ ├── elliptopefactory.m │ │ │ ├── spectrahedronfactory.m │ │ │ ├── symfixedrankYYfactory.m │ │ │ └── sympositivedefinitefactory.m │ ├── privatetools │ │ ├── applyStatsfun.m │ │ ├── canGetCost.m │ │ ├── canGetDirectionalDerivative.m │ │ ├── canGetEuclideanGradient.m │ │ ├── canGetGradient.m │ │ ├── canGetHessian.m │ │ ├── canGetLinesearch.m │ │ ├── canGetPrecon.m │ │ ├── getApproxHessian.m │ │ ├── getCost.m │ │ ├── getCostGrad.m │ │ ├── getDirectionalDerivative.m │ │ ├── getEuclideanGradient.m │ │ ├── getGlobalDefaults.m │ │ ├── getGradient.m │ │ ├── getHessian.m │ │ ├── getHessianFD.m │ │ ├── getLinesearch.m │ │ ├── getPrecon.m │ │ ├── getStore.m │ │ ├── hashmd5.m │ │ ├── 
mergeOptions.m │ │ ├── purgeStoredb.m │ │ ├── setStore.m │ │ └── stoppingcriterion.m │ ├── solvers │ │ ├── conjugategradient │ │ │ └── conjugategradient.m │ │ ├── linesearch │ │ │ ├── linesearch.m │ │ │ ├── linesearch_adaptive.m │ │ │ └── linesearch_hint.m │ │ ├── neldermead │ │ │ ├── centroid.m │ │ │ └── neldermead.m │ │ ├── pso │ │ │ └── pso.m │ │ ├── steepestdescent │ │ │ └── steepestdescent.m │ │ └── trustregions │ │ │ ├── license for original GenRTR code.txt │ │ │ ├── tCG.m │ │ │ └── trustregions.m │ └── tools │ │ ├── checkdiff.m │ │ ├── checkdiffSingle.m │ │ ├── checkdiffSinglePrecision.m │ │ ├── checkgradient.m │ │ ├── checkgradientSinglePrecision.m │ │ ├── checkhessian.m │ │ ├── diagsum.m │ │ ├── hessianspectrum.m │ │ ├── identify_linear_piece.m │ │ ├── identify_linear_piece_SinglePrecision.m │ │ ├── multiprod.m │ │ ├── multiprodmultitransp_license.txt │ │ ├── multiscale.m │ │ ├── multiskew.m │ │ ├── multisym.m │ │ ├── multitrace.m │ │ ├── multitransp.m │ │ ├── plotprofile.m │ │ ├── powermanifold.m │ │ └── productmanifold.m └── manopt_version.m ├── spdnet ├── vl_mybfc.m ├── vl_myfc.m ├── vl_myforbackward.m ├── vl_mylog.m ├── vl_myrec.m └── vl_mysoftmaxloss.m ├── spdnet_afew.m ├── spdnet_init_afew.m ├── spdnet_train_afew.m └── utils ├── dDiag.m ├── diagInv.m ├── diagLog.m ├── max_eig.m └── symmetric.m /README.md: -------------------------------------------------------------------------------- 1 | # SPDNet-master 2 | Zhiwu Huang and Luc Van Gool. A Riemannian Network for SPD Matrix Learning, In Proc. AAAI 2017. 3 | 4 | Version 1.0, Copyright(c) November, 2017. 5 | 6 | Note that the copyright of the manopt toolbox is reserved by https://www.manopt.org/ 7 | 8 | ## Usage 9 | 10 | Step1: Place the used AFEW SPD data under the folder "./data/afew/". Note that the used HDM05 and PaSC SPD data are also publicly available. 11 | 12 | Step2: Launch spdnet_afew.m for a simple example. 13 | 14 | ## Related Work/Implementation 15 | 16 | 1. 
Thanks to Oleg Smirnov, Sr. Applied Scientist at Amazon, a TensorFlow ManOpt library has been released to reproduce our SPDNet. 17 | 18 | 2. A NeurIPS 2019 paper "Riemannian batch normalization for SPD neural networks" develops a batch normalization layer on top of our SPDNet, with the official PyTorch code being publicly available at the 'Supplemental' tab. 19 | 20 | 3. A report "Second-order networks in PyTorch" is released. 21 | 22 | 4. Thanks to Alireza Davoudi, there is another Python implementation for SPDNet. 23 | 24 | 5. A direct extension of our SPDNet for facial emotion recognition is published by CVPR workshop 2018, with the code being available. 25 | 26 | 27 | ## How to Cite 28 | If you find this project helpful, please consider citing us as follows: 29 | ```bibtex 30 | @inproceedings{huang2017spdnet, 31 | title = {A Riemannian Network for SPD Matrix Learning}, 32 | author = {Huang, Zhiwu and 33 | Van Gool, Luc}, 34 | year = {2017}, 35 | booktitle = {Association for the Advancement of Artificial Intelligence (AAAI)} 36 | } 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /confPath.m: -------------------------------------------------------------------------------- 1 | % Add folders to path.
2 | addpath(pwd); 3 | 4 | cd manopt; 5 | addpath(genpath(pwd)); 6 | cd ..; 7 | 8 | cd utils; 9 | addpath(genpath(pwd)); 10 | cd ..; 11 | 12 | cd spdnet; 13 | addpath(genpath(pwd)); 14 | cd ..; 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /data/afew/spddb_afew_train_spd400_int_histeq.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/data/afew/spddb_afew_train_spd400_int_histeq.mat -------------------------------------------------------------------------------- /manopt/CLA.txt: -------------------------------------------------------------------------------- 1 | Thank you for your interest in Manopt. The purpose of this Contributor License Agreement is to 2 | clarify the intellectual property license granted with contributions of software from any person or 3 | entity (the "Contributor") to the owners of Manopt. This license is for your protection as a 4 | Contributor of software to Manopt and does not change your right to use your own contributions for 5 | any other purpose. 6 | 7 | The owners of Manopt are the copyright holders of Manopt indicated in the license files distributed 8 | with the software. 9 | 10 | You and the owners of Manopt hereby accept and agree to the following terms and conditions: 11 | 12 | Your "Contributions" means all of your past, present and future contributions of object code, source 13 | code and documentation to Manopt, however submitted to Manopt, excluding any submissions that are 14 | conspicuously marked or otherwise designated in writing by You as "Not a Contribution." 
15 | 16 | You hereby grant to the owners of Manopt a non-exclusive, irrevocable, worldwide, no-charge, 17 | transferable copyright license to use, execute, prepare derivative works of, and distribute 18 | (internally and externally, in object code and, if included in your Contributions, source code form) 19 | your Contributions. Except for the rights granted to the owners of Manopt in this paragraph, You 20 | reserve all right, title and interest in and to your Contributions. 21 | 22 | You represent that you are legally entitled to grant the above license. If your employer(s) have 23 | rights to intellectual property that you create, you represent that you have received permission to 24 | make the Contributions on behalf of that employer, or that your employer has waived such rights for 25 | your Contributions to Manopt. 26 | 27 | You represent that, except as disclosed in your Contribution submission(s), each of your 28 | Contributions is your original creation. You represent that your Contribution submission(s) 29 | included complete details of any license or other restriction (including, but not limited to, 30 | related patents and trademarks) associated with any part of your Contribution(s) (including a copy 31 | of any applicable license agreement). You agree to notify the owners of Manopt of any facts or 32 | circumstances of which you become aware that would make Your representations in the Agreement 33 | inaccurate in any respect. 34 | 35 | You are not expected to provide support for your Contributions, except to the extent you desire to 36 | provide support. You may provide support for free, for a fee, or not at all. Your Contributions are 37 | provided as-is, with all faults, defects and errors, and without any warranty of any kind (either 38 | express or implied) including, without limitation, any implied warranty of merchantability and 39 | fitness for a particular purpose and any warranty of non-infringement.
40 | 41 | 42 | This CLA is a modification of the CLA used by the UW Calendar project of the University of 43 | Washington: 44 | -------------------------------------------------------------------------------- /manopt/CREDITS.txt: -------------------------------------------------------------------------------- 1 | The Manopt project is led by these people: 2 | 3 | * Nicolas Boumal 4 | * Bamdev Mishra 5 | * Pierre-Antoine Absil 6 | * Rodolphe Sepulchre 7 | 8 | 9 | The following people have written code specifically for Manopt: 10 | 11 | * Nicolas Boumal 12 | * Bamdev Mishra 13 | * Pierre Borckmans 14 | 15 | 16 | Furthermore, code written by the following people can be found in Manopt: 17 | 18 | * Chris Baker 19 | * Pierre-Antoine Absil 20 | * Kyle Gallivan 21 | * Paolo de Leva 22 | * Wynton Moore 23 | * Michael Kleder 24 | -------------------------------------------------------------------------------- /manopt/LICENSE.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/LICENSE.txt -------------------------------------------------------------------------------- /manopt/README.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/README.txt --------------------------------------------------------------------------------
/manopt/checkinstall/basicexample.m:
--------------------------------------------------------------------------------
function basicexample
% Minimal end-to-end check that Manopt is installed and usable.
%
% function basicexample()
%
% Builds a random symmetric matrix A of size 1000, minimizes -x'*A*x over
% the unit sphere with the trust-regions solver, and plots the gradient
% norm recorded at each iteration.

    % Verify that Manopt was indeed added to the Matlab path.
    if isempty(which('spherefactory'))
        % error() only converts escape sequences such as \n when it
        % receives a format string followed by arguments; with a single
        % argument the '\n' would be printed literally.
        error('%s\n%s', ...
              'You should first add Manopt to the Matlab path.', ...
              'Please run importmanopt first.');
    end

    % Generate the problem data.
    n = 1000;
    A = randn(n);
    A = .5*(A+A');

    % Create the problem structure.
    manifold = spherefactory(n);
    problem.M = manifold;

    % Define the problem cost function and its gradient.
    problem.cost = @(x) -x'*(A*x);
    problem.grad = @(x) manifold.egrad2rgrad(x, -2*A*x);

    % Numerically check gradient consistency.
    checkgradient(problem);

    % Solve.
    % The trust-regions algorithm requires the Hessian. Since we do not
    % provide it, it will go for a standard approximation of it. The first
    % instruction tells Manopt not to issue a warning when this happens.
    warning('off', 'manopt:getHessian:approx');
    [x, xcost, info] = trustregions(problem); %#ok

    % Display some statistics.
    figure;
    semilogy([info.iter], [info.gradnorm], '.-');
    xlabel('Iteration #');
    ylabel('Gradient norm');
    title('Convergence of the trust-regions algorithm on the sphere');

end

--------------------------------------------------------------------------------
/manopt/examples/dominant_invariant_subspace.m:
--------------------------------------------------------------------------------
function [X, info] = dominant_invariant_subspace(A, p)
% Returns an orthonormal basis of the dominant invariant p-subspace of A.
%
% function [X, info] = dominant_invariant_subspace(A, p)
%
% Input: A real, symmetric matrix A of size nxn and an integer p <= n.
%        If A is omitted or empty, a random symmetric 128x128 matrix is
%        used; if p is omitted or empty, p = 3.
% Output: A real, orthonormal matrix X of size nxp such that trace(X'*A*X)
%         is maximized. That is, the columns of X form an orthonormal basis
%         of a dominant subspace of dimension p of A. These are thus
%         eigenvectors associated with the largest eigenvalues of A (in no
%         particular order). Sign is important: 2 is deemed a larger
%         eigenvalue than -5.
%         info is the statistics structure returned by trustregions.
%
% The optimization is performed on the Grassmann manifold, since only the
% space spanned by the columns of X matters. The implementation is short to
% show how Manopt can be used to quickly obtain a prototype. To make the
% implementation more efficient, one might first try to use the caching
% system, that is, use the optional 'store' arguments in the cost, grad and
% hess functions. Furthermore, using egrad2rgrad and ehess2rhess is quick
% and easy, but not always efficient. Having a look at the formulas
% implemented in these functions can help rewrite the code without them,
% possibly more efficiently.

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Nicolas Boumal, July 5, 2013
% Contributors:
%
% Change log:
%
%   NB Dec. 6, 2013:
%       We specify a max and initial trust region radius in the options.

    % Generate some random data to test the function
    if ~exist('A', 'var') || isempty(A)
        A = randn(128);
        A = (A+A')/2;
    end
    if ~exist('p', 'var') || isempty(p)
        p = 3;
    end

    % Make sure the input matrix is square and symmetric
    n = size(A, 1);
    assert(isreal(A), 'A must be real.')
    assert(size(A, 2) == n, 'A must be square.');
    assert(norm(A-A', 'fro') < n*eps, 'A must be symmetric.');
    % The check allows p == n, so the message must not claim strictness.
    assert(p<=n, 'p must not exceed n.');

    % Define the cost and its derivatives on the Grassmann manifold
    Gr = grassmannfactory(n, p);
    problem.M = Gr;
    problem.cost = @(X)    -trace(X'*A*X);
    problem.grad = @(X)    -2*Gr.egrad2rgrad(X, A*X);
    problem.hess = @(X, H) -2*Gr.ehess2rhess(X, A*X, A*H, H);

    % Execute some checks on the derivatives for early debugging.
    % These things can be commented out of course.
    % checkgradient(problem);
    % pause;
    % checkhessian(problem);
    % pause;

    % Issue a call to a solver. A random initial guess will be chosen and
    % default options are selected except for the ones we specify here.
    options.Delta_bar = 8*sqrt(p);
    [X, costX, info, options] = trustregions(problem, [], options); %#ok

    fprintf('Options used:\n');
    disp(options);

    % For our information, Manopt can also compute the spectrum of the
    % Riemannian Hessian on the tangent space at (any) X. Computing the
    % spectrum at the solution gives us some idea of the conditioning of
    % the problem. If we were to implement a preconditioner for the
    % Hessian, this would also inform us on its performance.
    %
    % Notice that (typically) all eigenvalues of the Hessian at the
    % solution are positive, i.e., we find an isolated minimizer. If we
    % replace the Grassmann manifold by the Stiefel manifold, hence still
    % optimizing over orthonormal matrices but ignoring the invariance
    % cost(XQ) = cost(X) for all Q orthogonal, then we see
    % dim O(p) = p(p-1)/2 zero eigenvalues in the Hessian spectrum, making
    % the optimizer not isolated anymore.
    if Gr.dim() < 512
        evs = hessianspectrum(problem, X);
        stairs(sort(evs));
        title(['Eigenvalues of the Hessian of the cost function ' ...
               'at the solution']);
        xlabel('Eigenvalue number (sorted)');
        ylabel('Value of the eigenvalue');
    end

end

--------------------------------------------------------------------------------
/manopt/examples/positive_definite_karcher_mean.m:
--------------------------------------------------------------------------------
function X = positive_definite_karcher_mean(A)
% Computes a Karcher mean of a collection of positive definite matrices.
%
% function X = positive_definite_karcher_mean(A)
%
% Input:  A 3D matrix A of size nxnxm such that each slice A(:,:,k) is a
%         positive definite matrix of size nxn. If A is omitted or empty,
%         random test data (n = 5, m = 10) is generated.
%
% Output: A positive definite matrix X of size nxn which is a Karcher mean
%         of the m matrices in A, that is, X minimizes the sum of squared
%         Riemannian distances to the matrices in A:
%            f(X) = sum_k=1^m .5*dist^2(X, A(:, :, k))
%         The distance is defined by the natural metric on the set of
%         positive definite matrices: dist(X,Y) = norm(logm(X\Y), 'fro').
%         (The code divides f by m; this does not change the minimizer.)
%
% This simple example is not the best way to compute Karcher means. Its
% purpose is to serve as base code to explore other algorithms. In
% particular, in the presence of large noise, this algorithm seems to not
% be able to reach points with a very small gradient norm. This may be
% caused by insufficient accuracy in the gradient computation.

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Nicolas Boumal, Sept. 3, 2013
% Contributors:
%
% Change log:
%

    % Generate some random data to test the function if none is given.
    if ~exist('A', 'var') || isempty(A)
        n = 5;
        m = 10;
        A = zeros(n, n, m);
        ref = diag(max(.1, 1+.1*randn(n, 1)));
        for i = 1 : m
            noise = 0.01*randn(n);
            noise = (noise + noise')/2;
            [V, D] = eig(ref + noise);
            % Clip the eigenvalues from below to keep the slice positive
            % definite.
            A(:, :, i) = V*diag(max(.01, diag(D)))*V';
        end
    end

    % Retrieve the size of the problem:
    % There are m matrices of size nxn to average.
    n = size(A, 1);
    m = size(A, 3);
    assert(n == size(A, 2), ...
           ['The slices of A must be square, i.e., the ' ...
            'first and second dimensions of A must be equal.']);

    % Our search space is the set of positive definite matrices of size n.
    % Notice that this is the only place we specify on which manifold we
    % wish to compute Karcher means. Replacing this factory for another
    % geometry will yield code to compute Karcher means on that other
    % manifold, provided that manifold is equipped with a dist function and
    % a logarithmic map log.
    M = sympositivedefinitefactory(n);

    % Define a problem structure, specifying the manifold M, the cost
    % function and its gradient.
    problem.M = M;
    problem.cost = @cost;
    problem.grad = @grad;

    % The functions below make many redundant computations. This
    % performance hit can be alleviated by using the caching system. We go
    % for a simple implementation here, as a tutorial example.

    % Cost function
    function f = cost(X)
        f = 0;
        for k = 1 : m
            f = f + M.dist(X, A(:, :, k))^2;
        end
        f = f/(2*m);
    end

    % Riemannian gradient of the cost function
    function g = grad(X)
        g = M.zerovec(X);
        for k = 1 : m
            % Update g in a linear combination of the form
            % g = g - [something]/m.
            g = M.lincomb(X, 1, g, -1/m, M.log(X, A(:, :, k)));
        end
    end

    % Execute some checks on the derivatives for early debugging.
    % These things can be commented out of course.
    % The slopes should agree on part of the plot at least. In this case,
    % it is sometimes necessary to inspect the plot visually to make the
    % call, but it is indeed correct.
    % checkgradient(problem);
    % pause;

    % Execute this if you want to force using a proper parallel vector
    % transport. This is not necessary. If you omit this, the default
    % vector transport is the identity map, which is (of course) cheaper
    % and seems to perform well in practice.
    % M.transp = M.paralleltransp;

    % Issue a call to a solver. Default options are selected.
    % Our initial guess is the first data point.
    X = trustregions(problem, A(:, :, 1));

end

--------------------------------------------------------------------------------
/manopt/importmanopt.m:
--------------------------------------------------------------------------------
% Add Manopt to the path and make all manopt components available.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 3, 2013.
% Contributors:
% Change log:
%   Aug.  7, 2013 (NB): Changed to work without the import command
%                       (new structure of the toolbox).
%   Aug.  8, 2013 (NB): Changed to use addpath_recursive, home brewed.
%   Aug. 22, 2013 (NB): Using genpath instead of homecooked
%                       addpath_recursive.

% Resolve the toolbox root from the location of this script rather than
% from the current directory, so the script works from any working
% directory and never has to cd. If the location cannot be determined
% (e.g., the code is pasted at the prompt), fall back to the current
% directory, as before.
manopt_root__ = fileparts(mfilename('fullpath'));
if isempty(manopt_root__)
    manopt_root__ = pwd;
end

addpath(manopt_root__);

% Recursively add Manopt directories to the Matlab path.
addpath(genpath(fullfile(manopt_root__, 'manopt')));

% This is a script: do not leave the helper variable in the caller's
% workspace.
clear manopt_root__;

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/complexcircle/complexcirclefactory.m:
--------------------------------------------------------------------------------
function M = complexcirclefactory(n)
% Returns a manifold struct to optimize over unit-modulus complex numbers.
%
% function M = complexcirclefactory()
% function M = complexcirclefactory(n)
%
% Description of vectors z in C^n (complex) such that each component z(i)
% has unit modulus. The manifold structure is the Riemannian submanifold
% structure from the embedding space R^2 x ... x R^2, i.e., the complex
% circle is identified with the unit circle in the real plane.
%
% By default, n = 1.
%
% See also spherecomplexfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   July 7, 2014 (NB): Added ehess2rhess function.
%

    if ~exist('n', 'var')
        n = 1;
    end

    M.name = @() sprintf('Complex circle (S^1)^%d', n);

    M.dim = @() n;

    M.inner = @(z, v, w) real(v'*w);

    M.norm = @(x, v) norm(v);

    % Angle between x(i) and y(i) on each circle, as a real vector in
    % [0, pi]. conj(x).*y is a unit complex number whose real part is the
    % cosine of that angle; acos must be applied to that real part (acos of
    % the complex number itself is not the angle). The clamp guards against
    % round-off pushing the cosine slightly outside [-1, 1].
    angles = @(x, y) acos(max(-1, min(1, real(conj(x) .* y))));

    M.dist = @(x, y) norm(angles(x, y));

    M.typicaldist = @() pi*sqrt(n);

    M.proj = @(z, u) u - real( conj(u) .* z ) .* z;

    M.tangent = M.proj;

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(z, egrad, ehess, zdot)
        rhess = M.proj(z, ehess - real(z.*conj(egrad)).*zdot);
    end

    M.exp = @exponential;
    function y = exponential(z, v, t)
        if nargin <= 2
            t = 1.0;
        end

        y = zeros(n, 1);
        tv = t*v;

        nrm_tv = abs(tv);

        % We need to distinguish between very small steps and the others.
        % For very small steps, we use a limit version of the exponential
        % (which actually coincides with the retraction), so as to not
        % divide by very small numbers.
        mask = nrm_tv > 1e-6;
        y(mask) = z(mask).*cos(nrm_tv(mask)) + ...
                  tv(mask).*(sin(nrm_tv(mask))./nrm_tv(mask));
        y(~mask) = z(~mask) + tv(~mask);
        y(~mask) = y(~mask) ./ abs(y(~mask));

    end

    M.retr = @retraction;
    function y = retraction(z, v, t)
        if nargin <= 2
            t = 1.0;
        end
        y = z+t*v;
        y = y ./ abs(y);
    end

    M.log = @logarithm;
    function v = logarithm(x1, x2)
        % Direction of the geodesic on each circle; its modulus is
        % |sin(angle)| for that circle.
        v = M.proj(x1, x2 - x1);
        ang = angles(x1, x2);
        nv = abs(v);
        % Rescale each component so that its modulus is the angle on its
        % own circle: a single global factor is only correct for n = 1.
        % Components with a tiny modulus are left as is: the factor tends
        % to 1 for nearby points (same threshold as in M.exp), and this
        % also avoids 0/0 when x1(i) == x2(i).
        mask = nv > 1e-6;
        v(mask) = v(mask) .* (ang(mask) ./ nv(mask));
    end

    M.hash = @(z) ['z' hashmd5( [real(z(:)) ; imag(z(:))] ) ];

    M.rand = @random;
    function z = random()
        z = randn(n, 1) + 1i*randn(n, 1);
        z = z ./ abs(z);
    end

    M.randvec = @randomvec;
    function v = randomvec(z)
        % i*z(k) is a basis vector of the tangent vector to the k-th circle
        v = randn(n, 1) .* (1i*z);
        v = v / norm(v);
    end

    M.lincomb = @lincomb;

    M.zerovec = @(x) zeros(n, 1);

    M.transp = @(x1, x2, d) M.proj(x2, d);

    M.pairmean = @pairmean;
    function z = pairmean(z1, z2)
        z = z1+z2;
        z = z ./ abs(z);
    end

    % Tangent vectors are vectorized by stacking real and imaginary parts.
    M.vec = @(x, u_mat) [real(u_mat) ; imag(u_mat)];
    M.mat = @(x, u_vec) u_vec(1:n) + 1i*u_vec((n+1):end);
    M.vecmatareisometries = @() true;

end


% Linear combination of tangent vectors:
% lincomb(x, a1, d1) returns a1*d1;
% lincomb(x, a1, d1, a2, d2) returns a1*d1 + a2*d2.
function d = lincomb(x, a1, d1, a2, d2) %#ok

    if nargin == 3
        d = a1*d1;
    elseif nargin == 5
        d = a1*d1 + a2*d2;
    else
        error('Bad use of complexcirclefactory.lincomb.');
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/euclidean/euclideanfactory.m:
--------------------------------------------------------------------------------
function M = euclideanfactory(m, n)
% Returns a manifold struct to optimize over m-by-n matrices.
%
% function M = euclideanfactory(m, n)
%
% Returns M, a structure describing the Euclidean space of m-by-n matrices
% equipped with the standard Frobenius distance and associated trace inner
% product as a manifold for Manopt.
%
% If n is omitted or empty, n = 1.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   July 5, 2013 (NB): added egrad2rgrad, ehess2rhess, mat, vec, tangent.


    if ~exist('n', 'var') || isempty(n)
        n = 1;
    end

    M.name = @() sprintf('Euclidean space R^(%dx%d)', m, n);

    M.dim = @() m*n;

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d, 'fro');

    M.dist = @(x, y) norm(x-y, 'fro');

    M.typicaldist = @() sqrt(m*n);

    % The space is flat: projections and gradient/Hessian conversions are
    % identity maps.
    M.proj = @(x, d) d;

    M.egrad2rgrad = @(x, g) g;

    M.ehess2rhess = @(x, eg, eh, d) eh;

    M.tangent = M.proj;

    M.exp = @exp;
    function y = exp(x, d, t)
        if nargin == 3
            y = x + t*d;
        else
            y = x + d;
        end
    end

    M.retr = M.exp;

    M.log = @(x, y) y-x;

    M.hash = @(x) ['z' hashmd5(x(:))];

    M.rand = @() randn(m, n);

    M.randvec = @randvec;
    function u = randvec(x) %#ok
        u = randn(m, n);
        u = u / norm(u, 'fro');
    end

    % lincomb(x, a1, d1) returns a1*d1;
    % lincomb(x, a1, d1, a2, d2) returns a1*d1 + a2*d2.
    M.lincomb = @lincomb;
    function v = lincomb(x, a1, d1, a2, d2) %#ok
        if nargin == 3
            v = a1*d1;
        elseif nargin == 5
            v = a1*d1 + a2*d2;
        else
            error('Bad usage of euclidean.lincomb');
        end
    end

    M.zerovec = @(x) zeros(m, n);

    M.transp = @(x1, x2, d) d;

    M.pairmean = @(x1, x2) .5*(x1+x2);

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [m, n]);
    M.vecmatareisometries = @() true;

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/euclidean/symmetricfactory.m:
--------------------------------------------------------------------------------
function M = symmetricfactory(n, k)
% Returns a manifold struct to optimize over k symmetric matrices of size n
%
% function M = symmetricfactory(n)
% function M = symmetricfactory(n, k)
%
% Returns M, a structure describing the Euclidean space of n-by-n symmetric
% matrices equipped with the standard Frobenius distance and associated
% trace inner product, as a manifold for Manopt.
% By default, k = 1. If k > 1, points and vectors are stored in 3D matrices
% X of size nxnxk such that each slice X(:, :, i), for i = 1:k, is
% symmetric.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 22, 2014.
% Contributors:
% Change log:

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end

    M.name = @() sprintf('(Symmetric matrices of size %d)^%d', n, k);

    M.dim = @() k*n*(n+1)/2;

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:), 'fro');

    M.dist = @(x, y) norm(x(:)-y(:), 'fro');

    M.typicaldist = @() sqrt(k)*n;

    % Projection from the ambient space of nxnxk arrays: multisym is
    % applied to the array (presumably symmetrizing each slice -- helper
    % defined elsewhere in Manopt).
    M.proj = @(x, d) multisym(d);

    M.egrad2rgrad = M.proj;

    % The search space is a linear subspace of the ambient space with the
    % induced metric, so the Riemannian Hessian is the projection of the
    % Euclidean Hessian.
    M.ehess2rhess = @(x, eg, eh, d) M.proj(x, eh);

    M.tangent = @(x, d) d;

    M.exp = @exp;
    function y = exp(x, d, t)
        if nargin == 3
            y = x + t*d;
        else
            y = x + d;
        end
    end

    M.retr = M.exp;

    M.log = @(x, y) y-x;

    M.hash = @(x) ['z' hashmd5(x(:))];

    M.rand = @() multisym(randn(n, n, k));

    M.randvec = @randvec;
    function u = randvec(x) %#ok
        u = multisym(randn(n, n, k));
        u = u / norm(u(:), 'fro');
    end

    % lincomb(x, a1, d1) returns a1*d1;
    % lincomb(x, a1, d1, a2, d2) returns a1*d1 + a2*d2.
    M.lincomb = @lincomb;
    function v = lincomb(x, a1, d1, a2, d2) %#ok
        if nargin == 3
            v = a1*d1;
        elseif nargin == 5
            v = a1*d1 + a2*d2;
        else
            error('Bad usage of symmetricfactory.lincomb');
        end
    end

    M.zerovec = @(x) zeros(n, n, k);

    M.transp = @(x1, x2, d) d;

    M.pairmean = @(x1, x2) .5*(x1+x2);

    M.vec = @(x, u_mat) u_mat(:);
    % Inverse of M.vec: vectors are nxnxk arrays (see M.zerovec), so the
    % column vector must be reshaped to [n, n, k]. The former [m, n]
    % referred to a variable m that does not exist in this function.
    M.mat = @(x, u_vec) reshape(u_vec, [n, n, k]);
    M.vecmatareisometries = @() true;

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/fixedrank/fixedrankMNquotientfactory.m:
--------------------------------------------------------------------------------
function M = fixedrankMNquotientfactory(m, n, k)
% Manifold of m-by-n matrices of rank k with quotient geometry.
%
% function M = fixedrankMNquotientfactory(m, n, k)
%
% This follows the quotient geometry described in the following paper:
% P.-A. Absil, L. Amodei and G. Meyer,
% "Two Newton methods on the manifold of fixed-rank matrices endowed
%  with Riemannian quotient geometries", arXiv, 2012.
%
% Paper link: http://arxiv.org/abs/1209.0068
%
% A point X on the manifold is represented as a structure with two
% fields: M and N. The matrix M (mxk) is orthonormal, while the matrix N
% (nxk) is full-rank.
%
% Tangent vectors are represented as a structure with two fields (M, N).
%
% M.dist is not implemented: calling it raises an error.

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, Dec. 30, 2012.
% Contributors:
% Change log:


    M.name = @() sprintf('MN'' quotient manifold of %dx%d matrices of rank %d', m, n, k);

    M.dim = @() (m+n-k)*k;

    % Choice of the metric is motivated by the symmetry present in the
    % space
    M.inner = @(X, eta, zeta) eta.M(:).'*zeta.M(:) + eta.N(:).'*zeta.N(:);

    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));

    M.dist = @(x, y) error('fixedrankMNquotientfactory.dist not implemented yet.');

    M.typicaldist = @() 10*k;

    % Helpers: symmetric part of a square matrix, and projection of H onto
    % the tangent space to the Stiefel manifold at the orthonormal matrix M.
    symm = @(X) .5*(X+X');
    stiefel_proj = @(M, H) H - M*symm(M'*H);

    % Only the M factor is constrained (orthonormal), so only eta.M is
    % projected; eta.N is returned unchanged.
    M.egrad2rgrad = @egrad2rgrad;
    function eta = egrad2rgrad(X, eta)
        eta.M = stiefel_proj(X.M, eta.M);
    end

    M.ehess2rhess = @ehess2rhess;
    function Hess = ehess2rhess(X, egrad, ehess, eta)

        % Directional derivative of the Riemannian gradient
        Hess.M = ehess.M - eta.M*symm(X.M'*egrad.M);
        Hess.M = stiefel_proj(X.M, Hess.M);

        Hess.N = ehess.N;

        % Projection onto the horizontal space
        Hess = M.proj(X, Hess);
    end


    M.proj = @projection;
    function etaproj = projection(X, eta)

        % Start by projecting the vector from Rmp x Rnp to the tangent
        % space to the total space, that is, eta.M should be in the
        % tangent space to Stiefel at X.M and eta.N is arbitrary.
        eta.M = stiefel_proj(X.M, eta.M);

        % Now project from the tangent space to the horizontal space, that
        % is, take care of the quotient.

        % First solve a Sylvester equation (A symm., B skew-symm.)
        A = X.N'*X.N + eye(k);
        B = eta.M'*X.M + eta.N'*X.N;
        B = B-B';
        omega = lyap(A, -B);

        % And project along the vertical space to the horizontal space.
        etaproj.M = eta.M + X.M*omega;
        etaproj.N = eta.N + X.N*omega;

    end

    M.exp = @exponential;
    function Y = exponential(X, eta, t)
        if nargin < 3
            t = 1.0;
        end

        A = t*X.M'*eta.M;
        S = t^2*eta.M'*eta.M;
        Y.M = [X.M t*eta.M]*expm([A -S ; eye(k) A])*eye(2*k, k)*expm(-A);

        % re-orthonormalize (seems necessary from time to time)
        [Q, R] = qr(Y.M, 0);
        % Flip column signs so that the diagonal of R is nonnegative.
        Y.M = Q * diag(sign(diag(R)));

        Y.N = X.N + t*eta.N;

    end

    % Factor M lives on the Stiefel manifold, hence we will reuse its
    % random generator.
    stiefelm = stiefelfactory(m, k);

    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        if nargin < 3
            t = 1.0;
        end

        Y.M = uf(X.M + t*eta.M); % This is a valid retraction
        Y.N = X.N + t*eta.N;
    end

    M.hash = @(X) ['z' hashmd5([X.M(:) ; X.N(:)])];

    M.rand = @random;
    function X = random()
        X.M = stiefelm.rand();
        X.N = randn(n, k);
    end

    % Random tangent vector at X, projected onto the horizontal space and
    % normalized to unit norm.
    M.randvec = @randomvec;
    function eta = randomvec(X)
        eta.M = randn(m, k);
        eta.N = randn(n, k);
        eta = projection(X, eta);
        nrm = M.norm(X, eta);
        eta.M = eta.M / nrm;
        eta.N = eta.N / nrm;
    end

    M.lincomb = @lincomb;

    M.zerovec = @(X) struct('M', zeros(m, k), 'N', zeros(n, k));

    M.transp = @(x1, x2, d) projection(x2, d);

end


% Linear combination of tangent vectors, applied to both fields M and N:
% lincomb(x, a1, d1) returns a1*d1;
% lincomb(x, a1, d1, a2, d2) returns a1*d1 + a2*d2.
function d = lincomb(x, a1, d1, a2, d2) %#ok

    if nargin == 3
        d.M = a1*d1.M;
        d.N = a1*d1.N;
    elseif nargin == 5
        d.M = a1*d1.M + a2*d2.M;
        d.N = a1*d1.N + a2*d2.N;
    else
        error('Bad use of fixedrankMNquotientfactory.lincomb.');
    end

end


% Replaces A by the orthonormal factor L*R' built from its thin SVD
% A = L*S*R'.
function A = uf(A)
    [L, unused, R] = svd(A, 0); %#ok
    A = L*R';
end
-------------------------------------------------------------------------------- /manopt/manopt/manifolds/fixedrank/fixedrankfactory_2factors.m: -------------------------------------------------------------------------------- 1 | function M = fixedrankfactory_2factors(m, n, k) 2 | % Manifold of m-by-n matrices of rank k with balanced quotient geometry. 3 | % 4 | % function M = fixedrankfactory_2factors(m, n, k) 5 | % 6 | % This follows the balanced quotient geometry described in the following paper: 7 | % G. Meyer, S. Bonnabel and R. Sepulchre, 8 | % "Linear regression under fixed-rank constraints: a Riemannian approach", 9 | % ICML 2011. 10 | % 11 | % Paper link: http://www.icml-2011.org/papers/350_icmlpaper.pdf 12 | % 13 | % A point X on the manifold is represented as a structure with two 14 | % fields: L and R. The matrices L (mxk) and R (nxk) are full column-rank 15 | % matrices such that X = L*R'. 16 | % 17 | % Tangent vectors are represented as a structure with two fields: L, R 18 | 19 | % This file is part of Manopt: www.manopt.org. 20 | % Original author: Bamdev Mishra, Dec. 30, 2012. 21 | % Contributors: 22 | % Change log: 23 | % July 10, 2013 (NB) : added vec, mat, tangent, tangent2ambient 24 | 25 | 26 | M.name = @() sprintf('LR'' quotient manifold of %dx%d matrices of rank %d', m, n, k); 27 | 28 | M.dim = @() (m+n-k)*k; 29 | 30 | % Some precomputations at the point X to be used in the inner product (and 31 | % pretty much everywhere else). 
32 | function X = prepare(X) 33 | if ~all(isfield(X,{'LtL','RtR','invRtR','invLtL'})) 34 | L = X.L; 35 | R = X.R; 36 | X.LtL = L'*L; 37 | X.RtR = R'*R; 38 | X.invLtL = inv(X.LtL); 39 | X.invRtR = inv(X.RtR); 40 | end 41 | end 42 | 43 | % Choice of the metric is motivated by the symmetry present in the space 44 | M.inner = @iproduct; 45 | function ip = iproduct(X, eta, zeta) 46 | X = prepare(X); 47 | ip = trace(X.invLtL*(eta.L'*zeta.L)) + trace( X.invRtR*(eta.R'*zeta.R)); 48 | end 49 | 50 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); 51 | 52 | M.dist = @(x, y) error('fixedrankfactory_2factors.dist not implemented yet.'); 53 | 54 | M.typicaldist = @() 10*k; 55 | 56 | symm = @(M) .5*(M+M'); 57 | 58 | M.egrad2rgrad = @egrad2rgrad; 59 | function eta = egrad2rgrad(X, eta) 60 | X = prepare(X); 61 | eta.L = eta.L*X.LtL; 62 | eta.R = eta.R*X.RtR; 63 | end 64 | 65 | M.ehess2rhess = @ehess2rhess; 66 | function Hess = ehess2rhess(X, egrad, ehess, eta) 67 | X = prepare(X); 68 | 69 | % Riemannian gradient 70 | rgrad = egrad2rgrad(X, egrad); 71 | 72 | % Directional derivative of the Riemannian gradient 73 | Hess.L = ehess.L*X.LtL + 2*egrad.L*symm(eta.L'*X.L); 74 | Hess.R = ehess.R*X.RtR + 2*egrad.R*symm(eta.R'*X.R); 75 | 76 | % We need a correction term for the non-constant metric 77 | Hess.L = Hess.L - rgrad.L*((X.invLtL)*symm(X.L'*eta.L)) - eta.L*(X.invLtL*symm(X.L'*rgrad.L)) + X.L*(X.invLtL*symm(eta.L'*rgrad.L)); 78 | Hess.R = Hess.R - rgrad.R*((X.invRtR)*symm(X.R'*eta.R)) - eta.R*(X.invRtR*symm(X.R'*rgrad.R)) + X.R*(X.invRtR*symm(eta.R'*rgrad.R)); 79 | 80 | % Projection onto the horizontal space 81 | Hess = M.proj(X, Hess); 82 | end 83 | 84 | M.proj = @projection; 85 | % Projection of the vector eta onto the horizontal space 86 | function etaproj = projection(X, eta) 87 | X = prepare(X); 88 | 89 | SS = (X.LtL)*(X.RtR); 90 | AS = (X.LtL)*(X.R'*eta.R) - (eta.L'*X.L)*(X.RtR); 91 | Omega = lyap(SS, SS,-AS); 92 | etaproj.L = eta.L + X.L*Omega'; 93 | etaproj.R = eta.R - 
X.R*Omega; 94 | end 95 | 96 | M.tangent = M.proj; 97 | M.tangent2ambient = @(X, eta) eta; 98 | 99 | M.retr = @retraction; 100 | function Y = retraction(X, eta, t) 101 | if nargin < 3 102 | t = 1.0; 103 | end 104 | 105 | Y.L = X.L + t*eta.L; 106 | Y.R = X.R + t*eta.R; 107 | 108 | % Numerical conditioning step: A simpler version. 109 | % We need to ensure that L and R do not have very relative 110 | % skewed norms. 111 | 112 | scaling = norm(X.L, 'fro')/norm(X.R, 'fro'); 113 | scaling = sqrt(scaling); 114 | Y.L = Y.L / scaling; 115 | Y.R = Y.R * scaling; 116 | 117 | % These are reused in the computation of the gradient and Hessian 118 | Y = prepare(Y); 119 | end 120 | 121 | M.exp = @exponential; 122 | function Y = exponential(X, eta, t) 123 | if nargin < 3 124 | t = 1.0; 125 | end 126 | 127 | Y = retraction(X, eta, t); 128 | warning('manopt:fixedrankfactory_2factors:exp', ... 129 | ['Exponential for fixed rank ' ... 130 | 'manifold not implemented yet. Used retraction instead.']); 131 | end 132 | 133 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.R(:)])]; 134 | 135 | M.rand = @random; 136 | function X = random() 137 | % A random point on the total space 138 | X.L = randn(m, k); 139 | X.R = randn(n, k); 140 | X = prepare(X); 141 | end 142 | 143 | M.randvec = @randomvec; 144 | function eta = randomvec(X) 145 | % A random vector in the horizontal space 146 | eta.L = randn(m, k); 147 | eta.R = randn(n, k); 148 | eta = projection(X, eta); 149 | nrm = M.norm(X, eta); 150 | eta.L = eta.L / nrm; 151 | eta.R = eta.R / nrm; 152 | end 153 | 154 | M.lincomb = @lincomb; 155 | 156 | M.zerovec = @(X) struct('L', zeros(m, k),'R', zeros(n, k)); 157 | 158 | M.transp = @(x1, x2, d) projection(x2, d); 159 | 160 | % vec and mat are not isometries, because of the unusual inner metric. 161 | M.vec = @(X, U) [U.L(:) ; U.R(:)]; 162 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ... 
163 | 'R', reshape(u((m*k+1):end), n, k)); 164 | M.vecmatareisometries = @() false; 165 | 166 | end 167 | 168 | % Linear combination of tangent vectors 169 | function d = lincomb(x, a1, d1, a2, d2) %#ok 170 | 171 | if nargin == 3 172 | d.L = a1*d1.L; 173 | d.R = a1*d1.R; 174 | elseif nargin == 5 175 | d.L = a1*d1.L + a2*d2.L; 176 | d.R = a1*d1.R + a2*d2.R; 177 | else 178 | error('Bad use of fixedrankfactory_2factors.lincomb.'); 179 | end 180 | 181 | end 182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/fixedrank/fixedrankfactory_2factors_preconditioned.m: -------------------------------------------------------------------------------- 1 | function M = fixedrankfactory_2factors_preconditioned(m, n, k) 2 | % Manifold of m-by-n matrices of rank k with new balanced quotient geometry 3 | % 4 | % function M = fixedrankfactory_2factors_preconditioned(m, n, k) 5 | % 6 | % This follows the quotient geometry described in the following paper: 7 | % B. Mishra, K. Adithya Apuroop and R. Sepulchre, 8 | % "A Riemannian geometry for low-rank matrix completion", 9 | % arXiv, 2012. 10 | % 11 | % Paper link: http://arxiv.org/abs/1211.1550 12 | % 13 | % This geometry is tuned to least-squares problems such as low-rank matrix 14 | % completion. 15 | % 16 | % A point X on the manifold is represented as a structure with two 17 | % fields: L and R. The matrices L (mxk) and R (nxk) are full column-rank 18 | % matrices. 19 | % 20 | % Tangent vectors are represented as a structure with two fields: L, R 21 | 22 | % This file is part of Manopt: www.manopt.org. 23 | % Original author: Bamdev Mishra, Dec. 30, 2012.
24 | % Contributors: 25 | % Change log: 26 | 27 | 28 | 29 | M.name = @() sprintf('LR''(tuned for least square problems) quotient manifold of %dx%d matrices of rank %d', m, n, k); 30 | 31 | M.dim = @() (m+n-k)*k; 32 | 33 | 34 | % Some precomputations at the point X to be used in the inner product (and 35 | % pretty much everywhere else). 36 | function X = prepare(X) 37 | if ~all(isfield(X,{'LtL','RtR','invRtR','invLtL'})) 38 | L = X.L; 39 | R = X.R; 40 | X.LtL = L'*L; 41 | X.RtR = R'*R; 42 | X.invLtL = inv(X.LtL); 43 | X.invRtR = inv(X.RtR); 44 | end 45 | end 46 | 47 | 48 | % The choice of metric is motivated by symmetry and tuned to least square 49 | % objective function 50 | M.inner = @iproduct; 51 | function ip = iproduct(X, eta, zeta) 52 | X = prepare(X); 53 | 54 | ip = trace(X.RtR*(eta.L'*zeta.L)) + trace(X.LtL*(eta.R'*zeta.R)); 55 | end 56 | 57 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); 58 | 59 | M.dist = @(x, y) error('fixedrankfactory_2factors_preconditioned.dist not implemented yet.'); 60 | 61 | M.typicaldist = @() 10*k; 62 | 63 | symm = @(M) .5*(M+M'); 64 | 65 | M.egrad2rgrad = @egrad2rgrad; 66 | function eta = egrad2rgrad(X, eta) 67 | X = prepare(X); 68 | 69 | eta.L = eta.L*X.invRtR; 70 | eta.R = eta.R*X.invLtL; 71 | end 72 | 73 | M.ehess2rhess = @ehess2rhess; 74 | function Hess = ehess2rhess(X, egrad, ehess, eta) 75 | X = prepare(X); 76 | 77 | % Riemannian gradient 78 | rgrad = egrad2rgrad(X, egrad); 79 | 80 | % Directional derivative of the Riemannian gradient 81 | Hess.L = ehess.L*X.invRtR - 2*egrad.L*(X.invRtR * symm(eta.R'*X.R) * X.invRtR); 82 | Hess.R = ehess.R*X.invLtL - 2*egrad.R*(X.invLtL * symm(eta.L'*X.L) * X.invLtL); 83 | 84 | % We still need a correction factor for the non-constant metric 85 | Hess.L = Hess.L + rgrad.L*(symm(eta.R'*X.R)*X.invRtR) + eta.L*(symm(rgrad.R'*X.R)*X.invRtR) - X.L*(symm(eta.R'*rgrad.R)*X.invRtR); 86 | Hess.R = Hess.R + rgrad.R*(symm(eta.L'*X.L)*X.invLtL) + eta.R*(symm(rgrad.L'*X.L)*X.invLtL) - 
X.R*(symm(eta.L'*rgrad.L)*X.invLtL); 87 | 88 | % Project on the horizontal space 89 | Hess = M.proj(X, Hess); 90 | 91 | end 92 | 93 | M.proj = @projection; 94 | function etaproj = projection(X, eta) 95 | X = prepare(X); 96 | 97 | Lambda = (eta.R'*X.R)*X.invRtR - X.invLtL*(X.L'*eta.L); 98 | Lambda = Lambda/2; 99 | etaproj.L = eta.L + X.L*Lambda; 100 | etaproj.R = eta.R - X.R*Lambda'; 101 | end 102 | 103 | M.tangent = M.proj; 104 | M.tangent2ambient = @(X, eta) eta; 105 | 106 | 107 | 108 | M.retr = @retraction; 109 | function Y = retraction(X, eta, t) 110 | if nargin < 3 111 | t = 1.0; 112 | end 113 | Y.L = X.L + t*eta.L; 114 | Y.R = X.R + t*eta.R; 115 | 116 | % Numerical conditioning step: A simpler version. 117 | % We need to ensure that L and R are do not have very relative 118 | % skewed norms. 119 | 120 | scaling = norm(X.L, 'fro')/norm(X.R, 'fro'); 121 | scaling = sqrt(scaling); 122 | Y.L = Y.L / scaling; 123 | Y.R = Y.R * scaling; 124 | 125 | % These are reused in the computation of the gradient and Hessian 126 | Y = prepare(Y); 127 | end 128 | 129 | 130 | M.exp = @exponential; 131 | function Y = exponential(X, eta, t) 132 | if nargin < 3 133 | t = 1.0; 134 | end 135 | 136 | Y = retraction(X, eta, t); 137 | warning('manopt:fixedrankfactory_2factors_preconditioned:exp', ... 138 | ['Exponential for fixed rank ' ... 139 | 'manifold not implemented yet. 
Used retraction instead.']); 140 | end 141 | 142 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.R(:)])]; 143 | 144 | M.rand = @random; 145 | 146 | function X = random() 147 | X.L = randn(m, k); 148 | X.R = randn(n, k); 149 | end 150 | 151 | M.randvec = @randomvec; 152 | function eta = randomvec(X) 153 | eta.L = randn(m, k); 154 | eta.R = randn(n, k); 155 | eta = projection(X, eta); 156 | nrm = M.norm(X, eta); 157 | eta.L = eta.L / nrm; 158 | eta.R = eta.R / nrm; 159 | end 160 | 161 | M.lincomb = @lincomb; 162 | 163 | M.zerovec = @(X) struct('L', zeros(m, k),'R', zeros(n, k)); 164 | 165 | M.transp = @(x1, x2, d) projection(x2, d); 166 | 167 | % vec and mat are not isometries, because of the unusual inner metric. 168 | M.vec = @(X, U) [U.L(:) ; U.R(:)]; 169 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ... 170 | 'R', reshape(u((m*k+1):end), n, k)); 171 | M.vecmatareisometries = @() false; 172 | 173 | end 174 | 175 | % Linear combination of tangent vectors 176 | function d = lincomb(x, a1, d1, a2, d2) %#ok 177 | 178 | if nargin == 3 179 | d.L = a1*d1.L; 180 | d.R = a1*d1.R; 181 | elseif nargin == 5 182 | d.L = a1*d1.L + a2*d2.L; 183 | d.R = a1*d1.R + a2*d2.R; 184 | else 185 | error('Bad use of fixedrankfactory_2factors_preconditioned.lincomb.'); 186 | end 187 | 188 | end 189 | 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/fixedrank/fixedrankfactory_3factors.m: -------------------------------------------------------------------------------- 1 | function M = fixedrankfactory_3factors(m, n, k) 2 | % Manifold of m-by-n matrices of rank k with polar quotient geometry. 3 | % 4 | % function M = fixedrankfactory_3factors(m, n, k) 5 | % 6 | % Follows the polar quotient geometry described in the following paper: 7 | % G. Meyer, S. Bonnabel and R. Sepulchre, 8 | % "Linear regression under fixed-rank constraints: a Riemannian approach", 9 | % ICML 2011. 
10 | % 11 | % Paper link: http://www.icml-2011.org/papers/350_icmlpaper.pdf 12 | % 13 | % Additional reference is 14 | % 15 | % B. Mishra, R. Meyer, S. Bonnabel and R. Sepulchre 16 | % "Fixed-rank matrix factorizations and Riemannian low-rank optimization", 17 | % arXiv, 2012. 18 | % 19 | % Paper link: http://arxiv.org/abs/1209.0430 20 | % 21 | % A point X on the manifold is represented as a structure with three 22 | % fields: L, S and R. The matrices L (mxk) and R (nxk) are orthonormal, 23 | % while the matrix S (kxk) is a symmetric positive definite full rank 24 | % matrix. 25 | % 26 | % Tangent vectors are represented as a structure with three fields: L, S 27 | % and R. 28 | 29 | % This file is part of Manopt: www.manopt.org. 30 | % Original author: Bamdev Mishra, Dec. 30, 2012. 31 | % Contributors: 32 | % Change log: 33 | 34 | M.name = @() sprintf('LSR'' quotient manifold of %dx%d matrices of rank %d', m, n, k); 35 | 36 | M.dim = @() (m+n-k)*k; 37 | 38 | % Choice of the metric on the orthnormal space is motivated by the symmetry present in the 39 | % space. The metric on the positive definite space is its natural metric. 40 | M.inner = @(X, eta, zeta) eta.L(:).'*zeta.L(:) + eta.R(:).'*zeta.R(:) ... 
41 | + trace( (X.S\eta.S) * (X.S\zeta.S) ); 42 | 43 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); 44 | 45 | M.dist = @(x, y) error('fixedrankfactory_3factors.dist not implemented yet.'); 46 | 47 | M.typicaldist = @() 10*k; 48 | 49 | skew = @(X) .5*(X-X'); 50 | symm = @(X) .5*(X+X'); 51 | stiefel_proj = @(L, H) H - L*symm(L'*H); 52 | 53 | M.egrad2rgrad = @egrad2rgrad; 54 | function eta = egrad2rgrad(X, eta) 55 | eta.L = stiefel_proj(X.L, eta.L); 56 | eta.S = X.S*symm(eta.S)*X.S; 57 | eta.R = stiefel_proj(X.R, eta.R); 58 | end 59 | 60 | 61 | M.ehess2rhess = @ehess2rhess; 62 | function Hess = ehess2rhess(X, egrad, ehess, eta) 63 | 64 | % Riemannian gradient for the factor S 65 | rgrad.S = X.S*symm(egrad.S)*X.S; 66 | 67 | % Directional derivatives of the Riemannian gradient 68 | Hess.L = ehess.L - eta.L*symm(X.L'*egrad.L); 69 | Hess.L = stiefel_proj(X.L, Hess.L); 70 | 71 | Hess.R = ehess.R - eta.R*symm(X.R'*egrad.R); 72 | Hess.R = stiefel_proj(X.R, Hess.R); 73 | 74 | Hess.S = X.S*symm(ehess.S)*X.S + 2*symm(eta.S*symm(egrad.S)*X.S); 75 | 76 | % Correction factor for the non-constant metric on the factor S 77 | Hess.S = Hess.S - symm(eta.S*(X.S\rgrad.S)); 78 | 79 | % Projection onto the horizontal space 80 | Hess = M.proj(X, Hess); 81 | end 82 | 83 | 84 | M.proj = @projection; 85 | function etaproj = projection(X, eta) 86 | % First, projection onto the tangent space of the total sapce 87 | eta.L = stiefel_proj(X.L, eta.L); 88 | eta.R = stiefel_proj(X.R, eta.R); 89 | eta.S = symm(eta.S); 90 | 91 | % Then, projection onto the horizontal space 92 | SS = X.S*X.S; 93 | AS = X.S*(skew(X.L'*eta.L) + skew(X.R'*eta.R) - 2*skew(X.S\eta.S))*X.S; 94 | omega = lyap(SS, -AS); 95 | 96 | etaproj.L = eta.L - X.L*omega; 97 | etaproj.S = eta.S - (X.S*omega - omega*X.S); 98 | etaproj.R = eta.R - X.R*omega; 99 | end 100 | 101 | M.tangent = M.proj; 102 | M.tangent2ambient = @(X, eta) eta; 103 | 104 | M.retr = @retraction; 105 | function Y = retraction(X, eta, t) 106 | if nargin < 3 107 | t 
= 1.0; 108 | end 109 | 110 | L = chol(X.S); 111 | Y.S = L'*expm(L'\(t*eta.S)/L)*L; 112 | Y.L = uf(X.L + t*eta.L); 113 | Y.R = uf(X.R + t*eta.R); 114 | end 115 | 116 | M.exp = @exponential; 117 | function Y = exponential(X, eta, t) 118 | if nargin < 3 119 | t = 1.0; 120 | end 121 | Y = retraction(X, eta, t); % no exponential is coded for this geometry: the retraction stands in and the caller is warned 122 | warning('manopt:fixedrankfactory_3factors:exp', ... 123 | ['Exponential for fixed rank ' ... 124 | 'manifold not implemented yet. Used retraction instead.']); 125 | end 126 | 127 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.S(:) ; X.R(:)])]; 128 | 129 | M.rand = @random; 130 | % Factors L and R live on Stiefel manifolds, hence we will reuse 131 | % their random generator. 132 | stiefelm = stiefelfactory(m, k); 133 | stiefeln = stiefelfactory(n, k); 134 | function X = random() 135 | X.L = stiefelm.rand(); 136 | X.R = stiefeln.rand(); 137 | X.S = diag(1+rand(k, 1)); 138 | end 139 | 140 | M.randvec = @randomvec; 141 | function eta = randomvec(X) 142 | % A random vector on the horizontal space 143 | eta.L = randn(m, k); 144 | eta.R = randn(n, k); 145 | eta.S = randn(k, k); 146 | eta = projection(X, eta); 147 | nrm = M.norm(X, eta); 148 | eta.L = eta.L / nrm; 149 | eta.R = eta.R / nrm; 150 | eta.S = eta.S / nrm; 151 | end 152 | 153 | M.lincomb = @lincomb; 154 | 155 | M.zerovec = @(X) struct('L', zeros(m, k), 'S', zeros(k, k), ... 156 | 'R', zeros(n, k)); 157 | 158 | M.transp = @(x1, x2, d) projection(x2, d); 159 | 160 | % vec and mat are not isometries, because of the unusual inner metric. 161 | M.vec = @(X, U) [U.L(:) ; U.S(:); U.R(:)]; 162 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ... 163 | 'S', reshape(u((m*k+1): m*k + k*k), k, k), ... 
164 | 'R', reshape(u((m*k+ k*k + 1):end), n, k)); 165 | M.vecmatareisometries = @() false; 166 | 167 | end 168 | 169 | % Linear combination of tangent vectors 170 | function d = lincomb(x, a1, d1, a2, d2) %#ok 171 | 172 | if nargin == 3 173 | d.L = a1*d1.L; 174 | d.R = a1*d1.R; 175 | d.S = a1*d1.S; 176 | elseif nargin == 5 177 | d.L = a1*d1.L + a2*d2.L; 178 | d.R = a1*d1.R + a2*d2.R; 179 | d.S = a1*d1.S + a2*d2.S; 180 | else 181 | error('Bad use of fixedrankfactory_3factors.lincomb.'); 182 | end 183 | 184 | end 185 | 186 | function A = uf(A) 187 | [L, unused, R] = svd(A, 0); %#ok 188 | A = L*R'; 189 | end 190 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/rotations/randrot.m: -------------------------------------------------------------------------------- 1 | function R = randrot(n, N) 2 | % Generates uniformly random rotation matrices. 3 | % 4 | % function R = randrot(n, N) 5 | % 6 | % R is a n-by-n-by-N matrix such that each slice R(:, :, i) is an 7 | % orthogonal matrix of size n of determinant +1 (i.e., a matrix in SO(n)). 8 | % By default, N = 1. 9 | % Complexity: N times O(n^3). 10 | % Theory in Diaconis and Shahshahani 1987 for the uniformity on O(n); 11 | % With details in Mezzadri 2007, 12 | % "How to generate random matrices from the classical compact groups." 13 | % To ensure matrices in SO(n), we permute the two first columns when 14 | % the determinant is -1. 15 | % 16 | % See also: randskew 17 | 18 | % This file is part of Manopt: www.manopt.org. 19 | % Original author: Nicolas Boumal, Sept. 25, 2012. 20 | % Contributors: 21 | % Change log: 22 | 23 | if nargin < 2 24 | N = 1; 25 | end 26 | 27 | if n == 1 28 | R = ones(1, 1, N); 29 | return; 30 | end 31 | 32 | R = zeros(n, n, N); 33 | 34 | for i = 1 : N 35 | 36 | % Generated as such, Q is uniformly distributed over O(n), the set 37 | % of orthogonal matrices. 
38 | A = randn(n); 39 | [Q, RR] = qr(A); 40 | Q = Q * diag(sign(diag(RR))); %% Mezzadri 2007 41 | 42 | % If Q is in O(n) but not in SO(n), we permute the two first 43 | % columns of Q such that det(new Q) = -det(Q), hence the new Q will 44 | % be in SO(n), uniformly distributed. 45 | if det(Q) < 0 46 | Q(:, [1 2]) = Q(:, [2 1]); 47 | end 48 | 49 | R(:, :, i) = Q; 50 | 51 | end 52 | 53 | end 54 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/rotations/randskew.m: -------------------------------------------------------------------------------- 1 | function S = randskew(n, N) 2 | % Generates random skew symmetric matrices with normal entries. 3 | % 4 | % function S = randskew(n, N) 5 | % 6 | % S is an n-by-n-by-N matrix where each slice S(:, :, i) for i = 1..N is a 7 | % random skew-symmetric matrix with upper triangular entries distributed 8 | % independently following a normal distribution (Gaussian, zero mean, unit 9 | % variance). 10 | % 11 | % See also: randrot 12 | 13 | % This file is part of Manopt: www.manopt.org. 14 | % Original author: Nicolas Boumal, Sept. 25, 2012. 15 | % Contributors: 16 | % Change log: 17 | 18 | 19 | if nargin < 2 20 | N = 1; 21 | end 22 | 23 | % Subindices of the (strictly) upper triangular entries of an n-by-n 24 | % matrix 25 | [I J] = find(triu(ones(n), 1)); 26 | 27 | K = repmat(1:N, n*(n-1)/2, 1); 28 | 29 | % Indices of the strictly upper triangular entries of all N slices of 30 | % an n-by-n-by-N matrix 31 | L = sub2ind([n n N], repmat(I, N, 1), repmat(J, N, 1), K(:)); 32 | 33 | % Allocate memory for N random skew matrices of size n-by-n and 34 | % populate each upper triangular entry with a random number following a 35 | % normal distribution and copy them with opposite sign on the 36 | % corresponding lower triangular side. 
37 | S = zeros(n, n, N); 38 | S(L) = randn(size(L)); 39 | S = S-multitransp(S); 40 | 41 | end 42 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/rotations/rotationsfactory.m: -------------------------------------------------------------------------------- 1 | function M = rotationsfactory(n, k) 2 | % Returns a manifold structure to optimize over rotation matrices. 3 | % 4 | % function M = rotationsfactory(n) 5 | % function M = rotationsfactory(n, k) 6 | % 7 | % Special orthogonal group (the manifold of rotations): deals with matrices 8 | % R of size n x n x k (or n x n if k = 1, which is the default) such that 9 | % each n x n matrix is orthogonal, with determinant 1, i.e., X'*X = eye(n) 10 | % if k = 1, or X(:, :, i)' * X(:, :, i) = eye(n) for i = 1 : k if k > 1. 11 | % 12 | % This is a description of SO(n)^k with the induced metric from the 13 | % embedding space (R^nxn)^k, i.e., this manifold is a Riemannian 14 | % submanifold of (R^nxn)^k endowed with the usual trace inner product. 15 | % 16 | % Tangent vectors are represented in the Lie algebra, i.e., as skew 17 | % symmetric matrices. Use the function M.tangent2ambient(X, H) to switch 18 | % from the Lie algebra representation to the embedding space 19 | % representation. 20 | % 21 | % By default, k = 1. 22 | % 23 | % See also: stiefelfactory 24 | 25 | % This file is part of Manopt: www.manopt.org. 26 | % Original author: Nicolas Boumal, Dec. 30, 2012. 27 | % Contributors: 28 | % Change log: 29 | % Jan. 
31, 2013, NB : added egrad2rgrad and ehess2rhess 30 | 31 | 32 | if ~exist('k', 'var') || isempty(k) 33 | k = 1; 34 | end 35 | 36 | if k == 1 37 | M.name = @() sprintf('Rotations manifold SO(%d)', n); 38 | elseif k > 1 39 | M.name = @() sprintf('Product rotations manifold SO(%d)^%d', n, k); 40 | else 41 | error('k must be an integer no less than 1.'); 42 | end 43 | 44 | M.dim = @() k*nchoosek(n, 2); 45 | 46 | M.inner = @(x, d1, d2) d1(:).'*d2(:); 47 | 48 | M.norm = @(x, d) norm(d(:)); 49 | 50 | M.typicaldist = @() pi*sqrt(n*k); 51 | 52 | M.proj = @(X, H) multiskew(multiprod(multitransp(X), H)); 53 | 54 | M.tangent = @(X, H) multiskew(H); 55 | 56 | M.tangent2ambient = @(X, U) multiprod(X, U); 57 | 58 | M.egrad2rgrad = M.proj; 59 | 60 | M.ehess2rhess = @ehess2rhess; 61 | function Rhess = ehess2rhess(X, Egrad, Ehess, H) 62 | % Reminder : H contains skew-symmetric matrices. The actual 63 | % direction that the point X is moved along is X*H. 64 | Xt = multitransp(X); 65 | XtEgrad = multiprod(Xt, Egrad); 66 | symXtEgrad = multisym(XtEgrad); 67 | XtEhess = multiprod(Xt, Ehess); 68 | Rhess = multiskew( XtEhess - multiprod(H, symXtEgrad) ); 69 | end 70 | 71 | M.retr = @retraction; 72 | function Y = retraction(X, U, t) 73 | if nargin == 3 74 | tU = t*U; 75 | else 76 | tU = U; 77 | end 78 | Y = X + multiprod(X, tU); 79 | for i = 1 : k 80 | [Q R] = qr(Y(:, :, i)); 81 | % The instruction with R ensures we are not flipping signs 82 | % of some columns, which should never happen in modern Matlab 83 | % versions but may be an issue with older versions. 84 | Y(:, :, i) = Q * diag(sign(sign(diag(R))+.5)); 85 | % This is guaranteed to always yield orthogonal matrices with 86 | % determinant +1. Simply look at the eigenvalues of a skew 87 | % symmetric matrix, then at those of identity plus that matrix, 88 | % and compute their product for the determinant: it's strictly 89 | % positive in all cases.
90 | end 91 | end 92 | 93 | M.exp = @exponential; 94 | function Y = exponential(X, U, t) 95 | if nargin == 3 96 | exptU = t*U; 97 | else 98 | exptU = U; 99 | end 100 | for i = 1 : k 101 | exptU(:, :, i) = expm(exptU(:, :, i)); 102 | end 103 | Y = multiprod(X, exptU); 104 | end 105 | 106 | M.log = @logarithm; 107 | function U = logarithm(X, Y) 108 | U = multiprod(multitransp(X), Y); 109 | for i = 1 : k 110 | % The result of logm should be real in theory, but it is 111 | % numerically useful to force it. 112 | U(:, :, i) = real(logm(U(:, :, i))); 113 | end 114 | % Ensure the tangent vector is in the Lie algebra. 115 | U = multiskew(U); 116 | end 117 | 118 | M.hash = @(X) ['z' hashmd5(X(:))]; 119 | 120 | M.rand = @() randrot(n, k); 121 | 122 | M.randvec = @randomvec; 123 | function U = randomvec(X) %#ok 124 | U = randskew(n, k); 125 | nrmU = sqrt(U(:).'*U(:)); 126 | U = U / nrmU; 127 | end 128 | 129 | M.lincomb = @lincomb; 130 | 131 | M.zerovec = @(x) zeros(n, n, k); 132 | 133 | M.transp = @(x1, x2, d) d; 134 | 135 | M.pairmean = @pairmean; 136 | function Y = pairmean(X1, X2) 137 | V = M.log(X1, X2); 138 | Y = M.exp(X1, .5*V); 139 | end 140 | 141 | M.dist = @(x, y) M.norm(x, M.log(x, y)); 142 | 143 | M.vec = @(x, u_mat) u_mat(:); 144 | M.mat = @(x, u_vec) reshape(u_vec, [n, n, k]); 145 | M.vecmatareisometries = @() true; 146 | 147 | end 148 | 149 | % Linear combination of tangent vectors 150 | function d = lincomb(x, a1, d1, a2, d2) %#ok 151 | 152 | if nargin == 3 153 | d = a1*d1; 154 | elseif nargin == 5 155 | d = a1*d1 + a2*d2; 156 | else 157 | error('Bad use of rotations.lincomb.'); 158 | end 159 | 160 | end 161 | 162 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/sphere/spherecomplexfactory.m: -------------------------------------------------------------------------------- 1 | function M = spherecomplexfactory(n, m) 2 | % Returns a manifold struct to optimize over unit-norm complex matrices. 
3 | % 4 | % function M = spherecomplexfactory(n) 5 | % function M = spherecomplexfactory(n, m) 6 | % 7 | % Manifold of n-by-m complex matrices of unit Frobenius norm. 8 | % By default, m = 1, which corresponds to the unit sphere in C^n. The 9 | % metric is such that the sphere is a Riemannian submanifold of the space 10 | % of 2nx2m real matrices with the usual trace inner product, i.e., the 11 | % usual metric. 12 | % 13 | % See also: spherefactory 14 | 15 | % This file is part of Manopt: www.manopt.org. 16 | % Original author: Nicolas Boumal, Dec. 30, 2012. 17 | % Contributors: 18 | % Change log: 19 | 20 | 21 | if ~exist('m', 'var') 22 | m = 1; 23 | end 24 | 25 | if m == 1 26 | M.name = @() sprintf('Complex sphere S^%d', n-1); 27 | else 28 | M.name = @() sprintf('Unit F-norm %dx%d complex matrices', n, m); 29 | end 30 | 31 | M.dim = @() 2*(n*m)-1; 32 | 33 | M.inner = @(x, d1, d2) real(d1(:)'*d2(:)); 34 | 35 | M.norm = @(x, d) norm(d, 'fro'); 36 | 37 | M.dist = @(x, y) real(acos(real(x(:)'*y(:)))); % outer real(): round-off can push the inner product just outside [-1, 1], where acos returns a complex value 38 | 39 | M.typicaldist = @() pi; 40 | 41 | M.proj = @(x, d) reshape(d(:) - x(:)*(real(x(:)'*d(:))), n, m); 42 | 43 | % For Riemannian submanifolds, converting a Euclidean gradient into a 44 | % Riemannian gradient amounts to an orthogonal projection.
45 | M.egrad2rgrad = M.proj; 46 | 47 | M.tangent = M.proj; 48 | 49 | M.exp = @exponential; 50 | 51 | M.retr = @retraction; 52 | 53 | M.log = @logarithm; 54 | function v = logarithm(x1, x2) 55 | error('The logarithmic map is not yet implemented for this manifold.'); 56 | end 57 | 58 | M.hash = @(x) ['z' hashmd5([real(x(:)) ; imag(x(:))])]; 59 | 60 | M.rand = @() random(n, m); 61 | 62 | M.randvec = @(x) randomvec(n, m, x); 63 | 64 | M.lincomb = @lincomb; 65 | 66 | M.zerovec = @(x) zeros(n, m); 67 | 68 | M.transp = @(x1, x2, d) M.proj(x2, d); 69 | 70 | M.pairmean = @pairmean; 71 | function y = pairmean(x1, x2) 72 | y = x1+x2; 73 | y = y / norm(y, 'fro'); 74 | end 75 | 76 | end 77 | 78 | % Exponential on the sphere 79 | function y = exponential(x, d, t) 80 | 81 | if nargin == 2 82 | t = 1; 83 | end 84 | 85 | td = t*d; 86 | 87 | nrm_td = norm(td, 'fro'); 88 | 89 | if nrm_td > 1e-6 90 | y = x*cos(nrm_td) + td*(sin(nrm_td)/nrm_td); 91 | else 92 | % If the step is too small, to avoid dividing by nrm_td, we choose 93 | % to approximate with this retraction-like step. 94 | y = x + td; 95 | y = y / norm(y, 'fro'); 96 | end 97 | 98 | end 99 | 100 | % Retraction on the sphere 101 | function y = retraction(x, d, t) 102 | 103 | if nargin == 2 104 | t = 1; 105 | end 106 | 107 | y = x+t*d; 108 | y = y/norm(y, 'fro'); 109 | 110 | end 111 | 112 | % Uniform random sampling on the sphere. 113 | function x = random(n, m) 114 | 115 | x = randn(n, m) + 1i*randn(n, m); 116 | x = x/norm(x, 'fro'); 117 | 118 | end 119 | 120 | % Random normalized tangent vector at x. 
121 | function d = randomvec(n, m, x) 122 | 123 | d = randn(n, m) + 1i*randn(n, m); 124 | d = reshape(d(:) - x(:)*(real(x(:)'*d(:))), n, m); 125 | d = d / norm(d, 'fro'); 126 | 127 | end 128 | 129 | % Linear combination of tangent vectors 130 | function d = lincomb(x, a1, d1, a2, d2) %#ok 131 | 132 | if nargin == 3 133 | d = a1*d1; 134 | elseif nargin == 5 135 | d = a1*d1 + a2*d2; 136 | else 137 | error('Bad use of spherecomplex.lincomb.'); 138 | end 139 | 140 | end 141 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/sphere/spherefactory.m: -------------------------------------------------------------------------------- 1 | function M = spherefactory(n, m) 2 | % Returns a manifold struct to optimize over unit-norm vectors or matrices. 3 | % 4 | % function M = spherefactory(n) 5 | % function M = spherefactory(n, m) 6 | % 7 | % Manifold of n-by-m real matrices of unit Frobenius norm. 8 | % By default, m = 1, which corresponds to the unit sphere in R^n. The 9 | % metric is such that the sphere is a Riemannian submanifold of the space 10 | % of nxm matrices with the usual trace inner product, i.e., the usual 11 | % metric. 12 | % 13 | % See also: obliquefactory spherecomplexfactory 14 | 15 | % This file is part of Manopt: www.manopt.org. 16 | % Original author: Nicolas Boumal, Dec. 30, 2012. 
17 | % Contributors: 18 | % Change log: 19 | 20 | 21 | if ~exist('m', 'var') 22 | m = 1; 23 | end 24 | 25 | if m == 1 26 | M.name = @() sprintf('Sphere S^%d', n-1); 27 | else 28 | M.name = @() sprintf('Unit F-norm %dx%d matrices', n, m); 29 | end 30 | 31 | M.dim = @() n*m-1; 32 | 33 | M.inner = @(x, d1, d2) d1(:).'*d2(:); 34 | 35 | M.norm = @(x, d) norm(d, 'fro'); 36 | 37 | M.dist = @(x, y) real(acos(x(:).'*y(:))); 38 | 39 | M.typicaldist = @() pi; 40 | 41 | M.proj = @(x, d) d - x*(x(:).'*d(:)); 42 | 43 | M.tangent = M.proj; 44 | 45 | % For Riemannian submanifolds, converting a Euclidean gradient into a 46 | % Riemannian gradient amounts to an orthogonal projection. 47 | M.egrad2rgrad = M.proj; 48 | 49 | M.ehess2rhess = @ehess2rhess; 50 | function rhess = ehess2rhess(x, egrad, ehess, u) 51 | rhess = M.proj(x, ehess) - (x(:)'*egrad(:))*u; 52 | end 53 | 54 | M.exp = @exponential; 55 | 56 | M.retr = @retraction; 57 | 58 | M.log = @logarithm; 59 | function v = logarithm(x1, x2) % tangent vector at x1 pointing towards x2, rescaled to length dist(x1, x2) 60 | v = M.proj(x1, x2 - x1); 61 | di = M.dist(x1, x2); 62 | nv = norm(v, 'fro'); 63 | if di > 1e-6, v = v * (di / nv); end % skip the rescaling for (near-)coincident points, where nv = 0 would give 0/0 = NaN 64 | end 65 | 66 | M.hash = @(x) ['z' hashmd5(x(:))]; 67 | 68 | M.rand = @() random(n, m); 69 | 70 | M.randvec = @(x) randomvec(n, m, x); 71 | 72 | M.lincomb = @lincomb; 73 | 74 | M.zerovec = @(x) zeros(n, m); 75 | 76 | M.transp = @(x1, x2, d) M.proj(x2, d); 77 | 78 | M.pairmean = @pairmean; 79 | function y = pairmean(x1, x2) 80 | y = x1+x2; 81 | y = y / norm(y, 'fro'); 82 | end 83 | 84 | M.vec = @(x, u_mat) u_mat(:); 85 | M.mat = @(x, u_vec) reshape(u_vec, [n, m]); 86 | M.vecmatareisometries = @() true; 87 | 88 | end 89 | 90 | % Exponential on the sphere 91 | function y = exponential(x, d, t) 92 | 93 | if nargin == 2 94 | t = 1; 95 | end 96 | 97 | td = t*d; 98 | 99 | nrm_td = norm(td, 'fro'); 100 | 101 | if nrm_td > 1e-6 102 | y = x*cos(nrm_td) + td*(sin(nrm_td)/nrm_td); 103 | else 104 | % if the step is too small, to avoid dividing by nrm_td, we choose 105 | % to approximate with this 
retraction-like step. 106 | y = x + td; 107 | y = y / norm(y, 'fro'); 108 | end 109 | 110 | end 111 | 112 | % Retraction on the sphere 113 | function y = retraction(x, d, t) 114 | 115 | if nargin == 2 116 | t = 1; 117 | end 118 | 119 | y = x + t*d; 120 | y = y / norm(y, 'fro'); 121 | 122 | end 123 | 124 | % Uniform random sampling on the sphere. 125 | function x = random(n, m) 126 | 127 | x = randn(n, m); 128 | x = x/norm(x, 'fro'); 129 | 130 | end 131 | 132 | % Random normalized tangent vector at x. 133 | function d = randomvec(n, m, x) 134 | 135 | d = randn(n, m); 136 | d = d - x*(x(:).'*d(:)); 137 | d = d / norm(d, 'fro'); 138 | 139 | end 140 | 141 | % Linear combination of tangent vectors 142 | function d = lincomb(x, a1, d1, a2, d2) %#ok 143 | 144 | if nargin == 3 145 | d = a1*d1; 146 | elseif nargin == 5 147 | d = a1*d1 + a2*d2; 148 | else 149 | error('Bad use of sphere.lincomb.'); 150 | end 151 | 152 | end 153 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/stiefel/stiefelfactory.m: -------------------------------------------------------------------------------- 1 | function M = stiefelfactory(n, p, k) 2 | % Returns a manifold structure to optimize over orthonormal matrices. 3 | % 4 | % function M = stiefelfactory(n, p) 5 | % function M = stiefelfactory(n, p, k) 6 | % 7 | % The Stiefel manifold is the set of orthonormal nxp matrices. If k 8 | % is larger than 1, this is the Cartesian product of the Stiefel manifold 9 | % taken k times. The metric is such that the manifold is a Riemannian 10 | % submanifold of R^nxp equipped with the usual trace inner product, that 11 | % is, it is the usual metric. 12 | % 13 | % Points are represented as matrices X of size n x p x k (or n x p if k=1, 14 | % which is the default) such that each n x p matrix is orthonormal, 15 | % i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for 16 | % i = 1 : k if k > 1. 
Tangent vectors are represented as matrices the same 17 | % size as points. 18 | % 19 | % By default, k = 1. 20 | % 21 | % See also: grassmannfactory rotationsfactory 22 | 23 | % This file is part of Manopt: www.manopt.org. 24 | % Original author: Nicolas Boumal, Dec. 30, 2012. 25 | % Contributors: 26 | % Change log: 27 | % July 5, 2013 (NB) : Added ehess2rhess. 28 | % Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected. 29 | % June 24, 2014 (NB) : Added true exponential map and changed the randvec 30 | % function so that it now returns a globally 31 | % normalized vector, not a vector where each 32 | % component is normalized (this only matters if k>1). 33 | 34 | 35 | if ~exist('k', 'var') || isempty(k) 36 | k = 1; 37 | end 38 | 39 | if k == 1 40 | M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p); 41 | elseif k > 1 42 | M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k); 43 | else 44 | error('k must be an integer no less than 1.'); 45 | end 46 | 47 | M.dim = @() k*(n*p - .5*p*(p+1)); 48 | 49 | M.inner = @(x, d1, d2) d1(:).'*d2(:); 50 | 51 | M.norm = @(x, d) norm(d(:)); 52 | 53 | M.dist = @(x, y) error('stiefel.dist not implemented yet.'); 54 | 55 | M.typicaldist = @() sqrt(p*k); 56 | 57 | M.proj = @projection; 58 | function Up = projection(X, U) 59 | 60 | XtU = multiprod(multitransp(X), U); 61 | symXtU = multisym(XtU); 62 | Up = U - multiprod(X, symXtU); 63 | 64 | % The code above is equivalent to, but much faster than, the code below. 65 | % 66 | % Up = zeros(size(U)); 67 | % function A = sym(A), A = .5*(A+A'); end 68 | % for i = 1 : k 69 | % Xi = X(:, :, i); 70 | % Ui = U(:, :, i); 71 | % Up(:, :, i) = Ui - Xi*sym(Xi'*Ui); 72 | % end 73 | 74 | end 75 | 76 | M.tangent = M.proj; 77 | 78 | % For Riemannian submanifolds, converting a Euclidean gradient into a 79 | % Riemannian gradient amounts to an orthogonal projection. 
80 | M.egrad2rgrad = M.proj; 81 | 82 | M.ehess2rhess = @ehess2rhess; 83 | function rhess = ehess2rhess(X, egrad, ehess, H) 84 | XtG = multiprod(multitransp(X), egrad); 85 | symXtG = multisym(XtG); 86 | HsymXtG = multiprod(H, symXtG); 87 | rhess = projection(X, ehess - HsymXtG); 88 | end 89 | 90 | M.retr = @retraction; 91 | function Y = retraction(X, U, t) 92 | if nargin < 3 93 | t = 1.0; 94 | end 95 | Y = X + t*U; 96 | for i = 1 : k 97 | [Q, R] = qr(Y(:, :, i), 0); 98 | % The instruction with R assures we are not flipping signs 99 | % of some columns, which should never happen in modern Matlab 100 | % versions but may be an issue with older versions. 101 | Y(:, :, i) = Q * diag(sign(sign(diag(R))+.5)); 102 | end 103 | end 104 | 105 | M.exp = @exponential; 106 | function Y = exponential(X, U, t) 107 | if nargin == 2 108 | t = 1; 109 | end 110 | tU = t*U; 111 | Y = zeros(size(X)); 112 | for i = 1 : k 113 | % From a formula by Ross Lippert, Example 5.4.2 in AMS08. 114 | Xi = X(:, :, i); 115 | Ui = tU(:, :, i); 116 | Y(:, :, i) = [Xi Ui] * ... 117 | expm([Xi'*Ui , -Ui'*Ui ; eye(p) , Xi'*Ui]) * ... 
118 | [ expm(-Xi'*Ui) ; zeros(p) ]; 119 | end 120 | 121 | end 122 | 123 | M.hash = @(X) ['z' hashmd5(X(:))]; 124 | 125 | M.rand = @random; 126 | function X = random() 127 | X = zeros(n, p, k); 128 | for i = 1 : k 129 | [Q, unused] = qr(randn(n, p), 0); %#ok 130 | X(:, :, i) = Q; 131 | end 132 | end 133 | 134 | M.randvec = @randomvec; 135 | function U = randomvec(X) 136 | U = projection(X, randn(n, p, k)); 137 | U = U / norm(U(:)); 138 | end 139 | 140 | M.lincomb = @lincomb; 141 | 142 | M.zerovec = @(x) zeros(n, p, k); 143 | 144 | M.transp = @(x1, x2, d) projection(x2, d); 145 | 146 | M.vec = @(x, u_mat) u_mat(:); 147 | M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]); 148 | M.vecmatareisometries = @() true; 149 | 150 | end 151 | 152 | % Linear combination of tangent vectors 153 | function d = lincomb(x, a1, d1, a2, d2) %#ok 154 | 155 | if nargin == 3 156 | d = a1*d1; 157 | elseif nargin == 5 158 | d = a1*d1 + a2*d2; 159 | else 160 | error('Bad use of stiefel.lincomb.'); 161 | end 162 | 163 | end 164 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/symfixedrank/spectrahedronfactory.m: -------------------------------------------------------------------------------- 1 | function M = spectrahedronfactory(n, k) 2 | % Manifold of n-by-n symmetric positive semidefinite matrices of rank k 3 | % with trace (sum of diagonal elements) being 1. 4 | % 5 | % function M = spectrahedronfactory(n, k) 6 | % 7 | % The geometry is based on the paper, 8 | % M. Journee, P.-A. Absil, F. Bach and R. Sepulchre, 9 | % "Low-Rank Optimization on the Cone of Positive Semidefinite Matrices", 10 | % SIOPT, 2010. 11 | % 12 | % Paper link: http://www.di.ens.fr/~fbach/journee2010_sdp.pdf 13 | % 14 | % A point X on the manifold is parameterized as YY^T where Y is a matrix of 15 | % size nxk. The matrix Y (nxk) is a full column-rank matrix. Hence, we deal 16 | % directly with Y. 
The trace constraint on X translates to the Frobenius 17 | % norm constraint on Y, i.e., trace(X) = || Y ||^2. 18 | 19 | % This file is part of Manopt: www.manopt.org. 20 | % Original author: Bamdev Mishra, July 11, 2013. 21 | % Contributors: 22 | % Change log: 23 | 24 | 25 | 26 | M.name = @() sprintf('YY'' quotient manifold of %dx%d PSD matrices of rank %d with trace 1 ', n, n, k); 27 | 28 | M.dim = @() (k*n - 1) - k*(k-1)/2; % Extra -1 is because of the trace constraint. 29 | 30 | % Euclidean metric on the total space 31 | M.inner = @(Y, eta, zeta) trace(eta'*zeta); 32 | 33 | M.norm = @(Y, eta) sqrt(M.inner(Y, eta, eta)); 34 | 35 | M.dist = @(Y, Z) error('spectrahedronfactory.dist not implemented yet.'); 36 | 37 | M.typicaldist = @() 10*k; 38 | 39 | M.proj = @projection; 40 | function etaproj = projection(Y, eta) 41 | % Projection onto the tangent space, i.e., on the tangent space of 42 | % ||Y|| = 1 43 | 44 | eta = eta - trace(eta'*Y)*Y; 45 | 46 | % Projection onto the horizontal space 47 | YtY = Y'*Y; 48 | SS = YtY; 49 | AS = Y'*eta - eta'*Y; 50 | Omega = lyap(SS, -AS); 51 | etaproj = eta - Y*Omega; 52 | end 53 | 54 | M.tangent = M.proj; 55 | M.tangent2ambient = @(Y, eta) eta; 56 | 57 | M.retr = @retraction; 58 | function Ynew = retraction(Y, eta, t) 59 | if nargin < 3 60 | t = 1.0; 61 | end 62 | Ynew = Y + t*eta; 63 | Ynew = Ynew/norm(Ynew,'fro'); 64 | end 65 | 66 | 67 | M.egrad2rgrad = @(Y, eta) eta - trace(eta'*Y)*Y; 68 | 69 | M.ehess2rhess = @ehess2rhess; 70 | function Hess = ehess2rhess(Y, egrad, ehess, eta) 71 | 72 | % Directional derivative of the Riemannian gradient 73 | Hess = ehess - trace(egrad'*Y)*eta - (trace(ehess'*Y) + trace(egrad'*eta))*Y; 74 | Hess = Hess - trace(Hess'*Y)*Y; 75 | 76 | % Project on the horizontal space 77 | Hess = M.proj(Y, Hess); 78 | 79 | end 80 | 81 | M.exp = @exponential; 82 | function Ynew = exponential(Y, eta, t) 83 | if nargin < 3 84 | t = 1.0; 85 | end 86 | 87 | Ynew = retraction(Y, eta, t); 88 | 
warning('manopt:spectrahedronfactory:exp', ... 89 | ['Exponential for fixed rank spectrahedron ' ... 90 | 'manifold not implemented yet. Used retraction instead.']); 91 | end 92 | 93 | % Notice that the hash of two equivalent points will be different... 94 | M.hash = @(Y) ['z' hashmd5(Y(:))]; 95 | 96 | M.rand = @random; 97 | 98 | function Y = random() 99 | Y = randn(n, k); 100 | Y = Y/norm(Y,'fro'); 101 | end 102 | 103 | M.randvec = @randomvec; 104 | function eta = randomvec(Y) 105 | eta = randn(n, k); 106 | eta = projection(Y, eta); 107 | nrm = M.norm(Y, eta); 108 | eta = eta / nrm; 109 | end 110 | 111 | M.lincomb = @lincomb; 112 | 113 | M.zerovec = @(Y) zeros(n, k); 114 | 115 | M.transp = @(Y1, Y2, d) projection(Y2, d); 116 | 117 | M.vec = @(Y, u_mat) u_mat(:); 118 | M.mat = @(Y, u_vec) reshape(u_vec, [n, k]); 119 | M.vecmatareisometries = @() true; 120 | 121 | end 122 | 123 | 124 | % Linear combination of tangent vectors 125 | function d = lincomb(Y, a1, d1, a2, d2) %#ok 126 | 127 | if nargin == 3 128 | d = a1*d1; 129 | elseif nargin == 5 130 | d = a1*d1 + a2*d2; 131 | else 132 | error('Bad use of spectrahedronfactory.lincomb.'); 133 | end 134 | 135 | end 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/symfixedrank/symfixedrankYYfactory.m: -------------------------------------------------------------------------------- 1 | function M = symfixedrankYYfactory(n, k) 2 | % Manifold of n-by-n symmetric positive semidefinite matrices of rank k. 3 | % 4 | % function M = symfixedrankYYfactory(n, k) 5 | % 6 | % The geometry is based on the paper, 7 | % M. Journee, P.-A. Absil, F. Bach and R. Sepulchre, 8 | % "Low-Rank Optimization on the Cone of Positive Semidefinite Matrices", 9 | % SIAM Journal on Optimization, 2010. 
10 | % 11 | % Paper link: http://www.di.ens.fr/~fbach/journee2010_sdp.pdf 12 | % 13 | % A point X on the manifold is parameterized as YY^T where Y is a matrix of 14 | % size nxk. The matrix Y (nxk) is a full column-rank matrix. Hence, we deal 15 | % directly with Y. 16 | % 17 | % Notice that this manifold is not complete: if optimization leads Y to be 18 | % rank-deficient, the geometry will break down. Hence, this geometry should 19 | % only be used if it is expected that the points of interest will have rank 20 | % exactly k. Reduce k if that is not the case. 21 | % 22 | % An alternative, complete, geometry for positive semidefinite matrices of 23 | % rank k is described in Bonnabel and Sepulchre 2009, "Riemannian Metric 24 | % and Geometric Mean for Positive Semidefinite Matrices of Fixed Rank", 25 | % SIAM Journal on Matrix Analysis and Applications. 26 | 27 | % This file is part of Manopt: www.manopt.org. 28 | % Original author: Bamdev Mishra, Dec. 30, 2012. 29 | % Contributors: 30 | % Change log: 31 | % July 10, 2013 (NB) 32 | % Added vec, mat, tangent, tangent2ambient ; 33 | % Correction for the dimension of the manifold. 
34 | 35 | 36 | M.name = @() sprintf('YY'' quotient manifold of %dx%d PSD matrices of rank %d', n, n, k); 37 | 38 | M.dim = @() k*n - k*(k-1)/2; 39 | 40 | % Euclidean metric on the total space 41 | M.inner = @(Y, eta, zeta) trace(eta'*zeta); 42 | 43 | M.norm = @(Y, eta) sqrt(M.inner(Y, eta, eta)); 44 | 45 | M.dist = @(Y, Z) error('symfixedrankYYfactory.dist not implemented yet.'); 46 | 47 | M.typicaldist = @() 10*k; 48 | 49 | M.proj = @projection; 50 | function etaproj = projection(Y, eta) 51 | % Projection onto the horizontal space 52 | YtY = Y'*Y; 53 | SS = YtY; 54 | AS = Y'*eta - eta'*Y; 55 | Omega = lyap(SS, -AS); 56 | etaproj = eta - Y*Omega; 57 | end 58 | 59 | M.tangent = M.proj; 60 | M.tangent2ambient = @(Y, eta) eta; 61 | 62 | M.retr = @retraction; 63 | function Ynew = retraction(Y, eta, t) 64 | if nargin < 3 65 | t = 1.0; 66 | end 67 | Ynew = Y + t*eta; 68 | end 69 | 70 | 71 | M.egrad2rgrad = @(Y, eta) eta; 72 | M.ehess2rhess = @(Y, egrad, ehess, U) M.proj(Y, ehess); 73 | 74 | M.exp = @exponential; 75 | function Ynew = exponential(Y, eta, t) 76 | if nargin < 3 77 | t = 1.0; 78 | end 79 | 80 | Ynew = retraction(Y, eta, t); 81 | warning('manopt:symfixedrankYYfactory:exp', ... 82 | ['Exponential for symmetric, fixed-rank ' ... 83 | 'manifold not implemented yet. Used retraction instead.']); 84 | end 85 | 86 | % Notice that the hash of two equivalent points will be different... 
87 | M.hash = @(Y) ['z' hashmd5(Y(:))]; 88 | 89 | M.rand = @random; 90 | 91 | function Y = random() 92 | Y = randn(n, k); 93 | end 94 | 95 | M.randvec = @randomvec; 96 | function eta = randomvec(Y) 97 | eta = randn(n, k); 98 | eta = projection(Y, eta); 99 | nrm = M.norm(Y, eta); 100 | eta = eta / nrm; 101 | end 102 | 103 | M.lincomb = @lincomb; 104 | 105 | M.zerovec = @(Y) zeros(n, k); 106 | 107 | M.transp = @(Y1, Y2, d) projection(Y2, d); 108 | 109 | M.vec = @(Y, u_mat) u_mat(:); 110 | M.mat = @(Y, u_vec) reshape(u_vec, [n, k]); 111 | M.vecmatareisometries = @() true; 112 | 113 | end 114 | 115 | 116 | % Linear combination of tangent vectors 117 | function d = lincomb(Y, a1, d1, a2, d2) %#ok 118 | 119 | if nargin == 3 120 | d = a1*d1; 121 | elseif nargin == 5 122 | d = a1*d1 + a2*d2; 123 | else 124 | error('Bad use of symfixedrankYYfactory.lincomb.'); 125 | end 126 | 127 | end 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /manopt/manopt/manifolds/symfixedrank/sympositivedefinitefactory.m: -------------------------------------------------------------------------------- 1 | function M = sympositivedefinitefactory(n) 2 | % Manifold of n-by-n symmetric positive definite matrices with 3 | % the bi-invariant geometry. 4 | % 5 | % function M = sympositivedefinitefactory(n) 6 | % 7 | % A point X on the manifold is represented as a symmetric positive definite 8 | % matrix X (nxn). 9 | % 10 | % The following material is referenced from Chapter 6 of the book: 11 | % Rajendra Bhatia, "Positive definite matrices", 12 | % Princeton University Press, 2007. 13 | 14 | % This file is part of Manopt: www.manopt.org. 15 | % Original author: Bamdev Mishra, August 29, 2013. 
16 | % Contributors: Nicolas Boumal 17 | % Change log: 18 | % 19 | % March 5, 2014 (NB) 20 | % There were a number of mistakes in the code owing to the tacit 21 | % assumption that if X and eta are symmetric, then X\eta is 22 | % symmetric too, which is not the case. See discussion on the Manopt 23 | % forum started on Jan. 19, 2014. Functions norm, dist, exp and log 24 | % were modified accordingly. Furthermore, they only require matrix 25 | % inversion (as well as matrix log or matrix exp), not matrix square 26 | % roots or their inverse. 27 | % 28 | % July 28, 2014 (NB) 29 | % The dim() function returned n*(n-1)/2 instead of n*(n+1)/2. 30 | % Implemented proper parallel transport from Sra and Hosseini (not 31 | % used by default). 32 | % Also added symmetrization in exp and log (to be sure). 33 | 34 | symm = @(X) .5*(X+X'); 35 | 36 | M.name = @() sprintf('Symmetric positive definite geometry of %dx%d matrices', n, n); 37 | 38 | M.dim = @() n*(n+1)/2; 39 | 40 | % Choice of the metric on the orthnormal space is motivated by the 41 | % symmetry present in the space. The metric on the positive definite 42 | % cone is its natural bi-invariant metric. 43 | M.inner = @(X, eta, zeta) trace( (X\eta) * (X\zeta) ); 44 | 45 | % Notice that X\eta is *not* symmetric in general. 46 | M.norm = @(X, eta) sqrt(trace((X\eta)^2)); 47 | 48 | % Same here: X\Y is not symmetric in general. There should be no need 49 | % to take the real part, but rounding errors may cause a small 50 | % imaginary part to appear, so we discard it. 
51 | M.dist = @(X, Y) sqrt(real(trace((logm(X\Y))^2))); 52 | 53 | 54 | M.typicaldist = @() sqrt(n*(n+1)/2); 55 | 56 | 57 | M.egrad2rgrad = @egrad2rgrad; 58 | function eta = egrad2rgrad(X, eta) 59 | eta = X*symm(eta)*X; 60 | end 61 | 62 | 63 | M.ehess2rhess = @ehess2rhess; 64 | function Hess = ehess2rhess(X, egrad, ehess, eta) 65 | % Directional derivatives of the Riemannian gradient 66 | Hess = X*symm(ehess)*X + 2*symm(eta*symm(egrad)*X); 67 | 68 | % Correction factor for the non-constant metric 69 | Hess = Hess - symm(eta*symm(egrad)*X); 70 | end 71 | 72 | 73 | M.proj = @(X, eta) symm(eta); 74 | 75 | M.tangent = M.proj; 76 | M.tangent2ambient = @(X, eta) eta; 77 | 78 | M.retr = @exponential; 79 | 80 | M.exp = @exponential; 81 | function Y = exponential(X, eta, t) 82 | if nargin < 3 83 | t = 1.0; 84 | end 85 | % The symm() and real() calls are mathematically not necessary but 86 | % are numerically necessary. 87 | Y = symm(X*real(expm(X\(t*eta)))); 88 | end 89 | 90 | M.log = @logarithm; 91 | function H = logarithm(X, Y) 92 | % Same remark regarding the calls to symm() and real(). 93 | H = symm(X*real(logm(X\Y))); 94 | end 95 | 96 | M.hash = @(X) ['z' hashmd5(X(:))]; 97 | 98 | % Generate a random symmetric positive definite matrix following a 99 | % certain distribution. The particular choice of a distribution is of 100 | % course arbitrary, and specific applications might require different 101 | % ones. 102 | M.rand = @random; 103 | function X = random() 104 | D = diag(1+rand(n, 1)); 105 | [Q, R] = qr(randn(n)); %#ok 106 | X = Q*D*Q'; 107 | end 108 | 109 | % Generate a uniformly random unit-norm tangent vector at X. 
110 | M.randvec = @randomvec; 111 | function eta = randomvec(X) 112 | eta = symm(randn(n)); 113 | nrm = M.norm(X, eta); 114 | eta = eta / nrm; 115 | end 116 | 117 | M.lincomb = @lincomb; 118 | 119 | M.zerovec = @(X) zeros(n); 120 | 121 | % Poor man's vector transport: exploit the fact that all tangent spaces 122 | % are the set of symmetric matrices, so that the identity is a sort of 123 | % vector transport. It may perform poorly if the origin and target (X1 124 | % and X2) are far apart though. This should not be the case for typical 125 | % optimization algorithms, which perform small steps. 126 | M.transp = @(X1, X2, eta) eta; 127 | 128 | % For reference, a proper vector transport is given here, following 129 | % work by Sra and Hosseini (2014), "Conic geometric optimisation on the 130 | % manifold of positive definite matrices", 131 | % http://arxiv.org/abs/1312.1039 132 | % This will not be used by default. To force the use of this transport, 133 | % call "M.transp = M.paralleltransp;" on your M returned by the present 134 | % factory. 135 | M.paralleltransp = @parallel_transport; 136 | function zeta = parallel_transport(X, Y, eta) 137 | E = sqrtm((Y/X)); 138 | zeta = E*eta*E'; 139 | end 140 | 141 | % vec and mat are not isometries, because of the unusual inner metric. 
142 | M.vec = @(X, U) U(:); 143 | M.mat = @(X, u) reshape(u, n, n); 144 | M.vecmatareisometries = @() false; 145 | 146 | end 147 | 148 | % Linear combination of tangent vectors 149 | function d = lincomb(X, a1, d1, a2, d2) %#ok 150 | if nargin == 3 151 | d = a1*d1; 152 | elseif nargin == 5 153 | d = a1*d1 + a2*d2; 154 | else 155 | error('Bad use of sympositivedefinitefactory.lincomb.'); 156 | end 157 | end 158 | 159 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/applyStatsfun.m: -------------------------------------------------------------------------------- 1 | function stats = applyStatsfun(problem, x, storedb, options, stats) 2 | % Apply the statsfun function to a stats structure (for solvers). 3 | % 4 | % function stats = applyStatsfun(problem, x, storedb, options, stats) 5 | % 6 | % Applies the options.statsfun user supplied function to the stats 7 | % structure, if it was provided, with the appropriate inputs, and returns 8 | % the (possibly) modified stats structure. 9 | % 10 | % See also: 11 | 12 | % This file is part of Manopt: www.manopt.org. 13 | % Original author: Nicolas Boumal, April 3, 2013. 14 | % Contributors: 15 | % Change log: 16 | 17 | if isfield(options, 'statsfun') 18 | 19 | is_octave = exist('OCTAVE_VERSION', 'builtin'); 20 | if ~is_octave 21 | narg = nargin(options.statsfun); 22 | else 23 | narg = 4; 24 | end 25 | 26 | switch narg 27 | case 3 28 | stats = options.statsfun(problem, x, stats); 29 | case 4 30 | store = getStore(problem, x, storedb); 31 | stats = options.statsfun(problem, x, stats, store); 32 | otherwise 33 | warning('manopt:statsfun', ... 
34 | 'statsfun unused: wrong number of inputs'); 35 | end 36 | end 37 | 38 | end 39 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetCost.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetCost(problem) 2 | % Checks whether the cost function can be computed for a problem structure. 3 | % 4 | % function candoit = canGetCost(problem) 5 | % 6 | % Returns true if the cost function can be computed given the problem 7 | % description, false otherwise. 8 | % 9 | % See also: getCost canGetDirectionalDerivative canGetGradient canGetHessian 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 13 | % Contributors: 14 | % Change log: 15 | 16 | 17 | candoit = isfield(problem, 'cost') || isfield(problem, 'costgrad'); 18 | 19 | end 20 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetDirectionalDerivative.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetDirectionalDerivative(problem) 2 | % Checks whether dir. derivatives can be computed for a problem structure. 3 | % 4 | % function candoit = canGetDirectionalDerivative(problem) 5 | % 6 | % Returns true if the directional derivatives of the cost function can be 7 | % computed given the problem description, false otherwise. 8 | % 9 | % See also: canGetCost canGetGradient canGetHessian 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 
13 | % Contributors: 14 | % Change log: 15 | 16 | candoit = isfield(problem, 'diff') || canGetGradient(problem); 17 | 18 | end 19 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetEuclideanGradient.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetEuclideanGradient(problem) 2 | % Checks whether the Euclidean gradient can be computed for a problem. 3 | % 4 | % function candoit = canGetEuclideanGradient(problem) 5 | % 6 | % Returns true if the Euclidean gradient can be computed given the problem 7 | % description, false otherwise. 8 | % 9 | % See also: canGetGradient getEuclideanGradient 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 13 | % Contributors: 14 | % Change log: 15 | 16 | 17 | candoit = isfield(problem, 'egrad'); 18 | 19 | end 20 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetGradient.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetGradient(problem) 2 | % Checks whether the gradient can be computed for a problem structure. 3 | % 4 | % function candoit = canGetGradient(problem) 5 | % 6 | % Returns true if the gradient of the cost function can be computed given 7 | % the problem description, false otherwise. 8 | % 9 | % See also: canGetCost canGetDirectionalDerivative canGetHessian 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 13 | % Contributors: 14 | % Change log: 15 | 16 | candoit = isfield(problem, 'grad') || isfield(problem, 'costgrad') || ... 
17 | canGetEuclideanGradient(problem); 18 | 19 | end 20 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetHessian.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetHessian(problem) 2 | % Checks whether the Hessian can be computed for a problem structure. 3 | % 4 | % function candoit = canGetHessian(problem) 5 | % 6 | % Returns true if the Hessian of the cost function can be computed given 7 | % the problem description, false otherwise. 8 | % 9 | % See also: canGetCost canGetDirectionalDerivative canGetGradient 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 13 | % Contributors: 14 | % Change log: 15 | 16 | candoit = isfield(problem, 'hess') || ... 17 | (isfield(problem, 'ehess') && canGetEuclideanGradient(problem)); 18 | 19 | if ~candoit && ... 20 | (isfield(problem, 'ehess') && ~canGetEuclideanGradient(problem)) 21 | warning('manopt:canGetHessian', ... 22 | ['If the Hessian is supplied as a Euclidean Hessian (ehess), ' ... 23 | 'then the Euclidean gradient must also be supplied (egrad).']); 24 | end 25 | 26 | end 27 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetLinesearch.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetLinesearch(problem) 2 | % Checks whether the problem structure can give a line-search a hint. 3 | % 4 | % function candoit = canGetLinesearch(problem) 5 | % 6 | % Returns true if the problem description includes a mechanism to give 7 | % line-search algorithms a hint as to "how far to look", false otherwise. 8 | % 9 | % See also: getLinesearch 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, July 17, 2014. 
13 | % Contributors: 14 | % Change log: 15 | 16 | 17 | candoit = isfield(problem, 'linesearch'); 18 | 19 | end 20 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/canGetPrecon.m: -------------------------------------------------------------------------------- 1 | function candoit = canGetPrecon(problem) 2 | % Checks whether a preconditioner was specified in the problem description. 3 | % 4 | % function candoit = canGetPrecon(problem) 5 | % 6 | % Returns true if a preconditioner was specified, false otherwise. Notice 7 | % that even if this function returns false, it is still possible to call 8 | % getPrecon, as the default preconditioner is simply the identity operator. 9 | % This check function is mostly useful to tell whether that default 10 | % preconditioner will be in use or not. 11 | % 12 | % See also: canGetDirectionalDerivative canGetGradient canGetHessian 13 | 14 | % This file is part of Manopt: www.manopt.org. 15 | % Original author: Nicolas Boumal, July 3, 2013. 16 | % Contributors: 17 | % Change log: 18 | 19 | candoit = isfield(problem, 'precon'); 20 | 21 | end 22 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/getApproxHessian.m: -------------------------------------------------------------------------------- 1 | function [approxhess, storedb] = getApproxHessian(problem, x, d, storedb) 2 | % Computes an approximation of the Hessian of the cost fun. at x along d. 3 | % 4 | % function [approxhess, storedb] = getApproxHessian(problem, x, d, storedb) 5 | % 6 | % Returns an approximation of the Hessian at x along d of the cost function 7 | % described in the problem structure. The cache database storedb is passed 8 | % along, possibly modified and returned in the process. 9 | % 10 | % If no approximate Hessian was furnished, this call is redirected to 11 | % getHessianFD. 
12 | % 13 | % See also: getHessianFD 14 | 15 | % This file is part of Manopt: www.manopt.org. 16 | % Original author: Nicolas Boumal, Dec. 30, 2012. 17 | % Contributors: 18 | % Change log: 19 | 20 | 21 | if isfield(problem, 'approxhess') 22 | %% Compute the approximate Hessian using approxhess. 23 | 24 | is_octave = exist('OCTAVE_VERSION', 'builtin'); 25 | if ~is_octave 26 | narg = nargin(problem.approxhess); 27 | else 28 | narg = 3; 29 | end 30 | 31 | % Check whether the approximate Hessian function wants to deal with 32 | % the store structure or not. 33 | switch narg 34 | case 2 35 | approxhess = problem.approxhess(x, d); 36 | case 3 37 | % Obtain, pass along, and save the store structure 38 | % associated to this point. 39 | store = getStore(problem, x, storedb); 40 | [approxhess store] = problem.approxhess(x, d, store); 41 | storedb = setStore(problem, x, storedb, store); 42 | otherwise 43 | up = MException('manopt:getApproxHessian:badapproxhess', ... 44 | 'approxhess should accept 2 or 3 inputs.'); 45 | throw(up); 46 | end 47 | 48 | else 49 | %% Try to fall back to a standard FD approximation. 50 | 51 | [approxhess, storedb] = getHessianFD(problem, x, d, storedb); 52 | 53 | end 54 | 55 | end 56 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/getCost.m: -------------------------------------------------------------------------------- 1 | function [cost, storedb] = getCost(problem, x, storedb) 2 | % Computes the cost function at x. 3 | % 4 | % function [cost, storedb] = getCost(problem, x, storedb) 5 | % 6 | % Returns the value at x of the cost function described in the problem 7 | % structure. The cache database storedb is passed along, possibly modified 8 | % and returned in the process. 9 | % 10 | % See also: canGetCost 11 | 12 | % This file is part of Manopt: www.manopt.org. 13 | % Original author: Nicolas Boumal, Dec. 30, 2012. 
% Contributors:
% Change log:


    %% First choice: the cost is given directly through problem.cost.
    if isfield(problem, 'cost')

        % Number of inputs declared by cost. Under Octave the handle is
        % not queried and the store-aware signature is assumed.
        if exist('OCTAVE_VERSION', 'builtin')
            nin = 2;
        else
            nin = nargin(problem.cost);
        end

        if nin == 1
            cost = problem.cost(x);
        elseif nin == 2
            % Fetch the store attached to x, pass it along and save it.
            st = getStore(problem, x, storedb);
            [cost, st] = problem.cost(x, st);
            storedb = setStore(problem, x, storedb, st);
        else
            throw(MException('manopt:getCost:badcost', ...
                'cost should accept 1 or 2 inputs.'));
        end
        return;

    end

    %% Second choice: extract the cost from problem.costgrad.
    if isfield(problem, 'costgrad')

        if exist('OCTAVE_VERSION', 'builtin')
            nin = 2;
        else
            nin = nargin(problem.costgrad);
        end

        if nin == 1
            % Only the first output (the cost) is requested.
            cost = problem.costgrad(x);
        elseif nin == 2
            % Fetch the store attached to x, pass it along and save it.
            % The gradient output is requested but discarded.
            st = getStore(problem, x, storedb);
            [cost, unusedgrad, st] = problem.costgrad(x, st); %#ok
            storedb = setStore(problem, x, storedb, st);
        else
            throw(MException('manopt:getCost:badcostgrad', ...
                'costgrad should accept 1 or 2 inputs.'));
        end
        return;

    end

    %% Neither cost nor costgrad is available: give up.
    throw(MException('manopt:getCost:fail', ...
        ['The problem description is not explicit enough to ' ...
         'compute the cost.']));

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getCostGrad.m:
--------------------------------------------------------------------------------
function [cost, grad, storedb] = getCostGrad(problem, x, storedb)
% Computes the cost function and the gradient at x in one call if possible.
%
% function [cost, grad, storedb] = getCostGrad(problem, x, storedb)
%
% Returns the value at x of the cost function described in the problem
% structure, as well as the gradient at x. The cache database storedb is
% passed along, possibly modified and returned in the process.
%
% If problem.costgrad is not defined, the cost and the gradient are
% obtained through two separate calls, to getCost and to getGradient.
%
% See also: canGetCost canGetGradient getCost getGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    %% Without costgrad, call getCost and getGradient one after the other.
    if ~isfield(problem, 'costgrad')
        [cost, storedb] = getCost(problem, x, storedb);
        [grad, storedb] = getGradient(problem, x, storedb);
        return;
    end

    %% Otherwise, obtain the cost/grad pair from a single costgrad call.

    % Number of inputs declared by costgrad. Under Octave the handle is
    % not queried and the store-aware signature is assumed.
    if exist('OCTAVE_VERSION', 'builtin')
        nin = 2;
    else
        nin = nargin(problem.costgrad);
    end

    if nin == 1
        [cost, grad] = problem.costgrad(x);
    elseif nin == 2
        % Fetch the store attached to x, pass it along and save it.
        st = getStore(problem, x, storedb);
        [cost, grad, st] = problem.costgrad(x, st);
        storedb = setStore(problem, x, storedb, st);
    else
        throw(MException('manopt:getCostGrad:badcostgrad', ...
            'costgrad should accept 1 or 2 inputs.'));
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getDirectionalDerivative.m:
--------------------------------------------------------------------------------
function [diff, storedb] = getDirectionalDerivative(problem, x, d, storedb)
% Computes the directional derivative of the cost function at x along d.
%
% function [diff, storedb] = getDirectionalDerivative(problem, x, d, storedb)
%
% Returns the derivative at x along d of the cost function described in the
% problem structure. The cache database storedb is passed along, possibly
% modified and returned in the process.
%
% If problem.diff is not defined but the gradient can be computed, the
% derivative is obtained as the inner product (problem.M.inner) of the
% gradient at x with d.
%
% See also: getGradient canGetDirectionalDerivative

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    %% First choice: the user-supplied diff function.
    if isfield(problem, 'diff')

        % Number of inputs declared by diff. Under Octave the handle is
        % not queried and the store-aware signature is assumed.
        if exist('OCTAVE_VERSION', 'builtin')
            nin = 3;
        else
            nin = nargin(problem.diff);
        end

        if nin == 2
            diff = problem.diff(x, d);
        elseif nin == 3
            % Fetch the store attached to x, pass it along and save it.
            st = getStore(problem, x, storedb);
            [diff, st] = problem.diff(x, d, st);
            storedb = setStore(problem, x, storedb, st);
        else
            throw(MException('manopt:getDirectionalDerivative:baddiff', ...
                'diff should accept 2 or 3 inputs.'));
        end
        return;

    end

    %% Second choice: inner product of the gradient at x with d.
    if canGetGradient(problem)
        [grad, storedb] = getGradient(problem, x, storedb);
        diff = problem.M.inner(x, grad, d);
        return;
    end

    %% Nothing usable in the problem description: give up.
    throw(MException('manopt:getDirectionalDerivative:fail', ...
        ['The problem description is not explicit enough to ' ...
         'compute the directional derivatives of f.']));

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getEuclideanGradient.m:
--------------------------------------------------------------------------------
function [egrad, storedb] = getEuclideanGradient(problem, x, storedb)
% Computes the Euclidean gradient of the cost function at x.
%
% function [egrad, storedb] = getEuclideanGradient(problem, x, storedb)
%
% Returns the Euclidean gradient at x of the cost function described in the
% problem structure. The cache database storedb is passed along, possibly
% modified and returned in the process.
%
% Because computing the Hessian based on the Euclidean Hessian will require
% the Euclidean gradient every time, to avoid overly redundant
% computations, if the egrad function does not use the store caching
% capabilities, we implement an automatic caching functionality. This means
% that even if the user does not use the store parameter, the hashing
% function will be used, and this can translate in a performance hit for
% small problems. For problems with expensive cost functions, this should
% be a bonus though.
% Writing egrad to accept the optional store parameter
% (as input and output) will disable automatic caching, but activate user
% controlled caching, which means hashing will be computed in all cases.
%
% If you absolutely do not want hashing to be used (and hence do not want
% caching to be used), you can define grad instead of egrad, without store
% support, and call problem.M.egrad2rgrad manually.
%
% See also: getGradient canGetGradient canGetEuclideanGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 9, 2013.
% Contributors:
% Change log:


    if isfield(problem, 'egrad')
    %% Compute the Euclidean gradient using egrad.

        % Under Octave the handle is not queried and the store-aware
        % signature (2 inputs) is assumed.
        is_octave = exist('OCTAVE_VERSION', 'builtin');
        if ~is_octave
            narg = nargin(problem.egrad);
        else
            narg = 2;
        end

        % Check whether the egrad function wants to deal with the store
        % structure or not.
        switch narg
            case 1
                % If it does not want to deal with the store structure,
                % then we do some caching of our own. There is a small
                % performance hit for this in some cases, but we expect
                % that this is most often the preferred choice.
                store = getStore(problem, x, storedb);
                if ~isfield(store, 'egrad__')
                    store.egrad__ = problem.egrad(x);
                    storedb = setStore(problem, x, storedb, store);
                end
                egrad = store.egrad__;
            case 2
                % Obtain, pass along, and save the store structure
                % associated to this point. If the user deals with the
                % store structure, then we don't do any automatic caching:
                % the user is in control.
                store = getStore(problem, x, storedb);
                [egrad, store] = problem.egrad(x, store);
                storedb = setStore(problem, x, storedb, store);
            otherwise
                up = MException('manopt:getEuclideanGradient:badegrad', ...
                    'egrad should accept 1 or 2 inputs.');
                throw(up);
        end

    else
    %% Abandon computing the Euclidean gradient

        up = MException('manopt:getEuclideanGradient:fail', ...
            ['The problem description is not explicit enough to ' ...
             'compute the Euclidean gradient of the cost.']);
        throw(up);

    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getGlobalDefaults.m:
--------------------------------------------------------------------------------
function opts = getGlobalDefaults()
% Returns a structure with default option values for Manopt.
%
% function opts = getGlobalDefaults()
%
% Returns a structure opts containing the global default options such as
% verbosity level etc. Typically, global defaults are overwritten by solver
% defaults, which are in turn overwritten by user-specified options.
% See the online Manopt documentation for details on options.
%
% The fields set here are: verbosity, debug, storedepth and maxtime.
%
% See also: mergeOptions

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    % Verbosity level: 0 is no output at all. The higher the verbosity, the
    % more info is printed / displayed during solver execution.
    opts.verbosity = 3;

    % If debug is set to true, additional computations may be performed and
    % debugging information is output during solver execution.
    opts.debug = false;

    % Maximum number of store structures to store. If set to 0, caching
    % capabilities are not disabled, but the cache will be emptied at each
    % iteration of iterative solvers (more specifically: every time the
    % solver calls the purgeStoredb tool).
    opts.storedepth = 20;

    % Maximum amount of time a solver may execute, in seconds.
    opts.maxtime = inf;

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getGradient.m:
--------------------------------------------------------------------------------
function [grad, storedb] = getGradient(problem, x, storedb)
% Computes the gradient of the cost function at x.
%
% function [grad, storedb] = getGradient(problem, x, storedb)
%
% Returns the gradient at x of the cost function described in the problem
% structure. The cache database storedb is passed along, possibly modified
% and returned in the process.
%
% The gradient is obtained from the first available source, in this order:
%   1. problem.grad;
%   2. problem.costgrad (second output);
%   3. the Euclidean gradient, converted with problem.M.egrad2rgrad.
% If none is available, an exception with identifier
% 'manopt:getGradient:fail' is thrown.
%
% See also: getDirectionalDerivative canGetGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   Fixed: the calling convention of problem.grad is now determined from
%   nargin(problem.grad); it was mistakenly read from problem.cost.


    if isfield(problem, 'grad')
    %% Compute the gradient using grad.

        % The number of inputs must be read from grad itself: cost and
        % grad may use different conventions (with / without store), and
        % cost may not be defined at all. Under Octave the handle is not
        % queried and the store-aware signature is assumed.
        is_octave = exist('OCTAVE_VERSION', 'builtin');
        if ~is_octave
            narg = nargin(problem.grad);
        else
            narg = 2;
        end

        % Check whether the gradient function wants to deal with the store
        % structure or not.
        switch narg
            case 1
                grad = problem.grad(x);
            case 2
                % Obtain, pass along, and save the store structure
                % associated to this point.
                store = getStore(problem, x, storedb);
                [grad, store] = problem.grad(x, store);
                storedb = setStore(problem, x, storedb, store);
            otherwise
                up = MException('manopt:getGradient:badgrad', ...
                    'grad should accept 1 or 2 inputs.');
                throw(up);
        end

    elseif isfield(problem, 'costgrad')
    %% Compute the gradient using costgrad.

        is_octave = exist('OCTAVE_VERSION', 'builtin');
        if ~is_octave
            narg = nargin(problem.costgrad);
        else
            narg = 2;
        end

        % Check whether the costgrad function wants to deal with the store
        % structure or not.
        switch narg
            case 1
                % The cost (first output) is requested but discarded.
                [unused, grad] = problem.costgrad(x); %#ok
            case 2
                % Obtain, pass along, and save the store structure
                % associated to this point.
                store = getStore(problem, x, storedb);
                [unused, grad, store] = problem.costgrad(x, store); %#ok
                storedb = setStore(problem, x, storedb, store);
            otherwise
                up = MException('manopt:getGradient:badcostgrad', ...
                    'costgrad should accept 1 or 2 inputs.');
                throw(up);
        end

    elseif canGetEuclideanGradient(problem)
    %% Compute the gradient using the Euclidean gradient.

        [egrad, storedb] = getEuclideanGradient(problem, x, storedb);
        grad = problem.M.egrad2rgrad(x, egrad);

    else
    %% Abandon computing the gradient.

        up = MException('manopt:getGradient:fail', ...
            ['The problem description is not explicit enough to ' ...
             'compute the gradient of the cost.']);
        throw(up);

    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getHessian.m:
--------------------------------------------------------------------------------
function [hess, storedb] = getHessian(problem, x, d, storedb)
% Computes the Hessian of the cost function at x along d.
%
% function [hess, storedb] = getHessian(problem, x, d, storedb)
%
% Returns the Hessian at x along d of the cost function described in the
% problem structure. The cache database storedb is passed along, possibly
% modified and returned in the process.
%
% If an exact Hessian is not provided, an approximate Hessian is returned
% if possible, without warning. If not possible, an exception will be
% thrown. To check whether an exact Hessian is available or not (typically
% to issue a warning if not), use canGetHessian.
%
% See also: getPrecon getApproxHessian canGetHessian

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:

    %% First choice: the user-supplied Riemannian Hessian, hess.
    if isfield(problem, 'hess')

        % Number of inputs declared by hess. Under Octave the handle is
        % not queried and the store-aware signature is assumed.
        if exist('OCTAVE_VERSION', 'builtin')
            nin = 3;
        else
            nin = nargin(problem.hess);
        end

        if nin == 2
            hess = problem.hess(x, d);
        elseif nin == 3
            % Fetch the store attached to x, pass it along and save it.
            st = getStore(problem, x, storedb);
            [hess, st] = problem.hess(x, d, st);
            storedb = setStore(problem, x, storedb, st);
        else
            throw(MException('manopt:getHessian:badhess', ...
                'hess should accept 2 or 3 inputs.'));
        end
        return;

    end

    %% Second choice: the Euclidean Hessian ehess, converted.
    if isfield(problem, 'ehess') && canGetEuclideanGradient(problem)

        % The conversion from the Euclidean to the Riemannian Hessian
        % requires the Euclidean gradient, so get that one first.
        [egrad, storedb] = getEuclideanGradient(problem, x, storedb);

        if exist('OCTAVE_VERSION', 'builtin')
            nin = 3;
        else
            nin = nargin(problem.ehess);
        end

        if nin == 2
            ehess = problem.ehess(x, d);
        elseif nin == 3
            % Fetch the store attached to x, pass it along and save it.
            st = getStore(problem, x, storedb);
            [ehess, st] = problem.ehess(x, d, st);
            storedb = setStore(problem, x, storedb, st);
        else
            throw(MException('manopt:getHessian:badehess', ...
                'ehess should accept 2 or 3 inputs.'));
        end

        % Turn the Euclidean Hessian into the Riemannian one.
        hess = problem.M.ehess2rhess(x, egrad, ehess, d);
        return;

    end

    %% Last resort: an approximation of the Hessian.
    [hess, storedb] = getApproxHessian(problem, x, d, storedb);

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getHessianFD.m:
--------------------------------------------------------------------------------
function [hessfd, storedb] = getHessianFD(problem, x, d, storedb)
% Computes an approx. of the Hessian w/ finite differences of the gradient.
%
% function [hessfd, storedb] = getHessianFD(problem, x, d, storedb)
%
% Return a finite difference approximation of the Hessian at x along d of
% the cost function described in the problem structure. The cache database
% storedb is passed along, possibly modified and returned in the process.
% The finite difference is based on computations of the gradient.
%
% If the norm of d is smaller than eps, the zero vector at x is returned.
% If the gradient cannot be computed, an exception is thrown.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    % The approximation is built from gradients: without them, give up.
    if ~canGetGradient(problem)
        throw(MException('manopt:getHessianFD:nogradient', ...
            'getHessianFD requires the gradient to be computable.'));
    end

    % A direction that is numerically zero yields the zero vector.
    if problem.M.norm(x, d) < eps
        hessfd = problem.M.zerovec(x);
        return;
    end

    % Length of the step taken along d to probe the gradient.
    % TODO: this parameter should be tunable by the users.
    fdstep = 1e-4;

    % TODO: to make this approximation of the Hessian radially linear, that
    % is, to have that HessianFD(x, a*d) = a*HessianFD(x, d) for all
    % points x, tangent vectors d and real scalars a, we need to pay
    % special attention to the case of a < 0. This requires a notion of
    % "sign" for vectors d.
    % If vectors are uniquely represented by a matrix (which is the case
    % for Riemannian submanifolds of the space of matrices), then this
    % function is appropriate:
    % sg = sign(d(find(d(:), 1, 'first')));
    % But it is more difficult to build such a function in general. For
    % now, we ignore this difficulty and let the sign always be +1. This
    % hardly impacts the actual performances, but may be an obstacle for
    % theoretical analysis.
    dsign = 1;
    dnorm = problem.M.norm(x, d);
    c = fdstep*dsign/dnorm;

    % Gradient at the base point x.
    [g0, storedb] = getGradient(problem, x, storedb);

    % Gradient at a point slightly further along d.
    xnext = problem.M.retr(x, d, c);
    [g1, storedb] = getGradient(problem, xnext, storedb);

    % Bring the second gradient back to the tangent space at x.
    g1 = problem.M.transp(xnext, x, g1);

    % Finite difference of the two gradients: (g1 - g0) / c.
    hessfd = problem.M.lincomb(x, 1/c, g1, -1/c, g0);

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getLinesearch.m:
--------------------------------------------------------------------------------
function [t, storedb] = getLinesearch(problem, x, d, storedb)
% Returns a hint for line-search algorithms.
%
% function [t, storedb] = getLinesearch(problem, x, d, storedb)
%
% For a line-search problem at x along the tangent direction d, computes
% and returns t such that retracting t*d at x yields a good point around
% where to look for a line-search solution. That is: t is a hint as to "how
% far to look" along the line.
%
% The cache database storedb is passed along, possibly modified and
% returned in the process.
%
% See also: canGetLinesearch

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 17, 2014.
% Contributors:
% Change log:


    %% Without a linesearch function in the problem, give up.
    if ~isfield(problem, 'linesearch')
        throw(MException('manopt:getLinesearch:fail', ...
            ['The problem description is not explicit enough to ' ...
             'compute a line-search hint.']));
    end

    %% Otherwise, obtain the hint from problem.linesearch.

    % Number of inputs declared by linesearch. Under Octave the handle is
    % not queried and the store-aware signature is assumed.
    if exist('OCTAVE_VERSION', 'builtin')
        nin = 3;
    else
        nin = nargin(problem.linesearch);
    end

    if nin == 2
        t = problem.linesearch(x, d);
    elseif nin == 3
        % Fetch the store attached to x, pass it along and save it.
        st = getStore(problem, x, storedb);
        [t, st] = problem.linesearch(x, d, st);
        storedb = setStore(problem, x, storedb, st);
    else
        throw(MException('manopt:getLinesearch:badfun', ...
            'linesearch should accept 2 or 3 inputs.'));
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getPrecon.m:
--------------------------------------------------------------------------------
function [Pd, storedb] = getPrecon(problem, x, d, storedb)
% Applies the preconditioner for the Hessian of the cost at x along d.
%
% function [Pd, storedb] = getPrecon(problem, x, d, storedb)
%
% Returns as Pd the result of applying the Hessian preconditioner to the
% tangent vector d at point x. If no preconditioner is specified, Pd = d
% (identity). The cache database storedb is passed along and returned.
%
% See also: getHessian

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    %% No preconditioner in the problem: act as the identity.
    if ~isfield(problem, 'precon')
        Pd = d;
        return;
    end

    %% Otherwise, apply problem.precon.

    % Number of inputs declared by precon. Under Octave the handle is
    % not queried and the store-aware signature is assumed.
    if exist('OCTAVE_VERSION', 'builtin')
        nin = 3;
    else
        nin = nargin(problem.precon);
    end

    if nin == 2
        Pd = problem.precon(x, d);
    elseif nin == 3
        % Fetch the store attached to x, pass it along and save it.
        st = getStore(problem, x, storedb);
        [Pd, st] = problem.precon(x, d, st);
        storedb = setStore(problem, x, storedb, st);
    else
        throw(MException('manopt:getPrecon:badprecon', ...
            'precon should accept 2 or 3 inputs.'));
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getStore.m:
--------------------------------------------------------------------------------
function store = getStore(problem, x, storedb)
% Extracts a store struct. pertaining to a point from the storedb database.
%
% function store = getStore(problem, x, storedb)
%
% Queries the storedb database of structures (itself a structure) and
% returns the store structure corresponding to the point x. If there is no
% record for the point x, returns an empty structure.
%
% See also: setStore purgeStoredb

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:

    % The key (field name) for the point x is given by the manifold's
    % hash function.
    key = problem.M.hash(x);

    % Start from an empty structure, and replace it with the recorded
    % store if the database has an entry under that key.
    store = struct();
    if isfield(storedb, key)
        store = storedb.(key);
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/hashmd5.m:
--------------------------------------------------------------------------------
function h = hashmd5(inp)
% Computes the MD5 hash of input data.
%
% function h = hashmd5(inp)
%
% Returns a string containing the MD5 hash of the input variable. The input
% variable may be of any class that can be typecast to uint8 format, which
% is fairly non-restrictive.

% This file is part of Manopt: www.manopt.org.
% This code is a stripped version of more general hashing code by
% Michael Kleder, Nov 2005.
% Change log:
%   Aug. 8, 2013 (NB) : Made x a static (persistent) variable, in the hope
%                       it will speed it up.
%                       Furthermore, the function is
%                       now Octave compatible.

    % Nonzero under Octave; selects between the Java digest (MATLAB) and
    % md5sum (Octave) below.
    is_octave = exist('OCTAVE_VERSION', 'builtin');

    % The Java MD5 digest object is created once and kept across calls.
    % It is never created under Octave, where it is not used.
    persistent x;
    if isempty(x) && ~is_octave
        x = java.security.MessageDigest.getInstance('MD5');
    end

    % Flatten the input to a column vector.
    inp=inp(:);
    % Convert strings and logicals into uint8 format
    if ischar(inp) || islogical(inp)
        inp=uint8(inp);
    else % Convert everything else into uint8 format without loss of data
        inp=typecast(inp,'uint8');
    end

    % Create hash
    if ~is_octave
        x.update(inp);
        h = typecast(x.digest, 'uint8');
        % One row of hex digits per byte before transposing, hence one
        % column per byte after transposing.
        h = dec2hex(h)';
        % Remote possibility: every hash byte needs a single hex digit
        % (all bytes < 16), in which case dec2hex gives one digit per
        % byte. Prepend a row of '0' so each byte has two digits:
        if(size(h,1))==1
            h = [repmat('0',[1 size(h,2)]);h];
        end
        % Read column by column to get the hex string, in lower case.
        h = lower(h(:)');
    else
        % NOTE(review): presumably the second argument makes md5sum hash
        % the given string itself rather than a file -- confirm against
        % the Octave documentation.
        h = md5sum(char(inp'), true);
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/privatetools/mergeOptions.m:
--------------------------------------------------------------------------------
function opts = mergeOptions(opts1, opts2)
% Merges two options structures with one having precedence over the other.
%
% function opts = mergeOptions(opts1, opts2)
%
% input: opts1 and opts2 are two structures.
% output: opts is a structure containing all fields of opts1 and opts2.
% Whenever a field is present in both opts1 and opts2, it is the value in
% opts2 that is kept.
%
% The typical usage is to have opts1 contain default options and opts2
% contain user-specified options that overwrite the defaults.
%
% See also: getGlobalDefaults

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
18 | % Contributors: 19 | % Change log: 20 | 21 | 22 | if isempty(opts1) 23 | opts1 = struct(); 24 | end 25 | if isempty(opts2) 26 | opts2 = struct(); 27 | end 28 | 29 | opts = opts1; 30 | fields = fieldnames(opts2); 31 | for i = 1 : length(fields) 32 | opts.(fields{i}) = opts2.(fields{i}); 33 | end 34 | 35 | end 36 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/purgeStoredb.m: -------------------------------------------------------------------------------- 1 | function storedb = purgeStoredb(storedb, storedepth) 2 | % Makes sure the storedb database does not exceed some maximum size. 3 | % 4 | % function storedb = purgeStoredb(storedb, storedepth) 5 | % 6 | % Trim the store database storedb such that it contains at most storedepth 7 | % elements (store structures). The 'lastset__' field of the store 8 | % structures is used to delete the oldest elements first. 9 | 10 | % This file is part of Manopt: www.manopt.org. 11 | % Original author: Nicolas Boumal, Dec. 30, 2012. 12 | % Contributors: 13 | % Change log: 14 | % Dec. 6, 2013, NB: 15 | % Now using a persistent uint32 counter instead of cputime to track 16 | % the most recently modified stores. 17 | 18 | 19 | if storedepth <= 0 20 | storedb = struct(); 21 | return; 22 | end 23 | 24 | % Get list of field names (keys). 25 | keys = fieldnames(storedb); 26 | nkeys = length(keys); 27 | 28 | % If we need to remove some of the elements in the database. 29 | if nkeys > storedepth 30 | 31 | % Get the last-set counter of each element: a higher number means 32 | % it was modified more recently. 33 | lastset = zeros(nkeys, 1, 'uint32'); 34 | for i = 1 : nkeys 35 | store = storedb.(keys{i}); 36 | lastset(i) = store.lastset__; 37 | end 38 | 39 | % Sort the counters and determine the threshold above which the 40 | % field needs to be removed. 
41 | sortlastset = sort(lastset, 1, 'descend'); 42 | minlastset = sortlastset(storedepth); 43 | 44 | % Remove all fields that are too old. 45 | storedb = rmfield(storedb, keys(lastset < minlastset)); 46 | 47 | end 48 | 49 | end 50 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/setStore.m: -------------------------------------------------------------------------------- 1 | function storedb = setStore(problem, x, storedb, store) 2 | % Updates the store struct. pertaining to a point in the storedb database. 3 | % 4 | % function storedb = setStore(problem, x, storedb, store) 5 | % 6 | % Updates the storedb database of structures such that the structure 7 | % corresponding to the point x will be replaced by store. If there was no 8 | % record for the point x, it is created and set to store. The updated 9 | % storedb database is returned. The lastset__ field of the store structure 10 | % keeps track of which stores were updated latest. 11 | 12 | % This file is part of Manopt: www.manopt.org. 13 | % Original author: Nicolas Boumal, Dec. 30, 2012. 14 | % Contributors: 15 | % Change log: 16 | % Dec. 6, 2013, NB: 17 | % Now using a persistent uint32 counter instead of cputime to track 18 | % the most recently modified stores. 19 | 20 | % This persistent counter is used to keep track of the order in which 21 | % store structures are updated. This is used by purgeStoredb to erase 22 | % the least recently useful store structures first when garbage 23 | % collecting. 24 | persistent counter; 25 | if isempty(counter) 26 | counter = uint32(0); 27 | end 28 | 29 | assert(nargout == 1, ... 30 | 'The output of setStore should replace your storedb.'); 31 | 32 | % Construct the fieldname (key) associated to the current point x. 33 | key = problem.M.hash(x); 34 | 35 | % Set the value associated to that key to store. 36 | storedb.(key) = store; 37 | 38 | % Add / update a last-set flag. 
39 | storedb.(key).lastset__ = counter; 40 | counter = counter + 1; 41 | 42 | end 43 | -------------------------------------------------------------------------------- /manopt/manopt/privatetools/stoppingcriterion.m: -------------------------------------------------------------------------------- 1 | function [stop reason] = stoppingcriterion(problem, x, options, info, last) 2 | % Checks for standard stopping criteria, as a helper to solvers. 3 | % 4 | % function [stop reason] = stoppingcriterion(problem, x, options, info, last) 5 | % 6 | % Executes standard stopping criterion checks, based on what is defined in 7 | % the info(last) stats structure and in the options structure. 8 | % 9 | % The returned number 'stop' is 0 if none of the stopping criteria 10 | % triggered, and a (strictly) positive integer otherwise. The integer 11 | % identifies which criterion triggered: 12 | % 0 : Nothing triggered; 13 | % 1 : Cost tolerance reached; 14 | % 2 : Gradient norm tolerance reached; 15 | % 3 : Max time exceeded; 16 | % 4 : Max iteration count reached; 17 | % 5 : Maximum number of cost evaluations reached; 18 | % 6 : User defined stopfun criterion triggered. 19 | % 20 | % The output 'reason' is a string describing the triggered event. 21 | 22 | % This file is part of Manopt: www.manopt.org. 23 | % Original author: Nicolas Boumal, Dec. 30, 2012. 24 | % Contributors: 25 | % Change log: 26 | 27 | 28 | stop = 0; 29 | reason = ''; 30 | 31 | stats = info(last); 32 | 33 | % Target cost attained 34 | if isfield(stats, 'cost') && isfield(options, 'tolcost') && ... 35 | stats.cost <= options.tolcost 36 | reason = 'Cost tolerance reached. See options.tolcost.'; 37 | stop = 1; 38 | return; 39 | end 40 | 41 | % Target gradient norm attained 42 | if isfield(stats, 'gradnorm') && isfield(options, 'tolgradnorm') && ... 43 | stats.gradnorm < options.tolgradnorm 44 | reason = 'Gradient norm tolerance reached. 
See options.tolgradnorm.'; 45 | stop = 2; 46 | return; 47 | end 48 | 49 | % Allotted time exceeded 50 | if isfield(stats, 'time') && isfield(options, 'maxtime') && ... 51 | stats.time >= options.maxtime 52 | reason = 'Max time exceeded. See options.maxtime.'; 53 | stop = 3; 54 | return; 55 | end 56 | 57 | % Allotted iteration count exceeded 58 | if isfield(stats, 'iter') && isfield(options, 'maxiter') && ... 59 | stats.iter >= options.maxiter 60 | reason = 'Max iteration count reached. See options.maxiter.'; 61 | stop = 4; 62 | return; 63 | end 64 | 65 | % Allotted function evaluation count exceeded 66 | if isfield(stats, 'costevals') && isfield(options, 'maxcostevals') && ... 67 | stats.costevals >= options.maxcostevals 68 | reason = 'Maximum number of cost evaluations reached. See options.maxcostevals.'; 69 | stop = 5; 70 | return; % Stop here like criteria 1-4, so that stopfun cannot overwrite code 5. 71 | end 72 | 73 | % Check whether the possibly user defined stopping criterion 74 | % triggers or not. 75 | if isfield(options, 'stopfun') 76 | userstop = options.stopfun(problem, x, info, last); 77 | if userstop 78 | reason = 'User defined stopfun criterion triggered. See options.stopfun.'; 79 | stop = 6; 80 | return; 81 | end 82 | end 83 | 84 | end 85 | -------------------------------------------------------------------------------- /manopt/manopt/solvers/linesearch/linesearch_adaptive.m: -------------------------------------------------------------------------------- 1 | function [stepsize newx storedb lsmem lsstats] = ... 2 | linesearch_adaptive(problem, x, d, f0, df0, options, storedb, lsmem) 3 | % Adaptive line search algorithm (step size selection) for descent methods. 4 | % 5 | % function [stepsize newx storedb lsmem lsstats] = 6 | % linesearch_adaptive(problem, x, d, f0, df0, options, storedb, lsmem) 7 | % 8 | % Adaptive linesearch algorithm for descent methods, based on a simple 9 | % backtracking method. On average, this line search intends to do only one 10 | % or two cost evaluations. 
11 | % 12 | % Contrary to linesearch.m, this function is not invariant under rescaling 13 | % of the search direction d. Nevertheless, it sometimes performs better. 14 | % 15 | % Inputs/Outputs : see help for linesearch 16 | % 17 | % See also: steepestdescent conjugategradients linesearch 18 | 19 | % This file is part of Manopt: www.manopt.org. 20 | % Original author: Bamdev Mishra, Dec. 30, 2012. 21 | % Contributors: Nicolas Boumal 22 | % Change log: 23 | % 24 | % Sept. 13, 2013 (NB) : 25 | % The automatic direction reversal feature was removed (it triggered 26 | % when df0 > 0). Direction reversal is a decision that needs to be 27 | % made by the solver, so it can know about it. 28 | % 29 | % Nov. 7, 2013 (NB) : 30 | % The whole function has been recoded to mimick more closely the new 31 | % version of linesearch.m. The parameters are available through the 32 | % options structure passed to the solver and have the same names and 33 | % same meaning as for the base linesearch. The information is logged 34 | % more reliably. 35 | 36 | 37 | % Backtracking default parameters. These can be overwritten in the 38 | % options structure which is passed to the solver. 39 | default_options.ls_contraction_factor = .5; 40 | default_options.ls_suff_decr = .5; 41 | default_options.ls_max_steps = 10; 42 | default_options.ls_initial_stepsize = 1; 43 | options = mergeOptions(default_options, options); 44 | 45 | contraction_factor = options.ls_contraction_factor; 46 | suff_decr = options.ls_suff_decr; 47 | max_ls_steps = options.ls_max_steps; 48 | initial_stepsize = options.ls_initial_stepsize; 49 | 50 | % Compute the norm of the search direction. 51 | norm_d = problem.M.norm(x, d); 52 | 53 | % If this is not the first iteration, then lsmem should have been 54 | % filled with a suggestion for the initial step. 
55 | if isstruct(lsmem) && isfield(lsmem, 'init_alpha') 56 | % Pick initial step size based on where we were last time, 57 | alpha = lsmem.init_alpha; 58 | 59 | % Otherwise, fall back to a user supplied suggestion. 60 | else 61 | alpha = initial_stepsize / norm_d; 62 | end 63 | 64 | % Make the chosen step and compute the cost there. 65 | newx = problem.M.retr(x, d, alpha); 66 | [newf storedb] = getCost(problem, newx, storedb); 67 | cost_evaluations = 1; 68 | 69 | % Backtrack while the Armijo criterion is not satisfied 70 | while newf > f0 + suff_decr*alpha*df0 71 | 72 | % Reduce the step size, 73 | alpha = contraction_factor * alpha; 74 | 75 | % and look closer down the line 76 | newx = problem.M.retr(x, d, alpha); 77 | [newf storedb] = getCost(problem, newx, storedb); 78 | cost_evaluations = cost_evaluations + 1; 79 | 80 | % Make sure we don't run out of budget 81 | if cost_evaluations >= max_ls_steps 82 | break; 83 | end 84 | 85 | end 86 | 87 | % If we got here without obtaining a decrease, we reject the step. 88 | if newf > f0 89 | alpha = 0; 90 | newx = x; 91 | newf = f0; %#ok 92 | end 93 | 94 | % As seen outside this function, stepsize is the size of the vector we 95 | % retract to make the step from x to newx. Since the step is alpha*d: 96 | stepsize = alpha * norm_d; 97 | 98 | % Fill lsmem with a suggestion for what the next initial step size 99 | % trial should be. On average we intend to do only one extra cost 100 | % evaluation. Notice how the suggestion is not about stepsize but about 101 | % alpha. This is the reason why this line search is not invariant under 102 | % rescaling of the search direction d. 103 | switch cost_evaluations 104 | case 1 105 | % If things go well, push your luck. 106 | lsmem.init_alpha = 2 * alpha; 107 | case 2 108 | % If things go smoothly, try to keep pace. 109 | lsmem.init_alpha = alpha; 110 | otherwise 111 | % If you backtracked a lot, the new stepsize is probably quite 112 | % small: try to recover. 
113 | lsmem.init_alpha = 2 * alpha; 114 | end 115 | 116 | % Save some statistics also, for possible analysis. 117 | lsstats.costevals = cost_evaluations; 118 | lsstats.stepsize = stepsize; 119 | lsstats.alpha = alpha; 120 | 121 | end 122 | -------------------------------------------------------------------------------- /manopt/manopt/solvers/linesearch/linesearch_hint.m: -------------------------------------------------------------------------------- 1 | function [stepsize, newx, storedb, lsmem, lsstats] = ... 2 | linesearch_hint(problem, x, d, f0, df0, options, storedb, lsmem) 3 | % Armijo line-search based on the line-search hint in the problem structure. 4 | % 5 | % function [stepsize, newx, storedb, lsmem, lsstats] = 6 | % linesearch_hint(problem, x, d, f0, df0, options, storedb, lsmem) 7 | % 8 | % Base line-search algorithm for descent methods, based on a simple 9 | % backtracking method. The search direction provided has to be a descent 10 | % direction, as indicated by a negative df0 = directional derivative of f 11 | % at x along d. 12 | % 13 | % The algorithm obtains an initial step size candidate from the problem 14 | % structure, typically through the problem.linesearch function. If that 15 | % step does not fulfill the Armijo sufficient decrease criterion, that step 16 | % size is reduced geometrically until a satisfactory step size is obtained 17 | % or until a failure criterion triggers. 18 | % 19 | % Below, the step will be constructed as alpha*d, and the step size is the 20 | % norm of that vector, thus: stepsize = alpha*norm_d. The step is executed 21 | % by retracting the vector alpha*d from the current point x, giving newx. 
22 | % 23 | % Inputs 24 | % 25 | % problem : structure holding the description of the optimization problem 26 | % x : current point on the manifold problem.M 27 | % d : tangent vector at x (descent direction) 28 | % f0 : cost value at x 29 | % df0 : directional derivative at x along d 30 | % options : options structure (see in code for usage) 31 | % storedb : store database structure for caching purposes 32 | % lsmem : not used 33 | % 34 | % Outputs 35 | % 36 | % stepsize : norm of the vector retracted to reach newx from x. 37 | % newx : next iterate suggested by the line-search algorithm, such that 38 | % the retraction at x of the vector alpha*d reaches newx. 39 | % storedb : the (possibly updated) store database structure. 40 | % lsmem : not used. 41 | % lsstats : statistics about the line-search procedure (stepsize, number 42 | % of trials etc). 43 | % 44 | % See also: steepestdescent conjugategradients linesearch 45 | 46 | % This file is part of Manopt: www.manopt.org. 47 | % Original author: Nicolas Boumal, July 17, 2014. 48 | % Contributors: 49 | % Change log: 50 | % 51 | 52 | 53 | % Backtracking default parameters. These can be overwritten in the 54 | % options structure which is passed to the solver. 55 | default_options.ls_contraction_factor = .5; 56 | default_options.ls_suff_decr = 1e-4; 57 | default_options.ls_max_steps = 25; 58 | 59 | options = mergeOptions(default_options, options); 60 | 61 | contraction_factor = options.ls_contraction_factor; 62 | suff_decr = options.ls_suff_decr; 63 | max_ls_steps = options.ls_max_steps; 64 | 65 | % Obtain an initial guess at alpha from the problem structure. It is 66 | % assumed that the present line-search is only called when the problem 67 | % structure provides enough information for the call here to work. 68 | [alpha, storedb] = getLinesearch(problem, x, d, storedb); 69 | 70 | % Make the chosen step and compute the cost there. 
71 | newx = problem.M.retr(x, d, alpha); 72 | [newf, storedb] = getCost(problem, newx, storedb); 73 | cost_evaluations = 1; 74 | 75 | % Backtrack while the Armijo criterion is not satisfied 76 | while newf > f0 + suff_decr*alpha*df0 77 | 78 | % Reduce the step size, 79 | alpha = contraction_factor * alpha; 80 | 81 | % and look closer down the line 82 | newx = problem.M.retr(x, d, alpha); 83 | [newf, storedb] = getCost(problem, newx, storedb); 84 | cost_evaluations = cost_evaluations + 1; 85 | 86 | % Make sure we don't run out of budget 87 | if cost_evaluations >= max_ls_steps 88 | break; 89 | end 90 | 91 | end 92 | 93 | % If we got here without obtaining a decrease, we reject the step. 94 | if newf > f0 95 | alpha = 0; 96 | newx = x; 97 | newf = f0; %#ok 98 | end 99 | 100 | % As seen outside this function, stepsize is the size of the vector we 101 | % retract to make the step from x to newx. Since the step is alpha*d: 102 | norm_d = problem.M.norm(x, d); 103 | stepsize = alpha * norm_d; 104 | 105 | % Save some statistics also, for possible analysis. 106 | lsstats.costevals = cost_evaluations; 107 | lsstats.stepsize = stepsize; 108 | lsstats.alpha = alpha; 109 | 110 | end 111 | -------------------------------------------------------------------------------- /manopt/manopt/solvers/neldermead/centroid.m: -------------------------------------------------------------------------------- 1 | function y = centroid(M, x) 2 | % Attempts the computation of a centroid of a set of points on a manifold. 3 | % 4 | % function y = centroid(M, x) 5 | % 6 | % M is a structure representing a manifold. x is a cell of points on that 7 | % manifold. 8 | 9 | % This file is part of Manopt: www.manopt.org. 10 | % Original author: Nicolas Boumal, Dec. 30, 2012. 
11 | % Contributors: 12 | % Change log: 13 | 14 | 15 | % For now, just apply a few steps of gradient descent for Karcher means 16 | 17 | n = numel(x); 18 | 19 | problem.M = M; 20 | 21 | problem.cost = @cost; 22 | function val = cost(y) 23 | val = 0; 24 | for i = 1 : n 25 | val = val + M.dist(y, x{i})^2; 26 | end 27 | val = val/2; 28 | end 29 | 30 | problem.grad = @grad; 31 | function g = grad(y) 32 | g = M.zerovec(y); 33 | for i = 1 : n 34 | g = M.lincomb(y, 1, g, -1, M.log(y, x{i})); 35 | end 36 | end 37 | 38 | % This line can be uncommented to check that the gradient is indeed 39 | % correct. This should always be the case if the dist and the log 40 | % functions in the manifold are correct. 41 | % checkgradient(problem); 42 | 43 | query = warning('query', 'manopt:getHessian:approx'); 44 | warning('off', 'manopt:getHessian:approx') 45 | options.verbosity = 0; 46 | options.maxiter = 15; 47 | y = trustregions(problem, x{randi(n)}, options); 48 | warning(query.state, 'manopt:getHessian:approx') 49 | 50 | end 51 | -------------------------------------------------------------------------------- /manopt/manopt/solvers/trustregions/license for original GenRTR code.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2007,2012 Christopher G. Baker, Pierre-Antoine Absil, Kyle A. Gallivan 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 
11 | * Neither the names of the contributors nor of their affiliated 12 | institutions may be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | 26 | For questions, please contact Chris Baker (chris@cgbaker.net) 27 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkdiff.m: -------------------------------------------------------------------------------- 1 | function checkdiff(problem, x, d) 2 | % Checks the consistency of the cost function and directional derivatives. 3 | % 4 | % function checkdiff(problem) 5 | % function checkdiff(problem, x) 6 | % function checkdiff(problem, x, d) 7 | % 8 | % checkdiff performs a numerical test to check that the directional 9 | % derivatives defined in the problem structure agree up to first order with 10 | % the cost function at some point x, along some direction d. The test is 11 | % based on a truncated Taylor series (see online Manopt documentation). 12 | % 13 | % Both x and d are optional and will be sampled at random if omitted. 14 | % 15 | % See also: checkgradient checkhessian 16 | 17 | % This file is part of Manopt: www.manopt.org. 
18 | % Original author: Nicolas Boumal, Dec. 30, 2012. 19 | % Contributors: 20 | % Change log: 21 | 22 | 23 | % Verify that the problem description is sufficient. 24 | if ~canGetCost(problem) 25 | error('It seems no cost was provided.'); 26 | end 27 | if ~canGetDirectionalDerivative(problem) 28 | error('It seems no directional derivatives were provided.'); 29 | end 30 | 31 | dbstore = struct(); 32 | 33 | x_isprovided = exist('x', 'var') && ~isempty(x); 34 | d_isprovided = exist('d', 'var') && ~isempty(d); 35 | 36 | if ~x_isprovided && d_isprovided 37 | error('If d is provided, x must be too, since d is tangent at x.'); 38 | end 39 | 40 | % If x and / or d are not specified, pick them at random. 41 | if ~x_isprovided 42 | x = problem.M.rand(); 43 | end 44 | if ~d_isprovided 45 | d = problem.M.randvec(x); 46 | end 47 | 48 | % Compute the value f0 at f and directional derivative at x along d. 49 | f0 = getCost(problem, x, dbstore); 50 | df0 = getDirectionalDerivative(problem, x, d, dbstore); 51 | 52 | % Compute the value of f at points on the geodesic (or approximation of 53 | % it) originating from x, along direction d, for stepsizes in a large 54 | % range given by h. 55 | h = logspace(-8, 0, 51); 56 | value = zeros(size(h)); 57 | for i = 1 : length(h) 58 | y = problem.M.exp(x, d, h(i)); 59 | value(i) = getCost(problem, y, dbstore); 60 | end 61 | 62 | % Compute the linear approximation of the cost function using f0 and 63 | % df0 at the same points. 64 | model = polyval([df0 f0], h); 65 | 66 | % Compute the approximation error 67 | err = abs(model - value); 68 | 69 | % And plot it. 70 | loglog(h, err); 71 | title(sprintf(['Directional derivative check.\nThe slope of the '... 72 | 'continuous line should match that of the dashed '... 73 | '(reference) line\nover at least a few orders of '... 74 | 'magnitude for h.'])); 75 | xlabel('h'); 76 | ylabel('Approximation error'); 77 | 78 | line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ... 
79 | 'color', 'k', 'LineStyle', '--', ... 80 | 'YLimInclude', 'off', 'XLimInclude', 'off'); 81 | 82 | 83 | % In a numerically reasonable neighborhood, the error should decrease 84 | % as the square of the stepsize, i.e., in loglog scale, the error 85 | % should have a slope of 2. 86 | window_len = 10; 87 | [range poly] = identify_linear_piece(log10(h), log10(err), window_len); 88 | hold on; 89 | loglog(h(range), 10.^polyval(poly, log10(h(range))), ... 90 | 'r-', 'LineWidth', 3); 91 | hold off; 92 | 93 | fprintf('The slope should be 2. It appears to be: %g.\n', poly(1)); 94 | fprintf(['If it is far from 2, then directional derivatives ' ... 95 | 'might be erroneous.\n']); 96 | 97 | end 98 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkdiffSingle.m: -------------------------------------------------------------------------------- 1 | function checkdiff(problem, x, d) 2 | % Checks the consistency of the cost function and directional derivatives. 3 | % 4 | % function checkdiff(problem) 5 | % function checkdiff(problem, x) 6 | % function checkdiff(problem, x, d) 7 | % 8 | % checkdiff performs a numerical test to check that the directional 9 | % derivatives defined in the problem structure agree up to first order with 10 | % the cost function at some point x, along some direction d. The test is 11 | % based on a truncated Taylor series (see online Manopt documentation). 12 | % 13 | % Both x and d are optional and will be sampled at random if omitted. 14 | % 15 | % See also: checkgradient checkhessian 16 | 17 | % This file is part of Manopt: www.manopt.org. 18 | % Original author: Nicolas Boumal, Dec. 30, 2012. 19 | % Contributors: 20 | % Change log: 21 | 22 | 23 | % Verify that the problem description is sufficient. 
24 | if ~canGetCost(problem) 25 | error('It seems no cost was provided.'); 26 | end 27 | if ~canGetDirectionalDerivative(problem) 28 | error('It seems no directional derivatives were provided.'); 29 | end 30 | 31 | dbstore = struct(); 32 | 33 | x_isprovided = exist('x', 'var') && ~isempty(x); 34 | d_isprovided = exist('d', 'var') && ~isempty(d); 35 | 36 | if ~x_isprovided && d_isprovided 37 | error('If d is provided, x must be too, since d is tangent at x.'); 38 | end 39 | 40 | % If x and / or d are not specified, pick them at random. 41 | if ~x_isprovided 42 | x = problem.M.rand(); 43 | end 44 | if ~d_isprovided 45 | d = problem.M.randvec(x); 46 | end 47 | 48 | % Compute the value f0 at f and directional derivative at x along d. 49 | f0 = getCost(problem, x, dbstore); 50 | df0 = getDirectionalDerivative(problem, x, d, dbstore); 51 | 52 | % Compute the value of f at points on the geodesic (or approximation of 53 | % it) originating from x, along direction d, for stepsizes in a large 54 | % range given by h. 55 | h = logspace(-8, 0, 51); 56 | value = zeros(size(h)); 57 | for i = 1 : length(h) 58 | y = problem.M.exp(x, d, h(i)); 59 | value(i) = getCost(problem, y, dbstore); 60 | end 61 | 62 | % Compute the linear approximation of the cost function using f0 and 63 | % df0 at the same points. 64 | model = polyval([df0 f0], h); 65 | 66 | % Compute the approximation error 67 | err = abs(model - value); 68 | 69 | % And plot it. 70 | loglog(h, err); 71 | title(sprintf(['Directional derivative check.\nThe slope of the '... 72 | 'continuous line should match that of the dashed '... 73 | '(reference) line\nover at least a few orders of '... 74 | 'magnitude for h.'])); 75 | xlabel('h'); 76 | ylabel('Approximation error'); 77 | 78 | line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ... 79 | 'color', 'k', 'LineStyle', '--', ... 
80 | 'YLimInclude', 'off', 'XLimInclude', 'off'); 81 | 82 | 83 | % In a numerically reasonable neighborhood, the error should decrease 84 | % as the square of the stepsize, i.e., in loglog scale, the error 85 | % should have a slope of 2. 86 | window_len = 10; 87 | [range poly] = identify_linear_piece(log10(h), log10(err), window_len); 88 | hold on; 89 | loglog(h(range), 10.^polyval(poly, log10(h(range))), ... 90 | 'r-', 'LineWidth', 3); 91 | hold off; 92 | 93 | fprintf('The slope should be 2. It appears to be: %g.\n', poly(1)); 94 | fprintf(['If it is far from 2, then directional derivatives ' ... 95 | 'might be erroneous.\n']); 96 | 97 | end 98 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkdiffSinglePrecision.m: -------------------------------------------------------------------------------- 1 | function checkdiffSinglePrecision(problem, x, d) 2 | % Checks the consistency of the cost function and directional derivatives. 3 | % 4 | % function checkdiff(problem) 5 | % function checkdiff(problem, x) 6 | % function checkdiff(problem, x, d) 7 | % 8 | % checkdiff performs a numerical test to check that the directional 9 | % derivatives defined in the problem structure agree up to first order with 10 | % the cost function at some point x, along some direction d. The test is 11 | % based on a truncated Taylor series (see online Manopt documentation). 12 | % 13 | % Both x and d are optional and will be sampled at random if omitted. 14 | % 15 | % See also: checkgradient checkhessian 16 | 17 | % This file is part of Manopt: www.manopt.org. 18 | % Original author: Nicolas Boumal, Dec. 30, 2012. 19 | % Contributors: 20 | % Change log: 21 | 22 | 23 | % Verify that the problem description is sufficient. 
24 | if ~canGetCost(problem) 25 | error('It seems no cost was provided.'); 26 | end 27 | if ~canGetDirectionalDerivative(problem) 28 | error('It seems no directional derivatives were provided.'); 29 | end 30 | 31 | dbstore = struct(); 32 | 33 | x_isprovided = exist('x', 'var') && ~isempty(x); 34 | d_isprovided = exist('d', 'var') && ~isempty(d); 35 | 36 | if ~x_isprovided && d_isprovided 37 | error('If d is provided, x must be too, since d is tangent at x.'); 38 | end 39 | 40 | % If x and / or d are not specified, pick them at random. 41 | if ~x_isprovided 42 | x = problem.M.rand(); 43 | end 44 | if ~d_isprovided 45 | d = problem.M.randvec(x); 46 | end 47 | 48 | % Compute the value f0 at f and directional derivative at x along d. 49 | f0 = getCost(problem, x, dbstore); 50 | df0 = getDirectionalDerivative(problem, x, d, dbstore); 51 | 52 | % Compute the value of f at points on the geodesic (or approximation of 53 | % it) originating from x, along direction d, for stepsizes in a large 54 | % range given by h. 55 | h = single(logspace(-8, 0, 51)); 56 | value = single(zeros(size(h))); 57 | for i = 1 : length(h) 58 | y = problem.M.exp(x, d, h(i)); 59 | y.encoder = single(y.encoder); 60 | y.decoder = single(y.decoder); 61 | value(i) = getCost(problem, y, dbstore); 62 | end 63 | 64 | % Compute the linear approximation of the cost function using f0 and 65 | % df0 at the same points. 66 | model = polyval([df0 f0], h); 67 | 68 | % Compute the approximation error 69 | err = abs(model - value); 70 | 71 | % And plot it. 72 | loglog(h, err); 73 | title(sprintf(['Directional derivative check.\nThe slope of the '... 74 | 'continuous line should match that of the dashed '... 75 | '(reference) line\nover at least a few orders of '... 76 | 'magnitude for h.'])); 77 | xlabel('h'); 78 | ylabel('Approximation error'); 79 | 80 | line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ... 81 | 'color', 'k', 'LineStyle', '--', ... 
82 | 'YLimInclude', 'off', 'XLimInclude', 'off'); 83 | 84 | 85 | % In a numerically reasonable neighborhood, the error should decrease 86 | % as the square of the stepsize, i.e., in loglog scale, the error 87 | % should have a slope of 2. 88 | window_len = 10; 89 | [range poly] = identify_linear_piece_SinglePrecision(log10(h), log10(err), window_len); 90 | hold on; 91 | loglog(h(range), 10.^polyval(poly, log10(h(range))), ... 92 | 'r-', 'LineWidth', 3); 93 | hold off; 94 | 95 | fprintf('The slope should be 2. It appears to be: %g.\n', poly(1)); 96 | fprintf(['If it is far from 2, then directional derivatives ' ... 97 | 'might be erroneous.\n']); 98 | 99 | end 100 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkgradient.m: -------------------------------------------------------------------------------- 1 | function checkgradient(problem, x, d) 2 | % Checks the consistency of the cost function and the gradient. 3 | % 4 | % function checkgradient(problem) 5 | % function checkgradient(problem, x) 6 | % function checkgradient(problem, x, d) 7 | % 8 | % checkgradient performs a numerical test to check that the gradient 9 | % defined in the problem structure agrees up to first order with the cost 10 | % function at some point x, along some direction d. The test is based on a 11 | % truncated Taylor series (see online Manopt documentation). 12 | % 13 | % It is also tested that the gradient is indeed a tangent vector. 14 | % 15 | % Both x and d are optional and will be sampled at random if omitted. 16 | % 17 | % See also: checkdiff checkhessian 18 | 19 | % This file is part of Manopt: www.manopt.org. 20 | % Original author: Nicolas Boumal, Dec. 30, 2012. 21 | % Contributors: 22 | % Change log: 23 | 24 | 25 | % Verify that the problem description is sufficient. 
26 | if ~canGetCost(problem) 27 | error('It seems no cost was provided.'); 28 | end 29 | if ~canGetGradient(problem) 30 | error('It seems no gradient provided.'); 31 | end 32 | 33 | dbstore = struct(); 34 | 35 | x_isprovided = exist('x', 'var') && ~isempty(x); 36 | d_isprovided = exist('d', 'var') && ~isempty(d); 37 | 38 | if ~x_isprovided && d_isprovided 39 | error('If d is provided, x must be too, since d is tangent at x.'); 40 | end 41 | 42 | % If x and / or d are not specified, pick them at random. 43 | if ~x_isprovided 44 | x = problem.M.rand(); 45 | end 46 | if ~d_isprovided 47 | d = problem.M.randvec(x); 48 | end 49 | 50 | %% Check that the gradient yields a first order model of the cost. 51 | 52 | % By removing the 'directional derivative' function, it should be so 53 | % (?) that the checkdiff function will use the gradient to compute 54 | % directional derivatives. 55 | if isfield(problem, 'diff') 56 | problem = rmfield(problem, 'diff'); 57 | end 58 | checkdiff(problem, x, d); 59 | title(sprintf(['Gradient check.\nThe slope of the continuous line ' ... 60 | 'should match that of the dashed (reference) line\n' ... 61 | 'over at least a few orders of magnitude for h.'])); 62 | xlabel('h'); 63 | ylabel('Approximation error'); 64 | 65 | %% Try to check that the gradient is a tangent vector. 66 | if isfield(problem.M, 'tangent') 67 | grad = getGradient(problem, x, dbstore); 68 | pgrad = problem.M.tangent(x, grad); 69 | residual = problem.M.lincomb(x, 1, grad, -1, pgrad); 70 | err = problem.M.norm(x, residual); 71 | fprintf('The residual should be 0, or very close. Residual: %g.\n', err); 72 | fprintf('If it is far from 0, then the gradient is not in the tangent space.\n'); 73 | else 74 | fprintf(['Unfortunately, Manopt was unable to verify that the '... 75 | 'gradient is indeed a tangent vector.\nPlease verify ' ... 76 | 'this manually or implement the ''tangent'' function ' ... 
77 | 'in your manifold structure.\n']); 78 | end 79 | 80 | end 81 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkgradientSinglePrecision.m: -------------------------------------------------------------------------------- 1 | function checkgradientSinglePrecision(problem, x, d) 2 | % Checks the consistency of the cost function and the gradient. 3 | % 4 | % function checkgradientSinglePrecision(problem) 5 | % function checkgradientSinglePrecision(problem, x) 6 | % function checkgradientSinglePrecision(problem, x, d) 7 | % 8 | % checkgradient performs a numerical test to check that the gradient 9 | % defined in the problem structure agrees up to first order with the cost 10 | % function at some point x, along some direction d. The test is based on a 11 | % truncated Taylor series (see online Manopt documentation). 12 | % 13 | % It is also tested that the gradient is indeed a tangent vector. 14 | % 15 | % Both x and d are optional and will be sampled at random if omitted. 16 | % 17 | % See also: checkdiff checkhessian 18 | 19 | % This file is part of Manopt: www.manopt.org. 20 | % Original author: Nicolas Boumal, Dec. 30, 2012. 21 | % Contributors: 22 | % Change log: 23 | 24 | 25 | % Verify that the problem description is sufficient. 26 | if ~canGetCost(problem) 27 | error('It seems no cost was provided.'); 28 | end 29 | if ~canGetGradient(problem) 30 | error('It seems no gradient provided.'); 31 | end 32 | 33 | dbstore = struct(); 34 | 35 | x_isprovided = exist('x', 'var') && ~isempty(x); 36 | d_isprovided = exist('d', 'var') && ~isempty(d); 37 | 38 | if ~x_isprovided && d_isprovided 39 | error('If d is provided, x must be too, since d is tangent at x.'); 40 | end 41 | 42 | % If x and / or d are not specified, pick them at random. 
43 | if ~x_isprovided 44 | x = problem.M.rand(); 45 | end 46 | if ~d_isprovided 47 | d = problem.M.randvec(x); 48 | end 49 | 50 | x.encoder = single(x.encoder); 51 | x.decoder = single(x.decoder); 52 | d.encoder = single(d.encoder); 53 | d.decoder = single(d.decoder); 54 | %% Check that the gradient yields a first order model of the cost. 55 | 56 | % By removing the 'directional derivative' function, it should be so 57 | % (?) that the checkdiff function will use the gradient to compute 58 | % directional derivatives. 59 | if isfield(problem, 'diff') 60 | problem = rmfield(problem, 'diff'); 61 | end 62 | checkdiffSinglePrecision(problem, x, d); 63 | title(sprintf(['Gradient check.\nThe slope of the continuous line ' ... 64 | 'should match that of the dashed (reference) line\n' ... 65 | 'over at least a few orders of magnitude for h.'])); 66 | xlabel('h'); 67 | ylabel('Approximation error'); 68 | 69 | %% Try to check that the gradient is a tangent vector. 70 | if isfield(problem.M, 'tangent') 71 | grad = getGradient(problem, x, dbstore); 72 | pgrad = problem.M.tangent(x, grad); 73 | residual = problem.M.lincomb(x, 1, grad, -1, pgrad); 74 | err = problem.M.norm(x, residual); 75 | fprintf('The residual should be 0, or very close. Residual: %g.\n', err); 76 | fprintf('If it is far from 0, then the gradient is not in the tangent space.\n'); 77 | else 78 | fprintf(['Unfortunately, Manopt was unable to verify that the '... 79 | 'gradient is indeed a tangent vector.\nPlease verify ' ... 80 | 'this manually or implement the ''tangent'' function ' ... 81 | 'in your manifold structure.\n']); 82 | end 83 | 84 | end 85 | -------------------------------------------------------------------------------- /manopt/manopt/tools/checkhessian.m: -------------------------------------------------------------------------------- 1 | function checkhessian(problem, x, d) 2 | % Checks the consistency of the cost function and the Hessian. 
3 | % 4 | % function checkhessian(problem) 5 | % function checkhessian(problem, x) 6 | % function checkhessian(problem, x, d) 7 | % 8 | % checkhessian performs a numerical test to check that the directional 9 | % derivatives and Hessian defined in the problem structure agree up to 10 | % second order with the cost function at some point x, along some direction 11 | % d. The test is based on a truncated Taylor series (see online Manopt 12 | % documentation). 13 | % 14 | % It is also tested that the Hessian along some direction is indeed a 15 | % tangent vector and that the Hessian operator is symmetric w.r.t. the 16 | % Riemannian metric. 17 | % 18 | % Both x and d are optional and will be sampled at random if omitted. 19 | % 20 | % See also: checkdiff checkgradient 21 | 22 | % This file is part of Manopt: www.manopt.org. 23 | % Original author: Nicolas Boumal, Dec. 30, 2012. 24 | % Contributors: 25 | % Change log: 26 | 27 | 28 | % Verify that the problem description is sufficient. 29 | if ~canGetCost(problem) 30 | error('It seems no cost was provided.'); 31 | end 32 | if ~canGetGradient(problem) 33 | error('It seems no gradient provided.'); 34 | end 35 | if ~canGetHessian(problem) 36 | error('It seems no Hessian was provided.'); 37 | end 38 | 39 | dbstore = struct(); 40 | 41 | x_isprovided = exist('x', 'var') && ~isempty(x); 42 | d_isprovided = exist('d', 'var') && ~isempty(d); 43 | 44 | if ~x_isprovided && d_isprovided 45 | error('If d is provided, x must be too, since d is tangent at x.'); 46 | end 47 | 48 | % If x and / or d are not specified, pick them at random. 49 | if ~x_isprovided 50 | x = problem.M.rand(); 51 | end 52 | if ~d_isprovided 53 | d = problem.M.randvec(x); 54 | end 55 | 56 | %% Check that the directional derivative and the Hessian at x along d 57 | %% yield a second order model of the cost function. 58 | 59 | % Compute the value f0 at f, directional derivative df0 at x along d, 60 | % and Hessian along [d, d]. 
61 | f0 = getCost(problem, x, dbstore); 62 | df0 = getDirectionalDerivative(problem, x, d, dbstore); 63 | d2f0 = problem.M.inner(x, d, getHessian(problem, x, d, dbstore)); 64 | 65 | % Compute the value of f at points on the geodesic (or approximation of 66 | % it) originating from x, along direction d, for stepsizes in a large 67 | % range given by h. 68 | h = logspace(-8, 0, 51); 69 | value = zeros(size(h)); 70 | for i = 1 : length(h) 71 | y = problem.M.exp(x, d, h(i)); 72 | value(i) = getCost(problem, y, dbstore); 73 | end 74 | 75 | % Compute the quadratic approximation of the cost function using f0, 76 | % df0 and d2f0 at the same points. 77 | model = polyval([.5*d2f0 df0 f0], h); 78 | 79 | % Compute the approximation error 80 | err = abs(model - value); 81 | 82 | % And plot it. 83 | loglog(h, err); 84 | title(sprintf(['Hessian check.\nThe slope of the continuous line ' ... 85 | 'should match that of the dashed (reference) line\n' ... 86 | 'over at least a few orders of magnitude for h.'])); 87 | xlabel('h'); 88 | ylabel('Approximation error'); 89 | 90 | line('xdata', [1e-8 1e0], 'ydata', [1e-16 1e8], ... 91 | 'color', 'k', 'LineStyle', '--', ... 92 | 'YLimInclude', 'off', 'XLimInclude', 'off'); 93 | 94 | % In a numerically reasonable neighborhood, the error should decrease 95 | % as the cube of the stepsize, i.e., in loglog scale, the error 96 | % should have a slope of 3. 97 | window_len = 10; 98 | [range poly] = identify_linear_piece(log10(h), log10(err), window_len); 99 | hold on; 100 | loglog(h(range), 10.^polyval(poly, log10(h(range))), ... 101 | 'r-', 'LineWidth', 3); 102 | hold off; 103 | 104 | fprintf('The slope should be 3. It appears to be: %g.\n', poly(1)); 105 | fprintf(['If it is far from 3, then directional derivatives or ' ... 106 | 'the Hessian might be erroneous.\n']); 107 | 108 | 109 | %% Check that the Hessian at x along direction d is a tangent vector. 
110 | if isfield(problem.M, 'tangent') 111 | hess = getHessian(problem, x, d, dbstore); 112 | phess = problem.M.tangent(x, hess); 113 | residual = problem.M.lincomb(x, 1, hess, -1, phess); 114 | err = problem.M.norm(x, residual); 115 | fprintf('The residual should be zero, or very close. '); 116 | fprintf('Residual: %g.\n', err); 117 | fprintf(['If it is far from 0, then the Hessian is not in the ' ... 118 | 'tangent plane.\n']); 119 | else 120 | fprintf(['Unfortunately, Manopt was un able to verify that the '... 121 | 'Hessian is indeed a tangent vector.\nPlease verify ' ... 122 | 'this manually.']); 123 | end 124 | 125 | %% Check that the Hessian at x is symmetric. 126 | d1 = problem.M.randvec(x); 127 | d2 = problem.M.randvec(x); 128 | h1 = getHessian(problem, x, d1, dbstore); 129 | h2 = getHessian(problem, x, d2, dbstore); 130 | v1 = problem.M.inner(x, d1, h2); 131 | v2 = problem.M.inner(x, h1, d2); 132 | value = v1-v2; 133 | fprintf(['<d1, H[d2]> - <H[d1], d2> should be zero, or very close.' ... 134 | '\n\tValue: %g - %g = %g.\n'], v1, v2, value); 135 | fprintf('If it is far from 0, then the Hessian is not symmetric.\n'); 136 | 137 | end 138 | -------------------------------------------------------------------------------- /manopt/manopt/tools/diagsum.m: -------------------------------------------------------------------------------- 1 | function [tracedtensor] = diagsum(tensor1, d1, d2) 2 | % C = DIAGSUM(A, d1, d2) Performs the trace 3 | % C(i[1],...,i[d1-1],i[d1+1],...,i[d2-1],i[d2+1],...i[n]) = 4 | % A(i[1],...,i[d1-1],k,i[d1+1],...,i[d2-1],k,i[d2+1],...,i[n]) 5 | % (Sum on k). 6 | % 7 | % C = DIAGSUM(A, d1, d2) traces A along the diagonal formed by dimensions d1 8 | % and d2. If the lengths of these dimensions are not equal, DIAGSUM traces 9 | % until the end of the shortest of dimensions d1 and d2 is reached. This is 10 | % an analogue of the built in TRACE function.
11 | % 12 | % Wynton Moore, January 2006 13 | 14 | 15 | dim1=size(tensor1); 16 | numdims=length(dim1); 17 | 18 | 19 | %check inputs 20 | if d1==d2 21 | tracedtensor=squeeze(sum(tensor1,d1)); 22 | elseif numdims==2 23 | tracedtensor=trace(tensor1); 24 | elseif dim1(d1)==1 && dim1(d2)==1 25 | tracedtensor=squeeze(tensor1); 26 | else 27 | 28 | 29 | %determine correct permutation 30 | swapd1=d1;swapd2=d2; 31 | 32 | if d1~=numdims-1 && d1~=numdims && d2~=numdims-1 33 | swapd1=numdims-1; 34 | elseif d1~=numdims-1 && d1~=numdims && d2~=numdims 35 | swapd1=numdims; 36 | end 37 | if d2~=numdims-1 && d2~=numdims && swapd1~=numdims-1 38 | swapd2=numdims-1; 39 | elseif d2~=numdims-1 && d2~=numdims && swapd1~=numdims 40 | swapd2=numdims; 41 | end 42 | 43 | 44 | %prepare for construction of selector tensor 45 | temp1=eye(numdims); 46 | permmatrix=temp1; 47 | permmatrix(:,d1)=temp1(:,swapd1); 48 | permmatrix(:,swapd1)=temp1(:,d1); 49 | permmatrix(:,d2)=temp1(:,swapd2); 50 | permmatrix(:,swapd2)=temp1(:,d2); 51 | 52 | selectordim=dim1*permmatrix; 53 | permvector=(1:numdims)*permmatrix; 54 | 55 | 56 | %construct selector tensor 57 | if numdims>3 58 | selector = ipermute(outer(ones(selectordim(1:numdims-2)), ... 59 | eye(selectordim(numdims-1), ... 60 | selectordim(numdims)), ... 61 | 0), ... 62 | permvector); 63 | else 64 | %when numdims=3, the above line gives ndims(selector)=4. This 65 | %routine avoids that error. When used with GMDMP, numdims will be 66 | %at least 4, so this routine will be unnecessary. 
67 | selector2=eye(selectordim(numdims-1), selectordim(numdims)); 68 | selector=zeros(selectordim); 69 | for j=1:selectordim(1) 70 | selector(j, :, :)=selector2; 71 | end 72 | selector=ipermute(selector, permvector); 73 | end 74 | 75 | 76 | %perform trace, discard resulting singleton dimensions 77 | tracedtensor=sum(sum(tensor1.*selector, d1), d2); 78 | tracedtensor=squeeze(tracedtensor); 79 | 80 | 81 | end 82 | 83 | 84 | %correction for aberration in squeeze function: 85 | %size(squeeze(rand(1,1,2)))=[2 1] 86 | nontracedimensions=dim1; 87 | nontracedimensions(d1)=[]; 88 | if d2>d1 89 | nontracedimensions(d2-1)=[]; 90 | else 91 | nontracedimensions(d2)=[]; 92 | end 93 | tracedsize=size(tracedtensor); 94 | % Next line modified, Nicolas Boumal, April 30, 2012, such that diagsum(A, 95 | % 1, 2) would compute the trace of A, a 2D matrix. 96 | if length(tracedsize)==2 && tracedsize(2)==1 && ... 97 | (isempty(nontracedimensions) || tracedsize(1)~=nontracedimensions(1)) 98 | 99 | tracedtensor=tracedtensor.'; 100 | 101 | end 102 | -------------------------------------------------------------------------------- /manopt/manopt/tools/hessianspectrum.m: -------------------------------------------------------------------------------- 1 | function lambdas = hessianspectrum(problem, x, sqrtprec) 2 | % Returns the eigenvalues of the (preconditioned) Hessian at x. 3 | % 4 | % function lambdas = hessianspectrum(problem, x) 5 | % function lambdas = hessianspectrum(problem, x, sqrtprecon) 6 | % 7 | % If the Hessian is defined in the problem structure and if no 8 | % preconditioner is defined, returns the eigenvalues of the Hessian 9 | % operator (which needs to be symmetric but not necessarily definite) on 10 | % the tangent space at x. There are problem.M.dim() eigenvalues. 11 | % 12 | % If a preconditioner is defined, the eigenvalues of the composition are 13 | % computed: precon o Hessian.
Remember that the preconditioner has to be 14 | % symmetric, positive definite, and is supposed to approximate the inverse 15 | % of the Hessian. 16 | % 17 | % Even though the Hessian and the preconditioner are both symmetric, their 18 | % composition is not symmetric, which can slow down the call to 'eigs' 19 | % substantially. If possible, you may specify the square root of the 20 | % preconditioner as an optional input sqrtprecon. This operator on the 21 | % tangent space at x must also be symmetric, positive definite, and such 22 | % that sqrtprecon o sqrtprecon = precon. Then the spectrum of the symmetric 23 | % operator sqrtprecon o hess o sqrtprecon is computed: it is the same as 24 | % the spectrum of precon o hess, but is generally faster to compute. 25 | % The operator sqrtprecon(x, u[, store]) accepts as input: a point x, 26 | % a tangent vector u and (optional) a store structure. 27 | % 28 | % The input and the output of the Hessian and of the preconditioner are 29 | % projected on the tangent space to avoid undesired contributions of the 30 | % ambient space. 31 | % 32 | % Requires the manifold description in problem.M to have these functions: 33 | % 34 | % u_vec = vec(x, u_mat) : 35 | % Returns a column vector representation of the normal (usually 36 | % matrix) representation of the tangent vector u_mat. vec must be an 37 | % isometry between the tangent space (with its Riemannian metric) and 38 | % a subspace of R^n where n = length(u_vec), with the 2-norm on R^n. 39 | % In other words: it is an orthogonal projector. 40 | % 41 | % u_mat = mat(x, u_vec) : 42 | % The inverse of vec (its adjoint). 43 | % 44 | % u_mat_clean = tangent(x, u_mat) : 45 | % Subtracts from the tangent vector u_mat any component that would 46 | % make it "not really tangent", by projection. 47 | % 48 | % answer = vecmatareisometries() : 49 | % Returns true if the linear maps encoded by vec and mat are 50 | % isometries, false otherwise. It is better if the answer is yes. 
51 | % 52 | 53 | % This file is part of Manopt: www.manopt.org. 54 | % Original author: Nicolas Boumal, July 3, 2013. 55 | % Contributors: 56 | % Change log: 57 | 58 | 59 | if ~canGetHessian(problem) 60 | warning('manopt:hessianspectrum:nohessian', ... 61 | ['The Hessian appears to be unavailable.\n' ... 62 | 'Will try to use an approximate Hessian instead.\n'... 63 | 'Since this approximation may not be linear or '... 64 | 'symmetric,\nthe computation might fail and the '... 65 | 'results (if any)\nmight make no sense.']); 66 | end 67 | 68 | vec = @(u_mat) problem.M.vec(x, u_mat); 69 | mat = @(u_vec) problem.M.mat(x, u_vec); 70 | tgt = @(u_mat) problem.M.tangent(x, u_mat); 71 | 72 | % n: size of a vectorized tangent vector 73 | % dim: dimension of the tangent space 74 | % necessarily, n >= dim. 75 | % The vectorized operators we build below will have at least n - dim 76 | % zero eigenvalues. 77 | n = length(vec(problem.M.zerovec(x))); 78 | dim = problem.M.dim(); 79 | 80 | % The store structure is not updated by the getHessian call because the 81 | % eigs function will not take care of it. This might be worked around, 82 | % but for now we simply obtain the store structure built from calling 83 | % the cost and gradient at x and pass that one for every Hessian call. 84 | % This will typically be enough, seen as the Hessian is not supposed to 85 | % store anything new. 
86 | storedb = struct(); 87 | if canGetGradient(problem) 88 | [unused1, unused2, storedb] = getCostGrad(problem, x, struct()); %#ok 89 | end 90 | 91 | hess = @(u_mat) tgt(getHessian(problem, x, tgt(u_mat), storedb)); 92 | hess_vec = @(u_vec) vec(hess(mat(u_vec))); 93 | 94 | % Regardless of preconditioning, we can only have a symmetric 95 | % eigenvalue problem if the vec/mat pair of the manifold is an 96 | % isometry: 97 | vec_mat_are_isometries = problem.M.vecmatareisometries(); 98 | 99 | if ~exist('sqrtprec', 'var') || isempty(sqrtprec) 100 | 101 | if ~canGetPrecon(problem) 102 | 103 | % There is no preconditioner: just deal with the (symmetric) 104 | % Hessian. 105 | 106 | eigs_opts.issym = vec_mat_are_isometries; 107 | eigs_opts.isreal = true; 108 | lambdas = eigs(hess_vec, n, dim, 'LM', eigs_opts); 109 | 110 | else 111 | 112 | % There is a preconditioner, but we don't have its square root: 113 | % deal with the non-symmetric composition prec o hess. 114 | 115 | prec = @(u_mat) tgt(getPrecon(problem, x, tgt(u_mat), storedb)); 116 | prec_vec = @(u_vec) vec(prec(mat(u_vec))); 117 | % prec_inv_vec = @(u_vec) pcg(prec_vec, u_vec); 118 | 119 | eigs_opts.issym = false; 120 | eigs_opts.isreal = true; 121 | lambdas = eigs(@(u_vec) prec_vec(hess_vec(u_vec)), ... 122 | n, dim, 'LM', eigs_opts); 123 | 124 | end 125 | 126 | else 127 | 128 | % There is a preconditioner, and we have its square root: deal with 129 | % the symmetric composition sqrtprec o hess o sqrtprec. 130 | % Need to check also whether sqrtprec uses the store cache or not.
131 | 132 | is_octave = exist('OCTAVE_VERSION', 'builtin'); 133 | if ~is_octave 134 | narg = nargin(sqrtprec); 135 | else 136 | narg = 3; 137 | end 138 | 139 | switch narg 140 | case 2 141 | sqrtprec_vec = @(u_vec) vec(tgt(sqrtprec(x, tgt(mat(u_vec))))); 142 | case 3 143 | store = getStore(problem, x, storedb); 144 | sqrtprec_vec = @(u_vec) vec(tgt(sqrtprec(x, tgt(mat(u_vec)), store))); 145 | otherwise 146 | error('sqrtprec must accept 2 or 3 inputs: x, u, (optional: store)'); 147 | end 148 | 149 | eigs_opts.issym = vec_mat_are_isometries; 150 | eigs_opts.isreal = true; 151 | lambdas = eigs(@(u_vec) ... 152 | sqrtprec_vec(hess_vec(sqrtprec_vec(u_vec))), ... 153 | n, dim, 'LM', eigs_opts); 154 | 155 | end 156 | 157 | end 158 | -------------------------------------------------------------------------------- /manopt/manopt/tools/identify_linear_piece.m: -------------------------------------------------------------------------------- 1 | function [range, poly] = identify_linear_piece(x, y, window_length) 2 | % Identify a segment of the curve (x, y) that appears to be linear. 3 | % 4 | % function [range poly] = identify_linear_piece(x, y, window_length) 5 | % 6 | % This function attempts to identify a contiguous segment of the curve 7 | % defined by the vectors x and y that appears to be linear. A line is fit 8 | % through the data over all windows of length window_length and the best 9 | % fit is retained. The output specifies the range of indices such that 10 | % x(range) is the portion over which (x, y) is the most linear and the 11 | % output poly specifies a first order polynomial that best fits (x, y) over 12 | % that range, following the usual matlab convention for polynomials 13 | % (highest degree coefficients first). 14 | % 15 | % See also: checkdiff checkgradient checkhessian 16 | 17 | % This file is part of Manopt: www.manopt.org. 18 | % Original author: Nicolas Boumal, July 8, 2013. 
19 | % Contributors: 20 | % Change log: 21 | 22 | residues = zeros(length(x)-window_length, 1); 23 | polys = zeros(2, length(residues)); 24 | for i = 1 : length(residues) 25 | range = i:i+window_length; 26 | [poly, meta] = polyfit(x(range), y(range), 1); 27 | residues(i) = meta.normr; 28 | polys(:, i) = poly'; 29 | end 30 | [unused, best] = min(residues); %#ok 31 | range = best:best+window_length; 32 | poly = polys(:, best)'; 33 | 34 | end 35 | -------------------------------------------------------------------------------- /manopt/manopt/tools/identify_linear_piece_SinglePrecision.m: -------------------------------------------------------------------------------- 1 | function [range, poly] = identify_linear_piece_SinglePrecision(x, y, window_length) 2 | % Identify a segment of the curve (x, y) that appears to be linear. 3 | % 4 | % function [range, poly] = identify_linear_piece_SinglePrecision(x, y, window_length) 5 | % 6 | % This function attempts to identify a contiguous segment of the curve 7 | % defined by the vectors x and y that appears to be linear. A line is fit 8 | % through the data over all windows of length window_length and the best 9 | % fit is retained. The output specifies the range of indices such that 10 | % x(range) is the portion over which (x, y) is the most linear and the 11 | % output poly specifies a first order polynomial that best fits (x, y) over 12 | % that range, following the usual matlab convention for polynomials 13 | % (highest degree coefficients first). 14 | % 15 | % See also: checkdiff checkgradient checkhessian 16 | 17 | % This file is part of Manopt: www.manopt.org. 18 | % Original author: Nicolas Boumal, July 8, 2013.
19 | % Contributors: 20 | % Change log: 21 | 22 | residues = single(zeros(length(x)-window_length, 1)); 23 | polys = single(zeros(2, length(residues))); 24 | for i = 1 : length(residues) 25 | range = i:i+window_length; 26 | [poly, meta] = polyfit(x(range), y(range), 1); 27 | residues(i) = meta.normr; 28 | polys(:, i) = poly'; 29 | end 30 | [unused, best] = min(residues); %#ok 31 | range = best:best+window_length; 32 | poly = polys(:, best)'; 33 | 34 | end 35 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multiprod.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/manopt/tools/multiprod.m -------------------------------------------------------------------------------- /manopt/manopt/tools/multiprodmultitransp_license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2009, Paolo de Leva 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in 12 | the documentation and/or other materials provided with the distribution 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 18 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multiscale.m: -------------------------------------------------------------------------------- 1 | function A = multiscale(scale, A) 2 | % Multiplies the 2D slices in a 3D matrix by individual scalars. 3 | % 4 | % function A = multiscale(scale, A) 5 | % 6 | % Given a vector scale of length N and a 3-dimensional matrix A of size 7 | % n-by-m-by-N, returns a matrix A of same size such that 8 | % A(:, :, k) := scale(k) * A(:, :, k); 9 | % 10 | % See also: multiprod multitransp multitrace 11 | 12 | % This file is part of Manopt: www.manopt.org. 13 | % Original author: Nicolas Boumal, Dec. 30, 2012. 14 | % Contributors: 15 | % Change log: 16 | 17 | 18 | assert(ndims(A) <= 3, ... 19 | ['multiscale is only well defined for matrix arrays of 3 ' ... 20 | 'or less dimensions.']); 21 | [n m N] = size(A); 22 | assert(numel(scale) == N, ... 23 | ['scale must be a vector whose length equals the third ' ... 24 | 'dimension of A, that is, the number of 2D matrix slices ' ... 
25 | 'in the 3D matrix A.']); 26 | 27 | scale = scale(:); 28 | A = reshape(bsxfun(@times, reshape(A, n*m, N), scale'), n, m, N); 29 | 30 | end 31 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multiskew.m: -------------------------------------------------------------------------------- 1 | function Y = multiskew(X) 2 | % Returns the skew-symmetric parts of the matrices in the 3D matrix X. 3 | % 4 | % function Y = multiskew(X) 5 | % 6 | % Y is a 3D matrix the same size as X. Each slice Y(:, :, i) is the 7 | % skew-symmetric part of the slice X(:, :, i). 8 | % 9 | % See also: multiprod multitransp multiscale multisym 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Jan. 31, 2013. 13 | % Contributors: 14 | % Change log: 15 | 16 | Y = .5*(X - multitransp(X)); 17 | 18 | end 19 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multisym.m: -------------------------------------------------------------------------------- 1 | function Y = multisym(X) 2 | % Returns the symmetric parts of the matrices in the 3D matrix X 3 | % 4 | % function Y = multisym(X) 5 | % 6 | % Y is a 3D matrix the same size as X. Each slice Y(:, :, i) is the 7 | % symmetric part of the slice X(:, :, i). 8 | % 9 | % See also: multiprod multitransp multiscale multiskew 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Jan. 31, 2013. 13 | % Contributors: 14 | % Change log: 15 | 16 | Y = .5*(X + multitransp(X)); 17 | 18 | end 19 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multitrace.m: -------------------------------------------------------------------------------- 1 | function tr = multitrace(A) 2 | % Computes the traces of the 2D slices in a 3D matrix. 
3 | % 4 | % function tr = multitrace(A) 5 | % 6 | % For a 3-dimensional matrix A of size n-by-n-by-N, returns a column vector 7 | % tr of length N such that tr(k) = trace(A(:, :, k)); 8 | % 9 | % See also: multiprod multitransp multiscale 10 | 11 | % This file is part of Manopt: www.manopt.org. 12 | % Original author: Nicolas Boumal, Dec. 30, 2012. 13 | % Contributors: 14 | % Change log: 15 | 16 | 17 | assert(ndims(A) <= 3, ... 18 | ['multitrace is only well defined for matrix arrays of 3 ' ... 19 | 'or less dimensions.']); 20 | 21 | tr = diagsum(A, 1, 2); 22 | 23 | end 24 | -------------------------------------------------------------------------------- /manopt/manopt/tools/multitransp.m: -------------------------------------------------------------------------------- 1 | function b = multitransp(a, dim) 2 | %MULTITRANSP Transposing arrays of matrices. 3 | % B = MULTITRANSP(A) is equivalent to B = MULTITRANSP(A, DIM), where 4 | % DIM = 1. 5 | % 6 | % B = MULTITRANSP(A, DIM) is equivalent to 7 | % B = PERMUTE(A, [1:DIM-1, DIM+1, DIM, DIM+2:NDIMS(A)]), where A is an 8 | % array containing N P-by-Q matrices along its dimensions DIM and DIM+1, 9 | % and B is an array containing the Q-by-P transpose (.') of those N 10 | % matrices along the same dimensions. N = NUMEL(A) / (P*Q), i.e. N is 11 | % equal to the number of elements in A divided by the number of elements 12 | % in each matrix. 13 | % 14 | % MULTITRANSP, PERMUTE and IPERMUTE are a generalization of TRANSPOSE 15 | % (.') for N-D arrays. 16 | % 17 | % Example: 18 | % A 5-by-9-by-3-by-2 array may be considered to be a block array 19 | % containing ten 9-by-3 matrices along dimensions 2 and 3. In this 20 | % case, its size is so indicated: 5-by-(9-by-3)-by-2 or 5x(9x3)x2. 21 | % If A is ................ a 5x(9x3)x2 array of 9x3 matrices, 22 | % C = MULTITRANSP(A, 2) is a 5x(3x9)x2 array of 3x9 matrices. 23 | % 24 | % See also PERMUTE, IPERMUTE, MULTIPROD, MULTITRACE, MULTISCALE. 
25 | 26 | % $ Version: 1.0 $ 27 | % CODE by: Paolo de Leva (IUSM, Rome, IT) 2005 Sep 9 28 | % COMMENTS by: Code author 2006 Nov 21 29 | % OUTPUT tested by: Code author 2005 Sep 13 30 | % ------------------------------------------------------------------------- 31 | 32 | % Setting DIM if not supplied. 33 | if nargin == 1, dim = 1; end 34 | 35 | % Transposing 36 | order = [1:dim-1, dim+1, dim, dim+2:ndims(a)]; 37 | b = permute(a, order); 38 | -------------------------------------------------------------------------------- /manopt/manopt/tools/plotprofile.m: -------------------------------------------------------------------------------- 1 | function cost = plotprofile(problem, x, d, t) 2 | % Plot the cost function along a geodesic or a retraction path. 3 | % 4 | % function plotprofile(problem, x, d, t) 5 | % function costs = plotprofile(problem, x, d, t) 6 | % 7 | % Plot profile evaluates the cost function along a geodesic gamma(t) such 8 | % that gamma(0) = x and the derivative of gamma at 0 is the direction d. 9 | % The input t is a vector specifying for which values of t we must evaluate 10 | % f(gamma(t)) (it may include negative values). 11 | % 12 | % If the function is called with an output, the plot is not drawn and the 13 | % values of the cost are returned for the instants t. 14 | 15 | % This file is part of Manopt: www.manopt.org. 16 | % Original author: Nicolas Boumal, Jan. 9, 2013. 17 | % Contributors: 18 | % Change log: 19 | 20 | % Verify that the problem description is sufficient. 
21 | if ~canGetCost(problem) 22 | error('It seems no cost was provided.'); 23 | end 24 | 25 | linesearch_fun = @(t) getCost(problem, problem.M.exp(x, d, t), struct()); 26 | 27 | cost = zeros(size(t)); 28 | for i = 1 : numel(t) 29 | cost(i) = linesearch_fun(t(i)); 30 | end 31 | 32 | if nargout == 0 33 | plot(t, cost); 34 | xlabel('t'); 35 | ylabel('f(Exp_x(t*d))'); 36 | end 37 | 38 | end 39 | -------------------------------------------------------------------------------- /manopt/manopt/tools/powermanifold.m: -------------------------------------------------------------------------------- 1 | function Mn = powermanifold(M, n) 2 | % Returns a structure describing a power manifold M^n = M x M x ... x M. 3 | % 4 | % function Mn = powermanifold(M, n) 5 | % 6 | % Input: a manifold structure M and an integer n >= 1. 7 | % 8 | % Output: a manifold structure Mn representing M x ... x M (n copies of M) 9 | % with the metric of M extended element-wise. Points and vectors are stored 10 | % as cells of size nx1. 11 | % 12 | % This code is for prototyping uses. The structures returned are often 13 | % inefficient representations of power manifolds owing to their use of 14 | % for-loops, but they should allow to rapidly try out an idea. 15 | % 16 | % Example (an inefficient representation of the oblique manifold (3, 10)): 17 | % Mn = powermanifold(spherefactory(3), 10) 18 | % disp(Mn.name()); 19 | % x = Mn.rand() 20 | % 21 | % See also: productmanifold 22 | 23 | % This file is part of Manopt: www.manopt.org. 24 | % Original author: Nicolas Boumal, Dec. 30, 2012. 25 | % Contributors: 26 | % Change log: 27 | % NB, July 4, 2013: Added support for vec, mat, tangent. 28 | % Added support for egrad2rgrad and ehess2rhess. 


    % Interior of the power-manifold factory: every field of Mn applies the
    % corresponding operation of the base manifold M to each of the n cells
    % of a point / tangent vector (M, n and Mn come from the enclosing
    % function, whose header is above this excerpt).
    assert(n >= 1, 'n must be an integer larger than or equal to 1.');

    % Name of the form '[<name of M>]^n'.
    Mn.name = @() sprintf('[%s]^%d', M.name(), n);

    % Dimension is n times the dimension of M.
    Mn.dim = @() n*M.dim();

    % Inner product: sum of the component-wise inner products of M.
    Mn.inner = @inner;
    function val = inner(x, u, v)
        val = 0;
        for i = 1 : n
            val = val + M.inner(x{i}, u{i}, v{i});
        end
    end

    % Norm induced by the inner product above.
    Mn.norm = @(x, d) sqrt(Mn.inner(x, d, d));

    % Distance: square root of the sum of squared component distances.
    Mn.dist = @dist;
    function d = dist(x, y)
        sqd = 0;
        for i = 1 : n
            sqd = sqd + M.dist(x{i}, y{i})^2;
        end
        d = sqrt(sqd);
    end

    % Typical distance: n copies of M.typicaldist() combined as above.
    Mn.typicaldist = @typicaldist;
    function d = typicaldist()
        sqd = 0;
        for i = 1 : n
            sqd = sqd + M.typicaldist()^2;
        end
        d = sqrt(sqd);
    end

    % Component-wise M.proj.
    Mn.proj = @proj;
    function u = proj(x, u)
        for i = 1 : n
            u{i} = M.proj(x{i}, u{i});
        end
    end

    % Component-wise M.tangent.
    Mn.tangent = @tangent;
    function u = tangent(x, u)
        for i = 1 : n
            u{i} = M.tangent(x{i}, u{i});
        end
    end

    % Use M.tangent2ambient per component when M provides it; otherwise
    % the tangent vector is returned unchanged.
    if isfield(M, 'tangent2ambient')
        Mn.tangent2ambient = @tangent2ambient;
    else
        Mn.tangent2ambient = @(x, u) u;
    end
    function u = tangent2ambient(x, u)
        for i = 1 : n
            u{i} = M.tangent2ambient(x{i}, u{i});
        end
    end

    % Component-wise M.egrad2rgrad.
    Mn.egrad2rgrad = @egrad2rgrad;
    function g = egrad2rgrad(x, g)
        for i = 1 : n
            g{i} = M.egrad2rgrad(x{i}, g{i});
        end
    end

    % Component-wise M.ehess2rhess.
    Mn.ehess2rhess = @ehess2rhess;
    function h = ehess2rhess(x, eg, eh, h)
        for i = 1 : n
            h{i} = M.ehess2rhess(x{i}, eg{i}, eh{i}, h{i});
        end
    end

    % Component-wise M.exp; the step t defaults to 1 when omitted.
    Mn.exp = @expo;
    function x = expo(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : n
            x{i} = M.exp(x{i}, u{i}, t);
        end
    end

    % Component-wise M.retr; the step t defaults to 1 when omitted.
    Mn.retr = @retr;
    function x = retr(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : n
            x{i} = M.retr(x{i}, u{i}, t);
        end
    end

    % Mn.log is only exposed when M has a 'log' field.
    if isfield(M, 'log')
        Mn.log = @loga;
    end
    function u = loga(x, y)
        u = cell(n, 1);
        for i = 1 : n
            u{i} = M.log(x{i}, y{i});
        end
    end

    % Hash: md5 of the concatenated component hashes, prefixed with 'z'.
    Mn.hash = @hash;
    function str = hash(x)
        str = '';
        for i = 1 : n
            str = [str M.hash(x{i})]; %#ok
        end
        str = ['z' hashmd5(str)];
    end

    % Linear combination a1*u1 (3 arguments) or a1*u1 + a2*u2 (5 arguments),
    % delegated per component to M.lincomb; any other arity is an error.
    Mn.lincomb = @lincomb;
    function x = lincomb(x, a1, u1, a2, u2)
        if nargin == 3
            for i = 1 : n
                x{i} = M.lincomb(x{i}, a1, u1{i});
            end
        elseif nargin == 5
            for i = 1 : n
                x{i} = M.lincomb(x{i}, a1, u1{i}, a2, u2{i});
            end
        else
            error('Bad usage of powermanifold.lincomb');
        end
    end

    % Random point: an n-by-1 cell of independent M.rand() draws.
    Mn.rand = @rand;
    function x = rand()
        x = cell(n, 1);
        for i = 1 : n
            x{i} = M.rand();
        end
    end

    % Random tangent vector: one M.randvec per component, then the whole
    % vector is scaled by 1/sqrt(n).
    Mn.randvec = @randvec;
    function u = randvec(x)
        u = cell(n, 1);
        for i = 1 : n
            u{i} = M.randvec(x{i});
        end
        u = Mn.lincomb(x, 1/sqrt(n), u);
    end

    % Zero tangent vector, component by component.
    Mn.zerovec = @zerovec;
    function u = zerovec(x)
        u = cell(n, 1);
        for i = 1 : n
            u{i} = M.zerovec(x{i});
        end
    end

    % Mn.transp is only exposed when M has a 'transp' field.
    if isfield(M, 'transp')
        Mn.transp = @transp;
    end
    function u = transp(x1, x2, u)
        for i = 1 : n
            u{i} = M.transp(x1{i}, x2{i}, u{i});
        end
    end

    % Mn.pairmean is only exposed when M has a 'pairmean' field.
    if isfield(M, 'pairmean')
        Mn.pairmean = @pairmean;
    end
    function y = pairmean(x1, x2)
        y = cell(n, 1);
        for i = 1 : n
            y{i} = M.pairmean(x1{i}, x2{i});
        end
    end

    % Compute the length of a vectorized tangent vector of M at x, assuming
    % this length is independent of the point x (that should be fine).
204 | if isfield(M, 'vec') 205 | rand_x = M.rand(); 206 | zero_u = M.zerovec(rand_x); 207 | len_vec = length(M.vec(rand_x, zero_u)); 208 | 209 | Mn.vec = @vec; 210 | 211 | if isfield(M, 'mat') 212 | Mn.mat = @mat; 213 | end 214 | 215 | end 216 | 217 | function u_vec = vec(x, u_mat) 218 | u_vec = zeros(len_vec, n); 219 | for i = 1 : n 220 | u_vec(:, i) = M.vec(x{i}, u_mat{i}); 221 | end 222 | u_vec = u_vec(:); 223 | end 224 | 225 | function u_mat = mat(x, u_vec) 226 | u_mat = cell(n, 1); 227 | u_vec = reshape(u_vec, len_vec, n); 228 | for i = 1 : n 229 | u_mat{i} = M.mat(x{i}, u_vec(:, i)); 230 | end 231 | end 232 | 233 | if isfield(M, 'vecmatareisometries') 234 | Mn.vecmatareisometries = M.vecmatareisometries; 235 | else 236 | Mn.vecmatareisometries = @() false; 237 | end 238 | 239 | end 240 | -------------------------------------------------------------------------------- /manopt/manopt_version.m: -------------------------------------------------------------------------------- 1 | function [version released] = manopt_version() 2 | % Returns the version of the Manopt package you are running, as a vector. 3 | % 4 | % function [version released] = manopt_version() 5 | % 6 | % version(1) is the primary version number. 7 | % released is the date this version was released, in the same format as the 8 | % date() function in Matlab. 
9 | 10 | version = [1, 0, 7]; 11 | released = '12-Aug-2014'; 12 | -------------------------------------------------------------------------------- /spdnet/vl_mybfc.m: -------------------------------------------------------------------------------- 1 | function [Y, Y_w] = vl_mybfc(X, W, dzdy) 2 | %[DZDX, DZDF, DZDB] = VL_MYBFC(X, F, B, DZDY) 3 | %BiMap layer 4 | 5 | Y = cell(length(X),1); 6 | for ix = 1 : length(X) 7 | Y{ix} = W'*X{ix}*W; 8 | end 9 | if nargin == 3 10 | [dim_ori, dim_tar] = size(W); 11 | Y_w = zeros(dim_ori,dim_tar); 12 | for ix = 1 : length(X) 13 | if iscell(dzdy)==1 14 | d_t = dzdy{ix}; 15 | else 16 | d_t = dzdy(:,ix); 17 | d_t = reshape(d_t,[dim_tar dim_tar]); 18 | end 19 | Y{ix} = W*d_t*W'; 20 | Y_w = Y_w+2*X{ix}*W*d_t; 21 | end 22 | end 23 | -------------------------------------------------------------------------------- /spdnet/vl_myfc.m: -------------------------------------------------------------------------------- 1 | function [Y, Y_w] = vl_myfc(X, W, dzdy) 2 | %[DZDX, DZDF, DZDB] = vl_myconv(X, F, B, DZDY) 3 | %regular fully connected layer 4 | 5 | X_t = zeros(size(X{1},1)^2,length(X)); 6 | for ix = 1 : length(X) 7 | x_t = X{ix}; 8 | X_t(:,ix) = x_t(:); 9 | end 10 | if nargin < 3 11 | Y = W'*X_t; 12 | else 13 | Y = W * dzdy; 14 | Y_w = X_t * dzdy'; 15 | end -------------------------------------------------------------------------------- /spdnet/vl_myforbackward.m: -------------------------------------------------------------------------------- 1 | function res = vl_myforbackward(net, x, dzdy, res, varargin) 2 | % vl_myforbackward evaluates a simple SPDNet 3 | 4 | opts.res = [] ; 5 | opts.conserveMemory = false ; 6 | opts.sync = false ; 7 | opts.disableDropout = false ; 8 | opts.freezeDropout = false ; 9 | opts.accumulate = false ; 10 | opts.cudnn = true ; 11 | opts.skipForward = false; 12 | opts.backPropDepth = +inf ; 13 | opts.epsilon = 1e-4; 14 | 15 | % opts = vl_argparse(opts, varargin); 16 | 17 | n = numel(net.layers) ; 18 | 19 | if 
(nargin <= 2) || isempty(dzdy) 20 | doder = false ; 21 | else 22 | doder = true ; 23 | end 24 | 25 | if opts.cudnn 26 | cudnn = {'CuDNN'} ; 27 | else 28 | cudnn = {'NoCuDNN'} ; 29 | end 30 | 31 | gpuMode = isa(x, 'gpuArray') ; 32 | 33 | if nargin <= 3 || isempty(res) 34 | res = struct(... 35 | 'x', cell(1,n+1), ... 36 | 'dzdx', cell(1,n+1), ... 37 | 'dzdw', cell(1,n+1), ... 38 | 'aux', cell(1,n+1), ... 39 | 'time', num2cell(zeros(1,n+1)), ... 40 | 'backwardTime', num2cell(zeros(1,n+1))) ; 41 | end 42 | if ~opts.skipForward 43 | res(1).x = x ; 44 | end 45 | 46 | 47 | % ------------------------------------------------------------------------- 48 | % Forward pass 49 | % ------------------------------------------------------------------------- 50 | 51 | for i=1:n 52 | if opts.skipForward, break; end; 53 | l = net.layers{i} ; 54 | res(i).time = tic ; 55 | switch l.type 56 | case 'bfc' 57 | res(i+1).x = vl_mybfc(res(i).x, l.weight) ; 58 | case 'fc' 59 | res(i+1).x = vl_myfc(res(i).x, l.weight) ; 60 | case 'rec' 61 | res(i+1).x = vl_myrec(res(i).x, opts.epsilon) ; 62 | case 'log' 63 | res(i+1).x = vl_mylog(res(i).x) ; 64 | case 'softmaxloss' 65 | res(i+1).x = vl_mysoftmaxloss(res(i).x, l.class) ; 66 | 67 | case 'custom' 68 | res(i+1) = l.forward(l, res(i), res(i+1)) ; 69 | otherwise 70 | error('Unknown layer type %s', l.type) ; 71 | end 72 | % optionally forget intermediate results 73 | forget = opts.conserveMemory ; 74 | forget = forget & (~doder || strcmp(l.type, 'relu')) ; 75 | forget = forget & ~(strcmp(l.type, 'loss') || strcmp(l.type, 'softmaxloss')) ; 76 | forget = forget & (~isfield(l, 'rememberOutput') || ~l.rememberOutput) ; 77 | if forget 78 | res(i).x = [] ; 79 | end 80 | if gpuMode & opts.sync 81 | % This should make things slower, but on MATLAB 2014a it is necessary 82 | % for any decent performance. 
83 | wait(gpuDevice) ; 84 | end 85 | res(i).time = toc(res(i).time) ; 86 | end 87 | 88 | % ------------------------------------------------------------------------- 89 | % Backward pass 90 | % ------------------------------------------------------------------------- 91 | 92 | if doder 93 | res(n+1).dzdx = dzdy ; 94 | for i=n:-1:max(1, n-opts.backPropDepth+1) 95 | l = net.layers{i} ; 96 | res(i).backwardTime = tic ; 97 | switch l.type 98 | case 'bfc' 99 | [res(i).dzdx, res(i).dzdw] = ... 100 | vl_mybfc(res(i).x, l.weight, res(i+1).dzdx) ; 101 | case 'fc' 102 | [res(i).dzdx, res(i).dzdw] = ... 103 | vl_myfc(res(i).x, l.weight, res(i+1).dzdx) ; 104 | case 'rec' 105 | res(i).dzdx = vl_myrec(res(i).x, opts.epsilon, res(i+1).dzdx) ; 106 | case 'log' 107 | res(i).dzdx = vl_mylog(res(i).x, res(i+1).dzdx) ; 108 | case 'softmaxloss' 109 | res(i).dzdx = vl_mysoftmaxloss(res(i).x, l.class, res(i+1).dzdx) ; 110 | case 'custom' 111 | res(i) = l.backward(l, res(i), res(i+1)) ; 112 | end 113 | if opts.conserveMemory 114 | res(i+1).dzdx = [] ; 115 | end 116 | if gpuMode & opts.sync 117 | wait(gpuDevice) ; 118 | end 119 | res(i).backwardTime = toc(res(i).backwardTime) ; 120 | end 121 | end 122 | 123 | -------------------------------------------------------------------------------- /spdnet/vl_mylog.m: -------------------------------------------------------------------------------- 1 | function Y = vl_mylog(X, dzdy) 2 | %Y = VL_MYLOG(X, DZDY) 3 | %LogEig layer 4 | 5 | Us = cell(length(X),1); 6 | Ss = cell(length(X),1); 7 | Vs = cell(length(X),1); 8 | 9 | for ix = 1 : length(X) 10 | [Us{ix},Ss{ix},Vs{ix}] = svd(X{ix}); 11 | end 12 | 13 | 14 | D = size(Ss{1},2); 15 | Y = cell(length(X),1); 16 | 17 | if nargin < 2 18 | for ix = 1:length(X) 19 | Y{ix} = Us{ix}*diag(log(diag(Ss{ix})))*Us{ix}'; 20 | end 21 | else 22 | for ix = 1:length(X) 23 | U = Us{ix}; S = Ss{ix}; V = Vs{ix}; 24 | diagS = diag(S); 25 | ind =diagS >(D*eps(max(diagS))); 26 | Dmin = (min(find(ind,1,'last'),D)); 27 | 28 | S 
= S(:,ind); U = U(:,ind); 29 | dLdC = double(reshape(dzdy(:,ix),[D D])); dLdC = symmetric(dLdC); 30 | 31 | 32 | dLdV = 2*dLdC*U*diagLog(S,0); 33 | dLdS = diagInv(S)*(U'*dLdC*U); 34 | if sum(ind) == 1 % diag behaves badly when there is only 1d 35 | K = 1./(S(1)*ones(1,Dmin)-(S(1)*ones(1,Dmin))'); 36 | K(eye(size(K,1))>0)=0; 37 | else 38 | K = 1./(diag(S)*ones(1,Dmin)-(diag(S)*ones(1,Dmin))'); 39 | K(eye(size(K,1))>0)=0; 40 | K(find(isinf(K)==1))=0; 41 | end 42 | if all(diagS==1) 43 | dzdx = zeros(D,D); 44 | else 45 | dzdx = U*(symmetric(K'.*(U'*dLdV))+dDiag(dLdS))*U'; 46 | 47 | end 48 | Y{ix} = dzdx; %warning('no normalization'); 49 | end 50 | end 51 | -------------------------------------------------------------------------------- /spdnet/vl_myrec.m: -------------------------------------------------------------------------------- 1 | function Y = vl_myrec(X, epsilon, dzdy) 2 | % Y = VL_MYREC (X, EPSILON, DZDY) 3 | % ReEig layer 4 | 5 | Us = cell(length(X),1); 6 | Ss = cell(length(X),1); 7 | Vs = cell(length(X),1); 8 | 9 | for ix = 1 : length(X) 10 | [Us{ix},Ss{ix},Vs{ix}] = svd(X{ix}); 11 | end 12 | 13 | D = size(Ss{1},2); 14 | Y = cell(length(X),1); 15 | 16 | if nargin < 3 17 | for ix = 1:length(X) 18 | [max_S, ~]=max_eig(Ss{ix},epsilon); 19 | Y{ix} = Us{ix}*max_S*Us{ix}'; 20 | end 21 | else 22 | for ix = 1:length(X) 23 | U = Us{ix}; S = Ss{ix}; V = Vs{ix}; 24 | 25 | Dmin = D; 26 | 27 | dLdC = double(dzdy{ix}); dLdC = symmetric(dLdC); 28 | 29 | [max_S, max_I]=max_eig(Ss{ix},epsilon); 30 | dLdV = 2*dLdC*U*max_S; 31 | dLdS = (diag(not(max_I)))*U'*dLdC*U; 32 | 33 | 34 | K = 1./(diag(S)*ones(1,Dmin)-(diag(S)*ones(1,Dmin))'); 35 | K(eye(size(K,1))>0)=0; 36 | K(find(isinf(K)==1))=0; 37 | 38 | dzdx = U*(symmetric(K'.*(U'*dLdV))+dDiag(dLdS))*U'; 39 | 40 | Y{ix} = dzdx; %warning('no normalization'); 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /spdnet/vl_mysoftmaxloss.m: 
--------------------------------------------------------------------------------
function Y = vl_mysoftmaxloss(X,c,dzdy)
% Softmax layer
%
%   Y = VL_MYSOFTMAXLOSS(X, C) returns the summed softmax log-loss of the
%   score matrix X (one column per sample, one row per class) for the
%   class labels C.
%
%   DZDX = VL_MYSOFTMAXLOSS(X, C, DZDY) returns the derivative of the loss
%   with respect to X, scaled by DZDY.

% class c = 0 skips a spatial location
c = double(reshape(c, 1, [])) ;   % accept row or column label vectors
mass = single(c > 0) ;

% convert to linear indexes into X (one entry per column). Skipped samples
% (c == 0) are pointed at row 1 so the index stays valid; their
% contribution is removed by mass below.
c_ = max(c, 1) + (0:numel(c)-1) * size(X,1) ;

% compute softmaxloss (scores are shifted by the column maximum for
% numerical stability)
Xmax = max(X,[],1) ;
ex = exp(bsxfun(@minus, X, Xmax)) ;

if nargin < 3
  t = Xmax + log(sum(ex,1)) - reshape(X(c_), [1 size(X,2)]) ;
  Y = sum(sum(mass .* t,1)) ;
else
  Y = bsxfun(@rdivide, ex, sum(ex,1)) ;
  Y(c_) = Y(c_) - 1;
  Y = bsxfun(@times, Y, bsxfun(@times, mass, dzdy)) ;
end
--------------------------------------------------------------------------------
/spdnet_afew.m:
--------------------------------------------------------------------------------
function [net, info] = spdnet_afew(varargin)
% SPDNET_AFEW  Train an SPDNet on the AFEW SPD data (simple example).
%   Sets up the path, the training options and the network, loads the
%   metadata file and calls spdnet_train_afew. VARARGIN is not used.

%set up the path
confPath;
%parameter setting
opts.dataDir = fullfile('./data/afew') ;
opts.imdbPathtrain = fullfile(opts.dataDir, 'spddb_afew_train_spd400_int_histeq.mat');
opts.batchSize = 30 ;
opts.test.batchSize = 1;
opts.numEpochs = 500 ;
opts.gpus = [] ;
opts.learningRate = 0.01*ones(1,opts.numEpochs);
opts.weightDecay = 0.0005 ;
opts.continue = 1;
%spdnet initialization
net = spdnet_init_afew() ;
%loading metadata (into a struct, so the variable used below is explicit)
metadata = load(opts.imdbPathtrain) ;
spd_train = metadata.spd_train ;
%spdnet training
[net, info] = spdnet_train_afew(net, spd_train, opts);


--------------------------------------------------------------------------------
/spdnet_init_afew.m:
--------------------------------------------------------------------------------
function net = spdnet_init_afew(varargin)
% spdnet_init initializes a spdnet
%
%   The network is bfc-rec-bfc-rec-bfc-log-fc-softmaxloss. Every 'bfc'
%   weight is made of the leading left singular vectors of a random
%   matrix A*A', so its columns are orthonormal; the 'fc' weight is
%   Gaussian with standard deviation 1/100. VARARGIN is not used.

rng('default');
rng(0) ;

opts.layernum = 3;
opts.datadim = [400, 200, 100, 50];

Winit = cell(opts.layernum,1);
for iw = 1 : opts.layernum
    A = rand(opts.datadim(iw));
    [U1, ~, ~] = svd(A * A');
    Winit{iw} = U1(:,1:opts.datadim(iw+1));
end

f=1/100 ;
classNum = 7;
% the fc layer sees the vectorized output of the last bfc layer
fdim = opts.datadim(opts.layernum+1)^2;
theta = f*randn(fdim, classNum, 'single');
Winit{end+1} = theta;

net.layers = {} ;
for iw = 1 : opts.layernum
    net.layers{end+1} = struct('type', 'bfc',...
                              'weight', Winit{iw}) ;
    if iw < opts.layernum
        net.layers{end+1} = struct('type', 'rec') ;
    else
        net.layers{end+1} = struct('type', 'log') ;
    end
end
net.layers{end+1} = struct('type', 'fc', ...
                           'weight', Winit{end}) ;
net.layers{end+1} = struct('type', 'softmaxloss') ;



--------------------------------------------------------------------------------
/spdnet_train_afew.m:
--------------------------------------------------------------------------------
function [net, info] = spdnet_train_afew(net, spd_train, opts)
% SPDNET_TRAIN_AFEW  Train and validate an SPDNet, one checkpoint per epoch.
%   Samples with spd_train.spd.set == 1 are used for training and those
%   with set == 2 for validation. After every epoch the network and the
%   statistics are saved to opts.dataDir/net-epoch-<epoch>.mat and the
%   curves are printed to opts.dataDir/net-train.pdf. When opts.continue is
%   set, epochs that already have a checkpoint are skipped and training
%   resumes from the last one.

opts.errorLabels = {'top1e'};
opts.train = find(spd_train.spd.set==1) ;
opts.val = find(spd_train.spd.set==2) ;

modelPath = @(ep) fullfile(opts.dataDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.dataDir, 'net-train.pdf') ;

for epoch=1:opts.numEpochs
    learningRate = opts.learningRate(epoch);

    % fast-forward to last checkpoint
    if opts.continue
        if exist(modelPath(epoch),'file')
            if epoch == opts.numEpochs
                load(modelPath(epoch), 'net', 'info') ;
            end
            continue ;
        end
        if epoch > 1
            fprintf('resuming by loading epoch %d\n', epoch-1) ;
            load(modelPath(epoch-1), 'net', 'info') ;
        end
    end

    indices.train = opts.train(randperm(length(opts.train))) ; % shuffle
    indices.val = opts.val;

    % a learning rate of 0 makes process_epoch run in validation mode
    [net,stats.train] = process_epoch(opts, epoch, spd_train, indices.train, learningRate, net) ;
    [net,stats.val] = process_epoch(opts, epoch, spd_train, indices.val, 0, net) ;


    % save
    evaluateMode = 0;

    if evaluateMode, sets = {'train'} ; else sets = {'train', 'val'} ; end

    for s = sets
        f = char(s) ;
        n = numel(indices.(f)) ;
        info.(f).objective(epoch) = stats.(f)(2) / n ;
        info.(f).error(:,epoch) = stats.(f)(3:end) / n ;
    end
    if ~evaluateMode, save(modelPath(epoch), 'net', 'info') ; end

    figure(1) ; clf ;
    hasError = 1 ;
    subplot(1,1+hasError,1) ;
    if ~evaluateMode
        semilogy(1:epoch, info.train.objective, '.-', 'linewidth', 2) ;
        hold on ;
    end
    semilogy(1:epoch, info.val.objective, '.--') ;
    xlabel('training epoch') ; ylabel('energy') ;
    grid on ;
    h=legend(sets) ;
    set(h,'color','none');
    title('objective') ;
    if hasError
        subplot(1,2,2) ; leg = {} ;
        if ~evaluateMode
            plot(1:epoch, info.train.error', '.-', 'linewidth', 2) ;
            hold on ;
            leg = horzcat(leg, strcat('train ', opts.errorLabels)) ;
        end
        plot(1:epoch, info.val.error', '.--') ;
        leg = horzcat(leg, strcat('val ', opts.errorLabels)) ;
        set(legend(leg{:}),'color','none') ;
        grid on ;
        xlabel('training epoch') ; ylabel('error') ;
        title('error') ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
end



function [net,stats] = process_epoch(opts, epoch, spd_train, trainInd, learningRate, net)
% Run one pass over the samples indexed by trainInd, in mini-batches.
% With learningRate > 0 the weights are updated after every batch;
% otherwise only the forward pass is evaluated (validation).
% stats = [elapsed time ; summed objective ; number of top-1 errors].

training = learningRate > 0 ;
if training, mode = 'training' ; else mode = 'validation' ; end

stats = [0 ; 0 ; 0] ;
numGpus = numel(opts.gpus) ;
if numGpus >= 1
    one = gpuArray(single(1)) ;
else
    one = single(1) ;
end

batchSize = opts.batchSize;
numSamples = length(trainInd) ;
numDone = 0 ;


for ib = 1 : batchSize : numSamples
    fprintf('%s: epoch %02d: batch %3d/%3d:', mode, epoch, ib, numSamples) ;
    batchTime = tic ;
    res = [];
    % the last batch may be shorter than batchSize
    batchSize_r = min(batchSize, numSamples - ib + 1) ;

    spd_data = cell(batchSize_r,1);
    spd_label = zeros(batchSize_r,1);
    for ib_r = 1 : batchSize_r
        sampleIdx = trainInd(ib+ib_r-1) ;
        % Build the path with the platform separator (a hard-coded '\'
        % only works on Windows). Backslashes inside the stored name are
        % mapped to the platform separator too (a no-op on Windows).
        sampleName = strrep(spd_train.spd.name{sampleIdx}, '\', filesep) ;
        sample = load(fullfile(spd_train.spdDir, sampleName)) ;
        spd_data{ib_r} = sample.Y1;
        spd_label(ib_r) = spd_train.spd.label(sampleIdx);
    end
    net.layers{end}.class = spd_label ;

    %forward/backward spdnet
    if training, dzdy = one; else dzdy = [] ; end
    res = vl_myforbackward(net, spd_data, dzdy, res) ;

    %accumulating gradients
    if numGpus <= 1
        [net,res] = accumulate_gradients(opts, learningRate, batchSize_r, net, res) ;
    else
        % The multi-GPU path relied on a memory map and helper functions
        % that do not exist in this code base, so it could never run.
        error('process_epoch:multiGpu', ...
              'Training on more than one GPU is not supported.') ;
    end

    % accumulate training errors: the prediction is the row with the
    % largest score in the output of the layer before the loss
    predictions = gather(res(end-1).x) ;
    [~,pre_label] = sort(predictions, 'descend') ;
    numErrors = sum(~bsxfun(@eq, pre_label(1,:)', spd_label)) ;

    numDone = numDone + batchSize_r ;
    batchTime = toc(batchTime) ;
    speed = batchSize_r/batchTime ;
    stats = stats+[batchTime ; res(end).x ; numErrors];

    fprintf(' %.2f s (%.1f data/s)', batchTime, speed) ;

    fprintf(' error: %.5f', stats(3)/numDone) ;
    fprintf(' obj: %.5f', stats(2)/numDone) ;

    fprintf(' [%d/%d]', numDone, batchSize_r);
    fprintf('\n') ;

end


% -------------------------------------------------------------------------
function [net,res] = accumulate_gradients(opts, lr, batchSize, net, res, mmap)
% -------------------------------------------------------------------------
% Apply one gradient step to every layer that has a weight and a non-empty
% res(l).dzdw. 'bfc' weights are updated on the Stiefel manifold (Riemannian
% gradient followed by a retraction); all other weights get a plain
% gradient step. opts and mmap are not used.
% NOTE: a default weightDecay is stored on the layer but no weight decay is
% applied to the update.
for l=numel(net.layers):-1:1
    if isempty(res(l).dzdw)
        continue ;
    end
    if ~isfield(net.layers{l}, 'learningRate')
        net.layers{l}.learningRate = 1 ;
    end
    if ~isfield(net.layers{l}, 'weightDecay')
        net.layers{l}.weightDecay = 1;
    end
    thisLR = lr * net.layers{l}.learningRate ;

    if isfield(net.layers{l}, 'weight')
        if strcmp(net.layers{l}.type,'bfc')
            W1=net.layers{l}.weight;
            W1grad = (1/batchSize)*res(l).dzdw;
            %gradient update on Stiefel manifolds
            problemW1.M = stiefelfactory(size(W1,1), size(W1,2));
            W1Rgrad = (problemW1.M.egrad2rgrad(W1, W1grad));
            net.layers{l}.weight = (problemW1.M.retr(W1, -thisLR*W1Rgrad)); %%!!!NOTE
        else
            net.layers{l}.weight = net.layers{l}.weight - thisLR * (1/batchSize)* res(l).dzdw ;
        end
    end
end

--------------------------------------------------------------------------------
/utils/dDiag.m:
--------------------------------------------------------------------------------
function M = dDiag(M)
% double diag function i.e. return a matrix that keeps only the diagonal
% of M; every off-diagonal element is set to zero

if isa(M,'gpuArray')
    I = eye(size(M),'single','gpuArray');
else
    I = eye(size(M));
end

M = I.*M;
--------------------------------------------------------------------------------
/utils/diagInv.m:
--------------------------------------------------------------------------------
function invX = diagInv(X)
% compute the inverse of a diagonal matrix
% (only the diagonal of X is read; the result is diag(1./diag(X)))
% no verification for speed
% FIXME add some checks for zeros on the diagonal

diagX = diag(X);
invX = diag(1./diagX);
end
--------------------------------------------------------------------------------
/utils/diagLog.m:
--------------------------------------------------------------------------------
function L = diagLog(D,c)
% compute log of a diagonal matrix add constant displacement c if necessary
% (L has the size of D; its leading square block holds
% diag(log(diag(D)+c)), everything else is zero; c defaults to 0)
if ~exist('c','var'), c = 0; end
[M,N] = size(D);
if isa(D,'gpuArray')
    L = zeros(size(D),'single','gpuArray');
else
    L = zeros(size(D));
end

m = min(M,N);
L(1:m,1:m) = diag(log(diag(D)+c));
end
--------------------------------------------------------------------------------
/utils/max_eig.m:
--------------------------------------------------------------------------------
function [L, L_I]=max_eig(D,c)
% clamp the diagonal of D from below: entries smaller than c are replaced
% by c (c defaults to 0). L has the size of D with the clamped values on
% the diagonal of its leading square block; L_I is the logical mask of the
% entries that were clamped.
if ~exist('c','var'), c = 0; end
[M,N] = size(D);
if isa(D,'gpuArray')
    L = zeros(size(D),'single','gpuArray');
else
    L = zeros(size(D));
end

m = min(M,N);
h1=diag(D);

L_I = h1 < c;
%reLU
h1(L_I) = c;
L(1:m,1:m) = diag(h1);
--------------------------------------------------------------------------------
/utils/symmetric.m:
--------------------------------------------------------------------------------
function ssX = symmetric(X)
% symmetrized tensor
% this is a version of symmetric that does symmetric of matrices on the
% first 2 dimensions for each of the matrices given by the 3rd dimension
%
% (c) 2015 catalin ionescu -- catalin.ionescu@ins.uni-bonn.de

ssX = .5*(X+permute(X,[2 1 3]));
end
--------------------------------------------------------------------------------