├── .clang-format ├── .github └── workflows │ └── cmake-multi-platform.yml ├── .gitignore ├── CMakeLists.txt ├── CONTRIBUTING.md ├── Doxyfile ├── LICENSE ├── README-zh.md ├── README.md ├── examples ├── CMakeLists.txt ├── diff_machine │ ├── CMakeLists.txt │ └── main.cpp ├── set_initial_values │ ├── CMakeLists.txt │ └── main.cpp ├── solve │ ├── CMakeLists.txt │ └── main.cpp └── solve2 │ ├── CMakeLists.txt │ └── main.cpp ├── scripts ├── all_tests_preheader.cpp ├── combine_src_to_header_only.py ├── combine_tests_to_single.py ├── pre-commit ├── run_me.bat └── util.py ├── single ├── include │ └── tomsolver │ │ └── tomsolver.h └── test │ ├── CMakeLists.txt │ └── all_tests.cpp ├── src └── tomsolver │ ├── config.cpp │ ├── config.h │ ├── diff.cpp │ ├── diff.h │ ├── error_type.cpp │ ├── error_type.h │ ├── functions.h │ ├── linear.cpp │ ├── linear.h │ ├── mat.cpp │ ├── mat.h │ ├── math_operator.cpp │ ├── math_operator.h │ ├── node.cpp │ ├── node.h │ ├── nonlinear.cpp │ ├── nonlinear.h │ ├── parse.cpp │ ├── parse.h │ ├── simplify.cpp │ ├── simplify.h │ ├── subs.cpp │ ├── subs.h │ ├── symmat.cpp │ ├── symmat.h │ ├── tomsolver.h │ ├── vars_table.cpp │ └── vars_table.h └── tests ├── CMakeLists.txt ├── diff_test.cpp ├── functions_test.cpp ├── helper.cpp ├── helper.h ├── linear_test.cpp ├── mat_test.cpp ├── memory_leak_detection.h ├── memory_leak_detection_win.h ├── node_test.cpp ├── parse_test.cpp ├── power_test.cpp ├── random_test.cpp ├── simplify_test.cpp ├── solve_base_test.cpp ├── solve_test.cpp ├── subs_test.cpp ├── symmat_test.cpp └── to_string_test.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: LLVM 3 | 4 | IndentWidth: 4 5 | 6 | # 访问说明符(public、private等)的偏移 7 | AccessModifierOffset: -4 8 | 9 | # 允许短的块放在同一行(Always 总是将短块合并成一行,Empty 只合并空块) 10 | AllowShortBlocksOnASingleLine: Empty 11 | 12 | # 允许短的函数放在同一行: None, InlineOnly(定义在类中), Empty(空函数), Inline(定义在类中,空函数), All 13 | 
AllowShortFunctionsOnASingleLine: Empty 14 | 15 | AllowShortLambdasOnASingleLine: Empty 16 | 17 | # 总是在template声明后换行 18 | AlwaysBreakTemplateDeclarations: true 19 | 20 | # 启用include排序 21 | SortIncludes: true 22 | --- 23 | Language: Cpp 24 | # Force pointers to the type for C++. 25 | DerivePointerAlignment: false 26 | PointerAlignment: Right 27 | ColumnLimit: 120 28 | --- 29 | Language: Proto 30 | # Don't format .proto files. 31 | DisableFormat: true -------------------------------------------------------------------------------- /.github/workflows/cmake-multi-platform.yml: -------------------------------------------------------------------------------- 1 | # This starter workflow is for a CMake project running on multiple platforms. There is a different starter workflow if you just want a single platform. 2 | # See: https://github.com/actions/starter-workflows/blob/main/ci/cmake-single-platform.yml 3 | name: CMake on multiple platforms 4 | 5 | on: 6 | push: 7 | branches: [ "master", "dev", "action" ] 8 | pull_request: 9 | branches: [ "master", "dev", "action" ] 10 | 11 | jobs: 12 | build: 13 | runs-on: ${{ matrix.os }} 14 | 15 | strategy: 16 | # Set fail-fast to false to ensure that feedback is delivered for all matrix combinations. Consider changing this to true when your workflow is stable. 17 | fail-fast: false 18 | 19 | # Set up a matrix that runs every combination of the following values: 20 | # 1. os: ubuntu-latest, windows-latest, macos-latest 21 | # 2. build_type: Debug, Release 22 | # 3. cpp_compiler: g++, clang++, cl (combinations listed under `exclude` are skipped) 23 | # 24 | # To add more build types (Release, Debug, RelWithDebInfo, etc.) customize the build_type list. 25 | matrix: 26 | os: [ubuntu-latest, windows-latest, macos-latest] 27 | build_type: [Debug, Release] 28 | cpp_compiler: [g++, clang++, cl] 29 | exclude: 30 | - os: ubuntu-latest 31 | cpp_compiler: cl 32 | - os: macos-latest 33 | cpp_compiler: cl 34 | 35 | steps: 36 | - uses: actions/checkout@v3 37 | 38 | - name: Set reusable strings 39 | # Turn repeated input strings (such as the build output directory) into step outputs.
These step outputs can be used throughout the workflow file. 40 | id: strings 41 | shell: bash 42 | run: | 43 | echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" 44 | 45 | - name: Configure CMake 46 | # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. 47 | # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type 48 | run: > 49 | cmake -B ${{ steps.strings.outputs.build-output-dir }} 50 | -DCMAKE_CXX_COMPILER=${{ matrix.cpp_compiler }} 51 | -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} 52 | -S ${{ github.workspace }} 53 | 54 | - name: Build 55 | # Build your program with the given configuration. Note that --config is needed because the default Windows generator is a multi-config generator (Visual Studio generator). 56 | run: cmake --build ${{ steps.strings.outputs.build-output-dir }} --config ${{ matrix.build_type }} 57 | 58 | - name: Test 59 | working-directory: ${{ steps.strings.outputs.build-output-dir }} 60 | # Execute tests defined by the CMake configuration. Note that --build-config is needed because the default Windows generator is a multi-config generator (Visual Studio generator). 
61 | # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail 62 | run: ctest --build-config ${{ matrix.build_type }} 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | output/ 3 | build/ 4 | .vs/ 5 | .vscode/ 6 | __pycache__/ 7 | .cache -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 3.11 is the first version with the FetchContent module (FetchContent_MakeAvailable needs 3.14, see the fallback below) 2 | cmake_minimum_required(VERSION 3.11 FATAL_ERROR) 3 | 4 | project(tomsolver) 5 | 6 | # C++14 7 | set(CMAKE_CXX_STANDARD 14) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 10 | 11 | if(MSVC) 12 | add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>") 13 | else() 14 | add_compile_options(-Wall -Wextra -pedantic -Werror) 15 | endif() 16 | 17 | # ===================================== 18 | option(BUILD_TESTS "Build tests" ON) 19 | message(STATUS "Option: BUILD_TESTS=${BUILD_TESTS}") 20 | 21 | option(BUILD_EXAMPLES "Build examples" ON) 22 | message(STATUS "Option: BUILD_EXAMPLES=${BUILD_EXAMPLES}") 23 | 24 | # ===================================== 25 | if(BUILD_TESTS) 26 | # Option: fetch googletest from a mirror site instead of GitHub 27 | option(USE_MIRROR_GTEST_REPO "Use mirror google test repository at gitcode.net" OFF) 28 | 29 | if(USE_MIRROR_GTEST_REPO) 30 | set(GTEST_URL "https://gitcode.net/mirrors/google/googletest.git") 31 | else() 32 | set(GTEST_URL "https://github.com/google/googletest.git") 33 | endif() 34 | 35 | message(STATUS "Fetch googletest from ${GTEST_URL}") 36 | 37 | include(FetchContent) 38 | FetchContent_Declare( 39 | googletest 40 | GIT_REPOSITORY ${GTEST_URL} 41 | GIT_TAG release-1.12.1 42 | ) 43 | 44 | # For Windows: Prevent overriding the parent project's compiler/linker settings 45 | set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) 46 | # FetchContent_MakeAvailable only exists since CMake 3.14; on 3.11-3.13 do the equivalent by hand 47 | if(COMMAND FetchContent_MakeAvailable) 48 | FetchContent_MakeAvailable(googletest) 49 | else() 50 | FetchContent_GetProperties(googletest) 51 | if(NOT googletest_POPULATED) 52 | FetchContent_Populate(googletest) 53 | add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) 54 | endif() 55 | endif()
56 | 57 | enable_testing() 58 | 59 | add_subdirectory(tests) 60 | add_subdirectory(single/test) 61 | endif() 62 | 63 | # ===================================== 64 | file(GLOB_RECURSE SOURCE_CODE 65 | src/*.cpp 66 | src/*.h 67 | ) 68 | 69 | add_library(tomsolver STATIC ${SOURCE_CODE}) 70 | 71 | target_include_directories(tomsolver PUBLIC 72 | $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src> 73 | $<INSTALL_INTERFACE:include> 74 | ) 75 | 76 | if(MSVC) 77 | target_compile_options(tomsolver PRIVATE /W4 /WX) 78 | endif() 79 | 80 | # ===================================== 81 | if(BUILD_EXAMPLES) 82 | add_subdirectory(examples) 83 | endif() 84 | 85 | # ===================================== 86 | # Install the library and its headers (GNUInstallDirs defines CMAKE_INSTALL_LIBDIR / CMAKE_INSTALL_BINDIR used below) 87 | include(GNUInstallDirs) 88 | install(TARGETS tomsolver 89 | EXPORT tomsolver_targets 90 | ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" 91 | LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" 92 | RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") 93 | 94 | install(DIRECTORY src/ 95 | DESTINATION include/ 96 | FILES_MATCHING PATTERN "*.h" 97 | ) 98 | 99 | # Generate and install *-targets.cmake 100 | install(EXPORT tomsolver_targets 101 | FILE tomsolver-targets.cmake 102 | NAMESPACE tomsolver:: 103 | DESTINATION share/tomsolver) 104 | 105 | # Generate the config file in the current binary dir (this ensures it's not placed directly in source) 106 | file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/tomsolver-config.cmake" 107 | "include(CMakeFindDependencyMacro)\n" 108 | "include(\"\${CMAKE_CURRENT_LIST_DIR}/tomsolver-targets.cmake\")\n" 109 | ) 110 | 111 | # Install the generated config file 112 | install(FILES "${CMAKE_CURRENT_BINARY_DIR}/tomsolver-config.cmake" 113 | DESTINATION share/tomsolver) 114 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Environment 2 | 3 | * All C++ code must be formatted with clang-format, so install a clang-format executable and configure your editor to use it, e.g.
4 | * In Visual Studio, see: 5 | https://learn.microsoft.com/en-us/visualstudio/ide/reference/options-text-editor-c-cpp-formatting?view=vs-2022 6 | * In Visual Studio Code, install the extension named "clang-format". 7 | * Make sure Python (version 3.6+) is installed, and then 8 | copy ${repo}/scripts/pre-commit to ${repo}/.git/hooks (on Linux/macOS also make it executable: `chmod +x .git/hooks/pre-commit`) 9 | 10 | Now you can modify the code freely; the hook regenerates the files under `single/` on every commit. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tom Willow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | # tomsolver 2 | 3 | ![workflow](https://github.com/tomwillow/tomsolver/actions/workflows/cmake-multi-platform.yml/badge.svg) 4 | 5 | [中文](https://github.com/tomwillow/tomsolver/blob/master/README-zh.md) [English](https://github.com/tomwillow/tomsolver) 6 | 7 | C++14 极简非线性方程组求解器 8 | 9 | > 让 C++求解非线性方程组像 Matlab fsolve 一样简单 10 | 11 | 地址: https://github.com/tomwillow/tomsolver 12 | 13 | **Contributors:** 14 | 15 | - Tom Willow (https://github.com/tomwillow) 16 | - lizho (https://github.com/lizho) 17 | 18 | # 特点 19 | 20 | - 简单!简单!会用 fsolve 就会用这个! 21 | - 单头文件,直接 include 完事儿! 22 | 23 | # 功能 24 | 25 | - 非线性方程组求解(牛顿-拉夫森法、LM 方法) 26 | - 线性方程组求解(高斯-列主元迭代法、逆矩阵) 27 | - 矩阵、向量运算(矩阵求逆、向量叉乘等) 28 | - “伪”符号运算(对表达式求导、对符号矩阵求雅可比矩阵) 29 | 30 | # 跨平台支持 31 | 32 | 测试环境: 33 | 34 | - Linux: ubuntu 22.04 x86_64 gcc 11.3.0 35 | - Windows: windows10 x64 Visual Studio 2019 36 | 37 | Github Actions 自动测试: 38 | 39 | - Linux-latest gcc Debug&Release 40 | - Linux-latest clang Debug&Release 41 | - Windows-latest msvc Debug&Release 42 | 43 | # 例子 44 | 45 | ```C++ 46 | #include 47 | 48 | using namespace tomsolver; 49 | 50 | int main() { 51 | /* 52 | Translate from Matlab code: 53 | 54 | root2d.m: 55 | function F = root2d(x) 56 | F(1) = exp(-exp(-(x(1)+x(2)))) - x(2)*(1+x(1)^2); 57 | F(2) = x(1)*cos(x(2)) + x(2)*sin(x(1)) - 0.5; 58 | end 59 | 60 | root2d_solve.m: 61 | format long 62 | fun = @root2d; 63 | x0 = [0,0]; 64 | x = fsolve(fun,x0) 65 | 66 | result: 67 | x = 68 | 69 | 0.353246561920553 0.606082026502285 70 | */ 71 | try { 72 | // 创建方程组 73 | SymVec f = { 74 | "exp(-exp(-(x1 + x2))) - x2 * (1 + x1 ^ 2)"_f, 75 | "x1 * cos(x2) + x2 * sin(x1) - 0.5"_f, 76 | }; 77 | 78 | // 求解! 
79 | VarsTable ans = Solve(f); 80 | 81 | // 打印结果 82 | std::cout << ans << std::endl; 83 | 84 | // 取得结果 85 | std::cout << "x1 = " << ans["x1"] << std::endl; 86 | std::cout << "x2 = " << ans["x2"] << std::endl; 87 | } catch (const std::runtime_error &err) { 88 | // 如果出错,捕获异常,并打印 89 | std::cerr << err.what() << std::endl; 90 | return -1; 91 | } 92 | 93 | return 0; 94 | } 95 | ``` 96 | 97 | # 用法 98 | 99 | ## 1. header-only用法 100 | 101 | 仅需要包含一个单头文件即可: 102 | `single/include/tomsolver/tomsolver.h` 103 | 104 | ## 2. 使用VCPKG 105 | 106 | 你可以使用[vcpkg](https://learn.microsoft.com/zh-cn/vcpkg/get_started/overview)进行统一的包管理。 107 | 108 | * 清单模式(推荐) 109 | 新建 `vcpkg.json`文件: 110 | 111 | ``` 112 | { 113 | "dependencies": [ 114 | "tomsolver" 115 | ] 116 | } 117 | ``` 118 | 119 | 然后在当前目录执行: 120 | `$ vcpkg install` 121 | 详细教程:[https://learn.microsoft.com/zh-cn/vcpkg/get_started/get-started?pivots=shell-powershell](https://learn.microsoft.com/zh-cn/vcpkg/get_started/get-started?pivots=shell-powershell) 122 | * 经典模式 123 | 直接执行: 124 | `$ vcpkg install tomsolver` 125 | 126 | ## 3. 二进制库+头文件用法 127 | 128 | ```bash 129 | $ git clone https://github.com/tomwillow/tomsolver 130 | $ mkdir build 131 | $ cd build 132 | $ cmake ../tomsolver 133 | $ cmake --build .
--target INSTALL 134 | ``` 135 | 136 | 然后添加 include 目录,并链接到库文件。 137 | 138 | # 目录结构 139 | 140 | - **src**: 源文件 141 | - **tests**: 单元测试 142 | - **single/include**: header-only 的 tomsolver.h 所在的文件夹 143 | - **single/test**: 所有单元测试整合为一个.cpp 文件,用于测试 tomsolver.h 是否正确 144 | - **scripts**: 用于生成 single 下面的单文件头文件和单文件测试 145 | 146 | ### 例子 147 | 148 | - **examples/solve**: 解非线性方程的例子,演示基本用法和怎么设置统一的初值 149 | - **examples/set_initial_values**: 解非线性方程的例子,演示怎么设置每个变量的初值 150 | - **examples/solve2**: 解非线性方程的例子,演示怎么切换解法和怎么替换方程中的已知量 151 | - **examples/diff_machine**: 求导器,输入一行表达式,输出这个表达式的求导结果 152 | 153 | # 开发计划 154 | 155 | - 增加 doxygen 注释+教程文档(CN+EN) 156 | - 增加 benchmark 测速 157 | - 增加使用 Eigen 库作为内置矩阵库的可选项 158 | - 对标 Matlab fsolve,增加更多非线性方程组解法 159 | - 在 Solve 函数中增加可选的 Config 参数,可以使用非全局的 Config 进行求解 160 | (类似于 Matlab fsolve 的 options) 161 | - 增加对二元/多元函数的支持,例如 pow(x, y) 162 | - 现在的 Simplify 函数还很朴素,把 Simplify 修改得更好 163 | - 增加 LaTeX 格式的公式输出 164 | 165 | # Thanks 166 | 167 | https://github.com/taehwan642 168 | 169 | # 微信交流群 170 | 171 | 如果有问题想要交流或者想参与开发,或者想与作者联系,欢迎加微信 tomwillow。备注 tomsolver 按照指引进群。 172 | 173 | 如果您觉得此项目不错,请赏颗星吧! 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tomsolver 2 | 3 | ![workflow](https://github.com/tomwillow/tomsolver/actions/workflows/cmake-multi-platform.yml/badge.svg) 4 | 5 | [中文](https://github.com/tomwillow/tomsolver/blob/master/README-zh.md) [English](https://github.com/tomwillow/tomsolver) 6 | 7 | Simplest, Well-tested, Non-linear equations solver library by C++14. 8 | 9 | origin: https://github.com/tomwillow/tomsolver 10 | 11 | > Make C++ solve nonlinear equations as easy as Matlab fsolve 12 | 13 | **Contributors:** 14 | 15 | - Tom Willow (https://github.com/tomwillow) 16 | - lizho (https://github.com/lizho) 17 | 18 | # Features 19 | 20 | - Simple! Simple! If you know how to use fsolve, you will use this! 
21 | - Single header file, just include it! 22 | 23 | # Functions 24 | 25 | - Solving nonlinear equations (Newton-Raphson method, LM method) 26 | - Solving linear equations (Gaussian-column pivot iteration method, inverse matrix) 27 | - Matrix and vector operations (matrix inversion, vector cross multiplication, etc.) 28 | - "Pseudo" symbolic operations (derivatives of expressions, Jacobian matrices of symbolic matrices) 29 | 30 | # Supported Platforms 31 | 32 | Tested at: 33 | 34 | - Linux: ubuntu 22.04 x86_64 gcc 11.3.0 35 | - Windows: windows10 x64 Visual Studio 2019 36 | 37 | Tested at Github Actions: 38 | 39 | - Linux-latest gcc Debug&Release 40 | - Linux-latest clang Debug&Release 41 | - Windows-latest msvc Debug&Release 42 | 43 | # Example 44 | 45 | ```C++ 46 | #include 47 | 48 | using namespace tomsolver; 49 | 50 | int main() { 51 | /* 52 | Translate from Matlab code: 53 | 54 | root2d.m: 55 | function F = root2d(x) 56 | F(1) = exp(-exp(-(x(1)+x(2)))) - x(2)*(1+x(1)^2); 57 | F(2) = x(1)*cos(x(2)) + x(2)*sin(x(1)) - 0.5; 58 | end 59 | 60 | root2d_solve.m: 61 | format long 62 | fun = @root2d; 63 | x0 = [0,0]; 64 | x = fsolve(fun,x0) 65 | 66 | result: 67 | x = 68 | 69 | 0.353246561920553 0.606082026502285 70 | */ 71 | try { 72 | // create equations from string 73 | SymVec f = { 74 | "exp(-exp(-(x1 + x2))) - x2 * (1 + x1 ^ 2)"_f, 75 | "x1 * cos(x2) + x2 * sin(x1) - 0.5"_f, 76 | }; 77 | 78 | // solve it! 79 | VarsTable ans = Solve(f); 80 | 81 | // print the solution 82 | std::cout << ans << std::endl; 83 | 84 | // get the values of solution 85 | std::cout << "x1 = " << ans["x1"] << std::endl; 86 | std::cout << "x2 = " << ans["x2"] << std::endl; 87 | } catch (const std::runtime_error &err) { 88 | // if any error occurs, exception will be thrown 89 | std::cerr << err.what() << std::endl; 90 | return -1; 91 | } 92 | 93 | return 0; 94 | } 95 | ``` 96 | 97 | # Usage 98 | 99 | ## 1. 
Header-Only usage 100 | 101 | Just include a single header file: 102 | `single/include/tomsolver/tomsolver.h` 103 | 104 | ## 2. VCPKG 105 | 106 | You can use [vcpkg](https://learn.microsoft.com/en-us/vcpkg/get_started/overview) to manage the package. 107 | 108 | * Manifest mode (recommended) 109 | Create a `vcpkg.json` file: 110 | 111 | ``` 112 | { 113 | "dependencies": [ 114 | "tomsolver" 115 | ] 116 | } 117 | ``` 118 | 119 | Then run the following command in the same directory: 120 | `$ vcpkg install` 121 | Tutorial: [https://learn.microsoft.com/en-us/vcpkg/get_started/get-started?pivots=shell-bash](https://learn.microsoft.com/en-us/vcpkg/get_started/get-started?pivots=shell-bash) 122 | * Classic mode 123 | `$ vcpkg install tomsolver` 124 | 125 | ## 3. Binary Library + Header Files usage 126 | 127 | ```bash 128 | $ git clone https://github.com/tomwillow/tomsolver 129 | $ mkdir build 130 | $ cd build 131 | $ cmake ../tomsolver 132 | $ cmake --build . --target INSTALL 133 | ``` 134 | 135 | Then add the installed include directory to your include path and link against the installed library. 136 | 137 | # Directory Structure 138 | 139 | - **src**: source files 140 | - **tests**: unit tests 141 | - **single/include**: the folder where the header-only tomsolver.h is located 142 | - **single/test**: All unit tests are integrated into one .cpp file to test whether tomsolver.h is correct.
143 | - **scripts**: scripts that generate the single header file and the single test file under `single/` 144 | 145 | ### examples 146 | 147 | - **examples/solve**: Example of solving nonlinear equations, demonstrating basic usage and how to set a unified initial value 148 | - **examples/set_initial_values**: Example of solving nonlinear equations, demonstrating how to set every variable's initial value 149 | - **examples/solve2**: Example of solving nonlinear equations, demonstrating how to switch solution methods and replace known quantities in the equation 150 | - **examples/diff_machine**: Differentiator: reads an expression and prints its derivative 151 | 152 | # Development Plan 153 | 154 | - add doxygen comments + tutorial document (CN+EN) 155 | - add benchmark tests 156 | - add an option to use Eigen library as matrix library 157 | - add more methods for solving nonlinear equations, matching Matlab fsolve 158 | - add an optional Config parameter to the Solve() function, so that a non-global Config can be used 159 | (similar to the options of Matlab fsolve) 160 | - add support for binary/multivariate functions, such as pow(x, y) 161 | - the current Simplify function is still very naive; improve it 162 | - add LaTeX format formula output 163 | 164 | # Thanks 165 | 166 | https://github.com/taehwan642 167 | 168 | # WeChat group 169 | 170 | If you have any questions, want to communicate, want to participate in development, or want to contact the authors, you are welcome to add _tomwillow_ on WeChat. 171 | 172 | If you think this repository is good, please give it a star!
173 | -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(solve) 2 | add_subdirectory(set_initial_values) 3 | add_subdirectory(solve2) 4 | add_subdirectory(diff_machine) -------------------------------------------------------------------------------- /examples/diff_machine/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | file(GLOB TEST_CODE 3 | *.cpp 4 | ) 5 | 6 | add_executable(DiffMachine ${TEST_CODE}) 7 | 8 | target_include_directories(DiffMachine PUBLIC 9 | ../../single/include 10 | ) 11 | -------------------------------------------------------------------------------- /examples/diff_machine/main.cpp: -------------------------------------------------------------------------------- 1 | #include <tomsolver/tomsolver.h> 2 | 3 | using std::cerr; 4 | using std::cin; 5 | using std::cout; 6 | using std::endl; 7 | using namespace tomsolver; 8 | 9 | const char usage[] = "=========== diff machine ===========\n" 10 | "example: \n" 11 | ">>x^5\n" 12 | "ans = \n" 13 | " 5*x^4\n" 14 | "\n" 15 | "optional functions:\n" 16 | "sin(x) cos(x) tan(x) asin(x) acos(x) atan(x)\n" 17 | "sqrt(x) log(x) log2(x) log10(x) exp(x)\n" 18 | "====================================\n"; 19 | 20 | // Interactive differentiator: reads one expression per line, determines the variable to differentiate with 21 | // respect to (asking the user when there are several), and prints the derivative. Stops at end of input. 22 | int main() { 23 | cout << usage << endl; 24 | 25 | while (true) { 26 | // read a whole line so that expressions may contain spaces, e.g. "x ^ 5" 27 | cout << ">>"; 28 | std::string expr; 29 | if (!std::getline(cin, expr)) { 30 | // end of input (Ctrl+D / Ctrl+Z) or stream error: stop instead of looping forever on a failed stream 31 | break; 32 | } 33 | if (expr.empty()) { 34 | continue; 35 | } 36 | 37 | try { 38 | // parse into an expression tree 39 | Node node = Parse(expr); 40 | 41 | // analyze variables 42 | auto varnamesSet = node->GetAllVarNames(); 43 | std::vector<std::string> varnames(varnamesSet.begin(), varnamesSet.end()); 44 | VarsTable varsTable(varnames, 0); 45 | std::string varname; 46 | 47 | if (varsTable.Vars().size() == 1) { 48 | // only one variable 49 | varname = varsTable.Vars()[0]; 50 | } else if (varsTable.Vars().size() > 1) { 51 | // multiple
variables 52 | cout << "more than 1 variable. who do you want to differentiate?" << endl; 53 | cout << ">>"; 54 | if (!std::getline(cin, varname)) { 55 | break; 56 | } 57 | if (!varsTable.Has(varname)) { 58 | throw std::runtime_error("no variable \"" + varname + "\" in expression: " + expr); 59 | } 60 | } 61 | 62 | Node dnode = Diff(Move(node), varname); 63 | cout << "ans = " << endl; 64 | cout << " " << dnode->ToString() << endl; 65 | } catch (const std::runtime_error &err) { cerr << err.what() << endl; } 66 | } 67 | 68 | return 0; 69 | } -------------------------------------------------------------------------------- /examples/set_initial_values/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | file(GLOB TEST_CODE 3 | *.cpp 4 | ) 5 | 6 | add_executable(Example_SetInitialValues ${TEST_CODE}) 7 | 8 | target_include_directories(Example_SetInitialValues PUBLIC 9 | ../../single/include 10 | ) 11 | -------------------------------------------------------------------------------- /examples/set_initial_values/main.cpp: -------------------------------------------------------------------------------- 1 | #include <tomsolver/tomsolver.h> 2 | 3 | using namespace tomsolver; 4 | 5 | int main() { 6 | // set the locale output to utf-8 7 | std::setlocale(LC_ALL, ".UTF8"); 8 | 9 | try { 10 | // create equations from string 11 | SymVec f = { 12 | "x^2+y^2-25"_f, 13 | "x^2-y^2-7"_f, 14 | }; 15 | 16 | Config::Get().logLevel = LogLevel::TRACE; 17 | 18 | VarsTable initialValues{{"x", 5}, {"y", -5}}; 19 | 20 | // solve it with custom initial values 21 | VarsTable ans = Solve(f, initialValues); 22 | 23 | // print the solution 24 | std::cout << ans << std::endl; 25 | 26 | // get the values of solution 27 | std::cout << "x = " << ans["x"] << std::endl; 28 | std::cout << "y = " << ans["y"] << std::endl; 29 | 30 | // substitute the obtained variables into the equations to verify the solution 31 | // if the result is 0, it indicates that the solution is correct 32 | std::cout << "equations: " <<
f.Subs(ans).Calc() << std::endl; 33 | } catch (const std::runtime_error &err) { 34 | // if any error occurs, exception will be thrown 35 | std::cerr << err.what() << std::endl; 36 | return -1; 37 | } 38 | 39 | return 0; 40 | } -------------------------------------------------------------------------------- /examples/solve/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | file(GLOB TEST_CODE 3 | *.cpp 4 | ) 5 | 6 | add_executable(Example_Solve ${TEST_CODE}) 7 | 8 | target_include_directories(Example_Solve PUBLIC 9 | ../../single/include 10 | ) 11 | -------------------------------------------------------------------------------- /examples/solve/main.cpp: -------------------------------------------------------------------------------- 1 | #include <tomsolver/tomsolver.h> 2 | 3 | using namespace tomsolver; 4 | 5 | int main() { 6 | /* 7 | Translated from Matlab code: 8 | 9 | root2d.m: 10 | function F = root2d(x) 11 | F(1) = exp(-exp(-(x(1)+x(2)))) - x(2)*(1+x(1)^2); 12 | F(2) = x(1)*cos(x(2)) + x(2)*sin(x(1)) - 0.5; 13 | end 14 | 15 | root2d_solve.m: 16 | format long 17 | fun = @root2d; 18 | x0 = [0,0]; 19 | x = fsolve(fun,x0) 20 | 21 | result: 22 | x = 23 | 24 | 0.353246561920553 0.606082026502285 25 | */ 26 | 27 | // set the locale output to utf-8 28 | std::setlocale(LC_ALL, ".UTF8"); 29 | 30 | try { 31 | // create equations from string 32 | SymVec f = { 33 | "exp(-exp(-(x1 + x2))) - x2 * (1 + x1 ^ 2)"_f, 34 | "x1 * cos(x2) + x2 * sin(x1) - 0.5"_f, 35 | }; 36 | 37 | // set the initial value to be 0.0 38 | Config::Get().initialValue = 0.0; 39 | Config::Get().nonlinearMethod = NonlinearMethod::NEWTON_RAPHSON; // use the Newton-Raphson method 40 | Config::Get().allowIndeterminateEquation = true; 41 | 42 | // solve it!
43 | VarsTable ans = Solve(f); 44 | 45 | // print the solution 46 | std::cout << ans << std::endl; 47 | 48 | // get the values of solution 49 | std::cout << "x1 = " << ans["x1"] << std::endl; 50 | std::cout << "x2 = " << ans["x2"] << std::endl; 51 | 52 | // substitute the obtained variables into the equations to verify the solution 53 | // if the result is 0, it indicates that the solution is correct 54 | std::cout << "equations: " << f.Subs(ans).Calc() << std::endl; 55 | } catch (const std::runtime_error &err) { 56 | // if any error occurs, exception will be thrown 57 | std::cerr << err.what() << std::endl; 58 | return -1; 59 | } 60 | 61 | return 0; 62 | } -------------------------------------------------------------------------------- /examples/solve2/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | file(GLOB TEST_CODE 3 | *.cpp 4 | ) 5 | 6 | add_executable(Example_Solve2 ${TEST_CODE}) 7 | 8 | target_include_directories(Example_Solve2 PUBLIC 9 | ../../single/include 10 | ) 11 | -------------------------------------------------------------------------------- /examples/solve2/main.cpp: -------------------------------------------------------------------------------- 1 | #include <tomsolver/tomsolver.h> 2 | 3 | using namespace tomsolver; 4 | 5 | int main() { 6 | std::setlocale(LC_ALL, ".UTF8"); 7 | 8 | try { 9 | // use the LM method 10 | Config::Get().nonlinearMethod = NonlinearMethod::LM; 11 | 12 | // define the equations 13 | // this is essentially a symbolic vector, and solving it means finding the roots where the vector equals the 14 | // zero vector 15 | SymVec f{ 16 | Parse("a/(b^2)-c/(d^2)"), 17 | Parse("129.56108*b-(a/(b^2)+1/a-2*b/(a^2))"), 18 | Parse("129.56108*d-(d/(c^2)-c/(d^2)-1/a)"), 19 | Parse("5*e-7-(2/3*pi*a^2*b+((sqrt(3)*c^2)/(3*sqrt(c^2/3+d^2))+a-c)^2*pi*d^2/(c^2/3+d^2))"), 20 | }; 21 | 22 | // substitute the symbolic constants "pi", "e" with their numerical values 23 | f.Subs(VarsTable{{"pi", PI}, {"e",
std::exp(1.0)}}); 24 | 25 | // print f (the symbolic vector) 26 | std::cout << f << std::endl; 27 | 28 | // solve 29 | auto ans = Solve(f); 30 | 31 | // print the solution 32 | std::cout << ans << std::endl; 33 | } catch (const std::runtime_error &err) { 34 | // if any error occurs, exception will be thrown 35 | std::cerr << err.what() << std::endl; 36 | return -1; 37 | } 38 | 39 | return 0; 40 | } -------------------------------------------------------------------------------- /scripts/all_tests_preheader.cpp: -------------------------------------------------------------------------------- 1 | #include "tomsolver/tomsolver.h" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | using namespace tomsolver; 12 | 13 | using std::cout; 14 | using std::deque; 15 | using std::endl; -------------------------------------------------------------------------------- /scripts/combine_src_to_header_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | import util 3 | 4 | 5 | # root directory of the repository 6 | root_dir = os.path.abspath(os.path.dirname(__file__) + "/..") 7 | 8 | 9 | if __name__ == "__main__": 10 | # output include directory of the single-header library 11 | target_dir = os.path.join(root_dir, "single/include/tomsolver") 12 | output_filename = f"{target_dir}/tomsolver.h" 13 | 14 | srcFilenames = [] 15 | for path in util.findAllFile(f"{root_dir}/src"): 16 | srcFilenames.append(path) 17 | srcFilenames.sort() 18 | 19 | include_dirs = [f"{root_dir}/src/tomsolver"] 20 | 21 | util.combineCodeFile(target_dir, output_filename, srcFilenames, include_dirs) 22 | -------------------------------------------------------------------------------- /scripts/combine_tests_to_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import util 5 | 6 | 7 | # root directory of the repository 8 | root_dir = os.path.abspath(os.path.dirname(__file__) + "/..") 9 | 10 | # include directories of this repository that the test code may include from. 11 |
# Used to drop those #include lines when the test files are merged. 12 | INCLUDE_DIR_PREFIXES = [ 13 | os.path.abspath(f"{root_dir}/src"), 14 | os.path.abspath(f"{root_dir}/tests"), 15 | ] 16 | 17 | class MyTest: 18 | """ 19 | Parses an xxx_test.cpp file and collects every TEST(xxx, xxx) { ... } block into self.contents. 20 | Everything outside the TEST blocks is ignored. 21 | """ 22 | 23 | def __init__(self, filename): 24 | self.filename = filename 25 | with open(filename, "r", encoding="utf-8") as f: 26 | lines_orig = f.readlines() 27 | 28 | self.contents = [] 29 | 30 | inTest = False 31 | for line in lines_orig: 32 | line = line.rstrip() 33 | 34 | if re.match(r"TEST\(\w+,\s+\w+\)\s+{", line): 35 | if inTest: 36 | raise RuntimeError("[ERR] parse error at " + line) 37 | self.contents.append(line) 38 | inTest = True 39 | continue 40 | 41 | if line == "}": 42 | if not inTest: 43 | raise RuntimeError("[ERR] parse error at " + line) 44 | self.contents.append(line) 45 | inTest = False 46 | continue 47 | 48 | if inTest: 49 | self.contents.append(line) 50 | 51 | 52 | class MyTestDepFile: 53 | """ 54 | Reads a file the test sources depend on into self.contents, dropping any "#pragma once" line. 55 | """ 56 | 57 | def __init__(self, filename): 58 | self.filename = filename 59 | with open(filename, "r", encoding="utf-8") as f: 60 | lines_orig = f.readlines() 61 | 62 | self.contents = [] 63 | 64 | for line in lines_orig: 65 | line = line.rstrip() 66 | 67 | # skip: the merged .cpp file must not contain "#pragma once" 68 | if line == "#pragma once": 69 | continue 70 | 71 | self.contents.append(line) 72 | 73 | 74 | if __name__ == "__main__": 75 | # output directory of the combined test file 76 | target_dir = os.path.join(root_dir, "single/test") 77 | output_filename = f"{target_dir}/all_tests.cpp" 78 | 79 | srcFilenames = [] 80 | for path in util.findAllFile(f"{root_dir}/tests"): 81 | if path.endswith("_test.cpp"): 82 | srcFilenames.append(path) 83 | srcFilenames.sort() 84 | 85 |
# 1. create the output directory 86 | if not os.path.isdir(target_dir): 87 | os.makedirs(target_dir) 88 | 89 | # 2. extract the TEST blocks from every *_test.cpp 90 | elements = [] 91 | for path in srcFilenames: 92 | elements.append(MyTest(path)) 93 | 94 | # 3. collect the helper code the tests depend on 95 | contents = [] 96 | contents.extend(MyTestDepFile(f"{root_dir}/tests/memory_leak_detection.h").contents) 97 | contents.extend( 98 | MyTestDepFile(f"{root_dir}/tests/memory_leak_detection_win.h").contents 99 | ) 100 | contents.extend(MyTestDepFile(f"{root_dir}/tests/helper.cpp").contents) 101 | 102 | tempContents = [] 103 | for line in contents: 104 | # drop every #include of a header that exists in this repository's src/ or tests/ directory 105 | obj = re.match(r"#include\s+[\"<]([\w_./\\]+)[\">]", line) 106 | if obj is None: 107 | tempContents.append(line) 108 | continue 109 | 110 | isInThisRepository = False 111 | pureIncludePath = obj.group(1) 112 | for include_dir in INCLUDE_DIR_PREFIXES: 113 | fullIncludePath = f"{include_dir}/{pureIncludePath}" 114 | if os.path.isfile(fullIncludePath): 115 | # if this included header file is in this repository, remove it (not append) 116 | isInThisRepository = True 117 | break 118 | if isInThisRepository: 119 | continue 120 | 121 | tempContents.append(line) 122 | 123 | contents = tempContents 124 | 125 | 126 | # write the output file 127 | with open(output_filename, "w", encoding="utf-8") as f: 128 | # start with the prepared header 129 | with open(f"{root_dir}/scripts/all_tests_preheader.cpp", "r", encoding="utf-8") as ff: 130 | headerLines = ff.readlines() 131 | 132 | f.writelines(headerLines) 133 | 134 | f.write("\n") 135 | 136 | f.write("\n".join(contents)) 137 | f.write("\n") 138 | 139 | # write all tests 140 | for ele in elements: 141 | print("writing content of ", ele.filename, "...") 142 | f.write("\n".join(ele.contents)) 143 | f.write("\n\n") 144 | 145 | # format the result with clang-format (path quoted so directories containing spaces work) 146 | ret = os.system(f'clang-format -i "{output_filename}"') 147 | if ret != 0: 148 | raise RuntimeError("[ERR] fail to run clang-format") 149 | 150 | print("Done.") 151 | -------------------------------------------------------------------------------- /scripts/pre-commit:
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # An example hook script to verify what is about to be committed. 4 | # Called by "git commit" with no arguments. The hook should 5 | # exit with non-zero status after issuing an appropriate message if 6 | # it wants to stop the commit. 7 | # 8 | # To enable this hook, rename this file to "pre-commit". 9 | set -e 10 | 11 | echo "[pre-commit] update single header file..." 12 | cd scripts 13 | python combine_src_to_header_only.py 14 | cd .. 15 | 16 | echo "[pre-commit] update single test file..." 17 | cd scripts 18 | python combine_tests_to_single.py 19 | cd .. 20 | 21 | git add single/ -------------------------------------------------------------------------------- /scripts/run_me.bat: -------------------------------------------------------------------------------- 1 | python combine_src_to_header_only.py 2 | python combine_tests_to_single.py 3 | pause -------------------------------------------------------------------------------- /scripts/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import typing 4 | 5 | 6 | def getFullPath(basename: str, include_dirs) -> typing.Union[None, str]: 7 | """ 8 | 传入一个basename,查找include_dirs里面是否有这个文件,并返回全路径。 9 | """ 10 | for dir in include_dirs: 11 | dest = os.path.join(dir, basename) 12 | if os.path.isfile(dest): 13 | return os.path.abspath(dest).replace("\\", "/") 14 | return None 15 | 16 | 17 | def findAllFile(base): 18 | """ 19 | 通过迭代方式遍历文件夹下的文件名。 20 | """ 21 | for root, ds, fs in os.walk(base): 22 | for f in fs: 23 | fullname = os.path.join(root, f) 24 | yield fullname 25 | 26 | 27 | class MyClass: 28 | """ 29 | 这个类用于分析.h或者.cpp文件,读入文件,获取它的include ""的内容和include <>的内容。 30 | 双引号的header视为内部依赖,尖括号的header视为外部依赖。 31 | #pragma once忽略掉。 32 | 其他的内容存入self.contents。 33 | """ 34 | 35 | def __init__(self, filename: str, include_dirs): 36 | self.filename = 
os.path.abspath(filename).replace("\\", "/") 37 | with open(filename, "r", encoding="utf-8") as f: 38 | lines_orig = f.readlines() 39 | 40 | # 内部依赖项,对于拓扑排序而言就是这个节点的入度 41 | self.depsInner = [] 42 | 43 | # 对外部库的依赖项 44 | self.depsLib = [] 45 | 46 | # #define的内容 47 | self.defines = [] 48 | 49 | self.contents = [] 50 | 51 | for line in lines_orig: 52 | stripedLine = line.strip() 53 | 54 | # 忽略 55 | if stripedLine == "#pragma once": 56 | continue 57 | 58 | innerDep = re.match(r"#include\s+\"([a-z_./\\]+)\"", stripedLine) 59 | if innerDep is not None: 60 | basename = innerDep.group(1) 61 | fullname = getFullPath(basename, include_dirs) 62 | if fullname is None: 63 | raise RuntimeError("can not find full path of ", basename) 64 | self.depsInner.append(fullname) 65 | continue 66 | 67 | libDep = re.match(r"#include\s+<([a-z_./\\]+)>", stripedLine) 68 | if libDep is not None: 69 | basename = libDep.group(1) 70 | self.depsLib.append(basename) 71 | continue 72 | 73 | if stripedLine.find("#define") == 0: 74 | self.defines.append(stripedLine) 75 | continue 76 | 77 | if re.match(r"^(? str: 88 | return f"inner deps: {self.depsInner}, std deps: {self.depsLib}" 89 | 90 | 91 | def combineClasses(elements: typing.List[MyClass], output_filename): 92 | # 3. 拓扑排序 93 | sorted = [] 94 | while len(elements) > 0: 95 | temp = [] 96 | 97 | # 遍历找到入度为0的节点,加入序列 98 | for i in range(len(elements) - 1, -1, -1): 99 | cls = elements[i] 100 | if len(cls.depsInner) > 0: 101 | continue 102 | 103 | # 此时入度为0 104 | temp.append(cls) 105 | elements.remove(cls) 106 | 107 | for t in elements: 108 | try: 109 | index = t.depsInner.index(cls.filename) 110 | except ValueError: 111 | continue 112 | obj = t.depsInner[index] 113 | t.depsInner.remove(obj) 114 | 115 | # 如果本轮没有入度为0的元素,代表有环形依赖 116 | if len(temp) == 0: 117 | print("[FAIL] remain elements: ", elements) 118 | raise RuntimeError("topological sort fail") 119 | 120 | sorted.extend(temp) 121 | 122 | # 4. 
提取所有的#define,#include <>的内容 123 | 124 | allDefines = [] 125 | allStdDeps = [] 126 | for ele in sorted: 127 | allDefines.extend(ele.defines) 128 | allStdDeps.extend(ele.depsLib) 129 | allDefines = list(set(allDefines)) 130 | allStdDeps = list(set(allStdDeps)) 131 | allDefines.sort() 132 | allStdDeps.sort() 133 | 134 | # 5. 写入 135 | with open(output_filename, "w", encoding="utf-8 sig") as f: 136 | f.write("#pragma once\n\n") 137 | 138 | # define先于#include写入 139 | for defs in allDefines: 140 | f.write(f"{defs}\n") 141 | 142 | # 写入标准库头文件include 143 | for header in allStdDeps: 144 | f.write(f"#include <{header}>\n") 145 | 146 | # 逐个把正文写入 147 | for ele in sorted: 148 | print("writing content of ", ele.filename, "...") 149 | f.write("\n".join(ele.contents)) 150 | f.write("\n") 151 | 152 | # 6. 用clang-format处理 153 | ret = os.system(f"clang-format -i {output_filename}") 154 | if ret != 0: 155 | raise RuntimeError("[ERR] fail to run clang-format") 156 | 157 | print("Done.") 158 | 159 | 160 | def combineCodeFile(target_dir, output_filename, src_filenames, include_dirs): 161 | # 1. 创建文件夹 162 | if not os.path.isdir(target_dir): 163 | os.makedirs(target_dir) 164 | 165 | # 2. 
把源路径下所有文件都分析一遍 166 | elements: typing.List[MyClass] = [] 167 | for path in src_filenames: 168 | elements.append(MyClass(path, include_dirs)) 169 | 170 | combineClasses(elements, output_filename) 171 | -------------------------------------------------------------------------------- /single/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | file(GLOB TEST_CODE 3 | ../include/tomsolver/*.h 4 | *.cpp 5 | ) 6 | 7 | add_executable(TomSolverSingleTest ${TEST_CODE}) 8 | 9 | target_include_directories(TomSolverSingleTest PUBLIC 10 | ../include 11 | ) 12 | 13 | target_link_libraries(TomSolverSingleTest PUBLIC 14 | gtest_main 15 | ) 16 | 17 | include(GoogleTest) 18 | gtest_discover_tests(TomSolverSingleTest) -------------------------------------------------------------------------------- /src/tomsolver/config.cpp: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tomsolver { 9 | 10 | namespace { 11 | 12 | static const std::tuple strategies[] = { 13 | {"%.16e", std::regex{"\\.?0+(?=e)"}}, 14 | {"%.16f", std::regex{"\\.?0+(?=$)"}}, 15 | }; 16 | 17 | } 18 | 19 | std::string ToString(double value) noexcept { 20 | if (value == 0.0) { 21 | return "0"; 22 | } 23 | 24 | char buf[64]; 25 | 26 | // 绝对值过大 或者 绝对值过小,应该使用科学计数法来表示 27 | auto getStrategyIdx = [absValue = std::abs(value)] { 28 | return (absValue >= 1.0e16 || absValue <= 1.0e-16) ? 
0 : 1; 29 | }; 30 | 31 | auto &strategy = strategies[getStrategyIdx()]; 32 | auto fmt = std::get<0>(strategy); 33 | auto &re = std::get<1>(strategy); 34 | 35 | snprintf(buf, sizeof(buf), fmt, value); 36 | return std::regex_replace(buf, re, ""); 37 | } 38 | 39 | void Config::Reset() noexcept { 40 | *this = {}; 41 | } 42 | 43 | Config &Config::Get() { 44 | static Config config; 45 | return config; 46 | } 47 | 48 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tomsolver { 6 | 7 | enum class LogLevel { OFF, FATAL, ERROR, WARN, INFO, DEBUG, TRACE, ALL }; 8 | 9 | enum class NonlinearMethod { NEWTON_RAPHSON, LM }; 10 | 11 | struct Config { 12 | /** 13 | * 指定出现浮点数无效值(inf, -inf, nan)时,是否抛出异常。默认为true。 14 | */ 15 | bool throwOnInvalidValue = true; 16 | 17 | double epsilon = 1.0e-9; 18 | 19 | LogLevel logLevel = LogLevel::WARN; 20 | 21 | /** 22 | * 最大迭代次数限制 23 | */ 24 | int maxIterations = 100; 25 | 26 | /** 27 | * 求解方法 28 | */ 29 | NonlinearMethod nonlinearMethod = NonlinearMethod::NEWTON_RAPHSON; 30 | 31 | /** 32 | * 非线性方程求解时,当没有为VarsTable传初值时,设定的初值 33 | */ 34 | double initialValue = 1.0; 35 | 36 | /** 37 | * 是否允许不定方程存在。 38 | * 例如,当等式数量大于未知数数量时,方程组成为不定方程; 39 | * 如果允许,此时将返回一组特解;如果不允许,将抛出异常。 40 | */ 41 | bool allowIndeterminateEquation = false; 42 | 43 | void Reset() noexcept; 44 | 45 | static Config &Get(); 46 | }; 47 | 48 | std::string ToString(double value) noexcept; 49 | 50 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/diff.cpp: -------------------------------------------------------------------------------- 1 | #include "diff.h" 2 | 3 | #include "functions.h" 4 | #include "simplify.h" 5 | 6 | #include 7 | #include // std::log 8 | 9 | namespace tomsolver { 10 | 11 | 
namespace internal { 12 | 13 | class DiffFunctions { 14 | public: 15 | struct DiffNode { 16 | NodeImpl &node; 17 | const bool isLeftChild; 18 | 19 | DiffNode(NodeImpl &node) : node(node), isLeftChild(node.parent && node.parent->left.get() == &node) {} 20 | }; 21 | 22 | static void DiffOnce(Node &root, const std::string &varname) { 23 | std::queue q; 24 | 25 | if (root->type == NodeType::OPERATOR) { 26 | DiffOnceOperator(root, q); 27 | } else { 28 | q.emplace(*root); 29 | } 30 | 31 | while (!q.empty()) { 32 | auto& node = q.front().node; 33 | auto isLeftChild = q.front().isLeftChild; 34 | q.pop(); 35 | 36 | switch (node.type) { 37 | case NodeType::VARIABLE: 38 | node.type = NodeType::NUMBER; 39 | node.value = node.varname == varname ? 1 : 0; 40 | node.varname = ""; 41 | break; 42 | 43 | case NodeType::NUMBER: 44 | node.value = 0; 45 | break; 46 | 47 | case NodeType::OPERATOR: { 48 | auto &child = isLeftChild ? node.parent->left : node.parent->right; 49 | DiffOnceOperator(child, q); 50 | break; 51 | } 52 | default: 53 | assert(0 && "inner bug"); 54 | } 55 | } 56 | } 57 | 58 | static void DiffOnceOperator(Node &node, std::queue &q) { 59 | auto parent = node->parent; 60 | 61 | // 调用前提:node是1元操作符 62 | // 如果node的成员是数字,那么整个node变为数字节点,value=0,且返回true 63 | // 例如: sin(1)' = 0 64 | auto CullNumberMember = [&node]() -> bool { 65 | assert(GetOperatorNum(node->op) == 1); 66 | assert(node->left); 67 | if (node->left->type == NodeType::NUMBER) { 68 | node->left = nullptr; 69 | node->type = NodeType::NUMBER; 70 | node->op = MathOperator::MATH_NULL; 71 | node->value = 0.0; 72 | return true; 73 | } 74 | return false; 75 | }; 76 | 77 | // 调用前提:node是2元操作符 78 | // 如果node的成员是数字,那么整个node变为数字节点,value=0,且返回true 79 | // 例如: (2*3)' = 0 80 | auto CullNumberMemberBinary = [&node]() -> bool { 81 | assert(GetOperatorNum(node->op) == 2); 82 | assert(node->left && node->right); 83 | if (node->left->type == NodeType::NUMBER && node->right->type == NodeType::NUMBER) { 84 | node->left = nullptr; 85 | 
node->right = nullptr; 86 | node->type = NodeType::NUMBER; 87 | node->op = MathOperator::MATH_NULL; 88 | node->value = 0.0; 89 | return true; 90 | } 91 | return false; 92 | }; 93 | 94 | switch (node->op) { 95 | case MathOperator::MATH_NULL: { 96 | assert(0 && "inner bug"); 97 | break; 98 | } 99 | case MathOperator::MATH_POSITIVE: 100 | case MathOperator::MATH_NEGATIVE: { 101 | q.emplace(*node->left); 102 | return; 103 | } 104 | 105 | // 函数 106 | case MathOperator::MATH_SIN: { 107 | if (CullNumberMember()) { 108 | return; 109 | } 110 | 111 | // sin(u)' = cos(u) * u' 112 | node->op = MathOperator::MATH_COS; 113 | auto u2 = Clone(node->left); 114 | q.emplace(*u2); 115 | node = Move(node) * Move(u2); 116 | node->parent = parent; 117 | break; 118 | } 119 | case MathOperator::MATH_COS: { 120 | if (CullNumberMember()) { 121 | return; 122 | } 123 | 124 | // cos(u)' = -sin(u) * u' 125 | node->op = MathOperator::MATH_SIN; 126 | auto u2 = Clone(node->left); 127 | q.emplace(*u2); 128 | node = -Move(node) * Move(u2); 129 | node->parent = parent; 130 | break; 131 | } 132 | case MathOperator::MATH_TAN: { 133 | if (CullNumberMember()) { 134 | return; 135 | } 136 | 137 | // tan'u = 1/(cos(u)^2) * u' 138 | node->op = MathOperator::MATH_COS; 139 | auto &u = node->left; 140 | auto u2 = Clone(u); 141 | q.emplace(*u2); 142 | node = Num(1) / (Move(node) ^ Num(2)) * Move(u2); 143 | node->parent = parent; 144 | return; 145 | } 146 | case MathOperator::MATH_ARCSIN: { 147 | if (CullNumberMember()) { 148 | return; 149 | } 150 | 151 | // asin'u = 1/sqrt(1-u^2) * u' 152 | auto &u = node->left; 153 | auto u2 = Clone(u); 154 | q.emplace(*u2); 155 | node = (Num(1) / sqrt(Num(1) - (Move(u) ^ Num(2)))) * Move(u2); 156 | node->parent = parent; 157 | return; 158 | } 159 | case MathOperator::MATH_ARCCOS: { 160 | if (CullNumberMember()) { 161 | return; 162 | } 163 | 164 | // acos'u = -1/sqrt(1-u^2) * u' 165 | auto &u = node->left; 166 | auto u2 = Clone(u); 167 | q.emplace(*u2); 168 | node = (Num(-1) / 
sqrt(Num(1) - (Move(u) ^ Num(2)))) * Move(u2); 169 | node->parent = parent; 170 | return; 171 | } 172 | case MathOperator::MATH_ARCTAN: { 173 | if (CullNumberMember()) { 174 | return; 175 | } 176 | 177 | // atan'u = 1/(1+u^2) * u' 178 | auto &u = node->left; 179 | auto u2 = Clone(u); 180 | q.emplace(*u2); 181 | node = (Num(1) / (Num(1) + (Move(u) ^ Num(2)))) * Move(u2); 182 | node->parent = parent; 183 | return; 184 | } 185 | case MathOperator::MATH_SQRT: { 186 | if (CullNumberMember()) { 187 | return; 188 | } 189 | 190 | // sqrt(u)' = 1/(2*sqrt(u)) * u' 191 | auto &u = node->left; 192 | auto u2 = Clone(u); 193 | q.emplace(*u2); 194 | node = Num(1) / (Num(2) * Move(node)) * Move(u2); 195 | node->parent = parent; 196 | return; 197 | } 198 | case MathOperator::MATH_LOG: { 199 | if (CullNumberMember()) { 200 | return; 201 | } 202 | 203 | // ln(u)' = 1/u * u' 204 | auto &u = node->left; 205 | auto u2 = Clone(u); 206 | q.emplace(*u2); 207 | node = (Num(1) / Move(u)) * Move(u2); 208 | node->parent = parent; 209 | return; 210 | } 211 | case MathOperator::MATH_LOG2: { 212 | if (CullNumberMember()) { 213 | return; 214 | } 215 | 216 | // loga(u)' = 1/(u * ln(a)) * u' 217 | auto a = 2.0; 218 | auto &u = node->left; 219 | auto u2 = Clone(u); 220 | q.emplace(*u2); 221 | node = (Num(1) / (Move(u) * Num(std::log(a)))) * Move(u2); 222 | node->parent = parent; 223 | return; 224 | } 225 | case MathOperator::MATH_LOG10: { 226 | if (CullNumberMember()) { 227 | return; 228 | } 229 | 230 | // loga(u)' = 1/(u * ln(a)) * u' 231 | auto a = 10.0; 232 | auto &u = node->left; 233 | auto u2 = Clone(u); 234 | q.emplace(*u2); 235 | node = (Num(1) / (Move(u) * Num(std::log(a)))) * Move(u2); 236 | node->parent = parent; 237 | return; 238 | } 239 | case MathOperator::MATH_EXP: { 240 | if (CullNumberMember()) { 241 | return; 242 | } 243 | 244 | // e^x=e^x 245 | if (node->left->type == NodeType::VARIABLE) 246 | return; 247 | 248 | // (e^u)' = e^u * u' 249 | auto u2 = Clone(node->left); 250 | 
q.emplace(*u2); 251 | node = Move(node) * Move(u2); 252 | node->parent = parent; 253 | break; 254 | } 255 | 256 | // 二元 257 | case MathOperator::MATH_ADD: 258 | case MathOperator::MATH_SUB: 259 | if (CullNumberMemberBinary()) { 260 | return; 261 | } 262 | // (u + v)' = u' + v' 263 | if (node->left) { 264 | q.emplace(*node->left); 265 | } 266 | if (node->right) { 267 | q.emplace(*node->right); 268 | } 269 | return; 270 | case MathOperator::MATH_MULTIPLY: { 271 | // 两个操作数中有一个是数字 272 | if ( node->left->type == NodeType::NUMBER) { 273 | q.emplace(*node->right); 274 | return; 275 | } 276 | if (node->right->type == NodeType::NUMBER) { 277 | q.emplace(*node->left); 278 | return; 279 | } 280 | 281 | if (CullNumberMemberBinary()) { 282 | return; 283 | } 284 | 285 | // (u*v)' = u' * v + u * v' 286 | auto &u = node->left; 287 | auto &v = node->right; 288 | q.emplace(*u); 289 | auto u2 = Clone(u); 290 | auto v2 = Clone(v); 291 | q.emplace(*v2); 292 | node = Move(node) + Move(u2) * Move(v2); 293 | node->parent = parent; 294 | return; 295 | } 296 | case MathOperator::MATH_DIVIDE: { 297 | // auto leftIsNumber = node->left->type == NodeType::NUMBER; 298 | auto rightIsNumber = node->right->type == NodeType::NUMBER; 299 | 300 | // f(x)/number = f'(x)/number 301 | if (rightIsNumber) { 302 | q.emplace(*node->left); 303 | return; 304 | } 305 | 306 | if (CullNumberMemberBinary()) { 307 | return; 308 | } 309 | 310 | // (u/v)' = (u'v - uv')/(v^2) 311 | auto &u = node->left; 312 | auto &v = node->right; 313 | auto u2 = Clone(u); 314 | auto v2 = Clone(v); 315 | auto v3 = Clone(v); 316 | q.emplace(*u); 317 | q.emplace(*v2); 318 | node = (Move(u) * Move(v) - Move(u2) * Move(v2)) / (Move(v3) ^ Num(2)); 319 | node->parent = parent; 320 | return; 321 | } 322 | case MathOperator::MATH_POWER: { 323 | // 如果两个操作数都是数字 324 | if (CullNumberMemberBinary()) { 325 | return; 326 | } 327 | 328 | auto lChildIsNumber = node->left->type == NodeType::NUMBER; 329 | auto rChildIsNumber = node->right->type == 
NodeType::NUMBER; 330 | 331 | // (u^a)' = a*u^(a-1) * u' 332 | if (rChildIsNumber) { 333 | auto &a = node->right; 334 | auto aValue = a->value; 335 | auto &u = node->left; 336 | auto u2 = Clone(u); 337 | q.emplace(*u2); 338 | node = std::move(a) * (std::move(u) ^ Num(aValue - 1)) * std::move(u2); 339 | node->parent = parent; 340 | return; 341 | } 342 | 343 | // (a^x)' = a^x * ln(a) when a>0 and a!=1 344 | if (lChildIsNumber) { 345 | auto &a = node->left; 346 | auto aValue = a->value; 347 | auto &u = node->right; 348 | auto u2 = Clone(u); 349 | q.emplace(*u2); 350 | node = (std::move(a) ^ std::move(u)) * log(Num(aValue)) * std::move(u2); 351 | node->parent = parent; 352 | return; 353 | } 354 | 355 | // (u^v)' = ( e^(v*ln(u)) )' = e^(v*ln(u)) * (v*ln(u))' = u^v * (v*ln(u))' 356 | // 左右都不是数字 357 | auto &u = node->left; 358 | auto &v = node->right; 359 | auto vln_u = Clone(v) * log(Clone(u)); 360 | q.emplace(*vln_u); 361 | node = Move(node) * std::move(vln_u); 362 | node->parent = parent; 363 | return; 364 | } 365 | 366 | case MathOperator::MATH_AND: { 367 | throw std::runtime_error("can not apply diff for AND(&) operator"); 368 | return; 369 | } 370 | case MathOperator::MATH_OR: { 371 | throw std::runtime_error("can not apply diff for OR(|) operator"); 372 | return; 373 | } 374 | case MathOperator::MATH_MOD: { 375 | throw std::runtime_error("can not apply diff for MOD(%) operator"); 376 | return; 377 | } 378 | case MathOperator::MATH_LEFT_PARENTHESIS: 379 | case MathOperator::MATH_RIGHT_PARENTHESIS: 380 | assert(0 && "inner bug"); 381 | return; 382 | default: 383 | assert(0 && "inner bug"); 384 | return; 385 | } 386 | } 387 | }; 388 | 389 | } // namespace internal 390 | 391 | Node Diff(const Node &node, const std::string &varname, int i) { 392 | auto node2 = Clone(node); 393 | return Diff(std::move(node2), varname, i); 394 | } 395 | 396 | Node Diff(Node &&node, const std::string &varname, int i) { 397 | assert(i > 0); 398 | auto n = std::move(node); 399 | while (i--) 
{ 400 | internal::DiffFunctions::DiffOnce(n, varname); 401 | } 402 | #ifndef NDEBUG 403 | auto s = n->ToString(); 404 | n->CheckParent(); 405 | #endif 406 | Simplify(n); 407 | #ifndef NDEBUG 408 | n->CheckParent(); 409 | #endif 410 | return n; 411 | } 412 | 413 | } // namespace tomsolver 414 | -------------------------------------------------------------------------------- /src/tomsolver/diff.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "node.h" 4 | 5 | namespace tomsolver { 6 | 7 | /** 8 | * node对varname求导。在node包含多个变量时,是对varname求偏导。 9 | * @exception runtime_error 如果表达式内包含AND(&) OR(|) MOD(%)这类不能求导的运算符,则抛出异常 10 | */ 11 | Node Diff(const Node &node, const std::string &varname, int i = 1); 12 | 13 | /** 14 | * node对varname求导。在node包含多个变量时,是对varname求偏导。 15 | * @exception runtime_error 如果表达式内包含AND(&) OR(|) MOD(%)这类不能求导的运算符,则抛出异常 16 | */ 17 | Node Diff(Node &&node, const std::string &varname, int i = 1); 18 | 19 | } // namespace tomsolver 20 | -------------------------------------------------------------------------------- /src/tomsolver/error_type.cpp: -------------------------------------------------------------------------------- 1 | #include "error_type.h" 2 | 3 | #include 4 | #include 5 | 6 | namespace tomsolver { 7 | 8 | std::string GetErrorInfo(ErrorType err) { 9 | switch (err) { 10 | case ErrorType::ERROR_INVALID_NUMBER: 11 | return u8"invalid number"; 12 | break; 13 | case ErrorType::ERROR_ILLEGALCHAR: 14 | return u8"illegal character"; 15 | break; 16 | case ErrorType::ERROR_INVALID_VARNAME: 17 | return u8"invalid variable name (must start with an underscore \"_\" or a letter)"; 18 | break; 19 | case ErrorType::ERROR_WRONG_EXPRESSION: 20 | return u8"invalid expression"; 21 | break; 22 | case ErrorType::ERROR_EMPTY_INPUT: 23 | return u8"empty input"; 24 | break; 25 | case ErrorType::ERROR_UNDEFINED_VARIABLE: 26 | return u8"undefined variable"; 27 | break; 28 | case 
ErrorType::ERROR_SUBS_NOT_EQUAL: 29 | return u8"number of substitutions does not match the number of items to be replaced"; 30 | break; 31 | case ErrorType::ERROR_NOT_LINK_VARIABLETABLE: 32 | return u8"not linked variable table"; 33 | break; 34 | case ErrorType::ERROR_OUTOF_DOMAIN: 35 | return u8"out of domain"; 36 | break; 37 | case ErrorType::ERROR_VAR_COUNT_NOT_EQUAL_NUM_COUNT: 38 | return u8"the number of variable is not equal with number count"; 39 | break; 40 | case ErrorType::ERROR_VAR_HAS_BEEN_DEFINED: 41 | return u8"variable redefined"; 42 | break; 43 | case ErrorType::ERROR_INDETERMINATE_EQUATION: 44 | return u8"indeterminate equation"; 45 | break; 46 | case ErrorType::ERROR_SINGULAR_MATRIX: 47 | return u8"singular matrix"; 48 | break; 49 | case ErrorType::ERROR_INFINITY_SOLUTIONS: 50 | return u8"infinite solutions"; 51 | break; 52 | case ErrorType::ERROR_OVER_DETERMINED_EQUATIONS: 53 | return u8"overdetermined equations"; 54 | break; 55 | case ErrorType::SIZE_NOT_MATCH: 56 | return u8"size does not match"; 57 | default: 58 | assert(0); 59 | break; 60 | } 61 | return u8"GetErrorInfo: bug"; 62 | } 63 | 64 | MathError::MathError(ErrorType errorType, const std::string &extInfo) 65 | : std::runtime_error(extInfo), errorType(errorType) { 66 | errInfo = GetErrorInfo(errorType); 67 | if (!extInfo.empty()) { 68 | errInfo += ": \"" + extInfo + "\""; 69 | } 70 | } 71 | 72 | const char *MathError::what() const noexcept { 73 | return errInfo.c_str(); 74 | } 75 | 76 | ErrorType MathError::GetErrorType() const noexcept { 77 | return errorType; 78 | } 79 | 80 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/error_type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tomsolver { 7 | 8 | enum class ErrorType { 9 | ERROR_INVALID_NUMBER, // 出现无效的浮点数(inf, -inf, nan) 10 | ERROR_ILLEGALCHAR, // 出现非法字符 11 | 
ERROR_INVALID_VARNAME, // 无效变量名 12 | ERROR_WRONG_EXPRESSION, // 表达式逻辑不正确 13 | ERROR_EMPTY_INPUT, // 表达式为空 14 | ERROR_UNDEFINED_VARIABLE, // 未定义的变量 15 | ERROR_SUBS_NOT_EQUAL, // 替换与被替换数量不对等 16 | ERROR_NOT_LINK_VARIABLETABLE, // 未链接变量表 17 | ERROR_OUTOF_DOMAIN, // 计算超出定义域 18 | ERROR_VAR_COUNT_NOT_EQUAL_NUM_COUNT, // 定义变量时变量数量与初始值不等 19 | ERROR_VAR_HAS_BEEN_DEFINED, // 变量重定义 20 | ERROR_INDETERMINATE_EQUATION, // 不定方程 21 | ERROR_SINGULAR_MATRIX, // 矩阵奇异 22 | ERROR_INFINITY_SOLUTIONS, // 无穷多解 23 | ERROR_OVER_DETERMINED_EQUATIONS, // 方程组过定义 24 | SIZE_NOT_MATCH // 维数不匹配 25 | }; 26 | 27 | std::string GetErrorInfo(ErrorType err); 28 | 29 | class MathError : public std::runtime_error { 30 | public: 31 | MathError(ErrorType errorType, const std::string &extInfo = {}); 32 | 33 | virtual const char *what() const noexcept override; 34 | 35 | ErrorType GetErrorType() const noexcept; 36 | 37 | private: 38 | ErrorType errorType; 39 | std::string errInfo; 40 | }; 41 | 42 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/functions.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "node.h" 3 | 4 | namespace tomsolver { 5 | 6 | template 7 | Node sin(T &&n) noexcept { 8 | return internal::UnaryOperator(MathOperator::MATH_SIN, std::forward(n)); 9 | } 10 | 11 | template 12 | Node cos(T &&n) noexcept { 13 | return internal::UnaryOperator(MathOperator::MATH_COS, std::forward(n)); 14 | } 15 | 16 | template 17 | Node tan(T &&n) noexcept { 18 | return internal::UnaryOperator(MathOperator::MATH_TAN, std::forward(n)); 19 | } 20 | 21 | template 22 | Node asin(T &&n) noexcept { 23 | return internal::UnaryOperator(MathOperator::MATH_ARCSIN, std::forward(n)); 24 | } 25 | 26 | template 27 | Node acos(T &&n) noexcept { 28 | return internal::UnaryOperator(MathOperator::MATH_ARCCOS, std::forward(n)); 29 | } 30 | 31 | template 32 | Node atan(T &&n) noexcept { 33 | return 
internal::UnaryOperator(MathOperator::MATH_ARCTAN, std::forward(n)); 34 | } 35 | 36 | template 37 | Node sqrt(T &&n) noexcept { 38 | return internal::UnaryOperator(MathOperator::MATH_SQRT, std::forward(n)); 39 | } 40 | 41 | template 42 | Node log(T &&n) noexcept { 43 | return internal::UnaryOperator(MathOperator::MATH_LOG, std::forward(n)); 44 | } 45 | 46 | template 47 | Node log2(T &&n) noexcept { 48 | return internal::UnaryOperator(MathOperator::MATH_LOG2, std::forward(n)); 49 | } 50 | 51 | template 52 | Node log10(T &&n) noexcept { 53 | return internal::UnaryOperator(MathOperator::MATH_LOG10, std::forward(n)); 54 | } 55 | 56 | template 57 | Node exp(T &&n) noexcept { 58 | return internal::UnaryOperator(MathOperator::MATH_EXP, std::forward(n)); 59 | } 60 | 61 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/linear.cpp: -------------------------------------------------------------------------------- 1 | #include "linear.h" 2 | 3 | #include "config.h" 4 | #include "error_type.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace tomsolver { 10 | 11 | namespace { 12 | template 13 | const T &asConst(T &a) { 14 | return a; 15 | } 16 | } // namespace 17 | 18 | Vec SolveLinear(Mat A, Vec b) { 19 | if (Config::Get().logLevel >= LogLevel::TRACE) { 20 | std::cout << "SolveLinear:Ax=b (x is the wanted)\n"; 21 | std::cout << "A=\n" + A.ToString(); 22 | std::cout << "b=\n" + b.ToString(); 23 | } 24 | 25 | int rows = A.Rows(); // 行数 26 | int cols = rows; // 列数=未知数个数 27 | 28 | int RankA = rows, RankAb = rows; // 初始值 29 | 30 | assert(rows == b.Rows()); // A行数不等于b行数 31 | 32 | Vec ret(rows); 33 | 34 | if (rows > 0) { 35 | cols = A.Cols(); 36 | } 37 | if (cols != rows) // 不是方阵 38 | { 39 | if (rows > cols) { 40 | // 过定义方程组 41 | throw MathError(ErrorType::ERROR_OVER_DETERMINED_EQUATIONS); 42 | } else { 43 | // 不定方程组 44 | ret.Resize(cols); 45 | } 46 | } 47 | 48 | std::vector TrueRowNumber(cols); 49 | 50 | // 列主元消元法 
51 | for (auto y = 0, x = 0; y < rows && x < cols; y++, x++) { 52 | // if (A[i].size() != rows) 53 | 54 | // 从当前行(y)到最后一行(rows-1)中,找出x列最大的一行与y行交换 55 | int maxAbsRowIndex = GetMaxAbsRowIndex(A, y, rows - 1, x); 56 | A.SwapRow(y, maxAbsRowIndex); 57 | b.SwapRow(y, maxAbsRowIndex); 58 | 59 | while (std::abs(A.Value(y, x)) < Config::Get().epsilon) // 如果当前值为0 x一直递增到非0 60 | { 61 | x++; 62 | if (x == cols) { 63 | break; 64 | } 65 | 66 | // 交换本行与最大行 67 | maxAbsRowIndex = GetMaxAbsRowIndex(A, y, rows - 1, x); 68 | A.SwapRow(y, maxAbsRowIndex); 69 | b.SwapRow(y, maxAbsRowIndex); 70 | } 71 | 72 | if (x != cols && x > y) { 73 | TrueRowNumber[y] = x; // 补齐方程时 当前行应换到x行 74 | } 75 | 76 | if (x == cols) // 本行全为0 77 | { 78 | RankA = y; 79 | if (std::abs(b[y]) < Config::Get().epsilon) { 80 | RankAb = y; 81 | } 82 | 83 | if (RankA != RankAb) { 84 | // 奇异,且系数矩阵及增广矩阵秩不相等->无解 85 | throw MathError(ErrorType::ERROR_SINGULAR_MATRIX); 86 | } else { 87 | // 跳出for,得到特解 88 | break; 89 | } 90 | } 91 | 92 | // 主对角线化为1 93 | auto ratioY = A.Value(y, x); 94 | // y行第j个->第cols个 95 | std::valarray rowY = asConst(A).Row(y, x) / ratioY; 96 | A.Row(y, x) = rowY; 97 | b[y] /= ratioY; 98 | 99 | // 每行化为0 100 | for (auto row = y + 1; row < rows; row++) // 下1行->最后1行 101 | { 102 | auto ratioRow = A.Value(row, x); 103 | if (std::abs(ratioRow) >= Config::Get().epsilon) { 104 | A.Row(row, x) -= rowY * ratioRow; 105 | b[row] -= b[y] * ratioRow; 106 | } 107 | } 108 | } 109 | 110 | bool bIndeterminateEquation = false; // 设置此变量是因为后面rows将=cols,标记以判断是否为不定方程组 111 | 112 | // 若为不定方程组,空缺行全填0继续运算 113 | if (rows != cols) { 114 | A.Resize(cols, cols); 115 | b.Resize(cols); 116 | rows = cols; 117 | bIndeterminateEquation = true; 118 | 119 | // 调整顺序 120 | for (int i = rows - 1; i >= 0; i--) { 121 | if (TrueRowNumber[i] != 0) { 122 | A.SwapRow(i, TrueRowNumber[i]); 123 | b.SwapRow(i, TrueRowNumber[i]); 124 | } 125 | } 126 | } 127 | 128 | // 后置换得到x 129 | for (int i = rows - 1; i >= 0; i--) // 最后1行->第1行 130 | { 131 | auto vec = 
asConst(A).Row(i, i + 1) * asConst(ret).Col(0, i + 1); 132 | ret[i] = b[i] - (vec.size() ? vec.sum() : 0); 133 | } 134 | 135 | if (RankA < cols && RankA == RankAb) { 136 | if (bIndeterminateEquation) { 137 | if (!Config::Get().allowIndeterminateEquation) { 138 | throw MathError(ErrorType::ERROR_INDETERMINATE_EQUATION, 139 | "A = " + A.ToString() + "\nb = " + b.ToString()); 140 | } 141 | } else { 142 | throw MathError(ErrorType::ERROR_INFINITY_SOLUTIONS); 143 | } 144 | } 145 | 146 | return ret; 147 | } 148 | 149 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/linear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mat.h" 4 | 5 | namespace tomsolver { 6 | 7 | /** 8 | * 求解线性方程组Ax = b。传入矩阵A,向量b,返回向量x。 9 | * @exception MathError 奇异矩阵 10 | * @exception MathError 矛盾方程组 11 | * @exception MathError 不定方程(设置Config::Get().allowIndeterminateEquation=true可以允许不定方程组返回一组特解) 12 | * 13 | */ 14 | Vec SolveLinear(Mat A, Vec b); 15 | 16 | } // namespace tomsolver 17 | -------------------------------------------------------------------------------- /src/tomsolver/mat.h: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Original Inverse(), Adjoint(), GetCofactor(), Det() is from https://github.com/taehwan642: 4 | 5 | /////////////////////////////////////////// 6 | MADE BY TAE HWAN KIM, SHIN JAE HO 7 | 김태환, 신재호 제작 8 | If you see this documents, you can learn & understand Faster. 9 | 밑에 자료들을 보시면, 더욱 빠르게 배우고 이해하실 수 있으실겁니다. 10 | https://www.wikihow.com/Find-the-Inverse-of-a-3x3-Matrix 11 | https://www.wikihow.com/Find-the-Determinant-of-a-3X3-Matrix 12 | LAST UPDATE 2020 - 03 - 30 13 | 마지막 업데이트 2020 - 03 - 30 14 | This is my Github Profile. You can use this source whenever you want. 15 | 제 깃허브 페이지입니다. 언제든지 이 소스를 가져다 쓰셔도 됩니다. 
16 | https://github.com/taehwan642 17 | Thanks :) 18 | 감사합니다 :) 19 | /////////////////////////////////////////// 20 | 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | namespace tomsolver { 31 | 32 | class Vec; 33 | 34 | class Mat { 35 | public: 36 | explicit Mat(int row, int col, double initValue = 0) noexcept; 37 | 38 | Mat(std::initializer_list> init) noexcept; 39 | 40 | Mat(int row, int col, std::valarray data) noexcept; 41 | 42 | Mat(const Mat &) = default; 43 | Mat(Mat &&) = default; 44 | Mat &operator=(const Mat &) = default; 45 | Mat &operator=(Mat &&) = default; 46 | 47 | std::slice_array Row(int i, int offset = 0); 48 | std::slice_array Col(int j, int offset = 0); 49 | auto Row(int i, int offset = 0) const -> decltype(std::declval>()[(std::slice{})]); 50 | auto Col(int j, int offset = 0) const -> decltype(std::declval>()[(std::slice{})]); 51 | const double &Value(int i, int j) const; 52 | double &Value(int i, int j); 53 | 54 | bool operator==(double m) const noexcept; 55 | bool operator==(const Mat &b) const noexcept; 56 | 57 | // unary minus (negation) 58 | Mat operator-() noexcept; 59 | 60 | Mat operator+(const Mat &b) const noexcept; 61 | Mat &operator+=(const Mat &b) noexcept; 62 | 63 | Mat operator-(const Mat &b) const noexcept; 64 | 65 | Mat operator*(double m) const noexcept; 66 | Mat operator*(const Mat &b) const noexcept; 67 | 68 | int Rows() const noexcept; 69 | 70 | int Cols() const noexcept; 71 | 72 | /** 73 | * Converts this matrix to a Vec. Throws if the number of columns is not 1. 74 | * @exception runtime_error if the number of columns is not 1 75 | */ 76 | Vec ToVec() const; 77 | 78 | Mat &SwapRow(int i, int j) noexcept; 79 | Mat &SwapCol(int i, int j) noexcept; 80 | 81 | std::string ToString() const noexcept; 82 | 83 | void Resize(int newRows, int newCols) noexcept; 84 | 85 | Mat &Zero() noexcept; 86 | 87 | Mat &Ones() noexcept; 88 | 89 | double Norm2() const noexcept; 90 | 91 | double NormInfinity() const noexcept; 92 | 93 | double NormNegInfinity() const noexcept; 94 | 95 |
double Min() const noexcept; 96 | 97 | void SetValue(double value) noexcept; 98 | 99 | /** 100 | * Returns whether the matrix is positive definite. 101 | */ 102 | bool PositiveDetermine() const noexcept; 103 | 104 | Mat Transpose() const noexcept; 105 | 106 | /** 107 | * Computes the inverse matrix. 108 | * @exception MathError if the matrix is singular 109 | */ 110 | Mat Inverse() const; 111 | 112 | protected: 113 | int rows; 114 | int cols; 115 | std::valarray data; 116 | 117 | friend Mat operator*(double k, const Mat &mat) noexcept; 118 | friend std::ostream &operator<<(std::ostream &out, const Mat &mat) noexcept; 119 | friend Mat EachDivide(const Mat &a, const Mat &b) noexcept; 120 | friend bool IsZero(const Mat &mat) noexcept; 121 | friend bool AllIsLessThan(const Mat &v1, const Mat &v2) noexcept; 122 | friend void GetCofactor(const Mat &A, Mat &temp, int p, int q, int n) noexcept; 123 | friend void Adjoint(const Mat &A, Mat &adj) noexcept; 124 | friend double Det(const Mat &A, int n) noexcept; 125 | }; 126 | 127 | Mat operator*(double k, const Mat &mat) noexcept; 128 | 129 | std::ostream &operator<<(std::ostream &out, const Mat &mat) noexcept; 130 | 131 | Mat EachDivide(const Mat &a, const Mat &b) noexcept; 132 | 133 | bool IsZero(const Mat &mat) noexcept; 134 | 135 | bool AllIsLessThan(const Mat &v1, const Mat &v2) noexcept; 136 | 137 | int GetMaxAbsRowIndex(const Mat &A, int rowStart, int rowEnd, int col) noexcept; 138 | 139 | /** 140 | * Computes the adjugate (adjoint) matrix of A into adj. 141 | */ 142 | void Adjoint(const Mat &A, Mat &adj) noexcept; 143 | 144 | void GetCofactor(const Mat &A, Mat &temp, int p, int q, int n) noexcept; 145 | 146 | /** 147 | * Computes the determinant of the matrix. 148 | */ 149 | double Det(const Mat &A, int n) noexcept; 150 | 151 | class Vec : public Mat { 152 | public: 153 | explicit Vec(int rows, double initValue = 0) noexcept; 154 | 155 | Vec(std::initializer_list init) noexcept; 156 | 157 | Vec(std::valarray data) noexcept; 158 | 159 | Mat &AsMat() noexcept; 160 | 161 | void Resize(int newRows) noexcept; 162 | 163 | double &operator[](std::size_t i) noexcept; 164 |
165 | double operator[](std::size_t i) const noexcept; 166 | 167 | Vec operator+(const Vec &b) const noexcept; 168 | 169 | // unary minus (negation) 170 | Vec operator-() noexcept; 171 | 172 | Vec operator-(const Vec &b) const noexcept; 173 | 174 | Vec operator*(double m) const noexcept; 175 | 176 | Vec operator*(const Vec &b) const noexcept; 177 | 178 | Vec operator/(const Vec &b) const noexcept; 179 | 180 | bool operator<(const Vec &b) noexcept; 181 | 182 | friend double Dot(const Vec &a, const Vec &b) noexcept; 183 | friend Vec operator*(double k, const Vec &V); 184 | }; 185 | 186 | /** 187 | * Vector dot product. 188 | */ 189 | double Dot(const Vec &a, const Vec &b) noexcept; 190 | 191 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/math_operator.cpp: -------------------------------------------------------------------------------- 1 | #include "math_operator.h" 2 | 3 | #include "config.h" 4 | #include "error_type.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace tomsolver { 11 | // Returns the textual form of op (asserts on MATH_NULL). 12 | std::string MathOperatorToStr(MathOperator op) { 13 | switch (op) { 14 | case MathOperator::MATH_NULL: 15 | assert(0); 16 | return ""; 17 | // unary 18 | case MathOperator::MATH_POSITIVE: 19 | return "+"; 20 | case MathOperator::MATH_NEGATIVE: 21 | return "-"; 22 | // functions 23 | case MathOperator::MATH_SIN: 24 | return "sin"; 25 | case MathOperator::MATH_COS: 26 | return "cos"; 27 | case MathOperator::MATH_TAN: 28 | return "tan"; 29 | case MathOperator::MATH_ARCSIN: 30 | return "asin"; 31 | case MathOperator::MATH_ARCCOS: 32 | return "acos"; 33 | case MathOperator::MATH_ARCTAN: 34 | return "atan"; 35 | case MathOperator::MATH_SQRT: 36 | return "sqrt"; 37 | case MathOperator::MATH_LOG: 38 | return "log"; 39 | case MathOperator::MATH_LOG2: 40 | return "log2"; 41 | case MathOperator::MATH_LOG10: 42 | return "log10"; 43 | case MathOperator::MATH_EXP: 44 | return "exp"; 45 | // binary 46 | case MathOperator::MATH_ADD: 47 |
return "+"; 48 | case MathOperator::MATH_SUB: 49 | return "-"; 50 | case MathOperator::MATH_MULTIPLY: 51 | return "*"; 52 | case MathOperator::MATH_DIVIDE: 53 | return "/"; 54 | case MathOperator::MATH_POWER: 55 | return "^"; 56 | case MathOperator::MATH_AND: 57 | return "&"; 58 | case MathOperator::MATH_OR: 59 | return "|"; 60 | case MathOperator::MATH_MOD: 61 | return "%"; 62 | case MathOperator::MATH_LEFT_PARENTHESIS: 63 | return "("; 64 | case MathOperator::MATH_RIGHT_PARENTHESIS: 65 | return ")"; 66 | } 67 | assert(0); 68 | return "err"; 69 | } 70 | // Number of operands: 1 for signs and functions, 2 for binary operators; asserts (and returns 0) otherwise. 71 | int GetOperatorNum(MathOperator op) noexcept { 72 | switch (op) { 73 | case MathOperator::MATH_POSITIVE: // unary plus / minus 74 | case MathOperator::MATH_NEGATIVE: 75 | 76 | case MathOperator::MATH_SIN: 77 | case MathOperator::MATH_COS: 78 | case MathOperator::MATH_TAN: 79 | case MathOperator::MATH_ARCSIN: 80 | case MathOperator::MATH_ARCCOS: 81 | case MathOperator::MATH_ARCTAN: 82 | case MathOperator::MATH_SQRT: 83 | case MathOperator::MATH_LOG: 84 | case MathOperator::MATH_LOG2: 85 | case MathOperator::MATH_LOG10: 86 | case MathOperator::MATH_EXP: 87 | return 1; 88 | 89 | case MathOperator::MATH_ADD: 90 | case MathOperator::MATH_SUB: 91 | case MathOperator::MATH_MULTIPLY: 92 | case MathOperator::MATH_DIVIDE: 93 | case MathOperator::MATH_POWER: //^ 94 | case MathOperator::MATH_AND: //& 95 | case MathOperator::MATH_OR: //| 96 | case MathOperator::MATH_MOD: //% 97 | return 2; 98 | 99 | case MathOperator::MATH_LEFT_PARENTHESIS: 100 | case MathOperator::MATH_RIGHT_PARENTHESIS: 101 | assert(0); 102 | break; 103 | default: 104 | assert(0); 105 | break; 106 | } 107 | assert(0); 108 | return 0; 109 | } 110 | // Operator precedence: a larger value binds more tightly. 111 | int Rank(MathOperator op) noexcept { 112 | switch (op) { 113 | case MathOperator::MATH_SIN: 114 | case MathOperator::MATH_COS: 115 | case MathOperator::MATH_TAN: 116 | case MathOperator::MATH_ARCSIN: 117 | case MathOperator::MATH_ARCCOS: 118 | case MathOperator::MATH_ARCTAN: 119 | case MathOperator::MATH_SQRT: 120 |
case MathOperator::MATH_LOG: 121 | case MathOperator::MATH_LOG2: 122 | case MathOperator::MATH_LOG10: 123 | case MathOperator::MATH_EXP: 124 | return 15; 125 | 126 | case MathOperator::MATH_POSITIVE: // every operator except a function can squeeze out (pop) a pending sign 127 | case MathOperator::MATH_NEGATIVE: 128 | return 14; 129 | 130 | case MathOperator::MATH_MOD: //% 131 | return 13; 132 | 133 | case MathOperator::MATH_AND: //& 134 | case MathOperator::MATH_OR: //| 135 | return 12; 136 | 137 | case MathOperator::MATH_POWER: //^ 138 | return 11; 139 | 140 | case MathOperator::MATH_MULTIPLY: 141 | case MathOperator::MATH_DIVIDE: 142 | return 10; 143 | 144 | case MathOperator::MATH_ADD: 145 | case MathOperator::MATH_SUB: 146 | return 5; 147 | 148 | case MathOperator::MATH_LEFT_PARENTHESIS: // parentheses rank lowest so that no other operator squeezes them out 149 | case MathOperator::MATH_RIGHT_PARENTHESIS: 150 | return 0; 151 | default: 152 | assert(0); 153 | break; 154 | } 155 | assert(0); 156 | return 0; 157 | } 158 | // Returns true for left-associative operators; functions and parentheses also report true. 159 | bool IsLeft2Right(MathOperator eOperator) noexcept { 160 | switch (eOperator) { 161 | case MathOperator::MATH_MOD: //% 162 | case MathOperator::MATH_AND: //& 163 | case MathOperator::MATH_OR: //| 164 | case MathOperator::MATH_MULTIPLY: 165 | case MathOperator::MATH_DIVIDE: 166 | case MathOperator::MATH_ADD: 167 | case MathOperator::MATH_SUB: 168 | return true; 169 | 170 | case MathOperator::MATH_POSITIVE: // unary signs are right-associative 171 | case MathOperator::MATH_NEGATIVE: 172 | case MathOperator::MATH_POWER: //^ 173 | return false; 174 | 175 | // associativity does not apply to functions and parentheses 176 | case MathOperator::MATH_SIN: 177 | case MathOperator::MATH_COS: 178 | case MathOperator::MATH_TAN: 179 | case MathOperator::MATH_ARCSIN: 180 | case MathOperator::MATH_ARCCOS: 181 | case MathOperator::MATH_ARCTAN: 182 | case MathOperator::MATH_SQRT: 183 | case MathOperator::MATH_LOG: 184 | case MathOperator::MATH_LOG2: 185 | case MathOperator::MATH_LOG10: 186 | case MathOperator::MATH_EXP: 187 | 188 | case MathOperator::MATH_LEFT_PARENTHESIS: 189 | case MathOperator::MATH_RIGHT_PARENTHESIS:
190 | return true; 191 | default: 192 | assert(0); 193 | } 194 | return false; 195 | } 196 | // Whether the operator obeys the associative law: only + and *. 197 | bool InAssociativeLaws(MathOperator eOperator) noexcept { 198 | switch (eOperator) { 199 | 200 | case MathOperator::MATH_POSITIVE: // unary plus / minus 201 | case MathOperator::MATH_NEGATIVE: 202 | 203 | case MathOperator::MATH_SQRT: 204 | case MathOperator::MATH_SIN: 205 | case MathOperator::MATH_COS: 206 | case MathOperator::MATH_TAN: 207 | case MathOperator::MATH_ARCSIN: 208 | case MathOperator::MATH_ARCCOS: 209 | case MathOperator::MATH_ARCTAN: 210 | case MathOperator::MATH_LOG: 211 | case MathOperator::MATH_LOG2: 212 | case MathOperator::MATH_LOG10: 213 | case MathOperator::MATH_EXP: 214 | 215 | case MathOperator::MATH_MOD: //% 216 | case MathOperator::MATH_AND: //& 217 | case MathOperator::MATH_OR: //| 218 | case MathOperator::MATH_POWER: //^ 219 | case MathOperator::MATH_DIVIDE: 220 | case MathOperator::MATH_SUB: 221 | 222 | case MathOperator::MATH_LEFT_PARENTHESIS: 223 | case MathOperator::MATH_RIGHT_PARENTHESIS: 224 | return false; 225 | 226 | case MathOperator::MATH_ADD: 227 | case MathOperator::MATH_MULTIPLY: 228 | return true; 229 | default: 230 | assert(0); 231 | break; 232 | } 233 | assert(0); 234 | return false; 235 | } 236 | // Whether op is a function such as sin or log. 237 | bool IsFunction(MathOperator op) noexcept { 238 | switch (op) { 239 | case MathOperator::MATH_SIN: 240 | case MathOperator::MATH_COS: 241 | case MathOperator::MATH_TAN: 242 | case MathOperator::MATH_ARCSIN: 243 | case MathOperator::MATH_ARCCOS: 244 | case MathOperator::MATH_ARCTAN: 245 | case MathOperator::MATH_SQRT: 246 | case MathOperator::MATH_LOG: 247 | case MathOperator::MATH_LOG2: 248 | case MathOperator::MATH_LOG10: 249 | case MathOperator::MATH_EXP: 250 | return true; 251 | 252 | case MathOperator::MATH_POSITIVE: 253 | case MathOperator::MATH_NEGATIVE: 254 | case MathOperator::MATH_MOD: //% 255 | case MathOperator::MATH_AND: //& 256 | case MathOperator::MATH_OR: //| 257 | case MathOperator::MATH_POWER: //^ 258 | case
MathOperator::MATH_MULTIPLY: 259 | case MathOperator::MATH_DIVIDE: 260 | case MathOperator::MATH_ADD: 261 | case MathOperator::MATH_SUB: 262 | case MathOperator::MATH_LEFT_PARENTHESIS: 263 | case MathOperator::MATH_RIGHT_PARENTHESIS: 264 | return false; 265 | default: 266 | assert(0); 267 | break; 268 | } 269 | assert(0); 270 | return false; 271 | } 272 | // Applies op to v1 (and to v2 for binary operators; v2 is ignored otherwise). If Config::Get().throwOnInvalidValue is set, an inf/NaN result throws MathError. 273 | double Calc(MathOperator op, double v1, double v2) { 274 | double ret = std::numeric_limits::quiet_NaN(); // stays NaN for cases that produce no value 275 | switch (op) { 276 | case MathOperator::MATH_SIN: 277 | ret = std::sin(v1); 278 | break; 279 | case MathOperator::MATH_COS: 280 | ret = std::cos(v1); 281 | break; 282 | case MathOperator::MATH_TAN: 283 | ret = std::tan(v1); 284 | break; 285 | case MathOperator::MATH_ARCSIN: 286 | ret = std::asin(v1); 287 | break; 288 | case MathOperator::MATH_ARCCOS: 289 | ret = std::acos(v1); 290 | break; 291 | case MathOperator::MATH_ARCTAN: 292 | ret = std::atan(v1); 293 | break; 294 | case MathOperator::MATH_SQRT: 295 | ret = std::sqrt(v1); 296 | break; 297 | case MathOperator::MATH_LOG: 298 | ret = std::log(v1); 299 | break; 300 | case MathOperator::MATH_LOG2: 301 | ret = std::log2(v1); 302 | break; 303 | case MathOperator::MATH_LOG10: 304 | ret = std::log10(v1); 305 | break; 306 | case MathOperator::MATH_EXP: 307 | ret = std::exp(v1); 308 | break; 309 | case MathOperator::MATH_POSITIVE: 310 | ret = v1; 311 | break; 312 | case MathOperator::MATH_NEGATIVE: 313 | ret = -v1; 314 | break; 315 | 316 | case MathOperator::MATH_MOD: //% 317 | if ((int)v2 != 0) ret = (int)v1 % (int)v2; // integer % 0 is undefined behaviour: leave ret as NaN so it is reported as an invalid value 318 | break; 319 | case MathOperator::MATH_AND: //& 320 | ret = (int)v1 & (int)v2; 321 | break; 322 | case MathOperator::MATH_OR: //| 323 | ret = (int)v1 | (int)v2; 324 | break; 325 | 326 | case MathOperator::MATH_POWER: //^ 327 | ret = std::pow(v1, v2); 328 | break; 329 | 330 | case MathOperator::MATH_ADD: 331 | ret = v1 + v2; 332 | break; 333 | case MathOperator::MATH_SUB: 334 | ret = v1 - v2; 335 | break; 336 | case MathOperator::MATH_MULTIPLY: 337 |
ret = v1 * v2; 338 | break; 339 | case MathOperator::MATH_DIVIDE: 340 | ret = v1 / v2; 341 | break; 342 | default: 343 | assert(0 && "[Calc] bug."); 344 | break; 345 | } 346 | 347 | if (Config::Get().throwOnInvalidValue == false) { 348 | return ret; 349 | } 350 | 351 | bool isInvalid = !std::isfinite(ret); // true for inf, -inf and NaN 352 | 353 | if (isInvalid) { 354 | // std::string info; 355 | std::stringstream info; 356 | info << "expression: \""; 357 | switch (GetOperatorNum(op)) { 358 | case 1: 359 | info << MathOperatorToStr(op) << " " << ToString(v1); 360 | break; 361 | case 2: 362 | info << ToString(v1) << " " << MathOperatorToStr(op) << " " << ToString(v2); 363 | break; 364 | default: 365 | assert(0); 366 | } 367 | info << "\""; 368 | throw MathError(ErrorType::ERROR_INVALID_NUMBER, info.str()); 369 | } 370 | 371 | return ret; 372 | } 373 | 374 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/math_operator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #define _USE_MATH_DEFINES 5 | #include 6 | 7 | namespace tomsolver { 8 | 9 | constexpr double PI = M_PI; 10 | // Degrees -> radians. The return type is auto (a decayed value), not T: for an lvalue argument T deduces to U&, and a U& return type cannot bind the computed temporary. 11 | template 12 | auto radians(T &&t) noexcept { 13 | return std::forward(t) / 180.0 * PI; 14 | } 15 | // Radians -> degrees (auto return type for the same reason as radians). 16 | template 17 | auto degrees(T &&t) noexcept { 18 | return std::forward(t) * 180.0 / PI; 19 | } 20 | 21 | enum class MathOperator { 22 | MATH_NULL, 23 | // unary 24 | MATH_POSITIVE, 25 | MATH_NEGATIVE, 26 | 27 | // functions 28 | MATH_SIN, 29 | MATH_COS, 30 | MATH_TAN, 31 | MATH_ARCSIN, 32 | MATH_ARCCOS, 33 | MATH_ARCTAN, 34 | MATH_SQRT, 35 | MATH_LOG, 36 | MATH_LOG2, 37 | MATH_LOG10, 38 | MATH_EXP, 39 | 40 | // binary 41 | MATH_ADD, 42 | MATH_SUB, 43 | MATH_MULTIPLY, 44 | MATH_DIVIDE, 45 | MATH_POWER, 46 | MATH_AND, 47 | MATH_OR, 48 | MATH_MOD, 49 | 50 | MATH_LEFT_PARENTHESIS, 51 | MATH_RIGHT_PARENTHESIS 52 | }; 53
| 54 | /** 55 | * Converts an operator to its textual form (e.g. "+", "sin"). 56 | */ 57 | std::string MathOperatorToStr(MathOperator op); 58 | 59 | /** 60 | * Returns the number of operands (1 for signs and functions, 2 for binary operators). 61 | */ 62 | int GetOperatorNum(MathOperator op) noexcept; 63 | 64 | /** 65 | * Returns the operator precedence; a larger value binds more tightly. 66 | 67 | */ 68 | int Rank(MathOperator op) noexcept; 69 | 70 | /** 71 | * Returns true if the operator is left-associative (functions and parentheses also report true). 72 | */ 73 | bool IsLeft2Right(MathOperator eOperator) noexcept; 74 | 75 | /** 76 | * Returns whether the operator obeys the associative law (only + and *). 77 | */ 78 | bool InAssociativeLaws(MathOperator eOperator) noexcept; 79 | 80 | /** 81 | * Returns whether the operator is a function such as sin or log. 82 | */ 83 | bool IsFunction(MathOperator op) noexcept; 84 | 85 | /** 86 | * Returns whether n is an integer and even. 87 | * FIXME: values outside the long long range are not handled. 88 | */ 89 | bool IsIntAndEven(double n) noexcept; 90 | // Applies op to v1 (and to v2 for binary operators). Throws MathError on an inf/NaN result when Config::Get().throwOnInvalidValue is set. 91 | double Calc(MathOperator op, double v1, double v2); 92 | 93 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/node.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "math_operator.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | namespace tomsolver { 16 | 17 | enum class NodeType { NUMBER, OPERATOR, VARIABLE }; 18 | 19 | // forward declarations 20 | namespace internal { 21 | struct NodeImpl; 22 | } 23 | class SymMat; 24 | 25 | /** 26 | * An expression node (owning pointer to a NodeImpl). 27 | */ 28 | using Node = std::unique_ptr; 29 | 30 | namespace internal { 31 | 32 | /** 33 | * Implementation of a single node; normally held through std::unique_ptr (see Node). 34 | */ 35 | struct NodeImpl { 36 | 37 | NodeImpl(NodeType type, MathOperator op, double value, std::string varname) noexcept; 38 | 39 | NodeImpl(const NodeImpl &rhs) noexcept; 40 | NodeImpl &operator=(const NodeImpl &rhs) noexcept; 41 | 42 | NodeImpl(NodeImpl &&rhs) noexcept; 43 | NodeImpl &operator=(NodeImpl &&rhs) noexcept; 44 | 45 | ~NodeImpl(); 46 | 47 | bool Equal(const Node &rhs) const noexcept; 48 | 49 | /** 50 | * Converts the whole tree to a string using an in-order traversal. 51 | * Example: 52 | * Node n = (Var("a") + Num(1)) * Var("b"); 53 | * then 54 |
n->ToString() == "(a+1.000000)*b" 55 | */ 56 | std::string ToString() const noexcept; 57 | 58 | /** 59 | * Evaluates the whole expression numerically without modifying this node. 60 | * @exception runtime_error if the expression still contains variables 61 | * @exception MathError if an invalid floating-point value (inf, -inf, nan) occurs 62 | */ 63 | double Vpa() const; 64 | 65 | /** 66 | * Evaluates the whole expression numerically. Unlike Vpa() this is non-const and returns *this, so it presumably rewrites this node into the resulting number (confirm in node.cpp). 67 | * @exception runtime_error if the expression still contains variables 68 | * @exception MathError if an invalid floating-point value (inf, -inf, nan) occurs 69 | */ 70 | NodeImpl &Calc(); 71 | 72 | /** 73 | * Returns all variable names that appear in the expression. 74 | */ 75 | std::set GetAllVarNames() const noexcept; 76 | 77 | /** 78 | * Checks that every parent pointer in the node tree is correct. 79 | */ 80 | void CheckParent() const noexcept; 81 | 82 | private: 83 | std::string varname; 84 | double value = 0; // initialized so the private defaulted constructor never leaves it indeterminate 85 | MathOperator op = MathOperator::MATH_NULL; 86 | NodeType type = NodeType::NUMBER; 87 | NodeImpl *parent = nullptr; 88 | Node left, right; 89 | NodeImpl() = default; 90 | 91 | /** 92 | * If this node is an OPERATOR, checks that its operand count matches the left/right pointers. 93 | */ 94 | void CheckOperatorNum() const noexcept; 95 | 96 | /** 97 | * Converts this node alone (children excluded) to a string. 98 | */ 99 | std::string NodeToStr() const noexcept; 100 | 101 | void ToStringRecursively(std::stringstream &output) const noexcept; 102 | 103 | void ToStringNonRecursively(std::stringstream &output) const noexcept; 104 | 105 | /** 106 | * Evaluates the expression recursively. 107 | * @exception runtime_error if the expression still contains variables 108 | * @exception MathError on domain errors, division by zero, etc. 109 | */ 110 | double VpaRecursively() const; 111 | 112 | /** 113 | * Evaluates the expression iteratively. 114 | * Slower than the recursive version, but it cannot overflow the stack. 115 | * Benchmark: for an expression of 4000 random arithmetic nodes, 1000 runs took 3000 ms in Release (recursive version: 2500 ms), 116 | * i.e. roughly 1333 ops/ms. 117 | * @exception runtime_error if the expression still contains variables 118 | * @exception MathError on domain errors, division by zero, etc. 119 | */ 120 | double VpaNonRecursively() const; 121 | 122 | /** 123 | * Releases the whole subtree below this node (the node itself is kept). 124 | * Implemented as a non-recursive post-order traversal of the binary tree. 125 | */ 126 | void Release() noexcept; 127 | 128 | friend Node Operator(MathOperator op, Node left, Node right) noexcept; 129 | 130 | friend Node CloneRecursively(const Node &rhs) noexcept; 131 | friend Node
CloneNonRecursively(const Node &rhs) noexcept; 132 | 133 | friend void CopyOrMoveTo(NodeImpl *parent, Node &child, Node &&n1) noexcept; 134 | friend void CopyOrMoveTo(NodeImpl *parent, Node &child, const Node &n1) noexcept; 135 | 136 | friend std::ostream &operator<<(std::ostream &out, const Node &n) noexcept; 137 | 138 | template 139 | friend Node UnaryOperator(MathOperator op, T &&n) noexcept; 140 | 141 | template 142 | friend Node BinaryOperator(MathOperator op, T1 &&n1, T2 &&n2) noexcept; 143 | 144 | friend class tomsolver::SymMat; 145 | friend class SimplifyFunctions; 146 | friend class DiffFunctions; 147 | friend class SubsFunctions; 148 | friend class ParseFunctions; 149 | }; 150 | 151 | Node CloneRecursively(const Node &rhs) noexcept; 152 | 153 | Node CloneNonRecursively(const Node &rhs) noexcept; 154 | 155 | /** 156 | * Given a node n and another node n1, moves n1 in to become a child of n. 157 | * Usage: CopyOrMoveTo(n.get(), n->left, std::move(n1)); 158 | */ 159 | void CopyOrMoveTo(NodeImpl *parent, Node &child, Node &&n1) noexcept; 160 | 161 | /** 162 | * Given a node n and another node n1, deep-copies n1 and makes the copy a child of n. 163 | * Usage: CopyOrMoveTo(n.get(), n->left, n1); 164 | */ 165 | void CopyOrMoveTo(NodeImpl *parent, Node &child, const Node &n1) noexcept; 166 | 167 | /** 168 | * Overloads std::ostream's << operator to print a Node. 169 | */ 170 | std::ostream &operator<<(std::ostream &out, const Node &n) noexcept; 171 | // Build an operator node; each operand is moved in (rvalue Node) or deep-copied (lvalue Node) via CopyOrMoveTo. 172 | template 173 | Node UnaryOperator(MathOperator op, T &&n) noexcept { 174 | auto ret = std::make_unique(NodeType::OPERATOR, op, 0, ""); 175 | CopyOrMoveTo(ret.get(), ret->left, std::forward(n)); 176 | return ret; 177 | } 178 | 179 | template 180 | Node BinaryOperator(MathOperator op, T1 &&n1, T2 &&n2) noexcept { 181 | auto ret = std::make_unique(NodeType::OPERATOR, op, 0, ""); 182 | CopyOrMoveTo(ret.get(), ret->left, std::forward(n1)); 183 | CopyOrMoveTo(ret.get(), ret->right, std::forward(n2)); 184 | return ret; 185 | } 186 | 187 | /** 188 | * Creates a new operator node. 189 | */ 190 | Node Operator(MathOperator op,
Node left = nullptr, Node right = nullptr) noexcept; 191 | 192 | } // namespace internal 193 | 194 | Node Clone(const Node &rhs) noexcept; 195 | 196 | /** 197 | * Moves a node; equivalent to std::move. 198 | */ 199 | Node Move(Node &rhs) noexcept; 200 | 201 | /** 202 | * Creates a new number node. 203 | */ 204 | Node Num(double num) noexcept; 205 | 206 | /** 207 | * Creates a new function node. 208 | */ 209 | Node Op(MathOperator op); 210 | 211 | /** 212 | * Returns whether the variable name is valid (only English letters, digits and underscores; the first character must be a letter or an underscore). 213 | */ 214 | bool VarNameIsLegal(const std::string &varname) noexcept; 215 | 216 | /** 217 | * Creates a new variable node. 218 | * @exception runtime_error if the name is not valid 219 | */ 220 | Node Var(std::string varname); 221 | // SFINAE helper: the operator overloads below only participate when their arguments match the SfinaeNodeImpl specializations (Node types). 222 | template 223 | struct SfinaeNodeImpl : std::false_type {}; 224 | 225 | template <> 226 | struct SfinaeNodeImpl : std::true_type {}; 227 | 228 | template <> 229 | struct SfinaeNodeImpl : std::true_type {}; 230 | 231 | template 232 | using SfinaeNode = std::enable_if_t...>::value, Node>; 233 | 234 | template 235 | SfinaeNode operator+(T1 &&n1, T2 &&n2) noexcept { 236 | return internal::BinaryOperator(MathOperator::MATH_ADD, std::forward(n1), std::forward(n2)); 237 | } 238 | 239 | template 240 | SfinaeNode &operator+=(Node &n1, T &&n2) noexcept { 241 | n1 = internal::BinaryOperator(MathOperator::MATH_ADD, std::move(n1), std::forward(n2)); 242 | return n1; 243 | } 244 | 245 | template 246 | SfinaeNode operator-(T1 &&n1, T2 &&n2) noexcept { 247 | return internal::BinaryOperator(MathOperator::MATH_SUB, std::forward(n1), std::forward(n2)); 248 | } 249 | 250 | template 251 | SfinaeNode operator-(T &&n1) noexcept { 252 | return internal::UnaryOperator(MathOperator::MATH_NEGATIVE, std::forward(n1)); 253 | } 254 | 255 | template 256 | SfinaeNode operator+(T &&n1) noexcept { 257 | return internal::UnaryOperator(MathOperator::MATH_POSITIVE, std::forward(n1)); 258 | } 259 | 260 | template 261 | SfinaeNode &operator-=(Node &n1, T &&n2) noexcept { 262 | n1 = internal::BinaryOperator(MathOperator::MATH_SUB, std::move(n1),
std::forward(n2)); 263 | return n1; 264 | } 265 | 266 | template 267 | SfinaeNode operator*(T1 &&n1, T2 &&n2) noexcept { 268 | return internal::BinaryOperator(MathOperator::MATH_MULTIPLY, std::forward(n1), std::forward(n2)); 269 | } 270 | 271 | template 272 | SfinaeNode &operator*=(Node &n1, T &&n2) noexcept { 273 | n1 = internal::BinaryOperator(MathOperator::MATH_MULTIPLY, std::move(n1), std::forward(n2)); 274 | return n1; 275 | } 276 | 277 | template 278 | SfinaeNode operator/(T1 &&n1, T2 &&n2) noexcept { 279 | return internal::BinaryOperator(MathOperator::MATH_DIVIDE, std::forward(n1), std::forward(n2)); 280 | } 281 | 282 | template 283 | SfinaeNode &operator/=(Node &n1, T &&n2) noexcept { 284 | n1 = internal::BinaryOperator(MathOperator::MATH_DIVIDE, std::move(n1), std::forward(n2)); 285 | return n1; 286 | } 287 | // NOTE: as a C++ operator, ^ binds more loosely than + - * /, so write a + (b ^ c), not a + b ^ c. 288 | template 289 | SfinaeNode operator^(T1 &&n1, T2 &&n2) noexcept { 290 | return internal::BinaryOperator(MathOperator::MATH_POWER, std::forward(n1), std::forward(n2)); 291 | } 292 | 293 | template 294 | SfinaeNode &operator^=(Node &n1, T &&n2) noexcept { 295 | n1 = internal::BinaryOperator(MathOperator::MATH_POWER, std::move(n1), std::forward(n2)); 296 | return n1; 297 | } 298 | 299 | } // namespace tomsolver 300 | -------------------------------------------------------------------------------- /src/tomsolver/nonlinear.cpp: -------------------------------------------------------------------------------- 1 | #include "nonlinear.h" 2 | 3 | #include "config.h" 4 | #include "error_type.h" 5 | #include "linear.h" 6 | 7 | #include 8 | #include 9 | 10 | using std::cout; 11 | using std::endl; 12 | using std::runtime_error; 13 | 14 | namespace tomsolver { 15 | 16 | double Armijo(const Vec &x, const Vec &d, std::function f, std::function df) { 17 | double alpha = 1; // initial step, a > 0 18 | double gamma = 0.4; // in (0, 0.5); larger is faster 19 | double sigma = 0.5; // in (0, 1); larger is slower 20 | Vec x_new(x); Vec fx = f(x); Mat dfxT = df(x).Transpose(); // f(x) and df(x) do not depend on alpha: evaluate them once instead of on every backtracking step 21 | while (1) { 22 | x_new = x + alpha * d; 23 | 24 | auto l =
f(x_new).Norm2(); 25 | auto r = (fx.AsMat() + gamma * alpha * dfxT * d).Norm2(); 26 | if (l <= r) // Armijo acceptance test 27 | { 28 | break; 29 | } else 30 | alpha = alpha * sigma; // shrink alpha and try again 31 | } 32 | return alpha; 33 | } 34 | 35 | double FindAlpha(const Vec &x, const Vec &d, std::function f, double uncert) { 36 | double alpha_cur = 0; 37 | 38 | double alpha_new = 1; 39 | 40 | int it = 0; 41 | int maxIter = 100; 42 | 43 | Vec g_cur = f(x + alpha_cur * d); 44 | 45 | while (std::abs(alpha_new - alpha_cur) > alpha_cur * uncert) { 46 | double alpha_old = alpha_cur; 47 | alpha_cur = alpha_new; 48 | Vec g_old = g_cur; 49 | g_cur = f(x + alpha_cur * d); 50 | 51 | if (g_cur < g_old) { 52 | break; 53 | } 54 | 55 | // FIXME: nan occurred 56 | alpha_new = EachDivide((g_cur * alpha_old - g_old * alpha_cur), (g_cur - g_old)).NormNegInfinity(); 57 | 58 | // cout << it<<"\t"< maxIter) { 60 | cout << "FindAlpha: over iterator" << endl; 61 | break; 62 | } 63 | } 64 | return alpha_new; 65 | } 66 | 67 | namespace internal { 68 | void PrintSolveStartInfo(const SymVec &equations, const VarsTable &varsTable, const char *method = "Newton Raphson") noexcept { // method: solver name for the log; the default keeps the Newton-Raphson output unchanged 69 | if (Config::Get().logLevel >= LogLevel::INFO) { 70 | cout << "Solve start.\n"; 71 | cout << " Method: " << method << "\n"; 72 | cout << "Equations:\n" + equations.ToString(); 73 | cout << "Inital Values:\n" + varsTable.ToString(); 74 | } 75 | } 76 | void PrintJacobian(const SymMat &jaEqs) noexcept { 77 | if (Config::Get().logLevel >= LogLevel::TRACE) { 78 | cout << "Jacobian:\n" + jaEqs.ToString(); 79 | } 80 | } 81 | void PrintAtIterationStart(int it) noexcept { 82 | if (Config::Get().logLevel >= LogLevel::INFO) { 83 | cout << std::string("==========") + "==========" + "\n"; 84 | cout << "iteration times = " + std::to_string(it) + "\n"; 85 | } 86 | } 87 | } // namespace internal 88 | 89 | VarsTable SolveByNewtonRaphson(const SymVec &equations, const VarsTable &varsTable) { 90 | int it = 0; // iteration counter 91 | VarsTable table = varsTable; 92 |
93 | Vec q = table.Values(); // start from the initial guess: each deltaq is accumulated into q, which is then written back to table 94 | internal::PrintSolveStartInfo(equations, varsTable); 95 | 96 | SymMat JaEqs = Jacobian(equations, table.Vars()); 97 | internal::PrintJacobian(JaEqs); 98 | 99 | while (1) { 100 | internal::PrintAtIterationStart(it); 101 | 102 | Vec phi = equations.Clone().Subs(table).Calc().ToMat().ToVec(); 103 | if (Config::Get().logLevel >= LogLevel::INFO) { 104 | cout << "phi = \n" + phi.ToString(); 105 | } 106 | 107 | if (phi == 0) { 108 | break; 109 | } 110 | 111 | if (it > Config::Get().maxIterations) { 112 | throw runtime_error("迭代次数超出限制"); 113 | } 114 | 115 | Mat ja = JaEqs.Clone().Subs(table).Calc().ToMat(); 116 | 117 | try { 118 | Vec deltaq = SolveLinear(ja, -phi); 119 | if (Config::Get().logLevel >= LogLevel::TRACE) { 120 | cout << "deltaq = " << deltaq << endl; 121 | } 122 | 123 | q += deltaq; 124 | } catch (const tomsolver::MathError &err) { 125 | if (err.GetErrorType() == ErrorType::ERROR_SINGULAR_MATRIX) { 126 | throw MathError(ErrorType::ERROR_SINGULAR_MATRIX, "tip: consider using different initial values"); 127 | } 128 | throw; 129 | } 130 | 131 | if (Config::Get().logLevel >= LogLevel::TRACE) { 132 | cout << "ja = " << ja << endl; 133 | cout << "q = " << q << endl; 134 | } 135 | 136 | table.SetValues(q); 137 | 138 | ++it; 139 | } 140 | return table; 141 | } 142 | 143 | VarsTable SolveByLM(const SymVec &equations, const VarsTable &varsTable) { 144 | int it = 0; // iteration counter 145 | VarsTable table = varsTable; 146 | int n = table.VarNums(); // number of unknowns 147 | Vec q = table.Values(); // current values of the unknowns 148 | internal::PrintSolveStartInfo(equations, varsTable, "Levenberg-Marquardt"); 149 | 150 | SymMat JaEqs = Jacobian(equations, table.Vars()); 151 | internal::PrintJacobian(JaEqs); 152 | 153 | while (1) { 154 | internal::PrintAtIterationStart(it); 155 | 156 | double mu = 1e-5; // LM damping factor λ (reset on every outer iteration) 157 | 158 | Vec F = equations.Clone().Subs(table).Calc().ToMat().ToVec(); // evaluate F at the current q 159 | 160 | if (Config::Get().logLevel >= LogLevel::TRACE) { 161 | cout << "F = " << F <<
endl; 162 | } 163 | 164 | if (F == 0) { // F is zero: a root of the system has been found 165 | break; 166 | } 167 | 168 | Vec FNew(n); // F for the next step 169 | Vec deltaq(n); // Δq 170 | while (1) { 171 | 172 | Mat J = JaEqs.Clone().Subs(table).Calc().ToMat(); // evaluate the Jacobian 173 | 174 | if (Config::Get().logLevel >= LogLevel::TRACE) { 175 | cout << "J = " << J << endl; 176 | } 177 | 178 | // Note: 179 | // Standard LM uses d = -(J'*J + λI)^(-1) * J'F, where J'*J keeps the matrix symmetric positive (semi-)definite. Sometimes d gets too large and convergence is hard. 180 | // Newton's method uses d = -(J + λI)^(-1) * F 181 | Mat eye(J.Cols(), J.Cols()); for (int k = 0; k < J.Cols(); ++k) eye.Value(k, k) = 1; // λI needs the n x n identity (J'*J is n x n, n = J.Cols()); Ones() would give an all-ones matrix instead 182 | // search direction 183 | Vec d = SolveLinear(J.Transpose() * J + mu * eye, 184 | -(J.Transpose() * F).ToVec()); // solve for d 185 | 186 | if (Config::Get().logLevel >= LogLevel::TRACE) { 187 | cout << "d = " << d << endl; 188 | } 189 | 190 | double alpha = Armijo( 191 | q, d, 192 | [&](Vec v) -> Vec { 193 | table.SetValues(v); 194 | return equations.Clone().Subs(table).Calc().ToMat().ToVec(); 195 | }, 196 | [&](Vec v) -> Mat { 197 | table.SetValues(v); 198 | return JaEqs.Clone().Subs(table).Calc().ToMat(); 199 | }); // 1-D line search for alpha 200 | 201 | // double alpha = FindAlpha(q, d, std::bind(SixBarAngPosition, std::placeholders::_1, thetaCDKL, Hhit)); 202 | 203 | // for (size_t i = 0; i < alpha.rows; ++i) 204 | //{ 205 | // if (alpha[i] != alpha[i]) 206 | // alpha[i] = 1.0; 207 | //} 208 | 209 | deltaq = alpha * d; // compute Δq 210 | 211 | Vec qTemp = q + deltaq; 212 | table.SetValues(qTemp); 213 | 214 | FNew = equations.Clone().Subs(table).Calc().ToMat().ToVec(); // evaluate the new F 215 | 216 | if (Config::Get().logLevel >= LogLevel::TRACE) { 217 | cout << "it=" << it << endl; 218 | cout << "\talpha=" << alpha << endl; 219 | cout << "mu=" << mu << endl; 220 | cout << "F.Norm2()=" << F.Norm2() << endl; 221 | cout << "FNew.Norm2()=" << FNew.Norm2() << endl; 222 | cout << "\tF(x k+1).Norm2()\t" << ((FNew.Norm2() < F.Norm2()) ?
"<" : ">=") << "\tF(x k).Norm2()\t" 223 | << endl; 224 | } 225 | 226 | if (FNew.Norm2() < F.Norm2()) // descent achieved: leave the inner loop 227 | { 228 | break; 229 | } else { 230 | mu *= 10.0; // increase λ so the step leans towards gradient descent 231 | } 232 | 233 | if (it++ == Config::Get().maxIterations) { 234 | throw runtime_error("迭代次数超出限制"); 235 | } 236 | } 237 | 238 | q += deltaq; // apply Δq to q 239 | 240 | table.SetValues(q); 241 | 242 | F = FNew; // update F 243 | 244 | if (it++ == Config::Get().maxIterations) { 245 | throw runtime_error("迭代次数超出限制"); 246 | } 247 | 248 | if (Config::Get().logLevel >= LogLevel::TRACE) { 249 | cout << std::string(20, '=') << endl; 250 | } 251 | } 252 | 253 | if (Config::Get().logLevel >= LogLevel::TRACE) { 254 | cout << "success" << endl; 255 | } 256 | 257 | return table; 258 | } 259 | 260 | VarsTable Solve(const SymVec &equations, const VarsTable &varsTable) { 261 | switch (Config::Get().nonlinearMethod) { 262 | case NonlinearMethod::NEWTON_RAPHSON: 263 | return SolveByNewtonRaphson(equations, varsTable); 264 | case NonlinearMethod::LM: 265 | return SolveByLM(equations, varsTable); 266 | } 267 | throw runtime_error("invalid config.NonlinearMethod value: " + 268 | std::to_string(static_cast(Config::Get().nonlinearMethod))); 269 | } 270 | 271 | VarsTable Solve(const SymVec &equations) { 272 | auto varNames = equations.GetAllVarNames(); 273 | std::vector vecVarNames(varNames.begin(), varNames.end()); 274 | VarsTable varsTable(std::move(vecVarNames), Config::Get().initialValue); 275 | return Solve(equations, varsTable); 276 | } 277 | 278 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/nonlinear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mat.h" 4 | #include "symmat.h" 5 | #include "vars_table.h" 6 | 7 | #include 8 | 9 | namespace tomsolver { 10 | 11 | /** 12 | * Armijo backtracking line search for the step size alpha along direction d from x. 13 | */ 14 | double Armijo(const Vec &x, const Vec &d,
std::function f, std::function df); 15 | 16 | /** 17 | * Secant-method line search for the step size alpha. 18 | */ 19 | double FindAlpha(const Vec &x, const Vec &d, std::function f, double uncert = 1.0e-5); 20 | 21 | /** 22 | * Solves the nonlinear system `equations` with the Newton-Raphson method. 23 | * Initial values and variable names are passed in through varsTable. 24 | * @exception runtime_error if the iteration limit is exceeded 25 | */ 26 | VarsTable SolveByNewtonRaphson(const SymVec &equations, const VarsTable &varsTable); 27 | 28 | /** 29 | * Solves the nonlinear system `equations` with the Levenberg-Marquardt method. 30 | * Initial values and variable names are passed in through varsTable. 31 | * @exception runtime_error if the iteration limit is exceeded 32 | */ 33 | VarsTable SolveByLM(const SymVec &equations, const VarsTable &varsTable); 34 | 35 | /** 36 | * Solve a system of nonlinear equations. 37 | * Initial values and variable names are passed through varsTable. Will not use the Config::Get().initialValue 38 | * @param equations: The system of equations. Essentially, it is a symbolic vector. 39 | * @param varsTable: The table of initial values. 40 | * @throws std::runtime_error: If the number of iterations exceeds the limit. tomsolver::MathError may also propagate (e.g. singular linear system, invalid numeric value). 41 | */ 42 | VarsTable Solve(const SymVec &equations, const VarsTable &varsTable); 43 | 44 | /** 45 | * Solve the equations. 46 | * Variable names are obtained by analyzing the equations. Initial values are obtained through Config::Get(). 47 | * @param equations: The system of equations. Essentially, it is a symbolic vector. 48 | * @throws std::runtime_error: If the number of iterations exceeds the limit. tomsolver::MathError may also propagate (e.g. singular linear system, invalid numeric value).
49 | */ 50 | VarsTable Solve(const SymVec &equations); 51 | 52 | } // namespace tomsolver 53 | -------------------------------------------------------------------------------- /src/tomsolver/parse.cpp: -------------------------------------------------------------------------------- 1 | #include "parse.h" 2 | 3 | #include "error_type.h" 4 | #include "math_operator.h" 5 | #include "node.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace tomsolver { 17 | 18 | namespace { 19 | // 64-bit FNV-1a hash; constexpr so that string literals can serve as switch case labels via _fnv1a. 20 | constexpr auto fnv1a(internal::StringView s) { 21 | constexpr uint64_t offsetBasis = 14695981039346656037ul; 22 | constexpr uint64_t prime = 1099511628211ul; 23 | 24 | uint64_t hash = offsetBasis; 25 | 26 | for (auto c : s) { 27 | hash = (hash ^ c) * prime; 28 | } 29 | 30 | return hash; 31 | } 32 | 33 | constexpr auto operator""_fnv1a(const char *s, size_t) { 34 | return fnv1a(s); 35 | } 36 | 37 | /* Returns whether c is one of the basic operator characters ()+-* /^&|% */ 38 | bool IsBasicOperator(char c) noexcept { 39 | switch (c) { 40 | case '(': 41 | case ')': 42 | case '+': 43 | case '-': 44 | case '*': 45 | case '/': 46 | case '^': 47 | case '&': 48 | case '|': 49 | case '%': 50 | return true; 51 | } 52 | return false; 53 | } 54 | 55 | /* Maps a basic operator character to its MathOperator; `unary` selects the sign operators for '+' and '-'. */ 56 | MathOperator BaseOperatorCharToEnum(char c, bool unary) noexcept { 57 | switch (c) { 58 | case '(': 59 | return MathOperator::MATH_LEFT_PARENTHESIS; 60 | case ')': 61 | return MathOperator::MATH_RIGHT_PARENTHESIS; 62 | case '+': 63 | return unary ? MathOperator::MATH_POSITIVE : MathOperator::MATH_ADD; 64 | case '-': 65 | return unary ?
MathOperator::MATH_NEGATIVE : MathOperator::MATH_SUB; 66 | case '*': 67 | return MathOperator::MATH_MULTIPLY; 68 | case '/': 69 | return MathOperator::MATH_DIVIDE; 70 | case '^': 71 | return MathOperator::MATH_POWER; 72 | case '&': 73 | return MathOperator::MATH_AND; 74 | case '|': 75 | return MathOperator::MATH_OR; 76 | case '%': 77 | return MathOperator::MATH_MOD; 78 | default: 79 | assert(0); 80 | } 81 | assert(0); 82 | return MathOperator::MATH_NULL; 83 | } 84 | 85 | MathOperator Str2Function(internal::StringView s) noexcept { // returns MATH_NULL when s is not a known function name 86 | switch (fnv1a(s)) { 87 | case "sin"_fnv1a: 88 | return MathOperator::MATH_SIN; 89 | case "cos"_fnv1a: 90 | return MathOperator::MATH_COS; 91 | case "tan"_fnv1a: 92 | return MathOperator::MATH_TAN; 93 | case "arcsin"_fnv1a: case "asin"_fnv1a: // also accept the asin/acos/atan spellings that MathOperatorToStr emits, so printed expressions parse back 94 | return MathOperator::MATH_ARCSIN; 95 | case "arccos"_fnv1a: case "acos"_fnv1a: 96 | return MathOperator::MATH_ARCCOS; 97 | case "arctan"_fnv1a: case "atan"_fnv1a: 98 | return MathOperator::MATH_ARCTAN; 99 | case "sqrt"_fnv1a: 100 | return MathOperator::MATH_SQRT; 101 | case "log"_fnv1a: 102 | return MathOperator::MATH_LOG; 103 | case "log2"_fnv1a: 104 | return MathOperator::MATH_LOG2; 105 | case "log10"_fnv1a: 106 | return MathOperator::MATH_LOG10; 107 | case "exp"_fnv1a: 108 | return MathOperator::MATH_EXP; 109 | } 110 | return MathOperator::MATH_NULL; 111 | } 112 | 113 | } // namespace 114 | 115 | const char *SingleParseError::what() const noexcept { 116 | return whatStr.c_str(); 117 | } 118 | 119 | int SingleParseError::GetLine() const noexcept { 120 | return line; 121 | } 122 | 123 | int SingleParseError::GetPos() const noexcept { 124 | return pos; 125 | } 126 | 127 | MultiParseError::MultiParseError(const std::vector &parseErrors) : parseErrors(parseErrors) { 128 | std::stringstream ss; 129 | std::transform(parseErrors.rbegin(), parseErrors.rend(), std::ostream_iterator(ss, "\n"), 130 | [](const auto &err) { 131 | return err.what(); 132 | }); 133 | whatStr = ss.str(); 134 | } 135 | 136 | const char *MultiParseError::what() const noexcept {
137 | return whatStr.c_str(); 138 | } 139 | 140 | namespace internal { 141 | 142 | std::deque ParseFunctions::ParseToTokens(StringView content) { 143 | 144 | if (content.empty()) { 145 | throw SingleParseError(0, 0, "empty input", content); 146 | } 147 | 148 | auto iter = content.begin(), nameIter = iter; 149 | std::deque ret; 150 | 151 | auto tryComfirmToken = [&ret, &iter, &nameIter, &content] { 152 | if (size_t size = std::distance(nameIter, iter)) { 153 | auto exp = StringView{&*nameIter, size}; 154 | ret.emplace_back(0, static_cast(std::distance(content.begin(), nameIter)), false, exp, content); 155 | auto &token = ret.back(); 156 | 157 | auto expStr = exp.toString(); 158 | // 检验是否为浮点数 159 | try { 160 | std::size_t sz; 161 | auto d = std::stod(expStr, &sz); 162 | if (sz == expStr.size()) { 163 | token.node = Num(d); 164 | return; 165 | } 166 | } catch (...) {} 167 | 168 | auto op = Str2Function(exp); 169 | if (op != MathOperator::MATH_NULL) { 170 | token.node = Op(op); 171 | return; 172 | } 173 | 174 | // 变量 175 | // 非运算符、数字、函数 176 | if (!VarNameIsLegal(expStr)) // 变量名首字符需为下划线或字母 177 | { 178 | throw SingleParseError(token.line, token.pos, exp, "Invalid variable name: \"", exp, "\""); 179 | } 180 | 181 | token.node = Var(expStr); 182 | } 183 | }; 184 | 185 | while (iter != content.end()) { 186 | if (IsBasicOperator(*iter)) { 187 | tryComfirmToken(); 188 | auto unaryOp = ret.empty() || (ret.back().node->type == NodeType::OPERATOR && 189 | ret.back().node->op != MathOperator::MATH_RIGHT_PARENTHESIS); 190 | ret.emplace_back(0, static_cast(std::distance(content.begin(), iter)), true, StringView{&*iter, 1}, 191 | content); 192 | ret.back().node = Op(BaseOperatorCharToEnum(*iter, unaryOp)); 193 | nameIter = ++iter; 194 | } else if (isspace(*iter)) { 195 | // 忽略tab (\t) whitespaces (\n, \v, \f, \r) space 196 | tryComfirmToken(); 197 | nameIter = ++iter; 198 | } else { 199 | ++iter; 200 | } 201 | } 202 | 203 | tryComfirmToken(); 204 | 205 | return ret; 206 | } 207 | 
208 | std::vector ParseFunctions::InOrderToPostOrder(std::deque &inOrder) { 209 | std::vector postOrder; 210 | int parenthesisBalance = 0; 211 | std::stack tokenStack; 212 | 213 | auto popToken = [&tokenStack] { 214 | auto r = std::move(tokenStack.top()); 215 | tokenStack.pop(); 216 | return r; 217 | }; 218 | 219 | while (!inOrder.empty()) { 220 | auto f = std::move(inOrder.front()); 221 | inOrder.pop_front(); 222 | 223 | switch (f.node->type) { 224 | // 数字直接入栈 225 | case NodeType::NUMBER: 226 | case NodeType::VARIABLE: 227 | postOrder.emplace_back(std::move(f)); 228 | continue; 229 | case NodeType::OPERATOR: 230 | break; 231 | default: 232 | assert(0); 233 | }; 234 | 235 | switch (f.node->op) { 236 | case MathOperator::MATH_LEFT_PARENTHESIS: 237 | parenthesisBalance++; 238 | break; 239 | 240 | case MathOperator::MATH_POSITIVE: 241 | case MathOperator::MATH_NEGATIVE: 242 | break; 243 | 244 | case MathOperator::MATH_RIGHT_PARENTHESIS: 245 | if (parenthesisBalance == 0) { 246 | throw SingleParseError(f.line, f.pos, f.content, "Parenthesis not match: \"", f.s, "\""); 247 | } 248 | for (auto token = popToken(); token.node->op != MathOperator::MATH_LEFT_PARENTHESIS; token = popToken()) { 249 | postOrder.emplace_back(std::move(token)); 250 | } 251 | if (!tokenStack.empty() && IsFunction(tokenStack.top().node->op)) { 252 | postOrder.emplace_back(popToken()); 253 | } 254 | while (!tokenStack.empty() && (tokenStack.top().node->op == MathOperator::MATH_POSITIVE || 255 | tokenStack.top().node->op == MathOperator::MATH_NEGATIVE)) { 256 | postOrder.emplace_back(popToken()); 257 | } 258 | continue; 259 | 260 | default: 261 | // 不是括号也不是正负号 262 | if (!tokenStack.empty()) { 263 | auto compare = 264 | IsLeft2Right(f.node->op) 265 | ? 
std::function{[cmp = std::less_equal<>{}, rank = Rank(f.node->op)]( 266 | const Token 267 | &token) { // 左结合,则挤出高优先级及同优先级符号 268 | return cmp(rank, Rank(token.node->op)); 269 | }} 270 | : std::function{ 271 | [cmp = std::less<>{}, rank = Rank(f.node->op)]( 272 | const Token &token) { // 右结合,则挤出高优先级,但不挤出同优先级符号 273 | return cmp(rank, Rank(token.node->op)); 274 | }}; 275 | 276 | while (!tokenStack.empty() && compare(tokenStack.top())) { 277 | postOrder.push_back(std::move(tokenStack.top())); // 符号进入post队列 278 | tokenStack.pop(); 279 | } 280 | } 281 | break; 282 | } 283 | 284 | tokenStack.push(std::move(f)); // 高优先级已全部挤出,当前符号入栈 285 | } 286 | 287 | // 剩下的元素全部入栈 288 | while (!tokenStack.empty()) { 289 | auto token = popToken(); 290 | 291 | // 退栈时出现左括号,说明没有找到与之匹配的右括号 292 | if (token.node->op == MathOperator::MATH_LEFT_PARENTHESIS) { 293 | throw SingleParseError(token.line, token.pos, token.content, "Parenthesis not match: \"", token.s, "\""); 294 | } 295 | 296 | postOrder.emplace_back(std::move(token)); 297 | } 298 | 299 | return postOrder; 300 | } 301 | 302 | // 将PostOrder建立为树,并进行表达式有效性检验(确保二元及一元运算符、函数均有操作数) 303 | Node ParseFunctions::BuildExpressionTree(std::vector &postOrder) { 304 | std::stack tokenStack; 305 | auto pushToken = [&tokenStack](Token &token) { 306 | tokenStack.emplace(std::move(token)); 307 | }; 308 | auto popNode = [&tokenStack] { 309 | auto node = std::move(tokenStack.top().node); 310 | tokenStack.pop(); 311 | return node; 312 | }; 313 | // 逐个识别PostOrder序列,构建表达式树 314 | for (auto &token : postOrder) { 315 | switch (token.node->type) { 316 | case NodeType::OPERATOR: 317 | if (GetOperatorNum(token.node->op) == 2) { 318 | if (tokenStack.empty()) { 319 | throw MathError{ErrorType::ERROR_WRONG_EXPRESSION}; 320 | } 321 | 322 | tokenStack.top().node->parent = token.node.get(); 323 | token.node->right = popNode(); 324 | 325 | if (tokenStack.empty()) { 326 | throw MathError{ErrorType::ERROR_WRONG_EXPRESSION}; 327 | } 328 | 329 | tokenStack.top().node->parent = 
token.node.get(); 330 | token.node->left = popNode(); 331 | 332 | pushToken(token); 333 | continue; 334 | } 335 | 336 | // 一元运算符 337 | assert(GetOperatorNum(token.node->op) == 1); 338 | 339 | if (tokenStack.empty()) { 340 | throw MathError{ErrorType::ERROR_WRONG_EXPRESSION}; 341 | } 342 | 343 | tokenStack.top().node->parent = token.node.get(); 344 | token.node->left = popNode(); 345 | 346 | break; 347 | 348 | case NodeType::NUMBER: 349 | case NodeType::VARIABLE: 350 | default: 351 | break; 352 | } 353 | 354 | pushToken(token); 355 | } 356 | 357 | // 如果现在临时栈里面有超过1个元素,那么除了栈顶,其他的都代表出错 358 | if (tokenStack.size() > 1) { 359 | // 扔掉最顶上的,构造到一半的表达式 360 | tokenStack.pop(); 361 | 362 | std::vector errors; 363 | while (!tokenStack.empty()) { 364 | Token &token = tokenStack.top(); 365 | errors.emplace_back(token.line, token.pos, token.content, "Parse Error at: \"", token.s, "\""); 366 | tokenStack.pop(); 367 | } 368 | throw MultiParseError(errors); 369 | } 370 | 371 | return popNode(); 372 | } 373 | 374 | } // namespace internal 375 | 376 | Node Parse(internal::StringView expression) { 377 | auto tokens = internal::ParseFunctions::ParseToTokens(expression); 378 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 379 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 380 | return node; 381 | } 382 | 383 | Node operator""_f(const char *exp, size_t) { 384 | return Parse(exp); 385 | } 386 | 387 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/parse.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "node.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace tomsolver { 11 | 12 | namespace internal { 13 | 14 | class StringView { 15 | public: 16 | constexpr StringView() noexcept = default; 17 | constexpr StringView(const char *str, size_t len) noexcept : str{str}, len{len} {} 18 
| constexpr StringView(const char *str) noexcept : str{str} { 19 | while (*str++) { 20 | len++; 21 | } 22 | } 23 | StringView(const std::string &str) noexcept : StringView{str.data(), str.size()} {} 24 | constexpr StringView(const StringView &) noexcept = default; 25 | constexpr StringView &operator=(const StringView &) noexcept = default; 26 | 27 | constexpr auto begin() const noexcept { 28 | return str; 29 | } 30 | constexpr auto end() const noexcept { 31 | return str + len; 32 | } 33 | constexpr auto empty() const noexcept { 34 | return !len; 35 | } 36 | 37 | auto toString() const noexcept { 38 | return std::string{begin(), end()}; 39 | } 40 | 41 | template 42 | friend Stream &operator<<(Stream &s, const internal::StringView &sv) { 43 | s.rdbuf()->sputn(sv.str, sv.len); 44 | return s; 45 | } 46 | 47 | private: 48 | const char *str = nullptr; 49 | size_t len = 0; 50 | }; 51 | 52 | template 53 | void append(Stream &) {} 54 | 55 | template 56 | void append(Stream &s, T &&arg, Ts &&...args) { 57 | s << std::forward(arg); 58 | append(s, std::forward(args)...); 59 | } 60 | 61 | struct Token; 62 | } // namespace internal 63 | 64 | class ParseError : public std::runtime_error { 65 | public: 66 | protected: 67 | ParseError() : std::runtime_error("parse error") {} 68 | }; 69 | 70 | class SingleParseError : public ParseError { 71 | public: 72 | template 73 | SingleParseError(int line, int pos, internal::StringView content, T &&...errInfo) 74 | : line(line), pos(pos), content(content) { 75 | std::stringstream ss; 76 | 77 | ss << "[Parse Error] "; 78 | internal::append(ss, std::forward(errInfo)...); 79 | ss << " at(" << line << ", " << pos << "):\n" 80 | << content << "\n" 81 | << std::string(pos, ' ') << "^---- error position"; 82 | 83 | whatStr = ss.str(); 84 | } 85 | 86 | virtual const char *what() const noexcept override; 87 | 88 | int GetLine() const noexcept; 89 | 90 | int GetPos() const noexcept; 91 | 92 | private: 93 | int line; // the line index 94 | int pos; // the 
position of character 95 | internal::StringView content; // the whole content of the line 96 | std::string whatStr; // the whole error message 97 | }; 98 | 99 | class MultiParseError : public ParseError { 100 | public: 101 | MultiParseError(const std::vector &parseErrors); 102 | 103 | virtual const char *what() const noexcept override; 104 | 105 | private: 106 | std::vector parseErrors; 107 | std::string whatStr; // the whole error message 108 | }; 109 | 110 | namespace internal { 111 | 112 | struct Token { 113 | Node node; // node 114 | internal::StringView s; // the string of this token 115 | int line; // the line index 116 | int pos; // the position of character 117 | bool isBaseOperator; // if is base operator (single-character operator or parenthesis) 118 | internal::StringView content; // the whole content of the line 119 | Token(int line, int pos, bool isBaseOperator, StringView s, StringView content) 120 | : s(s), line(line), pos(pos), isBaseOperator(isBaseOperator), content(content) {} 121 | }; 122 | 123 | class ParseFunctions { 124 | public: 125 | /** 126 | * 解析表达式字符串为in order记号流。其实就是做词法分析。 127 | * @exception ParseError 128 | */ 129 | static std::deque ParseToTokens(StringView expression); 130 | 131 | /** 132 | * 由in order序列得到post order序列。实质上是把记号流转化为逆波兰表达式。 133 | * @exception ParseError 134 | */ 135 | static std::vector InOrderToPostOrder(std::deque &inOrder); 136 | 137 | /** 138 | * 将PostOrder建立为树,并进行表达式有效性检验(确保二元及一元运算符、函数均有操作数)。 139 | * @exception ParseError 140 | */ 141 | static Node BuildExpressionTree(std::vector &postOrder); 142 | }; 143 | 144 | } // namespace internal 145 | 146 | /** 147 | * 把字符串解析为表达式。 148 | * @exception ParseError 149 | */ 150 | Node Parse(internal::StringView expression); 151 | 152 | Node operator""_f(const char *exp, size_t); 153 | 154 | } // namespace tomsolver -------------------------------------------------------------------------------- /src/tomsolver/simplify.cpp: 
-------------------------------------------------------------------------------- 1 | #include "simplify.h" 2 | #include "math_operator.h" 3 | #include "node.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace tomsolver { 11 | 12 | namespace internal { 13 | 14 | class SimplifyFunctions { 15 | public: 16 | struct SimplifyNode { 17 | NodeImpl &node; 18 | bool isLeftChild; 19 | 20 | SimplifyNode(NodeImpl &node) : node(node), isLeftChild(!node.parent || node.parent->left.get() == &node) {} 21 | }; 22 | 23 | // 对单节点n进行化简。 24 | static void SimplifySingleNode(std::unique_ptr &n) noexcept { 25 | auto parent = n->parent; 26 | switch (GetOperatorNum(n->op)) { 27 | // 对于1元运算符,且儿子是数字的,直接计算出来 28 | case 1: 29 | if (n->left->type == NodeType::NUMBER) { 30 | n->type = NodeType::NUMBER; 31 | n->value = tomsolver::Calc(n->op, n->left->value, 0); 32 | n->op = MathOperator::MATH_NULL; 33 | n->left = nullptr; 34 | } 35 | break; 36 | 37 | // 对于2元运算符 38 | case 2: 39 | // 儿子是数字的,直接计算出来 40 | if (n->left->type == NodeType::NUMBER && n->right->type == NodeType::NUMBER) { 41 | n->type = NodeType::NUMBER; 42 | n->value = tomsolver::Calc(n->op, n->left->value, n->right->value); 43 | n->op = MathOperator::MATH_NULL; 44 | n->left = nullptr; 45 | n->right = nullptr; 46 | return; 47 | } 48 | 49 | assert(n->left && n->right); 50 | bool lChildIs0 = n->left->type == NodeType::NUMBER && n->left->value == 0.0; 51 | bool rChildIs0 = n->right->type == NodeType::NUMBER && n->right->value == 0.0; 52 | bool lChildIs1 = n->left->type == NodeType::NUMBER && n->left->value == 1.0; 53 | bool rChildIs1 = n->right->type == NodeType::NUMBER && n->right->value == 1.0; 54 | 55 | // 任何数乘或被乘0、被0除、0的除0外的任何次方,等于0 56 | if ((n->op == MathOperator::MATH_MULTIPLY && (lChildIs0 || rChildIs0)) || 57 | (n->op == MathOperator::MATH_DIVIDE && lChildIs0) || (n->op == MathOperator::MATH_POWER && lChildIs0)) { 58 | n = Num(0); 59 | n->parent = parent; 60 | return; 61 | } 62 | 63 | // 
任何数加或被加0、被减0、乘或被乘1、被1除、开1次方,等于自身 64 | if ((n->op == MathOperator::MATH_ADD && (lChildIs0 || rChildIs0)) || 65 | (n->op == MathOperator::MATH_SUB && rChildIs0) || 66 | (n->op == MathOperator::MATH_MULTIPLY && (lChildIs1 || rChildIs1)) || 67 | (n->op == MathOperator::MATH_DIVIDE && rChildIs1) || (n->op == MathOperator::MATH_POWER && rChildIs1)) { 68 | 69 | if (lChildIs1 || lChildIs0) { 70 | n->left = nullptr; 71 | n = Move(n->right); 72 | n->parent = parent; 73 | } else if (rChildIs1 || rChildIs0) { 74 | n->right = nullptr; 75 | n = Move(n->left); 76 | n->parent = parent; 77 | } 78 | return; 79 | } 80 | } 81 | } 82 | 83 | // 后序遍历。非递归实现。 84 | static void SimplifyWholeNode(Node &node) { 85 | 86 | // 借助一个栈,得到反向的后序遍历序列,结果保存在revertedPostOrder。除了root节点,root节点不保存在revertedPostOrder里,最后单独化简。 87 | std::stack stk; 88 | std::deque revertedPostOrder; 89 | 90 | auto popNode = [&stk] { 91 | auto node = std::move(stk.top()); 92 | stk.pop(); 93 | return node; 94 | }; 95 | 96 | // ==== Part I ==== 97 | 98 | if (node->type != NodeType::OPERATOR) { 99 | return; 100 | } 101 | 102 | stk.push(SimplifyNode(*node.get())); 103 | 104 | while (!stk.empty()) { 105 | auto f = popNode(); 106 | 107 | if (f.node.left && f.node.left->type == NodeType::OPERATOR) { 108 | stk.push(SimplifyNode(*f.node.left.get())); 109 | } 110 | 111 | if (f.node.right && f.node.right->type == NodeType::OPERATOR) { 112 | stk.push(SimplifyNode(*f.node.right.get())); 113 | } 114 | 115 | revertedPostOrder.emplace_back(std::move(f)); 116 | } 117 | 118 | // pop掉root,root最后单独处理 119 | revertedPostOrder.pop_front(); 120 | 121 | // ==== Part II ==== 122 | std::for_each(revertedPostOrder.rbegin(), revertedPostOrder.rend(), [](SimplifyNode &snode) { 123 | SimplifySingleNode(snode.isLeftChild ? 
snode.node.parent->left : snode.node.parent->right); 124 | }); 125 | 126 | SimplifyFunctions::SimplifySingleNode(node); 127 | 128 | return; 129 | 130 | // if (GetOperateNum(now->eOperator) == 2) { 131 | 132 | // //任何节点的0次方均等于1,除了0的0次方已在前面报错 133 | // if (now->eOperator == MATH_POWER && RChildIs0) { 134 | // //替换掉当前运算符,这个1节点将在回溯时处理 135 | // //新建一个1节点 136 | // TNode *temp; 137 | // temp = new TNode; 138 | // temp->eType = NODE_NUMBER; 139 | // temp->value = 1; 140 | 141 | // // 0节点连接到上面 142 | // if (now != head) { 143 | // if (now->parent->left == now) { 144 | // now->parent->left = temp; 145 | // temp->parent = now->parent; 146 | // } 147 | // if (now->parent->right == now) { 148 | // now->parent->right = temp; 149 | // temp->parent = now->parent; 150 | // } 151 | // } else 152 | // head = temp; 153 | 154 | // DeleteNode(now); 155 | // } 156 | 157 | // // 0-x=-x 158 | // if (now->eOperator == MATH_SUB && LChildIs0) { 159 | // TNode *LChild = now->left; 160 | // TNode *RChild = now->right; 161 | // now->eOperator = MATH_NEGATIVE; 162 | // now->left = RChild; 163 | // now->right = NULL; 164 | 165 | // delete LChild; 166 | // } 167 | 168 | //} 169 | } 170 | }; 171 | 172 | } // namespace internal 173 | 174 | void Simplify(Node &node) noexcept { 175 | internal::SimplifyFunctions::SimplifyWholeNode(node); 176 | return; 177 | } 178 | 179 | } // namespace tomsolver 180 | -------------------------------------------------------------------------------- /src/tomsolver/simplify.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "node.h" 3 | 4 | namespace tomsolver { 5 | 6 | void Simplify(Node &node) noexcept; 7 | 8 | } // namespace tomsolver 9 | -------------------------------------------------------------------------------- /src/tomsolver/subs.cpp: -------------------------------------------------------------------------------- 1 | #include "subs.h" 2 | #include "node.h" 3 | 4 | #include 5 | #include 6 | #include 7 
| 8 | namespace tomsolver { 9 | 10 | namespace internal { 11 | 12 | class SubsFunctions { 13 | public: 14 | // 前序遍历。非递归实现。 15 | static Node SubsInner(Node node, const std::map &dict) noexcept { 16 | 17 | std::stack> stk; 18 | 19 | auto Replace = [&dict](Node &cur) { 20 | if (cur->type != NodeType::VARIABLE) { 21 | return false; 22 | } 23 | 24 | auto itor = dict.find(cur->varname); 25 | if (itor == dict.end()) { 26 | return false; 27 | } 28 | 29 | auto parent = cur->parent; 30 | cur = Clone(itor->second); 31 | cur->parent = parent; 32 | 33 | return true; 34 | }; 35 | 36 | if (!Replace(node)) { 37 | auto TryReplace = [&stk, &Replace](Node &cur) { 38 | if (cur && !Replace(cur)) { 39 | stk.emplace(*cur); 40 | } 41 | }; 42 | 43 | TryReplace(node->right); 44 | TryReplace(node->left); 45 | 46 | while (!stk.empty()) { 47 | auto &f = stk.top().get(); 48 | stk.pop(); 49 | TryReplace(f.right); 50 | TryReplace(f.left); 51 | } 52 | } 53 | 54 | #ifndef NDEBUG 55 | node->CheckParent(); 56 | #endif 57 | return node; 58 | } 59 | }; 60 | 61 | } // namespace internal 62 | 63 | Node Subs(const Node &node, const std::string &oldVar, const Node &newNode) noexcept { 64 | return Subs(Clone(node), oldVar, newNode); 65 | } 66 | 67 | Node Subs(Node &&node, const std::string &oldVar, const Node &newNode) noexcept { 68 | std::map dict; 69 | dict.insert({oldVar, Clone(newNode)}); 70 | return internal::SubsFunctions::SubsInner(Move(node), dict); 71 | } 72 | 73 | Node Subs(const Node &node, const std::vector &oldVars, const SymVec &newNodes) noexcept { 74 | return Subs(Clone(node), oldVars, newNodes); 75 | } 76 | 77 | Node Subs(Node &&node, const std::vector &oldVars, const SymVec &newNodes) noexcept { 78 | assert(static_cast(oldVars.size()) == newNodes.Rows()); 79 | std::map dict; 80 | for (size_t i = 0; i < oldVars.size(); ++i) { 81 | dict.insert({oldVars[i], Clone(newNodes[i])}); 82 | } 83 | return internal::SubsFunctions::SubsInner(Move(node), dict); 84 | } 85 | 86 | Node Subs(const Node 
&node, const std::map &dict) noexcept { 87 | return Subs(Clone(node), dict); 88 | } 89 | 90 | Node Subs(Node &&node, const std::map &dict) noexcept { 91 | return internal::SubsFunctions::SubsInner(Move(node), dict); 92 | } 93 | 94 | Node Subs(const Node &node, const std::map &varValues) noexcept { 95 | return Subs(Clone(node), varValues); 96 | } 97 | 98 | Node Subs(Node &&node, const std::map &varValues) noexcept { 99 | std::map dict; 100 | for (auto &item : varValues) { 101 | dict.insert({item.first, Num(item.second)}); 102 | } 103 | return internal::SubsFunctions::SubsInner(Move(node), dict); 104 | } 105 | 106 | Node Subs(const Node &node, const VarsTable &varsTable) noexcept { 107 | return Subs(Clone(node), varsTable); 108 | } 109 | 110 | Node Subs(Node &&node, const VarsTable &varsTable) noexcept { 111 | std::map dict; 112 | for (auto &item : varsTable) { 113 | dict.insert({item.first, Num(item.second)}); 114 | } 115 | return internal::SubsFunctions::SubsInner(Move(node), dict); 116 | } 117 | 118 | } // namespace tomsolver 119 | -------------------------------------------------------------------------------- /src/tomsolver/subs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "node.h" 4 | #include "symmat.h" 5 | #include "vars_table.h" 6 | 7 | namespace tomsolver { 8 | 9 | /** 10 | * 用newNode节点替换oldVar指定的变量。 11 | */ 12 | Node Subs(const Node &node, const std::string &oldVar, const Node &newNode) noexcept; 13 | 14 | /** 15 | * 用newNode节点替换oldVar指定的变量。 16 | */ 17 | Node Subs(Node &&node, const std::string &oldVar, const Node &newNode) noexcept; 18 | 19 | /** 20 | * 用newNodes节点替换oldVars指定的变量。 21 | */ 22 | Node Subs(const Node &node, const std::vector &oldVars, const SymVec &newNodes) noexcept; 23 | 24 | /** 25 | * 用newNodes节点替换oldVars指定的变量。 26 | */ 27 | Node Subs(Node &&node, const std::vector &oldVars, const SymVec &newNodes) noexcept; 28 | 29 | /** 30 | * 用newNodes节点替换oldVars指定的变量。 31 | */ 32 | 
Node Subs(const Node &node, const std::map &dict) noexcept; 33 | 34 | /** 35 | * 用newNodes节点替换oldVars指定的变量。 36 | */ 37 | Node Subs(Node &&node, const std::map &dict) noexcept; 38 | 39 | /** 40 | * 用newNodes节点替换oldVars指定的变量。 41 | */ 42 | Node Subs(const Node &node, const std::map &varValues) noexcept; 43 | 44 | /** 45 | * 用newNodes节点替换oldVars指定的变量。 46 | */ 47 | Node Subs(Node &&node, const std::map &varValues) noexcept; 48 | 49 | /** 50 | * 用newNodes节点替换oldVars指定的变量。 51 | */ 52 | Node Subs(const Node &node, const VarsTable &varsTable) noexcept; 53 | 54 | /** 55 | * 用newNodes节点替换oldVars指定的变量。 56 | */ 57 | Node Subs(Node &&node, const VarsTable &varsTable) noexcept; 58 | 59 | } // namespace tomsolver 60 | -------------------------------------------------------------------------------- /src/tomsolver/symmat.cpp: -------------------------------------------------------------------------------- 1 | #include "symmat.h" 2 | 3 | #include "diff.h" 4 | #include "error_type.h" 5 | #include "mat.h" 6 | #include "node.h" 7 | #include "subs.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace tomsolver { 15 | 16 | using DataType = std::valarray; 17 | 18 | SymMat::SymMat(int rows, int cols) noexcept : rows(rows), cols(cols) { 19 | assert(rows > 0 && cols > 0); 20 | data.reset(new DataType(rows * cols)); 21 | } 22 | 23 | SymMat::SymMat(std::initializer_list> init) noexcept { 24 | rows = static_cast(init.size()); 25 | cols = static_cast(std::max(init, [](auto lhs, auto rhs) { 26 | return lhs.size() < rhs.size(); 27 | }).size()); 28 | data.reset(new DataType(rows * cols)); 29 | 30 | auto i = 0; 31 | for (auto val : init) { 32 | auto j = 0; 33 | for (auto &node : val) { 34 | (*data)[i * cols + j++] = std::move(const_cast(node)); 35 | } 36 | i++; 37 | } 38 | } 39 | 40 | SymMat::SymMat(const Mat &rhs) noexcept : SymMat(rhs.Rows(), rhs.Cols()) { 41 | std::generate(std::begin(*data), std::end(*data), [p = std::addressof(rhs.Value(0, 0))]() mutable { 42 
| return Num(*p++); 43 | }); 44 | } 45 | 46 | SymMat SymMat::Clone() const noexcept { 47 | SymMat ret(Rows(), Cols()); 48 | std::generate(std::begin(*ret.data), std::end(*ret.data), [iter = std::begin(*data)]() mutable { 49 | return tomsolver::Clone(*iter++); 50 | }); 51 | return ret; 52 | } 53 | 54 | bool SymMat::Empty() const noexcept { 55 | return data->size() == 0; 56 | } 57 | 58 | int SymMat::Rows() const noexcept { 59 | return rows; 60 | } 61 | 62 | int SymMat::Cols() const noexcept { 63 | return cols; 64 | } 65 | 66 | SymVec SymMat::ToSymVec() const { 67 | assert(rows > 0); 68 | if (cols != 1) { 69 | throw std::runtime_error("SymMat::ToSymVec fail. rows is not one"); 70 | } 71 | return ToSymVecOneByOne(); 72 | } 73 | 74 | SymVec SymMat::ToSymVecOneByOne() const noexcept { 75 | SymVec v(rows * cols); 76 | std::generate(std::begin(*v.data), std::end(*v.data), [iter = std::begin(*data)]() mutable { 77 | return tomsolver::Clone(*iter++); 78 | }); 79 | return v; 80 | } 81 | 82 | Mat SymMat::ToMat() const { 83 | std::valarray newData(data->size()); 84 | std::generate(std::begin(newData), std::end(newData), [iter = std::begin(*data)]() mutable { 85 | if ((**iter).type != NodeType::NUMBER) { 86 | throw std::runtime_error("ToMat error: node is not number"); 87 | } 88 | return (**iter++).value; 89 | }); 90 | return {rows, cols, newData}; 91 | } 92 | 93 | SymMat &SymMat::Calc() { 94 | for (auto &node : *data) { 95 | node->Calc(); 96 | } 97 | return *this; 98 | } 99 | 100 | SymMat &SymMat::Subs(const std::map &varValues) noexcept { 101 | for (auto &node : *data) { 102 | node = tomsolver::Subs(std::move(node), varValues); 103 | } 104 | return *this; 105 | } 106 | 107 | SymMat &SymMat::Subs(const VarsTable &varsTable) noexcept { 108 | for (auto &node : *data) { 109 | node = tomsolver::Subs(std::move(node), varsTable); 110 | } 111 | return *this; 112 | } 113 | 114 | std::set SymMat::GetAllVarNames() const noexcept { 115 | std::set ret; 116 | for (auto &node : *data) { 117 
| auto names = node->GetAllVarNames(); 118 | ret.insert(names.begin(), names.end()); 119 | } 120 | return ret; 121 | } 122 | 123 | SymMat SymMat::operator-(const SymMat &rhs) const noexcept { 124 | assert(rhs.rows == rows && rhs.cols == cols); 125 | SymMat ret(rows, cols); 126 | std::generate(std::begin(*ret.data), std::end(*ret.data), 127 | [lhsIter = std::begin(*data), rhsIter = std::begin(*rhs.data)]() mutable { 128 | return *lhsIter++ - *rhsIter++; 129 | }); 130 | return ret; 131 | } 132 | 133 | SymMat SymMat::operator*(const SymMat &rhs) const { 134 | if (cols != rhs.rows) { 135 | throw MathError(ErrorType::SIZE_NOT_MATCH); 136 | } 137 | 138 | SymMat ans(rows, rhs.cols); 139 | for (int i = 0; i < Rows(); ++i) { 140 | for (int j = 0; j < rhs.cols; ++j) { 141 | auto sum = Value(i, 0) * rhs.Value(0, j); 142 | for (int k = 1; k < cols; ++k) { 143 | sum += Value(i, k) * rhs.Value(k, j); 144 | } 145 | ans.Value(i, j) = Move(sum); 146 | } 147 | } 148 | return ans; 149 | } 150 | 151 | bool SymMat::operator==(const SymMat &rhs) const noexcept { 152 | if (rhs.rows != rows || rhs.cols != cols) { 153 | return false; 154 | } 155 | return std::equal(std::begin(*data), std::end(*data), std::begin(*rhs.data), [](auto &node1, auto &node2) { 156 | return node1->Equal(node2); 157 | }); 158 | } 159 | 160 | Node &SymMat::Value(int i, int j) noexcept { 161 | return (*data)[i * cols + j]; 162 | } 163 | 164 | const Node &SymMat::Value(int i, int j) const noexcept { 165 | return (*data)[i * cols + j]; 166 | } 167 | 168 | std::string SymMat::ToString() const noexcept { 169 | if (data->size() == 0) { 170 | return "[]\n"; 171 | } 172 | 173 | std::stringstream ss; 174 | ss << "["; 175 | 176 | size_t i = 0; 177 | for (auto &node : *data) { 178 | ss << (i == 0 ? "" : " ") << node->ToString(); 179 | i++; 180 | ss << (i % cols == 0 ? (i == data->size() ? 
"]\n" : "\n") : ", "); 181 | } 182 | 183 | return ss.str(); 184 | } 185 | 186 | SymVec::SymVec(int rows) noexcept : SymMat(rows, 1) {} 187 | 188 | SymVec::SymVec(std::initializer_list init) noexcept : SymMat(static_cast(init.size()), 1) { 189 | auto i = 0; 190 | for (auto &node : init) { 191 | (*data)[i++] = std::move(const_cast(node)); 192 | } 193 | } 194 | 195 | SymVec SymVec::operator-(const SymVec &rhs) const noexcept { 196 | return SymMat::operator-(rhs).ToSymVec(); 197 | } 198 | 199 | Node &SymVec::operator[](std::size_t index) noexcept { 200 | return (*data)[index]; 201 | } 202 | 203 | const Node &SymVec::operator[](std::size_t index) const noexcept { 204 | return (*data)[index]; 205 | } 206 | 207 | SymMat Jacobian(const SymMat &equations, const std::vector &vars) noexcept { 208 | int rows = equations.rows; 209 | int cols = static_cast(vars.size()); 210 | SymMat ja(rows, cols); 211 | std::generate(std::begin(*ja.data), std::end(*ja.data), 212 | [iter = std::begin(*equations.data), &vars, i = size_t{0}]() mutable { 213 | if (i == vars.size()) { 214 | i = 0; 215 | iter++; 216 | } 217 | return Diff(*iter, vars[i++]); 218 | }); 219 | return ja; 220 | } 221 | 222 | std::ostream &operator<<(std::ostream &out, const SymMat &symMat) noexcept { 223 | return out << symMat.ToString(); 224 | } 225 | 226 | } // namespace tomsolver 227 | -------------------------------------------------------------------------------- /src/tomsolver/symmat.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mat.h" 4 | #include "node.h" 5 | #include "vars_table.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace tomsolver { 13 | 14 | class SymVec; 15 | class SymMat { 16 | public: 17 | /** 18 | * 19 | */ 20 | SymMat(int rows, int cols) noexcept; 21 | 22 | /** 23 | * 使用初始化列表构造。注意列表内的对象将被强行移动至Vec内部。 24 | */ 25 | SymMat(std::initializer_list> init) noexcept; 26 | 27 | /** 28 | * 从数值矩阵构造符号矩阵 29 | */ 30 
| SymMat(const Mat &rhs) noexcept; 31 | 32 | SymMat Clone() const noexcept; 33 | 34 | bool Empty() const noexcept; 35 | 36 | int Rows() const noexcept; 37 | 38 | int Cols() const noexcept; 39 | 40 | /** 41 | * 输出Vec。如果列数不为1,抛出异常。 42 | * @exception runtime_error 列数不为1 43 | */ 44 | SymVec ToSymVec() const; 45 | 46 | /** 47 | * 逐个元素转换为符号向量(列向量)。 48 | */ 49 | SymVec ToSymVecOneByOne() const noexcept; 50 | 51 | /** 52 | * 得到数值矩阵。前提条件是矩阵内的元素均为数值节点,否则抛出异常。 53 | * @exception runtime_error 存在非数值节点 54 | */ 55 | Mat ToMat() const; 56 | 57 | /** 58 | * 把矩阵的内的元素均计算为数值节点。 59 | * @exception runtime_error 如果有变量存在,则无法计算 60 | * @exception MathError 出现浮点数无效值(inf, -inf, nan) 61 | */ 62 | SymMat &Calc(); 63 | 64 | SymMat &Subs(const std::map &varValues) noexcept; 65 | 66 | SymMat &Subs(const VarsTable &varsTable) noexcept; 67 | 68 | /** 69 | * 返回符号矩阵内出现的所有变量名。 70 | */ 71 | std::set GetAllVarNames() const noexcept; 72 | 73 | /** 74 | * 如果rhs和自己的维数不匹配会触发assert。 75 | */ 76 | SymMat operator-(const SymMat &rhs) const noexcept; 77 | 78 | /** 79 | * 80 | * @exception MathError 维数不匹配 81 | */ 82 | SymMat operator*(const SymMat &rhs) const; 83 | 84 | /** 85 | * 返回是否相等。 86 | * 目前只能判断表达式树完全一致的情况。 87 | * TODO 改为可以判断等价表达式 88 | */ 89 | bool operator==(const SymMat &rhs) const noexcept; 90 | 91 | Node& Value(int i, int j) noexcept; 92 | const Node& Value(int i, int j) const noexcept; 93 | 94 | std::string ToString() const noexcept; 95 | 96 | protected: 97 | int rows, cols; 98 | std::unique_ptr> data; 99 | 100 | friend SymMat Jacobian(const SymMat &equations, const std::vector &vars) noexcept; 101 | }; 102 | 103 | class SymVec : public SymMat { 104 | public: 105 | /** 106 | * 107 | */ 108 | SymVec(int rows) noexcept; 109 | 110 | /** 111 | * 使用初始化列表构造。注意列表内的对象将被强行移动至Vec内部。 112 | */ 113 | SymVec(std::initializer_list init) noexcept; 114 | 115 | /** 116 | * 如果rhs和自己的维数不匹配会触发assert。 117 | */ 118 | SymVec operator-(const SymVec &rhs) const noexcept; 119 | 120 | Node &operator[](std::size_t index) noexcept; 
121 | 122 | const Node &operator[](std::size_t index) const noexcept; 123 | }; 124 | 125 | SymMat Jacobian(const SymMat &equations, const std::vector &vars) noexcept; 126 | 127 | std::ostream &operator<<(std::ostream &out, const SymMat &symMat) noexcept; 128 | 129 | } // namespace tomsolver 130 | -------------------------------------------------------------------------------- /src/tomsolver/tomsolver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config.h" 4 | #include "node.h" // error_type.h math_operator.h 5 | #include "functions.h" 6 | #include "simplify.h" 7 | #include "diff.h" 8 | #include "subs.h" // symmat.h vars_table.h 9 | #include "symmat.h" // mat.h vars_table.h 10 | #include "parse.h" 11 | #include "linear.h" 12 | #include "nonlinear.h" -------------------------------------------------------------------------------- /src/tomsolver/vars_table.cpp: -------------------------------------------------------------------------------- 1 | #include "vars_table.h" 2 | 3 | #include "config.h" 4 | 5 | #include 6 | #include 7 | 8 | namespace tomsolver { 9 | 10 | VarsTable::VarsTable(const std::vector &vars, double initValue) 11 | : vars(vars), values(static_cast(vars.size()), initValue) { 12 | for (auto &var : vars) { 13 | table.insert({var, initValue}); 14 | } 15 | assert(vars.size() == table.size() && "vars is not unique"); 16 | } 17 | 18 | VarsTable::VarsTable(std::initializer_list> initList) 19 | : VarsTable({initList.begin(), initList.end()}) { 20 | assert(vars.size() == table.size() && "vars is not unique"); 21 | } 22 | 23 | VarsTable::VarsTable(const std::map &table) noexcept 24 | : vars(table.size()), values(static_cast(table.size())), table(table) { 25 | int i = 0; 26 | for (auto &item : table) { 27 | vars[i] = item.first; 28 | values[i] = item.second; 29 | ++i; 30 | } 31 | } 32 | 33 | int VarsTable::VarNums() const noexcept { 34 | return static_cast(table.size()); 35 | } 36 | 37 | 
const std::vector &VarsTable::Vars() const noexcept {
38 |     return vars; // names, index-aligned with Values(); SetValues() pairs vars[i] with v[i]
39 | }
40 | // Returns the numeric values, index-aligned with Vars().
41 | const Vec &VarsTable::Values() const noexcept {
42 |     return values;
43 | }
44 | // Replaces every value; v must have one row per variable (only checked by assert, i.e. not in NDEBUG builds).
45 | void VarsTable::SetValues(const Vec &v) noexcept {
46 |     assert(v.Rows() == values.Rows());
47 |     values = v;
48 |     for (int i = 0; i < values.Rows(); ++i) {
49 |         table[vars[i]] = v[i]; // keep the name -> value map in sync with the vector
50 |     }
51 | }
52 | // Returns true if varname is one of the table's variables.
53 | bool VarsTable::Has(const std::string &varname) const noexcept {
54 |     return table.find(varname) != table.end();
55 | }
56 | // Formats one "name = value" line per variable, in the map's key order.
57 | std::string VarsTable::ToString() const noexcept {
58 |     std::string ret;
59 |     for (auto &item : table) {
60 |         ret += item.first + " = " + tomsolver::ToString(item.second) + "\n";
61 |     }
62 |     return ret;
63 | }
64 | // Iterators over the (name, value) pairs of the underlying map.
65 | std::map::const_iterator VarsTable::begin() const noexcept {
66 |     return table.begin();
67 | }
68 |
69 | std::map::const_iterator VarsTable::end() const noexcept {
70 |     return table.end();
71 | }
72 |
73 | std::map::const_iterator VarsTable::cbegin() const noexcept {
74 |     return table.cbegin();
75 | }
76 |
77 | std::map::const_iterator VarsTable::cend() const noexcept {
78 |     return table.cend();
79 | }
80 | // Equal when both tables hold the same names and each pair of values differs by at most Config::Get().epsilon.
81 | bool VarsTable::operator==(const VarsTable &rhs) const noexcept {
82 |     return values.Rows() == rhs.values.Rows() && // size check first: std::equal walks rhs.table for table.size() steps (assumes values.Rows() == table.size())
83 |            std::equal(table.begin(), table.end(), rhs.table.begin(), [](const auto &lhs, const auto &rhs) {
84 |                auto &lVar = lhs.first;
85 |                auto &lVal = lhs.second;
86 |                auto &rVar = rhs.first;
87 |                auto &rVal = rhs.second;
88 |                return lVar == rVar && std::abs(lVal - rVal) <= Config::Get().epsilon;
89 |            });
90 | }
91 | // Returns the value of varname; throws std::out_of_range if there is no such variable.
92 | double VarsTable::operator[](const std::string &varname) const {
93 |     auto it = table.find(varname);
94 |     if (it == table.end()) {
95 |         throw std::out_of_range("no such variable: " + varname);
96 |     }
97 |     return it->second;
98 | }
99 | // Writes table.ToString() to out.
100 | std::ostream &operator<<(std::ostream &out, const VarsTable &table) noexcept {
101 |     out << table.ToString();
102 |     return out;
103 | }
104 |
105 | } // namespace tomsolver
--------------------------------------------------------------------------------
/src/tomsolver/vars_table.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "mat.h"
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | namespace tomsolver {
10 |
11 | /**
12 |  * Variable table.
13 |  * Stores the mapping from variable names to their numeric values (both as a name -> value map and as a Vec).
14 |  */
15 | class VarsTable {
16 | public:
17 |     /**
18 |      * Creates a variable table in which every variable starts at the same value.
19 |      * @param vars variable names; they must be unique (checked only by assert)
20 |      * @param initValue initial value given to every variable
21 |      */
22 |     VarsTable(const std::vector &vars, double initValue);
23 |
24 |     /**
25 |      * Creates a variable table from {name, value} pairs.
26 |      * @param initList name/value pairs. NOTE(review): the pairs are collected into a std::map first, so a repeated
27 |      *        name silently keeps only one of its values and the "vars is not unique" assert cannot detect it.
28 |      */
29 |     explicit VarsTable(std::initializer_list> initList);
30 |
31 |     /**
32 |      * Creates a variable table from an existing name -> value map.
33 |      * @param table name/value pairs to copy
34 |      * Vars() and Values() are filled in the map's iteration (key) order.
35 |      */
36 |     explicit VarsTable(const std::map &table) noexcept;
37 |
38 |     /**
39 |      * Number of variables.
40 |      */
41 |     int VarNums() const noexcept;
42 |
43 |     /**
44 |      * Returns the variable names as a std::vector, index-aligned with Values().
45 |      */
46 |     const std::vector &Vars() const noexcept;
47 |
48 |     /**
49 |      * Returns the numeric vector holding every variable's value, index-aligned with Vars().
50 |      */
51 |     const Vec &Values() const noexcept;
52 |
53 |     /**
54 |      * Replaces all values; v must have one row per variable (checked only by assert).
55 |      */
56 |     void SetValues(const Vec &v) noexcept;
57 |
58 |     /**
59 |      * Returns whether a variable with the given name exists.
60 |      */
61 |     bool Has(const std::string &varname) const noexcept;
62 |
63 |     std::string ToString() const noexcept; // one "name = value" line per variable
64 |     // Iteration over the (name, value) pairs of the underlying map.
65 |     std::map::const_iterator begin() const noexcept;
66 |
67 |     std::map::const_iterator end() const noexcept;
68 |
69 |     std::map::const_iterator cbegin() const noexcept;
70 |
71 |     std::map::const_iterator cend() const noexcept;
72 |     // Same variable names, and every pair of values equal within Config::Get().epsilon.
73 |     bool operator==(const VarsTable &rhs) const noexcept;
74 |
75 |     /**
76 |      * Gets the value of a variable by name.
77 |      * @exception out_of_range thrown if there is no such variable
78 |      */
79 |     double operator[](const std::string &varname) const;
80 |
81 | private:
82 |     std::vector vars; // variable names, index-aligned with values
83 |     Vec values;       // numeric values, index-aligned with vars
84 |     std::map table;   // name -> value lookup, kept in sync by the constructors and SetValues()
85 | };
86 | // Writes table.ToString() to out.
87 | std::ostream &operator<<(std::ostream &out, const VarsTable &table) noexcept;
88 |
89 | } // namespace tomsolver
--------------------------------------------------------------------------------
/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # =====================================
2 | file(GLOB TEST_CODE
3 |     *.h
4 |     *.cpp
5 | )
6 | # Every header/source collected above is built into one GoogleTest executable.
7 | add_executable(tomsolver_tests ${TEST_CODE})
8 |
9 | target_include_directories(tomsolver_tests PUBLIC
10 |     ../src
11 | )
12 |
13 | target_link_libraries(tomsolver_tests PUBLIC
14 |     tomsolver
15 |     gtest_main # provides main(); the test sources define none
16 | )
17 | # Register the discovered TEST()s with CTest.
18 | include(GoogleTest)
19 | gtest_discover_tests(tomsolver_tests)
--------------------------------------------------------------------------------
/tests/diff_test.cpp:
--------------------------------------------------------------------------------
1 | #include "memory_leak_detection.h"
2 | #include
3 | #include
4 | #include
5 |
6 |
7 | #include
8 |
9 | #include
10 |
11 | using namespace tomsolver;
12 |
13 | using std::cout;
14 | using std::endl;
15 | // Unit tests for symbolic differentiation: Diff(node, varname) must produce the expected derivative tree.
16 | TEST(Diff, Base) {
17 |     MemoryLeakDetection mld;
18 |
19 |     Node n = Var("a");
20 |     ASSERT_TRUE(Diff(n, "a")->Equal(Num(1)));
21 |     ASSERT_TRUE(Diff(n, "b")->Equal(Num(0)));
22 |
23 |     ASSERT_TRUE(Diff(Num(1), "a")->Equal(Num(0)));
24 |
25 |     // diff(a+b, a) == 1
26 |     Node n2 = n + Var("b");
27 |     ASSERT_TRUE(Diff(n2, "a")->Equal(Num(1)));
28 | }
29 |
30 | TEST(Diff, Negative) {
31 |     MemoryLeakDetection mld;
32 |
33 |     Node n = -Var("a");
34 |     auto dn = Diff(n, "a");
35 |     ASSERT_TRUE(dn->Equal(Num(-1)));
36 |     ASSERT_TRUE(Diff(n, "b")->Equal(Num(0)));
37 |
38 |     ASSERT_TRUE(Diff(-Num(1), "a")->Equal(Num(0)));
39 |
40 |     // diff(-a+ -b, a) == -1 and diff(-a+ -b, b) == -1
41 |     Node n2 = n + -Var("b");
42 |     auto dn2a = Diff(n2, "a");
43 |     ASSERT_TRUE(dn2a->Equal(Num(-1)));
44 |     auto dn2b = Diff(n2, "b");
45 |     ASSERT_TRUE(dn2b->Equal(Num(-1)));
46 |
47 |     // diff(-a+ +b, a) == -1 and diff(-a+ +b, b) == 1
48 |     Node n3 = n + +Var("b");
49 |     ASSERT_TRUE(Diff(n3, "a")->Equal(Num(-1)));
50 |     ASSERT_TRUE(Diff(n3, "b")->Equal(Num(1)));
51 | }
52 |
53 | TEST(Diff, Sin) {
54 |     MemoryLeakDetection mld;
55 |
56 |     {
57 |         // sin'x = cos x
58 |         Node n = sin(Var("x"));
59 |         Node dn = Diff(n, "x");
60 |         dn->CheckParent();
61 |         cout << dn->ToString() << endl;
62 |         ASSERT_TRUE(dn->Equal(cos(Var("x"))));
63 |     }
64 | }
65 |
66 | TEST(Diff, Cos) {
67 |     MemoryLeakDetection mld;
68 |
69 |     {
70 |         // cos'x = -sin x
71 |         Node n = cos(Var("x"));
72 |         Node dn = Diff(n, "x");
73 |         dn->CheckParent();
74 |         cout << dn->ToString() << endl;
75 |         ASSERT_TRUE(dn->Equal(-sin(Var("x"))));
76 |     }
77 | }
78 |
79 | TEST(Diff, Tan) {
80 |     MemoryLeakDetection mld;
81 |
82 |     {
83 |         // tan'x = 1/(cos(x)^2)
84 |         Node n = tan(Var("x"));
85 |         Node dn = Diff(n, "x");
86 |         dn->CheckParent();
87 |         cout << dn->ToString() << endl;
88 |         ASSERT_TRUE(dn->Equal(Num(1) / (cos(Var("x")) ^ Num(2))));
89 |     }
90 | }
91 |
92 | TEST(Diff, ArcSin) {
93 |     MemoryLeakDetection mld;
94 |
95 |     {
96 |         // asin'x = 1/sqrt(1-x^2)
97 |         Node n = asin(Var("x"));
98 |         Node dn = Diff(n, "x");
99 |         dn->CheckParent();
100 |         cout << dn->ToString() << endl;
101 |         ASSERT_TRUE(dn->Equal(Num(1) / sqrt(Num(1) - (Var("x") ^ Num(2)))));
102 |     }
103 | }
104 |
105 | TEST(Diff, ArcCos) {
106 |     MemoryLeakDetection mld;
107 |
108 |     {
109 |         // acos'x = -1/sqrt(1-x^2)
110 |         Node n = acos(Var("x"));
111 |         Node dn = Diff(n, "x");
112 |         dn->CheckParent();
113 |         cout << dn->ToString() << endl;
114 |         ASSERT_TRUE(dn->Equal(Num(-1) / sqrt(Num(1) - (Var("x") ^ Num(2)))));
115 |     }
116 | }
117 |
118 | TEST(Diff, ArcTan) {
119 |     MemoryLeakDetection mld;
120 |
121 |     {
122 |         // atan'x = 1/(1+x^2)
123 |         Node n = atan(Var("x"));
124 |         Node dn = Diff(n, "x");
125 |         dn->CheckParent();
126 |         cout << dn->ToString() << endl;
127 |         ASSERT_TRUE(dn->Equal(Num(1) / (Num(1) + (Var("x") ^ Num(2)))));
128 |     }
129 | }
130 |
131 | TEST(Diff, Sqrt) {
132 |     MemoryLeakDetection mld;
133 |
134 |     {
135 |         // sqrt(x)' = 1/(2*sqrt(x))
136 |         Node n = sqrt(Var("x"));
137 |         Node dn = Diff(n, "x");
138 |         dn->CheckParent();
139 |         cout << dn->ToString() << endl;
140 |         ASSERT_TRUE(dn->Equal(Num(1) / (Num(2) * sqrt(Var("x")))));
141 |     }
142 | }
143 |
144 | TEST(Diff, Exp) {
145 |     MemoryLeakDetection mld;
146 |
147 |     {
148 |         // (e^x)' = e^x
149 |         Node n = exp(Var("x"));
150 |         Node dn = Diff(n, "x");
151 |         dn->CheckParent();
152 |         cout << dn->ToString() << endl;
153 |         ASSERT_TRUE(dn->Equal(exp(Var("x"))));
154 |     }
155 |
156 |     {
157 |         // (e^sin(x))' = e^sin(x)*cos(x)
158 |         Node n = exp(sin(Var("x")));
159 |         Node dn = Diff(n, "x");
160 |         dn->CheckParent();
161 |         cout << dn->ToString() << endl;
162 |         ASSERT_TRUE(dn->Equal(exp(sin(Var("x"))) * cos(Var("x"))));
163 |     }
164 | }
165 |
166 | TEST(Diff, Multiply) {
167 |     MemoryLeakDetection mld;
168 |
169 |     // diff(5*a, a) == 5
170 |     ASSERT_TRUE(Diff(Num(5) * Var("a"), "a")->Equal(Num(5)));
171 |
172 |     // diff(b*5, b) == 5
173 |     ASSERT_TRUE(Diff(Var("b") * Num(5), "b")->Equal(Num(5)));
174 |
175 |     {
176 |         // diff(a*b, a) == b
177 |         Node n = Var("a") * Var("b");
178 |         Node dn = Diff(n, "a");
179 |         dn->CheckParent();
180 |         cout << dn->ToString() << endl;
181 |         ASSERT_TRUE(dn->Equal(Var("b")));
182 |     }
183 |
184 |     {
185 |         // diff(a*b*a, a): the result is only printed and parent-checked; TODO: assert its expected form
186 |         Node n = Var("a") * Var("b") * Var("a");
187 |         Node dn = Diff(n, "a");
188 |         dn->CheckParent();
189 |         cout << dn->ToString() << endl;
190 |     }
191 | }
192 |
193 | TEST(Diff, Divide) {
194 |     MemoryLeakDetection mld;
195 |
196 |     {
197 |         // diff(b/5, b) == 1/5
198 |         Node d = Diff(Var("b") / Num(5), "b");
199 |         d->CheckParent();
200 |         ASSERT_TRUE(d->Equal(Num(1.0 / 5.0)));
201 |     }
202 |
203 |     {
204 |         // diff(5/a, a) == -5/a^2
205 |         Node d = Diff(Num(5) / Var("a"), "a");
206 |         d->CheckParent();
207 |         ASSERT_TRUE(d->Equal(Num(-5) / (Var("a") ^ Num(2))));
208 |     }
209 |
210 |     {
211 |         // diff(x^2/sin(x), x) = (2*x*sin(x)-x^2*cos(x))/sin(x)^2
212 |         Node n = (Var("x") ^ Num(2)) / sin(Var("x"));
213 |         Node dn = Diff(n, "x");
214 |         dn->CheckParent();
215 |         ASSERT_EQ(dn->ToString(), "(2*x*sin(x)-x^2*cos(x))/sin(x)^2");
216 |     }
217 | }
218 |
219 | TEST(Diff, Log) {
220 |     MemoryLeakDetection mld;
221 |
222 |     {
223 |         // log(x)' = 1/x
224 |         Node n = log(Var("x"));
225 |         Node dn = Diff(n, "x");
226 |         dn->CheckParent();
227 |         cout << dn->ToString() << endl;
228 |         ASSERT_TRUE(dn->Equal(Num(1) / Var("x")));
229 |     }
230 |
231 |     {
232 |         // log(sin(x))' = 1/sin(x) * cos(x)
233 |         Node n = log(sin(Var("x")));
234 |         Node dn = Diff(n, "x");
235 |         dn->CheckParent();
236 |         cout << dn->ToString() << endl;
237 |         ASSERT_TRUE(dn->Equal(Num(1) / sin(Var("x")) * cos(Var("x"))));
238 |     }
239 | }
240 |
241 | TEST(Diff, LogChain) {
242 |     MemoryLeakDetection mld;
243 |     {
244 |         // (x*ln(x))' = ln(x)+1
245 |         Node n = Var("x") * log(Var("x"));
246 |         Node dn = Diff(n, "x");
247 |         dn->CheckParent();
248 |         cout << dn->ToString() << endl;
249 |         ASSERT_EQ(dn->ToString(), "log(x)+x*1/x"); // FIXME: simplify
250 |     }
251 | }
252 |
253 | TEST(Diff, Log2) {
254 |     MemoryLeakDetection mld;
255 |
256 |     {
257 |         // loga(x)' = 1/(x * ln(a))
258 |         Node n = log2(Var("x"));
259 |         Node dn = Diff(n, "x");
260 |         dn->CheckParent();
261 |         cout << dn->ToString() << endl;
262 |         ASSERT_TRUE(dn->Equal(Num(1) / (Var("x") * Num(std::log(2)))));
263 |     }
264 | }
265 |
266 | TEST(Diff, Log10) {
267 |     MemoryLeakDetection mld;
268 |
269 |     {
270 |         // loga(x)' = 1/(x * ln(a))
271 |         Node n = log10(Var("x"));
272 |         Node dn = Diff(n, "x");
273 |         dn->CheckParent();
274 |         cout << dn->ToString() << endl;
275 |         ASSERT_TRUE(dn->Equal(Num(1) / (Var("x") * Num(std::log(10.0)))));
276 |     }
277 | }
278 |
279 | TEST(Diff, Power) {
280 |     MemoryLeakDetection mld;
281 |
282 |     {
283 |         // (x^a)' = a*x^(a-1)
284 |         Node n = Var("x") ^ Num(5);
285 |         Node dn = Diff(n, "x");
286 |         dn->CheckParent();
287 |         cout << dn->ToString() << endl;
288 |         ASSERT_TRUE(dn->Equal(Num(5) * (Var("x") ^ Num(4))));
289 |     }
290 |
291 |     {
292 |         // (a^x)' = a^x * ln(a) when a>0 and a!=1
293 |         Node n = Num(3) ^ Var("x");
294 |         Node dn = Diff(n, "x");
295 |         dn->CheckParent();
296 |         cout << dn->ToString() << endl;
297 |         Node expect = (Num(3) ^ Var("x")) * Num(std::log(3));
298 |         ASSERT_TRUE(dn->Equal(expect));
299 |     }
300 |
301 |     {
302 |         // (u^v)' = ( e^(v*ln(u)) )' = e^(v*ln(u)) * (v*ln(u))'
303 |
304 |         Node n = Var("x") ^ Var("x");
305 |         Node dn = Diff(n, "x");
306 |         dn->CheckParent();
307 |         cout << dn->ToString() << endl;
308 |         ASSERT_EQ(dn->ToString(), "x^x*(log(x)+x*1/x)"); // FIXME: simplify
309 |     }
310 |
311 |     {
312 |         // (u^v)' = ( e^(v*ln(u)) )' = e^(v*ln(u)) * (v*ln(u))'
313 |
314 |         Node n = sin(Var("x")) ^ cos(Var("x"));
315 |         Node dn = Diff(n, "x");
316 |         dn->CheckParent();
317 |         cout << dn->ToString() << endl;
318 |         ASSERT_EQ(dn->ToString(), "sin(x)^cos(x)*(-(sin(x))*log(sin(x))+cos(x)*1/sin(x)*cos(x))"); // FIXME: simplify
319 |     }
320 | }
321 |
322 | TEST(Diff, Combine) {
323 |     MemoryLeakDetection mld;
324 |
325 |     {
326 |         // diff(sin(a*b+c)*1*a, a)
327 |         Node n = sin(Var("a") * Var("b") + Var("c")) * Num(1) * Var("a");
328 |         Node dn = Diff(n, "a");
329 |         dn->CheckParent();
330 |         cout << dn->ToString() << endl;
331 |         ASSERT_EQ(dn->ToString(), "cos(a*b+c)*b*a+sin(a*b+c)");
332 |     }
333 |
334 |     {
335 |         // diff(sin(cos(x)+sin(x)), x) = cos(cos(x)+sin(x))*(-sin(x)+cos(x))
336 |         Node n = sin(cos(Var("x")) + sin(Var("x")));
337 |         Node dn = Diff(n, "x");
338 |         dn->CheckParent();
339 |         cout << dn->ToString() << endl;
340 |         ASSERT_EQ(dn->ToString(), "cos(cos(x)+sin(x))*(-(sin(x))+cos(x))");
341 |     }
342 | }
343 |
344 | TEST(Diff, Combine2) {
345 |     MemoryLeakDetection mld;
346 |
347 |     {
348 |         Node n = "sin(x)/log(x*y)"_f;
349 |         Node dn = Diff(n, "y");
350 |         dn->CheckParent();
351 |         cout << dn->ToString() << endl;
352 |
353 |         // TODO: simplify further, then assert the result (currently it is only printed)
354 |     }
355 | }
--------------------------------------------------------------------------------
/tests/functions_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "memory_leak_detection.h"
7 |
8 | #include
9 |
10 | #include
11 |
12 | #undef max
13 | #undef min
14 |
15 | using namespace tomsolver;
16 | // Tests for the elementary math functions applied to numeric nodes, and for invalid-value (inf/nan) detection.
17 |
using std::cout;
18 | using std::endl;
19 |
20 | TEST(Function, Trigonometric) {
21 |     MemoryLeakDetection mld;
22 |
23 |     int count = 100;
24 |     // The seed is printed so a failing run can be reproduced.
25 |     auto seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count());
26 |     cout << "seed = " << seed << endl;
27 |     std::default_random_engine eng(seed);
28 |     std::uniform_real_distribution unifNum;
29 |     // A default-constructed distribution samples [0, 1), keeping asin/acos/sqrt/log inside their domains (an exact 0 would make log throw).
30 |     for (int i = 0; i < count; ++i) {
31 |         double num = unifNum(eng);
32 |         ASSERT_DOUBLE_EQ(sin(Num(num))->Vpa(), sin(num));
33 |         ASSERT_DOUBLE_EQ(cos(Num(num))->Vpa(), cos(num));
34 |         ASSERT_DOUBLE_EQ(tan(Num(num))->Vpa(), tan(num));
35 |         ASSERT_DOUBLE_EQ(asin(Num(num))->Vpa(), asin(num));
36 |         ASSERT_DOUBLE_EQ(acos(Num(num))->Vpa(), acos(num));
37 |         ASSERT_DOUBLE_EQ(atan(Num(num))->Vpa(), atan(num));
38 |         ASSERT_DOUBLE_EQ(sqrt(Num(num))->Vpa(), sqrt(num));
39 |         ASSERT_DOUBLE_EQ(log(Num(num))->Vpa(), log(num));
40 |         ASSERT_DOUBLE_EQ(log2(Num(num))->Vpa(), log2(num));
41 |         ASSERT_DOUBLE_EQ(log10(Num(num))->Vpa(), log10(num));
42 |         ASSERT_DOUBLE_EQ(exp(Num(num))->Vpa(), exp(num));
43 |     }
44 | }
45 |
46 | TEST(Function, InvalidNumber) {
47 |     MemoryLeakDetection mld;
48 |     double inf = std::numeric_limits::infinity();
49 |     double inf2 = -std::numeric_limits::infinity();
50 |     double nan = std::numeric_limits::quiet_NaN();
51 |     double dblMax = std::numeric_limits::max();
52 |     // MY_ASSERT_THROW: if shouldThrow, `statement` must throw MathError (else FAIL); otherwise it must not throw MathError.
53 | #define MY_ASSERT_THROW(statement, shouldThrow) \
54 |     if (shouldThrow) { \
55 |         try { \
56 |             statement; \
57 |             FAIL(); \
58 |         } catch (const MathError &err) { \
59 |             std::cerr << "[Expected Exception]" << err.what() << std::endl; \
60 |         } \
61 |     } else { \
62 |         try { \
63 |             statement; \
64 |         } catch (const MathError &err) { \
65 |             std::cerr << "[Unexpected Exception]" << err.what() << std::endl; \
66 |             FAIL(); \
67 |         } \
68 |     }
69 |     // Runs every invalid-value case; each must throw exactly when shouldThrow is true.
70 |     auto Test = [&](bool shouldThrow) {
71 |         MY_ASSERT_THROW((Num(inf) + Num(1))->Vpa(), shouldThrow);
72 |         MY_ASSERT_THROW((Num(inf2) + Num(1))->Vpa(), shouldThrow);
73 |         MY_ASSERT_THROW((Num(nan) + Num(1))->Vpa(), shouldThrow);
74 |
75 |         MY_ASSERT_THROW((Num(1) / Num(0))->Vpa(), shouldThrow);
76 |
77 |         // pow(DBL_MAX, 2) overflows to inf
78 |         MY_ASSERT_THROW((Num(dblMax) ^ Num(2))->Vpa(), shouldThrow);
79 |
80 |         MY_ASSERT_THROW(asin(Num(1.1))->Vpa(), shouldThrow);
81 |         MY_ASSERT_THROW(asin(Num(-1.1))->Vpa(), shouldThrow);
82 |
83 |         MY_ASSERT_THROW(acos(Num(1.1))->Vpa(), shouldThrow);
84 |         MY_ASSERT_THROW(acos(Num(-1.1))->Vpa(), shouldThrow);
85 |
86 |         MY_ASSERT_THROW(sqrt(Num(-0.1))->Vpa(), shouldThrow);
87 |
88 |         MY_ASSERT_THROW(log(Num(0))->Vpa(), shouldThrow);
89 |         MY_ASSERT_THROW(log2(Num(0))->Vpa(), shouldThrow);
90 |         MY_ASSERT_THROW(log10(Num(0))->Vpa(), shouldThrow);
91 |     };
92 |
93 |     // With the default configuration, invalid values must throw.
94 |     Test(true);
95 |
96 |     // With invalid-value checking switched off manually, nothing should throw.
97 |     {
98 |         Config::Get().throwOnInvalidValue = false;
99 |
100 |         Test(false);
101 |
102 |         // Restore the default configuration.
103 |         Config::Get().Reset();
104 |     }
105 | }
106 |
107 | TEST(Function, ToString) {
108 |     MemoryLeakDetection mld;
109 |
110 |     Node f = Var("r") * sin(Var("omega") / Num(2.0) + Var("phi")) + Var("c");
111 |
112 |     ASSERT_EQ(f->ToString(), "r*sin(omega/2+phi)+c");
113 | }
--------------------------------------------------------------------------------
/tests/helper.cpp:
--------------------------------------------------------------------------------
1 | #include "helper.h"
2 |
3 | #include
4 | #include
5 |
6 | #include
7 | #include
8 | #include
9 |
10 | using std::cout;
11 | using std::endl;
12 | // Test helper that builds random numeric expression trees together with their expected value.
13 | namespace tomsolver {
14 | // Applies `len` random operations to a random number; returns {tree, value computed in parallel with std:: math}.
15 | std::pair CreateRandomExpresionTree(int len) {
16 |     auto seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count());
17 |     cout << "seed = " << seed << endl;
18 |     std::default_random_engine eng(seed);
19 |
20 |     std::vector ops{MathOperator::MATH_POSITIVE, MathOperator::MATH_NEGATIVE, MathOperator::MATH_ADD,
21 |                     MathOperator::MATH_SUB,      MathOperator::MATH_MULTIPLY, MathOperator::MATH_DIVIDE,
22 |                     MathOperator::MATH_SIN,      MathOperator::MATH_COS,      MathOperator::MATH_TAN,
23 |                     MathOperator::MATH_ARCSIN,   MathOperator::MATH_ARCCOS,   MathOperator::MATH_ARCTAN};
24 |     std::uniform_int_distribution unifOp(0, static_cast(ops.size()) - 1);
25 |     std::uniform_real_distribution unifNum(-100.0, 100.0);
26 |     double v = unifNum(eng);
27 |     auto node = Num(v);
28 |     // j counts applied operations only: `continue` retries when an op would divide by zero or leave the asin/acos domain.
29 |     for (int j = 0; j < len;) {
30 |         double num = unifNum(eng);
31 |         auto op = ops[unifOp(eng)];
32 |         // The parity of a second draw decides whether `num` goes in front of or behind the current tree.
33 |         bool frontOrBack = unifOp(eng) % 2;
34 |         switch (op) {
35 |         case MathOperator::MATH_POSITIVE: {
36 |             node = +std::move(node);
37 |             break;
38 |         }
39 |         case MathOperator::MATH_NEGATIVE: {
40 |             v = -v;
41 |             node = -std::move(node);
42 |             break;
43 |         }
44 |         case MathOperator::MATH_ADD:
45 |             if (frontOrBack) {
46 |                 v = num + v;
47 |                 node = Num(num) + std::move(node);
48 |             } else {
49 |                 v += num;
50 |                 node += Num(num);
51 |             }
52 |             break;
53 |         case MathOperator::MATH_SUB:
54 |             if (frontOrBack) {
55 |                 v = num - v;
56 |                 node = Num(num) - std::move(node);
57 |             } else {
58 |                 v -= num;
59 |                 node -= Num(num);
60 |             }
61 |             break;
62 |         case MathOperator::MATH_MULTIPLY:
63 |             if (frontOrBack) {
64 |                 v = num * v;
65 |                 node = Num(num) * std::move(node);
66 |             } else {
67 |                 v *= num;
68 |                 node *= Num(num);
69 |             }
70 |             break;
71 |         case MathOperator::MATH_DIVIDE:
72 |             if (frontOrBack) {
73 |                 if (v == 0) {
74 |                     continue;
75 |                 }
76 |                 v = num / v;
77 |                 node = Num(num) / std::move(node);
78 |             } else {
79 |                 if (num == 0) {
80 |                     continue;
81 |                 }
82 |                 v /= num;
83 |                 node /= Num(num);
84 |             }
85 |             break;
86 |         case MathOperator::MATH_SIN:
87 |             v = std::sin(v);
88 |             node = sin(std::move(node));
89 |             break;
90 |         case MathOperator::MATH_COS:
91 |             v = std::cos(v);
92 |             node = cos(std::move(node));
93 |             break;
94 |         case MathOperator::MATH_TAN:
95 |             v = std::tan(v);
96 |             node = tan(std::move(node));
97 |             break;
98 |         case MathOperator::MATH_ARCSIN:
99 |             if (v < -1.0 || v > 1.0) {
100 |                 continue;
101 |             }
102 |             v = std::asin(v);
103 |             node = asin(std::move(node));
104 |             break;
105 |         case MathOperator::MATH_ARCCOS:
106 |             if (v < -1.0 || v > 1.0) {
107 |                 continue;
108 |             }
109 |             v = std::acos(v);
110 |             node = acos(std::move(node));
111 |             break;
112 |         case MathOperator::MATH_ARCTAN:
113 |             v = std::atan(v);
114 |             node = atan(std::move(node));
115 |             break;
116 |         default:
117 |             assert(0);
118 |         }
119 |
120 |         ++j;
121 |     }
122 |     return {std::move(node), v};
123 | }
124 |
125 | } // namespace tomsolver
--------------------------------------------------------------------------------
/tests/helper.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace tomsolver {
5 | // Returns a random expression tree built from `len` operations and the value it should evaluate to.
6 | std::pair CreateRandomExpresionTree(int len);
7 |
8 | } // namespace tomsolver
--------------------------------------------------------------------------------
/tests/linear_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "memory_leak_detection.h"
4 |
5 | #include
6 |
7 | using namespace tomsolver;
8 | // Solves a 4x4 linear system A x = b and compares the solution with reference values.
9 | TEST(Linear, Base) {
10 |     MemoryLeakDetection mld;
11 |
12 |     Mat A = {{2, 1, -5, 1}, {1, -5, 0, 7}, {0, 2, 1, -1}, {1, 6, -1, -4}};
13 |     Vec b = {13, -9, 6, 0};
14 |
15 |     auto x = SolveLinear(std::move(A), std::move(b));
16 |
17 |     Vec expected = {-66.5555555555555429, 25.6666666666666643, -18.777777777777775, 26.55555555555555};
18 |
19 |     ASSERT_EQ(x, expected);
20 | }
--------------------------------------------------------------------------------
/tests/mat_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "memory_leak_detection.h"
5 |
6 | #include
7 |
8 | using namespace tomsolver;
9 |
10 | using std::cout;
11 | using std::endl;
12 | // Tests for the numeric matrix Mat: multiplication, Inverse() and PositiveDetermine().
13 | TEST(Mat, Multiply) {
14 |     MemoryLeakDetection mld;
15 |
16 |     Mat A = {{1, 2}, {3, 4}};
17 |     Mat B = {{6, 7}, {8, 9}};
18 |
19 |     Mat ret = A * B;
20 |     Mat expected = {{22, 25}, {50, 57}};
21 |     ASSERT_EQ(ret, expected);
22 | }
23 |
24 | TEST(Mat, Inverse) {
25 |     MemoryLeakDetection mld;
26 |
27 |     {
28 |         Mat A = {{1, 2}, {3, 4}};
29 |         auto inv = A.Inverse();
30 |         Mat expected = {{-2, 1}, {1.5, -0.5}};
31 |         ASSERT_EQ(inv, expected);
32 |     }
33 |     {
34 |         Mat A = {{1, 2, 3}, {4, 5, 6}, {-2, 7, 8}};
35 |         auto inv = A.Inverse();
36 |         Mat expected = {{-0.083333333333333, 0.208333333333333, -0.125000000000000},
37 |                         {-1.833333333333333, 0.583333333333333, 0.250000000000000},
38 |                         {1.583333333333333, -0.458333333333333, -0.125000000000000}};
39 |         ASSERT_EQ(inv, expected);
40 |     }
41 |     {
42 |         Mat A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
43 |         try {
44 |             auto inv = A.Inverse(); // singular matrix: Inverse() must throw MathError
45 |             FAIL();
46 |         } catch (const MathError &e) {
47 |             cout << "[Expected]" << e.what() << endl;
48 |         }
49 |     }
50 | }
51 |
52 | TEST(Mat, PositiveDetermine) {
53 |     MemoryLeakDetection mld;
54 |     // NOTE(review): PositiveDetermine presumably tests positive definiteness (4x4 Pascal matrix: yes; singular 3x3: no).
55 |     {
56 |         Mat A = {{1, 1, 1, 1}, {1, 2, 3, 4}, {1, 3, 6, 10}, {1, 4, 10, 20}};
57 |         ASSERT_TRUE(A.PositiveDetermine());
58 |     }
59 |     {
60 |         Mat A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
61 |         ASSERT_TRUE(!A.PositiveDetermine());
62 |     }
63 | }
64 |
--------------------------------------------------------------------------------
/tests/memory_leak_detection.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | // Uses the CRT-based leak detector when WIN32 is defined; elsewhere MemoryLeakDetection is an empty no-op.
3 | #ifdef WIN32
4 | #include "memory_leak_detection_win.h"
5 | #else
6 | class MemoryLeakDetection final {
7 | public:
8 |     MemoryLeakDetection() {}
9 | };
10 | #endif
11 |
--------------------------------------------------------------------------------
/tests/memory_leak_detection_win.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef WIN32
4 |
5 | #include
6 |
7 | #undef max
8 | #undef min
9 |
10 | #define _CRTDBG_MAP_ALLOC // to get more details
11 | #include
12 | #include //for malloc and free
13 |
14 | #include
15 |
16 | #include
17 | #include
18 | // Scope-based leak check: snapshot the CRT heap on construction, compare on destruction, and fail the test on any difference.
19 | class MemoryLeakDetection final {
20 | public:
21 |     MemoryLeakDetection() {
22 |         _CrtMemCheckpoint(&sOld); // take a snapshot
23 |     }
24 |
25 |     ~MemoryLeakDetection() {
26 |         _CrtMemCheckpoint(&sNew);                    // take a snapshot
27 |         if (_CrtMemDifference(&sDiff, &sOld, &sNew)) // if there is a difference
28 |         {
29 |             // OutputDebugString(TEXT("-----------_CrtMemDumpStatistics ---------"));
30 |             //_CrtMemDumpStatistics(&sDiff);
31 |             // OutputDebugString(TEXT("-----------_CrtMemDumpAllObjectsSince ---------"));
32 |             //_CrtMemDumpAllObjectsSince(&sOld);
33 |             // OutputDebugString(TEXT("-----------_CrtDumpMemoryLeaks ---------"));
34 |             _CrtDumpMemoryLeaks();
35 |
36 |             EXPECT_TRUE(0 && "Memory leak is detected! See debug output for detail."); // always fails, reporting the message
37 |         }
38 |     }
39 |     // Forwards to _CrtSetBreakAlloc so the debugger breaks at CRT allocation number `index`.
40 |     void SetBreakAlloc(long index) const noexcept {
41 |         (index); // no-op expression statement; index is used on the next line anyway
42 |         _CrtSetBreakAlloc(index);
43 |     }
44 |
45 | private:
46 |     _CrtMemState sOld;
47 |     _CrtMemState sNew;
48 |     _CrtMemState sDiff;
49 | };
50 |
51 | #endif
--------------------------------------------------------------------------------
/tests/node_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "memory_leak_detection.h"
7 |
8 | #include
9 |
10 | #include
11 |
12 | using namespace tomsolver;
13 |
14 | using std::cout;
15 | using std::endl;
16 | // Tests for expression nodes: building trees with operators, ToString, Clone/Move, Vpa and Equal.
17 | TEST(Node, Num) {
18 |     MemoryLeakDetection mld;
19 |
20 |     auto n = Num(10);
21 |     cout << n->ToString() << endl;
22 |     ASSERT_EQ(n->ToString(), "10");
23 |
24 |     // rvalue + rvalue
25 |     auto n2 = Num(1) + Num(2);
26 |     cout << n2->ToString() << endl;
27 |     ASSERT_EQ(n2->ToString(), "1+2");
28 |
29 |     // lvalue + lvalue
30 |     auto n3 = n + n2;
31 |     n3->CheckParent();
32 |     cout << n3->ToString() << endl;
33 |     ASSERT_EQ(n3->ToString(), "10+1+2");
34 |     cout << n3->ToString() << endl;
35 |     ASSERT_EQ(n3->ToString(), "10+1+2");
36 |
37 |     // the earlier n and n2 must not have been released (consumed)
38 |     ASSERT_EQ(n->ToString(), "10");
39 |     ASSERT_EQ(n2->ToString(), "1+2");
40 |
41 |     // lvalue + rvalue
42 |     auto n4 = n + Num(3);
43 |     ASSERT_EQ(n4->ToString(), "10+3");
44 |     // the earlier n must not have been released
45 |     ASSERT_EQ(n->ToString(), "10");
46 |
47 |     // rvalue + lvalue
48 |     auto n5 = Num(3) + n;
49 |     ASSERT_EQ(n5->ToString(), "3+10");
50 |     // the earlier n must not have been released
51 |     ASSERT_EQ(n->ToString(), "10");
52 |
53 |     n->CheckParent();
54 |     n2->CheckParent();
55 |     n4->CheckParent();
56 |     n5->CheckParent();
57 | }
58 |
59 | TEST(Node, Var) {
60 |     MemoryLeakDetection mld;
61 |     // A name may not start with a digit; the accepted names below start with a letter or '_'.
62 |     ASSERT_ANY_THROW(Var("0a"));
63 |
64 |     Var("a");
65 |     Var("a0");
66 |     Var("_");
67 |     Var("_a");
68 |     Var("_1");
69 |
70 |     auto expr = Var("a") - Num(1);
71 |     cout << expr << endl;
72 |     ASSERT_EQ(expr->ToString(), "a-1");
73 |
74 |     expr->CheckParent();
75 | }
76 |
77 | TEST(Node, Op) {
78 |     MemoryLeakDetection mld;
79 |     // MATH_NULL is rejected; a real operator is accepted.
80 |     ASSERT_ANY_THROW(Op(MathOperator::MATH_NULL));
81 |
82 |     ASSERT_NO_THROW(Op(MathOperator::MATH_ADD));
83 | }
84 |
85 | TEST(Node, Clone) {
86 |     MemoryLeakDetection mld;
87 |
88 |     Node n = Var("a") + Var("b") * Var("c");
89 |     n->CheckParent();
90 |
91 |     Node n2 = Clone(n);
92 |     n2->CheckParent();
93 |
94 |     ASSERT_EQ(n->ToString(), "a+b*c");
95 |     ASSERT_EQ(n2->ToString(), "a+b*c");
96 | }
97 |
98 | TEST(Node, Move) {
99 |     MemoryLeakDetection mld;
100 |
101 |     Node n = Var("a") + Var("b") * Var("c");
102 |     Node n2 = Move(n);
103 |
104 |     ASSERT_EQ(n, nullptr); // Move() leaves the source empty
105 |     ASSERT_EQ(n2->ToString(), "a+b*c");
106 |
107 |     n2->CheckParent();
108 | }
109 |
110 | TEST(Node, AddEqual) {
111 |     MemoryLeakDetection mld;
112 |
113 |     auto n = Num(10);
114 |
115 |     n += Num(1);
116 |     ASSERT_EQ(n->ToString(), "10+1");
117 |
118 |     auto n2 = Num(20);
119 |     n += n2;
120 |     ASSERT_EQ(n->ToString(), "10+1+20");
121 |
122 |     // the earlier n2 must not have been released
123 |     ASSERT_EQ(n2->ToString(), "20");
124 |
125 |     n->CheckParent();
126 |     n2->CheckParent();
127 | }
128 |
129 | TEST(Node, Sub) {
130 |     MemoryLeakDetection mld;
131 |     // A negative right-hand operand is parenthesized in the output.
132 |     Node n = Num(10) - Num(-10);
133 |     ASSERT_EQ(n->ToString(), "10-(-10)");
134 | }
135 |
136 | TEST(Node, SubEqual) {
137 |     MemoryLeakDetection mld;
138 |
139 |     auto n = Num(10);
140 |
141 |     n -= Num(1);
142 |     ASSERT_EQ(n->ToString(), "10-1");
143 |
144 |     auto n2 = Num(20);
145 |     n -= n2;
146 |     ASSERT_EQ(n->ToString(), "10-1-20");
147 |
148 |     // the earlier n2 must not have been released
149 |     ASSERT_EQ(n2->ToString(), "20");
150 |
151 |     n->CheckParent();
152 |     n2->CheckParent();
153 | }
154 |
155 | TEST(Node, Negative) {
156 |     MemoryLeakDetection mld;
157 |
158 |     {
159 |         Node n = -Num(10);
160 |         ASSERT_EQ(n->ToString(), "-10");
161 |     }
162 |
163 |     {
164 |         Node n = -Var("x");
165 |         ASSERT_EQ(n->ToString(), "-x");
166 |     }
167 |
168 |     {
169 |         Node n = +Var("y");
170 |         ASSERT_EQ(n->ToString(), "+y");
171 |     }
172 |
173 |     {
174 |         Node n = -(Var("x") + Num(2));
175 |         ASSERT_EQ(n->ToString(), "-(x+2)");
176 |     }
177 |
178 |     {
179 |         Node n = Var("y") + -(Var("x") + Num(2));
180 |         ASSERT_EQ(n->ToString(), "y+-(x+2)");
181 |     }
182 |
183 |     {
184 |         Node n = Var("y") + +(Var("x") + Num(2));
185 |         ASSERT_EQ(n->ToString(), "y++(x+2)");
186 |     }
187 |     // Nested unary signs inside functions must still evaluate numerically.
188 |     {
189 |         Node n = atan(cos(-(+(-Num(87.9117553746407054) / Num(90.5933224572584663)))));
190 |         ASSERT_DOUBLE_EQ(n->Vpa(), 0.51426347804323491);
191 |     }
192 | }
193 |
194 | TEST(Node, MulEqual) {
195 |     MemoryLeakDetection mld;
196 |
197 |     auto n = Num(10);
198 |
199 |     n *= Num(1);
200 |     ASSERT_EQ(n->ToString(), "10*1");
201 |
202 |     auto n2 = Num(20);
203 |     n *= n2;
204 |     ASSERT_EQ(n->ToString(), "10*1*20");
205 |
206 |     // the earlier n2 must not have been released
207 |     ASSERT_EQ(n2->ToString(), "20");
208 |
209 |     n->CheckParent();
210 |     n2->CheckParent();
211 | }
212 |
213 | TEST(Node, DivEqual) {
214 |     MemoryLeakDetection mld;
215 |
216 |     auto n = Num(10);
217 |
218 |     n /= Num(1);
219 |     ASSERT_EQ(n->ToString(), "10/1");
220 |
221 |     auto n2 = Num(20);
222 |     n /= n2;
223 |     ASSERT_EQ(n->ToString(), "10/1/20");
224 |
225 |     // the earlier n2 must not have been released
226 |     ASSERT_EQ(n2->ToString(), "20");
227 |
228 |     n->CheckParent();
229 |     n2->CheckParent();
230 | }
231 |
232 | TEST(Node, Multiply) {
233 |     MemoryLeakDetection mld;
234 |
235 |     {
236 |         auto expr = Var("a") + Var("b") * Var("c");
237 |         cout << expr << endl;
238 |         ASSERT_EQ(expr->ToString(), "a+b*c");
239 |
240 |         expr->CheckParent();
241 |     }
242 |
243 |     {
244 |         auto expr = (Var("a") + Var("b")) * Var("c");
245 |         cout << expr << endl;
246 |         ASSERT_EQ(expr->ToString(), "(a+b)*c");
247 |
248 |         expr->CheckParent();
249 |     }
250 |
251 |     {
252 |         auto expr = Num(1) + Num(2) * Num(3);
253 |         cout << expr << " = " << expr->Vpa() << endl;
254 |         ASSERT_DOUBLE_EQ(expr->Vpa(), 7.0);
255 |
256 |         expr->CheckParent();
257 |     }
258 |
259 |     {
260 |         auto expr = (Num(1) + Num(2)) * Num(3);
261 |         cout << expr << " = " << expr->Vpa() << endl;
262 |         ASSERT_DOUBLE_EQ(expr->Vpa(), 9.0);
263 |
264 |         expr->CheckParent();
265 |     }
266 | }
267 |
268 | TEST(Node, Divide) {
269 |     MemoryLeakDetection mld;
270 |
271 |     {
272 |         auto expr = Var("a") + Var("b") / Var("c");
273 |         cout << expr << endl;
274 |         ASSERT_EQ(expr->ToString(), "a+b/c");
275 |
276 |         expr->CheckParent();
277 |     }
278 |
279 |     {
280 |         auto expr = Num(1) + Num(2) / Num(4);
281 |         cout << expr << " = " << expr->Vpa() << endl;
282 |         ASSERT_DOUBLE_EQ(expr->Vpa(), 1.5);
283 |
284 |         expr->CheckParent();
285 |     }
286 |
287 |     {
288 |         auto expr = (Num(1) + Num(2)) / Num(4);
289 |         cout << expr << " = " << expr->Vpa() << endl;
290 |         ASSERT_DOUBLE_EQ(expr->Vpa(), 0.75);
291 |
292 |         expr->CheckParent();
293 |     }
294 |     // Division by zero is only reported when evaluating (Vpa), not when the tree is built.
295 |     auto expr = Num(1) / Num(0);
296 |     ASSERT_THROW(expr->Vpa(), MathError);
297 |
298 |     expr->CheckParent();
299 | }
300 |
301 | TEST(Node, Equal) {
302 |     MemoryLeakDetection mld;
303 |
304 |     Node n = Var("a") + Var("b") * Var("c");
305 |     Node n2 = Clone(n);
306 |     // A node, its clone, and a freshly built identical tree all compare Equal, in both directions.
307 |     ASSERT_TRUE(n->Equal(n));
308 |     ASSERT_TRUE(n->Equal(n2));
309 |     ASSERT_TRUE(n2->Equal(n));
310 |
311 |     ASSERT_TRUE(n->Equal(Var("a") + Var("b") * Var("c")));
312 |     ASSERT_TRUE((Var("a") + Var("b") * Var("c"))->Equal(n));
313 | }
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "memory_leak_detection.h" 6 | 7 | #include 8 | 9 | #include 10 | 11 | using namespace tomsolver; 12 | 13 | using std::cout; 14 | using std::deque; 15 | using std::endl; 16 | 17 | TEST(Parse, Base) { 18 | MemoryLeakDetection mld; 19 | std::setlocale(LC_ALL, ".UTF8"); 20 | 21 | { 22 | deque tokens = internal::ParseFunctions::ParseToTokens("1+2"); 23 | ASSERT_TRUE(tokens[0].node->Equal(Num(1))); 24 | ASSERT_TRUE(tokens[1].node->Equal(internal::Operator(MathOperator::MATH_ADD))); 25 | ASSERT_TRUE(tokens[2].node->Equal(Num(2))); 26 | } 27 | } 28 | 29 | TEST(Parse, Number) { 30 | MemoryLeakDetection mld; 31 | 32 | { 33 | deque tokens = internal::ParseFunctions::ParseToTokens(".12345"); 34 | ASSERT_TRUE(tokens[0].node->Equal(Num(.12345))); 35 | } 36 | 37 | { 38 | deque tokens = internal::ParseFunctions::ParseToTokens("7891.123"); 39 | ASSERT_TRUE(tokens[0].node->Equal(Num(7891.123))); 40 | } 41 | 42 | { 43 | deque tokens = internal::ParseFunctions::ParseToTokens("1e0"); 44 | ASSERT_TRUE(tokens[0].node->Equal(Num(1e0))); 45 | } 46 | 47 | std::default_random_engine eng( 48 | static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count())); 49 | 50 | std::uniform_real_distribution unif; 51 | 52 | for (int i = 0; i < 100; ++i) { 53 | double d = unif(eng); 54 | std::string expected = tomsolver::ToString(d); 55 | deque tokens = internal::ParseFunctions::ParseToTokens(expected); 56 | ASSERT_EQ(expected, tokens[0].node->ToString()); 57 | } 58 | } 59 | 60 | TEST(Parse, IllegalChar) { 61 | MemoryLeakDetection mld; 62 | 63 | try { 64 | deque tokens = internal::ParseFunctions::ParseToTokens("1#+2"); 65 | FAIL(); 66 | } catch (const SingleParseError &err) { 67 | cout << err.what() << endl; 68 | ASSERT_EQ(err.GetPos(), 0); 69 | } 70 | 71 | try { 72 | deque tokens = 73 | internal::ParseFunctions::ParseToTokens("a*cos(x1) + 
b*cos(x1-x2) + c*cos(?x1-x2-x3)"); 74 | FAIL(); 75 | } catch (const SingleParseError &err) { 76 | cout << err.what() << endl; 77 | ASSERT_EQ(err.GetPos(), 33); 78 | } 79 | } 80 | 81 | TEST(Parse, PositiveNegative) { 82 | MemoryLeakDetection mld; 83 | 84 | { 85 | deque tokens = internal::ParseFunctions::ParseToTokens("1/+2"); 86 | ASSERT_TRUE(tokens[0].node->Equal(Num(1))); 87 | ASSERT_TRUE(tokens[1].node->Equal(internal::Operator(MathOperator::MATH_DIVIDE))); 88 | ASSERT_TRUE(tokens[2].node->Equal(internal::Operator(MathOperator::MATH_POSITIVE))); 89 | ASSERT_TRUE(tokens[3].node->Equal(Num(2))); 90 | } 91 | 92 | { 93 | deque tokens = internal::ParseFunctions::ParseToTokens("1/-2"); 94 | ASSERT_TRUE(tokens[0].node->Equal(Num(1))); 95 | ASSERT_TRUE(tokens[1].node->Equal(internal::Operator(MathOperator::MATH_DIVIDE))); 96 | ASSERT_TRUE(tokens[2].node->Equal(internal::Operator(MathOperator::MATH_NEGATIVE))); 97 | ASSERT_TRUE(tokens[3].node->Equal(Num(2))); 98 | } 99 | 100 | { 101 | deque tokens = internal::ParseFunctions::ParseToTokens("-1--2"); 102 | ASSERT_TRUE(tokens[0].node->Equal(internal::Operator(MathOperator::MATH_NEGATIVE))); 103 | ASSERT_TRUE(tokens[1].node->Equal(Num(1))); 104 | ASSERT_TRUE(tokens[2].node->Equal(internal::Operator(MathOperator::MATH_SUB))); 105 | ASSERT_TRUE(tokens[3].node->Equal(internal::Operator(MathOperator::MATH_NEGATIVE))); 106 | ASSERT_TRUE(tokens[4].node->Equal(Num(2))); 107 | } 108 | } 109 | 110 | TEST(Parse, PostOrder) { 111 | MemoryLeakDetection mld; 112 | 113 | { 114 | deque tokens = internal::ParseFunctions::ParseToTokens("1*(2-3)"); 115 | 116 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 117 | 118 | ASSERT_TRUE(postOrder[0].node->Equal(Num(1))); 119 | ASSERT_TRUE(postOrder[1].node->Equal(Num(2))); 120 | ASSERT_TRUE(postOrder[2].node->Equal(Num(3))); 121 | ASSERT_TRUE(postOrder[3].node->Equal(internal::Operator(MathOperator::MATH_SUB))); 122 | 
ASSERT_TRUE(postOrder[4].node->Equal(internal::Operator(MathOperator::MATH_MULTIPLY))); 123 | } 124 | 125 | { 126 | deque tokens = internal::ParseFunctions::ParseToTokens("1*2-3"); 127 | 128 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 129 | 130 | ASSERT_TRUE(postOrder[0].node->Equal(Num(1))); 131 | ASSERT_TRUE(postOrder[1].node->Equal(Num(2))); 132 | ASSERT_TRUE(postOrder[2].node->Equal(internal::Operator(MathOperator::MATH_MULTIPLY))); 133 | ASSERT_TRUE(postOrder[3].node->Equal(Num(3))); 134 | ASSERT_TRUE(postOrder[4].node->Equal(internal::Operator(MathOperator::MATH_SUB))); 135 | } 136 | 137 | { 138 | deque tokens = internal::ParseFunctions::ParseToTokens("1-2-3"); 139 | 140 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 141 | 142 | ASSERT_TRUE(postOrder[0].node->Equal(Num(1))); 143 | ASSERT_TRUE(postOrder[1].node->Equal(Num(2))); 144 | ASSERT_TRUE(postOrder[2].node->Equal(internal::Operator(MathOperator::MATH_SUB))); 145 | ASSERT_TRUE(postOrder[3].node->Equal(Num(3))); 146 | ASSERT_TRUE(postOrder[4].node->Equal(internal::Operator(MathOperator::MATH_SUB))); 147 | } 148 | 149 | { 150 | deque tokens = internal::ParseFunctions::ParseToTokens("2^3^4"); 151 | 152 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 153 | 154 | ASSERT_TRUE(postOrder[0].node->Equal(Num(2))); 155 | ASSERT_TRUE(postOrder[1].node->Equal(Num(3))); 156 | ASSERT_TRUE(postOrder[2].node->Equal(Num(4))); 157 | ASSERT_TRUE(postOrder[3].node->Equal(internal::Operator(MathOperator::MATH_POWER))); 158 | ASSERT_TRUE(postOrder[4].node->Equal(internal::Operator(MathOperator::MATH_POWER))); 159 | } 160 | } 161 | 162 | TEST(Parse, PostOrderError) { 163 | MemoryLeakDetection mld; 164 | 165 | { 166 | deque tokens = internal::ParseFunctions::ParseToTokens("1*2-3)"); 167 | 168 | try { 169 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 170 | FAIL(); 171 | } catch (const ParseError &err) { 172 | cout << 
err.what() << endl; 173 | } 174 | } 175 | 176 | { 177 | deque tokens = internal::ParseFunctions::ParseToTokens("(1*2-3"); 178 | 179 | try { 180 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 181 | FAIL(); 182 | } catch (const ParseError &err) { 183 | cout << err.what() << endl; 184 | } 185 | } 186 | } 187 | 188 | TEST(Parse, BuildTree) { 189 | MemoryLeakDetection mld; 190 | { 191 | deque tokens = internal::ParseFunctions::ParseToTokens("1*(2-3)"); 192 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 193 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 194 | ASSERT_EQ(node->ToString(), "1*(2-3)"); 195 | node->CheckParent(); 196 | } 197 | 198 | { 199 | deque tokens = internal::ParseFunctions::ParseToTokens("1*2-3"); 200 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 201 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 202 | ASSERT_EQ(node->ToString(), "1*2-3"); 203 | node->CheckParent(); 204 | } 205 | 206 | { 207 | deque tokens = internal::ParseFunctions::ParseToTokens("x^2-y^2-7"); 208 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 209 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 210 | ASSERT_EQ(node->ToString(), "x^2-y^2-7"); 211 | node->CheckParent(); 212 | } 213 | } 214 | 215 | TEST(Parse, Mix) { 216 | MemoryLeakDetection mld; 217 | 218 | { 219 | deque tokens = 220 | internal::ParseFunctions::ParseToTokens("a*cos(x1) + b*cos(x1-x2) + c*cos(x1-x2-x3)"); 221 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 222 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 223 | node->CheckParent(); 224 | 225 | Node expected = Var("a") * cos(Var("x1")) + Var("b") * cos(Var("x1") - Var("x2")) + 226 | Var("c") * cos(Var("x1") - Var("x2") - Var("x3")); 227 | ASSERT_TRUE(node->Equal(expected)); 228 | } 229 | 230 | try { 231 | deque tokens = 
internal::ParseFunctions::ParseToTokens("x(1)*cos(2)"); 232 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 233 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 234 | FAIL(); 235 | } catch (const ParseError &err) { 236 | cout << err.what() << endl; 237 | } 238 | 239 | try { 240 | deque tokens = 241 | internal::ParseFunctions::ParseToTokens("x(1)*cos(x(2)) + x(2)*sin(x(1)) - 0.5"); 242 | auto postOrder = internal::ParseFunctions::InOrderToPostOrder(tokens); 243 | auto node = internal::ParseFunctions::BuildExpressionTree(postOrder); 244 | FAIL(); 245 | } catch (const ParseError &err) { 246 | cout << err.what() << endl; 247 | } 248 | } -------------------------------------------------------------------------------- /tests/power_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "memory_leak_detection.h" 4 | 5 | #include 6 | 7 | #include 8 | 9 | using namespace tomsolver; 10 | 11 | using std::cout; 12 | using std::endl; 13 | 14 | TEST(Power, Base) { 15 | MemoryLeakDetection mld; 16 | 17 | Node n = Num(2) ^ Num(3); 18 | ASSERT_EQ(n->ToString(), "2^3"); 19 | ASSERT_DOUBLE_EQ(n->Vpa(), 8); 20 | 21 | // CATION! 22 | Node n2 = Num(2) ^ Num(3) ^ Num(2); 23 | ASSERT_EQ(n2->ToString(), "(2^3)^2"); 24 | ASSERT_DOUBLE_EQ(n2->Vpa(), 64); 25 | 26 | // CATION! 
27 | Node n3 = Num(2) ^ (Num(3) ^ Num(2)); 28 | ASSERT_EQ(n3->ToString(), "2^(3^2)"); 29 | ASSERT_DOUBLE_EQ(n3->Vpa(), 512); 30 | } 31 | 32 | TEST(Power, Decimal) { 33 | MemoryLeakDetection mld; 34 | 35 | Node n3 = Num(1.1) ^ Num(0.25); 36 | ASSERT_DOUBLE_EQ(n3->Vpa(), std::pow(1.1, 0.25)); 37 | } -------------------------------------------------------------------------------- /tests/random_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "helper.h" 4 | #include "memory_leak_detection.h" 5 | 6 | 7 | #include 8 | 9 | #include 10 | 11 | using namespace tomsolver; 12 | 13 | using std::cout; 14 | using std::endl; 15 | 16 | TEST(Node, Random) { 17 | MemoryLeakDetection mld; 18 | 19 | int maxCount = 10; 20 | 21 | auto seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); 22 | cout << "seed = " << seed << endl; 23 | std::default_random_engine eng(seed); 24 | 25 | std::uniform_int_distribution unifCount(1, maxCount); 26 | for (int i = 0; i < 10; ++i) { 27 | int count = unifCount(eng); 28 | 29 | auto pr = CreateRandomExpresionTree(count); 30 | Node &node = pr.first; 31 | double v = pr.second; 32 | 33 | node->CheckParent(); 34 | 35 | double result = node->Vpa(); 36 | cout << node->ToString() << endl; 37 | cout << "\t result = " << result << endl; 38 | cout << "\t expected = " << v << endl; 39 | ASSERT_DOUBLE_EQ(result, v); 40 | 41 | // clone 42 | Node n2 = Clone(node); 43 | ASSERT_DOUBLE_EQ(result, n2->Vpa()); 44 | n2->CheckParent(); 45 | 46 | cout << endl; 47 | } 48 | } 49 | 50 | TEST(Clone, DoNotStackOverFlow) { 51 | MemoryLeakDetection mld; 52 | 53 | // 构造一个随机的长表达式 54 | auto pr = CreateRandomExpresionTree(10000); 55 | Node &node = pr.first; 56 | 57 | // clone,不应爆栈 58 | Node n2 = Clone(node); 59 | 60 | ASSERT_TRUE(node->Equal(n2)); 61 | } 62 | 63 | TEST(Vpa, DoNotStackOverFlow) { 64 | MemoryLeakDetection mld; 65 | 66 | // 构造一个随机的长表达式 67 | auto pr = 
CreateRandomExpresionTree(10000); 68 | Node &node = pr.first; 69 | double v = pr.second; 70 | 71 | double result = node->Vpa(); 72 | 73 | cout << "\t result = " << result << endl; 74 | cout << "\t expected = " << v << endl; 75 | ASSERT_DOUBLE_EQ(result, v); 76 | } 77 | 78 | TEST(ToString, DoNotStackOverFlow) { 79 | MemoryLeakDetection mld; 80 | 81 | // 构造一个随机的长表达式 82 | auto pr = CreateRandomExpresionTree(10000); 83 | Node &node = pr.first; 84 | 85 | // 输出表达式字符串,不应爆栈 86 | std::string s = node->ToString(); 87 | } -------------------------------------------------------------------------------- /tests/simplify_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "helper.h" 5 | #include "memory_leak_detection.h" 6 | 7 | #include 8 | 9 | #include 10 | 11 | using namespace tomsolver; 12 | 13 | using std::cout; 14 | using std::endl; 15 | 16 | TEST(Simplify, Base) { 17 | MemoryLeakDetection mld; 18 | 19 | Node n = sin(Num(0)); 20 | Simplify(n); 21 | 22 | ASSERT_EQ(n->ToString(), "0"); 23 | 24 | Node n2 = Num(1) + Num(2) * Num(3); 25 | Simplify(n2); 26 | 27 | ASSERT_EQ(n2->ToString(), "7"); 28 | 29 | ASSERT_TRUE(n2->Equal(Num(7))); 30 | } 31 | 32 | TEST(Simplify, Add) { 33 | MemoryLeakDetection mld; 34 | 35 | { 36 | Node n = Var("x") + Num(0); 37 | Simplify(n); 38 | ASSERT_EQ(n->ToString(), "x"); 39 | n->CheckParent(); 40 | } 41 | 42 | { 43 | Node n = Num(0) + Var("x"); 44 | Simplify(n); 45 | ASSERT_EQ(n->ToString(), "x"); 46 | n->CheckParent(); 47 | } 48 | } 49 | 50 | TEST(Simplify, Multiply) { 51 | MemoryLeakDetection mld; 52 | 53 | { 54 | Node n = Var("x") * Num(1) * Var("y") * Var("z"); 55 | Simplify(n); 56 | ASSERT_EQ(n->ToString(), "x*y*z"); 57 | n->CheckParent(); 58 | } 59 | 60 | { 61 | Node n = cos(Var("x")) * Num(1); 62 | Simplify(n); 63 | ASSERT_EQ(n->ToString(), "cos(x)"); 64 | n->CheckParent(); 65 | } 66 | 67 | { 68 | Node n = Num(1) * Var("x") * Num(0) + Num(0) * Var("y"); 69 | 
Simplify(n); 70 | ASSERT_EQ(n->ToString(), "0"); 71 | n->CheckParent(); 72 | } 73 | } 74 | 75 | TEST(Simplify, DoNotStackOverFlow) { 76 | MemoryLeakDetection mld; 77 | 78 | // 构造一个随机的长表达式 79 | auto pr = CreateRandomExpresionTree(100000); 80 | Node &node = pr.first; 81 | 82 | Simplify(node); 83 | } -------------------------------------------------------------------------------- /tests/solve_base_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | 8 | #include "memory_leak_detection.h" 9 | 10 | #include 11 | 12 | #include 13 | 14 | using namespace tomsolver; 15 | 16 | using std::cout; 17 | using std::endl; 18 | 19 | TEST(SolveBase, FindAlphaByArmijo) { 20 | MemoryLeakDetection mld; 21 | 22 | auto g = [](const Vec &x) -> Vec { 23 | return {pow(x[0] - 4, 4) + pow(x[1] - 3, 2) + 4 * pow(x[2] + 5, 4)}; 24 | }; 25 | 26 | auto dg = [](const Vec &x) -> Vec { 27 | return {4 * pow(x[0] - 4, 3), 2 * (x[1] - 3), 16 * pow(x[2] + 5, 3)}; 28 | }; 29 | 30 | Vec x{4, 2, -1}; 31 | Vec d = -Vec{0, -2, 1024}; 32 | double alpha = Armijo(x, d, g, dg); 33 | cout << alpha << endl; 34 | 35 | // FIXME: not match got results 36 | // double expected = 0.003866; 37 | } 38 | 39 | TEST(SolveBase, FindAlpha) { 40 | 41 | auto g = [](const Vec &x) -> Vec { 42 | return {pow(x[0] - 4, 4), pow(x[1] - 3, 2), 4 * pow(x[2] + 5, 4)}; 43 | }; 44 | 45 | Vec x{4, 2, -1}; 46 | Vec d = -Vec{0, -2, 1024}; 47 | double alpha = FindAlpha(x, d, g); 48 | cout << alpha << endl; 49 | 50 | // FIXME: not match got results 51 | // double expected = 0.003866; 52 | } 53 | 54 | TEST(SolveBase, Base) { 55 | // the example of this test is from: https://zhuanlan.zhihu.com/p/136889381 56 | 57 | MemoryLeakDetection mld; 58 | 59 | std::setlocale(LC_ALL, ".UTF8"); 60 | 61 | /* 62 | 以一个平面三轴机器人为例,运动学方程为 63 | a = 0.425; b = 0.39243; c=0.109; 64 | y = [ a*cos(x(1)) + b*cos(x(1)-x(2)) + c*cos(x(1)-x(2)-x(3)), 65 | a*sin(x(1)) 
+ b*sin(x(1)-x(2)) + c*sin(x(1)-x(2)-x(3)), 66 | x(1)-x(2)-x(3) ]; 67 | */ 68 | Node f1 = Var("a") * cos(Var("x1")) + Var("b") * cos(Var("x1") - Var("x2")) + 69 | Var("c") * cos(Var("x1") - Var("x2") - Var("x3")); 70 | Node f2 = Var("a") * sin(Var("x1")) + Var("b") * sin(Var("x1") - Var("x2")) + 71 | Var("c") * sin(Var("x1") - Var("x2") - Var("x3")); 72 | Node f3 = Var("x1") - Var("x2") - Var("x3"); 73 | 74 | SymVec f{Clone(f1), Clone(f2), Clone(f3)}; 75 | 76 | // 目标位置为:[0.5 0.4 0] 77 | SymVec b{Num(0.5), Num(0.4), Num(0)}; 78 | SymVec equations = f - b; 79 | equations.Subs(VarsTable{{"a", 0.425}, {"b", 0.39243}, {"c", 0.109}}); 80 | 81 | // 初值表 82 | VarsTable varsTable{{"x1", 1}, {"x2", 1}, {"x3", 1}}; 83 | 84 | // 期望值 85 | VarsTable expected{{"x1", 1.5722855035930956}, {"x2", 1.6360330989069252}, {"x3", -0.0637475947386077}}; 86 | 87 | // Newton-Raphson方法 88 | { 89 | VarsTable got = SolveByNewtonRaphson(equations, varsTable); 90 | cout << got << endl; 91 | 92 | ASSERT_EQ(got, expected); 93 | } 94 | 95 | // LM方法 96 | { 97 | VarsTable got = SolveByLM(equations, varsTable); 98 | cout << got << endl; 99 | 100 | ASSERT_EQ(got, expected); 101 | } 102 | } 103 | 104 | TEST(SolveBase, IndeterminateEquation) { 105 | MemoryLeakDetection mld; 106 | 107 | std::setlocale(LC_ALL, ".UTF8"); 108 | 109 | SymVec f = { 110 | "cos(x1) + cos(x1-x2) + cos(x1-x2-x3) - 1"_f, 111 | "sin(x1) + sin(x1-x2) + sin(x1-x2-x3) + 2"_f, 112 | }; 113 | 114 | // 不定方程,应该抛出异常 115 | try { 116 | VarsTable got = Solve(f); 117 | FAIL(); 118 | } catch (const MathError &e) { 119 | cout << e.what() << endl; 120 | } 121 | 122 | // 设置为允许不定方程 123 | Config::Get().allowIndeterminateEquation = true; 124 | 125 | // 结束时恢复设置 126 | std::shared_ptr defer(nullptr, [](auto) { 127 | Config::Get().Reset(); 128 | }); 129 | 130 | VarsTable got = Solve(f); 131 | cout << got << endl; 132 | } -------------------------------------------------------------------------------- /tests/solve_test.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "memory_leak_detection.h" 7 | 8 | #include 9 | 10 | #include 11 | 12 | using namespace tomsolver; 13 | 14 | using std::cout; 15 | using std::endl; 16 | 17 | TEST(Solve, Base) { 18 | // the example of this test is from: https://zhuanlan.zhihu.com/p/136889381 19 | 20 | MemoryLeakDetection mld; 21 | 22 | std::setlocale(LC_ALL, ".UTF8"); 23 | 24 | /* 25 | 以一个平面三轴机器人为例,运动学方程为 26 | a = 0.425; b = 0.39243; c=0.109; 27 | y = [ a*cos(x(1)) + b*cos(x(1)-x(2)) + c*cos(x(1)-x(2)-x(3)), 28 | a*sin(x(1)) + b*sin(x(1)-x(2)) + c*sin(x(1)-x(2)-x(3)), 29 | x(1)-x(2)-x(3) ]; 30 | */ 31 | SymVec f = { 32 | "a*cos(x1) + b*cos(x1-x2) + c*cos(x1-x2-x3)"_f, 33 | "a*sin(x1) + b*sin(x1-x2) + c*sin(x1-x2-x3)"_f, 34 | "x1-x2-x3"_f, 35 | }; 36 | 37 | // 目标位置为:[0.5 0.4 0] 38 | SymVec b{Num(0.5), Num(0.4), Num(0)}; 39 | SymVec equations = f - b; 40 | equations.Subs(VarsTable{{"a", 0.425}, {"b", 0.39243}, {"c", 0.109}}); 41 | 42 | // 期望值 43 | VarsTable expected{{"x1", 1.5722855035930956}, {"x2", 1.6360330989069252}, {"x3", -0.0637475947386077}}; 44 | 45 | // Newton-Raphson方法 46 | { 47 | Config::Get().nonlinearMethod = NonlinearMethod::NEWTON_RAPHSON; 48 | 49 | // 结束时恢复设置 50 | std::shared_ptr defer(nullptr, [&](...) { 51 | Config::Get().Reset(); 52 | }); 53 | 54 | VarsTable got = Solve(equations); 55 | cout << got << endl; 56 | 57 | ASSERT_EQ(got, expected); 58 | } 59 | 60 | // LM方法 61 | { 62 | Config::Get().nonlinearMethod = NonlinearMethod::LM; 63 | 64 | // 结束时恢复设置 65 | std::shared_ptr defer(nullptr, [&](...) 
{ 66 | Config::Get().Reset(); 67 | }); 68 | 69 | VarsTable got = Solve(equations); 70 | cout << got << endl; 71 | 72 | ASSERT_EQ(got, expected); 73 | } 74 | } 75 | 76 | TEST(Solve, Case1) { 77 | MemoryLeakDetection mld; 78 | 79 | std::setlocale(LC_ALL, ".UTF8"); 80 | 81 | /* 82 | 83 | Matlab code: 84 | from: https://ww2.mathworks.cn/help/optim/ug/fsolve.html 85 | 86 | root2d.m: 87 | function F = root2d(x) 88 | F(1) = exp(-exp(-(x(1)+x(2)))) - x(2)*(1+x(1)^2); 89 | F(2) = x(1)*cos(x(2)) + x(2)*sin(x(1)) - 0.5; 90 | end 91 | 92 | root2d_solve.m: 93 | format long 94 | fun = @root2d; 95 | x0 = [0,0]; 96 | x = fsolve(fun,x0) 97 | 98 | result: 99 | x = 100 | 101 | 0.353246561920553 0.606082026502285 102 | 103 | 104 | */ 105 | 106 | // 设置初值为0.0 107 | Config::Get().initialValue = 0.0; 108 | 109 | // 设置容差为1.0e-6,因为Matlab的默认容差是1.0e-6 110 | Config::Get().epsilon = 1.0e-6; 111 | 112 | // 结束时恢复设置 113 | std::shared_ptr defer(nullptr, [&](...) { 114 | Config::Get().Reset(); 115 | }); 116 | 117 | // 构造方程组 118 | SymVec f = { 119 | "exp(-exp(-(x1 + x2))) - x2 * (1 + x1 ^ 2)"_f, 120 | "x1 * cos(x2) + x2 * sin(x1) - 0.5"_f, 121 | }; 122 | 123 | // 求解,结果保存到ans 124 | VarsTable ans = Solve(f); 125 | 126 | // 打印出ans 127 | cout << ans << endl; 128 | 129 | // 单独获取变量的值 130 | cout << "x1 = " << ans["x1"] << endl; 131 | cout << "x2 = " << ans["x2"] << endl; 132 | 133 | ASSERT_EQ(ans, VarsTable({{"x1", 0.353246561920553}, {"x2", 0.606082026502285}})); 134 | } 135 | 136 | TEST(Solve, Case2) { 137 | MemoryLeakDetection mld; 138 | 139 | std::setlocale(LC_ALL, ".UTF8"); 140 | 141 | /* 142 | 143 | Translate from Matlab code: 144 | 145 | fun = @(x)x*x*x - [1,2;3,4]; 146 | x0 = ones(2); 147 | format long; 148 | fsolve(fun,x0) 149 | 150 | */ 151 | Config::Get().nonlinearMethod = NonlinearMethod::LM; 152 | 153 | // 结束时恢复设置 154 | std::shared_ptr defer(nullptr, [](auto) { 155 | Config::Get().Reset(); 156 | }); 157 | 158 | // 构造符号矩阵: [a b; c d] 159 | SymMat X({{Var("a"), Var("b")}, {Var("c"), Var("d")}}); 
160 | 161 | Mat B{{1, 2}, {3, 4}}; 162 | 163 | // 计算出矩阵X*X*X-B 164 | auto FMat = X * X * X - B; 165 | 166 | // 提取矩阵的每一个元素,构成4个方程组成的符号向量 167 | auto F = FMat.ToSymVecOneByOne(); 168 | 169 | cout << F << endl; 170 | 171 | // 把符号向量F作为方程组进行求解 172 | VarsTable ans = Solve(F); 173 | 174 | cout << ans << endl; 175 | 176 | VarsTable expected{ 177 | {"a", -0.129148906397607}, {"b", 0.8602157139938529}, {"c", 1.2903235709907794}, {"d", 1.1611746645931726}}; 178 | 179 | ASSERT_EQ(ans, expected); 180 | } 181 | 182 | TEST(Solve, Case3) { 183 | MemoryLeakDetection mld; 184 | 185 | std::setlocale(LC_ALL, ".UTF8"); 186 | 187 | Config::Get().nonlinearMethod = NonlinearMethod::LM; 188 | 189 | // 结束时恢复设置 190 | std::shared_ptr defer(nullptr, [&](...) { 191 | Config::Get().Reset(); 192 | }); 193 | 194 | SymVec f{ 195 | Parse("a/(b^2)-c/(d^2)"), 196 | Parse("129.56108*b-(a/(b^2)+1/a-2*b/(a^2))"), 197 | Parse("129.56108*d-(d/(c^2)-c/(d^2)-1/a)"), 198 | Parse("5*exp(1)-7-(2/3*pi*a^2*b+((sqrt(3)*c^2)/(3*sqrt(c^2/3+d^2))+a-c)^2*pi*d^2/(c^2/3+d^2))"), 199 | }; 200 | 201 | f.Subs(VarsTable{{"pi", PI}}); 202 | 203 | cout << f << endl; 204 | 205 | auto ans = Solve(f); 206 | 207 | cout << ans << endl; 208 | } 209 | 210 | TEST(Solve, Case4) { 211 | MemoryLeakDetection mld; 212 | 213 | std::setlocale(LC_ALL, ".UTF8"); 214 | 215 | SymVec f{ 216 | Parse("x^2+y^2-25"), 217 | Parse("x^2-y^2-7"), 218 | }; 219 | 220 | cout << f << endl; 221 | 222 | VarsTable initialValues{{"x", 4.1}, {"y", 3.1}}; 223 | 224 | auto ans = Solve(f, initialValues); 225 | 226 | cout << ans << endl; 227 | } 228 | -------------------------------------------------------------------------------- /tests/subs_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "memory_leak_detection.h" 5 | 6 | #include 7 | 8 | #include 9 | 10 | using namespace tomsolver; 11 | 12 | TEST(Subs, Base) { 13 | MemoryLeakDetection mld; 14 | 15 | Node n = Var("x"); 16 | 
ASSERT_EQ(Subs(n, "x", Var("y"))->ToString(), "y"); 17 | 18 | ASSERT_EQ(Subs(n, "x", Num(100))->ToString(), "100"); 19 | 20 | ASSERT_DOUBLE_EQ(Subs(std::move(n), "x", Num(99))->Vpa(), 99.0); 21 | ASSERT_EQ(n, nullptr); 22 | } 23 | 24 | TEST(Subs, Combine) { 25 | MemoryLeakDetection mld; 26 | 27 | { 28 | // x*y+sin(x) 29 | Node n = Var("x") * Var("y") + sin(Var("x")); 30 | 31 | n = Subs(std::move(n), "x", Var("x") + Num(1)); 32 | 33 | ASSERT_EQ(n->ToString(), "(x+1)*y+sin(x+1)"); 34 | } 35 | 36 | { 37 | // r*sin(x+y) 38 | Node n = Var("r") * sin(Var("x") + Var("y")); 39 | 40 | // -> 100*sin(360deg+30deg) == 50 41 | n = Subs(std::move(n), "x", Num(radians(360.0))); 42 | n = Subs(std::move(n), "y", Num(radians(30.0))); 43 | n = Subs(std::move(n), "r", Num(100)); 44 | 45 | ASSERT_DOUBLE_EQ(n->Vpa(), 50.0); 46 | } 47 | } 48 | 49 | TEST(Subs, Multiple) { 50 | MemoryLeakDetection mld; 51 | 52 | { 53 | // x*y+sin(x) 54 | Node n = Var("x") * Var("y") + sin(Var("x")); 55 | 56 | // 交换x y 57 | n = Subs(std::move(n), {"x", "y"}, {Var("y"), Var("x")}); 58 | ASSERT_EQ(n->ToString(), "y*x+sin(y)"); 59 | 60 | // x -> cos(y) 61 | n = Subs(std::move(n), {"x"}, {cos(Var("y"))}); 62 | ASSERT_EQ(n->ToString(), "y*cos(y)+sin(y)"); 63 | } 64 | } -------------------------------------------------------------------------------- /tests/symmat_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "memory_leak_detection.h" 5 | 6 | #include 7 | 8 | #include 9 | 10 | using namespace tomsolver; 11 | 12 | using std::cout; 13 | using std::endl; 14 | 15 | TEST(SymMat, Base) { 16 | MemoryLeakDetection mld; 17 | 18 | SymVec a{Var("a"), Var("b"), Var("c")}; 19 | 20 | cout << a.ToString() << endl; 21 | 22 | Node x = Var("x"); 23 | Node y = Var("y"); 24 | Node f1 = (sin(x) ^ Num(2)) + x * y + y - Num(3); 25 | Node f2 = Num(4) * x + (y ^ Num(2)); 26 | 27 | SymVec f{std::move(f1), std::move(f2)}; 28 | 29 | cout << f.ToString() 
<< endl; 30 | } 31 | 32 | TEST(SymMat, Multiply) { 33 | MemoryLeakDetection mld; 34 | 35 | SymMat X = {{Var("a"), Var("b")}, {Var("c"), Var("d")}}; 36 | 37 | SymMat ret = X * X; 38 | 39 | SymMat expected = {{Var("a") * Var("a") + Var("b") * Var("c"), Var("a") * Var("b") + Var("b") * Var("d")}, 40 | {Var("c") * Var("a") + Var("d") * Var("c"), Var("c") * Var("b") + Var("d") * Var("d")}}; 41 | 42 | cout << ret << endl; 43 | 44 | ASSERT_EQ(ret, expected); 45 | } -------------------------------------------------------------------------------- /tests/to_string_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "memory_leak_detection.h" 4 | 5 | #include 6 | 7 | #ifdef WIN32 8 | #undef max 9 | #undef min 10 | #endif 11 | 12 | using namespace tomsolver; 13 | 14 | using std::cout; 15 | using std::endl; 16 | 17 | TEST(ToString, Base) { 18 | MemoryLeakDetection mld; 19 | 20 | ASSERT_EQ(ToString(0.0), "0"); 21 | ASSERT_EQ(ToString(1.0), "1"); 22 | ASSERT_EQ(ToString(0.1), "0.1"); 23 | ASSERT_EQ(ToString(0.12), "0.12"); 24 | ASSERT_EQ(ToString(0.123456789123450), "0.12345678912345"); 25 | ASSERT_EQ(ToString(1234567890.0), "1234567890"); 26 | 27 | // 15位 28 | ASSERT_EQ(ToString(123456789012345), "123456789012345"); 29 | ASSERT_EQ(ToString(-123456789012345), "-123456789012345"); 30 | 31 | // 16位 32 | ASSERT_EQ(ToString(1234567890123456), "1234567890123456"); 33 | ASSERT_EQ(ToString(-1234567890123456), "-1234567890123456"); 34 | 35 | ASSERT_EQ(ToString(1.0e0), "1"); 36 | ASSERT_EQ(ToString(1e0), "1"); 37 | ASSERT_EQ(ToString(1e1), "10"); 38 | ASSERT_EQ(ToString(1e15), "1000000000000000"); 39 | ASSERT_EQ(ToString(1e16), "1e+16"); 40 | ASSERT_EQ(ToString(1.0e16), "1e+16"); 41 | ASSERT_EQ(ToString(1e-16), "9.9999999999999998e-17"); 42 | ASSERT_EQ(ToString(1.0e-16), "9.9999999999999998e-17"); 43 | 44 | ASSERT_EQ(ToString(std::numeric_limits::min()), "2.2250738585072014e-308"); 45 | 
ASSERT_EQ(ToString(std::numeric_limits::max()), "1.7976931348623157e+308"); 46 | ASSERT_EQ(ToString(std::numeric_limits::denorm_min()), "4.9406564584124654e-324"); 47 | ASSERT_EQ(ToString(std::numeric_limits::lowest()), "-1.7976931348623157e+308"); 48 | } --------------------------------------------------------------------------------