├── .clang-format ├── .gitignore ├── .gitmodules ├── AddToCeres.cmake ├── CMakeLists.txt ├── LICENSES ├── BINDINGS_LICENSE └── CERES_LICENSE ├── README.md ├── cmake ├── FindCeres.cmake ├── FindGlog.cmake └── pytorch_stuff.cmake ├── docs ├── debugging_log.md └── pytorch_stuff.md ├── examples ├── __init__.py ├── ceres_hello_world_analytic_diff.py ├── ceres_hello_world_example.py ├── ceres_rosenbrock_autodiff.py ├── ceres_rosenbrock_example.py ├── ceres_simple_bundle_adjuster.py ├── hello_world_python_autodiff.py ├── manual_setting_sys_path.py ├── pose_graph_slam_example.py ├── pytorch_torchscript_example.py └── utilities.py ├── pyproject.toml ├── python_bindings ├── ceres_examples_module.cpp ├── custom_cpp_cost_functions.cpp ├── python_module.cpp ├── pytorch_cost_function.cpp ├── pytorch_cost_function.h └── pytorch_module.cpp ├── python_tests ├── __init__.py ├── debug_functions.py ├── loss_function_test.py └── test_python_defined_cost_function.py ├── setup.cfg ├── setup.py └── tests └── pytorch_test.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | Standard: Cpp11 3 | BinPackArguments: false 4 | BinPackParameters: false 5 | PointerAlignment: Left 6 | DerivePointerAlignment: false 7 | CommentPragmas: NOLINT.* 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | #Specific folders: 4 | x86_64_windows_msvc/* 5 | x86_64_linux_gnu/* 6 | bin/* 7 | build/* 8 | cmake-build-debug/* 9 | 10 | 11 | # C++ files 12 | *.a 13 | *.so 14 | *.o 15 | 16 | # python files 17 | *.pyc 18 | *.npz 19 | *.npy 20 | *.pytest_cache 21 | *.p 22 | 23 | # Editor files 24 | 25 | #QT 26 | *.files 27 | *.includes 28 | *.creator 29 | 30 | # Visual studio 31 | *.sln 32 | *.VC.db 33 | scripts/\.vs/scripts/v15/\.suo 34 | \.vs/nl/v15/ 35 | \.vs/ 36 | 37 | # CLion 38 | .idea 39 | .idea/* 40 | 41 
| 42 | #Other filetypes 43 | *.xml 44 | *.iml 45 | *.orig 46 | *.jpg 47 | *.png 48 | *.autosave 49 | *.zip 50 | *.dir 51 | *.pyproj 52 | *.csv 53 | *.pt 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pybind11"] 2 | path = pybind11 3 | url = https://github.com/pybind/pybind11 4 | -------------------------------------------------------------------------------- /AddToCeres.cmake: -------------------------------------------------------------------------------- 1 | # Include this file from the end of ceres-solver/CMakeLists.txt to build the PyCeres module alongside Ceres. 2 | 3 | set(PYBIND11_CPP_STANDARD -std=c++11) 4 | add_subdirectory(${PROJECT_SOURCE_DIR}/ceres_python_bindings/pybind11) 5 | 6 | pybind11_add_module(PyCeres ${PROJECT_SOURCE_DIR}/ceres_python_bindings/python_bindings/python_module.cpp 7 | ${PROJECT_SOURCE_DIR}/ceres_python_bindings/python_bindings/ceres_examples_module.cpp 8 | ${PROJECT_SOURCE_DIR}/ceres_python_bindings/python_bindings/custom_cpp_cost_functions.cpp) 9 | 10 | # Scope the extra include paths to PyCeres only, so that including this file does not change the include path of other targets. 11 | target_include_directories(PyCeres PRIVATE ${PROJECT_SOURCE_DIR}/ceres_python_bindings/pybind11/include 12 | ${PROJECT_SOURCE_DIR}/internal) 13 | target_link_libraries(PyCeres PRIVATE ceres) 14 | 15 | message(STATUS "Python Bindings for Ceres(PyCeres) have been added") 16 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.02) 2 | project(PyCeres) 3 | 4 | set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") 5 | 6 | find_package(Ceres REQUIRED) 7 | find_package(Glog REQUIRED) 8 | find_package(LAPACK) 9 | 10 | include_directories(${CERES_INCLUDE_DIR}) 11 | 12 | # Since we are linking to ceres then we have to access the compile flags used to build it.
13 | add_definitions("-DCERES_IS_LINKED") 14 | 15 | add_subdirectory(pybind11) 16 | pybind11_add_module(PyCeres python_bindings/python_module.cpp 17 | python_bindings/ceres_examples_module.cpp 18 | python_bindings/custom_cpp_cost_functions.cpp 19 | # python_bindings/pytorch_cost_function.h 20 | # python_bindings/pytorch_cost_function.cpp 21 | ) 22 | # Locate Eigen and attach its imported target to PyCeres only. 23 | find_package(Eigen3 3.3 CONFIG REQUIRED 24 | HINTS ${HOMEBREW_INCLUDE_DIR_HINTS}) 25 | if (EIGEN3_FOUND) 26 | message(STATUS "Found Eigen version ${EIGEN3_VERSION_STRING}: ${EIGEN3_INCLUDE_DIRS}") 27 | if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*)" AND 28 | EIGEN3_VERSION_STRING VERSION_LESS 3.3.4) 29 | # As per issue #289: https://github.com/ceres-solver/ceres-solver/issues/289 30 | # the bundle_adjustment_test will fail for Eigen < 3.3.4 on aarch64. 31 | message(FATAL_ERROR "-- Ceres requires Eigen version >= 3.3.4 on aarch64. " 32 | "Detected version of Eigen is: ${EIGEN3_VERSION_STRING}.") 33 | endif() 34 | target_link_libraries(PyCeres PRIVATE Eigen3::Eigen) 35 | endif() 36 | 37 | target_link_libraries(PyCeres PRIVATE ${CERES_LIBRARY} ${GLOG_LIBRARY} ${LAPACK_LIBRARIES}) 38 | 39 | ############# 40 | ## Install ## 41 | ############# 42 | include(GNUInstallDirs) # Defines CMAKE_INSTALL_LIBDIR, CMAKE_INSTALL_BINDIR and CMAKE_INSTALL_INCLUDEDIR used below. 43 | install( 44 | TARGETS PyCeres 45 | EXPORT PyCeresTargets 46 | LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} 47 | ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} 48 | RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} 49 | PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) 50 | 51 | # Install the exported targets 52 | install( 53 | EXPORT PyCeresTargets 54 | FILE PyCeresTargets.cmake 55 | NAMESPACE PyCeres:: 56 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/PyCeres) 57 | 58 | # Build-tree directory in which the Python package __init__.py is generated. 59 | set(PyCeres_BUILDDIR "${PROJECT_BINARY_DIR}/PyCeres") 60 | # For the python package we need an init file 61 | file( 62 | GENERATE 63 | OUTPUT "${PyCeres_BUILDDIR}/__init__.py" 64 | CONTENT "from PyCeres.PyCeres import *\n") 65 | 66 | # Install the __init__.py file 67 | install( 68 | FILES
"${PyCeres_BUILDDIR}/__init__.py" 69 | DESTINATION ${CMAKE_INSTALL_PREFIX}) 70 | 71 | #install( 72 | # TARGETS pybind11_bindings 73 | # COMPONENT bindings 74 | # LIBRARY DESTINATION ${MYMATH_INSTALL_PREFIX} 75 | # ARCHIVE DESTINATION ${MYMATH_INSTALL_PREFIX} 76 | # RUNTIME DESTINATION ${MYMATH_INSTALL_PREFIX}) 77 | 78 | -------------------------------------------------------------------------------- /LICENSES/BINDINGS_LICENSE: -------------------------------------------------------------------------------- 1 | Ceres Solver Python Bindings 2 | Copyright Nikolaus Mitchell. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its contributors 13 | may be used to endorse or promote products derived from this software 14 | without specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 20 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 | POSSIBILITY OF SUCH DAMAGE. 27 | 28 | Author: nikolausmitchell@gmail.com (Nikolaus Mitchell) 29 | -------------------------------------------------------------------------------- /LICENSES/CERES_LICENSE: -------------------------------------------------------------------------------- 1 | Ceres Solver - A fast non-linear least squares minimizer 2 | Copyright 2015 Google Inc. All rights reserved. 3 | http://ceres-solver.org/ 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | * Neither the name of Google Inc. nor the names of its contributors may be 14 | used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ceres Python Wrapper 2 | 3 | This project uses pybind11 to wrap Ceres with a python interface. 4 | 5 | ## Build Setup 6 | There are different ways to build this library. The safest way is build it along with the Ceres library. 7 | 8 | ### Recommended: Build Alongside Ceres 9 | 10 | Clone the repository at https://github.com/Edwinem/ceres_python_bindings into 11 | your *ceres-solver* folder. 12 | 13 | Initialize and download the pybind11 submodule 14 | ```shell 15 | git clone https://github.com/Edwinem/ceres_python_bindings 16 | cd ceres_python_bindings 17 | git submodule init 18 | git submodule update 19 | ``` 20 | 21 | If you cloned it somewhere else then you must now copy and paste the 22 | *ceres_python_bindings* directory to your *ceres-solver* directory. 23 | 24 | Your ceres directory should now look something like this. 25 | ``` 26 | ceres-solver/ 27 | │ 28 | ├── CMakeLists.txt 29 | ├── include 30 | ├── ... 31 | │ 32 | ├── ceres_python_bindings/ - THIS REPOSITORY 33 | │ ├── pybind11 34 | │ ├── python_bindings 35 | │ ├── ... 36 | │ └── AddToCeres.cmake - file to include in Ceres CMakeLists.txt 37 | ``` 38 | 39 | Open up your *ceres-solver/CMakeLists.txt* and add the following to the end 40 | of the file. 
41 | 42 | ``` 43 | include(ceres_python_bindings/AddToCeres.cmake) 44 | ``` 45 | 46 | If everything was successful, then when you call *cmake* in your build folder it should 47 | output the following at the end 48 | 49 | ``` 50 | -- Python Bindings for Ceres(PyCeres) have been added 51 | ``` 52 | 53 | Build Ceres as you would normally. To specifically build the bindings you should 54 | call _make PyCeres_ . 55 | 56 | ### Build separately and link to Ceres 57 | 58 | Note that these methods assume that you have built and installed the Ceres library, either through `sudo apt-get ` 59 | or by doing `make install`. 60 | 61 | * You might have to modify the CMakeLists.txt to link to extra libraries such as suitesparse depending on how 62 | your Ceres library was built. 63 | 64 | #### Normal Cmake 65 | 66 | Clone the project and initialize the submodules. Call cmake as you would normally. 67 | 68 | ```shell 69 | cd ceres_python_bindings 70 | git submodule init 71 | git submodule update 72 | mkdir build 73 | cd build 74 | cmake .. 75 | make 76 | ``` 77 | 78 | #### Python setup.py 79 | 80 | This uses cmake-build-extension to call the cmake commands with python's setuptools. 81 | 82 | Activate your python virtual env. Within the *ceres_python_bindings* folder run `pip install .`. This will 83 | call the `setup.py` file and install PyCeres to your virtual environment. 84 | 85 | If this fails then your best bet is to use the normal cmake method and debug from there. 86 | 87 | ## How to import PyCeres 88 | 89 | ### Built with setuptools 90 | If you used the `setup.py` with *pip* then the library should have been installed to your virtualenv, and you 91 | can simply import it with 92 | 93 | ```python 94 | import PyCeres 95 | ``` 96 | 97 | ### Built with cmake 98 | 99 | Somewhere a file called **PyCeres.so** should have been built. It should be in your build directory. 100 | More likely its name looks something like this: **PyCeres.cpython-36m-x86_64-linux-gnu.so**.
101 | Mark down the location of this file. This location is what you have to add to 102 | python **sys.path** in order to use the library. An example of how to do this can 103 | be seen below. 104 | 105 | ```python 106 | pyceres_location="..." 107 | import sys 108 | sys.path.insert(0, pyceres_location) 109 | ``` 110 | 111 | After this you can now run 112 | ```python 113 | import PyCeres 114 | ``` 115 | to utilize the library. 116 | 117 | Another option is to copy and paste the **PyCeres.so** file to your virtualenv/lib folder, which allows you to skip 118 | the sys path modifications. 119 | 120 | ## How to use PyCeres 121 | 122 | You should peruse some of the examples listed below. It works almost exactly 123 | like Ceres in C++. The only care you have to take is that the parameters you 124 | pass to the AddResidualBlock() function is a numpy array. 125 | 126 | ### Basic HelloWorld 127 | 128 | Code for this example can be found in *examples/ceres_hello_world_example.py* 129 | 130 | This example is the same as the hello world example from Ceres. 131 | 132 | ```python 133 | import PyCeres # Import the Python Bindings 134 | import numpy as np 135 | 136 | # The variable to solve for with its initial value. 137 | initial_x = 5.0 138 | x = np.array([initial_x]) # Requires the variable to be in a numpy array 139 | 140 | # Here we create the problem as in normal Ceres 141 | problem = PyCeres.Problem() 142 | 143 | # Creates the CostFunction. 
This example uses a C++ wrapped function which 144 | # returns the Autodiffed cost function used in the C++ example 145 | cost_function = PyCeres.CreateHelloWorldCostFunction() 146 | 147 | # Add the costfunction and the parameter numpy array to the problem 148 | problem.AddResidualBlock(cost_function, None, x) 149 | 150 | # Setup the solver options as in normal ceres 151 | options = PyCeres.SolverOptions() 152 | # Ceres enums live in PyCeres and require the enum Type 153 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 154 | options.minimizer_progress_to_stdout = True 155 | summary = PyCeres.Summary() 156 | # Solve as you would normally 157 | PyCeres.Solve(options, problem, summary) 158 | print(summary.BriefReport() + " \n") 159 | print("x : " + str(initial_x) + " -> " + str(x) + "\n") 160 | ``` 161 | 162 | ### CostFunction in Python 163 | 164 | This library allows you to create your own custom CostFunction in Python to be 165 | used with the Ceres Solver. 166 | 167 | An custom CostFunction in Python can be seen here. 168 | ```python 169 | # function f(x) = 10 - x. 170 | # Comes from ceres/examples/helloworld_analytic_diff.cc 171 | class QuadraticCostFunction(PyCeres.CostFunction): 172 | def __init__(self): 173 | # MUST BE CALLED. Initializes the Ceres::CostFunction class 174 | super().__init__() 175 | 176 | # MUST BE CALLED. Sets the size of the residuals and parameters 177 | self.set_num_residuals(1) 178 | self.set_parameter_block_sizes([1]) 179 | 180 | # The CostFunction::Evaluate(...) 
virtual function implementation 181 | def Evaluate(self,parameters, residuals, jacobians): 182 | x=parameters[0][0] 183 | 184 | residuals[0] = 10 - x 185 | 186 | if (jacobians!=None): # check for Null 187 | jacobians[0][0] = -1 188 | 189 | return True 190 | ``` 191 | 192 | Some things to be aware of for a custom CostFunction 193 | 194 | * residuals is a numpy array 195 | * parameters,jacobians are lists of numpy arrays ([arr1,arr2,...]) 196 | * Indexing works similarly to Ceres C++. parameters[i] is the ith parameter block 197 | * You must always index into the list first, even if it only has 1 value. 198 | * You must call the base constructor with super. 199 | 200 | ### CostFunction defined in C++ 201 | It is possible to define your custom CostFunction in C++ and utilize it within 202 | the python framework. In order to do this we provide a file **python_bindings/custom_cpp_cost_functions.cpp**, 203 | which provides a place to write your own wrapper code. The easiest way to do this 204 | is to create an initialization function that creates your custom CostFunction class 205 | and returns a ceres::CostFunction* to it. That function should then be wrapped in 206 | the *void add_custom_cost_functions(py::module& m)* function. 207 | 208 | It should end up looking something like this. 209 | ```cpp 210 | 211 | #include <ceres/ceres.h> 212 | 213 | // Create a function that initializes your CostFunction and returns a ceres::CostFunction* 214 | 215 | ceres::CostFunction* CreateCustomCostFunction(arg1,arg2,...){ 216 | return new CustomCostFunction(arg1,arg2,...); 217 | } 218 | 219 | // In file custom_cpp_cost_functions.cpp add the following line 220 | 221 | void add_custom_cost_functions(py::module &m) { 222 | // .... 223 | m.def("CreateCustomCostFunction",&CreateCustomCostFunction); 224 | } 225 | 226 | ``` 227 | 228 | We provide a basic example of this in *custom_cpp_cost_functions.cpp*.
229 | Note you are responsible for ensuring that all the dependencies and includes are 230 | set correctly for your library. 231 | 232 | ### Running examples 233 | We provide a couple examples of how to use the library under *./python_tests*. 234 | They all assume the wrappers were built alongside Ceres for the PyCeres library. 235 | If you did not do this then you need to set the *PYCERES_LOCATION* environment 236 | variable. 237 | 238 | You need the following python libs to run these examples. 239 | 240 | **Required:** 241 | 242 | * numpy 243 | 244 | **Optional:** 245 | 246 | * pytest 247 | * jax 248 | 249 | 250 | ## Experimental PyTorch functionality 251 | 252 | 253 | 254 | ## Warnings: 255 | 256 | * Remember Ceres was designed with a C++ memory model. So you have to be careful 257 | when using it from Python. The main problem is that Python does not really have 258 | the concept of giving away ownership of memory. So it may try to delete something 259 | that Ceres still believes is valid. 260 | * I think for most stuff I setup the proper procedures that this doesn't happen ( 261 | e.g Ceres::Problem by default has Ownership turned off, cost_function can't be deleted 262 | until Problem is ,...) . But you never know what I missed. 263 | * Careful with wrapping AutodiffCostfunction. It takes over the memory of a cost 264 | functor which can cause errors. 265 | * Python has **GIL**. Therefore, cost functions written in Python have a fundamental 266 | slowdown, and can't be truly multithreaded. 267 | 268 | 269 | ## TODOs 270 | - [ ] The wrapper code that wraps the Evaluate pointers(residuals,parameters,..) 271 | needs a lot of improvement and optimization. We really need this to be a zero copy 272 | operation. 
273 | - [ ] Wrap all the variables for Summary and other classes 274 | - [ ] LocalParameterizations and Lossfunctions need to be properly wrapped 275 | - [ ] Callbacks need to be wrapped 276 | - [ ] Investigate how to wrap a basic python function for evaluate rather than 277 | go through the CostFunction( something like in the C api). 278 | - [ ] Add docstrings for all the wrapped stuff 279 | - [X] Add a place for users to define their CostFunctions in C++ 280 | - [ ] Evaluate speed of Python Cost Function vs C++ 281 | - [ ] Clean up AddResidualBlock() and set up the correct error checks 282 | - [ ] Figure out how google log should work with Python 283 | - [ ] Figure out if Jax or PyTorch could somehow be integrated so that we use 284 | their tensor/numpy arrays. 285 | 286 | ## Status 287 | Custom Cost functions work 288 | 289 | ## LICENSE 290 | Same as Ceres New BSD. 291 | 292 | ## Credit 293 | This is just a wrapper over the hard work of the main 294 | [Ceres](http://ceres-solver.org/) project. All the examples derive from ceres-solver/examples 295 | -------------------------------------------------------------------------------- /cmake/FindCeres.cmake: -------------------------------------------------------------------------------- 1 | # - Find Ceres library 2 | # Find the native Ceres includes and library 3 | # This module defines 4 | # CERES_INCLUDE_DIRS, where to find ceres.h, Set when 5 | # CERES_INCLUDE_DIR is found. 6 | # CERES_LIBRARIES, libraries to link against to use Ceres. 7 | # CERES_ROOT_DIR, The base directory to search for Ceres. 8 | # This can also be an environment variable. 9 | # CERES_FOUND, If false, do not try to use Ceres. 10 | # 11 | # also defined, but not for general use are 12 | # CERES_LIBRARY, where to find the Ceres library. 13 | 14 | # If CERES_ROOT_DIR was defined in the environment, use it. 
15 | IF(NOT CERES_ROOT_DIR AND NOT "$ENV{CERES_ROOT_DIR}" STREQUAL "") 16 | SET(CERES_ROOT_DIR "$ENV{CERES_ROOT_DIR}") 17 | ENDIF() 18 | # Candidate install prefixes; CERES_ROOT_DIR (possibly empty) is listed first. 19 | SET(_ceres_SEARCH_DIRS 20 | ${CERES_ROOT_DIR} 21 | /usr/local 22 | /sw # Fink 23 | /opt/local # DarwinPorts 24 | /opt/csw # Blastwave 25 | /opt/lib/ceres 26 | ) 27 | # Directory containing ceres/ceres.h. 28 | FIND_PATH(CERES_INCLUDE_DIR 29 | NAMES 30 | ceres/ceres.h 31 | HINTS 32 | ${_ceres_SEARCH_DIRS} 33 | PATH_SUFFIXES 34 | include 35 | ) 36 | # Directory containing cmake/CeresConfig.cmake; not part of the found/not-found decision below. 37 | FIND_PATH(CERES_CMAKE_DIR 38 | NAMES 39 | cmake/CeresConfig.cmake 40 | HINTS 41 | ${_ceres_SEARCH_DIRS} 42 | PATH_SUFFIXES 43 | lib64 lib 44 | ) 45 | # The Ceres library itself. 46 | FIND_LIBRARY(CERES_LIBRARY 47 | NAMES 48 | ceres 49 | HINTS 50 | ${_ceres_SEARCH_DIRS} 51 | PATH_SUFFIXES 52 | lib64 lib 53 | ) 54 | 55 | # handle the QUIETLY and REQUIRED arguments and set CERES_FOUND to TRUE if 56 | # all listed variables are TRUE 57 | INCLUDE(FindPackageHandleStandardArgs) 58 | FIND_PACKAGE_HANDLE_STANDARD_ARGS(Ceres DEFAULT_MSG 59 | CERES_LIBRARY CERES_INCLUDE_DIR) 60 | 61 | IF(CERES_FOUND) 62 | SET(CERES_LIBRARIES ${CERES_LIBRARY}) 63 | SET(CERES_INCLUDE_DIRS ${CERES_INCLUDE_DIR}) 64 | ENDIF() 65 | 66 | MARK_AS_ADVANCED( 67 | CERES_INCLUDE_DIR 68 | CERES_LIBRARY 69 | ) -------------------------------------------------------------------------------- /cmake/FindGlog.cmake: -------------------------------------------------------------------------------- 1 | # Ceres Solver - A fast non-linear least squares minimizer 2 | # Copyright 2013 Google Inc. All rights reserved. 3 | # http://code.google.com/p/ceres-solver/ 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of Google Inc. nor the names of its contributors may be 14 | # used to endorse or promote products derived from this software without 15 | # specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | # POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Author: alexs.mac@gmail.com (Alex Stewart) 30 | # 31 | 32 | # FindGlog.cmake - Find Google glog logging library. 33 | # 34 | # This module defines the following variables: 35 | # 36 | # GLOG_FOUND: TRUE iff glog is found. 37 | # GLOG_INCLUDE_DIRS: Include directories for glog. 38 | # GLOG_LIBRARIES: Libraries required to link glog. 39 | # 40 | # The following variables control the behaviour of this module: 41 | # 42 | # GLOG_INCLUDE_DIR_HINTS: List of additional directories in which to 43 | # search for glog includes, e.g: /timbuktu/include. 44 | # GLOG_LIBRARY_DIR_HINTS: List of additional directories in which to 45 | # search for glog libraries, e.g: /timbuktu/lib. 
46 | # 47 | # The following variables are also defined by this module, but in line with 48 | # CMake recommended FindPackage() module style should NOT be referenced directly 49 | # by callers (use the plural variables detailed above instead). These variables 50 | # do however affect the behaviour of the module via FIND_[PATH/LIBRARY]() which 51 | # are NOT re-called (i.e. search for library is not repeated) if these variables 52 | # are set with valid values _in the CMake cache_. This means that if these 53 | # variables are set directly in the cache, either by the user in the CMake GUI, 54 | # or by the user passing -DVAR=VALUE directives to CMake when called (which 55 | # explicitly defines a cache variable), then they will be used verbatim, 56 | # bypassing the HINTS variables and other hard-coded search locations. 57 | # 58 | # GLOG_INCLUDE_DIR: Include directory for glog, not including the 59 | # include directory of any dependencies. 60 | # GLOG_LIBRARY: glog library, not including the libraries of any 61 | # dependencies. 62 | 63 | # Called if we failed to find glog or any of its required dependencies, 64 | # unsets all public (designed to be used externally) variables and reports 65 | # error message at priority depending upon [REQUIRED/QUIET/] argument. 66 | MACRO(GLOG_REPORT_NOT_FOUND REASON_MSG) 67 | UNSET(GLOG_FOUND) 68 | UNSET(GLOG_INCLUDE_DIRS) 69 | UNSET(GLOG_LIBRARIES) 70 | # Make results of search visible in the CMake GUI if glog has not 71 | # been found so that user does not have to toggle to advanced view. 72 | MARK_AS_ADVANCED(CLEAR GLOG_INCLUDE_DIR 73 | GLOG_LIBRARY) 74 | # Note _FIND_[REQUIRED/QUIETLY] variables defined by FindPackage() 75 | # use the camelcase library name, not uppercase.
76 | IF (Glog_FIND_QUIETLY) 77 | MESSAGE(STATUS "Failed to find glog - " ${REASON_MSG} ${ARGN}) 78 | ELSEIF (Glog_FIND_REQUIRED) 79 | MESSAGE(FATAL_ERROR "Failed to find glog - " ${REASON_MSG} ${ARGN}) 80 | ELSE() 81 | # Neither QUIETLY nor REQUIRED, use no priority which emits a message 82 | # but continues configuration and allows generation. 83 | MESSAGE("-- Failed to find glog - " ${REASON_MSG} ${ARGN}) 84 | ENDIF () 85 | ENDMACRO(GLOG_REPORT_NOT_FOUND) 86 | 87 | # Search user-installed locations first, so that we prefer user installs 88 | # to system installs where both exist. 89 | # 90 | # TODO: Add standard Windows search locations for glog. 91 | LIST(APPEND GLOG_CHECK_INCLUDE_DIRS 92 | /usr/local/include 93 | /usr/local/homebrew/include # Mac OS X 94 | /opt/local/var/macports/software # Mac OS X. 95 | /opt/local/include 96 | /usr/include) 97 | LIST(APPEND GLOG_CHECK_LIBRARY_DIRS 98 | /usr/local/lib 99 | /usr/local/homebrew/lib # Mac OS X. 100 | /opt/local/lib 101 | /usr/lib) 102 | 103 | # Search supplied hint directories first if supplied. 104 | FIND_PATH(GLOG_INCLUDE_DIR 105 | NAMES glog/logging.h 106 | PATHS ${GLOG_INCLUDE_DIR_HINTS} 107 | ${GLOG_CHECK_INCLUDE_DIRS}) 108 | IF (NOT GLOG_INCLUDE_DIR OR 109 | NOT EXISTS ${GLOG_INCLUDE_DIR}) 110 | GLOG_REPORT_NOT_FOUND( 111 | "Could not find glog include directory, set GLOG_INCLUDE_DIR " 112 | "to directory containing glog/logging.h") 113 | ENDIF (NOT GLOG_INCLUDE_DIR OR 114 | NOT EXISTS ${GLOG_INCLUDE_DIR}) 115 | 116 | FIND_LIBRARY(GLOG_LIBRARY NAMES glog 117 | PATHS ${GLOG_LIBRARY_DIR_HINTS} 118 | ${GLOG_CHECK_LIBRARY_DIRS}) 119 | IF (NOT GLOG_LIBRARY OR 120 | NOT EXISTS ${GLOG_LIBRARY}) 121 | GLOG_REPORT_NOT_FOUND( 122 | "Could not find glog library, set GLOG_LIBRARY " 123 | "to full path to libglog.") 124 | ENDIF (NOT GLOG_LIBRARY OR 125 | NOT EXISTS ${GLOG_LIBRARY}) 126 | 127 | # Mark internally as found, then verify. GLOG_REPORT_NOT_FOUND() unsets it 128 | # if called. NOTE(review): the two search checks above run before this SET, so any UNSET(GLOG_FOUND) they performed is overwritten here - confirm that FIND_PACKAGE_HANDLE_STANDARD_ARGS() below is what reports those failures.
129 | SET(GLOG_FOUND TRUE) 130 | 131 | # Glog does not seem to provide any record of the version in its 132 | # source tree, thus cannot extract version. 133 | 134 | # Catch case when caller has set GLOG_INCLUDE_DIR in the cache / GUI and 135 | # thus FIND_[PATH/LIBRARY] are not called, but specified locations are 136 | # invalid, otherwise we would report the library as found. 137 | IF (GLOG_INCLUDE_DIR AND 138 | NOT EXISTS ${GLOG_INCLUDE_DIR}/glog/logging.h) 139 | GLOG_REPORT_NOT_FOUND( 140 | "Caller defined GLOG_INCLUDE_DIR:" 141 | " ${GLOG_INCLUDE_DIR} does not contain glog/logging.h header.") 142 | ENDIF (GLOG_INCLUDE_DIR AND 143 | NOT EXISTS ${GLOG_INCLUDE_DIR}/glog/logging.h) 144 | # TODO: This regex for glog library is pretty primitive, we use lowercase 145 | # for comparison to handle Windows using CamelCase library names, could 146 | # this check be better? 147 | STRING(TOLOWER "${GLOG_LIBRARY}" LOWERCASE_GLOG_LIBRARY) 148 | IF (GLOG_LIBRARY AND 149 | NOT "${LOWERCASE_GLOG_LIBRARY}" MATCHES ".*glog[^/]*") 150 | GLOG_REPORT_NOT_FOUND( 151 | "Caller defined GLOG_LIBRARY: " 152 | "${GLOG_LIBRARY} does not match glog.") 153 | ENDIF (GLOG_LIBRARY AND 154 | NOT "${LOWERCASE_GLOG_LIBRARY}" MATCHES ".*glog[^/]*") 155 | 156 | # Set standard CMake FindPackage variables if found. 157 | IF (GLOG_FOUND) 158 | SET(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) 159 | SET(GLOG_LIBRARIES ${GLOG_LIBRARY}) 160 | ENDIF (GLOG_FOUND) 161 | 162 | # Handle REQUIRED / QUIET optional arguments. 163 | INCLUDE(FindPackageHandleStandardArgs) 164 | FIND_PACKAGE_HANDLE_STANDARD_ARGS(Glog DEFAULT_MSG 165 | GLOG_INCLUDE_DIRS GLOG_LIBRARIES) 166 | 167 | # Only mark internal variables as advanced if we found glog, otherwise 168 | # leave them visible in the standard GUI for the user to set manually.
169 | IF (GLOG_FOUND) 170 | MARK_AS_ADVANCED(FORCE GLOG_INCLUDE_DIR 171 | GLOG_LIBRARY) 172 | ENDIF (GLOG_FOUND) -------------------------------------------------------------------------------- /cmake/pytorch_stuff.cmake: -------------------------------------------------------------------------------- 1 | # Placeholder for old Pytorch stuff 2 | 3 | list(APPEND PYTORCH_FILES "") 4 | option(WITH_PYTORCH "Enables PyTorch defined Cost Functions" OFF) 5 | set(PYCERES_TORCH_CMAKE_DIR "$ENV{HOME}/programming/python/env/lib/python3.6/site-packages/torch/share/cmake/Torch" CACHE PATH "Directory added to CMAKE_PREFIX_PATH to locate the Torch CMake package") 6 | if(WITH_PYTORCH) 7 | # PyTorch by default is built with the old C++ ABI, so we use that option here. 8 | add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) 9 | list(APPEND CMAKE_PREFIX_PATH "${PYCERES_TORCH_CMAKE_DIR}") 10 | find_package(Torch REQUIRED) 11 | include_directories(${TORCH_INCLUDE_DIRS}) 12 | 13 | add_definitions("-DWITH_PYTORCH") 14 | endif() 15 | if(WITH_PYTORCH) 16 | list(APPEND PYTORCH_FILES python_bindings/pytorch_cost_function.h 17 | python_bindings/pytorch_cost_function.cpp 18 | python_bindings/pytorch_module.cpp) 19 | endif() 20 | if(WITH_PYTORCH) 21 | target_link_libraries(PyCeres PRIVATE "${TORCH_LIBRARIES}") 22 | add_executable(torchscript tests/pytorch_test.cpp) 23 | target_link_libraries(torchscript PRIVATE "${TORCH_LIBRARIES}") 24 | set_property(TARGET torchscript PROPERTY CXX_STANDARD 14) 25 | endif() -------------------------------------------------------------------------------- /docs/debugging_log.md: -------------------------------------------------------------------------------- 1 | # Debugging log 2 | 3 | Contains some notes on problems encountered trying to wrap Ceres. Could be 4 | useful for people in the future. 5 | 6 | * AddResidualBlock returns ResidualBlock*. If you use AddResidualBlock and don't 7 | capture the result in a local variable then python would capture it and promptly 8 | delete it. Thus deleting the ResidualBlock in your problem.
Fix was to add 9 | py::return_value_policy::reference so python doesn't manage that memory. 10 | * AutodiffCostFunction takes ownership of the CostFunctor. So if the 11 | CostFunctor is created in Python then a double free will happen as the python 12 | garbage collector will delete the AutodiffCostFunction(deletes CostFunctor) and 13 | the CostFunctor. 14 | * Python manages the memory for the cost functions. This means it can delete the 15 | cost function before the Problem even uses it. In order to avoid this you have 16 | to make the relationship clear that cost function scope is dependent on the 17 | problem. This is done with the py::keep_alive<> command for AddResidualBlock 18 | * End user must call super().__init__() on custom CostFunctions defined in 19 | Python. If this is not done then the Base Class CostFunction is never 20 | initialized. 21 | * Seems like as soon as you start touching python stuff like the py::array you 22 | need to ensure that you have the GIL. The trampoline classes CostFunction would 23 | crash unless the gil was put as the first line. (Ahh the bug started occurring 24 | because I put release GIL in the Solve functions. Before I had this the GIL was 25 | held and therefore there was no crash) 26 | -------------------------------------------------------------------------------- /docs/pytorch_stuff.md: -------------------------------------------------------------------------------- 1 | **WARNING THIS IS CURRENTLY EXTREMELY EXPERIMENTAL** 2 | 3 | In order to bypass the fundamental slowness of Python (due to GIL and other factors), this 4 | library optionally provides the capability to utilize PyTorch's TorchScript. 5 | This allows you to define a CostFunction in Python, but bypass having to touch 6 | it when solving the Ceres::Problem. 7 | 8 | Right now only the standalone version of these bindings supports it. Lots of 9 | the paths are hardcoded. So you will have to change them too.
10 | 11 | To enable this functionality you must do the following things. 12 | 13 | - Enable the option in cmake by turning on _WITH_PYTORCH_ 14 | - Build Ceres and GLOG that you link to with _-D_GLIBCXX_USE_CXX11_ABI=0_ 15 | - The default PyTorch libs that you download from pip and other package managers 16 | are built with the old C++ ABI. 17 | 18 | Note this will break normal functionality as all Python instantiations now require 19 | an *import Torch* before you import *PyCeres*. 20 | 21 | Currently the TorchScript is passed by serialized files. -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Edwinem/ceres_python_bindings/d809e7651890c6e4b78c5c40c2f47d6c4aba8525/examples/__init__.py -------------------------------------------------------------------------------- /examples/ceres_hello_world_analytic_diff.py: -------------------------------------------------------------------------------- 1 | import PyCeres # Import the Python Bindings 2 | import numpy as np 3 | import pytest 4 | 5 | 6 | # A CostFunction implementing analytic derivatives for the 7 | # function f(x) = 10 - x.
8 | # Comes from ceres/examples/helloworld_analytic_diff.cc 9 | class QuadraticCostFunction(PyCeres.CostFunction): 10 | def __init__(self): 11 | super().__init__() 12 | self.set_num_residuals(1) 13 | self.set_parameter_block_sizes([1]) 14 | 15 | def Evaluate(self, parameters, residuals, jacobians): 16 | x = parameters[0][0] 17 | 18 | residuals[0] = 10 - x 19 | 20 | if (jacobians != None): 21 | jacobians[0][0] = -1 22 | 23 | return True 24 | 25 | 26 | def RunQuadraticFunction(): 27 | cost_function = QuadraticCostFunction() 28 | 29 | data = [0.5] 30 | np_data = np.array(data) 31 | 32 | print(np_data) 33 | 34 | problem = PyCeres.Problem() 35 | 36 | problem.AddResidualBlock(cost_function, None, np_data) 37 | options = PyCeres.SolverOptions() 38 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 39 | options.minimizer_progress_to_stdout = True 40 | summary = PyCeres.Summary() 41 | PyCeres.Solve(options, problem, summary) 42 | print(summary.BriefReport()) 43 | print("x : " + str(0.5) + " -> " + str(np_data[0])) 44 | 45 | 46 | RunQuadraticFunction() 47 | -------------------------------------------------------------------------------- /examples/ceres_hello_world_example.py: -------------------------------------------------------------------------------- 1 | """ Contains Ceres HelloWorld Example in Python 2 | 3 | This file contains the Ceres HelloWorld Example except it uses Python Bindings. 4 | 5 | """ 6 | 7 | import PyCeres # Import the Python Bindings 8 | 9 | import numpy as np 10 | 11 | # The variable to solve for with its initial value. 12 | initial_x = 5.0 13 | x = np.array([initial_x]) 14 | 15 | # Build the Problem 16 | problem = PyCeres.Problem() 17 | 18 | # Set up the only cost function (also known as residual). This uses a helper function written in C++ as Autodiff 19 | # cant be used in Python. 
It returns a CostFunction* 20 | cost_function = PyCeres.CreateHelloWorldCostFunction() 21 | 22 | problem.AddResidualBlock(cost_function, None, x) 23 | 24 | options = PyCeres.SolverOptions() 25 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR # Ceres enums live in PyCeres and require the enum Type 26 | options.minimizer_progress_to_stdout = True 27 | summary = PyCeres.Summary() 28 | PyCeres.Solve(options, problem, summary) 29 | print(summary.BriefReport() + " \n") 30 | print("x : " + str(initial_x) + " -> " + str(x) + "\n") 31 | -------------------------------------------------------------------------------- /examples/ceres_rosenbrock_autodiff.py: -------------------------------------------------------------------------------- 1 | """ Contains Ceres Rosenbrock example in Python 2 | """ 3 | 4 | import PyCeres # Import the Python Bindings 5 | import numpy as np 6 | 7 | from jax import grad 8 | 9 | 10 | # f(x,y) = (1-x)^2 + 100(y - x^2)^2; 11 | def CalcCost(x, y): 12 | return (1.0 - x) * (1.0 - x) + 100.0 * (y - x * x) * (y - x * x) 13 | 14 | 15 | class Rosenbrock(PyCeres.FirstOrderFunction): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def Evaluate(self, parameters, cost, gradient): 20 | x = parameters[0] 21 | y = parameters[1] 22 | cost[0] = CalcCost(x, y) 23 | if not (gradient is None): 24 | gradient[0], gradient[1] = grad(CalcCost, (0, 1))(x, y) 25 | return True 26 | 27 | def NumParameters(self): 28 | return 2 29 | 30 | 31 | parameters = [-1.2, 1.0] 32 | 33 | np_params = np.array(parameters) 34 | 35 | options = PyCeres.GradientProblemOptions() 36 | options.minimizer_progress_to_stdout = True 37 | 38 | summary = PyCeres.GradientProblemSummary() 39 | problem = PyCeres.GradientProblem(Rosenbrock()) 40 | PyCeres.Solve(options, problem, np_params, summary) 41 | 42 | print(summary.FullReport() + "\n") 43 | print("Initial x: " + str(-1.2) + " y: " + str(1.0)) 44 | print("Final x: " + str(np_params[0]) + " y: " + str(np_params[1])) 45 | 
-------------------------------------------------------------------------------- /examples/ceres_rosenbrock_example.py: -------------------------------------------------------------------------------- 1 | """ Contains Ceres Rosenbrock example in Python 2 | """ 3 | 4 | import numpy as np 5 | 6 | import PyCeres # Import the Python Bindings 7 | 8 | 9 | # f(x,y) = (1-x)^2 + 100(y - x^2)^2; 10 | class Rosenbrock(PyCeres.FirstOrderFunction): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def Evaluate(self, parameters, cost, gradient): 15 | x = parameters[0] 16 | y = parameters[1] 17 | 18 | cost[0] = (1.0 - x) * (1.0 - x) + 100.0 * (y - x * x) * (y - x * x) 19 | if not (gradient is None): 20 | gradient[0] = -2.0 * (1.0 - x) - 200.0 * (y - x * x) * 2.0 * x 21 | gradient[1] = 200.0 * (y - x * x) 22 | return True 23 | 24 | def NumParameters(self): 25 | return 2 26 | 27 | 28 | parameters = [-1.2, 1.0] 29 | 30 | np_params = np.array(parameters) 31 | 32 | options = PyCeres.GradientProblemOptions() 33 | options.minimizer_progress_to_stdout = True 34 | 35 | summary = PyCeres.GradientProblemSummary() 36 | problem = PyCeres.GradientProblem(Rosenbrock()) 37 | PyCeres.Solve(options, problem, np_params, summary) 38 | 39 | print(summary.FullReport() + "\n") 40 | print("Initial x: " + str(-1.2) + " y: " + str(1.0)) 41 | print("Final x: " + str(np_params[0]) + " y: " + str(np_params[1])) 42 | -------------------------------------------------------------------------------- /examples/ceres_simple_bundle_adjuster.py: -------------------------------------------------------------------------------- 1 | """ Contains Ceres Simple Bundle Adjustment in Python 2 | 3 | """ 4 | 5 | import numpy as np 6 | import argparse 7 | import PyCeres 8 | 9 | parser = argparse.ArgumentParser(description='Solves a Bundle Adjustment problem') 10 | parser.add_argument('file', help='File from http://grail.cs.washington.edu/projects/bal') 11 | args = parser.parse_args() 12 | 13 | if len(sys.argv) == 1: 14 
| sys.exit("No file provided") 15 | 16 | file = args.file 17 | 18 | bal_problem = PyCeres.BALProblem() 19 | 20 | bal_problem.LoadFile(file) 21 | 22 | problem = PyCeres.Problem() 23 | 24 | observations = bal_problem.observations() 25 | cameras = bal_problem.cameras() 26 | points = bal_problem.points() 27 | 28 | numpy_points = np.array(points) 29 | numpy_points = np.reshape(numpy_points, (-1, 3)) 30 | numpy_cameras = np.array(cameras) 31 | numpy_cameras = np.reshape(numpy_cameras, (-1, 9)) 32 | print(numpy_points[0]) 33 | 34 | for i in range(0, bal_problem.num_observations()): 35 | cost_function = PyCeres.CreateSnavelyCostFunction(observations[2 * i + 0], observations[2 * i + 1]) 36 | cam_index = bal_problem.camera_index(i) 37 | point_index = bal_problem.point_index(i) 38 | loss = PyCeres.HuberLoss(0.1) 39 | problem.AddResidualBlock(cost_function, loss, numpy_cameras[cam_index], numpy_points[point_index]) 40 | 41 | options = PyCeres.SolverOptions() 42 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_SCHUR 43 | options.minimizer_progress_to_stdout = True 44 | 45 | summary = PyCeres.Summary() 46 | PyCeres.Solve(options, problem, summary) 47 | print(summary.FullReport()) 48 | 49 | # Compare with CPP version 50 | 51 | print(" Running C++ version now ") 52 | PyCeres.SolveBALProblemWithCPP(bal_problem) 53 | cpp_points = bal_problem.points() 54 | cpp_points = np.array(cpp_points) 55 | cpp_points = np.reshape(cpp_points, (-1, 3)) 56 | print(" For point 1 Python has a value of " + str(numpy_points[0]) + " \n") 57 | print(" Cpp solved for point 1 a value of " + str(cpp_points[0])) 58 | -------------------------------------------------------------------------------- /examples/hello_world_python_autodiff.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Example of using a basic cost function with autodiff provided by Jax 4 | 5 | ''' 6 | import PyCeres # Import the Python Bindings 7 | import numpy as np 8 | from jax 
import grad 9 | 10 | 11 | def residual_calc(param_input): 12 | return 10 - param_input 13 | 14 | 15 | # function f(x) = 10 - x. 16 | class HelloWorldAutoDiff(PyCeres.CostFunction): 17 | def __init__(self): 18 | super().__init__() 19 | self.set_num_residuals(1) 20 | self.set_parameter_block_sizes([1]) 21 | 22 | def Evaluate(self, parameters, residuals, jacobians): 23 | x = parameters[0][0] 24 | print("Param is {}".format(x)) 25 | residuals[0] = residual_calc(x) 26 | print(residuals) 27 | 28 | if (jacobians != None): 29 | jacobians[0][0] = grad(residual_calc)(x) 30 | 31 | return True 32 | 33 | 34 | def RunHelloWorldAutoDiff(): 35 | cost_function = HelloWorldAutoDiff() 36 | 37 | data = [0.5] 38 | np_data = np.array(data) 39 | 40 | print(np_data) 41 | 42 | problem = PyCeres.Problem() 43 | 44 | problem.AddResidualBlock(cost_function, None, np_data) 45 | options = PyCeres.SolverOptions() 46 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 47 | options.minimizer_progress_to_stdout = True 48 | summary = PyCeres.Summary() 49 | PyCeres.Solve(options, problem, summary) 50 | print(summary.BriefReport()) 51 | print("x : " + str(0.5) + " -> " + str(np_data[0])) 52 | 53 | 54 | RunHelloWorldAutoDiff() 55 | -------------------------------------------------------------------------------- /examples/manual_setting_sys_path.py: -------------------------------------------------------------------------------- 1 | """ Contains Ceres HelloWorld Example in Python 2 | 3 | This file contains the Ceres HelloWorld Example but the import path is manually set. 4 | 5 | """ 6 | 7 | from utilities import find_and_import_pyceres 8 | 9 | # Function which tries to import the library 10 | find_and_import_pyceres() 11 | import PyCeres 12 | 13 | import numpy as np 14 | 15 | # The variable to solve for with its initial value. 
16 | initial_x = 5.0 17 | x = np.array([initial_x]) 18 | 19 | # Build the Problem 20 | problem = PyCeres.Problem() 21 | 22 | # Set up the only cost function (also known as residual). This uses a helper function written in C++ as Autodiff 23 | # cant be used in Python. It returns a CostFunction* 24 | cost_function = PyCeres.CreateHelloWorldCostFunction() 25 | 26 | problem.AddResidualBlock(cost_function, None, x) 27 | 28 | options = PyCeres.SolverOptions() 29 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR # Ceres enums live in PyCeres and require the enum Type 30 | options.minimizer_progress_to_stdout = True 31 | summary = PyCeres.Summary() 32 | PyCeres.Solve(options, problem, summary) 33 | print(summary.BriefReport() + " \n") 34 | print("x : " + str(initial_x) + " -> " + str(x) + "\n") 35 | -------------------------------------------------------------------------------- /examples/pose_graph_slam_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solves a g2o pose graph dataset using Ceres 3 | 4 | """ 5 | import PyCeres 6 | import numpy as np 7 | from jax import grad 8 | 9 | 10 | def residual_calc(): 11 | return 10 - param_input 12 | 13 | # function f(x) = 10 - x. 
14 | class PoseResidual(PyCeres.CostFunction): 15 | def __init__(self,dx,dy,dtheta): 16 | super().__init__() 17 | self.set_num_residuals(3) 18 | self.set_parameter_block_sizes([1,1,1,1,1,1]) 19 | self.dx=dx 20 | self.dy=dy 21 | self.dtheta=dtheta 22 | self.sqrt_info=np.identity(3)*0.05 23 | 24 | def Evaluate(self,parameters, residuals, jacobians): 25 | xa=parameters[0][0] 26 | ya=parameters[1][0] 27 | yaw_a=parameters[2][0] 28 | xb=parameters[3][0] 29 | yb=parameters[4][0] 30 | yaw_b=parameters[5][0] 31 | 32 | 33 | residuals[0] = residual_calc(x) 34 | 35 | if (jacobians!=None): 36 | jacobians[0][0] = grad(residual_calc)(x) 37 | 38 | return True 39 | 40 | 41 | 42 | def RunHelloWorldAutoDiff(): 43 | cost_function = HelloWorldAutoDiff() 44 | 45 | data = [0.5] 46 | np_data = np.array(data) 47 | 48 | print(np_data) 49 | 50 | problem = PyCeres.Problem() 51 | 52 | problem.AddResidualBlock(cost_function, None, np_data) 53 | options = PyCeres.SolverOptions() 54 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 55 | options.minimizer_progress_to_stdout = True 56 | summary = PyCeres.Summary() 57 | PyCeres.Solve(options, problem, summary) 58 | print(summary.BriefReport()) 59 | print ("x : " + str(0.5) + " -> " + str(np_data[0])) 60 | 61 | RunHelloWorldAutoDiff() -------------------------------------------------------------------------------- /examples/pytorch_torchscript_example.py: -------------------------------------------------------------------------------- 1 | 2 | import torch # Torch must be imported before PyCeres 3 | import PyCeres # Import the Python Bindings 4 | 5 | class ExampleTorchModule(torch.nn.Module): 6 | def __init__(self): 7 | super(ExampleTorchModule, self).__init__() 8 | 9 | def forward(self, input): 10 | residual=10-input 11 | return residual 12 | 13 | module=ExampleTorchModule() 14 | torchscript = torch.jit.script(module) 15 | 16 | 17 | # Currently we pass torchscript modules as files 18 | filename="example_torch_module.pt" 19 | 
torchscript.save(filename) 20 | # Create a PyTorchCostFunction. From a torchscript file. Additionally residual size and parameter block sizes must 21 | # be passed. 22 | torch_cost=PyCeres.CreateTorchCostFunction(filename,1,[1]) 23 | 24 | # Create the data in a PyTorch tensor 25 | data = [0.5] 26 | tensor=torch.tensor(data,dtype=torch.float64) 27 | tensor_vec=[tensor] # Data must be passed as a list of tensors 28 | 29 | # Create problem and options as usual 30 | problem = PyCeres.Problem() 31 | res=problem.AddResidualBlock(torch_cost,None,tensor_vec) 32 | 33 | options = PyCeres.SolverOptions() 34 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 35 | options.minimizer_progress_to_stdout = True 36 | summary = PyCeres.Summary() 37 | PyCeres.Solve(options, problem, summary) 38 | print(summary.BriefReport()) 39 | print("x : " + str(0.5) + " -> " + str(tensor[0])) -------------------------------------------------------------------------------- /examples/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | def find_and_import_pyceres(folder_path = None): 6 | """Function which tries to import PyCeres given the multiple possible install options. 7 | 8 | """ 9 | # Try importing it as if it was an installed library 10 | try: 11 | import PyCeres 12 | return 13 | except: 14 | pass 15 | 16 | # If it fails then try importing it from a given path 17 | if folder_path is not None: 18 | if not os.path.isdir(folder_path): 19 | raise ValueError("%s is not a folder, and therefore can't contain the PyCeres library.",folder_path) 20 | sys.path.insert(0, folder_path) 21 | try: 22 | import PyCeres 23 | return 24 | except: 25 | print("Will try to import from environment variable") 26 | # Try other importing options if the given path fails. 
27 | pyceres_location="" # Folder where the PyCeres lib is created 28 | # try to import from environment variable 29 | if os.getenv('PYCERES_LOCATION'): 30 | pyceres_location=os.getenv('PYCERES_LOCATION') 31 | else: 32 | pyceres_location="../../build/lib" # If the environment variable is not set 33 | # then it will assume this directory. Only will work if built with Ceres and 34 | # through the normal mkdir build, cd build, cmake .. procedure 35 | sys.path.insert(0, pyceres_location) 36 | import PyCeres 37 | return 38 | 39 | 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "wheel", 4 | "setuptools>=45", 5 | "setuptools_scm[toml]>=6.0", 6 | "cmake_build_extension", 7 | "numpy", 8 | "pybind11", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | 12 | #[build-system] 13 | #requires = ["setuptools", "wheel", "scikit-build", "cmake", "ninja"] 14 | #build-backend = "setuptools.build_meta" 15 | # 16 | #[tool.pytest.ini_options] 17 | #testpaths = ["tests"] 18 | -------------------------------------------------------------------------------- /python_bindings/ceres_examples_module.cpp: -------------------------------------------------------------------------------- 1 | // Ceres Solver - A fast non-linear least squares minimizer 2 | // Copyright 2015 Google Inc. All rights reserved. 3 | // http://ceres-solver.org/ 4 | // 5 | // Redistribution and use in source and binary forms, with or without 6 | // modification, are permitted provided that the following conditions are met: 7 | // 8 | // * Redistributions of source code must retain the above copyright notice, 9 | // this list of conditions and the following disclaimer. 
10 | // * Redistributions in binary form must reproduce the above copyright notice, 11 | // this list of conditions and the following disclaimer in the documentation 12 | // and/or other materials provided with the distribution. 13 | // * Neither the name of Google Inc. nor the names of its contributors may be 14 | // used to endorse or promote products derived from this software without 15 | // specific prior written permission. 16 | // 17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | // POSSIBILITY OF SUCH DAMAGE. 28 | // 29 | // Author: keir@google.com (Keir Mierle) 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #include "ceres/ceres.h" 36 | #include "ceres/rotation.h" 37 | 38 | namespace py = pybind11; 39 | 40 | struct HelloWorldCostFunctor { 41 | template 42 | bool operator()(const T* const x, T* residual) const { 43 | residual[0] = T(10.0) - x[0]; 44 | return true; 45 | } 46 | }; 47 | 48 | ceres::CostFunction* CreateHelloWorldCostFunction() { 49 | return new ceres::AutoDiffCostFunction( 50 | new HelloWorldCostFunctor); 51 | } 52 | 53 | // Read a Bundle Adjustment in the Large dataset. 
54 | class BALProblem { 55 | public: 56 | ~BALProblem() { 57 | delete[] point_index_; 58 | delete[] camera_index_; 59 | delete[] observations_; 60 | delete[] parameters_; 61 | } 62 | int num_observations() const { return num_observations_; } 63 | const double* observations() const { return observations_; } 64 | double* mutable_cameras() { return parameters_; } 65 | double* mutable_points() { return parameters_ + 9 * num_cameras_; } 66 | double* mutable_camera_for_observation(int i) { 67 | return mutable_cameras() + camera_index_[i] * 9; 68 | } 69 | double* mutable_point_for_observation(int i) { 70 | return mutable_points() + point_index_[i] * 3; 71 | } 72 | bool LoadFile(const char* filename) { 73 | FILE* fptr = fopen(filename, "r"); 74 | if (fptr == NULL) { 75 | return false; 76 | }; 77 | FscanfOrDie(fptr, "%d", &num_cameras_); 78 | FscanfOrDie(fptr, "%d", &num_points_); 79 | FscanfOrDie(fptr, "%d", &num_observations_); 80 | point_index_ = new int[num_observations_]; 81 | camera_index_ = new int[num_observations_]; 82 | observations_ = new double[2 * num_observations_]; 83 | num_parameters_ = 9 * num_cameras_ + 3 * num_points_; 84 | parameters_ = new double[num_parameters_]; 85 | for (int i = 0; i < num_observations_; ++i) { 86 | FscanfOrDie(fptr, "%d", camera_index_ + i); 87 | FscanfOrDie(fptr, "%d", point_index_ + i); 88 | for (int j = 0; j < 2; ++j) { 89 | FscanfOrDie(fptr, "%lf", observations_ + 2 * i + j); 90 | } 91 | } 92 | for (int i = 0; i < num_parameters_; ++i) { 93 | FscanfOrDie(fptr, "%lf", parameters_ + i); 94 | } 95 | return true; 96 | } 97 | 98 | template 99 | void FscanfOrDie(FILE* fptr, const char* format, T* value) { 100 | int num_scanned = fscanf(fptr, format, value); 101 | if (num_scanned != 1) { 102 | LOG(FATAL) << "Invalid UW data file."; 103 | } 104 | } 105 | int num_cameras_; 106 | int num_points_; 107 | int num_observations_; 108 | int num_parameters_; 109 | int* point_index_; 110 | int* camera_index_; 111 | double* observations_; 112 | 
double* parameters_; 113 | }; 114 | 115 | // Templated pinhole camera model for used with Ceres. The camera is 116 | // parameterized using 9 parameters: 3 for rotation, 3 for translation, 1 for 117 | // focal length and 2 for radial distortion. The principal point is not modeled 118 | // (i.e. it is assumed be located at the image center). 119 | struct SnavelyReprojectionError { 120 | SnavelyReprojectionError(double observed_x, double observed_y) 121 | : observed_x(observed_x), observed_y(observed_y) {} 122 | template 123 | bool operator()(const T* const camera, 124 | const T* const point, 125 | T* residuals) const { 126 | // camera[0,1,2] are the angle-axis rotation. 127 | T p[3]; 128 | ceres::AngleAxisRotatePoint(camera, point, p); 129 | // camera[3,4,5] are the translation. 130 | p[0] += camera[3]; 131 | p[1] += camera[4]; 132 | p[2] += camera[5]; 133 | // Compute the center of distortion. The sign change comes from 134 | // the camera model that Noah Snavely's Bundler assumes, whereby 135 | // the camera coordinate system has a negative z axis. 136 | T xp = -p[0] / p[2]; 137 | T yp = -p[1] / p[2]; 138 | // Apply second and fourth order radial distortion. 139 | const T& l1 = camera[7]; 140 | const T& l2 = camera[8]; 141 | T r2 = xp * xp + yp * yp; 142 | T distortion = 1.0 + r2 * (l1 + l2 * r2); 143 | // Compute final projected point position. 144 | const T& focal = camera[6]; 145 | T predicted_x = focal * distortion * xp; 146 | T predicted_y = focal * distortion * yp; 147 | // The error is the difference between the predicted and observed position. 148 | residuals[0] = predicted_x - observed_x; 149 | residuals[1] = predicted_y - observed_y; 150 | return true; 151 | } 152 | // Factory to hide the construction of the CostFunction object from 153 | // the client code. 
154 | static ceres::CostFunction* Create(const double observed_x, 155 | const double observed_y) { 156 | return (new ceres::AutoDiffCostFunction( 157 | new SnavelyReprojectionError(observed_x, observed_y))); 158 | } 159 | double observed_x; 160 | double observed_y; 161 | }; 162 | 163 | void SolveBALProblemWithCPP(BALProblem* bal_problem) { 164 | const double* observations = bal_problem->observations(); 165 | ceres::Problem problem; 166 | for (int i = 0; i < bal_problem->num_observations(); ++i) { 167 | ceres::CostFunction* cost_function = SnavelyReprojectionError::Create( 168 | observations[2 * i + 0], observations[2 * i + 1]); 169 | problem.AddResidualBlock(cost_function, 170 | NULL /* squared loss */, 171 | bal_problem->mutable_camera_for_observation(i), 172 | bal_problem->mutable_point_for_observation(i)); 173 | } 174 | ceres::Solver::Options options; 175 | options.linear_solver_type = ceres::DENSE_SCHUR; 176 | // options.minimizer_progress_to_stdout = true; 177 | ceres::Solver::Summary summary; 178 | ceres::Solve(options, &problem, &summary); 179 | } 180 | 181 | void add_pybinded_ceres_examples(py::module& m) { 182 | m.def("CreateHelloWorldCostFunction", &CreateHelloWorldCostFunction); 183 | 184 | py::class_ bal(m, "BALProblem"); 185 | bal.def(py::init<>()); 186 | bal.def("num_observations", &BALProblem::num_observations); 187 | bal.def("LoadFile", &BALProblem::LoadFile); 188 | 189 | bal.def("observations", [](BALProblem& myself) { 190 | std::vector double_data; 191 | for (int idx = 0; idx < myself.num_observations(); ++idx) { 192 | double_data.push_back(myself.observations()[2 * idx + 0]); 193 | double_data.push_back(myself.observations()[2 * idx + 1]); 194 | } 195 | return double_data; 196 | }); 197 | bal.def("cameras", [](BALProblem& myself) { 198 | std::vector double_data; 199 | for (int idx = 0; idx < myself.num_cameras_; ++idx) { 200 | double_data.push_back(myself.mutable_cameras()[9 * idx + 0]); 201 | double_data.push_back(myself.mutable_cameras()[9 * 
idx + 1]); 202 | double_data.push_back(myself.mutable_cameras()[9 * idx + 2]); 203 | double_data.push_back(myself.mutable_cameras()[9 * idx + 3]); 204 | double_data.push_back(myself.mutable_cameras()[9 * idx + 4]); 205 | double_data.push_back(myself.mutable_cameras()[9 * idx + 5]); 206 | double_data.push_back(myself.mutable_cameras()[9 * idx + 6]); 207 | double_data.push_back(myself.mutable_cameras()[9 * idx + 7]); 208 | double_data.push_back(myself.mutable_cameras()[9 * idx + 8]); 209 | } 210 | return double_data; 211 | }); 212 | bal.def("points", [](BALProblem& myself) { 213 | std::vector double_data; 214 | for (int idx = 0; idx < myself.num_points_; ++idx) { 215 | double_data.push_back(myself.mutable_points()[3 * idx + 0]); 216 | double_data.push_back(myself.mutable_points()[3 * idx + 1]); 217 | double_data.push_back(myself.mutable_points()[3 * idx + 2]); 218 | } 219 | return double_data; 220 | }); 221 | bal.def("mutable_cameras_for_observation", [](BALProblem& myself, int i) { 222 | std::vector double_data; 223 | double* ptr = myself.mutable_cameras() + myself.camera_index_[i] * 9; 224 | for (int idx = 0; idx < 9; ++idx) { 225 | double_data.push_back(ptr[idx]); 226 | } 227 | return double_data; 228 | }); 229 | bal.def("mutable_point_for_observation", [](BALProblem& myself, int i) { 230 | std::vector double_data; 231 | double* ptr = myself.mutable_points() + myself.point_index_[i] * 3; 232 | for (int idx = 0; idx < 3; ++idx) { 233 | double_data.push_back(ptr[idx]); 234 | } 235 | return double_data; 236 | }); 237 | 238 | bal.def("camera_index", 239 | [](BALProblem& myself, int i) { return myself.camera_index_[i]; }); 240 | 241 | bal.def("point_index", 242 | [](BALProblem& myself, int i) { return myself.point_index_[i]; }); 243 | 244 | m.def("CreateSnavelyCostFunction", &SnavelyReprojectionError::Create); 245 | m.def("SolveBALProblemWithCPP", &SolveBALProblemWithCPP); 246 | } 247 | -------------------------------------------------------------------------------- 
/python_bindings/custom_cpp_cost_functions.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace py = pybind11; 5 | 6 | struct ExampleFunctor { 7 | template 8 | bool operator()(const T* const x, T* residual) const { 9 | residual[0] = T(10.0) - x[0]; 10 | return true; 11 | } 12 | 13 | static ceres::CostFunction* Create() { 14 | return new ceres::AutoDiffCostFunction( 15 | new ExampleFunctor); 16 | } 17 | }; 18 | 19 | void add_custom_cost_functions(py::module& m) { 20 | // Use pybind11 code to wrap your own cost function which is defined in C++s 21 | 22 | // Here is an example 23 | m.def("CreateCustomExampleCostFunction", &ExampleFunctor::Create); 24 | } 25 | -------------------------------------------------------------------------------- /python_bindings/python_module.cpp: -------------------------------------------------------------------------------- 1 | // Ceres Solver Python Bindings 2 | // Copyright Nikolaus Mitchell. All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are met: 6 | // 7 | // * Redistributions of source code must retain the above copyright notice, 8 | // this list of conditions and the following disclaimer. 9 | // * Redistributions in binary form must reproduce the above copyright notice, 10 | // this list of conditions and the following disclaimer in the documentation 11 | // and/or other materials provided with the distribution. 12 | // * Neither the name of the copyright holder nor the names of its contributors 13 | // may be used to endorse or promote products derived from this software 14 | // without specific prior written permission. 
15 | // 16 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 20 | // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 | // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 | // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 | // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 | // POSSIBILITY OF SUCH DAMAGE. 27 | // 28 | // Author: nikolausmitchell@gmail.com (Nikolaus Mitchell) 29 | 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | #include 38 | #include 39 | 40 | // If the library is built by linking to ceres then we need to 41 | // access the defined compiler options(USE_SUITESPARSE,threading model 42 | // ...) 43 | #ifdef CERES_IS_LINKED 44 | #include 45 | #endif 46 | 47 | namespace py = pybind11; 48 | 49 | // Used for overloaded functions in C++11 50 | template 51 | using overload_cast_ = pybind11::detail::overload_cast_impl; 52 | 53 | // Forward decls for additionally modules 54 | void add_pybinded_ceres_examples(py::module& m); 55 | void add_custom_cost_functions(py::module& m); 56 | void add_torch_functionality(py::module& m); 57 | 58 | // Function to create a ceres Problem with the default options that Ceres does 59 | // NOT take ownership. Needed since Python expects to own the memory. 
60 | ceres::Problem CreatePythonProblem() { 61 | ceres::Problem::Options o; 62 | o.local_parameterization_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 63 | o.loss_function_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 64 | o.cost_function_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 65 | return ceres::Problem(o); 66 | } 67 | 68 | // Function to create Problem::Options with DO_NOT_TAKE_OWNERSHIP 69 | // This is because we want Python to manage our memory, not Ceres 70 | ceres::Problem::Options CreateNoOwnershipOption() { 71 | ceres::Problem::Options o; 72 | o.local_parameterization_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 73 | o.loss_function_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 74 | o.cost_function_ownership = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP; 75 | return o; 76 | } 77 | 78 | // Class which we can use to create a ceres::CostFunction in python. 79 | // This allows us to create python based cost functions. 80 | class PyCostFunction : public ceres::CostFunction { 81 | public: 82 | /* Inherit the constructors */ 83 | using ceres::CostFunction::CostFunction; 84 | using ceres::CostFunction::set_num_residuals; 85 | // Wraps the Ceres buffers in numpy arrays and forwards them to the Python "Evaluate" override; calls pybind11_fail when no override exists. 86 | bool Evaluate(double const* const* parameters, 87 | double* residuals, 88 | double** jacobians) const override { 89 | pybind11::gil_scoped_acquire gil; 90 | 91 | // Resize the vectors passed to python to the proper size.
And set the 92 | // pointer values 93 | if (!cached_flag) { 94 | parameters_vec.reserve(this->parameter_block_sizes().size()); 95 | jacobians_vec.reserve(this->parameter_block_sizes().size()); 96 | residuals_wrap = py::array_t(num_residuals(), residuals, dummy); 97 | for (size_t idx = 0; idx < parameter_block_sizes().size(); ++idx) { 98 | parameters_vec.emplace_back(py::array_t( 99 | this->parameter_block_sizes()[idx], parameters[idx], dummy)); 100 | jacobians_vec.emplace_back(py::array_t( 101 | this->parameter_block_sizes()[idx] * num_residuals(), 102 | jacobians ? jacobians[idx] : nullptr, /* jacobians is null when only residuals are requested */ 103 | dummy)); 104 | } 105 | cached_flag = true; 106 | } 107 | 108 | // Check if the pointers have changed and if they have then change them 109 | auto info = residuals_wrap.request(true); 110 | if (info.ptr != residuals) { 111 | residuals_wrap = py::array_t(num_residuals(), residuals, dummy); 112 | } 113 | // Each block is compared against its own pointer; the outer array address says nothing about the blocks. 114 | for (size_t idx = 0; idx < parameters_vec.size(); ++idx) { 115 | if (parameters_vec[idx].request(true).ptr != parameters[idx]) { 116 | parameters_vec[idx] = py::array_t( 117 | this->parameter_block_sizes()[idx], parameters[idx], dummy); 118 | } 119 | } 120 | if (jacobians) { 121 | // NOTE(review): Ceres may pass a null jacobians[idx]; confirm how the Python side should treat that block. 122 | for (size_t idx = 0; idx < jacobians_vec.size(); ++idx) { 123 | if (jacobians_vec[idx].request(true).ptr != jacobians[idx]) { 124 | jacobians_vec[idx] = py::array_t( 125 | this->parameter_block_sizes()[idx] * num_residuals(), 126 | jacobians[idx], 127 | dummy); 128 | } 129 | } 130 | } 131 | 132 | pybind11::function overload = pybind11::get_overload( 133 | static_cast(this), "Evaluate"); 134 | if (overload) { 135 | if (jacobians) { 136 | auto o = overload.operator()( 137 | parameters_vec, residuals_wrap, jacobians_vec); 138 | return pybind11::detail::cast_safe(std::move(o)); 139 | } else { 140 | auto o = overload.operator()( 141 | parameters_vec, residuals_wrap, nullptr); 142 | return pybind11::detail::cast_safe(std::move(o)); 143 | } 144 | } 145 | pybind11::pybind11_fail("Tried to call pure
virtual function \"" PYBIND11_STRINGIFY( 146 | Ceres::CostFunction) "::" "Evaluate \""); 147 | } 148 | 149 | private: 150 | // Vectors used to pass double pointers to python as pybind does not wrap 151 | // double pointers(**) like Ceres uses. 152 | // Mutable so they can be modified by the const function. 153 | mutable std::vector> parameters_vec; 154 | mutable std::vector> jacobians_vec; 155 | mutable bool cached_flag = false; // Flag used to determine if the vectors 156 | // need to be resized 157 | mutable py::array_t 158 | residuals_wrap; // Buffer to contain the residuals 159 | // pointer 160 | mutable py::str dummy; // Dummy variable for pybind11 so it doesn't make a 161 | // copy 162 | }; 163 | 164 | // Trampoline class so that we can create a LossFunction in Python. 165 | class PyLossFunction : public ceres::LossFunction { 166 | public: 167 | /* Inherit the constructors */ 168 | using ceres::LossFunction::LossFunction; 169 | // NOTE(review): the body is empty, so `out` is never written and nothing is forwarded to Python. 170 | void Evaluate(double sq_norm, double out[3]) const override {} 171 | }; 172 | 173 | // Trampoline class so that we can create a LocalParameterization in Python. 174 | class PyLocalParameterization : public ceres::LocalParameterization { 175 | public: /* Inherit the constructors */ 176 | using ceres::LocalParameterization::LocalParameterization; 177 | // NOTE(review): Plus/ComputeJacobian/MultiplyByJacobian are stubs that assert; they are not forwarded to Python. 178 | bool Plus(const double* x, const double* delta, double* x_plus_delta) const override { 179 | assert(false); 180 | return true; 181 | } 182 | bool ComputeJacobian(const double* x, double* jacobian) const override { 183 | assert(false); 184 | return true; 185 | } 186 | 187 | bool MultiplyByJacobian(const double* x, 188 | const int num_rows, 189 | const double* global_matrix, 190 | double* local_matrix) const override { 191 | assert(false); 192 | return true; 193 | } 194 | 195 | // Size of x. 196 | int GlobalSize() const override { 197 | PYBIND11_OVERLOAD_PURE( 198 | int, /* Return type */ 199 | ceres::LocalParameterization, /* Parent class */ 200 | GlobalSize, /* Name of function in C++ (must match Python name) */ 201 | ); 202 | } 203 | 204 | // Size of delta.
205 | int LocalSize() const override { 206 | PYBIND11_OVERLOAD_PURE( 207 | int, /* Return type */ 208 | ceres::LocalParameterization, /* Parent class */ 209 | LocalSize, /* Name of function in C++ (must match Python name) */ 210 | ); 211 | } 212 | }; 213 | 214 | // Trampoline class so we can create an EvaluationCallback in Python. 215 | class PyEvaluationCallBack : public ceres::EvaluationCallback { 216 | public: 217 | /* Inherit the constructors */ 218 | using ceres::EvaluationCallback::EvaluationCallback; 219 | 220 | void PrepareForEvaluation(bool evaluate_jacobians, 221 | bool new_evaluation_point) override { 222 | PYBIND11_OVERLOAD_PURE(void, /* Return type */ 223 | ceres::EvaluationCallback, /* Parent class */ 224 | PrepareForEvaluation, /* Name of function in C++ 225 | (must match Python name) */ 226 | evaluate_jacobians, 227 | new_evaluation_point /* Argument(s) */ 228 | ); 229 | } 230 | }; 231 | // Trampoline class so we can create a FirstOrderFunction in Python. 232 | class PyFirstOrderFunction : public ceres::FirstOrderFunction { 233 | public: 234 | /* Inherit the constructors */ 235 | using ceres::FirstOrderFunction::FirstOrderFunction; 236 | 237 | int NumParameters() const override { 238 | pybind11::gil_scoped_acquire gil; 239 | pybind11::function overload = pybind11::get_overload( 240 | static_cast(this), "NumParameters"); 241 | if (overload) { 242 | auto o = overload(); 243 | return pybind11::detail::cast_safe(std::move(o)); 244 | } 245 | 246 | pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY( 247 | ceres::FirstOrderFunction) "::" "NumParameters \""); 248 | } 249 | 250 | bool Evaluate(const double* const parameters, 251 | double* cost, 252 | double* gradient) const override { 253 | pybind11::gil_scoped_acquire gil; 254 | if (!cached_flag) { 255 | parameters_wrap = py::array_t(NumParameters(), parameters, dummy); 256 | gradient_wrap = py::array_t(NumParameters(), gradient, dummy); 257 | cost_wrap = py::array_t(1, cost, dummy); 258 | cached_flag = true; 259 | } 260 | 261 | // Check if
the pointers have changed and if they have then change them 262 | auto info = cost_wrap.request(true); 263 | if (info.ptr != cost) { 264 | cost_wrap = py::array_t(1, cost, dummy); 265 | } 266 | info = parameters_wrap.request(true); 267 | if (info.ptr != parameters) { 268 | parameters_wrap = py::array_t(NumParameters(), parameters, dummy); 269 | } 270 | if (gradient) { 271 | info = gradient_wrap.request(true); 272 | if (info.ptr != gradient) { 273 | gradient_wrap = py::array_t(NumParameters(), gradient, dummy); 274 | } 275 | } 276 | pybind11::function overload = pybind11::get_overload( 277 | static_cast(this), "Evaluate"); 278 | if (overload) { 279 | if (gradient) { 280 | auto o = overload.operator()( 281 | parameters_wrap, cost_wrap, gradient_wrap); 282 | return pybind11::detail::cast_safe(std::move(o)); 283 | } else { 284 | auto o = overload.operator()( 285 | parameters_wrap, cost_wrap, nullptr); 286 | return pybind11::detail::cast_safe(std::move(o)); 287 | } 288 | } 289 | pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY( 290 | ceres::FirstOrderFunction) "::" "Evaluate \""); 291 | } 292 | 293 | private: 294 | // Numpy arrays to pass to python that wrap the pointers 295 | // Mutable so they can be modified by the const function.
296 | mutable py::array_t parameters_wrap; 297 | mutable py::array_t gradient_wrap; 298 | mutable bool cached_flag = false; // Flag used to determine if the vectors 299 | // need to be resized 300 | mutable py::array_t cost_wrap; // Buffer to contain the cost ptr 301 | mutable py::str dummy; // Dummy variable for pybind11 so it doesn't make a 302 | // copy 303 | }; 304 | // Trampoline class so we can create an IterationCallback in Python. 305 | class PyIterationCallback : public ceres::IterationCallback { 306 | public: 307 | /* Inherit the constructors */ 308 | using ceres::IterationCallback::IterationCallback; 309 | // NOTE(review): the Python method looked up here is literally named "operator()"; confirm that is what Python subclasses are expected to define. 310 | ceres::CallbackReturnType operator()( 311 | const ceres::IterationSummary& summary) override { 312 | PYBIND11_OVERLOAD_PURE( 313 | ceres::CallbackReturnType, /* Return type */ 314 | ceres::IterationCallback, /* Parent class */ 315 | operator(), /* Name of function in C++ (must match Python name) */ 316 | summary /* Argument(s) */ 317 | ); 318 | } 319 | }; 320 | 321 | // Hacky Wrapper for ceres::FirstOrderFunction. 322 | // Essentially the problem is that GradientProblem takes ownership of the 323 | // passed in function. In order to stop a double delete from happening we 324 | // instead use this class. It wraps the ceres::FirstOrderFunction* pointer. 325 | // This function is then passed to GradientProblem. GradientProblem will then 326 | // delete this class instead of ceres::FirstOrderFunction. Python is free to 327 | // delete the FirstOrderFunction* without worrying about a double delete.
328 | class FirstOrderFunctionWrapper : public ceres::FirstOrderFunction { 329 | public: 330 | explicit FirstOrderFunctionWrapper(FirstOrderFunction* real_function) 331 | : function_(real_function) {} 332 | bool Evaluate(const double* const parameters, 333 | double* cost, 334 | double* gradient) const override { 335 | return function_->Evaluate(parameters, cost, gradient); 336 | } 337 | int NumParameters() const override { return function_->NumParameters(); } 338 | 339 | private: 340 | FirstOrderFunction* function_; 341 | }; 342 | 343 | // Same as FirstOrderFunctionWrapper 344 | class CostFunctionWrapper : public ceres::CostFunction { 345 | explicit CostFunctionWrapper(ceres::CostFunction* real_cost_function) 346 | : cost_function_(real_cost_function) { 347 | this->set_num_residuals(cost_function_->num_residuals()); 348 | *(this->mutable_parameter_block_sizes()) = 349 | cost_function_->parameter_block_sizes(); 350 | } 351 | 352 | bool Evaluate(double const* const* parameters, 353 | double* residuals, 354 | double** jacobians) const override { 355 | return cost_function_->Evaluate(parameters, residuals, jacobians); 356 | } 357 | 358 | private: 359 | CostFunction* cost_function_; 360 | }; 361 | 362 | // Parses a numpy array and extracts the pointer to the first element. 363 | // Requires that the numpy array be either an array or a row/column vector 364 | double* ParseNumpyData(py::array_t& np_buf) { 365 | py::buffer_info info = np_buf.request(); 366 | // This is essentially just all error checking. As it will always be the info 367 | // ptr 368 | if (info.ndim > 2) { 369 | std::string error_msg( 370 | "Number of dimensions must be <=2. This function" 371 | "only allows either an array or row/column vector(2D matrix) " + 372 | std::to_string(info.ndim)); 373 | throw std::runtime_error(error_msg); 374 | } 375 | if (info.ndim == 2) { 376 | // Row or Column Vector. 
Represents 1 parameter 377 | if (info.shape[0] == 1 || info.shape[1] == 1) { 378 | } else { 379 | std::string error_msg( 380 | "Matrix is not a row or column vector and instead has size " + 381 | std::to_string(info.shape[0]) + "x" + std::to_string(info.shape[1])); 382 | throw std::runtime_error(error_msg); 383 | } 384 | if (info.itemsize != 8) { 385 | std::string error_msg("Numpy vector must be of type double "); 386 | throw std::runtime_error(error_msg); 387 | } 388 | } 389 | return (double*)info.ptr; 390 | } 391 | 392 | PYBIND11_MODULE(PyCeres, m) { 393 | m.doc() = "Ceres wrappers"; // optional module docstring' 394 | 395 | py::enum_(m, "Ownership") 396 | .value("DO_NOT_TAKE_OWNERSHIP", ceres::Ownership::DO_NOT_TAKE_OWNERSHIP) 397 | .value("TAKE_OWNERSHIP", ceres::Ownership::TAKE_OWNERSHIP) 398 | .export_values(); 399 | 400 | py::enum_(m, "MinimizerType") 401 | .value("LINE_SEARCH", ceres::MinimizerType::LINE_SEARCH) 402 | .value("TRUST_REGION", ceres::MinimizerType::TRUST_REGION); 403 | 404 | py::enum_(m, "LineSearchType") 405 | .value("ARMIJO", ceres::LineSearchType::ARMIJO) 406 | .value("WOLFE", ceres::LineSearchType::WOLFE); 407 | 408 | py::enum_(m, "LineSearchDirectionType") 409 | .value("BFGS", ceres::LineSearchDirectionType::BFGS) 410 | .value("LBFGS", ceres::LineSearchDirectionType::LBFGS) 411 | .value("NONLINEAR_CONJUGATE_GRADIENT", 412 | ceres::LineSearchDirectionType::NONLINEAR_CONJUGATE_GRADIENT) 413 | .value("STEEPEST_DESCENT", 414 | ceres::LineSearchDirectionType::STEEPEST_DESCENT); 415 | 416 | py::enum_(m, 417 | "LineSearchInterpolationType") 418 | .value("BISECTION", ceres::LineSearchInterpolationType::BISECTION) 419 | .value("CUBIC", ceres::LineSearchInterpolationType::CUBIC) 420 | .value("QUADRATIC", ceres::LineSearchInterpolationType::QUADRATIC); 421 | 422 | py::enum_( 423 | m, "NonlinearConjugateGradientType") 424 | .value("FLETCHER_REEVES", 425 | ceres::NonlinearConjugateGradientType::FLETCHER_REEVES) 426 | .value("HESTENES_STIEFEL", 427 | 
ceres::NonlinearConjugateGradientType::HESTENES_STIEFEL) 428 | .value("POLAK_RIBIERE", 429 | ceres::NonlinearConjugateGradientType::POLAK_RIBIERE); 430 | 431 | py::enum_(m, "LinearSolverType") 432 | .value("DENSE_NORMAL_CHOLESKY", 433 | ceres::LinearSolverType::DENSE_NORMAL_CHOLESKY) 434 | .value("DENSE_QR", ceres::LinearSolverType::DENSE_QR) 435 | .value("SPARSE_NORMAL_CHOLESKY", 436 | ceres::LinearSolverType::SPARSE_NORMAL_CHOLESKY) 437 | .value("DENSE_SCHUR", ceres::LinearSolverType::DENSE_SCHUR) 438 | .value("SPARSE_SCHUR", ceres::LinearSolverType::SPARSE_SCHUR) 439 | .value("ITERATIVE_SCHUR", ceres::LinearSolverType::ITERATIVE_SCHUR) 440 | .value("CGNR", ceres::LinearSolverType::CGNR); 441 | 442 | py::enum_(m, "DoglegType") 443 | .value("TRADITIONAL_DOGLEG", ceres::DoglegType::TRADITIONAL_DOGLEG) 444 | .value("SUBSPACE_DOGLEG", ceres::DoglegType::SUBSPACE_DOGLEG); 445 | 446 | py::enum_(m, "TrustRegionStrategyType") 447 | .value("LEVENBERG_MARQUARDT", 448 | ceres::TrustRegionStrategyType::LEVENBERG_MARQUARDT) 449 | .value("DOGLEG", ceres::TrustRegionStrategyType::DOGLEG); 450 | 451 | py::enum_(m, "PreconditionerType") 452 | .value("IDENTITY", ceres::PreconditionerType::IDENTITY) 453 | .value("JACOBI", ceres::PreconditionerType::JACOBI) 454 | .value("SCHUR_JACOBI", ceres::PreconditionerType::SCHUR_JACOBI) 455 | .value("CLUSTER_JACOBI", ceres::PreconditionerType::CLUSTER_JACOBI) 456 | .value("CLUSTER_TRIDIAGONAL", 457 | ceres::PreconditionerType::CLUSTER_TRIDIAGONAL) 458 | .value("SUBSET", ceres::PreconditionerType::SUBSET); 459 | 460 | py::enum_(m, "VisibilityClusteringType") 461 | .value("CANONICAL_VIEWS", 462 | ceres::VisibilityClusteringType::CANONICAL_VIEWS) 463 | .value("SINGLE_LINKAGE", ceres::VisibilityClusteringType::SINGLE_LINKAGE); 464 | 465 | py::enum_( 466 | m, "DenseLinearAlgebraLibraryType") 467 | .value("EIGEN", ceres::DenseLinearAlgebraLibraryType::EIGEN) 468 | .value("LAPACK", ceres::DenseLinearAlgebraLibraryType::LAPACK); 469 | 470 | py::enum_( 
471 | m, "SparseLinearAlgebraLibraryType") 472 | .value("SUITE_SPARSE", 473 | ceres::SparseLinearAlgebraLibraryType::SUITE_SPARSE) 474 | .value("CX_SPARSE", ceres::SparseLinearAlgebraLibraryType::CX_SPARSE) 475 | .value("EIGEN_SPARSE", 476 | ceres::SparseLinearAlgebraLibraryType::EIGEN_SPARSE) 477 | .value("ACCELERATE_SPARSE", 478 | ceres::SparseLinearAlgebraLibraryType::ACCELERATE_SPARSE) 479 | .value("NO_SPARSE", ceres::SparseLinearAlgebraLibraryType::NO_SPARSE); 480 | 481 | py::enum_(m, "LoggingType") 482 | .value("SILENT", ceres::LoggingType::SILENT) 483 | .value("PER_MINIMIZER_ITERATION", 484 | ceres::LoggingType::PER_MINIMIZER_ITERATION); 485 | 486 | py::enum_(m, "CovarianceAlgorithmType") 487 | .value("DENSE_SVD", ceres::CovarianceAlgorithmType::DENSE_SVD) 488 | .value("SPARSE_QR", ceres::CovarianceAlgorithmType::SPARSE_QR); 489 | 490 | py::enum_(m, "CallbackReturnType") 491 | .value("SOLVER_CONTINUE", ceres::CallbackReturnType::SOLVER_CONTINUE) 492 | .value("SOLVER_ABORT", ceres::CallbackReturnType::SOLVER_ABORT) 493 | .value("SOLVER_TERMINATE_SUCCESSFULLY", 494 | ceres::CallbackReturnType::SOLVER_TERMINATE_SUCCESSFULLY); 495 | 496 | py::enum_(m, "DumpFormatType") 497 | .value("CONSOLE", ceres::DumpFormatType::CONSOLE) 498 | .value("TEXTFILE", ceres::DumpFormatType::TEXTFILE); 499 | 500 | using options = ceres::Problem::Options; 501 | py::class_ option(m, "ProblemOptions"); 502 | option.def(py::init(&CreateNoOwnershipOption)); // Ensures default is that 503 | // Python manages memory 504 | option.def_readwrite("cost_function_ownership", 505 | &options::cost_function_ownership); 506 | option.def_readwrite("loss_function_ownership", 507 | &options::loss_function_ownership); 508 | option.def_readwrite("local_parameterization_ownership", 509 | &options::local_parameterization_ownership); 510 | option.def_readwrite("enable_fast_removal", &options::enable_fast_removal); 511 | option.def_readwrite("disable_all_safety_checks", 512 | 
&options::disable_all_safety_checks); 513 | py::class_(m, "EvaluateOptions") 514 | .def(py::init<>()) 515 | // Doesn't make sense to wrap this as you can't see the pointers in python 516 | //.def_readwrite("parameter_blocks",&ceres::Problem::EvaluateOptions) 517 | .def_readwrite("apply_loss_function", 518 | &ceres::Problem::EvaluateOptions::apply_loss_function) 519 | .def_readwrite("num_threads", 520 | &ceres::Problem::EvaluateOptions::num_threads); 521 | 522 | // Wrapper around ceres ResidualBlockID. In Ceres a ResidualBlockId is 523 | // actually just a pointer to internal::ResidualBlock. However, since Ceres 524 | // uses a forward declaration we don't actually have the type definition. 525 | // (Ceres doesn't make it part of its public API). Since pybind11 needs a type 526 | // we use this class instead which simply holds the pointer. 527 | struct ResidualBlockIDWrapper { 528 | ResidualBlockIDWrapper(const ceres::ResidualBlockId& id) : id(id) {} 529 | const ceres::ResidualBlockId id; 530 | }; 531 | 532 | py::class_ residual_block_wrapper(m, "ResidualBlock"); 533 | 534 | py::class_ problem(m, "Problem"); 535 | problem.def(py::init(&CreatePythonProblem)); 536 | problem.def(py::init()); 537 | problem.def("NumParameterBlocks", &ceres::Problem::NumParameterBlocks); 538 | problem.def("NumParameters", &ceres::Problem::NumParameters); 539 | problem.def("NumResidualBlocks", &ceres::Problem::NumResidualBlocks); 540 | problem.def("NumResiduals", &ceres::Problem::NumResiduals); 541 | problem.def("ParameterBlockSize", &ceres::Problem::ParameterBlockSize); 542 | problem.def("SetParameterBlockConstant", 543 | [](ceres::Problem& myself, py::array_t& np_arr) { 544 | py::buffer_info info = np_arr.request(); 545 | myself.SetParameterBlockConstant((double*)info.ptr); 546 | }); 547 | problem.def("SetParameterBlockVariable", 548 | [](ceres::Problem& myself, py::array_t& np_arr) { 549 | py::buffer_info info = np_arr.request(); 550 | 
myself.SetParameterBlockVariable((double*)info.ptr); 551 | }); 552 | problem.def("IsParameterBlockConstant", 553 | [](ceres::Problem& myself, py::array_t& np_arr) { 554 | py::buffer_info info = np_arr.request(); 555 | return myself.IsParameterBlockConstant((double*)info.ptr); /* hand the flag back to Python instead of dropping it */ 556 | }); 557 | problem.def("SetParameterLowerBound", 558 | [](ceres::Problem& myself, 559 | py::array_t& np_arr, 560 | int index, 561 | double lower_bound) { 562 | py::buffer_info info = np_arr.request(); 563 | myself.SetParameterLowerBound( 564 | (double*)info.ptr, index, lower_bound); 565 | }); 566 | problem.def("SetParameterUpperBound", 567 | [](ceres::Problem& myself, 568 | py::array_t& np_arr, 569 | int index, 570 | double upper_bound) { 571 | py::buffer_info info = np_arr.request(); 572 | myself.SetParameterUpperBound( 573 | (double*)info.ptr, index, upper_bound); 574 | }); 575 | problem.def( 576 | "GetParameterLowerBound", 577 | [](ceres::Problem& myself, py::array_t& np_arr, int index) { 578 | py::buffer_info info = np_arr.request(); 579 | return myself.GetParameterLowerBound((double*)info.ptr, index); 580 | }); 581 | problem.def( 582 | "GetParameterUpperBound", 583 | [](ceres::Problem& myself, py::array_t& np_arr, int index) { 584 | py::buffer_info info = np_arr.request(); 585 | return myself.GetParameterUpperBound((double*)info.ptr, index); 586 | }); 587 | problem.def("GetParameterization", 588 | [](ceres::Problem& myself, py::array_t& np_arr) { 589 | py::buffer_info info = np_arr.request(); 590 | return myself.GetParameterization((double*)info.ptr); 591 | }, py::return_value_policy::reference); /* returned by reference so Python does not take ownership of the pointer */ 592 | problem.def("SetParameterization", 593 | [](ceres::Problem& myself, 594 | py::array_t& np_arr, 595 | ceres::LocalParameterization* local_parameterization) { 596 | py::buffer_info info = np_arr.request(); 597 | myself.SetParameterization((double*)info.ptr, 598 | local_parameterization); 599 | }, py::keep_alive<1, 3>()); /* LocalParameterization must outlive the Problem that points at it */ 600 | problem.def("ParameterBlockSize", 601 | [](ceres::Problem& myself, py::array_t& np_arr) { 602 | py::buffer_info info =
np_arr.request(); 603 | return myself.ParameterBlockSize((double*)info.ptr); 604 | }); 605 | problem.def("HasParameterBlock", 606 | [](ceres::Problem& myself, py::array_t& np_arr) { 607 | py::buffer_info info = np_arr.request(); 608 | return myself.HasParameterBlock((double*)info.ptr); 609 | }); 610 | problem.def( 611 | "AddResidualBlock", 612 | [](ceres::Problem& myself, 613 | ceres::CostFunction* cost, 614 | ceres::LossFunction* loss, 615 | py::array_t& values) { 616 | // Should we even do this error checking? 617 | double* pointer = ParseNumpyData(values); 618 | return ResidualBlockIDWrapper( 619 | myself.AddResidualBlock(cost, loss, pointer)); 620 | }, 621 | py::keep_alive<1, 2>(), // CostFunction 622 | py::keep_alive<1, 3>()); // LossFunction 623 | 624 | problem.def( 625 | "AddResidualBlock", 626 | [](ceres::Problem& myself, 627 | ceres::CostFunction* cost, 628 | ceres::LossFunction* loss, 629 | py::array_t& values1, 630 | py::array_t& values2) { 631 | double* pointer1 = ParseNumpyData(values1); 632 | double* pointer2 = ParseNumpyData(values2); 633 | return ResidualBlockIDWrapper( 634 | myself.AddResidualBlock(cost, loss, pointer1, pointer2)); 635 | }, 636 | py::keep_alive<1, 2>(), // Cost Function 637 | py::keep_alive<1, 3>()); // Loss Function 638 | problem.def( 639 | "AddResidualBlock", 640 | [](ceres::Problem& myself, 641 | ceres::CostFunction* cost, 642 | ceres::LossFunction* loss, 643 | py::array_t& values1, 644 | py::array_t& values2, 645 | py::array_t& values3) { 646 | double* pointer1 = ParseNumpyData(values1); 647 | double* pointer2 = ParseNumpyData(values2); 648 | double* pointer3 = ParseNumpyData(values3); 649 | return ResidualBlockIDWrapper( 650 | myself.AddResidualBlock(cost, loss, pointer1, pointer2, pointer3)); 651 | }, 652 | py::keep_alive<1, 2>(), // Cost Function 653 | py::keep_alive<1, 3>()); // Loss Function 654 | problem.def( 655 | "AddResidualBlock", 656 | [](ceres::Problem& myself, 657 | ceres::CostFunction* cost, 658 | 
ceres::LossFunction* loss, 659 | py::array_t& values1, 660 | py::array_t& values2, 661 | py::array_t& values3, 662 | py::array_t& values4) { 663 | double* pointer1 = ParseNumpyData(values1); 664 | double* pointer2 = ParseNumpyData(values2); 665 | double* pointer3 = ParseNumpyData(values3); 666 | double* pointer4 = ParseNumpyData(values4); 667 | return ResidualBlockIDWrapper(myself.AddResidualBlock( 668 | cost, loss, pointer1, pointer2, pointer3, pointer4)); 669 | }, 670 | py::keep_alive<1, 2>(), // Cost Function 671 | py::keep_alive<1, 3>()); // Loss Function 672 | 673 | problem.def( 674 | "AddResidualBlock", 675 | [](ceres::Problem& myself, 676 | ceres::CostFunction* cost, 677 | ceres::LossFunction* loss, 678 | py::array_t& values1, 679 | py::array_t& values2, 680 | py::array_t& values3, 681 | py::array_t& values4, 682 | py::array_t& values5) { 683 | double* pointer1 = ParseNumpyData(values1); 684 | double* pointer2 = ParseNumpyData(values2); 685 | double* pointer3 = ParseNumpyData(values3); 686 | double* pointer4 = ParseNumpyData(values4); 687 | double* pointer5 = ParseNumpyData(values5); 688 | return ResidualBlockIDWrapper(myself.AddResidualBlock( 689 | cost, loss, pointer1, pointer2, pointer3, pointer4, pointer5)); 690 | }, 691 | py::keep_alive<1, 2>(), // Cost Function 692 | py::keep_alive<1, 3>()); // Loss Function 693 | 694 | problem.def( 695 | "AddResidualBlock", 696 | [](ceres::Problem& myself, 697 | ceres::CostFunction* cost, 698 | ceres::LossFunction* loss, 699 | std::vector>& values) { 700 | std::vector pointer_values; 701 | for (int idx = 0; idx < values.size(); ++idx) { 702 | pointer_values.push_back(ParseNumpyData(values[idx])); 703 | } 704 | return ResidualBlockIDWrapper( 705 | myself.AddResidualBlock(cost, loss, pointer_values)); 706 | }, 707 | py::keep_alive<1, 2>(), // Cost Function 708 | py::keep_alive<1, 3>()); // Loss Function 709 | 710 | problem.def( 711 | "AddParameterBlock", 712 | [](ceres::Problem& myself, py::array_t& values, int size) { 
713 | double* pointer = ParseNumpyData(values); 714 | myself.AddParameterBlock(pointer, size); 715 | }); 716 | 717 | problem.def( 718 | "AddParameterBlock", 719 | [](ceres::Problem& myself, 720 | py::array_t& values, 721 | int size, 722 | ceres::LocalParameterization* local_parameterization) { 723 | double* pointer = ParseNumpyData(values); 724 | myself.AddParameterBlock(pointer, size, local_parameterization); 725 | }, 726 | py::keep_alive<1, 4>() // LocalParameterization 727 | ); 728 | 729 | problem.def("RemoveParameterBlock", 730 | [](ceres::Problem& myself, py::array_t& values) { 731 | double* pointer = ParseNumpyData(values); 732 | myself.RemoveParameterBlock(pointer); 733 | }); 734 | 735 | problem.def( 736 | "RemoveResidualBlock", 737 | [](ceres::Problem& myself, ResidualBlockIDWrapper& residual_block_id) { 738 | myself.RemoveResidualBlock(residual_block_id.id); 739 | }); 740 | 741 | py::class_ solver_options(m, "SolverOptions"); 742 | using s_options = ceres::Solver::Options; 743 | solver_options.def(py::init<>()); 744 | solver_options.def("IsValid", &s_options::IsValid); 745 | solver_options.def_readwrite("minimizer_type", &s_options::minimizer_type); 746 | solver_options.def_readwrite("line_search_direction_type", 747 | &s_options::line_search_direction_type); 748 | solver_options.def_readwrite("line_search_type", 749 | &s_options::line_search_type); 750 | solver_options.def_readwrite("nonlinear_conjugate_gradient_type", 751 | &s_options::nonlinear_conjugate_gradient_type); 752 | solver_options.def_readwrite("max_lbfgs_rank", &s_options::max_lbfgs_rank); 753 | solver_options.def_readwrite( 754 | "use_approximate_eigenvalue_bfgs_scaling", 755 | &s_options::use_approximate_eigenvalue_bfgs_scaling); 756 | solver_options.def_readwrite("line_search_interpolation_type", 757 | &s_options::line_search_interpolation_type); 758 | solver_options.def_readwrite("min_line_search_step_size", 759 | &s_options::min_line_search_step_size); 760 | solver_options.def_readwrite( 
761 | "line_search_sufficient_function_decrease", 762 | &s_options::line_search_sufficient_function_decrease); 763 | solver_options.def_readwrite("max_line_search_step_contraction", 764 | &s_options::max_line_search_step_contraction); 765 | solver_options.def_readwrite("min_line_search_step_contraction", 766 | &s_options::min_line_search_step_contraction); 767 | solver_options.def_readwrite( 768 | "max_num_line_search_step_size_iterations", 769 | &s_options::max_num_line_search_step_size_iterations); 770 | solver_options.def_readwrite( 771 | "max_num_line_search_direction_restarts", 772 | &s_options::max_num_line_search_direction_restarts); 773 | solver_options.def_readwrite( 774 | "line_search_sufficient_curvature_decrease", 775 | &s_options::line_search_sufficient_curvature_decrease); 776 | solver_options.def_readwrite("max_line_search_step_expansion", 777 | &s_options::max_line_search_step_expansion); 778 | solver_options.def_readwrite("trust_region_strategy_type", 779 | &s_options::trust_region_strategy_type); 780 | solver_options.def_readwrite("dogleg_type", &s_options::dogleg_type); 781 | solver_options.def_readwrite("use_nonmonotonic_steps", 782 | &s_options::use_nonmonotonic_steps); 783 | solver_options.def_readwrite("max_consecutive_nonmonotonic_steps", 784 | &s_options::max_consecutive_nonmonotonic_steps); 785 | solver_options.def_readwrite("max_num_iterations", 786 | &s_options::max_num_iterations); 787 | solver_options.def_readwrite("max_solver_time_in_seconds", 788 | &s_options::max_solver_time_in_seconds); 789 | solver_options.def_readwrite("num_threads", &s_options::num_threads); 790 | solver_options.def_readwrite("initial_trust_region_radius", 791 | &s_options::initial_trust_region_radius); 792 | solver_options.def_readwrite("max_trust_region_radius", 793 | &s_options::max_trust_region_radius); 794 | solver_options.def_readwrite("min_trust_region_radius", 795 | &s_options::min_trust_region_radius); 796 | 
solver_options.def_readwrite("min_relative_decrease", 797 | &s_options::min_relative_decrease); 798 | solver_options.def_readwrite("min_lm_diagonal", &s_options::min_lm_diagonal); 799 | solver_options.def_readwrite("max_lm_diagonal", &s_options::max_lm_diagonal); 800 | solver_options.def_readwrite("max_num_consecutive_invalid_steps", 801 | &s_options::max_num_consecutive_invalid_steps); 802 | solver_options.def_readwrite("function_tolerance", 803 | &s_options::function_tolerance); 804 | solver_options.def_readwrite("gradient_tolerance", 805 | &s_options::gradient_tolerance); 806 | solver_options.def_readwrite("parameter_tolerance", 807 | &s_options::parameter_tolerance); 808 | solver_options.def_readwrite("linear_solver_type", 809 | &s_options::linear_solver_type); 810 | solver_options.def_readwrite("preconditioner_type", 811 | &s_options::preconditioner_type); 812 | solver_options.def_readwrite("visibility_clustering_type", 813 | &s_options::visibility_clustering_type); 814 | solver_options.def_readwrite("dense_linear_algebra_library_type", 815 | &s_options::dense_linear_algebra_library_type); 816 | solver_options.def_readwrite("sparse_linear_algebra_library_type", 817 | &s_options::sparse_linear_algebra_library_type); 818 | solver_options.def_readwrite("use_explicit_schur_complement", 819 | &s_options::use_explicit_schur_complement); 820 | solver_options.def_readwrite("use_postordering", 821 | &s_options::use_postordering); 822 | solver_options.def_readwrite("dynamic_sparsity", 823 | &s_options::dynamic_sparsity); 824 | solver_options.def_readwrite("use_mixed_precision_solves", 825 | &s_options::use_mixed_precision_solves); 826 | solver_options.def_readwrite("max_num_refinement_iterations", 827 | &s_options::max_num_refinement_iterations); 828 | solver_options.def_readwrite("use_inner_iterations", 829 | &s_options::use_inner_iterations); 830 | solver_options.def_readwrite("minimizer_progress_to_stdout", 831 | &s_options::minimizer_progress_to_stdout); 832 | 833 | 
py::class_( 834 | m, "CostFunction") 835 | .def(py::init<>()) 836 | .def("num_residuals", &ceres::CostFunction::num_residuals) 837 | .def("num_parameter_blocks", 838 | [](ceres::CostFunction& myself) { 839 | return myself.parameter_block_sizes().size(); 840 | }) 841 | .def("parameter_block_sizes", 842 | &ceres::CostFunction::parameter_block_sizes, 843 | py::return_value_policy::reference) 844 | .def("set_num_residuals", &PyCostFunction::set_num_residuals) 845 | .def("set_parameter_block_sizes", 846 | [](ceres::CostFunction& myself, std::vector& sizes) { 847 | for (auto s : sizes) { 848 | const_cast&>(myself.parameter_block_sizes()) 849 | .push_back(s); 850 | } 851 | }); 852 | 853 | py::class_(m, "LossFunction") 854 | .def(py::init<>()); 855 | 856 | py::class_(m, "TrivialLoss") 857 | .def(py::init<>()); 858 | 859 | py::class_(m, "HuberLoss") 860 | .def(py::init()); 861 | 862 | py::class_(m, "SoftLOneLoss") 863 | .def(py::init()); 864 | 865 | py::class_(m, "CauchyLoss") 866 | .def(py::init()); 867 | 868 | py::class_ solver_summary(m, "Summary"); 869 | using s_summary = ceres::Solver::Summary; 870 | solver_summary.def(py::init<>()); 871 | solver_summary.def("BriefReport", &ceres::Solver::Summary::BriefReport); 872 | solver_summary.def("FullReport", &ceres::Solver::Summary::FullReport); 873 | solver_summary.def("IsSolutionUsable", 874 | &ceres::Solver::Summary::IsSolutionUsable); 875 | solver_summary.def_readwrite("initial_cost", 876 | &ceres::Solver::Summary::initial_cost); 877 | solver_summary.def_readwrite("final_cost", 878 | &ceres::Solver::Summary::final_cost); 879 | solver_summary.def_readwrite("minimizer_type", &s_summary::minimizer_type); 880 | solver_summary.def_readwrite("line_search_direction_type", 881 | &s_summary::line_search_direction_type); 882 | solver_summary.def_readwrite("line_search_type", 883 | &s_summary::line_search_type); 884 | solver_summary.def_readwrite("nonlinear_conjugate_gradient_type", 885 | 
&s_summary::nonlinear_conjugate_gradient_type); 886 | solver_summary.def_readwrite("max_lbfgs_rank", &s_summary::max_lbfgs_rank); 887 | solver_summary.def_readwrite("line_search_interpolation_type", 888 | &s_summary::line_search_interpolation_type); 889 | solver_summary.def_readwrite("trust_region_strategy_type", 890 | &s_summary::trust_region_strategy_type); 891 | solver_summary.def_readwrite("dogleg_type", &s_summary::dogleg_type); 892 | solver_summary.def_readwrite("visibility_clustering_type", 893 | &s_summary::visibility_clustering_type); 894 | solver_summary.def_readwrite("dense_linear_algebra_library_type", 895 | &s_summary::dense_linear_algebra_library_type); 896 | solver_summary.def_readwrite("sparse_linear_algebra_library_type", 897 | &s_summary::sparse_linear_algebra_library_type); 898 | solver_summary.def_readwrite("termination_type", 899 | &s_summary::termination_type); 900 | solver_summary.def_readwrite("message", &s_summary::message); 901 | solver_summary.def_readwrite("initial_cost", &s_summary::initial_cost); 902 | solver_summary.def_readwrite("final_cost", &s_summary::final_cost); 903 | solver_summary.def_readwrite("fixed_cost", &s_summary::fixed_cost); 904 | solver_summary.def_readwrite("iterations", &s_summary::iterations); 905 | solver_summary.def_readwrite("num_successful_steps", 906 | &s_summary::num_successful_steps); 907 | solver_summary.def_readwrite("num_unsuccessful_steps", 908 | &s_summary::num_unsuccessful_steps); 909 | solver_summary.def_readwrite("num_inner_iteration_steps", 910 | &s_summary::num_inner_iteration_steps); 911 | solver_summary.def_readwrite("num_line_search_steps", 912 | &s_summary::num_line_search_steps); 913 | solver_summary.def_readwrite("preprocessor_time_in_seconds", 914 | &s_summary::preprocessor_time_in_seconds); 915 | solver_summary.def_readwrite("minimizer_time_in_seconds", 916 | &s_summary::minimizer_time_in_seconds); 917 | solver_summary.def_readwrite("postprocessor_time_in_seconds", 918 | 
&s_summary::postprocessor_time_in_seconds); 919 | solver_summary.def_readwrite("total_time_in_seconds", 920 | &s_summary::total_time_in_seconds); 921 | solver_summary.def_readwrite("linear_solver_time_in_seconds", 922 | &s_summary::linear_solver_time_in_seconds); 923 | solver_summary.def_readwrite("num_linear_solves", 924 | &s_summary::num_linear_solves); 925 | solver_summary.def_readwrite("residual_evaluation_time_in_seconds", 926 | &s_summary::residual_evaluation_time_in_seconds); 927 | 928 | solver_summary.def_readwrite("num_residual_evaluations", 929 | &s_summary::num_residual_evaluations); 930 | 931 | solver_summary.def_readwrite("jacobian_evaluation_time_in_seconds", 932 | &s_summary::jacobian_evaluation_time_in_seconds); 933 | 934 | solver_summary.def_readwrite("num_jacobian_evaluations", 935 | &s_summary::num_jacobian_evaluations); 936 | solver_summary.def_readwrite("inner_iteration_time_in_seconds", 937 | &s_summary::inner_iteration_time_in_seconds); 938 | solver_summary.def_readwrite( 939 | "line_search_cost_evaluation_time_in_seconds", 940 | &s_summary::line_search_cost_evaluation_time_in_seconds); 941 | solver_summary.def_readwrite( 942 | "line_search_gradient_evaluation_time_in_seconds", 943 | &s_summary::line_search_gradient_evaluation_time_in_seconds); 944 | solver_summary.def_readwrite( 945 | "line_search_polynomial_minimization_time_in_seconds", 946 | &s_summary::line_search_polynomial_minimization_time_in_seconds); 947 | solver_summary.def_readwrite("line_search_total_time_in_seconds", 948 | &s_summary::line_search_total_time_in_seconds); 949 | solver_summary.def_readwrite("num_parameter_blocks", 950 | &s_summary::num_parameter_blocks); 951 | solver_summary.def_readwrite("num_parameters", &s_summary::num_parameters); 952 | solver_summary.def_readwrite("num_effective_parameters", 953 | &s_summary::num_effective_parameters); 954 | solver_summary.def_readwrite("num_residual_blocks", 955 | &s_summary::num_residual_blocks); 956 | 
solver_summary.def_readwrite("num_residuals", &s_summary::num_residuals); 957 | solver_summary.def_readwrite("num_parameter_blocks_reduced", 958 | &s_summary::num_parameter_blocks_reduced); 959 | solver_summary.def_readwrite("num_parameters_reduced", 960 | &s_summary::num_parameters_reduced); 961 | solver_summary.def_readwrite("num_effective_parameters_reduced", 962 | &s_summary::num_effective_parameters_reduced); 963 | solver_summary.def_readwrite("num_residual_blocks_reduced", 964 | &s_summary::num_residual_blocks_reduced); 965 | solver_summary.def_readwrite("num_residuals_reduced", 966 | &s_summary::num_residuals_reduced); 967 | solver_summary.def_readwrite("is_constrained", &s_summary::is_constrained); 968 | solver_summary.def_readwrite("num_threads_given", 969 | &s_summary::num_threads_given); 970 | solver_summary.def_readwrite("num_threads_used", 971 | &s_summary::num_threads_used); 972 | solver_summary.def_readwrite("linear_solver_ordering_given", 973 | &s_summary::linear_solver_ordering_given); 974 | solver_summary.def_readwrite("linear_solver_ordering_used", 975 | &s_summary::linear_solver_ordering_used); 976 | solver_summary.def_readwrite("schur_structure_given", 977 | &s_summary::schur_structure_given); 978 | solver_summary.def_readwrite("schur_structure_used", 979 | &s_summary::schur_structure_used); 980 | solver_summary.def_readwrite("inner_iterations_given", 981 | &s_summary::inner_iterations_given); 982 | solver_summary.def_readwrite("inner_iterations_used", 983 | &s_summary::inner_iterations_used); 984 | solver_summary.def_readwrite("inner_iteration_ordering_given", 985 | &s_summary::inner_iteration_ordering_given); 986 | solver_summary.def_readwrite("inner_iteration_ordering_used", 987 | &s_summary::inner_iteration_ordering_used); 988 | 989 | py::class_ iteration_summary(m, "IterationSummary"); 990 | using it_sum = ceres::IterationSummary; 991 | iteration_summary.def(py::init<>()); 992 | iteration_summary.def_readonly("iteration", 
&it_sum::iteration); 993 | iteration_summary.def_readonly("step_is_valid", &it_sum::step_is_valid); 994 | iteration_summary.def_readonly("step_is_nonmonotonic", 995 | &it_sum::step_is_nonmonotonic); 996 | iteration_summary.def_readonly("step_is_succesful", 997 | &it_sum::step_is_successful); 998 | iteration_summary.def_readonly("cost", &it_sum::cost); 999 | iteration_summary.def_readonly("cost_change", &it_sum::cost_change); 1000 | iteration_summary.def_readonly("gradient_max_norm", 1001 | &it_sum::gradient_max_norm); 1002 | iteration_summary.def_readonly("gradient_norm", &it_sum::gradient_norm); 1003 | iteration_summary.def_readonly("step_norm", &it_sum::step_norm); 1004 | iteration_summary.def_readonly("relative_decrease", 1005 | &it_sum::relative_decrease); 1006 | iteration_summary.def_readonly("trust_region_radius", 1007 | &it_sum::trust_region_radius); 1008 | iteration_summary.def_readonly("eta", &it_sum::eta); 1009 | iteration_summary.def_readonly("step_size", &it_sum::step_size); 1010 | iteration_summary.def_readonly("line_search_function_evaluations", 1011 | &it_sum::line_search_function_evaluations); 1012 | iteration_summary.def_readonly("line_search_gradient_evaluations", 1013 | &it_sum::line_search_gradient_evaluations); 1014 | iteration_summary.def_readonly("line_search_iterations", 1015 | &it_sum::line_search_iterations); 1016 | iteration_summary.def_readonly("linear_solver_iterations", 1017 | &it_sum::linear_solver_iterations); 1018 | iteration_summary.def_readonly("iteration_time_in_seconds", 1019 | &it_sum::iteration_time_in_seconds); 1020 | iteration_summary.def_readonly("step_solver_time_in_seconds", 1021 | &it_sum::step_solver_time_in_seconds); 1022 | iteration_summary.def_readonly("cumulative_time_in_seconds", 1023 | &it_sum::cumulative_time_in_seconds); 1024 | 1025 | py::class_(m, "IterationCallback") 1027 | .def(py::init<>()); 1028 | 1029 | py::class_(m, "EvaluationCallback") 1031 | .def(py::init<>()); 1032 | 1033 | py::class_(m, 
"FirstOrderFunction") 1035 | .def(py::init<>()); 1036 | 1037 | py::class_( 1039 | m, "LocalParameterization") 1040 | .def(py::init<>()) 1041 | .def("GlobalSize", &ceres::LocalParameterization::GlobalSize) 1042 | .def("LocalSize", &ceres::LocalParameterization::LocalSize); 1043 | 1044 | py::class_( 1045 | m, "IdentityParameterization") 1046 | .def(py::init()); 1047 | py::class_( 1048 | m, "QuaternionParameterization") 1049 | .def(py::init<>()); 1050 | py::class_(m, 1052 | "HomogeneousVectorParameterization") 1053 | .def(py::init()); 1054 | py::class_(m, "EigenQuaternionParameterization") 1056 | .def(py::init<>()); 1057 | py::class_( 1058 | m, "SubsetParameterization") 1059 | .def(py::init&>()); 1060 | 1061 | py::class_ grad_problem(m, "GradientProblem"); 1062 | grad_problem.def(py::init([](ceres::FirstOrderFunction* func) { 1063 | ceres::FirstOrderFunction* wrap = 1064 | new FirstOrderFunctionWrapper(func); 1065 | return ceres::GradientProblem(wrap); 1066 | }), 1067 | py::keep_alive<1, 2>() // FirstOrderFunction 1068 | ); 1069 | 1070 | grad_problem.def("NumParameters", &ceres::GradientProblem::NumParameters); 1071 | 1072 | py::class_ grad_options( 1073 | m, "GradientProblemOptions"); 1074 | using g_options = ceres::GradientProblemSolver::Options; 1075 | grad_options.def(py::init<>()); 1076 | grad_options.def("IsValid", &g_options::IsValid); 1077 | grad_options.def_readwrite("line_search_direction_type", 1078 | &g_options::line_search_direction_type); 1079 | grad_options.def_readwrite("line_search_type", &g_options::line_search_type); 1080 | grad_options.def_readwrite("nonlinear_conjugate_gradient_type", 1081 | &g_options::nonlinear_conjugate_gradient_type); 1082 | grad_options.def_readwrite("max_lbfgs_rank", &g_options::max_lbfgs_rank); 1083 | grad_options.def_readwrite( 1084 | "use_approximate_eigenvalue_bfgs_scaling", 1085 | &g_options::use_approximate_eigenvalue_bfgs_scaling); 1086 | grad_options.def_readwrite("line_search_interpolation_type", 1087 | 
&g_options::line_search_interpolation_type); 1088 | grad_options.def_readwrite("min_line_search_step_size", 1089 | &g_options::min_line_search_step_size); 1090 | grad_options.def_readwrite( 1091 | "line_search_sufficient_function_decrease", 1092 | &g_options::line_search_sufficient_function_decrease); 1093 | grad_options.def_readwrite("max_line_search_step_contraction", 1094 | &g_options::max_line_search_step_contraction); 1095 | grad_options.def_readwrite("min_line_search_step_contraction", 1096 | &g_options::min_line_search_step_contraction); 1097 | grad_options.def_readwrite( 1098 | "max_num_line_search_step_size_iterations", 1099 | &g_options::max_num_line_search_step_size_iterations); 1100 | grad_options.def_readwrite( 1101 | "max_num_line_search_direction_restarts", 1102 | &g_options::max_num_line_search_direction_restarts); 1103 | grad_options.def_readwrite( 1104 | "line_search_sufficient_curvature_decrease", 1105 | &g_options::line_search_sufficient_curvature_decrease); 1106 | grad_options.def_readwrite("max_line_search_step_expansion", 1107 | &g_options::max_line_search_step_expansion); 1108 | grad_options.def_readwrite("max_num_iterations", 1109 | &g_options::max_num_iterations); 1110 | grad_options.def_readwrite("max_solver_time_in_seconds", 1111 | &g_options::max_solver_time_in_seconds); 1112 | grad_options.def_readwrite("function_tolerance", 1113 | &g_options::function_tolerance); 1114 | grad_options.def_readwrite("gradient_tolerance", 1115 | &g_options::gradient_tolerance); 1116 | grad_options.def_readwrite("parameter_tolerance", 1117 | &g_options::parameter_tolerance); 1118 | grad_options.def_readwrite("minimizer_progress_to_stdout", 1119 | &g_options::minimizer_progress_to_stdout); 1120 | 1121 | py::class_ grad_summary( 1122 | m, "GradientProblemSummary"); 1123 | using g_sum = ceres::GradientProblemSolver::Summary; 1124 | grad_summary.def(py::init<>()); 1125 | grad_summary.def("BriefReport", 1126 | 
&ceres::GradientProblemSolver::Summary::BriefReport); 1127 | grad_summary.def("FullReport", 1128 | &ceres::GradientProblemSolver::Summary::FullReport); 1129 | grad_summary.def("IsSolutionUsable", 1130 | &ceres::GradientProblemSolver::Summary::IsSolutionUsable); 1131 | grad_summary.def_readwrite( 1132 | "initial_cost", &ceres::GradientProblemSolver::Summary::initial_cost); 1133 | grad_summary.def_readwrite( 1134 | "final_cost", &ceres::GradientProblemSolver::Summary::final_cost); 1135 | 1136 | // GradientProblem Solve 1137 | m.def("Solve", 1138 | [](const ceres::GradientProblemSolver::Options& options, 1139 | const ceres::GradientProblem& problem, 1140 | py::array_t& np_params, 1141 | ceres::GradientProblemSolver::Summary* summary) { 1142 | double* param_ptr = ParseNumpyData(np_params); 1143 | py::gil_scoped_release release; 1144 | ceres::Solve(options, problem, param_ptr, summary); 1145 | }); 1146 | 1147 | // The main Solve function 1148 | m.def("Solve", 1149 | overload_cast_()(&ceres::Solve), 1152 | py::call_guard()); 1153 | 1154 | py::class_ crs_mat(m, "CRSMatrix"); 1155 | crs_mat.def(py::init<>()); 1156 | crs_mat.def_readwrite("num_cols", &ceres::CRSMatrix::num_cols); 1157 | crs_mat.def_readwrite("num_rows", &ceres::CRSMatrix::num_rows); 1158 | crs_mat.def_readwrite("cols", &ceres::CRSMatrix::cols); 1159 | crs_mat.def_readwrite("rows", &ceres::CRSMatrix::rows); 1160 | crs_mat.def_readwrite("values", &ceres::CRSMatrix::values); 1161 | 1162 | py::class_ numdiff_options(m, 1163 | "NumericDiffOptions"); 1164 | numdiff_options.def(py::init<>()); 1165 | numdiff_options.def_readwrite("relative_step_size", 1166 | &ceres::NumericDiffOptions::relative_step_size); 1167 | numdiff_options.def_readwrite( 1168 | "ridders_relative_initial_step_size", 1169 | &ceres::NumericDiffOptions::ridders_relative_initial_step_size); 1170 | numdiff_options.def_readwrite( 1171 | "max_num_ridders_extrapolations", 1172 | &ceres::NumericDiffOptions::max_num_ridders_extrapolations); 1173 | 
numdiff_options.def_readwrite("ridders_epsilon", 1174 | &ceres::NumericDiffOptions::ridders_epsilon); 1175 | numdiff_options.def_readwrite( 1176 | "ridders_step_shrink_factor", 1177 | &ceres::NumericDiffOptions::ridders_step_shrink_factor); 1178 | 1179 | py::class_ probe_results( 1180 | m, "ProbeResults"); 1181 | probe_results.def(py::init<>()); 1182 | probe_results.def_readwrite( 1183 | "return_value", &ceres::GradientChecker::ProbeResults::return_value); 1184 | probe_results.def_readwrite("residuals", 1185 | &ceres::GradientChecker::ProbeResults::residuals); 1186 | probe_results.def_readwrite("jacobians", 1187 | &ceres::GradientChecker::ProbeResults::jacobians); 1188 | probe_results.def_readwrite( 1189 | "local_jacobians", 1190 | &ceres::GradientChecker::ProbeResults::local_jacobians); 1191 | probe_results.def_readwrite( 1192 | "numeric_jacobians", 1193 | &ceres::GradientChecker::ProbeResults::numeric_jacobians); 1194 | probe_results.def_readwrite( 1195 | "local_numeric_jacobians", 1196 | &ceres::GradientChecker::ProbeResults::local_numeric_jacobians); 1197 | probe_results.def_readwrite( 1198 | "maximum_relative_error", 1199 | &ceres::GradientChecker::ProbeResults::maximum_relative_error); 1200 | probe_results.def_readwrite("error_log", 1201 | &ceres::GradientChecker::ProbeResults::error_log); 1202 | 1203 | py::class_ gradient_checker(m, "GradientChecker"); 1204 | gradient_checker.def( 1205 | py::init*, 1207 | const ceres::NumericDiffOptions>()); 1208 | gradient_checker.def( 1209 | "Probe", 1210 | [](ceres::GradientChecker& myself, 1211 | std::vector>& parameters, 1212 | double relative_precision, 1213 | ceres::GradientChecker::ProbeResults* results) { 1214 | std::vector param_pointers; 1215 | for (auto& p : parameters) { 1216 | param_pointers.push_back(ParseNumpyData(p)); 1217 | } 1218 | return myself.Probe(param_pointers.data(), relative_precision, results); 1219 | }); 1220 | 1221 | py::class_ normal_prior(m, "NormalPrior"); 1222 | normal_prior.def(py::init()); 
1223 | 1224 | py::class_(m, "Context") 1225 | .def(py::init<>()) 1226 | .def("Create", &ceres::Context::Create); 1227 | 1228 | py::class_ cov_opt(m, "CovarianceOptions"); 1229 | using c_opt = ceres::Covariance::Options; 1230 | cov_opt.def_readwrite("sparse_linear_algebra_library_type", 1231 | &c_opt::sparse_linear_algebra_library_type); 1232 | cov_opt.def_readwrite("algorithm_type", &c_opt::algorithm_type); 1233 | cov_opt.def_readwrite("min_reciprocal_condition_number", 1234 | &c_opt::min_reciprocal_condition_number); 1235 | cov_opt.def_readwrite("null_space_rank", &c_opt::null_space_rank); 1236 | cov_opt.def_readwrite("num_threads", &c_opt::num_threads); 1237 | cov_opt.def_readwrite("apply_loss_function", &c_opt::apply_loss_function); 1238 | 1239 | py::class_ cov(m, "Covariance"); 1240 | cov.def(py::init()); 1241 | // cov.def("Compute",overload_cast_>&, 1243 | // ceres::Problem*>()(&ceres::Covariance::Compute)); 1244 | // cov.def("Compute",overload_cast_&, 1245 | // ceres::Problem*>()(&ceres::Covariance::Compute)); 1246 | 1247 | py::class_ cond_cost( 1248 | m, "ConditionedCostFunction"); 1249 | cond_cost.def( 1250 | py::init([](ceres::CostFunction* wrapped_cost_function, 1251 | const std::vector& conditioners) { 1252 | return new ceres::ConditionedCostFunction( 1253 | wrapped_cost_function, conditioners, ceres::DO_NOT_TAKE_OWNERSHIP); 1254 | })); 1255 | 1256 | add_pybinded_ceres_examples(m); 1257 | add_custom_cost_functions(m); 1258 | 1259 | #ifdef WITH_PYTORCH 1260 | add_torch_functionality(m); 1261 | #endif 1262 | 1263 | // Untested 1264 | 1265 | // Things below this line are wrapped ,but are rarely used even in C++ ceres. 1266 | // and thus are not tested. 
1267 | 1268 | // py::class_(m, "ScaledLoss") 1269 | // .def(py::init(), 1270 | // py::arg("ownership") = ceres::Ownership::DO_NOT_TAKE_OWNERSHIP); 1271 | } 1272 | -------------------------------------------------------------------------------- /python_bindings/pytorch_cost_function.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "pytorch_cost_function.h" 4 | 5 | PyTorchCostFunction::PyTorchCostFunction( 6 | int num_residuals, const std::vector& param_sizes) { 7 | set_num_residuals(num_residuals); 8 | *mutable_parameter_block_sizes() = param_sizes; 9 | inputs.resize(param_sizes.size()); 10 | } 11 | 12 | bool PyTorchCostFunction::Evaluate(double const* const* parameters, 13 | double* residuals, 14 | double** jacobians) const { 15 | for (size_t i = 0; i < inputs.size(); ++i) { 16 | torch::ArrayRef ref(parameters[i], parameter_block_sizes()[i]); 17 | // if jacobians exist then check if the specific jacobian is set. else false 18 | bool require_grad = jacobians ? 
jacobians[i] != nullptr : false; 19 | inputs[i] = torch::tensor(ref, torch::dtype(torch::kFloat64)) 20 | .set_requires_grad(require_grad); 21 | } 22 | 23 | torch::Tensor residual = module.forward(inputs).toTensor(); 24 | for (int i = 0; i < num_residuals(); i++) { 25 | residuals[i] = residual[i].item(); 26 | } 27 | 28 | if (jacobians) { 29 | residual.backward(); 30 | for (int i = 0; i < parameter_block_sizes().size(); ++i) { 31 | int param_size = parameter_block_sizes()[i]; 32 | auto& grad = inputs[i].toTensor().grad(); 33 | if (jacobians[i] && grad.defined()) { 34 | for (int j = 0; j < param_size; j++) { 35 | for (int k = 0; k < num_residuals(); k++) { 36 | jacobians[k][k * param_size + j] = 37 | grad[k * param_size + j].item(); 38 | } 39 | } 40 | } 41 | } 42 | } 43 | 44 | return true; 45 | } 46 | -------------------------------------------------------------------------------- /python_bindings/pytorch_cost_function.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef WITH_PYTORCH 4 | 5 | #include 6 | #include 7 | 8 | class PyTorchCostFunction : public ceres::CostFunction { 9 | public: 10 | PyTorchCostFunction(int num_residuals, 11 | const std::vector& param_sizes); 12 | 13 | virtual bool Evaluate(double const* const* parameters, 14 | double* residuals, 15 | double** jacobians) const override; 16 | 17 | public: 18 | mutable torch::jit::script::Module module; 19 | mutable std::vector inputs; 20 | }; 21 | 22 | #endif // WITH_PYTORCH 23 | -------------------------------------------------------------------------------- /python_bindings/pytorch_module.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // forward decl needed for residual block return 5 | #include 6 | 7 | #ifdef WITH_PYTORCH 8 | 9 | #include 10 | #include // One-stop header. 
11 | 12 | #include "pytorch_cost_function.h" 13 | 14 | namespace py = pybind11; 15 | 16 | void DoPtr(torch::jit::script::Module* m) { 17 | std::cout << "ptr works" << std::endl; 18 | } 19 | 20 | void DoCopy(torch::jit::script::Module m) { 21 | std::cout << "copy works " << std::endl; 22 | } 23 | 24 | void TestTensor(torch::Tensor a) { 25 | std::cout << "Tensor Success " << a.data_ptr() << std::endl; 26 | double* ptr = a.data_ptr(); 27 | std::cout << "Double ptr " << ptr << " after ptr " << std::endl; 28 | } 29 | 30 | void TestLiveRun(const std::string& code) { 31 | std::cout << "jit compiling code " << std::endl; 32 | 33 | auto m = torch::jit::compile(code); 34 | std::vector inputs; 35 | double param = 1; 36 | torch::ArrayRef ref(¶m, 1); 37 | inputs.push_back(torch::tensor(ref, torch::dtype(torch::kFloat64)) 38 | .set_requires_grad(true)); 39 | // inputs.push_back(torch::ones({1}).set_requires_grad(true)); 40 | 41 | // Execute the model and turn its output into a tensor. 42 | torch::jit::IValue out = m->run_method("forward", inputs[0]); 43 | 44 | auto output = out.toTensor(); 45 | 46 | std::cout << out << std::endl; 47 | 48 | output.backward(); 49 | 50 | auto& grad = inputs[0].toTensor().grad(); 51 | 52 | std::cout << grad << std::endl; 53 | } 54 | 55 | ceres::CostFunction* CreateTorchCostFunction( 56 | const std::string& filepath, 57 | int num_residuals, 58 | const std::vector& param_sizes) { 59 | PyTorchCostFunction* cost_func = 60 | new PyTorchCostFunction(num_residuals, param_sizes); 61 | try { 62 | cost_func->module = torch::jit::load(filepath); 63 | } catch (const c10::Error& e) { 64 | std::cerr << "error loading the model\n"; 65 | delete cost_func; 66 | throw std::runtime_error("Serialized model does not exist"); 67 | } 68 | 69 | return cost_func; 70 | } 71 | 72 | // Adds pytorch functionality to the module. 
73 | void add_torch_functionality(py::module& m) { 74 | m.def("DoPtr", &DoPtr); 75 | m.def("DoCopy", &DoCopy); 76 | m.def("TestTensor", &TestTensor); 77 | m.def("LiveRun", &TestLiveRun); 78 | 79 | py::class_ problem = 80 | (py::class_)m.attr("Problem"); 81 | 82 | problem.def("AddParameterBlockPythonFunc", 83 | [](ceres::Problem& myself, 84 | const std::string& func_src_code, 85 | const std::vector& sizes) { 86 | std::cout << func_src_code << std::endl; 87 | }); 88 | 89 | problem.def("AddParameterBlockTest", 90 | [](ceres::Problem& myself, torch::jit::script::Module* module) { 91 | std::cout << "Problem suceess" << std::endl; 92 | }); 93 | 94 | problem.def( 95 | "AddResidualBlock", 96 | [](ceres::Problem& myself, 97 | ceres::CostFunction* cost, 98 | ceres::LossFunction* loss, 99 | std::vector& tensors) { 100 | std::vector pointers; 101 | for (int i = 0; i < tensors.size(); ++i) { 102 | auto& t = tensors[i]; 103 | pointers.push_back(t.data_ptr()); 104 | } 105 | return myself.AddResidualBlock(cost, loss, pointers); 106 | }, 107 | py::keep_alive<1, 2>(), // Cost Function 108 | py::keep_alive<1, 3>(), // Loss Function 109 | py::return_value_policy::reference); 110 | 111 | m.def("CreateTorchCostFunction", &CreateTorchCostFunction); 112 | } 113 | 114 | #endif 115 | -------------------------------------------------------------------------------- /python_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Edwinem/ceres_python_bindings/d809e7651890c6e4b78c5c40c2f47d6c4aba8525/python_tests/__init__.py -------------------------------------------------------------------------------- /python_tests/debug_functions.py: -------------------------------------------------------------------------------- 1 | # Ceres Solver Python Bindings 2 | # Copyright Nikolaus Mitchell. All rights reserved. 
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright notice, 8 | # this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # * Neither the name of the copyright holder nor the names of its contributors may be 13 | # used to endorse or promote products derived from this software without 14 | # specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 20 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 | # POSSIBILITY OF SUCH DAMAGE. 
27 | # 28 | # Author: nikolausmitchell@gmail.com (Nikolaus Mitchell) 29 | 30 | 31 | ''' 32 | Functions used to help with debugging 33 | ''' 34 | 35 | 36 | def print_numpy_address(np_data): 37 | print(hex(np_data.__array_interface__['data'][0])) 38 | 39 | def GetPIDAndPause(): 40 | import os 41 | 42 | print(os.getpid()) 43 | 44 | input("Enter to continue ...") -------------------------------------------------------------------------------- /python_tests/loss_function_test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Edwinem/ceres_python_bindings/d809e7651890c6e4b78c5c40c2f47d6c4aba8525/python_tests/loss_function_test.py -------------------------------------------------------------------------------- /python_tests/test_python_defined_cost_function.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | pyceres_location="" # Folder where the PyCeres lib is created 4 | if os.getenv('PYCERES_LOCATION'): 5 | pyceres_location=os.getenv('PYCERES_LOCATION') 6 | else: 7 | pyceres_location="../../build/lib" # If the environment variable is not set 8 | # then it will assume this directory. Only will work if built with Ceres and 9 | # through the normal mkdir build, cd build, cmake .. 
procedure 10 | import sys 11 | sys.path.insert(0, pyceres_location) 12 | 13 | import PyCeres # Import the Python Bindings 14 | import numpy as np 15 | import pytest 16 | 17 | class PythonCostFunc(PyCeres.CostFunction): 18 | def __init__(self): 19 | super().__init__() 20 | self.set_num_residuals(2) 21 | self.set_parameter_block_sizes([3]) 22 | 23 | def Evaluate(self,parameters, residuals, jacobians): 24 | x=parameters[0][0] 25 | y=parameters[0][1] 26 | z=parameters[0][2] 27 | 28 | residuals[0]=x+2*y+4*z 29 | residuals[1]=y*z 30 | if jacobians!=None: 31 | jacobian=jacobians[0] 32 | jacobian[0 * 2 + 0] = 1 33 | jacobian[0 * 2 + 1] = 0 34 | 35 | jacobian[1 * 2 + 0] = 2 36 | jacobian[1 * 2 + 1] = z 37 | 38 | jacobian[2 * 2 + 0] = 4 39 | jacobian[2 * 2 + 1] = y 40 | return True 41 | 42 | 43 | def RunBasicProblem(): 44 | cost_function = PythonCostFunc() 45 | 46 | data = [0.76026643, -30.01799744, 0.55192142] 47 | np_data = np.array(data) 48 | 49 | print(np_data) 50 | 51 | problem = PyCeres.Problem() 52 | 53 | problem.AddResidualBlock(cost_function, None, np_data) 54 | options = PyCeres.SolverOptions() 55 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_QR 56 | options.minimizer_progress_to_stdout = True 57 | summary = PyCeres.Summary() 58 | PyCeres.Solve(options, problem, summary) 59 | return summary.final_cost 60 | 61 | def test_cost(): 62 | cost=RunBasicProblem() 63 | assert pytest.approx(0.0, 1e-10) == cost 64 | 65 | 66 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = PyCeres 3 | description = Python Bindings for Ceres library 4 | long_description = file: README.md; charset=UTF-8 5 | long_description_content_type = text/markdown 6 | author = Edwinem 7 | author_email = 735010+Edwinem@users.noreply.github.com 8 | license = BSD 9 | platforms = any 10 | url = https://github.com/Edwinem/ceres_python_bindings 11 
project_urls =
    Source = https://github.com/Edwinem/ceres_python_bindings
    Tracker = https://github.com/Edwinem/ceres_python_bindings/issues
keywords = nonlinear least squares optimization
# "Development Status" classifiers carry a numeric level; the bare
# "Production/Stable" form is not a registered trove classifier.
classifiers =
    Development Status :: 5 - Production/Stable
    Operating System :: OS Independent
    Operating System :: POSIX :: Linux
    Operating System :: MacOS
    Operating System :: Microsoft :: Windows
    Intended Audience :: Science/Research
    Intended Audience :: Developers
    Intended Audience :: Education
    Programming Language :: C++
    Programming Language :: Python :: 3 :: Only
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.6
    Programming Language :: Python :: 3.7
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9
    License :: OSI Approved :: BSD License

[options]
zip_safe = False
python_requires = >=3.6

[options.extras_require]
testing =
    pytest
    numpy
all =
    %(testing)s

# The Python tests live in python_tests/; there is no test/python directory
# in the repository tree.
[tool:pytest]
testpaths = python_tests
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import sys

from cmake_build_extension import BuildExtension, CMakeExtension
from setuptools import setup
from pathlib import Path

# Package metadata comes from setup.cfg; this script only registers the
# CMake-driven extension build.
setup(
    ext_modules=[
        CMakeExtension(
            name="CMakeProject",
            install_prefix="PyCeres",
            cmake_depends_on=["pybind11"],
            disable_editable=False,
            cmake_configure_options=[
                # This option points CMake to the right Python interpreter, and helps
                # the logic of FindPython3.cmake to find the active version
                f"-DPython3_ROOT_DIR={Path(sys.prefix)}"]
        )
    ],
    cmdclass=dict(build_ext=BuildExtension),
)
--------------------------------------------------------------------------------
/tests/pytorch_test.cpp:
-------------------------------------------------------------------------------- 1 | #include // One-stop header. 2 | 3 | #include 4 | #include 5 | 6 | 7 | int main(int argc, const char* argv[]) { 8 | 9 | 10 | 11 | torch::jit::script::Module module; 12 | try { 13 | // Deserialize the ScriptModule from a file using torch::jit::load(). 14 | module = torch::jit::load("../python_tests/example_torch_module.pt"); 15 | } 16 | catch (const c10::Error& e) { 17 | std::cerr << "error loading the model\n"; 18 | return -1; 19 | } 20 | 21 | std::cout << "ok\n"; 22 | 23 | std::vector inputs; 24 | double param=1; 25 | torch::ArrayRef ref(¶m, 1); 26 | inputs.push_back(torch::tensor(ref, torch::dtype(torch::kFloat64)).set_requires_grad(true)); 27 | //inputs.push_back(torch::ones({1}).set_requires_grad(true)); 28 | 29 | 30 | 31 | // Execute the model and turn its output into a tensor. 32 | at::Tensor output = module.forward(inputs).toTensor(); 33 | 34 | 35 | std::cout << output << std::endl; 36 | 37 | output.backward(); 38 | 39 | auto &grad = inputs[0].toTensor().grad(); 40 | 41 | std::cout << grad << std::endl; 42 | 43 | } --------------------------------------------------------------------------------