├── rl_ws_worlds ├── resource │ └── rl_ws_worlds ├── rl_ws_worlds │ ├── __init__.py │ └── run.py ├── setup.cfg ├── worlds │ ├── greenhouse_object_data.yaml │ ├── greenhouse_location_data.yaml │ ├── greenhouse_random.yaml │ ├── greenhouse_plain.yaml │ ├── greenhouse_battery.yaml │ └── banana.yaml ├── package.xml └── setup.py ├── slides ├── media │ ├── dqn.png │ ├── sb3.png │ ├── greenhouse.png │ ├── grid-world.png │ ├── gymnasium.png │ ├── mdp.drawio.png │ ├── ppo-graph.png │ ├── roscon25.png │ ├── actor-critic.png │ ├── eval-results.png │ ├── onnx-runtime.png │ ├── ros-wg-delib.png │ ├── tensorboard.png │ ├── twitter-post.png │ ├── model-free-rl.png │ ├── sb3-algo-choice.png │ ├── value-iteration.png │ ├── agent-env.drawio.png │ ├── greenhouse-random.png │ ├── optuna-dashboard.png │ ├── policy-iteration.png │ ├── rsl-parallel-sim.png │ ├── agent-env-sym.drawio.png │ ├── curriculum-learning.png │ ├── greenhouse-battery.png │ ├── pytorch-threading-question.png │ ├── tensorboard-reward-compare.png │ └── venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg ├── README.md └── main.md ├── .gitmodules ├── .gitattributes ├── pyrobosim_ros_gym ├── policies │ ├── BananaPick_DQN_random.pt │ ├── BananaPick_PPO_trained.pt │ ├── BananaPlace_PPO_trained.pt │ ├── GreenhousePlain_DQN_random.pt │ ├── GreenhousePlain_PPO_trained.pt │ ├── BananaPlaceNoSoda_PPO_trained.pt │ ├── GreenhouseBattery_PPO_trained.pt │ ├── GreenhouseRandom_DQN_trained.pt │ ├── GreenhouseRandom_PPO_trained.pt │ └── __init__.py ├── __init__.py ├── config │ ├── banana_env_config.yaml │ └── greenhouse_env_config.yaml ├── start_world.py ├── envs │ ├── __init__.py │ ├── pyrobosim_ros_env.py │ ├── banana.py │ └── greenhouse.py ├── policy_node.py ├── train.py └── eval.py ├── .dockerignore ├── .gitignore ├── rl_interfaces ├── action │ └── ExecutePolicy.action ├── CMakeLists.txt └── package.xml ├── .docker └── Dockerfile ├── LICENSE ├── pyproject.toml ├── .pre-commit-config.yaml ├── .github └── workflows │ └── ci.yml ├── README.md └── test └── test_doc.py /rl_ws_worlds/resource/rl_ws_worlds: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rl_ws_worlds/rl_ws_worlds/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slides/media/dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/dqn.png -------------------------------------------------------------------------------- /slides/media/sb3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/sb3.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pyrobosim"] 2 | path = pyrobosim 3 | url = https://github.com/sea-bass/pyrobosim.git 4 | -------------------------------------------------------------------------------- /slides/media/greenhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse.png 
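The `.gitmodules` entry above pins `pyrobosim` as a git submodule, so a fresh checkout needs the submodule fetched before the workspace will build. A minimal sketch, assuming the repository URL implied by the raw links in this listing (`ros-wg-delib/rl_deliberation`):

```bash
# Clone together with the pyrobosim submodule in one step...
git clone --recurse-submodules https://github.com/ros-wg-delib/rl_deliberation.git

# ...or, from an already-cloned checkout, fetch the submodule afterwards.
git submodule update --init --recursive
```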
-------------------------------------------------------------------------------- /slides/media/grid-world.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/grid-world.png -------------------------------------------------------------------------------- /slides/media/gymnasium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/gymnasium.png -------------------------------------------------------------------------------- /slides/media/mdp.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/mdp.drawio.png -------------------------------------------------------------------------------- /slides/media/ppo-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/ppo-graph.png -------------------------------------------------------------------------------- /slides/media/roscon25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/roscon25.png -------------------------------------------------------------------------------- /slides/media/actor-critic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/actor-critic.png -------------------------------------------------------------------------------- /slides/media/eval-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/eval-results.png -------------------------------------------------------------------------------- /slides/media/onnx-runtime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/onnx-runtime.png -------------------------------------------------------------------------------- /slides/media/ros-wg-delib.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/ros-wg-delib.png -------------------------------------------------------------------------------- /slides/media/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/tensorboard.png -------------------------------------------------------------------------------- /slides/media/twitter-post.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/twitter-post.png -------------------------------------------------------------------------------- /slides/media/model-free-rl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/model-free-rl.png 
-------------------------------------------------------------------------------- /slides/media/sb3-algo-choice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/sb3-algo-choice.png -------------------------------------------------------------------------------- /slides/media/value-iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/value-iteration.png -------------------------------------------------------------------------------- /rl_ws_worlds/setup.cfg: -------------------------------------------------------------------------------- 1 | [develop] 2 | script_dir=$base/lib/rl_ws_worlds 3 | [install] 4 | install_scripts=$base/lib/rl_ws_worlds 5 | -------------------------------------------------------------------------------- /slides/media/agent-env.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/agent-env.drawio.png -------------------------------------------------------------------------------- /slides/media/greenhouse-random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse-random.png -------------------------------------------------------------------------------- /slides/media/optuna-dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/optuna-dashboard.png -------------------------------------------------------------------------------- /slides/media/policy-iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/policy-iteration.png -------------------------------------------------------------------------------- /slides/media/rsl-parallel-sim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/rsl-parallel-sim.png -------------------------------------------------------------------------------- /slides/media/agent-env-sym.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/agent-env-sym.drawio.png -------------------------------------------------------------------------------- /slides/media/curriculum-learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/curriculum-learning.png -------------------------------------------------------------------------------- /slides/media/greenhouse-battery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse-battery.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # SCM syntax highlighting & preventing 3-way merges 2 | 
pixi.lock merge=binary linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /slides/media/pytorch-threading-question.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/pytorch-threading-question.png -------------------------------------------------------------------------------- /slides/media/tensorboard-reward-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/tensorboard-reward-compare.png -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/BananaPick_DQN_random.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPick_DQN_random.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/BananaPick_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPick_PPO_trained.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/BananaPlace_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPlace_PPO_trained.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/BananaPlaceNoSoda_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPlaceNoSoda_PPO_trained.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/GreenhouseBattery_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseBattery_PPO_trained.pt -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/GreenhouseRandom_DQN_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseRandom_DQN_trained.pt 
-------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/GreenhouseRandom_PPO_trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseRandom_PPO_trained.pt -------------------------------------------------------------------------------- /slides/media/venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Various types of metadata 2 | .mypy_cache 3 | .pixi 4 | .pytest_cache 5 | .ruff_cache 6 | **/__pycache__ 7 | 8 | # Colcon build files 9 | build/ 10 | install/ 11 | log/ 12 | 13 | # Model checkpoints and logging 14 | train_logs/ 15 | *.pt 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Various types of metadata 2 | .mypy_cache 3 | .pixi 4 | .pytest_cache 5 | .ruff_cache 6 | **/__pycache__ 7 | 8 | # Colcon build files 9 | build/ 10 | install/ 11 | log/ 12 | 13 | # Model checkpoints and logging 14 | train_logs/ 15 | **/*.pt 16 | 17 | # Slides 18 | slides/*.pdf 19 | -------------------------------------------------------------------------------- /rl_interfaces/action/ExecutePolicy.action: -------------------------------------------------------------------------------- 1 | # ROS action definition for executing a policy 2 | 3 | # Goal (empty) 4 | 5 | --- 6 | 7 | # Result 8 | 9 | bool success 10 | uint64 num_steps 11 | float64 cumulative_reward 12 | 13 | --- 14 | 15 | # Feedback 16 | 17 | uint64 num_steps 18 | float64 cumulative_reward 19 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/greenhouse_object_data.yaml: -------------------------------------------------------------------------------- 1 | ################### 2 | # Object metadata # 3 | ################### 4 | 5 | plant_good: 6 | footprint: 7 | type: circle 8 | radius: 0.08 9 | color: [0.1, 0.8, 0.1] 10 | 11 | plant_evil: 12 | footprint: 13 | type: circle 14 | radius: 0.08 15 | color: [0.9, 0.3, 0.1] 16 | -------------------------------------------------------------------------------- /.docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | SHELL ["/bin/bash", "-o", "pipefail", "-c"] 3 | 4 | # Install apt dependencies. 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | RUN apt-get update && \ 7 | apt-get upgrade -y && \ 8 | apt-get install -y \ 9 | git curl build-essential && \ 10 | rm -rf /var/lib/apt/lists/* 11 | 12 | # Install pixi 13 | RUN curl -fsSL https://pixi.sh/install.sh | sh 14 | ENV PATH=$PATH:/root/.pixi/bin 15 | 16 | # Get this source 17 | COPY . 
rl_deliberation 18 | WORKDIR /rl_deliberation 19 | 20 | # Build workspace 21 | RUN pixi run build 22 | -------------------------------------------------------------------------------- /rl_interfaces/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | 3 | project(rl_interfaces) 4 | 5 | # Default to C++17 6 | if(NOT CMAKE_CXX_STANDARD) 7 | set(CMAKE_CXX_STANDARD 17) 8 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 9 | endif() 10 | if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") 11 | add_compile_options(-Wall -Wextra -Wpedantic) 12 | endif() 13 | 14 | find_package(ament_cmake REQUIRED) 15 | find_package(rosidl_default_generators REQUIRED) 16 | 17 | rosidl_generate_interfaces(${PROJECT_NAME} 18 | "action/ExecutePolicy.action" 19 | ) 20 | 21 | ament_export_dependencies(rosidl_default_runtime) 22 | ament_package() 23 | -------------------------------------------------------------------------------- /rl_ws_worlds/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | rl_ws_worlds 5 | 1.0.0 6 | Example worlds for exploring reinforcement learning. 7 | Christian Henkel 8 | Sebastian Castro 9 | BSD-3-Clause 10 | 11 | pyrobosim_ros 12 | 13 | ament_copyright 14 | ament_flake8 15 | ament_pep257 16 | python3-pytest 17 | 18 | 19 | ament_python 20 | 21 | 22 | -------------------------------------------------------------------------------- /slides/README.md: -------------------------------------------------------------------------------- 1 | # Slides 2 | 3 | The accompanying slides are built with [Pandoc](https://pandoc.org/) and `pdflatex`. 4 | 5 | Refer to for useful information. 6 | 7 | ## Running pandoc natively 8 | 9 | You must first install these tools: 10 | 11 | ```bash 12 | sudo apt install pandoc texlive-latex-base texlive-latex-extra 13 | ``` 14 | 15 | Then, to build the slides: 16 | 17 | ```bash 18 | pandoc -t beamer main.md -o slides.pdf --listings --slide-level=2 19 | ``` 20 | 21 | ## Running pandoc with Docker (recommended) 22 | 23 | Or use the docker image that is also used in CI: 24 | (Run this command from the `slides/` directory) 25 | 26 | ```bash 27 | docker run \ 28 | --rm \ 29 | --volume "$(pwd):/data" \ 30 | --user $(id -u):$(id -g) \ 31 | pandoc/latex:3.7 -t beamer main.md -o slides.pdf --listings --slide-level=2 32 | ``` 33 | -------------------------------------------------------------------------------- /rl_interfaces/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | rl_interfaces 5 | 1.0.0 6 | Example interfaces for exploring reinforcement learning. 
7 | Christian Henkel 8 | Sebastian Castro 9 | BSD-3-Clause 10 | 11 | ament_cmake 12 | rosidl_default_generators 13 | 14 | rosidl_default_runtime 15 | 16 | rosidl_interface_packages 17 | 18 | 19 | ament_cmake 20 | 21 | 22 | -------------------------------------------------------------------------------- /rl_ws_worlds/setup.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | from setuptools import find_packages, setup 4 | 5 | package_name = "rl_ws_worlds" 6 | 7 | setup( 8 | name=package_name, 9 | version="1.0.0", 10 | packages=find_packages(exclude=["test"]), 11 | data_files=[ 12 | ("share/ament_index/resource_index/packages", ["resource/" + package_name]), 13 | ("share/" + package_name, ["package.xml"]), 14 | ("share/" + package_name + "/worlds", glob("worlds/*.*")), 15 | ], 16 | install_requires=["setuptools"], 17 | zip_safe=True, 18 | maintainer="Christian Henkel", 19 | maintainer_email="christian.henkel2@de.bosch.com", 20 | description="Example worlds for exploring reinforcement learning.", 21 | license="BSD-3-Clause", 22 | entry_points={ 23 | "console_scripts": [ 24 | "run = rl_ws_worlds.run:main", 25 | "is_at_goal = rl_ws_worlds.is_at_goal:main", 26 | ], 27 | }, 28 | ) 29 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/greenhouse_location_data.yaml: -------------------------------------------------------------------------------- 1 | ##################### 2 | # Location metadata # 3 | ##################### 4 | 5 | table: 6 | footprint: 7 | type: box 8 | dims: [0.8, 0.8] 9 | height: 0.5 10 | color: [0.3, 0.2, 0.1] 11 | nav_poses: 12 | - position: # left 13 | x: -0.6 14 | y: 0.0 15 | - position: # right 16 | x: 0.6 17 | y: 0.0 18 | rotation_eul: 19 | yaw: 3.14 20 | locations: 21 | - name: "tabletop" 22 | footprint: 23 | type: parent 24 | padding: 0.05 25 | 26 | charger: 27 | footprint: 28 | type: polygon 29 | coords: 30 | - [-0.3, -0.15] 31 | - [0.3, -0.15] 32 | - [0.3, 0.15] 33 | - [-0.3, 0.15] 34 | height: 0.1 35 | locations: 36 | - name: "dock" 37 | footprint: 38 | type: parent 39 | nav_poses: 40 | - position: # below 41 | x: 0.0 42 | y: -0.35 43 | rotation_eul: 44 | yaw: 1.57 45 | - position: # left 46 | x: -0.5 47 | y: 0.0 48 | - position: # above 49 | x: 0.0 50 | y: 0.35 51 | rotation_eul: 52 | yaw: -1.57 53 | - position: # right 54 | x: 0.5 55 | y: 0.0 56 | rotation_eul: 57 | yaw: 3.14 58 | color: [0.4, 0.4, 0] 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, ROS Deliberation Community Group 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | import importlib 10 | import os 11 | from typing import Any 12 | 13 | import yaml 14 | 15 | 16 | def get_config(config_path: str) -> dict[str, Any]: 17 | """Helper function to parse the configuration YAML file.""" 18 | if not os.path.isabs(config_path): 19 | default_path = os.path.join( 20 | os.path.dirname(os.path.abspath(__file__)), "config" 21 | ) 22 | config_path = os.path.join(default_path, config_path) 23 | with open(config_path, "r") as file: 24 | config = yaml.safe_load(file) 25 | 26 | # Handle special case of reward function 27 | training_args = config.get("training", {}) 28 | if "reward_fn" in training_args: 29 | module_name, function_name = training_args["reward_fn"].rsplit(".", 1) 30 | module = importlib.import_module(module_name) 31 | training_args["reward_fn"] = getattr(module, function_name) 32 | 33 | # Handle special case of policy_kwargs activation function needing to be a class instance. 34 | for subtype in training_args: 35 | subtype_config = training_args[subtype] 36 | if not isinstance(subtype_config, dict): 37 | continue 38 | policy_kwargs = subtype_config.get("policy_kwargs", {}) 39 | if "activation_fn" in policy_kwargs: 40 | module_name, class_name = policy_kwargs["activation_fn"].rsplit(".", 1) 41 | module = importlib.import_module(module_name) 42 | policy_kwargs["activation_fn"] = getattr(module, class_name) 43 | 44 | return config 45 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/config/banana_env_config.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration for the banana test environment 2 | 3 | training: 4 | # General training options 5 | max_training_steps: 25000 6 | 7 | # Evaluations during training 8 | # Every eval_freq steps, n_eval_episodes episodes will be run. 9 | # If the mean cumulative reward is greater than reward_threshold, 10 | # training will complete. 
11 | eval: 12 | n_eval_episodes: 10 13 | eval_freq: 1000 # Set to 2000 for more complicated sub-types 14 | reward_threshold: 9.5 # Mean reward to terminate training 15 | 16 | # Individual algorithm options 17 | DQN: 18 | gamma: 0.99 19 | exploration_initial_eps: 0.75 20 | exploration_final_eps: 0.05 21 | exploration_fraction: 0.25 22 | learning_starts: 100 23 | learning_rate: 0.0001 24 | batch_size: 32 25 | gradient_steps: 10 26 | train_freq: 4 # steps 27 | target_update_interval: 500 28 | policy_kwargs: 29 | activation_fn: torch.nn.ReLU 30 | net_arch: [64, 64] 31 | 32 | PPO: 33 | gamma: 0.99 34 | learning_rate: 0.0003 35 | batch_size: 32 36 | n_steps: 64 37 | policy_kwargs: 38 | activation_fn: torch.nn.ReLU 39 | net_arch: 40 | pi: [64, 64] # actor size 41 | vf: [32, 32] # critic size 42 | 43 | SAC: 44 | gamma: 0.99 45 | learning_rate: 0.0003 46 | batch_size: 32 47 | gradient_steps: 10 48 | train_freq: 4 # steps 49 | target_update_interval: 10 50 | policy_kwargs: 51 | activation_fn: torch.nn.ReLU 52 | net_arch: 53 | pi: [64, 64] # actor size 54 | qf: [32, 32] # critic size (SAC uses qf, not vf) 55 | 56 | A2C: 57 | gamma: 0.99 58 | learning_rate: 0.0007 59 | policy_kwargs: 60 | activation_fn: torch.nn.ReLU 61 | net_arch: 62 | pi: [64, 64] # actor size 63 | vf: [32, 32] # critic size 64 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "rl_deliberation" 3 | authors = [ 4 | {name = "Sebastian Castro", email = "sebas.a.castro@gmail.com"}, 5 | {name = "Christian Henkel", email = "christian.henkel2@de.bosch.com"}, 6 | ] 7 | version = "0.1.0" 8 | requires-python = ">=3.12, <3.13" 9 | dependencies = ["astar>=0.99,<0.100"] 10 | 11 | [tool.pixi.workspace] 12 | preview = ["pixi-build"] 13 | channels = ["robostack-kilted", "conda-forge"] 14 | platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] 15 | 16 | [tool.pixi.package.build] 17 | backend = { name = "pixi-build-ros", version = "*" } 18 | 19 | [tool.pixi.dependencies] 20 | python = ">=3.12, <3.13" 21 | colcon-common-extensions = ">=0.3.0,<0.4" 22 | compilers = ">=1.11.0,<2" 23 | pre-commit = ">=4.2.0,<5" 24 | rich = ">=14.1.0,<15" 25 | ros-kilted-ros-core = ">=0.12.0,<0.13" 26 | stable-baselines3 = ">=2.6.0,<3" 27 | setuptools = ">=78.1.1,<80.0.0" 28 | scipy = ">=1.16.0,<2" 29 | transforms3d = ">=0.4.2,<0.5" 30 | adjusttext = ">=1.3.0,<2" 31 | matplotlib = ">=3.10.5,<4" 32 | numpy = ">=1.26.4,<2" 33 | pycollada = ">=0.9.2,<0.10" 34 | pyside6 = ">=6.4.0" 35 | pyyaml = ">=6.0.3,<7" 36 | shapely = ">=2.0.1" 37 | trimesh = ">=4.7.1,<5" 38 | tqdm = ">=4.67.1,<5" 39 | tensorboard = "*" 40 | types-pyyaml = ">=6.0.12.20250915,<7" 41 | 42 | [tool.pixi.tasks] 43 | build = "colcon build --symlink-install" 44 | clean = "rm -rf build install log" 45 | lint = "pre-commit run -a" 46 | zenoh = "ros2 run rmw_zenoh_cpp rmw_zenohd" 47 | pyrobosim_demo = "ros2 launch pyrobosim_ros demo.launch.py" 48 | start_world = "python -m pyrobosim_ros_gym.start_world" 49 | train = "python -m pyrobosim_ros_gym.train" 50 | eval = "python -m pyrobosim_ros_gym.eval" 51 | tensorboard = "tensorboard --logdir ./train_logs/" 52 | policy_node = "python -m pyrobosim_ros_gym.policy_node" 53 | 54 | [tool.pixi.activation] 55 | scripts = ["install/setup.sh"] 56 | 57 | [tool.pixi.target.win-64.activation] 58 | scripts = ["install/setup.bat"] 59 | -------------------------------------------------------------------------------- 
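The `[tool.pixi.tasks]` table above maps the main entry points (world server, training, evaluation, TensorBoard, policy server) to short `pixi run` commands. A minimal sketch of chaining them for the greenhouse example follows; the `--env`, `--model`, and `--config` flags come from `start_world.py` and `policy_node.py` elsewhere in this listing, while the specific environment name is an assumption based on the bundled policy checkpoints, and the `train.py`/`eval.py` flags are omitted since those CLIs are not fully shown here.

```bash
# Build the colcon workspace; the pixi activation scripts then source install/setup.sh.
pixi run build

# Optional, if using the rmw_zenoh middleware: start the zenoh router.
pixi run zenoh

# Terminal 1: start the pyrobosim world server for an environment
# (environment name assumed from the bundled policy checkpoints).
pixi run start_world --env GreenhousePlain

# Terminal 2: serve a bundled pretrained policy behind the /execute_policy action.
pixi run policy_node \
    --model pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt \
    --config greenhouse_env_config.yaml

# Inspect training logs written to ./train_logs/.
pixi run tensorboard
```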
/rl_ws_worlds/rl_ws_worlds/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | """ 10 | Runner for ROS 2 Deliberation workshop worlds. 11 | """ 12 | 13 | import os 14 | import rclpy 15 | import threading 16 | 17 | from pyrobosim.core import WorldYamlLoader 18 | from pyrobosim.gui import start_gui 19 | from pyrobosim_ros.ros_interface import WorldROSWrapper 20 | from ament_index_python.packages import get_package_share_directory 21 | 22 | 23 | def create_ros_node() -> WorldROSWrapper: 24 | """Initializes ROS node""" 25 | rclpy.init() 26 | node = WorldROSWrapper(state_pub_rate=0.1, dynamics_rate=0.01) 27 | node.declare_parameter("world_name", "greenhouse") 28 | node.declare_parameter("headless", False) 29 | 30 | # Set the world file. 31 | world_name = node.get_parameter("world_name").value 32 | node.get_logger().info(f"Starting world '{world_name}'") 33 | world_file = os.path.join( 34 | get_package_share_directory("rl_ws_worlds"), 35 | "worlds", 36 | f"{world_name}.yaml", 37 | ) 38 | world = WorldYamlLoader().from_file(world_file) 39 | node.set_world(world) 40 | 41 | return node 42 | 43 | 44 | def start_node(node: WorldROSWrapper): 45 | headless = node.get_parameter("headless").value 46 | if headless: 47 | # Start ROS node in main thread if there is no GUI. 48 | node.start(wait_for_gui=False) 49 | else: 50 | # Start ROS node in separate thread 51 | ros_thread = threading.Thread(target=lambda: node.start(wait_for_gui=True)) 52 | ros_thread.start() 53 | 54 | # Start GUI in main thread 55 | start_gui(node.world) 56 | 57 | 58 | def main(): 59 | node = create_ros_node() 60 | start_node(node) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/config/greenhouse_env_config.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration for the greenhouse test environment 2 | 3 | training: 4 | # General training options 5 | max_training_steps: 10000 6 | 7 | # Reward function to use when stepping the environment 8 | # reward_fn: pyrobosim_ros_gym.envs.greenhouse.sparse_reward # Try this first 9 | # reward_fn: pyrobosim_ros_gym.envs.greenhouse.dense_reward # Try this second 10 | reward_fn: pyrobosim_ros_gym.envs.greenhouse.full_reward # End with this one 11 | 12 | # Evaluations during training 13 | # Every eval_freq steps, n_eval_episodes episodes will be run. 14 | # If the mean cumulative reward is greater than reward_threshold, 15 | # training will complete. 
16 | eval: 17 | n_eval_episodes: 5 18 | eval_freq: 500 19 | reward_threshold: 7.0 20 | 21 | # Individual algorithm options 22 | DQN: 23 | gamma: 0.99 24 | exploration_initial_eps: 0.75 25 | exploration_final_eps: 0.05 26 | exploration_fraction: 0.2 27 | learning_starts: 25 28 | learning_rate: 0.0002 29 | batch_size: 16 30 | gradient_steps: 5 31 | train_freq: 4 # steps 32 | target_update_interval: 10 33 | policy_kwargs: 34 | activation_fn: torch.nn.ReLU 35 | net_arch: [16, 8] 36 | 37 | PPO: 38 | gamma: 0.99 39 | learning_rate: 0.0003 40 | batch_size: 8 41 | n_steps: 8 42 | policy_kwargs: 43 | activation_fn: torch.nn.ReLU 44 | net_arch: 45 | pi: [16, 8] # actor size 46 | vf: [16, 8] # critic size 47 | 48 | SAC: 49 | gamma: 0.99 50 | learning_rate: 0.0003 51 | batch_size: 16 52 | train_freq: 4 # steps 53 | gradient_steps: 5 54 | tau: 0.005 55 | target_update_interval: 10 56 | policy_kwargs: 57 | activation_fn: torch.nn.ReLU 58 | net_arch: 59 | pi: [16, 8] # actor size 60 | qf: [16, 8] # critic size (SAC uses qf, not vf) 61 | 62 | A2C: 63 | gamma: 0.99 64 | learning_rate: 0.0007 65 | policy_kwargs: 66 | activation_fn: torch.nn.ReLU 67 | net_arch: 68 | pi: [16, 8] # actor size 69 | vf: [16, 8] # critic size 70 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/start_world.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | """Loads a world to act as a server for the RL problem.""" 10 | 11 | import argparse 12 | import rclpy 13 | import threading 14 | 15 | from pyrobosim.core import WorldYamlLoader 16 | from pyrobosim.gui import start_gui, WorldCanvasOptions 17 | from pyrobosim_ros.ros_interface import WorldROSWrapper 18 | 19 | from pyrobosim_ros_gym.envs import ( 20 | get_env_class_and_subtype_from_name, 21 | available_envs_w_subtype, 22 | ) 23 | from pyrobosim_ros_gym.envs.greenhouse import GreenhouseEnv 24 | 25 | 26 | def create_ros_node(world_file_path) -> WorldROSWrapper: 27 | """Initializes ROS node""" 28 | rclpy.init() 29 | world = WorldYamlLoader().from_file(world_file_path) 30 | return WorldROSWrapper(world=world, state_pub_rate=0.1, dynamics_rate=0.01) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--headless", action="store_true", help="Enables headless world loading." 37 | ) 38 | parser.add_argument( 39 | "--env", 40 | choices=available_envs_w_subtype(), 41 | help="The environment to use.", 42 | required=True, 43 | ) 44 | args = parser.parse_args() 45 | 46 | env_class, sub_type = get_env_class_and_subtype_from_name(args.env) 47 | node = create_ros_node(env_class.get_world_file_path(sub_type)) 48 | show_room_names = env_class != GreenhouseEnv 49 | 50 | if args.headless: 51 | # Start ROS node in main thread if there is no GUI. 52 | node.start(wait_for_gui=False) 53 | else: 54 | # Start ROS node in separate thread. 55 | ros_thread = threading.Thread(target=lambda: node.start(wait_for_gui=True)) 56 | ros_thread.start() 57 | 58 | # Start GUI in main thread. 
59 | options = WorldCanvasOptions(show_room_names=show_room_names) 60 | start_gui(node.world, options=options) 61 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Runs pre-commit hooks and other file format checks. 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v6.0.0 5 | hooks: 6 | - id: check-added-large-files 7 | - id: check-ast 8 | - id: check-builtin-literals 9 | - id: check-case-conflict 10 | - id: check-docstring-first 11 | - id: check-executables-have-shebangs 12 | - id: check-json 13 | - id: check-merge-conflict 14 | - id: check-symlinks 15 | - id: check-toml 16 | - id: check-vcs-permalinks 17 | - id: check-yaml 18 | - id: debug-statements 19 | - id: destroyed-symlinks 20 | - id: detect-private-key 21 | - id: end-of-file-fixer 22 | - id: fix-byte-order-marker 23 | - id: forbid-new-submodules 24 | - id: mixed-line-ending 25 | - id: name-tests-test 26 | - id: requirements-txt-fixer 27 | - id: sort-simple-yaml 28 | - id: trailing-whitespace 29 | exclude: slides/main.md$ 30 | 31 | # Autoformats Python code. 32 | - repo: https://github.com/psf/black.git 33 | rev: 25.9.0 34 | hooks: 35 | - id: black 36 | exclude: ^pyrobosim/ 37 | 38 | # Type checking 39 | - repo: https://github.com/pre-commit/mirrors-mypy 40 | rev: v1.18.2 41 | hooks: 42 | - id: mypy 43 | args: [--ignore-missing-imports] 44 | additional_dependencies: [types-PyYAML] 45 | 46 | # Finds spelling issues in code. 47 | - repo: https://github.com/codespell-project/codespell 48 | rev: v2.4.1 49 | hooks: 50 | - id: codespell 51 | files: ^.*\.(py|md|yaml|yml)$ 52 | 53 | # Finds issues in YAML files. 54 | - repo: https://github.com/adrienverge/yamllint 55 | rev: v1.37.1 56 | hooks: 57 | - id: yamllint 58 | args: 59 | [ 60 | "--no-warnings", 61 | "--config-data", 62 | "{extends: default, rules: {line-length: disable, braces: 63 | {max-spaces-inside: 1}, indentation: disable}}", 64 | ] 65 | types: [text] 66 | files: \.(yml|yaml)$ 67 | 68 | # Checks links in markdown files. 69 | - repo: https://github.com/tcort/markdown-link-check 70 | rev: v3.13.7 71 | hooks: 72 | - id: markdown-link-check 73 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | import os 10 | 11 | from gymnasium import Space 12 | from stable_baselines3 import DQN, PPO, SAC, A2C 13 | from stable_baselines3.common.base_class import BaseAlgorithm 14 | 15 | 16 | AVAILABLE_POLICIES = {alg.__name__: alg for alg in (DQN, PPO, SAC, A2C)} 17 | 18 | 19 | class ManualPolicy: 20 | """A policy that allows manual keyboard control of the robot.""" 21 | 22 | def __init__(self, action_space: Space): 23 | print("Welcome. 
You are the agent now!") 24 | self.action_space = action_space 25 | 26 | def predict(self, observation, deterministic): 27 | # print(f"Observation: {observation}") 28 | # print(f"Action space: {self.action_space}") 29 | possible_actions = list(range(self.action_space.n)) 30 | while True: 31 | try: 32 | action = int(input(f"Enter action from {possible_actions}: ")) 33 | if action in possible_actions: 34 | return action, None 35 | else: 36 | print(f"Action {action} not in {possible_actions}.") 37 | except ValueError: 38 | print("Invalid input, please enter an integer.") 39 | 40 | 41 | def model_and_env_type_from_path(model_path: str) -> tuple[BaseAlgorithm, str]: 42 | """ 43 | Loads a model and its corresponding environment type from its file path. 44 | 45 | The models are of the form _[_].pt. 46 | For example, path/to/model/GreenhousePlain_DQN_seed42_2025_10_18_18_02_21.pt. 47 | """ 48 | # Validate the file path to be of the right form 49 | assert os.path.isfile(model_path), f"Model {model_path} must be a valid file." 50 | model_fname = os.path.basename(model_path) 51 | model_name_parts = model_fname.split("_") 52 | assert ( 53 | len(model_name_parts) >= 2 54 | ), f"Model name {model_fname} must be of the form _[_].pt" 55 | env_type = model_name_parts[0] 56 | algorithm = model_name_parts[1] 57 | 58 | # Load the model 59 | if algorithm in AVAILABLE_POLICIES: 60 | model_class = AVAILABLE_POLICIES[algorithm] 61 | model = model_class.load(model_path) 62 | else: 63 | raise RuntimeError(f"Invalid algorithm type: {algorithm}") 64 | 65 | return (model, env_type) 66 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 
8 | 9 | from typing import Any, Callable, List, Dict 10 | 11 | import rclpy 12 | from rclpy.executors import Executor 13 | 14 | from .banana import BananaEnv 15 | from .greenhouse import GreenhouseEnv 16 | from .pyrobosim_ros_env import PyRoboSimRosEnv 17 | 18 | 19 | ENV_CLASS_FROM_NAME: Dict[str, type[PyRoboSimRosEnv]] = { 20 | "Banana": BananaEnv, 21 | "Greenhouse": GreenhouseEnv, 22 | } 23 | 24 | 25 | def available_envs_w_subtype() -> List[str]: 26 | """Return a list of environment types including subtypes.""" 27 | envs: List[str] = [] 28 | for name, env_class in ENV_CLASS_FROM_NAME.items(): 29 | for sub_type in env_class.sub_types: 30 | envs.append("".join((name, sub_type.name))) 31 | return envs 32 | 33 | 34 | def get_env_class_and_subtype_from_name(req_name: str): 35 | """Return the class of a chosen environment name (ignoring `sub_type`s).""" 36 | for name, env_class in ENV_CLASS_FROM_NAME.items(): 37 | if req_name.startswith(name): 38 | sub_type_str = req_name.replace(name, "") 39 | for st in env_class.sub_types: 40 | if st.name == sub_type_str: 41 | sub_type = st 42 | return env_class, sub_type 43 | raise RuntimeError(f"No environment found for {req_name}.") 44 | 45 | 46 | def get_env_by_name( 47 | env_name: str, 48 | node: rclpy.node.Node, 49 | max_steps_per_episode: int, 50 | realtime: bool, 51 | discrete_actions: bool, 52 | reward_fn: Callable[..., Any], 53 | executor: Executor | None = None, 54 | ) -> PyRoboSimRosEnv: 55 | """ 56 | Instantiate an environment class for a given type and `sub_type`. 57 | 58 | :param env_name: Name of environment, with subtype, e.g. BananaPick. 59 | :param node: Node instance needed for ROS communication. 60 | :param max_steps_per_episode: Limit the steps (when to end the episode). 61 | :param realtime: Whether actions take time. 62 | :param discrete_actions: Choose discrete actions (needed for DQN). 63 | :param reward_fn: The function used to compute the reward at each step. 64 | :param executor: Optional ROS executor. It must be already spinning! 
65 | """ 66 | base_class, sub_type = get_env_class_and_subtype_from_name(env_name) 67 | return base_class( 68 | sub_type, 69 | node, 70 | max_steps_per_episode, 71 | realtime, 72 | discrete_actions, 73 | reward_fn, 74 | executor=executor, 75 | ) 76 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/greenhouse_random.yaml: -------------------------------------------------------------------------------- 1 | # Problem 1 World 2 | 3 | metadata: 4 | locations: $PWD/greenhouse_location_data.yaml 5 | objects: $PWD/greenhouse_object_data.yaml 6 | 7 | params: 8 | name: greenhouse 9 | inflation_radius: 0.01 10 | object_radius: 0.01 11 | 12 | robots: 13 | - name: robot 14 | radius: 0.1 15 | location: hall 16 | path_executor: 17 | type: constant_velocity 18 | path_planner: 19 | type: astar 20 | grid_resolution: 0.1 21 | grid_inflation_radius: 0.1 22 | 23 | rooms: 24 | - name: hall 25 | footprint: 26 | type: polygon 27 | coords: 28 | - [-3, -3] 29 | - [3, -3] 30 | - [3, 3] 31 | - [-3, 3] 32 | nav_poses: 33 | - position: # left 34 | x: -2.0 35 | y: 0.0 36 | - position: # right 37 | x: 2.0 38 | y: 0.0 39 | rotation_eul: 40 | yaw: 3.14 41 | wall_width: 0.2 42 | color: [0.1, 0.1, 0] 43 | 44 | locations: 45 | - name: table_nw 46 | parent: hall 47 | category: table 48 | pose: 49 | position: 50 | x: -2 51 | y: 2 52 | - name: table_n 53 | parent: hall 54 | category: table 55 | pose: 56 | position: 57 | x: 0 58 | y: 2 59 | - name: table_ne 60 | parent: hall 61 | category: table 62 | pose: 63 | position: 64 | x: 2 65 | y: 2 66 | - name: table_w 67 | parent: hall 68 | category: table 69 | pose: 70 | position: 71 | x: -2 72 | y: 0 73 | - name: table_c 74 | parent: hall 75 | category: table 76 | pose: 77 | position: 78 | x: 0 79 | y: 0 80 | - name: table_e 81 | parent: hall 82 | category: table 83 | pose: 84 | position: 85 | x: 2 86 | y: 0 87 | - name: table_sw 88 | parent: hall 89 | category: table 90 | pose: 91 | position: 92 | x: -2 93 | y: -2 94 | - name: table_s 95 | parent: hall 96 | category: table 97 | pose: 98 | position: 99 | x: 0 100 | y: -2 101 | - name: table_se 102 | parent: hall 103 | category: table 104 | pose: 105 | position: 106 | x: 2 107 | y: -2 108 | 109 | objects: 110 | - name: plant0 111 | parent: table 112 | category: plant_good 113 | - name: plant1 114 | parent: table 115 | category: plant_evil 116 | - name: plant2 117 | parent: table 118 | category: plant_good 119 | - name: plant3 120 | parent: table 121 | category: plant_good 122 | - name: plant4 123 | parent: table 124 | category: plant_good 125 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/greenhouse_plain.yaml: -------------------------------------------------------------------------------- 1 | # Problem 1 World 2 | 3 | metadata: 4 | locations: $PWD/greenhouse_location_data.yaml 5 | objects: $PWD/greenhouse_object_data.yaml 6 | 7 | params: 8 | name: greenhouse 9 | inflation_radius: 0.01 10 | object_radius: 0.01 11 | 12 | robots: 13 | - name: robot 14 | radius: 0.1 15 | location: hall 16 | path_executor: 17 | type: constant_velocity 18 | path_planner: 19 | type: astar 20 | grid_resolution: 0.1 21 | grid_inflation_radius: 0.1 22 | 23 | rooms: 24 | - name: hall 25 | footprint: 26 | type: polygon 27 | coords: 28 | - [-3, -3] 29 | - [3, -3] 30 | - [3, 3] 31 | - [-3, 3] 32 | nav_poses: 33 | - position: # left 34 | x: -2.0 35 | y: 0.0 36 | - position: # right 37 | x: 2.0 38 | y: 0.0 39 | rotation_eul: 40 | yaw: 3.14 41 | wall_width: 
0.2 42 | color: [0.1, 0.1, 0] 43 | 44 | locations: 45 | - name: table_nw 46 | parent: hall 47 | category: table 48 | pose: 49 | position: 50 | x: -2 51 | y: 2 52 | - name: table_n 53 | parent: hall 54 | category: table 55 | pose: 56 | position: 57 | x: 0 58 | y: 2 59 | - name: table_ne 60 | parent: hall 61 | category: table 62 | pose: 63 | position: 64 | x: 2 65 | y: 2 66 | - name: table_w 67 | parent: hall 68 | category: table 69 | pose: 70 | position: 71 | x: -2 72 | y: 0 73 | - name: table_c 74 | parent: hall 75 | category: table 76 | pose: 77 | position: 78 | x: 0 79 | y: 0 80 | - name: table_e 81 | parent: hall 82 | category: table 83 | pose: 84 | position: 85 | x: 2 86 | y: 0 87 | - name: table_sw 88 | parent: hall 89 | category: table 90 | pose: 91 | position: 92 | x: -2 93 | y: -2 94 | - name: table_s 95 | parent: hall 96 | category: table 97 | pose: 98 | position: 99 | x: 0 100 | y: -2 101 | - name: table_se 102 | parent: hall 103 | category: table 104 | pose: 105 | position: 106 | x: 2 107 | y: -2 108 | 109 | objects: 110 | - name: plant0 111 | parent: table_nw 112 | category: plant_good 113 | - name: plant1 114 | parent: table_ne 115 | category: plant_evil 116 | - name: plant2 117 | parent: table_c 118 | category: plant_good 119 | - name: plant3 120 | parent: table_sw 121 | category: plant_good 122 | - name: plant4 123 | parent: table_se 124 | category: plant_good 125 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/greenhouse_battery.yaml: -------------------------------------------------------------------------------- 1 | # Problem 1 World 2 | 3 | metadata: 4 | locations: $PWD/greenhouse_location_data.yaml 5 | objects: $PWD/greenhouse_object_data.yaml 6 | 7 | params: 8 | name: greenhouse 9 | inflation_radius: 0.01 10 | object_radius: 0.01 11 | 12 | robots: 13 | - name: robot 14 | radius: 0.1 15 | location: hall 16 | path_executor: 17 | type: constant_velocity 18 | path_planner: 19 | type: astar 20 | grid_resolution: 0.1 21 | grid_inflation_radius: 0.1 22 | action_execution_options: 23 | close: 24 | battery_usage: 49 # 2 close actions per charge 25 | 26 | rooms: 27 | - name: hall 28 | footprint: 29 | type: polygon 30 | coords: 31 | - [-3, -3] 32 | - [3, -3] 33 | - [3, 3] 34 | - [-3, 3] 35 | nav_poses: 36 | - position: # left 37 | x: -2.0 38 | y: 0.0 39 | - position: # right 40 | x: 2.0 41 | y: 0.0 42 | rotation_eul: 43 | yaw: 3.14 44 | wall_width: 0.2 45 | color: [0.1, 0.1, 0] 46 | 47 | locations: 48 | - name: table_nw 49 | parent: hall 50 | category: table 51 | pose: 52 | position: 53 | x: -2 54 | y: 2 55 | - name: table_n 56 | parent: hall 57 | category: table 58 | pose: 59 | position: 60 | x: 0 61 | y: 2 62 | - name: table_ne 63 | parent: hall 64 | category: table 65 | pose: 66 | position: 67 | x: 2 68 | y: 2 69 | - name: table_w 70 | parent: hall 71 | category: table 72 | pose: 73 | position: 74 | x: -2 75 | y: 0 76 | - name: table_c 77 | parent: hall 78 | category: table 79 | pose: 80 | position: 81 | x: 0 82 | y: 0 83 | - name: table_e 84 | parent: hall 85 | category: table 86 | pose: 87 | position: 88 | x: 2 89 | y: 0 90 | - name: table_sw 91 | parent: hall 92 | category: table 93 | pose: 94 | position: 95 | x: -2 96 | y: -2 97 | - name: table_s 98 | parent: hall 99 | category: table 100 | pose: 101 | position: 102 | x: 0 103 | y: -2 104 | - name: table_se 105 | parent: hall 106 | category: table 107 | pose: 108 | position: 109 | x: 2 110 | y: -2 111 | 112 | - name: charger 113 | category: charger 114 | parent: hall 
115 | pose: 116 | position: 117 | x: 1.0 118 | y: -2.8 119 | is_charger: True 120 | 121 | objects: 122 | - name: plant0 123 | parent: table 124 | category: plant_good 125 | - name: plant1 126 | parent: table 127 | category: plant_evil 128 | - name: plant2 129 | parent: table 130 | category: plant_good 131 | - name: plant3 132 | parent: table 133 | category: plant_good 134 | - name: plant4 135 | parent: table 136 | category: plant_good 137 | -------------------------------------------------------------------------------- /rl_ws_worlds/worlds/banana.yaml: -------------------------------------------------------------------------------- 1 | ########################## 2 | # Test world description # 3 | ########################## 4 | 5 | # WORLD PARAMETERS 6 | params: 7 | name: banana 8 | object_radius: 0.01 # Radius around objects 9 | wall_height: 2.0 # Wall height for exporting to Gazebo 10 | 11 | 12 | # METADATA: Describes information about locations and objects 13 | metadata: 14 | locations: 15 | - $DATA/example_location_data_furniture.yaml 16 | - $DATA/example_location_data_accessories.yaml 17 | objects: 18 | - $DATA/example_object_data_food.yaml 19 | - $DATA/example_object_data_drink.yaml 20 | # ROBOTS 21 | robots: 22 | - name: robot 23 | radius: 0.1 24 | color: "#CC00CC" 25 | max_linear_velocity: 3.0 26 | max_angular_velocity: 6.0 27 | max_linear_acceleration: 10.0 28 | max_angular_acceleration: 10.0 29 | path_planner: 30 | type: world_graph 31 | collision_check_step_dist: 0.05 32 | path_executor: 33 | type: constant_velocity 34 | linear_velocity: 3.0 35 | max_angular_velocity: 6.0 36 | dt: 0.1 37 | validate_during_execution: false 38 | location: ["desk", "counter", "table"] 39 | 40 | # ROOMS: Polygonal regions that can contain object locations 41 | rooms: 42 | - name: kitchen 43 | footprint: 44 | type: polygon 45 | coords: 46 | - [-1, -1] 47 | - [1.5, -1] 48 | - [1.5, 1.5] 49 | - [0.5, 1.5] 50 | nav_poses: 51 | - position: 52 | x: 0.75 53 | y: 0.5 54 | wall_width: 0.2 55 | color: "red" 56 | 57 | - name: bedroom 58 | footprint: 59 | type: box 60 | dims: [1.75, 1.5] 61 | pose: 62 | position: 63 | x: 2.625 64 | y: 3.25 65 | wall_width: 0.2 66 | color: "#009900" 67 | 68 | - name: bathroom 69 | footprint: 70 | type: polygon 71 | coords: 72 | - [-1, 1] 73 | - [-1, 3.5] 74 | - [-3, 3.5] 75 | - [-2.5, 1] 76 | wall_width: 0.2 77 | color: [0, 0, 0.6] 78 | 79 | 80 | # HALLWAYS: Connect rooms 81 | hallways: 82 | - room_start: kitchen 83 | room_end: bathroom 84 | width: 0.7 85 | conn_method: auto 86 | is_open: True 87 | is_locked: False 88 | color: "#666666" 89 | 90 | - room_start: bathroom 91 | room_end: bedroom 92 | width: 0.5 93 | conn_method: angle 94 | conn_angle: 0.0 95 | offset: 0.8 96 | is_open: True 97 | is_locked: False 98 | color: "dimgray" 99 | 100 | - room_start: kitchen 101 | room_end: bedroom 102 | width: 0.6 103 | conn_method: points 104 | conn_points: 105 | - [1.0, 0.5] 106 | - [2.5, 0.5] 107 | - [2.5, 3.0] 108 | is_open: True 109 | is_locked: False 110 | 111 | 112 | # LOCATIONS: Can contain objects 113 | locations: 114 | - name: table0 115 | category: table 116 | parent: kitchen 117 | pose: 118 | position: 119 | x: 0.85 120 | y: -0.5 121 | rotation_eul: 122 | yaw: -90.0 123 | angle_units: "degrees" 124 | is_open: True 125 | is_locked: True 126 | 127 | - name: my_desk 128 | category: desk 129 | parent: bedroom 130 | pose: 131 | position: 132 | x: 0.525 133 | y: 0.4 134 | relative_to: bedroom 135 | is_open: True 136 | is_locked: False 137 | 138 | - name: counter0 139 | category: 
counter 140 | parent: bathroom 141 | pose: 142 | position: 143 | x: -2.45 144 | y: 2.5 145 | rotation_eul: 146 | yaw: 101.2 147 | angle_units: "degrees" 148 | is_open: True 149 | is_locked: True 150 | 151 | 152 | # OBJECTS: Can be picked, placed, and moved by robot 153 | objects: 154 | - category: apple 155 | 156 | - category: apple 157 | 158 | - category: water 159 | 160 | - category: water 161 | 162 | - category: banana 163 | parent: ["desk", "counter"] 164 | 165 | - category: banana 166 | parent: ["desk", "counter"] 167 | 168 | - name: soda 169 | category: coke 170 | parent: ["desk", "counter"] 171 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/policy_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | """Serves a policy as a ROS node with an action server for deployment.""" 10 | 11 | import argparse 12 | import time 13 | 14 | from gymnasium.spaces import Discrete 15 | import rclpy 16 | from rclpy.action import ActionServer, CancelResponse 17 | from rclpy.executors import ExternalShutdownException, MultiThreadedExecutor 18 | from rclpy.node import Node 19 | 20 | from rl_interfaces.action import ExecutePolicy # type: ignore[attr-defined] 21 | 22 | from pyrobosim_ros_gym import get_config 23 | from pyrobosim_ros_gym.envs import get_env_by_name 24 | from pyrobosim_ros_gym.policies import model_and_env_type_from_path 25 | 26 | 27 | class PolicyServerNode(Node): 28 | def __init__(self, args: argparse.Namespace, executor): 29 | super().__init__("policy_node") 30 | 31 | # Load the model and environment 32 | self.model, env_type = model_and_env_type_from_path(args.model) 33 | self.env = get_env_by_name( 34 | env_type, 35 | self, 36 | executor=executor, 37 | max_steps_per_episode=-1, 38 | realtime=True, 39 | discrete_actions=isinstance(self.model.action_space, Discrete), 40 | reward_fn=get_config(args.config)["training"].get("reward_fn"), 41 | ) 42 | 43 | self.action_server = ActionServer( 44 | self, 45 | ExecutePolicy, 46 | "/execute_policy", 47 | execute_callback=self.execute_policy, 48 | cancel_callback=self.cancel_policy, 49 | ) 50 | 51 | self.get_logger().info(f"Started policy node with model '{args.model}'.") 52 | 53 | def cancel_policy(self, goal_handle): 54 | self.get_logger().info("Canceling policy execution...") 55 | return CancelResponse.ACCEPT 56 | 57 | async def execute_policy(self, goal_handle): 58 | self.get_logger().info("Starting policy execution...") 59 | result = ExecutePolicy.Result() 60 | 61 | self.env.initialize() # Resets helper variables 62 | obs = self.env._get_obs() 63 | cumulative_reward = 0.0 64 | while True: 65 | if goal_handle.is_cancel_requested: 66 | goal_handle.canceled() 67 | self.get_logger().info("Policy execution canceled") 68 | return result 69 | 70 | action, _ = self.model.predict(obs, deterministic=True) 71 | obs, reward, terminated, truncated, info = self.env.step(action) 72 | num_steps = self.env.step_number 73 | cumulative_reward += reward 74 | self.get_logger().info( 75 | f"Step {num_steps}: cumulative reward = {cumulative_reward}" 76 | ) 77 | 78 | goal_handle.publish_feedback( 79 | ExecutePolicy.Feedback( 80 | num_steps=num_steps, cumulative_reward=cumulative_reward 81 | ) 82 | ) 83 | 84 | if 
terminated or truncated: 85 | break 86 | 87 | time.sleep(0.1) # Small sleep between actions 88 | 89 | goal_handle.succeed() 90 | result.success = info["success"] 91 | result.num_steps = num_steps 92 | result.cumulative_reward = cumulative_reward 93 | self.get_logger().info(f"Policy completed in {num_steps} steps.") 94 | self.get_logger().info( 95 | f"success: {result.success}, cumulative reward: {cumulative_reward}" 96 | ) 97 | return result 98 | 99 | 100 | def main(args=None): 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument( 103 | "--model", required=True, help="The name of the model to serve." 104 | ) 105 | parser.add_argument( 106 | "--config", 107 | help="Path to the configuration YAML file.", 108 | required=True, 109 | ) 110 | cli_args = parser.parse_args() 111 | 112 | rclpy.init(args=args) 113 | try: 114 | executor = MultiThreadedExecutor(num_threads=2) 115 | import threading 116 | 117 | threading.Thread(target=executor.spin).start() 118 | policy_server = PolicyServerNode(cli_args, executor) 119 | while True: 120 | time.sleep(0.5) 121 | except (KeyboardInterrupt, ExternalShutdownException): 122 | pass 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | # Run action on certain pull request events 5 | pull_request: 6 | types: [opened, synchronize, reopened, ready_for_review] 7 | 8 | # After pushing to main 9 | push: 10 | branches: [main] 11 | 12 | # Nightly job on default (main) branch 13 | schedule: 14 | - cron: '0 0 * * *' 15 | 16 | # Allow tests to be run manually 17 | workflow_dispatch: 18 | 19 | env: 20 | REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/rl_deliberation 21 | RAW_BRANCH_NAME: ${{ github.head_ref || github.ref_name }} 22 | 23 | jobs: 24 | build_and_test: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - name: Checkout repo 28 | uses: actions/checkout@v4 29 | with: 30 | submodules: true 31 | - name: Setup pixi 32 | uses: prefix-dev/setup-pixi@v0.8.11 33 | with: 34 | cache: true 35 | cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 36 | - name: Run formatting 37 | run: pixi run lint 38 | - name: Build the environment 39 | run: pixi run build 40 | docker: 41 | runs-on: ubuntu-latest 42 | steps: 43 | - name: Checkout repo 44 | uses: actions/checkout@v4 45 | - run: docker build -t ${{ env.REGISTRY_IMAGE }}:latest -f .docker/Dockerfile . 46 | - name: Push to gh registry with sha and branch 47 | # Do not push on PRs from forks. 
48 | if: ${{ !github.event.pull_request.head.repo.fork }} 49 | run: | 50 | echo ${{ secrets.GITHUB_TOKEN }} | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin 51 | # tag with sha and push 52 | docker tag ${{ env.REGISTRY_IMAGE }}:latest ${{ env.REGISTRY_IMAGE }}:${{ github.sha }} 53 | docker push ${{ env.REGISTRY_IMAGE }}:${{ github.sha }} 54 | # tag with branch name and push 55 | export BRANCH_NAME=$(echo ${{ env.RAW_BRANCH_NAME }} | sed 's/\//-/g') 56 | docker tag ${{ env.REGISTRY_IMAGE }}:latest ${{ env.REGISTRY_IMAGE }}:$BRANCH_NAME 57 | docker push ${{ env.REGISTRY_IMAGE }}:$BRANCH_NAME 58 | echo "TAG_NAME=$BRANCH_NAME" >> $GITHUB_ENV 59 | - name: Push to gh registry with latest if this is main 60 | run: | 61 | echo ${{ secrets.GITHUB_TOKEN }} | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin 62 | # only from main we actually push to latest 63 | docker push ${{ env.REGISTRY_IMAGE }}:latest 64 | echo "TAG_NAME=latest" >> $GITHUB_ENV 65 | if: github.ref == 'refs/heads/main' 66 | slides: 67 | runs-on: ubuntu-latest 68 | steps: 69 | - name: Checkout repo 70 | uses: actions/checkout@v4 71 | - name: Copy slides to main directory 72 | run: cp -r slides/* . 73 | - name: Run pandoc 74 | uses: docker://pandoc/latex:3.7 75 | with: 76 | args: "-t beamer main.md -o slides.pdf --listings --slide-level=2" 77 | - name: Upload slides 78 | uses: actions/upload-artifact@main 79 | with: 80 | name: slides.pdf 81 | path: slides.pdf 82 | - name: Non-latest slides as release 83 | uses: ncipollo/release-action@v1 84 | with: 85 | artifacts: 'slides.pdf' 86 | token: ${{ secrets.GITHUB_TOKEN }} 87 | commit: ${{ github.sha }} 88 | allowUpdates: true 89 | name: Workshop Slides 90 | tag: slides-pr${{ github.event.pull_request.number }} 91 | prerelease: true 92 | if: github.event_name == 'pull_request' 93 | - name: Latest slides as release 94 | uses: ncipollo/release-action@v1 95 | with: 96 | artifacts: 'slides.pdf' 97 | token: ${{ secrets.GITHUB_TOKEN }} 98 | commit: ${{ github.sha }} 99 | allowUpdates: true 100 | name: Workshop Slides 101 | tag: slides-latest 102 | makeLatest: true 103 | if: github.ref == 'refs/heads/main' 104 | pytest: 105 | runs-on: ubuntu-latest 106 | steps: 107 | - name: Checkout repo 108 | uses: actions/checkout@v4 109 | - name: Install python 110 | uses: actions/setup-python@v6 111 | with: 112 | python-version: '3.11' 113 | - name: Install dependencies 114 | run: pip install pytest pytest-md pytest-emoji sybil docker 115 | - name: Set up Docker 116 | uses: docker/setup-docker-action@v4 117 | - name: Run pytest 118 | uses: pavelzw/pytest-action@v2 119 | with: 120 | verbose: true 121 | emoji: true 122 | job-summary: true 123 | custom-arguments: '-vvs test/' 124 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 
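# Illustrative usage note (added commentary, not part of the original script).
# A matching pyrobosim world must already be running, e.g.:
#   pixi run start_world --env GreenhousePlain --headless
# Then training can be launched as shown in the README, e.g.:
#   pixi run train --env GreenhousePlain --config greenhouse_env_config.yaml --algorithm PPO --log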
8 | 9 | """Trains an RL policy.""" 10 | 11 | import argparse 12 | from datetime import datetime 13 | 14 | import rclpy 15 | from rclpy.node import Node 16 | from stable_baselines3 import DQN, PPO, SAC, A2C 17 | from stable_baselines3.common.callbacks import ( 18 | EvalCallback, 19 | StopTrainingOnRewardThreshold, 20 | ) 21 | from stable_baselines3.common.base_class import BaseAlgorithm 22 | 23 | from pyrobosim_ros_gym import get_config 24 | from pyrobosim_ros_gym.envs import get_env_by_name, available_envs_w_subtype 25 | 26 | 27 | def get_args() -> argparse.Namespace: 28 | """Helper function to parse the command-line arguments.""" 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--env", 32 | choices=available_envs_w_subtype(), 33 | help="The environment to use.", 34 | required=True, 35 | ) 36 | parser.add_argument( 37 | "--config", 38 | help="Path to the configuration YAML file.", 39 | required=True, 40 | ) 41 | parser.add_argument( 42 | "--algorithm", 43 | default="DQN", 44 | choices=["DQN", "PPO", "SAC", "A2C"], 45 | help="The algorithm with which to train a model.", 46 | ) 47 | parser.add_argument( 48 | "--discrete-actions", 49 | action="store_true", 50 | help="If true, uses discrete action space. Otherwise, uses continuous action space.", 51 | ) 52 | parser.add_argument("--seed", default=42, type=int, help="The RNG seed to use.") 53 | parser.add_argument( 54 | "--realtime", action="store_true", help="If true, slows down to real time." 55 | ) 56 | parser.add_argument( 57 | "--log", 58 | default=True, 59 | action="store_true", 60 | help="If true, logs data to Tensorboard.", 61 | ) 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | if __name__ == "__main__": 67 | args = get_args() 68 | config = get_config(args.config) 69 | 70 | # Create the environment 71 | rclpy.init() 72 | node = Node("pyrobosim_ros_env") 73 | env = get_env_by_name( 74 | args.env, 75 | node, 76 | max_steps_per_episode=25, 77 | realtime=args.realtime, 78 | discrete_actions=args.discrete_actions, 79 | reward_fn=config["training"].get("reward_fn"), 80 | ) 81 | 82 | # Train a model 83 | log_path = "train_logs" if args.log else None 84 | if args.algorithm == "DQN": 85 | dqn_config = config.get("training", {}).get("DQN", {}) 86 | model: BaseAlgorithm = DQN( 87 | "MlpPolicy", 88 | env=env, 89 | seed=args.seed, 90 | tensorboard_log=log_path, 91 | **dqn_config, 92 | ) 93 | elif args.algorithm == "PPO": 94 | ppo_config = config.get("training", {}).get("PPO", {}) 95 | model = PPO( 96 | "MlpPolicy", 97 | env=env, 98 | seed=args.seed, 99 | tensorboard_log=log_path, 100 | **ppo_config, 101 | ) 102 | elif args.algorithm == "SAC": 103 | sac_config = config.get("training", {}).get("SAC", {}) 104 | model = SAC( 105 | "MlpPolicy", 106 | env=env, 107 | seed=args.seed, 108 | tensorboard_log=log_path, 109 | **sac_config, 110 | ) 111 | elif args.algorithm == "A2C": 112 | a2c_config = config.get("training", {}).get("A2C", {}) 113 | model = A2C( 114 | "MlpPolicy", 115 | env=env, 116 | seed=args.seed, 117 | tensorboard_log=log_path, 118 | **a2c_config, 119 | ) 120 | else: 121 | raise RuntimeError(f"Invalid algorithm type: {args.algorithm}") 122 | print(f"\nTraining with {args.algorithm}...\n") 123 | 124 | # Train the model until it exceeds a specified reward threshold in evals. 
125 | training_config = config["training"] 126 | callback_on_best = StopTrainingOnRewardThreshold( 127 | reward_threshold=training_config["eval"]["reward_threshold"], 128 | verbose=1, 129 | ) 130 | eval_callback = EvalCallback( 131 | env, 132 | callback_on_new_best=callback_on_best, 133 | verbose=1, 134 | eval_freq=training_config["eval"]["eval_freq"], 135 | n_eval_episodes=training_config["eval"]["n_eval_episodes"], 136 | ) 137 | 138 | date_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 139 | log_name = f"{args.env}_{args.algorithm}_seed{args.seed}_{date_str}" 140 | model.learn( 141 | total_timesteps=training_config["max_training_steps"], 142 | progress_bar=True, 143 | tb_log_name=log_name, 144 | log_interval=1, 145 | callback=eval_callback, 146 | ) 147 | 148 | # Save the trained model 149 | model_name = f"{log_name}.pt" 150 | model.save(model_name) 151 | print(f"\nSaved model to {model_name}\n") 152 | 153 | rclpy.shutdown() 154 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/envs/pyrobosim_ros_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | import time 10 | from enum import Enum 11 | from functools import partial 12 | 13 | import gymnasium as gym 14 | import rclpy 15 | from rclpy.action import ActionClient 16 | 17 | from pyrobosim_msgs.action import ExecuteTaskAction 18 | from pyrobosim_msgs.srv import ( 19 | RequestWorldInfo, 20 | RequestWorldState, 21 | ResetWorld, 22 | SetLocationState, 23 | ) 24 | 25 | 26 | class PyRoboSimRosEnv(gym.Env): 27 | """Gym environment wrapping around the PyRoboSim ROS Interface.""" 28 | 29 | sub_types = Enum("sub_types", "DEFINE_IN_SUBCLASS") 30 | 31 | def __init__( 32 | self, 33 | node, 34 | reward_fn, 35 | reset_validation_fn=None, 36 | max_steps_per_episode=50, 37 | realtime=True, 38 | discrete_actions=True, 39 | executor=None, 40 | ): 41 | """ 42 | Instantiates a PyRoboSim ROS environment. 43 | 44 | :param node: The ROS node to use for creating clients. 45 | :param reward_fn: Function that calculates the reward (and possibly other outputs). 46 | :param reset_validation_fn: Function that calculates whether a reset is valid. 47 | If None (default), all resets are valid. 48 | :param max_steps_per_episode: Maximum number of steps before truncating an episode. 49 | If -1, there is no limit to number of steps. 50 | :param realtime: If True, commands PyRoboSim to run actions in real time. 51 | If False, actions run as quickly as possible for faster training. 52 | :param discrete_actions: If True, uses discrete actions, else uses continuous. 53 | :param executor: Optional ROS executor. It must be already spinning! 
54 | """ 55 | super().__init__() 56 | self.node = node 57 | self.executor = executor 58 | if self.executor is not None: 59 | self.executor.add_node(self.node) 60 | self.executor.wake() 61 | 62 | self.realtime = realtime 63 | self.max_steps_per_episode = max_steps_per_episode 64 | self.discrete_actions = discrete_actions 65 | 66 | if reward_fn is None: 67 | self.reward_fn = lambda _: 0.0 68 | else: 69 | self.reward_fn = partial(reward_fn, self) 70 | 71 | if reset_validation_fn is None: 72 | self.reset_validation_fn = lambda: True 73 | else: 74 | self.reset_validation_fn = lambda: reset_validation_fn(self) 75 | 76 | self.step_number = 0 77 | self.previous_location = None 78 | self.previous_action_type = None 79 | 80 | self.request_info_client = node.create_client( 81 | RequestWorldInfo, "/request_world_info" 82 | ) 83 | self.request_state_client = node.create_client( 84 | RequestWorldState, "/request_world_state" 85 | ) 86 | self.execute_action_client = ActionClient( 87 | node, ExecuteTaskAction, "/execute_action" 88 | ) 89 | self.reset_world_client = node.create_client(ResetWorld, "reset_world") 90 | self.set_location_state_client = node.create_client( 91 | SetLocationState, "set_location_state" 92 | ) 93 | 94 | self.request_info_client.wait_for_service() 95 | self.request_state_client.wait_for_service() 96 | self.execute_action_client.wait_for_server() 97 | self.reset_world_client.wait_for_service() 98 | self.set_location_state_client.wait_for_service() 99 | 100 | future = self.request_info_client.call_async(RequestWorldInfo.Request()) 101 | self._spin_future(future) 102 | self.world_info = future.result().info 103 | 104 | future = self.request_state_client.call_async(RequestWorldState.Request()) 105 | self._spin_future(future) 106 | self.world_state = future.result().state 107 | 108 | self.all_locations = [] 109 | for loc in self.world_state.locations: 110 | self.all_locations.extend(loc.spawns) 111 | self.num_locations = sum(len(loc.spawns) for loc in self.world_state.locations) 112 | self.loc_to_idx = {loc: idx for idx, loc in enumerate(self.all_locations)} 113 | print(f"{self.all_locations=}") 114 | 115 | self.action_space = self._action_space() 116 | print(f"{self.action_space=}") 117 | 118 | def _spin_future(self, future): 119 | if self.executor is None: 120 | rclpy.spin_until_future_complete(self.node, future) 121 | else: 122 | while not future.done(): 123 | time.sleep(0.1) 124 | 125 | def _action_space(self): 126 | raise NotImplementedError("implement in sub-class") 127 | 128 | def initialize(self): 129 | """Resets helper variables for deployment without doing a full reset.""" 130 | raise NotImplementedError("implement in sub-class") 131 | 132 | def step(self, action): 133 | raise NotImplementedError("implement in sub-class") 134 | 135 | def reset(self, seed=None, options=None): 136 | """Resets the environment with a specified seed and options.""" 137 | print(f"Resetting environment with {seed=}") 138 | super().reset(seed=seed) 139 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 
8 | 9 | """Evaluates a trained RL policy.""" 10 | 11 | import argparse 12 | 13 | import rclpy 14 | from gymnasium.spaces import Discrete 15 | from rclpy.node import Node 16 | from stable_baselines3.common.base_class import BaseAlgorithm 17 | 18 | from pyrobosim_ros_gym import get_config 19 | from pyrobosim_ros_gym.envs import available_envs_w_subtype, get_env_by_name 20 | from pyrobosim_ros_gym.policies import ManualPolicy, model_and_env_type_from_path 21 | 22 | MANUAL_STR = "manual" 23 | 24 | 25 | def get_args() -> argparse.Namespace: 26 | """Helper function to parse the command-line arguments.""" 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--model", 30 | type=str, 31 | help=f"The path of the model to evaluate. Can be '{MANUAL_STR}' for manual control.", 32 | ) 33 | parser.add_argument( 34 | "--env", 35 | type=str, 36 | help=f"The name of the environment to use if '--model {MANUAL_STR}' is selected.", 37 | choices=available_envs_w_subtype(), 38 | ) 39 | parser.add_argument( 40 | "--config", 41 | help="Path to the configuration YAML file.", 42 | required=True, 43 | ) 44 | parser.add_argument( 45 | "--num-episodes", 46 | default=3, 47 | type=int, 48 | help="The number of episodes to evaluate.", 49 | ) 50 | parser.add_argument("--seed", default=42, type=int, help="The RNG seed to use.") 51 | parser.add_argument( 52 | "--realtime", action="store_true", help="If true, slows down to real time." 53 | ) 54 | args = parser.parse_args() 55 | 56 | # Ensure '--env' is provided if '--model' is 'manual' 57 | if args.model == MANUAL_STR and not args.env: 58 | parser.error(f"--env must be specified when --model is '{MANUAL_STR}'.") 59 | if args.env and args.model is None: 60 | print("--env is specified but --model is not. Defaulting to manual control.") 61 | args.model = MANUAL_STR 62 | return args 63 | 64 | 65 | if __name__ == "__main__": 66 | args = get_args() 67 | config = get_config(args.config) 68 | 69 | rclpy.init() 70 | node = Node("pyrobosim_ros_env") 71 | 72 | # Load the model and environment 73 | model: BaseAlgorithm | ManualPolicy 74 | if args.model == MANUAL_STR: 75 | env = get_env_by_name( 76 | args.env, 77 | node, 78 | max_steps_per_episode=15, 79 | realtime=True, 80 | discrete_actions=True, 81 | reward_fn=config["training"].get("reward_fn"), 82 | ) 83 | model = ManualPolicy(env.action_space) 84 | else: 85 | model, env_type = model_and_env_type_from_path(args.model) 86 | env = get_env_by_name( 87 | env_type, 88 | node, 89 | max_steps_per_episode=15, 90 | realtime=args.realtime, 91 | discrete_actions=isinstance(model.action_space, Discrete), 92 | reward_fn=config["training"].get("reward_fn"), 93 | ) 94 | 95 | # Evaluate the model for some steps 96 | num_successful_episodes = 0 97 | reward_per_episode = [0.0 for _ in range(args.num_episodes)] 98 | custom_metrics_store: dict[str, list[float]] = {} 99 | custom_metrics_episode_mean: dict[str, float] = {} 100 | custom_metrics_per_episode: dict[str, list[float]] = {} 101 | for i_e in range(args.num_episodes): 102 | print(f">>> Starting episode {i_e+1}/{args.num_episodes}") 103 | obs, _ = env.reset(seed=i_e + args.seed) 104 | terminated = False 105 | truncated = False 106 | i_step = 0 107 | while not (terminated or truncated): 108 | print(f"{obs=}") 109 | action, _ = model.predict(obs, deterministic=True) 110 | print(f"{action=}") 111 | obs, reward, terminated, truncated, info = env.step(action) 112 | custom_metrics = info.get("metrics", {}) 113 | 114 | print(f"{reward=}") 115 | reward_per_episode[i_e] += reward 116 | 
117 | print(f"{custom_metrics=}") 118 | for k, v in custom_metrics.items(): 119 | if k not in custom_metrics_store: 120 | custom_metrics_store[k] = [] 121 | custom_metrics_store[k].append(v) 122 | 123 | print(f"{terminated=}") 124 | print(f"{truncated=}") 125 | print("." * 10) 126 | if terminated or truncated: 127 | success = info["success"] 128 | if success: 129 | num_successful_episodes += 1 130 | print(f"<<< Episode {i_e+1} finished with {success=}.") 131 | print(f"Total reward: {reward_per_episode[i_e]}") 132 | for k, v in custom_metrics_store.items(): 133 | mean_metric = sum(v) / len(v) if len(v) > 0 else 0.0 134 | custom_metrics_episode_mean[k] = mean_metric 135 | if k not in custom_metrics_per_episode: 136 | custom_metrics_per_episode[k] = [] 137 | custom_metrics_per_episode[k].append(mean_metric) 138 | print(f"Mean {k}: {mean_metric}") 139 | print("=" * 20) 140 | break 141 | 142 | print("Summary:") 143 | success_percent = 100.0 * num_successful_episodes / args.num_episodes 144 | print( 145 | f"Successful episodes: {num_successful_episodes} / {args.num_episodes} " 146 | f"({success_percent:.2f}%)" 147 | ) 148 | print(f"Reward over {args.num_episodes} episodes:") 149 | print(f" Mean: {sum(reward_per_episode)/args.num_episodes}") 150 | print(f" Min: {min(reward_per_episode)}") 151 | print(f" Max: {max(reward_per_episode)}") 152 | for k, v in custom_metrics_per_episode.items(): 153 | print(f"Custom metric '{k}' over {args.num_episodes} episodes:") 154 | print(f" Mean: {sum(v)/args.num_episodes}") 155 | print(f" Min: {min(v)}") 156 | print(f" Max: {max(v)}") 157 | 158 | rclpy.shutdown() 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning for Deliberation in ROS 2 2 | 3 | This repository contains materials for the [ROSCon 2025](https://roscon.ros.org/2025/) workshop on ROS 2 Deliberation Technologies. 4 | 5 | > [!NOTE] 6 | > This was moved here from . 7 | 8 | ## Setup 9 | 10 | This repo uses Pixi and RoboStack along with ROS 2 Kilted. 11 | 12 | First, install dependencies on your system (assuming you are using Linux). 13 | 14 | 15 | 21 | 22 | 23 | ```bash 24 | sudo apt install build-essential curl 25 | ``` 26 | 27 | Then, install Pixi. 28 | 29 | ```bash 30 | curl -fsSL https://pixi.sh/install.sh | sh 31 | ``` 32 | 33 | 39 | 40 | Clone the repo including submodules. 41 | 42 | ```bash 43 | git clone --recursive https://github.com/ros-wg-delib/rl_deliberation.git 44 | ``` 45 | 46 | Build the environment. 47 | 48 | 49 | ```bash 50 | pixi run build 51 | ``` 52 | 53 | To verify your installation, the following should launch a window of PyRoboSim. 54 | 55 | 56 | ```bash 57 | pixi run start_world --env GreenhousePlain 58 | ``` 59 | 60 | To explore the setup, you can also drop into a shell in the Pixi environment. 61 | 62 | 63 | ```bash 64 | pixi shell 65 | ``` 66 | 67 | ## Explore the environment 68 | 69 | There are different environments available. For example, to run the Greenhouse environment: 70 | 71 | 72 | ```bash 73 | pixi run start_world --env GreenhousePlain 74 | ``` 75 | 76 | All the following commands assume that the environment is running. 77 | You can also run the environment in headless mode for training. 78 | 79 | 80 | ```bash 81 | pixi run start_world --env GreenhousePlain --headless 82 | ``` 83 | 84 | But first, we can explore the environment with a random agent. 
85 | 86 | ## Evaluating with a random agent 87 | 88 | Assuming the environment is running, execute the evaluation script in another terminal: 89 | 90 | 91 | ```bash 92 | pixi run eval --config greenhouse_env_config.yaml --model pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt --num-episodes 1 --realtime 93 | ``` 94 | 95 | 100 | 101 | In your terminal, you will see multiple sections in the following format: 102 | 103 | ```plaintext 104 | .......... 105 | obs=array([1. , 0.99194384, 0. , 2.7288349, 0. , 3.3768525, 1.], dtype=float32) 106 | action=array(0) 107 | Maximum steps (10) exceeded. Truncated episode. 108 | reward=0.0 109 | custom_metrics={'watered_plant_fraction': 0.0, 'battery_level': 100.0} 110 | terminated=False 111 | truncated=False 112 | .......... 113 | ``` 114 | 115 | This is one step of the environment and the agent's interaction with it. 116 | 117 | - `obs` is the observation from the environment. It is an array describing the 3 closest plant objects: a class label (0 or 1) and the distance to the robot for each of them. The final entries indicate whether the robot's current location has already been watered and, in the battery environment variant, the robot's battery level. 118 | - `action` is the action taken by the agent. In this simple example, it can choose between 0 = move on and 1 = water plant. 119 | - `reward` is the reward received after taking the action, which is `0.0` in this case, because the agent did not water any plant. 120 | - `custom_metrics` provides additional information about the episode: 121 | - `watered_plant_fraction` indicates the fraction of good plants (between 0 and 1) watered thus far in the episode. 122 | - `battery_level` indicates the current battery level of the robot. (This will not decrease for this environment type, but it will later.) 123 | - `terminated` indicates whether the episode reached a terminal state (e.g., the task was completed or failed). 124 | - `truncated` indicates whether the episode ended due to a time limit. 125 | 126 | In the PyRoboSim window, you should also see the robot moving around at every step. 127 | 128 | At the end of the episode, and after all episodes are completed, you will see some more statistics printed in the terminal. 129 | 130 | ```plaintext 131 | .......... 132 | <<< Episode 1 finished with success=False. 133 | Total reward: 0.0 134 | Mean watered_plant_fraction: 0.0 135 | Mean battery_level: 100.0 136 | ==================== 137 | Summary: 138 | Reward over 1 episodes: 139 | Mean: 0.0 140 | Min: 0.0 141 | Max: 0.0 142 | Custom metric 'watered_plant_fraction' over 1 episodes: 143 | Mean: 0.0 144 | Min: 0.0 145 | Max: 0.0 146 | Custom metric 'battery_level' over 1 episodes: 147 | Mean: 100.0 148 | Min: 100.0 149 | Max: 100.0 150 | ``` 151 | 152 | ## Training a model 153 | 154 | While the environment is running (in headless mode if you prefer), you can train a model. 155 | 156 | ### Choose algorithm type 157 | 158 | For example, PPO: 159 | 160 | 161 | ```bash 162 | pixi run train --env GreenhousePlain --config greenhouse_env_config.yaml --algorithm PPO --log 163 | ``` 164 | 165 | Or DQN. 166 | Note that this needs the `--discrete-actions` flag. 167 | 168 | 169 | ```bash 170 | pixi run train --env GreenhousePlain --config greenhouse_env_config.yaml --algorithm DQN --discrete-actions --log 171 | ``` 172 | 173 | Note that at the end of training, the model name and path will be printed in the terminal: 174 | 175 | ```plaintext 176 | New best mean reward!
177 | 100% ━━━━━━━━━━━━━━━━━━━━━━━━━ 100/100 [ 0:00:35 < 0:00:00 , 2 it/s ] 178 | 179 | Saved model to GreenhousePlain_PPO_.pt 180 | ``` 181 | 182 | Remember this path, as you will need it later. 183 | 184 | ### You may find tensorboard useful 185 | 186 | 187 | ```bash 188 | pixi run tensorboard 189 | ``` 190 | 191 | It should contain one entry named after your recent training run (e.g. `GreenhousePlain_PPO_`). 192 | 193 | ### See your freshly trained policy in action 194 | 195 | To run an evaluation, execute the following code. 196 | 197 | 198 | ```bash 199 | pixi run eval --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt --num-episodes 3 --realtime 200 | ``` 201 | 202 | Or, to run more episodes as quickly as possible, launch your world with `--headless` and then execute: 203 | 204 | 205 | ```bash 206 | pixi run eval --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt --num-episodes 20 207 | ``` 208 | 209 | You can also see your trained policy in action as a ROS node. 210 | 211 | 212 | ```bash 213 | pixi run policy_node --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt 214 | ``` 215 | 216 | Then, in a separate terminal, you can send a goal. 217 | 218 | 219 | ```bash 220 | pixi shell 221 | ros2 action send_goal /execute_policy rl_interfaces/ExecutePolicy {} 222 | ``` 223 | 224 | Of course, you can also use this same action interface in your own user code! 225 | -------------------------------------------------------------------------------- /test/test_doc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from collections import OrderedDict 4 | from pathlib import Path 5 | import glob 6 | import pytest 7 | from sybil import Sybil 8 | from sybil.parsers.markdown.codeblock import CodeBlockParser 9 | from sybil.parsers.markdown.lexers import ( 10 | FencedCodeBlockLexer, 11 | DirectiveInHTMLCommentLexer, 12 | ) 13 | import docker 14 | 15 | 16 | COMMAND_PREFIX = "$ " 17 | IGNORED_OUTPUT = "..." 18 | DIR_NEW_ENV = "new-env" 19 | DIR_WORKDIR = "workdir" 20 | DIR_SKIP_NEXT = "skip-next" 21 | DIRECTIVE = "directive" 22 | DIR_CODE_BLOCK = "code-block" 23 | BASH = "bash" 24 | ARGUMENTS = "arguments" 25 | LINE_END = "\\" 26 | CWD = "cwd" 27 | EXPECTED_FILES = "expected-files" 28 | 29 | REPO_ROOT_FOLDER = os.path.join(os.path.dirname(__file__), "..") 30 | 31 | 32 | def evaluate_bash_block(example, cwd): 33 | """Executes a command and compares its output to the provided expected output.
34 | 35 | ```bash 36 | command 37 | ``` 38 | """ 39 | print(f"{example=}") 40 | lines = example.strip().split("\n") 41 | output = [] 42 | output_i = -1 43 | previous_cmd_line = "" 44 | next_is_cmd_continuation = False 45 | for line in lines: 46 | print(f"{line=}") 47 | if line.startswith(COMMAND_PREFIX) or next_is_cmd_continuation: 48 | # this is a command 49 | command = previous_cmd_line + line.replace(COMMAND_PREFIX, "") 50 | if command.endswith(LINE_END): 51 | # this must be merged with the next line 52 | previous_cmd_line = command.replace(LINE_END, "") 53 | next_is_cmd_continuation = True 54 | continue 55 | next_is_cmd_continuation = False 56 | print(f"{command=}") 57 | previous_cmd_line = "" 58 | # output = ( 59 | # subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True, cwd=cwd) 60 | # .strip() 61 | # .decode("ascii") 62 | # ) 63 | output = "" 64 | print(f"{output=}") 65 | output = [x.strip() for x in output.split("\n")] 66 | output_i = 0 67 | else: 68 | # this is expected output 69 | expected_line = line.strip() 70 | if len(expected_line) == 0: 71 | continue 72 | if output_i >= len(output): 73 | # end of captured output 74 | output_i = -1 75 | continue 76 | if output_i == -1: 77 | continue 78 | actual_line = output[output_i] 79 | while len(actual_line) == 0: 80 | # skip empty lines 81 | output_i += 1 82 | if output_i >= len(output): 83 | output_i = -1 84 | continue 85 | actual_line = output[output_i] 86 | if IGNORED_OUTPUT in expected_line: 87 | # skip this line 88 | output_i += 1 89 | continue 90 | assert actual_line == expected_line 91 | output_i += 1 92 | 93 | 94 | def collect_docs(): 95 | """Search for *.md files.""" 96 | assert os.path.exists(REPO_ROOT_FOLDER), f"Path must exist: {REPO_ROOT_FOLDER}" 97 | 98 | pattern = os.path.join(REPO_ROOT_FOLDER, "*.md") 99 | all_md_files = glob.glob(pattern) 100 | print(f"Found {len(all_md_files)} .md files by {pattern}.") 101 | 102 | # Configure Sybil 103 | fenced_lexer = FencedCodeBlockLexer(BASH) 104 | new_env_lexer = DirectiveInHTMLCommentLexer(directive=DIR_NEW_ENV) 105 | workdir_lexer = DirectiveInHTMLCommentLexer(directive=DIR_WORKDIR) 106 | skip_lexer = DirectiveInHTMLCommentLexer(directive=DIR_SKIP_NEXT) 107 | sybil = Sybil( 108 | parsers=[fenced_lexer, new_env_lexer, workdir_lexer, skip_lexer], 109 | filenames=all_md_files, 110 | ) 111 | documents = [] 112 | for f_path in all_md_files: 113 | doc = sybil.parse(Path(f_path)) 114 | rel_path = os.path.relpath(f_path, REPO_ROOT_FOLDER) 115 | if len(list(doc)) > 0: 116 | documents.append([rel_path, list(doc)]) 117 | print(f"Found {len(documents)} .md files with code to test.") 118 | return documents 119 | 120 | 121 | @pytest.mark.parametrize("path, blocks", collect_docs()) 122 | def test_doc_md(path, blocks): 123 | """Testing all code blocks in one *.md file under `path`.""" 124 | print(f"Testing {len(blocks)} code blocks in {path}.") 125 | env_blocks = OrderedDict() 126 | env_workdirs = OrderedDict() 127 | env_arguments = OrderedDict() 128 | current_env = "ENV DEFAULT" 129 | workdir = None 130 | skip_next_block = False 131 | for block in blocks: 132 | if DIRECTIVE in block.region.lexemes: 133 | directive = block.region.lexemes[DIRECTIVE] 134 | if directive == DIR_NEW_ENV: 135 | arguments = block.region.lexemes["arguments"] 136 | assert arguments not in env_blocks.keys() 137 | current_env = f"ENV {block.path}:{block.line}" 138 | if ARGUMENTS in block.region.lexemes: 139 | env_arguments[current_env] = block.region.lexemes[ARGUMENTS] 140 | elif directive == DIR_WORKDIR: 
141 | workdir = block.region.lexemes[ARGUMENTS] 142 | elif directive == DIR_SKIP_NEXT: 143 | skip_next_block = True 144 | else: 145 | raise RuntimeError(f"Unsupported directive {directive}.") 146 | else: 147 | if skip_next_block: 148 | skip_next_block = False 149 | continue 150 | language = block.region.lexemes["language"] 151 | source = block.region.lexemes["source"] 152 | assert language == BASH, f"Unsupported language {language}" 153 | if current_env not in env_blocks.keys(): 154 | env_blocks[current_env] = [] 155 | env_workdirs[current_env] = [] 156 | env_blocks[current_env].append(source) 157 | env_workdirs[current_env].append(workdir) 158 | workdir = None 159 | # After preprocessing all the environments, evaluate them. 160 | for env, blocks in env_blocks.items(): 161 | # Get arguments 162 | assert env in env_arguments, "Environment must have arguments." 163 | arguments = env_arguments[env] 164 | print(f"Evaluating environment >{env}< with {arguments=} ...") 165 | client = docker.from_env() 166 | arguments = arguments.split(" ") 167 | assert len(arguments) == 1, f"Expecting one argument. Got {arguments}" 168 | image = arguments[0] 169 | client.images.pull(image) 170 | container = client.containers.run( 171 | image=image, command="sleep 1d", auto_remove=True, detach=True 172 | ) 173 | # Get workdirs 174 | assert env in env_workdirs, "Environment must have working directories." 175 | workdirs = env_workdirs[env] 176 | assert len(blocks) == len(workdirs), "Must have one workdir entry per block." 177 | try: 178 | for block, workdir in zip(blocks, workdirs): 179 | for line in block.split("\n"): 180 | command = f'bash -ic "{line.strip()}"' 181 | if len(line) == 0: 182 | continue 183 | print(f"Executing {command=}") 184 | returncode, output = container.exec_run( 185 | command, tty=True, workdir=workdir 186 | ) 187 | output = output.decode("utf-8") 188 | if returncode == 0: 189 | print("✅") 190 | print("output ...") 191 | print(output) 192 | else: 193 | print(f"❌ {returncode=}") 194 | assert ( 195 | returncode == 0 196 | ), f"Unexpected Returncode: {returncode=}\noutput ...\n{output}" 197 | except Exception as e: 198 | print(e) 199 | assert False, str(e) 200 | finally: 201 | container.stop() 202 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/envs/banana.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 
8 | 9 | """Utilities for the banana test environment.""" 10 | 11 | from pprint import pprint 12 | import os 13 | from enum import Enum 14 | 15 | import numpy as np 16 | from gymnasium import spaces 17 | 18 | from pyrobosim_msgs.action import ExecuteTaskAction 19 | from pyrobosim_msgs.msg import ExecutionResult, TaskAction 20 | from pyrobosim_msgs.srv import RequestWorldState, ResetWorld 21 | 22 | from .pyrobosim_ros_env import PyRoboSimRosEnv 23 | 24 | 25 | class BananaEnv(PyRoboSimRosEnv): 26 | sub_types = Enum("sub_types", "Pick Place PlaceNoSoda") 27 | 28 | @classmethod 29 | def get_world_file_path(cls, _: sub_types) -> str: 30 | """Get the world file path for a given subtype.""" 31 | return os.path.join("rl_ws_worlds", "worlds", "banana.yaml") 32 | 33 | def __init__( 34 | self, 35 | sub_type: sub_types, 36 | node, 37 | max_steps_per_episode, 38 | realtime, 39 | discrete_actions, 40 | reward_fn=None, 41 | executor=None, 42 | ): 43 | """ 44 | Instantiate Banana environment. 45 | 46 | :param sub_types: Subtype of this environment (e.g. `BananaEnv.sub_types.Pick`). 47 | :param node: Node instance needed for ROS communication. 48 | :param max_steps_per_episode: Limit the steps (when to end the episode). 49 | If -1, there is no limit to number of steps. 50 | :param realtime: Whether actions take time. 51 | :param discrete_actions: Choose discrete actions (needed for DQN). 52 | :param reward_fn: Function that calculates the reward and termination criteria. 53 | The first argument needs to be the environment itself. 54 | The output needs to be a (reward, terminated, info) tuple. 55 | If not specified, uses the default for that environment. 56 | :param executor: Optional ROS executor. It must be already spinning! 57 | """ 58 | if sub_type == BananaEnv.sub_types.Pick: 59 | reward_fn = reward_fn or banana_picked_reward 60 | reset_validation_fn = None 61 | # eval_freq = 1000 62 | elif sub_type == BananaEnv.sub_types.Place: 63 | reward_fn = reward_fn or banana_on_table_reward 64 | reset_validation_fn = None 65 | # eval_freq = 2000 66 | elif sub_type == BananaEnv.sub_types.PlaceNoSoda: 67 | reward_fn = reward_fn or banana_on_table_avoid_soda_reward 68 | reset_validation_fn = avoid_soda_reset_validation 69 | # eval_freq = 2000 70 | else: 71 | raise ValueError(f"Invalid environment: {sub_type}") 72 | 73 | super().__init__( 74 | node, 75 | reward_fn, 76 | reset_validation_fn, 77 | max_steps_per_episode, 78 | realtime, 79 | discrete_actions, 80 | executor=executor, 81 | ) 82 | 83 | self.num_locations = sum(len(loc.spawns) for loc in self.world_state.locations) 84 | self.loc_to_idx = {loc: idx for idx, loc in enumerate(self.all_locations)} 85 | 86 | self.num_object_types = len(self.world_info.object_categories) 87 | self.obj_to_idx = { 88 | obj: idx for idx, obj in enumerate(self.world_info.object_categories) 89 | } 90 | 91 | # Observation space is defined by: 92 | # Previous action 93 | # Location of robot 94 | # Type of object robot is holding (if any) 95 | # Whether there is at least one of a specific object type at each location 96 | self.obs_size = ( 97 | self.num_locations # Number of locations robot can be in 98 | + self.num_object_types # Object types robot is holding 99 | + ( 100 | self.num_locations * self.num_object_types 101 | ) # Number of object categories per location 102 | ) 103 | 104 | self.observation_space = spaces.Box( 105 | low=-np.ones(self.obs_size, dtype=np.float32), 106 | high=np.ones(self.obs_size, dtype=np.float32), 107 | ) 108 | print(f"{self.observation_space=}") 109 | 110 | 
self.initialize() 111 | 112 | def _action_space(self): 113 | # Action space is defined by: 114 | # Move: To all possible object spawns 115 | # Pick: All possible object categories 116 | # Place: The current manipulated object 117 | idx = 0 118 | self.integer_to_action = {} 119 | for loc in self.all_locations: 120 | self.integer_to_action[idx] = TaskAction( 121 | type="navigate", target_location=loc 122 | ) 123 | idx += 1 124 | for obj_category in self.world_info.object_categories: 125 | self.integer_to_action[idx] = TaskAction(type="pick", object=obj_category) 126 | idx += 1 127 | self.integer_to_action[idx] = TaskAction(type="place") 128 | self.num_actions = len(self.integer_to_action) 129 | print("self.integer_to_action=") 130 | pprint(self.integer_to_action) 131 | 132 | if self.discrete_actions: 133 | return spaces.Discrete(self.num_actions) 134 | else: 135 | return spaces.Box( 136 | low=np.zeros(self.num_actions, dtype=np.float32), 137 | high=np.ones(self.num_actions, dtype=np.float32), 138 | ) 139 | 140 | def step(self, action): 141 | """Steps the environment with a specific action.""" 142 | self.previous_location = self.world_state.robots[0].last_visited_location 143 | 144 | goal = ExecuteTaskAction.Goal() 145 | if self.discrete_actions: 146 | goal.action = self.integer_to_action[int(action)] 147 | else: 148 | goal.action = self.integer_to_action[np.argmax(action)] 149 | goal.action.robot = "robot" 150 | goal.realtime_factor = 1.0 if self.realtime else -1.0 151 | 152 | goal_future = self.execute_action_client.send_goal_async(goal) 153 | self._spin_future(goal_future) 154 | 155 | result_future = goal_future.result().get_result_async() 156 | self._spin_future(result_future) 157 | 158 | action_result = result_future.result().result 159 | self.step_number += 1 160 | truncated = (self.max_steps_per_episode >= 0) and ( 161 | self.step_number >= self.max_steps_per_episode 162 | ) 163 | if truncated: 164 | print( 165 | f"Maximum steps ({self.max_steps_per_episode}) exceeded. Truncated episode." 166 | ) 167 | 168 | observation = self._get_obs() 169 | reward, terminated, info = self.reward_fn(goal, action_result) 170 | self.previous_action_type = goal.action.type 171 | 172 | info["metrics"] = { 173 | "at_banana_location": float(is_at_banana_location(self)), 174 | "holding_banana": float(is_holding_banana(self)), 175 | } 176 | return observation, reward, terminated, truncated, info 177 | 178 | def initialize(self): 179 | self.step_number = 0 180 | 181 | def reset(self, seed=None, options=None): 182 | super().reset(seed=seed) 183 | self.initialize() 184 | info = {} 185 | 186 | valid_reset = False 187 | num_reset_attempts = 0 188 | while not valid_reset: 189 | future = self.reset_world_client.call_async( 190 | ResetWorld.Request(seed=(seed or -1)) 191 | ) 192 | self._spin_future(future) 193 | 194 | observation = self._get_obs() 195 | 196 | valid_reset = self.reset_validation_fn() 197 | num_reset_attempts += 1 198 | seed = None # subsequent resets need to not use a fixed seed 199 | 200 | print(f"Reset environment in {num_reset_attempts} attempt(s).") 201 | return observation, info 202 | 203 | def _get_obs(self): 204 | """Calculate the observation. 
All elements are either -1.0 or +1.0.""" 205 | future = self.request_state_client.call_async(RequestWorldState.Request()) 206 | self._spin_future(future) 207 | world_state = future.result().state 208 | robot_state = world_state.robots[0] 209 | 210 | obs = -np.ones(self.obs_size, dtype=np.float32) 211 | 212 | # Robot's current location 213 | if robot_state.last_visited_location in self.loc_to_idx: 214 | loc_idx = self.loc_to_idx[robot_state.last_visited_location] 215 | obs[loc_idx] = 1.0 216 | 217 | # Object categories per location (including currently held object, if any) 218 | for obj in world_state.objects: 219 | obj_idx = self.obj_to_idx[obj.category] 220 | if obj.name == robot_state.manipulated_object: 221 | obs[self.num_locations + obj_idx] = 1.0 222 | else: 223 | loc_idx = self.loc_to_idx[obj.parent] 224 | obs[ 225 | self.num_locations 226 | + self.num_object_types 227 | + (loc_idx * self.num_object_types) 228 | + obj_idx 229 | ] = 1.0 230 | 231 | self.world_state = world_state 232 | return obs 233 | 234 | 235 | def is_at_banana_location(env): 236 | robot_state = env.world_state.robots[0] 237 | for obj in env.world_state.objects: 238 | if obj.category == "banana": 239 | if obj.parent == robot_state.last_visited_location: 240 | return True 241 | return False 242 | 243 | 244 | def is_holding_banana(env): 245 | robot_state = env.world_state.robots[0] 246 | for obj in env.world_state.objects: 247 | if obj.category == "banana" and obj.name == robot_state.manipulated_object: 248 | return True 249 | return False 250 | 251 | 252 | def banana_picked_reward(env, goal, action_result): 253 | """ 254 | Checks whether the robot has picked a banana. 255 | 256 | :param: The environment. 257 | :goal: The ROS action goal sent to the robot. 258 | :action_result: The result of the above goal. 259 | : 260 | :return: A tuple of (reward, terminated, info) 261 | """ 262 | # Calculate reward 263 | reward = 0.0 264 | terminated = False 265 | info = {"success": False} 266 | 267 | # Discourage repeating the same navigation action or failing to pick/place. 268 | if (goal.action.type == "navigate") and ( 269 | goal.action.target_location == env.previous_location 270 | ): 271 | reward -= 1.0 272 | if action_result.execution_result.status != ExecutionResult.SUCCESS: 273 | reward -= 1.0 274 | # Discourage repeat action types. 275 | if goal.action.type == env.previous_action_type: 276 | reward -= 0.5 277 | # Robot gets positive reward based on holding a banana, 278 | # and negative reward for being in locations without bananas. 279 | at_banana_location = is_at_banana_location(env) 280 | if is_holding_banana(env): 281 | reward += 10.0 282 | terminated = True 283 | at_banana_location = True 284 | info["success"] = True 285 | print(f"🍌 Picked banana. Episode succeeded in {env.step_number} steps!") 286 | 287 | # Reward shaping: Penalty if the robot is not at a location containing a banana. 288 | if not terminated and not at_banana_location: 289 | reward -= 0.5 290 | return reward, terminated, info 291 | 292 | 293 | def banana_on_table_reward(env, goal, action_result): 294 | """ 295 | Checks whether the robot has placed a banana on the table. 296 | 297 | :param: The environment. 298 | :goal: The ROS action goal sent to the robot. 299 | :action_result: The result of the above goal. 
300 | : 301 | :return: A tuple of (reward, terminated, info) 302 | """ 303 | # Calculate reward 304 | reward = 0.0 305 | terminated = False 306 | info = {"success": False} 307 | 308 | robot_state = env.world_state.robots[0] 309 | # Discourage repeating the same navigation action or failing to pick/place. 310 | if (goal.action.type == "navigate") and ( 311 | goal.action.target_location == env.previous_location 312 | ): 313 | reward -= 1.0 314 | if action_result.execution_result.status != ExecutionResult.SUCCESS: 315 | reward -= 1.0 316 | # Discourage repeat action types. 317 | if goal.action.type == env.previous_action_type: 318 | reward -= 0.5 319 | # Robot gets positive reward based on a banana being at the table. 320 | at_banana_location = False 321 | holding_banana = False 322 | for obj in env.world_state.objects: 323 | if obj.category == "banana": 324 | if obj.parent == robot_state.last_visited_location: 325 | at_banana_location = True 326 | if obj.parent == "table0_tabletop": 327 | print( 328 | f"🍌 Placed banana on the table. " 329 | f"Episode succeeded in {env.step_number} steps!" 330 | ) 331 | reward += 10.0 332 | terminated = True 333 | info["success"] = True 334 | break 335 | 336 | if obj.category == "banana" and obj.name == robot_state.manipulated_object: 337 | holding_banana = True 338 | 339 | # Reward shaping: Adjust the reward related to how close the robot is to completing the task. 340 | if not terminated: 341 | if holding_banana and env.previous_location != "table0_tabletop": 342 | reward += 0.25 343 | elif holding_banana: 344 | reward += 0.1 345 | elif at_banana_location: 346 | reward -= 0.1 347 | else: 348 | reward -= 0.25 349 | 350 | return reward, terminated, info 351 | 352 | 353 | def banana_on_table_avoid_soda_reward(env, goal, action_result): 354 | """ 355 | Checks whether the robot has placed a banana on the table without touching the soda. 356 | 357 | :param: The environment. 358 | :goal: The ROS action goal sent to the robot. 359 | :action_result: The result of the above goal. 360 | : 361 | :return: A tuple of (reward, terminated, info) 362 | """ 363 | # Start with the same reward as the no-soda case. 364 | reward, terminated, info = banana_on_table_reward(env, goal, action_result) 365 | 366 | # Robot gets additional negative reward being near a soda, 367 | # and fails if it tries to pick or place at a location with a soda. 368 | robot_state = env.world_state.robots[0] 369 | for obj in env.world_state.objects: 370 | if obj.category == "coke" and obj.parent == robot_state.last_visited_location: 371 | if goal.action.type == "navigate": 372 | reward -= 2.5 373 | else: 374 | print( 375 | "🔥 Tried to pick and place near a soda. " 376 | f"Episode failed in {env.step_number} steps!" 377 | ) 378 | reward -= 25.0 379 | terminated = True 380 | info["success"] = False 381 | 382 | return reward, terminated, info 383 | 384 | 385 | def avoid_soda_reset_validation(env): 386 | """ 387 | Checks whether an environment has been properly reset to avoid soda. 388 | 389 | Specifically, we are checking that: 390 | - There is at least one banana not next to a soda 391 | - The robot is not at a starting location where there is a soda 392 | 393 | :param: The environment. 394 | :return: True if valid, else False. 
395 | """ 396 | soda_location = None 397 | for obj in env.world_state.objects: 398 | if obj.category == "coke": 399 | soda_location = obj.parent 400 | 401 | valid_banana_locations = False 402 | for obj in env.world_state.objects: 403 | if obj.category == "banana" and obj.parent != soda_location: 404 | valid_banana_locations = True 405 | 406 | robot_location = env.world_state.robots[0].last_visited_location 407 | valid_robot_location = robot_location != soda_location 408 | 409 | return valid_banana_locations and valid_robot_location 410 | -------------------------------------------------------------------------------- /pyrobosim_ros_gym/envs/greenhouse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD 3-Clause License. 7 | # See the LICENSE file in the project root for license information. 8 | 9 | """Utilities for the greenhouse test environment.""" 10 | 11 | from enum import Enum 12 | import os 13 | 14 | import numpy as np 15 | from geometry_msgs.msg import Point 16 | from gymnasium import spaces 17 | 18 | from pyrobosim_msgs.action import ExecuteTaskAction 19 | from pyrobosim_msgs.msg import TaskAction, WorldState 20 | from pyrobosim_msgs.srv import RequestWorldState, ResetWorld 21 | 22 | from .pyrobosim_ros_env import PyRoboSimRosEnv 23 | 24 | 25 | def _dist(a: Point, b: Point) -> float: 26 | """Calculate distance between two (geometry_msgs.msg) Points.""" 27 | return float(np.linalg.norm([a.x - b.x, a.y - b.y, a.z - b.z])) 28 | 29 | 30 | class GreenhouseEnv(PyRoboSimRosEnv): 31 | sub_types = Enum("sub_types", "Plain Battery Random") 32 | 33 | @classmethod 34 | def get_world_file_path(cls, sub_type: sub_types) -> str: 35 | """Get the world file path for a given subtype.""" 36 | if sub_type == GreenhouseEnv.sub_types.Plain: 37 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_plain.yaml") 38 | elif sub_type == GreenhouseEnv.sub_types.Battery: 39 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_battery.yaml") 40 | elif sub_type == GreenhouseEnv.sub_types.Random: 41 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_random.yaml") 42 | else: 43 | raise ValueError(f"Invalid environment: {sub_type}") 44 | 45 | def __init__( 46 | self, 47 | sub_type: sub_types, 48 | node, 49 | max_steps_per_episode, 50 | realtime, 51 | discrete_actions, 52 | reward_fn, 53 | executor=None, 54 | ): 55 | """ 56 | Instantiate Greenhouse environment. 57 | 58 | :param sub_type: Subtype of this environment, e.g. `GreenhouseEnv.sub_types.Deterministic`. 59 | :param node: Node instance needed for ROS communication. 60 | :param max_steps_per_episode: Limit the steps (when to end the episode). 61 | If -1, there is no limit to number of steps. 62 | :param realtime: Whether actions take time. 63 | :param discrete_actions: Choose discrete actions (needed for DQN). 64 | :param reward_fn: Function that calculates the reward and termination criteria. 65 | The first argument needs to be the environment itself. 66 | The output needs to be a (reward, terminated) tuple. 67 | :param executor: Optional ROS executor. It must be already spinning! 
68 | """ 69 | if sub_type == GreenhouseEnv.sub_types.Plain: 70 | # All plants are in their places 71 | pass 72 | elif sub_type == GreenhouseEnv.sub_types.Random: 73 | # Plants are randomly across tables 74 | pass 75 | elif sub_type == GreenhouseEnv.sub_types.Battery: 76 | # Battery (= water) is limited 77 | pass 78 | else: 79 | raise ValueError(f"Invalid environment: {sub_type}") 80 | self.sub_type = sub_type 81 | 82 | super().__init__( 83 | node, 84 | reward_fn, 85 | None, # reset_validation_fn 86 | max_steps_per_episode, 87 | realtime, 88 | discrete_actions, 89 | executor=executor, 90 | ) 91 | 92 | # Observation space is defined by: 93 | self.max_n_objects = 3 94 | self.max_dist = 10 95 | # array of n objects with a class and distance each, 96 | # plus current location watered and (optionally) battery level. 97 | self.obs_size = 2 * self.max_n_objects + 1 98 | if self.sub_type == GreenhouseEnv.sub_types.Battery: 99 | self.obs_size += 1 100 | 101 | low = np.zeros(self.obs_size, dtype=np.float32) 102 | high = np.ones(self.obs_size, dtype=np.float32) # max class = 1 103 | high[1 : self.max_n_objects * 2 : 2] = self.max_dist 104 | self.observation_space = spaces.Box(low=low, high=high) 105 | print(f"{self.observation_space=}") 106 | 107 | self.plants = [obj.name for obj in self.world_state.objects] 108 | # print(f"{self.plants=}") 109 | self.good_plants = [ 110 | obj.name for obj in self.world_state.objects if obj.category == "plant_good" 111 | ] 112 | # print(f"{self.good_plants=}") 113 | 114 | self.waypoints = [ 115 | "table_c", 116 | "table_ne", 117 | "table_e", 118 | "table_se", 119 | "table_s", 120 | "table_sw", 121 | "table_w", 122 | "table_nw", 123 | "table_n", 124 | ] 125 | self.initialize() 126 | 127 | def _action_space(self): 128 | if self.sub_type == GreenhouseEnv.sub_types.Battery: 129 | self.num_actions = 3 # stay ducked, water plant, or go charge 130 | else: 131 | self.num_actions = 2 # stay ducked or water plant 132 | 133 | if self.discrete_actions: 134 | return spaces.Discrete(self.num_actions) 135 | else: 136 | return spaces.Box( 137 | low=np.zeros(self.num_actions, dtype=np.float32), 138 | high=np.ones(self.num_actions, dtype=np.float32), 139 | ) 140 | 141 | def step(self, action): 142 | info = {} 143 | truncated = (self.max_steps_per_episode >= 0) and ( 144 | self.step_number >= self.max_steps_per_episode 145 | ) 146 | if truncated: 147 | print( 148 | f"Maximum steps ({self.max_steps_per_episode}) exceeded. " 149 | f"Truncated episode with watered fraction {self.watered_plant_fraction()}." 
150 | ) 151 | 152 | # print(f"{'*'*10}") 153 | # print(f"{action=}") 154 | if self.discrete_actions: 155 | action = float(action) 156 | else: 157 | action = np.argmax(action) 158 | 159 | # Execute the current actions before calculating reward 160 | self.previous_battery_level = self.battery_level() 161 | if action == 1: # water a plant 162 | self.mark_table(self.get_current_location()) 163 | elif action == 2: # charge 164 | self.go_to_loc("charger") 165 | 166 | future = self.request_state_client.call_async(RequestWorldState.Request()) 167 | self._spin_future(future) 168 | self.world_state = future.result().state 169 | 170 | self.step_number += 1 171 | reward, terminated = self.reward_fn(action) 172 | # print(f"{reward=}") 173 | 174 | # Execute the remainder of the actions after calculating reward 175 | if not terminated: 176 | if action == 2: # charge 177 | self.go_to_loc(self.get_current_location()) 178 | else: 179 | self.go_to_loc(self.get_next_location()) 180 | 181 | # Update self.world_state and observation after finishing the action 182 | observation = self._get_obs() 183 | # print(f"{observation=}") 184 | 185 | info = { 186 | "success": self.watered_plant_fraction() == 1.0, 187 | "metrics": { 188 | "watered_plant_fraction": float(self.watered_plant_fraction()), 189 | "battery_level": float(self.battery_level()), 190 | }, 191 | } 192 | 193 | return observation, reward, terminated, truncated, info 194 | 195 | def mark_table(self, loc): 196 | close_goal = ExecuteTaskAction.Goal() 197 | close_goal.action = TaskAction() 198 | close_goal.action.robot = "robot" 199 | close_goal.action.type = "close" 200 | close_goal.action.target_location = loc 201 | 202 | goal_future = self.execute_action_client.send_goal_async(close_goal) 203 | self._spin_future(goal_future) 204 | 205 | result_future = goal_future.result().get_result_async() 206 | self._spin_future(result_future) 207 | 208 | def watered_plant_fraction(self): 209 | n_watered = 0 210 | for w in self.watered.values(): 211 | if w: 212 | n_watered += 1 213 | return n_watered / len(self.watered) 214 | 215 | def battery_level(self): 216 | return self.world_state.robots[0].battery_level 217 | 218 | def _get_plants_by_distance(self, world_state: WorldState): 219 | robot_state = world_state.robots[0] 220 | robot_pos = robot_state.pose.position 221 | # print(robot_pos) 222 | 223 | plants_by_distance = {} 224 | for obj in world_state.objects: 225 | pos = obj.pose.position 226 | dist = _dist(robot_pos, pos) 227 | dist = min(dist, self.max_dist) 228 | plants_by_distance[dist] = obj 229 | 230 | return plants_by_distance 231 | 232 | def initialize(self): 233 | self.step_number = 0 234 | self.waypoint_i = -1 235 | self.watered = {plant: False for plant in self.good_plants} 236 | self.previous_battery_level = self.battery_level() 237 | self.go_to_loc(self.get_next_location()) 238 | 239 | def reset(self, seed=None, options=None): 240 | super().reset(seed) 241 | 242 | valid_reset = False 243 | num_reset_attempts = 0 244 | while not valid_reset: 245 | future = self.reset_world_client.call_async( 246 | ResetWorld.Request(seed=(seed or -1)) 247 | ) 248 | self._spin_future(future) 249 | 250 | # Validate that there are no two plants in the same location. 
251 | observation = self._get_obs() 252 | valid_reset = True 253 | parent_locs = set() 254 | for obj in self.world_state.objects: 255 | if obj.parent not in parent_locs: 256 | parent_locs.add(obj.parent) 257 | else: 258 | valid_reset = False 259 | break 260 | 261 | num_reset_attempts += 1 262 | seed = None # subsequent resets need to not use a fixed seed 263 | 264 | self.initialize() 265 | print(f"Reset environment in {num_reset_attempts} attempt(s).") 266 | return observation, {} 267 | 268 | def _get_obs(self): 269 | """Calculate the observations""" 270 | future = self.request_state_client.call_async(RequestWorldState.Request()) 271 | self._spin_future(future) 272 | world_state = future.result().state 273 | plants_by_distance = self._get_plants_by_distance(world_state) 274 | 275 | obs = np.zeros(self.obs_size, dtype=np.float32) 276 | start_idx = 0 277 | 278 | for _ in range(self.max_n_objects): 279 | closest_d = min(plants_by_distance.keys()) 280 | plant = plants_by_distance.pop(closest_d) 281 | plant_class = 0 if plant.category == "plant_good" else 1 282 | obs[start_idx] = plant_class 283 | obs[start_idx + 1] = closest_d 284 | start_idx += 2 285 | 286 | cur_loc = world_state.robots[0].last_visited_location 287 | for loc in world_state.locations: 288 | if cur_loc == loc.name or cur_loc in loc.spawns: 289 | if not loc.is_open: 290 | obs[start_idx] = 1.0 # closed = watered 291 | break 292 | 293 | if self.sub_type == self.sub_type.Battery: 294 | obs[start_idx + 1] = self.battery_level() / 100.0 295 | 296 | self.world_state = world_state 297 | return obs 298 | 299 | def get_next_location(self): 300 | self.waypoint_i = (self.waypoint_i + 1) % len(self.waypoints) 301 | return self.get_current_location() 302 | 303 | def get_current_location(self): 304 | return self.waypoints[self.waypoint_i] 305 | 306 | def go_to_loc(self, loc: str): 307 | nav_goal = ExecuteTaskAction.Goal() 308 | nav_goal.action = TaskAction(type="navigate", target_location=loc) 309 | nav_goal.action.robot = "robot" 310 | nav_goal.realtime_factor = 1.0 if self.realtime else -1.0 311 | 312 | goal_future = self.execute_action_client.send_goal_async(nav_goal) 313 | self._spin_future(goal_future) 314 | 315 | result_future = goal_future.result().get_result_async() 316 | self._spin_future(result_future) 317 | 318 | 319 | def sparse_reward(env, action): 320 | """ 321 | The most basic Greenhouse environment reward function, which provides 322 | positive reward if all good plants are watered and negative reward if 323 | an evil plant is watered. 324 | """ 325 | reward = 0.0 326 | plants_by_distance = env._get_plants_by_distance(env.world_state) 327 | robot_location = env.world_state.robots[0].last_visited_location 328 | 329 | if action == 1: # move up to water 330 | for plant in plants_by_distance.values(): 331 | if plant.parent != robot_location: 332 | continue 333 | if plant.category == "plant_evil": 334 | print( 335 | "🌶️ Tried to water an evil plant. " 336 | f"Terminated in {env.step_number} steps " 337 | f"with watered fraction {env.watered_plant_fraction()}." 338 | ) 339 | return -5.0, True 340 | 341 | terminated = all(env.watered.values()) 342 | if terminated: 343 | print(f"💧 Watered all good plants! Succeeded in {env.step_number} steps.") 344 | reward += 8.0 345 | return reward, terminated 346 | 347 | 348 | def dense_reward(env, action): 349 | """ 350 | A simple Greenhouse environment reward function that provides reward each time 351 | a good plant is watered. 
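    Watering a good plant for the first time yields +2. Watering the evil plant
    yields -5 and terminates the episode; the episode also terminates once every
    good plant has been watered.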
352 | """ 353 | reward = 0.0 354 | plants_by_distance = env._get_plants_by_distance(env.world_state) 355 | robot_location = env.world_state.robots[0].last_visited_location 356 | 357 | if action == 1: # move up to water 358 | for plant in plants_by_distance.values(): 359 | if plant.parent != robot_location: 360 | continue 361 | if plant.category == "plant_good": 362 | if not env.watered[plant.name]: 363 | env.watered[plant.name] = True 364 | reward += 2.0 365 | elif plant.category == "plant_evil": 366 | print( 367 | "🌶️ Tried to water an evil plant. " 368 | f"Terminated in {env.step_number} steps " 369 | f"with watered fraction {env.watered_plant_fraction()}." 370 | ) 371 | return -5.0, True 372 | else: 373 | raise RuntimeError(f"Unknown category {plant.category}") 374 | 375 | terminated = all(env.watered.values()) 376 | if terminated: 377 | print(f"💧 Watered all good plants! Succeeded in {env.step_number} steps.") 378 | return reward, terminated 379 | 380 | 381 | def full_reward(env, action): 382 | """Full (solution) reward function for the Greenhouse environment.""" 383 | reward = 0.0 384 | plants_by_distance = env._get_plants_by_distance(env.world_state) 385 | robot_location = env.world_state.robots[0].last_visited_location 386 | 387 | if env.battery_level() <= 0.0: 388 | print( 389 | "🪫 Ran out of battery. " 390 | f"Terminated in {env.step_number} steps " 391 | f"with watered fraction {env.watered_plant_fraction()}." 392 | ) 393 | return -5.0, True 394 | 395 | if action == 0: # stay ducked 396 | # Robot gets a penalty if it decides to ignore a waterable plant. 397 | for plant in env.world_state.objects: 398 | if (plant.category == "plant_good") and (plant.parent == robot_location): 399 | for location in env.world_state.locations: 400 | if robot_location in location.spawns and location.is_open: 401 | # print("\tPassed over a waterable plant") 402 | reward -= 0.25 403 | break 404 | return reward, False 405 | 406 | elif action == 1: # move up to water 407 | for plant in plants_by_distance.values(): 408 | if plant.parent != robot_location: 409 | continue 410 | if plant.category == "plant_good": 411 | if not env.watered[plant.name]: 412 | env.watered[plant.name] = True 413 | reward += 2.0 414 | elif plant.category == "plant_evil": 415 | print( 416 | "🌶️ Tried to water an evil plant. " 417 | f"Terminated in {env.step_number} steps " 418 | f"with watered fraction {env.watered_plant_fraction()}." 419 | ) 420 | return -5.0, True 421 | else: 422 | raise RuntimeError(f"Unknown category {plant.category}") 423 | if reward == 0.0: # nothing watered, wasted water 424 | # print("\tWasted water") 425 | reward = -0.5 426 | 427 | elif action == 2: # charging 428 | # Reward shaping to get the robot to visit the charger when its 429 | # battery is low, but not when it is high. 430 | if env.previous_battery_level <= 5.0: 431 | # print(f"\tCharged when battery low ({self.previous_battery_level}) :)") 432 | reward += 0.5 433 | else: 434 | # print(f"\tCharged when battery high ({self.previous_battery_level}) :(") 435 | reward -= 1.0 436 | 437 | terminated = all(env.watered.values()) 438 | if terminated: 439 | print(f"💧 Watered all good plants! 
Succeeded in {env.step_number} steps.") 440 | 441 | return reward, terminated 442 | -------------------------------------------------------------------------------- /slides/main.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 3 | - Reinforcement Learning for Deliberation in ROS 2 4 | author: 5 | - Christian Henkel 6 | - Sebastian Castro 7 | theme: 8 | - Bergen 9 | date: 10 | - ROSCon 2025 / October 27, 2025 11 | aspectratio: 169 12 | fontsize: 10pt 13 | colorlinks: true 14 | indent: false 15 | header-includes: 16 | - \usepackage{listings} 17 | - \usepackage{xcolor} 18 | - \lstset{ 19 | basicstyle=\ttfamily\small, 20 | backgroundcolor=\color{gray!10}, 21 | keywordstyle=\color{blue}, 22 | stringstyle=\color{orange}, 23 | commentstyle=\color{gray}, 24 | showstringspaces=false 25 | } 26 | - \hypersetup{urlcolor=blue} 27 | - \urlstyle{tt} 28 | - \setbeamerfont{footline}{size=\normalsize} 29 | - \setbeamertemplate{navigation symbols}{} 30 | - \setbeamertemplate{footline}{\vspace{2pt} \hspace{2pt} \includegraphics[width=1.85cm]{media/ros-wg-delib.png} \includegraphics[width=1.85cm]{media/roscon25.png} \hspace*{5pt} \insertsection \hfill \insertframenumber{} / \inserttotalframenumber \hspace*{5pt} \vspace{2pt}} 31 | --- 32 | 33 | # Introduction 34 | 35 | ## Agenda 36 | 37 | | __Time__ | __Topic__ | 38 | |----------------|---------------------------------------------------------| 39 | | 13:00 - 13:30 | Introduction / Software Setup | 40 | | 13:30 - 14:00 | (Very) Quick intro to Reinforcement Learning | 41 | | 14:00 - 15:00 | Training and evaluating RL agents | 42 | | 15:00 - 15:30 | [Coffee break / leave a longer training running] | 43 | | 15:30 - 16:15 | Evaluating trained agents and running in ROS nodes | 44 | | 16:15 - 17:00 | Discussion: ROS 2, RL, and Deliberation | 45 | 46 | ## Software Setup 47 | 48 | 1. Clone the repository 49 | 50 | ```bash 51 | git clone --recursive \ 52 | https://github.com/ros-wg-delib/rl_deliberation.git 53 | ``` 54 | 55 | 2. Install Pixi: 56 | 57 | ```bash 58 | curl -fsSL https://pixi.sh/install.sh | sh 59 | ``` 60 | 61 | (or – recommend for autocompletion!) 62 | 63 | 3. Build the project: 64 | 65 | ```bash 66 | pixi run build 67 | ``` 68 | 69 | 4. Run an example: 70 | 71 | ```plain 72 | pixi run start_world --env GreenhousePlain 73 | ``` 74 | 75 | ## Learning Goals 76 | 77 | By the end of this workshop, you will be able to: 78 | 79 | - Recognize robotics problems that can be solved with reinforcement learning. 80 | - Understand the basic reinforcement learning concepts and terminology. 81 | - Observe the effects of changing algorithms, hyperparameters, and reward functions on training. 82 | 83 | ## What is Reinforcement Learning (RL)? 84 | 85 | ::: columns 86 | 87 | :::: column 88 | 89 | ### Basic model 90 | 91 | - Given an __agent__ and an __environment__. 92 | - Subject to the __state__ of the environment, 93 | - the agent takes an __action__. 94 | - the environment responds with a new __state__ and a __reward__. 95 | 96 | :::: 97 | 98 | :::: column 99 | ![Agent Environment Interaction](media/agent-env.drawio.png) 100 | 101 | See also [Sutton and Barto, Reinforcement Learning: An Introduction](http://incompleteideas.net/book/RLbook2020.pdf) 102 | :::: 103 | 104 | ::: 105 | 106 | ## What is Reinforcement Learning (RL)? 
107 |
108 | ### Notation
109 |
110 | ::: columns
111 |
112 | :::: column
113 |
114 | - Discrete time steps $t = 0, 1, 2, \dots$
115 | - The environment is in a __state $S_t$__
116 | - Agent performs an __action $A_t$__
117 | - Environment responds with a new state $S_{t+1}$ and a reward $R_{t+1}$
118 | - Based on that $S_{t+1}$, the agent selects the __next action__ $A_{t+1}$
119 |
120 | ::::
121 |
122 | :::: column
123 | ![Agent Environment Interaction](media/agent-env-sym.drawio.png)
124 | ::::
125 |
126 | :::
127 |
128 | ## RL Software in this Workshop
129 |
130 | ::: columns
131 |
132 | :::: column
133 |
134 | ![Gymnasium](media/gymnasium.png)
135 |
136 | \footnotesize Represents environments for RL
137 |
138 | ::::
139 |
140 | :::: column
141 |
142 | ![Stable Baselines3 (SB3)](media/sb3.png)
143 |
144 | \footnotesize RL algorithm implementations in PyTorch
145 |
146 | ::::
147 |
148 | :::
149 |
150 | ## Exercise 1: You Are the Agent
151 |
152 | ::: columns
153 |
154 | :::: {.column width=40%}
155 |
156 | Start the environment:
157 |
158 | ```plain
159 | pixi run start_world \
160 | --env GreenhousePlain
161 | ```
162 |
163 | ![Greenhouse environment](media/greenhouse.png){height=100px}
164 |
165 | __Welcome__
166 | You are a robot that has to water plants in a greenhouse.
167 |
168 | ::::
169 |
170 | :::: {.column width=60%}
171 |
172 | Then, in another terminal, run:
173 |
174 | ```plain
175 | pixi run eval --model manual \
176 | --env GreenhousePlain \
177 | --config \
178 | greenhouse_env_config.yaml
179 | ```
180 |
181 | ```plain
182 | Enter action from [0, 1]:
183 | ```
184 |
185 | On this prompt, you can choose:
186 |
187 | - __0__: Move forward without watering, or
188 | - __1__: Water the plant and move on.
189 |
190 | ---
191 |
192 | But be __careful__: If you try to water the evil plant _(red)_, you will be eaten.
193 |
194 | ![Evil Plant \tiny flickr/Tippitiwichet](media/venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg){width=80px}
195 |
196 | ::::
197 |
198 | :::
199 |
200 | # Concepts
201 |
202 | ## Environment = MDP
203 |
204 | ### MDP
205 |
206 | We assume the environment to follow a __Markov Decision Process (MDP)__ model.
207 | An MDP is defined as $< \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R}>$.
208 |
209 | - $s \in \mathcal{S}$ states and $a \in \mathcal{A}$ actions as above.
210 | - $\mathcal{P}$ State Transition Probability: $P(s'|s, a)$.
211 |   - For an action $a$ taken in state $s$, what is the probability of reaching state $s'$?
212 | - $\mathcal{R}$ Reward Function: $R(s, a)$.
213 |   - We will use this to motivate the agent to learn desired behavior.
214 |
215 | ![(Pessimistic) Example MDP](media/mdp.drawio.png){height=90px}
216 |
217 | ## Markovian
218 |
219 | ### Markov Property
220 |
221 | Implicit to the MDP formulation is the __Markov Property__:
222 |
223 | The future state $S_{t+1}$ depends only on the current state $S_t$ and action $A_t$,
224 | not on the sequence of events that preceded it.
225 |
226 | ### Practical implication
227 |
228 | Not a direct limitation for practical use, but something to be aware of.
229 |
230 | - E.g., if history matters, include it in the state representation.
231 | - However, this will not make learning easier.
232 |
233 | ## Agent = Policy
234 |
235 | ### Policy
236 |
237 | The agent's behavior is defined by a __policy__ $\pi$.
238 | A policy is a mapping from states to actions: $\pi: \mathcal{S} \rightarrow \mathcal{A}$.
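As a toy sketch (not part of the workshop code), a policy over a handful of discrete states can be written down directly, either as a fixed lookup or as a distribution to sample from; the table names below are purely illustrative:

```python
import random

# Deterministic policy: a lookup table from state to action (0 = stay ducked, 1 = water).
deterministic_policy = {"table_c": 1, "table_ne": 0, "table_e": 1}

def stochastic_policy(state: str) -> int:
    """Sample an action from pi(a | s); here the same distribution for every state."""
    actions, weights = [0, 1], [0.2, 0.8]
    return random.choices(actions, weights=weights)[0]

print(deterministic_policy["table_c"])  # -> 1
print(stochastic_policy("table_ne"))    # -> 0 or 1 (random)
```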
239 |
240 | ### Reminder
241 |
242 | We are trying to optimize the __cumulative reward__ (or __return__) over time:
243 |
244 | $$
245 | G_t = R_{t+1} + R_{t+2} + R_{t+3} + \dots
246 | $$
247 |
248 | In practice, we use a __discount factor__ $\gamma \in [0, 1]$ to prioritize immediate rewards:
249 |
250 | $$
251 | G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots
252 | $$
253 | $$
254 | G_t = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1}
255 | $$
256 |
257 | ## Learning
258 |
259 | ### Bellman Equation
260 |
261 | This is probably the __most fundamental equation in RL__.
262 | It estimates $v_{\pi}(s)$, known as the __state value function__ when using policy $\pi$:
263 |
264 | $$v_{\pi}(s) = \mathbb{E}_{\pi} [G_t | S_t = s]$$
265 | $$ = \mathbb{E}_{\pi} [R_{t+1} + \gamma v_{\pi}(S_{t+1}) | S_t = s]$$
266 | $$ = \sum_{a} \pi(a|s) \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi}(s')]$$
267 |
268 | ### Fundamental optimization goal
269 |
270 | So, we can formulate the problem to find an optimal policy $\pi^*$ as an optimization problem:
271 |
272 | $$\pi^* = \arg\max_{\pi} v_{\pi}(s), \quad \forall s \in \mathcal{S}$$
273 |
274 | # Methods
275 |
276 | ## Temporal Differencing
277 |
278 | The Bellman equation gives rise to __temporal differencing (TD)__ for training a policy.
279 |
280 | $$v_{\pi}(S_t) \leftarrow (1 - \alpha) v_{\pi}(S_t) + \alpha (R_{t+1} + \gamma v_{\pi}(S_{t+1}))$$
281 |
282 | where
283 |
284 | - $v_{\pi}(S_t)$ is the expected value of state $S_t$
285 | - $R_{t+1} + \gamma v_{\pi}(S_{t+1})$ is the actual reward obtained at $S_t$ plus the expected value of the next state $S_{t+1}$
286 | - $R_{t+1} + \gamma v_{\pi}(S_{t+1}) - v_{\pi}(S_t)$ is the __TD error__.
287 | - $\alpha$ is the __learning rate__.
288 |
289 | \small (a variant using the __state-action value function__ $Q_{\pi}(s, a)$
290 | \small is known as __Q-learning__.)
291 |
292 | ## Tabular Reinforcement Learning
293 |
294 | RL began with known MDPs + discrete states/actions, so $v_{\pi}(s)$ or $q_{\pi}(s,a)$ are __tables__.
295 |
296 | ![Grid world \tiny ([Silver, 2015](https://davidstarsilver.wordpress.com/wp-content/uploads/2025/04/lecture-3-planning-by-dynamic-programming-.pdf))](media/grid-world.png){width=150px}
297 |
298 | Can use __dynamic programming__ to iterate through the entire environment and converge on an optimal policy.
299 |
300 | ::: columns
301 |
302 | :::: column
303 |
304 | ![Value iteration \tiny ([Silver, 2015](https://davidstarsilver.wordpress.com/wp-content/uploads/2025/04/lecture-3-planning-by-dynamic-programming-.pdf))](media/value-iteration.png){width=80px}
305 |
306 | ::::
307 |
308 | :::: column
309 |
310 | ![Policy iteration \tiny ([Silver, 2015](https://davidstarsilver.wordpress.com/wp-content/uploads/2025/04/lecture-3-planning-by-dynamic-programming-.pdf))](media/policy-iteration.png){width=80px}
311 |
312 | ::::
313 |
314 | :::
315 |
316 | ## Model-Free Reinforcement Learning
317 |
318 | If the state-action space is too large, need to perform __rollouts__ to gain experience.
319 |
320 | Key: Balancing __exploitation__ and __exploration__!
321 |
322 | ![Model-free RL methods \tiny ([Silver, 2015](https://davidstarsilver.wordpress.com/wp-content/uploads/2025/04/lecture-4-model-free-prediction-.pdf))](media/model-free-rl.png){width=430px}
323 |
324 | ## Deep Reinforcement Learning
325 |
326 | When the observation space is too large (or worse, continuous), tabular methods no longer work.
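A rough back-of-the-envelope sketch (using this workshop's `GreenhousePlain` observation, which packs 7 values, several of them continuous distances): even a coarse discretization explodes the table.

```python
# Illustrative only: GreenhousePlain packs class + distance for up to 3 plants,
# plus a "current table watered" flag, into a 7-element observation.
obs_size = 7
bins_per_dim = 10  # coarse discretization of each entry

table_rows = bins_per_dim ** obs_size
print(f"{table_rows:,} table rows")  # 10,000,000 rows, one value estimate per row
```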
327 | 328 | Need a different function approximator -- _...why not a neural network?_ 329 | 330 | ![Deep Q-Network \tiny ([Mnih et al., 2015](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf))](media/dqn.png){width=250px} 331 | 332 | __Off-policy__: Can train on old experiences from a _replay buffer_. 333 | 334 | ## Actor-Critic / Policy Gradient Methods 335 | 336 | DQN only works for discrete actions, so what about continuous actions? 337 | 338 | ::: columns 339 | 340 | :::: {.column width=60%} 341 | 342 | - __Critic__ approximates value function. Trained via TD learning. 343 | 344 | - __Actor__ outputs actions (i.e., the policy). Trained via __policy gradient__, backpropagated from critic loss. 345 | 346 | --- 347 | 348 | - Initial methods were __on-policy__ -- can only train on the latest version of the policy with current experiences. 349 | Example: Proximal Policy Optimization (PPO) ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)). 350 | 351 | - Other approaches train actor and critic at different time scales to allow off-policy. 352 | Example: Soft Actor-Critic (SAC) ([Haarnoja et al., 2018](https://arxiv.org/abs/1801.01290)). 353 | 354 | :::: 355 | 356 | :::: {.column width=40%} 357 | 358 | ![Actor-Critic methods \tiny ([Sutton + Barto, 2020](http://incompleteideas.net/book/the-book-2nd.html))](media/actor-critic.png){width=160px} 359 | 360 | :::: 361 | ::: 362 | 363 | ## Concept Comparison 364 | 365 | ::: columns 366 | :::: {.column width=45%} 367 | 368 | ### Exploitation vs. Exploration 369 | 370 | __Exploitation__: Based on the current policy, select the action that maximizes expected reward. 371 | 372 | __Exploration__: Select actions that may not maximize immediate reward, but could lead to better long-term outcomes. 373 | 374 | ### On-policy vs. Off-policy 375 | 376 | __On-policy__: Learn the value of the policy being carried out by the agent. 377 | 378 | __Off-policy__: Learn the value of an optimal policy independently of the agent's actions. 379 | :::: 380 | 381 | :::: {.column width=55%} 382 | Benefits of __on-policy__ methods: 383 | 384 | $-$ Collect data that is relevant under the current policy. 385 | 386 | $-$ More stable learning. 387 | 388 | --- 389 | 390 | Benefits of __off-policy__ methods: 391 | 392 | $-$ Better sample efficiency. 393 | 394 | $-$ Relevant for real-world robotics (gathering data is expensive). 395 | 396 | :::: 397 | 398 | ::: 399 | 400 | ## Available Algorithms 401 | 402 | ### DQN 403 | 404 | __Deep Q Network__ 405 | 406 | ::: columns 407 | :::: {.column width=50%} 408 | \small 409 | Learns a Q-function $Q(s, a)$. Introduced _experience replay_ and _target networks_. 410 | :::: 411 | :::: {.column width=25%} 412 | \small 413 | Off-policy 414 | Discrete actions 415 | :::: 416 | :::: {.column width=25%} 417 | \small 418 | [Mnih et al., 2013](https://arxiv.org/abs/1312.5602) 419 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html) 420 | :::: 421 | ::: 422 | 423 | ### A2C 424 | 425 | __Advantage Actor-Critic__ 426 | 427 | ::: columns 428 | :::: {.column width=50%} 429 | \small 430 | $A(s, a) = Q(s, a) - V(s)$ _advantage function_ to reduce variance. 
431 | :::: 432 | :::: {.column width=25%} 433 | \small 434 | On-policy 435 | Any action space 436 | :::: 437 | :::: {.column width=25%} 438 | \small 439 | [Mnih et al., 2016](https://arxiv.org/abs/1602.01783) 440 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html) 441 | :::: 442 | ::: 443 | 444 | ### PPO 445 | 446 | __Proximal Policy Optimization__ 447 | 448 | ::: columns 449 | :::: {.column width=50%} 450 | \small 451 | Optimize policy directly. Uses a _clipped surrogate objective_ for stability. 452 | :::: 453 | :::: {.column width=25%} 454 | \small 455 | On-policy 456 | Any action space 457 | :::: 458 | :::: {.column width=25%} 459 | \small 460 | [Schulman et al., 2017](https://arxiv.org/abs/1707.06347) [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) 461 | :::: 462 | ::: 463 | 464 | ### SAC 465 | 466 | __Soft Actor-Critic__ 467 | 468 | ::: columns 469 | :::: {.column width=50%} 470 | \small 471 | Separate Actor & Critic NNs. Exploration by additional _entropy_ term. 472 | :::: 473 | :::: {.column width=25%} 474 | \small 475 | Off-policy 476 | Cont. actions 477 | :::: 478 | :::: {.column width=25%} 479 | \small 480 | [Haarnoja et al., 2018](https://arxiv.org/abs/1801.01290) 481 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/sac.html) 482 | ::::: 483 | ::: 484 | 485 | ## Which Algorithm to Choose? 486 | 487 | From the lead developer of SB3 (Antonin Raffin): 488 | 489 | ![Algorithm choice flowchart](media/sb3-algo-choice.png){width=300px} 490 | 491 | # Exercises 492 | 493 | ## Exercise 2: Run with a Random Agent 494 | 495 | ```plain 496 | pixi run start_world --env GreenhousePlain 497 | 498 | pixi run eval --realtime \ 499 | --config greenhouse_env_config.yaml --model \ 500 | pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt 501 | ``` 502 | 503 | ![Greenhouse environment](media/greenhouse.png){height=150px} 504 | 505 | ## Exercise 3: Training Your First Agent 506 | 507 | Start the world. 508 | 509 | ```plain 510 | pixi run start_world --env GreenhousePlain 511 | ``` 512 | 513 | Kick off training. 514 | 515 | ```plain 516 | pixi run train --config greenhouse_env_config.yaml \ 517 | --env GreenhousePlain --algorithm DQN \ 518 | --discrete-actions --realtime 519 | ``` 520 | 521 | The `--config` file points to `pyrobosim_ros_gym/config/greenhouse_env_config.yaml`, which lets you easily set up different algorithms and training parameters. 522 | 523 | ## Exercise 3: Training Your First Agent (For Real...) 524 | 525 | ... this is going to take a while. 526 | Let's speed things up. 527 | 528 | Run simulation headless, i.e., without the GUI. 529 | 530 | ```plain 531 | pixi run start_world --env GreenhousePlain --headless 532 | ``` 533 | 534 | Without "realtime", run actions as fast as possible 535 | 536 | ```plain 537 | pixi run train --config greenhouse_env_config.yaml \ 538 | --env GreenhousePlain --algorithm DQN --discrete-actions 539 | ``` 540 | 541 | __NOTE:__ Seeding the training run is important for reproducibility! 542 | 543 | We are running with `--seed 42` by default, but you can change it. 544 | 545 | ## Exercise 3: Visualizing Training Progress 546 | 547 | SB3 has visualization support for [TensorBoard](https://www.tensorflow.org/tensorboard). 548 | By adding the `--log` argument, a log file will be written to the `train_logs` folder. 
549 |
550 | ```plain
551 | pixi run train --config greenhouse_env_config.yaml \
552 | --env GreenhousePlain --algorithm DQN \
553 | --discrete-actions --log
554 | ```
555 |
556 | Open TensorBoard and follow the URL (usually `http://localhost:6006/`).
557 |
558 | ```plain
559 | pixi run tensorboard
560 | ```
561 |
562 | ![TensorBoard](media/tensorboard.png){width=200px}
563 |
564 | ## Exercise 3: Reward Engineering
565 |
566 | Open the file `pyrobosim_ros_gym/config/greenhouse_env_config.yaml`.
567 |
568 | There, you will see 3 options for the `training.reward_fn` parameter.
569 |
570 | Train models with each and compare the effects of __reward shaping__ on results.
571 |
572 | \small
573 | * `sparse_reward` : -5 if evil plant is watered, +8 if _all_ good plants are watered.
574 | * `dense_reward` : -5 if evil plant is watered, +2 for _each_ good plant that is watered.
575 | * `full_reward` : Same as above, but adds small penalties for wasting water / passing over a good plant.
576 |
577 | ![Comparing reward functions](media/tensorboard-reward-compare.png){height=120px}
578 |
579 | ## Exercise 3: Evaluating Your Trained Agent
580 |
581 | Once you have your trained models, you can evaluate them against the simulator.
582 |
583 | ```plain
584 | pixi run eval --config greenhouse_env_config.yaml \
585 | --model .pt --num-episodes 10
586 | ```
587 |
588 | By default, this will run just like training (as quickly as possible).
589 |
590 | You can add the `--realtime` flag to slow things down to "real-time" so you can visually inspect the results.
591 |
592 | ![Example evaluation results](media/eval-results.png){width=240px}
593 |
594 | ## Exercise 4: Train More Complicated Environments
595 |
596 | \small Training the `GreenhousePlain` environment is easy because the environment is _deterministic_; the plants are always in the same locations.
597 |
598 | For harder environments, you may want to switch algorithms (e.g., `PPO` or `SAC`).
599 |
600 | ::: columns
601 |
602 | :::: column
603 |
604 | ![`GreenhouseRandom` environment](media/greenhouse-random.png){width=120px}
605 |
606 | Plants are now spawned in random locations -- but only one per table.
607 |
608 | ::::
609 |
610 | :::: column
611 |
612 | ![`GreenhouseBattery` environment](media/greenhouse-battery.png){width=120px}
613 |
614 | Watering costs 49% battery -- must recharge after watering twice.
615 |
616 | Charging is a new action (id `2`).
617 |
618 | ::::
619 |
620 | :::
621 |
622 | __Challenge__: Evaluate your policy on the `GreenhouseRandom` environment!
623 |
624 | ## Application: Deploying a Trained Policy as a ROS Node
625 |
626 | 1. Start an environment of your choice.
627 |
628 | ```plain
629 | pixi run start_world --env GreenhouseRandom
630 | ```
631 |
632 | 2. Start the node with an appropriate model.
633 |
634 | ```plain
635 | pixi run policy_node --model .pt \
636 | --config greenhouse_env_config.yaml
637 | ```
638 |
639 | 3. Open an interactive shell.
640 |
641 | ```plain
642 | pixi shell
643 | ```
644 |
645 | 4. In the shell, send an action goal to run the policy to completion!
646 |
647 | ```plain
648 | ros2 action send_goal /execute_policy \
649 | rl_interfaces/ExecutePolicy {}
650 | ```
651 |
652 | # Discussion
653 |
654 | ## When to use RL?
655 |
656 | Arguably, our simple greenhouse problem did not need RL.
657 |
658 | ... but it was nice and educational... right?
659 |
660 | ### General rules
661 |
662 | - If easy to model, __engineer it by hand__ (e.g., controllers, behavior trees).
663 | - If difficult to model, but you can provide the answer (e.g., labels or demonstrations), consider __supervised learning__. 664 | - If difficult to model, and you cannot easily provide an answer, consider __reinforcement learning__. 665 | 666 | ## Scaling up Learning 667 | 668 | ### Parallel simulation 669 | 670 | ::: columns 671 | 672 | :::: column 673 | 674 | - Simulations can be parallelized using multiple CPUs / GPUs. 675 | 676 | - SB3 defaults to using __vectorized environments__. 677 | 678 | - Other tools for parallel RL include [NVIDIA Isaac Lab](https://github.com/isaac-sim/IsaacLab) and [mjlab](https://github.com/mujocolab/mjlab). 679 | 680 | :::: 681 | 682 | :::: column 683 | 684 | ![[Rudin et al., 2021](https://leggedrobotics.github.io/legged_gym/)](media/rsl-parallel-sim.png){height=90px} 685 | 686 | :::: 687 | 688 | ::: 689 | 690 | ## Curriculum learning 691 | 692 | ::: columns 693 | 694 | :::: column 695 | 696 | - __Reward shaping__ by itself is very important to speed up learning. 697 | 698 | - Rather than solving the hardest problem from scratch, introduce a __curriculum__ of progressively harder tasks. 699 | 700 | :::: 701 | 702 | :::: column 703 | 704 | ![[Bengio et al., 2009](https://qmro.qmul.ac.uk/xmlui/bitstream/handle/123456789/15972/Bengio%2C%202009%20Curriculum%20Learning.pdf?sequence=1&isAllowed=y)](media/curriculum-learning.png){height=100px} 705 | 706 | :::: 707 | 708 | ::: 709 | 710 | ## RL Experimentation 711 | 712 | ### RNG Seeds 713 | 714 | ::: columns 715 | 716 | :::: column 717 | 718 | - It is possible to get a "lucky" (or "unlucky") seed when training. 719 | 720 | - Best (and expected) practice is to run multiple experiments and report intervals. 721 | 722 | :::: 723 | 724 | :::: column 725 | 726 | ![[Schulman et al., 2017](https://arxiv.org/abs/1707.06347)](media/ppo-graph.png){height=90px} 727 | 728 | :::: 729 | 730 | ::: 731 | 732 | ## Hyperparameter Tuning 733 | 734 | ::: columns 735 | 736 | :::: column 737 | 738 | ### Hyperparameter 739 | 740 | - Any user-specified parameter for ML training. 741 | 742 | - e.g., learning rate, batch/network size, reward weights, ... 743 | 744 | - Consider using automated tools (e.g., [Optuna](https://github.com/optuna/optuna)) to help you tune hyperparameters. 745 | 746 | :::: 747 | 748 | :::: column 749 | 750 | ![[Optuna Dashboard](https://github.com/optuna/optuna-dashboard)](media/optuna-dashboard.png){height=90px} 751 | 752 | :::: 753 | 754 | ::: 755 | 756 | ## Deploying Policies to ROS 757 | 758 | ::: columns 759 | 760 | :::: {.column width=40%} 761 | 762 | ### Python 763 | 764 | - Can directly put PyTorch / Tensorflow / etc. models in a `rclpy` node. 765 | - Be careful with threading and CPU/GPU synchronization issues! 766 | 767 | ### C++ 768 | 769 | - If you need performance, consider using C++ for inference. 770 | - Facilitated by tools like [ONNX Runtime](https://onnxruntime.ai/inference). 771 | - Can also put your policy inside a `ros2_control` controller for real-time capabilities. 
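A minimal `rclpy` sketch of the Python route above (this is not the workshop's `policy_node.py`; the model path, observation, and policy output format are placeholder assumptions):

```python
import rclpy
import torch
from rclpy.node import Node


class PolicyNode(Node):
    def __init__(self):
        super().__init__("policy_node")
        # Assumes the policy was exported with torch.jit; the path is a placeholder.
        self.policy = torch.jit.load("policy.pt")
        self.policy.eval()
        self.timer = self.create_timer(1.0, self.step)

    def step(self):
        obs = torch.zeros(1, 7)  # placeholder: build this from your robot's state
        with torch.no_grad():    # inference only, no gradients needed
            action = int(self.policy(obs).argmax())
        self.get_logger().info(f"Selected action: {action}")


def main():
    rclpy.init()
    rclpy.spin(PolicyNode())


if __name__ == "__main__":
    main()
```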
772 | 773 | :::: 774 | 775 | :::: {.column width=60%} 776 | 777 | ![[Question on Pytorch](https://discuss.pytorch.org/t/are-there-any-reasons-why-running-gpu-inference-in-a-thread-would-be-slower/204519)](media/pytorch-threading-question.png){height=120px} 778 | 779 | ![[ONNX Runtime](https://onnxruntime.ai/inference)](media/onnx-runtime.png){height=100px} 780 | 781 | :::: 782 | 783 | ::: 784 | 785 | ## RL for Deliberation 786 | 787 | ### Background 788 | 789 | __State of the art RL works for fast, low-level control policies (e.g., locomotion)__ 790 | 791 | - Requires sim-to-real training because on-robot RL is hard and/or unsafe. 792 | - Alternatives: fine-tune pretrained policies or train _residual_ policies. 793 | 794 | ### Deliberation 795 | 796 | __How does this change for deliberation applications?__ 797 | 798 | - Facilitates on-robot RL: train high-level decision making, add a safety layer below. 799 | - Hierarchical RL dates back to the __options framework__ ([Sutton et al., 1998](http://incompleteideas.net/609%20dropbox/other%20readings%20and%20resources/Options.pdf)). 800 | - What kinds of high-level decisions can/should be learned? 801 | - What should this "safety layer" below look like? 802 | 803 | ## Further Resources 804 | 805 | ### RL Theory 806 | 807 | - \small Sutton & Barto Textbook: 808 | - \small David Silver Lectures: 809 | - \small Stable Baselines3 docs: 810 | 811 | ### ROS Deliberation 812 | 813 | 814 | 815 | Join our mailing list and ~monthly meetings! 816 | 817 | ![Happy RL journey! :)](media/twitter-post.png){height=100px} 818 | --------------------------------------------------------------------------------