├── rl_ws_worlds
│   ├── resource
│   │   └── rl_ws_worlds
│   ├── rl_ws_worlds
│   │   ├── __init__.py
│   │   └── run.py
│   ├── setup.cfg
│   ├── worlds
│   │   ├── greenhouse_object_data.yaml
│   │   ├── greenhouse_location_data.yaml
│   │   ├── greenhouse_random.yaml
│   │   ├── greenhouse_plain.yaml
│   │   ├── greenhouse_battery.yaml
│   │   └── banana.yaml
│   ├── package.xml
│   └── setup.py
├── slides
│   ├── media
│   │   ├── dqn.png
│   │   ├── sb3.png
│   │   ├── greenhouse.png
│   │   ├── grid-world.png
│   │   ├── gymnasium.png
│   │   ├── mdp.drawio.png
│   │   ├── ppo-graph.png
│   │   ├── roscon25.png
│   │   ├── actor-critic.png
│   │   ├── eval-results.png
│   │   ├── onnx-runtime.png
│   │   ├── ros-wg-delib.png
│   │   ├── tensorboard.png
│   │   ├── twitter-post.png
│   │   ├── model-free-rl.png
│   │   ├── sb3-algo-choice.png
│   │   ├── value-iteration.png
│   │   ├── agent-env.drawio.png
│   │   ├── greenhouse-random.png
│   │   ├── optuna-dashboard.png
│   │   ├── policy-iteration.png
│   │   ├── rsl-parallel-sim.png
│   │   ├── agent-env-sym.drawio.png
│   │   ├── curriculum-learning.png
│   │   ├── greenhouse-battery.png
│   │   ├── pytorch-threading-question.png
│   │   ├── tensorboard-reward-compare.png
│   │   └── venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg
│   ├── README.md
│   └── main.md
├── .gitmodules
├── .gitattributes
├── pyrobosim_ros_gym
│   ├── policies
│   │   ├── BananaPick_DQN_random.pt
│   │   ├── BananaPick_PPO_trained.pt
│   │   ├── BananaPlace_PPO_trained.pt
│   │   ├── GreenhousePlain_DQN_random.pt
│   │   ├── GreenhousePlain_PPO_trained.pt
│   │   ├── BananaPlaceNoSoda_PPO_trained.pt
│   │   ├── GreenhouseBattery_PPO_trained.pt
│   │   ├── GreenhouseRandom_DQN_trained.pt
│   │   ├── GreenhouseRandom_PPO_trained.pt
│   │   └── __init__.py
│   ├── __init__.py
│   ├── config
│   │   ├── banana_env_config.yaml
│   │   └── greenhouse_env_config.yaml
│   ├── start_world.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── pyrobosim_ros_env.py
│   │   ├── banana.py
│   │   └── greenhouse.py
│   ├── policy_node.py
│   ├── train.py
│   └── eval.py
├── .dockerignore
├── .gitignore
├── rl_interfaces
│   ├── action
│   │   └── ExecutePolicy.action
│   ├── CMakeLists.txt
│   └── package.xml
├── .docker
│   └── Dockerfile
├── LICENSE
├── pyproject.toml
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       └── ci.yml
├── README.md
└── test
    └── test_doc.py
/rl_ws_worlds/resource/rl_ws_worlds:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rl_ws_worlds/rl_ws_worlds/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slides/media/dqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/dqn.png
--------------------------------------------------------------------------------
/slides/media/sb3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/sb3.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pyrobosim"]
2 | path = pyrobosim
3 | url = https://github.com/sea-bass/pyrobosim.git
4 |
--------------------------------------------------------------------------------
/slides/media/greenhouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse.png
--------------------------------------------------------------------------------
/slides/media/grid-world.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/grid-world.png
--------------------------------------------------------------------------------
/slides/media/gymnasium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/gymnasium.png
--------------------------------------------------------------------------------
/slides/media/mdp.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/mdp.drawio.png
--------------------------------------------------------------------------------
/slides/media/ppo-graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/ppo-graph.png
--------------------------------------------------------------------------------
/slides/media/roscon25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/roscon25.png
--------------------------------------------------------------------------------
/slides/media/actor-critic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/actor-critic.png
--------------------------------------------------------------------------------
/slides/media/eval-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/eval-results.png
--------------------------------------------------------------------------------
/slides/media/onnx-runtime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/onnx-runtime.png
--------------------------------------------------------------------------------
/slides/media/ros-wg-delib.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/ros-wg-delib.png
--------------------------------------------------------------------------------
/slides/media/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/tensorboard.png
--------------------------------------------------------------------------------
/slides/media/twitter-post.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/twitter-post.png
--------------------------------------------------------------------------------
/slides/media/model-free-rl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/model-free-rl.png
--------------------------------------------------------------------------------
/slides/media/sb3-algo-choice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/sb3-algo-choice.png
--------------------------------------------------------------------------------
/slides/media/value-iteration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/value-iteration.png
--------------------------------------------------------------------------------
/rl_ws_worlds/setup.cfg:
--------------------------------------------------------------------------------
1 | [develop]
2 | script_dir=$base/lib/rl_ws_worlds
3 | [install]
4 | install_scripts=$base/lib/rl_ws_worlds
5 |
--------------------------------------------------------------------------------
/slides/media/agent-env.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/agent-env.drawio.png
--------------------------------------------------------------------------------
/slides/media/greenhouse-random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse-random.png
--------------------------------------------------------------------------------
/slides/media/optuna-dashboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/optuna-dashboard.png
--------------------------------------------------------------------------------
/slides/media/policy-iteration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/policy-iteration.png
--------------------------------------------------------------------------------
/slides/media/rsl-parallel-sim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/rsl-parallel-sim.png
--------------------------------------------------------------------------------
/slides/media/agent-env-sym.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/agent-env-sym.drawio.png
--------------------------------------------------------------------------------
/slides/media/curriculum-learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/curriculum-learning.png
--------------------------------------------------------------------------------
/slides/media/greenhouse-battery.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/greenhouse-battery.png
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # SCM syntax highlighting & preventing 3-way merges
2 | pixi.lock merge=binary linguist-language=YAML linguist-generated=true
3 |
--------------------------------------------------------------------------------
/slides/media/pytorch-threading-question.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/pytorch-threading-question.png
--------------------------------------------------------------------------------
/slides/media/tensorboard-reward-compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/tensorboard-reward-compare.png
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/BananaPick_DQN_random.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPick_DQN_random.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/BananaPick_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPick_PPO_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/BananaPlace_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPlace_PPO_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/BananaPlaceNoSoda_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/BananaPlaceNoSoda_PPO_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/GreenhouseBattery_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseBattery_PPO_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/GreenhouseRandom_DQN_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseRandom_DQN_trained.pt
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/GreenhouseRandom_PPO_trained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/pyrobosim_ros_gym/policies/GreenhouseRandom_PPO_trained.pt
--------------------------------------------------------------------------------
/slides/media/venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ros-wg-delib/rl_deliberation/HEAD/slides/media/venus_flytrap_src_wikimedia_commons_Tippitiwichet.jpg
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Various types of metadata
2 | .mypy_cache
3 | .pixi
4 | .pytest_cache
5 | .ruff_cache
6 | **/__pycache__
7 |
8 | # Colcon build files
9 | build/
10 | install/
11 | log/
12 |
13 | # Model checkpoints and logging
14 | train_logs/
15 | *.pt
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Various types of metadata
2 | .mypy_cache
3 | .pixi
4 | .pytest_cache
5 | .ruff_cache
6 | **/__pycache__
7 |
8 | # Colcon build files
9 | build/
10 | install/
11 | log/
12 |
13 | # Model checkpoints and logging
14 | train_logs/
15 | **/*.pt
16 |
17 | # Slides
18 | slides/*.pdf
19 |
--------------------------------------------------------------------------------
/rl_interfaces/action/ExecutePolicy.action:
--------------------------------------------------------------------------------
1 | # ROS action definition for executing a policy
2 |
3 | # Goal (empty)
4 |
5 | ---
6 |
7 | # Result
8 |
9 | bool success
10 | uint64 num_steps
11 | float64 cumulative_reward
12 |
13 | ---
14 |
15 | # Feedback
16 |
17 | uint64 num_steps
18 | float64 cumulative_reward
19 |
--------------------------------------------------------------------------------
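
The goal of this action is empty; the result and feedback carry the step count and cumulative reward. As a usage sketch (not part of the repository), a minimal rclpy client for this interface could look like the following, assuming the `rl_interfaces` package is built and the policy server from `pyrobosim_ros_gym/policy_node.py` is serving `/execute_policy`; the node and variable names here are illustrative only.

```python
import rclpy
from rclpy.action import ActionClient
from rclpy.node import Node

from rl_interfaces.action import ExecutePolicy


def main():
    rclpy.init()
    node = Node("execute_policy_client")  # illustrative node name
    client = ActionClient(node, ExecutePolicy, "/execute_policy")
    client.wait_for_server()

    # The goal carries no fields; feedback streams step count and cumulative reward.
    goal_future = client.send_goal_async(
        ExecutePolicy.Goal(),
        feedback_callback=lambda msg: node.get_logger().info(
            f"steps={msg.feedback.num_steps}, reward={msg.feedback.cumulative_reward}"
        ),
    )
    rclpy.spin_until_future_complete(node, goal_future)

    # Wait for the final result (success, num_steps, cumulative_reward).
    result_future = goal_future.result().get_result_async()
    rclpy.spin_until_future_complete(node, result_future)
    node.get_logger().info(f"success={result_future.result().result.success}")
    rclpy.shutdown()


if __name__ == "__main__":
    main()
```
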
/rl_ws_worlds/worlds/greenhouse_object_data.yaml:
--------------------------------------------------------------------------------
1 | ###################
2 | # Object metadata #
3 | ###################
4 |
5 | plant_good:
6 | footprint:
7 | type: circle
8 | radius: 0.08
9 | color: [0.1, 0.8, 0.1]
10 |
11 | plant_evil:
12 | footprint:
13 | type: circle
14 | radius: 0.08
15 | color: [0.9, 0.3, 0.1]
16 |
--------------------------------------------------------------------------------
/.docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:latest
2 | SHELL ["/bin/bash", "-o", "pipefail", "-c"]
3 |
4 | # Install apt dependencies.
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | RUN apt-get update && \
7 | apt-get upgrade -y && \
8 | apt-get install -y \
9 | git curl build-essential && \
10 | rm -rf /var/lib/apt/lists/*
11 |
12 | # Install pixi
13 | RUN curl -fsSL https://pixi.sh/install.sh | sh
14 | ENV PATH=$PATH:/root/.pixi/bin
15 |
16 | # Get this source
17 | COPY . rl_deliberation
18 | WORKDIR /rl_deliberation
19 |
20 | # Build workspace
21 | RUN pixi run build
22 |
--------------------------------------------------------------------------------
/rl_interfaces/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.22)
2 |
3 | project(rl_interfaces)
4 |
5 | # Default to C++17
6 | if(NOT CMAKE_CXX_STANDARD)
7 | set(CMAKE_CXX_STANDARD 17)
8 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
9 | endif()
10 | if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
11 | add_compile_options(-Wall -Wextra -Wpedantic)
12 | endif()
13 |
14 | find_package(ament_cmake REQUIRED)
15 | find_package(rosidl_default_generators REQUIRED)
16 |
17 | rosidl_generate_interfaces(${PROJECT_NAME}
18 | "action/ExecutePolicy.action"
19 | )
20 |
21 | ament_export_dependencies(rosidl_default_runtime)
22 | ament_package()
23 |
--------------------------------------------------------------------------------
/rl_ws_worlds/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
3 | <package format="3">
4 |   <name>rl_ws_worlds</name>
5 |   <version>1.0.0</version>
6 |   <description>Example worlds for exploring reinforcement learning.</description>
7 |   <maintainer email="christian.henkel2@de.bosch.com">Christian Henkel</maintainer>
8 |   <maintainer email="sebas.a.castro@gmail.com">Sebastian Castro</maintainer>
9 |   <license>BSD-3-Clause</license>
10 |
11 |   <exec_depend>pyrobosim_ros</exec_depend>
12 |
13 |   <test_depend>ament_copyright</test_depend>
14 |   <test_depend>ament_flake8</test_depend>
15 |   <test_depend>ament_pep257</test_depend>
16 |   <test_depend>python3-pytest</test_depend>
17 |
18 |   <export>
19 |     <build_type>ament_python</build_type>
20 |   </export>
21 | </package>
22 |
--------------------------------------------------------------------------------
/slides/README.md:
--------------------------------------------------------------------------------
1 | # Slides
2 |
3 | The accompanying slides are built with [Pandoc](https://pandoc.org/) and `pdflatex`.
4 |
5 | Refer to for useful information.
6 |
7 | ## Running pandoc natively
8 |
9 | You must first install these tools:
10 |
11 | ```bash
12 | sudo apt install pandoc texlive-latex-base texlive-latex-extra
13 | ```
14 |
15 | Then, to build the slides:
16 |
17 | ```bash
18 | pandoc -t beamer main.md -o slides.pdf --listings --slide-level=2
19 | ```
20 |
21 | ## Running pandoc with Docker (recommended)
22 |
23 | Alternatively, use the Docker image that is also used in CI.
24 | Run this command from the `slides/` directory:
25 |
26 | ```bash
27 | docker run \
28 | --rm \
29 | --volume "$(pwd):/data" \
30 | --user $(id -u):$(id -g) \
31 | pandoc/latex:3.7 -t beamer main.md -o slides.pdf --listings --slide-level=2
32 | ```
33 |
--------------------------------------------------------------------------------
/rl_interfaces/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
3 | <package format="3">
4 |   <name>rl_interfaces</name>
5 |   <version>1.0.0</version>
6 |   <description>Example interfaces for exploring reinforcement learning.</description>
7 |   <maintainer email="christian.henkel2@de.bosch.com">Christian Henkel</maintainer>
8 |   <maintainer email="sebas.a.castro@gmail.com">Sebastian Castro</maintainer>
9 |   <license>BSD-3-Clause</license>
10 |
11 |   <buildtool_depend>ament_cmake</buildtool_depend>
12 |   <buildtool_depend>rosidl_default_generators</buildtool_depend>
13 |
14 |   <exec_depend>rosidl_default_runtime</exec_depend>
15 |
16 |   <member_of_group>rosidl_interface_packages</member_of_group>
17 |
18 |   <export>
19 |     <build_type>ament_cmake</build_type>
20 |   </export>
21 | </package>
22 |
--------------------------------------------------------------------------------
/rl_ws_worlds/setup.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 |
3 | from setuptools import find_packages, setup
4 |
5 | package_name = "rl_ws_worlds"
6 |
7 | setup(
8 | name=package_name,
9 | version="1.0.0",
10 | packages=find_packages(exclude=["test"]),
11 | data_files=[
12 | ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
13 | ("share/" + package_name, ["package.xml"]),
14 | ("share/" + package_name + "/worlds", glob("worlds/*.*")),
15 | ],
16 | install_requires=["setuptools"],
17 | zip_safe=True,
18 | maintainer="Christian Henkel",
19 | maintainer_email="christian.henkel2@de.bosch.com",
20 | description="Example worlds for exploring reinforcement learning.",
21 | license="BSD-3-Clause",
22 | entry_points={
23 | "console_scripts": [
24 | "run = rl_ws_worlds.run:main",
25 | "is_at_goal = rl_ws_worlds.is_at_goal:main",
26 | ],
27 | },
28 | )
29 |
--------------------------------------------------------------------------------
/rl_ws_worlds/worlds/greenhouse_location_data.yaml:
--------------------------------------------------------------------------------
1 | #####################
2 | # Location metadata #
3 | #####################
4 |
5 | table:
6 | footprint:
7 | type: box
8 | dims: [0.8, 0.8]
9 | height: 0.5
10 | color: [0.3, 0.2, 0.1]
11 | nav_poses:
12 | - position: # left
13 | x: -0.6
14 | y: 0.0
15 | - position: # right
16 | x: 0.6
17 | y: 0.0
18 | rotation_eul:
19 | yaw: 3.14
20 | locations:
21 | - name: "tabletop"
22 | footprint:
23 | type: parent
24 | padding: 0.05
25 |
26 | charger:
27 | footprint:
28 | type: polygon
29 | coords:
30 | - [-0.3, -0.15]
31 | - [0.3, -0.15]
32 | - [0.3, 0.15]
33 | - [-0.3, 0.15]
34 | height: 0.1
35 | locations:
36 | - name: "dock"
37 | footprint:
38 | type: parent
39 | nav_poses:
40 | - position: # below
41 | x: 0.0
42 | y: -0.35
43 | rotation_eul:
44 | yaw: 1.57
45 | - position: # left
46 | x: -0.5
47 | y: 0.0
48 | - position: # above
49 | x: 0.0
50 | y: 0.35
51 | rotation_eul:
52 | yaw: -1.57
53 | - position: # right
54 | x: 0.5
55 | y: 0.0
56 | rotation_eul:
57 | yaw: 3.14
58 | color: [0.4, 0.4, 0]
59 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2025, ROS Deliberation Community Group
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | import importlib
10 | import os
11 | from typing import Any
12 |
13 | import yaml
14 |
15 |
16 | def get_config(config_path: str) -> dict[str, Any]:
17 | """Helper function to parse the configuration YAML file."""
18 | if not os.path.isabs(config_path):
19 | default_path = os.path.join(
20 | os.path.dirname(os.path.abspath(__file__)), "config"
21 | )
22 | config_path = os.path.join(default_path, config_path)
23 | with open(config_path, "r") as file:
24 | config = yaml.safe_load(file)
25 |
26 | # Handle special case of reward function
27 | training_args = config.get("training", {})
28 | if "reward_fn" in training_args:
29 | module_name, function_name = training_args["reward_fn"].rsplit(".", 1)
30 | module = importlib.import_module(module_name)
31 | training_args["reward_fn"] = getattr(module, function_name)
32 |
33 | # Handle special case of policy_kwargs activation function needing to be a class instance.
34 | for subtype in training_args:
35 | subtype_config = training_args[subtype]
36 | if not isinstance(subtype_config, dict):
37 | continue
38 | policy_kwargs = subtype_config.get("policy_kwargs", {})
39 | if "activation_fn" in policy_kwargs:
40 | module_name, class_name = policy_kwargs["activation_fn"].rsplit(".", 1)
41 | module = importlib.import_module(module_name)
42 | policy_kwargs["activation_fn"] = getattr(module, class_name)
43 |
44 | return config
45 |
--------------------------------------------------------------------------------
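
As a short usage sketch (assuming the package is importable), `get_config` resolves relative names against the bundled `config/` directory and replaces the dotted `reward_fn` and `activation_fn` strings with the actual Python objects:

```python
from pyrobosim_ros_gym import get_config

# Relative names resolve to pyrobosim_ros_gym/config/.
config = get_config("greenhouse_env_config.yaml")
training = config["training"]

print(training["reward_fn"])  # a callable, no longer the string from the YAML file
print(training["PPO"]["policy_kwargs"]["activation_fn"])  # torch.nn.ReLU, not a string
```
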
/pyrobosim_ros_gym/config/banana_env_config.yaml:
--------------------------------------------------------------------------------
1 | # Training configuration for the banana test environment
2 |
3 | training:
4 | # General training options
5 | max_training_steps: 25000
6 |
7 | # Evaluations during training
8 | # Every eval_freq steps, n_eval_episodes episodes will be run.
9 | # If the mean cumulative reward is greater than reward_threshold,
10 | # training will complete.
11 | eval:
12 | n_eval_episodes: 10
13 | eval_freq: 1000 # Set to 2000 for more complicated sub-types
14 | reward_threshold: 9.5 # Mean reward to terminate training
15 |
16 | # Individual algorithm options
17 | DQN:
18 | gamma: 0.99
19 | exploration_initial_eps: 0.75
20 | exploration_final_eps: 0.05
21 | exploration_fraction: 0.25
22 | learning_starts: 100
23 | learning_rate: 0.0001
24 | batch_size: 32
25 | gradient_steps: 10
26 | train_freq: 4 # steps
27 | target_update_interval: 500
28 | policy_kwargs:
29 | activation_fn: torch.nn.ReLU
30 | net_arch: [64, 64]
31 |
32 | PPO:
33 | gamma: 0.99
34 | learning_rate: 0.0003
35 | batch_size: 32
36 | n_steps: 64
37 | policy_kwargs:
38 | activation_fn: torch.nn.ReLU
39 | net_arch:
40 | pi: [64, 64] # actor size
41 | vf: [32, 32] # critic size
42 |
43 | SAC:
44 | gamma: 0.99
45 | learning_rate: 0.0003
46 | batch_size: 32
47 | gradient_steps: 10
48 | train_freq: 4 # steps
49 | target_update_interval: 10
50 | policy_kwargs:
51 | activation_fn: torch.nn.ReLU
52 | net_arch:
53 | pi: [64, 64] # actor size
54 | qf: [32, 32] # critic size (SAC uses qf, not vf)
55 |
56 | A2C:
57 | gamma: 0.99
58 | learning_rate: 0.0007
59 | policy_kwargs:
60 | activation_fn: torch.nn.ReLU
61 | net_arch:
62 | pi: [64, 64] # actor size
63 | vf: [32, 32] # critic size
64 |
--------------------------------------------------------------------------------
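
The `eval` block above describes the usual Stable-Baselines3 early-stopping pattern: evaluate every `eval_freq` steps over `n_eval_episodes` episodes and stop once the mean reward exceeds `reward_threshold`. The actual wiring lives in `train.py` (not included in this section); the following is only a sketch of how these settings map onto SB3 callbacks, using a placeholder Gymnasium environment instead of the pyrobosim_ros_gym envs:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Placeholder environments for illustration only.
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# Stop training once the mean evaluation reward exceeds the threshold.
stop_cb = StopTrainingOnRewardThreshold(reward_threshold=9.5, verbose=1)
eval_cb = EvalCallback(
    eval_env,
    callback_on_new_best=stop_cb,  # checked after each evaluation round
    eval_freq=1000,                # evaluate every eval_freq steps
    n_eval_episodes=10,            # episodes per evaluation round
)

# PPO hyperparameters taken from the PPO section above.
model = PPO("MlpPolicy", env, gamma=0.99, learning_rate=0.0003, batch_size=32, n_steps=64)
model.learn(total_timesteps=25000, callback=eval_cb)
```
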
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "rl_deliberation"
3 | authors = [
4 | {name = "Sebastian Castro", email = "sebas.a.castro@gmail.com"},
5 | {name = "Christian Henkel", email = "christian.henkel2@de.bosch.com"},
6 | ]
7 | version = "0.1.0"
8 | requires-python = ">=3.12, <3.13"
9 | dependencies = ["astar>=0.99,<0.100"]
10 |
11 | [tool.pixi.workspace]
12 | preview = ["pixi-build"]
13 | channels = ["robostack-kilted", "conda-forge"]
14 | platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"]
15 |
16 | [tool.pixi.package.build]
17 | backend = { name = "pixi-build-ros", version = "*" }
18 |
19 | [tool.pixi.dependencies]
20 | python = ">=3.12, <3.13"
21 | colcon-common-extensions = ">=0.3.0,<0.4"
22 | compilers = ">=1.11.0,<2"
23 | pre-commit = ">=4.2.0,<5"
24 | rich = ">=14.1.0,<15"
25 | ros-kilted-ros-core = ">=0.12.0,<0.13"
26 | stable-baselines3 = ">=2.6.0,<3"
27 | setuptools = ">=78.1.1,<80.0.0"
28 | scipy = ">=1.16.0,<2"
29 | transforms3d = ">=0.4.2,<0.5"
30 | adjusttext = ">=1.3.0,<2"
31 | matplotlib = ">=3.10.5,<4"
32 | numpy = ">=1.26.4,<2"
33 | pycollada = ">=0.9.2,<0.10"
34 | pyside6 = ">=6.4.0"
35 | pyyaml = ">=6.0.3,<7"
36 | shapely = ">=2.0.1"
37 | trimesh = ">=4.7.1,<5"
38 | tqdm = ">=4.67.1,<5"
39 | tensorboard = "*"
40 | types-pyyaml = ">=6.0.12.20250915,<7"
41 |
42 | [tool.pixi.tasks]
43 | build = "colcon build --symlink-install"
44 | clean = "rm -rf build install log"
45 | lint = "pre-commit run -a"
46 | zenoh = "ros2 run rmw_zenoh_cpp rmw_zenohd"
47 | pyrobosim_demo = "ros2 launch pyrobosim_ros demo.launch.py"
48 | start_world = "python -m pyrobosim_ros_gym.start_world"
49 | train = "python -m pyrobosim_ros_gym.train"
50 | eval = "python -m pyrobosim_ros_gym.eval"
51 | tensorboard = "tensorboard --logdir ./train_logs/"
52 | policy_node = "python -m pyrobosim_ros_gym.policy_node"
53 |
54 | [tool.pixi.activation]
55 | scripts = ["install/setup.sh"]
56 |
57 | [tool.pixi.target.win-64.activation]
58 | scripts = ["install/setup.bat"]
59 |
--------------------------------------------------------------------------------
/rl_ws_worlds/rl_ws_worlds/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """
10 | Runner for ROS 2 Deliberation workshop worlds.
11 | """
12 |
13 | import os
14 | import rclpy
15 | import threading
16 |
17 | from pyrobosim.core import WorldYamlLoader
18 | from pyrobosim.gui import start_gui
19 | from pyrobosim_ros.ros_interface import WorldROSWrapper
20 | from ament_index_python.packages import get_package_share_directory
21 |
22 |
23 | def create_ros_node() -> WorldROSWrapper:
24 | """Initializes ROS node"""
25 | rclpy.init()
26 | node = WorldROSWrapper(state_pub_rate=0.1, dynamics_rate=0.01)
27 | node.declare_parameter("world_name", "greenhouse")
28 | node.declare_parameter("headless", False)
29 |
30 | # Set the world file.
31 | world_name = node.get_parameter("world_name").value
32 | node.get_logger().info(f"Starting world '{world_name}'")
33 | world_file = os.path.join(
34 | get_package_share_directory("rl_ws_worlds"),
35 | "worlds",
36 | f"{world_name}.yaml",
37 | )
38 | world = WorldYamlLoader().from_file(world_file)
39 | node.set_world(world)
40 |
41 | return node
42 |
43 |
44 | def start_node(node: WorldROSWrapper):
45 | headless = node.get_parameter("headless").value
46 | if headless:
47 | # Start ROS node in main thread if there is no GUI.
48 | node.start(wait_for_gui=False)
49 | else:
50 | # Start ROS node in separate thread
51 | ros_thread = threading.Thread(target=lambda: node.start(wait_for_gui=True))
52 | ros_thread.start()
53 |
54 | # Start GUI in main thread
55 | start_gui(node.world)
56 |
57 |
58 | def main():
59 | node = create_ros_node()
60 | start_node(node)
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/config/greenhouse_env_config.yaml:
--------------------------------------------------------------------------------
1 | # Training configuration for the greenhouse test environment
2 |
3 | training:
4 | # General training options
5 | max_training_steps: 10000
6 |
7 | # Reward function to use when stepping the environment
8 | # reward_fn: pyrobosim_ros_gym.envs.greenhouse.sparse_reward # Try this first
9 | # reward_fn: pyrobosim_ros_gym.envs.greenhouse.dense_reward # Try this second
10 | reward_fn: pyrobosim_ros_gym.envs.greenhouse.full_reward # End with this one
11 |
12 | # Evaluations during training
13 | # Every eval_freq steps, n_eval_episodes episodes will be run.
14 | # If the mean cumulative reward is greater than reward_threshold,
15 | # training will complete.
16 | eval:
17 | n_eval_episodes: 5
18 | eval_freq: 500
19 | reward_threshold: 7.0
20 |
21 | # Individual algorithm options
22 | DQN:
23 | gamma: 0.99
24 | exploration_initial_eps: 0.75
25 | exploration_final_eps: 0.05
26 | exploration_fraction: 0.2
27 | learning_starts: 25
28 | learning_rate: 0.0002
29 | batch_size: 16
30 | gradient_steps: 5
31 | train_freq: 4 # steps
32 | target_update_interval: 10
33 | policy_kwargs:
34 | activation_fn: torch.nn.ReLU
35 | net_arch: [16, 8]
36 |
37 | PPO:
38 | gamma: 0.99
39 | learning_rate: 0.0003
40 | batch_size: 8
41 | n_steps: 8
42 | policy_kwargs:
43 | activation_fn: torch.nn.ReLU
44 | net_arch:
45 | pi: [16, 8] # actor size
46 | vf: [16, 8] # critic size
47 |
48 | SAC:
49 | gamma: 0.99
50 | learning_rate: 0.0003
51 | batch_size: 16
52 | train_freq: 4 # steps
53 | gradient_steps: 5
54 | tau: 0.005
55 | target_update_interval: 10
56 | policy_kwargs:
57 | activation_fn: torch.nn.ReLU
58 | net_arch:
59 | pi: [16, 8] # actor size
60 | qf: [16, 8] # critic size (SAC uses qf, not vf)
61 |
62 | A2C:
63 | gamma: 0.99
64 | learning_rate: 0.0007
65 | policy_kwargs:
66 | activation_fn: torch.nn.ReLU
67 | net_arch:
68 | pi: [16, 8] # actor size
69 | vf: [16, 8] # critic size
70 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/start_world.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Loads a world to act as a server for the RL problem."""
10 |
11 | import argparse
12 | import rclpy
13 | import threading
14 |
15 | from pyrobosim.core import WorldYamlLoader
16 | from pyrobosim.gui import start_gui, WorldCanvasOptions
17 | from pyrobosim_ros.ros_interface import WorldROSWrapper
18 |
19 | from pyrobosim_ros_gym.envs import (
20 | get_env_class_and_subtype_from_name,
21 | available_envs_w_subtype,
22 | )
23 | from pyrobosim_ros_gym.envs.greenhouse import GreenhouseEnv
24 |
25 |
26 | def create_ros_node(world_file_path) -> WorldROSWrapper:
27 | """Initializes ROS node"""
28 | rclpy.init()
29 | world = WorldYamlLoader().from_file(world_file_path)
30 | return WorldROSWrapper(world=world, state_pub_rate=0.1, dynamics_rate=0.01)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--headless", action="store_true", help="Enables headless world loading."
37 | )
38 | parser.add_argument(
39 | "--env",
40 | choices=available_envs_w_subtype(),
41 | help="The environment to use.",
42 | required=True,
43 | )
44 | args = parser.parse_args()
45 |
46 | env_class, sub_type = get_env_class_and_subtype_from_name(args.env)
47 | node = create_ros_node(env_class.get_world_file_path(sub_type))
48 | show_room_names = env_class != GreenhouseEnv
49 |
50 | if args.headless:
51 | # Start ROS node in main thread if there is no GUI.
52 | node.start(wait_for_gui=False)
53 | else:
54 | # Start ROS node in separate thread.
55 | ros_thread = threading.Thread(target=lambda: node.start(wait_for_gui=True))
56 | ros_thread.start()
57 |
58 | # Start GUI in main thread.
59 | options = WorldCanvasOptions(show_room_names=show_room_names)
60 | start_gui(node.world, options=options)
61 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # Runs pre-commit hooks and other file format checks.
3 | - repo: https://github.com/pre-commit/pre-commit-hooks
4 | rev: v6.0.0
5 | hooks:
6 | - id: check-added-large-files
7 | - id: check-ast
8 | - id: check-builtin-literals
9 | - id: check-case-conflict
10 | - id: check-docstring-first
11 | - id: check-executables-have-shebangs
12 | - id: check-json
13 | - id: check-merge-conflict
14 | - id: check-symlinks
15 | - id: check-toml
16 | - id: check-vcs-permalinks
17 | - id: check-yaml
18 | - id: debug-statements
19 | - id: destroyed-symlinks
20 | - id: detect-private-key
21 | - id: end-of-file-fixer
22 | - id: fix-byte-order-marker
23 | - id: forbid-new-submodules
24 | - id: mixed-line-ending
25 | - id: name-tests-test
26 | - id: requirements-txt-fixer
27 | - id: sort-simple-yaml
28 | - id: trailing-whitespace
29 | exclude: slides/main.md$
30 |
31 | # Autoformats Python code.
32 | - repo: https://github.com/psf/black.git
33 | rev: 25.9.0
34 | hooks:
35 | - id: black
36 | exclude: ^pyrobosim/
37 |
38 | # Type checking
39 | - repo: https://github.com/pre-commit/mirrors-mypy
40 | rev: v1.18.2
41 | hooks:
42 | - id: mypy
43 | args: [--ignore-missing-imports]
44 | additional_dependencies: [types-PyYAML]
45 |
46 | # Finds spelling issues in code.
47 | - repo: https://github.com/codespell-project/codespell
48 | rev: v2.4.1
49 | hooks:
50 | - id: codespell
51 | files: ^.*\.(py|md|yaml|yml)$
52 |
53 | # Finds issues in YAML files.
54 | - repo: https://github.com/adrienverge/yamllint
55 | rev: v1.37.1
56 | hooks:
57 | - id: yamllint
58 | args:
59 | [
60 | "--no-warnings",
61 | "--config-data",
62 | "{extends: default, rules: {line-length: disable, braces:
63 | {max-spaces-inside: 1}, indentation: disable}}",
64 | ]
65 | types: [text]
66 | files: \.(yml|yaml)$
67 |
68 | # Checks links in markdown files.
69 | - repo: https://github.com/tcort/markdown-link-check
70 | rev: v3.13.7
71 | hooks:
72 | - id: markdown-link-check
73 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | import os
10 |
11 | from gymnasium import Space
12 | from stable_baselines3 import DQN, PPO, SAC, A2C
13 | from stable_baselines3.common.base_class import BaseAlgorithm
14 |
15 |
16 | AVAILABLE_POLICIES = {alg.__name__: alg for alg in (DQN, PPO, SAC, A2C)}
17 |
18 |
19 | class ManualPolicy:
20 | """A policy that allows manual keyboard control of the robot."""
21 |
22 | def __init__(self, action_space: Space):
23 | print("Welcome. You are the agent now!")
24 | self.action_space = action_space
25 |
26 | def predict(self, observation, deterministic):
27 | # print(f"Observation: {observation}")
28 | # print(f"Action space: {self.action_space}")
29 | possible_actions = list(range(self.action_space.n))
30 | while True:
31 | try:
32 | action = int(input(f"Enter action from {possible_actions}: "))
33 | if action in possible_actions:
34 | return action, None
35 | else:
36 | print(f"Action {action} not in {possible_actions}.")
37 | except ValueError:
38 | print("Invalid input, please enter an integer.")
39 |
40 |
41 | def model_and_env_type_from_path(model_path: str) -> tuple[BaseAlgorithm, str]:
42 | """
43 | Loads a model and its corresponding environment type from its file path.
44 |
45 |     The models are of the form <env_type>_<algorithm>[_<suffix>].pt.
46 | For example, path/to/model/GreenhousePlain_DQN_seed42_2025_10_18_18_02_21.pt.
47 | """
48 | # Validate the file path to be of the right form
49 | assert os.path.isfile(model_path), f"Model {model_path} must be a valid file."
50 | model_fname = os.path.basename(model_path)
51 | model_name_parts = model_fname.split("_")
52 | assert (
53 | len(model_name_parts) >= 2
54 |     ), f"Model name {model_fname} must be of the form <env_type>_<algorithm>[_<suffix>].pt"
55 | env_type = model_name_parts[0]
56 | algorithm = model_name_parts[1]
57 |
58 | # Load the model
59 | if algorithm in AVAILABLE_POLICIES:
60 | model_class = AVAILABLE_POLICIES[algorithm]
61 | model = model_class.load(model_path)
62 | else:
63 | raise RuntimeError(f"Invalid algorithm type: {algorithm}")
64 |
65 | return (model, env_type)
66 |
--------------------------------------------------------------------------------
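
For example, loading one of the bundled checkpoints could look like this (a sketch; the path is given relative to the repository root):

```python
from pyrobosim_ros_gym.policies import model_and_env_type_from_path

# The filename encodes the environment sub-type and the algorithm:
# <env_type>_<algorithm>[_<suffix>].pt
model, env_type = model_and_env_type_from_path(
    "pyrobosim_ros_gym/policies/GreenhousePlain_PPO_trained.pt"
)
print(env_type)              # "GreenhousePlain"
print(type(model).__name__)  # "PPO"
```
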
/pyrobosim_ros_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | from typing import Any, Callable, List, Dict
10 |
11 | import rclpy
12 | from rclpy.executors import Executor
13 |
14 | from .banana import BananaEnv
15 | from .greenhouse import GreenhouseEnv
16 | from .pyrobosim_ros_env import PyRoboSimRosEnv
17 |
18 |
19 | ENV_CLASS_FROM_NAME: Dict[str, type[PyRoboSimRosEnv]] = {
20 | "Banana": BananaEnv,
21 | "Greenhouse": GreenhouseEnv,
22 | }
23 |
24 |
25 | def available_envs_w_subtype() -> List[str]:
26 | """Return a list of environment types including subtypes."""
27 | envs: List[str] = []
28 | for name, env_class in ENV_CLASS_FROM_NAME.items():
29 | for sub_type in env_class.sub_types:
30 | envs.append("".join((name, sub_type.name)))
31 | return envs
32 |
33 |
34 | def get_env_class_and_subtype_from_name(req_name: str):
35 | """Return the class of a chosen environment name (ignoring `sub_type`s)."""
36 | for name, env_class in ENV_CLASS_FROM_NAME.items():
37 | if req_name.startswith(name):
38 | sub_type_str = req_name.replace(name, "")
39 | for st in env_class.sub_types:
40 | if st.name == sub_type_str:
41 | sub_type = st
42 | return env_class, sub_type
43 | raise RuntimeError(f"No environment found for {req_name}.")
44 |
45 |
46 | def get_env_by_name(
47 | env_name: str,
48 | node: rclpy.node.Node,
49 | max_steps_per_episode: int,
50 | realtime: bool,
51 | discrete_actions: bool,
52 | reward_fn: Callable[..., Any],
53 | executor: Executor | None = None,
54 | ) -> PyRoboSimRosEnv:
55 | """
56 | Instantiate an environment class for a given type and `sub_type`.
57 |
58 | :param env_name: Name of environment, with subtype, e.g. BananaPick.
59 | :param node: Node instance needed for ROS communication.
60 | :param max_steps_per_episode: Limit the steps (when to end the episode).
61 | :param realtime: Whether actions take time.
62 | :param discrete_actions: Choose discrete actions (needed for DQN).
63 | :param reward_fn: The function used to compute the reward at each step.
64 | :param executor: Optional ROS executor. It must be already spinning!
65 | """
66 | base_class, sub_type = get_env_class_and_subtype_from_name(env_name)
67 | return base_class(
68 | sub_type,
69 | node,
70 | max_steps_per_episode,
71 | realtime,
72 | discrete_actions,
73 | reward_fn,
74 | executor=executor,
75 | )
76 |
--------------------------------------------------------------------------------
/rl_ws_worlds/worlds/greenhouse_random.yaml:
--------------------------------------------------------------------------------
1 | # Problem 1 World
2 |
3 | metadata:
4 | locations: $PWD/greenhouse_location_data.yaml
5 | objects: $PWD/greenhouse_object_data.yaml
6 |
7 | params:
8 | name: greenhouse
9 | inflation_radius: 0.01
10 | object_radius: 0.01
11 |
12 | robots:
13 | - name: robot
14 | radius: 0.1
15 | location: hall
16 | path_executor:
17 | type: constant_velocity
18 | path_planner:
19 | type: astar
20 | grid_resolution: 0.1
21 | grid_inflation_radius: 0.1
22 |
23 | rooms:
24 | - name: hall
25 | footprint:
26 | type: polygon
27 | coords:
28 | - [-3, -3]
29 | - [3, -3]
30 | - [3, 3]
31 | - [-3, 3]
32 | nav_poses:
33 | - position: # left
34 | x: -2.0
35 | y: 0.0
36 | - position: # right
37 | x: 2.0
38 | y: 0.0
39 | rotation_eul:
40 | yaw: 3.14
41 | wall_width: 0.2
42 | color: [0.1, 0.1, 0]
43 |
44 | locations:
45 | - name: table_nw
46 | parent: hall
47 | category: table
48 | pose:
49 | position:
50 | x: -2
51 | y: 2
52 | - name: table_n
53 | parent: hall
54 | category: table
55 | pose:
56 | position:
57 | x: 0
58 | y: 2
59 | - name: table_ne
60 | parent: hall
61 | category: table
62 | pose:
63 | position:
64 | x: 2
65 | y: 2
66 | - name: table_w
67 | parent: hall
68 | category: table
69 | pose:
70 | position:
71 | x: -2
72 | y: 0
73 | - name: table_c
74 | parent: hall
75 | category: table
76 | pose:
77 | position:
78 | x: 0
79 | y: 0
80 | - name: table_e
81 | parent: hall
82 | category: table
83 | pose:
84 | position:
85 | x: 2
86 | y: 0
87 | - name: table_sw
88 | parent: hall
89 | category: table
90 | pose:
91 | position:
92 | x: -2
93 | y: -2
94 | - name: table_s
95 | parent: hall
96 | category: table
97 | pose:
98 | position:
99 | x: 0
100 | y: -2
101 | - name: table_se
102 | parent: hall
103 | category: table
104 | pose:
105 | position:
106 | x: 2
107 | y: -2
108 |
109 | objects:
110 | - name: plant0
111 | parent: table
112 | category: plant_good
113 | - name: plant1
114 | parent: table
115 | category: plant_evil
116 | - name: plant2
117 | parent: table
118 | category: plant_good
119 | - name: plant3
120 | parent: table
121 | category: plant_good
122 | - name: plant4
123 | parent: table
124 | category: plant_good
125 |
--------------------------------------------------------------------------------
/rl_ws_worlds/worlds/greenhouse_plain.yaml:
--------------------------------------------------------------------------------
1 | # Problem 1 World
2 |
3 | metadata:
4 | locations: $PWD/greenhouse_location_data.yaml
5 | objects: $PWD/greenhouse_object_data.yaml
6 |
7 | params:
8 | name: greenhouse
9 | inflation_radius: 0.01
10 | object_radius: 0.01
11 |
12 | robots:
13 | - name: robot
14 | radius: 0.1
15 | location: hall
16 | path_executor:
17 | type: constant_velocity
18 | path_planner:
19 | type: astar
20 | grid_resolution: 0.1
21 | grid_inflation_radius: 0.1
22 |
23 | rooms:
24 | - name: hall
25 | footprint:
26 | type: polygon
27 | coords:
28 | - [-3, -3]
29 | - [3, -3]
30 | - [3, 3]
31 | - [-3, 3]
32 | nav_poses:
33 | - position: # left
34 | x: -2.0
35 | y: 0.0
36 | - position: # right
37 | x: 2.0
38 | y: 0.0
39 | rotation_eul:
40 | yaw: 3.14
41 | wall_width: 0.2
42 | color: [0.1, 0.1, 0]
43 |
44 | locations:
45 | - name: table_nw
46 | parent: hall
47 | category: table
48 | pose:
49 | position:
50 | x: -2
51 | y: 2
52 | - name: table_n
53 | parent: hall
54 | category: table
55 | pose:
56 | position:
57 | x: 0
58 | y: 2
59 | - name: table_ne
60 | parent: hall
61 | category: table
62 | pose:
63 | position:
64 | x: 2
65 | y: 2
66 | - name: table_w
67 | parent: hall
68 | category: table
69 | pose:
70 | position:
71 | x: -2
72 | y: 0
73 | - name: table_c
74 | parent: hall
75 | category: table
76 | pose:
77 | position:
78 | x: 0
79 | y: 0
80 | - name: table_e
81 | parent: hall
82 | category: table
83 | pose:
84 | position:
85 | x: 2
86 | y: 0
87 | - name: table_sw
88 | parent: hall
89 | category: table
90 | pose:
91 | position:
92 | x: -2
93 | y: -2
94 | - name: table_s
95 | parent: hall
96 | category: table
97 | pose:
98 | position:
99 | x: 0
100 | y: -2
101 | - name: table_se
102 | parent: hall
103 | category: table
104 | pose:
105 | position:
106 | x: 2
107 | y: -2
108 |
109 | objects:
110 | - name: plant0
111 | parent: table_nw
112 | category: plant_good
113 | - name: plant1
114 | parent: table_ne
115 | category: plant_evil
116 | - name: plant2
117 | parent: table_c
118 | category: plant_good
119 | - name: plant3
120 | parent: table_sw
121 | category: plant_good
122 | - name: plant4
123 | parent: table_se
124 | category: plant_good
125 |
--------------------------------------------------------------------------------
/rl_ws_worlds/worlds/greenhouse_battery.yaml:
--------------------------------------------------------------------------------
1 | # Problem 1 World
2 |
3 | metadata:
4 | locations: $PWD/greenhouse_location_data.yaml
5 | objects: $PWD/greenhouse_object_data.yaml
6 |
7 | params:
8 | name: greenhouse
9 | inflation_radius: 0.01
10 | object_radius: 0.01
11 |
12 | robots:
13 | - name: robot
14 | radius: 0.1
15 | location: hall
16 | path_executor:
17 | type: constant_velocity
18 | path_planner:
19 | type: astar
20 | grid_resolution: 0.1
21 | grid_inflation_radius: 0.1
22 | action_execution_options:
23 | close:
24 | battery_usage: 49 # 2 close actions per charge
25 |
26 | rooms:
27 | - name: hall
28 | footprint:
29 | type: polygon
30 | coords:
31 | - [-3, -3]
32 | - [3, -3]
33 | - [3, 3]
34 | - [-3, 3]
35 | nav_poses:
36 | - position: # left
37 | x: -2.0
38 | y: 0.0
39 | - position: # right
40 | x: 2.0
41 | y: 0.0
42 | rotation_eul:
43 | yaw: 3.14
44 | wall_width: 0.2
45 | color: [0.1, 0.1, 0]
46 |
47 | locations:
48 | - name: table_nw
49 | parent: hall
50 | category: table
51 | pose:
52 | position:
53 | x: -2
54 | y: 2
55 | - name: table_n
56 | parent: hall
57 | category: table
58 | pose:
59 | position:
60 | x: 0
61 | y: 2
62 | - name: table_ne
63 | parent: hall
64 | category: table
65 | pose:
66 | position:
67 | x: 2
68 | y: 2
69 | - name: table_w
70 | parent: hall
71 | category: table
72 | pose:
73 | position:
74 | x: -2
75 | y: 0
76 | - name: table_c
77 | parent: hall
78 | category: table
79 | pose:
80 | position:
81 | x: 0
82 | y: 0
83 | - name: table_e
84 | parent: hall
85 | category: table
86 | pose:
87 | position:
88 | x: 2
89 | y: 0
90 | - name: table_sw
91 | parent: hall
92 | category: table
93 | pose:
94 | position:
95 | x: -2
96 | y: -2
97 | - name: table_s
98 | parent: hall
99 | category: table
100 | pose:
101 | position:
102 | x: 0
103 | y: -2
104 | - name: table_se
105 | parent: hall
106 | category: table
107 | pose:
108 | position:
109 | x: 2
110 | y: -2
111 |
112 | - name: charger
113 | category: charger
114 | parent: hall
115 | pose:
116 | position:
117 | x: 1.0
118 | y: -2.8
119 | is_charger: True
120 |
121 | objects:
122 | - name: plant0
123 | parent: table
124 | category: plant_good
125 | - name: plant1
126 | parent: table
127 | category: plant_evil
128 | - name: plant2
129 | parent: table
130 | category: plant_good
131 | - name: plant3
132 | parent: table
133 | category: plant_good
134 | - name: plant4
135 | parent: table
136 | category: plant_good
137 |
--------------------------------------------------------------------------------
/rl_ws_worlds/worlds/banana.yaml:
--------------------------------------------------------------------------------
1 | ##########################
2 | # Test world description #
3 | ##########################
4 |
5 | # WORLD PARAMETERS
6 | params:
7 | name: banana
8 | object_radius: 0.01 # Radius around objects
9 | wall_height: 2.0 # Wall height for exporting to Gazebo
10 |
11 |
12 | # METADATA: Describes information about locations and objects
13 | metadata:
14 | locations:
15 | - $DATA/example_location_data_furniture.yaml
16 | - $DATA/example_location_data_accessories.yaml
17 | objects:
18 | - $DATA/example_object_data_food.yaml
19 | - $DATA/example_object_data_drink.yaml
20 | # ROBOTS
21 | robots:
22 | - name: robot
23 | radius: 0.1
24 | color: "#CC00CC"
25 | max_linear_velocity: 3.0
26 | max_angular_velocity: 6.0
27 | max_linear_acceleration: 10.0
28 | max_angular_acceleration: 10.0
29 | path_planner:
30 | type: world_graph
31 | collision_check_step_dist: 0.05
32 | path_executor:
33 | type: constant_velocity
34 | linear_velocity: 3.0
35 | max_angular_velocity: 6.0
36 | dt: 0.1
37 | validate_during_execution: false
38 | location: ["desk", "counter", "table"]
39 |
40 | # ROOMS: Polygonal regions that can contain object locations
41 | rooms:
42 | - name: kitchen
43 | footprint:
44 | type: polygon
45 | coords:
46 | - [-1, -1]
47 | - [1.5, -1]
48 | - [1.5, 1.5]
49 | - [0.5, 1.5]
50 | nav_poses:
51 | - position:
52 | x: 0.75
53 | y: 0.5
54 | wall_width: 0.2
55 | color: "red"
56 |
57 | - name: bedroom
58 | footprint:
59 | type: box
60 | dims: [1.75, 1.5]
61 | pose:
62 | position:
63 | x: 2.625
64 | y: 3.25
65 | wall_width: 0.2
66 | color: "#009900"
67 |
68 | - name: bathroom
69 | footprint:
70 | type: polygon
71 | coords:
72 | - [-1, 1]
73 | - [-1, 3.5]
74 | - [-3, 3.5]
75 | - [-2.5, 1]
76 | wall_width: 0.2
77 | color: [0, 0, 0.6]
78 |
79 |
80 | # HALLWAYS: Connect rooms
81 | hallways:
82 | - room_start: kitchen
83 | room_end: bathroom
84 | width: 0.7
85 | conn_method: auto
86 | is_open: True
87 | is_locked: False
88 | color: "#666666"
89 |
90 | - room_start: bathroom
91 | room_end: bedroom
92 | width: 0.5
93 | conn_method: angle
94 | conn_angle: 0.0
95 | offset: 0.8
96 | is_open: True
97 | is_locked: False
98 | color: "dimgray"
99 |
100 | - room_start: kitchen
101 | room_end: bedroom
102 | width: 0.6
103 | conn_method: points
104 | conn_points:
105 | - [1.0, 0.5]
106 | - [2.5, 0.5]
107 | - [2.5, 3.0]
108 | is_open: True
109 | is_locked: False
110 |
111 |
112 | # LOCATIONS: Can contain objects
113 | locations:
114 | - name: table0
115 | category: table
116 | parent: kitchen
117 | pose:
118 | position:
119 | x: 0.85
120 | y: -0.5
121 | rotation_eul:
122 | yaw: -90.0
123 | angle_units: "degrees"
124 | is_open: True
125 | is_locked: True
126 |
127 | - name: my_desk
128 | category: desk
129 | parent: bedroom
130 | pose:
131 | position:
132 | x: 0.525
133 | y: 0.4
134 | relative_to: bedroom
135 | is_open: True
136 | is_locked: False
137 |
138 | - name: counter0
139 | category: counter
140 | parent: bathroom
141 | pose:
142 | position:
143 | x: -2.45
144 | y: 2.5
145 | rotation_eul:
146 | yaw: 101.2
147 | angle_units: "degrees"
148 | is_open: True
149 | is_locked: True
150 |
151 |
152 | # OBJECTS: Can be picked, placed, and moved by robot
153 | objects:
154 | - category: apple
155 |
156 | - category: apple
157 |
158 | - category: water
159 |
160 | - category: water
161 |
162 | - category: banana
163 | parent: ["desk", "counter"]
164 |
165 | - category: banana
166 | parent: ["desk", "counter"]
167 |
168 | - name: soda
169 | category: coke
170 | parent: ["desk", "counter"]
171 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/policy_node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Serves a policy as a ROS node with an action server for deployment."""
10 |
11 | import argparse
12 | import time
13 |
14 | from gymnasium.spaces import Discrete
15 | import rclpy
16 | from rclpy.action import ActionServer, CancelResponse
17 | from rclpy.executors import ExternalShutdownException, MultiThreadedExecutor
18 | from rclpy.node import Node
19 |
20 | from rl_interfaces.action import ExecutePolicy # type: ignore[attr-defined]
21 |
22 | from pyrobosim_ros_gym import get_config
23 | from pyrobosim_ros_gym.envs import get_env_by_name
24 | from pyrobosim_ros_gym.policies import model_and_env_type_from_path
25 |
26 |
27 | class PolicyServerNode(Node):
28 | def __init__(self, args: argparse.Namespace, executor):
29 | super().__init__("policy_node")
30 |
31 | # Load the model and environment
32 | self.model, env_type = model_and_env_type_from_path(args.model)
33 | self.env = get_env_by_name(
34 | env_type,
35 | self,
36 | executor=executor,
37 | max_steps_per_episode=-1,
38 | realtime=True,
39 | discrete_actions=isinstance(self.model.action_space, Discrete),
40 | reward_fn=get_config(args.config)["training"].get("reward_fn"),
41 | )
42 |
43 | self.action_server = ActionServer(
44 | self,
45 | ExecutePolicy,
46 | "/execute_policy",
47 | execute_callback=self.execute_policy,
48 | cancel_callback=self.cancel_policy,
49 | )
50 |
51 | self.get_logger().info(f"Started policy node with model '{args.model}'.")
52 |
53 | def cancel_policy(self, goal_handle):
54 | self.get_logger().info("Canceling policy execution...")
55 | return CancelResponse.ACCEPT
56 |
57 | async def execute_policy(self, goal_handle):
58 | self.get_logger().info("Starting policy execution...")
59 | result = ExecutePolicy.Result()
60 |
61 | self.env.initialize() # Resets helper variables
62 | obs = self.env._get_obs()
63 | cumulative_reward = 0.0
64 | while True:
65 | if goal_handle.is_cancel_requested:
66 | goal_handle.canceled()
67 | self.get_logger().info("Policy execution canceled")
68 | return result
69 |
70 | action, _ = self.model.predict(obs, deterministic=True)
71 | obs, reward, terminated, truncated, info = self.env.step(action)
72 | num_steps = self.env.step_number
73 | cumulative_reward += reward
74 | self.get_logger().info(
75 | f"Step {num_steps}: cumulative reward = {cumulative_reward}"
76 | )
77 |
78 | goal_handle.publish_feedback(
79 | ExecutePolicy.Feedback(
80 | num_steps=num_steps, cumulative_reward=cumulative_reward
81 | )
82 | )
83 |
84 | if terminated or truncated:
85 | break
86 |
87 | time.sleep(0.1) # Small sleep between actions
88 |
89 | goal_handle.succeed()
90 | result.success = info["success"]
91 | result.num_steps = num_steps
92 | result.cumulative_reward = cumulative_reward
93 | self.get_logger().info(f"Policy completed in {num_steps} steps.")
94 | self.get_logger().info(
95 | f"success: {result.success}, cumulative reward: {cumulative_reward}"
96 | )
97 | return result
98 |
99 |
100 | def main(args=None):
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument(
103 | "--model", required=True, help="The name of the model to serve."
104 | )
105 | parser.add_argument(
106 | "--config",
107 | help="Path to the configuration YAML file.",
108 | required=True,
109 | )
110 | cli_args = parser.parse_args()
111 |
112 | rclpy.init(args=args)
113 | try:
114 | executor = MultiThreadedExecutor(num_threads=2)
115 |         import threading
116 |         # Spin the executor in a daemon thread so Ctrl-C can stop the process cleanly.
117 |         threading.Thread(target=executor.spin, daemon=True).start()
118 | policy_server = PolicyServerNode(cli_args, executor)
119 | while True:
120 | time.sleep(0.5)
121 | except (KeyboardInterrupt, ExternalShutdownException):
122 | pass
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | on:
4 | # Run action on certain pull request events
5 | pull_request:
6 | types: [opened, synchronize, reopened, ready_for_review]
7 |
8 | # After pushing to main
9 | push:
10 | branches: [main]
11 |
12 | # Nightly job on default (main) branch
13 | schedule:
14 | - cron: '0 0 * * *'
15 |
16 | # Allow tests to be run manually
17 | workflow_dispatch:
18 |
19 | env:
20 | REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/rl_deliberation
21 | RAW_BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
22 |
23 | jobs:
24 | build_and_test:
25 | runs-on: ubuntu-latest
26 | steps:
27 | - name: Checkout repo
28 | uses: actions/checkout@v4
29 | with:
30 | submodules: true
31 | - name: Setup pixi
32 | uses: prefix-dev/setup-pixi@v0.8.11
33 | with:
34 | cache: true
35 | cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
36 | - name: Run formatting
37 | run: pixi run lint
38 | - name: Build the environment
39 | run: pixi run build
40 | docker:
41 | runs-on: ubuntu-latest
42 | steps:
43 | - name: Checkout repo
44 | uses: actions/checkout@v4
45 | - run: docker build -t ${{ env.REGISTRY_IMAGE }}:latest -f .docker/Dockerfile .
46 | - name: Push to gh registry with sha and branch
47 | # Do not push on PRs from forks.
48 | if: ${{ !github.event.pull_request.head.repo.fork }}
49 | run: |
50 | echo ${{ secrets.GITHUB_TOKEN }} | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin
51 | # tag with sha and push
52 | docker tag ${{ env.REGISTRY_IMAGE }}:latest ${{ env.REGISTRY_IMAGE }}:${{ github.sha }}
53 | docker push ${{ env.REGISTRY_IMAGE }}:${{ github.sha }}
54 | # tag with branch name and push
55 | export BRANCH_NAME=$(echo ${{ env.RAW_BRANCH_NAME }} | sed 's/\//-/g')
56 | docker tag ${{ env.REGISTRY_IMAGE }}:latest ${{ env.REGISTRY_IMAGE }}:$BRANCH_NAME
57 | docker push ${{ env.REGISTRY_IMAGE }}:$BRANCH_NAME
58 | echo "TAG_NAME=$BRANCH_NAME" >> $GITHUB_ENV
59 | - name: Push to gh registry with latest if this is main
60 | run: |
61 | echo ${{ secrets.GITHUB_TOKEN }} | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin
62 | # only from main we actually push to latest
63 | docker push ${{ env.REGISTRY_IMAGE }}:latest
64 | echo "TAG_NAME=latest" >> $GITHUB_ENV
65 | if: github.ref == 'refs/heads/main'
66 | slides:
67 | runs-on: ubuntu-latest
68 | steps:
69 | - name: Checkout repo
70 | uses: actions/checkout@v4
71 | - name: Copy slides to main directory
72 | run: cp -r slides/* .
73 | - name: Run pandoc
74 | uses: docker://pandoc/latex:3.7
75 | with:
76 | args: "-t beamer main.md -o slides.pdf --listings --slide-level=2"
77 | - name: Upload slides
78 | uses: actions/upload-artifact@main
79 | with:
80 | name: slides.pdf
81 | path: slides.pdf
82 | - name: Non-latest slides as release
83 | uses: ncipollo/release-action@v1
84 | with:
85 | artifacts: 'slides.pdf'
86 | token: ${{ secrets.GITHUB_TOKEN }}
87 | commit: ${{ github.sha }}
88 | allowUpdates: true
89 | name: Workshop Slides
90 | tag: slides-pr${{ github.event.pull_request.number }}
91 | prerelease: true
92 | if: github.event_name == 'pull_request'
93 | - name: Latest slides as release
94 | uses: ncipollo/release-action@v1
95 | with:
96 | artifacts: 'slides.pdf'
97 | token: ${{ secrets.GITHUB_TOKEN }}
98 | commit: ${{ github.sha }}
99 | allowUpdates: true
100 | name: Workshop Slides
101 | tag: slides-latest
102 | makeLatest: true
103 | if: github.ref == 'refs/heads/main'
104 | pytest:
105 | runs-on: ubuntu-latest
106 | steps:
107 | - name: Checkout repo
108 | uses: actions/checkout@v4
109 | - name: Install python
110 | uses: actions/setup-python@v6
111 | with:
112 | python-version: '3.11'
113 | - name: Install dependencies
114 | run: pip install pytest pytest-md pytest-emoji sybil docker
115 | - name: Set up Docker
116 | uses: docker/setup-docker-action@v4
117 | - name: Run pytest
118 | uses: pavelzw/pytest-action@v2
119 | with:
120 | verbose: true
121 | emoji: true
122 | job-summary: true
123 | custom-arguments: '-vvs test/'
124 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Trains an RL policy."""
10 |
11 | import argparse
12 | from datetime import datetime
13 |
14 | import rclpy
15 | from rclpy.node import Node
16 | from stable_baselines3 import DQN, PPO, SAC, A2C
17 | from stable_baselines3.common.callbacks import (
18 | EvalCallback,
19 | StopTrainingOnRewardThreshold,
20 | )
21 | from stable_baselines3.common.base_class import BaseAlgorithm
22 |
23 | from pyrobosim_ros_gym import get_config
24 | from pyrobosim_ros_gym.envs import get_env_by_name, available_envs_w_subtype
25 |
26 |
27 | def get_args() -> argparse.Namespace:
28 | """Helper function to parse the command-line arguments."""
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument(
31 | "--env",
32 | choices=available_envs_w_subtype(),
33 | help="The environment to use.",
34 | required=True,
35 | )
36 | parser.add_argument(
37 | "--config",
38 | help="Path to the configuration YAML file.",
39 | required=True,
40 | )
41 | parser.add_argument(
42 | "--algorithm",
43 | default="DQN",
44 | choices=["DQN", "PPO", "SAC", "A2C"],
45 | help="The algorithm with which to train a model.",
46 | )
47 | parser.add_argument(
48 | "--discrete-actions",
49 | action="store_true",
50 | help="If true, uses discrete action space. Otherwise, uses continuous action space.",
51 | )
52 | parser.add_argument("--seed", default=42, type=int, help="The RNG seed to use.")
53 | parser.add_argument(
54 | "--realtime", action="store_true", help="If true, slows down to real time."
55 | )
56 | parser.add_argument(
57 | "--log",
58 |         default=False,
59 | action="store_true",
60 | help="If true, logs data to Tensorboard.",
61 | )
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | if __name__ == "__main__":
67 | args = get_args()
68 | config = get_config(args.config)
69 |
70 | # Create the environment
71 | rclpy.init()
72 | node = Node("pyrobosim_ros_env")
73 | env = get_env_by_name(
74 | args.env,
75 | node,
76 | max_steps_per_episode=25,
77 | realtime=args.realtime,
78 | discrete_actions=args.discrete_actions,
79 | reward_fn=config["training"].get("reward_fn"),
80 | )
81 |
82 | # Train a model
83 | log_path = "train_logs" if args.log else None
84 | if args.algorithm == "DQN":
85 | dqn_config = config.get("training", {}).get("DQN", {})
86 | model: BaseAlgorithm = DQN(
87 | "MlpPolicy",
88 | env=env,
89 | seed=args.seed,
90 | tensorboard_log=log_path,
91 | **dqn_config,
92 | )
93 | elif args.algorithm == "PPO":
94 | ppo_config = config.get("training", {}).get("PPO", {})
95 | model = PPO(
96 | "MlpPolicy",
97 | env=env,
98 | seed=args.seed,
99 | tensorboard_log=log_path,
100 | **ppo_config,
101 | )
102 | elif args.algorithm == "SAC":
103 | sac_config = config.get("training", {}).get("SAC", {})
104 | model = SAC(
105 | "MlpPolicy",
106 | env=env,
107 | seed=args.seed,
108 | tensorboard_log=log_path,
109 | **sac_config,
110 | )
111 | elif args.algorithm == "A2C":
112 | a2c_config = config.get("training", {}).get("A2C", {})
113 | model = A2C(
114 | "MlpPolicy",
115 | env=env,
116 | seed=args.seed,
117 | tensorboard_log=log_path,
118 | **a2c_config,
119 | )
120 | else:
121 | raise RuntimeError(f"Invalid algorithm type: {args.algorithm}")
122 | print(f"\nTraining with {args.algorithm}...\n")
123 |
124 | # Train the model until it exceeds a specified reward threshold in evals.
125 | training_config = config["training"]
126 | callback_on_best = StopTrainingOnRewardThreshold(
127 | reward_threshold=training_config["eval"]["reward_threshold"],
128 | verbose=1,
129 | )
130 | eval_callback = EvalCallback(
131 | env,
132 | callback_on_new_best=callback_on_best,
133 | verbose=1,
134 | eval_freq=training_config["eval"]["eval_freq"],
135 | n_eval_episodes=training_config["eval"]["n_eval_episodes"],
136 | )
137 |
138 | date_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
139 | log_name = f"{args.env}_{args.algorithm}_seed{args.seed}_{date_str}"
140 | model.learn(
141 | total_timesteps=training_config["max_training_steps"],
142 | progress_bar=True,
143 | tb_log_name=log_name,
144 | log_interval=1,
145 | callback=eval_callback,
146 | )
147 |
148 | # Save the trained model
149 | model_name = f"{log_name}.pt"
150 | model.save(model_name)
151 | print(f"\nSaved model to {model_name}\n")
152 |
153 | rclpy.shutdown()
154 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/envs/pyrobosim_ros_env.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | import time
10 | from enum import Enum
11 | from functools import partial
12 |
13 | import gymnasium as gym
14 | import rclpy
15 | from rclpy.action import ActionClient
16 |
17 | from pyrobosim_msgs.action import ExecuteTaskAction
18 | from pyrobosim_msgs.srv import (
19 | RequestWorldInfo,
20 | RequestWorldState,
21 | ResetWorld,
22 | SetLocationState,
23 | )
24 |
25 |
26 | class PyRoboSimRosEnv(gym.Env):
27 | """Gym environment wrapping around the PyRoboSim ROS Interface."""
28 |
29 | sub_types = Enum("sub_types", "DEFINE_IN_SUBCLASS")
30 |
31 | def __init__(
32 | self,
33 | node,
34 | reward_fn,
35 | reset_validation_fn=None,
36 | max_steps_per_episode=50,
37 | realtime=True,
38 | discrete_actions=True,
39 | executor=None,
40 | ):
41 | """
42 | Instantiates a PyRoboSim ROS environment.
43 |
44 | :param node: The ROS node to use for creating clients.
45 | :param reward_fn: Function that calculates the reward (and possibly other outputs).
46 | :param reset_validation_fn: Function that calculates whether a reset is valid.
47 | If None (default), all resets are valid.
48 | :param max_steps_per_episode: Maximum number of steps before truncating an episode.
49 | If -1, there is no limit to number of steps.
50 | :param realtime: If True, commands PyRoboSim to run actions in real time.
51 | If False, actions run as quickly as possible for faster training.
52 | :param discrete_actions: If True, uses discrete actions, else uses continuous.
53 | :param executor: Optional ROS executor. It must be already spinning!
54 | """
55 | super().__init__()
56 | self.node = node
57 | self.executor = executor
58 | if self.executor is not None:
59 | self.executor.add_node(self.node)
60 | self.executor.wake()
61 |
62 | self.realtime = realtime
63 | self.max_steps_per_episode = max_steps_per_episode
64 | self.discrete_actions = discrete_actions
65 |
66 | if reward_fn is None:
67 | self.reward_fn = lambda _: 0.0
68 | else:
69 | self.reward_fn = partial(reward_fn, self)
70 |
71 | if reset_validation_fn is None:
72 | self.reset_validation_fn = lambda: True
73 | else:
74 | self.reset_validation_fn = lambda: reset_validation_fn(self)
75 |
76 | self.step_number = 0
77 | self.previous_location = None
78 | self.previous_action_type = None
79 |
80 | self.request_info_client = node.create_client(
81 | RequestWorldInfo, "/request_world_info"
82 | )
83 | self.request_state_client = node.create_client(
84 | RequestWorldState, "/request_world_state"
85 | )
86 | self.execute_action_client = ActionClient(
87 | node, ExecuteTaskAction, "/execute_action"
88 | )
89 | self.reset_world_client = node.create_client(ResetWorld, "reset_world")
90 | self.set_location_state_client = node.create_client(
91 | SetLocationState, "set_location_state"
92 | )
93 |
94 | self.request_info_client.wait_for_service()
95 | self.request_state_client.wait_for_service()
96 | self.execute_action_client.wait_for_server()
97 | self.reset_world_client.wait_for_service()
98 | self.set_location_state_client.wait_for_service()
99 |
100 | future = self.request_info_client.call_async(RequestWorldInfo.Request())
101 | self._spin_future(future)
102 | self.world_info = future.result().info
103 |
104 | future = self.request_state_client.call_async(RequestWorldState.Request())
105 | self._spin_future(future)
106 | self.world_state = future.result().state
107 |
108 | self.all_locations = []
109 | for loc in self.world_state.locations:
110 | self.all_locations.extend(loc.spawns)
111 | self.num_locations = sum(len(loc.spawns) for loc in self.world_state.locations)
112 | self.loc_to_idx = {loc: idx for idx, loc in enumerate(self.all_locations)}
113 | print(f"{self.all_locations=}")
114 |
115 | self.action_space = self._action_space()
116 | print(f"{self.action_space=}")
117 |
118 | def _spin_future(self, future):
119 | if self.executor is None:
120 | rclpy.spin_until_future_complete(self.node, future)
121 | else:
122 | while not future.done():
123 | time.sleep(0.1)
124 |
125 | def _action_space(self):
126 | raise NotImplementedError("implement in sub-class")
127 |
128 | def initialize(self):
129 | """Resets helper variables for deployment without doing a full reset."""
130 | raise NotImplementedError("implement in sub-class")
131 |
132 | def step(self, action):
133 | raise NotImplementedError("implement in sub-class")
134 |
135 | def reset(self, seed=None, options=None):
136 | """Resets the environment with a specified seed and options."""
137 | print(f"Resetting environment with {seed=}")
138 | super().reset(seed=seed)
139 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Evaluates a trained RL policy."""
10 |
11 | import argparse
12 |
13 | import rclpy
14 | from gymnasium.spaces import Discrete
15 | from rclpy.node import Node
16 | from stable_baselines3.common.base_class import BaseAlgorithm
17 |
18 | from pyrobosim_ros_gym import get_config
19 | from pyrobosim_ros_gym.envs import available_envs_w_subtype, get_env_by_name
20 | from pyrobosim_ros_gym.policies import ManualPolicy, model_and_env_type_from_path
21 |
22 | MANUAL_STR = "manual"
23 |
24 |
25 | def get_args() -> argparse.Namespace:
26 | """Helper function to parse the command-line arguments."""
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument(
29 | "--model",
30 | type=str,
31 | help=f"The path of the model to evaluate. Can be '{MANUAL_STR}' for manual control.",
32 | )
33 | parser.add_argument(
34 | "--env",
35 | type=str,
36 | help=f"The name of the environment to use if '--model {MANUAL_STR}' is selected.",
37 | choices=available_envs_w_subtype(),
38 | )
39 | parser.add_argument(
40 | "--config",
41 | help="Path to the configuration YAML file.",
42 | required=True,
43 | )
44 | parser.add_argument(
45 | "--num-episodes",
46 | default=3,
47 | type=int,
48 | help="The number of episodes to evaluate.",
49 | )
50 | parser.add_argument("--seed", default=42, type=int, help="The RNG seed to use.")
51 | parser.add_argument(
52 | "--realtime", action="store_true", help="If true, slows down to real time."
53 | )
54 | args = parser.parse_args()
55 |
56 | # Ensure '--env' is provided if '--model' is 'manual'
57 | if args.model == MANUAL_STR and not args.env:
58 | parser.error(f"--env must be specified when --model is '{MANUAL_STR}'.")
59 | if args.env and args.model is None:
60 | print("--env is specified but --model is not. Defaulting to manual control.")
61 | args.model = MANUAL_STR
62 | return args
63 |
64 |
65 | if __name__ == "__main__":
66 | args = get_args()
67 | config = get_config(args.config)
68 |
69 | rclpy.init()
70 | node = Node("pyrobosim_ros_env")
71 |
72 | # Load the model and environment
73 | model: BaseAlgorithm | ManualPolicy
74 | if args.model == MANUAL_STR:
75 | env = get_env_by_name(
76 | args.env,
77 | node,
78 | max_steps_per_episode=15,
79 | realtime=True,
80 | discrete_actions=True,
81 | reward_fn=config["training"].get("reward_fn"),
82 | )
83 | model = ManualPolicy(env.action_space)
84 | else:
85 | model, env_type = model_and_env_type_from_path(args.model)
86 | env = get_env_by_name(
87 | env_type,
88 | node,
89 | max_steps_per_episode=15,
90 | realtime=args.realtime,
91 | discrete_actions=isinstance(model.action_space, Discrete),
92 | reward_fn=config["training"].get("reward_fn"),
93 | )
94 |
95 | # Evaluate the model for some steps
96 | num_successful_episodes = 0
97 | reward_per_episode = [0.0 for _ in range(args.num_episodes)]
98 | custom_metrics_store: dict[str, list[float]] = {}
99 | custom_metrics_episode_mean: dict[str, float] = {}
100 | custom_metrics_per_episode: dict[str, list[float]] = {}
101 | for i_e in range(args.num_episodes):
102 | print(f">>> Starting episode {i_e+1}/{args.num_episodes}")
103 | obs, _ = env.reset(seed=i_e + args.seed)
104 | terminated = False
105 | truncated = False
106 |         custom_metrics_store = {}  # Collect step metrics for this episode only.
107 | while not (terminated or truncated):
108 | print(f"{obs=}")
109 | action, _ = model.predict(obs, deterministic=True)
110 | print(f"{action=}")
111 | obs, reward, terminated, truncated, info = env.step(action)
112 | custom_metrics = info.get("metrics", {})
113 |
114 | print(f"{reward=}")
115 | reward_per_episode[i_e] += reward
116 |
117 | print(f"{custom_metrics=}")
118 | for k, v in custom_metrics.items():
119 | if k not in custom_metrics_store:
120 | custom_metrics_store[k] = []
121 | custom_metrics_store[k].append(v)
122 |
123 | print(f"{terminated=}")
124 | print(f"{truncated=}")
125 | print("." * 10)
126 | if terminated or truncated:
127 | success = info["success"]
128 | if success:
129 | num_successful_episodes += 1
130 | print(f"<<< Episode {i_e+1} finished with {success=}.")
131 | print(f"Total reward: {reward_per_episode[i_e]}")
132 | for k, v in custom_metrics_store.items():
133 | mean_metric = sum(v) / len(v) if len(v) > 0 else 0.0
134 | custom_metrics_episode_mean[k] = mean_metric
135 | if k not in custom_metrics_per_episode:
136 | custom_metrics_per_episode[k] = []
137 | custom_metrics_per_episode[k].append(mean_metric)
138 | print(f"Mean {k}: {mean_metric}")
139 | print("=" * 20)
140 | break
141 |
142 | print("Summary:")
143 | success_percent = 100.0 * num_successful_episodes / args.num_episodes
144 | print(
145 | f"Successful episodes: {num_successful_episodes} / {args.num_episodes} "
146 | f"({success_percent:.2f}%)"
147 | )
148 | print(f"Reward over {args.num_episodes} episodes:")
149 | print(f" Mean: {sum(reward_per_episode)/args.num_episodes}")
150 | print(f" Min: {min(reward_per_episode)}")
151 | print(f" Max: {max(reward_per_episode)}")
152 | for k, v in custom_metrics_per_episode.items():
153 | print(f"Custom metric '{k}' over {args.num_episodes} episodes:")
154 | print(f" Mean: {sum(v)/args.num_episodes}")
155 | print(f" Min: {min(v)}")
156 | print(f" Max: {max(v)}")
157 |
158 | rclpy.shutdown()
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reinforcement Learning for Deliberation in ROS 2
2 |
3 | This repository contains materials for the [ROSCon 2025](https://roscon.ros.org/2025/) workshop on ROS 2 Deliberation Technologies.
4 |
5 | > [!NOTE]
6 | > This was moved here from .
7 |
8 | ## Setup
9 |
10 | This repo uses Pixi and RoboStack along with ROS 2 Kilted.
11 |
12 | First, install dependencies on your system (assuming you are using Linux).
13 |
14 |
15 |
21 |
22 |
23 | ```bash
24 | sudo apt install build-essential curl
25 | ```
26 |
27 | Then, install Pixi.
28 |
29 | ```bash
30 | curl -fsSL https://pixi.sh/install.sh | sh
31 | ```
32 |
33 |
39 |
40 | Clone the repo including submodules.
41 |
42 | ```bash
43 | git clone --recursive https://github.com/ros-wg-delib/rl_deliberation.git
44 | ```
45 |
46 | Build the environment.
47 |
48 |
49 | ```bash
50 | pixi run build
51 | ```
52 |
53 | To verify your installation, the following command should launch a PyRoboSim window.
54 |
55 |
56 | ```bash
57 | pixi run start_world --env GreenhousePlain
58 | ```
59 |
60 | To explore the setup, you can also drop into a shell in the Pixi environment.
61 |
62 |
63 | ```bash
64 | pixi shell
65 | ```
66 |
67 | ## Explore the environment
68 |
69 | There are different environments available. For example, to run the Greenhouse environment:
70 |
71 |
72 | ```bash
73 | pixi run start_world --env GreenhousePlain
74 | ```
75 |
76 | All the following commands assume that the environment is running.
77 | You can also run the environment in headless mode for training.
78 |
79 |
80 | ```bash
81 | pixi run start_world --env GreenhousePlain --headless
82 | ```
83 |
84 | But first, we can explore the environment with a random agent.
85 |
86 | ## Evaluating with a random agent
87 |
88 | Assuming the environment is running, execute the evaluation script in another terminal:
89 |
90 |
91 | ```bash
92 | pixi run eval --config greenhouse_env_config.yaml --model pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt --num-episodes 1 --realtime
93 | ```
94 |
95 |
100 |
101 | In your terminal, you will see multiple sections in the following format:
102 |
103 | ```plaintext
104 | ..........
105 | obs=array([1. , 0.99194384, 0. , 2.7288349, 0. , 3.3768525, 1.], dtype=float32)
106 | action=array(0)
107 | Maximum steps (10) exceeded. Truncated episode.
108 | reward=0.0
109 | custom_metrics={'watered_plant_fraction': 0.0, 'battery_level': 100.0}
110 | terminated=False
111 | truncated=False
112 | ..........
113 | ```
114 |
115 | Each such section shows one step of the agent's interaction with the environment.
116 |
117 | - `obs` is the observation from the environment. It is an array describing the 3 closest plant objects, each with a class label (0 or 1) and its distance from the robot, followed by whether the robot's current location has already been watered and, in the battery variant of the environment, the robot's battery level (see the sketch after this list).
118 | - `action` is the action taken by the agent. In this simple example, it can choose between 0 = move on and 1 = water plant.
119 | - `reward` is the reward received after taking the action, which is `0.0` in this case, because the agent did not water any plant.
120 | - `custom_metrics` provides additional information about the episode:
121 | - `watered_plant_fraction` indicates the fraction of plants (between 0 and 1) watered thus far in the episode.
122 | - `battery_level` indicates the current battery level of the robot. (This will not decrease for this environment type, but it will later.)
123 | - `terminated` indicates whether the episode reached a terminal state (e.g., the task was completed or failed).
124 | - `truncated` indicates whether the episode ended due to a time limit.
125 |
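As a concrete illustration (this helper is not part of the repository), the observation above could be unpacked as follows, assuming the layout described in the list:

```python
import numpy as np

# Example observation from the GreenhousePlain environment (no battery entry).
obs = np.array(
    [1.0, 0.99194384, 0.0, 2.7288349, 0.0, 3.3768525, 1.0], dtype=np.float32
)

num_plants = 3
for i in range(num_plants):
    plant_class = int(obs[2 * i])      # 0 = good plant, 1 = other
    distance = float(obs[2 * i + 1])   # distance from the robot, capped at 10
    print(f"Plant {i}: class={plant_class}, distance={distance:.2f}")

# The final entry indicates whether the robot's current location has been watered.
print(f"Current location watered: {bool(obs[2 * num_plants])}")
```
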
126 | In the PyRoboSim window, you should also see the robot moving around at every step.
127 |
128 | At the end of each episode, and again after all episodes are completed, you will see additional statistics printed in the terminal.
129 |
130 | ```plaintext
131 | ..........
132 | <<< Episode 1 finished with success=False.
133 | Total reward: 0.0
134 | Mean watered_plant_fraction: 0.0
135 | Mean battery_level: 100.0
136 | ====================
137 | Summary:
138 | Reward over 1 episodes:
139 | Mean: 0.0
140 | Min: 0.0
141 | Max: 0.0
142 | Custom metric 'watered_plant_fraction' over 1 episodes:
143 | Mean: 0.0
144 | Min: 0.0
145 | Max: 0.0
146 | Custom metric 'battery_level' over 1 episodes:
147 | Mean: 100.0
148 | Min: 100.0
149 | Max: 100.0
150 | ```
151 |
152 | ## Training a model
153 |
154 | While the environment is running (in headless mode if you prefer), you can train a model.
155 |
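The `--config` YAML file controls the reward function, the algorithm hyperparameters, and the evaluation settings used during training. As a rough sketch only (based on the keys read in `train.py`; the actual `config/greenhouse_env_config.yaml` may differ, and all values below are illustrative), its `training` section is shaped like this:

```yaml
training:
  # A reward_fn entry may also appear here; it selects the reward function for the environment.
  max_training_steps: 10000
  eval:
    reward_threshold: 8.0   # stop training early once this mean eval reward is reached
    eval_freq: 1000
    n_eval_episodes: 5
  PPO:                      # keyword arguments forwarded to stable_baselines3.PPO
    n_steps: 64
    gamma: 0.99
```
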
156 | ### Choose algorithm type
157 |
158 | For example, to train with PPO:
159 |
160 |
161 | ```bash
162 | pixi run train --env GreenhousePlain --config greenhouse_env_config.yaml --algorithm PPO --log
163 | ```
164 |
165 | Or DQN.
166 | Note that this needs the `--discrete-actions` flag.
167 |
168 |
169 | ```bash
170 | pixi run train --env GreenhousePlain --config greenhouse_env_config.yaml --algorithm DQN --discrete-actions --log
171 | ```
172 |
173 | Note that at the end of training, the model name and path will be printed in the terminal:
174 |
175 | ```plaintext
176 | New best mean reward!
177 | 100% ━━━━━━━━━━━━━━━━━━━━━━━━━ 100/100 [ 0:00:35 < 0:00:00 , 2 it/s ]
178 |
179 | Saved model to GreenhousePlain_PPO_.pt
180 | ```
181 |
182 | Remember this path, as you will need it later.
183 |
184 | ### You may find TensorBoard useful
185 |
186 |
187 | ```bash
188 | pixi run tensorboard
189 | ```
190 |
191 | It should contain one entry named after your recent training run (e.g. `GreenhousePlain_PPO_`).
192 |
193 | ### See your freshly trained policy in action
194 |
195 | To run an evaluation, execute the following command.
196 |
197 |
198 | ```bash
199 | pixi run eval --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt --num-episodes 3 --realtime
200 | ```
201 |
202 | Or, to run more episodes as quickly as possible, launch your world with `--headless` and then execute:
203 |
204 |
205 | ```bash
206 | pixi run eval --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt --num-episodes 20
207 | ```
208 |
209 | You can also see your trained policy in action as a ROS node.
210 |
211 |
212 | ```bash
213 | pixi run policy_node --config greenhouse_env_config.yaml --model GreenhousePlain_PPO_.pt
214 | ```
215 |
216 | Then, in a separate terminal, you can send a goal.
217 |
218 |
219 | ```bash
220 | pixi shell
221 | ros2 action send_goal /execute_policy rl_interfaces/ExecutePolicy {}
222 | ```
223 |
224 | Of course, you can also use this same action interface in your own user code!
225 |
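For example, a minimal sketch (not part of this repository) of sending the same goal from your own `rclpy` code could look like the following; the client node name is arbitrary:

```python
import rclpy
from rclpy.action import ActionClient
from rclpy.node import Node

from rl_interfaces.action import ExecutePolicy


def main():
    rclpy.init()
    node = Node("policy_client")  # illustrative node name
    client = ActionClient(node, ExecutePolicy, "/execute_policy")
    client.wait_for_server()

    # Send an empty goal; the policy node runs the policy until the episode ends.
    goal_future = client.send_goal_async(ExecutePolicy.Goal())
    rclpy.spin_until_future_complete(node, goal_future)

    result_future = goal_future.result().get_result_async()
    rclpy.spin_until_future_complete(node, result_future)
    result = result_future.result().result
    node.get_logger().info(
        f"success={result.success}, steps={result.num_steps}, "
        f"reward={result.cumulative_reward}"
    )
    rclpy.shutdown()


if __name__ == "__main__":
    main()
```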
--------------------------------------------------------------------------------
/test/test_doc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from collections import OrderedDict
4 | from pathlib import Path
5 | import glob
6 | import pytest
7 | from sybil import Sybil
8 | from sybil.parsers.markdown.codeblock import CodeBlockParser
9 | from sybil.parsers.markdown.lexers import (
10 | FencedCodeBlockLexer,
11 | DirectiveInHTMLCommentLexer,
12 | )
13 | import docker
14 |
15 |
16 | COMMAND_PREFIX = "$ "
17 | IGNORED_OUTPUT = "..."
18 | DIR_NEW_ENV = "new-env"
19 | DIR_WORKDIR = "workdir"
20 | DIR_SKIP_NEXT = "skip-next"
21 | DIRECTIVE = "directive"
22 | DIR_CODE_BLOCK = "code-block"
23 | BASH = "bash"
24 | ARGUMENTS = "arguments"
25 | LINE_END = "\\"
26 | CWD = "cwd"
27 | EXPECTED_FILES = "expected-files"
28 |
29 | REPO_ROOT_FOLDER = os.path.join(os.path.dirname(__file__), "..")
30 |
31 |
32 | def evaluate_bash_block(example, cwd):
33 |     """Executes a command and compares its output to the provided expected output.
34 |
35 | ```bash
36 | command
37 | ```
38 | """
39 | print(f"{example=}")
40 | lines = example.strip().split("\n")
41 | output = []
42 | output_i = -1
43 | previous_cmd_line = ""
44 | next_is_cmd_continuation = False
45 | for line in lines:
46 | print(f"{line=}")
47 | if line.startswith(COMMAND_PREFIX) or next_is_cmd_continuation:
48 | # this is a command
49 | command = previous_cmd_line + line.replace(COMMAND_PREFIX, "")
50 | if command.endswith(LINE_END):
51 | # this must be merged with the next line
52 | previous_cmd_line = command.replace(LINE_END, "")
53 | next_is_cmd_continuation = True
54 | continue
55 | next_is_cmd_continuation = False
56 | print(f"{command=}")
57 | previous_cmd_line = ""
58 | # output = (
59 | # subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True, cwd=cwd)
60 | # .strip()
61 | # .decode("ascii")
62 | # )
63 | output = ""
64 | print(f"{output=}")
65 | output = [x.strip() for x in output.split("\n")]
66 | output_i = 0
67 | else:
68 | # this is expected output
69 | expected_line = line.strip()
70 | if len(expected_line) == 0:
71 | continue
72 | if output_i >= len(output):
73 | # end of captured output
74 | output_i = -1
75 | continue
76 | if output_i == -1:
77 | continue
78 | actual_line = output[output_i]
79 | while len(actual_line) == 0:
80 | # skip empty lines
81 | output_i += 1
82 | if output_i >= len(output):
83 | output_i = -1
84 | continue
85 | actual_line = output[output_i]
86 | if IGNORED_OUTPUT in expected_line:
87 | # skip this line
88 | output_i += 1
89 | continue
90 | assert actual_line == expected_line
91 | output_i += 1
92 |
93 |
94 | def collect_docs():
95 | """Search for *.md files."""
96 | assert os.path.exists(REPO_ROOT_FOLDER), f"Path must exist: {REPO_ROOT_FOLDER}"
97 |
98 | pattern = os.path.join(REPO_ROOT_FOLDER, "*.md")
99 | all_md_files = glob.glob(pattern)
100 | print(f"Found {len(all_md_files)} .md files by {pattern}.")
101 |
102 | # Configure Sybil
103 | fenced_lexer = FencedCodeBlockLexer(BASH)
104 | new_env_lexer = DirectiveInHTMLCommentLexer(directive=DIR_NEW_ENV)
105 | workdir_lexer = DirectiveInHTMLCommentLexer(directive=DIR_WORKDIR)
106 | skip_lexer = DirectiveInHTMLCommentLexer(directive=DIR_SKIP_NEXT)
107 | sybil = Sybil(
108 | parsers=[fenced_lexer, new_env_lexer, workdir_lexer, skip_lexer],
109 | filenames=all_md_files,
110 | )
111 | documents = []
112 | for f_path in all_md_files:
113 | doc = sybil.parse(Path(f_path))
114 | rel_path = os.path.relpath(f_path, REPO_ROOT_FOLDER)
115 | if len(list(doc)) > 0:
116 | documents.append([rel_path, list(doc)])
117 | print(f"Found {len(documents)} .md files with code to test.")
118 | return documents
119 |
120 |
121 | @pytest.mark.parametrize("path, blocks", collect_docs())
122 | def test_doc_md(path, blocks):
123 | """Testing all code blocks in one *.md file under `path`."""
124 | print(f"Testing {len(blocks)} code blocks in {path}.")
125 | env_blocks = OrderedDict()
126 | env_workdirs = OrderedDict()
127 | env_arguments = OrderedDict()
128 | current_env = "ENV DEFAULT"
129 | workdir = None
130 | skip_next_block = False
131 | for block in blocks:
132 | if DIRECTIVE in block.region.lexemes:
133 | directive = block.region.lexemes[DIRECTIVE]
134 | if directive == DIR_NEW_ENV:
135 | arguments = block.region.lexemes["arguments"]
136 | assert arguments not in env_blocks.keys()
137 | current_env = f"ENV {block.path}:{block.line}"
138 | if ARGUMENTS in block.region.lexemes:
139 | env_arguments[current_env] = block.region.lexemes[ARGUMENTS]
140 | elif directive == DIR_WORKDIR:
141 | workdir = block.region.lexemes[ARGUMENTS]
142 | elif directive == DIR_SKIP_NEXT:
143 | skip_next_block = True
144 | else:
145 | raise RuntimeError(f"Unsupported directive {directive}.")
146 | else:
147 | if skip_next_block:
148 | skip_next_block = False
149 | continue
150 | language = block.region.lexemes["language"]
151 | source = block.region.lexemes["source"]
152 | assert language == BASH, f"Unsupported language {language}"
153 | if current_env not in env_blocks.keys():
154 | env_blocks[current_env] = []
155 | env_workdirs[current_env] = []
156 | env_blocks[current_env].append(source)
157 | env_workdirs[current_env].append(workdir)
158 | workdir = None
159 | # After preprocessing all the environments, evaluate them.
160 | for env, blocks in env_blocks.items():
161 | # Get arguments
162 | assert env in env_arguments, "Environment must have arguments."
163 | arguments = env_arguments[env]
164 | print(f"Evaluating environment >{env}< with {arguments=} ...")
165 | client = docker.from_env()
166 | arguments = arguments.split(" ")
167 | assert len(arguments) == 1, f"Expecting one argument. Got {arguments}"
168 | image = arguments[0]
169 | client.images.pull(image)
170 | container = client.containers.run(
171 | image=image, command="sleep 1d", auto_remove=True, detach=True
172 | )
173 | # Get workdirs
174 | assert env in env_workdirs, "Environment must have working directories."
175 | workdirs = env_workdirs[env]
176 | assert len(blocks) == len(workdirs), "Must have one workdir entry per block."
177 | try:
178 | for block, workdir in zip(blocks, workdirs):
179 | for line in block.split("\n"):
180 | command = f'bash -ic "{line.strip()}"'
181 | if len(line) == 0:
182 | continue
183 | print(f"Executing {command=}")
184 | returncode, output = container.exec_run(
185 | command, tty=True, workdir=workdir
186 | )
187 | output = output.decode("utf-8")
188 | if returncode == 0:
189 | print("✅")
190 | print("output ...")
191 | print(output)
192 | else:
193 | print(f"❌ {returncode=}")
194 | assert (
195 | returncode == 0
196 | ), f"Unexpected Returncode: {returncode=}\noutput ...\n{output}"
197 | except Exception as e:
198 | print(e)
199 | assert False, str(e)
200 | finally:
201 | container.stop()
202 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/envs/banana.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Utilities for the banana test environment."""
10 |
11 | from pprint import pprint
12 | import os
13 | from enum import Enum
14 |
15 | import numpy as np
16 | from gymnasium import spaces
17 |
18 | from pyrobosim_msgs.action import ExecuteTaskAction
19 | from pyrobosim_msgs.msg import ExecutionResult, TaskAction
20 | from pyrobosim_msgs.srv import RequestWorldState, ResetWorld
21 |
22 | from .pyrobosim_ros_env import PyRoboSimRosEnv
23 |
24 |
25 | class BananaEnv(PyRoboSimRosEnv):
26 | sub_types = Enum("sub_types", "Pick Place PlaceNoSoda")
27 |
28 | @classmethod
29 | def get_world_file_path(cls, _: sub_types) -> str:
30 | """Get the world file path for a given subtype."""
31 | return os.path.join("rl_ws_worlds", "worlds", "banana.yaml")
32 |
33 | def __init__(
34 | self,
35 | sub_type: sub_types,
36 | node,
37 | max_steps_per_episode,
38 | realtime,
39 | discrete_actions,
40 | reward_fn=None,
41 | executor=None,
42 | ):
43 | """
44 | Instantiate Banana environment.
45 |
46 |         :param sub_type: Subtype of this environment (e.g. `BananaEnv.sub_types.Pick`).
47 | :param node: Node instance needed for ROS communication.
48 | :param max_steps_per_episode: Limit the steps (when to end the episode).
49 | If -1, there is no limit to number of steps.
50 | :param realtime: Whether actions take time.
51 | :param discrete_actions: Choose discrete actions (needed for DQN).
52 | :param reward_fn: Function that calculates the reward and termination criteria.
53 | The first argument needs to be the environment itself.
54 | The output needs to be a (reward, terminated, info) tuple.
55 | If not specified, uses the default for that environment.
56 | :param executor: Optional ROS executor. It must be already spinning!
57 | """
58 | if sub_type == BananaEnv.sub_types.Pick:
59 | reward_fn = reward_fn or banana_picked_reward
60 | reset_validation_fn = None
61 | # eval_freq = 1000
62 | elif sub_type == BananaEnv.sub_types.Place:
63 | reward_fn = reward_fn or banana_on_table_reward
64 | reset_validation_fn = None
65 | # eval_freq = 2000
66 | elif sub_type == BananaEnv.sub_types.PlaceNoSoda:
67 | reward_fn = reward_fn or banana_on_table_avoid_soda_reward
68 | reset_validation_fn = avoid_soda_reset_validation
69 | # eval_freq = 2000
70 | else:
71 | raise ValueError(f"Invalid environment: {sub_type}")
72 |
73 | super().__init__(
74 | node,
75 | reward_fn,
76 | reset_validation_fn,
77 | max_steps_per_episode,
78 | realtime,
79 | discrete_actions,
80 | executor=executor,
81 | )
82 |
83 | self.num_locations = sum(len(loc.spawns) for loc in self.world_state.locations)
84 | self.loc_to_idx = {loc: idx for idx, loc in enumerate(self.all_locations)}
85 |
86 | self.num_object_types = len(self.world_info.object_categories)
87 | self.obj_to_idx = {
88 | obj: idx for idx, obj in enumerate(self.world_info.object_categories)
89 | }
90 |
91 | # Observation space is defined by:
92 | # Previous action
93 | # Location of robot
94 | # Type of object robot is holding (if any)
95 | # Whether there is at least one of a specific object type at each location
96 | self.obs_size = (
97 | self.num_locations # Number of locations robot can be in
98 | + self.num_object_types # Object types robot is holding
99 | + (
100 | self.num_locations * self.num_object_types
101 | ) # Number of object categories per location
102 | )
103 |
104 | self.observation_space = spaces.Box(
105 | low=-np.ones(self.obs_size, dtype=np.float32),
106 | high=np.ones(self.obs_size, dtype=np.float32),
107 | )
108 | print(f"{self.observation_space=}")
109 |
110 | self.initialize()
111 |
112 | def _action_space(self):
113 | # Action space is defined by:
114 | # Move: To all possible object spawns
115 | # Pick: All possible object categories
116 | # Place: The current manipulated object
117 | idx = 0
118 | self.integer_to_action = {}
119 | for loc in self.all_locations:
120 | self.integer_to_action[idx] = TaskAction(
121 | type="navigate", target_location=loc
122 | )
123 | idx += 1
124 | for obj_category in self.world_info.object_categories:
125 | self.integer_to_action[idx] = TaskAction(type="pick", object=obj_category)
126 | idx += 1
127 | self.integer_to_action[idx] = TaskAction(type="place")
128 | self.num_actions = len(self.integer_to_action)
129 | print("self.integer_to_action=")
130 | pprint(self.integer_to_action)
131 |
132 | if self.discrete_actions:
133 | return spaces.Discrete(self.num_actions)
134 | else:
135 | return spaces.Box(
136 | low=np.zeros(self.num_actions, dtype=np.float32),
137 | high=np.ones(self.num_actions, dtype=np.float32),
138 | )
139 |
140 | def step(self, action):
141 | """Steps the environment with a specific action."""
142 | self.previous_location = self.world_state.robots[0].last_visited_location
143 |
144 | goal = ExecuteTaskAction.Goal()
145 | if self.discrete_actions:
146 | goal.action = self.integer_to_action[int(action)]
147 | else:
148 | goal.action = self.integer_to_action[np.argmax(action)]
149 | goal.action.robot = "robot"
150 | goal.realtime_factor = 1.0 if self.realtime else -1.0
151 |
152 | goal_future = self.execute_action_client.send_goal_async(goal)
153 | self._spin_future(goal_future)
154 |
155 | result_future = goal_future.result().get_result_async()
156 | self._spin_future(result_future)
157 |
158 | action_result = result_future.result().result
159 | self.step_number += 1
160 | truncated = (self.max_steps_per_episode >= 0) and (
161 | self.step_number >= self.max_steps_per_episode
162 | )
163 | if truncated:
164 | print(
165 | f"Maximum steps ({self.max_steps_per_episode}) exceeded. Truncated episode."
166 | )
167 |
168 | observation = self._get_obs()
169 | reward, terminated, info = self.reward_fn(goal, action_result)
170 | self.previous_action_type = goal.action.type
171 |
172 | info["metrics"] = {
173 | "at_banana_location": float(is_at_banana_location(self)),
174 | "holding_banana": float(is_holding_banana(self)),
175 | }
176 | return observation, reward, terminated, truncated, info
177 |
178 | def initialize(self):
179 | self.step_number = 0
180 |
181 | def reset(self, seed=None, options=None):
182 | super().reset(seed=seed)
183 | self.initialize()
184 | info = {}
185 |
186 | valid_reset = False
187 | num_reset_attempts = 0
188 | while not valid_reset:
189 | future = self.reset_world_client.call_async(
190 | ResetWorld.Request(seed=(seed or -1))
191 | )
192 | self._spin_future(future)
193 |
194 | observation = self._get_obs()
195 |
196 | valid_reset = self.reset_validation_fn()
197 | num_reset_attempts += 1
198 | seed = None # subsequent resets need to not use a fixed seed
199 |
200 | print(f"Reset environment in {num_reset_attempts} attempt(s).")
201 | return observation, info
202 |
203 | def _get_obs(self):
204 | """Calculate the observation. All elements are either -1.0 or +1.0."""
205 | future = self.request_state_client.call_async(RequestWorldState.Request())
206 | self._spin_future(future)
207 | world_state = future.result().state
208 | robot_state = world_state.robots[0]
209 |
210 | obs = -np.ones(self.obs_size, dtype=np.float32)
211 |
212 | # Robot's current location
213 | if robot_state.last_visited_location in self.loc_to_idx:
214 | loc_idx = self.loc_to_idx[robot_state.last_visited_location]
215 | obs[loc_idx] = 1.0
216 |
217 | # Object categories per location (including currently held object, if any)
218 | for obj in world_state.objects:
219 | obj_idx = self.obj_to_idx[obj.category]
220 | if obj.name == robot_state.manipulated_object:
221 | obs[self.num_locations + obj_idx] = 1.0
222 | else:
223 | loc_idx = self.loc_to_idx[obj.parent]
224 | obs[
225 | self.num_locations
226 | + self.num_object_types
227 | + (loc_idx * self.num_object_types)
228 | + obj_idx
229 | ] = 1.0
230 |
231 | self.world_state = world_state
232 | return obs
233 |
234 |
235 | def is_at_banana_location(env):
236 | robot_state = env.world_state.robots[0]
237 | for obj in env.world_state.objects:
238 | if obj.category == "banana":
239 | if obj.parent == robot_state.last_visited_location:
240 | return True
241 | return False
242 |
243 |
244 | def is_holding_banana(env):
245 | robot_state = env.world_state.robots[0]
246 | for obj in env.world_state.objects:
247 | if obj.category == "banana" and obj.name == robot_state.manipulated_object:
248 | return True
249 | return False
250 |
251 |
252 | def banana_picked_reward(env, goal, action_result):
253 | """
254 | Checks whether the robot has picked a banana.
255 |
256 |     :param env: The environment.
257 |     :param goal: The ROS action goal sent to the robot.
258 |     :param action_result: The result of the above goal.
259 |
260 |     :return: A tuple of (reward, terminated, info).
261 | """
262 | # Calculate reward
263 | reward = 0.0
264 | terminated = False
265 | info = {"success": False}
266 |
267 | # Discourage repeating the same navigation action or failing to pick/place.
268 | if (goal.action.type == "navigate") and (
269 | goal.action.target_location == env.previous_location
270 | ):
271 | reward -= 1.0
272 | if action_result.execution_result.status != ExecutionResult.SUCCESS:
273 | reward -= 1.0
274 | # Discourage repeat action types.
275 | if goal.action.type == env.previous_action_type:
276 | reward -= 0.5
277 | # Robot gets positive reward based on holding a banana,
278 | # and negative reward for being in locations without bananas.
279 | at_banana_location = is_at_banana_location(env)
280 | if is_holding_banana(env):
281 | reward += 10.0
282 | terminated = True
283 | at_banana_location = True
284 | info["success"] = True
285 | print(f"🍌 Picked banana. Episode succeeded in {env.step_number} steps!")
286 |
287 | # Reward shaping: Penalty if the robot is not at a location containing a banana.
288 | if not terminated and not at_banana_location:
289 | reward -= 0.5
290 | return reward, terminated, info
291 |
292 |
293 | def banana_on_table_reward(env, goal, action_result):
294 | """
295 | Checks whether the robot has placed a banana on the table.
296 |
297 |     :param env: The environment.
298 |     :param goal: The ROS action goal sent to the robot.
299 |     :param action_result: The result of the above goal.
300 |
301 |     :return: A tuple of (reward, terminated, info).
302 | """
303 | # Calculate reward
304 | reward = 0.0
305 | terminated = False
306 | info = {"success": False}
307 |
308 | robot_state = env.world_state.robots[0]
309 | # Discourage repeating the same navigation action or failing to pick/place.
310 | if (goal.action.type == "navigate") and (
311 | goal.action.target_location == env.previous_location
312 | ):
313 | reward -= 1.0
314 | if action_result.execution_result.status != ExecutionResult.SUCCESS:
315 | reward -= 1.0
316 | # Discourage repeat action types.
317 | if goal.action.type == env.previous_action_type:
318 | reward -= 0.5
319 | # Robot gets positive reward based on a banana being at the table.
320 | at_banana_location = False
321 | holding_banana = False
322 | for obj in env.world_state.objects:
323 | if obj.category == "banana":
324 | if obj.parent == robot_state.last_visited_location:
325 | at_banana_location = True
326 | if obj.parent == "table0_tabletop":
327 | print(
328 | f"🍌 Placed banana on the table. "
329 | f"Episode succeeded in {env.step_number} steps!"
330 | )
331 | reward += 10.0
332 | terminated = True
333 | info["success"] = True
334 | break
335 |
336 | if obj.category == "banana" and obj.name == robot_state.manipulated_object:
337 | holding_banana = True
338 |
339 | # Reward shaping: Adjust the reward related to how close the robot is to completing the task.
340 | if not terminated:
341 | if holding_banana and env.previous_location != "table0_tabletop":
342 | reward += 0.25
343 | elif holding_banana:
344 | reward += 0.1
345 | elif at_banana_location:
346 | reward -= 0.1
347 | else:
348 | reward -= 0.25
349 |
350 | return reward, terminated, info
351 |
352 |
353 | def banana_on_table_avoid_soda_reward(env, goal, action_result):
354 | """
355 | Checks whether the robot has placed a banana on the table without touching the soda.
356 |
357 |     :param env: The environment.
358 |     :param goal: The ROS action goal sent to the robot.
359 |     :param action_result: The result of the above goal.
360 |
361 |     :return: A tuple of (reward, terminated, info).
362 | """
363 |     # Start from the same reward as the basic banana-on-table case.
364 | reward, terminated, info = banana_on_table_reward(env, goal, action_result)
365 |
366 |     # Robot gets additional negative reward for being near a soda,
367 | # and fails if it tries to pick or place at a location with a soda.
368 | robot_state = env.world_state.robots[0]
369 | for obj in env.world_state.objects:
370 | if obj.category == "coke" and obj.parent == robot_state.last_visited_location:
371 | if goal.action.type == "navigate":
372 | reward -= 2.5
373 | else:
374 | print(
375 | "🔥 Tried to pick and place near a soda. "
376 | f"Episode failed in {env.step_number} steps!"
377 | )
378 | reward -= 25.0
379 | terminated = True
380 | info["success"] = False
381 |
382 | return reward, terminated, info
383 |
384 |
385 | def avoid_soda_reset_validation(env):
386 | """
387 | Checks whether an environment has been properly reset to avoid soda.
388 |
389 | Specifically, we are checking that:
390 | - There is at least one banana not next to a soda
391 | - The robot is not at a starting location where there is a soda
392 |
393 |     :param env: The environment.
394 | :return: True if valid, else False.
395 | """
396 | soda_location = None
397 | for obj in env.world_state.objects:
398 | if obj.category == "coke":
399 | soda_location = obj.parent
400 |
401 | valid_banana_locations = False
402 | for obj in env.world_state.objects:
403 | if obj.category == "banana" and obj.parent != soda_location:
404 | valid_banana_locations = True
405 |
406 | robot_location = env.world_state.robots[0].last_visited_location
407 | valid_robot_location = robot_location != soda_location
408 |
409 | return valid_banana_locations and valid_robot_location
410 |
--------------------------------------------------------------------------------
/pyrobosim_ros_gym/envs/greenhouse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2025, Sebastian Castro, Christian Henkel
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the BSD 3-Clause License.
7 | # See the LICENSE file in the project root for license information.
8 |
9 | """Utilities for the greenhouse test environment."""
10 |
11 | from enum import Enum
12 | import os
13 |
14 | import numpy as np
15 | from geometry_msgs.msg import Point
16 | from gymnasium import spaces
17 |
18 | from pyrobosim_msgs.action import ExecuteTaskAction
19 | from pyrobosim_msgs.msg import TaskAction, WorldState
20 | from pyrobosim_msgs.srv import RequestWorldState, ResetWorld
21 |
22 | from .pyrobosim_ros_env import PyRoboSimRosEnv
23 |
24 |
25 | def _dist(a: Point, b: Point) -> float:
26 | """Calculate distance between two (geometry_msgs.msg) Points."""
27 | return float(np.linalg.norm([a.x - b.x, a.y - b.y, a.z - b.z]))
28 |
29 |
30 | class GreenhouseEnv(PyRoboSimRosEnv):
31 | sub_types = Enum("sub_types", "Plain Battery Random")
32 |
33 | @classmethod
34 | def get_world_file_path(cls, sub_type: sub_types) -> str:
35 | """Get the world file path for a given subtype."""
36 | if sub_type == GreenhouseEnv.sub_types.Plain:
37 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_plain.yaml")
38 | elif sub_type == GreenhouseEnv.sub_types.Battery:
39 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_battery.yaml")
40 | elif sub_type == GreenhouseEnv.sub_types.Random:
41 | return os.path.join("rl_ws_worlds", "worlds", "greenhouse_random.yaml")
42 | else:
43 | raise ValueError(f"Invalid environment: {sub_type}")
44 |
45 | def __init__(
46 | self,
47 | sub_type: sub_types,
48 | node,
49 | max_steps_per_episode,
50 | realtime,
51 | discrete_actions,
52 | reward_fn,
53 | executor=None,
54 | ):
55 | """
56 | Instantiate Greenhouse environment.
57 |
58 |         :param sub_type: Subtype of this environment, e.g. `GreenhouseEnv.sub_types.Plain`.
59 | :param node: Node instance needed for ROS communication.
60 | :param max_steps_per_episode: Limit the steps (when to end the episode).
61 | If -1, there is no limit to number of steps.
62 | :param realtime: Whether actions take time.
63 | :param discrete_actions: Choose discrete actions (needed for DQN).
64 | :param reward_fn: Function that calculates the reward and termination criteria.
65 | The first argument needs to be the environment itself.
66 | The output needs to be a (reward, terminated) tuple.
67 | :param executor: Optional ROS executor. It must be already spinning!
68 | """
69 | if sub_type == GreenhouseEnv.sub_types.Plain:
70 | # All plants are in their places
71 | pass
72 | elif sub_type == GreenhouseEnv.sub_types.Random:
73 |             # Plants are placed randomly across the tables
74 | pass
75 | elif sub_type == GreenhouseEnv.sub_types.Battery:
76 | # Battery (= water) is limited
77 | pass
78 | else:
79 | raise ValueError(f"Invalid environment: {sub_type}")
80 | self.sub_type = sub_type
81 |
82 | super().__init__(
83 | node,
84 | reward_fn,
85 | None, # reset_validation_fn
86 | max_steps_per_episode,
87 | realtime,
88 | discrete_actions,
89 | executor=executor,
90 | )
91 |
92 |         # Observation space is defined by:
93 |         # an array of n objects with a class and distance each,
94 |         # plus current location watered and (optionally) battery level.
95 |         self.max_n_objects = 3
96 |         self.max_dist = 10
97 | self.obs_size = 2 * self.max_n_objects + 1
98 | if self.sub_type == GreenhouseEnv.sub_types.Battery:
99 | self.obs_size += 1
100 |
101 | low = np.zeros(self.obs_size, dtype=np.float32)
102 | high = np.ones(self.obs_size, dtype=np.float32) # max class = 1
103 | high[1 : self.max_n_objects * 2 : 2] = self.max_dist
104 | self.observation_space = spaces.Box(low=low, high=high)
105 | print(f"{self.observation_space=}")
106 |
107 | self.plants = [obj.name for obj in self.world_state.objects]
108 | # print(f"{self.plants=}")
109 | self.good_plants = [
110 | obj.name for obj in self.world_state.objects if obj.category == "plant_good"
111 | ]
112 | # print(f"{self.good_plants=}")
113 |
114 | self.waypoints = [
115 | "table_c",
116 | "table_ne",
117 | "table_e",
118 | "table_se",
119 | "table_s",
120 | "table_sw",
121 | "table_w",
122 | "table_nw",
123 | "table_n",
124 | ]
125 | self.initialize()
126 |
127 | def _action_space(self):
128 | if self.sub_type == GreenhouseEnv.sub_types.Battery:
129 | self.num_actions = 3 # stay ducked, water plant, or go charge
130 | else:
131 | self.num_actions = 2 # stay ducked or water plant
132 |
133 | if self.discrete_actions:
134 | return spaces.Discrete(self.num_actions)
135 | else:
136 | return spaces.Box(
137 | low=np.zeros(self.num_actions, dtype=np.float32),
138 | high=np.ones(self.num_actions, dtype=np.float32),
139 | )
140 |
141 | def step(self, action):
142 | info = {}
143 | truncated = (self.max_steps_per_episode >= 0) and (
144 | self.step_number >= self.max_steps_per_episode
145 | )
146 | if truncated:
147 | print(
148 | f"Maximum steps ({self.max_steps_per_episode}) exceeded. "
149 | f"Truncated episode with watered fraction {self.watered_plant_fraction()}."
150 | )
151 |
152 | # print(f"{'*'*10}")
153 | # print(f"{action=}")
154 | if self.discrete_actions:
155 | action = float(action)
156 | else:
157 | action = np.argmax(action)
158 |
159 |         # Execute the current action before calculating the reward
160 | self.previous_battery_level = self.battery_level()
161 | if action == 1: # water a plant
162 | self.mark_table(self.get_current_location())
163 | elif action == 2: # charge
164 | self.go_to_loc("charger")
165 |
166 | future = self.request_state_client.call_async(RequestWorldState.Request())
167 | self._spin_future(future)
168 | self.world_state = future.result().state
169 |
170 | self.step_number += 1
171 | reward, terminated = self.reward_fn(action)
172 | # print(f"{reward=}")
173 |
174 | # Execute the remainder of the actions after calculating reward
175 | if not terminated:
176 | if action == 2: # charge
177 | self.go_to_loc(self.get_current_location())
178 | else:
179 | self.go_to_loc(self.get_next_location())
180 |
181 | # Update self.world_state and observation after finishing the action
182 | observation = self._get_obs()
183 | # print(f"{observation=}")
184 |
185 | info = {
186 | "success": self.watered_plant_fraction() == 1.0,
187 | "metrics": {
188 | "watered_plant_fraction": float(self.watered_plant_fraction()),
189 | "battery_level": float(self.battery_level()),
190 | },
191 | }
192 |
193 | return observation, reward, terminated, truncated, info
194 |
195 | def mark_table(self, loc):
196 | close_goal = ExecuteTaskAction.Goal()
197 | close_goal.action = TaskAction()
198 | close_goal.action.robot = "robot"
199 | close_goal.action.type = "close"
200 | close_goal.action.target_location = loc
201 |
202 | goal_future = self.execute_action_client.send_goal_async(close_goal)
203 | self._spin_future(goal_future)
204 |
205 | result_future = goal_future.result().get_result_async()
206 | self._spin_future(result_future)
207 |
208 | def watered_plant_fraction(self):
209 | n_watered = 0
210 | for w in self.watered.values():
211 | if w:
212 | n_watered += 1
213 | return n_watered / len(self.watered)
214 |
215 | def battery_level(self):
216 | return self.world_state.robots[0].battery_level
217 |
218 | def _get_plants_by_distance(self, world_state: WorldState):
219 | robot_state = world_state.robots[0]
220 | robot_pos = robot_state.pose.position
221 | # print(robot_pos)
222 |
223 | plants_by_distance = {}
224 | for obj in world_state.objects:
225 | pos = obj.pose.position
226 | dist = _dist(robot_pos, pos)
227 | dist = min(dist, self.max_dist)
228 | plants_by_distance[dist] = obj
229 |
230 | return plants_by_distance
231 |
232 | def initialize(self):
233 | self.step_number = 0
234 | self.waypoint_i = -1
235 | self.watered = {plant: False for plant in self.good_plants}
236 | self.previous_battery_level = self.battery_level()
237 | self.go_to_loc(self.get_next_location())
238 |
239 | def reset(self, seed=None, options=None):
240 | super().reset(seed)
241 |
242 | valid_reset = False
243 | num_reset_attempts = 0
244 | while not valid_reset:
245 | future = self.reset_world_client.call_async(
246 | ResetWorld.Request(seed=(seed or -1))
247 | )
248 | self._spin_future(future)
249 |
250 | # Validate that there are no two plants in the same location.
251 | observation = self._get_obs()
252 | valid_reset = True
253 | parent_locs = set()
254 | for obj in self.world_state.objects:
255 | if obj.parent not in parent_locs:
256 | parent_locs.add(obj.parent)
257 | else:
258 | valid_reset = False
259 | break
260 |
261 | num_reset_attempts += 1
262 | seed = None # subsequent resets need to not use a fixed seed
263 |
264 | self.initialize()
265 | print(f"Reset environment in {num_reset_attempts} attempt(s).")
266 | return observation, {}
267 |
268 | def _get_obs(self):
269 | """Calculate the observations"""
270 | future = self.request_state_client.call_async(RequestWorldState.Request())
271 | self._spin_future(future)
272 | world_state = future.result().state
273 | plants_by_distance = self._get_plants_by_distance(world_state)
274 |
275 | obs = np.zeros(self.obs_size, dtype=np.float32)
276 | start_idx = 0
277 |
278 | for _ in range(self.max_n_objects):
279 | closest_d = min(plants_by_distance.keys())
280 | plant = plants_by_distance.pop(closest_d)
281 | plant_class = 0 if plant.category == "plant_good" else 1
282 | obs[start_idx] = plant_class
283 | obs[start_idx + 1] = closest_d
284 | start_idx += 2
285 |
286 | cur_loc = world_state.robots[0].last_visited_location
287 | for loc in world_state.locations:
288 | if cur_loc == loc.name or cur_loc in loc.spawns:
289 | if not loc.is_open:
290 | obs[start_idx] = 1.0 # closed = watered
291 | break
292 |
293 | if self.sub_type == GreenhouseEnv.sub_types.Battery:
294 | obs[start_idx + 1] = self.battery_level() / 100.0
295 |
296 | self.world_state = world_state
297 | return obs
298 |
299 | def get_next_location(self):
300 | self.waypoint_i = (self.waypoint_i + 1) % len(self.waypoints)
301 | return self.get_current_location()
302 |
303 | def get_current_location(self):
304 | return self.waypoints[self.waypoint_i]
305 |
306 | def go_to_loc(self, loc: str):
307 | nav_goal = ExecuteTaskAction.Goal()
308 | nav_goal.action = TaskAction(type="navigate", target_location=loc)
309 | nav_goal.action.robot = "robot"
310 | nav_goal.realtime_factor = 1.0 if self.realtime else -1.0
311 |
312 | goal_future = self.execute_action_client.send_goal_async(nav_goal)
313 | self._spin_future(goal_future)
314 |
315 | result_future = goal_future.result().get_result_async()
316 | self._spin_future(result_future)
317 |
318 |
319 | def sparse_reward(env, action):
320 | """
321 | The most basic Greenhouse environment reward function, which provides
322 | positive reward if all good plants are watered and negative reward if
323 | an evil plant is watered.
324 | """
325 | reward = 0.0
326 | plants_by_distance = env._get_plants_by_distance(env.world_state)
327 | robot_location = env.world_state.robots[0].last_visited_location
328 |
329 | if action == 1: # move up to water
330 | for plant in plants_by_distance.values():
331 | if plant.parent != robot_location:
332 | continue
333 | if plant.category == "plant_evil":
334 | print(
335 | "🌶️ Tried to water an evil plant. "
336 | f"Terminated in {env.step_number} steps "
337 | f"with watered fraction {env.watered_plant_fraction()}."
338 | )
339 | return -5.0, True
340 |
341 | terminated = all(env.watered.values())
342 | if terminated:
343 | print(f"💧 Watered all good plants! Succeeded in {env.step_number} steps.")
344 | reward += 8.0
345 | return reward, terminated
346 |
347 |
348 | def dense_reward(env, action):
349 | """
350 | A simple Greenhouse environment reward function that provides reward each time
351 | a good plant is watered.
352 | """
353 | reward = 0.0
354 | plants_by_distance = env._get_plants_by_distance(env.world_state)
355 | robot_location = env.world_state.robots[0].last_visited_location
356 |
357 | if action == 1: # move up to water
358 | for plant in plants_by_distance.values():
359 | if plant.parent != robot_location:
360 | continue
361 | if plant.category == "plant_good":
362 | if not env.watered[plant.name]:
363 | env.watered[plant.name] = True
364 | reward += 2.0
365 | elif plant.category == "plant_evil":
366 | print(
367 | "🌶️ Tried to water an evil plant. "
368 | f"Terminated in {env.step_number} steps "
369 | f"with watered fraction {env.watered_plant_fraction()}."
370 | )
371 | return -5.0, True
372 | else:
373 | raise RuntimeError(f"Unknown category {plant.category}")
374 |
375 | terminated = all(env.watered.values())
376 | if terminated:
377 | print(f"💧 Watered all good plants! Succeeded in {env.step_number} steps.")
378 | return reward, terminated
379 |
380 |
381 | def full_reward(env, action):
382 | """Full (solution) reward function for the Greenhouse environment."""
383 | reward = 0.0
384 | plants_by_distance = env._get_plants_by_distance(env.world_state)
385 | robot_location = env.world_state.robots[0].last_visited_location
386 |
387 | if env.battery_level() <= 0.0:
388 | print(
389 | "🪫 Ran out of battery. "
390 | f"Terminated in {env.step_number} steps "
391 | f"with watered fraction {env.watered_plant_fraction()}."
392 | )
393 | return -5.0, True
394 |
395 | if action == 0: # stay ducked
396 | # Robot gets a penalty if it decides to ignore a waterable plant.
397 | for plant in env.world_state.objects:
398 | if (plant.category == "plant_good") and (plant.parent == robot_location):
399 | for location in env.world_state.locations:
400 | if robot_location in location.spawns and location.is_open:
401 | # print("\tPassed over a waterable plant")
402 | reward -= 0.25
403 | break
404 | return reward, False
405 |
406 | elif action == 1: # move up to water
407 | for plant in plants_by_distance.values():
408 | if plant.parent != robot_location:
409 | continue
410 | if plant.category == "plant_good":
411 | if not env.watered[plant.name]:
412 | env.watered[plant.name] = True
413 | reward += 2.0
414 | elif plant.category == "plant_evil":
415 | print(
416 | "🌶️ Tried to water an evil plant. "
417 | f"Terminated in {env.step_number} steps "
418 | f"with watered fraction {env.watered_plant_fraction()}."
419 | )
420 | return -5.0, True
421 | else:
422 | raise RuntimeError(f"Unknown category {plant.category}")
423 | if reward == 0.0: # nothing watered, wasted water
424 | # print("\tWasted water")
425 | reward = -0.5
426 |
427 | elif action == 2: # charging
428 | # Reward shaping to get the robot to visit the charger when its
429 | # battery is low, but not when it is high.
430 | if env.previous_battery_level <= 5.0:
431 | # print(f"\tCharged when battery low ({self.previous_battery_level}) :)")
432 | reward += 0.5
433 | else:
434 | # print(f"\tCharged when battery high ({self.previous_battery_level}) :(")
435 | reward -= 1.0
436 |
437 | terminated = all(env.watered.values())
438 | if terminated:
439 | print(f"💧 Watered all good plants! Succeeded in {env.step_number} steps.")
440 |
441 | return reward, terminated
442 |
--------------------------------------------------------------------------------
/slides/main.md:
--------------------------------------------------------------------------------
1 | ---
2 | title:
3 | - Reinforcement Learning for Deliberation in ROS 2
4 | author:
5 | - Christian Henkel
6 | - Sebastian Castro
7 | theme:
8 | - Bergen
9 | date:
10 | - ROSCon 2025 / October 27, 2025
11 | aspectratio: 169
12 | fontsize: 10pt
13 | colorlinks: true
14 | indent: false
15 | header-includes:
16 | - \usepackage{listings}
17 | - \usepackage{xcolor}
18 | - \lstset{
19 | basicstyle=\ttfamily\small,
20 | backgroundcolor=\color{gray!10},
21 | keywordstyle=\color{blue},
22 | stringstyle=\color{orange},
23 | commentstyle=\color{gray},
24 | showstringspaces=false
25 | }
26 | - \hypersetup{urlcolor=blue}
27 | - \urlstyle{tt}
28 | - \setbeamerfont{footline}{size=\normalsize}
29 | - \setbeamertemplate{navigation symbols}{}
30 | - \setbeamertemplate{footline}{\vspace{2pt} \hspace{2pt} \includegraphics[width=1.85cm]{media/ros-wg-delib.png} \includegraphics[width=1.85cm]{media/roscon25.png} \hspace*{5pt} \insertsection \hfill \insertframenumber{} / \inserttotalframenumber \hspace*{5pt} \vspace{2pt}}
31 | ---
32 |
33 | # Introduction
34 |
35 | ## Agenda
36 |
37 | | __Time__ | __Topic__ |
38 | |----------------|---------------------------------------------------------|
39 | | 13:00 - 13:30 | Introduction / Software Setup |
40 | | 13:30 - 14:00 | (Very) Quick intro to Reinforcement Learning |
41 | | 14:00 - 15:00 | Training and evaluating RL agents |
42 | | 15:00 - 15:30 | [Coffee break / leave a longer training running] |
43 | | 15:30 - 16:15 | Evaluating trained agents and running in ROS nodes |
44 | | 16:15 - 17:00 | Discussion: ROS 2, RL, and Deliberation |
45 |
46 | ## Software Setup
47 |
48 | 1. Clone the repository
49 |
50 | ```bash
51 | git clone --recursive \
52 | https://github.com/ros-wg-delib/rl_deliberation.git
53 | ```
54 |
55 | 2. Install Pixi:
56 |
57 | ```bash
58 | curl -fsSL https://pixi.sh/install.sh | sh
59 | ```
60 |
61 | (or see the [Pixi installation docs](https://pixi.sh) – recommended for autocompletion!)
62 |
63 | 3. Build the project:
64 |
65 | ```bash
66 | pixi run build
67 | ```
68 |
69 | 4. Run an example:
70 |
71 | ```plain
72 | pixi run start_world --env GreenhousePlain
73 | ```
74 |
75 | ## Learning Goals
76 |
77 | By the end of this workshop, you will be able to:
78 |
79 | - Recognize robotics problems that can be solved with reinforcement learning.
80 | - Understand the basic reinforcement learning concepts and terminology.
81 | - Observe the effects of changing algorithms, hyperparameters, and reward functions on training.
82 |
83 | ## What is Reinforcement Learning (RL)?
84 |
85 | ::: columns
86 |
87 | :::: column
88 |
89 | ### Basic model
90 |
91 | - Given an __agent__ and an __environment__.
92 | - Subject to the __state__ of the environment,
93 | - the agent takes an __action__.
94 | - the environment responds with a new __state__ and a __reward__.
95 |
96 | ::::
97 |
98 | :::: column
99 | 
100 |
101 | See also [Sutton and Barto, Reinforcement Learning: An Introduction](http://incompleteideas.net/book/RLbook2020.pdf)
102 | ::::
103 |
104 | :::
105 |
106 | ## What is Reinforcement Learning (RL)?
107 |
108 | ### Notation
109 |
110 | ::: columns
111 |
112 | :::: column
113 |
114 | - Discrete time steps $t = 0, 1, 2, \dots$
115 | - The environment is in a __state $S_t$__
116 | - Agent performs an __action $A_t$__
117 | - Environment responds with a new state $S_{t+1}$ and a reward $R_{t+1}$
118 | - Based on the new state $S_{t+1}$, the agent selects the __next action__ $A_{t+1}$
119 |
120 | ::::
121 |
122 | :::: column
123 | 
124 | ::::
125 |
126 | :::
127 |
128 | ## RL Software in this Workshop
129 |
130 | ::: columns
131 |
132 | :::: column
133 |
134 | 
135 |
136 | \footnotesize Represents environments for RL
137 |
138 | ::::
139 |
140 | :::: column
141 |
142 | 
143 |
144 | \footnotesize RL algorithm implementations in PyTorch
145 |
146 | ::::
147 |
148 | :::
149 |
150 | ## Exercise 1: You are the agent
151 |
152 | ::: columns
153 |
154 | :::: {.column width=40%}
155 |
156 | Start the environment:
157 |
158 | ```plain
159 | pixi run start_world \
160 | --env GreenhousePlain
161 | ```
162 |
163 | {height=100px}
164 |
165 | __Welcome__
167 | You are a robot that has to water plants in a greenhouse.
167 |
168 | ::::
169 |
170 | :::: {.column width=60%}
171 |
172 | Then, in another terminal, run:
173 |
174 | ```plain
175 | pixi run eval --model manual \
176 | --env GreenhousePlain \
177 | --config \
178 | greenhouse_env_config.yaml
179 | ```
180 |
181 | ```plain
182 | Enter action from [0, 1]:
183 | ```
184 |
185 | On this prompt, you can choose:
186 |
187 | - __0__: Move forward without watering, or
188 | - __1__: Water the plant and move on.
189 |
190 | ---
191 |
192 | But be __careful__: If you try to water the evil plant _(red)_, you will be eaten.
193 |
194 | {width=80px}
195 |
196 | ::::
197 |
198 | :::
199 |
200 | # Concepts
201 |
202 | ## Environment = MDP
203 |
204 | ### MDP
205 |
206 | We assume the environment to follow a __Markov Decision Process (MDP)__ model.
207 | An MDP is defined as $< \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R}>$.
208 |
209 | - $s \in \mathcal{S}$ states and $a \in \mathcal{A}$ actions as above.
210 | - $\mathcal{P}$ State Transition Probability: $P(s'|s, a)$.
211 | - For an action $a$ taken in state $s$, what is the probability of reaching state $s'$?
212 | - $\mathcal{R}$ Reward Function: $R(s, a)$.
213 | - We will use this to motivate the agent to learn desired behavior.
214 |
215 | {height=90px}
216 |
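To make the notation concrete, here is a toy two-state MDP written out as plain Python dictionaries. This is purely illustrative (the states, actions, and numbers are made up) and is not how the workshop environments are defined:

```python
# Illustrative toy MDP with two states and two actions (made-up numbers).
# P[s][a] -> list of (next_state, probability); R[s][a] -> immediate reward.
P = {
    "dry": {"wait": [("dry", 1.0)], "water": [("watered", 0.9), ("dry", 0.1)]},
    "watered": {"wait": [("watered", 1.0)], "water": [("watered", 1.0)]},
}
R = {
    "dry": {"wait": 0.0, "water": 1.0},
    "watered": {"wait": 0.0, "water": -0.5},
}
```
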
217 | ## Markovian
218 |
219 | ### Markov Property
220 |
221 | Implicit to the MDP formulation is the __Markov Property__:
222 |
223 | The future state $S_{t+1}$ depends only on the current state $S_t$ and action $A_t$,
224 | not on the sequence of events that preceded it.
225 |
226 | ### Practical implication
227 |
228 | Not a direct limitation for practical use, but something to be aware of.
229 |
230 | - E.g., if history matters, include it in the state representation.
231 | - However, this will not make learning easier.
232 |
233 | ## Agent = Policy
234 |
235 | ### Policy
236 |
237 | The agent's behavior is defined by a __policy__ $\pi$.
238 | A policy is a mapping from states to actions: $\pi: \mathcal{S} \rightarrow \mathcal{A}$.
239 |
240 | ### Reminder
241 |
242 | We are trying to optimize the __cumulative reward__ (or __return__) over time:
243 |
244 | $$
245 | G_t = R_{t+1} + R_{t+2} + R_{t+3} + \dots
246 | $$
247 |
248 | In practice, we use a __discount factor__ $\gamma \in [0, 1]$ to prioritize immediate rewards:
249 |
250 | $$
251 | G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots
252 | $$
253 | $$
254 | G_t = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1}
255 | $$
256 |
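As a quick sanity check of the notation, the discounted return can be computed directly from a finite reward sequence. A minimal sketch (the reward values are made up):

```python
# Discounted return G_t = sum_k gamma^k * R_{t+k+1}, accumulated backwards.
def discounted_return(rewards, gamma=0.9):
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

print(discounted_return([0.0, 2.0, 0.0, 2.0, 8.0]))  # rewards received after time t
```
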
257 | ## Learning
258 |
259 | ### Bellman Equation
260 |
261 | This is probably the __most fundamental equation in RL__.
262 | It estimates $v_{\pi}(s)$, known as the __state value function__ when using policy $\pi$:
263 |
264 | $$v_{\pi}(s) = \mathbb{E}_{\pi} [G_t | S_t = s]$$
265 | $$ = \mathbb{E}_{\pi} [R_{t+1} + \gamma v_{\pi}(S_{t+1}) | S_t = s]$$
266 | $$ = \sum_{a} \pi(a|s) \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi}(s')]$$
267 |
268 | ### Fundamental optimization goal
269 |
270 | So, we can formulate the problem to find an optimal policy $\pi^*$ as an optimization problem:
271 |
272 | $$\pi^* = \arg\max_{\pi} v_{\pi}(s), \quad \forall s \in \mathcal{S}$$
273 |
274 | # Methods
275 |
276 | ## Temporal Differencing
277 |
278 | The Bellman equation gives rise to __temporal differencing (TD)__ for training a policy.
279 |
280 | $$v_{\pi}(S_t) \leftarrow (1 - \alpha) v_{\pi}(S_t) + \alpha (R_{t+1} + \gamma v_{\pi}(S_{t+1}))$$
281 |
282 | where
283 |
284 | - $v_{\pi}(S_t)$ is the expected value of state $S_t$
285 | - $R_{t+1} + \gamma v_{\pi}(S_{t+1})$ is the actual reward obtained at $S_t$ plus the expected value of the next state $S_{t+1}$
286 | - $R_{t+1} + \gamma v_{\pi}(S_{t+1}) - v_{\pi}(S_t)$ is the __TD error__.
287 | - $\alpha$ is the __learning rate__.
288 |
289 | \small (a variant using the __state-action value function__ $Q_{\pi}(s, a)$
290 | \small is known as __Q-learning__.)
291 |
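A minimal sketch of the TD(0) update above, assuming a dictionary-backed value table (illustrative only, not the workshop code):

```python
# One TD(0) update: V[s] <- (1 - alpha) * V[s] + alpha * (r_next + gamma * V[s_next])
def td0_update(V, s, r_next, s_next, alpha=0.1, gamma=0.99):
    td_error = r_next + gamma * V[s_next] - V[s]
    V[s] += alpha * td_error
    return td_error
```
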
292 | ## Tabular Reinforcement Learning
293 |
294 | RL began with known MDPs + discrete states/actions, so $v_{\pi}(s)$ or $q_{\pi}(s,a)$ are __tables__.
295 |
296 | ![](media/grid-world.png){width=150px}
297 |
298 | Can use __dynamic programming__ to iterate through the entire environment and converge on an optimal policy.
299 |
300 | ::: columns
301 |
302 | :::: column
303 |
304 | ![](media/value-iteration.png){width=80px}
305 |
306 | ::::
307 |
308 | :::: column
309 |
310 | ![](media/policy-iteration.png){width=80px}
311 |
312 | ::::
313 |
314 | :::
315 |
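For reference, a minimal value-iteration sketch for a small, fully known MDP, using `P` and `R` tables shaped like the toy MDP sketch earlier (illustrative only):

```python
# Value iteration: repeatedly apply the Bellman optimality backup until convergence.
# P[s][a] -> list of (next_state, probability); R[s][a] -> immediate reward.
def value_iteration(P, R, gamma=0.95, tol=1e-6):
    V = {s: 0.0 for s in P}
    while True:
        delta = 0.0
        for s in P:
            best = max(
                R[s][a] + gamma * sum(p * V[s2] for s2, p in P[s][a]) for a in P[s]
            )
            delta = max(delta, abs(best - V[s]))
            V[s] = best
        if delta < tol:
            return V
```
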
316 | ## Model-Free Reinforcement Learning
317 |
318 | If the state-action space is too large, need to perform __rollouts__ to gain experience.
319 |
320 | Key: Balancing __exploitation__ and __exploration__!
321 |
322 | ![](media/model-free-rl.png){width=430px}
323 |
324 | ## Deep Reinforcement Learning
325 |
326 | When the observation space is too large (or worse, continuous), tabular methods no longer work.
327 |
328 | Need a different function approximator -- _...why not a neural network?_
329 |
330 | ![](media/dqn.png){width=250px}
331 |
332 | __Off-policy__: Can train on old experiences from a _replay buffer_.
333 |
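At its core, a replay buffer is just a bounded store of past transitions sampled at random. A minimal sketch (the actual SB3 buffers are more involved):

```python
import random
from collections import deque

# Fixed-size buffer of (obs, action, reward, next_obs, done) transitions.
class ReplayBuffer:
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def add(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
```
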
334 | ## Actor-Critic / Policy Gradient Methods
335 |
336 | DQN only works for discrete actions, so what about continuous actions?
337 |
338 | ::: columns
339 |
340 | :::: {.column width=60%}
341 |
342 | - __Critic__ approximates value function. Trained via TD learning.
343 |
344 | - __Actor__ outputs actions (i.e., the policy). Trained via __policy gradient__, backpropagated from critic loss.
345 |
346 | ---
347 |
348 | - Initial methods were __on-policy__ -- can only train on the latest version of the policy with current experiences.
349 | Example: Proximal Policy Optimization (PPO) ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)).
350 |
351 | - Other approaches train actor and critic at different time scales to allow off-policy.
352 | Example: Soft Actor-Critic (SAC) ([Haarnoja et al., 2018](https://arxiv.org/abs/1801.01290)).
353 |
354 | ::::
355 |
356 | :::: {.column width=40%}
357 |
358 | ![](media/actor-critic.png){width=160px}
359 |
360 | ::::
361 | :::
362 |
363 | ## Concept Comparison
364 |
365 | ::: columns
366 | :::: {.column width=45%}
367 |
368 | ### Exploitation vs. Exploration
369 |
370 | __Exploitation__: Based on the current policy, select the action that maximizes expected reward.
371 |
372 | __Exploration__: Select actions that may not maximize immediate reward, but could lead to better long-term outcomes.
373 |
374 | ### On-policy vs. Off-policy
375 |
376 | __On-policy__: Learn the value of the policy being carried out by the agent.
377 |
378 | __Off-policy__: Learn the value of an optimal policy independently of the agent's actions.
379 | ::::
380 |
381 | :::: {.column width=55%}
382 | Benefits of __on-policy__ methods:
383 |
384 | $-$ Collect data that is relevant under the current policy.
385 |
386 | $-$ More stable learning.
387 |
388 | ---
389 |
390 | Benefits of __off-policy__ methods:
391 |
392 | $-$ Better sample efficiency.
393 |
394 | $-$ Relevant for real-world robotics (gathering data is expensive).
395 |
396 | ::::
397 |
398 | :::
399 |
400 | ## Available Algorithms
401 |
402 | ### DQN
403 |
404 | __Deep Q Network__
405 |
406 | ::: columns
407 | :::: {.column width=50%}
408 | \small
409 | Learns a Q-function $Q(s, a)$. Introduced _experience replay_ and _target networks_.
410 | ::::
411 | :::: {.column width=25%}
412 | \small
413 | Off-policy
414 | Discrete actions
415 | ::::
416 | :::: {.column width=25%}
417 | \small
418 | [Mnih et al., 2013](https://arxiv.org/abs/1312.5602)
419 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html)
420 | ::::
421 | :::
422 |
423 | ### A2C
424 |
425 | __Advantage Actor-Critic__
426 |
427 | ::: columns
428 | :::: {.column width=50%}
429 | \small
430 | $A(s, a) = Q(s, a) - V(s)$ _advantage function_ to reduce variance.
431 | ::::
432 | :::: {.column width=25%}
433 | \small
434 | On-policy
435 | Any action space
436 | ::::
437 | :::: {.column width=25%}
438 | \small
439 | [Mnih et al., 2016](https://arxiv.org/abs/1602.01783)
440 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html)
441 | ::::
442 | :::
443 |
444 | ### PPO
445 |
446 | __Proximal Policy Optimization__
447 |
448 | ::: columns
449 | :::: {.column width=50%}
450 | \small
451 | Optimize policy directly. Uses a _clipped surrogate objective_ for stability.
452 | ::::
453 | :::: {.column width=25%}
454 | \small
455 | On-policy
456 | Any action space
457 | ::::
458 | :::: {.column width=25%}
459 | \small
460 | [Schulman et al., 2017](https://arxiv.org/abs/1707.06347) [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)
461 | ::::
462 | :::
463 |
464 | ### SAC
465 |
466 | __Soft Actor-Critic__
467 |
468 | ::: columns
469 | :::: {.column width=50%}
470 | \small
471 | Separate Actor & Critic NNs. Exploration by additional _entropy_ term.
472 | ::::
473 | :::: {.column width=25%}
474 | \small
475 | Off-policy
476 | Cont. actions
477 | ::::
478 | :::: {.column width=25%}
479 | \small
480 | [Haarnoja et al., 2018](https://arxiv.org/abs/1801.01290)
481 | [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/sac.html)
482 | ::::
483 | :::
484 |
485 | ## Which Algorithm to Choose?
486 |
487 | From the lead developer of SB3 (Antonin Raffin):
488 |
489 | {width=300px}
490 |
491 | # Exercises
492 |
493 | ## Exercise 2: Run with a Random Agent
494 |
495 | ```plain
496 | pixi run start_world --env GreenhousePlain
497 |
498 | pixi run eval --realtime \
499 | --config greenhouse_env_config.yaml --model \
500 | pyrobosim_ros_gym/policies/GreenhousePlain_DQN_random.pt
501 | ```
502 |
503 | {height=150px}
504 |
505 | ## Exercise 3: Training Your First Agent
506 |
507 | Start the world.
508 |
509 | ```plain
510 | pixi run start_world --env GreenhousePlain
511 | ```
512 |
513 | Kick off training.
514 |
515 | ```plain
516 | pixi run train --config greenhouse_env_config.yaml \
517 | --env GreenhousePlain --algorithm DQN \
518 | --discrete-actions --realtime
519 | ```
520 |
521 | The `--config` argument points to `pyrobosim_ros_gym/config/greenhouse_env_config.yaml`, which lets you easily configure different algorithms and training parameters.
522 |
523 | ## Exercise 3: Training Your First Agent (For Real...)
524 |
525 | ... this is going to take a while.
526 | Let's speed things up.
527 |
528 | Run the simulation headless, i.e., without the GUI.
529 |
530 | ```plain
531 | pixi run start_world --env GreenhousePlain --headless
532 | ```
533 |
534 | Without "realtime", run actions as fast as possible
535 |
536 | ```plain
537 | pixi run train --config greenhouse_env_config.yaml \
538 | --env GreenhousePlain --algorithm DQN --discrete-actions
539 | ```
540 |
541 | __NOTE:__ Seeding the training run is important for reproducibility!
542 |
543 | We are running with `--seed 42` by default, but you can change it.
544 |
545 | ## Exercise 3: Visualizing Training Progress
546 |
547 | SB3 has visualization support for [TensorBoard](https://www.tensorflow.org/tensorboard).
548 | Adding the `--log` argument writes a log file to the `train_logs` folder.
549 |
550 | ```plain
551 | pixi run train --config greenhouse_env_config.yaml \
552 | --env GreenhousePlain --algorithm DQN \
553 | --discrete-actions --log
554 | ```
555 |
556 | Launch TensorBoard and open the URL it prints (usually `http://localhost:6006/`).
557 |
558 | ```plain
559 | pixi run tensorboard
560 | ```
561 |
562 | {width=200px}
563 |
564 | ## Exercise 3: Reward Engineering
565 |
566 | Open the file `pyrobosim_ros_gym/config/greenhouse_env_config.yaml`.
567 |
568 | There, you will see 3 options for the `training.reward_fn` parameter.
569 |
570 | Train models with each and compare the effects of __reward shaping__ on results.
571 |
572 | \small
573 | * `sparse_reward` : -5 if evil plant is watered, +8 if _all_ good plants are watered.
574 | * `dense_reward` : -5 if evil plant is watered, +2 for _each_ good plant that is watered.
575 | * `full_reward` : Same as above, but adds small penalties for wasting water / passing over a good plant.
576 |
577 | {height=120px}
578 |
579 | ## Exercise 3: Evaluating Your Trained Agent
580 |
581 | Once you have your trained models, you can evaluate them against the simulator.
582 |
583 | ```plain
584 | pixi run eval --config greenhouse_env_config.yaml \
585 | --model .pt --num-episodes 10
586 | ```
587 |
588 | By default, this will run just like training (as quickly as possible).
589 |
590 | You can add the `--realtime` flag to slow things down to "real-time" so you can visually inspect the results.
591 |
592 | {width=240px}
593 |
594 | ## Exercise 4: Train More Complicated Environments
595 |
596 | \small Training the `GreenhousePlain` environment is easy because the environment is _deterministic_; the plants are always in the same locations.
597 |
598 | For harder environments, you may want to switch algorithms (e.g., `PPO` or `SAC`).
599 |
600 | ::: columns
601 |
602 | :::: column
603 |
604 | {width=120px}
605 |
606 | Plants are now spawned in random locations -- but only one per table.
607 |
608 | ::::
609 |
610 | :::: column
611 |
612 | {width=120px}
613 |
614 | Watering costs 49% battery -- must recharge after watering twice.
615 |
616 | Charging is a new action (id `2`).
617 |
618 | ::::
619 |
620 | :::
621 |
622 | __Challenge__: Evaluate your policy on the `GreenhouseRandom` environment!
623 |
624 | ## Application: Deploying a Trained Policy as a ROS Node
625 |
626 | 1. Start an environment of your choice.
627 |
628 | ```plain
629 | pixi run start_world --env GreenhouseRandom
630 | ```
631 |
632 | 2. Start the node with an appropriate model.
633 |
634 | ```plain
635 | pixi run policy_node --model .pt \
636 | --config greenhouse_env_config.yaml
637 | ```
638 |
639 | 3. Open an interactive shell.
640 |
641 | ```plain
642 | pixi shell
643 | ```
644 |
645 | 4. In the shell, send an action goal to run the policy to completion!
646 |
647 | ```plain
648 | ros2 action send_goal /execute_policy \
649 | rl_interfaces/ExecutePolicy {}
650 | ```
651 |
652 | # Discussion
653 |
654 | ## When to use RL?
655 |
656 | Arguably, our simple greenhouse problem did not need RL.
657 |
658 | ... but it was nice and educational... right?
659 |
660 | ### General rules
661 |
662 | - If easy to model, __engineer it by hand__ (e.g., controllers, behavior trees).
663 | - If difficult to model, but you can provide the answer (e.g., labels or demonstrations), consider __supervised learning__.
664 | - If difficult to model, and you cannot easily provide an answer, consider __reinforcement learning__.
665 |
666 | ## Scaling up Learning
667 |
668 | ### Parallel simulation
669 |
670 | ::: columns
671 |
672 | :::: column
673 |
674 | - Simulations can be parallelized using multiple CPUs / GPUs.
675 |
676 | - SB3 defaults to using __vectorized environments__ (see the sketch below).
677 |
678 | - Other tools for parallel RL include [NVIDIA Isaac Lab](https://github.com/isaac-sim/IsaacLab) and [mjlab](https://github.com/mujocolab/mjlab).
679 |
680 | ::::
681 |
682 | :::: column
683 |
684 | ![](media/rsl-parallel-sim.png){height=90px}
685 |
686 | ::::
687 |
688 | :::
689 |
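For a standard Gymnasium environment, vectorized training in SB3 can look like the sketch below. This is illustrative only; the workshop environments talk to a single running simulator, so they are not parallelized this way.

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Run 4 copies of a lightweight environment in parallel and train PPO on them.
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10_000)
```
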
690 | ## Curriculum learning
691 |
692 | ::: columns
693 |
694 | :::: column
695 |
696 | - __Reward shaping__ by itself is very important to speed up learning.
697 |
698 | - Rather than solving the hardest problem from scratch, introduce a __curriculum__ of progressively harder tasks.
699 |
700 | ::::
701 |
702 | :::: column
703 |
704 | ![](media/curriculum-learning.png){height=100px}
705 |
706 | ::::
707 |
708 | :::
709 |
710 | ## RL Experimentation
711 |
712 | ### RNG Seeds
713 |
714 | ::: columns
715 |
716 | :::: column
717 |
718 | - It is possible to get a "lucky" (or "unlucky") seed when training.
719 |
720 | - Best (and expected) practice is to run multiple experiments and report intervals.
721 |
722 | ::::
723 |
724 | :::: column
725 |
726 | ![](media/ppo-graph.png){height=90px}
727 |
728 | ::::
729 |
730 | :::
731 |
732 | ## Hyperparameter Tuning
733 |
734 | ::: columns
735 |
736 | :::: column
737 |
738 | ### Hyperparameter
739 |
740 | - Any user-specified parameter for ML training.
741 |
742 | - e.g., learning rate, batch/network size, reward weights, ...
743 |
744 | - Consider using automated tools (e.g., [Optuna](https://github.com/optuna/optuna)) to help you tune hyperparameters (see the sketch below).
745 |
746 | ::::
747 |
748 | :::: column
749 |
750 | ![](media/optuna-dashboard.png){height=90px}
751 |
752 | ::::
753 |
754 | :::
755 |
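A minimal Optuna sketch, assuming a hypothetical `train_and_evaluate(params)` helper (not part of this repo) that trains briefly and returns the mean episode reward:

```python
import optuna

def train_and_evaluate(params):
    # Hypothetical helper: train for a short budget using `params`,
    # then return the mean episode reward from evaluation.
    return 0.0  # placeholder

def objective(trial):
    params = {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        "gamma": trial.suggest_float("gamma", 0.90, 0.999),
    }
    return train_and_evaluate(params)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```
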
756 | ## Deploying Policies to ROS
757 |
758 | ::: columns
759 |
760 | :::: {.column width=40%}
761 |
762 | ### Python
763 |
764 | - Can directly put PyTorch / TensorFlow / etc. models in a `rclpy` node (see the sketch below).
765 | - Be careful with threading and CPU/GPU synchronization issues!
766 |
767 | ### C++
768 |
769 | - If you need performance, consider using C++ for inference.
770 | - Facilitated by tools like [ONNX Runtime](https://onnxruntime.ai/inference).
771 | - Can also put your policy inside a `ros2_control` controller for real-time capabilities.
772 |
773 | ::::
774 |
775 | :::: {.column width=60%}
776 |
777 | ![](media/pytorch-threading-question.png){height=120px}
778 |
779 | ![](media/onnx-runtime.png){height=100px}
780 |
781 | ::::
782 |
783 | :::
784 |
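A minimal sketch of the Python route: running inference from a (hypothetical) TorchScript export of a policy inside an `rclpy` node, with made-up topic names. The repo's actual `policy_node.py` instead exposes the policy through the `ExecutePolicy` action used in the earlier exercise.

```python
import rclpy
from rclpy.node import Node
from std_msgs.msg import Float32MultiArray
import torch

class PolicyNode(Node):
    def __init__(self):
        super().__init__("policy_node")
        # TorchScript keeps inference self-contained (no Python class definitions needed).
        self.policy = torch.jit.load("policy.pt")
        self.policy.eval()
        self.sub = self.create_subscription(
            Float32MultiArray, "observation", self.on_observation, 10
        )
        self.pub = self.create_publisher(Float32MultiArray, "action", 10)

    def on_observation(self, msg):
        with torch.no_grad():  # avoid building autograd graphs in the callback
            obs = torch.tensor(msg.data, dtype=torch.float32).unsqueeze(0)
            action = self.policy(obs).argmax(dim=-1)
        out = Float32MultiArray()
        out.data = [float(action.item())]
        self.pub.publish(out)

def main():
    rclpy.init()
    rclpy.spin(PolicyNode())

if __name__ == "__main__":
    main()
```
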
785 | ## RL for Deliberation
786 |
787 | ### Background
788 |
789 | __State-of-the-art RL works for fast, low-level control policies (e.g., locomotion)__
790 |
791 | - Requires sim-to-real training because on-robot RL is hard and/or unsafe.
792 | - Alternatives: fine-tune pretrained policies or train _residual_ policies.
793 |
794 | ### Deliberation
795 |
796 | __How does this change for deliberation applications?__
797 |
798 | - Facilitates on-robot RL: train high-level decision making, add a safety layer below.
799 | - Hierarchical RL dates back to the __options framework__ ([Sutton et al., 1998](http://incompleteideas.net/609%20dropbox/other%20readings%20and%20resources/Options.pdf)).
800 | - What kinds of high-level decisions can/should be learned?
801 | - What should this "safety layer" below look like?
802 |
803 | ## Further Resources
804 |
805 | ### RL Theory
806 |
807 | - \small Sutton & Barto Textbook: <http://incompleteideas.net/book/RLbook2020.pdf>
808 | - \small David Silver Lectures: <https://www.davidsilver.uk/teaching/>
809 | - \small Stable Baselines3 docs: <https://stable-baselines3.readthedocs.io/>
810 |
811 | ### ROS Deliberation
812 |
813 | <https://github.com/ros-wg-delib>
814 |
815 | Join our mailing list and ~monthly meetings!
816 |
817 | {height=100px}
818 |
--------------------------------------------------------------------------------