├── .clang-format ├── .clang-format-check.sh ├── .clang-format-common.sh ├── .clang-format-fix.sh ├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── doc ├── CMakeLists.txt └── Doxyfile.in ├── include └── DDMPC │ ├── Dataset.h │ ├── MathUtils.h │ ├── StateEq.h │ ├── TorchUtils.h │ └── Training.h ├── msg ├── Dataset.msg └── StandardScaler.msg ├── package.xml ├── samples └── CMakeLists.txt ├── src ├── CMakeLists.txt ├── Dataset.cpp ├── StateEq.cpp └── Training.cpp ├── srv ├── GenerateDataset.srv └── RunSimOnce.srv └── tests ├── CMakeLists.txt ├── data └── .gitignore ├── scripts ├── plotTestMpcPushWalk.py ├── sampleClientSimTestMpcCart.py ├── simTestMpcCart.py └── simTestMpcCartMujoco.py ├── src ├── TestDataset.cpp ├── TestMathUtils.cpp ├── TestMpcCart.cpp ├── TestMpcCartWalk.cpp ├── TestMpcOscillator.cpp ├── TestMpcPushWalk.cpp ├── TestStateEq.cpp ├── TestTorchUtils.cpp └── TestTraining.cpp └── test ├── TestMpcCart.test └── TestMpcCartWalk.test /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | AccessModifierOffset: -2 4 | AlignAfterOpenBracket: Align 5 | AlignConsecutiveAssignments: false 6 | AlignConsecutiveDeclarations: false 7 | AlignEscapedNewlinesLeft: true 8 | AlignOperands: true 9 | AlignTrailingComments: false 10 | AllowAllParametersOfDeclarationOnNextLine: false 11 | AllowShortBlocksOnASingleLine: false 12 | AllowShortCaseLabelsOnASingleLine: false 13 | AllowShortFunctionsOnASingleLine: Empty 14 | AllowShortIfStatementsOnASingleLine: true 15 | AllowShortLoopsOnASingleLine: true 16 | AlwaysBreakAfterDefinitionReturnType: None 17 | AlwaysBreakAfterReturnType: None 18 | AlwaysBreakBeforeMultilineStrings: false 19 | AlwaysBreakTemplateDeclarations: true 20 | BinPackArguments: true 21 | BinPackParameters: false 22 | BreakBeforeBinaryOperators: NonAssignment 23 | BreakBeforeBraces: Allman 24 | BreakBeforeInheritanceComma: false 25 | 
BreakBeforeTernaryOperators: true 26 | BreakConstructorInitializersBeforeComma: false 27 | BreakConstructorInitializers: BeforeColon 28 | BreakAfterJavaFieldAnnotations: false 29 | BreakStringLiterals: true 30 | ColumnLimit: 120 31 | CommentPragmas: '^ IWYU pragma:' 32 | CompactNamespaces: false 33 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 34 | ConstructorInitializerIndentWidth: 0 35 | ContinuationIndentWidth: 4 36 | Cpp11BracedListStyle: true 37 | DerivePointerAlignment: false 38 | DisableFormat: false 39 | ExperimentalAutoDetectBinPacking: false 40 | FixNamespaceComments: true 41 | ForEachMacros: 42 | - foreach 43 | - Q_FOREACH 44 | - BOOST_FOREACH 45 | IncludeBlocks: Preserve 46 | IncludeCategories: 47 | - Regex: '^( $tmpfile 11 | if ! [[ -z `diff $tmpfile $f` ]]; then 12 | echo "Wrong formatting in $f" 13 | out=1 14 | fi 15 | done 16 | rm -f $tmpfile 17 | if [[ $out -eq 1 ]]; then 18 | echo "You can run ./.clang-format-fix.sh to fix the issues locally, then commit/push again" 19 | fi 20 | exit $out 21 | -------------------------------------------------------------------------------- /.clang-format-common.sh: -------------------------------------------------------------------------------- 1 | # This script is meant to be sourced from other scripts 2 | 3 | # Check for clang-format, prefer 10 if available 4 | if [[ -x "$(command -v clang-format-10)" ]]; then 5 | clang_format=clang-format-10 6 | elif [[ -x "$(command -v clang-format)" ]]; then 7 | clang_format=clang-format 8 | else 9 | echo "clang-format or clang-format-10 must be installed" 10 | exit 1 11 | fi 12 | 13 | # Find all source files in the project minus those that are auto-generated or we do not maintain 14 | src_files=`find include src tests samples -name '*.cpp' -or -name '*.h' -or -name '*.hpp'` 15 | -------------------------------------------------------------------------------- /.clang-format-fix.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | readonly this_dir=`cd $(dirname $0); pwd` 4 | cd $this_dir 5 | source .clang-format-common.sh 6 | 7 | for f in ${src_files}; do 8 | $clang_format -style=file -i $f 9 | done 10 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI of DataDrivenMPC 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | pull_request: 8 | branches: 9 | - '**' 10 | schedule: 11 | - cron: '0 0 * * 0' 12 | 13 | jobs: 14 | 15 | clang-format: 16 | runs-on: ubuntu-20.04 17 | steps: 18 | - name: Checkout repository code 19 | uses: actions/checkout@v2 20 | - name: Install clang-format-10 21 | run: | 22 | sudo apt-get -qq update 23 | sudo apt-get -qq install clang-format-10 24 | - name: Run clang-format-check 25 | run: | 26 | ./.clang-format-check.sh 27 | 28 | build-and-test: 29 | strategy: 30 | fail-fast: false 31 | matrix: 32 | os: [ubuntu-20.04] 33 | build-type: [Debug, RelWithDebInfo] 34 | runs-on: ${{ matrix.os }} 35 | steps: 36 | - name: Set ROS version 37 | run: | 38 | if [ "${{ matrix.os }}" == "ubuntu-20.04" ] 39 | then 40 | echo "ROS_DISTRO=noetic" >> $GITHUB_ENV 41 | echo "PYTHON_PACKAGE_PREFIX=python3" >> $GITHUB_ENV 42 | else # if [ "${{ matrix.os }}" == "ubuntu-18.04" ] 43 | echo "ROS_DISTRO=melodic" >> $GITHUB_ENV 44 | echo "PYTHON_PACKAGE_PREFIX=python" >> $GITHUB_ENV 45 | fi 46 | - name: Install Python dependencies # Only for tests 47 | run: | 48 | set -e 49 | set -x 50 | sudo apt-get update -qq 51 | sudo apt install ${PYTHON_PACKAGE_PREFIX}-tk 52 | python -m pip install --upgrade pip 53 | # Due to numpy conflicts, it must be run before apt numpy is installed. 
54 | pip install numpy pybullet matplotlib 55 | - name: Install ROS 56 | run: | 57 | set -e 58 | set -x 59 | sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 60 | wget http://packages.ros.org/ros.key -O - | sudo apt-key add - 61 | sudo apt-get update -qq 62 | sudo apt-get install -qq ros-${ROS_DISTRO}-ros-base ${PYTHON_PACKAGE_PREFIX}-catkin-tools ${PYTHON_PACKAGE_PREFIX}-rosdep doxygen graphviz 63 | - name: Setup catkin workspace 64 | run: | 65 | set -e 66 | set -x 67 | mkdir -p ${GITHUB_WORKSPACE}/catkin_ws/src/ 68 | cd ${GITHUB_WORKSPACE}/catkin_ws 69 | set +x 70 | . /opt/ros/${ROS_DISTRO}/setup.bash 71 | set -x 72 | catkin init 73 | catkin build --limit-status-rate 0.1 74 | - name: Checkout repository code 75 | uses: actions/checkout@v2 76 | with: 77 | submodules: recursive 78 | path: catkin_ws/src/DataDrivenMPC 79 | - name: Checkout NMPC 80 | uses: actions/checkout@v2 81 | with: 82 | repository: isri-aist/NMPC 83 | submodules: recursive 84 | path: catkin_ws/src/NMPC 85 | - name: Download libtorch 86 | run: | 87 | set -e 88 | set -x 89 | cd ${GITHUB_WORKSPACE} 90 | wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcpu.zip 91 | unzip libtorch-cxx11-abi-shared-with-deps-1.11.0+cpu.zip 92 | - name: Rosdep install 93 | run: | 94 | set -e 95 | set -x 96 | cd ${GITHUB_WORKSPACE}/catkin_ws 97 | set +x 98 | . devel/setup.bash 99 | set -x 100 | sudo rosdep init 101 | rosdep update 102 | rosdep install -y -r --from-paths src --ignore-src 103 | - name: Catkin build 104 | run: | 105 | set -e 106 | set -x 107 | cd ${GITHUB_WORKSPACE}/catkin_ws 108 | set +x 109 | . 
devel/setup.bash 110 | set -x 111 | catkin build --limit-status-rate 0.1 -DCMAKE_BUILD_TYPE=${{ matrix.build-type }} -DINSTALL_DOCUMENTATION=ON \ 112 | -DLIBTORCH_PATH=${GITHUB_WORKSPACE}/libtorch 113 | - name: Run tests 114 | run: | 115 | set -e 116 | set -x 117 | cd ${GITHUB_WORKSPACE}/catkin_ws 118 | set +x 119 | . devel/setup.bash 120 | set -x 121 | catkin build --limit-status-rate 0.1 --catkin-make-args run_tests -- data_driven_mpc --no-deps 122 | catkin_test_results --verbose --all build 123 | - name: Upload documentation 124 | # Only run for one configuration and on origin master branch 125 | if: matrix.os == 'ubuntu-20.04' && matrix.build-type == 'RelWithDebInfo' && github.repository_owner == 'isri-aist' && github.ref == 'refs/heads/master' 126 | run: | 127 | set -e 128 | set -x 129 | cd ${GITHUB_WORKSPACE}/catkin_ws/src/DataDrivenMPC 130 | git config --global user.name "Masaki Murooka" 131 | git config --global user.email "m-murooka@aist.go.jp" 132 | git remote set-url origin "https://mmurooka:${{ secrets.CI_TOKEN }}@github.com/isri-aist/DataDrivenMPC" 133 | git fetch --depth=1 origin gh-pages:gh-pages 134 | git clean -dfx 135 | git checkout --quiet gh-pages 136 | rm -rf doxygen/ 137 | cp -r ${GITHUB_WORKSPACE}/catkin_ws/build/data_driven_mpc/doc/html/ doxygen 138 | git add doxygen 139 | git_status=`git status -s` 140 | if test -n "$git_status"; then 141 | git commit --quiet -m "Update Doxygen HTML files from commit ${{ github.sha }}" 142 | git push origin gh-pages 143 | else 144 | echo "Github pages documentation is already up-to-date." 
145 | fi 146 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.pyc 3 | 4 | GPATH 5 | GRTAGS 6 | GSYMS 7 | GTAGS 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1) 2 | project(data_driven_mpc) 3 | 4 | add_compile_options(-std=c++17) 5 | 6 | find_package(catkin REQUIRED COMPONENTS 7 | message_generation 8 | roscpp 9 | rospy 10 | rosbag 11 | nmpc_ddp 12 | ) 13 | 14 | # Eigen 15 | find_package(Eigen3 REQUIRED) 16 | include_directories(${EIGEN3_INCLUDE_DIR}) 17 | 18 | # Torch 19 | message("-- LIBTORCH_PATH: ${LIBTORCH_PATH}") 20 | list(APPEND CMAKE_PREFIX_PATH ${LIBTORCH_PATH}) 21 | find_package(Torch REQUIRED) 22 | 23 | add_message_files( 24 | FILES 25 | Dataset.msg 26 | StandardScaler.msg 27 | ) 28 | 29 | add_service_files( 30 | FILES 31 | RunSimOnce.srv 32 | GenerateDataset.srv 33 | ) 34 | 35 | generate_messages( 36 | DEPENDENCIES 37 | ) 38 | 39 | catkin_package( 40 | CATKIN_DEPENDS 41 | roscpp 42 | rospy 43 | rosbag 44 | nmpc_ddp 45 | DEPENDS EIGEN3 46 | INCLUDE_DIRS include 47 | LIBRARIES DDMPC 48 | ) 49 | 50 | add_subdirectory(src) 51 | 52 | add_subdirectory(samples) 53 | 54 | if(CATKIN_ENABLE_TESTING) 55 | add_subdirectory(tests) 56 | endif() 57 | 58 | OPTION(INSTALL_DOCUMENTATION "Generate and install the documentation" OFF) 59 | if(INSTALL_DOCUMENTATION) 60 | add_subdirectory(doc) 61 | endif() 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, AIST-CNRS JRL 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [DataDrivenMPC](https://github.com/isri-aist/DataDrivenMPC) 2 | Model predictive control based on data-driven model 3 | 4 | [![CI](https://github.com/isri-aist/DataDrivenMPC/actions/workflows/ci.yaml/badge.svg)](https://github.com/isri-aist/DataDrivenMPC/actions/workflows/ci.yaml) 5 | [![Documentation](https://img.shields.io/badge/doxygen-online-brightgreen?logo=read-the-docs&style=flat)](https://isri-aist.github.io/DataDrivenMPC/) 6 | 7 | ## Install 8 | 9 | ### Requirements 10 | - Compiler supporting C++17 11 | - Tested on `Ubuntu 20.04 / ROS Noetic` and `Ubuntu 18.04 / ROS Melodic` 12 | 13 | ### Dependencies 14 | This package depends on 15 | - [libtorch (PyTorch C++ Frontend)](https://pytorch.org/cppdocs/installing.html) 16 | - [NMPC](https://github.com/isri-aist/NMPC) 17 | 18 | Some tests depend on 19 | - [PyBullet](https://pybullet.org) 20 | 21 | ### Installation procedure 22 | It is assumed that ROS is installed. 23 | 24 | 1. Follow [the official instructions](https://pytorch.org/cppdocs/installing.html) to download and extract the zip file of libtorch. 25 | 26 | 2. Setup catkin workspace. 27 | ```bash 28 | $ mkdir -p ~/ros/ws_ddmpc/src 29 | $ cd ~/ros/ws_ddmpc 30 | $ wstool init src 31 | $ wstool set -t src isri-aist/NMPC git@github.com:isri-aist/NMPC.git --git -y 32 | $ wstool set -t src isri-aist/DataDrivenMPC git@github.com:isri-aist/DataDrivenMPC.git --git -y 33 | $ wstool update -t src 34 | ``` 35 | 36 | 3. Install dependent packages. 37 | ```bash 38 | $ source /opt/ros/${ROS_DISTRO}/setup.bash 39 | $ rosdep install -y -r --from-paths src --ignore-src 40 | ``` 41 | 42 | 4. Build a package. 
43 | ```bash 44 | $ catkin build data_driven_mpc -DCMAKE_BUILD_TYPE=RelWithDebInfo -DLIBTORCH_PATH=/path/to/libtorch --catkin-make-args all tests 45 | ``` 46 | `/path/to/libtorch` is the absolute path to the directory named `libtorch` that was extracted in step 1. 47 | 48 | ## Examples 49 | Make sure that the package is built with the `--catkin-make-args tests` option. 50 | 51 | ### [MPC for Van der Pol oscillator](tests/src/TestMpcOscillator.cpp) 52 | Control the [Van der Pol oscillator](https://web.casadi.org/docs/#a-simple-test-problem) using the learned state equation. 53 | ```bash 54 | $ rosrun data_driven_mpc TestMpcOscillator 55 | ``` 56 | 57 | ### [MPC for walking with pushing](tests/src/TestMpcPushWalk.cpp) 58 | Control the CoM motion of the robot and the object by combining the known CoM-ZMP model and the learned object dynamics model. 59 | ```bash 60 | $ rosrun data_driven_mpc TestMpcPushWalk --gtest_filter=*.RunMPC 61 | $ rosrun data_driven_mpc plotTestMpcPushWalk.py 62 | ``` 63 | 64 | ### [MPC for cart pushing](tests/src/TestMpcCart.cpp) 65 | Control the position and angle of a one-wheeled cart in the PyBullet dynamics simulator. 66 | ```bash 67 | # 3-second simulation 68 | $ rostest data_driven_mpc TestMpcCart.test enable_gui:=true --text 69 | # Endless simulation 70 | $ rostest data_driven_mpc TestMpcCart.test no_exit:=true enable_gui:=true --text 71 | ``` 72 | 73 | ### [MPC for walking with cart pushing](tests/src/TestMpcCartWalk.cpp) 74 | Control the robot CoM and the position and angle of a one-wheeled cart in the PyBullet dynamics simulator. 75 | The robot CoM-ZMP model is known. The object dynamics model is learned.
76 | ```bash 77 | $ rostest data_driven_mpc TestMpcCartWalk.test enable_gui:=true --text 78 | ``` 79 | -------------------------------------------------------------------------------- /doc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(Doxygen REQUIRED) 2 | 3 | if(DOXYGEN_FOUND) 4 | set(DOXYFILE_PATH ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile) 5 | 6 | configure_file(Doxyfile.in ${DOXYFILE_PATH}) 7 | 8 | add_custom_target(DataDrivenMPC_doc ALL 9 | ${DOXYGEN_EXECUTABLE} ${DOXYFILE_PATH} 10 | DEPENDS ${DOXYFILE_PATH} 11 | COMMENT "Generating Doxygen documentation" 12 | ) 13 | endif() 14 | -------------------------------------------------------------------------------- /include/DDMPC/Dataset.h: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace DDMPC 10 | { 11 | /** \brief Class of single data for state equation. */ 12 | class Data 13 | { 14 | public: 15 | /** \brief Constructor. 16 | \param state single data tensor of current state 17 | \param input single data tensor of current input 18 | \param next_state single data tensor of next state 19 | */ 20 | Data(const torch::Tensor & state, const torch::Tensor & input, const torch::Tensor & next_state) 21 | : state_(state), input_(input), next_state_(next_state) 22 | { 23 | } 24 | 25 | public: 26 | //! Single data tensor of current state 27 | torch::Tensor state_; 28 | 29 | //! Single data tensor of current input 30 | torch::Tensor input_; 31 | 32 | //! Single data tensor of next state 33 | torch::Tensor next_state_; 34 | }; 35 | 36 | /** \brief Single data example of state equation. */ 37 | using Example = torch::data::Example; 38 | 39 | /** \brief Class of dataset for state equation. */ 40 | class Dataset : public torch::data::datasets::Dataset 41 | { 42 | public: 43 | /** \brief Constructor. 
44 | \param state all data tensor of current state 45 | \param input all data tensor of current input 46 | \param next_state all data tensor of next state 47 | */ 48 | explicit Dataset(const torch::Tensor & state, const torch::Tensor & input, const torch::Tensor & next_state) 49 | : state_(std::move(state)), input_(std::move(input)), next_state_(std::move(next_state)) 50 | { 51 | // Store size because tensors are passed to data loader with std::move 52 | size_ = state_.size(0); 53 | } 54 | 55 | /** \brief Returns a single data example. */ 56 | inline Example get(size_t index) override 57 | { 58 | return Data(state_[index], input_[index], next_state_[index]); 59 | } 60 | 61 | /** \brief Returns dataset size. */ 62 | inline torch::optional size() const override 63 | { 64 | return size_; 65 | } 66 | 67 | protected: 68 | //! All data tensor of current state 69 | torch::Tensor state_; 70 | 71 | //! All data tensor of current input 72 | torch::Tensor input_; 73 | 74 | //! All data tensor of next state 75 | torch::Tensor next_state_; 76 | 77 | //! Dataset size 78 | size_t size_; 79 | }; 80 | 81 | /** \brief Make dataset. 82 | \param state all data tensor of current state 83 | \param input all data tensor of current input 84 | \param next_state all data tensor of next state 85 | \param[out] train_dataset training dataset 86 | \param[out] test_dataset test dataset 87 | */ 88 | void makeDataset(const torch::Tensor & state, 89 | const torch::Tensor & input, 90 | const torch::Tensor & next_state, 91 | std::shared_ptr & train_dataset, 92 | std::shared_ptr & test_dataset); 93 | 94 | /** \brief Make batch tensors. 
95 | \param batch batch from data loader (must not be empty) 96 | \param device device to which tensors belong 97 | \param[out] b_state batch tensor of current state 98 | \param[out] b_input batch tensor of current input 99 | \param[out] b_next_state batch tensor of next state 100 | */ 101 | void makeBatchTensor(const std::vector & batch, 102 | const torch::Device & device, 103 | torch::Tensor & b_state, 104 | torch::Tensor & b_input, 105 | torch::Tensor & b_next_state); 106 | } // namespace DDMPC 107 | -------------------------------------------------------------------------------- /include/DDMPC/MathUtils.h: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace DDMPC 10 | { 11 | /** \brief Class of standardization (i.e., mean removal and variance scaling). 12 | \tparam Scalar scalar type 13 | \tparam DataDim data dimension 14 | */ 15 | template 16 | class StandardScaler 17 | { 18 | public: 19 | /** \brief Type of matrix. */ 20 | using Matrix = Eigen::Matrix; 21 | 22 | /** \brief Type of column vector. */ 23 | using Vector = Eigen::Matrix; 24 | 25 | /** \brief Type of row vector. */ 26 | using RowVector = Eigen::Matrix; 27 | 28 | public: 29 | /** \brief Constructor. 30 | \param data_all all data to calculate standardization coefficients (one sample per row) 31 | */ 32 | StandardScaler(const Matrix & data_all) 33 | { 34 | mean_vec_ = calcMean(data_all); 35 | stddev_vec_ = calcStddev(data_all, mean_vec_).cwiseMax(1e-6); // Set minimum to avoid division by zero 36 | } 37 | 38 | /** \brief Constructor. 39 | \param msg ROS message; NOTE(review): the arrays are read through raw data pointers without a size check, so each presumably must hold DataDim elements (as written by toMsg) 40 | */ 41 | StandardScaler(const data_driven_mpc::StandardScaler & msg) 42 | { 43 | mean_vec_ = RowVector(msg.mean_vec.data()); 44 | stddev_vec_ = RowVector(msg.stddev_vec.data()); 45 | } 46 | 47 | /** \brief Convert to ROS message.
*/ 48 | data_driven_mpc::StandardScaler toMsg() const 49 | { 50 | data_driven_mpc::StandardScaler msg; 51 | msg.mean_vec.resize(DataDim); 52 | RowVector::Map(&msg.mean_vec[0]) = mean_vec_; 53 | msg.stddev_vec.resize(DataDim); 54 | RowVector::Map(&msg.stddev_vec[0]) = stddev_vec_; 55 | return msg; 56 | } 57 | 58 | /** \brief Apply standardization. 59 | \param data data to apply standardization 60 | */ 61 | Matrix apply(const Matrix & data) const 62 | { 63 | return (data.rowwise() - mean_vec_).array().rowwise() / stddev_vec_.array(); 64 | } 65 | 66 | /** \brief Apply standardization. 67 | \param data single data to apply standardization 68 | */ 69 | Vector applyOne(const Vector & data) const 70 | { 71 | return (data - mean_vec_.transpose()).array() / stddev_vec_.transpose().array(); 72 | } 73 | 74 | /** \brief Apply inverse standardization. 75 | \param data data to apply inverse standardization 76 | */ 77 | Matrix applyInv(const Matrix & data) const 78 | { 79 | return (data.array().rowwise() * stddev_vec_.array()).matrix().rowwise() + mean_vec_; 80 | } 81 | 82 | /** \brief Apply inverse standardization. 83 | \param data single data to apply inverse standardization 84 | */ 85 | Vector applyOneInv(const Vector & data) const 86 | { 87 | return data.cwiseProduct(stddev_vec_.transpose()) + mean_vec_.transpose(); 88 | } 89 | 90 | /** \brief Calculate mean. */ 91 | static RowVector calcMean(const Matrix & data_all) 92 | { 93 | return data_all.colwise().mean(); 94 | } 95 | 96 | /** \brief Calculate sample standard deviation (N - 1 denominator, clamped to 1 so a single row gives 0 instead of NaN). */ 97 | static RowVector calcStddev(const Matrix & data_all, const RowVector & mean) 98 | { 99 | return ((data_all.rowwise() - mean).cwiseAbs2().colwise().sum() / Scalar(data_all.rows() > 1 ? data_all.rows() - 1 : 1)).cwiseSqrt(); 100 | } 101 | 102 | public: 103 | //! Mean of data 104 | RowVector mean_vec_; 105 | 106 | //!
Standard deviation of data 107 | RowVector stddev_vec_; 108 | }; 109 | } // namespace DDMPC 110 | -------------------------------------------------------------------------------- /include/DDMPC/StateEq.h: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace DDMPC 9 | { 10 | /** \brief Class of state equation based on neural network model. */ 11 | class StateEq 12 | { 13 | public: 14 | /** \brief Class of neural network model for state equation. */ 15 | class Model : public torch::nn::Module 16 | { 17 | public: 18 | /** \brief Constructor. 19 | \param state_dim state dimension 20 | \param input_dim input dimension 21 | \param middle_layer_dim middle layer dimension 22 | */ 23 | Model(int state_dim, int input_dim, int middle_layer_dim); 24 | 25 | /** \brief Forward model. 26 | \param x current state 27 | \param u current input 28 | \param enable_auto_grad whether to enable automatic gradient (default true) 29 | \param requires_grad_x whether to require gradient w.r.t. current state (default false) 30 | \param requires_grad_u whether to require gradient w.r.t. current input (default false) 31 | \returns next state 32 | 33 | The required gradients are stored in the member variables grad_x_ and grad_u_. 34 | */ 35 | torch::Tensor forward(torch::Tensor & x, 36 | torch::Tensor & u, 37 | bool enable_auto_grad = true, 38 | bool requires_grad_x = false, 39 | bool requires_grad_u = false); 40 | 41 | public: 42 | //! Whether to enable debug print 43 | bool debug_ = true; 44 | 45 | //! State dimension 46 | const int state_dim_; 47 | 48 | //! Input dimension 49 | const int input_dim_; 50 | 51 | //! Linear layers 52 | torch::nn::Linear linear1_ = nullptr; 53 | torch::nn::Linear linear2_ = nullptr; 54 | torch::nn::Linear linear3_ = nullptr; 55 | 56 | //! Gradient w.r.t. current state 57 | torch::Tensor grad_x_; 58 | 59 | //! Gradient w.r.t. 
current input 60 | torch::Tensor grad_u_; 61 | }; 62 | 63 | /** \brief Model pointer. 64 | 65 | See "Module Ownership" section of https://pytorch.org/tutorials/advanced/cpp_frontend.html for details 66 | */ 67 | TORCH_MODULE_IMPL(ModelPtr, Model); 68 | 69 | public: 70 | /** \brief Constructor. 71 | \param state_dim state dimension 72 | \param input_dim input dimension 73 | \param middle_layer_dim middle layer dimension (default 32) 74 | */ 75 | StateEq(int state_dim, int input_dim, int middle_layer_dim = 32) 76 | : model_ptr_(ModelPtr(state_dim, input_dim, middle_layer_dim)) 77 | { 78 | } 79 | 80 | /** \brief Calculate next state. 81 | \param x current state 82 | \param u current input 83 | \returns next state 84 | */ 85 | Eigen::VectorXd eval(const Eigen::VectorXd & x, const Eigen::VectorXd & u); 86 | 87 | /** \brief Calculate next state. 88 | \param x current state 89 | \param u current input 90 | \param[out] grad_x gradient w.r.t. x (not calculated when the matrix size is zero) 91 | \param[out] grad_u gradient w.r.t. u (not calculated when the matrix size is zero) 92 | \returns next state 93 | */ 94 | Eigen::VectorXd eval(const Eigen::VectorXd & x, 95 | const Eigen::VectorXd & u, 96 | Eigen::Ref grad_x, 97 | Eigen::Ref grad_u); 98 | 99 | /** \brief Get state dimension. */ 100 | inline int stateDim() const 101 | { 102 | return model_ptr_->state_dim_; 103 | } 104 | 105 | /** \brief Get input dimension. */ 106 | inline int inputDim() const 107 | { 108 | return model_ptr_->input_dim_; 109 | } 110 | 111 | public: 112 | //! Model pointer 113 | ModelPtr model_ptr_ = nullptr; 114 | }; 115 | } // namespace DDMPC 116 | -------------------------------------------------------------------------------- /include/DDMPC/TorchUtils.h: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace DDMPC 9 | { 10 | /** \brief Row major version of Eigen::MatrixXf.
*/ 11 | using MatrixXfRowMajor = Eigen::Matrix; 12 | /** \brief Row major version of Eigen::MatrixXd. */ 13 | using MatrixXdRowMajor = Eigen::Matrix; 14 | 15 | /** \brief Convert to torch::Tensor. 16 | \param mat input (Eigen::MatrixXf) 17 | 18 | Even if a column-major matrix is passed as an argument, it is automatically converted (copied) to row major; the resulting tensor owns a clone of the data. 19 | */ 20 | inline torch::Tensor toTorchTensor(const MatrixXfRowMajor & mat) 21 | { 22 | return torch::from_blob(const_cast(mat.data()), {mat.rows(), mat.cols()}).clone(); 23 | } 24 | 25 | /** \brief Convert to Eigen::MatrixXf. 26 | \param tensor input (torch::Tensor, must be 2-D); NOTE(review): its raw data pointer is mapped directly as size(0) x size(1) floats, so the tensor presumably must be contiguous float data in host memory - confirm 27 | */ 28 | inline Eigen::MatrixXf toEigenMatrix(const torch::Tensor & tensor) 29 | { 30 | assert(tensor.dim() == 2); 31 | 32 | float * tensor_data_ptr = const_cast(tensor.data_ptr()); 33 | return Eigen::MatrixXf(Eigen::Map(tensor_data_ptr, tensor.size(0), tensor.size(1))); 34 | } 35 | } // namespace DDMPC 36 | -------------------------------------------------------------------------------- /include/DDMPC/Training.h: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace DDMPC 9 | { 10 | /** \brief Class to train a model. */ 11 | class Training 12 | { 13 | public: 14 | /** \brief Constructor. 15 | \param device_type type of device (torch::DeviceType::CUDA or torch::DeviceType::CPU) 16 | */ 17 | Training(torch::DeviceType device_type = torch::DeviceType::CPU); 18 | 19 | /** \brief Train a model.
20 | \param state_eq state equation 21 | \param train_dataset training dataset 22 | \param test_dataset test dataset 23 | \param model_path path to save model parameters 24 | \param batch_size batch size for train and test 25 | \param num_epoch number of epochs for learning 26 | \param learning_rate learning rate 27 | */ 28 | void run(const std::shared_ptr & state_eq, 29 | const std::shared_ptr & train_dataset, 30 | const std::shared_ptr & test_dataset, 31 | const std::string & model_path, 32 | int batch_size = 64, 33 | int num_epoch = 100, 34 | double learning_rate = 1e-3) const; 35 | 36 | /** \brief Load a model from file. 37 | \param state_eq state equation 38 | \param model_path path to load model parameters 39 | */ 40 | void load(const std::shared_ptr & state_eq, const std::string & model_path) const; 41 | 42 | public: 43 | //! Whether to enable debug print 44 | bool debug_ = true; 45 | 46 | //! Device on which to place the model parameters 47 | std::shared_ptr device_; 48 | }; 49 | } // namespace DDMPC 50 | -------------------------------------------------------------------------------- /msg/Dataset.msg: -------------------------------------------------------------------------------- 1 | # Number of samples in the dataset 2 | int32 dataset_size 3 | 4 | # Discretization period of the state equation [sec] 5 | float64 dt 6 | 7 | # State dimension 8 | int32 state_dim 9 | 10 | # Input dimension 11 | int32 input_dim 12 | 13 | # All states 14 | float64[] state_all 15 | 16 | # All inputs 17 | float64[] input_all 18 | 19 | # All next states 20 | float64[] next_state_all 21 | -------------------------------------------------------------------------------- /msg/StandardScaler.msg: -------------------------------------------------------------------------------- 1 | # Mean of data 2 | float64[] mean_vec 3 | 4 | # Standard deviation of data 5 | float64[] stddev_vec 6 | -------------------------------------------------------------------------------- /package.xml:
-------------------------------------------------------------------------------- 1 | 2 | data_driven_mpc 3 | 0.1.0 4 | 5 | Model predictive control based on data-driven model 6 | 7 | Masaki Murooka 8 | BSD 9 | 10 | http://ros.org/wiki/data_driven_mpc 11 | Masaki Murooka 12 | 13 | catkin 14 | 15 | roscpp 16 | rospy 17 | rosbag 18 | nmpc_ddp 19 | 20 | eigen 21 | 22 | rosunit 23 | rostest 24 | 25 | doxygen 26 | graphviz 27 | 28 | -------------------------------------------------------------------------------- /samples/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isri-aist/DataDrivenMPC/d95b019f1944ab4bad2f52e9a46f10ef9dbfbc06/samples/CMakeLists.txt -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(DDMPC 2 | StateEq.cpp 3 | Dataset.cpp 4 | Training.cpp 5 | ) 6 | target_include_directories(DDMPC PUBLIC 7 | ${PROJECT_SOURCE_DIR}/include 8 | ${catkin_INCLUDE_DIRS} 9 | ) 10 | target_link_libraries(DDMPC PUBLIC 11 | ${TORCH_LIBRARIES} ${catkin_LIBRARIES} 12 | ) 13 | -------------------------------------------------------------------------------- /src/Dataset.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | using namespace DDMPC; 6 | 7 | void DDMPC::makeDataset(const torch::Tensor & state, 8 | const torch::Tensor & input, 9 | const torch::Tensor & next_state, 10 | std::shared_ptr & train_dataset, 11 | std::shared_ptr & test_dataset) 12 | { 13 | // Make dataset for train and test 14 | int n_all = state.size(0); 15 | int n_train = 0.7 * n_all; 16 | int n_test = n_all - n_train; 17 | torch::Tensor perm = torch::randperm(n_all, torch::kInt64); 18 | torch::Tensor train_perm = perm.index({at::indexing::Slice(0, n_train)}); 19 | torch::Tensor test_perm = 
perm.index({at::indexing::Slice(n_train, n_train + n_test)}); 20 | train_dataset = 21 | std::make_shared(state.index({train_perm}), input.index({train_perm}), next_state.index({train_perm})); 22 | test_dataset = 23 | std::make_shared(state.index({test_perm}), input.index({test_perm}), next_state.index({test_perm})); 24 | 25 | // Print debug information 26 | constexpr bool debug = true; 27 | if(debug) 28 | { 29 | std::cout << "Construct Dataset" << std::endl; 30 | std::cout << " - train_dataset size: " << train_dataset->size().value() << std::endl 31 | << " - test_dataset size: " << test_dataset->size().value() << std::endl; 32 | } 33 | } 34 | 35 | void DDMPC::makeBatchTensor(const std::vector & batch, 36 | const torch::Device & device, 37 | torch::Tensor & b_state, 38 | torch::Tensor & b_input, 39 | torch::Tensor & b_next_state) 40 | { 41 | int batch_size = batch.size(); 42 | 43 | // Allocate batch tensors 44 | { 45 | const auto & data = static_cast(batch[0]); 46 | b_state = torch::empty({batch_size, data.state_.size(0)}); 47 | b_input = torch::empty({batch_size, data.input_.size(0)}); 48 | b_next_state = torch::empty({batch_size, data.next_state_.size(0)}); 49 | } 50 | 51 | // Set batch tensors 52 | for(int i = 0; i < batch_size; i++) 53 | { 54 | const auto & data = static_cast(batch[i]); 55 | b_state.index({i}) = data.state_; 56 | b_input.index({i}) = data.input_; 57 | b_next_state.index({i}) = data.next_state_; 58 | } 59 | 60 | // Send to device 61 | b_state = b_state.to(device); 62 | b_input = b_input.to(device); 63 | b_next_state = b_next_state.to(device); 64 | } 65 | -------------------------------------------------------------------------------- /src/StateEq.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | #include 5 | 6 | using namespace DDMPC; 7 | 8 | StateEq::Model::Model(int state_dim, int input_dim, int middle_layer_dim) : state_dim_(state_dim), 
input_dim_(input_dim) 9 | { 10 | // Instantiate layers 11 | linear1_ = register_module("linear1", torch::nn::Linear(state_dim_ + input_dim_, middle_layer_dim)); 12 | linear2_ = register_module("linear2", torch::nn::Linear(middle_layer_dim, middle_layer_dim)); 13 | linear3_ = register_module("linear3", torch::nn::Linear(middle_layer_dim, state_dim_)); 14 | 15 | // Print debug information 16 | if(debug_) 17 | { 18 | std::cout << "Construct NN Module" << std::endl; 19 | std::cout << " - state_dim: " << state_dim_ << std::endl; 20 | std::cout << " - input_dim: " << input_dim_ << std::endl; 21 | std::cout << " - layer dims: " << state_dim_ + input_dim_ << " -> " << middle_layer_dim << " -> " 22 | << middle_layer_dim << " -> " << state_dim_ << std::endl; 23 | } 24 | 25 | // Workaround to avoid torch error 26 | // See https://github.com/pytorch/pytorch/issues/35736#issuecomment-688078143 27 | torch::cuda::is_available(); 28 | } 29 | 30 | torch::Tensor StateEq::Model::forward(torch::Tensor & x, 31 | torch::Tensor & u, 32 | bool enable_auto_grad, 33 | bool requires_grad_x, 34 | bool requires_grad_u) 35 | { 36 | // Check dimensions 37 | assert(x.size(1) == state_dim_); 38 | assert(u.size(1) == input_dim_); 39 | assert(x.size(0) == u.size(0)); 40 | 41 | // Setup gradient 42 | bool requires_grad = (requires_grad_x || requires_grad_u); 43 | if(requires_grad) 44 | { 45 | enable_auto_grad = true; 46 | } 47 | torch::Tensor x_repeated; 48 | torch::Tensor u_repeated; 49 | if(requires_grad) 50 | { 51 | x_repeated = x.repeat({state_dim_, 1}); 52 | u_repeated = u.repeat({state_dim_, 1}); 53 | } 54 | if(requires_grad_x) 55 | { 56 | x_repeated.set_requires_grad(true); 57 | assert(!x_repeated.mutable_grad().defined()); 58 | } 59 | if(requires_grad_u) 60 | { 61 | u_repeated.set_requires_grad(true); 62 | assert(!u_repeated.mutable_grad().defined()); 63 | } 64 | if(requires_grad && x.size(0) > 1) 65 | { 66 | throw std::runtime_error("batch size should be 1 when requiring gradient. 
batch size: " 67 | + std::to_string(x.size(0))); 68 | } 69 | 70 | std::unique_ptr no_grad; 71 | if(!enable_auto_grad) 72 | { 73 | // Not calculate gradient 74 | no_grad = std::make_unique(); 75 | } 76 | 77 | // Calculate network output 78 | torch::Tensor xu = requires_grad ? torch::cat({x_repeated, u_repeated}, 1) : torch::cat({x, u}, 1); 79 | xu = torch::relu(linear1_(xu)); 80 | xu = torch::relu(linear2_(xu)); 81 | torch::Tensor next_x = linear3_(xu); 82 | 83 | // Calculate gradient 84 | if(requires_grad) 85 | { 86 | // See https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa for Jacobian calculation 87 | next_x.backward(torch::eye(state_dim_)); 88 | if(requires_grad_x) 89 | { 90 | grad_x_ = x_repeated.grad(); 91 | } 92 | if(requires_grad_u) 93 | { 94 | grad_u_ = u_repeated.grad(); 95 | } 96 | } 97 | 98 | return requires_grad ? next_x.index({0}).view({1, -1}) : next_x; 99 | } 100 | 101 | Eigen::VectorXd StateEq::eval(const Eigen::VectorXd & x, const Eigen::VectorXd & u) 102 | { 103 | // Check dimensions 104 | assert(x.size() == stateDim()); 105 | assert(u.size() == inputDim()); 106 | 107 | // Set tensor 108 | torch::Tensor x_tensor = toTorchTensor(x.transpose().cast()); 109 | torch::Tensor u_tensor = toTorchTensor(u.transpose().cast()); 110 | 111 | // Forward network 112 | torch::Tensor next_x_tensor = model_ptr_->forward(x_tensor, u_tensor, false); 113 | 114 | // Set output variables 115 | return toEigenMatrix(next_x_tensor).transpose().cast(); 116 | } 117 | 118 | Eigen::VectorXd StateEq::eval(const Eigen::VectorXd & x, 119 | const Eigen::VectorXd & u, 120 | Eigen::Ref grad_x, 121 | Eigen::Ref grad_u) 122 | { 123 | // Check dimensions 124 | assert(x.size() == stateDim()); 125 | assert(u.size() == inputDim()); 126 | assert(grad_x.size() == 0 || (grad_x.rows() == stateDim() && grad_x.cols() == stateDim())); 127 | assert(grad_u.size() == 0 || (grad_u.rows() == stateDim() && grad_u.cols() == inputDim())); 128 | 129 | // Set tensor 130 | torch::Tensor 
x_tensor = toTorchTensor(x.transpose().cast()); 131 | torch::Tensor u_tensor = toTorchTensor(u.transpose().cast()); 132 | bool requires_grad_x = grad_x.size() > 0; 133 | bool requires_grad_u = grad_u.size() > 0; 134 | 135 | // Forward network 136 | torch::Tensor next_x_tensor = model_ptr_->forward(x_tensor, u_tensor, false, requires_grad_x, requires_grad_u); 137 | 138 | // Set output variables 139 | if(requires_grad_x) 140 | { 141 | grad_x = toEigenMatrix(model_ptr_->grad_x_).cast(); 142 | } 143 | if(requires_grad_u) 144 | { 145 | grad_u = toEigenMatrix(model_ptr_->grad_u_).cast(); 146 | } 147 | return toEigenMatrix(next_x_tensor).transpose().cast(); 148 | } 149 | -------------------------------------------------------------------------------- /src/Training.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | #include 5 | 6 | using namespace DDMPC; 7 | 8 | Training::Training(torch::DeviceType device_type) 9 | { 10 | // Set device 11 | if(!torch::cuda::is_available() && device_type == torch::DeviceType::CUDA) 12 | { 13 | std::cout << "CUDA is unavailable! 
Overwrite with CPU" << std::endl; 14 | device_type = torch::DeviceType::CPU; 15 | } 16 | device_ = std::make_shared(device_type); 17 | } 18 | 19 | void Training::run(const std::shared_ptr & state_eq, 20 | const std::shared_ptr & train_dataset, 21 | const std::shared_ptr & test_dataset, 22 | const std::string & model_path, 23 | int batch_size, 24 | int num_epoch, 25 | double learning_rate) const 26 | { 27 | // Make data loader for train and test 28 | auto train_data_loader = torch::data::make_data_loader( 29 | std::move(*train_dataset), torch::data::DataLoaderOptions(batch_size).workers(2)); 30 | auto test_data_loader = torch::data::make_data_loader( 31 | std::move(*test_dataset), torch::data::DataLoaderOptions(batch_size).workers(2)); 32 | if(debug_) 33 | { 34 | std::cout << "Construct DataLoader" << std::endl; 35 | std::cout << " - batch_size: " << batch_size << std::endl; 36 | } 37 | 38 | // Setup model 39 | auto model_ptr = state_eq->model_ptr_; 40 | model_ptr->to(*device_); 41 | 42 | // Make optimizer 43 | torch::optim::Adam optimizer(model_ptr->parameters(), torch::optim::AdamOptions(learning_rate)); 44 | 45 | // Learn 46 | if(debug_) 47 | { 48 | std::cout << "Start training" << std::endl; 49 | std::cout << " - device_type: " << device_->type() << std::endl; 50 | } 51 | 52 | std::ofstream ofs("/tmp/DataDrivenMPCTraining.txt"); 53 | ofs << "epoch train_loss test_loss" << std::endl; 54 | 55 | float test_loss_ave_min = std::numeric_limits::max(); 56 | for(int i_epoch = 0; i_epoch < num_epoch; i_epoch++) 57 | { // For each epoch 58 | // Train for one epoch 59 | float train_loss_ave = 0.0; 60 | for(const std::vector & batch : *train_data_loader) 61 | { // For each batch 62 | // Make batch tensor 63 | torch::Tensor b_state, b_input, b_next_state_gt; 64 | makeBatchTensor(batch, *device_, b_state, b_input, b_next_state_gt); 65 | 66 | // Forward, calculate loss, and optimize 67 | model_ptr->zero_grad(); 68 | torch::Tensor b_next_state_pred = model_ptr->forward(b_state, 
b_input); 69 | torch::Tensor loss = torch::nn::functional::mse_loss(b_next_state_pred, b_next_state_gt, 70 | torch::nn::functional::MSELossFuncOptions(torch::kSum)); 71 | loss.backward(); 72 | optimizer.step(); 73 | train_loss_ave += loss.to(torch::DeviceType::CPU).item(); 74 | } 75 | train_loss_ave /= static_cast(train_dataset->size().value()); 76 | 77 | // Test for one epoch 78 | float test_loss_ave = 0.0; 79 | for(const std::vector & batch : *test_data_loader) 80 | { // For each batch 81 | // Not calculate gradient 82 | torch::NoGradGuard no_grad; 83 | 84 | // Make batch tensor 85 | torch::Tensor b_state, b_input, b_next_state_gt; 86 | makeBatchTensor(batch, *device_, b_state, b_input, b_next_state_gt); 87 | 88 | // Forward and calculate loss 89 | torch::Tensor b_next_state_pred = model_ptr->forward(b_state, b_input); 90 | torch::Tensor loss = torch::nn::functional::mse_loss(b_next_state_pred, b_next_state_gt, 91 | torch::nn::functional::MSELossFuncOptions(torch::kSum)); 92 | test_loss_ave += loss.to(torch::DeviceType::CPU).item(); 93 | } 94 | test_loss_ave /= static_cast(test_dataset->size().value()); 95 | 96 | // Print result 97 | std::cout << "[" << i_epoch << "/" << num_epoch << "] train_loss: " << train_loss_ave 98 | << ", test_loss: " << test_loss_ave << std::endl; 99 | ofs << i_epoch << " " << train_loss_ave << " " << test_loss_ave << std::endl; 100 | 101 | // Save model parameters only if the model has the best performance 102 | if(test_loss_ave < test_loss_ave_min) 103 | { 104 | test_loss_ave_min = test_loss_ave; 105 | std::cout << "Best performance. 
Save model to " << model_path << std::endl; 106 | torch::save(model_ptr, model_path); 107 | } 108 | } 109 | } 110 | 111 | void Training::load(const std::shared_ptr & state_eq, const std::string & model_path) const 112 | { 113 | // Setup model 114 | auto model_ptr = state_eq->model_ptr_; 115 | model_ptr->to(*device_); 116 | 117 | // Load model parameters 118 | if(debug_) 119 | { 120 | std::cout << "Load model from " << model_path << std::endl; 121 | } 122 | torch::load(model_ptr, model_path); 123 | } 124 | -------------------------------------------------------------------------------- /srv/GenerateDataset.srv: -------------------------------------------------------------------------------- 1 | # File name to save dataset 2 | string filename 3 | 4 | # Number of data in dataset 5 | int32 dataset_size 6 | 7 | # Discretization period of the state equation [sec] 8 | float64 dt 9 | 10 | # Min/max of state random samples 11 | float64[] state_min 12 | float64[] state_max 13 | 14 | # Min/max of input random samples 15 | float64[] input_min 16 | float64[] input_max 17 | 18 | --- 19 | -------------------------------------------------------------------------------- /srv/RunSimOnce.srv: -------------------------------------------------------------------------------- 1 | # Duration to proceed the simulation [sec] 2 | # Specify zero to set/get the state without proceeding the simulation. 3 | float64 dt 4 | 5 | # Initial state 6 | # Leave empty to start from the internal state. 
7 | float64[] state 8 | 9 | # Control input 10 | float64[] input 11 | 12 | # Additional data 13 | float64[] additional_data 14 | 15 | --- 16 | 17 | # Final state 18 | float64[] state 19 | 20 | # Additional data 21 | float64[] additional_data 22 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(rostest REQUIRED) 2 | 3 | set(DDMPC_gtest_list 4 | TestMathUtils 5 | TestTorchUtils 6 | TestStateEq 7 | TestDataset 8 | TestTraining 9 | TestMpcOscillator 10 | TestMpcPushWalk 11 | ) 12 | 13 | set(DDMPC_rostest_list 14 | TestMpcCart 15 | TestMpcCartWalk 16 | ) 17 | 18 | foreach(NAME IN LISTS DDMPC_gtest_list) 19 | catkin_add_gtest(${NAME} src/${NAME}.cpp) 20 | target_link_libraries(${NAME} DDMPC) 21 | endforeach() 22 | 23 | foreach(NAME IN LISTS DDMPC_rostest_list) 24 | add_rostest_gtest(${NAME} test/${NAME}.test src/${NAME}.cpp) 25 | target_link_libraries(${NAME} DDMPC) 26 | endforeach() 27 | -------------------------------------------------------------------------------- /tests/data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /tests/scripts/plotTestMpcPushWalk.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class PlotTestMpcPushWalk(object): 8 | def __init__(self, result_file_path): 9 | self.result_data_list = np.genfromtxt(result_file_path, dtype=None, delimiter=None, names=True) 10 | print("[PlotTestMpcPushWalk] Load {}".format(result_file_path)) 11 | 12 | fig = plt.figure() 13 | plt.rcParams["font.size"] = 16 14 | 15 | ax = fig.add_subplot(311) 16 | ax.plot(self.result_data_list["time"], self.result_data_list["robot_com_pos"], 17 | color="green", label="planned robot pos") 18 | ax.plot(self.result_data_list["time"], self.result_data_list["obj_com_pos"], 19 | color="coral", label="planned obj pos") 20 | ax.plot(self.result_data_list["time"], self.result_data_list["robot_zmp"], 21 | color="red", label="planned robot zmp") 22 | ax.plot(self.result_data_list["time"], self.result_data_list["ref_obj_com_pos"], 23 | color="cyan", linestyle="dashed", label="ref obj pos", zorder=-1) 24 | ax.plot(self.result_data_list["time"], self.result_data_list["ref_robot_zmp"], 25 | color="blue", linestyle="dashed", label="ref robot zmp", zorder=-1) 26 | ax.set_xlabel("time [s]") 27 | ax.set_ylabel("pos [m]") 28 | ax.grid() 29 | ax.legend(loc="upper left") 30 | 31 | ax = fig.add_subplot(312) 32 | ax.plot(self.result_data_list["time"], self.result_data_list["obj_force"], 33 | color="green", label="obj force") 34 | ax.set_xlabel("time [s]") 35 | ax.set_ylabel("force [N]") 36 | ax.grid() 37 | ax.legend(loc="upper right") 38 | 39 | mass = 50.0 # [kg] 40 | damper_forces = mass * (self.result_data_list["obj_com_vel"][1:] - self.result_data_list["obj_com_vel"][:-1]) / \ 41 | (self.result_data_list["time"][1:] - self.result_data_list["time"][:-1]) \ 42 | - self.result_data_list["obj_force"][:-1] 43 | ax = fig.add_subplot(313) 44 | ax.scatter(self.result_data_list["obj_com_vel"][:-1], damper_forces, 45 | color="green", label="damper force") 46 | ax.set_xlabel("vel [m/s]") 47 | ax.set_ylabel("force 
[N]") 48 | ax.grid() 49 | ax.legend(loc="upper right") 50 | 51 | plt.show() 52 | 53 | 54 | if __name__ == "__main__": 55 | result_file_path = "/tmp/TestMpcPushWalkResult-Linear.txt" 56 | 57 | import sys 58 | if len(sys.argv) >= 2: 59 | result_file_path = sys.argv[1] 60 | 61 | plot = PlotTestMpcPushWalk(result_file_path) 62 | -------------------------------------------------------------------------------- /tests/scripts/sampleClientSimTestMpcCart.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import time 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import rospy 8 | from data_driven_mpc.srv import * 9 | 10 | 11 | # Setup ROS 12 | rospy.init_node("sample_client_sim_test_mpc_cart") 13 | run_sim_once_cli = rospy.ServiceProxy("/run_sim_once", RunSimOnce) 14 | generate_dataset_cli = rospy.ServiceProxy("/generate_dataset", GenerateDataset) 15 | 16 | # Generate dataset 17 | rospy.wait_for_service("/generate_dataset") 18 | 19 | req = GenerateDatasetRequest() 20 | req.filename = "/tmp/DataDrivenMPCDataset.bag" 21 | req.dataset_size = 1000 22 | req.dt = 0.05 # [sec] 23 | req.state_max = np.array([1.0, 1.0, np.deg2rad(30), np.deg2rad(60)]) 24 | req.state_min = -1 * req.state_max 25 | req.input_max = np.array([100.0, 100.0]) 26 | req.input_min = -1 * req.input_max 27 | generate_dataset_cli(req) 28 | 29 | # Set initial state 30 | rospy.wait_for_service("/run_sim_once") 31 | 32 | req = RunSimOnceRequest() 33 | req.dt = 0.0 34 | req.state = [0.0, 1.0, np.deg2rad(-10.0), 0.0] 35 | req.input = [0.0, 0.0] 36 | res = run_sim_once_cli(req) 37 | state = np.array(res.state) 38 | 39 | # Setup variables 40 | time_list = [] 41 | state_list = [] 42 | force_list = [] 43 | 44 | t = 0.0 # [sec] 45 | end_t = 3.0 # [sec] 46 | dt = 0.05 # [sec] 47 | while True: 48 | # Calculate manipulation force 49 | _, _, theta, theta_dot = state 50 | manip_force_z = -100.0 * theta -20.0 * theta_dot # [N] 51 | 52 | # 
Save 53 | time_list.append(t) 54 | state_list.append(state.tolist()) 55 | force_list.append(manip_force_z) 56 | 57 | # Check terminal condition 58 | if t >= end_t: 59 | break 60 | 61 | # Run simulation step 62 | req = RunSimOnceRequest() 63 | req.dt = dt 64 | req.state = [] 65 | req.input = [0.0, manip_force_z] 66 | res = run_sim_once_cli(req) 67 | state = np.array(res.state) 68 | 69 | # Sleep and increment time 70 | time.sleep(dt) 71 | t += dt 72 | 73 | # Plot result 74 | time_list = np.array(time_list) 75 | state_list = np.array(state_list) 76 | force_list = np.array(force_list) 77 | 78 | fig = plt.figure() 79 | 80 | data_list = [state_list[:, 0], np.rad2deg(state_list[:, 2]), force_list] 81 | label_list = ["pos [m]", "angle [deg]", "force [N]"] 82 | for i in range(len(data_list)): 83 | ax = fig.add_subplot(3, 1, i + 1) 84 | ax.plot(time_list, data_list[i], linestyle="-", marker='o') 85 | ax.set_xlabel("time [s]") 86 | ax.set_ylabel(label_list[i]) 87 | ax.grid() 88 | 89 | plt.show() 90 | -------------------------------------------------------------------------------- /tests/scripts/simTestMpcCart.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import time 4 | import numpy as np 5 | import pybullet 6 | import pybullet_data 7 | 8 | import rospy 9 | import rosbag 10 | from data_driven_mpc.msg import * 11 | from data_driven_mpc.srv import * 12 | 13 | 14 | class SimTestMpcCart(object): 15 | def __init__(self, enable_gui): 16 | # Instantiate simulator 17 | if enable_gui: 18 | pybullet.connect(pybullet.GUI) 19 | else: 20 | pybullet.connect(pybullet.DIRECT) 21 | 22 | # Set simulation parameters 23 | self.dt = 0.005 # [sec] 24 | pybullet.setTimeStep(self.dt) 25 | pybullet.setGravity(0, 0, -9.8) # [m/s^2] 26 | 27 | # Set debug parameters 28 | pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0) 29 | 30 | # Setup models 31 | pybullet.setAdditionalSearchPath(pybullet_data.getDataPath()) 32 | 33 | ## Setup floor 34 | pybullet.loadURDF("plane100.urdf") 35 | 36 | ## Setup cart 37 | self.box_half_scale = np.array([0.35, 0.25, 0.15]) # [m] 38 | box_col_shape_idx = pybullet.createCollisionShape(pybullet.GEOM_BOX, 39 | halfExtents=self.box_half_scale) 40 | self.cylinder_radius = 0.1 # [m] 41 | cylinder_height = 0.1 # [m] 42 | cylinder_col_shape_idx = pybullet.createCollisionShape(pybullet.GEOM_CYLINDER, 43 | radius=self.cylinder_radius, 44 | height=cylinder_height) 45 | box_mass = rospy.get_param("~box_mass", 8.0) # [kg] 46 | self.box_com_offset = np.array([-0.02, 0.0, -0.1]) # [m] 47 | cylinder_mass = 2.0 # [kg] 48 | self.cart_body_uid = pybullet.createMultiBody(baseMass=box_mass, 49 | baseCollisionShapeIndex=box_col_shape_idx, 50 | baseVisualShapeIndex=-1, 51 | basePosition=[0.0, 0.0, 2 * self.cylinder_radius + self.box_half_scale[2]], # [m] 52 | baseOrientation=[0.0, 0.0, 0.0, 1.0], 53 | baseInertialFramePosition=self.box_com_offset, 54 | baseInertialFrameOrientation=[0.0, 0.0, 0.0, 1.0], 55 | linkMasses=[cylinder_mass], 56 | linkCollisionShapeIndices=[cylinder_col_shape_idx], 57 | linkVisualShapeIndices=[-1], 58 | linkPositions=[[0.0, 0.0, -1 * (self.cylinder_radius + 
self.box_half_scale[2])]], # [m] 59 | linkOrientations=[pybullet.getQuaternionFromEuler([np.pi/2, 0.0, 0.0])], 60 | linkInertialFramePositions=[[0.0, 0.0, 0.0]], # [m] 61 | linkInertialFrameOrientations=[[0.0, 0.0, 0.0, 1.0]], 62 | linkParentIndices=[0], 63 | linkJointTypes=[pybullet.JOINT_FIXED], 64 | linkJointAxis=[[0.0, 0.0, 1.0]]) 65 | pybullet.changeVisualShape(objectUniqueId=self.cart_body_uid, 66 | linkIndex=-1, 67 | rgbaColor=[0.0, 1.0, 0.0, 0.8]) 68 | pybullet.changeVisualShape(objectUniqueId=self.cart_body_uid, 69 | linkIndex=0, 70 | rgbaColor=[0.1, 0.1, 0.1, 0.8]) 71 | 72 | # Set dynamics parameters 73 | pybullet.changeDynamics(bodyUniqueId=self.cart_body_uid, 74 | linkIndex=0, 75 | lateralFriction=rospy.get_param("~lateral_friction", 0.05)) 76 | 77 | # Setup variables 78 | self.force_line_uid = -1 79 | 80 | # Setup ROS 81 | run_sim_once_srv = rospy.Service("/run_sim_once", RunSimOnce, self.runSimOnceCallback) 82 | generate_dataset_srv = rospy.Service("/generate_dataset", GenerateDataset, self.generateDatasetCallback) 83 | 84 | def runOnce(self, manip_force=None): 85 | """"Run simulation step once. 
86 | 87 | Args: 88 | manip_force manipulation force in world frame 89 | """ 90 | if manip_force is not None: 91 | # Apply manipulation force 92 | box_link_pos, box_link_rot = pybullet.getBasePositionAndOrientation(bodyUniqueId=self.cart_body_uid) 93 | box_link_pos = np.array(box_link_pos) 94 | box_link_rot = np.array(pybullet.getMatrixFromQuaternion(box_link_rot)).reshape((3, 3)) 95 | manip_pos_local = np.array([-1 * self.box_half_scale[0], 0.0, self.box_half_scale[2]]) - self.box_com_offset 96 | manip_pos = box_link_pos + box_link_rot.dot(manip_pos_local) 97 | pybullet.applyExternalForce(objectUniqueId=self.cart_body_uid, 98 | linkIndex=0, 99 | forceObj=manip_force, 100 | posObj=manip_pos, 101 | flags=pybullet.WORLD_FRAME) 102 | 103 | # Visualize external force 104 | force_scale = 0.01 105 | self.force_line_uid = pybullet.addUserDebugLine(lineFromXYZ=manip_pos, 106 | lineToXYZ=manip_pos + force_scale * manip_force, 107 | lineColorRGB=[1, 0, 0], 108 | lineWidth=5.0, 109 | replaceItemUniqueId=self.force_line_uid) 110 | else: 111 | # Delete external force 112 | if self.force_line_uid != -1: 113 | pybullet.removeUserDebugItem(self.force_line_uid) 114 | self.force_line_uid = -1 115 | 116 | # Process simulation step 117 | pybullet.stepSimulation() 118 | 119 | def getState(self): 120 | """"Get state [p, p_dot, theta, theta_dot].""" 121 | cylinder_link_state = pybullet.getLinkState(bodyUniqueId=self.cart_body_uid, linkIndex=0, computeLinkVelocity=True) 122 | p = cylinder_link_state[4][0] # [m] 123 | p_dot = cylinder_link_state[6][0] # [m/s] 124 | theta = pybullet.getEulerFromQuaternion( 125 | pybullet.getBasePositionAndOrientation(bodyUniqueId=self.cart_body_uid)[1])[1] # [rad] 126 | theta_dot = pybullet.getBaseVelocity(bodyUniqueId=self.cart_body_uid)[1][1] # [rad/s] 127 | return np.array([p, p_dot, theta, theta_dot]) 128 | 129 | def setState(self, state): 130 | """Set state [p, p_dot, theta, theta_dot].""" 131 | p, p_dot, theta, theta_dot = state 132 | 
local_pos_from_cylinder_to_box = np.array( 133 | [self.box_com_offset[0], 0.0, self.cylinder_radius + self.box_half_scale[2] + self.box_com_offset[2]]) 134 | global_pos_from_cylinder_to_box = np.array( 135 | pybullet.getMatrixFromQuaternion(pybullet.getQuaternionFromEuler( 136 | [0.0, theta, 0.0]))).reshape((3, 3)).dot(local_pos_from_cylinder_to_box) 137 | box_pos = np.array([p, 0.0, self.cylinder_radius]) + global_pos_from_cylinder_to_box 138 | box_rot = pybullet.getQuaternionFromEuler([0.0, theta, 0.0]) 139 | pybullet.resetBasePositionAndOrientation(bodyUniqueId=self.cart_body_uid, 140 | posObj=box_pos, 141 | ornObj=box_rot) 142 | linear_vel = np.array([p_dot, 0.0, 0.0]) + np.cross(np.array([0.0, theta_dot, 0.0]), global_pos_from_cylinder_to_box) 143 | angular_vel = np.array([0.0, theta_dot, 0.0]) 144 | pybullet.resetBaseVelocity(objectUniqueId=self.cart_body_uid, 145 | linearVelocity=linear_vel, 146 | angularVelocity=angular_vel) 147 | 148 | def runSimOnceCallback(self, req): 149 | """ROS service callback to run simulation step once.""" 150 | assert len(req.state) == 0 or len(req.state) == 4, \ 151 | "req.state dimension is invalid {} != 0 or 4".format(len(req.state)) 152 | assert len(req.input) == 2, \ 153 | "req.input dimension is invalid {} != 2".format(len(req.input)) 154 | assert len(req.additional_data) == 0, \ 155 | "req.additional_data dimension is invalid {} != 0".format(len(req.additional_data)) 156 | 157 | if len(req.state) > 0: 158 | self.setState(np.array(req.state)) 159 | 160 | manip_force = np.array([req.input[0], 0.0, req.input[1]]) 161 | for i in range(int(req.dt / self.dt)): 162 | self.runOnce(manip_force) 163 | 164 | res = RunSimOnceResponse() 165 | res.state = self.getState() 166 | return res 167 | 168 | def generateDatasetCallback(self, req): 169 | """ROS service callback to generate dataset.""" 170 | state_min = np.array(req.state_min) 171 | state_max = np.array(req.state_max) 172 | input_min = np.array(req.input_min) 173 | input_max = 
np.array(req.input_max) 174 | 175 | state_all = [] 176 | input_all = [] 177 | next_state_all = [] 178 | for i in range(req.dataset_size): 179 | state = state_min + np.random.rand(len(state_min)) * (state_max - state_min) 180 | input = input_min + np.random.rand(len(input_min)) * (input_max - input_min) 181 | manip_force = np.array([input[0], 0.0, input[1]]) 182 | self.setState(state) 183 | for i in range(int(req.dt / self.dt)): 184 | self.runOnce(manip_force) 185 | next_state = self.getState() 186 | state_all.append(state) 187 | input_all.append(input) 188 | next_state_all.append(next_state) 189 | 190 | msg = Dataset() 191 | msg.dataset_size = req.dataset_size 192 | msg.dt = req.dt 193 | msg.state_dim = len(state_min) 194 | msg.input_dim = len(input_min) 195 | msg.state_all = np.array(state_all).flatten() 196 | msg.input_all = np.array(input_all).flatten() 197 | msg.next_state_all = np.array(next_state_all).flatten() 198 | bag = rosbag.Bag(req.filename, "w") 199 | bag.write("/dataset", msg) 200 | bag.close() 201 | print("[SimTestMpcCart] Save dataset to {}".format(req.filename)) 202 | 203 | res = GenerateDatasetResponse() 204 | return res 205 | 206 | 207 | def demo(): 208 | sim = SimTestMpcCart(True) 209 | sim.setState([0.3, 1.0, np.deg2rad(-10.0), 0.0]) 210 | 211 | t = 0.0 # [sec] 212 | while pybullet.isConnected(): 213 | # Calculate manipulation force 214 | _, _, theta, theta_dot = sim.getState() 215 | manip_force_z = -500.0 * theta -100.0 * theta_dot # [N] 216 | manip_force = np.array([0.0, 0.0, manip_force_z]) 217 | 218 | # Run simulation step 219 | sim.runOnce(manip_force) 220 | 221 | # Sleep and increment time 222 | time.sleep(sim.dt) 223 | t += sim.dt 224 | 225 | 226 | if __name__ == "__main__": 227 | # demo() 228 | 229 | rospy.init_node("sim_test_mpc_cart") 230 | sim = SimTestMpcCart(enable_gui=rospy.get_param("~enable_gui", True)) 231 | rospy.spin() 232 | -------------------------------------------------------------------------------- 
/tests/scripts/simTestMpcCartMujoco.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import time 4 | import numpy as np 5 | import mujoco 6 | import mujoco.viewer 7 | 8 | import rospy 9 | import rosbag 10 | from data_driven_mpc.msg import * 11 | from data_driven_mpc.srv import * 12 | 13 | 14 | class SimTestMpcCart(object): 15 | def __init__(self, enable_gui): 16 | # TODO: Support enable_gui=False 17 | 18 | # Setup xml 19 | self.box_half_scale = np.array([0.35, 0.25, 0.15]) # [m] 20 | box_mass = rospy.get_param("~box_mass", 8.0) # [kg] 21 | box_com_offset = np.array([-0.02, 0.0, -0.1]) # [m] 22 | self.cylinder_radius = 0.1 # [m] 23 | cylinder_height = 0.1 # [m] 24 | cylinder_mass = 2.0 # [kg] 25 | lateral_friction = rospy.get_param("~lateral_friction", 0.05) 26 | 27 | xml_str = """ 28 | 29 | 52 | """.format(cart_pos_z=2.0*self.cylinder_radius+self.box_half_scale[2], 53 | box_half_scale=self.box_half_scale, 54 | box_mass=box_mass, 55 | box_com_offset=box_com_offset, 56 | cylinder_radius=self.cylinder_radius, 57 | cylinder_half_height=0.5*cylinder_height, 58 | cylinder_mass=cylinder_mass, 59 | sliding_friction=lateral_friction, 60 | cylinder_pos_z=-1*(self.cylinder_radius+self.box_half_scale[2])) 61 | # with open("/tmp/simTestMpcCartMujoco.xml", "w") as f: 62 | # f.write(xml_str) 63 | 64 | # Instantiate simulator 65 | self.model = mujoco.MjModel.from_xml_string(xml_str) 66 | self.data = mujoco.MjData(self.model) 67 | self.viewer = mujoco.viewer.launch_passive(self.model, self.data) 68 | self.viewer.cam.azimuth = 120.0 69 | self.viewer.cam.distance = 10.0 70 | 71 | # Set simulation parameters 72 | self.dt = 0.005 # [sec] 73 | self.model.opt.timestep = self.dt 74 | 75 | # Setup ROS 76 | run_sim_once_srv = rospy.Service("/run_sim_once", RunSimOnce, self.runSimOnceCallback) 77 | generate_dataset_srv = rospy.Service("/generate_dataset", GenerateDataset, self.generateDatasetCallback) 78 | 79 | def 
runOnce(self, manip_force=None): 80 | """"Run simulation step once. 81 | 82 | Args: 83 | manip_force manipulation force in world frame 84 | """ 85 | if manip_force is not None: 86 | # Apply manipulation force 87 | box_link_pos = self.data.geom("box").xpos 88 | box_link_quat = np.zeros(4) 89 | mujoco.mju_mat2Quat(box_link_quat, self.data.geom("box").xmat) 90 | manip_pos_local = np.array([-1 * self.box_half_scale[0], 0.0, self.box_half_scale[2]]) 91 | manip_pos = np.zeros(3) 92 | mujoco.mju_trnVecPose(manip_pos, box_link_pos, box_link_quat, manip_pos_local) 93 | manip_moment = np.cross(manip_pos - self.data.body("cart").xipos, manip_force) 94 | self.data.body("cart").xfrc_applied += np.concatenate([manip_force, manip_moment]) 95 | 96 | # TODO: Visualize manipulation force 97 | 98 | # Process simulation step 99 | mujoco.mj_step(self.model, self.data) 100 | self.viewer.sync() 101 | 102 | def getState(self): 103 | """"Get state [p, p_dot, theta, theta_dot].""" 104 | p = self.data.geom("cylinder").xpos[0] # [m] 105 | box_quat = np.zeros(4) 106 | mujoco.mju_mat2Quat(box_quat, self.data.geom("box").xmat) 107 | theta = np.arcsin(2 * (box_quat[0] * box_quat[2] - box_quat[1] * box_quat[3])) # [rad] 108 | local_pos_from_cylinder_to_box = np.array([0.0, 0.0, self.cylinder_radius + self.box_half_scale[2]]) 109 | global_pos_from_cylinder_to_box = np.zeros(3) 110 | mujoco.mju_rotVecQuat(global_pos_from_cylinder_to_box, local_pos_from_cylinder_to_box, box_quat) 111 | vel = np.zeros(6) 112 | mujoco.mj_objectVelocity(self.model, self.data, mujoco.mjtObj.mjOBJ_XBODY, self.model.body("cart").id, vel, 0) 113 | vel[3:6] += np.cross(vel[0:3], -1 * global_pos_from_cylinder_to_box) 114 | p_dot = vel[3] # [m/s] 115 | theta_dot = vel[1] # [rad/s] 116 | return np.array([p, p_dot, theta, theta_dot]) 117 | 118 | def setState(self, state): 119 | """Set state [p, p_dot, theta, theta_dot].""" 120 | p, p_dot, theta, theta_dot = state 121 | box_quat = np.zeros(4) 122 | 
mujoco.mju_axisAngle2Quat(box_quat, np.array([0.0, 1.0, 0.0]), theta) 123 | local_pos_from_cylinder_to_box = np.array([0.0, 0.0, self.cylinder_radius + self.box_half_scale[2]]) 124 | global_pos_from_cylinder_to_box = np.zeros(3) 125 | mujoco.mju_rotVecQuat(global_pos_from_cylinder_to_box, local_pos_from_cylinder_to_box, box_quat) 126 | box_pos = np.array([p, 0.0, self.cylinder_radius]) + global_pos_from_cylinder_to_box 127 | self.data.qpos = np.concatenate([box_pos, box_quat]) 128 | linear_vel = np.array([p_dot, 0.0, 0.0]) + np.cross(np.array([0.0, theta_dot, 0.0]), global_pos_from_cylinder_to_box) 129 | angular_vel = np.array([0.0, theta_dot, 0.0]) 130 | self.data.qvel = np.concatenate([linear_vel, angular_vel]) 131 | 132 | def runSimOnceCallback(self, req): 133 | """ROS service callback to run simulation step once.""" 134 | assert len(req.state) == 0 or len(req.state) == 4, \ 135 | "req.state dimension is invalid {} != 0 or 4".format(len(req.state)) 136 | assert len(req.input) == 2, \ 137 | "req.input dimension is invalid {} != 2".format(len(req.input)) 138 | assert len(req.additional_data) == 0, \ 139 | "req.additional_data dimension is invalid {} != 0".format(len(req.additional_data)) 140 | 141 | if len(req.state) > 0: 142 | self.setState(np.array(req.state)) 143 | 144 | manip_force = np.array([req.input[0], 0.0, req.input[1]]) 145 | for i in range(int(req.dt / self.dt)): 146 | self.runOnce(manip_force) 147 | 148 | res = RunSimOnceResponse() 149 | res.state = self.getState() 150 | return res 151 | 152 | def generateDatasetCallback(self, req): 153 | """ROS service callback to generate dataset.""" 154 | state_min = np.array(req.state_min) 155 | state_max = np.array(req.state_max) 156 | input_min = np.array(req.input_min) 157 | input_max = np.array(req.input_max) 158 | 159 | state_all = [] 160 | input_all = [] 161 | next_state_all = [] 162 | for i in range(req.dataset_size): 163 | state = state_min + np.random.rand(len(state_min)) * (state_max - state_min) 164 | 
input = input_min + np.random.rand(len(input_min)) * (input_max - input_min) 165 | manip_force = np.array([input[0], 0.0, input[1]]) 166 | self.setState(state) 167 | for i in range(int(req.dt / self.dt)): 168 | self.runOnce(manip_force) 169 | next_state = self.getState() 170 | state_all.append(state) 171 | input_all.append(input) 172 | next_state_all.append(next_state) 173 | 174 | msg = Dataset() 175 | msg.dataset_size = req.dataset_size 176 | msg.dt = req.dt 177 | msg.state_dim = len(state_min) 178 | msg.input_dim = len(input_min) 179 | msg.state_all = np.array(state_all).flatten() 180 | msg.input_all = np.array(input_all).flatten() 181 | msg.next_state_all = np.array(next_state_all).flatten() 182 | bag = rosbag.Bag(req.filename, "w") 183 | bag.write("/dataset", msg) 184 | bag.close() 185 | print("[SimTestMpcCart] Save dataset to {}".format(req.filename)) 186 | 187 | res = GenerateDatasetResponse() 188 | return res 189 | 190 | 191 | def demo(): 192 | sim = SimTestMpcCart(True) 193 | sim.setState([0.3, 1.0, np.deg2rad(-10.0), 0.0]) 194 | 195 | t = 0.0 # [sec] 196 | while sim.viewer.is_running(): 197 | # Calculate manipulation force 198 | _, _, theta, theta_dot = sim.getState() 199 | manip_force_z = -500.0 * theta -100.0 * theta_dot # [N] 200 | manip_force = np.array([0.0, 0.0, manip_force_z]) 201 | 202 | # Run simulation step 203 | sim.runOnce(manip_force) 204 | 205 | # Sleep and increment time 206 | time.sleep(sim.dt) 207 | t += sim.dt 208 | 209 | 210 | if __name__ == "__main__": 211 | # demo() 212 | 213 | rospy.init_node("sim_test_mpc_cart") 214 | sim = SimTestMpcCart(enable_gui=rospy.get_param("~enable_gui", True)) 215 | rospy.spin() 216 | -------------------------------------------------------------------------------- /tests/src/TestDataset.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | TEST(TestDataset, Test1) 9 | { 10 | int 
dataset_size = 100; 11 | int state_dim = 2; 12 | int input_dim = 1; 13 | torch::Tensor state_all = DDMPC::toTorchTensor(Eigen::MatrixXf::Random(dataset_size, state_dim)); 14 | torch::Tensor input_all = DDMPC::toTorchTensor(Eigen::MatrixXf::Random(dataset_size, input_dim)); 15 | torch::Tensor next_state_all = DDMPC::toTorchTensor(Eigen::MatrixXf::Random(dataset_size, state_dim)); 16 | 17 | DDMPC::Dataset dataset(state_all, input_all, next_state_all); 18 | 19 | for(int i = 0; i < dataset_size; i++) 20 | { 21 | const DDMPC::Data & data = static_cast(dataset.get(i)); 22 | EXPECT_LT((data.state_ - state_all[i]).norm().item(), 1e-8); 23 | EXPECT_LT((data.input_ - input_all[i]).norm().item(), 1e-8); 24 | EXPECT_LT((data.next_state_ - next_state_all[i]).norm().item(), 1e-8); 25 | } 26 | } 27 | 28 | int main(int argc, char ** argv) 29 | { 30 | testing::InitGoogleTest(&argc, argv); 31 | return RUN_ALL_TESTS(); 32 | } 33 | -------------------------------------------------------------------------------- /tests/src/TestMathUtils.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | 7 | template 8 | void testStandardScaler() 9 | { 10 | // Generate data 11 | Eigen::RowVector3d scale(10.0, 20.0, 30.0); 12 | Eigen::RowVector3d offset(0.0, 100.0, -200.0); 13 | Eigen::MatrixX3d data_all = (Eigen::MatrixX3d::Random(10000, 3) * scale.asDiagonal()).rowwise() + offset; 14 | Eigen::MatrixX3d train_data = data_all.topRows(static_cast(0.7 * data_all.rows())); 15 | Eigen::MatrixX3d test_data = data_all.bottomRows(static_cast(0.3 * data_all.rows())); 16 | 17 | // Apply standardization 18 | DDMPC::StandardScaler standard_scaler(train_data); 19 | Eigen::MatrixX3d standardized_train_data = standard_scaler.apply(train_data); 20 | Eigen::MatrixX3d standardized_test_data = standard_scaler.apply(test_data); 21 | 22 | // Check standardized train data 23 | { 24 | Eigen::RowVector3d mean = 
DDMPC::StandardScaler::calcMean(standardized_train_data); 25 | Eigen::RowVector3d stddev = DDMPC::StandardScaler::calcStddev(standardized_train_data, mean); 26 | EXPECT_LT(mean.norm(), 1e-10); 27 | EXPECT_LT((stddev.array() - 1.0).matrix().norm(), 1e-10); 28 | } 29 | 30 | // Check standardized test data 31 | { 32 | Eigen::RowVector3d mean = DDMPC::StandardScaler::calcMean(standardized_test_data); 33 | Eigen::RowVector3d stddev = DDMPC::StandardScaler::calcStddev(standardized_test_data, mean); 34 | EXPECT_LT(mean.norm(), 0.1); 35 | EXPECT_LT((stddev.array() - 1.0).matrix().norm(), 0.1); 36 | } 37 | 38 | // Check inverse standardization 39 | { 40 | Eigen::MatrixX3d restored_test_data = standard_scaler.applyInv(standardized_test_data); 41 | EXPECT_LT((test_data - restored_test_data).norm(), 1e-10); 42 | 43 | Eigen::Vector3d single_test_data = test_data.row(0).transpose(); 44 | EXPECT_LT((single_test_data - standard_scaler.applyOneInv(standard_scaler.applyOne(single_test_data))).norm(), 45 | 1e-10); 46 | EXPECT_LT((single_test_data - standard_scaler.applyOne(standard_scaler.applyOneInv(single_test_data))).norm(), 47 | 1e-10); 48 | } 49 | } 50 | 51 | TEST(TestMathUtils, StandardScalerFixedSize) 52 | { 53 | testStandardScaler<3>(); 54 | } 55 | 56 | TEST(TestMathUtils, StandardScalerDynamicSize) 57 | { 58 | testStandardScaler(); 59 | } 60 | 61 | int main(int argc, char ** argv) 62 | { 63 | testing::InitGoogleTest(&argc, argv); 64 | return RUN_ALL_TESTS(); 65 | } 66 | -------------------------------------------------------------------------------- /tests/src/TestMpcCart.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | /** \brief DDP problem based on data-driven state 
equation. 24 | 25 | State consists of [p, p_dot, theta, theta_dot]. 26 | Input consists of [fx, fz]. 27 | p is the cart position. theta is the cart angle. 28 | fx and fz are the manipulation force in the X and Z directions. 29 | */ 30 | class DDPProblem : public nmpc_ddp::DDPProblem<4, 2> 31 | { 32 | public: 33 | struct WeightParam 34 | { 35 | StateDimVector running_state; 36 | InputDimVector running_input; 37 | StateDimVector terminal_state; 38 | 39 | WeightParam(const StateDimVector & _running_state = StateDimVector(1e2, 1e-2, 1e2, 1e-2), 40 | const InputDimVector & _running_input = InputDimVector::Constant(1e-2), 41 | const StateDimVector & _terminal_state = StateDimVector(1e2, 1e0, 1e2, 1e0)) 42 | : running_state(_running_state), running_input(_running_input), terminal_state(_terminal_state) 43 | { 44 | } 45 | }; 46 | 47 | public: 48 | DDPProblem(double dt, 49 | const std::shared_ptr & state_eq, 50 | const WeightParam & weight_param = WeightParam()) 51 | : nmpc_ddp::DDPProblem<4, 2>(dt), state_eq_(state_eq), weight_param_(weight_param) 52 | { 53 | } 54 | 55 | virtual StateDimVector stateEq(double t, const StateDimVector & x, const InputDimVector & u) const override 56 | { 57 | return next_state_standard_scaler_->applyOneInv( 58 | state_eq_->eval(state_standard_scaler_->applyOne(x), input_standard_scaler_->applyOne(u))); 59 | } 60 | 61 | virtual double runningCost(double t, const StateDimVector & x, const InputDimVector & u) const override 62 | { 63 | return 0.5 * weight_param_.running_state.dot(x.cwiseAbs2()) + 0.5 * weight_param_.running_input.dot(u.cwiseAbs2()); 64 | } 65 | 66 | virtual double terminalCost(double t, const StateDimVector & x) const override 67 | { 68 | return 0.5 * weight_param_.terminal_state.dot(x.cwiseAbs2()); 69 | } 70 | 71 | virtual void calcStateEqDeriv(double t, 72 | const StateDimVector & x, 73 | const InputDimVector & u, 74 | Eigen::Ref state_eq_deriv_x, 75 | Eigen::Ref state_eq_deriv_u) const override 76 | { 77 | 
state_eq_->eval(state_standard_scaler_->applyOne(x), input_standard_scaler_->applyOne(u), state_eq_deriv_x, 78 | state_eq_deriv_u); 79 | state_eq_deriv_x.array().colwise() *= next_state_standard_scaler_->stddev_vec_.transpose().array(); 80 | state_eq_deriv_x.array().rowwise() /= state_standard_scaler_->stddev_vec_.array(); 81 | state_eq_deriv_u.array().colwise() *= next_state_standard_scaler_->stddev_vec_.transpose().array(); 82 | state_eq_deriv_u.array().rowwise() /= input_standard_scaler_->stddev_vec_.array(); 83 | } 84 | 85 | virtual void calcStateEqDeriv(double t, 86 | const StateDimVector & x, 87 | const InputDimVector & u, 88 | Eigen::Ref state_eq_deriv_x, 89 | Eigen::Ref state_eq_deriv_u, 90 | std::vector & state_eq_deriv_xx, 91 | std::vector & state_eq_deriv_uu, 92 | std::vector & state_eq_deriv_xu) const override 93 | { 94 | throw std::runtime_error("Second-order derivatives of state equation are not implemented."); 95 | } 96 | 97 | virtual void calcRunningCostDeriv(double t, 98 | const StateDimVector & x, 99 | const InputDimVector & u, 100 | Eigen::Ref running_cost_deriv_x, 101 | Eigen::Ref running_cost_deriv_u) const override 102 | { 103 | running_cost_deriv_x = weight_param_.running_state.cwiseProduct(x); 104 | running_cost_deriv_u = weight_param_.running_input.cwiseProduct(u); 105 | } 106 | 107 | virtual void calcRunningCostDeriv(double t, 108 | const StateDimVector & x, 109 | const InputDimVector & u, 110 | Eigen::Ref running_cost_deriv_x, 111 | Eigen::Ref running_cost_deriv_u, 112 | Eigen::Ref running_cost_deriv_xx, 113 | Eigen::Ref running_cost_deriv_uu, 114 | Eigen::Ref running_cost_deriv_xu) const override 115 | { 116 | calcRunningCostDeriv(t, x, u, running_cost_deriv_x, running_cost_deriv_u); 117 | 118 | running_cost_deriv_xx.setZero(); 119 | running_cost_deriv_xx.diagonal() = weight_param_.running_state; 120 | running_cost_deriv_uu.setZero(); 121 | running_cost_deriv_uu.diagonal() = weight_param_.running_input; 122 | 
running_cost_deriv_xu.setZero(); 123 | } 124 | 125 | virtual void calcTerminalCostDeriv(double t, 126 | const StateDimVector & x, 127 | Eigen::Ref terminal_cost_deriv_x) const override 128 | { 129 | terminal_cost_deriv_x = weight_param_.terminal_state.cwiseProduct(x); 130 | } 131 | 132 | virtual void calcTerminalCostDeriv(double t, 133 | const StateDimVector & x, 134 | Eigen::Ref terminal_cost_deriv_x, 135 | Eigen::Ref terminal_cost_deriv_xx) const override 136 | { 137 | calcTerminalCostDeriv(t, x, terminal_cost_deriv_x); 138 | 139 | terminal_cost_deriv_xx.setZero(); 140 | terminal_cost_deriv_xx.diagonal() = weight_param_.terminal_state; 141 | } 142 | 143 | void setStandardScaler(const std::shared_ptr> & state_standard_scaler, 144 | const std::shared_ptr> & input_standard_scaler, 145 | const std::shared_ptr> & next_state_standard_scaler) 146 | 147 | { 148 | state_standard_scaler_ = state_standard_scaler; 149 | input_standard_scaler_ = input_standard_scaler; 150 | next_state_standard_scaler_ = next_state_standard_scaler; 151 | } 152 | 153 | protected: 154 | std::shared_ptr state_eq_; 155 | 156 | WeightParam weight_param_; 157 | 158 | std::shared_ptr> state_standard_scaler_; 159 | std::shared_ptr> input_standard_scaler_; 160 | std::shared_ptr> next_state_standard_scaler_; 161 | }; 162 | 163 | namespace Eigen 164 | { 165 | using MatrixXdRowMajor = Eigen::Matrix; 166 | } 167 | 168 | TEST(TestMpcCart, Test1) 169 | { 170 | ros::NodeHandle nh; 171 | ros::NodeHandle pnh("~"); 172 | ros::ServiceClient generate_dataset_cli = nh.serviceClient("/generate_dataset"); 173 | ros::ServiceClient run_sim_once_cli = nh.serviceClient("/run_sim_once"); 174 | ASSERT_TRUE(generate_dataset_cli.waitForExistence(ros::Duration(10.0))) 175 | << "[TestMpcCart] Failed to wait for ROS service to generate dataset." << std::endl; 176 | ASSERT_TRUE(run_sim_once_cli.waitForExistence(ros::Duration(10.0))) 177 | << "[TestMpcCart] Failed to wait for ROS service to run simulation once." 
<< std::endl; 178 | 179 | //// 1. Train state equation //// 180 | double horizon_dt = 0.1; // [sec] 181 | 182 | // Instantiate state equation 183 | int state_dim = 4; 184 | int input_dim = 2; 185 | int middle_layer_dim = 32; 186 | auto state_eq = std::make_shared(state_dim, input_dim, middle_layer_dim); 187 | 188 | // Instantiate problem 189 | auto ddp_problem = std::make_shared(horizon_dt, state_eq); 190 | 191 | // Call service to generate dataset 192 | auto start_dataset_time = std::chrono::system_clock::now(); 193 | data_driven_mpc::GenerateDataset generate_dataset_srv; 194 | std::string dataset_filename = ros::package::getPath("data_driven_mpc") + "/tests/data/TestMpcCartDataset.bag"; 195 | int dataset_size = 10000; 196 | DDPProblem::StateDimVector x_max = DDPProblem::StateDimVector(1.0, 0.2, 0.4, 0.5); 197 | DDPProblem::InputDimVector u_max = DDPProblem::InputDimVector(15.0, 15.0); 198 | generate_dataset_srv.request.filename = dataset_filename; 199 | generate_dataset_srv.request.dataset_size = dataset_size; 200 | generate_dataset_srv.request.dt = horizon_dt; 201 | generate_dataset_srv.request.state_max.resize(state_dim); 202 | DDPProblem::StateDimVector::Map(&generate_dataset_srv.request.state_max[0], state_dim) = x_max; 203 | generate_dataset_srv.request.state_min.resize(state_dim); 204 | DDPProblem::StateDimVector::Map(&generate_dataset_srv.request.state_min[0], state_dim) = -1 * x_max; 205 | generate_dataset_srv.request.input_max.resize(input_dim); 206 | DDPProblem::InputDimVector::Map(&generate_dataset_srv.request.input_max[0], input_dim) = u_max; 207 | generate_dataset_srv.request.input_min.resize(input_dim); 208 | DDPProblem::InputDimVector::Map(&generate_dataset_srv.request.input_min[0], input_dim) = -1 * u_max; 209 | ASSERT_TRUE(generate_dataset_cli.call(generate_dataset_srv)) 210 | << "[TestMpcCart] Failed to call ROS service to generate dataset." 
<< std::endl; 211 | 212 | // Load dataset from rosbag 213 | Eigen::MatrixXd state_all; 214 | Eigen::MatrixXd input_all; 215 | Eigen::MatrixXd next_state_all; 216 | rosbag::Bag dataset_bag; 217 | dataset_bag.open(dataset_filename, rosbag::bagmode::Read); 218 | for(rosbag::MessageInstance const msg : 219 | rosbag::View(dataset_bag, rosbag::TopicQuery(std::vector{"/dataset"}))) 220 | { 221 | data_driven_mpc::Dataset::ConstPtr dataset_msg = msg.instantiate(); 222 | state_all = Eigen::Map(dataset_msg->state_all.data(), dataset_size, state_dim); 223 | input_all = Eigen::Map(dataset_msg->input_all.data(), dataset_size, input_dim); 224 | next_state_all = 225 | Eigen::Map(dataset_msg->next_state_all.data(), dataset_size, state_dim); 226 | break; 227 | } 228 | dataset_bag.close(); 229 | 230 | // Instantiate standardization scalar 231 | auto state_standard_scaler = std::make_shared>(state_all); 232 | auto input_standard_scaler = std::make_shared>(input_all); 233 | auto next_state_standard_scaler = std::make_shared>(next_state_all); 234 | ddp_problem->setStandardScaler(state_standard_scaler, input_standard_scaler, next_state_standard_scaler); 235 | 236 | // Instantiate dataset 237 | std::shared_ptr train_dataset; 238 | std::shared_ptr test_dataset; 239 | DDMPC::makeDataset(DDMPC::toTorchTensor(state_standard_scaler->apply(state_all).cast()), 240 | DDMPC::toTorchTensor(input_standard_scaler->apply(input_all).cast()), 241 | DDMPC::toTorchTensor(next_state_standard_scaler->apply(next_state_all).cast()), 242 | train_dataset, test_dataset); 243 | std::cout << "dataset duration: " 244 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 245 | - start_dataset_time) 246 | .count() 247 | << " [s]" << std::endl; 248 | 249 | // Training model 250 | auto start_train_time = std::chrono::system_clock::now(); 251 | DDMPC::Training training; 252 | std::string model_path = ros::package::getPath("data_driven_mpc") + "/tests/data/TestMpcCartModel.pt"; 253 | int batch_size = 256; 254 
| int num_epoch = 250; 255 | double learning_rate = 1e-3; 256 | training.run(state_eq, train_dataset, test_dataset, model_path, batch_size, num_epoch, learning_rate); 257 | training.load(state_eq, model_path); 258 | std::cout << "train duration: " 259 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 260 | - start_train_time) 261 | .count() 262 | << " [s]" << std::endl; 263 | 264 | std::cout << "Run the following commands in gnuplot:\n" 265 | << " set key autotitle columnhead\n" 266 | << " set key noenhanced\n" 267 | << " plot \"/tmp/DataDrivenMPCTraining.txt\" u 1:2 w lp, \"\" u 1:3 w lp\n"; 268 | 269 | //// 2. Run MPC //// 270 | double horizon_duration = 2.0; // [sec] 271 | int horizon_steps = static_cast(horizon_duration / horizon_dt); 272 | double end_t = 3.0; // [sec] 273 | 274 | // Instantiate solver 275 | auto ddp_solver = std::make_shared>(ddp_problem); 276 | auto input_limits_func = [&](double t) -> std::array { 277 | std::array limits; 278 | limits[0] = -1 * u_max; 279 | limits[1] = u_max; 280 | return limits; 281 | }; 282 | ddp_solver->setInputLimitsFunc(input_limits_func); 283 | ddp_solver->config().with_input_constraint = true; 284 | ddp_solver->config().horizon_steps = horizon_steps; 285 | ddp_solver->config().max_iter = 5; 286 | 287 | // Initialize MPC 288 | double sim_dt = 0.05; // [sec] 289 | double current_t = 0; 290 | DDPProblem::StateDimVector current_x = DDPProblem::StateDimVector(-0.2, 0.0, 0.2, 0.0); 291 | std::vector current_u_list(horizon_steps, DDPProblem::InputDimVector::Zero()); 292 | 293 | // Run MPC loop 294 | std::string file_path = "/tmp/TestMpcCartResult.txt"; 295 | std::ofstream ofs(file_path); 296 | ofs << "time p p_dot theta theta_dot fx fz ddp_iter computation_time" << std::endl; 297 | bool no_exit = false; 298 | pnh.getParam("no_exit", no_exit); 299 | while(no_exit || current_t < end_t) 300 | { 301 | // Solve 302 | auto start_time = std::chrono::system_clock::now(); 303 | ddp_solver->solve(current_t, current_x, 
current_u_list); 304 | 305 | // Set input 306 | const auto & input_limits = input_limits_func(current_t); 307 | DDPProblem::InputDimVector current_u = 308 | ddp_solver->controlData().u_list[0].cwiseMax(input_limits[0]).cwiseMin(input_limits[1]); 309 | double duration = 310 | 1e3 311 | * std::chrono::duration_cast>(std::chrono::system_clock::now() - start_time) 312 | .count(); 313 | 314 | // Check 315 | for(int i = 0; i < state_dim; i++) 316 | { 317 | EXPECT_LT(std::abs(current_x[i]), 5 * x_max[i]) << "[TestMpcCart] Violate x[" << i << "] limits." << std::endl; 318 | } 319 | for(int i = 0; i < input_dim; i++) 320 | { 321 | EXPECT_LE(std::abs(current_u[i]), u_max[i]) << "[TestMpcCart] Violate u[" << i << "] limits." << std::endl; 322 | } 323 | 324 | // Dump 325 | ofs << current_t << " " << current_x.transpose() << " " << current_u.transpose() << " " 326 | << ddp_solver->traceDataList().back().iter << " " << duration << std::endl; 327 | 328 | // Update to next step 329 | current_t += sim_dt; 330 | data_driven_mpc::RunSimOnce run_sim_once_srv; 331 | run_sim_once_srv.request.dt = sim_dt; 332 | run_sim_once_srv.request.state.resize(state_dim); 333 | DDPProblem::StateDimVector::Map(&run_sim_once_srv.request.state[0], state_dim) = current_x; 334 | run_sim_once_srv.request.input.resize(input_dim); 335 | DDPProblem::InputDimVector::Map(&run_sim_once_srv.request.input[0], input_dim) = current_u; 336 | ASSERT_TRUE(run_sim_once_cli.call(run_sim_once_srv)) 337 | << "[TestMpcCart] Failed to call ROS service to run simulation once." 
<< std::endl; 338 | current_x = DDPProblem::StateDimVector::Map(&run_sim_once_srv.response.state[0], state_dim); 339 | current_u_list = ddp_solver->controlData().u_list; 340 | } 341 | 342 | // Final check 343 | const DDPProblem::InputDimVector & current_u = ddp_solver->controlData().u_list[0]; 344 | EXPECT_LT(std::abs(current_x[0]), 0.15); 345 | EXPECT_LT(std::abs(current_x[1]), 0.5); 346 | EXPECT_LT(std::abs(current_x[2]), 0.15); 347 | EXPECT_LT(std::abs(current_x[3]), 0.5); 348 | EXPECT_LT(std::abs(current_u[0]), 15.0); 349 | EXPECT_LT(std::abs(current_u[1]), 15.0); 350 | 351 | std::cout << "Run the following commands in gnuplot:\n" 352 | << " set key autotitle columnhead\n" 353 | << " set key noenhanced\n" 354 | << " plot \"" << file_path << "\" u 1:2 w lp, \"\" u 1:4 w lp # state\n" 355 | << " plot \"" << file_path << "\" u 1:6 w lp, \"\" u 1:7 w lp # input\n"; 356 | } 357 | 358 | int main(int argc, char ** argv) 359 | { 360 | // Setup ROS 361 | ros::init(argc, argv, "test_mpc_cart"); 362 | 363 | testing::InitGoogleTest(&argc, argv); 364 | return RUN_ALL_TESTS(); 365 | } 366 | -------------------------------------------------------------------------------- /tests/src/TestMpcCartWalk.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | /** \brief DDP problem based on combination of analytical and data-driven models. 24 | 25 | State consists of [robot_com_pos_x, robot_com_vel_x, robot_com_pos_y, robot_com_vel_y, obj_p, obj_p_dot, obj_theta, 26 | obj_theta_dot]. Input consists of [robot_zmp_x, robot_zmp_y, obj_fx, obj_fz]. obj_p is the cart position. obj_theta 27 | is the cart angle. 
obj_fx and obj_fz are the manipulation force applied to an object in the X and Z directions. 28 | */ 29 | class DDPProblem : public nmpc_ddp::DDPProblem<8, 4> 30 | { 31 | public: 32 | static constexpr int RobotStateDim = 4; 33 | static constexpr int ObjStateDim = 4; 34 | static constexpr int ObjInputDim = 2; 35 | 36 | public: 37 | using RobotStateDimVector = Eigen::Matrix; 38 | using ObjStateDimVector = Eigen::Matrix; 39 | using ObjInputDimVector = Eigen::Matrix; 40 | 41 | public: 42 | struct WeightParam 43 | { 44 | StateDimVector running_state; 45 | InputDimVector running_input; 46 | StateDimVector terminal_state; 47 | 48 | WeightParam() 49 | { 50 | running_state << 0.0, 1e-4, 0.0, 1e-4, 1e3, 1e-4, 1e2, 1e-4; 51 | running_input << 1e1, 1e1, 1e-1, 1e-1; 52 | terminal_state << 1.0, 1.0, 1.0, 1.0, 1e2, 1.0, 1.0, 1.0; 53 | } 54 | }; 55 | 56 | public: 57 | DDPProblem( 58 | double dt, 59 | const std::shared_ptr & state_eq, 60 | const std::function, Eigen::Ref)> & ref_func = nullptr, 61 | const WeightParam & weight_param = WeightParam()) 62 | : nmpc_ddp::DDPProblem<8, 4>(dt), state_eq_(state_eq), ref_func_(ref_func), weight_param_(weight_param) 63 | { 64 | } 65 | 66 | RobotStateDimVector simulateRobot(const StateDimVector & x, const InputDimVector & u, double dt) const 67 | { 68 | double robot_com_pos_x = x[0]; 69 | double robot_com_vel_x = x[1]; 70 | double robot_com_pos_y = x[2]; 71 | double robot_com_vel_y = x[3]; 72 | double obj_p = x[4]; 73 | double robot_zmp_x = u[0]; 74 | double robot_zmp_y = u[1]; 75 | double obj_fx = u[2]; 76 | double obj_fz = u[3]; 77 | 78 | // See equation (3) in https://hal.archives-ouvertes.fr/hal-03425667/ 79 | double omega2 = gravity_acc_ / robot_com_height_; 80 | double zeta = robot_mass_ * gravity_acc_; 81 | double kappa = 1.0 + obj_fz / zeta; 82 | double gamma_x = (-1 * grasp_pos_z_ * obj_fx + (obj_p + grasp_pos_x_local_) * obj_fz) / zeta; 83 | 84 | RobotStateDimVector x_dot; 85 | x_dot[0] = robot_com_vel_x; 86 | x_dot[1] = omega2 * 
(robot_com_pos_x - kappa * robot_zmp_x + gamma_x); 87 | x_dot[2] = robot_com_vel_y; 88 | x_dot[3] = omega2 * (robot_com_pos_y - kappa * robot_zmp_y); 89 | 90 | return x.head<4>() + dt * x_dot; 91 | } 92 | 93 | ObjStateDimVector simulateObj(const ObjStateDimVector & x, const ObjInputDimVector & u, double dt) const 94 | { 95 | data_driven_mpc::RunSimOnce run_sim_once_srv; 96 | run_sim_once_srv.request.dt = dt; 97 | run_sim_once_srv.request.state.resize(ObjStateDim); 98 | ObjStateDimVector::Map(&run_sim_once_srv.request.state[0], ObjStateDim) = x; 99 | run_sim_once_srv.request.input.resize(ObjInputDim); 100 | ObjInputDimVector::Map(&run_sim_once_srv.request.input[0], ObjInputDim) = u; 101 | _callRunSimOnce(run_sim_once_srv); 102 | return ObjStateDimVector::Map(&run_sim_once_srv.response.state[0], ObjStateDim); 103 | } 104 | 105 | StateDimVector simulate(const StateDimVector & x, const InputDimVector & u, double dt) const 106 | { 107 | StateDimVector next_x; 108 | next_x.head<4>() = simulateRobot(x, u, dt); 109 | next_x.tail<4>() = simulateObj(x.tail<4>(), u.tail<2>(), dt); 110 | 111 | return next_x; 112 | } 113 | 114 | virtual StateDimVector stateEq(double t, const StateDimVector & x, const InputDimVector & u) const override 115 | { 116 | StateDimVector next_x; 117 | next_x.head<4>() = simulateRobot(x, u, dt_); 118 | next_x.tail<4>() = next_state_standard_scaler_->applyOneInv( 119 | state_eq_->eval(state_standard_scaler_->applyOne(x.tail<4>()), input_standard_scaler_->applyOne(u.tail<2>()))); 120 | 121 | return next_x; 122 | } 123 | 124 | virtual double runningCost(double t, const StateDimVector & x, const InputDimVector & u) const override 125 | { 126 | StateDimVector ref_x; 127 | InputDimVector ref_u; 128 | ref_func_(t, ref_x, ref_u); 129 | 130 | return 0.5 * weight_param_.running_state.dot((x - ref_x).cwiseAbs2()) 131 | + 0.5 * weight_param_.running_input.dot((u - ref_u).cwiseAbs2()); 132 | } 133 | 134 | virtual double terminalCost(double t, const StateDimVector & 
x) const override 135 | { 136 | StateDimVector ref_x; 137 | InputDimVector ref_u; 138 | ref_func_(t, ref_x, ref_u); 139 | 140 | return 0.5 * weight_param_.terminal_state.dot((x - ref_x).cwiseAbs2()); 141 | } 142 | 143 | virtual void calcStateEqDeriv(double t, 144 | const StateDimVector & x, 145 | const InputDimVector & u, 146 | Eigen::Ref state_eq_deriv_x, 147 | Eigen::Ref state_eq_deriv_u) const override 148 | { 149 | double robot_com_pos_x = x[0]; 150 | double robot_com_vel_x = x[1]; 151 | double robot_com_pos_y = x[2]; 152 | double robot_com_vel_y = x[3]; 153 | double obj_p = x[4]; 154 | double robot_zmp_x = u[0]; 155 | double robot_zmp_y = u[1]; 156 | double obj_fx = u[2]; 157 | double obj_fz = u[3]; 158 | 159 | // See equation (3) in https://hal.archives-ouvertes.fr/hal-03425667/ 160 | double omega2 = gravity_acc_ / robot_com_height_; 161 | double zeta = robot_mass_ * gravity_acc_; 162 | double kappa = 1.0 + obj_fz / zeta; 163 | double gamma_x = (-1 * grasp_pos_z_ * obj_fx + (obj_p + grasp_pos_x_local_) * obj_fz) / zeta; 164 | 165 | state_eq_deriv_x.setZero(); 166 | state_eq_deriv_u.setZero(); 167 | 168 | state_eq_deriv_x(0, 1) = 1; 169 | state_eq_deriv_x(1, 0) = omega2; 170 | state_eq_deriv_x(1, 4) = omega2 * obj_fz / zeta; 171 | state_eq_deriv_x(2, 3) = 1; 172 | state_eq_deriv_x(3, 2) = omega2; 173 | state_eq_deriv_x.topRows<4>() *= dt_; 174 | state_eq_deriv_x.diagonal().head<4>().array() += 1.0; 175 | 176 | state_eq_deriv_u(1, 0) = -1 * omega2 * kappa; 177 | state_eq_deriv_u(1, 2) = -1 * omega2 * grasp_pos_z_ / zeta; 178 | state_eq_deriv_u(1, 3) = (omega2 / zeta) * (-1 * robot_zmp_x + (obj_p + grasp_pos_x_local_)); 179 | state_eq_deriv_u(3, 1) = -1 * omega2 * kappa; 180 | state_eq_deriv_u(3, 3) = -1 * omega2 * robot_zmp_y / zeta; 181 | state_eq_deriv_u.topRows<4>() *= dt_; 182 | 183 | state_eq_->eval(state_standard_scaler_->applyOne(x.tail<4>()), input_standard_scaler_->applyOne(u.tail<2>()), 184 | state_eq_deriv_x.bottomRightCorner<4, 4>(), 
state_eq_deriv_u.bottomRightCorner<4, 2>()); 185 | state_eq_deriv_x.bottomRightCorner<4, 4>().array().colwise() *= 186 | next_state_standard_scaler_->stddev_vec_.transpose().array(); 187 | state_eq_deriv_x.bottomRightCorner<4, 4>().array().rowwise() /= state_standard_scaler_->stddev_vec_.array(); 188 | state_eq_deriv_u.bottomRightCorner<4, 2>().array().colwise() *= 189 | next_state_standard_scaler_->stddev_vec_.transpose().array(); 190 | state_eq_deriv_u.bottomRightCorner<4, 2>().array().rowwise() /= input_standard_scaler_->stddev_vec_.array(); 191 | } 192 | 193 | virtual void calcStateEqDeriv(double t, 194 | const StateDimVector & x, 195 | const InputDimVector & u, 196 | Eigen::Ref state_eq_deriv_x, 197 | Eigen::Ref state_eq_deriv_u, 198 | std::vector & state_eq_deriv_xx, 199 | std::vector & state_eq_deriv_uu, 200 | std::vector & state_eq_deriv_xu) const override 201 | { 202 | throw std::runtime_error("Second-order derivatives of state equation are not implemented."); 203 | } 204 | 205 | virtual void calcRunningCostDeriv(double t, 206 | const StateDimVector & x, 207 | const InputDimVector & u, 208 | Eigen::Ref running_cost_deriv_x, 209 | Eigen::Ref running_cost_deriv_u) const override 210 | { 211 | StateDimVector ref_x; 212 | InputDimVector ref_u; 213 | ref_func_(t, ref_x, ref_u); 214 | 215 | running_cost_deriv_x = weight_param_.running_state.cwiseProduct(x - ref_x); 216 | running_cost_deriv_u = weight_param_.running_input.cwiseProduct(u - ref_u); 217 | } 218 | 219 | virtual void calcRunningCostDeriv(double t, 220 | const StateDimVector & x, 221 | const InputDimVector & u, 222 | Eigen::Ref running_cost_deriv_x, 223 | Eigen::Ref running_cost_deriv_u, 224 | Eigen::Ref running_cost_deriv_xx, 225 | Eigen::Ref running_cost_deriv_uu, 226 | Eigen::Ref running_cost_deriv_xu) const override 227 | { 228 | calcRunningCostDeriv(t, x, u, running_cost_deriv_x, running_cost_deriv_u); 229 | 230 | running_cost_deriv_xx.setZero(); 231 | running_cost_deriv_xx.diagonal() = 
weight_param_.running_state; 232 | running_cost_deriv_uu.setZero(); 233 | running_cost_deriv_uu.diagonal() = weight_param_.running_input; 234 | running_cost_deriv_xu.setZero(); 235 | } 236 | 237 | virtual void calcTerminalCostDeriv(double t, 238 | const StateDimVector & x, 239 | Eigen::Ref terminal_cost_deriv_x) const override 240 | { 241 | StateDimVector ref_x; 242 | InputDimVector ref_u; 243 | ref_func_(t, ref_x, ref_u); 244 | 245 | terminal_cost_deriv_x = weight_param_.terminal_state.cwiseProduct(x - ref_x); 246 | } 247 | 248 | virtual void calcTerminalCostDeriv(double t, 249 | const StateDimVector & x, 250 | Eigen::Ref terminal_cost_deriv_x, 251 | Eigen::Ref terminal_cost_deriv_xx) const override 252 | { 253 | calcTerminalCostDeriv(t, x, terminal_cost_deriv_x); 254 | 255 | terminal_cost_deriv_xx.setZero(); 256 | terminal_cost_deriv_xx.diagonal() = weight_param_.terminal_state; 257 | } 258 | 259 | void setStandardScaler(const std::shared_ptr> & state_standard_scaler, 260 | const std::shared_ptr> & input_standard_scaler, 261 | const std::shared_ptr> & next_state_standard_scaler) 262 | 263 | { 264 | state_standard_scaler_ = state_standard_scaler; 265 | input_standard_scaler_ = input_standard_scaler; 266 | next_state_standard_scaler_ = next_state_standard_scaler; 267 | } 268 | 269 | protected: 270 | void _callRunSimOnce(data_driven_mpc::RunSimOnce & run_sim_once_srv) const 271 | { 272 | // This is a wrapper for google-test's ASSERT_*, which can only be used with void functions 273 | // https://google.github.io/googletest/advanced.html#assertion-placement 274 | ASSERT_TRUE(ros::service::call("/run_sim_once", run_sim_once_srv)) 275 | << "[TestMpcCartWalk] Failed to call ROS service to run simulation once." 
<< std::endl; 276 | } 277 | 278 | protected: 279 | std::shared_ptr state_eq_; 280 | 281 | std::function, Eigen::Ref)> ref_func_; 282 | 283 | WeightParam weight_param_; 284 | 285 | std::shared_ptr> state_standard_scaler_; 286 | std::shared_ptr> input_standard_scaler_; 287 | std::shared_ptr> next_state_standard_scaler_; 288 | 289 | double gravity_acc_ = 9.8; // [m/s^2] 290 | double robot_mass_ = 60.0; // [kg] 291 | double robot_com_height_ = 0.8; // [m] 292 | double grasp_pos_x_local_ = -0.35; // [m] 293 | double grasp_pos_z_ = 0.5; // [m] 294 | }; 295 | 296 | namespace Eigen 297 | { 298 | using MatrixXdRowMajor = Eigen::Matrix; 299 | } 300 | 301 | TEST(TestMpcCartWalk, RunMPC) 302 | { 303 | ros::NodeHandle nh; 304 | ros::NodeHandle pnh("~"); 305 | ros::ServiceClient generate_dataset_cli = nh.serviceClient("/generate_dataset"); 306 | ASSERT_TRUE(generate_dataset_cli.waitForExistence(ros::Duration(10.0))) 307 | << "[TestMpcCartWalk] Failed to wait for ROS service to generate dataset." << std::endl; 308 | ASSERT_TRUE(ros::service::waitForService("/run_sim_once", ros::Duration(10.0))) 309 | << "[TestMpcCartWalk] Failed to wait for ROS service to run simulation once." << std::endl; 310 | 311 | //// 1. 
Train state equation //// 312 | double horizon_dt = 0.1; // [sec] 313 | 314 | // Instantiate state equation 315 | int middle_layer_dim = 32; 316 | auto state_eq = std::make_shared(DDPProblem::ObjStateDim, DDPProblem::ObjInputDim, middle_layer_dim); 317 | 318 | // Instantiate problem 319 | auto ref_func = [&](double t, Eigen::Ref ref_x, 320 | Eigen::Ref ref_u) -> void { 321 | // Add small values to avoid numerical instability at inequality bounds 322 | constexpr double epsilon_t = 1e-6; 323 | t += epsilon_t; 324 | 325 | // Object position 326 | ref_x.setZero(); 327 | if(t < 1.0) // [sec] 328 | { 329 | ref_x[4] = 0.5; // [m] 330 | } 331 | else if(t < 3.0) // [sec] 332 | { 333 | ref_x[4] = 0.3 * (t - 1.0) + 0.5; // [m] 334 | } 335 | else 336 | { 337 | ref_x[4] = 1.1; // [m] 338 | } 339 | 340 | // ZMP position 341 | ref_u.setZero(); 342 | double step_total_duration = 1.0; // [sec] 343 | double step_transit_duration = 0.2; // [sec] 344 | double zmp_x_step = 0.2; // [m] 345 | std::vector zmp_y_list = {0.0, -0.1, 0.1, 0.0}; // [m] 346 | int step_idx = std::clamp(static_cast(std::floor((t - 1.5) / step_total_duration)), -1, 2); 347 | if(step_idx == -1) 348 | { 349 | ref_u[0] = 0.0; // [m] 350 | ref_u[1] = 0.0; // [m] 351 | } 352 | else 353 | { 354 | double step_start_time = 1.5 + step_idx * step_total_duration; 355 | double ratio = std::clamp((t - step_start_time) / step_transit_duration, 0.0, 1.0); 356 | ref_u[0] = zmp_x_step * (static_cast(step_idx) + ratio); 357 | ref_u[1] = (1.0 - ratio) * zmp_y_list[step_idx] + ratio * zmp_y_list[step_idx + 1]; 358 | } 359 | ref_x[0] = ref_u[0]; 360 | ref_x[1] = ref_u[1]; 361 | }; 362 | auto ddp_problem = std::make_shared(horizon_dt, state_eq, ref_func); 363 | 364 | // Call service to generate dataset 365 | auto start_dataset_time = std::chrono::system_clock::now(); 366 | data_driven_mpc::GenerateDataset generate_dataset_srv; 367 | std::string dataset_filename = ros::package::getPath("data_driven_mpc") + 
"/tests/data/TestMpcCartWalkDataset.bag"; 368 | int dataset_size = 10000; 369 | DDPProblem::ObjStateDimVector x_max = DDPProblem::ObjStateDimVector(2.0, 0.2, 0.4, 0.5); 370 | DDPProblem::ObjStateDimVector x_min = DDPProblem::ObjStateDimVector(0.0, -0.2, -0.4, -0.5); 371 | DDPProblem::ObjInputDimVector u_max = DDPProblem::ObjInputDimVector(15.0, 15.0); 372 | generate_dataset_srv.request.filename = dataset_filename; 373 | generate_dataset_srv.request.dataset_size = dataset_size; 374 | generate_dataset_srv.request.dt = horizon_dt; 375 | generate_dataset_srv.request.state_max.resize(DDPProblem::ObjStateDim); 376 | DDPProblem::ObjStateDimVector::Map(&generate_dataset_srv.request.state_max[0], DDPProblem::ObjStateDim) = x_max; 377 | generate_dataset_srv.request.state_min.resize(DDPProblem::ObjStateDim); 378 | DDPProblem::ObjStateDimVector::Map(&generate_dataset_srv.request.state_min[0], DDPProblem::ObjStateDim) = x_min; 379 | generate_dataset_srv.request.input_max.resize(DDPProblem::ObjInputDim); 380 | DDPProblem::ObjInputDimVector::Map(&generate_dataset_srv.request.input_max[0], DDPProblem::ObjInputDim) = u_max; 381 | generate_dataset_srv.request.input_min.resize(DDPProblem::ObjInputDim); 382 | DDPProblem::ObjInputDimVector::Map(&generate_dataset_srv.request.input_min[0], DDPProblem::ObjInputDim) = -1 * u_max; 383 | ASSERT_TRUE(generate_dataset_cli.call(generate_dataset_srv)) 384 | << "[TestMpcCartWalk] Failed to call ROS service to generate dataset." 
<< std::endl; 385 | 386 | // Load dataset from rosbag 387 | Eigen::MatrixXd state_all; 388 | Eigen::MatrixXd input_all; 389 | Eigen::MatrixXd next_state_all; 390 | rosbag::Bag dataset_bag; 391 | dataset_bag.open(dataset_filename, rosbag::bagmode::Read); 392 | for(rosbag::MessageInstance const msg : 393 | rosbag::View(dataset_bag, rosbag::TopicQuery(std::vector{"/dataset"}))) 394 | { 395 | data_driven_mpc::Dataset::ConstPtr dataset_msg = msg.instantiate(); 396 | state_all = 397 | Eigen::Map(dataset_msg->state_all.data(), dataset_size, DDPProblem::ObjStateDim); 398 | input_all = 399 | Eigen::Map(dataset_msg->input_all.data(), dataset_size, DDPProblem::ObjInputDim); 400 | next_state_all = Eigen::Map(dataset_msg->next_state_all.data(), dataset_size, 401 | DDPProblem::ObjStateDim); 402 | break; 403 | } 404 | dataset_bag.close(); 405 | 406 | // Instantiate standardization scalar 407 | auto state_standard_scaler = std::make_shared>(state_all); 408 | auto input_standard_scaler = std::make_shared>(input_all); 409 | auto next_state_standard_scaler = 410 | std::make_shared>(next_state_all); 411 | ddp_problem->setStandardScaler(state_standard_scaler, input_standard_scaler, next_state_standard_scaler); 412 | 413 | // Instantiate dataset 414 | std::shared_ptr train_dataset; 415 | std::shared_ptr test_dataset; 416 | DDMPC::makeDataset(DDMPC::toTorchTensor(state_standard_scaler->apply(state_all).cast()), 417 | DDMPC::toTorchTensor(input_standard_scaler->apply(input_all).cast()), 418 | DDMPC::toTorchTensor(next_state_standard_scaler->apply(next_state_all).cast()), 419 | train_dataset, test_dataset); 420 | std::cout << "dataset duration: " 421 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 422 | - start_dataset_time) 423 | .count() 424 | << " [s]" << std::endl; 425 | 426 | // Training model 427 | auto start_train_time = std::chrono::system_clock::now(); 428 | DDMPC::Training training; 429 | std::string model_path = ros::package::getPath("data_driven_mpc") + 
"/tests/data/TestMpcCartWalkModel.pt"; 430 | int batch_size = 256; 431 | int num_epoch = 400; 432 | double learning_rate = 1e-3; 433 | training.run(state_eq, train_dataset, test_dataset, model_path, batch_size, num_epoch, learning_rate); 434 | training.load(state_eq, model_path); 435 | std::cout << "train duration: " 436 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 437 | - start_train_time) 438 | .count() 439 | << " [s]" << std::endl; 440 | 441 | std::cout << "Run the following commands in gnuplot:\n" 442 | << " set key autotitle columnhead\n" 443 | << " set key noenhanced\n" 444 | << " plot \"/tmp/DataDrivenMPCTraining.txt\" u 1:2 w lp, \"\" u 1:3 w lp\n"; 445 | 446 | //// 2. Run MPC //// 447 | double horizon_duration = 2.0; // [sec] 448 | int horizon_steps = static_cast(horizon_duration / horizon_dt); 449 | double end_t = 5.0; // [sec] 450 | 451 | // Instantiate solver 452 | auto ddp_solver = std::make_shared>(ddp_problem); 453 | auto input_limits_func = [&](double t) -> std::array { 454 | std::array limits; 455 | limits[0] << Eigen::Vector2d::Constant(-1e10), -1 * u_max; 456 | limits[1] << Eigen::Vector2d::Constant(1e10), u_max; 457 | return limits; 458 | }; 459 | ddp_solver->setInputLimitsFunc(input_limits_func); 460 | ddp_solver->config().with_input_constraint = true; 461 | ddp_solver->config().horizon_steps = horizon_steps; 462 | ddp_solver->config().max_iter = 5; 463 | 464 | // Initialize MPC 465 | double sim_dt = 0.05; // [sec] 466 | double current_t = 0; 467 | DDPProblem::StateDimVector current_x; 468 | current_x << 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0; 469 | std::vector current_u_list(horizon_steps, DDPProblem::InputDimVector::Zero()); 470 | 471 | // Run MPC loop 472 | std::string file_path = "/tmp/TestMpcCartWalkResult.txt"; 473 | std::ofstream ofs(file_path); 474 | ofs << "time robot_com_pos_x robot_com_vel_x robot_com_pos_y robot_com_vel_y obj_p obj_p_dot obj_theta obj_theta_dot " 475 | "robot_zmp_x robot_zmp_y obj_fx obj_fz 
ref_robot_com_pos_x ref_robot_com_vel_x ref_robot_com_pos_y " 476 | "ref_robot_com_vel_y ref_obj_p ref_obj_p_dot ref_obj_theta ref_obj_theta_dot ref_robot_zmp_x ref_robot_zmp_y " 477 | "ref_obj_fx ref_obj_fz ddp_iter computation_time" 478 | << std::endl; 479 | bool no_exit = false; 480 | pnh.getParam("no_exit", no_exit); 481 | while(no_exit || current_t < end_t) 482 | { 483 | // Solve 484 | auto start_time = std::chrono::system_clock::now(); 485 | ddp_solver->solve(current_t, current_x, current_u_list); 486 | 487 | // Set input 488 | const auto & input_limits = input_limits_func(current_t); 489 | DDPProblem::InputDimVector current_u = 490 | ddp_solver->controlData().u_list[0].cwiseMax(input_limits[0]).cwiseMin(input_limits[1]); 491 | double duration = 492 | 1e3 493 | * std::chrono::duration_cast>(std::chrono::system_clock::now() - start_time) 494 | .count(); 495 | 496 | // Check 497 | DDPProblem::StateDimVector current_ref_x; 498 | DDPProblem::InputDimVector current_ref_u; 499 | ref_func(current_t, current_ref_x, current_ref_u); 500 | for(int i = 0; i < current_x.size(); i++) 501 | { 502 | EXPECT_LT(std::abs(current_x[i] - current_ref_x[i]), 10.0) 503 | << "[TestMpcCartWalk] Violate running check for x[" << i << "]." << std::endl; 504 | } 505 | for(int i = 0; i < current_u.size(); i++) 506 | { 507 | EXPECT_LT(std::abs(current_u[i] - current_ref_u[i]), 100.0) 508 | << "[TestMpcCartWalk] Violate running check for u[" << i << "]." 
<< std::endl; 509 | } 510 | EXPECT_LE(std::abs(current_u[2]), u_max[0]); // [N] 511 | EXPECT_LE(std::abs(current_u[3]), u_max[1]); // [N] 512 | 513 | // Dump 514 | ofs << current_t << " " << current_x.transpose() << " " << current_u.transpose() << " " << current_ref_x.transpose() 515 | << " " << current_ref_u.transpose() << " " << ddp_solver->traceDataList().back().iter << " " << duration 516 | << std::endl; 517 | 518 | // Update to next step 519 | current_t += sim_dt; 520 | current_x = ddp_problem->simulate(current_x, current_u, sim_dt); 521 | current_u_list = ddp_solver->controlData().u_list; 522 | } 523 | 524 | // Final check 525 | const DDPProblem::InputDimVector & current_u = ddp_solver->controlData().u_list[0]; 526 | DDPProblem::StateDimVector current_ref_x; 527 | DDPProblem::InputDimVector current_ref_u; 528 | ref_func(current_t, current_ref_x, current_ref_u); 529 | DDPProblem::StateDimVector x_tolerance; 530 | x_tolerance << 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5; 531 | DDPProblem::InputDimVector u_tolerance; 532 | u_tolerance << 0.5, 0.5, 10.0, 10.0; 533 | for(int i = 0; i < current_x.size(); i++) 534 | { 535 | EXPECT_LT(std::abs(current_x[i] - current_ref_x[i]), x_tolerance[i]) 536 | << "[TestMpcCartWalk] Violate final check for x[" << i << "]." << std::endl; 537 | } 538 | for(int i = 0; i < current_u.size(); i++) 539 | { 540 | EXPECT_LT(std::abs(current_u[i] - current_ref_u[i]), u_tolerance[i]) 541 | << "[TestMpcCartWalk] Violate final check for u[" << i << "]." 
<< std::endl; 542 | } 543 | 544 | std::cout << "Run the following commands in gnuplot:\n" 545 | << " set key autotitle columnhead\n" 546 | << " set key noenhanced\n" 547 | << " plot \"" << file_path 548 | << "\" u 1:2 w lp, \"\" u 1:6 w lp, \"\" u 1:10 w lp, \"\" u 1:18 w l lw 2, \"\" u 1:22 w l lw 2 # pos_x\n" 549 | << " plot \"" << file_path << "\" u 1:4 w lp, \"\" u 1:11 w lp, \"\" u 1:23 w l lw 2 # pos_y\n" 550 | << " plot \"" << file_path << "\" u 1:8 w lp # obj_theta\n" 551 | << " plot \"" << file_path << "\" u 1:12 w lp, \"\" u 1:13 w lp # obj_force\n" 552 | << " plot \"" << file_path << "\" u 1:26 w lp # ddp_iter\n" 553 | << " plot \"" << file_path << "\" u 1:27 w lp # computation_time\n"; 554 | } 555 | 556 | TEST(TestMpcCartWalk, CheckDerivatives) 557 | { 558 | constexpr double deriv_eps = 1e-4; 559 | 560 | double horizon_dt = 0.05; // [sec] 561 | int middle_layer_dim = 32; 562 | auto state_eq = std::make_shared(DDPProblem::ObjStateDim, DDPProblem::ObjInputDim, middle_layer_dim); 563 | auto ddp_problem = std::make_shared(horizon_dt, state_eq); 564 | 565 | int dataset_size = 1000; 566 | Eigen::MatrixXd state_all = 1.0 * Eigen::MatrixXd::Random(dataset_size, DDPProblem::ObjStateDim); 567 | Eigen::MatrixXd input_all = 100.0 * Eigen::MatrixXd::Random(dataset_size, DDPProblem::ObjInputDim); 568 | Eigen::MatrixXd next_state_all = 10.0 * Eigen::MatrixXd::Random(dataset_size, DDPProblem::ObjStateDim); 569 | auto state_standard_scaler = std::make_shared>(state_all); 570 | auto input_standard_scaler = std::make_shared>(input_all); 571 | auto next_state_standard_scaler = 572 | std::make_shared>(next_state_all); 573 | ddp_problem->setStandardScaler(state_standard_scaler, input_standard_scaler, next_state_standard_scaler); 574 | 575 | double t = 0; 576 | DDPProblem::StateDimVector x = 1.0 * DDPProblem::StateDimVector::Random(); 577 | DDPProblem::InputDimVector u = 100.0 * DDPProblem::InputDimVector::Random(); 578 | 579 | DDPProblem::StateStateDimMatrix 
state_eq_deriv_x_analytical; 580 | DDPProblem::StateInputDimMatrix state_eq_deriv_u_analytical; 581 | ddp_problem->calcStateEqDeriv(t, x, u, state_eq_deriv_x_analytical, state_eq_deriv_u_analytical); 582 | 583 | DDPProblem::StateStateDimMatrix state_eq_deriv_x_numerical; 584 | DDPProblem::StateInputDimMatrix state_eq_deriv_u_numerical; 585 | for(int i = 0; i < ddp_problem->stateDim(); i++) 586 | { 587 | state_eq_deriv_x_numerical.col(i) = 588 | (ddp_problem->stateEq(t, x + deriv_eps * DDPProblem::StateDimVector::Unit(i), u) 589 | - ddp_problem->stateEq(t, x - deriv_eps * DDPProblem::StateDimVector::Unit(i), u)) 590 | / (2 * deriv_eps); 591 | } 592 | for(int i = 0; i < ddp_problem->inputDim(); i++) 593 | { 594 | state_eq_deriv_u_numerical.col(i) = 595 | (ddp_problem->stateEq(t, x, u + deriv_eps * DDPProblem::InputDimVector::Unit(i)) 596 | - ddp_problem->stateEq(t, x, u - deriv_eps * DDPProblem::InputDimVector::Unit(i))) 597 | / (2 * deriv_eps); 598 | } 599 | 600 | EXPECT_LT((state_eq_deriv_x_analytical - state_eq_deriv_x_numerical).norm(), 1e-2) 601 | << "state_eq_deriv_x_analytical:\n" 602 | << state_eq_deriv_x_analytical << std::endl 603 | << "state_eq_deriv_x_numerical:\n" 604 | << state_eq_deriv_x_numerical << std::endl 605 | << "state_eq_deriv_x_error:\n" 606 | << state_eq_deriv_x_analytical - state_eq_deriv_x_numerical << std::endl; 607 | EXPECT_LT((state_eq_deriv_u_analytical - state_eq_deriv_u_numerical).norm(), 1e-2) 608 | << "state_eq_deriv_u_analytical:\n" 609 | << state_eq_deriv_u_analytical << std::endl 610 | << "state_eq_deriv_u_numerical:\n" 611 | << state_eq_deriv_u_numerical << std::endl 612 | << "state_eq_deriv_u_error:\n" 613 | << state_eq_deriv_u_analytical - state_eq_deriv_u_numerical << std::endl; 614 | } 615 | 616 | int main(int argc, char ** argv) 617 | { 618 | // Setup ROS 619 | ros::init(argc, argv, "test_mpc_cart_walk"); 620 | 621 | testing::InitGoogleTest(&argc, argv); 622 | return RUN_ALL_TESTS(); 623 | } 624 | 
-------------------------------------------------------------------------------- /tests/src/TestMpcOscillator.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | /** \brief DDP problem based on data-driven state equation. */ 15 | class DDPProblem : public nmpc_ddp::DDPProblem<2, 1> 16 | { 17 | public: 18 | struct WeightParam 19 | { 20 | StateDimVector running_state; 21 | InputDimVector running_input; 22 | StateDimVector terminal_state; 23 | 24 | WeightParam(const StateDimVector & _running_state = StateDimVector::Constant(1.0), 25 | const InputDimVector & _running_input = InputDimVector::Constant(1.0), 26 | const StateDimVector & _terminal_state = StateDimVector::Constant(1.0)) 27 | : running_state(_running_state), running_input(_running_input), terminal_state(_terminal_state) 28 | { 29 | } 30 | }; 31 | 32 | public: 33 | DDPProblem(double dt, 34 | const std::shared_ptr & state_eq, 35 | const WeightParam & weight_param = WeightParam()) 36 | : nmpc_ddp::DDPProblem<2, 1>(dt), state_eq_(state_eq), weight_param_(weight_param) 37 | { 38 | } 39 | 40 | virtual StateDimVector stateEq(double t, const StateDimVector & x, const InputDimVector & u) const override 41 | { 42 | return state_eq_->eval(x, u); 43 | } 44 | 45 | virtual double runningCost(double t, const StateDimVector & x, const InputDimVector & u) const override 46 | { 47 | return 0.5 * weight_param_.running_state.dot(x.cwiseAbs2()) + 0.5 * weight_param_.running_input.dot(u.cwiseAbs2()); 48 | } 49 | 50 | virtual double terminalCost(double t, const StateDimVector & x) const override 51 | { 52 | return 0.5 * weight_param_.terminal_state.dot(x.cwiseAbs2()); 53 | } 54 | 55 | virtual void calcStateEqDeriv(double t, 56 | const StateDimVector & x, 57 | const InputDimVector & u, 58 | Eigen::Ref state_eq_deriv_x, 59 | Eigen::Ref 
state_eq_deriv_u) const override 60 | { 61 | state_eq_->eval(x, u, state_eq_deriv_x, state_eq_deriv_u); 62 | } 63 | 64 | virtual void calcStateEqDeriv(double t, 65 | const StateDimVector & x, 66 | const InputDimVector & u, 67 | Eigen::Ref state_eq_deriv_x, 68 | Eigen::Ref state_eq_deriv_u, 69 | std::vector & state_eq_deriv_xx, 70 | std::vector & state_eq_deriv_uu, 71 | std::vector & state_eq_deriv_xu) const override 72 | { 73 | throw std::runtime_error("Second-order derivatives of state equation are not implemented."); 74 | } 75 | 76 | virtual void calcRunningCostDeriv(double t, 77 | const StateDimVector & x, 78 | const InputDimVector & u, 79 | Eigen::Ref running_cost_deriv_x, 80 | Eigen::Ref running_cost_deriv_u) const override 81 | { 82 | running_cost_deriv_x = weight_param_.running_state.cwiseProduct(x); 83 | running_cost_deriv_u = weight_param_.running_input.cwiseProduct(u); 84 | } 85 | 86 | virtual void calcRunningCostDeriv(double t, 87 | const StateDimVector & x, 88 | const InputDimVector & u, 89 | Eigen::Ref running_cost_deriv_x, 90 | Eigen::Ref running_cost_deriv_u, 91 | Eigen::Ref running_cost_deriv_xx, 92 | Eigen::Ref running_cost_deriv_uu, 93 | Eigen::Ref running_cost_deriv_xu) const override 94 | { 95 | calcRunningCostDeriv(t, x, u, running_cost_deriv_x, running_cost_deriv_u); 96 | 97 | running_cost_deriv_xx.setZero(); 98 | running_cost_deriv_xx.diagonal() = weight_param_.running_state; 99 | running_cost_deriv_uu.setZero(); 100 | running_cost_deriv_uu.diagonal() = weight_param_.running_input; 101 | running_cost_deriv_xu.setZero(); 102 | } 103 | 104 | virtual void calcTerminalCostDeriv(double t, 105 | const StateDimVector & x, 106 | Eigen::Ref terminal_cost_deriv_x) const override 107 | { 108 | terminal_cost_deriv_x = weight_param_.terminal_state.cwiseProduct(x); 109 | } 110 | 111 | virtual void calcTerminalCostDeriv(double t, 112 | const StateDimVector & x, 113 | Eigen::Ref terminal_cost_deriv_x, 114 | Eigen::Ref terminal_cost_deriv_xx) const override 
115 | { 116 | calcTerminalCostDeriv(t, x, terminal_cost_deriv_x); 117 | 118 | terminal_cost_deriv_xx.setZero(); 119 | terminal_cost_deriv_xx.diagonal() = weight_param_.terminal_state; 120 | } 121 | 122 | protected: 123 | WeightParam weight_param_; 124 | 125 | std::shared_ptr state_eq_; 126 | }; 127 | 128 | namespace Eigen 129 | { 130 | using Vector1d = Eigen::Matrix; 131 | } 132 | 133 | // Van der Pol oscillator 134 | // https://web.casadi.org/docs/#a-simple-test-problem 135 | Eigen::Vector2d simulate(const Eigen::Vector2d & x, const Eigen::Vector1d & u, double dt) 136 | { 137 | Eigen::Vector2d x_dot; 138 | x_dot[0] = (1.0 - std::pow(x[1], 2)) * x[0] - x[1] + u[0]; 139 | x_dot[1] = x[0]; 140 | return x + dt * x_dot; 141 | } 142 | 143 | TEST(TestMpcOscillator, Test1) 144 | { 145 | //// 1. Train state equation //// 146 | double horizon_dt = 0.03; // [sec] 147 | int state_dim = 2; 148 | int input_dim = 1; 149 | int middle_layer_dim = 16; 150 | auto state_eq = std::make_shared(state_dim, input_dim, middle_layer_dim); 151 | 152 | // Generate dataset 153 | auto start_dataset_time = std::chrono::system_clock::now(); 154 | int dataset_size = 100000; 155 | Eigen::MatrixXd state_all = 2.0 * Eigen::MatrixXd::Random(dataset_size, state_dim); 156 | Eigen::MatrixXd input_all = 2.0 * Eigen::MatrixXd::Random(dataset_size, input_dim); 157 | Eigen::MatrixXd next_state_all(dataset_size, state_dim); 158 | for(int i = 0; i < dataset_size; i++) 159 | { 160 | next_state_all.row(i) = 161 | simulate(state_all.row(i).transpose(), input_all.row(i).transpose(), horizon_dt).transpose(); 162 | } 163 | 164 | // Instantiate dataset 165 | std::shared_ptr train_dataset; 166 | std::shared_ptr test_dataset; 167 | DDMPC::makeDataset(DDMPC::toTorchTensor(state_all.cast()), DDMPC::toTorchTensor(input_all.cast()), 168 | DDMPC::toTorchTensor(next_state_all.cast()), train_dataset, test_dataset); 169 | std::cout << "dataset duration: " 170 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 
171 | - start_dataset_time) 172 | .count() 173 | << " [s]" << std::endl; 174 | 175 | // Training model 176 | auto start_train_time = std::chrono::system_clock::now(); 177 | DDMPC::Training training; 178 | std::string model_path = "/tmp/TestMpcOscillatorModel.pt"; 179 | int batch_size = 256; 180 | int num_epoch = 500; 181 | training.run(state_eq, train_dataset, test_dataset, model_path, batch_size, num_epoch); 182 | std::cout << "train duration: " 183 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 184 | - start_train_time) 185 | .count() 186 | << " [s]" << std::endl; 187 | 188 | std::cout << "Run the following commands in gnuplot:\n" 189 | << " set key autotitle columnhead\n" 190 | << " set key noenhanced\n" 191 | << " plot \"/tmp/DataDrivenMPCTraining.txt\" u 1:2 w lp, \"\" u 1:3 w lp\n"; 192 | 193 | //// 2. Run MPC //// 194 | double horizon_duration = 5.0; // [sec] 195 | int horizon_steps = static_cast(horizon_duration / horizon_dt); 196 | double end_t = 10.0; // [sec] 197 | 198 | // Instantiate problem 199 | auto ddp_problem = std::make_shared(horizon_dt, state_eq); 200 | 201 | // Instantiate solver 202 | auto ddp_solver = std::make_shared>(ddp_problem); 203 | auto input_limits_func = [&](double t) -> std::array { 204 | std::array limits; 205 | limits[0].setConstant(input_dim, -1.0); 206 | limits[1].setConstant(input_dim, 1.0); 207 | return limits; 208 | }; 209 | ddp_solver->setInputLimitsFunc(input_limits_func); 210 | ddp_solver->config().with_input_constraint = true; 211 | ddp_solver->config().horizon_steps = horizon_steps; 212 | 213 | // Initialize MPC 214 | double sim_dt = 0.02; // [sec] 215 | double current_t = 0; 216 | DDPProblem::StateDimVector current_x = DDPProblem::StateDimVector(0.0, 1.0); 217 | std::vector current_u_list(horizon_steps, DDPProblem::InputDimVector::Zero()); 218 | 219 | // Run MPC loop 220 | bool first_iter = true; 221 | std::string file_path = "/tmp/TestMpcOscillatorResult.txt"; 222 | std::ofstream ofs(file_path); 223 
| ofs << "time x[0] x[1] u[0] ddp_iter computation_time" << std::endl; 224 | while(current_t < end_t) 225 | { 226 | // Solve 227 | auto start_time = std::chrono::system_clock::now(); 228 | ddp_solver->solve(current_t, current_x, current_u_list); 229 | if(first_iter) 230 | { 231 | first_iter = false; 232 | ddp_solver->config().max_iter = 5; 233 | } 234 | 235 | // Set input 236 | const auto & input_limits = input_limits_func(current_t); 237 | DDPProblem::InputDimVector current_u = 238 | ddp_solver->controlData().u_list[0].cwiseMax(input_limits[0]).cwiseMin(input_limits[1]); 239 | double duration = 240 | 1e3 241 | * std::chrono::duration_cast>(std::chrono::system_clock::now() - start_time) 242 | .count(); 243 | 244 | // Check 245 | EXPECT_LT(std::abs(current_x[0]), 2.0); 246 | EXPECT_LT(std::abs(current_x[1]), 2.0); 247 | EXPECT_LE(std::abs(current_u[0]), 1.0); 248 | 249 | // Dump 250 | ofs << current_t << " " << current_x.transpose() << " " << current_u.transpose() << " " 251 | << ddp_solver->traceDataList().back().iter << " " << duration << std::endl; 252 | 253 | // Update to next step 254 | current_t += sim_dt; 255 | current_x = simulate(current_x, current_u, sim_dt); 256 | current_u_list = ddp_solver->controlData().u_list; 257 | current_u_list.erase(current_u_list.begin()); 258 | current_u_list.push_back(current_u_list.back()); 259 | } 260 | 261 | // Final check 262 | const DDPProblem::InputDimVector & current_u = ddp_solver->controlData().u_list[0]; 263 | EXPECT_LT(std::abs(current_x[0]), 0.1); 264 | // Most of the time it reaches a better convergence, but sometimes it's worse 265 | EXPECT_LT(std::abs(current_x[1]), 0.5); 266 | EXPECT_LT(std::abs(current_u[0]), 0.5); 267 | 268 | std::cout << "Run the following commands in gnuplot:\n" 269 | << " set key autotitle columnhead\n" 270 | << " set key noenhanced\n" 271 | << " plot \"" << file_path << "\" u 1:2 w lp, \"\" u 1:3 w lp, \"\" u 1:4 w lp\n"; 272 | } 273 | 274 | int main(int argc, char ** argv) 275 | { 276 | 
testing::InitGoogleTest(&argc, argv); 277 | return RUN_ALL_TESTS(); 278 | } 279 | -------------------------------------------------------------------------------- /tests/src/TestMpcPushWalk.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace Eigen 16 | { 17 | using Vector1d = Eigen::Matrix; 18 | } 19 | 20 | /** \brief DDP problem based on combination of analytical and data-driven models. 21 | 22 | State consists of [robot_com_pos, robot_com_vel, obj_com_pos, obj_com_vel]. 23 | Input consists of [robot_zmp, obj_force]. 24 | */ 25 | class DDPProblem : public nmpc_ddp::DDPProblem<4, 2> 26 | { 27 | public: 28 | struct WeightParam 29 | { 30 | StateDimVector running_state; 31 | InputDimVector running_input; 32 | StateDimVector terminal_state; 33 | 34 | WeightParam(const StateDimVector & _running_state = StateDimVector::Constant(1.0), 35 | const InputDimVector & _running_input = InputDimVector::Constant(1.0), 36 | const StateDimVector & _terminal_state = StateDimVector::Constant(1.0)) 37 | : running_state(_running_state), running_input(_running_input), terminal_state(_terminal_state) 38 | { 39 | } 40 | }; 41 | 42 | public: 43 | DDPProblem( 44 | double dt, 45 | const std::shared_ptr & state_eq, 46 | const std::function & damper_func = nullptr, 47 | const std::function, Eigen::Ref)> & ref_func = nullptr, 48 | const WeightParam & weight_param = WeightParam()) 49 | : nmpc_ddp::DDPProblem<4, 2>(dt), state_eq_(state_eq), damper_func_(damper_func), ref_func_(ref_func), 50 | weight_param_(weight_param) 51 | { 52 | } 53 | 54 | Eigen::Vector2d simulateRobot(const Eigen::Vector2d & x, const Eigen::Vector2d & u, double dt) const 55 | { 56 | double robot_com_pos = x[0]; 57 | double robot_com_vel = x[1]; 58 | double robot_zmp = u[0]; 59 | double obj_force = u[1]; 60 | 61 | 
Eigen::Vector2d x_dot; 62 | x_dot[0] = robot_com_vel; 63 | x_dot[1] = gravity_acc_ / robot_com_height_ * (robot_com_pos - robot_zmp) 64 | - obj_grasp_height_ / (robot_mass_ * robot_com_height_) * obj_force; 65 | 66 | return x + dt * x_dot; 67 | } 68 | 69 | Eigen::Vector2d simulateObj(const Eigen::Vector2d & x, const Eigen::Vector1d & u, double dt) const 70 | { 71 | double obj_com_pos = x[0]; 72 | double obj_com_vel = x[1]; 73 | double obj_force = u[0]; 74 | 75 | Eigen::Vector2d x_dot; 76 | x_dot[0] = obj_com_vel; 77 | x_dot[1] = (damper_func_(obj_com_vel) + obj_force) / obj_mass_; 78 | 79 | return x + dt * x_dot; 80 | } 81 | 82 | StateDimVector simulate(const StateDimVector & x, const InputDimVector & u, double dt) const 83 | { 84 | StateDimVector next_x; 85 | next_x.head<2>() = simulateRobot(x.head<2>(), u, dt); 86 | next_x.tail<2>() = simulateObj(x.tail<2>(), u.tail<1>(), dt); 87 | 88 | return next_x; 89 | } 90 | 91 | virtual StateDimVector stateEq(double t, const StateDimVector & x, const InputDimVector & u) const override 92 | { 93 | StateDimVector next_x; 94 | next_x.head<2>() = simulateRobot(x.head<2>(), u, dt_); 95 | next_x.tail<2>() = next_state_standard_scaler_->applyOneInv( 96 | state_eq_->eval(state_standard_scaler_->applyOne(x.tail<2>()), input_standard_scaler_->applyOne(u.tail<1>()))); 97 | 98 | return next_x; 99 | } 100 | 101 | virtual double runningCost(double t, const StateDimVector & x, const InputDimVector & u) const override 102 | { 103 | StateDimVector ref_x; 104 | InputDimVector ref_u; 105 | ref_func_(t, ref_x, ref_u); 106 | 107 | return 0.5 * weight_param_.running_state.dot((x - ref_x).cwiseAbs2()) 108 | + 0.5 * weight_param_.running_input.dot((u - ref_u).cwiseAbs2()); 109 | } 110 | 111 | virtual double terminalCost(double t, const StateDimVector & x) const override 112 | { 113 | StateDimVector ref_x; 114 | InputDimVector ref_u; 115 | ref_func_(t, ref_x, ref_u); 116 | 117 | return 0.5 * weight_param_.terminal_state.dot((x - 
ref_x).cwiseAbs2()); 118 | } 119 | 120 | virtual void calcStateEqDeriv(double t, 121 | const StateDimVector & x, 122 | const InputDimVector & u, 123 | Eigen::Ref state_eq_deriv_x, 124 | Eigen::Ref state_eq_deriv_u) const override 125 | { 126 | state_eq_deriv_x.setZero(); 127 | state_eq_deriv_u.setZero(); 128 | 129 | state_eq_deriv_x(0, 1) = 1; 130 | state_eq_deriv_x(1, 0) = gravity_acc_ / robot_com_height_; 131 | state_eq_deriv_x.topRows<2>() *= dt_; 132 | state_eq_deriv_x.diagonal().head<2>().array() += 1.0; 133 | 134 | state_eq_deriv_u(1, 0) = -1 * gravity_acc_ / robot_com_height_; 135 | state_eq_deriv_u(1, 1) = -1 * obj_grasp_height_ / (robot_mass_ * robot_com_height_); 136 | state_eq_deriv_u.topRows<2>() *= dt_; 137 | 138 | state_eq_->eval(state_standard_scaler_->applyOne(x.tail<2>()), input_standard_scaler_->applyOne(u.tail<1>()), 139 | state_eq_deriv_x.bottomRightCorner<2, 2>(), state_eq_deriv_u.bottomRightCorner<2, 1>()); 140 | state_eq_deriv_x.bottomRightCorner<2, 2>().array().colwise() *= 141 | next_state_standard_scaler_->stddev_vec_.transpose().array(); 142 | state_eq_deriv_x.bottomRightCorner<2, 2>().array().rowwise() /= state_standard_scaler_->stddev_vec_.array(); 143 | state_eq_deriv_u.bottomRightCorner<2, 1>().array().colwise() *= 144 | next_state_standard_scaler_->stddev_vec_.transpose().array(); 145 | state_eq_deriv_u.bottomRightCorner<2, 1>().array().rowwise() /= input_standard_scaler_->stddev_vec_.array(); 146 | } 147 | 148 | virtual void calcStateEqDeriv(double t, 149 | const StateDimVector & x, 150 | const InputDimVector & u, 151 | Eigen::Ref state_eq_deriv_x, 152 | Eigen::Ref state_eq_deriv_u, 153 | std::vector & state_eq_deriv_xx, 154 | std::vector & state_eq_deriv_uu, 155 | std::vector & state_eq_deriv_xu) const override 156 | { 157 | throw std::runtime_error("Second-order derivatives of state equation are not implemented."); 158 | } 159 | 160 | virtual void calcRunningCostDeriv(double t, 161 | const StateDimVector & x, 162 | const 
InputDimVector & u, 163 | Eigen::Ref running_cost_deriv_x, 164 | Eigen::Ref running_cost_deriv_u) const override 165 | { 166 | StateDimVector ref_x; 167 | InputDimVector ref_u; 168 | ref_func_(t, ref_x, ref_u); 169 | 170 | running_cost_deriv_x = weight_param_.running_state.cwiseProduct(x - ref_x); 171 | running_cost_deriv_u = weight_param_.running_input.cwiseProduct(u - ref_u); 172 | } 173 | 174 | virtual void calcRunningCostDeriv(double t, 175 | const StateDimVector & x, 176 | const InputDimVector & u, 177 | Eigen::Ref running_cost_deriv_x, 178 | Eigen::Ref running_cost_deriv_u, 179 | Eigen::Ref running_cost_deriv_xx, 180 | Eigen::Ref running_cost_deriv_uu, 181 | Eigen::Ref running_cost_deriv_xu) const override 182 | { 183 | calcRunningCostDeriv(t, x, u, running_cost_deriv_x, running_cost_deriv_u); 184 | 185 | running_cost_deriv_xx.setZero(); 186 | running_cost_deriv_xx.diagonal() = weight_param_.running_state; 187 | running_cost_deriv_uu.setZero(); 188 | running_cost_deriv_uu.diagonal() = weight_param_.running_input; 189 | running_cost_deriv_xu.setZero(); 190 | } 191 | 192 | virtual void calcTerminalCostDeriv(double t, 193 | const StateDimVector & x, 194 | Eigen::Ref terminal_cost_deriv_x) const override 195 | { 196 | StateDimVector ref_x; 197 | InputDimVector ref_u; 198 | ref_func_(t, ref_x, ref_u); 199 | 200 | terminal_cost_deriv_x = weight_param_.terminal_state.cwiseProduct(x - ref_x); 201 | } 202 | 203 | virtual void calcTerminalCostDeriv(double t, 204 | const StateDimVector & x, 205 | Eigen::Ref terminal_cost_deriv_x, 206 | Eigen::Ref terminal_cost_deriv_xx) const override 207 | { 208 | calcTerminalCostDeriv(t, x, terminal_cost_deriv_x); 209 | 210 | terminal_cost_deriv_xx.setZero(); 211 | terminal_cost_deriv_xx.diagonal() = weight_param_.terminal_state; 212 | } 213 | 214 | void setStandardScaler(const std::shared_ptr> & state_standard_scaler, 215 | const std::shared_ptr> & input_standard_scaler, 216 | const std::shared_ptr> & next_state_standard_scaler) 217 
| 218 | { 219 | state_standard_scaler_ = state_standard_scaler; 220 | input_standard_scaler_ = input_standard_scaler; 221 | next_state_standard_scaler_ = next_state_standard_scaler; 222 | } 223 | /* damper_func_ presumably holds the damper_func passed to the constructor (velocity -> force in [N] in RunMPC); ref_func_ is called as ref_func_(t, ref_x, ref_u) to fill the reference state and input. */ 224 | protected: 225 | std::shared_ptr state_eq_; 226 | 227 | std::function damper_func_; 228 | 229 | std::function, Eigen::Ref)> ref_func_; 230 | 231 | WeightParam weight_param_; 232 | 233 | std::shared_ptr> state_standard_scaler_; 234 | std::shared_ptr> input_standard_scaler_; 235 | std::shared_ptr> next_state_standard_scaler_; 236 | /* Physical parameters of the robot and the object (units in the trailing comments). */ 237 | double gravity_acc_ = 9.8; // [m/s^2] 238 | double robot_mass_ = 100.0; // [kg] 239 | double robot_com_height_ = 1.0; // [m] 240 | double obj_mass_ = 50.0; // [kg] 241 | double obj_grasp_height_ = 1.0; // [m] 242 | }; 243 | /* End-to-end test: trains the learned state equation on object samples produced by simulateObj, then runs the DDP solver as MPC in closed loop with simulate() and checks that states and inputs stay near the reference. */ 244 | TEST(TestMpcPushWalk, RunMPC) 245 | { 246 | //// 1. Train state equation //// 247 | double horizon_dt = 0.1; // [sec] 248 | 249 | // Instantiate state equation 250 | int obj_state_dim = 2; 251 | int obj_input_dim = 1; 252 | int middle_layer_dim = 8; 253 | auto state_eq = std::make_shared(obj_state_dim, obj_input_dim, middle_layer_dim); 254 | /* damper_type_str selects which damper_func lambda is built below; an unknown name throws std::runtime_error. */ 255 | // Instantiate problem 256 | std::string damper_type_str = "Linear"; 257 | std::function damper_func = nullptr; 258 | if(damper_type_str == "None") 259 | { 260 | damper_func = [&](double vel) -> double { 261 | return 0; // [N] 262 | }; 263 | } 264 | else if(damper_type_str == "Linear") 265 | { 266 | damper_func = [&](double vel) -> double { 267 | return -100 * vel; // [N] 268 | }; 269 | } 270 | else if(damper_type_str == "Square") 271 | { 272 | damper_func = [&](double vel) -> double { 273 | return -500 * std::copysign(std::pow(vel, 2), vel); // [N] 274 | }; 275 | } 276 | else if(damper_type_str == "Sqrt") 277 | { 278 | damper_func = [&](double vel) -> double { 279 | return -20 * std::copysign(std::sqrt(std::abs(vel)), vel); // [N] 280 | }; 281 | } 282 | else if(damper_type_str == "Offset") 283 | { 284 | damper_func = [&](double vel) -> double { 285 | return -25; //
[N] 286 | }; 287 | } 288 | else 289 | { 290 | throw std::runtime_error("Invalid damper_type_str: " + damper_type_str); 291 | } 292 | auto ref_func = [&](double t, Eigen::Ref ref_x, 293 | Eigen::Ref ref_u) -> void { 294 | // Add small values to avoid numerical instability at inequality bounds 295 | constexpr double epsilon_t = 1e-6; 296 | t += epsilon_t; 297 | /* Reference object position ref_x[2]: 0.2 m before t = 1 s, a 0.3 m/s ramp until t = 3 s, then 0.8 m. */ 298 | // Object position 299 | ref_x.setZero(); 300 | if(t < 1.0) // [sec] 301 | { 302 | ref_x[2] = 0.2; // [m] 303 | } 304 | else if(t < 3.0) // [sec] 305 | { 306 | ref_x[2] = 0.3 * (t - 1.0) + 0.2; // [m] 307 | } 308 | else 309 | { 310 | ref_x[2] = 0.8; // [m] 311 | } 312 | /* Reference ZMP ref_u[0]: 0 until t = 1.5 s, then three 0.2 m steps starting at t = 1.5, 2.5 and 3.5 s, each moving linearly over step_transit_duration; the reference robot CoM position ref_x[0] is set equal to it. */ 313 | // ZMP position 314 | ref_u.setZero(); 315 | double step_total_duration = 1.0; // [sec] 316 | double step_transit_duration = 0.2; // [sec] 317 | int step_idx = std::clamp(static_cast(std::floor((t - 1.5) / step_total_duration)), -1, 2); 318 | if(step_idx == -1) 319 | { 320 | ref_u[0] = 0.0; // [m] 321 | } 322 | else 323 | { 324 | double step_start_time = 1.5 + step_idx * step_total_duration; 325 | ref_u[0] = 0.2 * (step_idx + std::clamp((t - step_start_time) / step_transit_duration, 0.0, 1.0)); 326 | } 327 | ref_x[0] = ref_u[0]; 328 | }; 329 | DDPProblem::WeightParam weight_param; 330 | weight_param.running_state << 0.0, 1e-6, 1e2, 1e-6; 331 | weight_param.running_input << 1e2, 1e-5; 332 | weight_param.terminal_state << 1.0, 1e-1, 1.0, 1e-1; 333 | auto ddp_problem = std::make_shared(horizon_dt, state_eq, damper_func, ref_func, weight_param); 334 | /* Training samples: object states uniform in [-1, 1], object inputs uniform in [-100, 100], and next states from simulateObj over one horizon_dt. */ 335 | // Generate dataset 336 | auto start_dataset_time = std::chrono::system_clock::now(); 337 | int dataset_size = 20000; 338 | Eigen::MatrixXd state_all = 1.0 * Eigen::MatrixXd::Random(dataset_size, obj_state_dim); 339 | Eigen::MatrixXd input_all = 100.0 * Eigen::MatrixXd::Random(dataset_size, obj_input_dim); 340 | Eigen::MatrixXd next_state_all(dataset_size, obj_state_dim); 341 | for(int i = 0; i < dataset_size; i++) 342 | { 343 | next_state_all.row(i) = 344 |
ddp_problem->simulateObj(state_all.row(i).transpose(), input_all.row(i).transpose(), horizon_dt).transpose(); 345 | } 346 | /* Scalers are constructed from the raw samples and also handed to ddp_problem; the dataset below is built from the scaled samples. */ 347 | // Instantiate standardization scaler 348 | auto state_standard_scaler = std::make_shared>(state_all); 349 | auto input_standard_scaler = std::make_shared>(input_all); 350 | auto next_state_standard_scaler = std::make_shared>(next_state_all); 351 | ddp_problem->setStandardScaler(state_standard_scaler, input_standard_scaler, next_state_standard_scaler); 352 | 353 | // Instantiate dataset 354 | std::shared_ptr train_dataset; 355 | std::shared_ptr test_dataset; 356 | DDMPC::makeDataset(DDMPC::toTorchTensor(state_standard_scaler->apply(state_all).cast()), 357 | DDMPC::toTorchTensor(input_standard_scaler->apply(input_all).cast()), 358 | DDMPC::toTorchTensor(next_state_standard_scaler->apply(next_state_all).cast()), 359 | train_dataset, test_dataset); 360 | std::cout << "dataset duration: " 361 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 362 | - start_dataset_time) 363 | .count() 364 | << " [s]" << std::endl; 365 | 366 | // Training model 367 | auto start_train_time = std::chrono::system_clock::now(); 368 | DDMPC::Training training; 369 | std::string model_path = "/tmp/TestMpcPushWalkModel.pt"; 370 | int batch_size = 256; 371 | int num_epoch = 200; 372 | double learning_rate = 5e-3; 373 | training.run(state_eq, train_dataset, test_dataset, model_path, batch_size, num_epoch, learning_rate); 374 | std::cout << "train duration: " 375 | << std::chrono::duration_cast>(std::chrono::system_clock::now() 376 | - start_train_time) 377 | .count() 378 | << " [s]" << std::endl; 379 | 380 | std::cout << "Run the following commands in gnuplot:\n" 381 | << " set key autotitle columnhead\n" 382 | << " set key noenhanced\n" 383 | << " plot \"/tmp/DataDrivenMPCTraining.txt\" u 1:2 w lp, \"\" u 1:3 w lp\n"; 384 | 385 | //// 2.
Run MPC //// 386 | double horizon_duration = 2.0; // [sec] 387 | int horizon_steps = static_cast(horizon_duration / horizon_dt); 388 | double end_t = 5.0; // [sec] 389 | /* Input limits: input 0 (ZMP) is effectively unbounded (+/-1e10) and input 1 (object force) is limited to +/-50 N. */ 390 | // Instantiate solver 391 | auto ddp_solver = std::make_shared>(ddp_problem); 392 | auto input_limits_func = [&](double t) -> std::array { 393 | std::array limits; 394 | limits[0] << -1e10, -50.0; 395 | limits[1] << 1e10, 50.0; 396 | return limits; 397 | }; 398 | ddp_solver->setInputLimitsFunc(input_limits_func); 399 | ddp_solver->config().with_input_constraint = true; 400 | ddp_solver->config().horizon_steps = horizon_steps; 401 | ddp_solver->config().max_iter = 2; 402 | /* Each control cycle runs at most two DDP iterations, warm-started from the previous input sequence (current_u_list is updated at the end of the loop). */ 403 | // Initialize MPC 404 | double sim_dt = 0.05; // [sec] 405 | double current_t = 0; 406 | DDPProblem::StateDimVector current_x = DDPProblem::StateDimVector(0.0, 0.0, 0.2, 0.0); 407 | std::vector current_u_list(horizon_steps, DDPProblem::InputDimVector::Zero()); 408 | 409 | // Run MPC loop 410 | std::string file_path = "/tmp/TestMpcPushWalkResult-" + damper_type_str + ".txt"; 411 | std::ofstream ofs(file_path); 412 | ofs << "time robot_com_pos robot_com_vel obj_com_pos obj_com_vel robot_zmp obj_force ref_obj_com_pos ref_robot_zmp " 413 | "ddp_iter computation_time" 414 | << std::endl; 415 | while(current_t < end_t) 416 | { 417 | // Solve 418 | auto start_time = std::chrono::system_clock::now(); 419 | ddp_solver->solve(current_t, current_x, current_u_list); 420 | /* Only the first planned input is applied, clipped to the input limits. */ 421 | // Set input 422 | const auto & input_limits = input_limits_func(current_t); 423 | DDPProblem::InputDimVector current_u = 424 | ddp_solver->controlData().u_list[0].cwiseMax(input_limits[0]).cwiseMin(input_limits[1]); 425 | double duration = 426 | 1e3 427 | * std::chrono::duration_cast>(std::chrono::system_clock::now() - start_time) 428 | .count(); 429 | 430 | // Check 431 | DDPProblem::StateDimVector current_ref_x; 432 | DDPProblem::InputDimVector current_ref_u; 433 | ref_func(current_t, current_ref_x, current_ref_u); 434 | EXPECT_LT(std::abs(current_x[0] -
current_ref_x[0]), 1.0); // [m] 435 | EXPECT_LT(std::abs(current_x[1] - current_ref_x[1]), 1.0); // [m/s] 436 | EXPECT_LT(std::abs(current_x[2] - current_ref_x[2]), 1.0); // [m] 437 | EXPECT_LT(std::abs(current_x[3] - current_ref_x[3]), 1.0); // [m/s] 438 | EXPECT_LT(std::abs(current_u[0] - current_ref_u[0]), 1.0); // [m] 439 | EXPECT_LE(std::abs(current_u[1]), 50.0); // [N] 440 | 441 | // Dump 442 | ofs << current_t << " " << current_x.transpose() << " " << current_u.transpose() << " " << current_ref_x[2] << " " 443 | << current_ref_u[0] << " " << ddp_solver->traceDataList().back().iter << " " << duration << std::endl; 444 | 445 | // Update to next step 446 | current_t += sim_dt; 447 | current_x = ddp_problem->simulate(current_x, current_u, sim_dt); 448 | current_u_list = ddp_solver->controlData().u_list; 449 | } 450 | /* Tighter tolerances at the end of the run. NOTE(review): current_u here is the unclipped first planned input, unlike the clipped input applied inside the loop. */ 451 | // Final check 452 | const DDPProblem::InputDimVector & current_u = ddp_solver->controlData().u_list[0]; 453 | DDPProblem::StateDimVector current_ref_x; 454 | DDPProblem::InputDimVector current_ref_u; 455 | ref_func(current_t, current_ref_x, current_ref_u); 456 | EXPECT_LT(std::abs(current_x[0] - current_ref_x[0]), 0.1); // [m] 457 | EXPECT_LT(std::abs(current_x[1] - current_ref_x[1]), 0.1); // [m/s] 458 | EXPECT_LT(std::abs(current_x[2] - current_ref_x[2]), 0.1); // [m] 459 | EXPECT_LT(std::abs(current_x[3] - current_ref_x[3]), 0.1); // [m/s] 460 | EXPECT_LT(std::abs(current_u[0] - current_ref_u[0]), 0.1); // [m] 461 | EXPECT_LE(std::abs(current_u[1]), 10.0); // [N] 462 | 463 | std::cout << "Run the following commands in gnuplot:\n" 464 | << " set key autotitle columnhead\n" 465 | << " set key noenhanced\n" 466 | << " plot \"" << file_path 467 | << "\" u 1:2 w lp, \"\" u 1:4 w lp, \"\" u 1:6 w lp, \"\" u 1:8 w l lw 2, \"\" u 1:9 w l lw 2\n" 468 | << " plot \"" << file_path << "\" u 1:7 w lp\n"; 469 | } 470 | /* Compares the analytical state-equation Jacobians of DDPProblem against central finite differences, using an untrained model and scalers built from random data. */ 471 | TEST(TestMpcPushWalk, CheckDerivatives) 472 | { 473 | constexpr double deriv_eps = 1e-4; 474 | 475 | double horizon_dt = 0.05; //
[sec] 476 | int obj_state_dim = 2; 477 | int obj_input_dim = 1; 478 | int middle_layer_dim = 4; 479 | auto state_eq = std::make_shared(obj_state_dim, obj_input_dim, middle_layer_dim); 480 | auto ddp_problem = std::make_shared(horizon_dt, state_eq); 481 | /* Scalers are built from random data; no dataset is simulated and no training is run in this test. */ 482 | int dataset_size = 1000; 483 | Eigen::MatrixXd state_all = 1.0 * Eigen::MatrixXd::Random(dataset_size, obj_state_dim); 484 | Eigen::MatrixXd input_all = 100.0 * Eigen::MatrixXd::Random(dataset_size, obj_input_dim); 485 | Eigen::MatrixXd next_state_all = 10.0 * Eigen::MatrixXd::Random(dataset_size, obj_state_dim); 486 | auto state_standard_scaler = std::make_shared>(state_all); 487 | auto input_standard_scaler = std::make_shared>(input_all); 488 | auto next_state_standard_scaler = std::make_shared>(next_state_all); 489 | ddp_problem->setStandardScaler(state_standard_scaler, input_standard_scaler, next_state_standard_scaler); 490 | 491 | double t = 0; 492 | DDPProblem::StateDimVector x(0.1, -0.2, 0.3, -0.4); 493 | DDPProblem::InputDimVector u(10.0, -20.0); 494 | 495 | DDPProblem::StateStateDimMatrix state_eq_deriv_x_analytical; 496 | DDPProblem::StateInputDimMatrix state_eq_deriv_u_analytical; 497 | ddp_problem->calcStateEqDeriv(t, x, u, state_eq_deriv_x_analytical, state_eq_deriv_u_analytical); 498 | /* Numerical Jacobians by central differences with step deriv_eps, one column per state/input direction. */ 499 | DDPProblem::StateStateDimMatrix state_eq_deriv_x_numerical; 500 | DDPProblem::StateInputDimMatrix state_eq_deriv_u_numerical; 501 | for(int i = 0; i < ddp_problem->stateDim(); i++) 502 | { 503 | state_eq_deriv_x_numerical.col(i) = 504 | (ddp_problem->stateEq(t, x + deriv_eps * DDPProblem::StateDimVector::Unit(i), u) 505 | - ddp_problem->stateEq(t, x - deriv_eps * DDPProblem::StateDimVector::Unit(i), u)) 506 | / (2 * deriv_eps); 507 | } 508 | for(int i = 0; i < ddp_problem->inputDim(); i++) 509 | { 510 | state_eq_deriv_u_numerical.col(i) = 511 | (ddp_problem->stateEq(t, x, u + deriv_eps * DDPProblem::InputDimVector::Unit(i)) 512 | - ddp_problem->stateEq(t, x, u - deriv_eps *
DDPProblem::InputDimVector::Unit(i))) 513 | / (2 * deriv_eps); 514 | } 515 | /* Analytical and numerical Jacobians must agree within 1e-2 in Frobenius norm; on failure all three matrices are printed. */ 516 | EXPECT_LT((state_eq_deriv_x_analytical - state_eq_deriv_x_numerical).norm(), 1e-2) 517 | << "state_eq_deriv_x_analytical:\n" 518 | << state_eq_deriv_x_analytical << std::endl 519 | << "state_eq_deriv_x_numerical:\n" 520 | << state_eq_deriv_x_numerical << std::endl 521 | << "state_eq_deriv_x_error:\n" 522 | << state_eq_deriv_x_analytical - state_eq_deriv_x_numerical << std::endl; 523 | EXPECT_LT((state_eq_deriv_u_analytical - state_eq_deriv_u_numerical).norm(), 1e-2) 524 | << "state_eq_deriv_u_analytical:\n" 525 | << state_eq_deriv_u_analytical << std::endl 526 | << "state_eq_deriv_u_numerical:\n" 527 | << state_eq_deriv_u_numerical << std::endl 528 | << "state_eq_deriv_u_error:\n" 529 | << state_eq_deriv_u_analytical - state_eq_deriv_u_numerical << std::endl; 530 | } 531 | 532 | int main(int argc, char ** argv) 533 | { 534 | testing::InitGoogleTest(&argc, argv); 535 | return RUN_ALL_TESTS(); 536 | } 537 | -------------------------------------------------------------------------------- /tests/src/TestStateEq.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | /* Checks that StateEq::eval returns NaN-free outputs and Jacobians, and that the Jacobians match central finite differences within 1e-3. */ 8 | TEST(TestStateEq, Test1) 9 | { 10 | int state_dim = 3; 11 | int input_dim = 2; 12 | DDMPC::StateEq state_eq(state_dim, input_dim); 13 | 14 | Eigen::VectorXd x = Eigen::VectorXd::Random(state_dim); 15 | Eigen::VectorXd u = Eigen::VectorXd::Random(input_dim); 16 | Eigen::MatrixXd grad_x(state_dim, state_dim); 17 | Eigen::MatrixXd grad_u(state_dim, input_dim); 18 | 19 | Eigen::VectorXd next_x = state_eq.eval(x, u, grad_x, grad_u); 20 | EXPECT_FALSE(next_x.array().isNaN().any()); 21 | EXPECT_FALSE(grad_x.array().isNaN().any()); 22 | EXPECT_FALSE(grad_u.array().isNaN().any()); 23 | 24 | Eigen::MatrixXd grad_x_numerical(state_dim, state_dim); 25 | Eigen::MatrixXd grad_u_numerical(state_dim, input_dim); 26 | double
eps = 1e-4; 27 | for(int i = 0; i < state_dim; i++) 28 | { 29 | grad_x_numerical.col(i) = (state_eq.eval(x + eps * Eigen::VectorXd::Unit(state_dim, i), u) 30 | - state_eq.eval(x - eps * Eigen::VectorXd::Unit(state_dim, i), u)) 31 | / (2 * eps); 32 | } 33 | for(int i = 0; i < input_dim; i++) 34 | { 35 | grad_u_numerical.col(i) = (state_eq.eval(x, u + eps * Eigen::VectorXd::Unit(input_dim, i)) 36 | - state_eq.eval(x, u - eps * Eigen::VectorXd::Unit(input_dim, i))) 37 | / (2 * eps); 38 | } 39 | EXPECT_LT((grad_x - grad_x_numerical).norm(), 1e-3); 40 | EXPECT_LT((grad_u - grad_u_numerical).norm(), 1e-3); 41 | } 42 | 43 | int main(int argc, char ** argv) 44 | { 45 | testing::InitGoogleTest(&argc, argv); 46 | return RUN_ALL_TESTS(); 47 | } 48 | -------------------------------------------------------------------------------- /tests/src/TestTorchUtils.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | /* Round trip Eigen -> torch -> Eigen -> torch must reproduce the original values. */ 7 | TEST(TestTorchUtils, Test1) 8 | { 9 | Eigen::MatrixXf mat1 = Eigen::MatrixXf::Random(4, 10); 10 | torch::Tensor tensor1 = DDMPC::toTorchTensor(mat1); 11 | Eigen::MatrixXf mat1_restored = DDMPC::toEigenMatrix(tensor1); 12 | torch::Tensor tensor1_restored = DDMPC::toTorchTensor(mat1_restored); 13 | 14 | EXPECT_LT((mat1 - mat1_restored).norm(), 1e-8); 15 | EXPECT_LT((tensor1 - tensor1_restored).norm().item(), 1e-8); 16 | } 17 | 18 | int main(int argc, char ** argv) 19 | { 20 | testing::InitGoogleTest(&argc, argv); 21 | return RUN_ALL_TESTS(); 22 | } 23 | -------------------------------------------------------------------------------- /tests/src/TestTraining.cpp: -------------------------------------------------------------------------------- 1 | /* Author: Masaki Murooka */ 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | /* Trains a StateEq on an exactly linear map from (state, input) to next state and checks that the prediction error on sample 0 decreases and ends below 0.2. */ 8 | TEST(TestTraining, Test1) 9 | { 10 | // Generate dataset 11 | int state_dim = 2; 12 | int input_dim = 1; 13 | int dataset_size = 1000; 14 |
Eigen::MatrixXd state_all = Eigen::MatrixXd::Random(dataset_size, state_dim); 15 | Eigen::MatrixXd input_all = Eigen::MatrixXd::Random(dataset_size, input_dim); 16 | Eigen::MatrixXd next_state_all(dataset_size, state_dim); 17 | next_state_all.col(0) = 1.0 * state_all.col(0) + 2.0 * state_all.col(1) - 1.0 * input_all.col(0); 18 | next_state_all.col(1) = -2.0 * state_all.col(0) + -1.0 * state_all.col(1) + 2.0 * input_all.col(0); 19 | /* Ground truth is the linear map next_x0 = x0 + 2 x1 - u, next_x1 = -2 x0 - x1 + 2 u. */ 20 | // Instantiate state equation and dataset 21 | auto state_eq = std::make_shared(state_dim, input_dim); 22 | std::shared_ptr train_dataset; 23 | std::shared_ptr test_dataset; 24 | DDMPC::makeDataset(DDMPC::toTorchTensor(state_all.cast()), DDMPC::toTorchTensor(input_all.cast()), 25 | DDMPC::toTorchTensor(next_state_all.cast()), train_dataset, test_dataset); 26 | 27 | double before_error = (next_state_all.row(0).transpose() - state_eq->eval(state_all.row(0), input_all.row(0))).norm(); 28 | 29 | // Training model 30 | DDMPC::Training training; 31 | std::string model_path = "/tmp/TestTrainingModel.pt"; 32 | training.run(state_eq, train_dataset, test_dataset, model_path); 33 | 34 | double after_error = (next_state_all.row(0).transpose() - state_eq->eval(state_all.row(0), input_all.row(0))).norm(); 35 | /* NOTE(review): the error is measured on sample 0 only, which is part of the data given to makeDataset. */ 36 | // Check error 37 | EXPECT_LT(after_error, before_error); 38 | EXPECT_LT(after_error, 0.2); 39 | 40 | std::cout << "Run the following commands in gnuplot:\n" 41 | << " set key autotitle columnhead\n" 42 | << " set key noenhanced\n" 43 | << " plot \"/tmp/DataDrivenMPCTraining.txt\" u 1:2 w lp, \"\" u 1:3 w lp\n"; 44 | } 45 | 46 | int main(int argc, char ** argv) 47 | { 48 | testing::InitGoogleTest(&argc, argv); 49 | return RUN_ALL_TESTS(); 50 | } 51 | -------------------------------------------------------------------------------- /tests/test/TestMpcCart.test: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | no_exit: $(arg no_exit) 11 | 12 | 13 | 14 | 16 | 17 | enable_gui:
$(arg enable_gui) 18 | 19 | 20 | 22 | 23 | enable_gui: $(arg enable_gui) 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /tests/test/TestMpcCartWalk.test: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | no_exit: $(arg no_exit) 10 | 11 | 12 | 13 | 14 | 15 | enable_gui: $(arg enable_gui) 16 | box_mass: 15.0 17 | lateral_friction: 0.075 18 | 19 | 20 | 21 | --------------------------------------------------------------------------------