├── docs ├── source │ ├── rst_files │ │ ├── readme.rst │ │ ├── readme-sdk.rst │ │ ├── readme-docs.rst │ │ ├── readme-test.rst │ │ ├── readme-docker.rst │ │ ├── protobuf.rst │ │ ├── api-sdk.rst │ │ ├── api-common.rst │ │ └── api-simulator.rst │ ├── index.rst │ └── conf.py ├── make-html.sh ├── Makefile ├── make.bat ├── .readthedocs.yaml └── README.md ├── sdk ├── assets │ └── exported-python-sdk │ │ ├── requirements.txt │ │ ├── check.sh │ │ ├── __init__.py │ │ ├── README.md │ │ ├── trajectory_tracker.py │ │ └── check_export_symbols.py ├── docs │ └── DRLTT-SDK-mainpage.md ├── drltt-sdk │ ├── common │ │ ├── common.h │ │ ├── math.h │ │ ├── CMakeLists.txt │ │ ├── protobuf_operators_test.cpp │ │ ├── io.cpp │ │ ├── protobuf_operators.cpp │ │ ├── io.h │ │ ├── geometry.h │ │ └── protobuf_operators.h │ ├── CMakeLists.txt │ ├── managers │ │ ├── CMakeLists.txt │ │ ├── observation_manager.h │ │ └── observation_manager.cpp │ ├── dynamics_models │ │ ├── bicycle_model_test.cpp │ │ ├── CMakeLists.txt │ │ ├── base_dynamics_model.cpp │ │ ├── bicycle_model.h │ │ ├── base_dynamics_model.h │ │ └── bicycle_model.cpp │ ├── inference │ │ ├── CMakeLists.txt │ │ ├── policy_inference_test.cpp │ │ ├── policy_inference.cpp │ │ └── policy_inference.h │ ├── environments │ │ ├── CMakeLists.txt │ │ ├── trajectory_tracking.h │ │ └── trajectory_tracking.cpp │ └── trajectory_tracker │ │ ├── CMakeLists.txt │ │ ├── trajectory_tracker.cpp │ │ ├── trajectory_tracker_pybind_export.cpp │ │ └── trajectory_tracker.h ├── format-cpp-code.sh ├── Doxyfile-cpp ├── compile-source.sh ├── export-py-sdk.sh ├── CMakeLists.txt ├── compile-in-docker.sh └── .clang-format ├── submodules ├── README.md └── setup-submodules.sh ├── cicd ├── stop-all-gitlab-runners.sh ├── test-cpp-ci.sh ├── start-gitlab-runner.sh └── README.md ├── drltt ├── __init__.py ├── simulator │ ├── observation │ │ ├── __init__.py │ │ └── observation_manager.py │ ├── trajectory_tracker │ │ ├── __init__.py │ │ ├── trajectory_tracker_test.py │ │ └── 
trajectory_tracker.py │ ├── trajectory │ │ ├── __init__.py │ │ ├── random_walk_test.py │ │ ├── reference_line_test.py │ │ └── random_walk.py │ ├── visualization │ │ ├── __init__.py │ │ ├── visualize_trajectory_tracking_episode_test.py │ │ ├── utils.py │ │ └── visualize_trajectory_tracking_episode.py │ ├── dynamics_models │ │ ├── __init__.py │ │ ├── bicycle_model_test.py │ │ ├── dynamics_model_manager.py │ │ └── base_dynamics_model.py │ ├── environments │ │ ├── __init__.py │ │ ├── trajectory_tracking_env_test.py │ │ └── env_interface.py │ ├── __init__.py │ ├── common.py │ └── rl_learning │ │ ├── __init__.py │ │ ├── sb3_utils.py │ │ ├── sb3_learner_test.py │ │ └── sb3_learner.py └── common │ ├── __init__.py │ ├── proto │ └── proto_def │ │ ├── drltt_proto │ │ ├── environment │ │ │ ├── environment.proto │ │ │ └── trajectory_tracking.proto │ │ ├── dynamics_model │ │ │ ├── observation.proto │ │ │ ├── action.proto │ │ │ ├── state.proto │ │ │ ├── basics.proto │ │ │ └── hyper_parameter.proto │ │ ├── sdk │ │ │ └── exported_policy_test_case.proto │ │ └── trajectory │ │ │ └── trajectory.proto │ │ ├── CMakeLists.txt │ │ └── compile_proto.sh │ ├── gym_helper.py │ ├── future.py │ ├── geometry.py │ ├── registry.py │ └── io.py ├── format-python-code.sh ├── requirements ├── pypi-doc.txt └── pypi.txt ├── configs └── trajectory_tracking │ ├── config-track-tiny.yaml │ ├── config-track.yaml │ ├── config-track-var-reflen.yaml │ ├── test_samples │ ├── config-track-test-sample-dummy.yaml │ ├── config-track-test-sample-fast.yaml │ └── config-track-test-sample.yaml │ └── config-track-base.yaml ├── test ├── test-doc.sh ├── test-python.sh ├── test.sh ├── test-cpp.sh └── README.md ├── .gitmodules ├── format-code.sh ├── setup.sh ├── setup-minimum.sh ├── scripts ├── train_eval_trace-track.sh ├── train_eval_trace-track_tiny.sh ├── train_eval_trace-track_var_reflen.sh ├── tests │ ├── train_eval_trace-track_test_sample.sh │ ├── train_eval_trace-track_test_sample_dummy.sh │ └── 
train_eval_trace-track_test_sample_fast.sh └── eval │ ├── eval-track_tiny.sh │ ├── eval-track.sh │ └── eval-track_var_reflen.sh ├── install-setup-protoc-gen-doc.sh ├── .gitlab-ci.yml ├── install-setup-doxygen.sh ├── LICENSE ├── install-setup-protoc.sh ├── docker ├── README.md └── Dockerfile.cicd ├── .gitignore └── tools └── main.py /docs/source/rst_files/readme.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../README.md 2 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | protobuf 3 | -------------------------------------------------------------------------------- /docs/source/rst_files/readme-sdk.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../sdk/README.md 2 | -------------------------------------------------------------------------------- /docs/source/rst_files/readme-docs.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../docs/README.md 2 | -------------------------------------------------------------------------------- /docs/source/rst_files/readme-test.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../test/README.md 2 | -------------------------------------------------------------------------------- /docs/source/rst_files/readme-docker.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../docker/README.md 2 | -------------------------------------------------------------------------------- /sdk/docs/DRLTT-SDK-mainpage.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | Welcome to the documentation for DRLTT. 
4 | -------------------------------------------------------------------------------- /submodules/README.md: -------------------------------------------------------------------------------- 1 | # Submodules 2 | 3 | This directory contains asset repositories and third party repositories. 4 | -------------------------------------------------------------------------------- /cicd/stop-all-gitlab-runners.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker container stop $(docker container ls -q --filter name=drltt-cicd*) 4 | -------------------------------------------------------------------------------- /drltt/__init__.py: -------------------------------------------------------------------------------- 1 | from . import common 2 | from . import simulator 3 | 4 | __all__ = [ 5 | 'common', 6 | 'simulator', 7 | ] 8 | -------------------------------------------------------------------------------- /drltt/simulator/observation/__init__.py: -------------------------------------------------------------------------------- 1 | from .observation_manager import ObservationManager 2 | 3 | __all__ = [ 4 | 'ObservationManager', 5 | ] 6 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .trajectory_tracker import TrajectoryTracker 2 | 3 | __all__ = [ 4 | 'TrajectoryTracker', 5 | ] 6 | -------------------------------------------------------------------------------- /format-python-code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo FORMATTING PYTHON CODE... 
4 | 5 | black ${BLACK_ARGS} --config ./configs/code_formatting/pyproject.toml ./ 6 | -------------------------------------------------------------------------------- /requirements/pypi-doc.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme 3 | myst-parser # markdown, substitute of `recommonmark` 4 | m2r2 5 | breathe # doxygen within sphinx 6 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/config-track-tiny.yaml: -------------------------------------------------------------------------------- 1 | 2 | algorithm: 3 | policy_kwargs: 4 | net_arch: 5 | pi: [128, 32] 6 | qf: [1024, 512, 256, 128] 7 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/config-track.yaml: -------------------------------------------------------------------------------- 1 | 2 | algorithm: 3 | policy_kwargs: 4 | net_arch: 5 | pi: [256, 256, 128, 128, 64, 64] 6 | qf: [1024, 512, 256, 128] 7 | -------------------------------------------------------------------------------- /docs/source/rst_files/protobuf.rst: -------------------------------------------------------------------------------- 1 | Protobuf Definition 2 | ============================ 3 | 4 | .. 
raw:: html 5 | :file: ../../../drltt/common/proto/proto_doc_gen/index.html 6 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "common/geometry.h" 4 | #include "common/io.h" 5 | #include "common/math.h" 6 | #include "common/protobuf_operators.h" 7 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory/__init__.py: -------------------------------------------------------------------------------- 1 | from .random_walk import random_walk 2 | from .reference_line import ReferenceLineManager 3 | 4 | __all__ = [ 5 | 'random_walk', 6 | 'ReferenceLineManager', 7 | ] 8 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/config-track-var-reflen.yaml: -------------------------------------------------------------------------------- 1 | # reference line of variable length 2 | environment: 3 | tracking_length_lb: 20 4 | tracking_length_ub: 200 5 | 6 | learning: 7 | total_timesteps: 3_000_000 8 | -------------------------------------------------------------------------------- /test/test-doc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | log_dir=./test-log 4 | mkdir -p $log_dir 5 | 6 | ( 7 | source setup.sh 8 | ./docs/make-html.sh 9 | retval=$? 10 | ) 2>&1 | tee ./${log_dir}/doc-test.log 11 | 12 | exit $retval 13 | -------------------------------------------------------------------------------- /docs/source/rst_files/api-sdk.rst: -------------------------------------------------------------------------------- 1 | APIs: SDK 2 | ============== 3 | 4 | .. doxygenclass:: drltt::TrajectoryTracker 5 | :members: 6 | 7 | .. doxygenclass:: drltt::TrajectoryTracking 8 | :members: 9 | 10 | .. 
doxygenclass:: drltt::TorchJITModulePolicy 11 | :members: 12 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install -r ./requirements.txt 4 | 5 | export LD_LIBRARY_PATH=./lib:${LD_LIBRARY_PATH} 6 | export PYTHONPATH=$PWD:$PYTHONPATH 7 | export PYTHONPATH=./proto_gen_py:$PYTHONPATH 8 | 9 | python ./check_export_symbols.py 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/drltt-assets"] 2 | path = submodules/drltt-assets 3 | url = https://github.com/MARMOTatZJU/drltt-assets.git 4 | [submodule "submodules/waymax-visualization"] 5 | path = submodules/waymax-visualization 6 | url = https://github.com/MARMOTatZJU/waymax-visualization.git 7 | -------------------------------------------------------------------------------- /drltt/simulator/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from drltt.common import Registry 2 | 3 | VISUALIZATION_FUNCTIONS = Registry() 4 | 5 | from .visualize_trajectory_tracking_episode import visualize_trajectory_tracking_episode 6 | 7 | __all__ = [ 8 | 'visualize_trajectory_tracking_episode', 9 | ] 10 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(drltt-sdk) 2 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 3 | 4 | add_subdirectory(common) 5 | add_subdirectory(dynamics_models) 6 | add_subdirectory(inference) 7 | add_subdirectory(managers) 8 | add_subdirectory(environments) 9 | add_subdirectory(trajectory_tracker) 10 | 
-------------------------------------------------------------------------------- /drltt/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import Registry, build_object_within_registry_from_config 2 | from .geometry import normalize_angle 3 | from .io import load_config_from_yaml, GLOBAL_DEBUG_INFO 4 | 5 | __all__ = [ 6 | 'Registry', 7 | 'build_object_within_registry_from_config', 8 | 'load_config_from_yaml', 9 | ] 10 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | PACKAGE_DIR = os.path.dirname(__file__) 5 | USR_LIB_DIR = f'{PACKAGE_DIR}/lib' 6 | SDK_LIB_DIR = PACKAGE_DIR 7 | sys.path.append(PACKAGE_DIR) 8 | 9 | 10 | from .trajectory_tracker import TrajectoryTracker 11 | 12 | __all__ = [ 13 | 'TrajectoryTracker', 14 | ] 15 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/environment/environment.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "drltt_proto/environment/trajectory_tracking.proto"; 4 | 5 | package drltt_proto; 6 | 7 | // Environment 8 | message Environment { 9 | // Trajectory tracking. 10 | optional TrajectoryTrackingEnvironment trajectory_tracking = 1; 11 | } 12 | -------------------------------------------------------------------------------- /test/test-python.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | log_dir=./test-log 4 | mkdir -p $log_dir 5 | 6 | ( 7 | echo "SETTING UP PYTHON TESTING ENVIRONMENT" 8 | source ./scripts/tests/train_eval_trace-track_test_sample_dummy.sh 9 | echo "TEST PYTHON CODE" 10 | pytest 11 | retval=$? 
12 | ) 2>&1 | tee ./${log_dir}/python-test.log 13 | 14 | exit $retval 15 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/README.md: -------------------------------------------------------------------------------- 1 | # DRLTT standalone SDK 2 | 3 | Exported DRLTT Python SDK package. 4 | 5 | No dependency except for `Python=3.8`. Multiple supported Python versions planned in the future. 6 | 7 | Documentation: https://drl-based-trajectory-tracking.readthedocs.io/ 8 | 9 | NOTE: `LD_LIBRARY_PATH` must include `./lib` before process starts. 10 | -------------------------------------------------------------------------------- /format-code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./format-python-code.sh 4 | python_check_result=$? 5 | 6 | pushd ./sdk 7 | ./format-cpp-code.sh 8 | cpp_check_result=$? 9 | popd 10 | 11 | if [ $python_check_result -ne 0 ] || [ $cpp_check_result -ne 0 ]; then 12 | echo "Code formatting failed" 13 | exit 1 14 | fi 15 | 16 | echo "Code formatting succeeded" 17 | exit 0 18 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/managers/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(drltt-sdk) 2 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 3 | 4 | 5 | file(GLOB_RECURSE _SRCS "*.[hc]pp") 6 | list(FILTER _SRCS EXCLUDE REGEX "_test.[hc]pp$") 7 | add_library(managers STATIC ${_SRCS}) 8 | target_link_libraries(managers 9 | ${TORCH_LIBRARIES} 10 | common 11 | dynamics_models 12 | ) 13 | 14 | -------------------------------------------------------------------------------- /submodules/setup-submodules.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | drltt_proto_gen_py_dir_default="${DRLTT_REPO_DIR}/submodules/drltt-assets/proto_gen_py/" 4 | export 
PYTHONPATH=:${drltt_proto_gen_py_dir_default}:${PYTHONPATH} 5 | 6 | waymax_viz_dir=submodules/waymax-visualization 7 | export PYTHONPATH=:${waymax_viz_dir}:${PYTHONPATH} 8 | 9 | echo "DRLTT submodule setup finished." 10 | -------------------------------------------------------------------------------- /docs/make-html.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source setup.sh 3 | 4 | pushd docs 5 | rm -rf build 6 | make html SPHINXOPTS="-W" 7 | make_sphinx_ret_val=$? 8 | if [ $make_sphinx_ret_val -eq 0 ];then 9 | echo "Built the Sphinx documentation successfully." 10 | else 11 | echo "Sphinx documentation building failed!!!" 12 | fi 13 | popd 14 | 15 | exit $make_sphinx_ret_val 16 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/dynamics_model/observation.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | 6 | message Observation { 7 | // Observation of bicycle model. 8 | optional BicycleModelObservation bicycle_model = 1; 9 | } 10 | 11 | 12 | message BicycleModelObservation { 13 | // Vectorized feature. 14 | repeated float feature = 1; 15 | } 16 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/dynamics_model/action.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | 6 | message Action { 7 | optional BicycleModelAction bicycle_model = 1; 8 | } 9 | 10 | 11 | message BicycleModelAction { 12 | // Acceleration in [m/s/s]. 13 | optional float a = 1; 14 | // Steering angle in [rad]. 
15 | optional float s = 2; 16 | } 17 | -------------------------------------------------------------------------------- /drltt/simulator/dynamics_models/__init__.py: -------------------------------------------------------------------------------- 1 | from drltt.common import Registry 2 | 3 | DYNAMICS_MODELS = Registry() 4 | 5 | from .base_dynamics_model import BaseDynamicsModel 6 | from .dynamics_model_manager import DynamicsModelManager 7 | from .bicycle_model import BicycleModel 8 | 9 | 10 | __all__ = [ 11 | 'BaseDynamicsModel', 12 | 'BicycleModel', 13 | 'DynamicsModelManager', 14 | ] 15 | -------------------------------------------------------------------------------- /drltt/simulator/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from drltt.common import Registry 2 | 3 | ENVIRONMENTS = Registry() 4 | 5 | from .env_interface import CustomizedEnvInterface 6 | from .trajectory_tracking_env import TrajectoryTrackingEnv 7 | from .env_interface import ExtendedGymEnv 8 | 9 | __all__ = [ 10 | 'TrajectoryTrackingEnv', 11 | 'CustomizedEnvInterface', 12 | 'ExtendedGymEnv', 13 | ] 14 | -------------------------------------------------------------------------------- /requirements/pypi.txt: -------------------------------------------------------------------------------- 1 | # common dependencies 2 | pytest 3 | numpy 4 | pandas 5 | numpy 6 | scipy 7 | matplotlib 8 | protobuf==3.20.1 9 | black==23.11.0 10 | PyYAML 11 | frozendict 12 | 13 | # deep reinforcement learning (DRL)-related 14 | torch 15 | torchvision 16 | gym>=0.26.2 # SB3 specifies a version of 0.21.0, which is not compatible with up-to-date version of SB3 17 | stable-baselines3[extra] 18 | -------------------------------------------------------------------------------- /drltt/simulator/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import DTYPE, EPSILON, TEST_CONFIG_PATHS, 
TEST_CHECKPOINT_DIR 2 | from .trajectory_tracker.trajectory_tracker import TrajectoryTracker 3 | from .environments.trajectory_tracking_env import TrajectoryTrackingEnv 4 | 5 | __all__ = [ 6 | 'DTYPE', 7 | 'EPSILON', 8 | 'TEST_CONFIG_PATHS', 9 | 'TEST_CHECKPOINT_DIR', 10 | 'TrajectoryTrackingEnv', 11 | 'TrajectoryTracker', 12 | ] 13 | -------------------------------------------------------------------------------- /test/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | work_dir=./test-log 4 | if [[ -d $work_dir ]];then 5 | bak_work_dir=${work_dir}-bak 6 | if [[ -d ${bak_work_dir} ]];then rm -rf ${bak_work_dir};fi 7 | mv ${work_dir} ${bak_work_dir} 8 | fi 9 | mkdir -p $work_dir 10 | 11 | test_script_dir=$(dirname $0) 12 | 13 | ./${test_script_dir}/test-python.sh "$@" 14 | ./${test_script_dir}/test-cpp.sh "$@" 15 | ./${test_script_dir}/test-doc.sh "$@" 16 | -------------------------------------------------------------------------------- /cicd/test-cpp-ci.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | log_dir=./test-log 4 | mkdir -p $log_dir 5 | 6 | ( 7 | ./scripts/tests/train_eval_trace-track_test_sample.sh 8 | cd sdk 9 | export BUILD_DIR=$(realpath ./build) 10 | export PROTO_GEN_DIR=$(realpath ./proto_gen) 11 | export LIBTORCH_DIR=/libtorch 12 | export REPO_ROOT_DIR="$(git rev-parse --show-toplevel)" 13 | ./compile-source.sh 14 | ) 2>&1 | tee ./${log_dir}/cpp-test-ci.log 15 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/dynamics_model/state.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | import "drltt_proto/dynamics_model/basics.proto"; 6 | 7 | message State { 8 | // State of bicycle model. 
9 | optional BicycleModelState bicycle_model = 1; 10 | } 11 | 12 | message BicycleModelState { 13 | // Body state. 14 | BodyState body_state = 1; 15 | // Velocity in [m/s]. 16 | optional float v = 2; 17 | } 18 | -------------------------------------------------------------------------------- /drltt/simulator/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO) 4 | 5 | import numpy as np 6 | 7 | DTYPE = np.float32 8 | EPSILON = 1e-6 9 | 10 | TEST_CONFIG_PATHS = ( 11 | 'configs/trajectory_tracking/config-track-base.yaml', 12 | 'configs/trajectory_tracking/config-track-tiny.yaml', 13 | 'configs/trajectory_tracking/test_samples/config-track-test-sample.yaml', 14 | ) 15 | 16 | TEST_CHECKPOINT_DIR = 'work_dir/track-test/checkpoint' 17 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/dynamics_model/basics.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | // State coordinates in fixed space frame. 6 | message BodyState { 7 | // X-coordinate in [m]. 8 | optional float x = 1; 9 | // Y-coordinate in [m]. 10 | optional float y = 2; 11 | // Orientation/heading in [rad]. 12 | optional float r = 3; 13 | } 14 | 15 | // TODO: move to higher level 16 | message DebugInfo { 17 | repeated float data = 1; 18 | } 19 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/test_samples/config-track-test-sample-dummy.yaml: -------------------------------------------------------------------------------- 1 | # Overriding config file for DUMMY TEST. 2 | # No consistency ensured. ONLY FOR PYTHON TEST. 
3 | 4 | environment: 5 | tracking_length_lb: 45 6 | tracking_length_ub: 55 7 | 8 | algorithm: 9 | learning_starts: 128 10 | batch_size: 128 11 | learning_rate: 1.0e-3 12 | tau: 0.01 13 | 14 | learning: 15 | total_timesteps: 256 16 | log_interval: 10 17 | 18 | evaluation: 19 | eval_config: 20 | n_episodes: 1 21 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -exo pipefail 4 | echo "DRLTT setup starts." 5 | 6 | source ./setup-minimum.sh 7 | 8 | source ./install-setup-protoc.sh 9 | source ./install-setup-doxygen.sh 10 | source ./install-setup-protoc-gen-doc.sh 11 | 12 | # compile protobuf and generate documentation of protobuf 13 | bash ${DRLTT_PROTO_DIR}/proto_def/compile_proto.sh 14 | 15 | # generate documentation for cpp 16 | (cd ./sdk ; doxygen Doxyfile-cpp) 17 | 18 | echo "DRLTT setup finished." 19 | 20 | set +exo pipefail 21 | -------------------------------------------------------------------------------- /drltt/common/gym_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym.spaces import Space 4 | 5 | 6 | def scale_action(action: np.ndarray, action_space: Space) -> np.ndarray: 7 | """Linearly rescale an action so that `action_space.low` maps to -1 and `action_space.high` maps to +1. 8 | 9 | Args: 10 | action: Action to be scaled. 11 | action_space: Reference action space whose `low`/`high` bounds define the source range. 12 | 13 | Returns: 14 | np.ndarray: Scaled action. 15 | """ 16 | scaled_action = 2 * (action - action_space.low) / (action_space.high - action_space.low) - 1 17 | return scaled_action 18 | -------------------------------------------------------------------------------- /setup-minimum.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # THIS SCRIPT NEEDS TO BE RUN AT THE ROOT DIR. OF THE REPO. 
3 | 4 | export DRLTT_REPO_DIR=$(realpath ${PWD}) 5 | export DRLTT_PY_LIB_DIR=${DRLTT_REPO_DIR}/drltt 6 | export DRLTT_PROTO_DIR=${DRLTT_PY_LIB_DIR}/common/proto 7 | export DRLTT_PROTO_GEN_PY_DIR=${DRLTT_PROTO_DIR}/proto_gen_py 8 | 9 | source submodules/setup-submodules.sh 10 | 11 | export PYTHONPATH=${DRLTT_REPO_DIR}:${PYTHONPATH} 12 | export PYTHONPATH=:${DRLTT_PROTO_GEN_PY_DIR}:${PYTHONPATH} 13 | 14 | echo "set PYTHONPATH: $PYTHONPATH" 15 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/math.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <algorithm> 4 | 5 | const float EPSILON = 1e-6; 6 | 7 | /** 8 | * @brief Clip an object. Comparison operators need to be implemented for this 9 | * object. 10 | * 11 | * @tparam T Object type. 12 | * @param n Object to be clipped. 13 | * @param lower Lower bound. 14 | * @param upper Upper bound. 15 | * @return T Clipped object. 16 | */ 17 | template <typename T> 18 | T clip(const T& n, const T& lower, const T& upper) { 19 | return std::max(lower, std::min(n, upper)); 20 | } 21 | -------------------------------------------------------------------------------- /drltt/simulator/rl_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from drltt.common import Registry 2 | 3 | METRICS = Registry() 4 | 5 | from .sb3_learner import compute_bicycle_model_metrics, train_with_sb3, eval_with_sb3, roll_out_one_episode 6 | from .sb3_export import export_sb3_jit_module, test_sb3_jit_module 7 | 8 | __all__ = [ 9 | 'train_with_sb3', 10 | 'eval_with_sb3', 11 | 'compute_bicycle_model_metrics', 12 | 'export_sb3_jit_module', 13 | 'test_sb3_jit_module', 14 | 'OnnxableActorCriticPolicy', 15 | 'roll_out_one_episode', 16 | ] 17 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/sdk/exported_policy_test_case.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | // Float-point tensor. 6 | message TensorFP{ 7 | // Tensor's shape. 8 | repeated int32 shape = 1; 9 | // Tensor's underlying data. 10 | repeated float data = 2; 11 | } 12 | 13 | message ExportedPolicyTestCases{ 14 | // Observation tensor, shape=(sample_number, observation_dim) 15 | optional TensorFP observations = 1; 16 | // Action tensor, shape=(sample_number, action_dim) 17 | optional TensorFP actions = 2; 18 | } 19 | -------------------------------------------------------------------------------- /sdk/format-cpp-code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "FORMATTING CPP CODE..." 4 | 5 | if [[ ! -x $(command -v clang-format) ]];then 6 | echo "clang-format not found." 7 | exit -1 8 | fi 9 | 10 | clang_format_result=0 11 | for f in $(find . -regex '.*\.\(cpp\|hpp\|cu\|cuh\|c\|h\)' \ 12 | -not -path "./build/*" \ 13 | -not -path "./proto_gen/*" \ 14 | ) 15 | do 16 | clang-format --verbose --style=file --Werror ${CLANG_FORMAT_ARGS} -i $f 17 | if [ $? -ne 0 ]; then 18 | clang_format_result=$(($clang_format_result + 1)) 19 | fi 20 | done 21 | 22 | exit ${clang_format_result} 23 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/test_samples/config-track-test-sample-fast.yaml: -------------------------------------------------------------------------------- 1 | # Overriding config file for FAST TEST. 2 | # Can be used during debugging of tests. 3 | # Quicker than preparation with FORMAL TEST config while ensuring enough consistency of testing results. 
4 | 5 | environment: 6 | tracking_length_lb: 45 7 | tracking_length_ub: 55 8 | 9 | algorithm: 10 | learning_starts: 4096 11 | batch_size: 4096 12 | learning_rate: 1.0e-4 13 | tau: 0.001 14 | 15 | learning: 16 | total_timesteps: 16384 17 | log_interval: 10 18 | 19 | evaluation: 20 | eval_config: 21 | n_episodes: 20 22 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/test_samples/config-track-test-sample.yaml: -------------------------------------------------------------------------------- 1 | # Overriding config file for FORMAL TEST. 2 | # Each commit/pull request/merge request need to pass tests 3 | # associated with this config 4 | # Long preparation time but ensure consistent testing results. 5 | 6 | environment: 7 | tracking_length_lb: 45 8 | tracking_length_ub: 55 9 | 10 | algorithm: 11 | learning_starts: 4096 12 | batch_size: 4096 13 | learning_rate: 1.0e-4 14 | tau: 0.0001 15 | 16 | learning: 17 | total_timesteps: 65536 18 | log_interval: 10 19 | 20 | evaluation: 21 | eval_config: 22 | n_episodes: 20 23 | -------------------------------------------------------------------------------- /sdk/Doxyfile-cpp: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/google/google-api-cpp-client/blob/master/doxygen.config 2 | DOXYFILE_ENCODING = UTF-8 3 | PROJECT_NAME = "DRLTT C++ SDK" 4 | OUTPUT_DIRECTORY = ./doxygen_output 5 | INPUT = ./drltt-sdk/ 6 | RECURSIVE = YES 7 | INPUT += ./docs/DRLTT-SDK-mainpage.md 8 | USE_MDFILE_AS_MAINPAGE = ./docs/DRLTT-SDK-mainpage.md 9 | 10 | GENERATE_XML = YES 11 | XML_OUTPUT = xml 12 | 13 | WARN_FORMAT = "$file:$line: $text" 14 | WARN_LOGFILE = "doxygen.warnings.log" 15 | -------------------------------------------------------------------------------- /scripts/train_eval_trace-track.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source setup.sh 3 | 
# Builds the static library `proto` from every .proto file below this directory.
find_package(Protobuf REQUIRED)
include_directories(${PROTOBUF_INCLUDE_DIRS})

# TODO remove redundant set of this part
# Keep the directory layout of the .proto files in the generated sources.
set(PROTOBUF_GENERATE_CPP_APPEND_PATH FALSE)
# NOTE(review): inside set() the word PROTOC_OUT_DIR is just another list element,
# so the import dirs become <source dir>;PROTOC_OUT_DIR;<PROTO_GENERATE_PATH>.
# It looks like it was meant as a keyword argument elsewhere -- confirm.
set(PROTOBUF_IMPORT_DIRS ${CMAKE_CURRENT_SOURCE_DIR} PROTOC_OUT_DIR ${PROTO_GENERATE_PATH})
# Collect all proto definitions recursively and generate C++ sources/headers for them.
file(GLOB_RECURSE PROTO_DEF "${CMAKE_CURRENT_SOURCE_DIR}/*.proto")
protobuf_generate_cpp(PROTO_SRC PROTO_HEADER ${PROTO_DEF} )

# compile protobuf-generated source
# PROTO_GENERATE_PATH is expected to be defined by the including CMake scope.
include_directories(${PROTO_GENERATE_PATH})
add_library(proto STATIC ${PROTO_HEADER} ${PROTO_SRC})
target_link_libraries(proto PRIVATE ${Protobuf_LIBRARIES})
configs/trajectory_tracking/config-track-base.yaml \ 17 | configs/trajectory_tracking/config-track-tiny.yaml \ 18 | --checkpoint-dir $work_dir/checkpoint \ 19 | --num-test-cases 1024 \ 20 | --train --eval --trace 21 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/dynamics_models/bicycle_model_test.cpp: -------------------------------------------------------------------------------- 1 | #include "bicycle_model.h" 2 | #include 3 | 4 | using namespace drltt; 5 | 6 | TEST(DynamicsModelTest, BicycleModelTest) { 7 | drltt_proto::HyperParameter hyper_parameter; 8 | hyper_parameter.mutable_bicycle_model()->set_front_overhang(0.9); 9 | hyper_parameter.mutable_bicycle_model()->set_rear_overhang(0.9); 10 | hyper_parameter.mutable_bicycle_model()->set_wheelbase(2.7); 11 | hyper_parameter.mutable_bicycle_model()->set_width(1.8); 12 | BicycleModel dynamics_model(hyper_parameter); 13 | drltt_proto::Action action; 14 | float delta_t = 0.1; 15 | dynamics_model.Step(action, delta_t); 16 | } 17 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/inference/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(drltt-sdk) 2 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 3 | 4 | 5 | file(GLOB_RECURSE _SRCS "*.[hc]pp") 6 | list(FILTER _SRCS EXCLUDE REGEX "_test.[hc]pp$") 7 | add_library(inference STATIC ${_SRCS}) 8 | target_link_libraries(inference 9 | ${TORCH_LIBRARIES} 10 | common 11 | ) 12 | 13 | # gtest 14 | if(BUILD_TESTS) 15 | file(GLOB_RECURSE _TEST_SRCS "*test.[hc]pp") 16 | add_executable( 17 | inference_test 18 | ${_TEST_SRCS} 19 | ) 20 | target_link_libraries( 21 | inference_test 22 | GTest::gtest_main 23 | inference 24 | ) 25 | gtest_discover_tests(inference_test) 26 | endif(BUILD_TESTS) 27 | -------------------------------------------------------------------------------- /drltt/common/future.py: 
"""Components ensuring compatibility"""

from typing import Callable

# Resolve the optional decorator once at import time instead of on every call.
try:
    from typing import override as _typing_override
except ImportError:  # `typing.override` is only available for Python>=3.12
    _typing_override = None


def override(func: Callable) -> Callable:
    """Apply `typing.override` when available, otherwise return `func` unchanged.

    `typing.override` is supported only for Python>=3.12.

    Args:
        func: Function desired to be decorated by `typing.override`

    Returns:
        Callable: Decorated function if Python>=3.12. Undecorated function otherwise.
    """
    if _typing_override is None:
        return func
    return _typing_override(func)


# Only a missing name is an expected failure here; a bare `except:` would also
# hide unrelated errors such as KeyboardInterrupt.
try:
    from typing import Self
except ImportError:
    from typing import Any

    Self = Any
# Builds the static library `environments` and, optionally, its gtest executable.
project(drltt-sdk)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# Library sources: every .cpp/.hpp below this directory except *_test files.
file(GLOB_RECURSE _SRCS "*.[hc]pp")
list(FILTER _SRCS EXCLUDE REGEX "_test.[hc]pp$")
add_library(environments STATIC ${_SRCS})
target_link_libraries(environments
    common
    dynamics_models
    inference
    managers
)

# gtest
if(BUILD_TESTS)
    # NOTE(review): this glob ("*test") is wider than the exclusion regex above
    # ("_test"), so a file such as "footest.cpp" would land in both targets.
    # NOTE(review): no *test.cpp file appears to exist in this directory; an empty
    # source list would make add_executable fail -- confirm.
    file(GLOB_RECURSE _TEST_SRCS "*test.[hc]pp")
    add_executable(
        environments_test
        ${_TEST_SRCS}
    )
    target_link_libraries(
        environments_test
        environments
        GTest::gtest_main
    )
    gtest_discover_tests(environments_test)
endif(BUILD_TESTS)
#!/bin/bash
# Train, evaluate and trace a policy with the FORMAL TEST config.
# Results go to work_dir/track-test; an existing directory is rotated to
# work_dir/track-test-bak (replacing any older backup).
source setup.sh
work_dir=work_dir/track-test
if [[ -d $work_dir ]];then
    bak_work_dir=${work_dir}-bak
    if [[ -d ${bak_work_dir} ]];then rm -rf ${bak_work_dir};fi
    mv ${work_dir} ${bak_work_dir}
fi
mkdir -p $work_dir

# Keep a copy of this script next to the results for reproducibility.
# BASH_SOURCE[0] (unlike $0) still names this file when the script is sourced,
# which is how test/test-cpp.sh invokes it.
script_path=${BASH_SOURCE[0]}
cp "$script_path" "$work_dir/"

python tools/main.py \
    --config-files \
        configs/trajectory_tracking/config-track-base.yaml \
        configs/trajectory_tracking/config-track-tiny.yaml \
        configs/trajectory_tracking/test_samples/config-track-test-sample.yaml \
    --checkpoint-dir $work_dir/checkpoint \
    --num-test-cases 1024 \
    --train --eval --trace
$origin_work_dir/* $work_dir 12 | 13 | script_path=$0 14 | cp $script_path $work_dir/ 15 | 16 | python tools/main.py \ 17 | --config-files \ 18 | configs/trajectory_tracking/config-track-base.yaml \ 19 | configs/trajectory_tracking/config-track-tiny.yaml \ 20 | --checkpoint-dir $work_dir/checkpoint \ 21 | --num-test-cases 1024 \ 22 | --eval 23 | -------------------------------------------------------------------------------- /scripts/tests/train_eval_trace-track_test_sample_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source setup.sh 3 | work_dir=work_dir/track-test 4 | if [[ -d $work_dir ]];then 5 | bak_work_dir=${work_dir}-bak 6 | if [[ -d ${bak_work_dir} ]];then rm -rf ${bak_work_dir};fi 7 | mv ${work_dir} ${bak_work_dir} 8 | fi 9 | mkdir -p $work_dir 10 | 11 | script_path=$0 12 | cp $script_path $work_dir/ 13 | 14 | python tools/main.py \ 15 | --config-files \ 16 | configs/trajectory_tracking/config-track-base.yaml \ 17 | configs/trajectory_tracking/config-track-tiny.yaml \ 18 | configs/trajectory_tracking/test_samples/config-track-test-sample-dummy.yaml \ 19 | --checkpoint-dir $work_dir/checkpoint \ 20 | --num-test-cases 1 \ 21 | --train --eval --trace 22 | -------------------------------------------------------------------------------- /scripts/tests/train_eval_trace-track_test_sample_fast.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source setup.sh 3 | work_dir=work_dir/track-test 4 | if [[ -d $work_dir ]];then 5 | bak_work_dir=${work_dir}-bak 6 | if [[ -d ${bak_work_dir} ]];then rm -rf ${bak_work_dir};fi 7 | mv ${work_dir} ${bak_work_dir} 8 | fi 9 | mkdir -p $work_dir 10 | 11 | script_path=$0 12 | cp $script_path $work_dir/ 13 | 14 | python tools/main.py \ 15 | --config-files \ 16 | configs/trajectory_tracking/config-track-base.yaml \ 17 | configs/trajectory_tracking/config-track-tiny.yaml \ 18 | 
#!/bin/bash
# Run the C++ SDK tests and log all output to ./test-log/cpp-test.log.
#
# Usage: test-cpp.sh [fast] [args forwarded to sdk/compile-in-docker.sh ...]
#   fast                   use the fast data-generation script (argument is consumed)
#   test reuse-checkpoint  as the first two remaining arguments: skip checkpoint
#                          generation and only source setup.sh
#
# Exit status: the exit status of sdk/compile-in-docker.sh.

log_dir=./test-log
mkdir -p $log_dir

(

    echo "SETTING UP CPP TESTING ENVIRONMENT"
    if [[ $1 == "fast" ]]; then
        gen_data_script=./scripts/tests/train_eval_trace-track_test_sample_fast.sh
        shift
    else
        gen_data_script=./scripts/tests/train_eval_trace-track_test_sample.sh
    fi

    if [[ ! ( $1 == "test" && $2 == "reuse-checkpoint" ) ]]; then
        # generate checkpoint for sdk test
        source $gen_data_script
    else
        source setup.sh
    fi
    echo "TEST CPP CODE"
    pushd sdk
        bash ./compile-in-docker.sh "$@"
        retval=$?
    popd
    # Variables set in this subshell are invisible to the parent shell, so hand
    # the result over through the subshell's own exit status.
    exit $retval
) 2>&1 | tee ./${log_dir}/cpp-test.log
# PIPESTATUS[0] is the status of the subshell, not of `tee`; it must be read
# immediately after the pipeline.
retval=${PIPESTATUS[0]}

exit $retval
#!/bin/bash
# Evaluate the checkpoint shipped under submodules/drltt-assets/checkpoints/track.
# The checkpoint is copied to work_dir/eval/track first, so the evaluation never
# writes into the submodule; an existing work dir is rotated to work_dir/eval/track-bak.
source setup.sh
origin_work_dir=submodules/drltt-assets/checkpoints/track
work_dir=work_dir/eval/$(basename $origin_work_dir)
if [[ -d $work_dir ]];then
    bak_work_dir=${work_dir}-bak
    # Only one backup generation is kept.
    if [[ -d ${bak_work_dir} ]];then rm -rf ${bak_work_dir};fi
    mv ${work_dir} ${bak_work_dir}
fi
mkdir -p $work_dir
cp -r $origin_work_dir/* $work_dir

# Keep a copy of the invoking script next to the results.
script_path=$0
cp $script_path $work_dir/

# NOTE(review): configs/trajectory_tracking/config-track-eval.yaml is not listed in
# the repository tree visible to this review -- confirm that the file exists.
python tools/main.py \
    --config-files \
        configs/trajectory_tracking/config-track-base.yaml \
        configs/trajectory_tracking/config-track.yaml \
        configs/trajectory_tracking/config-track-eval.yaml \
    --checkpoint-dir $work_dir/checkpoint \
    --num-test-cases 1024 \
    --eval
import numpy as np

from drltt.simulator.dynamics_models import BicycleModel


def test_bicycle_model():
    """Build a BicycleModel, set a state and check two derived attributes are in range."""
    model = BicycleModel(
        front_overhang=0.9,
        rear_overhang=0.9,
        wheelbase=2.7,
        width=1.8,
        action_space_lb=[-3.0, -0.5235987755983],
        action_space_ub=[+3.0, +0.5235987755983],
        max_lat_acc=2.0,
    )
    initial_state = np.array((0.0, 0.0, 0.0, 5.0))
    model.set_state(initial_state)

    # Both derived quantities must lie strictly inside their open intervals.
    assert 0.0 < model.cog_relative_position_between_axles
    assert 1.0 > model.cog_relative_position_between_axles
    assert 0.0 < model.max_steer
    assert np.pi > model.max_steer


if __name__ == '__main__':
    test_bicycle_model()
-------------------------------------------------------------------------------- /install-setup-protoc-gen-doc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROTO_GEN_DOC_VERSION="1.5.1" 4 | 5 | home_bin_dir=$HOME/.local/bin 6 | export PATH=${home_bin_dir}:$PATH 7 | 8 | if [[ ! -x $(command -v protoc-gen-doc) ]];then 9 | mkdir -p ${home_bin_dir} 10 | protoc_doc_gen_binary=${home_bin_dir}/protoc-gen-doc 11 | if [[ ! -f ${protoc_doc_gen_binary} ]];then 12 | tmp_dir=/tmp/drltt-$(openssl rand -hex 6) 13 | mkdir ${tmp_dir} 14 | pushd ${tmp_dir} 15 | PROTO_GEN_DOC_FILENAME=protoc-gen-doc_${PROTO_GEN_DOC_VERSION}_linux_amd64.tar.gz && \ 16 | curl -OL https://github.com/pseudomuto/protoc-gen-doc/releases/download/v${PROTO_GEN_DOC_VERSION}/${PROTO_GEN_DOC_FILENAME} && \ 17 | tar -xvf ${PROTO_GEN_DOC_FILENAME} -C $home_bin_dir/ protoc-gen-doc && \ 18 | popd 19 | rm -rf ${tmp_dir} 20 | fi 21 | fi 22 | 23 | echo "protoc-gen-doc installed and setup at $(which protoc-gen-doc)" 24 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/trajectory_tracker/trajectory_tracker.cpp: -------------------------------------------------------------------------------- 1 | #include "trajectory_tracker.h" 2 | 3 | namespace drltt { 4 | TrajectoryTracker::TrajectoryTracker(const std::string& load_path, 5 | int dynamics_model_index) { 6 | _env.LoadPolicy(load_path + "./traced_policy.pt"); 7 | _env.LoadEnvData(load_path + "./env_data.bin"); 8 | _env.set_dynamics_model_hyper_parameter(dynamics_model_index); 9 | } 10 | 11 | bool TrajectoryTracker::set_reference_line( 12 | const REFERENCE_LINE& reference_line) { 13 | return _env.set_reference_line(reference_line); 14 | } 15 | 16 | bool TrajectoryTracker::set_dynamics_model_initial_state( 17 | const STATE& init_state) { 18 | return _env.set_dynamics_model_initial_state(init_state); 19 | } 20 | 21 | TRAJECTORY 
TrajectoryTracker::TrackReferenceLine() { 22 | _env.RollOut(); 23 | return _env.get_tracked_trajectory(); 24 | } 25 | 26 | } // namespace drltt 27 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /cicd/start-gitlab-runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | runner_image_name="drltt:cicd" 4 | executor_image_name="drltt:runtime" 5 | 6 | gitlab_url=$1 7 | token=$2 8 | 9 | if [[ -z ${token} ]];then 10 | echo "No gitlab runner token provided." 
    // State at this waypoint.
    optional State state = 1;
    // Action associated with this waypoint.
30 | optional Action action = 2; 31 | } 32 | 33 | // Trajectory 34 | message Trajectory { 35 | // Sequence of waypoint 36 | repeated TrajectoryWaypoint waypoints = 1; 37 | } 38 | -------------------------------------------------------------------------------- /drltt/simulator/environments/trajectory_tracking_env_test.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | 3 | from drltt.common import build_object_within_registry_from_config 4 | from drltt.common.io import load_and_override_configs 5 | from drltt.simulator import TEST_CONFIG_PATHS 6 | from drltt.simulator.environments import ENVIRONMENTS 7 | 8 | 9 | def test_trajectory_tracking_env(): 10 | config = load_and_override_configs(TEST_CONFIG_PATHS) 11 | 12 | env_config = config['environment'] 13 | 14 | env: Env = build_object_within_registry_from_config(ENVIRONMENTS, env_config) 15 | 16 | # observation, extra_info = env.reset() 17 | observation = env.reset() 18 | action = env.dynamics_model_manager.get_sampled_dynamics_model().get_action_space().sample() 19 | # observation, scalar_reward, terminated, truncated, extra_info = env.step(action) 20 | observation, scalar_reward, terminated, extra_info = env.step(action) 21 | 22 | env.get_dynamics_model_info() 23 | 24 | 25 | if __name__ == '__main__': 26 | test_trajectory_tracking_env() 27 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/protobuf_operators_test.cpp: -------------------------------------------------------------------------------- 1 | #include "protobuf_operators.h" 2 | #include 3 | 4 | TEST(BodyStateOperatorTest, BodyStateAdditionTest) { 5 | drltt_proto::BodyState state; 6 | state.set_x(0.); 7 | state.set_y(0.); 8 | state.set_r(0.); 9 | drltt_proto::BodyState state2; 10 | state2.set_x(1.); 11 | state2.set_y(1.); 12 | state2.set_r(1.); 13 | drltt_proto::BodyState state3 = state + state2; 14 | EXPECT_EQ(state3.x(), 1.0); 15 | 
EXPECT_EQ(state3.y(), 1.0); 16 | EXPECT_EQ(state3.r(), 1.0); 17 | } 18 | 19 | TEST(BodyStateOperatorTest, BodyStateMultiplicationTest) { 20 | drltt_proto::BodyState state; 21 | state.set_x(1.); 22 | state.set_y(1.); 23 | state.set_r(1.); 24 | float scalar = 3.0; 25 | drltt_proto::BodyState state2 = state * scalar; 26 | EXPECT_EQ(state2.x(), 3.0); 27 | EXPECT_EQ(state2.y(), 3.0); 28 | EXPECT_EQ(state2.r(), 3.0); 29 | drltt_proto::BodyState state3 = scalar * state; 30 | EXPECT_EQ(state3.x(), 3.0); 31 | EXPECT_EQ(state3.y(), 3.0); 32 | EXPECT_EQ(state3.r(), 3.0); 33 | } 34 | -------------------------------------------------------------------------------- /cicd/README.md: -------------------------------------------------------------------------------- 1 | # DRLTT CI/CD 2 | 3 | DRLTT uses [gitlab-cicd](https://docs.gitlab.com/ee/ci/pipelines/) to build CI/CD. 4 | 5 | ## Build Docker image for CI/CD 6 | 7 | Build Docker images `drltt:cicd` and `drltt:runtime` following the instructions in [Docker instructions](../docker/README.md). 8 | 9 | ## Set-up runner 10 | 11 | Go to `https://${GITLAB_URL}/${USER}/${REPO_NAME}/-/runners/new`, create a new runner (select `Linux`/`Run untagged jobs`), and copy the runner token ${RUNNER_TOKEN}. 12 | 13 | On the runner machine, run `start-gitlab-runner.sh`, and pass the runner token to it. 14 | 15 | ```bash 16 | ./start-gitlab-runner.sh ${GITLAB_URL} ${RUNNER_TOKEN} 17 | ``` 18 | 19 | Cleanup: To stop and clear all runners, run `stop-all-gitlab-runners.sh`. 20 | 21 | ```bash 22 | ./stop-all-gitlab-runners.sh 23 | ``` 24 | 25 | ## Configure Pipelines/Jobs/etc. 26 | 27 | See `.gitlab-ci.yml` for details. 28 | 29 | ## Check CI/CD results 30 | 31 | Go to `https://${GITLAB_URL}/${USER}/${REPO_NAME}/-/pipelines` to see the CI/CD and check out artifacts (generated files). 
32 | -------------------------------------------------------------------------------- /docs/source/rst_files/api-simulator.rst: -------------------------------------------------------------------------------- 1 | APIs: drltt.simulator 2 | ================================= 3 | 4 | drltt.simulator.environments 5 | --------------------------------- 6 | .. automodule:: drltt.simulator.environments 7 | :members: 8 | :special-members: 9 | :private-members: 10 | 11 | drltt.simulator.dynamics_models 12 | --------------------------------- 13 | .. automodule:: drltt.simulator.dynamics_models 14 | :members: 15 | :special-members: 16 | :private-members: 17 | 18 | drltt.simulator.rl_learning 19 | --------------------------------- 20 | .. automodule:: drltt.simulator.rl_learning 21 | :members: 22 | :special-members: 23 | :private-members: 24 | 25 | drltt.simulator.observation 26 | --------------------------------- 27 | .. automodule:: drltt.simulator.observation 28 | :members: 29 | :special-members: 30 | :private-members: 31 | 32 | drltt.simulator.trajectory 33 | --------------------------------- 34 | .. 
#!/bin/bash
# Make `doxygen` available on PATH, downloading the pinned release into
# $HOME/.local/bin when no executable is found.
# The script never calls `exit`, so it stays safe to source.

DOXYGEN_VERSION="1.10.0"

home_bin_dir=$HOME/.local/bin
export PATH=${home_bin_dir}:$PATH

if [[ ! -x $(command -v doxygen) ]];then
    mkdir -p ${home_bin_dir}
    doxygen_binary=${home_bin_dir}/doxygen
    if [[ ! -f ${doxygen_binary} ]];then
        # mktemp creates a unique directory atomically (no openssl needed).
        tmp_dir=$(mktemp -d /tmp/drltt-XXXXXX)
        pushd ${tmp_dir}
            DOXYGEN_TARBALL_NAME=doxygen-${DOXYGEN_VERSION}
            DOXYGEN_TARBALL_FILENAME=${DOXYGEN_TARBALL_NAME}.linux.bin.tar.gz
            DOXYGEN_RELEASE_NAME=Release_$(echo $DOXYGEN_VERSION | sed -r 's/\./_/g')
            DOXYGEN_URL="https://github.com/doxygen/doxygen/releases/download/${DOXYGEN_RELEASE_NAME}/${DOXYGEN_TARBALL_FILENAME}"
            # -f: fail on HTTP errors instead of saving the error page as a tarball.
            # Every step runs only if the previous one succeeded; in particular
            # `mv` is never executed from the wrong directory.
            curl -fOL ${DOXYGEN_URL} && \
            tar -xzvf ${DOXYGEN_TARBALL_FILENAME} && \
            ( \
                cd ${DOXYGEN_TARBALL_NAME} && \
                mv ./bin/* $home_bin_dir/ ; \
            )
        popd
        rm -rf ${tmp_dir}
    fi
fi

# Report what actually happened rather than claiming success unconditionally.
if [[ -x $(command -v doxygen) ]];then
    echo "doxygen installed and setup at $(which doxygen)"
else
    echo "doxygen installation failed: no doxygen executable found on PATH" >&2
fi
-------------------------------------------------------------------------------- /drltt/common/proto/proto_def/compile_proto.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Start compiling protobuf." 4 | 5 | src_dir=$(dirname $0) 6 | python_output_dir=${src_dir}/../proto_gen_py/ 7 | rm -rf ${python_output_dir} 8 | mkdir -p ${python_output_dir} 9 | cpp_output_dir=${src_dir}/../proto_gen_cpp/ 10 | rm -rf ${cpp_output_dir} 11 | mkdir -p ${cpp_output_dir} 12 | doc_output_dir=${src_dir}/../proto_doc_gen/ 13 | rm -rf ${doc_output_dir} 14 | mkdir -p ${doc_output_dir} 15 | 16 | # NOTE: if building the cpp target within docker, protobuf also needs to be compiled within docker. 17 | # otherwise, the cpp compiler will not find the protobuf include files. 18 | # TODO: verify if `--experimental_allow_proto3_optional` is necessary 19 | protoc -I ${src_dir} \ 20 | --python_out ${python_output_dir} \ 21 | --cpp_out ${cpp_output_dir} \ 22 | --doc_out ${doc_output_dir} \ 23 | --experimental_allow_proto3_optional \ 24 | $(find ${src_dir} -name *.proto) 25 | 26 | # create `__init__.py` files for compiled Python package 27 | find ${python_output_dir} -mindepth 1 -type d -exec touch {}/__init__.py \; 28 | 29 | echo "Protobuf compiled." 
30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Yinda Xu] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /drltt/simulator/environments/env_interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from abc import abstractmethod 3 | 4 | import numpy as np 5 | from gym import Env 6 | from gym.spaces import Space 7 | 8 | from drltt_proto.environment.environment_pb2 import Environment 9 | 10 | 11 | class CustomizedEnvInterface: 12 | """Customized interface for extending gym for DRLTT.""" 13 | 14 | env_info: Environment 15 | 16 | @abstractmethod 17 | def export_environment_data(self) -> Environment: 18 | """Export environment data. 19 | 20 | Return: 21 | Environment: Environment data in proto structure. 22 | """ 23 | env_data = Environment() 24 | env_data.CopyFrom(self.env_info) 25 | 26 | return env_data 27 | 28 | @abstractmethod 29 | def get_state(self) -> np.ndarray: 30 | """Get the underlying state. 31 | 32 | Returns: 33 | np.ndarray: Vectorized underlying state. 34 | """ 35 | raise NotImplementedError 36 | 37 | 38 | class ExtendedGymEnv(Env, CustomizedEnvInterface): 39 | """The type object for the extended gym environment.""" 40 | 41 | pass 42 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # DRLTT TEST 2 | 3 | ## DRLTT Python Test 4 | 5 | Python testing is done with *pytest*. To launch the Python testing, run `./test/test-python.sh`: 6 | 7 | .. literalinclude:: ../../../test/test-python.sh 8 | :language: bash 9 | 10 | ## DRLTT CPP Test 11 | 12 | CPP testing is performed through *gtest* immediately after building. To launch the CPP testing, run `./test/test-cpp.sh`: 13 | 14 | .. literalinclude:: ../../../test/test-cpp.sh 15 | :language: bash 16 | 17 | Please refer to *DRLTT SDK* for details. 18 | 19 | #### Accelerating CPP testing 20 | 21 | To skip SDK exporting (e.g. 
while debugging the test running), run: 22 | 23 | ```bash 24 | ./test-cpp.sh test 25 | ``` 26 | 27 | To skip both SDK exporting and checkpoint generation (e.g. while debugging the test building), run: 28 | 29 | ```bash 30 | ./test-cpp.sh test reuse-checkpoint 31 | ``` 32 | 33 | To use a sample config with a shorter time for test data generation (a dummy training), run: 34 | 35 | ```bash 36 | ./test-cpp.sh fast test 37 | ``` 38 | 39 | TODO: refactor argument parsing logic in test scripts. 40 | 41 | ## DRLTT Documentation Test 42 | 43 | To test the Documentation, run `./test/test-doc.sh`: 44 | -------------------------------------------------------------------------------- /install-setup-protoc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # - release v21.12, 2022/12/14 4 | # - https://github.com/protocolbuffers/protobuf/releases/tag/v21.12 5 | # - tag v3.21.12, 2022/12/13 6 | # - https://github.com/protocolbuffers/protobuf/releases/tag/v3.21.12 7 | # - installation instruction 8 | # - https://google.github.io/proto-lens/installing-protoc.html 9 | 10 | protobuf_release_version="21.12" 11 | proto_release_filename=protoc-${protobuf_release_version}-linux-x86_64.zip 12 | 13 | home_bin_dir=$HOME/.local/bin 14 | export PATH=${home_bin_dir}:$PATH 15 | 16 | if [[ ! -x $(command -v protoc) ]];then 17 | mkdir -p ${home_bin_dir} 18 | protoc_binary=${home_bin_dir}/protoc 19 | if [[ ! 
-f ${protoc_binary} ]];then 20 | tmp_dir=/tmp/drltt-$(openssl rand -hex 6) 21 | mkdir ${tmp_dir} 22 | pushd ${tmp_dir} 23 | curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${protobuf_release_version}/${proto_release_filename} 24 | unzip ${proto_release_filename} -d protobuf-release/ 25 | mv protobuf-release/bin/protoc ${home_bin_dir} 26 | popd 27 | rm -rf ${tmp_dir} 28 | fi 29 | fi 30 | 31 | echo "protoc installed and setup at $(which protoc)" 32 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/dynamics_models/base_dynamics_model.cpp: -------------------------------------------------------------------------------- 1 | #include "base_dynamics_model.h" 2 | 3 | namespace drltt { 4 | 5 | BaseDynamicsModel::BaseDynamicsModel( 6 | const drltt_proto::HyperParameter& hyper_parameter) { 7 | _hyper_parameter.CopyFrom(hyper_parameter); 8 | this->parse_hyper_parameter(); 9 | } 10 | 11 | BaseDynamicsModel::BaseDynamicsModel( 12 | const drltt_proto::HyperParameter& hyper_parameter, 13 | const drltt_proto::State& init_state) 14 | : BaseDynamicsModel(hyper_parameter) { 15 | Reset(init_state); 16 | } 17 | 18 | void BaseDynamicsModel::Reset(const drltt_proto::State& state) { 19 | _state.CopyFrom(state); 20 | } 21 | 22 | drltt_proto::State BaseDynamicsModel::get_state() const { 23 | return _state; 24 | } 25 | 26 | bool BaseDynamicsModel::set_state(const drltt_proto::State& state) { 27 | _state.CopyFrom(state); 28 | return true; 29 | } 30 | 31 | drltt_proto::HyperParameter BaseDynamicsModel::get_hyper_parameter() const { 32 | return _hyper_parameter; 33 | } 34 | 35 | bool BaseDynamicsModel::set_hyper_parameter( 36 | const drltt_proto::HyperParameter& hyper_parameter) { 37 | _hyper_parameter.CopyFrom(hyper_parameter); 38 | return true; 39 | } 40 | 41 | } // namespace drltt 42 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/managers/observation_manager.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "drltt-sdk/common/common.h" 5 | #include "drltt-sdk/dynamics_models/base_dynamics_model.h" 6 | #include "drltt_proto/dynamics_model/state.pb.h" 7 | #include "drltt_proto/trajectory/trajectory.pb.h" 8 | 9 | namespace drltt { 10 | 11 | class ObservationManager { 12 | public: 13 | ObservationManager() = default; 14 | ~ObservationManager() {}; 15 | bool Reset(drltt_proto::ReferenceLine* reference_line_ptr, 16 | BaseDynamicsModel* dynamics_model_ptr); 17 | // TODO: remove index and window 18 | bool get_observation(const drltt_proto::BodyState& body_state, 19 | int start_index, int tracking_length, 20 | int n_observation_steps, 21 | std::vector* observation); 22 | 23 | private: 24 | bool get_reference_line_observation(const drltt_proto::BodyState& body_state, 25 | int start_index, int tracking_length, 26 | int n_observation_steps, 27 | std::vector* observation); 28 | const drltt_proto::ReferenceLine* _reference_line_ptr; 29 | const BaseDynamicsModel* _dynamics_model_ptr; 30 | }; 31 | 32 | } // namespace drltt 33 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/io.cpp: -------------------------------------------------------------------------------- 1 | #include "io.h" 2 | 3 | namespace drltt { 4 | 5 | drltt_proto::DebugInfo global_debug_info = drltt_proto::DebugInfo(); 6 | 7 | torch::Tensor parse_tensor_proto_to_torch_tensor( 8 | const drltt_proto::TensorFP& tensor_proto) { 9 | std::vector shape_vec(tensor_proto.shape().begin(), 10 | tensor_proto.shape().end()); 11 | // TODO: verify type 12 | std::vector data_vec(tensor_proto.data().begin(), 13 | tensor_proto.data().end()); 14 | 15 | // TODO: remove copy by using RVO and std::move 16 | // NOTE: only data_vec.data() pointer copied. data need to be copied 17 | // otherwise. 
18 | torch::Tensor parsed_tensor = 19 | torch::from_blob(data_vec.data(), shape_vec, torch::kFloat32); 20 | 21 | // TODO: remove copy by using RVO and std::move 22 | return parsed_tensor.clone(); 23 | } 24 | 25 | bool convert_tensor_to_vector(const torch::Tensor& tensor, 26 | std::vector* vector) { 27 | auto flattened_tensor = tensor.view({tensor.numel()}); 28 | vector->reserve(flattened_tensor.numel()); 29 | vector->assign(flattened_tensor.data_ptr(), 30 | flattened_tensor.data_ptr() + flattened_tensor.numel()); 31 | return true; 32 | } 33 | 34 | } // namespace drltt 35 | -------------------------------------------------------------------------------- /sdk/compile-source.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: This script runs in docker container. 4 | # NOTE: Current directory is `${SDK_ROOT_DIR}/build`. 5 | 6 | set -exo pipefail 7 | 8 | # shared libraries 9 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$USR_LIB_DIR 10 | 11 | # TODO figure out root cause and resolve it in a more formal way 12 | # https://stackoverflow.com/questions/19901934/libpthread-so-0-error-adding-symbols-dso-missing-from-command-line 13 | export LDFLAGS="-Wl,--copy-dt-needed-entries" 14 | 15 | # Check version and location of Protobuf compiler 16 | echo "Using protoc $(protoc --version) at $(which protoc)" 17 | 18 | # Clear existing compiled files 19 | rm -rf ${BUILD_DIR} 20 | mkdir -p ${BUILD_DIR} 21 | rm -rf ${PROTO_GEN_DIR} 22 | mkdir -p ${PROTO_GEN_DIR} 23 | 24 | # Configure and build 25 | pushd ${BUILD_DIR} 26 | cmake .. \ 27 | -DBUILD_TESTS=ON \ 28 | -DREPO_ROOT_DIR=${REPO_ROOT_DIR} \ 29 | -DMACRO_CHECKPOINT_DIR=${CHECKPOINT_DIR} \ 30 | -DLIBTORCH_DIR=${LIBTORCH_DIR} \ 31 | && make -j$(nproc --all) 2>&1 | tee ./build.log 32 | cmake_ret_val=$? 
33 | if [ ${cmake_ret_val} != 0 ]; then 34 | echo "cmake failed with exit ${cmake_ret_val}" 35 | exit ${cmake_ret_val} 36 | fi 37 | ctest -VV --rerun-failed --output-on-failure 2>&1 | tee ./test.log 38 | popd 39 | set +exo pipefail 40 | -------------------------------------------------------------------------------- /drltt/simulator/rl_learning/sb3_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Callable 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | 6 | from drltt.simulator.environments.env_interface import ExtendedGymEnv 7 | 8 | 9 | def roll_out_one_episode( 10 | environment: ExtendedGymEnv, 11 | policy_func: Callable, 12 | **kwargs, 13 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: 14 | """Roll out one episode and return a trajectory. 15 | 16 | Args: 17 | environment: The associated environment. 18 | policy_func: The policy function, observation -> action. 19 | 20 | Returns: 21 | List[np.ndarray]: States. 22 | List[np.ndarray]: Actions. 23 | List[np.ndarray]: Observations. 24 | """ 25 | states = list() 26 | actions = list() 27 | observations = list() 28 | 29 | obs = environment.reset(**kwargs) 30 | state = environment.get_state() 31 | 32 | done = False 33 | while not done: 34 | action = policy_func(obs) 35 | # collect data 36 | states.append(deepcopy(state)) 37 | actions.append(deepcopy(action)) 38 | observations.append(deepcopy(obs)) 39 | # step environment 40 | obs, reward, done, info = environment.step(action) 41 | state = environment.get_state() 42 | 43 | return states, actions, observations 44 | -------------------------------------------------------------------------------- /sdk/export-py-sdk.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: current directory is `./sdk/build` 4 | 5 | echo "Exporting Python SDK at ${export_dir}..." 
6 | 7 | project_name=${PROJECT_NAME} 8 | package_name=$(echo ${project_name}-py | sed -r 's/-/_/g') 9 | export_dir=${BUILD_DIR}/${package_name} 10 | mkdir -p $export_dir 11 | 12 | # export dependency shared library. 13 | # TODO: consider a more elegant way, like packaging 14 | libtorch_lib_dir=${LIBTORCH_DIR}/lib 15 | cp -r ${USR_LIB_DIR} $export_dir/ 16 | echo "User lib size: $(du -sh $USR_LIB_DIR)" 17 | if [[ ! -v PY_SDK_NO_LIBTORCH_EXPORTED ]];then 18 | cp -r ${libtorch_lib_dir} $export_dir/ 19 | echo "libtorch lib size: $(du -sh $libtorch_lib_dir)" 20 | fi 21 | echo "Total exported lib size: $(du -sh $export_dir/lib)" 22 | 23 | # export sdk shared library 24 | sdk_so=$(ls ${BUILD_DIR}/${project_name}/trajectory_tracker/trajectory_tracker_*.so|head -n 1) 25 | cp $sdk_so $export_dir/ 26 | pushd $export_dir/ 27 | ln -sf $(basename $sdk_so) export_symbols.so 28 | popd 29 | 30 | # export assets 31 | cp -r ${CHECKPOINT_DIR} $export_dir/ 32 | cp -r assets/exported-python-sdk/* $export_dir/ 33 | cp -r /proto/proto_gen_py $export_dir/ 34 | 35 | # package into tarball 36 | # TODO: move to a more formal packaging way 37 | pushd ${BUILD_DIR} 38 | tar -czf ./${package_name}.tar.gz ${package_name}/ 39 | echo "library packed: ${package_name}.tar.gz" 40 | popd 41 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/protobuf_operators.cpp: -------------------------------------------------------------------------------- 1 | #include "protobuf_operators.h" 2 | 3 | drltt_proto::BodyState operator+(drltt_proto::BodyState lhs, 4 | const drltt_proto::BodyState& rhs) { 5 | lhs.set_x(lhs.x() + rhs.x()); 6 | lhs.set_y(lhs.y() + rhs.y()); 7 | lhs.set_r(normalize_angle(lhs.r() + rhs.r())); 8 | 9 | return lhs; 10 | } 11 | 12 | drltt_proto::BodyState operator*(drltt_proto::BodyState lhs, float rhs) { 13 | lhs.set_x(lhs.x() * rhs); 14 | lhs.set_y(lhs.y() * rhs); 15 | lhs.set_r(lhs.r() * rhs); 16 | 17 | return lhs; 18 | } 19 | 20 | 
drltt_proto::BodyState operator*(float lhs, drltt_proto::BodyState rhs) { 21 | return rhs * lhs; 22 | } 23 | 24 | drltt_proto::BicycleModelState operator+( 25 | drltt_proto::BicycleModelState lhs, 26 | const drltt_proto::BicycleModelState& rhs) { 27 | lhs.mutable_body_state()->CopyFrom(lhs.body_state() + rhs.body_state()); 28 | lhs.set_v(lhs.v() + rhs.v()); 29 | 30 | return lhs; 31 | } 32 | 33 | drltt_proto::BicycleModelState operator*(drltt_proto::BicycleModelState lhs, 34 | float rhs) { 35 | lhs.mutable_body_state()->CopyFrom(lhs.body_state() * rhs); 36 | lhs.set_v(lhs.v() * rhs); 37 | 38 | return lhs; 39 | } 40 | 41 | drltt_proto::BicycleModelState operator*(float lhs, 42 | drltt_proto::BicycleModelState rhs) { 43 | return rhs * lhs; 44 | } 45 | -------------------------------------------------------------------------------- /drltt/simulator/visualization/visualize_trajectory_tracking_episode_test.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | 3 | from drltt.common import build_object_within_registry_from_config 4 | from drltt.common.io import load_and_override_configs, generate_random_string 5 | from drltt.simulator import TEST_CONFIG_PATHS 6 | from drltt.simulator.environments import ENVIRONMENTS 7 | from drltt.simulator.rl_learning.sb3_learner import build_sb3_algorithm_from_config 8 | from drltt.simulator.rl_learning.sb3_utils import roll_out_one_episode 9 | from drltt.simulator.visualization.visualize_trajectory_tracking_episode import visualize_trajectory_tracking_episode 10 | 11 | from drltt_proto.environment.environment_pb2 import Environment 12 | 13 | 14 | def test_visualize_trajectory_tracking_episode(): 15 | config = load_and_override_configs(TEST_CONFIG_PATHS) 16 | 17 | env_config = config['environment'] 18 | environment: Env = build_object_within_registry_from_config(ENVIRONMENTS, env_config) 19 | algorithm_config = config['algorithm'] 20 | algorithm = 
build_sb3_algorithm_from_config(environment, algorithm_config) 21 | 22 | roll_out_one_episode(environment, lambda obs: algorithm.predict(obs)[0]) 23 | env_data: Environment = environment.export_environment_data() 24 | 25 | viz_prefix = f'/tmp/drltt-pytest-{generate_random_string(6)}' 26 | visualization_function = visualize_trajectory_tracking_episode 27 | visualization_function(env_data, viz_prefix) 28 | 29 | 30 | if __name__ == '__main__': 31 | test_visualize_trajectory_tracking_episode() 32 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/dynamics_model/hyper_parameter.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package drltt_proto; 4 | 5 | // TODO: create a subpackage for dynamics model 6 | 7 | // TODO: add name field: LongVehicle/Truck/LongVehicle 8 | message HyperParameter { 9 | // The class name of the dynamics model. 10 | optional string type= 1; 11 | // The name of dynamics model. E.g. short vehicle, long vehicle, etc. 12 | optional string name= 2; 13 | // Bicycle model's hyper parameter. 14 | optional BicycleModelHyperParameter bicycle_model = 3; 15 | } 16 | 17 | message BicycleModelHyperParameter { 18 | // Vehicle length in [m]. 19 | optional float length = 1; 20 | // Distance in [m] between vehicle front and front axle. 21 | optional float front_overhang = 2; 22 | // Distance in [m] between vehicle rear and rear axle. 23 | optional float rear_overhang = 3; 24 | // Distance in [m] between front axle and rear axle. 25 | optional float wheelbase = 4; 26 | // Vehicle width in [m]. 27 | optional float width = 5; 28 | // Distance in [m] between front axle and center-of-gravity (CoG). 29 | optional float frontwheel_to_cog = 6; 30 | // Distance in [m] between rear axle and center-of-gravity (CoG). 31 | optional float rearwheel_to_cog = 7; 32 | // Action space's lower bound. 
33 | repeated float action_space_lb = 8; 34 | // Action space's upper bound. 35 | repeated float action_space_ub = 9; 36 | // Maximum lateral acceleration. 37 | optional float max_lat_acc = 10; 38 | } 39 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DRL-based Trajectory Tracking (DRLTT) documentation master file, created by 2 | sphinx-quickstart on Sun Feb 18 09:17:22 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to DRL-based Trajectory Tracking (DRLTT)'s documentation! 7 | ================================================================= 8 | 9 | Code repository: https://github.com/MARMOTatZJU/drl-based-trajectory-tracking 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | :caption: Introduction 14 | 15 | rst_files/readme 16 | 17 | .. toctree:: 18 | :maxdepth: 3 19 | :caption: SDK 20 | 21 | rst_files/readme-sdk 22 | 23 | .. toctree:: 24 | :maxdepth: 3 25 | :caption: Test 26 | 27 | rst_files/readme-test 28 | 29 | .. toctree:: 30 | :maxdepth: 3 31 | :caption: Documentation Guide 32 | 33 | rst_files/readme-docs 34 | 35 | .. toctree:: 36 | :maxdepth: 3 37 | :caption: Docker 38 | 39 | rst_files/readme-docker 40 | 41 | 42 | .. Indices and tables 43 | .. ================== 44 | 45 | .. * :ref:`genindex` 46 | .. * :ref:`modindex` 47 | .. * :ref:`search` 48 | 49 | 50 | API 51 | ================== 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | :caption: APIs: simulator 56 | 57 | rst_files/api-simulator 58 | 59 | .. toctree:: 60 | :maxdepth: 2 61 | :caption: APIs: common 62 | 63 | rst_files/api-common 64 | 65 | .. toctree:: 66 | :maxdepth: 2 67 | :caption: APIs: SDK 68 | 69 | rst_files/api-sdk 70 | 71 | .. 
toctree:: 72 | :maxdepth: 2 73 | :caption: Protobuf Definition 74 | 75 | rst_files/protobuf 76 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory_tracker/trajectory_tracker_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from drltt.simulator import TEST_CHECKPOINT_DIR 4 | from drltt.simulator.trajectory_tracker import TrajectoryTracker 5 | 6 | 7 | def test_trajectory_tracker_print(): 8 | trajectory_tracker = TrajectoryTracker(checkpoint_dir=TEST_CHECKPOINT_DIR) 9 | trajectory_tracker.get_dynamics_model_info() 10 | 11 | 12 | def test_trajectory_tracker_random_reference_line(): 13 | trajectory_tracker = TrajectoryTracker(checkpoint_dir=TEST_CHECKPOINT_DIR) 14 | states, actions = trajectory_tracker.track_reference_line() 15 | reference_line = trajectory_tracker.get_reference_line() 16 | 17 | assert len(states) == len(actions) == len(reference_line) 18 | 19 | 20 | def test_trajectory_tracker(): 21 | trajectory_tracker = TrajectoryTracker(checkpoint_dir=TEST_CHECKPOINT_DIR) 22 | 23 | init_v = 5.0 24 | init_r = 0.0 25 | init_state = (0.0, 0.0, init_r, init_v) 26 | tracking_length = 50 27 | step_interval = trajectory_tracker.get_step_interval() 28 | reference_line = [ 29 | ( 30 | np.cos(init_r) * init_v * step_interval * step_index, 31 | np.sin(init_r) * init_v * step_interval * step_index, 32 | ) 33 | for step_index in range(tracking_length) 34 | ] 35 | states, actions = trajectory_tracker.track_reference_line( 36 | init_state=init_state, 37 | reference_line=reference_line, 38 | ) 39 | 40 | assert len(states) == len(reference_line) 41 | assert len(actions) == len(reference_line) 42 | 43 | 44 | if __name__ == '__main__': 45 | test_trajectory_tracker() 46 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/trajectory_tracker/trajectory_tracker_pybind_export.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "trajectory_tracker.h" 5 | 6 | // TODO: export class directly 7 | // reference: https://pybind11.readthedocs.io/en/stable/classes.html 8 | // TODO: add unit test for exported library 9 | namespace py = pybind11; 10 | using namespace drltt; 11 | 12 | bool TrajectoryTrackerSetReferenceLine(TrajectoryTracker& trajectory_tracker, 13 | const REFERENCE_LINE& reference_line) { 14 | return trajectory_tracker.set_reference_line(reference_line); 15 | } 16 | 17 | bool TrajectoryTrackerSetDynamicsModelInitialState( 18 | TrajectoryTracker& trajectory_tracker, const STATE& init_state) { 19 | return trajectory_tracker.set_dynamics_model_initial_state(init_state); 20 | } 21 | TRAJECTORY TrajectoryTrackerTrackReferenceLine( 22 | TrajectoryTracker& trajectory_tracker) { 23 | return trajectory_tracker.TrackReferenceLine(); 24 | } 25 | 26 | // Reference: https://pybind11.readthedocs.io/en/stable/classes.html 27 | PYBIND11_MODULE(export_symbols, m) { 28 | m.doc() = "DRL-based Trajectory Tracking (DRLTT)"; 29 | py::class_(m, "TrajectoryTracker") 30 | .def(py::init<>()) 31 | .def(py::init()); 32 | m.def("trajectory_tracker_set_reference_line", 33 | &TrajectoryTrackerSetReferenceLine); 34 | m.def("trajectory_tracker_set_dynamics_model_initial_state", 35 | &TrajectoryTrackerSetDynamicsModelInitialState); 36 | m.def("trajectory_tracker_track_reference_line", 37 | &TrajectoryTrackerTrackReferenceLine); 38 | } 39 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/inference/policy_inference_test.cpp: -------------------------------------------------------------------------------- 1 | #include "policy_inference.h" 2 | #include 3 | #include "common/io.h" 4 | #include "drltt_proto/sdk/exported_policy_test_case.pb.h" 5 | 6 | using namespace drltt; 7 | 8 | TEST(PolicyInferenceTest, ForwardTest) { 9 | const std::string 
checkpoint_dir = MACRO_CHECKPOINT_DIR; 10 | const std::string module_path = checkpoint_dir + "/traced_policy.pt"; 11 | const std::string test_cases_path = 12 | checkpoint_dir + "/traced_policy_test_cases.bin"; 13 | 14 | // load test case data 15 | drltt_proto::ExportedPolicyTestCases test_cases_proto; 16 | parse_proto_from_file(test_cases_proto, test_cases_path); 17 | torch::Tensor gt_observations_tensor = 18 | parse_tensor_proto_to_torch_tensor(test_cases_proto.observations()); 19 | torch::Tensor gt_actions_tensor = 20 | parse_tensor_proto_to_torch_tensor(test_cases_proto.actions()); 21 | 22 | // perform inference 23 | TorchJITModulePolicy policy; 24 | policy.Load(module_path); 25 | torch::Tensor jit_actions_tensor = policy.Infer(gt_observations_tensor); 26 | 27 | // check result 28 | const float atol = 1e-5; 29 | const float rtol = 1e-3; 30 | const bool all_close = 31 | torch::allclose(jit_actions_tensor, gt_actions_tensor, atol, rtol); 32 | EXPECT_TRUE(all_close); 33 | torch::Tensor isclose = 34 | torch::isclose(jit_actions_tensor, gt_actions_tensor, atol, rtol); 35 | const float isclose_ratio = 36 | static_cast(torch::sum(isclose).item()) / 37 | static_cast(isclose.numel()); 38 | const float isclose_ratio_thres = 0.95; 39 | EXPECT_GT(isclose_ratio, isclose_ratio_thres); 40 | } 41 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/inference/policy_inference.cpp: -------------------------------------------------------------------------------- 1 | #include "policy_inference.h" 2 | 3 | namespace drltt { 4 | 5 | bool TorchJITModulePolicy::Load(const std::string& jit_module_path) { 6 | _module = torch::jit::load(jit_module_path); 7 | return true; 8 | } 9 | 10 | torch::Tensor TorchJITModulePolicy::Infer( 11 | const torch::Tensor& observations_tensor) { 12 | // TODO: use RVO and std::move to reduce copy 13 | std::vector jit_inputs; 14 | jit_inputs.push_back(observations_tensor); 15 | torch::Tensor jit_actions_tensor = 
_module.forward(jit_inputs).toTensor(); 16 | 17 | return jit_actions_tensor; 18 | } 19 | 20 | // TODO: use RVO and std::move to reduce copy 21 | std::vector TorchJITModulePolicy::Infer( 22 | const std::vector& observations_vec, 23 | const std::initializer_list& shape_vec) { 24 | std::vector observations_vec_copied(observations_vec.begin(), 25 | observations_vec.end()); 26 | torch::Tensor observation_tensor = 27 | torch::from_blob(observations_vec_copied.data(), shape_vec, 28 | torch::kFloat32) 29 | .view(shape_vec); 30 | torch::Tensor jit_actions_tensor = Infer(observation_tensor); 31 | std::vector actions_vec; 32 | convert_tensor_to_vector(jit_actions_tensor, &actions_vec); 33 | return actions_vec; 34 | } 35 | 36 | // TODO: use RVO and std::move to reduce copy 37 | std::vector TorchJITModulePolicy::Infer( 38 | const std::vector& observations_vec) { 39 | return Infer(observations_vec, 40 | {1, static_cast(observations_vec.size())}); 41 | } 42 | 43 | } // namespace drltt 44 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/io.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "drltt_proto/dynamics_model/basics.pb.h" 8 | #include "drltt_proto/sdk/exported_policy_test_case.pb.h" 9 | 10 | namespace drltt { 11 | 12 | extern drltt_proto::DebugInfo global_debug_info; 13 | 14 | /** 15 | * @brief Parse tensor in proto to torch Tensor. 16 | * TODO: use RVO and std::move to avoid copy 17 | * 18 | * @param tensor_proto Tensor in proto. 19 | * @return torch::Tensor Parsed torch Tensor. 20 | */ 21 | torch::Tensor parse_tensor_proto_to_torch_tensor( 22 | const drltt_proto::TensorFP& tensor_proto); 23 | 24 | /** 25 | * @brief Parse protobuf message from binary file. 26 | * 27 | * @tparam T The type of protobuf message. 28 | * @param proto_msg Protobuf message. 29 | * @param proto_path Path to the binary file. 
30 | * @return true Parsing succeeded. 31 | * @return false Parsing failed. 32 | */ 33 | template 34 | bool parse_proto_from_file(T& proto_msg, const std::string& proto_path) { 35 | std::fstream input(proto_path, std::ios::in | std::ios::binary); 36 | if (!input) { 37 | std::cerr << proto_path << " not found!!!" << std::endl; 38 | return false; 39 | } else if (!proto_msg.ParseFromIstream(&input)) { 40 | std::cerr << "Parsing error." << std::endl; 41 | return false; 42 | } 43 | input.close(); 44 | return true; 45 | } 46 | 47 | /** 48 | * @brief Convert tensor to std::vector. 49 | * 50 | * @param tensor Source tensor. 51 | * @param vector Vector to store converted data. 52 | * @return true Conversion succeeded. 53 | * @return false Conversion failed. 54 | */ 55 | bool convert_tensor_to_vector(const torch::Tensor& tensor, 56 | std::vector* vector); 57 | 58 | } // namespace drltt 59 | -------------------------------------------------------------------------------- /drltt/simulator/visualization/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | def get_subjective_brightness(pixels: np.ndarray) -> np.ndarray: 7 | """Get subjective brightness. 8 | Reference: https://computergraphics.stackexchange.com/questions/5085/light-intensity-of-an-rgb-value 9 | 10 | Args: 11 | pixels: Input pixels, shape=[..., 3], channel_order=. 12 | 13 | Returns: 14 | np.ndarray: Brightness in [0.0, 1.0]. 15 | """ 16 | return np.round(0.21 * pixels[..., 0] + 0.72 * pixels[..., 1] + 0.07 * pixels[..., 2]) / 255.0 17 | 18 | 19 | def scale_xy_lim( 20 | xy_lim: Tuple[Tuple[float, float], Tuple[float, float]], ratio: float 21 | ) -> Tuple[Tuple[float, float], Tuple[float, float]]: 22 | """Scale a pair of limits on x/y-axis with a scaling ratio. 23 | 24 | Args: 25 | xy_lim (Tuple[Tuple[float, float], Tuple[float, float]]): Limits on x/y-axis to be scaled, format=. 
26 | ratio (float): Scaling ratio. 27 | 28 | Returns: 29 | Tuple[Tuple[float, float], Tuple[float, float]]: Scaled limits on x/y-axis. 30 | """ 31 | x_lim, y_lim = xy_lim 32 | scaled_xlim = scale_axe_lim(x_lim, ratio) 33 | scaled_ylim = scale_axe_lim(y_lim, ratio) 34 | 35 | return (scaled_xlim, scaled_ylim) 36 | 37 | 38 | def scale_axe_lim(axe_lim: Tuple[float, float], ratio: float) -> Tuple[float, float]: 39 | """Scale an axe limit. 40 | 41 | Args: 42 | axe_lim (Tuple[float, float]): Axe limit to be scaled. 43 | ratio (float): Scaling ratio. 44 | 45 | Returns: 46 | Tuple[float, float]: Scaled axe limit. 47 | """ 48 | mid = (axe_lim[0] + axe_lim[1]) / 2 49 | scaled_lb = mid + (axe_lim[0] - mid) * ratio 50 | scaled_ub = mid + (axe_lim[1] - mid) * ratio 51 | 52 | return (scaled_lb, scaled_ub) 53 | -------------------------------------------------------------------------------- /sdk/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | project(drltt-sdk) 3 | set(SDK_ROOT_DIR ${PROJECT_SOURCE_DIR}) 4 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 5 | 6 | option(BUILD_TESTS "Build tests." 
OFF) 7 | add_definitions(-DREPO_ROOT_DIR="$(git rev-parse --show-toplevel)") 8 | add_definitions(-DMACRO_CHECKPOINT_DIR="${REPO_ROOT_DIR}/work_dir/track-test/checkpoint") 9 | add_definitions(-DLIBTORCH_DIR="/libtorch") 10 | 11 | message("Using checkpoint at ${MACRO_CHECKPOINT_DIR}.") 12 | 13 | set(CMAKE_CXX_FLAGS "-O2 -Werror -fPIC -std=c++20 -march=native -ftree-vectorize") 14 | 15 | # TODO resolve it in more formal way 16 | # https://github.com/protocolbuffers/protobuf/issues/14500 17 | # set(CMAKE_MODULE_LINKER_FLAGS "-Wl,--copy-dt-needed-entries") 18 | # set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--copy-dt-needed-entries") 19 | # set(CMAKE_STATIC_LINKER_FLAGS "-Wl,--copy-dt-needed-entries") 20 | 21 | # libtorch 22 | set(LIBTORCH_CMAKE_PREFIX_PATH "${LIBTORCH_DIR}/share/cmake/") 23 | list(APPEND CMAKE_PREFIX_PATH ${LIBTORCH_CMAKE_PREFIX_PATH}) 24 | find_package(Torch REQUIRED) 25 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 26 | include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) 27 | 28 | # gtest 29 | if(BUILD_TESTS) 30 | find_package(GTest REQUIRED) 31 | include_directories(${GTEST_INCLUDE_DIR}) 32 | include(GoogleTest) 33 | enable_testing() 34 | endif(BUILD_TESTS) 35 | 36 | # protobuf 37 | set(PROTO_GENERATE_PATH ${SDK_ROOT_DIR}/proto_gen) 38 | add_subdirectory(${REPO_ROOT_DIR}/drltt/common/proto/proto_def proto_def) 39 | include_directories(${PROTO_GENERATE_PATH}) 40 | 41 | # setup pybind11 42 | execute_process(COMMAND "which python" OUTPUT_VARIABLE PYTHON_PATH) 43 | set(PYTHON_EXECUTABLE ${PYTHON_PATH}) 44 | set(PYBIND11_FINDPYTHON ON) 45 | find_package(pybind11 REQUIRED) 46 | include_directories(${PYBIND11_INCLUDE_DIRS}) 47 | 48 | 49 | add_subdirectory(drltt-sdk) 50 | 51 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/inference/policy_inference.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file policy_inference.h 3 | * @brief Policy inference 
based on LibTorch. 4 | * 5 | */ 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include "drltt-sdk/common/common.h" 11 | 12 | namespace drltt { 13 | 14 | // TODO: logging 15 | /** 16 | * @brief A policy based on torch JIT module. 17 | * 18 | */ 19 | class TorchJITModulePolicy { 20 | public: 21 | TorchJITModulePolicy() = default; 22 | ~TorchJITModulePolicy() {}; 23 | /** 24 | * @brief Load a JIT module from a specified path. 25 | * 26 | * @param jit_module_path Path to JIT module. 27 | * @return true Module loading succeeded. 28 | * @return false Module loading failed. 29 | */ 30 | bool Load(const std::string& jit_module_path); 31 | /** 32 | * @brief Perform inference. 33 | * 34 | * @param observations_tensor The torch tensor of observations, 35 | * Shape={batch_size, observation_dim}. 36 | * @return torch::Tensor The torch tensor of actions. Shape={batch_size, 37 | * action_dim} 38 | */ 39 | torch::Tensor Infer(const torch::Tensor& observations_tensor); 40 | /** 41 | * @brief Perform inference. 42 | * 43 | * @param observations_vec Vector of observation data. 44 | * @param shape_vec Vector of the shape of observation tensor. 45 | * @return std::vector , size=batch_size*action_dim 46 | */ 47 | std::vector Infer(const std::vector& observations_vec, 48 | const std::initializer_list& shape_vec); 49 | /** 50 | * @brief Perform inference. 51 | * 52 | * @param observations_vec Vector of observation data. Assuming batch_size=1, 53 | * i.e. original shape={1, observation_dim} 54 | * @return std::vector , size=batch_size*action_dim. 
55 | */ 56 | std::vector Infer(const std::vector& observations_vec); 57 | 58 | protected: 59 | torch::jit::script::Module _module; 60 | }; 61 | } // namespace drltt 62 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/geometry.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "drltt_proto/dynamics_model/state.pb.h" 6 | 7 | // TODO: import from: https://github.com/ros/angles 8 | 9 | /** 10 | * Normalize a scalar angle to [0, 2*pi). 11 | * Source: http://docs.ros.org/en/indigo/api/angles/html/angles_8h_source.html, 12 | * L68 13 | * @param angle Input angle. 14 | * @return Normalized angle. 15 | */ 16 | static inline double normalize_angle_positive(double angle) { 17 | return fmod(fmod(angle, 2.0 * M_PI) + 2.0 * M_PI, 2.0 * M_PI); 18 | } 19 | 20 | /** 21 | * Normalize a scalar angle to [-pi, pi). 22 | * Source: http://docs.ros.org/en/indigo/api/angles/html/angles_8h_source.html, 23 | * L81 24 | * @param angle Input angle. 25 | * @return Normalized angle. 26 | */ 27 | static inline double normalize_angle(double angle) { 28 | double a = normalize_angle_positive(angle); 29 | if (a >= M_PI) 30 | a -= 2.0 * M_PI; 31 | return a; 32 | } 33 | 34 | /** 35 | * @brief Transfer a SO(2) state from the world frame to the body frame. 36 | * TODO: referenceline waypoint: move to body state 37 | * TODO: resolve c++ function naming issue 38 | * TODO: unit test 39 | * 40 | * @param body_state 41 | * @param state State to be transformed. 
42 | */ 43 | static inline void transform_to_local_from_world( 44 | const drltt_proto::BodyState& body_state, drltt_proto::BodyState* state) { 45 | const float x = body_state.x(); 46 | const float y = body_state.y(); 47 | const float r = body_state.r(); 48 | 49 | const float transformed_x = std::cos(r) * state->x() + 50 | std::sin(r) * state->y() - x * std::cos(r) - 51 | y * std::sin(r); 52 | const float transformed_y = -std::sin(r) * state->x() + 53 | std::cos(r) * state->y() + x * std::sin(r) - 54 | y * std::cos(r); 55 | const float transformed_r = state->r() - r; 56 | 57 | state->set_x(transformed_x); 58 | state->set_y(transformed_y); 59 | state->set_r(transformed_r); 60 | } 61 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory/reference_line_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from drltt.simulator.dynamics_models import BicycleModel 4 | from drltt.simulator.trajectory.random_walk import random_walk 5 | from drltt.simulator.trajectory.reference_line import ReferenceLineManager 6 | 7 | 8 | def test_reference_line_manager(): 9 | dynamics_model = BicycleModel( 10 | front_overhang=0.9, 11 | rear_overhang=0.9, 12 | wheelbase=2.7, 13 | width=1.8, 14 | action_space_lb=[-3.0, -0.5235987755983], 15 | action_space_ub=[+3.0, +0.5235987755983], 16 | ) 17 | dynamics_model.set_state(np.array((0.0, 0.0, 0.0, 10.0))) 18 | tracking_length = 60 19 | n_observation_steps = 15 20 | reference_line, trajectory = random_walk(dynamics_model, step_interval=0.1, walk_length=tracking_length) 21 | reference_line_manager = ReferenceLineManager( 22 | n_observation_steps=n_observation_steps, 23 | pad_mode='repeat', 24 | ) 25 | reference_line_manager.set_reference_line(reference_line, tracking_length=tracking_length) 26 | assert len(reference_line_manager.raw_reference_line.waypoints) == 60 27 | 28 | 29 | def 
test_estimate_init_state_from_reference_line(): 30 | step_interval = 0.1 31 | init_v = 5.0 32 | init_r = np.pi / 3 33 | reference_line_length = 20 34 | 35 | reference_line_arr = np.array([ 36 | ( 37 | np.cos(init_r) * init_v * step_interval * step_index, 38 | np.sin(init_r) * init_v * step_interval * step_index, 39 | ) 40 | for step_index in range(reference_line_length) 41 | ]) 42 | reference_line = ReferenceLineManager.np_array_to_reference_line(reference_line_arr) 43 | estimated_init_state = ReferenceLineManager.estimate_init_state_from_reference_line(reference_line, step_interval) 44 | assert np.isclose(estimated_init_state.bicycle_model.body_state.r, init_r) 45 | assert np.isclose(estimated_init_state.bicycle_model.v, init_v) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_reference_line_manager() 50 | test_estimate_init_state_from_reference_line() 51 | -------------------------------------------------------------------------------- /configs/trajectory_tracking/config-track-base.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | action_space_lb: &ACTION_SPACE_LB [-4.5, -0.5235987755983] 3 | action_space_ub: &ACTION_SPACE_UB [+4.5, +0.5235987755983] 4 | 5 | environment: 6 | type: 'TrajectoryTrackingEnv' 7 | step_interval: 0.1 8 | # TODO: add step index/ total number to observation 9 | # to deal with variance of cumulative reward w.r.t. 
different trajectory length 10 | tracking_length_lb: 50 11 | tracking_length_ub: 50 12 | reference_line_pad_mode: 'repeat' 13 | init_state_lb: [-1.e-8, -1.e-8, -3.1415926536, 0.1] 14 | init_state_ub: [+1.e-8, +1.e-8, +3.1415926536, 40.0] 15 | n_observation_steps: 15 16 | dynamics_model_configs: 17 | - type: 'BicycleModel' 18 | name: 'ShortVehicle' 19 | front_overhang: 0.9 20 | rear_overhang: 0.9 21 | wheelbase: 2.7 22 | width: 1.8 23 | action_space_lb: *ACTION_SPACE_LB 24 | action_space_ub: *ACTION_SPACE_UB 25 | max_lat_acc: 8.0 26 | - type: 'BicycleModel' 27 | name: 'LongVehicle' 28 | front_overhang: 2.3 29 | rear_overhang: 2.0 30 | wheelbase: 6.1 31 | width: 2.5 32 | action_space_lb: *ACTION_SPACE_LB 33 | action_space_ub: *ACTION_SPACE_UB 34 | max_lat_acc: 4.0 35 | - type: 'BicycleModel' 36 | name: 'MiddleVehicle' 37 | front_overhang: 1.0 38 | rear_overhang: 1.5 39 | wheelbase: 3.4 40 | width: 2.65 41 | action_space_lb: *ACTION_SPACE_LB 42 | action_space_ub: *ACTION_SPACE_UB 43 | max_lat_acc: 2.0 44 | 45 | algorithm: 46 | type: TD3 47 | policy: MlpPolicy 48 | train_freq: [4, "episode"] 49 | learning_rate: 1.0e-4 50 | scaled_action_noise: 51 | mean: [0.0, 0.0] 52 | sigma: [0.9, 0.3] 53 | verbose: 1 54 | learning_starts: 4096 55 | batch_size: 4096 56 | tau: 0.0001 57 | 58 | learning: 59 | total_timesteps: 1_000_000 60 | log_interval: 10 61 | 62 | evaluation: 63 | eval_config: 64 | n_episodes: 200 65 | compute_metrics_name: "compute_bicycle_model_metrics" 66 | visualization_function_name: "visualize_trajectory_tracking_episode" 67 | overriden_environment: {} 68 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory/random_walk.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | from gym.spaces import Space 5 | 6 | from drltt.simulator.dynamics_models import BaseDynamicsModel 7 | 8 | from 
drltt_proto.trajectory.trajectory_pb2 import ReferenceLineWaypoint, ReferenceLine, TrajectoryWaypoint, Trajectory 9 | 10 | 11 | def random_walk( 12 | dynamics_model: BaseDynamicsModel, 13 | step_interval: float, 14 | walk_length: int, 15 | ) -> Tuple[ReferenceLine, Trajectory]: 16 | """Perform random walk to generate reference line. 17 | 18 | Args: 19 | dynamics_model: Dynamics model for random walk. 20 | step_interval: Time interval of a step. 21 | walk_length: Length of the generated trajectory. 22 | 23 | Returns: 24 | Tuple[ReferenceLine, Trajectory]: Random walk results: 25 | 26 | * ReferenceLine: Generated reference line. 27 | * Trajectory: Generated trajectory. 28 | """ 29 | assert walk_length >= 1, f'Illegal walk_length: {walk_length}' 30 | 31 | action_space: Space = dynamics_model.get_action_space() 32 | 33 | all_states = list() 34 | all_actions = list() 35 | for step_idx in range(walk_length - 1): 36 | state: np.ndarray = dynamics_model.get_state() 37 | action: np.ndarray = action_space.sample() 38 | 39 | all_states.append(state) 40 | all_actions.append(action) 41 | 42 | dynamics_model.step(action, step_interval) 43 | 44 | all_states.append(dynamics_model.get_state()) 45 | all_actions.append(np.zeros_like(all_actions[-1])) 46 | 47 | reference_line = ReferenceLine() 48 | trajectory = Trajectory() 49 | for state, action in zip(all_states, all_actions): 50 | ref_wpt = ReferenceLineWaypoint() 51 | ref_wpt.x = state[0] 52 | ref_wpt.y = state[1] 53 | reference_line.waypoints.append(ref_wpt) 54 | 55 | trj_wpt = TrajectoryWaypoint() 56 | trj_wpt.state.CopyFrom(dynamics_model.serialize_state(state)) 57 | trj_wpt.action.CopyFrom(dynamics_model.serialize_action(action)) 58 | trajectory.waypoints.append(trj_wpt) 59 | 60 | return (reference_line, trajectory) 61 | -------------------------------------------------------------------------------- /drltt/simulator/rl_learning/sb3_learner_test.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from gym import Env 5 | 6 | from drltt.common import build_object_within_registry_from_config 7 | from drltt.common.io import load_and_override_configs, generate_random_string 8 | from drltt.simulator import TEST_CONFIG_PATHS 9 | from drltt.simulator.environments import ENVIRONMENTS 10 | from drltt.simulator.rl_learning.sb3_learner import build_sb3_algorithm_from_config, train_with_sb3, eval_with_sb3 11 | 12 | 13 | # TODO: use pytest fixture/setup to refactor tests within file, to avoid copy-paste of test codes 14 | # TODO: use unit-test config to override sample config. 15 | 16 | 17 | def test_train_with_sb3(): 18 | config = load_and_override_configs(TEST_CONFIG_PATHS) 19 | 20 | env_config = config['environment'] 21 | environment: Env = build_object_within_registry_from_config(ENVIRONMENTS, env_config) 22 | config['learning']['total_timesteps'] = 64 23 | config['algorithm']['learning_starts'] = 16 24 | config['algorithm']['batch_size'] = 16 25 | test_checkpoint_dir = f'/tmp/drltt-pytest-{generate_random_string(6)}' 26 | test_checkpoint_file_prefix = f'{test_checkpoint_dir}/checkpoint' 27 | algorithm = train_with_sb3( 28 | environment=environment, 29 | algorithm_config=config['algorithm'], 30 | learning_config=config['learning'], 31 | checkpoint_file_prefix=test_checkpoint_file_prefix, 32 | ) 33 | shutil.rmtree(test_checkpoint_dir) 34 | 35 | 36 | def test_eval_with_sb3(): 37 | config = load_and_override_configs(TEST_CONFIG_PATHS) 38 | 39 | env_config = config['environment'] 40 | environment: Env = build_object_within_registry_from_config(ENVIRONMENTS, env_config) 41 | algorithm_config = config['algorithm'] 42 | algorithm = build_sb3_algorithm_from_config(environment, algorithm_config) 43 | 44 | eval_config = config['evaluation']['eval_config'] 45 | eval_config['n_episodes'] = 10 46 | report_dir = f'/tmp/drltt-pytest-{generate_random_string(6)}' 
47 | eval_with_sb3(environment, algorithm, report_dir, **eval_config) 48 | if os.path.exists(report_dir): 49 | shutil.rmtree(report_dir, ignore_errors=True) 50 | 51 | 52 | if __name__ == '__main__': 53 | test_train_with_sb3() 54 | test_eval_with_sb3() 55 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/common/protobuf_operators.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "drltt_proto/dynamics_model/basics.pb.h" 4 | #include "drltt_proto/dynamics_model/state.pb.h" 5 | #include "geometry.h" 6 | 7 | /** 8 | * Add two body state together. Perform normalization on heading. 9 | * @param lhs Left hand side body state. 10 | * @param rhs Right hand side body state. 11 | * @return Added body state 12 | */ 13 | drltt_proto::BodyState operator+(drltt_proto::BodyState lhs, 14 | const drltt_proto::BodyState& rhs); 15 | 16 | /** 17 | * Scale a body state with a scalar to represent a difference in body state. No 18 | * normalization performed on heading in this case. 19 | * @param lhs Body state. 20 | * @param rhs Scalar. 21 | * @return Scaled body state. 22 | */ 23 | drltt_proto::BodyState operator*(drltt_proto::BodyState lhs, float rhs); 24 | 25 | /** 26 | * Scale a body state with a scalar to represent a difference in body state. 27 | * NOTE: No normalization performed on heading in this case. 28 | * @param lhs Scalar. 29 | * @param rhs Body state. 30 | * @return Scaled body state. 31 | */ 32 | drltt_proto::BodyState operator*(float lhs, drltt_proto::BodyState rhs); 33 | 34 | /** 35 | * Add two bicycle model state together. Perform normalization on heading. 36 | * @param lhs Left hand side bicycle model state. 37 | * @param rhs Right hand side bicycle model state. 
38 | * @return Added bicycle model state 39 | */ 40 | drltt_proto::BicycleModelState operator+( 41 | drltt_proto::BicycleModelState lhs, 42 | const drltt_proto::BicycleModelState& rhs); 43 | 44 | /** 45 | * Scale a bicycle model state with a scalar to represent a difference in 46 | * bicycle model state. 47 | * @param lhs Bicycle model state. 48 | * @param rhs Scalar. 49 | * @return Scaled bicycle model state. 50 | */ 51 | drltt_proto::BicycleModelState operator*(drltt_proto::BicycleModelState lhs, 52 | float rhs); 53 | 54 | /** 55 | * Scale a bicycle model state with a scalar to represent a difference in 56 | * bicycle model state. 57 | * @param lhs Scalar. 58 | * @param rhs Bicycle model state. 59 | * @return Scaled bicycle model state. 60 | */ 61 | drltt_proto::BicycleModelState operator*(float lhs, 62 | drltt_proto::BicycleModelState rhs); 63 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/trajectory_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | from enum import Enum 3 | import os 4 | 5 | from . import PACKAGE_DIR, USR_LIB_DIR, SDK_LIB_DIR 6 | 7 | CHECKPOINT_DIR = f'{PACKAGE_DIR}/checkpoint/' 8 | 9 | from export_symbols import ( 10 | TrajectoryTracker as ExportedTrajectoryTracker, 11 | trajectory_tracker_set_reference_line, 12 | trajectory_tracker_set_dynamics_model_initial_state, 13 | trajectory_tracker_track_reference_line, 14 | ) 15 | 16 | 17 | class DynamicsModelType(Enum): 18 | """Vechicle type. 
19 | 20 | See `./configs/trajectory_tracking/config-track-base.yaml` for definition 21 | """ 22 | 23 | SHORT_VEHICLE: int = 0 24 | LONG_VEHICLE: int = 1 25 | MIDDLE_VEHICLE: int = 2 26 | 27 | 28 | class TrajectoryTracker: 29 | """DRLTT Trajectory Tracking policy wrapper""" 30 | 31 | def __init__(self, vehicle_type: int = DynamicsModelType.SHORT_VEHICLE.value): 32 | """ 33 | Args: 34 | vehicle_type: Vehicle type. Default is short vehicle. 35 | """ 36 | self._tracker = ExportedTrajectoryTracker(CHECKPOINT_DIR, vehicle_type) 37 | 38 | def track_reference_line( 39 | self, 40 | reference_line: List[Tuple[float, float]], 41 | init_state: Union[Tuple[float, float, float, float], None] = None, 42 | ) -> Tuple[List[Tuple[float, float, float, float]], List[Tuple[float, float]]]: 43 | """Track a reference line with the underlying policy model. 44 | 45 | Nomenclature: 46 | 47 | - x: X-coordinate in [m] within (-inf, +inf) 48 | - y: Y-coordinate in [m] within (-inf, +inf) 49 | - r: heading in [rad] within [-pi, pi), following convention of math lib like `std::atan2` 50 | - v: scalar speed in [m/s] within [0, +inf) 51 | 52 | Args: 53 | reference_line: Reference line, format=List[(x, y)]. 54 | init_state: Initial state, format=(x, y, r, v). If None, no initial state is set explicitly on the tracker. 55 | 56 | Return: 57 | Tuple[states, actions]: The tracked trajectory. All elements have the same length 58 | that is equal to the length of reference line. 59 | 60 | - The first element is a sequence of states. 61 | - The second element is a sequence of actions. 
62 | """ 63 | trajectory_tracker_set_reference_line(self._tracker, reference_line) 64 | if init_state is not None: 65 | trajectory_tracker_set_dynamics_model_initial_state(self._tracker, init_state) 66 | rt_states, rt_actions, rt_observations, rt_debug_datas = trajectory_tracker_track_reference_line(self._tracker) 67 | 68 | return rt_states, rt_actions 69 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/managers/observation_manager.cpp: -------------------------------------------------------------------------------- 1 | #include "observation_manager.h" 2 | 3 | namespace drltt { 4 | bool ObservationManager::Reset(drltt_proto::ReferenceLine* reference_line_ptr, 5 | BaseDynamicsModel* dynamics_model_ptr) { 6 | if (reference_line_ptr == nullptr || dynamics_model_ptr == nullptr) { 7 | std::cerr << "nullptr found." << std::endl; 8 | return false; 9 | } 10 | _reference_line_ptr = reference_line_ptr; 11 | _dynamics_model_ptr = dynamics_model_ptr; 12 | return true; 13 | } 14 | 15 | bool ObservationManager::get_observation( 16 | const drltt_proto::BodyState& body_state, int start_index, 17 | int tracking_length, int n_observation_steps, 18 | std::vector* observation) { 19 | get_reference_line_observation(body_state, start_index, tracking_length, 20 | n_observation_steps, observation); 21 | _dynamics_model_ptr->get_state_observation(observation); 22 | _dynamics_model_ptr->get_dynamics_model_observation(observation); 23 | return true; 24 | } 25 | 26 | // TODO unit test with `env_data.bin` 27 | bool ObservationManager::get_reference_line_observation( 28 | const drltt_proto::BodyState& body_state, int start_index, 29 | int tracking_length, int n_observation_steps, 30 | std::vector* observation) { 31 | // slice reference line segment 32 | std::vector observed_waypoint_ptrs; 33 | const int reference_line_length = _reference_line_ptr->waypoints().size(); 34 | for (int index = start_index; index < start_index + n_observation_steps; 35 
| ++index) { 36 | if (index >= reference_line_length) { 37 | // padding at the end by repeating. TODO: support different type of 38 | // padding. 39 | const drltt_proto::ReferenceLineWaypoint& waypoint = 40 | _reference_line_ptr->waypoints().at(reference_line_length - 1); 41 | observed_waypoint_ptrs.push_back(&waypoint); 42 | } else { 43 | // normal observation 44 | const drltt_proto::ReferenceLineWaypoint& waypoint = 45 | _reference_line_ptr->waypoints().at(index); 46 | observed_waypoint_ptrs.push_back(&waypoint); 47 | } 48 | } 49 | // transform to body frame 50 | for (const auto& waypoint_ptr : observed_waypoint_ptrs) { 51 | drltt_proto::BodyState point; 52 | point.set_x(waypoint_ptr->x()); 53 | point.set_y(waypoint_ptr->y()); 54 | transform_to_local_from_world(body_state, &point); 55 | observation->push_back(point.x()); 56 | observation->push_back(point.y()); 57 | } 58 | // forward number of steps 59 | const int forward_tracking_length = tracking_length - start_index; 60 | observation->push_back(static_cast(forward_tracking_length)); 61 | 62 | return true; 63 | } 64 | 65 | } // namespace drltt -------------------------------------------------------------------------------- /sdk/drltt-sdk/dynamics_models/bicycle_model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // TODO: verify if it is necessary 4 | #define _USE_MATH_DEFINES 5 | 6 | #include 7 | #include 8 | #include "base_dynamics_model.h" 9 | 10 | namespace drltt { 11 | 12 | class BicycleModel : public BaseDynamicsModel { 13 | public: 14 | BicycleModel() = default; 15 | BicycleModel(const drltt_proto::HyperParameter& hyper_parameter) 16 | : BaseDynamicsModel(hyper_parameter) {} 17 | BicycleModel(const drltt_proto::HyperParameter& hyper_parameter, 18 | const drltt_proto::State& state) 19 | : BaseDynamicsModel(hyper_parameter, state) {} 20 | void Step(const drltt_proto::Action& action, float delta_t) override; 21 | ~BicycleModel() = default; 22 | 23 
| bool get_state_observation(std::vector* observation) const override; 24 | bool get_dynamics_model_observation( 25 | std::vector* observation) const override; 26 | 27 | protected: 28 | void parse_hyper_parameter() override; 29 | /** 30 | * @brief Compute derivative w.r.t. time. 31 | * TODO: use RVO to avoid copy. 32 | * @param state Dynamics model state. 33 | * @param action Dynamics model action. 34 | * @param hyper_parameter Dynamics model hyper parameter. 35 | * @return drltt_proto::State Derivative. Each field of the state represents 36 | * the derivative w.r.t. time of the field. 37 | */ 38 | drltt_proto::State _compute_derivative( 39 | const drltt_proto::State& state, const drltt_proto::Action& action, 40 | const drltt_proto::HyperParameter& hyper_parameter); 41 | /** 42 | * @brief Get the relative position of Center of Gravity (CoG) between axles. 43 | * The front axle represents 1.0, while the rear axle represents 0.0. 44 | * 45 | * @return float The relative position of CoG. 46 | */ 47 | float GetCogRelativePositionBetweenAxles() const; 48 | /** 49 | * @brief Compute variables related to rotations. 50 | * 51 | * @param steering_angle Steering angle. 52 | * @param omega The angle between the speed direction of Center of Gravity 53 | * (CoG) and the vehicle heading. 54 | * @param rotation_radius_inv The inverse of rotation radius of vehicle under 55 | * current steering angle. Return the inverse for reasons of numerical 56 | * stability. 57 | * @return true Computation succeeded. 58 | * @return false Computation failed. 59 | */ 60 | bool ComputeRotationRelatedVariables(float steering_angle, float* omega, 61 | float* rotation_radius_inv) const; 62 | /** 63 | * @brief Get the maximum steering angle. Based on current speed and the 64 | * maximum lateral acceleration. 65 | * 66 | * @return float The maximum steering angle. 
67 | */ 68 | float GetMaxSteeringAngle() const; 69 | }; 70 | } // namespace drltt 71 | -------------------------------------------------------------------------------- /sdk/compile-in-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | project_name=drltt-sdk 4 | image_name=drltt:runtime 5 | 6 | repo_root_dir=$(git rev-parse --show-toplevel) 7 | repo_dirname=$(basename $repo_root_dir) 8 | sdk_root_dir=$(dirname $0) 9 | build_dir=${sdk_root_dir}/build 10 | 11 | docker_repo_root_dir=/${repo_dirname} 12 | docker_repo_work_dir=${docker_repo_root_dir}/work_dir 13 | docker_checkpoint_dir=${docker_repo_work_dir}/track-test/checkpoint 14 | docker_sdk_root_dir=${docker_repo_root_dir}/sdk # TODO remove hardcode, use relative path 15 | docker_proto_gen_dir=${docker_sdk_root_dir}/proto_gen 16 | docker_build_dir=${docker_sdk_root_dir}/build 17 | 18 | docker_usr_lib_dir=/usr/local/lib 19 | 20 | if [[ -v HOST_LIBTORCH_DIR ]];then 21 | docker_libtorch_dir=/libtorch-host 22 | mount_host_libtorch_in_docker="-v ${HOST_LIBTORCH_DIR}:${docker_libtorch_dir}:ro" 23 | echo "Using libtorch mounted from host: ${HOST_LIBTORCH_DIR}" 24 | else 25 | docker_libtorch_dir=/libtorch 26 | mount_host_libtorch_in_docker="" 27 | echo "Using libtorch preinstalled in docker image" 28 | fi 29 | 30 | if [[ -v HOST_PYTORCH_DIR ]];then 31 | docker_pytorch_dir=/pytorch-host 32 | mount_pytorch_in_docker="-v ${HOST_PYTORCH_DIR}:${docker_pytorch_dir}:rw" # mount the host directory that the -v test above checked 33 | else 34 | docker_pytorch_dir=/pytorch 35 | mount_pytorch_in_docker="" 36 | fi 37 | 38 | # TODO: reorganize mounted path 39 | 40 | nonsudo_user_arg=" --user $(id -u):$(id -g) " 41 | nonsudo_user_source_cmd=" source /work_dir/.bashrc " 42 | 43 | if [[ $1 == "interactive" ]]; then 44 | docker_arg_suffix=" -it ${image_name} " 45 | docker_container_cmd=" bash " 46 | if [[ $2 == "nonsudo" ]]; then 47 | docker_arg_suffix="${nonsudo_user_arg} ${docker_arg_suffix}" 48 | docker_container_cmd=" 
${nonsudo_user_source_cmd} && bash " 49 | fi 50 | else 51 | docker_arg_suffix="${nonsudo_user_arg} ${image_name} " 52 | docker_container_cmd=" ${nonsudo_user_source_cmd} && cd ${docker_sdk_root_dir} && bash ./compile-source.sh" 53 | if [[ ! $1 == "test" ]]; then 54 | docker_container_cmd="${docker_container_cmd} && bash ./export-py-sdk.sh" 55 | fi 56 | fi 57 | 58 | docker run --name drltt-sdk --entrypoint bash -e "ACCEPT_EULA=Y" --rm --network=host \ 59 | -e "PRIVACY_CONSENT=Y" \ 60 | -e "REPO_ROOT_DIR=${docker_repo_root_dir}" \ 61 | -e "PROJECT_NAME=${project_name}" \ 62 | -e "BUILD_DIR=${docker_build_dir}" \ 63 | -e "PROTO_GEN_DIR=${docker_proto_gen_dir}" \ 64 | -e "CHECKPOINT_DIR=${docker_checkpoint_dir}" \ 65 | -e "USR_LIB_DIR=${docker_usr_lib_dir}" \ 66 | -e "LIBTORCH_DIR=${docker_libtorch_dir}" \ 67 | -e "PYTORCH_DIR=${docker_pytorch_dir}" \ 68 | -v ${repo_root_dir}:${docker_repo_root_dir}:rw \ 69 | ${mount_pytorch_in_docker} \ 70 | ${mount_host_libtorch_in_docker} \ 71 | ${docker_arg_suffix} \ 72 | -c "${docker_container_cmd}" 73 | -------------------------------------------------------------------------------- /sdk/.clang-format: -------------------------------------------------------------------------------- 1 | # Google C/C++ Code Style settings 2 | # https://clang.llvm.org/docs/ClangFormatStyleOptions.html 3 | # Author: Kehan Xue, kehan.xue (at) gmail.com 4 | # URL: https://github.com/kehanXue/google-style-clang-format 5 | 6 | Language: Cpp 7 | BasedOnStyle: Google 8 | AccessModifierOffset: -1 9 | AlignAfterOpenBracket: Align 10 | AlignConsecutiveAssignments: None 11 | AlignOperands: Align 12 | AllowAllArgumentsOnNextLine: true 13 | AllowAllConstructorInitializersOnNextLine: true 14 | AllowAllParametersOfDeclarationOnNextLine: false 15 | AllowShortBlocksOnASingleLine: Empty 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: Inline 18 | AllowShortIfStatementsOnASingleLine: Never # To avoid conflict, set this "Never" and 
each "if statement" should include brace when coding 19 | AllowShortLambdasOnASingleLine: Inline 20 | AllowShortLoopsOnASingleLine: false 21 | AlwaysBreakAfterReturnType: None 22 | AlwaysBreakTemplateDeclarations: Yes 23 | BinPackArguments: true 24 | BreakBeforeBraces: Custom 25 | BraceWrapping: 26 | AfterCaseLabel: false 27 | AfterClass: false 28 | AfterStruct: false 29 | AfterControlStatement: Never 30 | AfterEnum: false 31 | AfterFunction: false 32 | AfterNamespace: false 33 | AfterUnion: false 34 | AfterExternBlock: false 35 | BeforeCatch: false 36 | BeforeElse: false 37 | BeforeLambdaBody: false 38 | IndentBraces: false 39 | SplitEmptyFunction: false 40 | SplitEmptyRecord: false 41 | SplitEmptyNamespace: false 42 | BreakBeforeBinaryOperators: None 43 | BreakBeforeTernaryOperators: true 44 | BreakConstructorInitializers: BeforeColon 45 | BreakInheritanceList: BeforeColon 46 | ColumnLimit: 80 47 | CompactNamespaces: false 48 | ContinuationIndentWidth: 4 49 | Cpp11BracedListStyle: true 50 | DerivePointerAlignment: false # Make sure the * or & align on the left 51 | EmptyLineBeforeAccessModifier: LogicalBlock 52 | FixNamespaceComments: true 53 | IncludeBlocks: Preserve 54 | IndentCaseLabels: true 55 | IndentPPDirectives: None 56 | IndentWidth: 2 57 | KeepEmptyLinesAtTheStartOfBlocks: true 58 | MaxEmptyLinesToKeep: 1 59 | NamespaceIndentation: None 60 | ObjCSpaceAfterProperty: false 61 | ObjCSpaceBeforeProtocolList: true 62 | PointerAlignment: Left 63 | ReflowComments: true 64 | # SeparateDefinitionBlocks: Always # Only support since clang-format 14 65 | SpaceAfterCStyleCast: false 66 | SpaceAfterLogicalNot: false 67 | SpaceAfterTemplateKeyword: true 68 | SpaceBeforeAssignmentOperators: true 69 | SpaceBeforeCpp11BracedList: false 70 | SpaceBeforeCtorInitializerColon: true 71 | SpaceBeforeInheritanceColon: true 72 | SpaceBeforeParens: ControlStatements 73 | SpaceBeforeRangeBasedForLoopColon: true 74 | SpaceBeforeSquareBrackets: false 75 | SpaceInEmptyParentheses: 
false 76 | SpacesBeforeTrailingComments: 2 77 | SpacesInAngles: false 78 | SpacesInCStyleCastParentheses: false 79 | SpacesInContainerLiterals: false 80 | SpacesInParentheses: false 81 | SpacesInSquareBrackets: false 82 | Standard: c++11 83 | TabWidth: 4 84 | UseTab: Never 85 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker 2 | 3 | DRLTT uses docker for SDK compiling, CI/CD, and more. 4 | 5 | Currently supported docker images and their hierarchical relationship: 6 | 7 | ```text 8 | `drltt` 9 | ├── `drltt:cicd` # Docker image for setting up CI/CD 10 | └── `drltt:runtime` # Docker image for run DRLTT, i.e. training/testing/sdk compilation/etc. 11 | ``` 12 | 13 | ## `drltt:cicd`: Build DRLTT CI/CD Docker Image 14 | 15 | ```bash 16 | docker image build --tag drltt:cicd - < ./Dockerfile.cicd 17 | ``` 18 | 19 | ## `drltt:runtime`: Build DRLTT Runtime Docker Image 20 | 21 | ```bash 22 | docker image build --tag drltt:runtime - < ./Dockerfile.runtime 23 | ``` 24 | 25 | ## Useful Tips 26 | 27 | ### Tips 1: Networking Issue 28 | 29 | For network environments within Mainland China, you may consider using a domestic pip source to accelerate this process and setting the timeout to a larger value: 30 | 31 | ```bash 32 | docker image build --tag drltt:runtime --build-arg PIP_ARGS=" -i https://pypi.tuna.tsinghua.edu.cn/simple --timeout 1000 " - < ./Dockerfile.runtime 33 | ``` 34 | 35 | For APT source, you may consider using a domestic apt source to accelerate this process by appending the following part to the `./Dockerfile`: 36 | 37 | ```dockerfile 38 | # Example using TUNA apt source 39 | ARG APT_SOURCE_LIST=/etc/apt/sources.list 40 | RUN \ 41 | mv ${APT_SOURCE_LIST} ${APT_SOURCE_LIST}.bak && \ 42 | touch ${APT_SOURCE_LIST} && \ 43 | printf "deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe multiverse" >> 
${APT_SOURCE_LIST} && \ 44 | printf "deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 45 | printf "deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 46 | printf "deb http://security.ubuntu.com/ubuntu/ jammy-security main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 47 | cat ${APT_SOURCE_LIST} 48 | ``` 49 | 50 | ### Tip 2: Retrying the image building command 51 | 52 | ```bash 53 | until docker image build --tag drltt:runtime --build-arg PIP_ARGS=" -i https://pypi.tuna.tsinghua.edu.cn/simple --timeout 1000 " - < ./Dockerfile.runtime; do echo "Image build failed, retrying..."; done 54 | ``` 55 | 56 | ### Tip 3: Transferring Docker images 57 | 58 | To save time by transferring the Docker images, save with the command: 59 | 60 | ```bash 61 | docker image save drltt:runtime -o ./drltt.runtime.image 62 | ``` 63 | 64 | Then load it on the target machine with the command: 65 | 66 | ```bash 67 | docker image load -i ./drltt.runtime.image 68 | ``` 69 | 70 | ### Tip 4: Clearing Cache 71 | 72 | To remove unused images and build cache, run: 73 | 74 | ```bash 75 | docker system prune 76 | ``` 77 | 78 | ### Tip 5: To debug with docker images 79 | 80 | ```bash 81 | image_name=drltt:runtime 82 | docker run --name ${image_name//:/-}-$((RANDOM%100000)) --entrypoint bash --rm --network=host -it ${image_name} -c bash 83 | ``` 84 | -------------------------------------------------------------------------------- /drltt/common/proto/proto_def/drltt_proto/environment/trajectory_tracking.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "drltt_proto/dynamics_model/state.proto"; 4 | import "drltt_proto/dynamics_model/action.proto"; 5 | import "drltt_proto/dynamics_model/observation.proto"; 6 | import "drltt_proto/dynamics_model/hyper_parameter.proto"; 7 | 
import "drltt_proto/dynamics_model/basics.proto"; 8 | import "drltt_proto/trajectory/trajectory.proto"; 9 | 10 | package drltt_proto; 11 | 12 | message TrajectoryTrackingHyperParameter { 13 | // Interval between each time step in [s]. 14 | optional float step_interval = 1; 15 | // Initial state space's lower bound. 16 | repeated float init_state_lb = 2; 17 | // Initial state space's upper bound. 18 | repeated float init_state_ub = 3; 19 | // Number of observable forward waypoints on the reference line. 20 | optional int32 n_observation_steps = 4; 21 | // Collection of hyper-parameters of dynamics models. 22 | repeated HyperParameter dynamics_models_hyper_parameters = 5; 23 | // Tracking length's lower bound. 24 | optional int32 tracking_length_lb = 6; 25 | // Tracking length's upper bound. 26 | optional int32 tracking_length_ub = 7; 27 | // Reference line pad mode. 28 | optional string reference_line_pad_mode = 8; 29 | // TODO: consider a better place to store this hparam 30 | optional int32 max_n_episodes = 9; 31 | } 32 | 33 | message DynamicsModelData { 34 | // Dynamics model's class name. 35 | optional string type = 1; 36 | // Hyper-parameter. 37 | optional HyperParameter hyper_parameter = 2; 38 | // Sequence of states. 39 | repeated State states = 3; 40 | // Sequence of actions. 41 | repeated Action actions = 4; 42 | // Sequence of observations. 43 | repeated Observation observations = 5; 44 | // Debug information. 45 | repeated DebugInfo debug_infos = 6; 46 | } 47 | 48 | // Data recorded from an episode, i.e. a trajectory. 49 | message TrajectoryTrackingEpisode { 50 | // Last step index. 51 | optional int32 step_index = 1; 52 | // Environment's hyper-parameter. 53 | optional TrajectoryTrackingHyperParameter hyper_parameter = 2; 54 | // Reference line. 55 | optional ReferenceLine reference_line = 3; 56 | // Real tracking length. 57 | optional int32 tracking_length = 4; 58 | // Dynamics model's data, including hyper-parameter and trajectory, i.e. 
sequence of state-observation-action triplets. 59 | DynamicsModelData dynamics_model = 5; 60 | // Index of selected dynamics model. 61 | optional int32 selected_dynamics_model_index = 6; 62 | // Sequence of rewards. 63 | repeated float rewards = 7; 64 | } 65 | 66 | // Trajectory tracking environment. 67 | message TrajectoryTrackingEnvironment { 68 | // Hyper-parameter of the environment. 69 | optional TrajectoryTrackingHyperParameter hyper_parameter = 1; 70 | // Episode data. 71 | optional TrajectoryTrackingEpisode episode = 2; 72 | repeated TrajectoryTrackingEpisode episodes = 3; 73 | } 74 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/trajectory_tracker/trajectory_tracker.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file trajectory_tracker.h 3 | * @brief Trajectory tracker exported SDK. 4 | * 5 | */ 6 | #pragma once 7 | 8 | #include "drltt-sdk/environments/trajectory_tracking.h" 9 | 10 | /** 11 | * drltt 12 | */ 13 | namespace drltt { 14 | // clang-format off 15 | /** 16 | * @brief DRLTT Trajectory Tracking C++ SDK. 17 | * 18 | * Nomenclature for documentation: 19 | * 20 | * - x: X-coordinate in [m] within (-inf, +inf). 21 | * - y: Y-coordinate in [m] within (-inf, +inf). 22 | * - r: heading in [rad] within [-pi, pi), following convention of math lib like `std::atan2`. 23 | * - v: scalar speed in [m/s] within [0, +inf). 24 | * - a: acceleration in [m/s/s] within [0, +inf). 25 | * - s: steering angle in [rad] within [-max_s, +max_s] where `max_s` is the steering limit. 26 | * 27 | * TODO: move this part to the doc. of protobuf. 28 | * 29 | * Predefined types for documentation: 30 | * 31 | * - STATE : tuple, state of dynamics model. 32 | * - ACTION : tuple, action of dynamics model. 33 | * - OBSERVATION : vector, vectorized observation feature. 34 | * - REFERENCE_WAYPOINT : tuple, waypoint on the reference line. 
35 | * - REFERENCE_LINE : vector, reference line for the dynamics model to track. 36 | */ 37 | // clang-format on 38 | class TrajectoryTracker { 39 | 40 | public: 41 | TrajectoryTracker() = default; 42 | ~TrajectoryTracker() {}; 43 | 44 | /** 45 | * @param load_path Path to checkpoint. See "Checkpoint Structure" in 46 | * `./README.md` for detail. 47 | * @param dynamics_model_index The index of dynamics model. The order 48 | * corresponds to the `dynamics_model_configs` within the config YAML in the 49 | * checkpoint folder. 50 | */ 51 | TrajectoryTracker(const std::string& load_path, int dynamics_model_index); 52 | /** 53 | * @brief Set a reference line. 54 | * 55 | * It will estimate an initial state of the dynamics 56 | * model, which may be overwritten by other functions later if necessary. 57 | * 58 | * @param reference_line Reference line to be tracked. 59 | * @return Success flag. 60 | */ 61 | bool set_reference_line(const REFERENCE_LINE& reference_line); 62 | /** 63 | * @brief Set the initial state of the dynamics model. 64 | * 65 | * @param init_state Initial state. 66 | * @return true Setting succeeded. 67 | * @return false Setting failed. 68 | */ 69 | bool set_dynamics_model_initial_state(const STATE& init_state); 70 | /** 71 | * @brief Perform trajectory tracking. 72 | * 73 | * Roll out a trajectory and return the tracked trajectory. 74 | * 75 | * @return Trajectory containing states, actions, and observations of the 76 | * roll-out. Format: (vector of STATE, vector of ACTION, vector of OBSERVATION). 
77 | */ 78 | TRAJECTORY TrackReferenceLine(); 79 | 80 | private: 81 | TrajectoryTracking _env; 82 | }; 83 | } // namespace drltt 84 | -------------------------------------------------------------------------------- /sdk/assets/exported-python-sdk/check_export_symbols.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from export_symbols import ( 4 | TrajectoryTracker, 5 | trajectory_tracker_set_reference_line, 6 | trajectory_tracker_set_dynamics_model_initial_state, 7 | trajectory_tracker_track_reference_line, 8 | ) 9 | from drltt_proto.environment.environment_pb2 import Environment 10 | 11 | 12 | def main(): 13 | env_data = Environment() 14 | with open('./checkpoint/env_data.bin', 'rb') as f: 15 | env_data.ParseFromString(f.read()) 16 | 17 | reference_line = [ 18 | (waypoint.x, waypoint.y) 19 | for waypoint in env_data.trajectory_tracking.episode.reference_line.waypoints[ 20 | : env_data.trajectory_tracking.episode.tracking_length 21 | ] 22 | ] 23 | init_state_proto = env_data.trajectory_tracking.episode.dynamics_model.states[0] 24 | init_state = ( 25 | init_state_proto.bicycle_model.body_state.x, 26 | init_state_proto.bicycle_model.body_state.y, 27 | init_state_proto.bicycle_model.body_state.r, 28 | init_state_proto.bicycle_model.v, 29 | ) 30 | tracker = TrajectoryTracker('./checkpoint/', env_data.trajectory_tracking.episode.selected_dynamics_model_index) 31 | trajectory_tracker_set_reference_line(tracker, reference_line) 32 | trajectory_tracker_set_dynamics_model_initial_state(tracker, init_state) 33 | rt_states, rt_actions, rt_observations, rt_debug_datas = trajectory_tracker_track_reference_line(tracker) 34 | 35 | print('state shape: ', np.array(rt_states).shape) 36 | print('action shape: ', np.array(rt_actions).shape) 37 | print('observation shape: ', np.array(rt_observations).shape) 38 | 39 | gt_states = list() 40 | for state in env_data.trajectory_tracking.episode.dynamics_model.states: 41 | 
gt_state = ( 42 | state.bicycle_model.body_state.x, 43 | state.bicycle_model.body_state.y, 44 | state.bicycle_model.body_state.r, 45 | state.bicycle_model.v, 46 | ) 47 | gt_states.append(gt_state) 48 | 49 | gt_actions = list() 50 | for action in env_data.trajectory_tracking.episode.dynamics_model.actions: 51 | gt_action = ( 52 | action.bicycle_model.a, 53 | action.bicycle_model.s, 54 | ) 55 | gt_actions.append(gt_action) 56 | 57 | gt_observations = list() 58 | for observation in env_data.trajectory_tracking.episode.dynamics_model.observations: 59 | gt_observation = tuple(observation.bicycle_model.feature) 60 | gt_observations.append(gt_observation) 61 | 62 | state_diffs = np.array(gt_states, dtype=np.float32) - np.array(rt_states, dtype=np.float32) 63 | print(f'state max diff: {state_diffs.max()}') 64 | 65 | action_diffs = np.array(gt_actions, dtype=np.float32) - np.array(rt_actions, dtype=np.float32) 66 | print(f'action max diff: {action_diffs.max()}') 67 | 68 | observation_diffs = np.array(gt_observations, dtype=np.float32) - np.array(rt_observations, dtype=np.float32) 69 | print(f'observation max diff: {observation_diffs.max()}') 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/dynamics_models/base_dynamics_model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "drltt-sdk/common/common.h" 5 | #include "drltt_proto/dynamics_model/action.pb.h" 6 | #include "drltt_proto/dynamics_model/hyper_parameter.pb.h" 7 | #include "drltt_proto/dynamics_model/state.pb.h" 8 | 9 | namespace drltt { 10 | 11 | class BaseDynamicsModel { 12 | public: 13 | BaseDynamicsModel() = default; 14 | /** 15 | * @brief Construct a new Base Dynamics Model object 16 | * 17 | * @param hyper_parameter Hyper-parameter. 
18 | */ 19 | BaseDynamicsModel(const drltt_proto::HyperParameter& hyper_parameter); 20 | /** 21 | * @brief Construct a new Base Dynamics Model object 22 | * 23 | * @param hyper_parameter Hyper-parameter. 24 | * @param state Initial state. 25 | */ 26 | BaseDynamicsModel(const drltt_proto::HyperParameter& hyper_parameter, 27 | const drltt_proto::State& state); 28 | /** 29 | * @brief Reset the environment. 30 | * 31 | * @param state Initial state. 32 | */ 33 | void Reset(const drltt_proto::State& state); 34 | /** 35 | * @brief Perform a forward step with the given action, 36 | * 37 | * @param action Action. 38 | * @param delta_t Step interval in [s]. 39 | */ 40 | virtual void Step(const drltt_proto::Action& action, float delta_t) {} 41 | /** 42 | * @brief Get the state object. 43 | * 44 | * @return drltt_proto::State State proto. 45 | */ 46 | drltt_proto::State get_state() const; 47 | /** 48 | * @brief Set the state object 49 | * 50 | * @param state State used for setting. 51 | * @return true Setting succeeded. 52 | * @return false Setting failed. 53 | */ 54 | bool set_state(const drltt_proto::State& state); 55 | // TODO: get_body_state() const 56 | /** 57 | * @brief Get the hyper parameter object. 58 | * 59 | * @return drltt_proto::HyperParameter Hyper-parameter proto. 60 | */ 61 | drltt_proto::HyperParameter get_hyper_parameter() const; 62 | /** 63 | * @brief Set the hyper parameter object 64 | * 65 | * @param hyper_parameter Hyper-parameter used for setting. 66 | * @return true Setting succeeded. 67 | * @return false Setting failed. 68 | */ 69 | bool set_hyper_parameter(const drltt_proto::HyperParameter& hyper_parameter); 70 | 71 | ~BaseDynamicsModel() = default; 72 | 73 | // avoid copy by using RVO and std::move 74 | /** 75 | * @brief Get the state observation object. 76 | * 77 | * @param observation The returned observation. 
78 | * @return true 79 | * @return false 80 | */ 81 | virtual bool get_state_observation(std::vector* observation) const { 82 | return false; 83 | } 84 | virtual bool get_dynamics_model_observation( 85 | std::vector* observation) const { 86 | return false; 87 | } 88 | 89 | drltt_proto::DebugInfo get_debug_info() { return _debug_info; } 90 | 91 | protected: 92 | virtual void parse_hyper_parameter() {} 93 | drltt_proto::State _state; 94 | drltt_proto::DebugInfo _debug_info; // TODO: move to trajectory tracking env 95 | drltt_proto::HyperParameter _hyper_parameter; 96 | }; 97 | 98 | } // namespace drltt 99 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | # Configuration file for the Sphinx documentation builder. 3 | # 4 | # For the full list of built-in configuration values, see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Project information ----------------------------------------------------- 8 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 9 | 10 | import os 11 | import sys 12 | import subprocess 13 | import logging 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | conf_py_file = os.path.dirname(__file__) 17 | project_dir = os.path.realpath(f'{conf_py_file}/../../') 18 | proto_dir = f'{project_dir}/common/proto/proto_gen_py' 19 | waymax_viz_dir = f'{project_dir}/submodules/waymax-visualization' # TODO: consider move paths to a config files like YAML 20 | 21 | # Reference: https://stackoverflow.com/questions/10324393/sphinx-build-fail-autodoc-cant-import-find-module 22 | sys.path.append(project_dir) 23 | sys.path.append(proto_dir) 24 | sys.path.append(waymax_viz_dir) 25 | logging.info('sys.path:') 26 | for p in sys.path: 27 | logging.info(p) 28 | 29 | project = 'DRL-based Trajectory Tracking (DRLTT)' 30 | copyright 
= '2024, Yinda Xu, Lidong Yu' 31 | author = 'Yinda Xu, Lidong Yu' 32 | 33 | # -- General configuration --------------------------------------------------- 34 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 35 | 36 | extensions = [ 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.napoleon', 39 | # 'myst_parser', # substitute of `recommonmark` 40 | 'm2r2', # TODO: figure out why myst_parser does not work. 41 | 'breathe', 42 | ] 43 | 44 | source_suffix = { 45 | '.rst': 'restructuredtext', 46 | '.txt': 'markdown', 47 | '.md': 'markdown', 48 | } 49 | 50 | sdk_name = "drltt-sdk" 51 | breathe_projects = { sdk_name: "../../sdk/doxygen_output/xml/" } 52 | breathe_default_project = sdk_name 53 | 54 | templates_path = ['_templates'] 55 | exclude_patterns = [] 56 | 57 | 58 | 59 | # -- Processing scripts --------------------------------------------------- 60 | 61 | # Integrate Doxygen into Sphinx 62 | # Reference: https://leimao.github.io/blog/CPP-Documentation-Using-Sphinx/ 63 | # subprocess.call('make clean', shell=True) 64 | # logging.info('Compiling CPP documentation with Doxygen...') 65 | # subprocess.call('cd ../../sdk ; doxygen Doxyfile-cpp ', shell=True) 66 | # logging.info('CPP documentation compiled.') 67 | 68 | # protoc-gen-doc 69 | # logging.info('Compiling Protobuf documentation with Doxygen...') 70 | # home_bin_dir = '/home/docs/.local/bin/' 71 | # sys.path.append(home_bin_dir) # TOOD: use $HOME instead 72 | # append_path_prefix = (f'PATH=${{PATH}}:{home_bin_dir}') 73 | # subprocess.call(f'cd ../../common/proto/proto_def ; {append_path_prefix} bash compile_proto.sh ', shell=True) 74 | # logging.info('Protobuf documentation compiled.') 75 | 76 | 77 | 78 | # -- Options for HTML output ------------------------------------------------- 79 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 80 | 81 | import sphinx_rtd_theme 82 | html_theme = 'sphinx_rtd_theme' 83 | html_theme_path = 
[sphinx_rtd_theme.get_html_theme_path()] 84 | html_static_path = [] 85 | -------------------------------------------------------------------------------- /drltt/simulator/observation/observation_manager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from gym.spaces import Space 4 | 5 | from drltt.simulator.dynamics_models import DynamicsModelManager 6 | from drltt.simulator.trajectory.reference_line import ReferenceLineManager 7 | 8 | from drltt_proto.dynamics_model.basics_pb2 import BodyState 9 | 10 | 11 | class ObservationManager: 12 | """Manager for observation. 13 | 14 | Attributes: 15 | reference_line_manager: handler of underlying reference line manager 16 | dynamics_model_manager: handler of underlying Dynamics model manager 17 | """ 18 | 19 | reference_line_manager: ReferenceLineManager 20 | dynamics_model_manager: DynamicsModelManager 21 | 22 | def __init__( 23 | self, 24 | reference_line_manager: ReferenceLineManager, 25 | dynamics_model_manager: DynamicsModelManager, 26 | ): 27 | """ 28 | Args: 29 | reference_line_manager: Underlying reference line manager 30 | dynamics_model_manager: underlying Dynamics model manager 31 | """ 32 | self.reference_line_manager = reference_line_manager 33 | self.dynamics_model_manager = dynamics_model_manager 34 | 35 | def get_observation_space(self) -> Space: 36 | """Return observation space, usually consisting of multiple sub-observation spaces. 37 | 38 | Returns: 39 | Space: observation space. 
40 | """ 41 | reference_line_observation_space = self.reference_line_manager.get_observation_space() 42 | state_observation_space = self.dynamics_model_manager.get_state_observation_space() 43 | synamics_model_observation_space = self.dynamics_model_manager.get_dynamics_model_observation_space() 44 | 45 | observation_space_tuple = gym.spaces.Tuple( 46 | ( 47 | reference_line_observation_space, 48 | state_observation_space, 49 | synamics_model_observation_space, 50 | ) 51 | ) 52 | observation_space = gym.spaces.flatten_space(observation_space_tuple) 53 | 54 | return observation_space 55 | 56 | def get_observation(self, episode_data, body_state: BodyState) -> np.ndarray: 57 | """Return the vectorized observation, which is usually ego-centric. 58 | 59 | Args: 60 | episode_data: Episode data. 61 | body_state: (Serialized) body state of agent/dynamics model. 62 | 63 | Returns: 64 | np.ndarray: Vectorized observation. 65 | """ 66 | reference_line_observation = self.reference_line_manager.get_observation_by_index( 67 | episode_data=episode_data, body_state=body_state 68 | ) 69 | state_observation = self.dynamics_model_manager.get_sampled_dynamics_model().get_state_observation() 70 | dynamics_model_observation = ( 71 | self.dynamics_model_manager.get_sampled_dynamics_model().get_dynamics_model_observation() 72 | ) 73 | 74 | observation = np.concatenate( 75 | ( 76 | reference_line_observation, 77 | state_observation, 78 | dynamics_model_observation, 79 | ), 80 | axis=0, 81 | ) 82 | 83 | return observation 84 | -------------------------------------------------------------------------------- /docker/Dockerfile.cicd: -------------------------------------------------------------------------------- 1 | # Dockerfile for lauching CI/CD environment for DRLTT 2 | 3 | FROM ubuntu:22.04 4 | WORKDIR / 5 | 6 | ARG USER_WORKDIR=/work_dir 7 | ARG USER_BIN=/usr_bin 8 | ARG USER_BASHRC=/work_dir/.bashrc 9 | ARG CURL_ARGS=" -OL --retry-all-errors " 10 | ARG USR_LOCAL_BIN_DIR=/usr/local/bin/ 
11 | ARG USR_LOCAL_LIB_DIR=/usr/local/lib/ 12 | ARG APT_SOURCE_LIST=/etc/apt/sources.list 13 | 14 | ENV DEBIAN_FRONTEND noninteractive 15 | ENV SHELL=/bin/bash 16 | ENV LD_LIBRARY_PATH=${USR_LOCAL_BIN_DIR}:${LD_LIBRARY_PATH} 17 | ENV LD_LIBRARY_PATH=${USR_LOCAL_LIB_DIR}:${LD_LIBRARY_PATH} 18 | 19 | SHELL [ "/bin/bash", "-c" ] 20 | 21 | RUN \ 22 | echo "Activating APT links to source code for in-container building" && \ 23 | printf '\n' >> ${APT_SOURCE_LIST} && \ 24 | printf "deb-src http://archive.ubuntu.com/ubuntu/ jammy main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 25 | printf '\n' >> ${APT_SOURCE_LIST} && \ 26 | printf "deb-src http://archive.ubuntu.com/ubuntu/ jammy-updates main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 27 | printf '\n' >> ${APT_SOURCE_LIST} && \ 28 | printf "deb-src http://archive.ubuntu.com/ubuntu/ jammy-security main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 29 | printf '\n' >> ${APT_SOURCE_LIST} && \ 30 | printf "deb-src http://archive.ubuntu.com/ubuntu/ jammy-backports main restricted universe multiverse" >> ${APT_SOURCE_LIST} && \ 31 | printf '\n' >> ${APT_SOURCE_LIST} && \ 32 | apt update --fix-missing && \ 33 | echo "Adding ppa:deadsnakes/ppa, for installing python3.x" && \ 34 | DEBIAN_FRONTEND=noninteractive apt install -y software-properties-common && \ 35 | echo "Add PPA for Python" && \ 36 | add-apt-repository ppa:deadsnakes/ppa && \ 37 | echo "Updating APT list" && \ 38 | apt update --fix-missing && \ 39 | echo "APT list configuration done." 
40 | 41 | RUN \ 42 | echo "Installing network tools" && \ 43 | apt install -y \ 44 | ca-certificates \ 45 | libssl-dev \ 46 | curl wget \ 47 | gnutls-bin \ 48 | && \ 49 | echo "Installing compiling tools" && \ 50 | apt install -y \ 51 | cmake \ 52 | build-essential \ 53 | cmake \ 54 | python3 \ 55 | && \ 56 | echo "Installing system utilities" && \ 57 | apt install -y \ 58 | unzip \ 59 | vim \ 60 | && \ 61 | echo "Installing dependency libraries" && \ 62 | apt install -y \ 63 | libblas-dev \ 64 | liblapack-dev \ 65 | libeigen3-dev \ 66 | python3-opencv \ 67 | libopencv-dev \ 68 | libgtest-dev \ 69 | libboost-all-dev \ 70 | libabsl-dev \ 71 | && \ 72 | mkdir $USER_WORKDIR && chmod -R 777 $USER_WORKDIR && \ 73 | mkdir $USER_BIN && chmod -R 777 $USER_BIN && \ 74 | echo "Configuring user's bashrc" && \ 75 | touch $USER_BASHRC && \ 76 | echo "# Add LD_LIBRARY_PATH installed by user." >> $USER_BASHRC && \ 77 | echo 'export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH' >> $USER_BASHRC && \ 78 | echo "source $USER_BASHRC" >> /root/.bashrc && \ 79 | echo "apt package installation and bashrc modification done." 80 | 81 | RUN \ 82 | echo "Installing gitlab-runner..." && \ 83 | ( \ 84 | curl -L --output /usr/local/bin/gitlab-runner https://gitlab-runner-downloads.gitlab.cn/latest/binaries/gitlab-runner-linux-amd64; \ 85 | chmod +x /usr/local/bin/gitlab-runner; \ 86 | useradd --comment 'GitLab Runner' --create-home gitlab-runner --shell /bin/bash; \ 87 | gitlab-runner install --user=gitlab-runner --working-directory=/home/gitlab-runner; \ 88 | # sudo gitlab-runner start; \ 89 | ) && \ 90 | echo "gitlab-runner installed." 
91 | 92 | -------------------------------------------------------------------------------- /drltt/common/geometry.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | 5 | 6 | def normalize_angle(angle: Union[float, np.ndarray]) -> Union[float, np.ndarray]: 7 | """Normalize angle to [-pi, pi), compatible with Numpy vectorization. 8 | 9 | Args: 10 | angle: Angle to be normalized. 11 | 12 | Returns: 13 | Normalized angle. 14 | 15 | """ 16 | return (angle + np.pi) % (2 * np.pi) - np.pi 17 | 18 | 19 | # TODO: use screw theory to refactor this part 20 | def transform_points(points: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray: 21 | """Transform 2-D points with a homogeneous transform matrix. 22 | 23 | Args: 24 | points: Points to be transformed, shape=(N, 2). 25 | transform_matrix: Homogeneous transform matrix which lies in group SE(2), shape=(3, 3). 26 | 27 | Returns: 28 | np.ndarray: Transformed points, shape=(N, 2). 29 | """ 30 | points = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1) 31 | points = transform_matrix @ points.T 32 | points = points.T 33 | points = points[:, :2] 34 | 35 | return points 36 | 37 | 38 | def transform_between_local_and_world(points: np.ndarray, body_state: np.ndarray, trans_dir: str) -> np.ndarray: 39 | """Transform points between the local (body) frame and the world frame. 40 | 41 | Args: 42 | points (np.ndarray): Points to be transformed, shape=(N, 2). 43 | body_state (np.ndarray): Body state, format=(x, y, r, ...); only the first three entries are used. 44 | trans_dir (str): Transform direction. 45 | 'world_to_local': from world frame to body frame. 46 | 'local_to_world': from body frame to world frame. 47 | 48 | Returns: 49 | np.ndarray: Transformed points, shape=(N, 2). 
50 | """ 51 | points, body_state = points.copy(), body_state.copy() 52 | x, y, r = body_state[:3] 53 | 54 | translation_matrix = np.array( 55 | ( 56 | (1.0, 0.0, -x), 57 | (0.0, 1.0, -y), 58 | (0.0, 0.0, 1.0), 59 | ) 60 | ) 61 | rotation_matrix = np.array( 62 | ( 63 | (np.cos(r), np.sin(r), 0.0), 64 | (-np.sin(r), np.cos(r), 0.0), 65 | (0.0, 0.0, 1.0), 66 | ) 67 | ) 68 | if trans_dir == 'world_to_local': 69 | transform_matrix = rotation_matrix @ translation_matrix 70 | elif trans_dir == 'local_to_world': 71 | transform_matrix = np.linalg.inv(rotation_matrix @ translation_matrix) 72 | else: 73 | raise ValueError(f'Unknown transform direction: {trans_dir}') 74 | 75 | transformed_points = transform_points(points, transform_matrix) 76 | 77 | return transformed_points 78 | 79 | 80 | def transform_to_local_from_world(points: np.ndarray, body_state: np.ndarray) -> np.ndarray: 81 | """Transform points from the world frame to the local (body) frame. 82 | 83 | Args: 84 | points (np.ndarray): Points in world frame, shape=(N, 2). 85 | body_state (np.ndarray): Body state, format=(x, y, r, ...); only the first three entries are used. 86 | 87 | Returns: 88 | np.ndarray: Points in body frame, shape=(N, 2). 89 | """ 90 | return transform_between_local_and_world( 91 | points, 92 | body_state, 93 | trans_dir='world_to_local', 94 | ) 95 | 96 | 97 | def transform_to_world_from_local(points: np.ndarray, body_state: np.ndarray) -> np.ndarray: 98 | """Transform points from the local (body) frame to the world frame. 99 | 100 | Args: 101 | points (np.ndarray): Points in body frame, shape=(N, 2). 102 | body_state (np.ndarray): Body state, format=(x, y, r, ...); only the first three entries are used. 103 | 104 | Returns: 105 | np.ndarray: Points in world frame, shape=(N, 2). 
106 | """ 107 | return transform_between_local_and_world( 108 | points, 109 | body_state, 110 | trans_dir='local_to_world', 111 | ) 112 | -------------------------------------------------------------------------------- /drltt/common/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from copy import deepcopy 3 | import logging 4 | 5 | from frozendict import frozendict 6 | 7 | from .future import override, Self 8 | 9 | 10 | def _register_generic( 11 | module_dict: Dict, 12 | module_name: str, 13 | module: object, 14 | ): 15 | assert module_name not in module_dict, logging.info(module_name, module_dict, 'defined in several script files') 16 | module_dict[module_name] = module 17 | 18 | 19 | class Registry(dict): 20 | """A helper class for registering and managing modules 21 | 22 | >>> # build registry 23 | >>> REGISTRY_NAME = Registry(name='REGISTRY_NAME') 24 | 25 | >>> # register function 26 | >>> @REGISTRY_NAME.register 27 | >>> def foo(): 28 | >>> pass 29 | 30 | >>> # register class 31 | >>> @REGISTRY_NAME.register 32 | >>> class bar(): 33 | >>> pass 34 | 35 | >>> # fetch for class construction within builder 36 | >>> class_instance = REGISTRY_NAME[module_name](*args, **kwargs) 37 | 38 | >>> # fetch for function call 39 | >>> result = REGISTRY_NAME[module_name](*args, **kwargs) 40 | """ 41 | 42 | def __init__(self, name: str = 'Registry', **kwargs): 43 | """ 44 | Args: 45 | name: name of the registry 46 | """ 47 | self.name = name 48 | super(Registry, self).__init__(**kwargs) 49 | 50 | def register(self, module: Any) -> Any: 51 | """Register module (class/function/etc.) into this registry. 52 | 53 | Typically used as decorator, thus return the input module itself. 54 | 55 | Args: 56 | module: Python object that needs to be registered. 57 | 58 | Returns: 59 | Any: The registered module. 
60 | """ 61 | name = module.__name__ 62 | _register_generic(self, name, module) 63 | 64 | return module 65 | 66 | @override 67 | def __getitem__(self, name: str) -> Any: 68 | if name not in self: 69 | raise ValueError(f'Object {name} does not exist') 70 | 71 | return super().__getitem__(name) 72 | 73 | def register_from_python_module(self, module: Any) -> Self: 74 | """Register all members from a Python module. 75 | 76 | Args: 77 | module: Python module to be processed. 78 | 79 | Returns: 80 | Self: The yielded registry. 81 | """ 82 | self.name = module.__name__ 83 | for k, v in module.__dict__.items(): 84 | if k.startswith('__'): 85 | continue 86 | self[k] = v 87 | 88 | return self 89 | 90 | 91 | def build_object_within_registry_from_config( 92 | registry: Registry, 93 | config: Dict[str, Any] = frozendict(), # TODO: replace with Python built-in static mapping object 94 | **kwargs, 95 | ) -> Any: 96 | """Builder function to build object within a registry from config. 97 | 98 | Config should be in form of keyword arguments (dict-like). 99 | Support adding additional config items through kwargs. 100 | 101 | NOTE: both config and kwargs are deep-copied before being merged. 102 | 103 | Args: 104 | registry: registry to retrieve class to be constructed. 105 | config: config mapping that provides the class name and the corresponding arguments, 106 | which should be arranged in the following format: 107 | 108 | .. code-block:: YAML 109 | 110 | type: TYPENAME 111 | arg1: value1 112 | arg2: value2 113 | ... 114 | 115 | **kwargs: key-word arguments to be passed to the retrieved class function. 116 | 117 | Returns: 118 | Any: The built object. 
119 | """ 120 | config = deepcopy(config) 121 | config = dict(**config) 122 | kwargs = deepcopy(kwargs) 123 | config.update(kwargs) 124 | class_name = config.pop('type') 125 | obj = registry[class_name](**config) 126 | 127 | return obj 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/github/gitignore/blob/main/Python.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # working directory 165 | .vscode/ 166 | work_dir/ 167 | work_dirs/ 168 | *.json 169 | *.txt 170 | !CMakeLists.txt 171 | *.jpg 172 | *.png 173 | *.mp4 174 | *.zip 175 | 176 | # protobuf generated code 177 | **/*proto_gen*/**/* 178 | # protobuf generated documentation 179 | **/*proto_doc_gen*/**/* 180 | 181 | # doxygen output 182 | sdk/doxygen_output/ 183 | 184 | # test log 185 | /test-log* 186 | -------------------------------------------------------------------------------- /drltt/simulator/trajectory_tracker/trajectory_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from drltt.simulator.environments.trajectory_tracking_env import TrajectoryTrackingEnv 7 | from drltt.simulator.rl_learning.sb3_utils import roll_out_one_episode 8 | from drltt.simulator.trajectory.reference_line import ReferenceLineManager 9 | 10 | from drltt_proto.environment.environment_pb2 import Environment 11 | from drltt_proto.trajectory.trajectory_pb2 import ReferenceLine 12 | 13 | 14 | class TrajectoryTracker: 15 | """Trajectory tracking SDK class""" 16 | 17 | def __init__(self, checkpoint_dir: str): 18 | """ 19 | Args: 20 | checkpoint_dir: Path to the checkpoint's directory. 21 | """ 22 | self.policy = torch.jit.load(f'{checkpoint_dir}/traced_policy.pt') 23 | env_info = Environment() 24 | with open(f'{checkpoint_dir}/env_data.bin', 'rb') as f: 25 | env_info.ParseFromString(f.read()) 26 | self.env = TrajectoryTrackingEnv(env_info=env_info) 27 | 28 | def track_reference_line( 29 | self, 30 | init_state: Union[Tuple[float, float, float, float], None] = None, 31 | dynamics_model_name: Union[str, None] = 'ShortVehicle', 32 | reference_line: Union[List[Tuple[float, float]], None] = None, 33 | ) -> Tuple[List[Tuple[float, float, float, float]], List[Tuple[float, float]]]: 34 | """Track a reference line with the underlying policy model. 
35 | 36 | Nomenclature: 37 | 38 | - x: X-coordinate in [m] within (-inf, +inf) 39 | - y: Y-coordinate in [m] within (-inf, +inf) 40 | - r: heading in [rad] within [-pi, pi), following convention of math lib like `std::atan2` 41 | - v: scalar speed in [m/s] within [0, +inf) 42 | 43 | TODO: refer this part to definition of prototype and so for avoid redundant documentation. 44 | 45 | Args: 46 | init_state: Initial state, format=<x, y, r, v>. 47 | dynamics_model_name: Name of the dynamics model. 48 | reference_line: Reference line, format=List[<x, y>]. 49 | 50 | Return: 51 | Tuple[states, action]: The tracked trajectory. All elements have the same length 52 | that is equal to the length of reference line. 53 | 54 | - The first element is a sequence of states. 55 | - The second element is a sequence of actions. 56 | """ 57 | if init_state is not None: 58 | init_state = np.array(init_state) 59 | if reference_line is not None: 60 | reference_line: ReferenceLine = ReferenceLineManager.np_array_to_reference_line(np.array(reference_line)) 61 | 62 | states, actions, observations = roll_out_one_episode( 63 | self.env, 64 | self.policy_func, 65 | init_state=init_state, 66 | dynamics_model_name=dynamics_model_name, 67 | reference_line=reference_line, 68 | ) 69 | return ( 70 | [tuple(state) for state in states], 71 | [tuple(action) for action in actions], 72 | ) 73 | 74 | def policy_func(self, observation: np.ndarray) -> np.ndarray: 75 | """Wrapper of underlying JIT policy in form of func(observation) -> action. 76 | 77 | Including preprocessing and post processing of the tensor. 78 | 79 | Args: 80 | observation: Observation. 81 | 82 | Returns: 83 | np.ndarray: Action. 84 | """ 85 | observation_tensor = torch.from_numpy(observation).reshape(1, -1) 86 | action_tensor = self.policy(observation_tensor) 87 | action = action_tensor.reshape(-1).numpy() 88 | 89 | return action 90 | 91 | def get_step_interval(self) -> float: 92 | """Get the step interval in time. 
93 | 94 | Returns: 95 | float: Time step interval. 96 | """ 97 | return self.env.env_info.trajectory_tracking.hyper_parameter.step_interval 98 | 99 | def get_dynamics_model_info(self) -> str: 100 | """Get the string for the information of dynamics models. 101 | 102 | Returns: 103 | str: Pretty string containing information of dynamics models. 104 | """ 105 | return self.env.get_dynamics_model_info() 106 | 107 | def get_reference_line(self) -> List[Tuple[float, float]]: 108 | """Get the current reference line. 109 | 110 | Returns: 111 | List[Tuple[float, float]]: Reference line, shape=[]. 112 | """ 113 | arr = ReferenceLineManager.reference_line_to_np_array(self.env.get_reference_line()) 114 | 115 | return [tuple(waypoint) for waypoint in arr] 116 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/dynamics_models/bicycle_model.cpp: -------------------------------------------------------------------------------- 1 | #include "bicycle_model.h" 2 | 3 | namespace drltt { 4 | 5 | void BicycleModel::Step(const drltt_proto::Action& action, float delta_t) { 6 | 7 | drltt_proto::State derivative = 8 | _compute_derivative(_state, action, _hyper_parameter); 9 | _state.mutable_bicycle_model()->CopyFrom( 10 | _state.bicycle_model() + 11 | derivative.bicycle_model() * 12 | delta_t); // TODO move to canonical implementation `+=` 13 | } 14 | 15 | bool BicycleModel::get_state_observation( 16 | std::vector* observation) const { 17 | observation->push_back(_state.bicycle_model().v()); 18 | observation->push_back(GetMaxSteeringAngle()); 19 | 20 | return true; 21 | } 22 | 23 | bool BicycleModel::get_dynamics_model_observation( 24 | std::vector* observation) const { 25 | observation->push_back(_hyper_parameter.bicycle_model().front_overhang()); 26 | observation->push_back(_hyper_parameter.bicycle_model().wheelbase()); 27 | observation->push_back(_hyper_parameter.bicycle_model().rear_overhang()); 28 | 
observation->push_back(_hyper_parameter.bicycle_model().width()); 29 | observation->push_back(_hyper_parameter.bicycle_model().length()); 30 | observation->push_back(_hyper_parameter.bicycle_model().max_lat_acc()); 31 | 32 | return true; 33 | } 34 | 35 | void BicycleModel::parse_hyper_parameter() { 36 | float front_overhang = _hyper_parameter.bicycle_model().front_overhang(); 37 | float wheelbase = _hyper_parameter.bicycle_model().wheelbase(); 38 | float rear_overhang = _hyper_parameter.bicycle_model().rear_overhang(); 39 | float length = front_overhang + wheelbase + rear_overhang; 40 | _hyper_parameter.mutable_bicycle_model()->set_length(length); 41 | _hyper_parameter.mutable_bicycle_model()->set_frontwheel_to_cog( 42 | wheelbase + rear_overhang - length / 2); 43 | _hyper_parameter.mutable_bicycle_model()->set_rearwheel_to_cog( 44 | wheelbase + front_overhang - length / 2); 45 | } 46 | 47 | drltt_proto::State BicycleModel::_compute_derivative( 48 | const drltt_proto::State& state, const drltt_proto::Action& action, 49 | const drltt_proto::HyperParameter& hyper_parameter) { 50 | float x = state.bicycle_model().body_state().x(); 51 | float y = state.bicycle_model().body_state().y(); 52 | float r = state.bicycle_model().body_state().r(); 53 | float v = state.bicycle_model().v(); 54 | float a = action.bicycle_model().a(); 55 | float s = action.bicycle_model().s(); 56 | 57 | // clip s 58 | const float max_steer = GetMaxSteeringAngle(); 59 | s = clip(s, -max_steer, +max_steer); 60 | 61 | float omega; 62 | float rotation_radius_inv; 63 | ComputeRotationRelatedVariables(s, &omega, &rotation_radius_inv); 64 | 65 | float dx_dt = v * std::cos(r + omega); 66 | float dy_dt = v * std::sin(r + omega); 67 | float dr_dt = v * rotation_radius_inv; 68 | float dv_dt = a; 69 | 70 | drltt_proto::State derivative; 71 | derivative.mutable_bicycle_model()->mutable_body_state()->set_x(dx_dt); 72 | derivative.mutable_bicycle_model()->mutable_body_state()->set_y(dy_dt); 73 | 
derivative.mutable_bicycle_model()->mutable_body_state()->set_r(dr_dt); 74 | derivative.mutable_bicycle_model()->set_v(dv_dt); 75 | 76 | return derivative; 77 | } 78 | 79 | float BicycleModel::GetCogRelativePositionBetweenAxles() const { 80 | const float frontwheel_to_cog = 81 | _hyper_parameter.bicycle_model().frontwheel_to_cog(); 82 | const float rearwheel_to_cog = 83 | _hyper_parameter.bicycle_model().rearwheel_to_cog(); 84 | 85 | return rearwheel_to_cog / (rearwheel_to_cog + frontwheel_to_cog); 86 | } 87 | 88 | bool BicycleModel::ComputeRotationRelatedVariables( 89 | float steering_angle, float* omega, float* rotation_radius_inv) const { 90 | const float cog_relative_position_between_axles = 91 | GetCogRelativePositionBetweenAxles(); 92 | *omega = 93 | std::atan(cog_relative_position_between_axles * std::tan(steering_angle)); 94 | *omega = normalize_angle(*omega); 95 | // Return the inverse of the radius to ensure numerical stability. 96 | *rotation_radius_inv = 97 | std::sin(*omega) / _hyper_parameter.bicycle_model().rearwheel_to_cog(); 98 | 99 | return true; 100 | } 101 | 102 | float BicycleModel::GetMaxSteeringAngle() const { 103 | const float max_lat_acc = _hyper_parameter.bicycle_model().max_lat_acc(); 104 | const float v = _state.bicycle_model().v(); 105 | const float rearwheel_to_cog = 106 | _hyper_parameter.bicycle_model().rearwheel_to_cog(); 107 | 108 | float max_s; 109 | 110 | const float asin_arg = 111 | rearwheel_to_cog * max_lat_acc / std::max(v * v, EPSILON); 112 | 113 | if (asin_arg <= 1.0) { 114 | max_s = std::atan(std::tan(std::asin(asin_arg)) / 115 | GetCogRelativePositionBetweenAxles()); 116 | } else { 117 | max_s = M_PI; 118 | } 119 | 120 | return max_s; 121 | } 122 | 123 | } // namespace drltt 124 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/environments/trajectory_tracking.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file 
trajectory_tracking.h 3 | * @brief Trajectory tracking environment. 4 | * 5 | */ 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "drltt-sdk/common/common.h" 13 | #include "drltt-sdk/dynamics_models/bicycle_model.h" 14 | #include "drltt-sdk/inference/policy_inference.h" 15 | #include "drltt-sdk/managers/observation_manager.h" 16 | #include "drltt_proto/dynamics_model/state.pb.h" 17 | #include "drltt_proto/environment/environment.pb.h" 18 | #include "drltt_proto/trajectory/trajectory.pb.h" 19 | 20 | typedef std::tuple REFERENCE_WAYPOINT; 21 | typedef std::vector REFERENCE_LINE; 22 | typedef std::tuple STATE; 23 | typedef std::tuple ACTION; 24 | typedef std::vector OBSERVATION; 25 | typedef std::vector DEBUG_DATA; 26 | typedef std::tuple, std::vector, 27 | std::vector, std::vector> 28 | TRAJECTORY; 29 | 30 | namespace drltt { 31 | 32 | // TODO: use factory to make it configurable 33 | // TODO: verify if proto can be passed to python through pybind 34 | /** 35 | * @brief Trajectory tracking environment. 36 | * 37 | * Class for managing dynamics models/reference lines/policy model/etc. and 38 | * performing rollouts. 39 | * 40 | */ 41 | class TrajectoryTracking { 42 | public: 43 | TrajectoryTracking() = default; 44 | ~TrajectoryTracking() {} 45 | /** 46 | * @brief Load the underlying policy. 47 | * 48 | * @param policy_path Path to the policy. 49 | * @return true Loading succeeded. 50 | * @return false Loading failed. 51 | */ 52 | bool LoadPolicy(const std::string& policy_path); 53 | /** 54 | * @brief Load the environment data. 55 | * 56 | * @param env_data_path Path to the protobuf binary file of environment data. 57 | * @return true Loading succeeded. 58 | * @return false Loading failed. 59 | */ 60 | bool LoadEnvData(const std::string& env_data_path); 61 | /** 62 | * @brief Set the dynamics model hyper parameter with index. 63 | * TODO: provide function to set dynamics model by name. 
64 | * 65 | * @param index The index of the dynamics model stored in the `env_data`. 66 | * @return true Setting succeeded. 67 | * @return false Setting failed. 68 | */ 69 | bool set_dynamics_model_hyper_parameter(int index); 70 | /** 71 | * @brief Set the reference line object. 72 | * 73 | * @param reference_line Reference line. 74 | * @return true Setting succeeded. 75 | * @return false Setting failed. 76 | */ 77 | bool set_reference_line( 78 | const std::vector& reference_line); 79 | /** 80 | * @brief Set the reference line object. 81 | * 82 | * @param reference_line Reference line. 83 | * @return true Setting succeeded. 84 | * @return false Setting failed. 85 | */ 86 | bool set_reference_line(const drltt_proto::ReferenceLine& reference_line); 87 | /** 88 | * @brief Estimate the initial state. 89 | * 90 | * @param reference_line Reference line. 91 | * @param state The returned initial state. 92 | * @param delta_t Step interval. 93 | * @return true Setting succeeded. 94 | * @return false Setting failed. 95 | */ 96 | static bool EstimateInitialState( 97 | const drltt_proto::ReferenceLine& reference_line, 98 | drltt_proto::State& state, float delta_t); 99 | /** 100 | * @brief Set the dynamics model initial state object. 101 | * 102 | * @param state Initial state to be set to dynamics model. 103 | * @return true Setting succeeded. 104 | * @return false Setting failed. 105 | */ 106 | bool set_dynamics_model_initial_state(STATE state); 107 | /** 108 | * @brief Set the dynamics model initial state object. 109 | * 110 | * @param state Initial state to be set to dynamics model. 111 | * @return true Setting succeeded. 112 | * @return false Setting failed. 113 | */ 114 | bool set_dynamics_model_initial_state(drltt_proto::State state); 115 | /** 116 | * @brief Roll out a trajectory based on underlying policy model and 117 | * environment. 118 | * 119 | * @return true Roll-out succeeded. 120 | * @return false Roll-out failed. 
121 | */ 122 | bool RollOut(); 123 | /** 124 | * @brief Get the tracked trajectory object 125 | * 126 | * @return TRAJECTORY Tracked trajectory, format=,, 127 | * vector, vector> 128 | */ 129 | TRAJECTORY get_tracked_trajectory(); 130 | 131 | private: 132 | TorchJITModulePolicy _policy_model; 133 | drltt_proto::ReferenceLine _reference_line; 134 | drltt_proto::Environment _env_data; 135 | BicycleModel _dynamics_model; 136 | ObservationManager _observation_manager; 137 | std::vector _states; 138 | std::vector _actions; 139 | std::vector _observations; 140 | std::vector _debug_infos; 141 | }; 142 | 143 | } // namespace drltt 144 | -------------------------------------------------------------------------------- /tools/main.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import argparse 3 | import os 4 | import sys 5 | from copy import deepcopy 6 | import logging 7 | import shutil 8 | 9 | import gym 10 | from gym import Env 11 | from stable_baselines3.common.base_class import BaseAlgorithm 12 | 13 | from drltt.common import build_object_within_registry_from_config 14 | from drltt.common.io import load_and_override_configs, override_config, save_config_to_yaml 15 | from drltt.simulator.rl_learning.sb3_learner import train_with_sb3, eval_with_sb3 16 | from drltt.simulator.rl_learning.sb3_export import export_sb3_jit_module 17 | from drltt.simulator.environments import ENVIRONMENTS, ExtendedGymEnv 18 | from drltt.simulator.rl_learning.sb3_learner import SB3_MODULES 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | '--config-files', 25 | metavar='N', 26 | type=str, 27 | nargs='+', 28 | help=( 29 | 'Config file(s). If multiple paths provided, the first config is base and will overridden by the rest' 30 | ' respectively.' 
31 | ), 32 | ) 33 | parser.add_argument('--checkpoint-dir', type=str) 34 | parser.add_argument('--train', action='store_true', default=False) 35 | parser.add_argument('--eval', action='store_true', default=False) 36 | parser.add_argument('--trace', action='store_true', default=False) 37 | parser.add_argument('--test-case-save-format', type=str, default='protobuf') 38 | parser.add_argument('--num-test-cases', type=int, default=1) 39 | 40 | args = parser.parse_args() 41 | 42 | return args 43 | 44 | 45 | def configure_root_logger(log_dir: str): 46 | """Configure root logger. 47 | 48 | Remove all default handlers and add customized handlers. 49 | 50 | Args: 51 | log_dir: Directory to dump log output. 52 | """ 53 | os.makedirs(log_dir, exist_ok=True) 54 | FORMAT = '%(asctime)s :: %(name)s :: %(levelname)-8s :: %(message)s' 55 | FORMATTER = logging.Formatter(fmt=FORMAT) 56 | 57 | logger = logging.root 58 | logger.handlers.clear() 59 | 60 | stream_handler = logging.StreamHandler(stream=sys.stdout) 61 | stream_handler.setFormatter(FORMATTER) 62 | stream_handler.setLevel(logging.INFO) 63 | logger.addHandler(stream_handler) 64 | 65 | file_handler = logging.FileHandler(filename=f'{log_dir}/log.txt') 66 | file_handler.setFormatter(FORMATTER) 67 | file_handler.setLevel(logging.INFO) 68 | logger.addHandler(file_handler) 69 | 70 | logging.info(f'Logging directory configured at: {log_dir}') 71 | 72 | 73 | def main(args): 74 | configure_root_logger(args.checkpoint_dir) 75 | 76 | config = load_and_override_configs(args.config_files) 77 | env_config = config['environment'] 78 | 79 | # backup config 80 | os.makedirs(args.checkpoint_dir, exist_ok=True) 81 | save_config_to_yaml(config, f'{args.checkpoint_dir}/config.yaml') 82 | for i_config, config_p in enumerate(args.config_files): 83 | config_basename = os.path.basename(config_p) 84 | config_save_path = f'{i_config:02}-{config_basename}' 85 | shutil.copy(config_p, f'{args.checkpoint_dir}/{config_save_path}') 86 | logging.info(f'Config 
file backed up at: {config_save_path}') 87 | 88 | checkpoint_file_prefix = f'{args.checkpoint_dir}/checkpoint' # without extension 89 | 90 | if args.train: 91 | environment: ExtendedGymEnv = build_object_within_registry_from_config(ENVIRONMENTS, deepcopy(env_config)) 92 | train_with_sb3( 93 | environment=environment, 94 | algorithm_config=deepcopy(config['algorithm']), 95 | learning_config=deepcopy(config['learning']), 96 | checkpoint_file_prefix=checkpoint_file_prefix, 97 | ) 98 | 99 | if args.eval: 100 | eval_config = config['evaluation'] 101 | eval_env_config = override_config(deepcopy(env_config), deepcopy(eval_config['overriden_environment'])) 102 | eval_environment: ExtendedGymEnv = build_object_within_registry_from_config( 103 | ENVIRONMENTS, deepcopy(eval_env_config) 104 | ) 105 | 106 | eval_algorithm: BaseAlgorithm = SB3_MODULES[config['algorithm']['type']].load(checkpoint_file_prefix) 107 | 108 | eval_with_sb3( 109 | eval_environment, 110 | eval_algorithm, 111 | report_dir=args.checkpoint_dir, 112 | **eval_config['eval_config'], 113 | ) 114 | 115 | if args.trace: 116 | trace_algorithm: BaseAlgorithm = SB3_MODULES[config['algorithm']['type']].load(checkpoint_file_prefix) 117 | trace_environment: ExtendedGymEnv = build_object_within_registry_from_config(ENVIRONMENTS, deepcopy(env_config)) 118 | export_sb3_jit_module( 119 | trace_algorithm, 120 | trace_environment, 121 | device='cpu', 122 | export_dir=args.checkpoint_dir, 123 | n_test_cases=args.num_test_cases, 124 | test_case_save_format=args.test_case_save_format, 125 | ) 126 | 127 | 128 | if __name__ == '__main__': 129 | args = parse_args() 130 | main(args) 131 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | ## Sphinx 4 | 5 | ### For initialization 6 | 7 | Initialize requirements of documentation generation: 8 | 9 | ```bash 10 | pip install -r 
requirements/pypi-doc.txt 11 | ``` 12 | 13 | Initialize Sphinx project: 14 | 15 | ```bash 16 | mkdir docs && cd docs 17 | sphinx-quickstart 18 | 19 | # Tips for the interactive configuration phase 20 | # Choose the option "Separate sources from build" 21 | # Reference: https://stackoverflow.com/questions/65829149/what-does-separate-source-and-build-directories-mean 22 | ``` 23 | 24 | Build HTML: 25 | 26 | ```bash 27 | cd build 28 | make html 29 | ``` 30 | 31 | ### For incremental changes 32 | 33 | As the documentation of the current project has already been set up, you only need to run `./docs/make-html.sh` instead of the previous steps: 34 | 35 | .. literalinclude:: ../../../docs/make-html.sh 36 | :language: bash 37 | 38 | 39 | ### Start the server and view the documentation pages 40 | 41 | Start the HTTP server on the remote side 42 | 43 | ```bash 44 | cd docs/build/html 45 | python -m http.server 8080 -b localhost 46 | ``` 47 | 48 | Create an SSH tunnel on the local side, which forwards connections/requests from local to remote (server) 49 | 50 | ```bash 51 | ssh -L 8080:localhost:8080 remote-server 52 | ``` 53 | 54 | ## Use RST Files within the Sphinx Documentation 55 | 56 | ### Include Markdown 57 | 58 | For including Markdown files into the Sphinx documentation, this project utilizes `m2r2` which needs to be installed through `pip`: 59 | 60 | ``` 61 | pip install m2r2 62 | ``` 63 | 64 | Then in `.rst` file, use the `.. mdinclude::` directive to include Markdown file. The path needs to be relative to the `.rst` file. 65 | 66 | ``` 67 | .. mdinclude:: ../../../README.md 68 | ``` 69 | 70 | ### Include Code snippets 71 | 72 | Use the `.. literalinclude::` directive to include a code snippet from a script/source file. 73 | 74 | For example, 75 | 76 | ``` 77 | .. literalinclude:: ../../../docs/make-html.sh 78 | :language: bash 79 | ``` 80 | 81 | results in: 82 | 83 | .. 
literalinclude:: ../../../docs/make-html.sh 84 | :language: bash 85 | 86 | ## Auto-Generation of API Documentations 87 | 88 | The feature of auto-documentation is realized by various Sphinx extensions and third-party tools. 89 | 90 | This project adopts [Google-style Python docstrings](https://google.github.io/styleguide/pyguide.html), [Example Google Style Python Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for auto-generation of Python API pages. 91 | 92 | The authors would like to thank [PyTorch](https://pytorch.org/docs/stable/index.html) for being an exemplar of documentation. 93 | 94 | Reference: 95 | 96 | * https://sphinx-rtd-theme.readthedocs.io/en/stable/demo/api.html 97 | 98 | ### Python documentation 99 | 100 | Sphinx can utilize Napoleon for auto-generation of API pages of Python code. 101 | 102 | ```python 103 | import os 104 | import sys 105 | sys.path.insert(0, os.path.abspath(path_to_python_module_dir)) 106 | sys.path.insert(0, os.path.abspath(path_to_python_module2_dir)) 107 | ... 108 | 109 | extensions = [ 110 | ... 111 | 'sphinx.ext.autodoc', # support for auto-doc generation 112 | 'sphinx.ext.napoleon', # support for numpy / google style 113 | ] 114 | ``` 115 | 116 | References: 117 | 118 | * https://sphinxcontrib-napoleon.readthedocs.io/en/latest/ 119 | * https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html 120 | * https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html 121 | * https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html 122 | 123 | 124 | 125 | ### Doxygen documentation 126 | 127 | This project utilizes Doxygen for auto-generation of API pages of C++ SDK, and `breathe` for including these pages into the Sphinx documentation. 
128 | 129 | `breathe` needs to be installed through `pip`: 130 | 131 | ```bash 132 | pip install breathe 133 | ``` 134 | 135 | Firstly, XML files are generated: 136 | 137 | ```bash 138 | cd sdk 139 | doxygen Doxyfile-cpp 140 | ``` 141 | 142 | Then, in `conf.py`, set up `breathe` to include the XML generated in the previous step. 143 | 144 | ```python 145 | extensions = [ 146 | "...", 147 | "breathe", 148 | ] 149 | sdk_name = "..." 150 | breathe_projects = { sdk_name: "../../sdk/doxygen_output/xml/" } 151 | breathe_default_project = sdk_name 152 | ``` 153 | 154 | Reference: 155 | 156 | * https://www.doxygen.nl/manual/starting.html 157 | * https://leimao.github.io/blog/CPP-Documentation-Using-Sphinx/ 158 | * https://github.com/leimao/Sphinx-CPP-TriangleLib 159 | * https://leimao.github.io/blog/CPP-Documentation-Using-Doxygen/ 160 | * https://github.com/leimao/Doxygen-CPP-TriangleLib 161 | 162 | 163 | 164 | ### Protobuf documentation 165 | 166 | This project uses `protoc-gen-doc`, an extension of the Protobuf compiler, to automatically generate API pages for Protobuf definitions. 167 | 168 | To activate this extension, pass `--doc_out` argument to `protoc`. 169 | 170 | ```bash 171 | protoc ... \ 172 | --doc_out ${doc_output_dir} \ 173 | ... 
174 | ``` 175 | 176 | References: 177 | 178 | * https://github.com/pseudomuto/protoc-gen-doc 179 | -------------------------------------------------------------------------------- /drltt/common/io.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Mapping 2 | import os 3 | import logging 4 | import random 5 | import string 6 | 7 | import yaml 8 | import numpy as np 9 | 10 | from drltt_proto.sdk.exported_policy_test_case_pb2 import TensorFP 11 | from drltt_proto.dynamics_model.basics_pb2 import DebugInfo 12 | 13 | GLOBAL_DEBUG_INFO = DebugInfo() 14 | 15 | 16 | def load_config_from_yaml(config_file: str) -> Dict: 17 | """Load config from yaml and handle exceptions. 18 | 19 | Args: 20 | config_file: Path to configuration file. 21 | """ 22 | if not os.path.exists(config_file): 23 | raise FileNotFoundError(f'{config_file} does not exist') 24 | with open(config_file, 'r') as f: 25 | config = yaml.safe_load(f) 26 | logging.info(f'Loaded config from: {config_file}') 27 | 28 | return config 29 | 30 | 31 | def save_config_to_yaml(config: Dict, config_file: str): 32 | """Save config to YAML file. 33 | 34 | Args: 35 | config: Config object to be saved. 36 | config_file: Path to YAML file. 37 | """ 38 | with open(config_file, 'w') as f: 39 | yaml.dump(config, f) 40 | 41 | logging.info(f'Saved config at: {config_file}') 42 | 43 | 44 | def convert_list_to_tuple_within_dict( 45 | dictionary: Dict, 46 | exceptions: Tuple[str] = tuple(), 47 | ) -> Dict: 48 | """Recursively cast list to tuple within a dict. 49 | 50 | Avoid modification issue (mutable / immutable). Support specified exception key. 51 | 52 | Also deal with some type issues. 53 | e.g. in case of stable-baselines3, https://github.com/DLR-RM/stable-baselines3/blob/v2.2.1/stable_baselines3/common/off_policy_algorithm.py#L157 54 | 55 | Args: 56 | dictionary: Dictionary to be processed. 57 | exceptions: Exceptional keys. 
58 | 59 | Returns: 60 | Dict: Processed dictionary. 61 | """ 62 | for k in tuple(dictionary.keys()): 63 | if k in exceptions: 64 | continue 65 | if isinstance(dictionary[k], List): 66 | dictionary[k] = tuple(dictionary[k]) 67 | elif isinstance(dictionary[k], Mapping): 68 | dictionary[k] = convert_list_to_tuple_within_dict( 69 | dictionary[k], 70 | exceptions=exceptions, 71 | ) 72 | else: 73 | continue 74 | 75 | return dictionary 76 | 77 | 78 | def generate_random_string(n: int) -> str: 79 | """Generate a string with characters randomly chosen. 80 | 81 | Args: 82 | n: Desired length of string. 83 | 84 | Returns: 85 | str: Random string. 86 | """ 87 | return ''.join(random.choices(string.ascii_uppercase + string.digits, k=n)) 88 | 89 | 90 | def override_config( 91 | base_config: Dict, 92 | update_config: Dict, 93 | allow_new_key: bool = False, 94 | ) -> Dict: 95 | """Override the value config. 96 | 97 | Args: 98 | base_config: Base config to be processed. 99 | update_config: Incremental config which contains key-value pair for overriding. 100 | allow_new_key: Whether allow creation of new key. 101 | 102 | Returns: 103 | Dict: The overridden config. 104 | """ 105 | for k in tuple(update_config.keys()): 106 | if k not in base_config: 107 | if allow_new_key: 108 | base_config[k] = update_config[k] 109 | else: 110 | continue 111 | if type(base_config[k]) != type(update_config[k]): 112 | continue 113 | if isinstance(update_config[k], Dict): 114 | base_config[k] = override_config( 115 | base_config[k], 116 | update_config[k], 117 | allow_new_key=allow_new_key, 118 | ) 119 | else: 120 | base_config[k] = update_config[k] 121 | 122 | return base_config 123 | 124 | 125 | def load_and_override_configs(config_paths: List[str]) -> Dict: 126 | """Load and override a series of config files. 127 | 128 | Args: 129 | config_paths: The base config and the overriding configs. 130 | 131 | * The first config will serve as base config. 
132 | * The rest configs will override the base config, respectively. 133 | 134 | Returns: 135 | Dict: The loaded and overridden config. 136 | """ 137 | config = load_config_from_yaml(config_paths[0]) 138 | 139 | for overriding_cfg_file in config_paths[1:]: 140 | overriding_config = load_config_from_yaml(overriding_cfg_file) 141 | override_config(config, overriding_config, allow_new_key=True) 142 | 143 | config = convert_list_to_tuple_within_dict(config) 144 | 145 | return config 146 | 147 | 148 | def convert_numpy_to_TensorFP(arr: np.ndarray) -> TensorFP: 149 | """Convert numpy array to tensor proto. 150 | 151 | Args: 152 | arr: Numpy array to be converted. 153 | 154 | Returns: 155 | TensorFP: Converted tensor proto. 156 | """ 157 | tensor_proto = TensorFP() 158 | tensor_proto.shape.extend(arr.shape) 159 | tensor_proto.data.extend(arr.reshape(-1)) 160 | 161 | return tensor_proto 162 | 163 | 164 | def convert_TensorFP_to_numpy(tensor_proto: TensorFP) -> np.ndarray: 165 | """Convert tensor proto to numpy array. 166 | 167 | Args: 168 | tensor_proto: Tensor proto. 169 | 170 | Returns: 171 | np.ndarray: Converted numpy array. 172 | """ 173 | arr = np.array(tensor_proto.data).reshape(*tensor_proto.shape) 174 | 175 | return arr 176 | -------------------------------------------------------------------------------- /drltt/simulator/dynamics_models/dynamics_model_manager.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Iterable 2 | 3 | import numpy as np 4 | from gym.spaces import Space 5 | 6 | from drltt.common import build_object_within_registry_from_config 7 | from drltt_proto.dynamics_model.hyper_parameter_pb2 import HyperParameter 8 | from . import BaseDynamicsModel, DYNAMICS_MODELS 9 | 10 | 11 | class DynamicsModelManager: 12 | """Manager for Dynamics Models, supporting random sampling/operating on sampled dynamics model/etc. 
13 | 14 | Attributes: 15 | dynamics_models: Collection of dynamics models 16 | sampled_dynamics_model: The currently sampled dynamics model. 17 | sampled_dynamics_model_index: The index of the currently sampled dynamics model. 18 | """ 19 | 20 | dynamics_models: List[BaseDynamicsModel] 21 | sampled_dynamics_model: Union[BaseDynamicsModel, None] = None 22 | sampled_dynamics_model_index: int = -1 23 | 24 | def __init__( 25 | self, hyper_parameters: Iterable[HyperParameter] = tuple(), dynamics_model_configs: Iterable = tuple() 26 | ): 27 | """ 28 | Args: 29 | hyper_parameters: Hyper-parameters of dynamics models to be managed by the manager. 30 | dynamics_model_configs: configurations of dynamics models. 31 | """ 32 | self.dynamics_models = list() 33 | self.sampled_dynamics_model = None 34 | 35 | if len(hyper_parameters) > 0: 36 | # TODO: add test for this branch 37 | for dynamics_hyper_parameter in hyper_parameters: 38 | dynamics_model_type = ( 39 | dynamics_hyper_parameter.type 40 | ) # TODO wrap this to function like build_object_within_registry_from_config 41 | dynamics_model = DYNAMICS_MODELS[dynamics_model_type](hyper_parameter=dynamics_hyper_parameter) 42 | self.dynamics_models.append(dynamics_model) 43 | else: 44 | for dynamics_model_config in dynamics_model_configs: 45 | dynamics_model = build_object_within_registry_from_config(DYNAMICS_MODELS, dynamics_model_config) 46 | self.dynamics_models.append(dynamics_model) 47 | 48 | # TODO: Check if all spaces are identical within the collection 49 | 50 | self.probabilities = (1 / len(self.dynamics_models),) * len(self.dynamics_models) 51 | self.sampled_dynamics_model = self.dynamics_models[0] 52 | 53 | # TODO: unify data structure for dynamics models 54 | self.names_to_indexes_and_dynamics_models = { 55 | dm.get_name(): (dm_idx, dm) for dm_idx, dm in enumerate(self.dynamics_models) 56 | } 57 | 58 | def get_dynamics_model_observation_space(self) -> Space: 59 | """Get the dynamics model observation space. 
def sample_dynamics_model(self) -> Tuple[int, BaseDynamicsModel]:
    """Randomly sample a dynamics model according to `self.probabilities`.

    The sampled index and model are stored in `sampled_dynamics_model_index` and
    `sampled_dynamics_model`.

    Returns:
        Tuple[int, BaseDynamicsModel]: The index and object of the sampled dynamics model.
    """
    # `np.random.choice` returns a NumPy integer; convert it so the stored/returned index is the plain
    # `int` promised by the annotations.
    sampled_index = int(np.random.choice(len(self.dynamics_models), p=self.probabilities))
    self.sampled_dynamics_model_index = sampled_index
    self.sampled_dynamics_model = self.dynamics_models[sampled_index]
    return self.sampled_dynamics_model_index, self.get_sampled_dynamics_model()
