├── README.md
├── confPath.m
├── data
└── afew
│ └── spddb_afew_train_spd400_int_histeq.mat
├── manopt
├── CLA.txt
├── COPYING.txt
├── CREDITS.txt
├── LICENSE.txt
├── README.txt
├── checkinstall
│ └── basicexample.m
├── examples
│ ├── dominant_invariant_subspace.m
│ ├── generalized_procrustes.m
│ ├── low_rank_matrix_completion.m
│ ├── maxcut.m
│ ├── maxcut_octave.m
│ ├── packing_on_the_sphere.m
│ ├── positive_definite_karcher_mean.m
│ ├── sparse_pca.m
│ └── truncated_svd.m
├── importmanopt.m
├── manopt
│ ├── manifolds
│ │ ├── complexcircle
│ │ │ └── complexcirclefactory.m
│ │ ├── euclidean
│ │ │ ├── euclideanfactory.m
│ │ │ └── symmetricfactory.m
│ │ ├── fixedrank
│ │ │ ├── fixedrankMNquotientfactory.m
│ │ │ ├── fixedrankembeddedfactory.m
│ │ │ ├── fixedrankfactory_2factors.m
│ │ │ ├── fixedrankfactory_2factors_preconditioned.m
│ │ │ ├── fixedrankfactory_2factors_subspace_projection.m
│ │ │ ├── fixedrankfactory_3factors.m
│ │ │ └── fixedrankfactory_3factors_preconditioned.m
│ │ ├── grassmann
│ │ │ └── grassmannfactory.m
│ │ ├── oblique
│ │ │ └── obliquefactory.m
│ │ ├── rotations
│ │ │ ├── randrot.m
│ │ │ ├── randskew.m
│ │ │ └── rotationsfactory.m
│ │ ├── sphere
│ │ │ ├── spherecomplexfactory.m
│ │ │ └── spherefactory.m
│ │ ├── stiefel
│ │ │ └── stiefelfactory.m
│ │ └── symfixedrank
│ │ │ ├── elliptopefactory.m
│ │ │ ├── spectrahedronfactory.m
│ │ │ ├── symfixedrankYYfactory.m
│ │ │ └── sympositivedefinitefactory.m
│ ├── privatetools
│ │ ├── applyStatsfun.m
│ │ ├── canGetCost.m
│ │ ├── canGetDirectionalDerivative.m
│ │ ├── canGetEuclideanGradient.m
│ │ ├── canGetGradient.m
│ │ ├── canGetHessian.m
│ │ ├── canGetLinesearch.m
│ │ ├── canGetPrecon.m
│ │ ├── getApproxHessian.m
│ │ ├── getCost.m
│ │ ├── getCostGrad.m
│ │ ├── getDirectionalDerivative.m
│ │ ├── getEuclideanGradient.m
│ │ ├── getGlobalDefaults.m
│ │ ├── getGradient.m
│ │ ├── getHessian.m
│ │ ├── getHessianFD.m
│ │ ├── getLinesearch.m
│ │ ├── getPrecon.m
│ │ ├── getStore.m
│ │ ├── hashmd5.m
│ │ ├── mergeOptions.m
│ │ ├── purgeStoredb.m
│ │ ├── setStore.m
│ │ └── stoppingcriterion.m
│ ├── solvers
│ │ ├── conjugategradient
│ │ │ └── conjugategradient.m
│ │ ├── linesearch
│ │ │ ├── linesearch.m
│ │ │ ├── linesearch_adaptive.m
│ │ │ └── linesearch_hint.m
│ │ ├── neldermead
│ │ │ ├── centroid.m
│ │ │ └── neldermead.m
│ │ ├── pso
│ │ │ └── pso.m
│ │ ├── steepestdescent
│ │ │ └── steepestdescent.m
│ │ └── trustregions
│ │ │ ├── license for original GenRTR code.txt
│ │ │ ├── tCG.m
│ │ │ └── trustregions.m
│ └── tools
│ │ ├── checkdiff.m
│ │ ├── checkdiffSingle.m
│ │ ├── checkdiffSinglePrecision.m
│ │ ├── checkgradient.m
│ │ ├── checkgradientSinglePrecision.m
│ │ ├── checkhessian.m
│ │ ├── diagsum.m
│ │ ├── hessianspectrum.m
│ │ ├── identify_linear_piece.m
│ │ ├── identify_linear_piece_SinglePrecision.m
│ │ ├── multiprod.m
│ │ ├── multiprodmultitransp_license.txt
│ │ ├── multiscale.m
│ │ ├── multiskew.m
│ │ ├── multisym.m
│ │ ├── multitrace.m
│ │ ├── multitransp.m
│ │ ├── plotprofile.m
│ │ ├── powermanifold.m
│ │ └── productmanifold.m
└── manopt_version.m
├── spdnet
├── vl_mybfc.m
├── vl_myfc.m
├── vl_myforbackward.m
├── vl_mylog.m
├── vl_myrec.m
└── vl_mysoftmaxloss.m
├── spdnet_afew.m
├── spdnet_init_afew.m
├── spdnet_train_afew.m
└── utils
├── dDiag.m
├── diagInv.m
├── diagLog.m
├── max_eig.m
└── symmetric.m
/README.md:
--------------------------------------------------------------------------------
1 | # SPDNet-master
2 | Zhiwu Huang and Luc Van Gool. A Riemannian Network for SPD Matrix Learning, In Proc. AAAI 2017.
3 |
4 | Version 1.0, Copyright(c) November, 2017.
5 |
6 | Note that the copyright of the bundled Manopt toolbox belongs to its authors (https://www.manopt.org/).
7 |
8 | ## Usage
9 |
10 | Step 1: Place the AFEW SPD data under the folder "./data/afew/". Note that the HDM05 and PaSC SPD data we used are also publicly available.
11 |
12 | Step 2: Run spdnet_afew.m for a simple example.
13 |
14 | ## Related Work/Implementation
15 |
16 | 1. Thanks to Oleg Smirnov, Sr. Applied Scientist at Amazon, a TensorFlow ManOpt library that reproduces our SPDNet has been released.
17 |
18 | 2. A NeurIPS 2019 paper, "Riemannian batch normalization for SPD neural networks", develops a batch normalization layer on top of our SPDNet; the official PyTorch code is publicly available under the paper's 'Supplemental' tab.
19 |
20 | 3. A report "Second-order networks in PyTorch" is released.
21 |
22 | 4. Thanks to Alireza Davoudi, there is another Python implementation for SPDNet.
23 |
24 | 5. A direct extension of our SPDNet for facial emotion recognition was published at a CVPR 2018 workshop, and its code is available.
25 |
26 |
27 | ## How to Cite
28 | If you find this project helpful, please consider citing us as follows:
29 | ```bibtex
30 | @inproceedings{huang2017spdnet,
31 | title = {A Riemannian Network for SPD Matrix Learning},
32 | author = {Huang, Zhiwu and
33 | Van Gool, Luc},
34 | year = {2017},
35 | booktitle = {Association for the Advancement of Artificial Intelligence (AAAI)}
36 | }
37 | ```
38 |
39 |
--------------------------------------------------------------------------------
/confPath.m:
--------------------------------------------------------------------------------
% Add the SPDNet folders to the MATLAB path.
%
% Adds the repository root folder, then the manopt, utils and spdnet
% folders together with all of their subfolders (via genpath).
%
% Folders are located relative to this script instead of the current
% folder, and the current folder is never changed. The previous version
% used cd, so it only worked when launched from the repository root and
% left MATLAB in the wrong folder if any step failed.
confPathRoot = fileparts(mfilename('fullpath'));
if isempty(confPathRoot)
    % mfilename is empty when this code is not run from a file (e.g. when
    % pasted at the prompt): fall back to the current folder, as before.
    confPathRoot = pwd;
end
addpath(confPathRoot);
confPathDirs = {'manopt', 'utils', 'spdnet'};
for confPathIdx = 1:numel(confPathDirs)
    addpath(genpath(fullfile(confPathRoot, confPathDirs{confPathIdx})));
end
% Scripts run in the caller's workspace: remove the helper variables.
clear confPathRoot confPathDirs confPathIdx;
--------------------------------------------------------------------------------
/data/afew/spddb_afew_train_spd400_int_histeq.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/data/afew/spddb_afew_train_spd400_int_histeq.mat
--------------------------------------------------------------------------------
/manopt/CLA.txt:
--------------------------------------------------------------------------------
1 | Thank you for your interest in Manopt. The purpose of this Contributor License Agreement is to
2 | clarify the intellectual property license granted with contributions of software from any person or
3 | entity (the "Contributor") to the owners of Manopt. This license is for your protection as a
4 | Contributor of software to Manopt and does not change your right to use your own contributions for
5 | any other purpose.
6 |
7 | The owners of Manopt are the copyright holders of Manopt indicated in the license files distributed
8 | with the software.
9 |
10 | You and the owners of Manopt hereby accept and agree to the following terms and conditions:
11 |
12 | Your "Contributions" means all of your past, present and future contributions of object code, source
13 | code and documentation to Manopt, however submitted to Manopt, excluding any submissions that are
14 | conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
15 |
16 | You hereby grant to the owners of Manopt a non-exclusive, irrevocable, worldwide, no-charge,
17 | transferable copyright license to use, execute, prepare derivative works of, and distribute
18 | (internally and externally, in object code and, if included in your Contributions, source code form)
19 | your Contributions. Except for the rights granted to the owners of Manopt in this paragraph, You
20 | reserve all right, title and interest in and to your Contributions.
21 |
22 | You represent that you are legally entitled to grant the above license. If your employer(s) have
23 | rights to intellectual property that you create, you represent that you have received permission to
24 | make the Contributions on behalf of that employer, or that your employer has waived such rights for
25 | your Contributions to Manopt.
26 |
27 | You represent that, except as disclosed in your Contribution submission(s), each of your
28 | Contributions is your original creation. You represent that your Contribution submission(s)
29 | included complete details of any license or other restriction (including, but not limited to,
30 | related patents and trademarks) associated with any part of your Contribution(s) (including a copy
31 | of any applicable license agreement). You agree to notify the owners of Manopt of any facts or
32 | circumstances of which you become aware that would make Your representations in the Agreement
33 | inaccurate in any respect.
34 |
35 | You are not expected to provide support for your Contributions, except to the extent you desire to
36 | provide support. You may provide support for free, for a fee, or not at all. Your Contributions are
37 | provided as-is, with all faults, defects and errors, and without any warranty of any kind (either
38 | express or implied) including, without limitation, any implied warranty of merchantability and
39 | fitness for a particular purpose and any warranty of non-infringement.
40 |
41 |
42 | This CLA is a modification of the CLA used by the UW Calendar project of the University of
43 | Washington:
44 |
--------------------------------------------------------------------------------
/manopt/CREDITS.txt:
--------------------------------------------------------------------------------
1 | The Manopt project is led by these people:
2 |
3 | * Nicolas Boumal
4 | * Bamdev Mishra
5 | * Pierre-Antoine Absil
6 | * Rodolphe Sepulchre
7 |
8 |
9 | The following people have written code specifically for Manopt:
10 |
11 | * Nicolas Boumal
12 | * Bamdev Mishra
13 | * Pierre Borckmans
14 |
15 |
16 | Furthermore, code written by the following people can be found in Manopt:
17 |
18 | * Chris Baker
19 | * Pierre-Antoine Absil
20 | * Kyle Gallivan
21 | * Paolo de Leva
22 | * Wynton Moore
23 | * Michael Kleder
24 |
--------------------------------------------------------------------------------
/manopt/LICENSE.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/LICENSE.txt
--------------------------------------------------------------------------------
/manopt/README.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/README.txt
--------------------------------------------------------------------------------
/manopt/checkinstall/basicexample.m:
--------------------------------------------------------------------------------
function basicexample
% Checks that Manopt is installed by solving a small problem on the sphere.
%
% function basicexample
%
% Builds a random symmetric 1000x1000 matrix A and minimizes -x'*A*x over
% the unit sphere with trustregions, after a numerical gradient check.
% It then plots the gradient norm at each iteration.
%
% Errors if spherefactory (i.e., Manopt) is not on the Matlab path.

    % Verify that Manopt was indeed added to the Matlab path.
    if isempty(which('spherefactory'))
        % error() only expands escape sequences such as \n when it receives
        % more than one argument, hence the explicit format string.
        error('%s\n%s', 'You should first add Manopt to the Matlab path.', ...
              'Please run importmanopt first.');
    end

    % Generate the problem data.
    n = 1000;
    A = randn(n);
    A = .5*(A+A');

    % Create the problem structure.
    manifold = spherefactory(n);
    problem.M = manifold;

    % Define the problem cost function and its gradient.
    problem.cost = @(x) -x'*(A*x);
    problem.grad = @(x) manifold.egrad2rgrad(x, -2*A*x);

    % Numerically check gradient consistency.
    checkgradient(problem);

    % Solve.
    % The trust-regions algorithm requires the Hessian. Since we do not
    % provide it, it will go for a standard approximation of it. Tell Manopt
    % not to warn about this, but only for the duration of this function:
    % the previous warning state is restored on exit (even on error).
    previousWarningState = warning('off', 'manopt:getHessian:approx');
    restoreWarning = onCleanup(@() warning(previousWarningState)); %#ok<NASGU>
    [x, xcost, info] = trustregions(problem); %#ok<ASGLU>

    % Display some statistics.
    figure;
    semilogy([info.iter], [info.gradnorm], '.-');
    xlabel('Iteration #');
    ylabel('Gradient norm');
    title('Convergence of the trust-regions algorithm on the sphere');

end
--------------------------------------------------------------------------------
/manopt/examples/dominant_invariant_subspace.m:
--------------------------------------------------------------------------------
function [X, info] = dominant_invariant_subspace(A, p)
% Returns an orthonormal basis of the dominant invariant p-subspace of A.
%
% function X = dominant_invariant_subspace(A, p)
% function [X, info] = dominant_invariant_subspace(A, p)
%
% Input: A real, symmetric matrix A of size nxn and an integer p <= n.
%        If A is omitted or empty, a random symmetric 128x128 matrix is
%        used. If p is omitted or empty, p = 3.
% Output: X, a real, orthonormal matrix of size nxp such that
%         trace(X'*A*X) is maximized. That is, the columns of X form an
%         orthonormal basis of a dominant subspace of dimension p of A.
%         These are thus eigenvectors associated with the largest
%         eigenvalues of A (in no particular order). Sign is important:
%         2 is deemed a larger eigenvalue than -5.
%         info, the info output of trustregions (iteration log).
%
% The optimization is performed on the Grassmann manifold, since only the
% space spanned by the columns of X matters. The implementation is short to
% show how Manopt can be used to quickly obtain a prototype. To make the
% implementation more efficient, one might first try to use the caching
% system, that is, use the optional 'store' arguments in the cost, grad and
% hess functions. Furthermore, using egrad2rgrad and ehess2rhess is quick
% and easy, but not always efficient. Having a look at the formulas
% implemented in these functions can help rewrite the code without them,
% possibly more efficiently.

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Nicolas Boumal, July 5, 2013
% Contributors:
%
% Change log:
%
%   NB Dec. 6, 2013:
%       We specify a max and initial trust region radius in the options.
%   (review):
%       Symmetry is checked with a tolerance relative to norm(A), the p
%       check message matches the check (p <= n), and the Hessian spectrum
%       is plotted in a new figure.

    % Generate some random data to test the function
    if ~exist('A', 'var') || isempty(A)
        A = randn(128);
        A = (A+A')/2;
    end
    if ~exist('p', 'var') || isempty(p)
        p = 3;
    end

    % Make sure the input matrix is square and symmetric
    n = size(A, 1);
    assert(isreal(A), 'A must be real.')
    assert(size(A, 2) == n, 'A must be square.');
    % Relative tolerance: matrices with a large norm that are symmetric up
    % to round-off are accepted. max(1, .) keeps the former absolute
    % tolerance n*eps as a floor, so no previously accepted A is rejected.
    assert(norm(A-A', 'fro') < n*eps*max(1, norm(A, 'fro')), ...
           'A must be symmetric.');
    assert(p<=n, 'p must be at most n.');

    % Define the cost and its derivatives on the Grassmann manifold
    Gr = grassmannfactory(n, p);
    problem.M = Gr;
    problem.cost = @(X) -trace(X'*A*X);
    problem.grad = @(X) -2*Gr.egrad2rgrad(X, A*X);
    problem.hess = @(X, H) -2*Gr.ehess2rhess(X, A*X, A*H, H);

    % Execute some checks on the derivatives for early debugging.
    % These things can be commented out of course.
    % checkgradient(problem);
    % pause;
    % checkhessian(problem);
    % pause;

    % Issue a call to a solver. A random initial guess will be chosen and
    % default options are selected except for the ones we specify here.
    options.Delta_bar = 8*sqrt(p);
    [X, costX, info, options] = trustregions(problem, [], options); %#ok<ASGLU>

    fprintf('Options used:\n');
    disp(options);

    % For our information, Manopt can also compute the spectrum of the
    % Riemannian Hessian on the tangent space at (any) X. Computing the
    % spectrum at the solution gives us some idea of the conditioning of
    % the problem. If we were to implement a preconditioner for the
    % Hessian, this would also inform us on its performance.
    %
    % Notice that (typically) all eigenvalues of the Hessian at the
    % solution are positive, i.e., we find an isolated minimizer. If we
    % replace the Grassmann manifold by the Stiefel manifold, hence still
    % optimizing over orthonormal matrices but ignoring the invariance
    % cost(XQ) = cost(X) for all Q orthogonal, then we see
    % dim O(p) = p(p-1)/2 zero eigenvalues in the Hessian spectrum, making
    % the optimizer not isolated anymore.
    if Gr.dim() < 512
        evs = hessianspectrum(problem, X);
        % Plot in a new figure rather than into the caller's current axes.
        figure;
        stairs(sort(evs));
        title(['Eigenvalues of the Hessian of the cost function ' ...
               'at the solution']);
        xlabel('Eigenvalue number (sorted)');
        ylabel('Value of the eigenvalue');
    end

end
--------------------------------------------------------------------------------
/manopt/examples/positive_definite_karcher_mean.m:
--------------------------------------------------------------------------------
function X = positive_definite_karcher_mean(A)
% Computes a Karcher mean of a collection of positive definite matrices.
%
% function X = positive_definite_karcher_mean(A)
%
% Input:  A 3D matrix A of size nxnxm such that each slice A(:,:,k) is a
%         positive definite matrix of size nxn. If A is omitted or empty,
%         random test data is generated: m = 10 noisy symmetric
%         perturbations of a random 5x5 diagonal matrix, with eigenvalues
%         clipped at .01.
%
% Output: A positive definite matrix X of size nxn which is a Karcher mean
%         of the m matrices in A. That is, X minimizes the average of the
%         halved squared Riemannian distances to the matrices in A:
%            f(X) = 1/(2m) * sum_{k=1}^m dist^2(X, A(:, :, k))
%         The factor 1/m does not change the minimizer. It keeps the cost
%         and gradient magnitudes independent of the number of matrices.
%         The distance is defined by the natural metric on the set of
%         positive definite matrices: dist(X,Y) = norm(logm(X\Y), 'fro').
%
% This simple example is not the best way to compute Karcher means. Its
% purpose is to serve as base code to explore other algorithms. In
% particular, in the presence of large noise, this algorithm seems to not
% be able to reach points with a very small gradient norm. This may be
% caused by insufficient accuracy in the gradient computation.

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Nicolas Boumal, Sept. 3, 2013
% Contributors:
%
% Change log:
%
%   (review): The documented cost now matches the implemented one, which
%       is averaged over the m matrices.

    % Generate some random data to test the function if none is given.
    if ~exist('A', 'var') || isempty(A)
        n = 5;
        m = 10;
        A = zeros(n, n, m);
        ref = diag(max(.1, 1+.1*randn(n, 1)));
        for i = 1 : m
            noise = 0.01*randn(n);
            noise = (noise + noise')/2;
            [V, D] = eig(ref + noise);
            % Clip the eigenvalues at .01 so every slice is positive definite.
            A(:, :, i) = V*diag(max(.01, diag(D)))*V';
        end
    end

    % Retrieve the size of the problem:
    % There are m matrices of size nxn to average.
    n = size(A, 1);
    m = size(A, 3);
    assert(n == size(A, 2), ...
           ['The slices of A must be square, i.e., the ' ...
            'first and second dimensions of A must be equal.']);

    % Our search space is the set of positive definite matrices of size n.
    % Notice that this is the only place we specify on which manifold we
    % wish to compute Karcher means. Replacing this factory for another
    % geometry will yield code to compute Karcher means on that other
    % manifold, provided that manifold is equipped with a dist function and
    % a logarithmic map log.
    M = sympositivedefinitefactory(n);

    % Define a problem structure, specifying the manifold M, the cost
    % function and its gradient.
    problem.M = M;
    problem.cost = @cost;
    problem.grad = @grad;

    % The functions below make many redundant computations. This
    % performance hit can be alleviated by using the caching system. We go
    % for a simple implementation here, as a tutorial example.

    % Cost function: f(X) = 1/(2m) * sum_k dist^2(X, A(:, :, k)).
    function f = cost(X)
        f = 0;
        for k = 1 : m
            f = f + M.dist(X, A(:, :, k))^2;
        end
        f = f/(2*m);
    end

    % Riemannian gradient of the cost function:
    % grad f(X) = -(1/m) * sum_k log_X(A(:, :, k)).
    function g = grad(X)
        g = M.zerovec(X);
        for k = 1 : m
            % Update g in a linear combination of the form
            % g = g - [something]/m.
            g = M.lincomb(X, 1, g, -1/m, M.log(X, A(:, :, k)));
        end
    end

    % Execute some checks on the derivatives for early debugging.
    % These things can be commented out of course.
    % The slopes should agree on part of the plot at least. In this case,
    % it is sometimes necessary to inspect the plot visually to make the
    % call, but it is indeed correct.
    % checkgradient(problem);
    % pause;

    % Execute this if you want to force using a proper parallel vector
    % transport. This is not necessary. If you omit this, the default
    % vector transport is the identity map, which is (of course) cheaper
    % and seems to perform well in practice.
    % M.transp = M.paralleltransp;

    % Issue a call to a solver. Default options are selected.
    % Our initial guess is the first data point.
    X = trustregions(problem, A(:, :, 1));

end
--------------------------------------------------------------------------------
/manopt/importmanopt.m:
--------------------------------------------------------------------------------
% Add Manopt to the path and make all manopt components available.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 3, 2013.
% Contributors:
% Change log:
%   Aug. 7, 2013 (NB): Changed to work without the import command
%                      (new structure of the toolbox).
%   Aug. 8, 2013 (NB): Changed to use addpath_recursive, home brewed.
%   Aug. 22, 2013 (NB): Using genpath instead of homecooked
%                       addpath_recursive.
%   (review): Folders are resolved relative to this script, so it no
%             longer has to be launched from the Manopt root folder. The
%             current folder is no longer changed (previously a failure
%             could leave MATLAB inside the manopt subfolder).

importmanoptRoot = fileparts(mfilename('fullpath'));
if isempty(importmanoptRoot)
    % mfilename is empty when this code is not run from a file: fall back
    % to the current folder, as the previous version did.
    importmanoptRoot = pwd;
end
addpath(importmanoptRoot);

% Recursively add Manopt directories to the Matlab path.
addpath(genpath(fullfile(importmanoptRoot, 'manopt')));

% Scripts run in the caller's workspace: remove the helper variable.
clear importmanoptRoot;
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/complexcircle/complexcirclefactory.m:
--------------------------------------------------------------------------------
function M = complexcirclefactory(n)
% Returns a manifold struct to optimize over unit-modulus complex numbers.
%
% function M = complexcirclefactory()
% function M = complexcirclefactory(n)
%
% Description of vectors z in C^n (complex) such that each component z(i)
% has unit modulus. The manifold structure is the Riemannian submanifold
% structure from the embedding space R^2 x ... x R^2, i.e., the complex
% circle is identified with the unit circle in the real plane.
%
% By default, n = 1.
%
% See also spherecomplexfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   July 7, 2014 (NB): Added ehess2rhess function.
%
%   (review): dist used acos on a complex number, which returns complex
%       values and wrong distances. It now uses the per-component angle.
%       log rescaled the whole vector by one global factor, which is wrong
%       for n > 1, and it divided 0/0 when x1 == x2. It is now computed
%       exactly, component by component.

    if ~exist('n', 'var')
        n = 1;
    end

    M.name = @() sprintf('Complex circle (S^1)^%d', n);

    M.dim = @() n;

    M.inner = @(z, v, w) real(v'*w);

    M.norm = @(x, v) norm(v);

    % Geodesic distance. On each circle, the distance between x(i) and y(i)
    % is the absolute angle |arg(conj(x(i))*y(i))|, which lies in [0, pi].
    % The distance on the product manifold is the 2-norm of those angles.
    M.dist = @(x, y) norm(angle(conj(x) .* y));

    M.typicaldist = @() pi*sqrt(n);

    M.proj = @(z, u) u - real( conj(u) .* z ) .* z;

    M.tangent = M.proj;

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(z, egrad, ehess, zdot)
        rhess = M.proj(z, ehess - real(z.*conj(egrad)).*zdot);
    end

    M.exp = @exponential;
    function y = exponential(z, v, t)
        if nargin <= 2
            t = 1.0;
        end

        y = zeros(n, 1);
        tv = t*v;

        nrm_tv = abs(tv);

        % We need to distinguish between very small steps and the others.
        % For very small steps, we use a limit version of the exponential
        % (which actually coincides with the retraction), so as to not
        % divide by very small numbers.
        mask = nrm_tv > 1e-6;
        y(mask) = z(mask).*cos(nrm_tv(mask)) + ...
                  tv(mask).*(sin(nrm_tv(mask))./nrm_tv(mask));
        y(~mask) = z(~mask) + tv(~mask);
        y(~mask) = y(~mask) ./ abs(y(~mask));

    end

    M.retr = @retraction;
    function y = retraction(z, v, t)
        if nargin <= 2
            t = 1.0;
        end
        y = z+t*v;
        y = y ./ abs(y);
    end

    M.log = @logarithm;
    function v = logarithm(x1, x2)
        % On each circle, 1i*x1(i) spans the tangent line at x1(i), and
        % exp(x1(i), 1i*x1(i)*theta) = x1(i)*exp(1i*theta). Taking theta as
        % the signed angle from x1(i) to x2(i) therefore inverts the
        % exponential exactly. This gives a zero vector when x1 == x2, and
        % norm(v) == M.dist(x1, x2).
        v = 1i * x1 .* angle(conj(x1) .* x2);
    end

    M.hash = @(z) ['z' hashmd5( [real(z(:)) ; imag(z(:))] ) ];

    M.rand = @random;
    function z = random()
        z = randn(n, 1) + 1i*randn(n, 1);
        z = z ./ abs(z);
    end

    M.randvec = @randomvec;
    function v = randomvec(z)
        % i*z(k) is a basis vector of the tangent vector to the k-th circle
        v = randn(n, 1) .* (1i*z);
        v = v / norm(v);
    end

    M.lincomb = @lincomb;

    M.zerovec = @(x) zeros(n, 1);

    M.transp = @(x1, x2, d) M.proj(x2, d);

    M.pairmean = @pairmean;
    function z = pairmean(z1, z2)
        z = z1+z2;
        z = z ./ abs(z);
    end

    M.vec = @(x, u_mat) [real(u_mat) ; imag(u_mat)];
    M.mat = @(x, u_vec) u_vec(1:n) + 1i*u_vec((n+1):end);
    M.vecmatareisometries = @() true;

end

% Linear combination of tangent vectors to the complex circle.
%
% d = lincomb(x, a1, d1)          returns a1*d1
% d = lincomb(x, a1, d1, a2, d2)  returns a1*d1 + a2*d2
%
% x (the base point) is unused: tangent vectors are plain complex vectors.
% Any other number of arguments is an error.
function d = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>

    if nargin == 3
        d = a1*d1;
    elseif nargin == 5
        d = a1*d1 + a2*d2;
    else
        error('Bad use of complexcircle.lincomb.');
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/euclidean/euclideanfactory.m:
--------------------------------------------------------------------------------
function M = euclideanfactory(m, n)
% Returns a manifold struct to optimize over m-by-n matrices.
%
% function M = euclideanfactory(m, n)
%
% Returns M, a structure describing the Euclidean space of m-by-n matrices
% equipped with the standard Frobenius distance and associated trace inner
% product as a manifold for Manopt. When n is omitted or empty, n = 1.
%
% Every tangent space is the whole space. Projections, gradient and
% Hessian conversions and vector transport are therefore identities, and
% the exponential map (also used as retraction) is the straight line
% x + t*d.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   July 5, 2013 (NB): added egrad2rgrad, ehess2rhess, mat, vec, tangent.

    if ~exist('n', 'var') || isempty(n)
        n = 1;
    end

    % Metric structure: Frobenius inner product and distance.
    M.name = @() sprintf('Euclidean space R^(%dx%d)', m, n);
    M.dim = @() m*n;
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    M.norm = @(x, d) norm(d, 'fro');
    M.dist = @(x, y) norm(x-y, 'fro');
    M.typicaldist = @() sqrt(m*n);

    % Flat space: all of these are identity maps.
    M.proj = @(x, d) d;
    M.egrad2rgrad = @(x, g) g;
    M.ehess2rhess = @(x, eg, eh, d) eh;
    M.tangent = M.proj;

    M.exp = @straightline;
    function y = straightline(x, d, t)
        % Step from x along d, scaled by t when t is provided.
        if nargin == 3
            d = t*d;
        end
        y = x + d;
    end

    M.retr = M.exp;

    M.log = @(x, y) y-x;

    M.hash = @(x) ['z' hashmd5(x(:))];

    M.rand = @() randn(m, n);

    M.randvec = @unitrandvec;
    function u = unitrandvec(x) %#ok<INUSD>
        % Gaussian direction normalized to unit Frobenius norm.
        u = randn(m, n);
        u = u / norm(u, 'fro');
    end

    M.lincomb = @combine;
    function v = combine(x, a1, d1, a2, d2) %#ok<INUSL>
        switch nargin
            case 3
                v = a1*d1;
            case 5
                v = a1*d1 + a2*d2;
            otherwise
                error('Bad usage of euclidean.lincomb');
        end
    end

    M.zerovec = @(x) zeros(m, n);

    M.transp = @(x1, x2, d) d;

    M.pairmean = @(x1, x2) .5*(x1+x2);

    % Points and vectors are already plain arrays: vec/mat are reshapes.
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [m, n]);
    M.vecmatareisometries = @() true;

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/euclidean/symmetricfactory.m:
--------------------------------------------------------------------------------
function M = symmetricfactory(n, k)
% Returns a manifold struct to optimize over k symmetric matrices of size n
%
% function M = symmetricfactory(n)
% function M = symmetricfactory(n, k)
%
% Returns M, a structure describing the Euclidean space of n-by-n symmetric
% matrices equipped with the standard Frobenius distance and associated
% trace inner product, as a manifold for Manopt.
% By default, k = 1. If k > 1, points and vectors are stored in 3D matrices
% X of size nxnxk such that each slice X(:, :, i), for i = 1:k, is
% symmetric.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 22, 2014.
% Contributors:
% Change log:
%   (review): mat referenced an undefined variable m; it now reshapes to
%       n x n x k. Added ehess2rhess. Fixed the lincomb error message.

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end

    M.name = @() sprintf('(Symmetric matrices of size %d)^%d', n, k);

    M.dim = @() k*n*(n+1)/2;

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:), 'fro');

    M.dist = @(x, y) norm(x(:)-y(:), 'fro');

    M.typicaldist = @() sqrt(k)*n;

    % Projection onto the symmetric matrices, applied via the multisym
    % helper (defined elsewhere in Manopt).
    M.proj = @(x, d) multisym(d);

    M.egrad2rgrad = M.proj;

    % The symmetric matrices form a linear subspace of the embedding space
    % and carry the induced (Frobenius) metric. The Riemannian Hessian is
    % therefore just the projection of the Euclidean Hessian, with no
    % curvature correction term.
    M.ehess2rhess = @(x, eg, eh, d) multisym(eh);

    M.tangent = @(x, d) d;

    M.exp = @exponential;
    function y = exponential(x, d, t)
        % Straight-line step x + t*d, with t = 1 when omitted.
        if nargin == 3
            y = x + t*d;
        else
            y = x + d;
        end
    end

    M.retr = M.exp;

    M.log = @(x, y) y-x;

    M.hash = @(x) ['z' hashmd5(x(:))];

    M.rand = @() multisym(randn(n, n, k));

    M.randvec = @randvec;
    function u = randvec(x) %#ok<INUSD>
        u = multisym(randn(n, n, k));
        u = u / norm(u(:), 'fro');
    end

    M.lincomb = @lincomb;
    function v = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>
        if nargin == 3
            v = a1*d1;
        elseif nargin == 5
            v = a1*d1 + a2*d2;
        else
            error('Bad usage of symmetric.lincomb');
        end
    end

    M.zerovec = @(x) zeros(n, n, k);

    M.transp = @(x1, x2, d) d;

    M.pairmean = @(x1, x2) .5*(x1+x2);

    % vec stacks all n*n*k entries; mat must undo exactly that layout.
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, n, k]);
    M.vecmatareisometries = @() true;

end

--------------------------------------------------------------------------------
/manopt/manopt/manifolds/fixedrank/fixedrankMNquotientfactory.m:
--------------------------------------------------------------------------------
1 | function M = fixedrankMNquotientfactory(m, n, k)
2 | % Manifold of m-by-n matrices of rank k with quotient geometry.
3 | %
4 | % function M = fixedrankMNquotientfactory(m, n, k)
5 | %
6 | % This follows the quotient geometry described in the following paper:
7 | % P.-A. Absil, L. Amodei and G. Meyer,
8 | % "Two Newton methods on the manifold of fixed-rank matrices endowed
9 | % with Riemannian quotient geometries", arXiv, 2012.
10 | %
11 | % Paper link: http://arxiv.org/abs/1209.0068
12 | %
13 | % A point X on the manifold is represented as a structure with two
14 | % fields: M and N. The matrix M (mxk) is orthonormal, while the matrix N
15 | % (nxk) is full-rank.
16 | %
17 | % Tangent vectors are represented as a structure with two fields (M, N).
18 |
19 | % This file is part of Manopt: www.manopt.org.
20 | % Original author: Bamdev Mishra, Dec. 30, 2012.
21 | % Contributors:
22 | % Change log:
23 |
24 |
25 | M.name = @() sprintf('MN'' quotient manifold of %dx%d matrices of rank %d', m, n, k);
26 |
27 | M.dim = @() (m+n-k)*k;
28 |
29 | % Choice of the metric is motivated by the symmetry present in the
30 | % space
31 | M.inner = @(X, eta, zeta) eta.M(:).'*zeta.M(:) + eta.N(:).'*zeta.N(:);
32 |
33 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
34 |
35 | M.dist = @(x, y) error('fixedrankMNquotientfactory.dist not implemented yet.');
36 |
37 | M.typicaldist = @() 10*k;
38 |
39 | symm = @(X) .5*(X+X');
40 | stiefel_proj = @(M, H) H - M*symm(M'*H);
41 |
42 | M.egrad2rgrad = @egrad2rgrad;
43 | function eta = egrad2rgrad(X, eta)
44 | eta.M = stiefel_proj(X.M, eta.M);
45 | end
46 |
47 | M.ehess2rhess = @ehess2rhess;
48 | function Hess = ehess2rhess(X, egrad, ehess, eta)
49 |
50 | % Directional derivative of the Riemannian gradient
51 | Hess.M = ehess.M - eta.M*symm(X.M'*egrad.M);
52 | Hess.M = stiefel_proj(X.M, Hess.M);
53 |
54 | Hess.N = ehess.N;
55 |
56 | % Projection onto the horizontal space
57 | Hess = M.proj(X, Hess);
58 | end
59 |
60 |
61 | M.proj = @projection;
function etaproj = projection(X, eta)
% Projects an ambient vector eta = (eta.M, eta.N) onto the horizontal
% space at X.
    % 1) Tangent space of the total space: eta.M is made tangent to Stiefel
    %    at X.M; eta.N is unconstrained.
    tM = stiefel_proj(X.M, eta.M);
    tN = eta.N;

    % 2) Take care of the quotient. lyap(lhs, -rhs) solves
    %    lhs*omega + omega*lhs' = rhs, where lhs is symmetric and rhs is
    %    skew-symmetric by construction.
    lhs = eye(k) + X.N'*X.N;
    C = tM'*X.M + tN'*X.N;
    rhs = C - C';
    omega = lyap(lhs, -rhs);

    % Move along the vertical direction (X.M*omega, X.N*omega).
    etaproj.M = tM + X.M*omega;
    etaproj.N = tN + X.N*omega;
end
83 |
84 | M.exp = @exponential;
function Y = exponential(X, eta, t)
% Exponential map at X along t*eta (t defaults to 1). The M factor uses a
% matrix-exponential formula for the Stiefel manifold; the N factor moves
% along a straight line.
    if nargin < 3
        t = 1.0;
    end

    XtV = t*X.M'*eta.M;
    VtV = t^2*eta.M'*eta.M;
    Y.M = [X.M t*eta.M]*expm([XtV, -VtV; eye(k), XtV])*eye(2*k, k)*expm(-XtV);

    % Y.M can drift away from orthonormality numerically: re-orthonormalize
    % it. Scaling Q's columns by the signs of diag(Rfac) undoes sign flips
    % introduced by qr.
    [Q, Rfac] = qr(Y.M, 0);
    Y.M = Q * diag(sign(diag(Rfac)));

    Y.N = X.N + t*eta.N;
end
101 |
102 | % Factor M lives on the Stiefel manifold, hence we will reuse its
103 | % random generator.
104 | stiefelm = stiefelfactory(m, k);
105 |
106 | M.retr = @retraction;
function Y = retraction(X, eta, t)
% Retraction at X along t*eta (t defaults to 1). The updated M factor is
% mapped back to an orthonormal matrix with uf (U*V' from the thin SVD);
% the N factor moves linearly.
    if nargin >= 3
        s = t;
    else
        s = 1.0;
    end
    Y.M = uf(X.M + s*eta.M);
    Y.N = X.N + s*eta.N;
end
115 |
116 | M.hash = @(X) ['z' hashmd5([X.M(:) ; X.N(:)])];
117 |
118 | M.rand = @random;
function X = random()
% Random point: X.M comes from the Stiefel factory's random generator and
% X.N has i.i.d. standard normal entries.
    X = struct('M', stiefelm.rand(), 'N', randn(n, k));
end
123 |
124 | M.randvec = @randomvec;
function eta = randomvec(X)
% Random tangent vector at X: Gaussian ambient vector, projected onto the
% horizontal space, then normalized to unit norm.
    amb.M = randn(m, k);
    amb.N = randn(n, k);
    eta = projection(X, amb);
    vnorm = M.norm(X, eta);
    eta = struct('M', eta.M / vnorm, 'N', eta.N / vnorm);
end
133 |
134 | M.lincomb = @lincomb;
135 |
136 | M.zerovec = @(X) struct('M', zeros(m, k), 'N', zeros(n, k));
137 |
138 | M.transp = @(x1, x2, d) projection(x2, d);
139 |
140 | end
141 |
142 |
143 | % Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok
% Linear combination of tangent vectors (fields M and N):
% d = a1*d1 with three arguments, d = a1*d1 + a2*d2 with five.
    switch nargin
        case 3
            d.M = a1*d1.M;
            d.N = a1*d1.N;
        case 5
            d.M = a1*d1.M + a2*d2.M;
            d.N = a1*d1.N + a2*d2.N;
        otherwise
            error('Bad use of fixedrankMNquotientfactory.lincomb.');
    end
end
157 |
158 |
function A = uf(A)
% Orthonormal factor of A: with the thin SVD A = U*S*V', returns U*V'.
    [U, S, V] = svd(A, 0); %#ok<ASGLU>
    A = U*V';
end
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/fixedrank/fixedrankfactory_2factors.m:
--------------------------------------------------------------------------------
1 | function M = fixedrankfactory_2factors(m, n, k)
2 | % Manifold of m-by-n matrices of rank k with balanced quotient geometry.
3 | %
4 | % function M = fixedrankfactory_2factors(m, n, k)
5 | %
6 | % This follows the balanced quotient geometry described in the following paper:
7 | % G. Meyer, S. Bonnabel and R. Sepulchre,
8 | % "Linear regression under fixed-rank constraints: a Riemannian approach",
9 | % ICML 2011.
10 | %
11 | % Paper link: http://www.icml-2011.org/papers/350_icmlpaper.pdf
12 | %
13 | % A point X on the manifold is represented as a structure with two
14 | % fields: L and R. The matrices L (mxk) and R (nxk) are full column-rank
15 | % matrices such that X = L*R'.
16 | %
17 | % Tangent vectors are represented as a structure with two fields: L, R
18 |
19 | % This file is part of Manopt: www.manopt.org.
20 | % Original author: Bamdev Mishra, Dec. 30, 2012.
21 | % Contributors:
22 | % Change log:
23 | % July 10, 2013 (NB) : added vec, mat, tangent, tangent2ambient
24 |
25 |
26 | M.name = @() sprintf('LR'' quotient manifold of %dx%d matrices of rank %d', m, n, k);
27 |
28 | M.dim = @() (m+n-k)*k;
29 |
30 | % Some precomputations at the point X to be used in the inner product (and
31 | % pretty much everywhere else).
function X = prepare(X)
% Caches L'*L, R'*R and their inverses in X (fields LtL, RtR, invLtL,
% invRtR) so that other functions of this factory can reuse them. Returns
% X untouched if all four fields are already present.
    if all(isfield(X, {'LtL', 'RtR', 'invRtR', 'invLtL'}))
        return;
    end
    X.LtL = X.L'*X.L;
    X.RtR = X.R'*X.R;
    X.invLtL = inv(X.LtL);
    X.invRtR = inv(X.RtR);
end
42 |
43 | % Choice of the metric is motivated by the symmetry present in the space
44 | M.inner = @iproduct;
function ip = iproduct(X, eta, zeta)
% Inner product at X:
%   trace(inv(L'*L)*eta.L'*zeta.L) + trace(inv(R'*R)*eta.R'*zeta.R).
    Xp = prepare(X);
    ipL = trace(Xp.invLtL*(eta.L'*zeta.L));
    ipR = trace(Xp.invRtR*(eta.R'*zeta.R));
    ip = ipL + ipR;
end
49 |
50 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
51 |
52 | M.dist = @(x, y) error('fixedrankfactory_2factors.dist not implemented yet.');
53 |
54 | M.typicaldist = @() 10*k;
55 |
56 | symm = @(M) .5*(M+M');
57 |
58 | M.egrad2rgrad = @egrad2rgrad;
function eta = egrad2rgrad(X, eta)
% Riemannian gradient from the Euclidean gradient: each factor is
% right-multiplied by the Gram matrix of the corresponding factor of X
% (L'*L and R'*R).
    Xp = prepare(X);
    eta.L = eta.L*Xp.LtL;
    eta.R = eta.R*Xp.RtR;
end
64 |
65 | M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
% Riemannian Hessian-vector product at X along the tangent vector eta,
% computed from the Euclidean gradient egrad and the Euclidean
% Hessian-vector product ehess. The result is projected onto the
% horizontal space.
    X = prepare(X);

    % Riemannian gradient
    rgrad = egrad2rgrad(X, egrad);

    % Directional derivative of the Riemannian gradient, i.e., of
    % egrad.L*(L'*L) and egrad.R*(R'*R), along eta.
    Hess.L = ehess.L*X.LtL + 2*egrad.L*symm(eta.L'*X.L);
    Hess.R = ehess.R*X.RtR + 2*egrad.R*symm(eta.R'*X.R);

    % We need a correction term for the non-constant metric (its weights
    % inv(L'*L) and inv(R'*R) depend on the point).
    Hess.L = Hess.L - rgrad.L*((X.invLtL)*symm(X.L'*eta.L)) - eta.L*(X.invLtL*symm(X.L'*rgrad.L)) + X.L*(X.invLtL*symm(eta.L'*rgrad.L));
    Hess.R = Hess.R - rgrad.R*((X.invRtR)*symm(X.R'*eta.R)) - eta.R*(X.invRtR*symm(X.R'*rgrad.R)) + X.R*(X.invRtR*symm(eta.R'*rgrad.R));

    % Projection onto the horizontal space
    Hess = M.proj(X, Hess);
end
83 |
84 | M.proj = @projection;
85 | % Projection of the vector eta onto the horizontal space
function etaproj = projection(X, eta)
% Projects eta = (eta.L, eta.R) onto the horizontal space at X. It adds
% the vertical correction (L*Omega', -R*Omega), where Omega solves the
% Sylvester equation G*Omega + Omega*G = rhs with G = (L'*L)*(R'*R).
    Xp = prepare(X);
    G = Xp.LtL*Xp.RtR;
    rhs = Xp.LtL*(Xp.R'*eta.R) - (eta.L'*Xp.L)*Xp.RtR;
    Omega = lyap(G, G, -rhs);
    etaproj.L = eta.L + Xp.L*Omega';
    etaproj.R = eta.R - Xp.R*Omega;
end
95 |
96 | M.tangent = M.proj;
97 | M.tangent2ambient = @(X, eta) eta;
98 |
99 | M.retr = @retraction;
function Y = retraction(X, eta, t)
% Retraction at X along t*eta (t defaults to 1). Both factors move
% linearly and are then rebalanced as (L, R) -> (L/s, R*s), which leaves
% L*R' unchanged. Gram matrices of the result are cached via prepare.
    if nargin < 3
        t = 1.0;
    end

    % Numerical conditioning: keep the norms of L and R from becoming too
    % skewed relative to each other. The scale is computed from X's
    % factors, not from the updated ones.
    s = sqrt(norm(X.L, 'fro')/norm(X.R, 'fro'));
    Y.L = (X.L + t*eta.L) / s;
    Y.R = (X.R + t*eta.R) * s;

    Y = prepare(Y);
end
120 |
121 | M.exp = @exponential;
function Y = exponential(X, eta, t)
% No exponential map is implemented for this geometry: this falls back to
% the retraction (t defaults to 1) and issues a warning saying so.
    if nargin < 3
        t = 1.0;
    end
    Y = retraction(X, eta, t);
    msg = ['Exponential for fixed rank ' ...
           'manifold not implemented yet. Used retraction instead.'];
    warning('manopt:fixedrankfactory_2factors:exp', msg);
end
132 |
133 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.R(:)])];
134 |
135 | M.rand = @random;
function X = random()
% Random point of the total space: L and R with i.i.d. standard normal
% entries, with their Gram matrices cached by prepare.
    X = prepare(struct('L', randn(m, k), 'R', randn(n, k)));
end
142 |
143 | M.randvec = @randomvec;
function eta = randomvec(X)
% Random unit-norm vector in the horizontal space at X.
    amb.L = randn(m, k);
    amb.R = randn(n, k);
    eta = projection(X, amb);
    vnorm = M.norm(X, eta);
    eta = struct('L', eta.L / vnorm, 'R', eta.R / vnorm);
end
153 |
154 | M.lincomb = @lincomb;
155 |
156 | M.zerovec = @(X) struct('L', zeros(m, k),'R', zeros(n, k));
157 |
158 | M.transp = @(x1, x2, d) projection(x2, d);
159 |
160 | % vec and mat are not isometries, because of the unusual inner metric.
161 | M.vec = @(X, U) [U.L(:) ; U.R(:)];
162 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ...
163 | 'R', reshape(u((m*k+1):end), n, k));
164 | M.vecmatareisometries = @() false;
165 |
166 | end
167 |
168 | % Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok
% Linear combination of tangent vectors with fields L and R:
% a1*d1 (three arguments) or a1*d1 + a2*d2 (five arguments).
    if nargin == 3
        d = struct('L', a1*d1.L, 'R', a1*d1.R);
    elseif nargin == 5
        d = struct('L', a1*d1.L + a2*d2.L, 'R', a1*d1.R + a2*d2.R);
    else
        error('Bad use of fixedrankfactory_2factors.lincomb.');
    end
end
182 |
183 |
184 |
185 |
186 |
187 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/fixedrank/fixedrankfactory_2factors_preconditioned.m:
--------------------------------------------------------------------------------
1 | function M = fixedrankfactory_2factors_preconditioned(m, n, k)
2 | % Manifold of m-by-n matrices of rank k with new balanced quotient geometry
3 | %
4 | % function M = fixedrankfactory_2factors_preconditioned(m, n, k)
5 | %
6 | % This follows the quotient geometry described in the following paper:
7 | % B. Mishra, K. Adithya Apuroop and R. Sepulchre,
8 | % "A Riemannian geometry for low-rank matrix completion",
9 | % arXiv, 2012.
10 | %
11 | % Paper link: http://arxiv.org/abs/1211.1550
12 | %
13 | % This geometry is tuned to least-squares problems such as low-rank matrix
14 | % completion.
15 | %
16 | % A point X on the manifold is represented as a structure with two
17 | % fields: L and R. The matrices L (mxk) and R (nxk) are full column-rank
18 | % matrices.
19 | %
20 | % Tangent vectors are represented as a structure with two fields: L, R
21 |
22 | % This file is part of Manopt: www.manopt.org.
23 | % Original author: Bamdev Mishra, Dec. 30, 2012.
24 | % Contributors:
25 | % Change log:
26 |
27 |
28 |
29 | M.name = @() sprintf('LR''(tuned for least square problems) quotient manifold of %dx%d matrices of rank %d', m, n, k);
30 |
31 | M.dim = @() (m+n-k)*k;
32 |
33 |
34 | % Some precomputations at the point X to be used in the inner product (and
35 | % pretty much everywhere else).
function X = prepare(X)
% Caches L'*L, R'*R and their inverses in X (fields LtL, RtR, invLtL,
% invRtR) unless all four fields are already present.
    cached = {'LtL', 'RtR', 'invRtR', 'invLtL'};
    if ~all(isfield(X, cached))
        LtL = X.L'*X.L;
        RtR = X.R'*X.R;
        X.LtL = LtL;
        X.RtR = RtR;
        X.invLtL = inv(LtL);
        X.invRtR = inv(RtR);
    end
end
46 |
47 |
48 | % The choice of metric is motivated by symmetry and tuned to least square
49 | % objective function
50 | M.inner = @iproduct;
function ip = iproduct(X, eta, zeta)
% Inner product at X, weighted by the Gram matrices of the opposite
% factor: trace(R'*R*eta.L'*zeta.L) + trace(L'*L*eta.R'*zeta.R).
    Xp = prepare(X);
    termL = trace(Xp.RtR*(eta.L'*zeta.L));
    termR = trace(Xp.LtL*(eta.R'*zeta.R));
    ip = termL + termR;
end
56 |
57 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
58 |
59 | M.dist = @(x, y) error('fixedrankfactory_2factors_preconditioned.dist not implemented yet.');
60 |
61 | M.typicaldist = @() 10*k;
62 |
63 | symm = @(M) .5*(M+M');
64 |
65 | M.egrad2rgrad = @egrad2rgrad;
function eta = egrad2rgrad(X, eta)
% Riemannian gradient from the Euclidean gradient: the L part is scaled by
% inv(R'*R) and the R part by inv(L'*L).
    Xp = prepare(X);
    eta.L = eta.L*Xp.invRtR;
    eta.R = eta.R*Xp.invLtL;
end
72 |
73 | M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
% Riemannian Hessian-vector product at X along the tangent vector eta,
% computed from the Euclidean gradient egrad and the Euclidean
% Hessian-vector product ehess. The result is projected onto the
% horizontal space.
    X = prepare(X);

    % Riemannian gradient
    rgrad = egrad2rgrad(X, egrad);

    % Directional derivative of the Riemannian gradient, i.e., of
    % egrad.L*inv(R'*R) and egrad.R*inv(L'*L), along eta.
    Hess.L = ehess.L*X.invRtR - 2*egrad.L*(X.invRtR * symm(eta.R'*X.R) * X.invRtR);
    Hess.R = ehess.R*X.invLtL - 2*egrad.R*(X.invLtL * symm(eta.L'*X.L) * X.invLtL);

    % We still need a correction factor for the non-constant metric (its
    % weights R'*R and L'*L depend on the point).
    Hess.L = Hess.L + rgrad.L*(symm(eta.R'*X.R)*X.invRtR) + eta.L*(symm(rgrad.R'*X.R)*X.invRtR) - X.L*(symm(eta.R'*rgrad.R)*X.invRtR);
    Hess.R = Hess.R + rgrad.R*(symm(eta.L'*X.L)*X.invLtL) + eta.R*(symm(rgrad.L'*X.L)*X.invLtL) - X.R*(symm(eta.L'*rgrad.L)*X.invLtL);

    % Project on the horizontal space
    Hess = M.proj(X, Hess);

end
92 |
93 | M.proj = @projection;
function etaproj = projection(X, eta)
% Projects eta = (eta.L, eta.R) onto the horizontal space at X by adding
% (L*Lambda, -R*Lambda'), where
% Lambda = ((eta.R'*R)*inv(R'*R) - inv(L'*L)*(L'*eta.L)) / 2.
    Xp = prepare(X);
    Lambda = ((eta.R'*Xp.R)*Xp.invRtR - Xp.invLtL*(Xp.L'*eta.L)) / 2;
    etaproj.L = eta.L + Xp.L*Lambda;
    etaproj.R = eta.R - Xp.R*Lambda';
end
102 |
103 | M.tangent = M.proj;
104 | M.tangent2ambient = @(X, eta) eta;
105 |
106 |
107 |
108 | M.retr = @retraction;
function Y = retraction(X, eta, t)
% Retraction at X along t*eta (t defaults to 1): linear update of both
% factors followed by the rebalancing (L, R) -> (L/s, R*s), which leaves
% L*R' unchanged. Gram matrices of the result are cached via prepare.
    if nargin < 3
        t = 1.0;
    end
    newL = X.L + t*eta.L;
    newR = X.R + t*eta.R;

    % Numerical conditioning: we need to ensure that L and R do not have
    % very skewed relative norms. The scale is computed from X's factors,
    % not from the updated ones.
    s = sqrt(norm(X.L, 'fro')/norm(X.R, 'fro'));
    Y.L = newL / s;
    Y.R = newR * s;

    Y = prepare(Y);
end
128 |
129 |
130 | M.exp = @exponential;
function Y = exponential(X, eta, t)
% No exponential map is implemented for this geometry: this falls back to
% the retraction (t defaults to 1) and issues a warning saying so.
    if nargin >= 3
        Y = retraction(X, eta, t);
    else
        Y = retraction(X, eta, 1.0);
    end
    warning('manopt:fixedrankfactory_2factors_preconditioned:exp', ...
            ['Exponential for fixed rank ' ...
             'manifold not implemented yet. Used retraction instead.']);
end
141 |
142 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.R(:)])];
143 |
144 | M.rand = @random;
145 |
function X = random()
% Random point of the total space: L and R with i.i.d. standard normal
% entries (no Gram matrices are cached here).
    X = struct('L', randn(m, k), 'R', randn(n, k));
end
150 |
151 | M.randvec = @randomvec;
function eta = randomvec(X)
% Random unit-norm vector in the horizontal space at X.
    eta = projection(X, struct('L', randn(m, k), 'R', randn(n, k)));
    nrm = M.norm(X, eta);
    eta.L = eta.L / nrm;
    eta.R = eta.R / nrm;
end
160 |
161 | M.lincomb = @lincomb;
162 |
163 | M.zerovec = @(X) struct('L', zeros(m, k),'R', zeros(n, k));
164 |
165 | M.transp = @(x1, x2, d) projection(x2, d);
166 |
167 | % vec and mat are not isometries, because of the unusual inner metric.
168 | M.vec = @(X, U) [U.L(:) ; U.R(:)];
169 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ...
170 | 'R', reshape(u((m*k+1):end), n, k));
171 | M.vecmatareisometries = @() false;
172 |
173 | end
174 |
175 | % Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok
% Linear combination of tangent vectors with fields L and R:
% a1*d1 (three arguments) or a1*d1 + a2*d2 (five arguments).
    if nargin ~= 3 && nargin ~= 5
        error('Bad use of fixedrankfactory_2factors_preconditioned.lincomb.');
    end
    d.L = a1*d1.L;
    d.R = a1*d1.R;
    if nargin == 5
        d.L = d.L + a2*d2.L;
        d.R = d.R + a2*d2.R;
    end
end
189 |
190 |
191 |
192 |
193 |
194 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/fixedrank/fixedrankfactory_3factors.m:
--------------------------------------------------------------------------------
1 | function M = fixedrankfactory_3factors(m, n, k)
2 | % Manifold of m-by-n matrices of rank k with polar quotient geometry.
3 | %
4 | % function M = fixedrankfactory_3factors(m, n, k)
5 | %
6 | % Follows the polar quotient geometry described in the following paper:
7 | % G. Meyer, S. Bonnabel and R. Sepulchre,
8 | % "Linear regression under fixed-rank constraints: a Riemannian approach",
9 | % ICML 2011.
10 | %
11 | % Paper link: http://www.icml-2011.org/papers/350_icmlpaper.pdf
12 | %
13 | % Additional reference is
14 | %
15 | % B. Mishra, R. Meyer, S. Bonnabel and R. Sepulchre
16 | % "Fixed-rank matrix factorizations and Riemannian low-rank optimization",
17 | % arXiv, 2012.
18 | %
19 | % Paper link: http://arxiv.org/abs/1209.0430
20 | %
21 | % A point X on the manifold is represented as a structure with three
22 | % fields: L, S and R. The matrices L (mxk) and R (nxk) are orthonormal,
23 | % while the matrix S (kxk) is a symmetric positive definite full rank
24 | % matrix.
25 | %
26 | % Tangent vectors are represented as a structure with three fields: L, S
27 | % and R.
28 |
29 | % This file is part of Manopt: www.manopt.org.
30 | % Original author: Bamdev Mishra, Dec. 30, 2012.
31 | % Contributors:
32 | % Change log:
33 |
34 | M.name = @() sprintf('LSR'' quotient manifold of %dx%d matrices of rank %d', m, n, k);
35 |
36 | M.dim = @() (m+n-k)*k;
37 |
38 | % Choice of the metric on the orthonormal factors is motivated by the symmetry present in the
39 | % space. The metric on the positive definite space is its natural metric.
40 | M.inner = @(X, eta, zeta) eta.L(:).'*zeta.L(:) + eta.R(:).'*zeta.R(:) ...
41 | + trace( (X.S\eta.S) * (X.S\zeta.S) );
42 |
43 | M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
44 |
45 | M.dist = @(x, y) error('fixedrankfactory_3factors.dist not implemented yet.');
46 |
47 | M.typicaldist = @() 10*k;
48 |
49 | skew = @(X) .5*(X-X');
50 | symm = @(X) .5*(X+X');
51 | stiefel_proj = @(L, H) H - L*symm(L'*H);
52 |
53 | M.egrad2rgrad = @egrad2rgrad;
function eta = egrad2rgrad(X, eta)
% Riemannian gradient from the Euclidean gradient. L and R are projected
% onto the tangent spaces to Stiefel at X.L and X.R. S becomes
% X.S*sym(eta.S)*X.S.
    S = X.S;
    eta.L = stiefel_proj(X.L, eta.L);
    eta.S = S*symm(eta.S)*S;
    eta.R = stiefel_proj(X.R, eta.R);
end
60 |
61 | M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
% Riemannian Hessian-vector product at X along the tangent vector eta,
% computed from the Euclidean gradient egrad and the Euclidean
% Hessian-vector product ehess. The result is projected onto the
% horizontal space.

    % Riemannian gradient for the factor S
    rgrad.S = X.S*symm(egrad.S)*X.S;

    % Directional derivatives of the Riemannian gradient.
    % L and R: Euclidean Hessian minus eta times sym(X'*egrad), then
    % projected onto the tangent space to Stiefel.
    Hess.L = ehess.L - eta.L*symm(X.L'*egrad.L);
    Hess.L = stiefel_proj(X.L, Hess.L);

    Hess.R = ehess.R - eta.R*symm(X.R'*egrad.R);
    Hess.R = stiefel_proj(X.R, Hess.R);

    % S: derivative of S*sym(egrad.S)*S along eta.
    Hess.S = X.S*symm(ehess.S)*X.S + 2*symm(eta.S*symm(egrad.S)*X.S);

    % Correction factor for the non-constant metric on the factor S
    Hess.S = Hess.S - symm(eta.S*(X.S\rgrad.S));

    % Projection onto the horizontal space
    Hess = M.proj(X, Hess);
end
82 |
83 |
84 | M.proj = @projection;
function etaproj = projection(X, eta)
% Projects eta = (eta.L, eta.S, eta.R) onto the horizontal space at X.
    % First, projection onto the tangent space of the total space: Stiefel
    % projections for L and R, symmetric part for S.
    tL = stiefel_proj(X.L, eta.L);
    tR = stiefel_proj(X.R, eta.R);
    tS = symm(eta.S);

    % Then, projection onto the horizontal space. lyap(S2, -rhs) solves
    % S2*omega + omega*S2' = rhs; the vertical part
    % (L*omega, S*omega - omega*S, R*omega) is then removed.
    S2 = X.S*X.S;
    rhs = X.S*(skew(X.L'*tL) + skew(X.R'*tR) - 2*skew(X.S\tS))*X.S;
    omega = lyap(S2, -rhs);

    etaproj.L = tL - X.L*omega;
    etaproj.S = tS - (X.S*omega - omega*X.S);
    etaproj.R = tR - X.R*omega;
end
100 |
101 | M.tangent = M.proj;
102 | M.tangent2ambient = @(X, eta) eta;
103 |
104 | M.retr = @retraction;
function Y = retraction(X, eta, t)
% Retraction at X along t*eta (t defaults to 1).
    if nargin < 3
        t = 1.0;
    end
    % S factor: with the upper Cholesky factor U (X.S = U'*U), move to
    % U'*expm(U'\(t*eta.S)/U)*U, which is symmetric positive definite
    % whenever eta.S is symmetric.
    U = chol(X.S);
    Y.S = U'*expm(U'\(t*eta.S)/U)*U;
    % Orthonormal factors: uf (U*V' of the thin SVD) of the linear update.
    Y.L = uf(X.L + t*eta.L);
    Y.R = uf(X.R + t*eta.R);
end
115 |
116 | M.exp = @exponential;
function Y = exponential(X, eta, t)
% No exponential map is implemented for this geometry: this falls back to
% the retraction (t defaults to 1) and issues a warning saying so.
% The warning text previously read "Lsed retraction"; it is now "Used".
    if nargin < 3
        t = 1.0;
    end
    Y = retraction(X, eta, t);
    warning('manopt:fixedrankfactory_3factors:exp', ...
            ['Exponential for fixed rank ' ...
             'manifold not implemented yet. Used retraction instead.']);
end
126 |
127 | M.hash = @(X) ['z' hashmd5([X.L(:) ; X.S(:) ; X.R(:)])];
128 |
129 | M.rand = @random;
130 | % Factors L and R live on Stiefel manifolds, hence we will reuse
131 | % their random generator.
132 | stiefelm = stiefelfactory(m, k);
133 | stiefeln = stiefelfactory(n, k);
function X = random()
% Random point: L and R come from the Stiefel factories' random generators,
% and S is diagonal with entries 1 + rand, i.e., in (1, 2).
    X = struct('L', stiefelm.rand(), ...
               'R', stiefeln.rand(), ...
               'S', diag(1 + rand(k, 1)));
end
139 |
140 | M.randvec = @randomvec;
function eta = randomvec(X)
% Random unit-norm vector in the horizontal space at X.
    raw.L = randn(m, k);
    raw.R = randn(n, k);
    raw.S = randn(k, k);
    eta = projection(X, raw);
    nrm = M.norm(X, eta);
    eta = struct('L', eta.L / nrm, 'S', eta.S / nrm, 'R', eta.R / nrm);
end
152 |
153 | M.lincomb = @lincomb;
154 |
155 | M.zerovec = @(X) struct('L', zeros(m, k), 'S', zeros(k, k), ...
156 | 'R', zeros(n, k));
157 |
158 | M.transp = @(x1, x2, d) projection(x2, d);
159 |
160 | % vec and mat are not isometries, because of the unusual inner metric.
161 | M.vec = @(X, U) [U.L(:) ; U.S(:); U.R(:)];
162 | M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ...
163 | 'S', reshape(u((m*k+1): m*k + k*k), k, k), ...
164 | 'R', reshape(u((m*k+ k*k + 1):end), n, k));
165 | M.vecmatareisometries = @() false;
166 |
167 | end
168 |
169 | % Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok
% Linear combination of tangent vectors with fields L, R and S:
% a1*d1 (three arguments) or a1*d1 + a2*d2 (five arguments).
    switch nargin
        case 3
            d = struct('L', a1*d1.L, 'R', a1*d1.R, 'S', a1*d1.S);
        case 5
            d = struct('L', a1*d1.L + a2*d2.L, ...
                       'R', a1*d1.R + a2*d2.R, ...
                       'S', a1*d1.S + a2*d2.S);
        otherwise
            error('Bad use of fixedrankfactory_3factors.lincomb.');
    end
end
185 |
function A = uf(A)
% Returns the orthonormal factor U*V' of the thin SVD A = U*S*V'.
    [leftSV, sv, rightSV] = svd(A, 0); %#ok<ASGLU>
    A = leftSV*rightSV';
end
190 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/rotations/randrot.m:
--------------------------------------------------------------------------------
function R = randrot(n, N)
% Generates uniformly random rotation matrices.
%
% function R = randrot(n, N)
%
% R is an n-by-n-by-N array such that each slice R(:, :, i) is an
% orthogonal matrix of size n with determinant +1 (i.e., a matrix in SO(n)).
% By default, N = 1.
% Complexity: N times O(n^3).
% Theory in Diaconis and Shahshahani 1987 for the uniformity on O(n);
% with details in Mezzadri 2007,
% "How to generate random matrices from the classical compact groups."
% To ensure matrices in SO(n), we swap the first two columns when
% the determinant is -1.
%
% See also: randskew

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Sept. 25, 2012.
% Contributors:
% Change log:

    if nargin < 2
        N = 1;
    end

    % SO(1) only contains the number 1.
    if n == 1
        R = ones(1, 1, N);
        return;
    end

    R = zeros(n, n, N);
    for slice = 1 : N
        % Q is uniformly distributed over O(n) once the signs of diag(T)
        % are absorbed into its columns (Mezzadri 2007).
        [Q, T] = qr(randn(n));
        Q = Q * diag(sign(diag(T)));

        % Swapping two columns flips the sign of the determinant, which
        % maps O(n) \ SO(n) onto SO(n) while preserving uniformity.
        if det(Q) < 0
            Q(:, [1 2]) = Q(:, [2 1]);
        end

        R(:, :, slice) = Q;
    end

end
54 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/rotations/randskew.m:
--------------------------------------------------------------------------------
function S = randskew(n, N)
% Generates random skew symmetric matrices with normal entries.
%
% function S = randskew(n, N)
%
% S is an n-by-n-by-N matrix where each slice S(:, :, i) for i = 1..N is a
% random skew-symmetric matrix with upper triangular entries distributed
% independently following a normal distribution (Gaussian, zero mean, unit
% variance). By default, N = 1. For n = 1 the slices are zero.
%
% See also: randrot

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Sept. 25, 2012.
% Contributors:
% Change log:


    if nargin < 2
        N = 1;
    end

    % Linear indices (column-major order) of the strictly upper triangular
    % entries of a single n-by-n matrix. The (:) forces a column vector:
    % for n = 1 there are no such entries and find may return an empty row,
    % which would make the index arrays below inconsistently shaped.
    idx = find(triu(ones(n), 1));
    idx = idx(:);

    % The same entries in every slice: slice i is offset by (i-1)*n^2, so
    % L = [idx; idx + n^2; ...; idx + (N-1)*n^2].
    L = bsxfun(@plus, idx, (0:N-1)*(n*n));
    L = L(:);

    % Allocate memory for N random skew matrices of size n-by-n and
    % populate each upper triangular entry with a random number following a
    % normal distribution and copy them with opposite sign on the
    % corresponding lower triangular side.
    S = zeros(n, n, N);
    S(L) = randn(size(L));
    S = S-multitransp(S);

end
42 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/rotations/rotationsfactory.m:
--------------------------------------------------------------------------------
function M = rotationsfactory(n, k)
% Returns a manifold structure to optimize over rotation matrices.
%
% function M = rotationsfactory(n)
% function M = rotationsfactory(n, k)
%
% Special orthogonal group (the manifold of rotations): deals with matrices
% R of size n x n x k (or n x n if k = 1, which is the default) such that
% each n x n matrix is orthogonal, with determinant 1, i.e., X'*X = eye(n)
% if k = 1, or X(:, :, i)' * X(:, :, i) = eye(n) for i = 1 : k if k > 1.
%
% This is a description of SO(n)^k with the induced metric from the
% embedding space (R^nxn)^k, i.e., this manifold is a Riemannian
% submanifold of (R^nxn)^k endowed with the usual trace inner product.
%
% Tangent vectors are represented in the Lie algebra, i.e., as skew
% symmetric matrices. Use the function M.tangent2ambient(X, H) to switch
% from the Lie algebra representation to the embedding space
% representation.
%
% By default, k = 1. n = 1 is accepted: SO(1) = {1} has dimension 0.
%
% See also: stiefelfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
% Jan. 31, 2013, NB : added egrad2rgrad and ehess2rhess


    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end

    if k == 1
        M.name = @() sprintf('Rotations manifold SO(%d)', n);
    elseif k > 1
        M.name = @() sprintf('Product rotations manifold SO(%d)^%d', n, k);
    else
        error('k must be an integer no less than 1.');
    end

    % dim SO(n) = n*(n-1)/2. This is written out instead of nchoosek(n, 2),
    % which errors for n = 1 (nchoosek requires k <= n).
    M.dim = @() k*n*(n-1)/2;

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:));

    M.typicaldist = @() pi*sqrt(n*k);

    M.proj = @(X, H) multiskew(multiprod(multitransp(X), H));

    M.tangent = @(X, H) multiskew(H);

    M.tangent2ambient = @(X, U) multiprod(X, U);

    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function Rhess = ehess2rhess(X, Egrad, Ehess, H)
        % Reminder: H contains skew-symmetric matrices. The actual
        % direction that the point X is moved along is X*H.
        Xt = multitransp(X);
        XtEgrad = multiprod(Xt, Egrad);
        symXtEgrad = multisym(XtEgrad);
        XtEhess = multiprod(Xt, Ehess);
        Rhess = multiskew( XtEhess - multiprod(H, symXtEgrad) );
    end

    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = X + multiprod(X, tU);
        for i = 1 : k
            [Q, R] = qr(Y(:, :, i));
            % The instruction with R ensures we are not flipping signs
            % of some columns, which should never happen in modern Matlab
            % versions but may be an issue with older versions.
            Y(:, :, i) = Q * diag(sign(sign(diag(R))+.5));
            % This is guaranteed to always yield orthogonal matrices with
            % determinant +1. Simply look at the eigenvalues of a skew
            % symmetric matrix, then at those of identity plus that
            % matrix, and compute their product for the determinant: it's
            % strictly positive in all cases.
        end
    end

    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            exptU = t*U;
        else
            exptU = U;
        end
        for i = 1 : k
            exptU(:, :, i) = expm(exptU(:, :, i));
        end
        Y = multiprod(X, exptU);
    end

    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = multiprod(multitransp(X), Y);
        for i = 1 : k
            % The result of logm should be real in theory, but it is
            % numerically useful to force it.
            U(:, :, i) = real(logm(U(:, :, i)));
        end
        % Ensure the tangent vector is in the Lie algebra.
        U = multiskew(U);
    end

    M.hash = @(X) ['z' hashmd5(X(:))];

    M.rand = @() randrot(n, k);

    M.randvec = @randomvec;
    function U = randomvec(X) %#ok
        U = randskew(n, k);
        nrmU = sqrt(U(:).'*U(:));
        U = U / nrmU;
    end

    M.lincomb = @lincomb;

    M.zerovec = @(x) zeros(n, n, k);

    M.transp = @(x1, x2, d) d;

    M.pairmean = @pairmean;
    function Y = pairmean(X1, X2)
        V = M.log(X1, X2);
        Y = M.exp(X1, .5*V);
    end

    M.dist = @(x, y) M.norm(x, M.log(x, y));

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, n, k]);
    M.vecmatareisometries = @() true;

end
148 |
149 | % Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok
% Linear combination of tangent vectors: a1*d1 (three arguments) or
% a1*d1 + a2*d2 (five arguments).
    switch nargin
        case 3
            d = a1*d1;
        case 5
            d = a1*d1 + a2*d2;
        otherwise
            error('Bad use of rotations.lincomb.');
    end
end
161 |
162 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/sphere/spherecomplexfactory.m:
--------------------------------------------------------------------------------
function M = spherecomplexfactory(n, m)
% Returns a manifold struct to optimize over unit-norm complex matrices.
%
% function M = spherecomplexfactory(n)
% function M = spherecomplexfactory(n, m)
%
% Manifold of n-by-m complex matrices of unit Frobenius norm.
% By default, m = 1, which corresponds to the unit sphere in C^n. The
% metric is such that the sphere is a Riemannian submanifold of the space
% of n-by-m complex matrices seen as a real vector space of dimension 2nm
% (e.g., identified with 2n-by-m real matrices), with the inner product
% real(trace(A'*B)), i.e., the usual metric.
%
% See also: spherefactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    if ~exist('m', 'var')
        m = 1;
    end

    if m == 1
        M.name = @() sprintf('Complex sphere S^%d', n-1);
    else
        M.name = @() sprintf('Unit F-norm %dx%d complex matrices', n, m);
    end

    M.dim = @() 2*(n*m)-1;

    M.inner = @(x, d1, d2) real(d1(:)'*d2(:));

    M.norm = @(x, d) norm(d, 'fro');

    % In exact arithmetic real(x'*y) lies in [-1, 1], but rounding can push
    % it slightly outside, where acos returns a complex number. Keeping the
    % real part gives 0 (resp. pi) there, as spherefactory does.
    M.dist = @(x, y) real(acos(real(x(:)'*y(:))));

    M.typicaldist = @() pi;

    M.proj = @(x, d) reshape(d(:) - x(:)*(real(x(:)'*d(:))), n, m);

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.tangent = M.proj;

    M.exp = @exponential;

    M.retr = @retraction;

    M.log = @logarithm;
    function v = logarithm(x1, x2) %#ok<STOUT,INUSD>
        error('The logarithmic map is not yet implemented for this manifold.');
    end

    M.hash = @(x) ['z' hashmd5([real(x(:)) ; imag(x(:))])];

    M.rand = @() random(n, m);

    M.randvec = @(x) randomvec(n, m, x);

    M.lincomb = @lincomb;

    M.zerovec = @(x) zeros(n, m);

    M.transp = @(x1, x2, d) M.proj(x2, d);

    M.pairmean = @pairmean;
    function y = pairmean(x1, x2)
        y = x1+x2;
        y = y / norm(y, 'fro');
    end

end
77 |
78 | % Exponential on the sphere
function y = exponential(x, d, t)
% Exponential map on the sphere: follows the great circle from x along
% t*d (t defaults to 1), assuming d is tangent at x.
    if nargin == 2
        t = 1;
    end

    td = t*d;
    theta = norm(td, 'fro');

    if ~(theta > 1e-6)
        % Step too small to divide by theta safely: use the
        % retraction-like step x + td, normalized.
        y = x + td;
        y = y / norm(y, 'fro');
        return;
    end

    y = x*cos(theta) + td*(sin(theta)/theta);
end
101 | function y = retraction(x, d, t)
102 |
103 | if nargin == 2
104 | t = 1;
105 | end
106 |
107 | y = x+t*d;
108 | y = y/norm(y, 'fro');
109 |
110 | end
111 |
112 | % Uniform random sampling on the sphere.
function x = random(n, m)
% Uniformly distributed random point on the complex sphere: a complex
% Gaussian n-by-m matrix scaled to unit Frobenius norm.
    re = randn(n, m);
    im = randn(n, m);
    x = re + 1i*im;
    x = x / norm(x, 'fro');
end
119 |
120 | % Random normalized tangent vector at x.
function d = randomvec(n, m, x)
% Random unit-norm tangent vector at x. A complex Gaussian matrix has its
% component along x (in the real inner product) removed, and the result is
% normalized.
    g = randn(n, m) + 1i*randn(n, m);
    g = g(:);
    xv = x(:);
    d = reshape(g - xv*(real(xv'*g)), n, m);
    d = d / norm(d, 'fro');
end
128 |
129 | % Linear combination of tangent vectors: returns a1*d1, or a1*d1 + a2*d2.
130 | % The base point x is accepted for interface compatibility but not used.
131 | function d = lincomb(x, a1, d1, a2, d2) %#ok
132 |
133 |     switch nargin
134 |         case 3
135 |             d = a1*d1;
136 |         case 5
137 |             d = a1*d1 + a2*d2;
138 |         otherwise
139 |             error('Bad use of spherecomplex.lincomb.');
140 |     end
141 |
142 | end
141 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/sphere/spherefactory.m:
--------------------------------------------------------------------------------
1 | function M = spherefactory(n, m)
2 | % Returns a manifold struct to optimize over unit-norm vectors or matrices.
3 | %
4 | % function M = spherefactory(n)
5 | % function M = spherefactory(n, m)
6 | %
7 | % Manifold of n-by-m real matrices of unit Frobenius norm.
8 | % By default, m = 1, which corresponds to the unit sphere in R^n. The
9 | % metric is such that the sphere is a Riemannian submanifold of the space
10 | % of nxm matrices with the usual trace inner product, i.e., the usual
11 | % metric.
12 | %
13 | % See also: obliquefactory spherecomplexfactory
14 |
15 | % This file is part of Manopt: www.manopt.org.
16 | % Original author: Nicolas Boumal, Dec. 30, 2012.
17 | % Contributors:
18 | % Change log:
19 | %   M.log no longer returns NaN when x1 == x2 (it divided 0 by 0), and
20 | %   M.dist is computed from the chordal distance, which stays accurate
21 | %   for nearby points where acos(x'*y) loses precision.
22 |
23 |
24 |     if ~exist('m', 'var')
25 |         m = 1;
26 |     end
27 |
28 |     if m == 1
29 |         M.name = @() sprintf('Sphere S^%d', n-1);
30 |     else
31 |         M.name = @() sprintf('Unit F-norm %dx%d matrices', n, m);
32 |     end
33 |
34 |     M.dim = @() n*m-1;
35 |
36 |     M.inner = @(x, d1, d2) d1(:).'*d2(:);
37 |
38 |     M.norm = @(x, d) norm(d, 'fro');
39 |
40 |     % Geodesic distance, i.e., the angle between x and y. For unit-norm
41 |     % points ||x - y|| = 2*sin(angle/2), so the angle is recovered with
42 |     % asin, which is well conditioned for nearby points (acos(x'*y) is
43 |     % not: there x'*y rounds to 1). real() guards against a chordal
44 |     % distance that rounds slightly above 2.
45 |     M.dist = @dist;
46 |     function d = dist(x, y)
47 |         chordal_distance = norm(x - y, 'fro');
48 |         d = real(2*asin(.5*chordal_distance));
49 |     end
50 |
51 |     M.typicaldist = @() pi;
52 |
53 |     % Orthogonal projection onto the tangent space at x (x is assumed to
54 |     % have unit norm).
55 |     M.proj = @(x, d) d - x*(x(:).'*d(:));
56 |
57 |     M.tangent = M.proj;
58 |
59 |     % For Riemannian submanifolds, converting a Euclidean gradient into a
60 |     % Riemannian gradient amounts to an orthogonal projection.
61 |     M.egrad2rgrad = M.proj;
62 |
63 |     % Riemannian Hessian along u: projected Euclidean Hessian plus the
64 |     % curvature correction -(x'*egrad)*u.
65 |     M.ehess2rhess = @ehess2rhess;
66 |     function rhess = ehess2rhess(x, egrad, ehess, u)
67 |         rhess = M.proj(x, ehess) - (x(:)'*egrad(:))*u;
68 |     end
69 |
70 |     M.exp = @exponential;
71 |
72 |     M.retr = @retraction;
73 |
74 |     % Logarithmic map: tangent vector at x1 pointing towards x2 whose norm
75 |     % equals the geodesic distance. The projection of x2 - x1 already has
76 |     % the right direction and only needs rescaling. When x1 and x2
77 |     % (nearly) coincide, that projection is (nearly) zero and rescaling it
78 |     % would compute 0/0 = NaN, so it is returned as is: its norm then
79 |     % agrees with the distance to first order. The log is undefined for
80 |     % antipodal points (the rescaling then divides by zero).
81 |     M.log = @logarithm;
82 |     function v = logarithm(x1, x2)
83 |         v = M.proj(x1, x2 - x1);
84 |         di = M.dist(x1, x2);
85 |         if di > 1e-6
86 |             nv = norm(v, 'fro');
87 |             v = v * (di / nv);
88 |         end
89 |     end
90 |
91 |     M.hash = @(x) ['z' hashmd5(x(:))];
92 |
93 |     M.rand = @() random(n, m);
94 |
95 |     M.randvec = @(x) randomvec(n, m, x);
96 |
97 |     M.lincomb = @lincomb;
98 |
99 |     M.zerovec = @(x) zeros(n, m);
100 |
101 |     % Vector transport by projection onto the tangent space at x2.
102 |     M.transp = @(x1, x2, d) M.proj(x2, d);
103 |
104 |     % Midpoint of x1 and x2 on the sphere: the normalized sum.
105 |     M.pairmean = @pairmean;
106 |     function y = pairmean(x1, x2)
107 |         y = x1+x2;
108 |         y = y / norm(y, 'fro');
109 |     end
110 |
111 |     % Tangent vectors are n-by-m matrices; vec/mat merely reshape them,
112 |     % which is an isometry for the trace inner product used here.
113 |     M.vec = @(x, u_mat) u_mat(:);
114 |     M.mat = @(x, u_vec) reshape(u_vec, [n, m]);
115 |     M.vecmatareisometries = @() true;
116 |
117 | end
89 |
90 | % Exponential map on the sphere: moves from x along the great circle with
91 | % initial velocity t*d (d is assumed tangent at x).
92 | function y = exponential(x, d, t)
93 |
94 |     if nargin == 2
95 |         t = 1;
96 |     end
97 |
98 |     v = t*d;
99 |     theta = norm(v, 'fro');
100 |
101 |     if theta > 1e-6
102 |         y = x*cos(theta) + v*(sin(theta)/theta);
103 |         return;
104 |     end
105 |
106 |     % For a tiny (or NaN) step length, avoid dividing by theta and take a
107 |     % retraction-like step instead: move, then renormalize.
108 |     y = x + v;
109 |     y = y / norm(y, 'fro');
110 |
111 | end
111 |
112 | % Retraction on the sphere: step to x + t*d in the ambient space and
113 | % return to the sphere by normalization.
114 | function y = retraction(x, d, t)
115 |
116 |     if nargin == 2
117 |         t = 1;   % unit step by default
118 |     end
119 |
120 |     ambient_point = x + t*d;
121 |     y = ambient_point / norm(ambient_point, 'fro');
122 |
123 | end
123 |
124 | % Uniformly distributed random point on the sphere: the direction of a
125 | % standard Gaussian n-by-m matrix is uniform, so it suffices to normalize.
126 | function x = random(n, m)
127 |
128 |     g = randn(n, m);
129 |     x = g / norm(g, 'fro');
130 |
131 | end
131 |
132 | % Random unit-norm tangent vector at x: a Gaussian matrix with its
133 | % component along x removed, then normalized.
134 | function d = randomvec(n, m, x)
135 |
136 |     g = randn(n, m);
137 |     g = g - x*(x(:).'*g(:));
138 |     d = g / norm(g, 'fro');
139 |
140 | end
140 |
141 | % Linear combination of tangent vectors: a1*d1, or a1*d1 + a2*d2 when a
142 | % second pair is given. The base point x is not used.
143 | function d = lincomb(x, a1, d1, a2, d2) %#ok
144 |
145 |     if nargin ~= 3 && nargin ~= 5
146 |         error('Bad use of sphere.lincomb.');
147 |     end
148 |
149 |     d = a1*d1;
150 |     if nargin == 5
151 |         d = d + a2*d2;
152 |     end
153 |
154 | end
153 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/stiefel/stiefelfactory.m:
--------------------------------------------------------------------------------
1 | function M = stiefelfactory(n, p, k)
2 | % Returns a manifold structure to optimize over orthonormal matrices.
3 | %
4 | % function M = stiefelfactory(n, p)
5 | % function M = stiefelfactory(n, p, k)
6 | %
7 | % The Stiefel manifold is the set of orthonormal nxp matrices. If k
8 | % is larger than 1, this is the Cartesian product of the Stiefel manifold
9 | % taken k times. The metric is such that the manifold is a Riemannian
10 | % submanifold of R^nxp equipped with the usual trace inner product, that
11 | % is, it is the usual metric.
12 | %
13 | % Points are represented as matrices X of size n x p x k (or n x p if k=1,
14 | % which is the default) such that each n x p matrix is orthonormal,
15 | % i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
16 | % i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
17 | % size as points.
18 | %
19 | % By default, k = 1. k must be a positive integer.
20 | %
21 | % See also: grassmannfactory rotationsfactory
22 |
23 | % This file is part of Manopt: www.manopt.org.
24 | % Original author: Nicolas Boumal, Dec. 30, 2012.
25 | % Contributors:
26 | % Change log:
27 | % July 5, 2013 (NB) : Added ehess2rhess.
28 | % Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
29 | % June 24, 2014 (NB) : Added true exponential map and changed the randvec
30 | % function so that it now returns a globally
31 | % normalized vector, not a vector where each
32 | % component is normalized (this only matters if k>1).
33 | % Review : k is now checked to be an integer (non-integer values such
34 | % as 2.5 used to be accepted), and Xi'*Ui is computed once per
35 | % slice in the exponential map.
36 |
37 |
38 |     if ~exist('k', 'var') || isempty(k)
39 |         k = 1;
40 |     end
41 |
42 |     % The error message has always promised an integer; enforce it, since
43 |     % k is used below as a loop bound and as an array dimension.
44 |     if ~(isscalar(k) && k >= 1 && k == round(k))
45 |         error('k must be an integer no less than 1.');
46 |     end
47 |
48 |     if k == 1
49 |         M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
50 |     else
51 |         M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
52 |     end
53 |
54 |     M.dim = @() k*(n*p - .5*p*(p+1));
55 |
56 |     M.inner = @(x, d1, d2) d1(:).'*d2(:);
57 |
58 |     M.norm = @(x, d) norm(d(:));
59 |
60 |     M.dist = @(x, y) error('stiefel.dist not implemented yet.');
61 |
62 |     M.typicaldist = @() sqrt(p*k);
63 |
64 |     % Orthogonal projection onto the tangent space at X, slice by slice:
65 |     % U - X*sym(X'*U).
66 |     M.proj = @projection;
67 |     function Up = projection(X, U)
68 |
69 |         XtU = multiprod(multitransp(X), U);
70 |         symXtU = multisym(XtU);
71 |         Up = U - multiprod(X, symXtU);
72 |
73 |         % The code above is equivalent to, but much faster than, the code below.
74 |         %
75 |         %     Up = zeros(size(U));
76 |         %     function A = sym(A), A = .5*(A+A'); end
77 |         %     for i = 1 : k
78 |         %         Xi = X(:, :, i);
79 |         %         Ui = U(:, :, i);
80 |         %         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
81 |         %     end
82 |
83 |     end
84 |
85 |     M.tangent = M.proj;
86 |
87 |     % For Riemannian submanifolds, converting a Euclidean gradient into a
88 |     % Riemannian gradient amounts to an orthogonal projection.
89 |     M.egrad2rgrad = M.proj;
90 |
91 |     % Riemannian Hessian along H: project ehess - H*sym(X'*egrad) onto the
92 |     % tangent space at X.
93 |     M.ehess2rhess = @ehess2rhess;
94 |     function rhess = ehess2rhess(X, egrad, ehess, H)
95 |         XtG = multiprod(multitransp(X), egrad);
96 |         symXtG = multisym(XtG);
97 |         HsymXtG = multiprod(H, symXtG);
98 |         rhess = projection(X, ehess - HsymXtG);
99 |     end
100 |
101 |     % QR-based retraction: orthonormalize each slice of X + t*U.
102 |     M.retr = @retraction;
103 |     function Y = retraction(X, U, t)
104 |         if nargin < 3
105 |             t = 1.0;
106 |         end
107 |         Y = X + t*U;
108 |         for i = 1 : k
109 |             [Q, R] = qr(Y(:, :, i), 0);
110 |             % The instruction with R assures we are not flipping signs
111 |             % of some columns, which should never happen in modern Matlab
112 |             % versions but may be an issue with older versions.
113 |             Y(:, :, i) = Q * diag(sign(sign(diag(R))+.5));
114 |         end
115 |     end
116 |
117 |     M.exp = @exponential;
118 |     function Y = exponential(X, U, t)
119 |         if nargin == 2
120 |             t = 1;
121 |         end
122 |         tU = t*U;
123 |         Y = zeros(size(X));
124 |         for i = 1 : k
125 |             % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
126 |             Xi = X(:, :, i);
127 |             Ui = tU(:, :, i);
128 |             XiUi = Xi'*Ui;   % shared by both matrix exponentials
129 |             Y(:, :, i) = [Xi Ui] * ...
130 |                          expm([XiUi , -Ui'*Ui ; eye(p) , XiUi]) * ...
131 |                          [ expm(-XiUi) ; zeros(p) ];
132 |         end
133 |
134 |     end
135 |
136 |     M.hash = @(X) ['z' hashmd5(X(:))];
137 |
138 |     M.rand = @random;
139 |     function X = random()
140 |         X = zeros(n, p, k);
141 |         for i = 1 : k
142 |             [Q, unused] = qr(randn(n, p), 0); %#ok
143 |             X(:, :, i) = Q;
144 |         end
145 |     end
146 |
147 |     % Random tangent vector at X with unit global norm.
148 |     M.randvec = @randomvec;
149 |     function U = randomvec(X)
150 |         U = projection(X, randn(n, p, k));
151 |         U = U / norm(U(:));
152 |     end
153 |
154 |     M.lincomb = @lincomb;
155 |
156 |     M.zerovec = @(x) zeros(n, p, k);
157 |
158 |     % Vector transport by projection onto the tangent space at x2.
159 |     M.transp = @(x1, x2, d) projection(x2, d);
160 |
161 |     M.vec = @(x, u_mat) u_mat(:);
162 |     M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
163 |     M.vecmatareisometries = @() true;
164 |
165 | end
151 |
152 | % Linear combination of tangent vectors: a1*d1 if only one pair is
153 | % given, a1*d1 + a2*d2 if two are. The base point x is ignored.
154 | function d = lincomb(x, a1, d1, a2, d2) %#ok
155 |
156 |     if nargin == 5
157 |         d = a1*d1 + a2*d2;
158 |     elseif nargin == 3
159 |         d = a1*d1;
160 |     else
161 |         error('Bad use of stiefel.lincomb.');
162 |     end
163 |
164 | end
164 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/symfixedrank/spectrahedronfactory.m:
--------------------------------------------------------------------------------
1 | function M = spectrahedronfactory(n, k)
2 | % Manifold of n-by-n symmetric positive semidefinite matrices of rank k
3 | % with trace (sum of diagonal elements) being 1.
4 | %
5 | % function M = spectrahedronfactory(n, k)
6 | %
7 | % The geometry is based on the paper,
8 | % M. Journee, P.-A. Absil, F. Bach and R. Sepulchre,
9 | % "Low-Rank Optimization on the Cone of Positive Semidefinite Matrices",
10 | % SIOPT, 2010.
11 | %
12 | % Paper link: http://www.di.ens.fr/~fbach/journee2010_sdp.pdf
13 | %
14 | % A point X on the manifold is parameterized as YY^T where Y is a matrix of
15 | % size nxk. The matrix Y (nxk) is a full column-rank matrix. Hence, we deal
16 | % directly with Y. The trace constraint on X translates to the Frobenius
17 | % norm constraint on Y, i.e., trace(X) = || Y ||^2.
18 | %
19 | % The horizontal projection calls lyap (Control System Toolbox in Matlab).
20 |
21 | % This file is part of Manopt: www.manopt.org.
22 | % Original author: Bamdev Mishra, July 11, 2013.
23 | % Contributors:
24 | % Change log:
25 | % Review: M.name passed only (n, k) to a format with three %d fields, so
26 | % the name showed an nxk size and no rank; it now passes (n, n, k).
27 |
28 |
29 |     M.name = @() sprintf('YY'' quotient manifold of %dx%d PSD matrices of rank %d with trace 1 ', n, n, k);
30 |
31 |     % k*n entries in Y, minus 1 for the trace (unit Frobenius norm)
32 |     % constraint, minus k*(k-1)/2 for the symmetry Y -> Y*Q (Q orthogonal),
33 |     % which leaves X = Y*Y' unchanged.
34 |     M.dim = @() (k*n - 1) - k*(k-1)/2;
35 |
36 |     % Euclidean metric on the total space
37 |     M.inner = @(Y, eta, zeta) trace(eta'*zeta);
38 |
39 |     M.norm = @(Y, eta) sqrt(M.inner(Y, eta, eta));
40 |
41 |     M.dist = @(Y, Z) error('spectrahedronfactory.dist not implemented yet.');
42 |
43 |     M.typicaldist = @() 10*k;
44 |
45 |     M.proj = @projection;
46 |     function etaproj = projection(Y, eta)
47 |         % Projection onto the tangent space, i.e., on the tangent space of
48 |         % ||Y|| = 1
49 |         eta = eta - trace(eta'*Y)*Y;
50 |
51 |         % Projection onto the horizontal space: remove Y*Omega, where
52 |         % Omega solves the Lyapunov equation
53 |         % (Y'Y)*Omega + Omega*(Y'Y) = Y'*eta - eta'*Y.
54 |         YtY = Y'*Y;
55 |         SS = YtY;
56 |         AS = Y'*eta - eta'*Y;
57 |         Omega = lyap(SS, -AS);
58 |         etaproj = eta - Y*Omega;
59 |     end
60 |
61 |     M.tangent = M.proj;
62 |     M.tangent2ambient = @(Y, eta) eta;
63 |
64 |     % Retraction: step in the total space, then rescale to unit Frobenius
65 |     % norm (i.e., unit trace of Y*Y').
66 |     M.retr = @retraction;
67 |     function Ynew = retraction(Y, eta, t)
68 |         if nargin < 3
69 |             t = 1.0;
70 |         end
71 |         Ynew = Y + t*eta;
72 |         Ynew = Ynew/norm(Ynew,'fro');
73 |     end
74 |
75 |     % Projection of the Euclidean gradient onto the tangent space of the
76 |     % sphere ||Y|| = 1. NOTE(review): no horizontal projection is applied;
77 |     % presumably the Euclidean gradient of a cost depending only on Y*Y'
78 |     % is already horizontal - confirm for your cost.
79 |     M.egrad2rgrad = @(Y, eta) eta - trace(eta'*Y)*Y;
80 |
81 |     M.ehess2rhess = @ehess2rhess;
82 |     function Hess = ehess2rhess(Y, egrad, ehess, eta)
83 |
84 |         % Directional derivative of the Riemannian gradient
85 |         Hess = ehess - trace(egrad'*Y)*eta - (trace(ehess'*Y) + trace(egrad'*eta))*Y;
86 |         Hess = Hess - trace(Hess'*Y)*Y;
87 |
88 |         % Project on the horizontal space
89 |         Hess = M.proj(Y, Hess);
90 |
91 |     end
92 |
93 |     % No exponential map is available: the retraction is used instead,
94 |     % with a warning.
95 |     M.exp = @exponential;
96 |     function Ynew = exponential(Y, eta, t)
97 |         if nargin < 3
98 |             t = 1.0;
99 |         end
100 |
101 |         Ynew = retraction(Y, eta, t);
102 |         warning('manopt:spectrahedronfactory:exp', ...
103 |             ['Exponential for fixed rank spectrahedron ' ...
104 |             'manifold not implenented yet. Used retraction instead.']);
105 |     end
106 |
107 |     % Notice that the hash of two equivalent points will be different...
108 |     M.hash = @(Y) ['z' hashmd5(Y(:))];
109 |
110 |     % Random point: Gaussian n-by-k matrix scaled to unit Frobenius norm.
111 |     M.rand = @random;
112 |
113 |     function Y = random()
114 |         Y = randn(n, k);
115 |         Y = Y/norm(Y,'fro');
116 |     end
117 |
118 |     % Random unit-norm tangent (horizontal) vector at Y.
119 |     M.randvec = @randomvec;
120 |     function eta = randomvec(Y)
121 |         eta = randn(n, k);
122 |         eta = projection(Y, eta);
123 |         nrm = M.norm(Y, eta);
124 |         eta = eta / nrm;
125 |     end
126 |
127 |     M.lincomb = @lincomb;
128 |
129 |     M.zerovec = @(Y) zeros(n, k);
130 |
131 |     % Vector transport by projection at the target point Y2.
132 |     M.transp = @(Y1, Y2, d) projection(Y2, d);
133 |
134 |     M.vec = @(Y, u_mat) u_mat(:);
135 |     M.mat = @(Y, u_vec) reshape(u_vec, [n, k]);
136 |     M.vecmatareisometries = @() true;
137 |
138 | end
122 |
123 |
124 | % Linear combination of tangent vectors: a1*d1, or a1*d1 + a2*d2.
125 | % The base point Y is not used.
126 | function d = lincomb(Y, a1, d1, a2, d2) %#ok
127 |
128 |     switch nargin
129 |         case 3
130 |             d = a1*d1;
131 |         case 5
132 |             d = a1*d1 + a2*d2;
133 |         otherwise
134 |             error('Bad use of spectrahedronfactory.lincomb.');
135 |     end
136 |
137 | end
136 |
137 |
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/symfixedrank/symfixedrankYYfactory.m:
--------------------------------------------------------------------------------
1 | function M = symfixedrankYYfactory(n, k)
2 | % Manifold of n-by-n symmetric positive semidefinite matrices of rank k.
3 | %
4 | % function M = symfixedrankYYfactory(n, k)
5 | %
6 | % The geometry is based on the paper,
7 | % M. Journee, P.-A. Absil, F. Bach and R. Sepulchre,
8 | % "Low-Rank Optimization on the Cone of Positive Semidefinite Matrices",
9 | % SIAM Journal on Optimization, 2010.
10 | %
11 | % Paper link: http://www.di.ens.fr/~fbach/journee2010_sdp.pdf
12 | %
13 | % A point X on the manifold is parameterized as YY^T where Y is a matrix of
14 | % size nxk. The matrix Y (nxk) is a full column-rank matrix. Hence, we deal
15 | % directly with Y.
16 | %
17 | % Notice that this manifold is not complete: if optimization leads Y to be
18 | % rank-deficient, the geometry will break down. Hence, this geometry should
19 | % only be used if it is expected that the points of interest will have rank
20 | % exactly k. Reduce k if that is not the case.
21 | %
22 | % An alternative, complete, geometry for positive semidefinite matrices of
23 | % rank k is described in Bonnabel and Sepulchre 2009, "Riemannian Metric
24 | % and Geometric Mean for Positive Semidefinite Matrices of Fixed Rank",
25 | % SIAM Journal on Matrix Analysis and Applications.
26 |
27 | % This file is part of Manopt: www.manopt.org.
28 | % Original author: Bamdev Mishra, Dec. 30, 2012.
29 | % Contributors:
30 | % Change log:
31 | % July 10, 2013 (NB)
32 | % Added vec, mat, tangent, tangent2ambient ;
33 | % Correction for the dimension of the manifold.
34 | % Review
35 | % M.name passed only (n, k) to a format with three %d fields, so
36 | % the name showed an nxk size and no rank; it now passes (n, n, k).
37 |
38 |
39 |     M.name = @() sprintf('YY'' quotient manifold of %dx%d PSD matrices of rank %d', n, n, k);
40 |
41 |     % k*n entries in Y, minus k*(k-1)/2 for the symmetry Y -> Y*Q
42 |     % (Q orthogonal), which leaves X = Y*Y' unchanged.
43 |     M.dim = @() k*n - k*(k-1)/2;
44 |
45 |     % Euclidean metric on the total space
46 |     M.inner = @(Y, eta, zeta) trace(eta'*zeta);
47 |
48 |     M.norm = @(Y, eta) sqrt(M.inner(Y, eta, eta));
49 |
50 |     M.dist = @(Y, Z) error('symfixedrankYYfactory.dist not implemented yet.');
51 |
52 |     M.typicaldist = @() 10*k;
53 |
54 |     M.proj = @projection;
55 |     function etaproj = projection(Y, eta)
56 |         % Projection onto the horizontal space: remove Y*Omega, where
57 |         % Omega solves the Lyapunov equation
58 |         % (Y'Y)*Omega + Omega*(Y'Y) = Y'*eta - eta'*Y.
59 |         YtY = Y'*Y;
60 |         SS = YtY;
61 |         AS = Y'*eta - eta'*Y;
62 |         Omega = lyap(SS, -AS);
63 |         etaproj = eta - Y*Omega;
64 |     end
65 |
66 |     M.tangent = M.proj;
67 |     M.tangent2ambient = @(Y, eta) eta;
68 |
69 |     % Retraction: plain step in the total space R^(nxk).
70 |     M.retr = @retraction;
71 |     function Ynew = retraction(Y, eta, t)
72 |         if nargin < 3
73 |             t = 1.0;
74 |         end
75 |         Ynew = Y + t*eta;
76 |     end
77 |
78 |
79 |     % NOTE(review): the Euclidean gradient is returned unchanged,
80 |     % presumably because for a cost depending only on Y*Y' it is already
81 |     % horizontal - confirm for your cost.
82 |     M.egrad2rgrad = @(Y, eta) eta;
83 |     M.ehess2rhess = @(Y, egrad, ehess, U) M.proj(Y, ehess);
84 |
85 |     % No exponential map is available: the retraction is used instead,
86 |     % with a warning.
87 |     M.exp = @exponential;
88 |     function Ynew = exponential(Y, eta, t)
89 |         if nargin < 3
90 |             t = 1.0;
91 |         end
92 |
93 |         Ynew = retraction(Y, eta, t);
94 |         warning('manopt:symfixedrankYYfactory:exp', ...
95 |             ['Exponential for symmetric, fixed-rank ' ...
96 |             'manifold not implemented yet. Used retraction instead.']);
97 |     end
98 |
99 |     % Notice that the hash of two equivalent points will be different...
100 |     M.hash = @(Y) ['z' hashmd5(Y(:))];
101 |
102 |     M.rand = @random;
103 |
104 |     function Y = random()
105 |         Y = randn(n, k);
106 |     end
107 |
108 |     % Random unit-norm tangent (horizontal) vector at Y.
109 |     M.randvec = @randomvec;
110 |     function eta = randomvec(Y)
111 |         eta = randn(n, k);
112 |         eta = projection(Y, eta);
113 |         nrm = M.norm(Y, eta);
114 |         eta = eta / nrm;
115 |     end
116 |
117 |     M.lincomb = @lincomb;
118 |
119 |     M.zerovec = @(Y) zeros(n, k);
120 |
121 |     % Vector transport by projection at the target point Y2.
122 |     M.transp = @(Y1, Y2, d) projection(Y2, d);
123 |
124 |     M.vec = @(Y, u_mat) u_mat(:);
125 |     M.mat = @(Y, u_vec) reshape(u_vec, [n, k]);
126 |     M.vecmatareisometries = @() true;
127 |
128 | end
114 |
115 |
116 | % Linear combination of tangent vectors: returns a1*d1 (three inputs) or
117 | % a1*d1 + a2*d2 (five inputs). The point Y is unused.
118 | function d = lincomb(Y, a1, d1, a2, d2) %#ok
119 |
120 |     if nargin == 3
121 |         d = a1*d1;
122 |         return;
123 |     end
124 |     if nargin == 5
125 |         d = a1*d1 + a2*d2;
126 |         return;
127 |     end
128 |     error('Bad use of symfixedrankYYfactory.lincomb.');
129 |
130 | end
128 |
129 |
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/manopt/manopt/manifolds/symfixedrank/sympositivedefinitefactory.m:
--------------------------------------------------------------------------------
1 | function M = sympositivedefinitefactory(n)
2 | % Manifold of n-by-n symmetric positive definite matrices with
3 | % the bi-invariant geometry.
4 | %
5 | % function M = sympositivedefinitefactory(n)
6 | %
7 | % A point X on the manifold is represented as a symmetric positive definite
8 | % matrix X (nxn).
9 | %
10 | % The following material is referenced from Chapter 6 of the book:
11 | % Rajendra Bhatia, "Positive definite matrices",
12 | % Princeton University Press, 2007.
13 |
14 | % This file is part of Manopt: www.manopt.org.
15 | % Original author: Bamdev Mishra, August 29, 2013.
16 | % Contributors: Nicolas Boumal
17 | % Change log:
18 | %
19 | % March 5, 2014 (NB)
20 | % There were a number of mistakes in the code owing to the tacit
21 | % assumption that if X and eta are symmetric, then X\eta is
22 | % symmetric too, which is not the case. See discussion on the Manopt
23 | % forum started on Jan. 19, 2014. Functions norm, dist, exp and log
24 | % were modified accordingly. Furthermore, they only require matrix
25 | % inversion (as well as matrix log or matrix exp), not matrix square
26 | % roots or their inverse.
27 | %
28 | % July 28, 2014 (NB)
29 | % The dim() function returned n*(n-1)/2 instead of n*(n+1)/2.
30 | % Implemented proper parallel transport from Sra and Hosseini (not
31 | % used by default).
32 | % Also added symmetrization in exp and log (to be sure).
33 | %
34 | % Review
35 | % ehess2rhess added 2*symm(eta*symm(egrad)*X) and then subtracted
36 | % symm(eta*symm(egrad)*X); it now computes the net term once.
37 |
38 |     symm = @(X) .5*(X+X');
39 |
40 |     M.name = @() sprintf('Symmetric positive definite geometry of %dx%d matrices', n, n);
41 |
42 |     M.dim = @() n*(n+1)/2;
43 |
44 |     % The metric is the natural bi-invariant metric of the positive
45 |     % definite cone: <eta, zeta>_X = trace((X\eta) * (X\zeta)).
46 |     M.inner = @(X, eta, zeta) trace( (X\eta) * (X\zeta) );
47 |
48 |     % Notice that X\eta is *not* symmetric in general.
49 |     M.norm = @(X, eta) sqrt(trace((X\eta)^2));
50 |
51 |     % Same here: X\Y is not symmetric in general. There should be no need
52 |     % to take the real part, but rounding errors may cause a small
53 |     % imaginary part to appear, so we discard it.
54 |     M.dist = @(X, Y) sqrt(real(trace((logm(X\Y))^2)));
55 |
56 |     M.typicaldist = @() sqrt(n*(n+1)/2);
57 |
58 |     % Riemannian gradient for this metric: X*sym(egrad)*X.
59 |     M.egrad2rgrad = @egrad2rgrad;
60 |     function eta = egrad2rgrad(X, eta)
61 |         eta = X*symm(eta)*X;
62 |     end
63 |
64 |     % Riemannian Hessian along eta. Differentiating the Riemannian
65 |     % gradient X*sym(egrad)*X gives X*sym(ehess)*X + 2*sym(eta*sym(egrad)*X);
66 |     % the correction for the non-constant metric subtracts
67 |     % sym(eta*sym(egrad)*X), leaving a single such term.
68 |     M.ehess2rhess = @ehess2rhess;
69 |     function Hess = ehess2rhess(X, egrad, ehess, eta)
70 |         Hess = X*symm(ehess)*X + symm(eta*symm(egrad)*X);
71 |     end
72 |
73 |     M.proj = @(X, eta) symm(eta);
74 |
75 |     M.tangent = M.proj;
76 |     M.tangent2ambient = @(X, eta) eta;
77 |
78 |     M.retr = @exponential;
79 |
80 |     M.exp = @exponential;
81 |     function Y = exponential(X, eta, t)
82 |         if nargin < 3
83 |             t = 1.0;
84 |         end
85 |         % The symm() and real() calls are mathematically not necessary but
86 |         % are numerically necessary.
87 |         Y = symm(X*real(expm(X\(t*eta))));
88 |     end
89 |
90 |     M.log = @logarithm;
91 |     function H = logarithm(X, Y)
92 |         % Same remark regarding the calls to symm() and real().
93 |         H = symm(X*real(logm(X\Y)));
94 |     end
95 |
96 |     M.hash = @(X) ['z' hashmd5(X(:))];
97 |
98 |     % Generate a random symmetric positive definite matrix following a
99 |     % certain distribution. The particular choice of a distribution is of
100 |     % course arbitrary, and specific applications might require different
101 |     % ones. Here: eigenvalues 1+rand in (1, 2), eigenvectors from the QR
102 |     % factorization of a Gaussian matrix.
103 |     M.rand = @random;
104 |     function X = random()
105 |         D = diag(1+rand(n, 1));
106 |         [Q, R] = qr(randn(n)); %#ok
107 |         X = Q*D*Q';
108 |     end
109 |
110 |     % Generate a random unit-norm tangent vector at X: a symmetrized
111 |     % Gaussian matrix, normalized in the metric at X.
112 |     M.randvec = @randomvec;
113 |     function eta = randomvec(X)
114 |         eta = symm(randn(n));
115 |         nrm = M.norm(X, eta);
116 |         eta = eta / nrm;
117 |     end
118 |
119 |     M.lincomb = @lincomb;
120 |
121 |     M.zerovec = @(X) zeros(n);
122 |
123 |     % Poor man's vector transport: exploit the fact that all tangent spaces
124 |     % are the set of symmetric matrices, so that the identity is a sort of
125 |     % vector transport. It may perform poorly if the origin and target (X1
126 |     % and X2) are far apart though. This should not be the case for typical
127 |     % optimization algorithms, which perform small steps.
128 |     M.transp = @(X1, X2, eta) eta;
129 |
130 |     % For reference, a proper vector transport is given here, following
131 |     % work by Sra and Hosseini (2014), "Conic geometric optimisation on the
132 |     % manifold of positive definite matrices",
133 |     % http://arxiv.org/abs/1312.1039
134 |     % This will not be used by default. To force the use of this transport,
135 |     % call "M.transp = M.paralleltransp;" on your M returned by the present
136 |     % factory.
137 |     M.paralleltransp = @parallel_transport;
138 |     function zeta = parallel_transport(X, Y, eta)
139 |         E = sqrtm((Y/X));
140 |         zeta = E*eta*E';
141 |     end
142 |
143 |     % vec and mat are not isometries, because of the unusual inner metric.
144 |     M.vec = @(X, U) U(:);
145 |     M.mat = @(X, u) reshape(u, n, n);
146 |     M.vecmatareisometries = @() false;
147 |
148 | end
147 |
148 | % Linear combination of tangent vectors: a1*d1, or a1*d1 + a2*d2 when a
149 | % second coefficient/vector pair is supplied. X is not used.
150 | function d = lincomb(X, a1, d1, a2, d2) %#ok
151 |     switch nargin
152 |         case 3
153 |             d = a1*d1;
154 |         case 5
155 |             d = a1*d1 + a2*d2;
156 |         otherwise
157 |             error('Bad use of sympositivedefinitefactory.lincomb.');
158 |     end
159 | end
158 |
159 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/applyStatsfun.m:
--------------------------------------------------------------------------------
1 | function stats = applyStatsfun(problem, x, storedb, options, stats)
2 | % Apply the statsfun function to a stats structure (for solvers).
3 | %
4 | % function stats = applyStatsfun(problem, x, storedb, options, stats)
5 | %
6 | % If options.statsfun was provided, calls it and returns the (possibly)
7 | % modified stats structure; otherwise stats is returned unchanged.
8 | % statsfun may take 3 inputs, statsfun(problem, x, stats), or 4 inputs,
9 | % statsfun(problem, x, stats, store), where store is obtained through
10 | % getStore(problem, x, storedb). With any other number of inputs, a
11 | % 'manopt:statsfun' warning is issued and stats is left unchanged.
12 | %
13 | % See also:
14 |
15 | % This file is part of Manopt: www.manopt.org.
16 | % Original author: Nicolas Boumal, April 3, 2013.
17 | % Contributors:
18 | % Change log:
19 |
20 |     % Nothing to do if the user did not supply a statsfun.
21 |     if ~isfield(options, 'statsfun')
22 |         return;
23 |     end
24 |
25 |     % On Octave the number of inputs is not queried: the 4-input form is
26 |     % assumed.
27 |     if exist('OCTAVE_VERSION', 'builtin')
28 |         narg = 4;
29 |     else
30 |         narg = nargin(options.statsfun);
31 |     end
32 |
33 |     if narg == 3
34 |         stats = options.statsfun(problem, x, stats);
35 |     elseif narg == 4
36 |         store = getStore(problem, x, storedb);
37 |         stats = options.statsfun(problem, x, stats, store);
38 |     else
39 |         warning('manopt:statsfun', ...
40 |                 'statsfun unused: wrong number of inputs');
41 |     end
42 |
43 | end
39 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetCost.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetCost(problem)
2 | % Checks whether the cost function can be computed for a problem structure.
3 | %
4 | % function candoit = canGetCost(problem)
5 | %
6 | % Returns true if the problem structure has a 'cost' or a 'costgrad'
7 | % field, false otherwise.
8 | %
9 | % See also: getCost canGetDirectionalDerivative canGetGradient canGetHessian
10 |
11 | % This file is part of Manopt: www.manopt.org.
12 | % Original author: Nicolas Boumal, Dec. 30, 2012.
13 | % Contributors:
14 | % Change log:
15 |
16 |
17 |     % Either field is enough to evaluate the cost.
18 |     candoit = any(isfield(problem, {'cost', 'costgrad'}));
19 |
20 | end
20 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetDirectionalDerivative.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetDirectionalDerivative(problem)
2 | % Checks whether dir. derivatives can be computed for a problem structure.
3 | %
4 | % function candoit = canGetDirectionalDerivative(problem)
5 | %
6 | % Returns true if the problem has a 'diff' field; otherwise returns the
7 | % result of canGetGradient(problem).
8 | %
9 | % See also: canGetCost canGetGradient canGetHessian
10 |
11 | % This file is part of Manopt: www.manopt.org.
12 | % Original author: Nicolas Boumal, Dec. 30, 2012.
13 | % Contributors:
14 | % Change log:
15 |
16 |     if isfield(problem, 'diff')
17 |         candoit = true;
18 |     else
19 |         candoit = canGetGradient(problem);
20 |     end
21 |
22 | end
19 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetEuclideanGradient.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetEuclideanGradient(problem)
2 | % Checks whether the Euclidean gradient can be computed for a problem.
3 | %
4 | % function candoit = canGetEuclideanGradient(problem)
5 | %
6 | % Returns true if the problem structure has an 'egrad' field, false
7 | % otherwise.
8 | %
9 | % See also: canGetGradient getEuclideanGradient
10 |
11 | % This file is part of Manopt: www.manopt.org.
12 | % Original author: Nicolas Boumal, Dec. 30, 2012.
13 | % Contributors:
14 | % Change log:
15 |
16 |
17 |     required_field = 'egrad';
18 |     candoit = isfield(problem, required_field);
19 |
20 | end
20 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetGradient.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetGradient(problem)
2 | % Checks whether the gradient can be computed for a problem structure.
3 | %
4 | % function candoit = canGetGradient(problem)
5 | %
6 | % Returns true if the problem has a 'grad' or 'costgrad' field, or if
7 | % canGetEuclideanGradient(problem) is true; false otherwise.
8 | %
9 | % See also: canGetCost canGetDirectionalDerivative canGetHessian
10 |
11 | % This file is part of Manopt: www.manopt.org.
12 | % Original author: Nicolas Boumal, Dec. 30, 2012.
13 | % Contributors:
14 | % Change log:
15 |
16 |     % The Euclidean-gradient check is only reached if neither direct
17 |     % gradient field is present (short-circuit ||).
18 |     candoit = any(isfield(problem, {'grad', 'costgrad'})) || ...
19 |               canGetEuclideanGradient(problem);
20 |
21 | end
20 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetHessian.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetHessian(problem)
2 | % Checks whether the Hessian can be computed for a problem structure.
3 | %
4 | % function candoit = canGetHessian(problem)
5 | %
6 | % Returns true if the problem has a 'hess' field, or an 'ehess' field
7 | % together with a Euclidean gradient (canGetEuclideanGradient); false
8 | % otherwise. If 'ehess' is given without a Euclidean gradient, a
9 | % 'manopt:canGetHessian' warning is issued.
10 | %
11 | % See also: canGetCost canGetDirectionalDerivative canGetGradient
12 |
13 | % This file is part of Manopt: www.manopt.org.
14 | % Original author: Nicolas Boumal, Dec. 30, 2012.
15 | % Contributors:
16 | % Change log:
17 |
18 |     if isfield(problem, 'hess')
19 |         candoit = true;
20 |     elseif isfield(problem, 'ehess')
21 |         candoit = canGetEuclideanGradient(problem);
22 |         if ~candoit
23 |             warning('manopt:canGetHessian', ...
24 |                 ['If the Hessian is supplied as a Euclidean Hessian (ehess), ' ...
25 |                  'then the Euclidean gradient must also be supplied (egrad).']);
26 |         end
27 |     else
28 |         candoit = false;
29 |     end
30 |
31 | end
27 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetLinesearch.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetLinesearch(problem)
2 | % Checks whether the problem structure can give a line-search a hint.
3 | %
4 | % function candoit = canGetLinesearch(problem)
5 | %
6 | % Returns true if the problem description includes a mechanism (a
7 | % 'linesearch' field) to give line-search algorithms a hint as to "how far
8 | % to look", false otherwise.
9 | %
10 | % See also: getLinesearch
11 |
12 | % This file is part of Manopt: www.manopt.org.
13 | % Original author: Nicolas Boumal, July 17, 2014.
14 | % Contributors:
15 | % Change log:
16 |
17 |
18 |     hint_field = 'linesearch';
19 |     candoit = isfield(problem, hint_field);
20 |
21 | end
20 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/canGetPrecon.m:
--------------------------------------------------------------------------------
1 | function candoit = canGetPrecon(problem)
2 | % Checks whether a preconditioner was specified in the problem description.
3 | %
4 | % function candoit = canGetPrecon(problem)
5 | %
6 | % Returns true if the problem has a 'precon' field, false otherwise.
7 | % Notice that even if this function returns false, it is still possible
8 | % to call getPrecon, as the default preconditioner is simply the identity
9 | % operator. This check function is mostly useful to tell whether that
10 | % default preconditioner will be in use or not.
11 | %
12 | % See also: canGetDirectionalDerivative canGetGradient canGetHessian
13 |
14 | % This file is part of Manopt: www.manopt.org.
15 | % Original author: Nicolas Boumal, July 3, 2013.
16 | % Contributors:
17 | % Change log:
18 |
19 |     precon_field = 'precon';
20 |     candoit = isfield(problem, precon_field);
21 |
22 | end
22 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getApproxHessian.m:
--------------------------------------------------------------------------------
1 | function [approxhess, storedb] = getApproxHessian(problem, x, d, storedb)
2 | % Computes an approximation of the Hessian of the cost fun. at x along d.
3 | %
4 | % function [approxhess, storedb] = getApproxHessian(problem, x, d, storedb)
5 | %
6 | % Returns an approximation of the Hessian at x along d of the cost function
7 | % described in the problem structure. The cache database storedb is passed
8 | % along, possibly modified and returned in the process.
9 | %
10 | % If problem.approxhess is given, it is called either as approxhess(x, d)
11 | % or as [approxhess, store] = approxhess(x, d, store); any other number of
12 | % inputs throws 'manopt:getApproxHessian:badapproxhess'. If no approximate
13 | % Hessian was furnished, this call is redirected to getHessianFD.
14 | %
15 | % See also: getHessianFD
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 |     if ~isfield(problem, 'approxhess')
24 |         % No approximate Hessian: fall back to a standard FD approximation.
25 |         [approxhess, storedb] = getHessianFD(problem, x, d, storedb);
26 |         return;
27 |     end
28 |
29 |     % Number of inputs approxhess expects. On Octave it is not queried and
30 |     % the 3-input (store-aware) form is assumed.
31 |     if exist('OCTAVE_VERSION', 'builtin')
32 |         narg = 3;
33 |     else
34 |         narg = nargin(problem.approxhess);
35 |     end
36 |
37 |     if narg == 2
38 |         approxhess = problem.approxhess(x, d);
39 |     elseif narg == 3
40 |         % Fetch the store structure associated to x, hand it to
41 |         % approxhess, and save the possibly updated version back.
42 |         store = getStore(problem, x, storedb);
43 |         [approxhess, store] = problem.approxhess(x, d, store);
44 |         storedb = setStore(problem, x, storedb, store);
45 |     else
46 |         up = MException('manopt:getApproxHessian:badapproxhess', ...
47 |                         'approxhess should accept 2 or 3 inputs.');
48 |         throw(up);
49 |     end
50 |
51 | end
56 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getCost.m:
--------------------------------------------------------------------------------
1 | function [cost, storedb] = getCost(problem, x, storedb)
2 | % Computes the cost function at x.
3 | %
4 | % function [cost, storedb] = getCost(problem, x, storedb)
5 | %
6 | % Returns the value at x of the cost function described in the problem
7 | % structure. The cache database storedb is passed along, possibly modified,
8 | % and returned.
9 | %
10 | % The cost is taken from problem.cost if present, otherwise from the first
11 | % output of problem.costgrad. Either may optionally take the store
12 | % structure as an extra last input and return it as an extra last output.
13 | % If neither field exists, an exception manopt:getCost:fail is thrown.
14 | %
15 | % See also: canGetCost
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 | on_octave = exist('OCTAVE_VERSION', 'builtin');
24 |
25 | if isfield(problem, 'cost')
26 |     %% Evaluate through problem.cost.
27 |
28 |     % On Octave, nargin is not queried: the store-aware form is assumed.
29 |     if on_octave
30 |         nin = 2;
31 |     else
32 |         nin = nargin(problem.cost);
33 |     end
34 |
35 |     if nin == 1
36 |         cost = problem.cost(x);
37 |     elseif nin == 2
38 |         % Hand the store of x to the cost function and save what
39 |         % it returns.
40 |         store = getStore(problem, x, storedb);
41 |         [cost, store] = problem.cost(x, store);
42 |         storedb = setStore(problem, x, storedb, store);
43 |     else
44 |         up = MException('manopt:getCost:badcost', ...
45 |                         'cost should accept 1 or 2 inputs.');
46 |         throw(up);
47 |     end
48 |
49 | elseif isfield(problem, 'costgrad')
50 |     %% Evaluate through problem.costgrad, discarding the gradient.
51 |
52 |     if on_octave
53 |         nin = 2;
54 |     else
55 |         nin = nargin(problem.costgrad);
56 |     end
57 |
58 |     if nin == 1
59 |         cost = problem.costgrad(x);
60 |     elseif nin == 2
61 |         % The gradient output must be requested to reach the store
62 |         % (third output); it is then ignored.
63 |         store = getStore(problem, x, storedb);
64 |         [cost, unused, store] = problem.costgrad(x, store); %#ok<ASGLU>
65 |         storedb = setStore(problem, x, storedb, store);
66 |     else
67 |         up = MException('manopt:getCost:badcostgrad', ...
68 |                         'costgrad should accept 1 or 2 inputs.');
69 |         throw(up);
70 |     end
71 |
72 | else
73 |     %% Not enough information to compute the cost.
74 |
75 |     up = MException('manopt:getCost:fail', ...
76 |         ['The problem description is not explicit enough to ' ...
77 |          'compute the cost.']);
78 |     throw(up);
79 |
80 | end
81 |
82 | end
83 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getCostGrad.m:
--------------------------------------------------------------------------------
1 | function [cost, grad, storedb] = getCostGrad(problem, x, storedb)
2 | % Computes the cost function and the gradient at x in one call if possible.
3 | %
4 | % function [cost, grad, storedb] = getCostGrad(problem, x, storedb)
5 | %
6 | % Returns the value at x of the cost function described in the problem
7 | % structure, as well as the gradient at x. The cache database storedb is
8 | % passed along, possibly modified and returned in the process.
9 | %
10 | % If problem.costgrad is defined, both quantities come from a single call
11 | % to it, as [cost, grad] = costgrad(x) or, to use the store structure, as
12 | % [cost, grad, store] = costgrad(x, store). Otherwise, getCost and
13 | % getGradient are called one after the other.
14 | %
15 | % See also: canGetCost canGetGradient getCost getGradient
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 | if isfield(problem, 'costgrad')
24 |     %% Compute the cost/grad pair using costgrad.
25 |
26 |     % On Octave, nargin is not queried: the store-aware form is assumed.
27 |     is_octave = exist('OCTAVE_VERSION', 'builtin');
28 |     if ~is_octave
29 |         narg = nargin(problem.costgrad);
30 |     else
31 |         narg = 2;
32 |     end
33 |
34 |     % Check whether the costgrad function wants to deal with the store
35 |     % structure or not.
36 |     switch narg
37 |         case 1
38 |             [cost, grad] = problem.costgrad(x);
39 |         case 2
40 |             % Obtain, pass along, and save the store structure
41 |             % associated to this point.
42 |             store = getStore(problem, x, storedb);
43 |             [cost, grad, store] = problem.costgrad(x, store);
44 |             storedb = setStore(problem, x, storedb, store);
45 |         otherwise
46 |             up = MException('manopt:getCostGrad:badcostgrad', ...
47 |                 'costgrad should accept 1 or 2 inputs.');
48 |             throw(up);
49 |     end
50 |
51 | else
52 |     %% Revert to calling getCost and getGradient separately.
53 |
54 |     [cost, storedb] = getCost(problem, x, storedb);
55 |     [grad, storedb] = getGradient(problem, x, storedb);
56 |
57 | end
58 |
59 | end
60 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getDirectionalDerivative.m:
--------------------------------------------------------------------------------
1 | function [diff, storedb] = getDirectionalDerivative(problem, x, d, storedb)
2 | % Computes the directional derivative of the cost function at x along d.
3 | %
4 | % function [diff, storedb] = getDirectionalDerivative(problem, x, d, storedb)
5 | %
6 | % Returns the derivative at x along d of the cost function described in the
7 | % problem structure. The cache database storedb is passed along, possibly
8 | % modified, and returned.
9 | %
10 | % Sources, in order of preference:
11 | %   1. problem.diff, called as diff(x, d) or [diff, store] = diff(x, d, store);
12 | %   2. the inner product problem.M.inner(x, grad, d), where grad is the
13 | %      gradient at x, when canGetGradient says it is available;
14 | %   3. otherwise, an exception manopt:getDirectionalDerivative:fail.
15 | %
16 | % See also: getGradient canGetDirectionalDerivative
17 |
18 | % This file is part of Manopt: www.manopt.org.
19 | % Original author: Nicolas Boumal, Dec. 30, 2012.
20 | % Contributors:
21 | % Change log:
22 |
23 |
24 | if isfield(problem, 'diff')
25 |     %% Use the user-supplied diff function.
26 |
27 |     % On Octave, nargin is not queried: the store-aware form is assumed.
28 |     if exist('OCTAVE_VERSION', 'builtin')
29 |         nin = 3;
30 |     else
31 |         nin = nargin(problem.diff);
32 |     end
33 |
34 |     if nin == 2
35 |         diff = problem.diff(x, d);
36 |     elseif nin == 3
37 |         % Thread the store of x through diff and save it back.
38 |         store = getStore(problem, x, storedb);
39 |         [diff, store] = problem.diff(x, d, store);
40 |         storedb = setStore(problem, x, storedb, store);
41 |     else
42 |         up = MException('manopt:getDirectionalDerivative:baddiff', ...
43 |                         'diff should accept 2 or 3 inputs.');
44 |         throw(up);
45 |     end
46 |
47 | elseif canGetGradient(problem)
48 |     %% Derive it from the gradient: inner product of the gradient with d.
49 |
50 |     [grad, storedb] = getGradient(problem, x, storedb);
51 |     diff = problem.M.inner(x, grad, d);
52 |
53 | else
54 |     %% Not enough information to compute directional derivatives.
55 |
56 |     up = MException('manopt:getDirectionalDerivative:fail', ...
57 |         ['The problem description is not explicit enough to ' ...
58 |          'compute the directional derivatives of f.']);
59 |     throw(up);
60 |
61 | end
62 |
63 | end
64 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getEuclideanGradient.m:
--------------------------------------------------------------------------------
1 | function [egrad, storedb] = getEuclideanGradient(problem, x, storedb)
2 | % Computes the Euclidean gradient of the cost function at x.
3 | %
4 | % function [egrad, storedb] = getEuclideanGradient(problem, x, storedb)
5 | %
6 | % Returns the Euclidean gradient at x of the cost function described in the
7 | % problem structure. The cache database storedb is passed along, possibly
8 | % modified, and returned.
9 | %
10 | % Because computing the Hessian from the Euclidean Hessian needs the
11 | % Euclidean gradient every time, this function caches it automatically
12 | % when problem.egrad takes only x: the value is kept in the field egrad__
13 | % of the store structure of x and reused by later calls at the same point
14 | % (as long as that store is kept). The hashing function is therefore used
15 | % even if the user never touches the store, which may cost some time on
16 | % small problems but pays off when egrad is expensive. If egrad instead
17 | % accepts (and returns) the store structure, automatic caching is turned
18 | % off and the user controls caching; hashing is used in both cases.
19 | %
20 | % To avoid hashing (and hence caching) altogether, define grad instead of
21 | % egrad, without store support, and call problem.M.egrad2rgrad manually.
22 | %
23 | % See also: getGradient canGetGradient canGetEuclideanGradient
24 |
25 | % This file is part of Manopt: www.manopt.org.
26 | % Original author: Nicolas Boumal, July 9, 2013.
27 | % Contributors:
28 | % Change log:
29 |
30 |
31 | if ~isfield(problem, 'egrad')
32 |     %% Not enough information to compute the Euclidean gradient.
33 |     up = MException('manopt:getEuclideanGradient:fail', ...
34 |         ['The problem description is not explicit enough to ' ...
35 |          'compute the Euclidean gradient of the cost.']);
36 |     throw(up);
37 | end
38 |
39 | %% Use the user-supplied egrad function.
40 |
41 | % On Octave, nargin is not queried: the store-aware form is assumed.
42 | if exist('OCTAVE_VERSION', 'builtin')
43 |     nin = 2;
44 | else
45 |     nin = nargin(problem.egrad);
46 | end
47 |
48 | if nin == 1
49 |     % egrad ignores the store: cache its result ourselves under the
50 |     % egrad__ field, computing it only if it is not there yet.
51 |     store = getStore(problem, x, storedb);
52 |     if ~isfield(store, 'egrad__')
53 |         store.egrad__ = problem.egrad(x);
54 |         storedb = setStore(problem, x, storedb, store);
55 |     end
56 |     egrad = store.egrad__;
57 | elseif nin == 2
58 |     % egrad manages the store itself: no automatic caching here.
59 |     store = getStore(problem, x, storedb);
60 |     [egrad, store] = problem.egrad(x, store);
61 |     storedb = setStore(problem, x, storedb, store);
62 | else
63 |     up = MException('manopt:getEuclideanGradient:badegrad', ...
64 |                     'egrad should accept 1 or 2 inputs.');
65 |     throw(up);
66 | end
67 |
68 | end
69 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getGlobalDefaults.m:
--------------------------------------------------------------------------------
1 | function opts = getGlobalDefaults()
2 | % Returns a structure with default option values for Manopt.
3 | %
4 | % function opts = getGlobalDefaults()
5 | %
6 | % Returns a structure opts holding the global default options, such as
7 | % the verbosity level. Global defaults are typically overridden by solver
8 | % defaults, which are in turn overridden by user-specified options.
9 | % See the online Manopt documentation for details on options.
10 | %
11 | % See also: mergeOptions
12 |
13 | % This file is part of Manopt: www.manopt.org.
14 | % Original author: Nicolas Boumal, Dec. 30, 2012.
15 | % Contributors:
16 | % Change log:
17 |
18 |
19 | % Verbosity level: 0 means no output at all. The higher the value, the
20 | % more information is printed / displayed while a solver runs.
21 | verbosity_level = 3;
22 |
23 | % When true, solvers may perform additional computations and output
24 | % debugging information while they run.
25 | debug_mode = false;
26 |
27 | % Maximum number of store structures kept in the cache. Setting it to 0
28 | % does not disable caching: the cache is then emptied every time the
29 | % solver calls purgeStoredb (typically at each iteration).
30 | store_depth = 20;
31 |
32 | % Maximum amount of time a solver may run, in seconds.
33 | max_time = inf;
34 |
35 | % Field order matches the order in which the options are listed above.
36 | opts = struct('verbosity', verbosity_level, ...
37 |               'debug', debug_mode, ...
38 |               'storedepth', store_depth, ...
39 |               'maxtime', max_time);
40 |
41 | end
42 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getGradient.m:
--------------------------------------------------------------------------------
1 | function [grad, storedb] = getGradient(problem, x, storedb)
2 | % Computes the gradient of the cost function at x.
3 | %
4 | % function [grad, storedb] = getGradient(problem, x, storedb)
5 | %
6 | % Returns the gradient at x of the cost function described in the problem
7 | % structure. The cache database storedb is passed along, possibly modified
8 | % and returned in the process.
9 | %
10 | % The gradient is obtained, in order of preference, from problem.grad,
11 | % from the second output of problem.costgrad, or by converting the
12 | % Euclidean gradient (see getEuclideanGradient) with problem.M.egrad2rgrad.
13 | % grad and costgrad may optionally accept the store structure as an extra
14 | % last input and return it as an extra last output. If none of these
15 | % sources is available, an exception manopt:getGradient:fail is thrown.
16 | %
17 | % See also: getDirectionalDerivative canGetGradient
18 |
19 | % This file is part of Manopt: www.manopt.org.
20 | % Original author: Nicolas Boumal, Dec. 30, 2012.
21 | % Contributors:
22 | % Change log:
23 |
24 |
25 | if isfield(problem, 'grad')
26 |     %% Compute the gradient using grad.
27 |
28 |     % The store-handling convention is decided by the signature of grad
29 |     % itself; querying another function (such as cost) would pick the
30 |     % wrong calling form, or fail if that field does not exist. On
31 |     % Octave, nargin is not queried: the store-aware form is assumed.
32 |     is_octave = exist('OCTAVE_VERSION', 'builtin');
33 |     if ~is_octave
34 |         narg = nargin(problem.grad);
35 |     else
36 |         narg = 2;
37 |     end
38 |
39 |     % Check whether the gradient function wants to deal with the store
40 |     % structure or not.
41 |     switch narg
42 |         case 1
43 |             grad = problem.grad(x);
44 |         case 2
45 |             % Obtain, pass along, and save the store structure
46 |             % associated to this point.
47 |             store = getStore(problem, x, storedb);
48 |             [grad, store] = problem.grad(x, store);
49 |             storedb = setStore(problem, x, storedb, store);
50 |         otherwise
51 |             up = MException('manopt:getGradient:badgrad', ...
52 |                 'grad should accept 1 or 2 inputs.');
53 |             throw(up);
54 |     end
55 |
56 | elseif isfield(problem, 'costgrad')
57 |     %% Compute the gradient using costgrad.
58 |
59 |     % On Octave, nargin is not queried: the store-aware form is assumed.
60 |     is_octave = exist('OCTAVE_VERSION', 'builtin');
61 |     if ~is_octave
62 |         narg = nargin(problem.costgrad);
63 |     else
64 |         narg = 2;
65 |     end
66 |
67 |     % Check whether the costgrad function wants to deal with the store
68 |     % structure or not. The cost (first output) is discarded.
69 |     switch narg
70 |         case 1
71 |             [unused, grad] = problem.costgrad(x); %#ok
72 |         case 2
73 |             % Obtain, pass along, and save the store structure
74 |             % associated to this point.
75 |             store = getStore(problem, x, storedb);
76 |             [unused, grad, store] = problem.costgrad(x, store); %#ok
77 |             storedb = setStore(problem, x, storedb, store);
78 |         otherwise
79 |             up = MException('manopt:getGradient:badcostgrad', ...
80 |                 'costgrad should accept 1 or 2 inputs.');
81 |             throw(up);
82 |     end
83 |
84 | elseif canGetEuclideanGradient(problem)
85 |     %% Compute the gradient using the Euclidean gradient.
86 |
87 |     [egrad, storedb] = getEuclideanGradient(problem, x, storedb);
88 |     grad = problem.M.egrad2rgrad(x, egrad);
89 |
90 | else
91 |     %% Abandon computing the gradient.
92 |
93 |     up = MException('manopt:getGradient:fail', ...
94 |         ['The problem description is not explicit enough to ' ...
95 |          'compute the gradient of the cost.']);
96 |     throw(up);
97 |
98 | end
99 |
100 | end
101 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getHessian.m:
--------------------------------------------------------------------------------
1 | function [hess, storedb] = getHessian(problem, x, d, storedb)
2 | % Computes the Hessian of the cost function at x along d.
3 | %
4 | % function [hess, storedb] = getHessian(problem, x, d, storedb)
5 | %
6 | % Returns the Hessian at x along d of the cost function described in the
7 | % problem structure. The cache database storedb is passed along, possibly
8 | % modified, and returned.
9 | %
10 | % Sources, in order of preference:
11 | %   1. problem.hess, as hess(x, d) or [hess, store] = hess(x, d, store);
12 | %   2. problem.ehess together with the Euclidean gradient, converted to
13 | %      the Riemannian Hessian by problem.M.ehess2rhess;
14 | %   3. getApproxHessian, silently (no warning is issued).
15 | % If no approximation can be built either, an exception is thrown. To
16 | % know whether an exact Hessian is available (e.g., to warn the user),
17 | % call canGetHessian.
18 | %
19 | % See also: getPrecon getApproxHessian canGetHessian
20 |
21 | % This file is part of Manopt: www.manopt.org.
22 | % Original author: Nicolas Boumal, Dec. 30, 2012.
23 | % Contributors:
24 | % Change log:
25 |
26 | if isfield(problem, 'hess')
27 |     %% Exact Riemannian Hessian supplied by the user.
28 |
29 |     % On Octave, nargin is not queried: the store-aware form is assumed.
30 |     if exist('OCTAVE_VERSION', 'builtin')
31 |         nin = 3;
32 |     else
33 |         nin = nargin(problem.hess);
34 |     end
35 |
36 |     if nin == 2
37 |         hess = problem.hess(x, d);
38 |     elseif nin == 3
39 |         store = getStore(problem, x, storedb);
40 |         [hess, store] = problem.hess(x, d, store);
41 |         storedb = setStore(problem, x, storedb, store);
42 |     else
43 |         up = MException('manopt:getHessian:badhess', ...
44 |                         'hess should accept 2 or 3 inputs.');
45 |         throw(up);
46 |     end
47 |
48 | elseif isfield(problem, 'ehess') && canGetEuclideanGradient(problem)
49 |     %% Euclidean Hessian supplied: convert it to the Riemannian one.
50 |
51 |     % ehess2rhess needs the Euclidean gradient; fetch it first.
52 |     [egrad, storedb] = getEuclideanGradient(problem, x, storedb);
53 |
54 |     if exist('OCTAVE_VERSION', 'builtin')
55 |         nin = 3;
56 |     else
57 |         nin = nargin(problem.ehess);
58 |     end
59 |
60 |     if nin == 2
61 |         ehess = problem.ehess(x, d);
62 |     elseif nin == 3
63 |         % The store is read only now, after getEuclideanGradient may
64 |         % have updated the database.
65 |         store = getStore(problem, x, storedb);
66 |         [ehess, store] = problem.ehess(x, d, store);
67 |         storedb = setStore(problem, x, storedb, store);
68 |     else
69 |         up = MException('manopt:getHessian:badehess', ...
70 |                         'ehess should accept 2 or 3 inputs.');
71 |         throw(up);
72 |     end
73 |
74 |     hess = problem.M.ehess2rhess(x, egrad, ehess, d);
75 |
76 | else
77 |     %% No exact Hessian available: approximate it.
78 |
79 |     [hess, storedb] = getApproxHessian(problem, x, d, storedb);
80 |
81 | end
82 |
83 | end
84 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getHessianFD.m:
--------------------------------------------------------------------------------
1 | function [hessfd, storedb] = getHessianFD(problem, x, d, storedb)
2 | % Computes an approx. of the Hessian w/ finite differences of the gradient.
3 | %
4 | % function [hessfd, storedb] = getHessianFD(problem, x, d, storedb)
5 | %
6 | % Return a finite difference approximation of the Hessian at x along d of
7 | % the cost function described in the problem structure. The cache database
8 | % storedb is passed along, possibly modified and returned in the process.
9 | % The finite difference is based on computations of the gradient:
10 | %
11 | %   hessfd = (transp(x1 -> x, grad f(x1)) - grad f(x)) / c,
12 | %   with x1 = problem.M.retr(x, d, c) and c = epsilon / norm(d).
13 | %
14 | % If the norm of d is below eps, the zero vector at x is returned.
15 | % If the gradient cannot be computed, an exception is thrown.
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 | if ~canGetGradient(problem)
24 |     up = MException('manopt:getHessianFD:nogradient', ...
25 |         'getHessianFD requires the gradient to be computable.');
26 |     throw(up);
27 | end
28 |
29 | % The norm of d is needed both for the small-step check and to scale the
30 | % step below: compute it once.
31 | norm_d = problem.M.norm(x, d);
32 |
33 | % First, check whether the step d is not too small.
34 | if norm_d < eps
35 |     hessfd = problem.M.zerovec(x);
36 |     return;
37 | end
38 |
39 | % Parameter: how far do we look?
40 | % TODO: this parameter should be tunable by the users.
41 | epsilon = 1e-4;
42 |
43 | % TODO: to make this approximation of the Hessian radially linear, that
44 | %       is, to have that HessianFD(x, a*d) = a*HessianFD(x, d) for all
45 | %       points x, tangent vectors d and real scalars a, we need to pay
46 | %       special attention to the case of a < 0. This requires a notion of
47 | %       "sign" for vectors d.
48 | %       If vectors are uniquely represented by a matrix (which is the case
49 | %       for Riemannian submanifolds of the space of matrices), then this
50 | %       function is appropriate:
51 | %           sg = sign(d(find(d(:), 1, 'first')));
52 | %       But it is more difficult to build such a function in general. For
53 | %       now, we ignore this difficulty and let the sign always be +1. This
54 | %       hardly impacts the actual performances, but may be an obstacle for
55 | %       theoretical analysis.
56 | sg = 1;
57 |
58 | % Step size along d, chosen so that c*d has norm epsilon.
59 | c = epsilon*sg/norm_d;
60 |
61 | % Compute the gradient at the current point.
62 | [grad0, storedb] = getGradient(problem, x, storedb);
63 |
64 | % Compute a point a little further along d and the gradient there.
65 | x1 = problem.M.retr(x, d, c);
66 | [grad1, storedb] = getGradient(problem, x1, storedb);
67 |
68 | % Transport grad1 back from x1 to x.
69 | grad1 = problem.M.transp(x1, x, grad1);
70 |
71 | % Return the finite difference of them.
72 | hessfd = problem.M.lincomb(x, 1/c, grad1, -1/c, grad0);
73 |
74 | end
75 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getLinesearch.m:
--------------------------------------------------------------------------------
1 | function [t, storedb] = getLinesearch(problem, x, d, storedb)
2 | % Returns a hint for line-search algorithms.
3 | %
4 | % function [t, storedb] = getLinesearch(problem, x, d, storedb)
5 | %
6 | % For a line-search problem at x along the tangent direction d, computes
7 | % and returns t such that retracting t*d at x yields a good point around
8 | % which to look for a line-search solution. In other words, t is a hint
9 | % as to "how far to look" along the line.
10 | %
11 | % The hint comes from problem.linesearch, called as linesearch(x, d) or
12 | % [t, store] = linesearch(x, d, store). If that field is absent, an
13 | % exception manopt:getLinesearch:fail is thrown.
14 | %
15 | % The cache database storedb is passed along, possibly modified, and
16 | % returned.
17 | %
18 | % See also: canGetLinesearch
19 |
20 | % This file is part of Manopt: www.manopt.org.
21 | % Original author: Nicolas Boumal, July 17, 2014.
22 | % Contributors:
23 | % Change log:
24 |
25 |
26 | if ~isfield(problem, 'linesearch')
27 |     %% Not enough information to compute a line-search hint.
28 |     up = MException('manopt:getLinesearch:fail', ...
29 |         ['The problem description is not explicit enough to ' ...
30 |          'compute a line-search hint.']);
31 |     throw(up);
32 | end
33 |
34 | %% Use the user-supplied linesearch function.
35 |
36 | % On Octave, nargin is not queried: the store-aware form is assumed.
37 | if exist('OCTAVE_VERSION', 'builtin')
38 |     nin = 3;
39 | else
40 |     nin = nargin(problem.linesearch);
41 | end
42 |
43 | if nin == 2
44 |     t = problem.linesearch(x, d);
45 | elseif nin == 3
46 |     % Thread the store of x through linesearch and save it back.
47 |     store = getStore(problem, x, storedb);
48 |     [t, store] = problem.linesearch(x, d, store);
49 |     storedb = setStore(problem, x, storedb, store);
50 | else
51 |     up = MException('manopt:getLinesearch:badfun', ...
52 |                     'linesearch should accept 2 or 3 inputs.');
53 |     throw(up);
54 | end
55 |
56 | end
57 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getPrecon.m:
--------------------------------------------------------------------------------
1 | function [Pd, storedb] = getPrecon(problem, x, d, storedb)
2 | % Applies the preconditioner for the Hessian of the cost at x along d.
3 | %
4 | % function [Pd, storedb] = getPrecon(problem, x, d, storedb)
5 | %
6 | % Returns as Pd the result of applying the Hessian preconditioner to the
7 | % tangent vector d at point x. If no preconditioner is specified, Pd = d
8 | % (identity).
9 | %
10 | % problem.precon may be called either as precon(x, d) or, to use the
11 | % store structure, as [Pd, store] = precon(x, d, store). The cache
12 | % database storedb is passed along, possibly modified and returned in the
13 | % process.
14 | %
15 | % See also: getHessian
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 | if isfield(problem, 'precon')
24 |     %% Compute the preconditioning using precon.
25 |
26 |     % On Octave, nargin is not queried: the store-aware form is assumed.
27 |     is_octave = exist('OCTAVE_VERSION', 'builtin');
28 |     if ~is_octave
29 |         narg = nargin(problem.precon);
30 |     else
31 |         narg = 3;
32 |     end
33 |
34 |     % Check whether the precon function wants to deal with the store
35 |     % structure or not.
36 |     switch narg
37 |         case 2
38 |             Pd = problem.precon(x, d);
39 |         case 3
40 |             % Obtain, pass along, and save the store structure
41 |             % associated to this point.
42 |             store = getStore(problem, x, storedb);
43 |             [Pd, store] = problem.precon(x, d, store);
44 |             storedb = setStore(problem, x, storedb, store);
45 |         otherwise
46 |             up = MException('manopt:getPrecon:badprecon', ...
47 |                 'precon should accept 2 or 3 inputs.');
48 |             throw(up);
49 |     end
50 |
51 | else
52 |     %% No preconditioner provided, so just use the identity.
53 |
54 |     Pd = d;
55 |
56 | end
57 |
58 | end
59 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/getStore.m:
--------------------------------------------------------------------------------
1 | function store = getStore(problem, x, storedb)
2 | % Extracts a store struct. pertaining to a point from the storedb database.
3 | %
4 | % function store = getStore(problem, x, storedb)
5 | %
6 | % Looks up, in the storedb database (a structure whose field names are
7 | % the keys produced by problem.M.hash), the store structure recorded for
8 | % the point x. If there is no record for x, an empty structure (struct())
9 | % is returned.
10 | %
11 | % See also: setStore purgeStoredb
12 |
13 | % This file is part of Manopt: www.manopt.org.
14 | % Original author: Nicolas Boumal, Dec. 30, 2012.
15 | % Contributors:
16 | % Change log:
17 |
18 | % Field name (key) under which the store of x would be filed.
19 | fieldkey = problem.M.hash(x);
20 |
21 | % Return the recorded store if there is one, an empty structure otherwise.
22 | store = struct();
23 | if isfield(storedb, fieldkey)
24 |     store = storedb.(fieldkey);
25 | end
26 |
27 | end
28 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/hashmd5.m:
--------------------------------------------------------------------------------
1 | function h = hashmd5(inp)
2 | % Computes the MD5 hash of input data.
3 | %
4 | % function h = hashmd5(inp)
5 | %
6 | % Returns a string containing the MD5 hash of the input variable, written
7 | % as lowercase hexadecimal digits. The input variable may be of any class
8 | % that can be typecast to uint8 format, which is fairly non-restrictive;
9 | % char and logical inputs are converted with uint8 instead.
10 |
11 | % This file is part of Manopt: www.manopt.org.
12 | % This code is a stripped version of more general hashing code by
13 | % Michael Kleder, Nov 2005.
14 | % Change log:
15 | %   Aug. 8, 2013 (NB) : Made x a static (persistent) variable, in the hope
16 | %                       it will speed it up. Furthermore, the function is
17 | %                       now Octave compatible.
18 |
19 | is_octave = exist('OCTAVE_VERSION', 'builtin');
20 |
21 | % On MATLAB, a single Java MD5 digester is created once and reused.
22 | persistent x;
23 | if isempty(x) && ~is_octave
24 |     x = java.security.MessageDigest.getInstance('MD5');
25 | end
26 |
27 | inp = inp(:);
28 | % Convert strings and logicals into uint8 format.
29 | if ischar(inp) || islogical(inp)
30 |     inp = uint8(inp);
31 | else % Convert everything else into uint8 format without loss of data.
32 |     inp = typecast(inp, 'uint8');
33 | end
34 |
35 | % Create hash.
36 | if ~is_octave
37 |     % digest() completes the hash and resets the digester for the next
38 |     % call.
39 |     x.update(inp);
40 |     h = typecast(x.digest, 'uint8');
41 |     h = dec2hex(h)';
42 |     % dec2hex pads all bytes to the width of the widest one. In the
43 |     % remote case where every byte is below 16 (a single hex digit),
44 |     % the rows come out one digit wide: pad with leading zeros.
45 |     if size(h, 1) == 1
46 |         h = [repmat('0', [1 size(h, 2)]); h];
47 |     end
48 |     h = lower(h(:)');
49 | elseif exist('hash')
50 |     % Newer Octave versions provide hash(); md5sum is deprecated there.
51 |     h = hash('md5', char(inp'));
52 | else
53 |     % Older Octave versions only provide md5sum.
54 |     h = md5sum(char(inp'), true);
55 | end
56 |
57 | end
58 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/mergeOptions.m:
--------------------------------------------------------------------------------
1 | function opts = mergeOptions(opts1, opts2)
2 | % Merges two options structures with one having precedence over the other.
3 | %
4 | % function opts = mergeOptions(opts1, opts2)
5 | %
6 | % input: opts1 and opts2 are two structures (an empty input is treated
7 | %        as a structure with no fields).
8 | % output: opts is a structure containing all fields of opts1 and opts2.
9 | % Whenever a field is present in both opts1 and opts2, the value from
10 | % opts2 is the one kept.
11 | %
12 | % Typically, opts1 holds default options and opts2 holds user-specified
13 | % options that override the defaults.
14 | %
15 | % See also: getGlobalDefaults
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 |
22 |
23 | % Start from the lower-priority structure (or from an empty one).
24 | if isempty(opts1)
25 |     opts = struct();
26 | else
27 |     opts = opts1;
28 | end
29 |
30 | % Copy every field of the higher-priority structure on top of it.
31 | if ~isempty(opts2)
32 |     names = fieldnames(opts2);
33 |     for k = 1 : numel(names)
34 |         opts.(names{k}) = opts2.(names{k});
35 |     end
36 | end
37 |
38 | end
39 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/purgeStoredb.m:
--------------------------------------------------------------------------------
1 | function storedb = purgeStoredb(storedb, storedepth)
2 | % Makes sure the storedb database does not exceed some maximum size.
3 | %
4 | % function storedb = purgeStoredb(storedb, storedepth)
5 | %
6 | % Trims the store database storedb so that it holds at most storedepth
7 | % elements (store structures). Stores are ranked by the 'lastset__' field
8 | % found in each of them, and the oldest ones are removed first. If
9 | % storedepth <= 0, the whole database is emptied.
10 | %
11 | % NOTE(review): stores whose lastset__ equals the threshold are all kept,
12 | % so more than storedepth elements survive if counter values repeat.
13 |
14 | % This file is part of Manopt: www.manopt.org.
15 | % Original author: Nicolas Boumal, Dec. 30, 2012.
16 | % Contributors:
17 | % Change log:
18 | % Dec. 6, 2013, NB:
19 | %     Now using a persistent uint32 counter instead of cputime to track
20 | %     the most recently modified stores.
21 |
22 |
23 | % A non-positive depth means that nothing is kept.
24 | if storedepth <= 0
25 |     storedb = struct();
26 |     return;
27 | end
28 |
29 | names = fieldnames(storedb);
30 |
31 | % Only act when the database holds more stores than allowed.
32 | if numel(names) > storedepth
33 |
34 |     % Last-set counter of every store, as uint32: a higher value means
35 |     % the store was modified more recently.
36 |     stamps = cellfun(@(name) uint32(storedb.(name).lastset__), names);
37 |
38 |     % The storedepth-th most recent counter is the oldest one kept;
39 |     % every store strictly older than that is removed.
40 |     ranked = sort(stamps, 1, 'descend');
41 |     oldest_kept = ranked(storedepth);
42 |     storedb = rmfield(storedb, names(stamps < oldest_kept));
43 |
44 | end
45 |
46 | end
47 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/setStore.m:
--------------------------------------------------------------------------------
1 | function storedb = setStore(problem, x, storedb, store)
2 | % Updates the store struct. pertaining to a point in the storedb database.
3 | %
4 | % function storedb = setStore(problem, x, storedb, store)
5 | %
6 | % Files store in the storedb database under the key problem.M.hash(x),
7 | % replacing any previous record for x (or creating one if there was
8 | % none), and returns the updated database. Each stored structure is
9 | % stamped with a lastset__ field holding a counter that grows with every
10 | % call, so that purgeStoredb can tell which stores were updated latest.
11 | %
12 | % The caller must capture the output (this is asserted), since the
13 | % updated database is only available through it.
14 | %
15 | % NOTE(review): the counter is a uint32 and MATLAB integer arithmetic
16 | % saturates, so after intmax('uint32') calls all new stamps are equal.
17 |
18 | % This file is part of Manopt: www.manopt.org.
19 | % Original author: Nicolas Boumal, Dec. 30, 2012.
20 | % Contributors:
21 | % Change log:
22 | % Dec. 6, 2013, NB:
23 | %     Now using a persistent uint32 counter instead of cputime to track
24 | %     the most recently modified stores.
25 |
26 | % Persistent stamp recording the order in which stores are written. It
27 | % is used by purgeStoredb to discard the least recently updated stores
28 | % first.
29 | persistent counter;
30 | if isempty(counter)
31 |     counter = uint32(0);
32 | end
33 |
34 | assert(nargout == 1, ...
35 |        'The output of setStore should replace your storedb.');
36 |
37 | % Stamp the store with the current counter, then file it under the key
38 | % (field name) derived from x.
39 | store.lastset__ = counter;
40 | storedb.(problem.M.hash(x)) = store;
41 | counter = counter + 1;
42 |
43 | end
44 |
--------------------------------------------------------------------------------
/manopt/manopt/privatetools/stoppingcriterion.m:
--------------------------------------------------------------------------------
1 | function [stop reason] = stoppingcriterion(problem, x, options, info, last)
2 | % Checks for standard stopping criteria, as a helper to solvers.
3 | %
4 | % function [stop reason] = stoppingcriterion(problem, x, options, info, last)
5 | %
6 | % Executes standard stopping criterion checks, based on what is defined in
7 | % the info(last) stats structure and in the options structure.
8 | %
9 | % The returned number 'stop' is 0 if none of the stopping criteria
10 | % triggered, and a (strictly) positive integer otherwise. The integer
11 | % identifies which criterion triggered:
12 | % 0 : Nothing triggered;
13 | % 1 : Cost tolerance reached;
14 | % 2 : Gradient norm tolerance reached;
15 | % 3 : Max time exceeded;
16 | % 4 : Max iteration count reached;
17 | % 5 : Maximum number of cost evaluations reached;
18 | % 6 : User defined stopfun criterion triggered.
19 | %
20 | % Criteria are tested in the order listed above and the first one that
21 | % triggers is reported. options.stopfun is only called if none of the
22 | % built-in criteria 1 to 5 triggered.
23 | %
24 | % The output 'reason' is a string describing the triggered event.
25 |
26 | % This file is part of Manopt: www.manopt.org.
27 | % Original author: Nicolas Boumal, Dec. 30, 2012.
28 | % Contributors:
29 | % Change log:
30 | % Criterion 5 (max cost evaluations) now returns immediately, like the
31 | % other built-in criteria, instead of falling through to stopfun, which
32 | % could overwrite its stop code and reason.
33 |
34 |
35 |     stop = 0;
36 |     reason = '';
37 |
38 |     stats = info(last);
39 |
40 |     % Target cost attained
41 |     if isfield(stats, 'cost') && isfield(options, 'tolcost') && ...
42 |        stats.cost <= options.tolcost
43 |         reason = 'Cost tolerance reached. See options.tolcost.';
44 |         stop = 1;
45 |         return;
46 |     end
47 |
48 |     % Target gradient norm attained
49 |     if isfield(stats, 'gradnorm') && isfield(options, 'tolgradnorm') && ...
50 |        stats.gradnorm < options.tolgradnorm
51 |         reason = 'Gradient norm tolerance reached. See options.tolgradnorm.';
52 |         stop = 2;
53 |         return;
54 |     end
55 |
56 |     % Allotted time exceeded
57 |     if isfield(stats, 'time') && isfield(options, 'maxtime') && ...
58 |        stats.time >= options.maxtime
59 |         reason = 'Max time exceeded. See options.maxtime.';
60 |         stop = 3;
61 |         return;
62 |     end
63 |
64 |     % Allotted iteration count exceeded
65 |     if isfield(stats, 'iter') && isfield(options, 'maxiter') && ...
66 |        stats.iter >= options.maxiter
67 |         reason = 'Max iteration count reached. See options.maxiter.';
68 |         stop = 4;
69 |         return;
70 |     end
71 |
72 |     % Allotted function evaluation count exceeded
73 |     if isfield(stats, 'costevals') && isfield(options, 'maxcostevals') && ...
74 |        stats.costevals >= options.maxcostevals
75 |         reason = 'Maximum number of cost evaluations reached. See options.maxcostevals.';
76 |         stop = 5;
77 |         % Return like the other criteria: otherwise a triggering
78 |         % stopfun below would overwrite stop = 5 and its reason.
79 |         return;
80 |     end
81 |
82 |     % Check whether the possibly user defined stopping criterion
83 |     % triggers or not.
84 |     if isfield(options, 'stopfun')
85 |         userstop = options.stopfun(problem, x, info, last);
86 |         if userstop
87 |             reason = 'User defined stopfun criterion triggered. See options.stopfun.';
88 |             stop = 6;
89 |             return;
90 |         end
91 |     end
92 |
93 | end
94 |
--------------------------------------------------------------------------------
/manopt/manopt/solvers/linesearch/linesearch_adaptive.m:
--------------------------------------------------------------------------------
1 | function [stepsize newx storedb lsmem lsstats] = ...
2 | linesearch_adaptive(problem, x, d, f0, df0, options, storedb, lsmem)
3 | % Adaptive line search algorithm (step size selection) for descent methods.
4 | %
5 | % function [stepsize newx storedb lsmem lsstats] =
6 | % linesearch_adaptive(problem, x, d, f0, df0, options, storedb, lsmem)
7 | %
8 | % Adaptive linesearch algorithm for descent methods, based on a simple
9 | % backtracking method. On average, this line search intends to do only one
10 | % or two cost evaluations.
11 | %
12 | % Contrary to linesearch.m, this function is not invariant under rescaling
13 | % of the search direction d. Nevertheless, it sometimes performs better.
14 | %
15 | % The first trial step is alpha = lsmem.init_alpha if lsmem carries a
16 | % positive suggestion from a previous call, and otherwise
17 | % alpha = options.ls_initial_stepsize / norm(d). If d is the zero vector,
18 | % no step is attempted: newx = x and stepsize = 0 are returned without
19 | % any cost evaluation, and lsmem is returned unchanged.
20 | %
21 | % Inputs/Outputs : see help for linesearch
22 | %
23 | % See also: steepestdescent conjugategradients linesearch
24 |
25 | % This file is part of Manopt: www.manopt.org.
26 | % Original author: Bamdev Mishra, Dec. 30, 2012.
27 | % Contributors: Nicolas Boumal
28 | % Change log:
29 | %
30 | % Sept. 13, 2013 (NB) :
31 | % The automatic direction reversal feature was removed (it triggered
32 | % when df0 > 0). Direction reversal is a decision that needs to be
33 | % made by the solver, so it can know about it.
34 | %
35 | % Nov. 7, 2013 (NB) :
36 | % The whole function has been recoded to mimick more closely the new
37 | % version of linesearch.m. The parameters are available through the
38 | % options structure passed to the solver and have the same names and
39 | % same meaning as for the base linesearch. The information is logged
40 | % more reliably.
41 | %
42 | % Robustness fixes:
43 | % - A zero suggestion in lsmem.init_alpha (left behind after a rejected
44 | %   step) is ignored; previously it pinned every later step to 0.
45 | % - A zero search direction no longer yields alpha = Inf and a
46 | %   possibly NaN trial point.
47 |
48 |
49 |     % Backtracking default parameters. These can be overwritten in the
50 |     % options structure which is passed to the solver.
51 |     default_options.ls_contraction_factor = .5;
52 |     default_options.ls_suff_decr = .5;
53 |     default_options.ls_max_steps = 10;
54 |     default_options.ls_initial_stepsize = 1;
55 |     options = mergeOptions(default_options, options);
56 |
57 |     contraction_factor = options.ls_contraction_factor;
58 |     suff_decr = options.ls_suff_decr;
59 |     max_ls_steps = options.ls_max_steps;
60 |     initial_stepsize = options.ls_initial_stepsize;
61 |
62 |     % Compute the norm of the search direction.
63 |     norm_d = problem.M.norm(x, d);
64 |
65 |     % With a zero direction, initial_stepsize / norm_d below would be Inf,
66 |     % and retracting Inf * 0 would typically produce NaNs. Nothing can be
67 |     % gained along d anyway, so report a null step. lsmem is passed
68 |     % through unchanged so that a previous suggestion is not lost.
69 |     if norm_d == 0
70 |         stepsize = 0;
71 |         newx = x;
72 |         lsstats.costevals = 0;
73 |         lsstats.stepsize = 0;
74 |         lsstats.alpha = 0;
75 |         return;
76 |     end
77 |
78 |     % If this is not the first iteration, then lsmem should have been
79 |     % filled with a suggestion for the initial step. A suggestion of 0
80 |     % (recorded after a rejected step, since 2 * 0 = 0) is discarded: it
81 |     % would make every subsequent trial step null.
82 |     if isstruct(lsmem) && isfield(lsmem, 'init_alpha') && ...
83 |        lsmem.init_alpha > 0
84 |         % Pick initial step size based on where we were last time,
85 |         alpha = lsmem.init_alpha;
86 |
87 |     % Otherwise, fall back to a user supplied suggestion.
88 |     else
89 |         alpha = initial_stepsize / norm_d;
90 |     end
91 |
92 |     % Make the chosen step and compute the cost there.
93 |     newx = problem.M.retr(x, d, alpha);
94 |     [newf storedb] = getCost(problem, newx, storedb);
95 |     cost_evaluations = 1;
96 |
97 |     % Backtrack while the Armijo criterion is not satisfied
98 |     while newf > f0 + suff_decr*alpha*df0
99 |
100 |         % Reduce the step size,
101 |         alpha = contraction_factor * alpha;
102 |
103 |         % and look closer down the line
104 |         newx = problem.M.retr(x, d, alpha);
105 |         [newf storedb] = getCost(problem, newx, storedb);
106 |         cost_evaluations = cost_evaluations + 1;
107 |
108 |         % Make sure we don't run out of budget
109 |         if cost_evaluations >= max_ls_steps
110 |             break;
111 |         end
112 |
113 |     end
114 |
115 |     % If we got here without obtaining a decrease, we reject the step.
116 |     if newf > f0
117 |         alpha = 0;
118 |         newx = x;
119 |         newf = f0; %#ok
120 |     end
121 |
122 |     % As seen outside this function, stepsize is the size of the vector we
123 |     % retract to make the step from x to newx. Since the step is alpha*d:
124 |     stepsize = alpha * norm_d;
125 |
126 |     % Fill lsmem with a suggestion for what the next initial step size
127 |     % trial should be. On average we intend to do only one extra cost
128 |     % evaluation. Notice how the suggestion is not about stepsize but about
129 |     % alpha. This is the reason why this line search is not invariant under
130 |     % rescaling of the search direction d.
131 |     switch cost_evaluations
132 |         case 1
133 |             % If things go well, push your luck.
134 |             lsmem.init_alpha = 2 * alpha;
135 |         case 2
136 |             % If things go smoothly, try to keep pace.
137 |             lsmem.init_alpha = alpha;
138 |         otherwise
139 |             % If you backtracked a lot, the new stepsize is probably quite
140 |             % small: try to recover.
141 |             lsmem.init_alpha = 2 * alpha;
142 |     end
143 |
144 |     % Save some statistics also, for possible analysis.
145 |     lsstats.costevals = cost_evaluations;
146 |     lsstats.stepsize = stepsize;
147 |     lsstats.alpha = alpha;
148 |
149 | end
150 |
--------------------------------------------------------------------------------
/manopt/manopt/solvers/linesearch/linesearch_hint.m:
--------------------------------------------------------------------------------
1 | function [stepsize, newx, storedb, lsmem, lsstats] = ...
2 | linesearch_hint(problem, x, d, f0, df0, options, storedb, lsmem)
3 | % Armijo line-search based on the line-search hint in the problem structure.
4 | %
5 | % function [stepsize, newx, storedb, lsmem, lsstats] =
6 | % linesearch_hint(problem, x, d, f0, df0, options, storedb, lsmem)
7 | %
8 | % Backtracking line-search for descent methods. The first trial step size
9 | % alpha is requested from the problem structure through getLinesearch
10 | % (typically backed by problem.linesearch). While the retracted point
11 | % fails the Armijo sufficient decrease test
12 | %
13 | %     f(R_x(alpha*d)) <= f0 + ls_suff_decr * alpha * df0,
14 | %
15 | % alpha is multiplied by ls_contraction_factor, until the test passes or
16 | % the budget of ls_max_steps cost evaluations is used up. If the final
17 | % trial point does not have a cost at most f0, the step is rejected:
18 | % alpha = 0 and newx = x.
19 | %
20 | % The search direction d is expected to be a descent direction, i.e.,
21 | % df0 (the directional derivative of f at x along d) should be negative.
22 | % The step executed is alpha*d, so that stepsize = alpha * norm(d).
23 | %
24 | % Inputs
25 | %
26 | % problem : structure holding the description of the optimization problem
27 | % x : current point on the manifold problem.M
28 | % d : tangent vector at x (descent direction)
29 | % f0 : cost value at x
30 | % df0 : directional derivative at x along d
31 | % options : options structure; reads ls_contraction_factor (default .5),
32 | %           ls_suff_decr (default 1e-4) and ls_max_steps (default 25)
33 | % storedb : store database structure for caching purposes
34 | % lsmem : not used
35 | %
36 | % Outputs
37 | %
38 | % stepsize : norm of the vector retracted to reach newx from x.
39 | % newx : next iterate suggested by the line-search algorithm, such that
40 | %        the retraction at x of the vector alpha*d reaches newx.
41 | % storedb : the (possibly updated) store database structure.
42 | % lsmem : not used (returned unchanged).
43 | % lsstats : structure with fields costevals, stepsize and alpha.
44 | %
45 | % See also: steepestdescent conjugategradients linesearch
46 |
47 | % This file is part of Manopt: www.manopt.org.
48 | % Original author: Nicolas Boumal, July 17, 2014.
49 | % Contributors:
50 | % Change log:
51 | %
52 |
53 |
54 |     % Default backtracking parameters; fields present in the solver's
55 |     % options structure take precedence.
56 |     defaults = struct('ls_contraction_factor', .5, ...
57 |                       'ls_suff_decr', 1e-4, ...
58 |                       'ls_max_steps', 25);
59 |     options = mergeOptions(defaults, options);
60 |
61 |     shrink = options.ls_contraction_factor;
62 |     armijo = options.ls_suff_decr;
63 |     budget = options.ls_max_steps;
64 |
65 |     % The problem structure is assumed to carry enough information for
66 |     % getLinesearch to produce the initial trial step.
67 |     [alpha, storedb] = getLinesearch(problem, x, d, storedb);
68 |
69 |     % First trial point.
70 |     newx = problem.M.retr(x, d, alpha);
71 |     [newf, storedb] = getCost(problem, newx, storedb);
72 |     nevals = 1;
73 |
74 |     % Shrink alpha until the Armijo test passes or the budget runs out.
75 |     while newf > f0 + armijo*alpha*df0
76 |         alpha = shrink * alpha;
77 |         newx = problem.M.retr(x, d, alpha);
78 |         [newf, storedb] = getCost(problem, newx, storedb);
79 |         nevals = nevals + 1;
80 |         if nevals >= budget
81 |             break;
82 |         end
83 |     end
84 |
85 |     % No decrease at all: stay at x.
86 |     if newf > f0
87 |         alpha = 0;
88 |         newx = x;
89 |     end
90 |
91 |     % The retracted vector is alpha*d, whose norm is the step size.
92 |     stepsize = alpha * problem.M.norm(x, d);
93 |
94 |     % Statistics for possible later analysis.
95 |     lsstats = struct('costevals', nevals, ...
96 |                      'stepsize', stepsize, ...
97 |                      'alpha', alpha);
98 |
99 | end
100 |
--------------------------------------------------------------------------------
/manopt/manopt/solvers/neldermead/centroid.m:
--------------------------------------------------------------------------------
1 | function y = centroid(M, x)
2 | % Attempts the computation of a centroid of a set of points on a manifold.
3 | %
4 | % function y = centroid(M, x)
5 | %
6 | % M is a structure representing a manifold. x is a cell of points on that
7 | % manifold (it must contain at least one point).
8 | %
9 | % The centroid is approximated by running trustregions silently, with at
10 | % most 15 iterations, on the Karcher mean cost
11 | %
12 | %     f(y) = 1/2 * sum_i dist(y, x{i})^2,
13 | %
14 | % whose Riemannian gradient is -sum_i log(y, x{i}). The solver starts
15 | % from one of the points x{i}, chosen uniformly at random.
16 |
17 | % This file is part of Manopt: www.manopt.org.
18 | % Original author: Nicolas Boumal, Dec. 30, 2012.
19 | % Contributors:
20 | % Change log:
21 | % The state of warning 'manopt:getHessian:approx' is now restored even if
22 | % the solver throws, and an empty x raises an explicit error.
23 |
24 |
25 |     n = numel(x);
26 |     if n == 0
27 |         error('centroid needs at least one point in x.');
28 |     end
29 |
30 |     problem.M = M;
31 |
32 |     problem.cost = @cost;
33 |     function val = cost(y)
34 |         val = 0;
35 |         for i = 1 : n
36 |             val = val + M.dist(y, x{i})^2;
37 |         end
38 |         val = val/2;
39 |     end
40 |
41 |     problem.grad = @grad;
42 |     function g = grad(y)
43 |         g = M.zerovec(y);
44 |         for i = 1 : n
45 |             g = M.lincomb(y, 1, g, -1, M.log(y, x{i}));
46 |         end
47 |     end
48 |
49 |     % This line can be uncommented to check that the gradient is indeed
50 |     % correct. This should always be the case if the dist and the log
51 |     % functions in the manifold are correct.
52 |     % checkgradient(problem);
53 |
54 |     % problem has no Hessian; silence 'manopt:getHessian:approx'
55 |     % (presumably raised when the solver approximates it) for the
56 |     % duration of the solver call only.
57 |     query = warning('query', 'manopt:getHessian:approx');
58 |     warning('off', 'manopt:getHessian:approx');
59 |     options.verbosity = 0;
60 |     options.maxiter = 15;
61 |     try
62 |         y = trustregions(problem, x{randi(n)}, options);
63 |     catch err
64 |         % Restore the caller's warning state before propagating.
65 |         warning(query.state, 'manopt:getHessian:approx');
66 |         rethrow(err);
67 |     end
68 |     warning(query.state, 'manopt:getHessian:approx');
69 |
70 | end
71 |
--------------------------------------------------------------------------------
/manopt/manopt/solvers/trustregions/license for original GenRTR code.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2007,2012 Christopher G. Baker, Pierre-Antoine Absil, Kyle A. Gallivan
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 | * Redistributions of source code must retain the above copyright
7 | notice, this list of conditions and the following disclaimer.
8 | * Redistributions in binary form must reproduce the above copyright
9 | notice, this list of conditions and the following disclaimer in the
10 | documentation and/or other materials provided with the distribution.
11 | * Neither the names of the contributors nor of their affiliated
12 | institutions may be used to endorse or promote products
13 | derived from this software without specific prior written permission.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY
19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
26 | For questions, please contact Chris Baker (chris@cgbaker.net)
27 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/checkdiff.m:
--------------------------------------------------------------------------------
1 | function checkdiff(problem, x, d)
2 | % Checks the consistency of the cost function and directional derivatives.
3 | %
4 | % function checkdiff(problem)
5 | % function checkdiff(problem, x)
6 | % function checkdiff(problem, x, d)
7 | %
8 | % Numerically tests whether the directional derivatives supplied by the
9 | % problem structure agree to first order with its cost function at a
10 | % point x along a tangent direction d (truncated Taylor series, see the
11 | % online Manopt documentation). The cost is sampled along
12 | % problem.M.exp(x, d, h) for 51 step lengths h spread logarithmically
13 | % between 1e-8 and 1 and compared with the linear model f0 + h*df0. The
14 | % error is plotted in loglog scale next to a dashed reference line of
15 | % slope 2. A slope fitted on a linear piece of the curve is drawn in red
16 | % and printed: it should be close to 2.
17 | %
18 | % x and d are optional; omitted (or empty) ones are drawn at random with
19 | % problem.M.rand and problem.M.randvec. d cannot be given without x.
20 | %
21 | % See also: checkgradient checkhessian
22 |
23 | % This file is part of Manopt: www.manopt.org.
24 | % Original author: Nicolas Boumal, Dec. 30, 2012.
25 | % Contributors:
26 | % Change log:
27 |
28 |
29 |     % The test needs both a cost and directional derivatives.
30 |     if ~canGetCost(problem)
31 |         error('It seems no cost was provided.');
32 |     end
33 |     if ~canGetDirectionalDerivative(problem)
34 |         error('It seems no directional derivatives were provided.');
35 |     end
36 |
37 |     storedb = struct();
38 |
39 |     has_x = exist('x', 'var') && ~isempty(x);
40 |     has_d = exist('d', 'var') && ~isempty(d);
41 |     if has_d && ~has_x
42 |         error('If d is provided, x must be too, since d is tangent at x.');
43 |     end
44 |
45 |     % Draw whatever was not supplied (x first: d must be tangent at x).
46 |     if ~has_x
47 |         x = problem.M.rand();
48 |     end
49 |     if ~has_d
50 |         d = problem.M.randvec(x);
51 |     end
52 |
53 |     % Cost and directional derivative at the base point.
54 |     f0 = getCost(problem, x, storedb);
55 |     df0 = getDirectionalDerivative(problem, x, d, storedb);
56 |
57 |     % Cost along the exponential curve leaving x in direction d.
58 |     h = logspace(-8, 0, 51);
59 |     value = zeros(size(h));
60 |     for k = 1 : numel(h)
61 |         value(k) = getCost(problem, problem.M.exp(x, d, h(k)), storedb);
62 |     end
63 |
64 |     % Gap between the sampled cost and its first-order model.
65 |     err = abs(polyval([df0 f0], h) - value);
66 |
67 |     % Plot the gap, with a slope-2 reference line from (1e-8, 1e-8) to
68 |     % (1, 1e8) that is excluded from the axis limits computation.
69 |     loglog(h, err);
70 |     title(sprintf(['Directional derivative check.\nThe slope of the '...
71 |                    'continuous line should match that of the dashed '...
72 |                    '(reference) line\nover at least a few orders of '...
73 |                    'magnitude for h.']));
74 |     xlabel('h');
75 |     ylabel('Approximation error');
76 |     line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
77 |          'color', 'k', 'LineStyle', '--', ...
78 |          'YLimInclude', 'off', 'XLimInclude', 'off');
79 |
80 |     % Where the computation is numerically sound, the error should scale
81 |     % like h^2, i.e., show a slope of 2 in loglog scale. Fit a line on
82 |     % the piece found by identify_linear_piece (window of 10 samples)
83 |     % and overlay it in red.
84 |     [fit_range, fit_poly] = identify_linear_piece(log10(h), log10(err), 10);
85 |     hold on;
86 |     loglog(h(fit_range), 10.^polyval(fit_poly, log10(h(fit_range))), ...
87 |            'r-', 'LineWidth', 3);
88 |     hold off;
89 |
90 |     fprintf('The slope should be 2. It appears to be: %g.\n', fit_poly(1));
91 |     fprintf(['If it is far from 2, then directional derivatives ' ...
92 |              'might be erroneous.\n']);
93 |
94 | end
95 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/checkdiffSingle.m:
--------------------------------------------------------------------------------
1 | function checkdiffSingle(problem, x, d)
2 | % Checks the consistency of the cost function and directional derivatives.
3 | %
4 | % function checkdiffSingle(problem)
5 | % function checkdiffSingle(problem, x)
6 | % function checkdiffSingle(problem, x, d)
7 | %
8 | % This file used to hold a verbatim copy of checkdiff, including the
9 | % function name "checkdiff", which did not match the file name (MATLAB
10 | % then uses the file name). It now forwards to checkdiff so that the two
11 | % cannot drift apart. Omitted inputs are passed on as [], which checkdiff
12 | % treats exactly like omitted inputs: x and d are sampled at random.
13 | %
14 | % Despite its name, this check runs in whatever precision the problem
15 | % uses; checkdiffSinglePrecision is the single-precision variant.
16 | %
17 | % See also: checkdiff checkgradient checkhessian
18 |
19 | % This file is part of Manopt: www.manopt.org.
20 | % Original author: Nicolas Boumal, Dec. 30, 2012.
21 | % Contributors:
22 | % Change log:
23 |
24 |
25 |     % checkdiff considers empty inputs as not provided.
26 |     if nargin < 2
27 |         x = [];
28 |     end
29 |     if nargin < 3
30 |         d = [];
31 |     end
32 |     checkdiff(problem, x, d);
33 |
34 | end
35 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/checkdiffSinglePrecision.m:
--------------------------------------------------------------------------------
1 | function checkdiffSinglePrecision(problem, x, d)
2 | % Checks the consistency of the cost function and directional derivatives.
3 | %
4 | % function checkdiffSinglePrecision(problem)
5 | % function checkdiffSinglePrecision(problem, x)
6 | % function checkdiffSinglePrecision(problem, x, d)
7 | %
8 | % Single-precision variant of checkdiff. It performs the same numerical
9 | % test that the directional derivatives defined in the problem structure
10 | % agree up to first order with the cost function at some point x, along
11 | % some direction d (truncated Taylor series, see online Manopt
12 | % documentation). The difference is that the step lengths h and the
13 | % sampled points are converted to single precision before the cost is
14 | % evaluated there.
15 | %
16 | % NOTE(review): the sampled points y = problem.M.exp(x, d, h) are assumed
17 | % to be structures with numeric fields 'encoder' and 'decoder'. Only those
18 | % two fields are cast to single; other point types are not supported.
19 | % f0 and df0 are evaluated at x as given (presumably double precision
20 | % when x is sampled here with problem.M.rand).
21 | %
22 | % Both x and d are optional and will be sampled at random if omitted.
23 | %
24 | % See also: checkdiff checkgradientSinglePrecision checkhessian
25 |
26 | % This file is part of Manopt: www.manopt.org.
27 | % Original author: Nicolas Boumal, Dec. 30, 2012.
28 | % Contributors:
29 | % Change log:
30 |
31 |
32 |     % Verify that the problem description is sufficient.
33 |     if ~canGetCost(problem)
34 |         error('It seems no cost was provided.');
35 |     end
36 |     if ~canGetDirectionalDerivative(problem)
37 |         error('It seems no directional derivatives were provided.');
38 |     end
39 |
40 |     dbstore = struct();
41 |
42 |     x_isprovided = exist('x', 'var') && ~isempty(x);
43 |     d_isprovided = exist('d', 'var') && ~isempty(d);
44 |
45 |     if ~x_isprovided && d_isprovided
46 |         error('If d is provided, x must be too, since d is tangent at x.');
47 |     end
48 |
49 |     % If x and / or d are not specified, pick them at random.
50 |     if ~x_isprovided
51 |         x = problem.M.rand();
52 |     end
53 |     if ~d_isprovided
54 |         d = problem.M.randvec(x);
55 |     end
56 |
57 |     % Compute the cost value f0 and the directional derivative df0 at x
58 |     % along d.
59 |     f0 = getCost(problem, x, dbstore);
60 |     df0 = getDirectionalDerivative(problem, x, d, dbstore);
61 |
62 |     % Compute the value of f at points on the geodesic (or approximation of
63 |     % it) originating from x, along direction d, for stepsizes in a large
64 |     % range given by h. Both h and the sampled points are single precision.
65 |     h = single(logspace(-8, 0, 51));
66 |     value = zeros(size(h), 'single');
67 |     for i = 1 : length(h)
68 |         y = problem.M.exp(x, d, h(i));
69 |         y.encoder = single(y.encoder);
70 |         y.decoder = single(y.decoder);
71 |         value(i) = getCost(problem, y, dbstore);
72 |     end
73 |
74 |     % Compute the linear approximation of the cost function using f0 and
75 |     % df0 at the same points.
76 |     model = polyval([df0 f0], h);
77 |
78 |     % Compute the approximation error
79 |     err = abs(model - value);
80 |
81 |     % And plot it.
82 |     loglog(h, err);
83 |     title(sprintf(['Directional derivative check.\nThe slope of the '...
84 |                    'continuous line should match that of the dashed '...
85 |                    '(reference) line\nover at least a few orders of '...
86 |                    'magnitude for h.']));
87 |     xlabel('h');
88 |     ylabel('Approximation error');
89 |
90 |     % Reference line of slope 2, excluded from the axis limits.
91 |     line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
92 |          'color', 'k', 'LineStyle', '--', ...
93 |          'YLimInclude', 'off', 'XLimInclude', 'off');
94 |
95 |
96 |     % In a numerically reasonable neighborhood, the error should decrease
97 |     % as the square of the stepsize, i.e., in loglog scale, the error
98 |     % should have a slope of 2.
99 |     window_len = 10;
100 |     [range poly] = identify_linear_piece_SinglePrecision(log10(h), log10(err), window_len);
101 |     hold on;
102 |     loglog(h(range), 10.^polyval(poly, log10(h(range))), ...
103 |            'r-', 'LineWidth', 3);
104 |     hold off;
105 |
106 |     fprintf('The slope should be 2. It appears to be: %g.\n', poly(1));
107 |     fprintf(['If it is far from 2, then directional derivatives ' ...
108 |              'might be erroneous.\n']);
109 |
110 | end
111 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/checkgradient.m:
--------------------------------------------------------------------------------
1 | function checkgradient(problem, x, d)
2 | % Checks the consistency of the cost function and the gradient.
3 | %
4 | % function checkgradient(problem)
5 | % function checkgradient(problem, x)
6 | % function checkgradient(problem, x, d)
7 | %
8 | % checkgradient performs a numerical test to check that the gradient
9 | % defined in the problem structure agrees up to first order with the cost
10 | % function at some point x, along some direction d. The test is based on a
11 | % truncated Taylor series (see online Manopt documentation). It is run
12 | % through checkdiff after removing problem.diff, so that directional
13 | % derivatives have to be obtained from the gradient.
14 | %
15 | % If the manifold structure provides a 'tangent' function, it is also
16 | % tested that the gradient is indeed a tangent vector: the norm of the
17 | % part of the gradient removed by problem.M.tangent is printed.
18 | %
19 | % Both x and d are optional and will be sampled at random if omitted.
20 | %
21 | % See also: checkdiff checkhessian
22 |
23 | % This file is part of Manopt: www.manopt.org.
24 | % Original author: Nicolas Boumal, Dec. 30, 2012.
25 | % Contributors:
26 | % Change log:
27 | % The message printed when no 'tangent' function is available now ends
28 | % with a newline.
29 |
30 |
31 |     % Verify that the problem description is sufficient.
32 |     if ~canGetCost(problem)
33 |         error('It seems no cost was provided.');
34 |     end
35 |     if ~canGetGradient(problem)
36 |         error('It seems no gradient provided.');
37 |     end
38 |
39 |     dbstore = struct();
40 |
41 |     x_isprovided = exist('x', 'var') && ~isempty(x);
42 |     d_isprovided = exist('d', 'var') && ~isempty(d);
43 |
44 |     if ~x_isprovided && d_isprovided
45 |         error('If d is provided, x must be too, since d is tangent at x.');
46 |     end
47 |
48 |     % If x and / or d are not specified, pick them at random.
49 |     if ~x_isprovided
50 |         x = problem.M.rand();
51 |     end
52 |     if ~d_isprovided
53 |         d = problem.M.randvec(x);
54 |     end
55 |
56 |     %% Check that the gradient yields a first order model of the cost.
57 |
58 |     % Removing the 'directional derivative' function presumably makes
59 |     % getDirectionalDerivative (used by checkdiff) fall back on the
60 |     % gradient, so the check below exercises the gradient.
61 |     if isfield(problem, 'diff')
62 |         problem = rmfield(problem, 'diff');
63 |     end
64 |     checkdiff(problem, x, d);
65 |     title(sprintf(['Gradient check.\nThe slope of the continuous line ' ...
66 |                    'should match that of the dashed (reference) line\n' ...
67 |                    'over at least a few orders of magnitude for h.']));
68 |     xlabel('h');
69 |     ylabel('Approximation error');
70 |
71 |     %% Try to check that the gradient is a tangent vector.
72 |     if isfield(problem.M, 'tangent')
73 |         grad = getGradient(problem, x, dbstore);
74 |         pgrad = problem.M.tangent(x, grad);
75 |         residual = problem.M.lincomb(x, 1, grad, -1, pgrad);
76 |         err = problem.M.norm(x, residual);
77 |         fprintf('The residual should be 0, or very close. Residual: %g.\n', err);
78 |         fprintf('If it is far from 0, then the gradient is not in the tangent space.\n');
79 |     else
80 |         % End with a newline so later output does not run onto this line.
81 |         fprintf(['Unfortunately, Manopt was unable to verify that the '...
82 |                  'gradient is indeed a tangent vector.\nPlease verify ' ...
83 |                  'this manually or implement the ''tangent'' function ' ...
84 |                  'in your manifold structure.\n']);
85 |     end
86 |
87 | end
88 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/checkgradientSinglePrecision.m:
--------------------------------------------------------------------------------
function checkgradientSinglePrecision(problem, x, d)
% Checks the consistency of the cost function and the gradient, with the
% point and direction cast to single precision.
%
% function checkgradientSinglePrecision(problem)
% function checkgradientSinglePrecision(problem, x)
% function checkgradientSinglePrecision(problem, x, d)
%
% Single-precision variant of checkgradient. It performs a numerical test
% to check that the gradient defined in the problem structure agrees up to
% first order with the cost function at some point x, along some direction
% d. The test is based on a truncated Taylor series (see online Manopt
% documentation) and is carried out by checkdiffSinglePrecision.
%
% Before testing, the fields x.encoder, x.decoder, d.encoder and d.decoder
% are cast to single precision: x and d must therefore be structures that
% have (at least) the fields 'encoder' and 'decoder'.
%
% If the manifold problem.M provides a 'tangent' function, it is also
% tested that the gradient is indeed a tangent vector.
%
% Both x and d are optional and will be sampled at random if omitted.
%
% See also: checkgradient checkdiffSinglePrecision checkhessian

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    % Verify that the problem description is sufficient.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    if ~canGetGradient(problem)
        error('It seems no gradient provided.');
    end

    dbstore = struct();

    x_isprovided = exist('x', 'var') && ~isempty(x);
    d_isprovided = exist('d', 'var') && ~isempty(d);

    if ~x_isprovided && d_isprovided
        error('If d is provided, x must be too, since d is tangent at x.');
    end

    % If x and / or d are not specified, pick them at random.
    if ~x_isprovided
        x = problem.M.rand();
    end
    if ~d_isprovided
        d = problem.M.randvec(x);
    end

    % The single-precision cast below only knows about the (project
    % specific) 'encoder' and 'decoder' fields: fail early with a clear
    % message rather than with a generic "non-existent field" error.
    required_fields = {'encoder', 'decoder'};
    if ~isstruct(x) || ~all(isfield(x, required_fields)) || ...
       ~isstruct(d) || ~all(isfield(d, required_fields))
        error(['checkgradientSinglePrecision expects x and d to be ' ...
               'structures with fields ''encoder'' and ''decoder''.']);
    end

    x.encoder = single(x.encoder);
    x.decoder = single(x.decoder);
    d.encoder = single(d.encoder);
    d.decoder = single(d.decoder);

    %% Check that the gradient yields a first order model of the cost.

    % By removing the 'directional derivative' function, it should be so
    % (?) that the checkdiff function will use the gradient to compute
    % directional derivatives.
    if isfield(problem, 'diff')
        problem = rmfield(problem, 'diff');
    end
    checkdiffSinglePrecision(problem, x, d);
    title(sprintf(['Gradient check.\nThe slope of the continuous line ' ...
                   'should match that of the dashed (reference) line\n' ...
                   'over at least a few orders of magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');

    %% Try to check that the gradient is a tangent vector.
    if isfield(problem.M, 'tangent')
        % The residual between the gradient and its "cleaned" tangent
        % version measures how far the gradient is from the tangent space.
        grad = getGradient(problem, x, dbstore);
        pgrad = problem.M.tangent(x, grad);
        residual = problem.M.lincomb(x, 1, grad, -1, pgrad);
        err = problem.M.norm(x, residual);
        fprintf('The residual should be 0, or very close. Residual: %g.\n', err);
        fprintf('If it is far from 0, then the gradient is not in the tangent space.\n');
    else
        fprintf(['Unfortunately, Manopt was unable to verify that the '...
                 'gradient is indeed a tangent vector.\nPlease verify ' ...
                 'this manually or implement the ''tangent'' function ' ...
                 'in your manifold structure.\n']);
    end

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/checkhessian.m:
--------------------------------------------------------------------------------
function checkhessian(problem, x, d)
% Checks the consistency of the cost function and the Hessian.
%
% function checkhessian(problem)
% function checkhessian(problem, x)
% function checkhessian(problem, x, d)
%
% checkhessian performs a numerical test to check that the directional
% derivatives and Hessian defined in the problem structure agree up to
% second order with the cost function at some point x, along some direction
% d. The test is based on a truncated Taylor series (see online Manopt
% documentation): the error between the cost along exp(x, h*d) and its
% second order model should decay as h^3, i.e., with slope 3 in log-log
% scale. The estimated slope is printed and drawn in red.
%
% It is also tested that the Hessian along some direction is indeed a
% tangent vector (when problem.M provides 'tangent') and that the Hessian
% operator is symmetric w.r.t. the Riemannian metric (checked on two random
% tangent vectors).
%
% Both x and d are optional and will be sampled at random if omitted.
%
% See also: checkdiff checkgradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    % Verify that the problem description is sufficient.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    if ~canGetGradient(problem)
        error('It seems no gradient provided.');
    end
    if ~canGetHessian(problem)
        error('It seems no Hessian was provided.');
    end

    dbstore = struct();

    x_isprovided = exist('x', 'var') && ~isempty(x);
    d_isprovided = exist('d', 'var') && ~isempty(d);

    if ~x_isprovided && d_isprovided
        error('If d is provided, x must be too, since d is tangent at x.');
    end

    % If x and / or d are not specified, pick them at random.
    if ~x_isprovided
        x = problem.M.rand();
    end
    if ~d_isprovided
        d = problem.M.randvec(x);
    end

    %% Check that the directional derivative and the Hessian at x along d
    %% yield a second order model of the cost function.

    % Compute the value f0 of the cost at x, the directional derivative df0
    % at x along d, and the second order term d2f0 = <d, Hess f(x)[d]>.
    f0 = getCost(problem, x, dbstore);
    df0 = getDirectionalDerivative(problem, x, d, dbstore);
    d2f0 = problem.M.inner(x, d, getHessian(problem, x, d, dbstore));

    % Compute the value of f at points on the geodesic (or approximation of
    % it) originating from x, along direction d, for stepsizes in a large
    % range given by h.
    h = logspace(-8, 0, 51);
    value = zeros(size(h));
    for i = 1 : length(h)
        y = problem.M.exp(x, d, h(i));
        value(i) = getCost(problem, y, dbstore);
    end

    % Compute the quadratic approximation of the cost function using f0,
    % df0 and d2f0 at the same points.
    model = polyval([.5*d2f0 df0 f0], h);

    % Compute the approximation error
    err = abs(model - value);

    % And plot it.
    loglog(h, err);
    title(sprintf(['Hessian check.\nThe slope of the continuous line ' ...
                   'should match that of the dashed (reference) line\n' ...
                   'over at least a few orders of magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');

    % Reference line of slope 3 in log-log scale: from (1e-8, 1e-16) to
    % (1e0, 1e8), i.e., 24 decades of error over 8 decades of h.
    line('xdata', [1e-8 1e0], 'ydata', [1e-16 1e8], ...
         'color', 'k', 'LineStyle', '--', ...
         'YLimInclude', 'off', 'XLimInclude', 'off');

    % In a numerically reasonable neighborhood, the error should decrease
    % as the cube of the stepsize, i.e., in loglog scale, the error
    % should have a slope of 3.
    window_len = 10;
    [range, poly] = identify_linear_piece(log10(h), log10(err), window_len);
    hold on;
    loglog(h(range), 10.^polyval(poly, log10(h(range))), ...
           'r-', 'LineWidth', 3);
    hold off;

    fprintf('The slope should be 3. It appears to be: %g.\n', poly(1));
    fprintf(['If it is far from 3, then directional derivatives or ' ...
             'the Hessian might be erroneous.\n']);


    %% Check that the Hessian at x along direction d is a tangent vector.
    if isfield(problem.M, 'tangent')
        hess = getHessian(problem, x, d, dbstore);
        phess = problem.M.tangent(x, hess);
        residual = problem.M.lincomb(x, 1, hess, -1, phess);
        err = problem.M.norm(x, residual);
        fprintf('The residual should be zero, or very close. ');
        fprintf('Residual: %g.\n', err);
        fprintf(['If it is far from 0, then the Hessian is not in the ' ...
                 'tangent plane.\n']);
    else
        fprintf(['Unfortunately, Manopt was unable to verify that the '...
                 'Hessian is indeed a tangent vector.\nPlease verify ' ...
                 'this manually or implement the ''tangent'' function ' ...
                 'in your manifold structure.\n']);
    end

    %% Check that the Hessian at x is symmetric.
    % For a symmetric operator H, <d1, H[d2]> = <H[d1], d2> for all
    % tangent vectors d1, d2; test this on two random ones.
    d1 = problem.M.randvec(x);
    d2 = problem.M.randvec(x);
    h1 = getHessian(problem, x, d1, dbstore);
    h2 = getHessian(problem, x, d2, dbstore);
    v1 = problem.M.inner(x, d1, h2);
    v2 = problem.M.inner(x, h1, d2);
    value = v1-v2;
    fprintf(['<d1, H[d2]> - <H[d1], d2> should be zero, or very close.' ...
             '\n\tValue: %g - %g = %g.\n'], v1, v2, value);
    fprintf('If it is far from 0, then the Hessian is not symmetric.\n');

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/diagsum.m:
--------------------------------------------------------------------------------
function [tracedtensor] = diagsum(tensor1, d1, d2)
% C = DIAGSUM(A, d1, d2) Performs the trace
% C(i[1],...,i[d1-1],i[d1+1],...,i[d2-1],i[d2+1],...i[n]) =
% A(i[1],...,i[d1-1],k,i[d1+1],...,i[d2-1],k,i[d2+1],...,i[n])
% (Sum on k).
%
% C = DIAGSUM(A, d1, d2) traces A along the diagonal formed by dimensions d1
% and d2. If the lengths of these dimensions are not equal, DIAGSUM traces
% until the end of the shortest of dimensions d1 and d2 is reached. This is
% an analogue of the built in TRACE function.
%
% Wynton Moore, January 2006
%
% Review notes:
%  - For 2D input, the diagonal entries are summed directly, so vectors and
%    non-square matrices are supported as documented above (the built-in
%    TRACE only accepts square matrices).
%  - For input with more than 2 dimensions, the trace is obtained by
%    multiplying A element-wise with a 0/1 selector tensor, so an Inf or
%    NaN entry off the traced diagonal yields NaN (Inf*0) in the result.
%  - The numdims > 3 branch calls a function OUTER that is not defined in
%    this file; it must be available on the path.


dim1=size(tensor1);
numdims=length(dim1);


%check inputs
if d1==d2
    % Degenerate request: sum along the single dimension d1.
    tracedtensor=squeeze(sum(tensor1,d1));
elseif numdims==2
    % Plain matrix: sum the entries (i, i), i = 1..min(size). The linear
    % index of entry (i, i) is (i-1)*(rows+1) + 1. Unlike TRACE, this
    % accepts non-square matrices and vectors.
    diag_idx=(0:min(dim1)-1)*(dim1(1)+1)+1;
    tracedtensor=full(sum(tensor1(diag_idx)));
elseif dim1(d1)==1 && dim1(d2)==1
    % Both traced dimensions are singletons: nothing to sum.
    tracedtensor=squeeze(tensor1);
else


    %determine correct permutation
    % (move dimensions d1 and d2 to the last two positions)
    swapd1=d1;swapd2=d2;

    if d1~=numdims-1 && d1~=numdims && d2~=numdims-1
        swapd1=numdims-1;
    elseif d1~=numdims-1 && d1~=numdims && d2~=numdims
        swapd1=numdims;
    end
    if d2~=numdims-1 && d2~=numdims && swapd1~=numdims-1
        swapd2=numdims-1;
    elseif d2~=numdims-1 && d2~=numdims && swapd1~=numdims
        swapd2=numdims;
    end


    %prepare for construction of selector tensor
    temp1=eye(numdims);
    permmatrix=temp1;
    permmatrix(:,d1)=temp1(:,swapd1);
    permmatrix(:,swapd1)=temp1(:,d1);
    permmatrix(:,d2)=temp1(:,swapd2);
    permmatrix(:,swapd2)=temp1(:,d2);

    selectordim=dim1*permmatrix;
    permvector=(1:numdims)*permmatrix;


    %construct selector tensor
    if numdims>3
        selector = ipermute(outer(ones(selectordim(1:numdims-2)), ...
                                  eye(selectordim(numdims-1), ...
                                      selectordim(numdims)), ...
                                  0), ...
                            permvector);
    else
        %when numdims=3, the above line gives ndims(selector)=4. This
        %routine avoids that error. When used with GMDMP, numdims will be
        %at least 4, so this routine will be unnecessary.
        selector2=eye(selectordim(numdims-1), selectordim(numdims));
        selector=zeros(selectordim);
        for j=1:selectordim(1)
            selector(j, :, :)=selector2;
        end
        selector=ipermute(selector, permvector);
    end


    %perform trace, discard resulting singleton dimensions
    tracedtensor=sum(sum(tensor1.*selector, d1), d2);
    tracedtensor=squeeze(tracedtensor);


end


%correction for aberration in squeeze function:
%size(squeeze(rand(1,1,2)))=[2 1]
nontracedimensions=dim1;
nontracedimensions(d1)=[];
if d2>d1
    nontracedimensions(d2-1)=[];
else
    nontracedimensions(d2)=[];
end
tracedsize=size(tracedtensor);
% Next line modified, Nicolas Boumal, April 30, 2012, such that diagsum(A,
% 1, 2) would compute the trace of A, a 2D matrix.
if length(tracedsize)==2 && tracedsize(2)==1 && ...
        (isempty(nontracedimensions) || tracedsize(1)~=nontracedimensions(1))

    tracedtensor=tracedtensor.';

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/hessianspectrum.m:
--------------------------------------------------------------------------------
function lambdas = hessianspectrum(problem, x, sqrtprec)
% Returns the eigenvalues of the (preconditioned) Hessian at x.
%
% function lambdas = hessianspectrum(problem, x)
% function lambdas = hessianspectrum(problem, x, sqrtprec)
%
% If the Hessian is defined in the problem structure and if no
% preconditioner is defined, returns the eigenvalues of the Hessian
% operator (which needs to be symmetric but not necessarily definite) on
% the tangent space at x. There are problem.M.dim() eigenvalues.
%
% If a preconditioner is defined, the eigenvalues of the composition are
% computed: precon o Hessian. Remember that the preconditioner has to be
% symmetric, positive definite, and is supposed to approximate the inverse
% of the Hessian.
%
% Even though the Hessian and the preconditioner are both symmetric, their
% composition is not symmetric, which can slow down the call to 'eigs'
% substantially. If possible, you may specify the square root of the
% preconditioner as an optional input sqrtprec. This operator on the
% tangent space at x must also be symmetric, positive definite, and such
% that sqrtprec o sqrtprec = precon. Then the spectrum of the symmetric
% operator sqrtprec o hess o sqrtprec is computed: it is the same as
% the spectrum of precon o hess, but is generally faster to compute.
% The operator sqrtprec(x, u[, store]) accepts as input: a point x,
% a tangent vector u and (optional) a store structure.
%
% The input and the output of the Hessian and of the preconditioner are
% projected on the tangent space to avoid undesired contributions of the
% ambient space.
%
% The spectrum is computed with 'eigs' on the vectorized operator. When
% the number of requested eigenvalues, problem.M.dim(), is too large for
% 'eigs' to handle an operator given as a function handle (it must be
% smaller than n for symmetric problems and than n-1 otherwise, where n is
% the length of a vectorized tangent vector), the n-by-n matrix of the
% operator is built explicitly (n operator applications) and 'eig' is used
% instead; the dim eigenvalues of largest magnitude are then returned.
%
% Requires the manifold description in problem.M to have these functions:
%
%   u_vec = vec(x, u_mat) :
%       Returns a column vector representation of the normal (usually
%       matrix) representation of the tangent vector u_mat. vec must be an
%       isometry between the tangent space (with its Riemannian metric) and
%       a subspace of R^n where n = length(u_vec), with the 2-norm on R^n.
%
%   u_mat = mat(x, u_vec) :
%       The inverse of vec (its adjoint).
%
%   u_mat_clean = tangent(x, u_mat) :
%       Subtracts from the tangent vector u_mat any component that would
%       make it "not really tangent", by projection.
%
%   answer = vecmatareisometries() :
%       Returns true if the linear maps encoded by vec and mat are
%       isometries, false otherwise. It is better if the answer is yes.
%

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 3, 2013.
% Contributors:
% Change log:


    if ~canGetHessian(problem)
        warning('manopt:hessianspectrum:nohessian', ...
                ['The Hessian appears to be unavailable.\n' ...
                 'Will try to use an approximate Hessian instead.\n'...
                 'Since this approximation may not be linear or '...
                 'symmetric,\nthe computation might fail and the '...
                 'results (if any)\nmight make no sense.']);
    end

    vec = @(u_mat) problem.M.vec(x, u_mat);
    mat = @(u_vec) problem.M.mat(x, u_vec);
    tgt = @(u_mat) problem.M.tangent(x, u_mat);

    % n: size of a vectorized tangent vector
    % dim: dimension of the tangent space
    % necessarily, n >= dim.
    % The vectorized operators we build below will have at least n - dim
    % zero eigenvalues.
    n = length(vec(problem.M.zerovec(x)));
    dim = problem.M.dim();

    % The store structure is not updated by the getHessian call because the
    % eigs function will not take care of it. This might be worked around,
    % but for now we simply obtain the store structure built from calling
    % the cost and gradient at x and pass that one for every Hessian call.
    % This will typically be enough, seen as the Hessian is not supposed to
    % store anything new.
    storedb = struct();
    if canGetGradient(problem)
        [unused1, unused2, storedb] = getCostGrad(problem, x, struct()); %#ok
    end

    hess = @(u_mat) tgt(getHessian(problem, x, tgt(u_mat), storedb));
    hess_vec = @(u_vec) vec(hess(mat(u_vec)));

    % Regardless of preconditioning, we can only have a symmetric
    % eigenvalue problem if the vec/mat pair of the manifold is an
    % isometry:
    vec_mat_are_isometries = problem.M.vecmatareisometries();

    if ~exist('sqrtprec', 'var') || isempty(sqrtprec)

        if ~canGetPrecon(problem)

            % There is no preconditioner: just deal with the (symmetric)
            % Hessian.
            lambdas = largest_eigenvalues(hess_vec, n, dim, ...
                                          vec_mat_are_isometries);

        else

            % There is a preconditioner, but we don't have its square root:
            % deal with the non-symmetric composition prec o hess.
            prec = @(u_mat) tgt(getPrecon(problem, x, tgt(u_mat), storedb));
            prec_vec = @(u_vec) vec(prec(mat(u_vec)));

            lambdas = largest_eigenvalues( ...
                             @(u_vec) prec_vec(hess_vec(u_vec)), ...
                             n, dim, false);

        end

    else

        % There is a preconditioner, and we have its square root: deal with
        % the symmetric composition sqrtprec o hess o sqrtprec.
        % Need to check also whether sqrtprec uses the store cache or not.
        % In Octave, nargin on a function handle is not relied upon and
        % sqrtprec is assumed to accept the store as third input.
        is_octave = exist('OCTAVE_VERSION', 'builtin');
        if ~is_octave
            narg = nargin(sqrtprec);
        else
            narg = 3;
        end

        switch narg
            case 2
                sqrtprec_vec = @(u_vec) vec(tgt(sqrtprec(x, tgt(mat(u_vec)))));
            case 3
                store = getStore(problem, x, storedb);
                sqrtprec_vec = @(u_vec) vec(tgt(sqrtprec(x, tgt(mat(u_vec)), store)));
            otherwise
                error('sqrtprec must accept 2 or 3 inputs: x, u, (optional: store)');
        end

        lambdas = largest_eigenvalues( ...
                      @(u_vec) sqrtprec_vec(hess_vec(sqrtprec_vec(u_vec))), ...
                      n, dim, vec_mat_are_isometries);

    end

end


function lambdas = largest_eigenvalues(op_vec, n, k, issym)
% Returns the k eigenvalues of largest magnitude of the real linear map
% op_vec on R^n, given as a function handle acting on column vectors.
% Uses eigs when it can handle k for a function-handle operator, and an
% explicit n-by-n matrix with eig otherwise.

    k = min(k, n);

    if (issym && k < n) || (~issym && k < n - 1)

        eigs_opts.issym = issym;
        eigs_opts.isreal = true;
        lambdas = eigs(op_vec, n, k, 'LM', eigs_opts);

    else

        % eigs cannot return that many eigenvalues of a function-handle
        % operator: build the matrix column by column and use eig.
        A = zeros(n);
        In = eye(n);
        for j = 1 : n
            A(:, j) = op_vec(In(:, j));
        end
        if issym
            % Remove round-off asymmetry so that eig returns real values.
            A = (A + A')/2;
        end
        lambdas = eig(A);
        [unused, order] = sort(abs(lambdas), 'descend'); %#ok
        lambdas = lambdas(order(1:k));

    end

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/identify_linear_piece.m:
--------------------------------------------------------------------------------
function [range, poly] = identify_linear_piece(x, y, window_length)
% Identify a segment of the curve (x, y) that appears to be linear.
%
% function [range, poly] = identify_linear_piece(x, y, window_length)
%
% This function attempts to identify a contiguous segment of the curve
% defined by the vectors x and y that appears to be linear. A line is fit
% (least squares, via polyfit) through every window of window_length+1
% consecutive points, and the window whose fit has the smallest residual
% norm is retained. The output range holds the indices of that window, so
% that x(range) is the portion over which (x, y) is the most linear, and
% the output poly specifies the first order polynomial that best fits
% (x, y) over that range, following the usual matlab convention for
% polynomials (highest degree coefficients first).
%
% If x has window_length points or fewer, no window fits inside the curve
% and the whole curve is used as the single window instead.
%
% See also: checkdiff checkgradient checkhessian

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 8, 2013.
% Contributors:
% Change log:

    num_windows = length(x) - window_length;

    % Too few points to slide a window: fit the whole curve.
    if num_windows < 1
        range = 1 : length(x);
        poly = polyfit(x(range), y(range), 1);
        return;
    end

    residues = zeros(num_windows, 1);
    polys = zeros(2, num_windows);
    for i = 1 : num_windows
        window = i : i+window_length;
        [p, meta] = polyfit(x(window), y(window), 1);
        residues(i) = meta.normr;
        polys(:, i) = p';
    end
    [unused, best] = min(residues); %#ok
    range = best : best+window_length;
    poly = polys(:, best)';

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/identify_linear_piece_SinglePrecision.m:
--------------------------------------------------------------------------------
function [range, poly] = identify_linear_piece_SinglePrecision(x, y, window_length)
% Identify a segment of the curve (x, y) that appears to be linear
% (single-precision variant of identify_linear_piece).
%
% function [range, poly] = identify_linear_piece_SinglePrecision(x, y, window_length)
%
% This function attempts to identify a contiguous segment of the curve
% defined by the vectors x and y that appears to be linear. A line is fit
% (least squares, via polyfit) through every window of window_length+1
% consecutive points, and the window whose fit has the smallest residual
% norm is retained. The output range holds the indices of that window, so
% that x(range) is the portion over which (x, y) is the most linear, and
% the output poly (stored in single precision) specifies the first order
% polynomial that best fits (x, y) over that range, following the usual
% matlab convention for polynomials (highest degree coefficients first).
%
% If x has window_length points or fewer, no window fits inside the curve
% and the whole curve is used as the single window instead.
%
% See also: identify_linear_piece checkdiff checkgradient checkhessian

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 8, 2013.
% Contributors:
% Change log:

    num_windows = length(x) - window_length;

    % Too few points to slide a window: fit the whole curve.
    if num_windows < 1
        range = 1 : length(x);
        poly = single(polyfit(x(range), y(range), 1));
        return;
    end

    residues = zeros(num_windows, 1, 'single');
    polys = zeros(2, num_windows, 'single');
    for i = 1 : num_windows
        window = i : i+window_length;
        [p, meta] = polyfit(x(window), y(window), 1);
        residues(i) = meta.normr;
        polys(:, i) = p';
    end
    [unused, best] = min(residues); %#ok
    range = best : best+window_length;
    poly = polys(:, best)';

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/multiprod.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiwu-huang/SPDNet/527110a2209e918785075738b81c4ab091664e61/manopt/manopt/tools/multiprod.m
--------------------------------------------------------------------------------
/manopt/manopt/tools/multiprodmultitransp_license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2009, Paolo de Leva
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are
6 | met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in
12 | the documentation and/or other materials provided with the distribution
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
18 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24 | POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/manopt/manopt/tools/multiscale.m:
--------------------------------------------------------------------------------
function A = multiscale(scale, A)
% Multiplies the 2D slices in a 3D matrix by individual scalars.
%
% function A = multiscale(scale, A)
%
% Given a vector scale of length N and a 3-dimensional matrix A of size
% n-by-m-by-N, returns a matrix A of same size such that
%   A(:, :, k) := scale(k) * A(:, :, k);
% Complex scales are applied as given (not conjugated).
%
% See also: multiprod multitransp multitrace

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    assert(ndims(A) <= 3, ...
           ['multiscale is only well defined for matrix arrays of 3 ' ...
            'or less dimensions.']);
    N = size(A, 3);
    assert(numel(scale) == N, ...
           ['scale must be a vector whose length equals the third ' ...
            'dimension of A, that is, the number of 2D matrix slices ' ...
            'in the 3D matrix A.']);

    % Lay the scales out along the third dimension and broadcast. Using
    % reshape (rather than the conjugate transpose ') keeps complex scales
    % unconjugated, as documented.
    A = bsxfun(@times, A, reshape(scale, 1, 1, N));

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/multiskew.m:
--------------------------------------------------------------------------------
function Y = multiskew(X)
% Skew-symmetric parts of the 2D slices of a 3D matrix.
%
% function Y = multiskew(X)
%
% Returns Y, of the same size as X, in which every slice is the
% skew-symmetric part of the corresponding slice of X:
%   Y(:, :, i) = (X(:, :, i) - X(:, :, i).') / 2.
%
% See also: multiprod multitransp multiscale multisym

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 31, 2013.
% Contributors:
% Change log:

    Xt = multitransp(X);
    Y = (X - Xt) / 2;

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/multisym.m:
--------------------------------------------------------------------------------
function Y = multisym(X)
% Symmetric parts of the 2D slices of a 3D matrix.
%
% function Y = multisym(X)
%
% Returns Y, of the same size as X, in which every slice is the symmetric
% part of the corresponding slice of X:
%   Y(:, :, i) = (X(:, :, i) + X(:, :, i).') / 2.
%
% See also: multiprod multitransp multiscale multiskew

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 31, 2013.
% Contributors:
% Change log:

    Xt = multitransp(X);
    Y = (X + Xt) / 2;

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/multitrace.m:
--------------------------------------------------------------------------------
function tr = multitrace(A)
% Computes the traces of the 2D slices in a 3D matrix.
%
% function tr = multitrace(A)
%
% For a 3-dimensional matrix A of size n-by-n-by-N, returns a column vector
% tr of length N such that tr(k) = trace(A(:, :, k));
% Only diagonal entries are read, so Inf or NaN values off the diagonal do
% not affect the result. For non-square slices, the diagonal is summed up
% to min(n, m).
%
% See also: multiprod multitransp multiscale

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:


    assert(ndims(A) <= 3, ...
           ['multitrace is only well defined for matrix arrays of 3 ' ...
            'or less dimensions.']);

    [n, m, N] = size(A);
    k = min(n, m);

    % Linear indices of the diagonal entries (i, i) of the first slice are
    % (i-1)*(n+1) + 1; shifting by n*m per slice gives a k-by-N index grid
    % whose column j holds the diagonal of slice j.
    diag_idx = (0 : k-1)' * (n+1) + 1;
    idx = bsxfun(@plus, diag_idx, (0 : N-1) * (n*m));

    % Sum each column (one slice's diagonal) and return a column vector.
    tr = full(sum(reshape(A(idx), k, N), 1)).';

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/multitransp.m:
--------------------------------------------------------------------------------
function b = multitransp(a, dim)
%MULTITRANSP Transposing arrays of matrices.
%   B = MULTITRANSP(A) transposes (.') every matrix stored along the first
%   two dimensions of A; it is the same as MULTITRANSP(A, 1).
%
%   B = MULTITRANSP(A, DIM) treats A as a block array of P-by-Q matrices
%   laid out along dimensions DIM and DIM+1, and returns B holding the
%   Q-by-P transposes (.') of those matrices along the same dimensions. It
%   equals PERMUTE(A, [1:DIM-1, DIM+1, DIM, DIM+2:NDIMS(A)]). No complex
%   conjugation is performed.
%
%   Example:
%       A 5-by-9-by-3-by-2 array holds ten 9-by-3 matrices along
%       dimensions 2 and 3, i.e., it is a 5x(9x3)x2 block array.
%       C = MULTITRANSP(A, 2) is then a 5x(3x9)x2 array of 3x9 matrices.
%
%   See also PERMUTE, IPERMUTE, MULTIPROD, MULTITRACE, MULTISCALE.

% $ Version: 1.0 $
% CODE by:            Paolo de Leva (IUSM, Rome, IT) 2005 Sep 9
% COMMENTS by:        Code author                    2006 Nov 21
% OUTPUT tested by:   Code author                    2005 Sep 13
% -------------------------------------------------------------------------

% Matrices live along the first two dimensions unless told otherwise.
if nargin == 1
    dim = 1;
end

% Identity permutation, long enough to contain dimension dim+1, with the
% entries for dimensions dim and dim+1 exchanged.
order = 1 : max(ndims(a), dim + 1);
order([dim, dim + 1]) = [dim + 1, dim];
b = permute(a, order);

--------------------------------------------------------------------------------
/manopt/manopt/tools/plotprofile.m:
--------------------------------------------------------------------------------
function cost = plotprofile(problem, x, d, t)
% Plot the cost function along a geodesic or a retraction path.
%
% function plotprofile(problem, x, d, t)
% function cost = plotprofile(problem, x, d, t)
%
% For every entry t(i) of t (negative values are allowed), evaluates the
% cost at problem.M.exp(x, d, t(i)), i.e., along the curve produced by the
% manifold's exp map from x in the direction d.
%
% When called with an output, the costs (an array of the same size as t)
% are returned and nothing is drawn. When called without output, the cost
% is plotted against t.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 9, 2013.
% Contributors:
% Change log:

    % Verify that the problem description is sufficient.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end

    cost = zeros(size(t));
    for k = 1 : numel(t)
        % A fresh (empty) store is used for every evaluation.
        y = problem.M.exp(x, d, t(k));
        cost(k) = getCost(problem, y, struct());
    end

    if nargout ~= 0
        return;
    end

    plot(t, cost);
    xlabel('t');
    ylabel('f(Exp_x(t*d))');

end

--------------------------------------------------------------------------------
/manopt/manopt/tools/powermanifold.m:
--------------------------------------------------------------------------------
1 | function Mn = powermanifold(M, n)
2 | % Returns a structure describing a power manifold M^n = M x M x ... x M.
3 | %
4 | % function Mn = powermanifold(M, n)
5 | %
6 | % Input: a manifold structure M and an integer n >= 1.
7 | %
8 | % Output: a manifold structure Mn representing M x ... x M (n copies of M)
9 | % with the metric of M extended element-wise. Points and vectors are stored
10 | % as cells of size nx1.
11 | %
12 | % This code is for prototyping uses. The structures returned are often
13 | % inefficient representations of power manifolds owing to their use of
14 | % for-loops, but they should allow to rapidly try out an idea.
15 | %
16 | % Example (an inefficient representation of the oblique manifold (3, 10)):
17 | % Mn = powermanifold(spherefactory(3), 10)
18 | % disp(Mn.name());
19 | % x = Mn.rand()
20 | %
21 | % See also: productmanifold
22 |
23 | % This file is part of Manopt: www.manopt.org.
24 | % Original author: Nicolas Boumal, Dec. 30, 2012.
25 | % Contributors:
26 | % Change log:
27 | % NB, July 4, 2013: Added support for vec, mat, tangent.
28 | % Added support for egrad2rgrad and ehess2rhess.
29 |
30 |
31 | assert(n >= 1, 'n must be an integer larger than or equal to 1.');
32 |
33 | Mn.name = @() sprintf('[%s]^%d', M.name(), n);
34 |
35 | Mn.dim = @() n*M.dim();
36 |
37 | Mn.inner = @inner;
38 | function val = inner(x, u, v)
39 | val = 0;
40 | for i = 1 : n
41 | val = val + M.inner(x{i}, u{i}, v{i});
42 | end
43 | end
44 |
45 | Mn.norm = @(x, d) sqrt(Mn.inner(x, d, d));
46 |
47 | Mn.dist = @dist;
48 | function d = dist(x, y)
49 | sqd = 0;
50 | for i = 1 : n
51 | sqd = sqd + M.dist(x{i}, y{i})^2;
52 | end
53 | d = sqrt(sqd);
54 | end
55 |
56 | Mn.typicaldist = @typicaldist;
57 | function d = typicaldist()
58 | sqd = 0;
59 | for i = 1 : n
60 | sqd = sqd + M.typicaldist()^2;
61 | end
62 | d = sqrt(sqd);
63 | end
64 |
65 | Mn.proj = @proj;
66 | function u = proj(x, u)
67 | for i = 1 : n
68 | u{i} = M.proj(x{i}, u{i});
69 | end
70 | end
71 |
72 | Mn.tangent = @tangent;
73 | function u = tangent(x, u)
74 | for i = 1 : n
75 | u{i} = M.tangent(x{i}, u{i});
76 | end
77 | end
78 |
79 | if isfield(M, 'tangent2ambient')
80 | Mn.tangent2ambient = @tangent2ambient;
81 | else
82 | Mn.tangent2ambient = @(x, u) u;
83 | end
84 | function u = tangent2ambient(x, u)
85 | for i = 1 : n
86 | u{i} = M.tangent2ambient(x{i}, u{i});
87 | end
88 | end
89 |
90 | Mn.egrad2rgrad = @egrad2rgrad;
91 | function g = egrad2rgrad(x, g)
92 | for i = 1 : n
93 | g{i} = M.egrad2rgrad(x{i}, g{i});
94 | end
95 | end
96 |
97 | Mn.ehess2rhess = @ehess2rhess;
98 | function h = ehess2rhess(x, eg, eh, h)
99 | for i = 1 : n
100 | h{i} = M.ehess2rhess(x{i}, eg{i}, eh{i}, h{i});
101 | end
102 | end
103 |
104 | Mn.exp = @expo;
105 | function x = expo(x, u, t)
106 | if nargin < 3
107 | t = 1.0;
108 | end
109 | for i = 1 : n
110 | x{i} = M.exp(x{i}, u{i}, t);
111 | end
112 | end
113 |
114 | Mn.retr = @retr;
115 | function x = retr(x, u, t)
116 | if nargin < 3
117 | t = 1.0;
118 | end
119 | for i = 1 : n
120 | x{i} = M.retr(x{i}, u{i}, t);
121 | end
122 | end
123 |
124 | if isfield(M, 'log')
125 | Mn.log = @loga;
126 | end
127 | function u = loga(x, y)
128 | u = cell(n, 1);
129 | for i = 1 : n
130 | u{i} = M.log(x{i}, y{i});
131 | end
132 | end
133 |
134 | Mn.hash = @hash;
135 | function str = hash(x)
136 | str = '';
137 | for i = 1 : n
138 | str = [str M.hash(x{i})]; %#ok
139 | end
140 | str = ['z' hashmd5(str)];
141 | end
142 |
143 | Mn.lincomb = @lincomb;
144 | function x = lincomb(x, a1, u1, a2, u2)
145 | if nargin == 3
146 | for i = 1 : n
147 | x{i} = M.lincomb(x{i}, a1, u1{i});
148 | end
149 | elseif nargin == 5
150 | for i = 1 : n
151 | x{i} = M.lincomb(x{i}, a1, u1{i}, a2, u2{i});
152 | end
153 | else
154 | error('Bad usage of powermanifold.lincomb');
155 | end
156 | end
157 |
158 | Mn.rand = @rand;
159 | function x = rand()
160 | x = cell(n, 1);
161 | for i = 1 : n
162 | x{i} = M.rand();
163 | end
164 | end
165 |
166 | Mn.randvec = @randvec;
167 | function u = randvec(x)
168 | u = cell(n, 1);
169 | for i = 1 : n
170 | u{i} = M.randvec(x{i});
171 | end
172 | u = Mn.lincomb(x, 1/sqrt(n), u);
173 | end
174 |
175 | Mn.zerovec = @zerovec;
176 | function u = zerovec(x)
177 | u = cell(n, 1);
178 | for i = 1 : n
179 | u{i} = M.zerovec(x{i});
180 | end
181 | end
182 |
183 | if isfield(M, 'transp')
184 | Mn.transp = @transp;
185 | end
186 | function u = transp(x1, x2, u)
187 | for i = 1 : n
188 | u{i} = M.transp(x1{i}, x2{i}, u{i});
189 | end
190 | end
191 |
192 | if isfield(M, 'pairmean')
193 | Mn.pairmean = @pairmean;
194 | end
195 | function y = pairmean(x1, x2)
196 | y = cell(n, 1);
197 | for i = 1 : n
198 | y{i} = M.pairmean(x1{i}, x2{i});
199 | end
200 | end
201 |
202 | % Compute the length of a vectorized tangent vector of M at x, assuming
203 | % this length is independent of the point x (that should be fine).
204 | if isfield(M, 'vec')
205 | rand_x = M.rand();
206 | zero_u = M.zerovec(rand_x);
207 | len_vec = length(M.vec(rand_x, zero_u));
208 |
209 | Mn.vec = @vec;
210 |
211 | if isfield(M, 'mat')
212 | Mn.mat = @mat;
213 | end
214 |
215 | end
216 |
217 | function u_vec = vec(x, u_mat)
218 | u_vec = zeros(len_vec, n);
219 | for i = 1 : n
220 | u_vec(:, i) = M.vec(x{i}, u_mat{i});
221 | end
222 | u_vec = u_vec(:);
223 | end
224 |
225 | function u_mat = mat(x, u_vec)
226 | u_mat = cell(n, 1);
227 | u_vec = reshape(u_vec, len_vec, n);
228 | for i = 1 : n
229 | u_mat{i} = M.mat(x{i}, u_vec(:, i));
230 | end
231 | end
232 |
233 | if isfield(M, 'vecmatareisometries')
234 | Mn.vecmatareisometries = M.vecmatareisometries;
235 | else
236 | Mn.vecmatareisometries = @() false;
237 | end
238 |
239 | end
240 |
--------------------------------------------------------------------------------
/manopt/manopt_version.m:
--------------------------------------------------------------------------------
1 | function [version, released] = manopt_version()
2 | % MANOPT_VERSION  Version number and release date of this Manopt copy.
3 | %
4 | % function [version, released] = manopt_version()
5 | %
6 | % version  : row vector of version components; version(1) is the primary
7 | %            version number.
8 | % released : release date of this version, as a string in the same format
9 | %            as the output of Matlab's date() function.
10 |
11 | released = '12-Aug-2014';
12 | version = [1 0 7];
13 |
--------------------------------------------------------------------------------
/spdnet/vl_mybfc.m:
--------------------------------------------------------------------------------
1 | function [Y, Y_w] = vl_mybfc(X, W, dzdy)
2 | %VL_MYBFC  BiMap layer of SPDNet.
3 | %   Y = VL_MYBFC(X, W) returns the cell array Y{i} = W' * X{i} * W for
4 | %   every matrix of the cell array X (the second output is then []).
5 | %
6 | %   [DZDX, DZDW] = VL_MYBFC(X, W, DZDY) runs the backward pass instead:
7 | %   DZDX{i} = W * G{i} * W' and DZDW = sum_i 2 * X{i} * W * G{i}, where
8 | %   G{i} is DZDY{i} if DZDY is a cell array, or column i of DZDY reshaped
9 | %   to a size(W,2)-by-size(W,2) matrix otherwise. The factor 2 in DZDW is
10 | %   exact when X{i} and G{i} are symmetric.
11 |
12 | numMats = length(X);
13 | Y = cell(numMats, 1);
14 |
15 | if nargin < 3
16 |     % forward pass
17 |     for ix = 1 : numMats
18 |         Y{ix} = W' * X{ix} * W;
19 |     end
20 |     Y_w = [];
21 |     return;
22 | end
23 |
24 | % backward pass: the forward products W'*X{i}*W are not needed here
25 | [dim_ori, dim_tar] = size(W);
26 | Y_w = zeros(dim_ori, dim_tar);
27 | gradIsCell = iscell(dzdy);
28 | for ix = 1 : numMats
29 |     if gradIsCell
30 |         d_t = dzdy{ix};
31 |     else
32 |         d_t = reshape(dzdy(:, ix), [dim_tar dim_tar]);
33 |     end
34 |     Y{ix} = W * d_t * W';
35 |     Y_w = Y_w + 2 * X{ix} * W * d_t;
36 | end
37 |
--------------------------------------------------------------------------------
/spdnet/vl_myfc.m:
--------------------------------------------------------------------------------
1 | function [Y, Y_w] = vl_myfc(X, W, dzdy)
2 | %VL_MYFC  Regular fully connected layer on vectorised matrices.
3 | %   Y = VL_MYFC(X, W) stacks X{i}(:) as column i of a matrix X_t and
4 | %   returns W' * X_t (the second output is then []).
5 | %   [DZDX, DZDW] = VL_MYFC(X, W, DZDY) returns DZDX = W * DZDY (in the
6 | %   same vectorised layout, one column per input) and DZDW = X_t * DZDY'.
7 | %   X_t has size(X{1},1)^2 rows, so each X{i} is expected to be square;
8 | %   an empty X gives a size(W,1)-by-0 X_t.
9 |
10 | numMats = length(X);
11 | if numMats == 0
12 |     numRows = size(W, 1);
13 | else
14 |     numRows = size(X{1}, 1)^2;
15 | end
16 | X_t = zeros(numRows, numMats);
17 | for ix = 1 : numMats
18 |     X_t(:, ix) = X{ix}(:);
19 | end
20 |
21 | if nargin < 3
22 |     Y = W' * X_t;
23 |     Y_w = [];
24 | else
25 |     Y = W * dzdy;
26 |     Y_w = X_t * dzdy';
27 | end
--------------------------------------------------------------------------------
/spdnet/vl_myforbackward.m:
--------------------------------------------------------------------------------
1 | function res = vl_myforbackward(net, x, dzdy, res, varargin)
2 | % VL_MYFORBACKWARD  Evaluate a simple SPDNet: forward pass, then a backward pass if DZDY is non-empty.
3 | %   RES(1).x = X and RES(i+1).x is the output of layer i; the backward pass fills RES(i).dzdx, and RES(i).dzdw for 'bfc'/'fc' layers.
4 | opts.res = [] ; % (not read in this function)
5 | opts.conserveMemory = false ;
6 | opts.sync = false ;
7 | opts.disableDropout = false ; % (not read in this function)
8 | opts.freezeDropout = false ; % (not read in this function)
9 | opts.accumulate = false ; % (not read in this function)
10 | opts.cudnn = true ;
11 | opts.skipForward = false;
12 | opts.backPropDepth = +inf ;
13 | opts.epsilon = 1e-4; % threshold passed to vl_myrec by 'rec' layers
14 | % NOTE(review): VARARGIN is never parsed (the vl_argparse call below is commented out), so the defaults above always apply.
15 | % opts = vl_argparse(opts, varargin);
16 |
17 | n = numel(net.layers) ;
18 | % a backward pass is run only when DZDY is supplied and non-empty
19 | if (nargin <= 2) || isempty(dzdy)
20 | doder = false ;
21 | else
22 | doder = true ;
23 | end
24 | % NOTE(review): cudnn is set below but not read anywhere in this function
25 | if opts.cudnn
26 | cudnn = {'CuDNN'} ;
27 | else
28 | cudnn = {'NoCuDNN'} ;
29 | end
30 |
31 | gpuMode = isa(x, 'gpuArray') ; % NOTE(review): process_epoch passes a cell array here, for which this is false
32 | % allocate one record per layer input, plus one for the network output
33 | if nargin <= 3 || isempty(res)
34 | res = struct(...
35 | 'x', cell(1,n+1), ...
36 | 'dzdx', cell(1,n+1), ...
37 | 'dzdw', cell(1,n+1), ...
38 | 'aux', cell(1,n+1), ...
39 | 'time', num2cell(zeros(1,n+1)), ...
40 | 'backwardTime', num2cell(zeros(1,n+1))) ;
41 | end
42 | if ~opts.skipForward
43 | res(1).x = x ;
44 | end
45 |
46 |
47 | % -------------------------------------------------------------------------
48 | % Forward pass
49 | % -------------------------------------------------------------------------
50 |
51 | for i=1:n
52 | if opts.skipForward, break; end; % reuse the outputs already stored in RES
53 | l = net.layers{i} ;
54 | res(i).time = tic ;
55 | switch l.type
56 | case 'bfc'
57 | res(i+1).x = vl_mybfc(res(i).x, l.weight) ;
58 | case 'fc'
59 | res(i+1).x = vl_myfc(res(i).x, l.weight) ;
60 | case 'rec'
61 | res(i+1).x = vl_myrec(res(i).x, opts.epsilon) ;
62 | case 'log'
63 | res(i+1).x = vl_mylog(res(i).x) ;
64 | case 'softmaxloss'
65 | res(i+1).x = vl_mysoftmaxloss(res(i).x, l.class) ;
66 |
67 | case 'custom'
68 | res(i+1) = l.forward(l, res(i), res(i+1)) ;
69 | otherwise
70 | error('Unknown layer type %s', l.type) ;
71 | end
72 | % optionally forget intermediate results
73 | forget = opts.conserveMemory ;
74 | forget = forget & (~doder || strcmp(l.type, 'relu')) ; % keep inputs when a backward pass will need them
75 | forget = forget & ~(strcmp(l.type, 'loss') || strcmp(l.type, 'softmaxloss')) ;
76 | forget = forget & (~isfield(l, 'rememberOutput') || ~l.rememberOutput) ;
77 | if forget
78 | res(i).x = [] ;
79 | end
80 | if gpuMode & opts.sync
81 | % This should make things slower, but on MATLAB 2014a it is necessary
82 | % for any decent performance.
83 | wait(gpuDevice) ;
84 | end
85 | res(i).time = toc(res(i).time) ;
86 | end
87 |
88 | % -------------------------------------------------------------------------
89 | % Backward pass
90 | % -------------------------------------------------------------------------
91 |
92 | if doder
93 | res(n+1).dzdx = dzdy ; % seed: derivative w.r.t. the network output
94 | for i=n:-1:max(1, n-opts.backPropDepth+1) % backPropDepth = +inf visits every layer
95 | l = net.layers{i} ;
96 | res(i).backwardTime = tic ;
97 | switch l.type
98 | case 'bfc'
99 | [res(i).dzdx, res(i).dzdw] = ...
100 | vl_mybfc(res(i).x, l.weight, res(i+1).dzdx) ;
101 | case 'fc'
102 | [res(i).dzdx, res(i).dzdw] = ...
103 | vl_myfc(res(i).x, l.weight, res(i+1).dzdx) ;
104 | case 'rec'
105 | res(i).dzdx = vl_myrec(res(i).x, opts.epsilon, res(i+1).dzdx) ;
106 | case 'log'
107 | res(i).dzdx = vl_mylog(res(i).x, res(i+1).dzdx) ;
108 | case 'softmaxloss'
109 | res(i).dzdx = vl_mysoftmaxloss(res(i).x, l.class, res(i+1).dzdx) ;
110 | case 'custom'
111 | res(i) = l.backward(l, res(i), res(i+1)) ;
112 | end % no 'otherwise': other layer types are skipped in the backward pass
113 | if opts.conserveMemory
114 | res(i+1).dzdx = [] ;
115 | end
116 | if gpuMode & opts.sync
117 | wait(gpuDevice) ;
118 | end
119 | res(i).backwardTime = toc(res(i).backwardTime) ;
120 | end
121 | end
122 |
123 |
--------------------------------------------------------------------------------
/spdnet/vl_mylog.m:
--------------------------------------------------------------------------------
1 | function Y = vl_mylog(X, dzdy)
2 | %Y = VL_MYLOG(X, DZDY)
3 | %LogEig layer. Forward (one input): Y{i} = U*diag(log(diag(S)))*U' with [U,S,~] = svd(X{i}).
4 | %Backward (two inputs): Y{i} = dL/dX{i}, where column i of DZDY holds the D*D entries of dL/dY{i}.
5 | Us = cell(length(X),1);
6 | Ss = cell(length(X),1);
7 | Vs = cell(length(X),1);
8 | % decompose every input; NOTE(review): assumes each X{i} is SPD, so that its SVD is its eigendecomposition -- confirm
9 | for ix = 1 : length(X)
10 | [Us{ix},Ss{ix},Vs{ix}] = svd(X{ix});
11 | end
12 | % D = size(X{1},2); all inputs are assumed to share this size
13 |
14 | D = size(Ss{1},2);
15 | Y = cell(length(X),1);
16 | % forward pass: matrix logarithm through the decomposition
17 | if nargin < 2
18 | for ix = 1:length(X)
19 | Y{ix} = Us{ix}*diag(log(diag(Ss{ix})))*Us{ix}';
20 | end
21 | else % backward pass
22 | for ix = 1:length(X)
23 | U = Us{ix}; S = Ss{ix}; V = Vs{ix}; % (V is not used below)
24 | diagS = diag(S);
25 | ind =diagS >(D*eps(max(diagS))); % keep singular values above a relative tolerance
26 | Dmin = (min(find(ind,1,'last'),D)); % index of the last kept value (capped at D)
27 |
28 | S = S(:,ind); U = U(:,ind); % S and U become D-by-k, with k = sum(ind)
29 | dLdC = double(reshape(dzdy(:,ix),[D D])); dLdC = symmetric(dLdC); % column ix of dzdy as a D-by-D matrix, symmetrised
30 | % gradients of U*log(S)*U' w.r.t. U (dLdV) and S (dLdS)
31 | % NOTE(review): for k < D, U and diagLog(S,0) are both D-by-k, so the product on the next line needs k == D -- confirm this case cannot occur
32 | dLdV = 2*dLdC*U*diagLog(S,0);
33 | dLdS = diagInv(S)*(U'*dLdC*U);
34 | if sum(ind) == 1 % diag behaves badly when there is only 1d
35 | K = 1./(S(1)*ones(1,Dmin)-(S(1)*ones(1,Dmin))');
36 | K(eye(size(K,1))>0)=0;
37 | else
38 | K = 1./(diag(S)*ones(1,Dmin)-(diag(S)*ones(1,Dmin))'); % K(i,j) = 1/(s_i - s_j)
39 | K(eye(size(K,1))>0)=0;
40 | K(find(isinf(K)==1))=0;
41 | end
42 | if all(diagS==1) % NOTE(review): at X = I the derivative of logm is the identity map (dL/dX = dLdC), not zero -- confirm this shortcut is intended
43 | dzdx = zeros(D,D);
44 | else
45 | dzdx = U*(symmetric(K'.*(U'*dLdV))+dDiag(dLdS))*U';
46 |
47 | end
48 | Y{ix} = dzdx; %warning('no normalization');
49 | end
50 | end
51 |
--------------------------------------------------------------------------------
/spdnet/vl_myrec.m:
--------------------------------------------------------------------------------
1 | function Y = vl_myrec(X, epsilon, dzdy)
2 | % Y = VL_MYREC (X, EPSILON, DZDY)
3 | % ReEig layer (eigenvalue rectification).
4 | %   Forward (two inputs): for each X{i} with [U, S, ~] = svd(X{i}),
5 | %     Y{i} = U * S_r * U', where S_r is S with its diagonal clamped from
6 | %     below at EPSILON (see max_eig).
7 | %   Backward (three inputs): Y{i} is the derivative of the loss w.r.t.
8 | %     X{i}, given DZDY{i}, the derivative w.r.t. the forward output.
9 |
10 | numMats = length(X);
11 | Uall = cell(numMats, 1);
12 | Sall = cell(numMats, 1);
13 | for k = 1 : numMats
14 |     [Uall{k}, Sall{k}, ~] = svd(X{k});
15 | end
16 |
17 | D = size(Sall{1}, 2);
18 | Y = cell(numMats, 1);
19 |
20 | if nargin < 3
21 |     for k = 1 : numMats
22 |         S_rect = max_eig(Sall{k}, epsilon);
23 |         Y{k} = Uall{k} * S_rect * Uall{k}';
24 |     end
25 |     return;
26 | end
27 |
28 | for k = 1 : numMats
29 |     U = Uall{k};
30 |     S = Sall{k};
31 |     G = symmetric(double(dzdy{k}));          % symmetrised dL/dY{k}
32 |     [S_rect, isClamped] = max_eig(S, epsilon);
33 |     dLdV = 2 * G * U * S_rect;               % gradient w.r.t. U
34 |     dLdS = diag(~isClamped) * U' * G * U;    % gradient w.r.t. S, rows of clamped values zeroed
35 |     sigma = diag(S) * ones(1, D);
36 |     K = 1 ./ (sigma - sigma');               % K(i,j) = 1/(s_i - s_j)
37 |     K(eye(size(K, 1)) > 0) = 0;
38 |     K(isinf(K)) = 0;                         % repeated values
39 |     Y{k} = U * (symmetric(K' .* (U' * dLdV)) + dDiag(dLdS)) * U';
40 | end
41 |
--------------------------------------------------------------------------------
/spdnet/vl_mysoftmaxloss.m:
--------------------------------------------------------------------------------
1 | function Y = vl_mysoftmaxloss(X,c,dzdy)
2 | %VL_MYSOFTMAXLOSS  Softmax log-loss layer.
3 | %   Y = VL_MYSOFTMAXLOSS(X, C) returns the summed softmax log-loss of the
4 | %   columns of X (one column per sample, one row per class), where C(i) is
5 | %   the row of the true class of column i.
6 | %   Y = VL_MYSOFTMAXLOSS(X, C, DZDY) returns the derivative of that loss
7 | %   w.r.t. X, multiplied by DZDY.
8 | %   A label C(i) = 0 skips sample i: it contributes neither to the loss nor
9 | %   to the derivative. C may be a row or a column vector.
10 |
11 | labels = reshape(c, 1, []);   % always a row, one entry per column of X
12 | isLabelled = labels > 0;
13 |
14 | % class c = 0 skips a spatial location
15 | mass = single(isLabelled) ;
16 |
17 | % linear indices of the true-class entries; skipped samples point at row 1
18 | % of their own column (a placeholder cancelled by their zero mass)
19 | rows = labels;
20 | rows(~isLabelled) = 1;
21 | idx = rows + (0 : numel(labels) - 1) * size(X,1);
22 |
23 | % compute softmaxloss (shifted by the column maximum for stability)
24 | Xmax = max(X,[],1) ;
25 | ex = exp(bsxfun(@minus, X, Xmax)) ;
26 |
27 | if nargin < 3
28 |     t = Xmax + log(sum(ex,1)) - reshape(X(idx), [1 size(X,2)]) ;
29 |     Y = sum(sum(mass .* t,1)) ;
30 | else
31 |     Y = bsxfun(@rdivide, ex, sum(ex,1)) ;
32 |     % subtract 1 at the true class of labelled samples only
33 |     trueIdx = idx(isLabelled);
34 |     Y(trueIdx) = Y(trueIdx) - 1;
35 |     Y = bsxfun(@times, Y, bsxfun(@times, mass, dzdy)) ;
36 | end
--------------------------------------------------------------------------------
/spdnet_afew.m:
--------------------------------------------------------------------------------
1 | function [net, info] = spdnet_afew(varargin)
2 | %SPDNET_AFEW  Train an SPDNet on the AFEW data set.
3 | %   [NET, INFO] = SPDNET_AFEW() sets up the path (confPath), builds the
4 | %   network (spdnet_init_afew), loads the data-set description SPD_TRAIN
5 | %   from OPTS.IMDBPATHTRAIN and trains the network (spdnet_train_afew).
6 | %   Input arguments are currently ignored.
7 |
8 | %set up the path
9 | confPath;
10 | %parameter setting
11 | opts.dataDir = fullfile('./data/afew') ;
12 | opts.imdbPathtrain = fullfile(opts.dataDir, 'spddb_afew_train_spd400_int_histeq.mat');
13 | opts.batchSize = 30 ;
14 | opts.test.batchSize = 1;
15 | opts.numEpochs = 500 ;
16 | opts.gpus = [] ;
17 | opts.learningRate = 0.01*ones(1,opts.numEpochs);
18 | opts.weightDecay = 0.0005 ;
19 | opts.continue = 1;   % resume from saved epochs in opts.dataDir, if any
20 | %spdnet initialization
21 | net = spdnet_init_afew() ;
22 | %loading metadata into a struct (instead of creating variables implicitly)
23 | data = load(opts.imdbPathtrain, 'spd_train') ;
24 | if ~isfield(data, 'spd_train')
25 |     error('spdnet_afew:missingVariable', ...
26 |         'File %s does not contain the variable spd_train.', opts.imdbPathtrain) ;
27 | end
28 | spd_train = data.spd_train ;
29 | %spdnet training
30 | [net, info] = spdnet_train_afew(net, spd_train, opts);
31 |
--------------------------------------------------------------------------------
/spdnet_init_afew.m:
--------------------------------------------------------------------------------
1 | function net = spdnet_init_afew(varargin)
2 | % SPDNET_INIT_AFEW  Build and initialise the SPDNet used for AFEW.
3 | %   NET = SPDNET_INIT_AFEW() returns a struct whose LAYERS field is the
4 | %   cell array of layers
5 | %     bfc, rec, bfc, rec, bfc, log, fc, softmaxloss
6 | %   i.e. OPTS.LAYERNUM BiMap ('bfc') layers separated by ReEig ('rec')
7 | %   layers, then LogEig, a fully connected layer and the softmax loss.
8 | %   The k-th 'bfc' weight is a datadim(k)-by-datadim(k+1) matrix with
9 | %   orthonormal columns (leading left singular vectors of A*A' for a random
10 | %   A); the 'fc' weight is single-precision Gaussian noise scaled by 1/100.
11 | %   The global random generator is reset (rng(0)), so the result is
12 | %   reproducible. Input arguments are currently ignored.
13 |
14 | rng('default');
15 | rng(0) ;
16 |
17 | opts.layernum = 3;
18 | opts.datadim = [400, 200, 100, 50];
19 | opts.classNum = 7;
20 | opts.fcScale = 1/100;
21 | assert(numel(opts.datadim) >= opts.layernum + 1, ...
22 |     'spdnet_init_afew: datadim needs layernum+1 entries.');
23 |
24 | % BiMap weights: one rand() draw per layer, in layer order
25 | Winit = cell(opts.layernum, 1);
26 | for iw = 1 : opts.layernum
27 |     A = rand(opts.datadim(iw));
28 |     [U1, ~, ~] = svd(A * A');
29 |     Winit{iw} = U1(:, 1:opts.datadim(iw+1));
30 | end
31 |
32 | % fully connected layer on the vectorised output of the last BiMap layer
33 | fdim = opts.datadim(opts.layernum+1)^2;
34 | theta = opts.fcScale * randn(fdim, opts.classNum, 'single');
35 |
36 | net.layers = {} ;
37 | for iw = 1 : opts.layernum
38 |     net.layers{end+1} = struct('type', 'bfc', 'weight', Winit{iw}) ;
39 |     if iw < opts.layernum
40 |         net.layers{end+1} = struct('type', 'rec') ;
41 |     end
42 | end
43 | net.layers{end+1} = struct('type', 'log') ;
44 | net.layers{end+1} = struct('type', 'fc', 'weight', theta) ;
45 | net.layers{end+1} = struct('type', 'softmaxloss') ;
46 |
--------------------------------------------------------------------------------
/spdnet_train_afew.m:
--------------------------------------------------------------------------------
1 | function [net, info] = spdnet_train_afew(net, spd_train, opts)
2 | %SPDNET_TRAIN_AFEW  Train an SPDNet by mini-batch SGD on the SET==1 samples of
3 | %   SPD_TRAIN.SPD, validating on the SET==2 samples after every epoch. Each
4 | %   epoch is saved to OPTS.DATADIR/net-epoch-<k>.mat and plotted to
5 | %   net-train.pdf; with OPTS.CONTINUE set, saved epochs are skipped/reloaded.
6 |
7 | opts.errorLabels = {'top1e'};
8 | opts.train = find(spd_train.spd.set==1) ;
9 | opts.val = find(spd_train.spd.set==2) ;
10 | % checkpoint and figure paths do not depend on the epoch
11 | modelPath = @(ep) fullfile(opts.dataDir, sprintf('net-epoch-%d.mat', ep));
12 | modelFigPath = fullfile(opts.dataDir, 'net-train.pdf') ;
13 |
14 | for epoch=1:opts.numEpochs
15 |     learningRate = opts.learningRate(epoch);
16 |     % fast-forward to last checkpoint
17 |     if opts.continue
18 |         if exist(modelPath(epoch),'file')
19 |             if epoch == opts.numEpochs
20 |                 load(modelPath(epoch), 'net', 'info') ;
21 |             end
22 |             continue ;
23 |         end
24 |         if epoch > 1
25 |             fprintf('resuming by loading epoch %d\n', epoch-1) ;
26 |             load(modelPath(epoch-1), 'net', 'info') ;
27 |         end
28 |     end
29 |
30 |     train = opts.train(randperm(length(opts.train))) ; % shuffle
31 |     val = opts.val;
32 |
33 |     [net,stats.train] = process_epoch(opts, epoch, spd_train, train, learningRate, net) ;
34 |     [net,stats.val] = process_epoch(opts, epoch, spd_train, val, 0, net) ;
35 |
36 |     evaluateMode = 0; % save
37 |     if evaluateMode, sets = {'train'} ; else sets = {'train', 'val'} ; end
38 |     % number of samples per set, looked up by name (no eval on local variables)
39 |     numSamples = struct('train', numel(train), 'val', numel(val)) ;
40 |     for f = sets
41 |         f = char(f) ; %#ok<FXSET>
42 |         n = numSamples.(f) ;
43 |         info.(f).objective(epoch) = stats.(f)(2) / n ;
44 |         info.(f).error(:,epoch) = stats.(f)(3:end) / n ;
45 |     end
46 |     if ~evaluateMode, save(modelPath(epoch), 'net', 'info') ; end
47 |
48 |     % learning curves
49 |     figure(1) ; clf ;
50 |     hasError = 1 ;
51 |     subplot(1,1+hasError,1) ;
52 |     if ~evaluateMode
53 |         semilogy(1:epoch, info.train.objective, '.-', 'linewidth', 2) ;
54 |         hold on ;
55 |     end
56 |     semilogy(1:epoch, info.val.objective, '.--') ;
57 |     xlabel('training epoch') ; ylabel('energy') ;
58 |     grid on ;
59 |     h=legend(sets) ;
60 |     set(h,'color','none');
61 |     title('objective') ;
62 |     if hasError
63 |         subplot(1,2,2) ; leg = {} ;
64 |         if ~evaluateMode
65 |             plot(1:epoch, info.train.error', '.-', 'linewidth', 2) ;
66 |             hold on ;
67 |             leg = horzcat(leg, strcat('train ', opts.errorLabels)) ;
68 |         end
69 |         plot(1:epoch, info.val.error', '.--') ;
70 |         leg = horzcat(leg, strcat('val ', opts.errorLabels)) ;
71 |         set(legend(leg{:}),'color','none') ;
72 |         grid on ;
73 |         xlabel('training epoch') ; ylabel('error') ;
74 |         title('error') ;
75 |     end
76 |     drawnow ;
77 |     print(1, modelFigPath, '-dpdf') ;
78 | end
79 | function [net,stats] = process_epoch(opts, epoch, spd_train, trainInd, learningRate, net)
80 | % PROCESS_EPOCH  Run one pass over the samples TRAININD of SPD_TRAIN.
81 | %   With LEARNINGRATE > 0 every mini-batch is forwarded, back-propagated
82 | %   and used for a weight update; otherwise the samples are only evaluated.
83 | %   STATS = [summed batch time; summed loss; number of top-1 errors].
84 |
85 | training = learningRate > 0 ;
86 | if training, mode = 'training' ; else mode = 'validation' ; end
87 |
88 | stats = [0 ; 0 ; 0] ;
89 | numGpus = numel(opts.gpus) ;
90 | if numGpus >= 1
91 |     one = gpuArray(single(1)) ;
92 | else
93 |     one = single(1) ;
94 | end
95 | % NOTE(review): the multi-GPU branch below also needs opts.memoryMapFile,
96 | % map_gradients and write_gradients, none of which are defined in this file
97 | mmap = [] ;
98 |
99 | batchSize = opts.batchSize;
100 | numSamples = length(trainInd);
101 | numDone = 0 ;
102 |
103 | for ib = 1 : batchSize : numSamples
104 |     fprintf('%s: epoch %02d: batch %3d/%3d:', mode, epoch, ib, numSamples) ;
105 |     batchTime = tic ;
106 |     res = [];
107 |     % the last mini-batch may be smaller than batchSize
108 |     batchSize_r = min(batchSize, numSamples - ib + 1);
109 |     spd_data = cell(batchSize_r,1);
110 |     spd_label = zeros(batchSize_r,1);
111 |     for ib_r = 1 : batchSize_r
112 |         sampleInd = trainInd(ib+ib_r-1);
113 |         % fullfile instead of a hard-coded '\' so the path also works off Windows
114 |         spdPath = fullfile(spd_train.spdDir, spd_train.spd.name{sampleInd});
115 |         sample = load(spdPath, 'Y1'); spd_data{ib_r} = sample.Y1;
116 |         spd_label(ib_r) = spd_train.spd.label(sampleInd);
117 |     end
118 |     net.layers{end}.class = spd_label ;
119 |
120 |     %forward/backward spdnet
121 |     if training, dzdy = one; else dzdy = [] ; end
122 |     res = vl_myforbackward(net, spd_data, dzdy, res) ;
123 |
124 |     %accumulating gradients
125 |     if numGpus <= 1
126 |         [net,res] = accumulate_gradients(opts, learningRate, batchSize_r, net, res) ;
127 |     else
128 |         if isempty(mmap)
129 |             mmap = map_gradients(opts.memoryMapFile, net, res, numGpus) ;
130 |         end
131 |         write_gradients(mmap, net, res) ;
132 |         labBarrier() ;
133 |         [net,res] = accumulate_gradients(opts, learningRate, batchSize_r, net, res, mmap) ;
134 |     end
135 |
136 |     % accumulate top-1 errors (predicted class = highest score per column)
137 |     predictions = gather(res(end-1).x) ;
138 |     [~,pre_label] = sort(predictions, 'descend') ;
139 |     numErrors = sum(~bsxfun(@eq, pre_label(1,:)', spd_label)) ;
140 |
141 |     numDone = numDone + batchSize_r ;
142 |     batchTime = toc(batchTime) ;
143 |     speed = batchSize_r/batchTime ;   % actual size of this (possibly last) batch
144 |     stats = stats+[batchTime ; res(end).x ; numErrors];
145 |
146 |     fprintf(' %.2f s (%.1f data/s)', batchTime, speed) ;
147 |     fprintf(' error: %.5f', stats(3)/numDone) ;
148 |     fprintf(' obj: %.5f', stats(2)/numDone) ;
149 |     fprintf(' [%d/%d]', numDone, numSamples);
150 |     fprintf('\n') ;
151 | end
152 |
153 |
154 | % -------------------------------------------------------------------------
155 | function [net,res] = accumulate_gradients(opts, lr, batchSize, net, res, mmap)
156 | % -------------------------------------------------------------------------
157 | % ACCUMULATE_GRADIENTS  Take one SGD step for every layer that has a weight
158 | % gradient in RES(l).DZDW, averaged over BATCHSIZE samples. 'bfc' weights
159 | % are updated on the Stiefel manifold (Riemannian gradient from Manopt's
160 | % stiefelfactory, followed by its retraction); any other weight takes a
161 | % plain Euclidean step. Missing per-layer learningRate / weightDecay fields
162 | % are set to 1; weightDecay is not used by the update itself.
163 | % OPTS and MMAP are accepted for interface compatibility but not read.
164 | for layerIdx = numel(net.layers):-1:1
165 |     if isempty(res(layerIdx).dzdw)
166 |         continue;
167 |     end
168 |     layer = net.layers{layerIdx};
169 |     if ~isfield(layer, 'learningRate'), layer.learningRate = 1 ; end
170 |     if ~isfield(layer, 'weightDecay'), layer.weightDecay = 1; end
171 |     step = lr * layer.learningRate ;
172 |     if isfield(layer, 'weight')
173 |         if strcmp(layer.type, 'bfc')
174 |             W = layer.weight;
175 |             egrad = (1/batchSize)*res(layerIdx).dzdw;
176 |             % gradient update on the Stiefel manifold
177 |             stiefel = stiefelfactory(size(W,1), size(W,2));
178 |             rgrad = stiefel.egrad2rgrad(W, egrad);
179 |             layer.weight = stiefel.retr(W, -step*rgrad);
180 |         else
181 |             layer.weight = layer.weight - step * (1/batchSize)* res(layerIdx).dzdw ;
182 |         end
183 |     end
184 |     net.layers{layerIdx} = layer;
185 | end
186 |
187 |
--------------------------------------------------------------------------------
/utils/dDiag.m:
--------------------------------------------------------------------------------
1 | function M = dDiag(M)
2 | % DDIAG  Keep only the diagonal of a matrix ("double diag").
3 | %   M = DDIAG(M) returns a matrix of the same size and class as M whose
4 | %   entries M(k,k) are unchanged and whose other entries are exactly zero
5 | %   (M may be rectangular). For square M this matches diag(diag(M)).
6 | %   Off-diagonal entries are overwritten rather than multiplied by zero,
7 | %   so Inf/NaN entries off the diagonal do not turn into NaN.
8 |
9 | if isa(M,'gpuArray')
10 |     onDiag = eye(size(M),'single','gpuArray') > 0;
11 | else
12 |     onDiag = eye(size(M)) > 0;
13 | end
14 |
15 | M(~onDiag) = 0;
--------------------------------------------------------------------------------
/utils/diagInv.m:
--------------------------------------------------------------------------------
1 | function invX = diagInv(X)
2 | % DIAGINV  Inverse of a diagonal matrix.
3 | %   INVX = DIAGINV(X) returns diag(1 ./ d), where d = [X(1,1), ..., X(k,k)]
4 | %   with k = min(size(X)); off-diagonal entries of X are ignored.
5 | %   X is always treated as a matrix: an n-by-1 or 1-by-n input has the
6 | %   single diagonal entry X(1,1), so the result is 1-by-1 (plain diag()
7 | %   would instead read such a vector as a list of diagonal entries).
8 | %   For speed there is no check for zeros on the diagonal; they give Inf.
9 |
10 | numRows = size(X, 1);
11 | k = min(size(X));
12 | d = X(1 : numRows+1 : numRows*k);   % linear indices of X(1,1), ..., X(k,k)
13 | invX = diag(1 ./ d);
14 | end
--------------------------------------------------------------------------------
/utils/diagLog.m:
--------------------------------------------------------------------------------
1 | function L = diagLog(D,c)
2 | % DIAGLOG  Logarithm of the diagonal of a (possibly rectangular) matrix.
3 | %   L = DIAGLOG(D, C) returns a matrix of the same size as D that is zero
4 | %   except for L(k,k) = log(D(k,k) + C), k = 1..min(size(D)). C defaults
5 | %   to 0. D is always treated as a matrix: an n-by-1 or 1-by-n input has
6 | %   the single diagonal entry D(1,1). Off-diagonal entries of D are ignored.
7 |
8 | if nargin < 2, c = 0; end
9 | [M,N] = size(D);
10 | if isa(D,'gpuArray')
11 |     % keep the element class of D instead of forcing single precision
12 |     L = zeros(size(D), classUnderlying(D), 'gpuArray');
13 | else
14 |     L = zeros(size(D));
15 | end
16 |
17 | m = min(M,N);
18 | d = D(1 : M+1 : M*m);   % D(1,1), ..., D(m,m)
19 | L(1:m,1:m) = diag(log(d + c));
20 | end
--------------------------------------------------------------------------------
/utils/max_eig.m:
--------------------------------------------------------------------------------
1 | function [L, L_I]=max_eig(D,c)
2 | % MAX_EIG  Clamp the diagonal of D from below at C (ReEig rectification).
3 | %   [L, L_I] = MAX_EIG(D, C) returns a matrix L of the same size as D that
4 | %   is zero except on its diagonal, where L(k,k) = D(k,k), replaced by C
5 | %   whenever D(k,k) < C (k = 1..min(size(D))). L_I is the logical column
6 | %   marking the replaced entries. C defaults to 0. Off-diagonal entries of
7 | %   D are ignored; an n-by-1 or 1-by-n D has the single diagonal entry D(1,1).
8 |
9 | if nargin < 2, c = 0; end
10 | [M,N] = size(D);
11 | if isa(D,'gpuArray')
12 |     % keep the element class of D instead of forcing single precision
13 |     L = zeros(size(D), classUnderlying(D), 'gpuArray');
14 | else
15 |     L = zeros(size(D));
16 | end
17 |
18 | m = min(M,N);
19 | h1 = reshape(D(1 : M+1 : M*m), [], 1);   % D(1,1), ..., D(m,m) as a column
20 | L_I = h1 < c;
21 | %reLU
22 | h1(L_I) = c;
23 | L(1:m,1:m) = diag(h1);
--------------------------------------------------------------------------------
/utils/symmetric.m:
--------------------------------------------------------------------------------
1 | function ssX = symmetric(X)
2 | % SYMMETRIC  Symmetric part of a matrix, page by page.
3 | %   SSX = SYMMETRIC(X) returns (X + X.')/2 for a matrix X; for a 3-D array
4 | %   the symmetrisation is applied to every page X(:,:,k) independently
5 | %   (no complex conjugation is involved).
6 | %
7 | % (c) 2015 catalin ionescu -- catalin.ionescu@ins.uni-bonn.de
8 |
9 | pageTransposed = permute(X, [2 1 3]);
10 | ssX = 0.5 * (X + pageTransposed);
11 | end
--------------------------------------------------------------------------------