├── .gitignore ├── CMakeLists.txt ├── README.md ├── cmake └── FindGlog.cmake ├── include ├── adaptive_linesearch.hpp ├── common.hpp ├── conjugate_gradient.hpp ├── cost_func.hpp ├── euclidean.hpp ├── loss.hpp ├── lrucache.hpp ├── manifold.hpp ├── minimizer.hpp ├── problem.hpp ├── product_manifold.hpp ├── rotation.hpp ├── sphere.hpp ├── tcg.hpp └── trust_region.hpp └── src ├── common.cc ├── loss.cc ├── main.cc ├── minimizer.cc ├── rayleigh_quotient_test.cc ├── tcg.cc ├── tcg_test.cc └── tr_test.cc /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build 3 | bazel-* 4 | CMakeLists.txt.user -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(manopt) 4 | 5 | #set(CMAKE_BUILD_TYPE "RelWithDebInfo") 6 | set(CMAKE_BUILD_TYPE "Release") 7 | set (CMAKE_CXX_STANDARD 14) 8 | 9 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") 10 | 11 | # find_package(Eigen3 REQUIRED) 12 | find_package(Glog REQUIRED) 13 | 14 | include_directories(include/) 15 | include_directories(/usr/include/eigen3/) 16 | 17 | add_executable(manopt_test src/main.cc) 18 | target_link_libraries(manopt_test glog) 19 | 20 | # add_executable(manopt_tcg_test src/tcg_test.cc src/common.cc src/minimizer.cc) 21 | add_executable(manopt_tr_test src/tr_test.cc src/tcg.cc src/common.cc src/minimizer.cc) 22 | target_link_libraries(manopt_tr_test glog) 23 | 24 | add_executable(rayleigh_quotient_test src/rayleigh_quotient_test.cc src/tcg.cc src/common.cc src/minimizer.cc) 25 | target_link_libraries(rayleigh_quotient_test glog) 26 | 27 | # add_executable(balm_test examples/balm.cc examples/PCRegistration/plane.cpp examples/PCRegistration/create_points.cpp examples/SE3/SO3.cpp examples/SE3/SE3.cpp examples/PCRegistration/arun.cpp examples/PCRegistration/plane_registration.cpp 
src/tcg.cc src/common.cc src/minimizer.cc) 28 | # target_link_libraries(balm_test glog) 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # manopt_cpp(c++ solver for optimization on manifolds) 2 | 3 | 流形上的求解器 4 | 5 | manopt_cpp是[manopt](https://www.manopt.org/) 的C++版本,用来求解流形上的优化问题。 6 | 7 | manopt_cpp使用了模板静态存储流形数据结构,避免了频繁申请内存带来的效率降低。 8 | 9 | 10 | 支持的流形 11 | 12 | 1)欧式空间 13 | 2)球面 14 | 3)旋转矩阵 15 | 4)积流形 16 | 17 | 支持的求解器 18 | 19 | 目前支持两种求解器,Trust-regions (RTR)和Conjugate-gradient。 20 | 21 | 22 | 求解最大特征值问题 $\max\limits_{x\in\mathbb{R}^n, x \neq 0} \frac{x^\top A x}{x^\top x}.$ 23 | 24 | 问题定义 25 | 26 | ```c++ 27 | class RQCostFunction : public GradientCostFunction { 28 | public: 29 | using Scalar = typename MType::Scalar; 30 | using MPoint = typename MType::MPoint; 31 | using TVector = typename MType::TVector; 32 | using MPtr = typename MType::Ptr; 33 | 34 | RQCostFunction(const MPtr& manifold_, const Eigen::MatrixXd& A_) 35 | : manifold(manifold_), A(A_) {} 36 | 37 | Scalar cost(const MPoint& x) const override { 38 | Eigen::MatrixXd v = -x.transpose() * A * x; 39 | return v(0, 0); 40 | } 41 | 42 | TVector gradient(const MPoint& x) const override { 43 | TVector grad = -2 * A * x; 44 | return manifold->proj(x, grad); 45 | } 46 | 47 | private: 48 | MPtr manifold; 49 | Eigen::MatrixXd A; 50 | }; 51 | ``` 52 | 53 | 求解 54 | ```c++ 55 | MType::Ptr M = std::make_shared(); 56 | typedef MType::MPoint MPoint; 57 | 58 | Eigen::MatrixXd B = Eigen::MatrixXd::Random(N, N); 59 | Eigen::MatrixXd A = 0.5*(B.transpose() + B); 60 | 61 | Problem::Ptr problem = std::make_shared>(); 62 | problem->setManifold(M); 63 | 64 | std::shared_ptr> func = std::make_shared(M, A); 65 | problem->setGradientCostFunction(func); 66 | 67 | MPoint x0 = M->rand(); 68 | 69 | TrustRegion tr(problem); 70 | Summary summary; 71 | double start = wallTimeInSeconds(); 72 | tr.solve(x0, 
&summary); 73 | std::cout << wallTimeInSeconds() - start << std::endl; 74 | std::cout << summary.fullReport() << std::endl; 75 | ``` 76 | -------------------------------------------------------------------------------- /cmake/FindGlog.cmake: -------------------------------------------------------------------------------- 1 | # Ceres Solver - A fast non-linear least squares minimizer 2 | # Copyright 2015 Google Inc. All rights reserved. 3 | # http://ceres-solver.org/ 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of Google Inc. nor the names of its contributors may be 14 | # used to endorse or promote products derived from this software without 15 | # specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | # POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Author: alexs.mac@gmail.com (Alex Stewart) 30 | # 31 | 32 | # FindGlog.cmake - Find Google glog logging library. 33 | # 34 | # This module defines the following variables: 35 | # 36 | # GLOG_FOUND: TRUE iff glog is found. 37 | # GLOG_INCLUDE_DIRS: Include directories for glog. 38 | # GLOG_LIBRARIES: Libraries required to link glog. 39 | # FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION: True iff the version of glog found 40 | # was built & installed / exported 41 | # as a CMake package. 42 | # 43 | # The following variables control the behaviour of this module: 44 | # 45 | # GLOG_PREFER_EXPORTED_GLOG_CMAKE_CONFIGURATION: TRUE/FALSE, iff TRUE then 46 | # then prefer using an exported CMake configuration 47 | # generated by glog > 0.3.4 over searching for the 48 | # glog components manually. Otherwise (FALSE) 49 | # ignore any exported glog CMake configurations and 50 | # always perform a manual search for the components. 51 | # Default: TRUE iff user does not define this variable 52 | # before we are called, and does NOT specify either 53 | # GLOG_INCLUDE_DIR_HINTS or GLOG_LIBRARY_DIR_HINTS 54 | # otherwise FALSE. 55 | # GLOG_INCLUDE_DIR_HINTS: List of additional directories in which to 56 | # search for glog includes, e.g: /timbuktu/include. 57 | # GLOG_LIBRARY_DIR_HINTS: List of additional directories in which to 58 | # search for glog libraries, e.g: /timbuktu/lib. 
59 | # 60 | # The following variables are also defined by this module, but in line with 61 | # CMake recommended FindPackage() module style should NOT be referenced directly 62 | # by callers (use the plural variables detailed above instead). These variables 63 | # do however affect the behaviour of the module via FIND_[PATH/LIBRARY]() which 64 | # are NOT re-called (i.e. search for library is not repeated) if these variables 65 | # are set with valid values _in the CMake cache_. This means that if these 66 | # variables are set directly in the cache, either by the user in the CMake GUI, 67 | # or by the user passing -DVAR=VALUE directives to CMake when called (which 68 | # explicitly defines a cache variable), then they will be used verbatim, 69 | # bypassing the HINTS variables and other hard-coded search locations. 70 | # 71 | # GLOG_INCLUDE_DIR: Include directory for glog, not including the 72 | # include directory of any dependencies. 73 | # GLOG_LIBRARY: glog library, not including the libraries of any 74 | # dependencies. 75 | 76 | # Reset CALLERS_CMAKE_FIND_LIBRARY_PREFIXES to its value when 77 | # FindGlog was invoked. 78 | macro(GLOG_RESET_FIND_LIBRARY_PREFIX) 79 | if (MSVC AND CALLERS_CMAKE_FIND_LIBRARY_PREFIXES) 80 | set(CMAKE_FIND_LIBRARY_PREFIXES "${CALLERS_CMAKE_FIND_LIBRARY_PREFIXES}") 81 | endif() 82 | endmacro(GLOG_RESET_FIND_LIBRARY_PREFIX) 83 | 84 | # Called if we failed to find glog or any of it's required dependencies, 85 | # unsets all public (designed to be used externally) variables and reports 86 | # error message at priority depending upon [REQUIRED/QUIET/] argument. 87 | macro(GLOG_REPORT_NOT_FOUND REASON_MSG) 88 | unset(GLOG_FOUND) 89 | unset(GLOG_INCLUDE_DIRS) 90 | unset(GLOG_LIBRARIES) 91 | # Make results of search visible in the CMake GUI if glog has not 92 | # been found so that user does not have to toggle to advanced view. 
93 | mark_as_advanced(CLEAR GLOG_INCLUDE_DIR 94 | GLOG_LIBRARY) 95 | 96 | glog_reset_find_library_prefix() 97 | 98 | # Note _FIND_[REQUIRED/QUIETLY] variables defined by FindPackage() 99 | # use the camelcase library name, not uppercase. 100 | if (Glog_FIND_QUIETLY) 101 | message(STATUS "Failed to find glog - " ${REASON_MSG} ${ARGN}) 102 | elseif (Glog_FIND_REQUIRED) 103 | message(FATAL_ERROR "Failed to find glog - " ${REASON_MSG} ${ARGN}) 104 | else() 105 | # Neither QUIETLY nor REQUIRED, use no priority which emits a message 106 | # but continues configuration and allows generation. 107 | message("-- Failed to find glog - " ${REASON_MSG} ${ARGN}) 108 | endif () 109 | return() 110 | endmacro(GLOG_REPORT_NOT_FOUND) 111 | 112 | # Protect against any alternative find_package scripts for this library having 113 | # been called previously (in a client project) which set GLOG_FOUND, but not 114 | # the other variables we require / set here which could cause the search logic 115 | # here to fail. 116 | unset(GLOG_FOUND) 117 | 118 | # ----------------------------------------------------------------- 119 | # By default, if the user has expressed no preference for using an exported 120 | # glog CMake configuration over performing a search for the installed 121 | # components, and has not specified any hints for the search locations, then 122 | # prefer a glog exported configuration if available. 123 | if (NOT DEFINED GLOG_PREFER_EXPORTED_GLOG_CMAKE_CONFIGURATION 124 | AND NOT GLOG_INCLUDE_DIR_HINTS 125 | AND NOT GLOG_LIBRARY_DIR_HINTS) 126 | message(STATUS "No preference for use of exported glog CMake configuration " 127 | "set, and no hints for include/library directories provided. 
" 128 | "Defaulting to preferring an installed/exported glog CMake configuration " 129 | "if available.") 130 | set(GLOG_PREFER_EXPORTED_GLOG_CMAKE_CONFIGURATION TRUE) 131 | endif() 132 | 133 | if (GLOG_PREFER_EXPORTED_GLOG_CMAKE_CONFIGURATION) 134 | # Try to find an exported CMake configuration for glog, as generated by 135 | # glog versions > 0.3.4 136 | # 137 | # We search twice, s/t we can invert the ordering of precedence used by 138 | # find_package() for exported package build directories, and installed 139 | # packages (found via CMAKE_SYSTEM_PREFIX_PATH), listed as items 6) and 7) 140 | # respectively in [1]. 141 | # 142 | # By default, exported build directories are (in theory) detected first, and 143 | # this is usually the case on Windows. However, on OS X & Linux, the install 144 | # path (/usr/local) is typically present in the PATH environment variable 145 | # which is checked in item 4) in [1] (i.e. before both of the above, unless 146 | # NO_SYSTEM_ENVIRONMENT_PATH is passed). As such on those OSs installed 147 | # packages are usually detected in preference to exported package build 148 | # directories. 149 | # 150 | # To ensure a more consistent response across all OSs, and as users usually 151 | # want to prefer an installed version of a package over a locally built one 152 | # where both exist (esp. as the exported build directory might be removed 153 | # after installation), we first search with NO_CMAKE_PACKAGE_REGISTRY which 154 | # means any build directories exported by the user are ignored, and thus 155 | # installed directories are preferred. If this fails to find the package 156 | # we then research again, but without NO_CMAKE_PACKAGE_REGISTRY, so any 157 | # exported build directories will now be detected. 
158 | # 159 | # To prevent confusion on Windows, we also pass NO_CMAKE_BUILDS_PATH (which 160 | # is item 5) in [1]), to not preferentially use projects that were built 161 | # recently with the CMake GUI to ensure that we always prefer an installed 162 | # version if available. 163 | # 164 | # NOTE: We use the NAMES option as glog erroneously uses 'google-glog' as its 165 | # project name when built with CMake, but exports itself as just 'glog'. 166 | # On Linux/OS X this does not break detection as the project name is 167 | # not used as part of the install path for the CMake package files, 168 | # e.g. /usr/local/lib/cmake/glog, where the suffix is hardcoded 169 | # in glog's CMakeLists. However, on Windows the project name *is* 170 | # part of the install prefix: C:/Program Files/google-glog/[include,lib]. 171 | # However, by default CMake checks: 172 | # C:/Program Files/ which does not 173 | # exist and thus detection fails. Thus we use the NAMES to force the 174 | # search to use both google-glog & glog. 175 | # 176 | # [1] http://www.cmake.org/cmake/help/v2.8.11/cmake.html#command:find_package 177 | find_package(glog QUIET 178 | NAMES google-glog glog 179 | NO_MODULE 180 | NO_CMAKE_PACKAGE_REGISTRY 181 | NO_CMAKE_BUILDS_PATH) 182 | if (glog_FOUND) 183 | message(STATUS "Found installed version of glog: ${glog_DIR}") 184 | else() 185 | # Failed to find an installed version of glog, repeat search allowing 186 | # exported build directories. 187 | message(STATUS "Failed to find installed glog CMake configuration, " 188 | "searching for glog build directories exported with CMake.") 189 | # Again pass NO_CMAKE_BUILDS_PATH, as we know that glog is exported and 190 | # do not want to treat projects built with the CMake GUI preferentially. 
191 | find_package(glog QUIET 192 | NAMES google-glog glog 193 | NO_MODULE 194 | NO_CMAKE_BUILDS_PATH) 195 | if (glog_FOUND) 196 | message(STATUS "Found exported glog build directory: ${glog_DIR}") 197 | endif(glog_FOUND) 198 | endif(glog_FOUND) 199 | 200 | set(FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION ${glog_FOUND}) 201 | 202 | if (FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION) 203 | message(STATUS "Detected glog version: ${glog_VERSION}") 204 | set(GLOG_FOUND ${glog_FOUND}) 205 | # glog wraps the include directories into the exported glog::glog target. 206 | set(GLOG_INCLUDE_DIR "") 207 | set(GLOG_LIBRARY glog::glog) 208 | else (FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION) 209 | message(STATUS "Failed to find an installed/exported CMake configuration " 210 | "for glog, will perform search for installed glog components.") 211 | endif (FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION) 212 | endif(GLOG_PREFER_EXPORTED_GLOG_CMAKE_CONFIGURATION) 213 | 214 | if (NOT GLOG_FOUND) 215 | # Either failed to find an exported glog CMake configuration, or user 216 | # told us not to use one. Perform a manual search for all glog components. 217 | 218 | # Handle possible presence of lib prefix for libraries on MSVC, see 219 | # also GLOG_RESET_FIND_LIBRARY_PREFIX(). 220 | if (MSVC) 221 | # Preserve the caller's original values for CMAKE_FIND_LIBRARY_PREFIXES 222 | # s/t we can set it back before returning. 223 | set(CALLERS_CMAKE_FIND_LIBRARY_PREFIXES "${CMAKE_FIND_LIBRARY_PREFIXES}") 224 | # The empty string in this list is important, it represents the case when 225 | # the libraries have no prefix (shared libraries / DLLs). 226 | set(CMAKE_FIND_LIBRARY_PREFIXES "lib" "" "${CMAKE_FIND_LIBRARY_PREFIXES}") 227 | endif (MSVC) 228 | 229 | # Search user-installed locations first, so that we prefer user installs 230 | # to system installs where both exist. 
231 | list(APPEND GLOG_CHECK_INCLUDE_DIRS 232 | /usr/local/include 233 | /usr/local/homebrew/include # Mac OS X 234 | /opt/local/var/macports/software # Mac OS X. 235 | /opt/local/include 236 | /usr/include) 237 | # Windows (for C:/Program Files prefix). 238 | list(APPEND GLOG_CHECK_PATH_SUFFIXES 239 | glog/include 240 | glog/Include 241 | Glog/include 242 | Glog/Include 243 | google-glog/include # CMake installs with project name prefix. 244 | google-glog/Include) 245 | 246 | list(APPEND GLOG_CHECK_LIBRARY_DIRS 247 | /usr/local/lib 248 | /usr/local/homebrew/lib # Mac OS X. 249 | /opt/local/lib 250 | /usr/lib) 251 | # Windows (for C:/Program Files prefix). 252 | list(APPEND GLOG_CHECK_LIBRARY_SUFFIXES 253 | glog/lib 254 | glog/Lib 255 | Glog/lib 256 | Glog/Lib 257 | google-glog/lib # CMake installs with project name prefix. 258 | google-glog/Lib) 259 | 260 | # Search supplied hint directories first if supplied. 261 | find_path(GLOG_INCLUDE_DIR 262 | NAMES glog/logging.h 263 | HINTS ${GLOG_INCLUDE_DIR_HINTS} 264 | PATHS ${GLOG_CHECK_INCLUDE_DIRS} 265 | PATH_SUFFIXES ${GLOG_CHECK_PATH_SUFFIXES}) 266 | if (NOT GLOG_INCLUDE_DIR OR 267 | NOT EXISTS ${GLOG_INCLUDE_DIR}) 268 | glog_report_not_found( 269 | "Could not find glog include directory, set GLOG_INCLUDE_DIR " 270 | "to directory containing glog/logging.h") 271 | endif (NOT GLOG_INCLUDE_DIR OR 272 | NOT EXISTS ${GLOG_INCLUDE_DIR}) 273 | 274 | find_library(GLOG_LIBRARY NAMES glog 275 | HINTS ${GLOG_LIBRARY_DIR_HINTS} 276 | PATHS ${GLOG_CHECK_LIBRARY_DIRS} 277 | PATH_SUFFIXES ${GLOG_CHECK_LIBRARY_SUFFIXES}) 278 | if (NOT GLOG_LIBRARY OR 279 | NOT EXISTS ${GLOG_LIBRARY}) 280 | glog_report_not_found( 281 | "Could not find glog library, set GLOG_LIBRARY " 282 | "to full path to libglog.") 283 | endif (NOT GLOG_LIBRARY OR 284 | NOT EXISTS ${GLOG_LIBRARY}) 285 | 286 | # Mark internally as found, then verify. GLOG_REPORT_NOT_FOUND() unsets 287 | # if called. 
288 | set(GLOG_FOUND TRUE) 289 | 290 | # Glog does not seem to provide any record of the version in its 291 | # source tree, thus cannot extract version. 292 | 293 | # Catch case when caller has set GLOG_INCLUDE_DIR in the cache / GUI and 294 | # thus FIND_[PATH/LIBRARY] are not called, but specified locations are 295 | # invalid, otherwise we would report the library as found. 296 | if (GLOG_INCLUDE_DIR AND 297 | NOT EXISTS ${GLOG_INCLUDE_DIR}/glog/logging.h) 298 | glog_report_not_found( 299 | "Caller defined GLOG_INCLUDE_DIR:" 300 | " ${GLOG_INCLUDE_DIR} does not contain glog/logging.h header.") 301 | endif (GLOG_INCLUDE_DIR AND 302 | NOT EXISTS ${GLOG_INCLUDE_DIR}/glog/logging.h) 303 | # TODO: This regex for glog library is pretty primitive, we use lowercase 304 | # for comparison to handle Windows using CamelCase library names, could 305 | # this check be better? 306 | string(TOLOWER "${GLOG_LIBRARY}" LOWERCASE_GLOG_LIBRARY) 307 | if (GLOG_LIBRARY AND 308 | NOT "${LOWERCASE_GLOG_LIBRARY}" MATCHES ".*glog[^/]*") 309 | glog_report_not_found( 310 | "Caller defined GLOG_LIBRARY: " 311 | "${GLOG_LIBRARY} does not match glog.") 312 | endif (GLOG_LIBRARY AND 313 | NOT "${LOWERCASE_GLOG_LIBRARY}" MATCHES ".*glog[^/]*") 314 | 315 | glog_reset_find_library_prefix() 316 | 317 | endif(NOT GLOG_FOUND) 318 | 319 | # Set standard CMake FindPackage variables if found. 320 | if (GLOG_FOUND) 321 | set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) 322 | set(GLOG_LIBRARIES ${GLOG_LIBRARY}) 323 | endif (GLOG_FOUND) 324 | 325 | # If we are using an exported CMake glog target, the include directories are 326 | # wrapped into the target itself, and do not have to be (and are not) 327 | # separately specified. In which case, we should not add GLOG_INCLUDE_DIRS 328 | # to the list of required variables in order that glog be reported as found. 
329 | if (FOUND_INSTALLED_GLOG_CMAKE_CONFIGURATION) 330 | set(GLOG_REQUIRED_VARIABLES GLOG_LIBRARIES) 331 | else() 332 | set(GLOG_REQUIRED_VARIABLES GLOG_INCLUDE_DIRS GLOG_LIBRARIES) 333 | endif() 334 | 335 | # Handle REQUIRED / QUIET optional arguments. 336 | include(FindPackageHandleStandardArgs) 337 | find_package_handle_standard_args(Glog DEFAULT_MSG 338 | ${GLOG_REQUIRED_VARIABLES}) 339 | 340 | # Only mark internal variables as advanced if we found glog, otherwise 341 | # leave them visible in the standard GUI for the user to set manually. 342 | if (GLOG_FOUND) 343 | mark_as_advanced(FORCE GLOG_INCLUDE_DIR 344 | GLOG_LIBRARY 345 | glog_DIR) # Autogenerated by find_package(glog) 346 | endif (GLOG_FOUND) 347 | -------------------------------------------------------------------------------- /include/adaptive_linesearch.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ADAPTIVE_LINESEARCH_H_ 2 | #define ADAPTIVE_LINESEARCH_H_ 3 | 4 | #include "common.hpp" 5 | #include "manifold.hpp" 6 | 7 | namespace manopt 8 | { 9 | 10 | template 11 | class AdaptiveLineSearch 12 | { 13 | public: 14 | using MPoint = typename MType::MPoint; 15 | using TVector = typename MType::TVector; 16 | using Scalar = typename MType::Scalar; 17 | 18 | struct Options 19 | { 20 | double ls_contraction_factor = .5; 21 | double ls_suff_decr = .5; 22 | int ls_max_steps = 10; 23 | int ls_initial_stepsize = 1; 24 | }; 25 | 26 | AdaptiveLineSearch(typename Problem::Ptr problem_) : problem(problem_), init_alpha(-1) 27 | { 28 | } 29 | 30 | void search(const MPoint &x, const TVector &d, const Scalar &f0, const Scalar &df0, int &stepsize, MPoint &new_x) 31 | { 32 | double contraction_factor = options.ls_contraction_factor; 33 | double suff_decr = options.ls_suff_decr; 34 | int max_ls_steps = options.ls_max_steps; 35 | int initial_stepsize = options.ls_initial_stepsize; 36 | 37 | Scalar norm_d = problem->M()->norm(x, d); 38 | 39 | // 初始化alpha 40 | double alpha 
= initial_stepsize / norm_d; 41 | if (init_alpha > 0) 42 | { 43 | alpha = init_alpha; 44 | } 45 | 46 | TVector alpha_d = problem->M()->lincomb(x, alpha, d); 47 | MPoint newx = problem->M()->retr(x, alpha_d); 48 | 49 | Scalar newf = problem->getCost(newx); 50 | int cost_evaluations = 1; 51 | 52 | // 检查满足 Armijo 条件 53 | while (newf > f0 + suff_decr * alpha * df0) 54 | { 55 | alpha = contraction_factor * alpha; 56 | alpha_d = problem->M()->lincomb(x, alpha, d); 57 | newx = problem->M()->retr(x, alpha_d); 58 | newf = problem->getCost(newx); 59 | cost_evaluations++; 60 | 61 | // std::cout << "alpha = " << alpha << std::endl; 62 | 63 | if (cost_evaluations > max_ls_steps) 64 | { 65 | break; 66 | } 67 | } 68 | 69 | // std::cout << "cost_evaluations = " << cost_evaluations << std::endl; 70 | 71 | if (newf > f0) 72 | { 73 | alpha = 0; 74 | newx = x; 75 | } 76 | 77 | stepsize = alpha * norm_d; 78 | new_x = newx; 79 | 80 | // 更新初始init_alpha 81 | if (cost_evaluations == 1) 82 | { 83 | init_alpha = 2 * alpha; 84 | } 85 | else if (cost_evaluations == 2) 86 | { 87 | init_alpha = alpha; 88 | } 89 | else 90 | { 91 | init_alpha = 2 * alpha; 92 | } 93 | } 94 | 95 | private: 96 | typename Problem::Ptr problem; 97 | Options options; 98 | double init_alpha; 99 | }; 100 | 101 | } // namespace manopt 102 | 103 | #endif // ADAPTIVE_LINESEARCH_H_ 104 | -------------------------------------------------------------------------------- /include/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_H_ 2 | #define COMMON_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace manopt { 9 | 10 | double wallTimeInSeconds(); 11 | 12 | std::string stringPrintf(const char* format, ...); 13 | 14 | } // namespace manopt 15 | 16 | #endif // COMMON_H_ 17 | -------------------------------------------------------------------------------- /include/conjugate_gradient.hpp: 
-------------------------------------------------------------------------------- 1 | #ifndef CONJUGATE_GRADIENT_H_ 2 | #define CONJUGATE_GRADIENT_H_ 3 | #include 4 | #include 5 | 6 | #include "adaptive_linesearch.hpp" 7 | #include "common.hpp" 8 | #include "manifold.hpp" 9 | #include "minimizer.hpp" 10 | 11 | namespace manopt 12 | { 13 | 14 | template 15 | class ConjugateGradient 16 | { 17 | public: 18 | using MPoint = typename MType::MPoint; 19 | using TVector = typename MType::TVector; 20 | using Scalar = typename MType::Scalar; 21 | 22 | enum BetaType 23 | { 24 | S_D, 25 | F_R, 26 | P_R, 27 | H_S, 28 | H_Z, 29 | L_S 30 | }; 31 | 32 | struct Options 33 | { 34 | double minstepsize = 1e-10; 35 | int maxiter = 1000; 36 | double tolgradnorm = 1e-6; 37 | int storedepth = 20; 38 | BetaType beta_type = H_Z; 39 | double orth_value = 1e10; 40 | double gradnorm_tol = 1e-5; 41 | double cost_change_tol = 1e-6; 42 | int max_loop = 200; 43 | double max_time = 20; 44 | }; 45 | 46 | ConjugateGradient(typename Problem::Ptr problem_) 47 | : problem(problem_), options() {} 48 | 49 | bool checkStoppingCriterionReached() 50 | { 51 | //检查梯度 52 | if (iteration_summary.gradient_norm < 53 | options.gradnorm_tol) 54 | { 55 | solver_summary->message = stringPrintf("Gradient tolerance reached. Gradient max norm: %lf <= %lf.", 56 | iteration_summary.gradient_norm, 57 | options.gradnorm_tol); 58 | solver_summary->termination_type = CONVERGENCE; 59 | return true; 60 | } 61 | 62 | //检查cost_change 63 | // if (iteration_summary.cost_change < options.cost_change_tol) 64 | // { 65 | // solver_summary->message = stringPrintf("Cost change tolerance reached. 
Cost change: %lf <= %lf.", 66 | // iteration_summary.cost_change, 67 | // std::abs(iteration_summary.cost) * options.cost_change_tol); 68 | // solver_summary->termination_type = CONVERGENCE; 69 | // return true; 70 | // } 71 | 72 | //检查执行时间 73 | const double total_solver_time = wallTimeInSeconds() - 74 | start_time_in_secs; 75 | if (total_solver_time > options.max_time) 76 | { 77 | solver_summary->message = stringPrintf("Maximum solver time reached. Total solver time: %lf >= %lf.", total_solver_time, 78 | options.max_time); 79 | solver_summary->termination_type = NO_CONVERGENCE; 80 | return true; 81 | } 82 | 83 | //检查迭代次数 84 | if (iteration_summary.iteration >= options.max_loop) 85 | { 86 | solver_summary->message = stringPrintf("Maximum number of iterations reached. Number of iterations: %d.", 87 | iteration_summary.iteration); 88 | solver_summary->termination_type = NO_CONVERGENCE; 89 | 90 | return true; 91 | } 92 | 93 | return false; 94 | } 95 | 96 | TVector getPreconditioner(MPoint &x, const TVector &grad) { return grad; } 97 | 98 | void solve(MPoint &x, Summary *summary) 99 | { 100 | //记录开始时间 101 | start_time_in_secs = wallTimeInSeconds(); 102 | solver_summary = summary; 103 | 104 | AdaptiveLineSearch linesearch(problem); 105 | 106 | Scalar cost = problem->getCost(x); 107 | TVector grad = problem->getGradient(x); 108 | 109 | Scalar norm_grad = problem->M()->norm(x, grad); 110 | 111 | TVector pgrad = getPreconditioner(x, grad); 112 | Scalar grad_pgrad = problem->M()->inner(x, grad, pgrad); 113 | TVector desc_dir = problem->M()->lincomb(x, -1, pgrad); 114 | 115 | int stepsize; 116 | MPoint new_x; 117 | 118 | iteration_summary.gradient_norm = norm_grad; 119 | while (!checkStoppingCriterionReached()) 120 | { 121 | Scalar df0 = problem->M()->inner(x, grad, desc_dir); 122 | 123 | // std::cout << "df0 = " << df0 << std::endl; 124 | // std::cout << "iter = " << iteration_summary.iteration << ", cost = " << cost << ", gradnorm = " << norm_grad << std::endl; 125 | 126 | 
if (df0 > 0) 127 | { 128 | throw new std::runtime_error("got an ascent direction"); 129 | } 130 | 131 | // std::cout << "x = " << x.transpose() << std::endl; 132 | // std::cout << "desc_dir = " << desc_dir.transpose() << std::endl; 133 | // std::cout << "cost = " << cost << std::endl; 134 | // std::cout << "df0 = " << df0 << std::endl; 135 | // std::cout << "=========" << std::endl; 136 | linesearch.search(x, desc_dir, cost, df0, stepsize, new_x); 137 | 138 | 139 | // std::cout << "stepsize = " << stepsize << std::endl; 140 | // std::cout << "new_x = " << new_x.transpose() << std::endl; 141 | 142 | Scalar new_cost = problem->getCost(new_x); 143 | TVector new_grad = problem->getGradient(new_x); 144 | 145 | Scalar new_norm_grad = problem->M()->norm(new_x, new_grad); 146 | TVector new_pgrad = getPreconditioner(new_x, new_grad); 147 | Scalar new_grad_pgrad = problem->M()->inner(new_x, new_grad, new_pgrad); 148 | 149 | // std::cout << "new_norm_grad = " << new_norm_grad << std::endl; 150 | // std::cout << "new_pgrad = " << new_pgrad.transpose() << std::endl; 151 | // std::cout << "new_grad_pgrad = " << new_grad_pgrad << std::endl; 152 | 153 | // if (i ++ > 10) { 154 | // break; 155 | // } 156 | 157 | double beta; 158 | if (options.beta_type == S_D) 159 | { 160 | beta = 0; 161 | desc_dir = problem->M()->lincomb(new_x, -1, new_pgrad); 162 | } 163 | else 164 | { 165 | // std::cout << "x = " << x.transpose() << ", new_x = " << new_x.transpose() << std::endl; 166 | TVector old_grad = problem->M()->transp(x, new_x, grad); 167 | Scalar orth_grads = 168 | problem->M()->inner(new_x, old_grad, new_pgrad) / new_grad_pgrad; 169 | 170 | if (std::abs(orth_grads) >= options.orth_value) 171 | { 172 | beta = 0; 173 | desc_dir = problem->M()->lincomb(x, -1, new_pgrad); 174 | } 175 | else 176 | { 177 | desc_dir = problem->M()->transp(x, new_x, desc_dir); 178 | // std::cout << "desc_dir = " << desc_dir.transpose() << std::endl; 179 | 180 | switch (options.beta_type) 181 | { 182 | case F_R: 
183 | beta = new_grad_pgrad / grad_pgrad; 184 | 185 | break; 186 | case P_R: 187 | { 188 | TVector diff = 189 | problem->M()->lincomb(new_x, 1, new_grad, -1, old_grad); 190 | Scalar ip_diff = problem->M()->inner(new_x, new_pgrad, diff); 191 | beta = ip_diff / grad_pgrad; 192 | beta = std::max(0, beta); 193 | } 194 | break; 195 | 196 | case H_S: 197 | { 198 | TVector diff = 199 | problem->M()->lincomb(new_x, 1, new_grad, -1, old_grad); 200 | Scalar ip_diff = problem->M()->inner(new_x, new_pgrad, diff); 201 | beta = ip_diff / problem->M()->inner(new_x, diff, desc_dir); 202 | // std::cout << "beta = " << beta << std::endl; 203 | beta = std::max(0, beta); 204 | } 205 | break; 206 | 207 | case H_Z: 208 | { 209 | TVector diff = 210 | problem->M()->lincomb(new_x, 1, new_grad, -1, old_grad); 211 | TVector p_old_grad = problem->M()->transp(x, new_x, pgrad); 212 | TVector p_diff = problem->M()->lincomb(new_x, 1, new_pgrad, -1, p_old_grad); 213 | Scalar deno = problem->M()->inner(new_x, diff, desc_dir); 214 | Scalar numo = problem->M()->inner(new_x, diff, new_pgrad); 215 | numo = numo - 2 * problem->M()->inner(new_x, diff, p_diff) * problem->M()->inner(new_x, desc_dir, new_grad) / deno; 216 | beta = numo / deno; 217 | 218 | // Robustness(see Hager - Zhang paper mentioned above) 219 | Scalar desc_dir_norm = problem->M()->norm(new_x, desc_dir); 220 | Scalar eta_HZ = -1 / (desc_dir_norm * std::min(0.01, norm_grad)); 221 | beta = std::max(beta, eta_HZ); 222 | } 223 | break; 224 | 225 | case L_S: 226 | { 227 | TVector diff = 228 | problem->M()->lincomb(new_x, 1, new_grad, -1, old_grad); 229 | Scalar ip_diff = problem->M()->inner(new_x, new_pgrad, diff); 230 | Scalar denom = -problem->M()->inner(x, grad, desc_dir); 231 | Scalar betaLS = ip_diff / denom; 232 | Scalar betaCD = new_grad_pgrad / denom; 233 | beta = std::max(0, std::min(betaLS, betaCD)); 234 | } 235 | 236 | break; 237 | default: 238 | throw new std::runtime_error("Unknown options.beta_type."); 239 | } 240 | 241 | 
desc_dir = problem->M()->lincomb(new_x, -1, new_pgrad, beta, desc_dir); 242 | } 243 | } 244 | 245 | iteration_summary.cost = new_cost; 246 | iteration_summary.cost_change = cost - new_cost; 247 | iteration_summary.gradient_norm = new_norm_grad; 248 | iteration_summary.iteration += 1; 249 | 250 | x = new_x; 251 | cost = new_cost; 252 | grad = new_grad; 253 | pgrad = new_pgrad; 254 | norm_grad = new_norm_grad; 255 | grad_pgrad = new_grad_pgrad; 256 | } 257 | } 258 | 259 | private: 260 | IterationSummary iteration_summary; 261 | typename Problem::Ptr problem; 262 | Options options; 263 | 264 | Summary *solver_summary; 265 | double start_time_in_secs; 266 | }; 267 | 268 | } // namespace manopt 269 | 270 | #endif // TRUST_REGION_H_ 271 | -------------------------------------------------------------------------------- /include/cost_func.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COST_FUNCTION_H_ 2 | #define COST_FUNCTION_H_ 3 | 4 | #include "manifold.hpp" 5 | 6 | namespace manopt { 7 | 8 | template 9 | class GradientCostFunction { 10 | public: 11 | using Scalar = typename MType::Scalar; 12 | using MPoint = typename MType::MPoint; 13 | using TVector = typename MType::TVector; 14 | 15 | typedef std::shared_ptr> Ptr; 16 | public: 17 | virtual Scalar cost(const MPoint& x) const = 0; 18 | 19 | virtual TVector gradient(const MPoint& x) const = 0; 20 | 21 | virtual void iterationCallback(const MPoint& x){} 22 | }; 23 | 24 | } // namespace manopt 25 | 26 | #endif // COST_FUNCTION_H_ 27 | -------------------------------------------------------------------------------- /include/euclidean.hpp: -------------------------------------------------------------------------------- 1 | #ifndef EUCLIDEAN_H_ 2 | #define EUCLIDEAN_H_ 3 | 4 | #include "manifold.hpp" 5 | #include 6 | 7 | namespace manopt { 8 | 9 | template 10 | class Euclidean; 11 | 12 | namespace internal { 13 | template 14 | struct traits> { 15 | enum { 16 | Rows = N, 17 | 
Cols = (k==Dynamic ? Dynamic : k) 18 | }; 19 | 20 | using Scalar = Scalar_; 21 | using MPoint = Eigen::Matrix; 22 | using TVector = MPoint; 23 | using EPoint = MPoint; 24 | using ETVector = MPoint; 25 | }; 26 | } 27 | 28 | template 29 | class Euclidean : public MatrixManifold> { 30 | using Base = MatrixManifold>; 31 | public: 32 | using Scalar = typename Base::Scalar; 33 | using MPoint = typename Base::MPoint; 34 | using MSPoint = Eigen::Matrix; 35 | using TVector = typename Base::TVector; 36 | using ETVector = typename Base::ETVector; 37 | using VectorX = typename Eigen::Matrix; 38 | 39 | Euclidean(int sz = k) : sz_(sz) { 40 | CHECK(sz != Dynamic); 41 | } 42 | 43 | int dim() const override { 44 | return N*sz_; 45 | } 46 | 47 | MPoint rand() const override { 48 | return MPoint::Random(N, sz_); 49 | } 50 | 51 | TVector randvec(const MPoint& x) const override { 52 | TVector U = TVector::Random(N, sz_); 53 | return U / U.norm(); 54 | } 55 | 56 | TVector zerovec(const MPoint& x) const override { 57 | return TVector::Zero(N, sz_); 58 | } 59 | 60 | MSPoint at(const MPoint& x, int idx) const { 61 | CHECK(idx < sz_); 62 | return x.col(idx); 63 | } 64 | 65 | void set(MPoint& x, int idx, const MSPoint& v) const { 66 | CHECK(idx < sz_); 67 | x.col(idx) = v; 68 | } 69 | 70 | void accumulate(MPoint& x, int idx, const MSPoint& v) const { 71 | CHECK(idx < sz_); 72 | x.col(idx) += v; 73 | } 74 | 75 | Scalar inner(const MPoint& x, const TVector& d1, const TVector& d2) const override { 76 | Eigen::Map v1(d1.data(), d1.size()); 77 | Eigen::Map v2(d2.data(), d2.size()); 78 | return v1.dot(v2); 79 | } 80 | 81 | Scalar norm(const MPoint& x, const TVector& d) const override { 82 | Eigen::Map v(d.data(), d.size()); 83 | return v.norm(); 84 | } 85 | 86 | TVector proj(const MPoint& x, const ETVector& H) const override { 87 | return H; 88 | } 89 | 90 | TVector tangent(const MPoint& x, const ETVector& H) const override { 91 | return H; 92 | } 93 | 94 | MPoint retr(const MPoint& x, const 
TVector& U, Scalar t = 1) const override { 95 | return x+t*U; 96 | } 97 | 98 | TVector transp(const MPoint& x1, const MPoint& x, const TVector& U) const override { 99 | return U; 100 | } 101 | 102 | TVector lincomb(const MPoint&x, Scalar a, const TVector& d) const override { 103 | return a*d; 104 | } 105 | 106 | TVector lincomb(const MPoint&x, Scalar a1, const TVector& d1, Scalar a2, const TVector& d2) const override { 107 | return a1*d1+a2*d2; 108 | } 109 | 110 | Scalar typicaldist() const override { 111 | return sqrt(dim()); 112 | } 113 | 114 | private: 115 | int sz_; 116 | }; 117 | 118 | template 119 | using EuclideanX = Euclidean; 120 | 121 | } // namespace manopt 122 | 123 | #endif // EUCLIDEAN_H_ 124 | -------------------------------------------------------------------------------- /include/loss.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LOSS_H_ 2 | #define LOSS_H_ 3 | 4 | namespace manopt { 5 | 6 | class LOSS { 7 | public: 8 | virtual ~LOSS() {} 9 | 10 | virtual double v(double v) const = 0; 11 | 12 | virtual double j(double v) const = 0; 13 | }; 14 | 15 | class TrivialLOSS : public LOSS { 16 | public: 17 | explicit TrivialLOSS() {} 18 | 19 | double v(double v) const override; 20 | 21 | double j(double v) const override; 22 | }; 23 | 24 | class HuberLOSS : public LOSS { 25 | public: 26 | explicit HuberLOSS(double a) : a_(a), b_(a * a) {} 27 | 28 | double v(double v) const override; 29 | 30 | double j(double v) const override; 31 | 32 | private: 33 | const double a_; 34 | // b = a^2. 35 | const double b_; 36 | }; 37 | 38 | class CauchyLOSS : public LOSS { 39 | public: 40 | explicit CauchyLOSS(double a) : b_(a * a), c_(1 / b_) {} 41 | 42 | double v(double v) const override; 43 | 44 | double j(double v) const override; 45 | 46 | private: 47 | // b = a^2. 48 | const double b_; 49 | // c = 1 / a^2. 
50 | const double c_; 51 | }; 52 | 53 | } // namespace manopt 54 | 55 | #endif // LOSS_H_ 56 | -------------------------------------------------------------------------------- /include/lrucache.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * File: lrucache.hpp 3 | * Author: Alexander Ponomarev 4 | * 5 | * Created on June 20, 2013, 5:09 PM 6 | */ 7 | 8 | #ifndef _LRUCACHE_HPP_INCLUDED_ 9 | #define _LRUCACHE_HPP_INCLUDED_ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace cache { 18 | 19 | template 20 | class lru_cache { 21 | public: 22 | typedef typename std::pair key_value_pair_t; 23 | typedef typename std::list::iterator list_iterator_t; 24 | 25 | lru_cache(size_t max_size) : 26 | _max_size(max_size) { 27 | } 28 | 29 | void put(const key_t& key, const value_t& value) { 30 | auto it = _cache_items_map.find(key); 31 | _cache_items_list.push_front(key_value_pair_t(key, value)); 32 | if (it != _cache_items_map.end()) { 33 | _cache_items_list.erase(it->second); 34 | _cache_items_map.erase(it); 35 | } 36 | _cache_items_map[key] = _cache_items_list.begin(); 37 | 38 | if (_cache_items_map.size() > _max_size) { 39 | auto last = _cache_items_list.end(); 40 | last--; 41 | _cache_items_map.erase(last->first); 42 | _cache_items_list.pop_back(); 43 | } 44 | } 45 | 46 | const value_t& get(const key_t& key) { 47 | auto it = _cache_items_map.find(key); 48 | if (it == _cache_items_map.end()) { 49 | throw std::range_error("There is no such key in cache"); 50 | } else { 51 | _cache_items_list.splice(_cache_items_list.begin(), _cache_items_list, it->second); 52 | return it->second->second; 53 | } 54 | } 55 | 56 | void print() { 57 | for (auto elem : _cache_items_list) { 58 | std::cout << elem.first << std::endl; 59 | } 60 | } 61 | 62 | bool exists(const key_t& key) const { 63 | return _cache_items_map.find(key) != _cache_items_map.end(); 64 | } 65 | 66 | size_t size() const { 67 | return 
_cache_items_map.size(); 68 | } 69 | 70 | private: 71 | std::list _cache_items_list; 72 | std::unordered_map _cache_items_map; 73 | size_t _max_size; 74 | }; 75 | 76 | } // namespace cache 77 | 78 | #endif /* _LRUCACHE_HPP_INCLUDED_ */ 79 | 80 | -------------------------------------------------------------------------------- /include/manifold.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MANIFOLD_H_ 2 | #define MANIFOLD_H_ 3 | 4 | #include 5 | 6 | namespace manopt { 7 | 8 | const int Dynamic = -1; 9 | 10 | namespace internal { 11 | template 12 | struct traits; 13 | } 14 | 15 | template 16 | class AbstractManifold { 17 | public: 18 | using Scalar = typename manopt::internal::traits::Scalar; 19 | using MPoint = typename manopt::internal::traits::MPoint; 20 | using TVector = typename manopt::internal::traits::TVector; 21 | using EPoint = typename manopt::internal::traits::EPoint; 22 | using ETVector = typename manopt::internal::traits::ETVector; 23 | 24 | typedef std::shared_ptr> Ptr; 25 | 26 | virtual int dim() const = 0; 27 | 28 | //流形上不好定义zero,所以rand也用来初始化。 29 | virtual MPoint rand() const = 0; 30 | 31 | virtual TVector randvec(const MPoint& x) const = 0; 32 | 33 | virtual TVector zerovec(const MPoint& x) const = 0; 34 | 35 | virtual Scalar inner(const MPoint& x, const TVector& d1, const TVector& d2) const = 0; 36 | 37 | virtual Scalar norm(const MPoint& x, const TVector& d) const = 0; 38 | 39 | virtual TVector proj(const MPoint& x, const ETVector& H) const = 0; 40 | 41 | //H是embed上的表示 42 | virtual TVector tangent(const MPoint& x, const ETVector& H) const = 0; 43 | 44 | virtual MPoint retr(const MPoint& x, const TVector& U, Scalar t = 1) const = 0; 45 | 46 | //把x1点处的切向量U移动到x点。 47 | virtual MPoint transp(const MPoint& x1, const MPoint& x, const TVector& U) const = 0; 48 | 49 | virtual TVector lincomb(const MPoint&x, Scalar a, const TVector& d) const = 0; 50 | 51 | virtual TVector lincomb(const MPoint&x, Scalar a1, 
const TVector& d1, Scalar a2, const TVector& d2) const = 0; 52 | 53 | // returns the typical distance on the Manifold M, 54 | // which is for example the longest distance in a unit cell or injectivity radius. 55 | virtual Scalar typicaldist() const = 0; 56 | }; 57 | 58 | template 59 | class MatrixManifold : public AbstractManifold { 60 | public: 61 | enum { 62 | Rows = internal::traits::Rows, 63 | Cols = internal::traits::Cols, 64 | }; 65 | }; 66 | 67 | } // namespace manopt 68 | 69 | #endif // MANIFOLD_H_ 70 | -------------------------------------------------------------------------------- /include/minimizer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MINIMIZER_H_ 2 | #define MINIMIZER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include "tcg.hpp" 8 | 9 | namespace manopt 10 | { 11 | 12 | enum TCGStopReason; 13 | 14 | enum TerminationType 15 | { 16 | // Minimizer terminated because one of the convergence criterion set 17 | // by the user was satisfied. 18 | CONVERGENCE, 19 | 20 | // The solver ran for maximum number of iterations or maximum amount 21 | // of time specified by the user, but none of the convergence 22 | // criterion specified by the user were met. The user's parameter 23 | // blocks will be updated with the solution found so far. 24 | NO_CONVERGENCE, 25 | 26 | // The minimizer terminated because of an error. The user's 27 | // parameter blocks will not be updated. 
28 | FAILURE, 29 | }; 30 | 31 | struct IterationSummary 32 | { 33 | IterationSummary() 34 | : iteration(0), 35 | step_is_accept(false), 36 | cost(0.0), 37 | cost_change(0.0), 38 | gradient_norm(0.0), 39 | trust_region_ratio(0.0), 40 | trust_region_radius(0.0), 41 | consecutive_TRplus(0), 42 | consecutive_TRminus(0), 43 | used_cauchy(false), 44 | tcg_iterations(0), 45 | tcg_time(0.0), 46 | tcg_reason(), 47 | total_time(0.0) {} 48 | 49 | // 迭代次数 50 | int iteration; 51 | 52 | // step是否成功 53 | bool step_is_accept; 54 | 55 | // Cost 56 | double cost; 57 | 58 | // old_cost - new_cost 59 | double cost_change; 60 | 61 | // 2-norm of the gradient vector. 62 | double gradient_norm; 63 | 64 | //用来判断二阶近似是否准确 65 | double trust_region_ratio; 66 | 67 | // 信赖域半径 68 | double trust_region_radius; 69 | 70 | //信赖半径增加次数 71 | int consecutive_TRplus; 72 | 73 | //信赖半径减小次数 74 | int consecutive_TRminus; 75 | 76 | //是否使用柯西点 77 | bool used_cauchy; 78 | 79 | // TCG迭代次数 80 | int tcg_iterations; 81 | 82 | // TCG时间 83 | double tcg_time; 84 | 85 | // TCG停止理由 86 | enum TCGStopReason tcg_reason; 87 | 88 | // Solve总时间 89 | double total_time; 90 | }; 91 | 92 | struct Summary 93 | { 94 | Summary(); 95 | 96 | std::string fullReport() const; 97 | 98 | bool isSolutionUsable() const; 99 | 100 | std::string message = "ceres::Solve was not called."; 101 | 102 | TerminationType termination_type = FAILURE; 103 | 104 | std::vector iterations; 105 | }; 106 | 107 | } // namespace manopt 108 | 109 | #endif // MINIMIZER_H_ 110 | -------------------------------------------------------------------------------- /include/problem.hpp: -------------------------------------------------------------------------------- 1 | #ifndef PROBLEM_H_ 2 | #define PROBLEM_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "manifold.hpp" 11 | #include "cost_func.hpp" 12 | #include 13 | 14 | namespace manopt 15 | { 16 | template 17 | class Problem 18 | { 19 | public: 20 | using Scalar = 
typename MType::Scalar; 21 | using MPoint = typename MType::MPoint; 22 | using TVector = typename MType::TVector; 23 | using MPtr = typename MType::Ptr; 24 | 25 | typedef std::shared_ptr> Ptr; 26 | 27 | private: 28 | MPtr manifold; 29 | typename GradientCostFunction::Ptr functor_; 30 | GradientCostFunction *func_ptr_; 31 | 32 | public: 33 | Problem() : func_ptr_(nullptr){}; 34 | Problem(Problem &&); 35 | Problem &operator=(Problem &&){}; 36 | 37 | Problem(const Problem &) = delete; 38 | Problem &operator=(const Problem &) = delete; 39 | 40 | ~Problem(){}; 41 | 42 | void setManifold(MPtr manifold_) 43 | { 44 | manifold = manifold_; 45 | } 46 | 47 | MPtr M() const 48 | { 49 | return manifold; 50 | } 51 | 52 | // 53 | TVector getGradient(const MPoint &x) 54 | { 55 | CHECK(functor_.get() != nullptr || func_ptr_ != nullptr) << "must set setGradientCostFunction[Ptr] first!"; 56 | return func_ptr_ != nullptr ? func_ptr_->gradient(x) : functor_->gradient(x); 57 | } 58 | 59 | //实现数值Hessian 60 | TVector getHessian(const MPoint &x, const TVector &d) 61 | { 62 | Scalar norm_d = M()->norm(x, d); 63 | Scalar eps = (Scalar)std::pow(2, -14); 64 | if (norm_d < std::numeric_limits::epsilon()) 65 | { 66 | std::cout << "!!!norm_d is too small, return zerovec!!!" 
<< std::endl; 67 | return M()->zerovec(x); 68 | } 69 | 70 | Scalar c = eps / norm_d; 71 | TVector grad = getGradient(x); 72 | 73 | MPoint x1 = M()->retr(x, d, c); 74 | TVector grad1 = getGradient(x1); 75 | 76 | grad1 = M()->transp(x1, x, grad1); 77 | return M()->lincomb(x, 1 / c, grad1, -1 / c, grad); 78 | } 79 | 80 | void setGradientCostFunction(typename GradientCostFunction::Ptr func) 81 | { 82 | functor_ = func; 83 | } 84 | 85 | void setGradientCostFunctionPtr(GradientCostFunction *func_ptr) 86 | { 87 | func_ptr_ = func_ptr; 88 | } 89 | 90 | Scalar getCost(const MPoint &x) 91 | { 92 | CHECK(functor_.get() != nullptr || func_ptr_ != nullptr) << "must set setGradientCostFunction[Ptr] first!"; 93 | return func_ptr_ != nullptr ? func_ptr_->cost(x) : functor_->cost(x); 94 | } 95 | 96 | std::pair getCostAndGrad(const MPoint &x) 97 | { 98 | return std::make_pair(getCost(x), getGradient(x)); 99 | } 100 | 101 | void iterationCallback(const MPoint &x) 102 | { 103 | CHECK(functor_.get() != nullptr || func_ptr_ != nullptr) << "must set setGradientCostFunction[Ptr] first!"; 104 | return func_ptr_ != nullptr ? 
func_ptr_->iterationCallback(x) : functor_->iterationCallback(x); 105 | } 106 | }; 107 | 108 | } // namespace manopt 109 | 110 | #endif // PROBLEM_H_ 111 | -------------------------------------------------------------------------------- /include/product_manifold.hpp: -------------------------------------------------------------------------------- 1 | #ifndef PRODUCT_EUCLIDEAN_H_ 2 | #define PRODUCT_EUCLIDEAN_H_ 3 | 4 | #include "manifold.hpp" 5 | #include 6 | #include 7 | #include 8 | 9 | namespace manopt { 10 | 11 | template 12 | class ProductManifold; 13 | 14 | namespace internal { 15 | template 16 | struct traits> { 17 | using Scalar = typename T1::Scalar; 18 | using MPoint = std::pair; 19 | using TVector = std::pair; 20 | using EPoint = std::pair; 21 | using ETVector = std::pair; 22 | }; 23 | } 24 | 25 | template 26 | class ProductManifold : public AbstractManifold> { 27 | using Base = AbstractManifold>; 28 | public: 29 | using Scalar = typename Base::Scalar; 30 | using MPoint = typename Base::MPoint; 31 | using TVector = typename Base::TVector; 32 | using ETVector = typename Base::ETVector; 33 | using SubType1 = T1; 34 | using SubType2 = T2; 35 | 36 | typedef std::shared_ptr> Ptr; 37 | 38 | ProductManifold(T1 t1_, T2 t2_) : 39 | t1(t1_), t2(t2_) 40 | {} 41 | 42 | int dim() const override { 43 | return t1.dim()+t2.dim(); 44 | } 45 | 46 | MPoint rand() const override { 47 | return std::make_pair(t1.rand(), t2.rand()); 48 | } 49 | 50 | TVector randvec(const MPoint& x) const override { 51 | return std::make_pair(t1.randvec(x.first), t2.randvec(x.second)); 52 | } 53 | 54 | TVector zerovec(const MPoint& x) const override { 55 | return std::make_pair(t1.zerovec(x.first), t2.zerovec(x.second)); 56 | } 57 | 58 | Scalar inner(const MPoint& x, const TVector& d1, const TVector& d2) const override { 59 | return t1.inner(x.first, d1.first, d2.first) + t2.inner(x.second, d1.second, d2.second); 60 | } 61 | 62 | Scalar norm(const MPoint& x, const TVector& d) const override { 
63 | return sqrt(inner(x, d, d)); 64 | } 65 | 66 | TVector proj(const MPoint& x, const ETVector& H) const override { 67 | return std::make_pair(t1.proj(x.first, H.first), t2.proj(x.second, H.second)); 68 | } 69 | 70 | TVector tangent(const MPoint& x, const ETVector& H) const override { 71 | return std::make_pair(t1.tangent(x.first, H.first), t2.tangent(x.second, H.second)); 72 | } 73 | 74 | MPoint retr(const MPoint& x, const TVector& U, Scalar t = 1) const override { 75 | return std::make_pair(t1.retr(x.first, U.first, t), t2.retr(x.second, U.second, t)); 76 | } 77 | 78 | TVector transp(const MPoint& x1, const MPoint& x, const TVector& U) const override { 79 | return std::make_pair(t1.transp(x1.first, x.first, U.first), t2.transp(x1.second, x.second, U.second)); 80 | } 81 | 82 | TVector lincomb(const MPoint&x, Scalar a, const TVector& d) const override { 83 | return std::make_pair(t1.lincomb(x.first, a, d.first), 84 | t2.lincomb(x.second, a, d.second)); 85 | } 86 | 87 | TVector lincomb(const MPoint&x, Scalar a1, const TVector& d1, Scalar a2, const TVector& d2) const override { 88 | return std::make_pair(t1.lincomb(x.first, a1, d1.first, a2, d2.first), 89 | t2.lincomb(x.second, a1, d1.second, a2, d2.second)); 90 | } 91 | 92 | Scalar typicaldist() const override { 93 | return sqrt(std::pow(t1.typicaldist(), 2) + std::pow(t2.typicaldist(), 2)); 94 | } 95 | 96 | //private: => public 97 | T1 t1; 98 | T2 t2; 99 | }; 100 | 101 | } // namespace manopt 102 | 103 | 104 | #endif // PRODUCT_EUCLIDEAN_H_ 105 | -------------------------------------------------------------------------------- /include/rotation.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ROTATION_H_ 2 | #define ROTATION_H_ 3 | 4 | #include "manifold.hpp" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace manopt { 11 | 12 | template 13 | class Rotation; 14 | 15 | namespace internal { 16 | template 17 | struct traits> { 18 | enum { 19 | Rows = 
N, 20 | Cols = (k==Dynamic ? Dynamic : N*k) 21 | }; 22 | 23 | using Scalar = Scalar_; 24 | using MPoint = Eigen::Matrix; 25 | using TVector = MPoint; 26 | using EPoint = MPoint; 27 | using ETVector = MPoint; 28 | }; 29 | } 30 | 31 | template 32 | class Rotation : public MatrixManifold> { 33 | using Base = MatrixManifold>; 34 | public: 35 | using Scalar = typename Base::Scalar; 36 | using MPoint = typename Base::MPoint; 37 | using MSPoint = Eigen::Matrix; 38 | using TVector = typename Base::TVector; 39 | using ETVector = typename Base::ETVector; 40 | using VectorX = typename Eigen::Matrix; 41 | 42 | typedef std::shared_ptr> Ptr; 43 | 44 | Rotation(int sz = k) : sz_(sz) { 45 | CHECK(sz != Dynamic); 46 | } 47 | 48 | int dim() const override { 49 | return sz_*N*(N-1)/2; 50 | } 51 | 52 | //Generated as such, Q is uniformly distributed over O(n), 53 | //the group of orthogonal matrices; see Mezzadri 2007 54 | MPoint rand() const override { 55 | MPoint Q = MPoint::Random(N, cols()); 56 | for (int i=0; i qr(rand); 59 | MSPoint QS = qr.householderQ(); 60 | if (QS.determinant() < 0) { 61 | const VectorX col0 = QS.col(0); 62 | QS.col(0) = QS.col(1); 63 | QS.col(1) = col0; 64 | } 65 | Q.block(0, i*N, N, N) = QS; 66 | } 67 | 68 | return Q; 69 | } 70 | 71 | TVector randvec(const MPoint& x) const { 72 | TVector U(N, cols()); 73 | for (int i=0; i v1(d1.data(), d1.size()); 101 | Eigen::Map v2(d2.data(), d2.size()); 102 | return v1.dot(v2); 103 | } 104 | 105 | Scalar norm(const MPoint& x, const TVector& d) const override { 106 | return d.norm(); 107 | } 108 | 109 | TVector proj(const MPoint& x, const ETVector& H) const override { 110 | TVector U(N, cols()); 111 | for (int i=0; i qr(x.block(0, i*N, N, N)+t*xu); 132 | MSPoint QS = qr.householderQ(); 133 | MSPoint R = qr.matrixQR(); 134 | for (int j=0; j3) { 173 | throw new std::runtime_error("not supported!"); 174 | } 175 | 176 | // if (N > 3) { 177 | // int n = 3; 178 | // for (int i=3; i 197 | using RotationX = Rotation; 198 | 199 | 
} // namespace manopt 200 | 201 | #endif // ROTATION_H_ 202 | -------------------------------------------------------------------------------- /include/sphere.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SPHERE_H_ 2 | #define SPHERE_H_ 3 | 4 | #include "manifold.hpp" 5 | #include 6 | #include 7 | 8 | namespace manopt { 9 | 10 | template 11 | class Sphere; 12 | 13 | namespace internal { 14 | template 15 | struct traits> { 16 | using Scalar = Scalar_; 17 | using MPoint = Eigen::Matrix; 18 | using TVector = MPoint; 19 | using EPoint = MPoint; 20 | using ETVector = MPoint; 21 | }; 22 | } 23 | 24 | template 25 | class Sphere : public AbstractManifold> { 26 | using Base = AbstractManifold>; 27 | public: 28 | using Scalar = typename Base::Scalar; 29 | using MPoint = typename Base::MPoint; 30 | using TVector = typename Base::TVector; 31 | using ETVector = typename Base::ETVector; 32 | using VectorX = typename Eigen::Matrix; 33 | 34 | typedef std::shared_ptr> Ptr; 35 | 36 | int dim() const override { 37 | return N-1; 38 | } 39 | 40 | MPoint rand() const override { 41 | MPoint x = MPoint::Random(); 42 | return x / x.norm(); 43 | } 44 | 45 | TVector randvec(const MPoint& x) const { 46 | TVector d = proj(x, TVector::Random()); 47 | return d / d.norm(); 48 | } 49 | 50 | TVector zerovec(const MPoint& x) const override { 51 | return TVector::Zero(); 52 | } 53 | 54 | Scalar inner(const MPoint& x, const TVector& d1, const TVector& d2) const override { 55 | Eigen::Map v1(d1.data(), d1.size()); 56 | Eigen::Map v2(d2.data(), d2.size()); 57 | return v1.dot(v2); 58 | } 59 | 60 | Scalar norm(const MPoint& x, const TVector& d) const override { 61 | Eigen::Map v(d.data(), d.size()); 62 | return v.norm(); 63 | } 64 | 65 | TVector proj(const MPoint& x, const ETVector& H) const override { 66 | Eigen::Map x1(x.data(), x.size()); 67 | Eigen::Map d1(H.data(), H.size()); 68 | 69 | return H - x1.dot(d1)*x; 70 | } 71 | 72 | TVector tangent(const 
MPoint& x, const ETVector& H) const override { 73 | return proj(x, H); 74 | } 75 | 76 | MPoint retr(const MPoint& x, const TVector& U, Scalar t = 1) const override { 77 | MPoint y = x + t*U; 78 | return y / y.norm(); 79 | } 80 | 81 | TVector transp(const MPoint& x1, const MPoint& x, const TVector& U) const override { 82 | return proj(x, U); 83 | } 84 | 85 | TVector lincomb(const MPoint&x, Scalar a, const TVector& d) const override { 86 | return a*d; 87 | } 88 | 89 | TVector lincomb(const MPoint&x, Scalar a1, const TVector& d1, Scalar a2, const TVector& d2) const override { 90 | return a1*d1+a2*d2; 91 | } 92 | 93 | Scalar typicaldist() const override { 94 | return M_PI; 95 | } 96 | 97 | }; 98 | 99 | } // namespace manopt 100 | 101 | #endif // SPHERE_H_ 102 | -------------------------------------------------------------------------------- /include/tcg.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TCG_H_ 2 | #define TCG_H_ 3 | #include 4 | #include 5 | #include 6 | #include "manifold.hpp" 7 | #include "problem.hpp" 8 | #include 9 | 10 | namespace manopt { 11 | 12 | enum TCGStopReason { 13 | NEGATIVE_CURVATURE, 14 | EXCEEDED_TRUST_REGION, 15 | LINEAR_CONVERGENCE, 16 | SUPERLINEAR_CONVERGENCE, 17 | MAXIMUM_INNER_REACHED, 18 | MODEL_INCREASED 19 | }; 20 | 21 | extern std::array TCGStopReasonStrings; 22 | 23 | template 24 | class TruncatedConjugateGradient { 25 | public: 26 | using MPoint = typename MType::MPoint; 27 | using TVector = typename MType::TVector; 28 | using Scalar = typename MType::Scalar; 29 | 30 | struct Result { 31 | TVector eta; 32 | TVector Heta; 33 | int loopCount; 34 | TCGStopReason stopReason; 35 | }; 36 | 37 | TruncatedConjugateGradient(typename Problem::Ptr problem_, 38 | double theta_, 39 | double kappa_, 40 | bool use_rand_, 41 | int min_loop_, 42 | int max_loop_, 43 | bool debug_) : 44 | problem(problem_), 45 | theta(theta_), 46 | kappa(kappa_), 47 | use_rand(use_rand_), 48 | 
min_loop(min_loop_), 49 | max_loop(max_loop_), 50 | debug(debug_) { 51 | CHECK(max_loop > min_loop); 52 | } 53 | 54 | // x:点所在流形上位置 55 | // grad: x处梯度 56 | // eta0: 是优化方向初值,非随机方法为zerovec(0)。 57 | // tr_radius: 信赖域半径 58 | // tCG求解子问题 59 | // ```math 60 | // \operatorname*{arg\,min}_{η ∈ T_xM} 61 | // m_x(η) \quad\text{where}  62 | // m_x(η) = F(x) + ⟨\operatorname{grad}F(x),η⟩_x + \frac{1}{2}⟨\operatorname{Hess}F(x)[η],η⟩_x, 63 | // ``` 64 | // 每次循环主要更新以下变量。 65 | // eta(优化变量的迭代值) 66 | // r(残差的迭代值) 67 | // z(残差preconditioner后的迭代值) 68 | // mdelta(每次优化方向) 69 | // 参考 70 | //https://manoptjl.org/v0.1/solvers/truncatedConjugateGradient/ 71 | // 72 | //TODO加上preconditioner的实现 73 | Result solve(const MPoint& x, const TVector& grad, const TVector& eta0, double tr_radius) { 74 | // 75 | 76 | TVector eta = eta0, Heta, r, z; 77 | Scalar e_Pe, model_value; 78 | 79 | if (use_rand) { 80 | Heta = problem->getHessian(x, eta); 81 | r = problem->M()->lincomb(x, 1, grad, 1, Heta); 82 | e_Pe = problem->M()->inner(x, eta, eta); 83 | z = r; 84 | model_value = problem->M()->inner(x, eta, grad) + 0.5 * problem->M()->inner(x, eta, Heta); 85 | } else { 86 | Heta = problem->M()->zerovec(x); 87 | r = grad; 88 | e_Pe = 0; 89 | z = r; /* get_preconditioner(p, x, r) not support yet */ 90 | model_value = 0; 91 | } 92 | 93 | //初始化残差 94 | Scalar r_r = problem->M()->inner(x, r, r); 95 | Scalar norm_r = sqrt(r_r); 96 | Scalar norm_r0 = norm_r; 97 | 98 | Scalar z_r = problem->M()->inner(x, z, r); 99 | Scalar d_Pd = z_r; 100 | TVector mdelta = z; 101 | Scalar e_Pd = use_rand ? 
-problem->M()->inner(x, eta, mdelta) : 0; 102 | 103 | for (int j=0; jgetHessian(x, mdelta); 106 | Scalar d_Hd = problem->M()->inner(x, mdelta, Hmdelta); 107 | Scalar alpha = z_r / d_Hd; 108 | 109 | //e_Pe_new为了方便计算|eta|^2_P。 110 | Scalar e_Pe_new = e_Pe + 2 * alpha * e_Pd + alpha * alpha * d_Pd; 111 | 112 | #if 0 113 | std::cout << "e_Pe_new = " << e_Pe_new 114 | << ",\t|eta|^2 = " << problem->M()->inner(x, 115 | problem->M()->lincomb(x, 1, eta, -alpha, mdelta), 116 | problem->M()->lincomb(x, 1, eta, -alpha, mdelta)) 117 | << std::endl; 118 | #endif 119 | 120 | if (debug) { 121 | printf("DBG: (r,r) : % 3.6e\n", r_r); 122 | printf("DBG: (d,Hd) : % 3.6e\n", d_Hd); 123 | printf("DBG: alpha : % 3.6e\n", alpha); 124 | } 125 | 126 | //如果极小值点处的d_Hd应该>=0, 127 | //e_Pe_new >= tr_radius*tr_radius表示超出指定的信赖域半径 128 | if (d_Hd <= 0 || e_Pe_new >= tr_radius*tr_radius) { 129 | //求解tau满足 | eta+ tau*mdelta |_P = Delta,这样返回的|eta|正好为信赖域半径 130 | Scalar tau = (-e_Pd + sqrt(e_Pd*e_Pd + d_Pd * (tr_radius*tr_radius - e_Pe))) / d_Pd; 131 | if (debug) { 132 | printf("DBG: tau : % 3.6e\n", tau); 133 | } 134 | eta = problem->M()->lincomb(x, 1, eta, -tau, mdelta); 135 | Heta = problem->M()->lincomb(x, 1, Heta, -tau, Hmdelta); 136 | 137 | if (d_Hd <= 0) { 138 | return Result{eta, Heta, j+1, NEGATIVE_CURVATURE}; 139 | } else { 140 | return Result{eta, Heta, j+1, EXCEEDED_TRUST_REGION}; 141 | } 142 | } 143 | 144 | e_Pe = e_Pe_new; 145 | 146 | TVector new_eta = problem->M()->lincomb(x, 1, eta, -alpha, mdelta); 147 | //应该是problem->getHessian(x, new_eta) 148 | //TODO 检查new_Hη是否能很好近似 149 | TVector new_Heta = problem->M()->lincomb(x, 1, Heta, -alpha, Hmdelta); 150 | 151 | // No negative curvature and eta - alpha * (mdelta) inside TR: accept it. 
152 | Scalar new_model_value = problem->M()->inner(x, new_eta, grad) 153 | + 0.5 * problem->M()->inner(x, new_eta, new_Heta); 154 | if (new_model_value > model_value) { 155 | return Result{eta, Heta, j+1, MODEL_INCREASED}; 156 | } 157 | 158 | eta = new_eta; 159 | Heta = new_Heta; 160 | model_value = new_model_value; 161 | r = problem->M()->lincomb(x, 1, r, -alpha, Hmdelta); 162 | 163 | r_r = problem->M()->inner(x, r, r); 164 | norm_r = sqrt(r_r); 165 | 166 | //如果norm_r已经很小了 167 | if (j+1 >= min_loop 168 | && norm_r <= norm_r0*std::min(std::pow(norm_r0, theta), kappa)) { 169 | if (kappa < std::pow(norm_r0, theta)) { 170 | return Result{eta, Heta, j+1, LINEAR_CONVERGENCE}; //linear convergence 171 | } else { 172 | return Result{eta, Heta, j+1, SUPERLINEAR_CONVERGENCE}; //superlinear convergence 173 | } 174 | } 175 | 176 | // Precondition the r. 177 | z = use_rand ? r : r /* get_preconditioner(p, x, r) not support yet */; 178 | Scalar zold_rold = z_r; 179 | z_r = problem->M()->inner(x, z, r); 180 | //# Compute new search direction. 
181 | Scalar beta = z_r / zold_rold; 182 | mdelta = problem->M()->lincomb(x, 1, z, beta, mdelta); 183 | //投影δ,由于存在误差可能会偏离切平面 184 | mdelta = problem->M()->tangent(x, mdelta); 185 | //方便计算是否超出信赖域 186 | e_Pd = beta * (alpha * d_Pd + e_Pd); 187 | d_Pd = z_r + beta * beta * d_Pd; 188 | } 189 | 190 | return Result{eta, Heta, max_loop, MAXIMUM_INNER_REACHED}; 191 | } 192 | 193 | private: 194 | typename Problem::Ptr problem; 195 | 196 | double theta; 197 | double kappa; 198 | bool use_rand; 199 | int min_loop; 200 | int max_loop; 201 | bool debug; 202 | }; 203 | 204 | } // namespace manopt 205 | 206 | 207 | #endif // TCG 208 | -------------------------------------------------------------------------------- /include/trust_region.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TRUST_REGION_H_ 2 | #define TRUST_REGION_H_ 3 | #include 4 | #include 5 | #include "common.hpp" 6 | #include "manifold.hpp" 7 | #include "minimizer.hpp" 8 | #include "tcg.hpp" 9 | 10 | namespace manopt { 11 | 12 | template 13 | class TrustRegion { 14 | public: 15 | using MPoint = typename MType::MPoint; 16 | using TVector = typename MType::TVector; 17 | using Scalar = typename MType::Scalar; 18 | using tCGResult = typename TruncatedConjugateGradient::Result; 19 | 20 | struct Options { 21 | int min_loop = 3; 22 | int max_loop = 60; 23 | int min_inner = 1; 24 | int max_inner = 1; 25 | double max_time = 20; 26 | double gradnorm_tol = 1e-5; 27 | double cost_change_tol = 1e-6; 28 | double kappa = 0.1; 29 | double theta = 1.0; 30 | double rho_prime = 0.1; 31 | double Delta_bar; //初始化为 M::typicaldist() 32 | double Delta0; 33 | bool use_rand = false; 34 | double rho_regularization = 1e3; 35 | bool verbosity = false; 36 | bool debug = false; 37 | 38 | Options(double db, int max_inner_) : 39 | max_inner(max_inner_), 40 | Delta_bar(db), 41 | Delta0(db/8) { } 42 | }; 43 | 44 | TrustRegion(typename Problem::Ptr problem_, double db=-1) : 45 | problem(problem_), 46 | 
options(db < 0 ? problem_->M()->typicaldist() : db, problem_->M()->dim()), 47 | tCG(problem_, options.theta, options.kappa, options.use_rand, 48 | options.min_inner, options.max_inner, options.debug){ 49 | } 50 | 51 | bool checkStoppingCriterionReached() { 52 | //检查梯度 53 | if (iteration_summary.step_is_accept && iteration_summary.gradient_norm < options.gradnorm_tol) { 54 | solver_summary->message = stringPrintf("Gradient tolerance reached. " 55 | "Gradient max norm: %lf <= %lf", 56 | iteration_summary.gradient_norm, 57 | options.gradnorm_tol); 58 | solver_summary->termination_type = CONVERGENCE; 59 | return true; 60 | } 61 | 62 | //检查cost_change 63 | if (iteration_summary.step_is_accept && 64 | iteration_summary.cost_change < std::abs(iteration_summary.cost)*options.cost_change_tol) { 65 | solver_summary->message = stringPrintf("Cost change tolerance reached. " 66 | "Cost change: %lf <= %lf", 67 | iteration_summary.cost_change, 68 | std::abs(iteration_summary.cost)*options.cost_change_tol); 69 | solver_summary->termination_type = CONVERGENCE; 70 | return true; 71 | } 72 | 73 | //检查执行时间 74 | const double total_solver_time = wallTimeInSeconds() - start_time_in_secs; 75 | if (total_solver_time > options.max_time) { 76 | solver_summary->message = stringPrintf("Maximum solver time reached. " 77 | "Total solver time: %lf >= %lf.", 78 | total_solver_time, 79 | options.max_time); 80 | solver_summary->termination_type = NO_CONVERGENCE; 81 | return true; 82 | } 83 | 84 | //检查迭代次数 85 | if (iteration_summary.iteration >= options.max_loop) { 86 | solver_summary->message = stringPrintf("Maximum number of iterations reached. 
" 87 | "Number of iterations: %d.", 88 | iteration_summary.iteration); 89 | solver_summary->termination_type = NO_CONVERGENCE; 90 | 91 | return true; 92 | } 93 | 94 | return false; 95 | } 96 | 97 | void solve(MPoint& x, Summary* summary) { 98 | //记录开始时间 99 | start_time_in_secs = wallTimeInSeconds(); 100 | solver_summary = summary; 101 | 102 | Scalar fx = problem->getCost(x); 103 | TVector fgradx = problem->getGradient(x); 104 | Scalar norm_grad = problem->M()->norm(x, fgradx); 105 | 106 | iteration_summary.cost = fx; 107 | //iteration_summary.cost_change = 0.0; 108 | iteration_summary.gradient_norm = norm_grad; 109 | iteration_summary.trust_region_radius = options.Delta0; 110 | iteration_summary.consecutive_TRplus = 0; 111 | iteration_summary.consecutive_TRminus = 0; 112 | 113 | solver_summary->iterations.clear(); 114 | solver_summary->iterations.push_back(iteration_summary); 115 | 116 | 117 | //输出LOG 118 | if (options.verbosity) { 119 | std::string output = "iter acc/REJ cost |gradient| tr_ratio tr_radius tcg_iter tcg_time total_time tcg_reason\n"; 120 | output += stringPrintf("% 4d % 3.2e % 3.2e ", 121 | 0, 122 | fx, 123 | norm_grad 124 | ); 125 | std::cout << output << std::endl; 126 | } 127 | 128 | TVector eta; 129 | TVector Heta; 130 | 131 | while(!checkStoppingCriterionReached()) { 132 | 133 | double Delta = iteration_summary.trust_region_radius; 134 | 135 | //随机eta初始值 136 | if (options.use_rand) { 137 | eta = problem->M()->lincomb(x, 1e-6, problem->M()->randvec(x)); 138 | while (problem->M()->norm(x, eta) > Delta) { 139 | eta = problem->M()->lincomb(x, 1.220703125e-4 /*sqrt(sqrt(eps))*/, eta); 140 | } 141 | } else { 142 | eta = problem->M()->zerovec(x); 143 | } 144 | 145 | double tCG_begin = wallTimeInSeconds(); 146 | tCGResult res = tCG.solve(x, fgradx, eta, Delta); 147 | double tCG_elapsed = wallTimeInSeconds()-tCG_begin; 148 | eta = res.eta; 149 | Heta = res.Heta; 150 | 151 | iteration_summary.tcg_iterations = res.loopCount; 152 | 
iteration_summary.tcg_reason = res.stopReason; 153 | 154 | //检查eta和Heta跟柯西点结果比较~ 155 | if (options.use_rand) { 156 | iteration_summary.used_cauchy = false; 157 | TVector Hg = problem->getHessian(x, fgradx); 158 | Scalar g_Hg = problem->M()->inner(x, fgradx, Hg); 159 | double tau_c; 160 | //g_Hg小于0,使用更小的stepsize。 161 | if (g_Hg <= 0) { 162 | tau_c = 1; 163 | } else { 164 | tau_c = std::min(std::pow(norm_grad, 3)/(Delta*g_Hg), 1); 165 | } 166 | 167 | TVector eta_c = problem->M()->lincomb(x, -tau_c * Delta / norm_grad, fgradx); 168 | TVector Heta_c = problem->M()->lincomb(x, -tau_c * Delta / norm_grad, Hg); 169 | 170 | Scalar mdle = fx + problem->M()->inner(x, fgradx, eta) + 0.5*problem->M()->inner(x, Heta, eta); 171 | Scalar mdlec = fx + problem->M()->inner(x, fgradx, eta_c) + 0.5*problem->M()->inner(x, Heta_c, eta_c); 172 | if (mdlec < mdle) { 173 | eta = eta_c; 174 | Heta = Heta_c; 175 | iteration_summary.used_cauchy = true; 176 | std::cout << "used_cauchy" << std::endl; 177 | } 178 | } 179 | 180 | //下一个点x候选 181 | MPoint x_prop = problem->M()->retr(x, eta); 182 | Scalar fx_prop = problem->getCost(x_prop); 183 | 184 | //计算信赖半径 185 | //rho的分子rhonum 186 | Scalar rhonum = fx - fx_prop; 187 | //方便计算rho的分母rhoden 188 | TVector vecrho = problem->M()->lincomb(x, 1, fgradx, .5, Heta); 189 | Scalar rhoden = -problem->M()->inner(x, eta, vecrho); 190 | 191 | //这个是在x在接近收敛时分子分母都比较小,所以会有扰动。 192 | //为了去掉扰动的影响都会加一个比较小的数 193 | Scalar rho_reg = std::max(1., std::abs(fx)) * 2.220446049250313e-16 /*eps*/ * options.rho_regularization; 194 | rhonum = rhonum + rho_reg; 195 | rhoden = rhoden + rho_reg; 196 | 197 | if (options.debug > 0) { 198 | printf("DBG: rhonum : % 3.6e\n", rhonum); 199 | printf("DBG: rhoden : % 3.6e\n", rhoden); 200 | } 201 | 202 | bool model_decreased = (rhoden >= 0); //TODO: 是否应该为>>> bool model_decreased = (rhonum >= 0) 203 | double rho = rhonum / rhoden; 204 | iteration_summary.trust_region_ratio = rho; 205 | 206 | if (options.debug > 0) { 207 | printf("DBG: new f(x) 
: % 3.6e\n", fx_prop); 208 | printf("DBG: used rho : % 3.6e\n", rho); 209 | } 210 | 211 | //根据rho调整信赖域 212 | if (rho < 0.25 || !model_decreased || std::isnan(rho)) { 213 | if (std::isnan(rho)) { 214 | std::cout << "WARNING: rho is nan" << std::endl; 215 | } 216 | iteration_summary.trust_region_radius /= 4; 217 | iteration_summary.consecutive_TRplus = 0; 218 | iteration_summary.consecutive_TRminus += 1; 219 | } else if (rho > 0.75 && (res.stopReason == NEGATIVE_CURVATURE || res.stopReason == EXCEEDED_TRUST_REGION)) { 220 | iteration_summary.trust_region_radius = std::min(2*iteration_summary.trust_region_radius, 221 | options.Delta_bar); 222 | iteration_summary.consecutive_TRplus += 1; 223 | iteration_summary.consecutive_TRminus = 0; 224 | } else { 225 | iteration_summary.consecutive_TRplus = 0; 226 | iteration_summary.consecutive_TRminus = 0; 227 | } 228 | 229 | //确定"接受"还是"拒绝" 230 | if (model_decreased && rho > options.rho_prime) { 231 | 232 | if (1) { 233 | problem->iterationCallback(x_prop); 234 | fx_prop = problem->getCost(x_prop); 235 | } 236 | 237 | // 238 | x = x_prop; 239 | fx = fx_prop; 240 | fgradx = problem->getGradient(x); 241 | norm_grad = problem->M()->norm(x, fgradx); 242 | 243 | iteration_summary.cost = fx; 244 | CHECK(rhonum >= 0); 245 | iteration_summary.cost_change = rhonum; 246 | iteration_summary.gradient_norm = norm_grad; 247 | iteration_summary.step_is_accept = true; 248 | } else { 249 | iteration_summary.cost_change = 0; 250 | iteration_summary.step_is_accept = false; 251 | } 252 | 253 | iteration_summary.tcg_time = tCG_elapsed; 254 | iteration_summary.total_time = wallTimeInSeconds() - start_time_in_secs; 255 | iteration_summary.iteration += 1; 256 | solver_summary->iterations.push_back(iteration_summary); 257 | 258 | if (options.verbosity) { 259 | std::string output= 260 | stringPrintf("% 4d %s % 3.2e % 3.2e % 3.2e % 3.2e % 4d % 3.2e % 3.2e %s", 261 | iteration_summary.iteration, 262 | (iteration_summary.step_is_accept ? 
"acc" : "REJ"), 263 | iteration_summary.cost, 264 | iteration_summary.gradient_norm, 265 | iteration_summary.trust_region_ratio, 266 | iteration_summary.trust_region_radius, 267 | iteration_summary.tcg_iterations, 268 | iteration_summary.tcg_time, 269 | iteration_summary.total_time, 270 | TCGStopReasonStrings[iteration_summary.tcg_reason].c_str() 271 | ); 272 | 273 | std::cout << output << std::endl; 274 | } 275 | } 276 | 277 | } 278 | 279 | private: 280 | IterationSummary iteration_summary; 281 | typename Problem::Ptr problem; 282 | Options options; 283 | TruncatedConjugateGradient tCG; 284 | 285 | Summary* solver_summary; 286 | double start_time_in_secs; 287 | 288 | }; 289 | 290 | } // namespace manopt 291 | 292 | 293 | #endif // TRUST_REGION_H_ 294 | -------------------------------------------------------------------------------- /src/common.cc: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | #include 3 | 4 | namespace manopt { 5 | 6 | double wallTimeInSeconds() { 7 | timeval time_val; 8 | gettimeofday(&time_val, nullptr); 9 | return (time_val.tv_sec + time_val.tv_usec * 1e-6); 10 | } 11 | 12 | void stringAppendV(std::string* dst, const char* format, va_list ap) { 13 | // First try with a small fixed size buffer 14 | char space[1024]; 15 | 16 | // It's possible for methods that use a va_list to invalidate 17 | // the data in it upon use. The fix is to make a copy 18 | // of the structure before using it and use that copy instead. 19 | va_list backup_ap; 20 | va_copy(backup_ap, ap); 21 | int result = vsnprintf(space, sizeof(space), format, backup_ap); 22 | va_end(backup_ap); 23 | 24 | if (result < (int)sizeof(space)) { 25 | if (result >= 0) { 26 | // Normal case -- everything fit. 27 | dst->append(space, result); 28 | return; 29 | } 30 | 31 | if (result < 0) { 32 | // Just an error. 
33 | return; 34 | } 35 | } 36 | 37 | // Increase the buffer size to the size requested by vsnprintf, 38 | // plus one for the closing \0. 39 | int length = result + 1; 40 | char* buf = new char[length]; 41 | 42 | // Restore the va_list before we use it again 43 | va_copy(backup_ap, ap); 44 | result = vsnprintf(buf, length, format, backup_ap); 45 | va_end(backup_ap); 46 | 47 | if (result >= 0 && result < length) { 48 | // It fit 49 | dst->append(buf, result); 50 | } 51 | delete[] buf; 52 | } 53 | 54 | std::string stringPrintf(const char* format, ...) { 55 | va_list ap; 56 | va_start(ap, format); 57 | std::string result; 58 | stringAppendV(&result, format, ap); 59 | va_end(ap); 60 | return result; 61 | } 62 | 63 | } // namespace manopt 64 | -------------------------------------------------------------------------------- /src/loss.cc: -------------------------------------------------------------------------------- 1 | #include "loss.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | namespace manopt { 7 | 8 | double TrivialLOSS::v(double v) const { 9 | return v; 10 | } 11 | 12 | double TrivialLOSS::j(double v) const { 13 | return 1.0; 14 | } 15 | 16 | double HuberLOSS::v(double v) const { 17 | if (v > b_) { 18 | return 2.0 * a_ * sqrt(v) - b_; 19 | } else { 20 | return v; 21 | } 22 | } 23 | 24 | double HuberLOSS::j(double v) const { 25 | if (v > b_) { 26 | return std::max(std::numeric_limits::min(), a_ / sqrt(v)); 27 | } else { 28 | return 1.0; 29 | } 30 | } 31 | 32 | double CauchyLOSS::v(double v) const { 33 | return b_*log(1.0 + v * c_); 34 | } 35 | 36 | double CauchyLOSS::j(double v) const { 37 | const double inv = 1.0 / (1.0 + v * c_); 38 | return std::max(std::numeric_limits::min(), inv); 39 | } 40 | 41 | } // namespace manopt -------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | #include "rotation.hpp" 2 | #include "euclidean.hpp" 3 | 
#include "product_manifold.hpp" 4 | #include 5 | #include 6 | 7 | using namespace manopt; 8 | 9 | 10 | int main() { 11 | int k = 3; 12 | RotationX R(k); 13 | EuclideanX t(k); 14 | ProductManifold M(R, t); 15 | 16 | using MType = decltype(M); 17 | MType::MPoint x = M.rand(); 18 | std::cout << x.first << std::endl; 19 | std::cout << x.second << std::endl; 20 | 21 | return 0; 22 | } -------------------------------------------------------------------------------- /src/minimizer.cc: -------------------------------------------------------------------------------- 1 | #include "tcg.hpp" 2 | #include "common.hpp" 3 | #include "minimizer.hpp" 4 | 5 | namespace manopt { 6 | 7 | Summary::Summary() { 8 | } 9 | 10 | std::string Summary::fullReport() const { 11 | std::string output; 12 | for (uint i=0; i /* NOTE(review): this dump appears to have stripped every span from a `<` followed by a letter up to the next `>` (cf. the empty `#include` lines and `std::make_shared>()`). Here that swallowed the rest of src/minimizer.cc and the header and first lines of what is presumably src/rayleigh_quotient_test.cc (numbering restarts at 7 below). Regenerate the dump from the repository instead of hand-editing this text. */ 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace manopt; 13 | int constexpr N = 1000; 14 | using MType = Sphere; 15 | 16 | class RQCostFunction : public GradientCostFunction { 17 | public: 18 | using Scalar = typename MType::Scalar; 19 | using MPoint = typename MType::MPoint; 20 | using TVector = typename MType::TVector; 21 | using MPtr = typename MType::Ptr; 22 | 23 | RQCostFunction(const MPtr& manifold_, const Eigen::MatrixXd& A_) 24 | : manifold(manifold_), A(A_) {} 25 | 26 | Scalar cost(const MPoint& x) const override { 27 | Eigen::MatrixXd v = -x.transpose() * A * x; 28 | return v(0, 0); 29 | } 30 | 31 | TVector gradient(const MPoint& x) const override { 32 | TVector grad = -2 * A * x; 33 | return manifold->proj(x, grad); 34 | } 35 | 36 | private: 37 | MPtr manifold; 38 | Eigen::MatrixXd A; 39 | }; 40 | 41 | int main() { 42 | MType::Ptr M = std::make_shared(); 43 | typedef MType::MPoint MPoint; 44 | 45 | for (int i = 0; i < 10; i++) { 46 | srand(i); 47 | std::cout << " i = " << i << std::endl; 48 | Eigen::MatrixXd B = Eigen::MatrixXd::Random(N, N); 49 | Eigen::MatrixXd A = 0.5 * (B.transpose() + B); 50 | 51 | // A << 0.5377, 1.3480, -1.3462,
1.3480, 0.3188, -0.4825, -1.3462, -0.4825, 52 | // 3.5784; 53 | 54 | // std::cout << "A = " << A << std::endl; 55 | Problem::Ptr problem = std::make_shared>(); 56 | problem->setManifold(M); 57 | 58 | std::shared_ptr> func = 59 | std::make_shared(M, A); 60 | problem->setGradientCostFunction(func); 61 | 62 | MPoint x0 = M->rand(); 63 | // std::cout << x0 << std::endl; 64 | 65 | // std::cout << "x_old = " << x0 << std::endl; 66 | ConjugateGradient cg(problem); 67 | TrustRegion tr(problem); 68 | 69 | Summary summary; 70 | double start = wallTimeInSeconds(); 71 | MPoint x1 = x0; 72 | tr.solve(x1, &summary); 73 | std::cout << wallTimeInSeconds() - start << std::endl; 74 | start = wallTimeInSeconds(); 75 | MPoint x2 = x0; 76 | cg.solve(x2, &summary); 77 | std::cout << wallTimeInSeconds() - start << std::endl; /* NOTE(review): cost() above is -x^T A x, which is unchanged under x -> -x, so TR and CG may legitimately converge to opposite signs of the same eigenvector; checking only (x2-x1).norm() can then throw spuriously. Consider also accepting (x2+x1).norm() <= 1e-3. */ 78 | if ((x2-x1).norm() > 1e-3) { 79 | throw std::runtime_error("> 1e-3"); 80 | } 81 | 82 | // if (!summary.isSolutionUsable()) { 83 | // break; 84 | // } 85 | } 86 | 87 | 88 | return 0; 89 | } -------------------------------------------------------------------------------- /src/tcg.cc: -------------------------------------------------------------------------------- 1 | #include "tcg.hpp" 2 | 3 | namespace manopt { 4 | 5 | std::array TCGStopReasonStrings = { 6 | "negative curvature", 7 | "exceeded trust region", 8 | "reached target residual-kappa (linear)", 9 | "reached target residual-theta (superlinear)", 10 | "maximum inner iterations", 11 | "model increased" 12 | }; 13 | 14 | } // namespace manopt -------------------------------------------------------------------------------- /src/tcg_test.cc: -------------------------------------------------------------------------------- 1 | #include "rotation.hpp" 2 | #include "problem.hpp" 3 | #include "cost_func.hpp" 4 | #include "tcg.hpp" 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace manopt; 10 | 11 | using MType = Rotation; 12 | 13 | class ICPCostFunction : public GradientCostFunction { 14 | public: 15 |
using Scalar = typename MType::Scalar; 16 | using MPoint = typename MType::MPoint; 17 | using TVector = typename MType::TVector; 18 | using MPtr = typename MType::Ptr; 19 | 20 | ICPCostFunction(const MPtr& manifold_, 21 | const std::vector& xx_, 22 | const std::vector& yy_) : 23 | manifold(manifold_), xx(xx_), yy(yy_) {} 24 | 25 | Scalar Cost(const MPoint& Rx) { 26 | Scalar err = 0; /* NOTE(review): the text between line 27 and line 44 of src/tcg_test.cc appears to have been stripped by the dump generator, so Cost() and the start of the gradient computation are incomplete below. This file also uses Cost/SetManifold/SetGradientCostFunction/Solve, whereas src/tr_test.cc uses cost/setManifold/setGradientCostFunction/solve, and its build target is commented out in CMakeLists.txt. */ 27 | for (int i=0; ihat(bv); 45 | grad += 2*manifold->inner(Rx, base, K)*base; 46 | } 47 | return -grad; 48 | } 49 | 50 | private: 51 | MPtr manifold; 52 | std::vector xx; 53 | std::vector yy; 54 | 55 | }; 56 | 57 | int main() { 58 | MType::Ptr M = std::make_shared(); 59 | typedef MType::MPoint MPoint; 60 | typedef MType::TVector TVector; 61 | 62 | Problem::Ptr problem = std::make_shared>(); 63 | problem->SetManifold(M); 64 | 65 | std::vector xx; 66 | std::vector yy; 67 | 68 | auto x = M->rand(); 69 | int N=10; 70 | //std::cout << std::fixed << std::setprecision(10) << "rot = " << x << std::endl; 71 | for (int i=0; i> func = std::make_shared(M, xx , yy); 79 | problem->SetGradientCostFunction(func); 80 | 81 | // // auto* minimizer = Minimizer::Create(MinimizerType::TRUST_REGION); 82 | 83 | // // MType x=M.rand(); 84 | // // minimizer.Minimize(M, problem, x); 85 | 86 | TruncatedConjugateGradient tCG(problem, 1, 0.1, false, 1, 3); 87 | MPoint xp; 88 | // xp << -0.556901566919582, 0.569921486505996, 0.604193796675560, 89 | // 0.119423032618585, 0.774822686777889, -0.620796217149530, 90 | // -0.821948163876581, -0.273567730618556, -0.499561720666789; 91 | 92 | xp << -0.610781863867384, 0.754507303162354, 0.240133804878356, 93 | 0.215023383875814, 0.449933313925957, -0.866790030749201, 94 | -0.762043607123162, -0.477785247254745, -0.437047821580708; 95 | 96 | TVector grad; 97 | // grad << 0, 37.2910299761626, -168.4016299926723, 98 | // -37.2910299761626, 0, 636.6724794533905, 99 | // 168.4016299926723, -636.6724794533905, 0; 100 | grad << 0, 93.1050313081506, -63.9490650111049, 101 |
-93.1050313081506, 0, 407.9876545622247, 102 | 63.9490650111049, -407.9876545622247, 0; 103 | 104 | TVector eta = TVector::Zero(); 105 | 106 | tCG.Solve(xp, grad, eta, 1.360349523175663); 107 | // , TVector grad, TVector eta, double tr_radius 108 | 109 | return 0; 110 | } -------------------------------------------------------------------------------- /src/tr_test.cc: -------------------------------------------------------------------------------- 1 | #include "rotation.hpp" 2 | #include "problem.hpp" 3 | #include "cost_func.hpp" 4 | #include "tcg.hpp" 5 | #include "trust_region.hpp" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace manopt; 12 | 13 | using MType = RotationX; 14 | 15 | class ICPCostFunction : public GradientCostFunction { 16 | public: 17 | using Scalar = typename MType::Scalar; 18 | using MPoint = typename MType::MPoint; 19 | using TVector = typename MType::TVector; 20 | using MPtr = typename MType::Ptr; 21 | 22 | ICPCostFunction(const MPtr& manifold_, 23 | const std::vector& xx_, 24 | const std::vector& yy_) : 25 | manifold(manifold_), xx(xx_), yy(yy_) {} 26 | 27 | Scalar cost(const MPoint& Rx) const override { 28 | Scalar err = 0; /* NOTE(review): in src/tr_test.cc the text between line 29 and line 37, and again between line 39 and line 46, appears to have been stripped by the dump generator; cost() and the gradient loop are incomplete below. */ 29 | for (uint i=0; izerovec(Rx); 38 | TVector K = manifold->zerovec(Rx); 39 | for (uint i=0; ihat(bv); 47 | //grad += 2*manifold->inner(Rx, base, K)*base; 48 | grad += 2*(base*K).trace()*base; 49 | } 50 | return grad; 51 | } 52 | 53 | private: 54 | MPtr manifold; 55 | std::vector xx; 56 | std::vector yy; 57 | 58 | }; 59 | 60 | int main() { 61 | MType::Ptr M = std::make_shared(1); 62 | typedef MType::MPoint MPoint; 63 | 64 | Problem::Ptr problem = std::make_shared>(); 65 | problem->setManifold(M); 66 | 67 | std::vector xx; 68 | std::vector yy; 69 | 70 | auto x = M->rand(); 71 | int N=10; 72 | srand (2); 73 | //std::cout << std::fixed << std::setprecision(10) << "rot = " << x << std::endl; 74 | for (int i=0; i> func = std::make_shared(M, xx , yy); 82 | problem->setGradientCostFunction(func); 83
| 84 | MPoint xp = M->rand(); 85 | 86 | xp << -0.556901566919582, 0.569921486505996, 0.604193796675560, 87 | 0.119423032618585, 0.774822686777889, -0.620796217149530, 88 | -0.821948163876581, -0.273567730618556, -0.499561720666789; 89 | 90 | TrustRegion tr(problem); 91 | Summary summary; 92 | tr.solve(xp, &summary); 93 | std::cout << xp << std::endl; 94 | std::cout << summary.fullReport() << std::endl; 95 | 96 | return 0; 97 | } --------------------------------------------------------------------------------