@VISUALIZATION_FUNCTIONS.register
def visualize_trajectory_tracking_episode(
    env_data: Environment,
    viz_prefix: str,
    n_steps_per_viz: int = 30,
):
    """Visualize an episode of trajectory tracking and save images.

    The episode is split into consecutive chunks of `n_steps_per_viz` steps; the last chunk also absorbs
    the remaining steps. An episode shorter than one chunk is drawn as a single image. Each chunk is
    saved to `{viz_prefix}-{chunk index}-stacked.png`.

    TODO: fix non-unified data range issue.

    Args:
        env_data (Environment): Environment data holding the trajectory tracking episode.
        viz_prefix (str): The prefix of visualization files to be saved.
        n_steps_per_viz (int, optional): Number of steps per draw. Defaults to 30.

    Raises:
        ValueError: If `n_steps_per_viz` is not positive.
    """
    assert isinstance(
        env_data, Environment
    ), '`visualize_trajectory_tracking_episode` requires env_data to be in class `Environment`'
    if n_steps_per_viz <= 0:
        raise ValueError(f'`n_steps_per_viz` must be positive, got {n_steps_per_viz}')
    episode = env_data.trajectory_tracking.episode

    traj_len = episode.tracking_length
    # Floor division alone yields zero chunks for a short episode, which would silently skip drawing.
    n_viz = max(1, traj_len // n_steps_per_viz) if traj_len > 0 else 0

    dm_states = episode.dynamics_model.states
    refline_waypoints = episode.reference_line.waypoints
    bicycle_hyper_parameter = episode.dynamics_model.hyper_parameter.bicycle_model
    step_interval = episode.hyper_parameter.step_interval

    for viz_idx in range(n_viz):
        first_step = viz_idx * n_steps_per_viz
        last_step = traj_len if viz_idx == n_viz - 1 else (viz_idx + 1) * n_steps_per_viz
        step_indexes = range(first_step, last_step)
        n_steps = len(step_indexes)

        # plot trajectory
        bicycle_states = [dm_states[idx].bicycle_model for idx in step_indexes]
        traj = Trajectory(
            x=jnp.array([bm.body_state.x for bm in bicycle_states]).reshape(1, -1),
            y=jnp.array([bm.body_state.y for bm in bicycle_states]).reshape(1, -1),
            z=jnp.array([0.0] * n_steps).reshape(1, -1),
            vel_x=jnp.array([np.cos(bm.body_state.r) * bm.v for bm in bicycle_states]).reshape(1, -1),
            vel_y=jnp.array([np.sin(bm.body_state.r) * bm.v for bm in bicycle_states]).reshape(1, -1),
            yaw=jnp.array([bm.body_state.r for bm in bicycle_states]).reshape(1, -1),
            timestamp_micros=jnp.array(
                [int(step_interval * 1e6 * idx) for idx in step_indexes],
                dtype=jnp.int32,
            ).reshape(1, -1),
            valid=jnp.array([True] * n_steps, dtype=bool).reshape(1, -1),
            length=jnp.array([bicycle_hyper_parameter.length] * n_steps).reshape(1, -1),
            width=jnp.array([bicycle_hyper_parameter.width] * n_steps).reshape(1, -1),
            height=jnp.array([1.4] * n_steps).reshape(1, -1),
        )  # shape=(#agents, #timesteps)

        waypoints = [refline_waypoints[idx] for idx in step_indexes]
        waypoint_norms = [math.hypot(waypoint.x, waypoint.y) for waypoint in waypoints]
        reference_line = RoadgraphPoints(
            x=jnp.array([waypoint.x for waypoint in waypoints]).reshape(-1),
            y=jnp.array([waypoint.y for waypoint in waypoints]).reshape(-1),
            z=jnp.array([0.0] * n_steps).reshape(-1),
            # A waypoint located at the origin has no direction; use 0.0 instead of dividing by zero.
            dir_x=jnp.array([
                waypoint.x / norm if norm > 0.0 else 0.0 for waypoint, norm in zip(waypoints, waypoint_norms)
            ]).reshape(-1),
            dir_y=jnp.array([
                waypoint.y / norm if norm > 0.0 else 0.0 for waypoint, norm in zip(waypoints, waypoint_norms)
            ]).reshape(-1),
            dir_z=jnp.array([0.0] * n_steps).reshape(-1),
            types=jnp.array([MapElementIds.STOP_SIGN] * n_steps, dtype=np.int32).reshape(-1),
            ids=jnp.array(list(step_indexes), dtype=np.int32).reshape(-1),
            valid=jnp.array([True] * n_steps, dtype=bool).reshape(-1),
        )  # shape=(#timesteps)

        n_steps_within_current_image = traj.shape[1]
        is_controlled = np.array([
            True,
        ]).reshape(1)
        fig, ax = viz_utils.init_fig_ax()
        ax.set_aspect('equal', adjustable='box')
        ax.set_box_aspect(1)
        # Draw every step of the chunk onto the same axes to obtain one stacked image.
        for viz_step_idx in range(n_steps_within_current_image):
            plot_trajectory(ax, traj, is_controlled, time_idx=viz_step_idx)
        plot_roadgraph_points(ax, reference_line)
        stacked_img = viz_utils.img_from_fig(fig)
        viz_utils.save_img_as_png(stacked_img, f'{viz_prefix}-{viz_idx}-stacked.png')
        del fig, ax

        # TODO: figure out why manually stacking not work (probably related to detail of `plot_trajectory`).
        #   The abandoned approach rendered one figure per step with shared axis limits
        #   (`scale_xy_lim((ax.get_xlim(), ax.get_ylim()), ratio=1.3)`), averaged the images and
        #   re-weighted the intensity before saving to f'{viz_prefix}-{viz_idx}.png'.
def __init__(
    self,
    init_state: Union[np.ndarray, None] = None,
    **kwargs,
):
    """
    Args:
        init_state: Initial state to be set. When `None`, no state is set by this constructor.
        **kwargs: Accepted but not used by this base constructor.
    """
    # `set_state` serializes the vectorized state before storing it.
    if init_state is not None:
        self.set_state(init_state)
def set_state(self, new_state: np.ndarray) -> None:
    """Set the internal state from a vectorized state.

    Args:
        new_state: Vectorized state; it is converted with `serialize_state` before being stored.
    """
    self.state = self.serialize_state(new_state)
def get_dynamics_model_observation_space(self) -> Space:
    """Build the unbounded box space in which the dynamics model observation lies.

    The space is derived from the current dynamics model observation: one unbounded dimension per
    observation element, and the observation's own shape.

    Returns:
        Space: Dynamics model observation space.
    """
    sample_observation = self.get_dynamics_model_observation()
    # One infinite bound per observation element; negated for the lower limit.
    unbounded = np.ones((sample_observation.size,), dtype=DTYPE) * np.inf

    return gym.spaces.Box(
        low=-unbounded,
        high=+unbounded,
        shape=sample_observation.shape,
        dtype=DTYPE,
    )
212 | """ 213 | raise NotImplementedError 214 | 215 | @abstractmethod 216 | def get_action_space(self) -> Space: 217 | """Get action space. 218 | 219 | Returns: 220 | Space: Action space. 221 | """ 222 | raise NotImplementedError 223 | 224 | @abstractmethod 225 | def jacobian(self, state: np.ndarray, action: np.ndarray) -> np.ndarray: 226 | """Compute jacobian by performing linearization at a given point (state-action pair). 227 | 228 | Args: 229 | state: State at the linearization point, shape=(n_dims_state,). 230 | action: Action at the linearization point, shape(n_dims_action,). 231 | 232 | Returns: 233 | np.ndarray: Jacobian matrix, shape=(n_dims_state, n_dims_state + n_dims_action). 234 | """ 235 | raise NotImplementedError 236 | -------------------------------------------------------------------------------- /sdk/drltt-sdk/environments/trajectory_tracking.cpp: -------------------------------------------------------------------------------- 1 | #include "trajectory_tracking.h" 2 | 3 | namespace drltt { 4 | 5 | bool TrajectoryTracking::LoadPolicy(const std::string& policy_path) { 6 | return _policy_model.Load(policy_path); 7 | } 8 | 9 | bool TrajectoryTracking::LoadEnvData(const std::string& env_data_path) { 10 | return parse_proto_from_file(_env_data, env_data_path); 11 | } 12 | 13 | bool TrajectoryTracking::set_dynamics_model_hyper_parameter(int index) { 14 | const google::protobuf::RepeatedPtrField& 15 | all_hparams = _env_data.trajectory_tracking() 16 | .hyper_parameter() 17 | .dynamics_models_hyper_parameters(); 18 | if (index >= 0 && index <= all_hparams.size()) { 19 | _dynamics_model.set_hyper_parameter(all_hparams.at(index)); 20 | return true; 21 | } else { 22 | return false; 23 | } 24 | } 25 | 26 | bool TrajectoryTracking::set_reference_line( 27 | const std::vector& reference_line) { 28 | drltt_proto::ReferenceLine new_reference_line; 29 | for (const auto& it : reference_line) { 30 | drltt_proto::ReferenceLineWaypoint* reference_waypoint_ptr = 31 | 
new_reference_line.add_waypoints(); 32 | reference_waypoint_ptr->set_x(std::get<0>(it)); 33 | reference_waypoint_ptr->set_y(std::get<1>(it)); 34 | } 35 | return set_reference_line(new_reference_line); 36 | } 37 | 38 | bool TrajectoryTracking::set_reference_line( 39 | const drltt_proto::ReferenceLine& reference_line) { 40 | _reference_line.CopyFrom(reference_line); 41 | 42 | drltt_proto::State init_state; 43 | EstimateInitialState( 44 | _reference_line, init_state, 45 | _env_data.trajectory_tracking().hyper_parameter().step_interval()); 46 | _dynamics_model.set_state(init_state); 47 | 48 | return true; 49 | } 50 | 51 | // TODO: implement test 52 | bool TrajectoryTracking::EstimateInitialState( 53 | const drltt_proto::ReferenceLine& reference_line, drltt_proto::State& state, 54 | float delta_t) { 55 | const int length = reference_line.waypoints().size(); 56 | if (length <= 1) { 57 | std::cerr << "`reference_line` too short to estimate initial state."; 58 | return false; 59 | } 60 | 61 | const int window_size = 5; 62 | const float discount_factor = 1 / std::exp(1.0); 63 | const int real_window_size = std::min(window_size, length); 64 | float total_displacement = 0; 65 | float total_displacement_x = 0; 66 | float total_displacement_y = 0; 67 | float total_coefficient = 0; 68 | 69 | const drltt_proto::ReferenceLineWaypoint init_waypoint = 70 | reference_line.waypoints().at(0); 71 | for (int index = 0; index < real_window_size - 1; ++index) { 72 | const drltt_proto::ReferenceLineWaypoint& current_waypoint = 73 | reference_line.waypoints().at(index); 74 | const drltt_proto::ReferenceLineWaypoint& next_waypoint = 75 | reference_line.waypoints().at(index + 1); 76 | const float coef = std::pow(discount_factor, index); 77 | total_displacement += 78 | (std::hypot(next_waypoint.x() - current_waypoint.x(), 79 | next_waypoint.y() - current_waypoint.y()) * 80 | coef); 81 | total_displacement_x += ((next_waypoint.x() - current_waypoint.x()) * coef); 82 | total_displacement_y += 
((next_waypoint.y() - current_waypoint.y()) * coef); 83 | total_coefficient += coef; 84 | } 85 | 86 | const float init_r = std::atan2(total_displacement_y, total_displacement_x); 87 | const float init_v = total_displacement / total_coefficient / delta_t; 88 | 89 | state.mutable_bicycle_model()->mutable_body_state()->set_x(init_waypoint.x()); 90 | state.mutable_bicycle_model()->mutable_body_state()->set_y(init_waypoint.y()); 91 | state.mutable_bicycle_model()->mutable_body_state()->set_r(init_r); 92 | state.mutable_bicycle_model()->set_v(init_v); 93 | 94 | return true; 95 | } 96 | 97 | bool TrajectoryTracking::set_dynamics_model_initial_state(STATE state) { 98 | drltt_proto::State init_state; 99 | init_state.mutable_bicycle_model()->mutable_body_state()->set_x( 100 | std::get<0>(state)); 101 | init_state.mutable_bicycle_model()->mutable_body_state()->set_y( 102 | std::get<1>(state)); 103 | init_state.mutable_bicycle_model()->mutable_body_state()->set_r( 104 | std::get<2>(state)); 105 | init_state.mutable_bicycle_model()->set_v(std::get<3>(state)); 106 | return set_dynamics_model_initial_state(init_state); 107 | } 108 | 109 | bool TrajectoryTracking::set_dynamics_model_initial_state( 110 | drltt_proto::State state) { 111 | _dynamics_model.set_state(state); 112 | return true; 113 | } 114 | 115 | bool TrajectoryTracking::RollOut() { 116 | _states.clear(); 117 | _actions.clear(); 118 | _observation_manager.Reset(&_reference_line, &_dynamics_model); 119 | 120 | const int tracking_length = _reference_line.waypoints().size(); 121 | const float step_interval = 122 | _env_data.trajectory_tracking().hyper_parameter().step_interval(); 123 | const float n_observation_steps = 124 | _env_data.trajectory_tracking().hyper_parameter().n_observation_steps(); 125 | for (int step_index = 0; step_index < tracking_length; ++step_index) { 126 | // state 127 | drltt_proto::State state = _dynamics_model.get_state(); 128 | const drltt_proto::BodyState& body_state = 129 | 
state.bicycle_model().body_state(); 130 | // observation 131 | std::vector observation_vec; 132 | _observation_manager.get_observation(body_state, step_index, 133 | tracking_length, n_observation_steps, 134 | &observation_vec); 135 | drltt_proto::Observation observation; 136 | for (const auto scalar : observation_vec) { 137 | observation.mutable_bicycle_model()->add_feature(scalar); 138 | } 139 | // action 140 | std::vector action_vec = _policy_model.Infer(observation_vec); 141 | drltt_proto::Action action; 142 | action.mutable_bicycle_model()->set_a(action_vec.at(0)); 143 | action.mutable_bicycle_model()->set_s(action_vec.at(1)); 144 | // record state, action, observation 145 | _states.push_back(state); 146 | _actions.push_back(action); 147 | _observations.push_back(observation); 148 | // debug_info 149 | _debug_infos.push_back(global_debug_info); 150 | global_debug_info.mutable_data()->Clear(); 151 | // step 152 | _dynamics_model.Step(action, step_interval); 153 | } 154 | return true; 155 | } 156 | 157 | TRAJECTORY TrajectoryTracking::get_tracked_trajectory() { 158 | // std::assert(_states.size() == _actions.size()); 159 | const int trajectory_length = _states.size(); 160 | TRAJECTORY trajectory; 161 | for (const auto& tracked_state : _states) { 162 | STATE state; 163 | std::get<0>(state) = tracked_state.bicycle_model().body_state().x(); 164 | std::get<1>(state) = tracked_state.bicycle_model().body_state().y(); 165 | std::get<2>(state) = tracked_state.bicycle_model().body_state().r(); 166 | std::get<3>(state) = tracked_state.bicycle_model().v(); 167 | std::get<0>(trajectory).push_back(state); 168 | } 169 | for (const auto& tracked_action : _actions) { 170 | ACTION action; 171 | std::get<0>(action) = tracked_action.bicycle_model().a(); 172 | std::get<1>(action) = tracked_action.bicycle_model().s(); 173 | std::get<1>(trajectory).push_back(action); 174 | } 175 | for (const auto& tracked_observation : _observations) { 176 | const auto& feature = 
tracked_observation.bicycle_model().feature(); 177 | OBSERVATION observation(feature.begin(), feature.end()); 178 | std::get<2>(trajectory).push_back(observation); 179 | } 180 | for (const auto& debug_info : _debug_infos) { 181 | DEBUG_DATA debug_data(debug_info.data().begin(), debug_info.data().end()); 182 | std::get<3>(trajectory).push_back(debug_data); 183 | } 184 | return trajectory; 185 | } 186 | 187 | } // namespace drltt 188 | -------------------------------------------------------------------------------- /drltt/simulator/rl_learning/sb3_learner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Union, Any 2 | import os 3 | import logging 4 | from copy import deepcopy 5 | import json 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import gym 10 | import stable_baselines3 11 | from stable_baselines3.common.utils import configure 12 | from stable_baselines3.common.base_class import BaseAlgorithm 13 | from stable_baselines3.common.noise import NormalActionNoise 14 | 15 | from . import METRICS 16 | from .sb3_utils import roll_out_one_episode 17 | 18 | from drltt.common import Registry, build_object_within_registry_from_config 19 | from drltt.common.gym_helper import scale_action 20 | from drltt.simulator.environments import ExtendedGymEnv 21 | from drltt.simulator.visualization import VISUALIZATION_FUNCTIONS 22 | 23 | from drltt_proto.environment.environment_pb2 import Environment 24 | 25 | SB3_MODULES = Registry().register_from_python_module(stable_baselines3) 26 | SB3_LOGGING_FORMAT_STRINGS = ['stdout', 'log', 'csv'] 27 | 28 | 29 | def build_sb3_algorithm_from_config( 30 | environment: gym.Env, 31 | algorithm_config: Dict, 32 | ) -> BaseAlgorithm: 33 | """Build an algorithm from Stable Baselines 3. 34 | 35 | Args: 36 | environment: The associated environment. 37 | algorithm_config: The algorihm config. 38 | 39 | Returns: 40 | BaseAlgorithm: Built algorithm. 
def train_with_sb3(
    environment: ExtendedGymEnv,
    algorithm_config: Dict,
    learning_config: Dict,
    checkpoint_file_prefix: str = '',
) -> Union[BaseAlgorithm, None]:
    """RL Training with Stable Baselines3.

    Args:
        environment: Training environment.
        algorithm_config: Configuration of the algorithm. It is not modified by this function.
        learning_config: Configuration of the learning, forwarded to `algorithm.learn`.
        checkpoint_file_prefix: File prefix (i.e. path without extension) to save checkpoint file.
            When empty, nothing is saved.

    Returns:
        Union[BaseAlgorithm, None]: The algorithm object with trained models, or `None` if training was
        skipped because the checkpoint file already exists.
    """
    save_checkpoint = checkpoint_file_prefix != ''
    checkpoint_file = f'{checkpoint_file_prefix}.zip'
    checkpoint_dir = os.path.dirname(checkpoint_file)
    if save_checkpoint and os.path.exists(checkpoint_file):
        logging.warning(f'Training aborted as checkpoint exists: {checkpoint_file}')
        return None

    # The builder pops from and writes into the config it receives; hand it a shallow copy so the
    # caller's config stays reusable.
    algorithm = build_sb3_algorithm_from_config(environment, dict(algorithm_config))
    # `os.path.join` keeps the path relative when the prefix has no directory part, instead of
    # producing a path under the filesystem root.
    algorithm.set_logger(
        configure(os.path.join(checkpoint_dir, 'sb3-train'), format_strings=SB3_LOGGING_FORMAT_STRINGS)
    )
    algorithm.learn(**learning_config)

    if save_checkpoint:
        # save model
        if checkpoint_dir:
            # `os.makedirs('')` raises; an empty directory part means the current directory.
            os.makedirs(checkpoint_dir, exist_ok=True)
        algorithm.save(checkpoint_file_prefix)
        logging.info(f'SB3 Algorithm Policy saved at: {checkpoint_file}')

        # save environment data
        # NOTE(review): kept under the checkpoint branch because it writes next to the checkpoint --
        # confirm the intended nesting.
        for _ in range(environment.env_info.trajectory_tracking.hyper_parameter.max_n_episodes + 1):
            roll_out_one_episode(environment, lambda obs: algorithm.predict(obs)[0])
        env_data = environment.export_environment_data()
        env_data_save_path = os.path.join(checkpoint_dir, 'env_data.bin')
        with open(env_data_save_path, 'wb') as f:
            f.write(env_data.SerializeToString())
        logging.info(f'Environment data saved to {env_data_save_path}')

    return algorithm
@METRICS.register
def compute_bicycle_model_metrics(
    env_data: Environment,
    environment: ExtendedGymEnv,
) -> Dict[str, Any]:
    """Compute per-episode metrics of the bicycle model.

    Args:
        env_data: Environment data holding the trajectory tracking episode.
        environment: Associated environment; its action space is used to scale actions.

    Returns:
        Dict[str, Any]: Computed metrics.

        - l2_distance_median: median L2 distance between tracked and reference positions
        - scaled_action_norm_median: median norm of the scaled actions
        - reward_median: median reward
    """
    assert isinstance(
        env_data, Environment
    ), f'`compute_bicycle_model_metrics` requires env_data to be in class `Environment`'
    episode = env_data.trajectory_tracking.episode

    def _measure_step(step_idx):
        """Return (position error norm, scaled action norm, reward) of one step."""
        waypoint = episode.reference_line.waypoints[step_idx]
        body_state = episode.dynamics_model.states[step_idx].bicycle_model.body_state
        bicycle_action = episode.dynamics_model.actions[step_idx].bicycle_model
        step_reward = episode.rewards[step_idx]

        position_error = np.array((body_state.x, body_state.y)) - np.array((waypoint.x, waypoint.y))
        scaled_action = scale_action(
            np.array((bicycle_action.a, bicycle_action.s)),
            environment.action_space,
        )
        return np.linalg.norm(position_error), np.linalg.norm(scaled_action), step_reward

    measurements = [_measure_step(step_idx) for step_idx in range(episode.tracking_length)]

    return dict(
        l2_distance_median=np.median([dist for dist, _, _ in measurements]),
        scaled_action_norm_median=np.median([action_norm for _, action_norm, _ in measurements]),
        reward_median=np.median([step_reward for _, _, step_reward in measurements]),
    )