├── amp_rsl_rl
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   └── amp_ppo.py
│   ├── storage
│   │   ├── __init__.py
│   │   └── replay_buffer.py
│   ├── runners
│   │   ├── __init__.py
│   │   └── amp_on_policy_runner.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── ac_moe.py
│   │   └── discriminator.py
│   └── utils
│       ├── __init__.py
│       ├── exporter.py
│       └── motion_loader.py
├── .github
│   └── workflows
│       └── publish-pypi.yml
├── LICENSE
├── pyproject.toml
├── benchmarking
│   ├── benchmark_replay_buffer.py
│   └── benchmark_download_and_loader.py
├── example
│   └── load_amp_dataset.py
├── README.md
└── .gitignore
/amp_rsl_rl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | 
7 | """Main module for the amp_rsl_rl package."""
8 | 
--------------------------------------------------------------------------------
/amp_rsl_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | """Implementation of different RL agents using AMP."""
7 | 
8 | from .amp_ppo import AMP_PPO
9 | 
10 | __all__ = ["AMP_PPO"]
11 | 
--------------------------------------------------------------------------------
/amp_rsl_rl/storage/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | 
7 | """Implementation of replay buffer for storing and sampling data."""
8 | 
9 | from .replay_buffer import ReplayBuffer
10 | 
11 | __all__ = ["ReplayBuffer"]
12 | 
--------------------------------------------------------------------------------
/amp_rsl_rl/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | """Implementation of runners for environment-agent interaction."""
7 | 
8 | from .amp_on_policy_runner import AMPOnPolicyRunner
9 | 
10 | __all__ = ["AMPOnPolicyRunner"]
11 | 
--------------------------------------------------------------------------------
/amp_rsl_rl/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | 
7 | """Implementation of the network for the AMP algorithm."""
8 | 
9 | from .discriminator import Discriminator
10 | from .ac_moe import ActorMoE, ActorCriticMoE
11 | 
12 | __all__ = ["Discriminator", "ActorCriticMoE", "ActorMoE"]
13 | 
--------------------------------------------------------------------------------
/amp_rsl_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | 7 | """Utilities for amp""" 8 | 9 | from .motion_loader import AMPLoader, download_amp_dataset_from_hf 10 | from .exporter import export_policy_as_onnx 11 | 12 | __all__ = [ 13 | "AMPLoader", 14 | "download_amp_dataset_from_hf", 15 | "export_policy_as_onnx", 16 | ] 17 | -------------------------------------------------------------------------------- /.github/workflows/publish-pypi.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Publish to PyPI on release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: 🐍 Set up Python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.x' 17 | 18 | - name: 📦 Install build dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install build 22 | 23 | - name: 🏗️ Build package 24 | run: python -m build 25 | 26 | - name: 🚀 Publish to PyPI 27 | uses: pypa/gh-action-pypi-publish@release/v1 28 | with: 29 | password: ${{ secrets.PYPI_API_TOKEN }} 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, Artificial and Mechanical Intelligence 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=61.0", 4 | "setuptools_scm>=7.0", 5 | "wheel" # Added for better binary distribution support 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [project] 10 | name = "amp-rsl-rl" 11 | description = "Adversarial Motion Prior (AMP) reinforcement learning extension for PPO based on RSL-RL." 
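# NOTE: the version field is intentionally absent here; it is declared dynamic
# below and derived from git tags at build time via setuptools_scm.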
12 | authors = [ 13 | { name = "Giulio Romualdi", email = "giulio.romualdi@iit.it" }, 14 | { name = "Giuseppe L'Erario", email = "giuseppe.lerario@iit.it" } 15 | ] 16 | maintainers = [ # Added separate maintainers section 17 | { name = "Giulio Romualdi", email = "giulio.romualdi@iit.it" }, 18 | { name = "Giuseppe L'Erario", email = "giuseppe.lerario@iit.it" } 19 | ] 20 | license = "BSD-3-Clause" 21 | readme = "README.md" 22 | requires-python = ">=3.8" 23 | classifiers = [ # Added standard PyPI classifiers 24 | "Development Status :: 4 - Beta", 25 | "Intended Audience :: Science/Research", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 31 | ] 32 | keywords = ["reinforcement-learning", "robotics", "motion-priors", "ppo"] 33 | dependencies = [ 34 | "numpy>=1.21.0", 35 | "scipy>=1.7.0", 36 | "rsl-rl-lib>=3.0.0", 37 | "torch>=2.6.0", 38 | "tensordict>=0.7.0", 39 | ] 40 | dynamic = ["version"] 41 | 42 | [project.optional-dependencies] 43 | examples = ["huggingface_hub"] 44 | 45 | [project.urls] 46 | Homepage = "https://github.com/ami-iit/amp-rsl-rl" 47 | Repository = "https://github.com/ami-iit/amp-rsl-rl" 48 | BugTracker = "https://github.com/ami-iit/amp-rsl-rl/issues" 49 | Changelog = "https://github.com/ami-iit/amp-rsl-rl/releases" 50 | 51 | [tool.setuptools_scm] 52 | local_scheme = "dirty-tag" 53 | 54 | [tool.setuptools] 55 | packages = ["amp_rsl_rl"] 56 | -------------------------------------------------------------------------------- /benchmarking/benchmark_replay_buffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # benchmark_replay_buffer.py 3 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 4 | # All rights reserved. 5 | # 6 | # SPDX-License-Identifier: BSD-3-Clause 7 | 8 | import time 9 | import torch 10 | from amp_rsl_rl.storage.replay_buffer import ReplayBuffer 11 | 12 | # ============================================= 13 | # CONFIGURATION 14 | # ============================================= 15 | device_str = "cuda" if torch.cuda.is_available() else "cpu" 16 | obs_dim = 60 # dimension of each state vector 17 | buffer_size = 200_000 # capacity of the circular buffer 18 | insert_batch = 4096 # how many transitions per insert() call 19 | num_inserts = 50 # how many insert() calls to benchmark 20 | mini_batch_size = 1024 # size of each sampled mini-batch 21 | num_mini_batches = 20 # how many mini-batches to sample 22 | 23 | 24 | def main(): 25 | device = torch.device(device_str) 26 | print(f"\n[ReplayBuffer Benchmark] Device: {device}\n") 27 | 28 | # 1) Initialize buffer 29 | buf = ReplayBuffer(obs_dim, buffer_size, device) 30 | 31 | # 2) Prepare dummy data 32 | dummy_states = torch.randn(insert_batch, obs_dim, device=device) 33 | dummy_next = torch.randn(insert_batch, obs_dim, device=device) 34 | 35 | # Warm up (GPU kernels, caches, etc.) 
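    # (A few untimed insert/sample rounds let CUDA initialize its kernels and the
    # caching allocator reach steady state, so the timed sections below exclude
    # one-off setup costs.)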
36 |     for _ in range(5):
37 |         buf.insert(dummy_states, dummy_next)
38 |         for _ in buf.feed_forward_generator(1, mini_batch_size):
39 |             pass
40 | 
41 |     # 3) Benchmark insert()
42 |     if torch.cuda.is_available():
43 |         torch.cuda.synchronize()
44 |     t0 = time.perf_counter()
45 |     for _ in range(num_inserts):
46 |         buf.insert(dummy_states, dummy_next)
47 |     if torch.cuda.is_available():
48 |         torch.cuda.synchronize()
49 |     t1 = time.perf_counter()
50 | 
51 |     total_inserted = insert_batch * num_inserts
52 |     insert_rate = total_inserted / (t1 - t0)
53 |     print(
54 |         f"[insert] Inserted {total_inserted} samples in {(t1 - t0)*1e3:.1f} ms → "
55 |         f"{insert_rate:,.0f} samples/s"
56 |     )
57 | 
58 |     # Ensure there's enough data to sample
59 |     assert len(buf) >= mini_batch_size * num_mini_batches, (
60 |         f"Need at least {mini_batch_size * num_mini_batches} samples, "
61 |         f"but buffer has only {len(buf)}"
62 |     )
63 | 
64 |     # 4) Benchmark sampling
65 |     if torch.cuda.is_available(): torch.cuda.synchronize()  # no-op on CPU-only runs
66 |     t2 = time.perf_counter()
67 |     sampled = 0
68 |     for states, next_states in buf.feed_forward_generator(
69 |         num_mini_batches, mini_batch_size
70 |     ):
71 |         sampled += states.size(0)
72 |     if torch.cuda.is_available(): torch.cuda.synchronize()
73 |     t3 = time.perf_counter()
74 | 
75 |     sample_rate = sampled / (t3 - t2)
76 |     print(
77 |         f"[sample] Sampled {sampled} samples in {(t3 - t2)*1e3:.1f} ms → "
78 |         f"{sample_rate:,.0f} samples/s"
79 |     )
80 | 
81 |     # 5) Combined insert + sample
82 |     if torch.cuda.is_available(): torch.cuda.synchronize()
83 |     t4 = time.perf_counter()
84 |     ops = 0
85 |     for _ in range(num_inserts):
86 |         buf.insert(dummy_states, dummy_next)
87 |         for states, _ in buf.feed_forward_generator(1, mini_batch_size):
88 |             ops += states.size(0)
89 |     if torch.cuda.is_available(): torch.cuda.synchronize()
90 |     t5 = time.perf_counter()
91 | 
92 |     combined_rate = (total_inserted + ops) / (t5 - t4)
93 |     print(
94 |         f"[combined] insert+sample of {total_inserted + ops} ops in {(t5 - t4)*1e3:.1f} ms → "
95 |         f"{combined_rate:,.0f} ops/s\n"
96 |     )
97 | 
98 | 
99 | if __name__ == "__main__":
100 |     main()
101 | 
--------------------------------------------------------------------------------
/example/load_amp_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia
2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | # Import required libraries 7 | from pathlib import Path # For path manipulation 8 | import tempfile # For creating temporary directories 9 | from amp_rsl_rl.utils import ( 10 | AMPLoader, 11 | download_amp_dataset_from_hf, 12 | ) # Core AMP utilities 13 | import torch # PyTorch for tensor operations 14 | 15 | # ============================================= 16 | # CONFIGURATION SECTION 17 | # ============================================= 18 | # Define the dataset source and files to download 19 | repo_id = "ami-iit/amp-dataset" # Hugging Face repository ID 20 | robot_folder = "ergocub" # Subfolder containing robot-specific datasets 21 | 22 | # List of motion dataset files to download 23 | files = [ 24 | "ergocub_stand_still.npy", # Standing still motion 25 | "ergocub_walk_left0.npy", # Walking left motion 26 | "ergocub_walk.npy", # Straight walking motion 27 | "ergocub_walk_right2.npy", # Walking right motion 28 | ] 29 | 30 | # ============================================= 31 | # DATASET DOWNLOAD AND LOADING 32 | # ============================================= 33 | # Create a temporary directory to store downloaded datasets 34 | # This ensures clean up after we're done 35 | with tempfile.TemporaryDirectory() as tmpdirname: 36 | local_dir = Path(tmpdirname) # Convert to Path object for easier handling 37 | 38 | # Download datasets from Hugging Face Hub 39 | # Returns the base names of the downloaded files (without .npy extension) 40 | dataset_names = download_amp_dataset_from_hf( 41 | local_dir, # Where to save the files 42 | robot_folder=robot_folder, # Which robot dataset to use 43 | files=files, # Which specific motion files to download 44 | repo_id=repo_id, # Repository ID on Hugging Face Hub 45 | ) 46 | 47 | # ============================================= 48 | # DATASET PROCESSING WITH AMPLoader 49 | # ============================================= 50 | # Initialize the AMPLoader to process and manage the motion data 51 | loader = AMPLoader( 52 | device="cpu", # Use CPU for processing (change to "cuda" for GPU) 53 | dataset_path_root=local_dir, # Path to downloaded datasets 54 | dataset_names=dataset_names, # Names of the loaded datasets 55 | dataset_weights=[1.0] * len(dataset_names), # Equal weights for all motions 56 | simulation_dt=1 / 60.0, # Simulation timestep (60Hz) 57 | slow_down_factor=1, # Don't slow down the motions 58 | expected_joint_names=None, # Use default joint ordering 59 | ) 60 | 61 | # ============================================= 62 | # EXAMPLE USAGE 63 | # ============================================= 64 | # Get the first motion sequence from the loader 65 | motion = loader.motion_data[0] 66 | 67 | # Print basic information about the loaded motion 68 | print("Loaded dataset with", len(motion), "frames.") 69 | 70 | # Get and print a sample observation (first frame) 71 | sample_obs = motion.get_amp_dataset_obs(torch.tensor([0])) # Get frame 0 72 | print("Sample AMP observation:", sample_obs) 73 | 74 | # The motion data contains: 75 | # - Joint positions and velocities 76 | # - Base linear/angular velocities (local and world frames) 77 | # - Base orientation (quaternion) 78 | 79 | # Typical usage patterns: 80 | # 1. For training: Use loader.feed_forward_generator() to get batches 81 | # 2. For reset: Use loader.get_state_for_reset() to initialize robot states 82 | # 3. 
For observation: Use motion.get_amp_dataset_obs() to get specific frames
83 | 
84 | # The temporary directory is automatically deleted when the 'with' block ends
85 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AMP-RSL-RL
2 | 
3 | AMP-RSL-RL is a reinforcement learning library that extends the Proximal Policy Optimization (PPO) implementation of [RSL-RL](https://github.com/leggedrobotics/rsl_rl) to incorporate Adversarial Motion Priors (AMP). This framework enables humanoid agents to learn motor skills from motion capture data using adversarial imitation learning techniques.
4 | 
5 | ---
6 | 
7 | ## 📦 Installation
8 | 
9 | The repository is available on PyPI under the package name **amp-rsl-rl**. You can install it directly using pip:
10 | 
11 | ```bash
12 | pip install amp-rsl-rl
13 | ```
14 | 
15 | Alternatively, if you prefer to clone the repository and install it locally, follow these steps:
16 | 
17 | 1. Clone the repository:
18 |    ```bash
19 |    git clone https://github.com/ami-iit/amp-rsl-rl.git
20 |    cd amp-rsl-rl
21 |    ```
22 | 
23 | 2. Install the package:
24 |    ```bash
25 |    pip install .
26 |    ```
27 | 
28 | For editable/development mode:
29 | 
30 | ```bash
31 | pip install -e .
32 | ```
33 | 
34 | If you want to run the examples, please install with:
35 | 
36 | ```bash
37 | pip install .[examples]
38 | ```
39 | 
40 | The required dependencies include:
41 | 
42 | - `numpy`
43 | - `scipy`
44 | - `torch`
45 | - `rsl-rl-lib`
46 | 
47 | These will be automatically installed via pip.
48 | 
49 | ---
50 | 
51 | ## 📂 Project Structure
52 | 
53 | ```
54 | amp_rsl_rl/
55 | │
56 | ├── algorithms/   # AMP and PPO implementations
57 | ├── networks/     # Neural networks for policy and discriminator
58 | ├── runners/      # Training and evaluation routines
59 | ├── storage/      # Replay buffer for experience collection
60 | ├── utils/        # Dataset loaders and motion tools
61 | ```
62 | 
63 | ---
64 | 
65 | ## 📁 Dataset Structure
66 | 
67 | The AMP-RSL-RL framework expects motion capture datasets in `.npy` format. Each `.npy` file must contain a Python dictionary with the following keys:
68 | 
69 | - **`joints_list`**: `List[str]`
70 |   A list of joint names. These should correspond to the joint order expected by the agent.
71 | 
72 | - **`joint_positions`**: `List[np.ndarray]`
73 |   A list where each element is a NumPy array representing the joint positions at a frame. All arrays should have the same shape `(N,)`, where `N` is the number of joints.
74 | 
75 | - **`root_position`**: `List[np.ndarray]`
76 |   A list of 3D vectors representing the position of the base (root) of the agent in world coordinates for each frame.
77 | 
78 | - **`root_quaternion`**: `List[np.ndarray]`
79 |   A list of unit quaternions in **`xyzw`** format (SciPy convention), representing the base orientation of the agent for each frame.
80 | 
81 | - **`fps`**: `float`
82 |   The number of frames per second in the original dataset. This is used to resample the data to match the simulator's timestep.
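
Because each file is simply a pickled Python dictionary, a quick way to sanity-check a dataset before training is to load it directly with NumPy. The snippet below is an illustrative sketch (the file name is a placeholder):

```python
import numpy as np

# np.save stores the dictionary as a pickled object, so allow_pickle is required on load
data = np.load("ergocub_walk.npy", allow_pickle=True).item()

required = {"joints_list", "joint_positions", "root_position", "root_quaternion", "fps"}
assert required.issubset(data.keys()), f"missing keys: {required - data.keys()}"

print(f"{len(data['joint_positions'])} frames at {data['fps']} fps, "
      f"{len(data['joints_list'])} joints")
```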
83 | 84 | ### Example 85 | 86 | Here’s an example of how the structure might look when loaded in Python: 87 | 88 | ```python 89 | { 90 | "joints_list": ["hip", "knee", "ankle"], 91 | "joint_positions": [np.array([0.1, -0.2, 0.3]), np.array([0.11, -0.21, 0.31]), ...], 92 | "root_position": [np.array([0.0, 0.0, 1.0]), np.array([0.01, 0.0, 1.0]), ...], 93 | "root_quaternion": [np.array([0.0, 0.0, 0.0, 1.0]), np.array([0.0, 0.0, 0.1, 0.99]), ...], 94 | "fps": 120.0 95 | } 96 | ``` 97 | 98 | All lists must have the same number of entries (i.e. one per frame). The dataset should represent smooth motion captured over time. 99 | 100 | --- 101 | 102 | ## 📚 Supported Dataset 103 | 104 | For a ready-to-use motion capture dataset, you can use the [AMP Dataset on Hugging Face](https://huggingface.co/datasets/ami-iit/amp-dataset). This dataset is curated to work seamlessly with the AMP-RSL-RL framework. 105 | 106 | --- 107 | 108 | ## 🧑‍💻 Authors 109 | 110 | - **Giulio Romualdi** – [@GiulioRomualdi](https://github.com/GiulioRomualdi) 111 | - **Giuseppe L'Erario** – [@Giulero](https://github.com/Giulero) 112 | 113 | --- 114 | 115 | ## 📄 License 116 | 117 | BSD 3-Clause License © 2025 Istituto Italiano di Tecnologia 118 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /benchmarking/benchmark_download_and_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """ 7 | benchmark_download_and_loader.py: Download AMP datasets and benchmark loader performance with fixed config. 
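
The script downloads the datasets from Hugging Face and then times three stages:
AMPLoader initialization, AMP observation sampling via feed_forward_generator,
and reset-state sampling via get_state_for_reset.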
8 | """ 9 | 10 | from pathlib import Path 11 | import tempfile 12 | import time 13 | import torch 14 | 15 | from amp_rsl_rl.utils import AMPLoader, download_amp_dataset_from_hf 16 | 17 | # ============================================= 18 | # CONFIGURATION (hard-coded) 19 | # ============================================= 20 | repo_id = "ami-iit/amp-dataset" 21 | robot_folder = "ergocub" 22 | files = [ 23 | "ergocub_stand_still.npy", 24 | "ergocub_walk_left0.npy", 25 | "ergocub_walk.npy", 26 | "ergocub_walk_right2.npy", 27 | ] 28 | dataset_weights = [1.0] * len(files) 29 | device_str = "cuda" # or "cpu" 30 | simulation_dt = 1.0 / 60.0 31 | slow_down_factor = 1 32 | num_samples = 20 33 | batch_size = 32768 34 | 35 | joint_names = [ 36 | "l_ankle_pitch", 37 | "l_ankle_roll", 38 | "l_knee", 39 | "l_hip_yaw", 40 | "l_hip_roll", 41 | "l_hip_pitch", 42 | "r_ankle_pitch", 43 | "r_ankle_roll", 44 | "r_knee", 45 | "r_hip_yaw", 46 | "r_hip_roll", 47 | "r_hip_pitch", 48 | "torso_yaw", 49 | "torso_roll", 50 | "torso_pitch", 51 | ] 52 | 53 | 54 | def main(): 55 | device = torch.device(device_str) 56 | print(f"Using device: {device}") 57 | 58 | # Download into a temporary directory 59 | with tempfile.TemporaryDirectory() as tmpdirname: 60 | local_dir = Path(tmpdirname) 61 | print(f"Downloading {len(files)} files to {local_dir}...") 62 | dataset_names = download_amp_dataset_from_hf( 63 | destination_dir=local_dir, 64 | robot_folder=robot_folder, 65 | files=files, 66 | repo_id=repo_id, 67 | ) 68 | print("Downloaded datasets:", dataset_names) 69 | 70 | # Build datasets dictionary using returned dataset_names (without .npy extension) 71 | datasets = {name: weight for name, weight in zip(dataset_names, dataset_weights)} 72 | 73 | # Initialize loader and measure time 74 | t0 = time.perf_counter() 75 | loader = AMPLoader( 76 | device=device, 77 | dataset_path_root=local_dir, 78 | datasets=datasets, 79 | simulation_dt=simulation_dt, 80 | slow_down_factor=slow_down_factor, 81 | expected_joint_names=joint_names, 82 | ) 83 | if device.type == "cuda": 84 | torch.cuda.synchronize() 85 | t1 = time.perf_counter() 86 | print(f"Loader initialization took {t1 - t0:.3f} seconds") 87 | 88 | # Sampling performance 89 | total_batches = max(1, num_samples // batch_size) 90 | print(f"Sampling {total_batches} batches of size {batch_size}...") 91 | sampler = loader.feed_forward_generator(total_batches, batch_size) 92 | # Warm-up (esp. 
for CUDA) 93 | for _ in range(2): 94 | try: 95 | next(sampler) 96 | except StopIteration: 97 | break 98 | sampler = loader.feed_forward_generator(total_batches, batch_size) 99 | 100 | if device.type == "cuda": 101 | torch.cuda.synchronize() 102 | t2 = time.perf_counter() 103 | frames = 0 104 | for obs, next_obs in sampler: 105 | frames += obs.size(0) 106 | if device.type == "cuda": 107 | torch.cuda.synchronize() 108 | t3 = time.perf_counter() 109 | fps = frames / (t3 - t2) 110 | print( 111 | f"Sampled {frames} frames in {(t3 - t2) * 1000:.3f} milliseconds → {fps:.1f} frames/s" 112 | ) 113 | 114 | # Reset-state sampling performance 115 | print(f"Sampling {total_batches} reset-state batches of size {batch_size}...") 116 | if device.type == "cuda": 117 | torch.cuda.synchronize() 118 | t4 = time.perf_counter() 119 | states = 0 120 | for _ in range(total_batches): 121 | loader.get_state_for_reset(batch_size) 122 | states += batch_size 123 | if device.type == "cuda": 124 | torch.cuda.synchronize() 125 | t5 = time.perf_counter() 126 | sps = states / (t5 - t4) 127 | print( 128 | f"Sampled {states} states in {(t5 - t4) * 1000 :.3f} milliseconds → {sps:.1f} states/s" 129 | ) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /amp_rsl_rl/storage/replay_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | import torch 7 | from typing import Generator, Tuple, Union 8 | 9 | 10 | class ReplayBuffer: 11 | """ 12 | Fixed-size circular buffer to store state and next-state experience tuples. 13 | 14 | Attributes: 15 | states (Tensor): Buffer of current states. 16 | next_states (Tensor): Buffer of next states. 17 | buffer_size (int): Maximum number of elements in the buffer. 18 | device (str or torch.device): Device where tensors are stored. 19 | step (int): Current write index. 20 | num_samples (int): Total number of inserted samples (up to buffer_size). 21 | """ 22 | 23 | def __init__( 24 | self, 25 | obs_dim: int, 26 | buffer_size: int, 27 | device: Union[str, torch.device] = "cpu", 28 | ) -> None: 29 | """ 30 | Initialize a ReplayBuffer object. 31 | 32 | Args: 33 | obs_dim (int): Dimension of the observation space. 34 | buffer_size (int): Maximum number of transitions to store. 35 | device (str or torch.device): Torch device where buffers are allocated ('cpu' or 'cuda'). 36 | """ 37 | self.device = torch.device(device) 38 | self.buffer_size = buffer_size 39 | 40 | # Pre-allocate buffers on the target device 41 | self.states = torch.zeros( 42 | (buffer_size, obs_dim), dtype=torch.float32, device=self.device 43 | ) 44 | self.next_states = torch.zeros( 45 | (buffer_size, obs_dim), dtype=torch.float32, device=self.device 46 | ) 47 | 48 | self.step = 0 49 | self.num_samples = 0 50 | 51 | def insert( 52 | self, 53 | states: torch.Tensor, 54 | next_states: torch.Tensor, 55 | ) -> None: 56 | """ 57 | Add a batch of states and next_states to the buffer. 58 | 59 | Args: 60 | states (Tensor): Batch of current states (batch_size, obs_dim). 61 | next_states (Tensor): Batch of next states (batch_size, obs_dim). 
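
        Example (illustrative sketch; values are arbitrary):
            >>> import torch
            >>> buf = ReplayBuffer(obs_dim=4, buffer_size=8, device="cpu")
            >>> buf.insert(torch.randn(3, 4), torch.randn(3, 4))
            >>> len(buf)
            3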
62 | """ 63 | # Move incoming data to buffer device if necessary 64 | states = states.to(self.device) 65 | next_states = next_states.to(self.device) 66 | 67 | batch_size = states.shape[0] 68 | end = self.step + batch_size 69 | 70 | if end <= self.buffer_size: 71 | self.states[self.step : end] = states 72 | self.next_states[self.step : end] = next_states 73 | else: 74 | # Wrap around 75 | first_part = self.buffer_size - self.step 76 | self.states[self.step :] = states[:first_part] 77 | self.next_states[self.step :] = next_states[:first_part] 78 | remainder = batch_size - first_part 79 | self.states[:remainder] = states[first_part:] 80 | self.next_states[:remainder] = next_states[first_part:] 81 | 82 | # Update pointers 83 | self.step = end % self.buffer_size 84 | self.num_samples = min(self.buffer_size, self.num_samples + batch_size) 85 | 86 | def feed_forward_generator( 87 | self, 88 | num_mini_batch: int, 89 | mini_batch_size: int, 90 | allow_replacement: bool = True, 91 | ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: 92 | """ 93 | Yield `num_mini_batch` mini‑batches of (state, next_state) tuples from the buffer, 94 | each of length `mini_batch_size`. 95 | 96 | If the total number of requested samples is larger than the number of 97 | items currently stored (`len(self)`), the method will 98 | 99 | * raise an error when `allow_replacement=False`; 100 | * silently sample **with replacement** when `allow_replacement=True` 101 | (the default). 102 | 103 | Args 104 | ---- 105 | num_mini_batch : int 106 | mini_batch_size : int 107 | allow_replacement : bool, optional 108 | Whether to allow sampling with replacement when the request 109 | exceeds the number of stored transitions. 110 | """ 111 | total = num_mini_batch * mini_batch_size 112 | 113 | # Sampling with replacement might yield duplicate samples, which can affect training dynamics 114 | if total > self.num_samples: 115 | if not allow_replacement: 116 | raise ValueError( 117 | f"Not enough samples in buffer: requested {total}, " 118 | f"but have {self.num_samples}" 119 | ) 120 | # Permute‑then‑modulo 121 | cycles = (total + self.num_samples - 1) // self.num_samples 122 | big_size = self.num_samples * cycles 123 | big_perm = torch.randperm(big_size, device=self.device) 124 | indices = big_perm[:total] % self.num_samples 125 | else: 126 | # Sample WITHOUT replacement 127 | indices = torch.randperm(self.num_samples, device=self.device)[:total] 128 | 129 | # Yield the mini‑batches 130 | for i in range(num_mini_batch): 131 | batch_idx = indices[i * mini_batch_size : (i + 1) * mini_batch_size] 132 | yield self.states[batch_idx], self.next_states[batch_idx] 133 | 134 | def __len__(self) -> int: 135 | """ 136 | Return the number of valid samples currently stored in the buffer. 137 | """ 138 | return self.num_samples 139 | -------------------------------------------------------------------------------- /amp_rsl_rl/utils/exporter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers. 2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | # Code taken from https://github.com/isaac-sim/IsaacLab/blob/5716d5600a1a0e45345bc01342a70bd81fac7889/source/isaaclab_rl/isaaclab_rl/rsl_rl/exporter.py 7 | 8 | import copy 9 | import os 10 | import torch 11 | from amp_rsl_rl.networks import ActorMoE 12 | 13 | 14 | def export_policy_as_onnx( 15 | actor_critic: object, 16 | path: str, 17 | normalizer: object | None = None, 18 | filename="policy.onnx", 19 | verbose=False, 20 | ): 21 | """Export policy into a Torch ONNX file. 22 | 23 | Args: 24 | actor_critic: The actor-critic torch module. 25 | normalizer: The empirical normalizer module. If None, Identity is used. 26 | path: The path to the saving directory. 27 | filename: The name of exported ONNX file. Defaults to "policy.onnx". 28 | verbose: Whether to print the model summary. Defaults to False. 29 | """ 30 | if not os.path.exists(path): 31 | os.makedirs(path, exist_ok=True) 32 | policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose) 33 | policy_exporter.export(path, filename) 34 | 35 | 36 | """ 37 | Helper Classes - Private. 38 | """ 39 | 40 | 41 | class _TorchPolicyExporter(torch.nn.Module): 42 | """Exporter of actor-critic into JIT file.""" 43 | 44 | def __init__(self, actor_critic, normalizer=None): 45 | super().__init__() 46 | self.actor = copy.deepcopy(actor_critic.actor) 47 | self.is_recurrent = actor_critic.is_recurrent 48 | if self.is_recurrent: 49 | self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) 50 | self.rnn.cpu() 51 | self.register_buffer( 52 | "hidden_state", 53 | torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size), 54 | ) 55 | self.register_buffer( 56 | "cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) 57 | ) 58 | self.forward = self.forward_lstm 59 | self.reset = self.reset_memory 60 | # copy normalizer if exists 61 | if normalizer: 62 | self.normalizer = copy.deepcopy(normalizer) 63 | else: 64 | self.normalizer = torch.nn.Identity() 65 | 66 | def forward_lstm(self, x): 67 | x = self.normalizer(x) 68 | x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) 69 | self.hidden_state[:] = h 70 | self.cell_state[:] = c 71 | x = x.squeeze(0) 72 | return self.actor(x) 73 | 74 | def forward(self, x): 75 | return self.actor(self.normalizer(x)) 76 | 77 | @torch.jit.export 78 | def reset(self): 79 | pass 80 | 81 | def reset_memory(self): 82 | self.hidden_state[:] = 0.0 83 | self.cell_state[:] = 0.0 84 | 85 | def export(self, path, filename): 86 | os.makedirs(path, exist_ok=True) 87 | path = os.path.join(path, filename) 88 | self.to("cpu") 89 | traced_script_module = torch.jit.script(self) 90 | traced_script_module.save(path) 91 | 92 | 93 | class _OnnxPolicyExporter(torch.nn.Module): 94 | """Exporter of actor-critic into ONNX file.""" 95 | 96 | def __init__(self, actor_critic, normalizer=None, verbose=False): 97 | super().__init__() 98 | self.verbose = verbose 99 | self.actor = copy.deepcopy(actor_critic.actor) 100 | self.is_recurrent = actor_critic.is_recurrent 101 | if self.is_recurrent: 102 | self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) 103 | self.rnn.cpu() 104 | self.forward = self.forward_lstm 105 | # copy normalizer if exists 106 | if normalizer: 107 | self.normalizer = copy.deepcopy(normalizer) 108 | else: 109 | self.normalizer = torch.nn.Identity() 110 | 111 | def forward_lstm(self, x_in, h_in, c_in): 112 | x_in = self.normalizer(x_in) 113 | x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in)) 114 | x = x.squeeze(0) 115 | return 
self.actor(x), h, c 116 | 117 | def forward(self, x): 118 | return self.actor(self.normalizer(x)) 119 | 120 | def export(self, path, filename): 121 | self.to("cpu") 122 | if self.is_recurrent: 123 | obs = torch.zeros(1, self.rnn.input_size) 124 | h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) 125 | c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) 126 | actions, h_out, c_out = self(obs, h_in, c_in) 127 | torch.onnx.export( 128 | self, 129 | (obs, h_in, c_in), 130 | os.path.join(path, filename), 131 | export_params=True, 132 | opset_version=11, 133 | verbose=self.verbose, 134 | input_names=["obs", "h_in", "c_in"], 135 | output_names=["actions", "h_out", "c_out"], 136 | dynamic_axes={}, 137 | ) 138 | else: 139 | obs = ( 140 | torch.zeros(1, self.actor.obs_dim) 141 | if isinstance(self.actor, ActorMoE) 142 | else torch.zeros(1, self.actor[0].in_features) 143 | ) 144 | torch.onnx.export( 145 | self, 146 | obs, 147 | os.path.join(path, filename), 148 | export_params=True, 149 | opset_version=11, 150 | verbose=self.verbose, 151 | input_names=["obs"], 152 | output_names=["actions"], 153 | dynamic_axes={}, 154 | ) 155 | -------------------------------------------------------------------------------- /amp_rsl_rl/networks/ac_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.distributions import Normal 11 | from rsl_rl.networks import EmpiricalNormalization 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class MLP_net(nn.Sequential): 16 | def __init__(self, in_dim, hidden_dims, out_dim, act): 17 | layers = [nn.Linear(in_dim, hidden_dims[0]), act] 18 | for i in range(len(hidden_dims)): 19 | if i == len(hidden_dims) - 1: 20 | layers.append(nn.Linear(hidden_dims[i], out_dim)) 21 | else: 22 | layers.extend([nn.Linear(hidden_dims[i], hidden_dims[i + 1]), act]) 23 | super().__init__(*layers) 24 | 25 | 26 | class ActorMoE(nn.Module): 27 | """ 28 | Mixture-of-Experts actor: ⎡expert_1(x) … expert_K(x)⎤·softmax(gate(x)) 29 | """ 30 | 31 | def __init__( 32 | self, 33 | obs_dim: int, 34 | act_dim: int, 35 | hidden_dims, 36 | num_experts: int = 4, 37 | gate_hidden_dims: list[int] | None = None, 38 | activation="elu", 39 | ): 40 | super().__init__() 41 | self.obs_dim = obs_dim 42 | self.act_dim = act_dim 43 | self.num_experts = num_experts 44 | act = resolve_nn_activation(activation) 45 | 46 | # experts 47 | self.experts = nn.ModuleList( 48 | [MLP_net(obs_dim, hidden_dims, act_dim, act) for _ in range(num_experts)] 49 | ) 50 | 51 | # gating network 52 | gate_layers = [] 53 | last_dim = obs_dim 54 | gate_hidden_dims = gate_hidden_dims or [] 55 | for h in gate_hidden_dims: 56 | gate_layers += [nn.Linear(last_dim, h), act] 57 | last_dim = h 58 | gate_layers.append(nn.Linear(last_dim, num_experts)) 59 | self.gate = nn.Sequential(*gate_layers) 60 | self.softmax = nn.Softmax(dim=-1) # kept separate for ONNX clarity 61 | 62 | def forward(self, x: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Args: 65 | x: [batch, obs_dim] 66 | Returns: 67 | mean action: [batch, act_dim] 68 | """ 69 | expert_out = torch.stack([e(x) for e in self.experts], dim=-1) 70 | gate_logits = self.gate(x) # [batch, K] 71 | weights = self.softmax(gate_logits).unsqueeze(1) # [batch, 1, K] 72 | return (expert_out * 
weights).sum(-1) # weighted sum -> [batch, A] 73 | 74 | 75 | class ActorCriticMoE(nn.Module): 76 | """Actor-critic module powered by a Mixture-of-Experts policy network. 77 | 78 | The API mirrors :class:`rsl_rl.modules.ActorCritic` so the class can be 79 | referenced via the standard policy ``class_name`` in configuration files. 80 | Observations are provided as TensorDict (or dict-like) containers and are 81 | grouped via ``obs_groups`` exactly like the upstream implementation. 82 | """ 83 | 84 | is_recurrent = False 85 | 86 | def __init__( 87 | self, 88 | obs, 89 | obs_groups, 90 | num_actions: int, 91 | actor_hidden_dims=[256, 256, 256], 92 | critic_hidden_dims=[256, 256, 256], 93 | num_experts: int = 4, 94 | activation: str = "elu", 95 | init_noise_std: float = 1.0, 96 | noise_std_type: str = "scalar", 97 | actor_obs_normalization: bool = False, 98 | critic_obs_normalization: bool = False, 99 | **kwargs, 100 | ): 101 | if kwargs: 102 | print( 103 | ( 104 | "ActorCriticMoE.__init__ ignored unexpected arguments: " 105 | + str(list(kwargs.keys())) 106 | ) 107 | ) 108 | super().__init__() 109 | 110 | self.obs_groups = obs_groups 111 | 112 | num_actor_obs = 0 113 | for obs_group in obs_groups["policy"]: 114 | assert ( 115 | len(obs[obs_group].shape) == 2 116 | ), "ActorCriticMoE only supports 1D flattened observations." 117 | num_actor_obs += obs[obs_group].shape[-1] 118 | 119 | num_critic_obs = 0 120 | for obs_group in obs_groups["critic"]: 121 | assert ( 122 | len(obs[obs_group].shape) == 2 123 | ), "ActorCriticMoE only supports 1D flattened observations." 124 | num_critic_obs += obs[obs_group].shape[-1] 125 | 126 | act = resolve_nn_activation(activation) 127 | 128 | self.actor = ActorMoE( 129 | obs_dim=num_actor_obs, 130 | act_dim=num_actions, 131 | hidden_dims=actor_hidden_dims, 132 | num_experts=num_experts, 133 | gate_hidden_dims=actor_hidden_dims[:-1], 134 | activation=activation, 135 | ) 136 | self.critic = MLP_net(num_critic_obs, critic_hidden_dims, 1, act) 137 | 138 | self.actor_obs_normalization = actor_obs_normalization 139 | if actor_obs_normalization: 140 | self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) 141 | else: 142 | self.actor_obs_normalizer = nn.Identity() 143 | 144 | self.critic_obs_normalization = critic_obs_normalization 145 | if critic_obs_normalization: 146 | self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) 147 | else: 148 | self.critic_obs_normalizer = nn.Identity() 149 | 150 | self.noise_std_type = noise_std_type 151 | if self.noise_std_type == "scalar": 152 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 153 | elif self.noise_std_type == "log": 154 | self.log_std = nn.Parameter( 155 | torch.log(init_noise_std * torch.ones(num_actions)) 156 | ) 157 | else: 158 | raise ValueError("noise_std_type must be 'scalar' or 'log'") 159 | 160 | self.distribution = None 161 | Normal.set_default_validate_args(False) 162 | 163 | print(f"Actor (MoE) structure:\n{self.actor}") 164 | print(f"Critic MLP structure:\n{self.critic}") 165 | 166 | def reset(self, dones=None): # noqa: D401 167 | pass 168 | 169 | def forward(self): 170 | raise NotImplementedError 171 | 172 | @property 173 | def action_mean(self): 174 | return self.distribution.mean 175 | 176 | @property 177 | def action_std(self): 178 | return self.distribution.stddev 179 | 180 | @property 181 | def entropy(self): 182 | return self.distribution.entropy().sum(dim=-1) 183 | 184 | def update_distribution(self, observations): 185 | mean = self.actor(observations) 186 | 
if self.noise_std_type == "scalar": 187 | std = self.std.expand_as(mean) 188 | else: # "log" 189 | std = torch.exp(self.log_std).expand_as(mean) 190 | self.distribution = Normal(mean, std) 191 | 192 | def act(self, obs, **kwargs): 193 | actor_obs = self.get_actor_obs(obs) 194 | actor_obs = self.actor_obs_normalizer(actor_obs) 195 | self.update_distribution(actor_obs) 196 | return self.distribution.sample() 197 | 198 | def get_actions_log_prob(self, actions): 199 | return self.distribution.log_prob(actions).sum(dim=-1) 200 | 201 | def act_inference(self, obs): 202 | actor_obs = self.get_actor_obs(obs) 203 | actor_obs = self.actor_obs_normalizer(actor_obs) 204 | return self.actor(actor_obs) 205 | 206 | def evaluate(self, obs, **kwargs): 207 | critic_obs = self.get_critic_obs(obs) 208 | critic_obs = self.critic_obs_normalizer(critic_obs) 209 | return self.critic(critic_obs) 210 | 211 | def get_actor_obs(self, obs): 212 | obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]] 213 | return torch.cat(obs_list, dim=-1) 214 | 215 | def get_critic_obs(self, obs): 216 | obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]] 217 | return torch.cat(obs_list, dim=-1) 218 | 219 | def update_normalization(self, obs): 220 | if self.actor_obs_normalization: 221 | actor_obs = self.get_actor_obs(obs) 222 | self.actor_obs_normalizer.update(actor_obs) 223 | if self.critic_obs_normalization: 224 | critic_obs = self.get_critic_obs(obs) 225 | self.critic_obs_normalizer.update(critic_obs) 226 | 227 | # unchanged load_state_dict so checkpoints from the old class still load 228 | def load_state_dict(self, state_dict, strict=True): 229 | super().load_state_dict(state_dict, strict=strict) 230 | return True 231 | -------------------------------------------------------------------------------- /amp_rsl_rl/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import autograd 9 | from torch.nn import functional as F 10 | 11 | from rsl_rl.networks import EmpiricalNormalization 12 | 13 | 14 | class Discriminator(nn.Module): 15 | """Discriminator implements the discriminator network for the AMP algorithm. 16 | 17 | This network is trained to distinguish between expert and policy-generated data. 18 | It also provides reward signals for the policy through adversarial learning. 19 | 20 | Args: 21 | input_dim (int): Dimension of the concatenated input state (state + next state). 22 | hidden_layer_sizes (list): List of hidden layer sizes. 23 | reward_scale (float): Scale factor for the computed reward. 24 | reward_clamp_epsilon (float): Numerical epsilon used when clamping rewards. 25 | device (str | torch.device): Device to run the model on. 26 | loss_type (str): Type of loss function to use ('BCEWithLogits' or 'Wasserstein'). 27 | eta_wgan (float): Scaling factor for the Wasserstein loss (if used). 28 | use_minibatch_std (bool): Whether to use minibatch standard deviation in the network 29 | empirical_normalization (bool): Whether to normalize AMP observations empirically before scoring. 
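
    Example (illustrative sketch; an 8-dimensional AMP observation pair, i.e.
    a state and a next state of 4 dimensions each):
        >>> import torch
        >>> disc = Discriminator(input_dim=8, hidden_layer_sizes=[32, 32], reward_scale=1.0)
        >>> s, s_next = torch.randn(16, 4), torch.randn(16, 4)
        >>> logits = disc(torch.cat([s, s_next], dim=-1))  # shape (16, 1)
        >>> rewards = disc.predict_reward(s, s_next)       # shape (16,)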
30 | """ 31 | 32 | def __init__( 33 | self, 34 | input_dim: int, 35 | hidden_layer_sizes: list[int], 36 | reward_scale: float, 37 | reward_clamp_epsilon: float = 1.0e-4, 38 | device: str | torch.device = "cpu", 39 | loss_type: str = "BCEWithLogits", 40 | eta_wgan: float = 0.3, 41 | use_minibatch_std: bool = True, 42 | empirical_normalization: bool = False, 43 | ): 44 | super().__init__() 45 | 46 | self.device = torch.device(device) 47 | self.input_dim = input_dim 48 | self.reward_scale = reward_scale 49 | self.reward_clamp_epsilon = reward_clamp_epsilon 50 | layers = [] 51 | curr_in_dim = input_dim 52 | 53 | for hidden_dim in hidden_layer_sizes: 54 | layers.append(nn.Linear(curr_in_dim, hidden_dim)) 55 | layers.append(nn.ReLU()) 56 | curr_in_dim = hidden_dim 57 | 58 | self.trunk = nn.Sequential(*layers) 59 | final_in_dim = hidden_layer_sizes[-1] + (1 if use_minibatch_std else 0) 60 | self.linear = nn.Linear(final_in_dim, 1) 61 | 62 | self.empirical_normalization = empirical_normalization 63 | amp_obs_dim = input_dim // 2 64 | if empirical_normalization: 65 | self.amp_normalizer = EmpiricalNormalization(shape=[amp_obs_dim]) 66 | else: 67 | self.amp_normalizer = nn.Identity() 68 | 69 | self.to(self.device) 70 | self.train() 71 | self.use_minibatch_std = use_minibatch_std 72 | self.loss_type = loss_type if loss_type is not None else "BCEWithLogits" 73 | if self.loss_type == "BCEWithLogits": 74 | self.loss_fun = torch.nn.BCEWithLogitsLoss() 75 | elif self.loss_type == "Wasserstein": 76 | self.loss_fun = None 77 | self.eta_wgan = eta_wgan 78 | print("The Wasserstein-like loss is experimental") 79 | else: 80 | raise ValueError( 81 | f"Unsupported loss type: {self.loss_type}. Supported types are 'BCEWithLogits' and 'Wasserstein'." 82 | ) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | """Forward pass through the discriminator. 86 | 87 | Args: 88 | x (Tensor): Input tensor (batch_size, input_dim). 89 | 90 | Returns: 91 | Tensor: Discriminator output logits/scores. 92 | """ 93 | 94 | # Normalize AMP observations. If not enabled the normalizer is identity. 95 | # split state and next_state and apply normalization 96 | state, next_state = torch.split(x, self.input_dim // 2, dim=-1) 97 | state = self.amp_normalizer(state) 98 | next_state = self.amp_normalizer(next_state) 99 | x = torch.cat([state, next_state], dim=-1) 100 | 101 | h = self.trunk(x) 102 | if self.use_minibatch_std: 103 | s = self._minibatch_std_scalar(h) 104 | h = torch.cat([h, s], dim=-1) 105 | return self.linear(h) 106 | 107 | def _minibatch_std_scalar(self, h: torch.Tensor) -> torch.Tensor: 108 | """Mean over feature-wise std across the batch; shape (B,1).""" 109 | if h.shape[0] <= 1: 110 | return h.new_zeros((h.shape[0], 1)) 111 | s = h.float().std(dim=0, unbiased=False).mean() 112 | return s.expand(h.shape[0], 1).to(h.dtype) 113 | 114 | def predict_reward( 115 | self, 116 | state: torch.Tensor, 117 | next_state: torch.Tensor, 118 | ) -> torch.Tensor: 119 | """Predicts reward based on discriminator output using a log-style formulation. 120 | 121 | Args: 122 | state (Tensor): Current state tensor. 123 | next_state (Tensor): Next state tensor. 124 | 125 | Returns: 126 | Tensor: Computed adversarial reward. 
127 | """ 128 | with torch.no_grad(): 129 | 130 | # No need to normalize here as normalization is done in forward() 131 | discriminator_logit = self.forward(torch.cat([state, next_state], dim=-1)) 132 | 133 | if self.loss_type == "Wasserstein": 134 | discriminator_logit = torch.tanh(self.eta_wgan * discriminator_logit) 135 | return self.reward_scale * torch.exp(discriminator_logit).squeeze() 136 | # softplus(logit) == -log(1 - sigmoid(logit)) 137 | reward = F.softplus(discriminator_logit) 138 | reward = self.reward_scale * reward 139 | return reward.squeeze() 140 | 141 | def policy_loss(self, discriminator_output: torch.Tensor) -> torch.Tensor: 142 | """ 143 | Computes the loss for the discriminator when classifying policy-generated transitions. 144 | Uses binary cross-entropy loss where the target label for policy transitions is 0. 145 | 146 | Parameters 147 | ---------- 148 | discriminator_output : torch.Tensor 149 | The raw logits output from the discriminator for policy data. 150 | 151 | Returns 152 | ------- 153 | torch.Tensor 154 | The computed policy loss. 155 | """ 156 | expected = torch.zeros_like(discriminator_output, device=self.device) 157 | return self.loss_fun(discriminator_output, expected) 158 | 159 | def expert_loss(self, discriminator_output: torch.Tensor) -> torch.Tensor: 160 | """ 161 | Computes the loss for the discriminator when classifying expert transitions. 162 | Uses binary cross-entropy loss where the target label for expert transitions is 1. 163 | 164 | Parameters 165 | ---------- 166 | discriminator_output : torch.Tensor 167 | The raw logits output from the discriminator for expert data. 168 | 169 | Returns 170 | ------- 171 | torch.Tensor 172 | The computed expert loss. 173 | """ 174 | expected = torch.ones_like(discriminator_output, device=self.device) 175 | return self.loss_fun(discriminator_output, expected) 176 | 177 | def update_normalization(self, *batches: torch.Tensor) -> None: 178 | """Update empirical statistics using provided AMP batches.""" 179 | if not self.empirical_normalization: 180 | return 181 | with torch.no_grad(): 182 | for batch in batches: 183 | self.amp_normalizer.update(batch) 184 | 185 | def compute_loss( 186 | self, 187 | policy_d, 188 | expert_d, 189 | sample_amp_expert, 190 | sample_amp_policy, 191 | lambda_: float = 10, 192 | ): 193 | 194 | # Compute gradient penalty to stabilize discriminator training. 195 | sample_amp_expert = tuple(self.amp_normalizer(s) for s in sample_amp_expert) 196 | sample_amp_policy = tuple(self.amp_normalizer(s) for s in sample_amp_policy) 197 | grad_pen_loss = self.compute_grad_pen( 198 | expert_states=sample_amp_expert, 199 | policy_states=sample_amp_policy, 200 | lambda_=lambda_, 201 | ) 202 | if self.loss_type == "BCEWithLogits": 203 | expert_loss = self.loss_fun(expert_d, torch.ones_like(expert_d)) 204 | policy_loss = self.loss_fun(policy_d, torch.zeros_like(policy_d)) 205 | # AMP loss is the average of expert and policy losses. 206 | amp_loss = 0.5 * (expert_loss + policy_loss) 207 | elif self.loss_type == "Wasserstein": 208 | amp_loss = self.wgan_loss(policy_d=policy_d, expert_d=expert_d) 209 | return amp_loss, grad_pen_loss 210 | 211 | def compute_grad_pen( 212 | self, 213 | expert_states: tuple[torch.Tensor, torch.Tensor], 214 | policy_states: tuple[torch.Tensor, torch.Tensor], 215 | lambda_: float = 10, 216 | ) -> torch.Tensor: 217 | """Computes the gradient penalty used to regularize the discriminator. 
218 | 219 | Args: 220 | expert_states (tuple[Tensor, Tensor]): A tuple containing batches of expert states and expert next states. 221 | policy_states (tuple[Tensor, Tensor]): A tuple containing batches of policy states and policy next states. 222 | lambda_ (float): Penalty coefficient. 223 | 224 | Returns: 225 | Tensor: Gradient penalty value. 226 | """ 227 | expert = torch.cat(expert_states, -1) 228 | 229 | if self.loss_type == "Wasserstein": 230 | policy = torch.cat(policy_states, -1) 231 | alpha = torch.rand(expert.size(0), 1, device=expert.device) 232 | alpha = alpha.expand_as(expert) 233 | data = alpha * expert + (1 - alpha) * policy 234 | data = data.detach().requires_grad_(True) 235 | h = self.trunk(data) 236 | if self.use_minibatch_std: 237 | with torch.no_grad(): 238 | s = self._minibatch_std_scalar(h) 239 | h = torch.cat([h, s], dim=-1) 240 | scores = self.linear(h) 241 | grad = autograd.grad( 242 | outputs=scores, 243 | inputs=data, 244 | grad_outputs=torch.ones_like(scores), 245 | create_graph=True, 246 | retain_graph=True, 247 | only_inputs=True, 248 | )[0] 249 | return lambda_ * (grad.norm(2, dim=1) - 1.0).pow(2).mean() 250 | elif self.loss_type == "BCEWithLogits": 251 | # R1 regularizer on REAL: 0.5 * lambda * ||∇_x D(x_real)||^2 252 | data = expert.detach().requires_grad_(True) 253 | # Compute D(x_real) with minibatch-std DETACHED, 254 | # so gradients are w.r.t. the sample itself, not the batch statistics. 255 | h = self.trunk(data) 256 | if self.use_minibatch_std: 257 | with torch.no_grad(): 258 | s = self._minibatch_std_scalar(h) 259 | h = torch.cat([h, s], dim=-1) 260 | scores = self.linear(h) 261 | 262 | grad = autograd.grad( 263 | outputs=scores.sum(), 264 | inputs=data, 265 | create_graph=True, 266 | retain_graph=True, 267 | only_inputs=True, 268 | )[0] 269 | return 0.5 * lambda_ * (grad.pow(2).sum(dim=1)).mean() 270 | 271 | else: 272 | raise ValueError( 273 | f"Unsupported loss type: {self.loss_type}. Supported types are 'BCEWithLogits' and 'Wasserstein'." 274 | ) 275 | 276 | def wgan_loss(self, policy_d, expert_d): 277 | """ 278 | This loss function computes a modified Wasserstein loss for the discriminator. 279 | The original Wasserstein loss is D(policy) - D(expert), but here we apply a tanh 280 | transformation to the discriminator outputs scaled by eta_wgan. This helps in stabilizing the training. 281 | Args: 282 | policy_d (Tensor): Discriminator output for policy data. 283 | expert_d (Tensor): Discriminator output for expert data. 284 | """ 285 | policy_d = torch.tanh(self.eta_wgan * policy_d) 286 | expert_d = torch.tanh(self.eta_wgan * expert_d) 287 | return policy_d.mean() - expert_d.mean() 288 | -------------------------------------------------------------------------------- /amp_rsl_rl/utils/motion_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | from pathlib import Path
7 | from typing import List, Union, Tuple, Generator, Dict
8 | from dataclasses import dataclass
9 | 
10 | import torch
11 | import numpy as np
12 | from scipy.spatial.transform import Rotation, Slerp
13 | from scipy.interpolate import interp1d
14 | 
15 | 
16 | def download_amp_dataset_from_hf(
17 |     destination_dir: Path,
18 |     robot_folder: str,
19 |     files: list,
20 |     repo_id: str = "ami-iit/amp-dataset",
21 | ) -> list:
22 |     """
23 |     Downloads AMP dataset files from Hugging Face and saves them to `destination_dir`.
24 |     Ensures real file copies (not symlinks or hard links).
25 | 
26 |     Args:
27 |         destination_dir (Path): Local directory to save the files.
28 |         robot_folder (str): Folder in the Hugging Face dataset repo to pull from.
29 |         files (list): List of filenames to download.
30 |         repo_id (str): Hugging Face repository ID. Default is "ami-iit/amp-dataset".
31 | 
32 |     Returns:
33 |         List[str]: List of dataset names (without .npy extension).
34 |     """
35 |     from huggingface_hub import hf_hub_download
36 | 
37 |     destination_dir.mkdir(parents=True, exist_ok=True)
38 |     dataset_names = []
39 | 
40 |     for file in files:
41 |         file_path = hf_hub_download(
42 |             repo_id=repo_id,
43 |             filename=f"{robot_folder}/{file}",
44 |             repo_type="dataset",
45 |             local_files_only=False,
46 |         )
47 |         local_copy = destination_dir / file
48 |         # Deep copy to avoid symlinks
49 |         with open(file_path, "rb") as src_file, open(local_copy, "wb") as dst_file:
50 |             dst_file.write(src_file.read())
51 |         dataset_names.append(file.replace(".npy", ""))
52 | 
53 |     return dataset_names
54 | 
55 | 
56 | @dataclass
57 | class MotionData:
58 |     """
59 |     Data class representing motion data for humanoid agents.
60 | 
61 |     This class stores joint positions and velocities, base velocities (both in local
62 |     and mixed/world frames), and base orientation (as quaternion). It offers utilities
63 |     for preparing data in AMP-compatible format, as well as for producing environment reset states.
64 | 
65 |     Attributes:
66 |     - joint_positions: shape (T, N)
67 |     - joint_velocities: shape (T, N)
68 |     - base_lin_velocities_mixed: linear velocity in world frame
69 |     - base_ang_velocities_mixed: angular velocity in world (mixed) frame
70 |     - base_lin_velocities_local: linear velocity in local (body) frame
71 |     - base_ang_velocities_local: angular velocity in local (body) frame
72 |     - base_quat: orientation quaternion as torch.Tensor in wxyz order
73 | 
74 |     Notes:
75 |     - The quaternion is expected in the dataset as `xyzw` format (SciPy default),
76 |       and it is converted internally to `wxyz` format to be compatible with IsaacLab conventions.
77 |     - All data is converted to torch.Tensor on the specified device during initialization.
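
    Example (illustrative sketch with all-zero placeholder data; T frames, N joints):

        T, N = 100, 23
        md = MotionData(
            joint_positions=np.zeros((T, N)),
            joint_velocities=np.zeros((T, N)),
            base_lin_velocities_mixed=np.zeros((T, 3)),
            base_ang_velocities_mixed=np.zeros((T, 3)),
            base_lin_velocities_local=np.zeros((T, 3)),
            base_ang_velocities_local=np.zeros((T, 3)),
            base_quat=Rotation.identity(T),
        )
        obs = md.get_amp_dataset_obs(torch.arange(4))  # shape (4, 2 * N + 6)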
78 | """ 79 | 80 | joint_positions: Union[torch.Tensor, np.ndarray] 81 | joint_velocities: Union[torch.Tensor, np.ndarray] 82 | base_lin_velocities_mixed: Union[torch.Tensor, np.ndarray] 83 | base_ang_velocities_mixed: Union[torch.Tensor, np.ndarray] 84 | base_lin_velocities_local: Union[torch.Tensor, np.ndarray] 85 | base_ang_velocities_local: Union[torch.Tensor, np.ndarray] 86 | base_quat: Union[Rotation, torch.Tensor] 87 | device: torch.device = torch.device("cpu") 88 | 89 | def __post_init__(self) -> None: 90 | # Convert numpy arrays (or SciPy Rotations) to torch tensors 91 | def to_tensor(x): 92 | return torch.tensor(x, device=self.device, dtype=torch.float32) 93 | 94 | if isinstance(self.joint_positions, np.ndarray): 95 | self.joint_positions = to_tensor(self.joint_positions) 96 | if isinstance(self.joint_velocities, np.ndarray): 97 | self.joint_velocities = to_tensor(self.joint_velocities) 98 | if isinstance(self.base_lin_velocities_mixed, np.ndarray): 99 | self.base_lin_velocities_mixed = to_tensor(self.base_lin_velocities_mixed) 100 | if isinstance(self.base_ang_velocities_mixed, np.ndarray): 101 | self.base_ang_velocities_mixed = to_tensor(self.base_ang_velocities_mixed) 102 | if isinstance(self.base_lin_velocities_local, np.ndarray): 103 | self.base_lin_velocities_local = to_tensor(self.base_lin_velocities_local) 104 | if isinstance(self.base_ang_velocities_local, np.ndarray): 105 | self.base_ang_velocities_local = to_tensor(self.base_ang_velocities_local) 106 | if isinstance(self.base_quat, Rotation): 107 | quat_xyzw = self.base_quat.as_quat() # (T,4) xyzw 108 | # convert to wxyz 109 | self.base_quat = torch.tensor( 110 | quat_xyzw[:, [3, 0, 1, 2]], 111 | device=self.device, 112 | dtype=torch.float32, 113 | ) 114 | 115 | def __len__(self) -> int: 116 | return self.joint_positions.shape[0] 117 | 118 | def get_amp_dataset_obs(self, indices: torch.Tensor) -> torch.Tensor: 119 | """ 120 | Returns the AMP observation tensor for given indices. 121 | 122 | Args: 123 | indices: indices of samples to retrieve 124 | 125 | Returns: 126 | Concatenated observation tensor 127 | """ 128 | return torch.cat( 129 | ( 130 | self.joint_positions[indices], 131 | self.joint_velocities[indices], 132 | self.base_lin_velocities_local[indices], 133 | self.base_ang_velocities_local[indices], 134 | ), 135 | dim=1, 136 | ) 137 | 138 | def get_state_for_reset(self, indices: torch.Tensor) -> Tuple[torch.Tensor, ...]: 139 | """ 140 | Returns the full state needed for environment reset. 141 | 142 | Args: 143 | indices: indices of samples to retrieve 144 | 145 | Returns: 146 | Tuple of (quat, joint_positions, joint_velocities, base_lin_velocities, base_ang_velocities) 147 | """ 148 | return ( 149 | self.base_quat[indices], 150 | self.joint_positions[indices], 151 | self.joint_velocities[indices], 152 | self.base_lin_velocities_local[indices], 153 | self.base_ang_velocities_local[indices], 154 | ) 155 | 156 | def get_random_sample_for_reset(self, items: int = 1) -> Tuple[torch.Tensor, ...]: 157 | indices = torch.randint(0, len(self), (items,), device=self.device) 158 | return self.get_state_for_reset(indices) 159 | 160 | 161 | class AMPLoader: 162 | """ 163 | Loader and processor for humanoid motion capture datasets in AMP format. 
164 | 165 | Responsibilities: 166 | - Loading .npy files containing motion data 167 | - Building a unified joint ordering across all datasets 168 | - Resampling trajectories to match the simulator's timestep 169 | - Computing derived quantities (velocities, local-frame motion) 170 | - Returning torch-friendly MotionData instances 171 | 172 | Dataset format: 173 | Each .npy contains a dict with keys: 174 | - "joints_list": List[str] 175 | - "joint_positions": List[np.ndarray] 176 | - "root_position": List[np.ndarray] 177 | - "root_quaternion": List[np.ndarray] (xyzw) 178 | - "fps": float (frames/sec) 179 | 180 | Args: 181 | device: Target torch device ('cpu' or 'cuda') 182 | dataset_path_root: Directory containing the .npy motion files 183 | datasets: Dictionary mapping dataset names (without extension) to sampling weights (floats) 184 | simulation_dt: Timestep used by the simulator 185 | slow_down_factor: Integer factor to slow down original data 186 | expected_joint_names: (Optional) override for joint ordering 187 | """ 188 | 189 | def __init__( 190 | self, 191 | device: str, 192 | dataset_path_root: Path, 193 | datasets: Dict[str, float], 194 | simulation_dt: float, 195 | slow_down_factor: int, 196 | expected_joint_names: Union[List[str], None] = None, 197 | ) -> None: 198 | self.device = device 199 | if isinstance(dataset_path_root, str): 200 | dataset_path_root = Path(dataset_path_root) 201 | 202 | # ─── Parse dataset names and weights ─── 203 | dataset_names = list(datasets.keys()) 204 | dataset_weights = list(datasets.values()) 205 | 206 | # ─── Build union of all joint names if not provided ─── 207 | if expected_joint_names is None: 208 | joint_union: List[str] = [] 209 | seen = set() 210 | for name in dataset_names: 211 | p = dataset_path_root / f"{name}.npy" 212 | info = np.load(str(p), allow_pickle=True).item() 213 | for j in info["joints_list"]: 214 | if j not in seen: 215 | seen.add(j) 216 | joint_union.append(j) 217 | expected_joint_names = joint_union 218 | # ───────────────────────────────────────────────────────── 219 | 220 | # Load and process each dataset into MotionData 221 | self.motion_data: List[MotionData] = [] 222 | for dataset_name in dataset_names: 223 | dataset_path = dataset_path_root / f"{dataset_name}.npy" 224 | md = self.load_data( 225 | dataset_path, 226 | simulation_dt, 227 | slow_down_factor, 228 | expected_joint_names, 229 | ) 230 | self.motion_data.append(md) 231 | 232 | # Normalize dataset-level sampling weights 233 | weights = torch.tensor(dataset_weights, dtype=torch.float32, device=self.device) 234 | self.dataset_weights = weights / weights.sum() 235 | 236 | # Precompute flat buffers for fast sampling 237 | obs_list, next_obs_list, reset_states = [], [], [] 238 | for data, w in zip(self.motion_data, self.dataset_weights): 239 | T = len(data) 240 | idx = torch.arange(T, device=self.device) 241 | obs = data.get_amp_dataset_obs(idx) 242 | next_idx = torch.clamp(idx + 1, max=T - 1) 243 | next_obs = data.get_amp_dataset_obs(next_idx) 244 | 245 | obs_list.append(obs) 246 | next_obs_list.append(next_obs) 247 | 248 | quat, jp, jv, blv, bav = data.get_state_for_reset(idx) 249 | reset_states.append(torch.cat([quat, jp, jv, blv, bav], dim=1)) 250 | 251 | self.all_obs = torch.cat(obs_list, dim=0) 252 | self.all_next_obs = torch.cat(next_obs_list, dim=0) 253 | self.all_states = torch.cat(reset_states, dim=0) 254 | 255 | # Build per-frame sampling weights: weight_i / length_i 256 | lengths = [len(d) for d in self.motion_data] 257 | per_frame = torch.cat( 258 | [ 
259 | torch.full((L,), w / L, device=self.device) 260 | for w, L in zip(self.dataset_weights, lengths) 261 | ] 262 | ) 263 | self.per_frame_weights = per_frame / per_frame.sum() 264 | 265 | def _resample_data_Rn( 266 | self, 267 | data: List[np.ndarray], 268 | original_keyframes, 269 | target_keyframes, 270 | ) -> np.ndarray: 271 | f = interp1d(original_keyframes, data, axis=0) 272 | return f(target_keyframes) 273 | 274 | def _resample_data_SO3( 275 | self, 276 | raw_quaternions: List[np.ndarray], 277 | original_keyframes, 278 | target_keyframes, 279 | ) -> Rotation: 280 | 281 | # the quaternion is expected in the dataset as `xyzw` format (SciPy default) 282 | tmp = Rotation.from_quat(raw_quaternions) 283 | slerp = Slerp(original_keyframes, tmp) 284 | return slerp(target_keyframes) 285 | 286 | def _compute_ang_vel( 287 | self, 288 | data: List[Rotation], 289 | dt: float, 290 | local: bool = False, 291 | ) -> np.ndarray: 292 | R_prev = data[:-1] 293 | R_next = data[1:] 294 | 295 | if local: 296 | # Exp = R_i⁻¹ · R_{i+1} 297 | rel = R_prev.inv() * R_next 298 | else: 299 | # Exp = R_{i+1} · R_i⁻¹ 300 | rel = R_next * R_prev.inv() 301 | 302 | # Log-map to rotation vectors and divide by Δt 303 | rotvec = rel.as_rotvec() / dt 304 | 305 | return np.vstack((rotvec, rotvec[-1])) 306 | 307 | def _compute_raw_derivative(self, data: np.ndarray, dt: float) -> np.ndarray: 308 | d = (data[1:] - data[:-1]) / dt 309 | return np.vstack([d, d[-1:]]) 310 | 311 | def load_data( 312 | self, 313 | dataset_path: Path, 314 | simulation_dt: float, 315 | slow_down_factor: int = 1, 316 | expected_joint_names: Union[List[str], None] = None, 317 | ) -> MotionData: 318 | """ 319 | Loads and processes one motion dataset. 320 | 321 | Returns: 322 | MotionData instance 323 | """ 324 | data = np.load(str(dataset_path), allow_pickle=True).item() 325 | dataset_joint_names = data["joints_list"] 326 | 327 | # build index map for expected_joint_names 328 | idx_map: List[Union[int, None]] = [] 329 | for j in expected_joint_names: 330 | if j in dataset_joint_names: 331 | idx_map.append(dataset_joint_names.index(j)) 332 | else: 333 | idx_map.append(None) 334 | 335 | # reorder & fill joint positions 336 | jp_list: List[np.ndarray] = [] 337 | for frame in data["joint_positions"]: 338 | arr = np.zeros((len(idx_map),), dtype=frame.dtype) 339 | for i, src_idx in enumerate(idx_map): 340 | if src_idx is not None: 341 | arr[i] = frame[src_idx] 342 | jp_list.append(arr) 343 | 344 | dt = 1.0 / data["fps"] / float(slow_down_factor) 345 | T = len(jp_list) 346 | t_orig = np.linspace(0, T * dt, T) 347 | T_new = int(T * dt / simulation_dt) 348 | t_new = np.linspace(0, T * dt, T_new) 349 | 350 | resampled_joint_positions = self._resample_data_Rn(jp_list, t_orig, t_new) 351 | resampled_joint_velocities = self._compute_raw_derivative( 352 | resampled_joint_positions, simulation_dt 353 | ) 354 | 355 | resampled_base_positions = self._resample_data_Rn( 356 | data["root_position"], t_orig, t_new 357 | ) 358 | resampled_base_orientations = self._resample_data_SO3( 359 | data["root_quaternion"], t_orig, t_new 360 | ) 361 | 362 | resampled_base_lin_vel_mixed = self._compute_raw_derivative( 363 | resampled_base_positions, simulation_dt 364 | ) 365 | 366 | resampled_base_ang_vel_mixed = self._compute_ang_vel( 367 | resampled_base_orientations, simulation_dt, local=False 368 | ) 369 | 370 | resampled_base_lin_vel_local = np.stack( 371 | [ 372 | R.as_matrix().T @ v 373 | for R, v in zip( 374 | resampled_base_orientations, resampled_base_lin_vel_mixed 375 | ) 
376 | ] 377 | ) 378 | resampled_base_ang_vel_local = self._compute_ang_vel( 379 | resampled_base_orientations, simulation_dt, local=True 380 | ) 381 | 382 | return MotionData( 383 | joint_positions=resampled_joint_positions, 384 | joint_velocities=resampled_joint_velocities, 385 | base_lin_velocities_mixed=resampled_base_lin_vel_mixed, 386 | base_ang_velocities_mixed=resampled_base_ang_vel_mixed, 387 | base_lin_velocities_local=resampled_base_lin_vel_local, 388 | base_ang_velocities_local=resampled_base_ang_vel_local, 389 | base_quat=resampled_base_orientations, 390 | device=self.device, 391 | ) 392 | 393 | def feed_forward_generator( 394 | self, num_mini_batch: int, mini_batch_size: int 395 | ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: 396 | """ 397 | Yields mini-batches of (state, next_state) pairs for training, 398 | sampled directly from precomputed buffers. 399 | 400 | Args: 401 | num_mini_batch: Number of mini-batches to yield 402 | mini_batch_size: Size of each mini-batch 403 | Yields: 404 | Tuple of (state, next_state) tensors 405 | """ 406 | for _ in range(num_mini_batch): 407 | idx = torch.multinomial( 408 | self.per_frame_weights, mini_batch_size, replacement=True 409 | ) 410 | yield self.all_obs[idx], self.all_next_obs[idx] 411 | 412 | def get_state_for_reset(self, number_of_samples: int) -> Tuple[torch.Tensor, ...]: 413 | """ 414 | Randomly samples full states for environment resets, 415 | sampled directly from the precomputed state buffer. 416 | 417 | Args: 418 | number_of_samples: Number of samples to retrieve 419 | Returns: 420 | Tuple of (quat, joint_positions, joint_velocities, base_lin_velocities, base_ang_velocities) 421 | """ 422 | idx = torch.multinomial( 423 | self.per_frame_weights, number_of_samples, replacement=True 424 | ) 425 | full = self.all_states[idx] 426 | joint_dim = self.motion_data[0].joint_positions.shape[1] 427 | 428 | # The dimensions of the full state are: 429 | # - 4 (quat) + joint_dim (joint_positions) + joint_dim (joint_velocities) 430 | # + 3 (base_lin_velocities) + 3 (base_ang_velocities) 431 | # = 4 + joint_dim + joint_dim + 3 + 3 432 | dims = [4, joint_dim, joint_dim, 3, 3] 433 | return torch.split(full, dims, dim=1) 434 | -------------------------------------------------------------------------------- /amp_rsl_rl/algorithms/amp_ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Optional, Tuple, Dict, Any 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from tensordict import TensorDict 15 | 16 | # External modules providing the actor-critic model, storage utilities, and AMP components. 17 | from rsl_rl.modules import ActorCritic 18 | from rsl_rl.storage import RolloutStorage 19 | 20 | from amp_rsl_rl.storage import ReplayBuffer 21 | from amp_rsl_rl.networks import Discriminator 22 | from amp_rsl_rl.utils import AMPLoader 23 | 24 | 25 | class AMP_PPO: 26 | """ 27 | AMP_PPO implements Adversarial Motion Priors (AMP) combined with Proximal Policy Optimization (PPO). 28 | 29 | The algorithm mirrors the structure of upstream ``PPO`` from ``rsl_rl`` but augments each update 30 | with a discriminator trained on expert trajectories. 
Observations feed into the policy as 31 | TensorDicts, allowing the actor and critic to consume different observation groups. 32 | 33 | Parameters 34 | ---------- 35 | actor_critic : ActorCritic 36 | Policy network providing ``act``/``evaluate``/``update_normalization`` APIs. 37 | discriminator : Discriminator 38 | AMP discriminator distinguishing expert vs policy motion pairs. 39 | amp_data : AMPLoader 40 | Data loader that provides batches of expert motion data. 41 | num_learning_epochs : int, default=1 42 | Number of passes over the rollout buffer per update. 43 | num_mini_batches : int, default=1 44 | Number of mini-batches to divide each epoch's data into. 45 | clip_param : float, default=0.2 46 | PPO clipping parameter that bounds the policy update step. 47 | gamma : float, default=0.998 48 | Discount factor. 49 | lam : float, default=0.95 50 | Lambda parameter for Generalized Advantage Estimation (GAE). 51 | value_loss_coef : float, default=1.0 52 | Coefficient for the value function loss term in the PPO loss. 53 | entropy_coef : float, default=0.0 54 | Coefficient for the entropy regularization term (encouraging exploration). 55 | learning_rate : float, default=1e-3 56 | Initial learning rate. 57 | max_grad_norm : float, default=1.0 58 | Maximum gradient norm for clipping gradients during backpropagation. 59 | use_clipped_value_loss : bool, default=True 60 | Enables the clipped value loss variant of PPO. 61 | schedule : str, default="fixed" 62 | Either ``"fixed"`` or ``"adaptive"`` (based on KL). 63 | desired_kl : float, default=0.01 64 | Target KL divergence when using the adaptive schedule. 65 | amp_replay_buffer_size : int, default=100_000 66 | Size of the replay buffer storing policy-generated AMP samples. 67 | use_smooth_ratio_clipping : bool, default=False 68 | Enables smooth ratio clipping instead of hard clamping. 69 | device : str, default="cpu" 70 | Torch device used by the module. 71 | """ 72 | 73 | actor_critic: ActorCritic 74 | 75 | def __init__( 76 | self, 77 | actor_critic: ActorCritic, 78 | discriminator: Discriminator, 79 | amp_data: AMPLoader, 80 | num_learning_epochs: int = 1, 81 | num_mini_batches: int = 1, 82 | clip_param: float = 0.2, 83 | gamma: float = 0.998, 84 | lam: float = 0.95, 85 | value_loss_coef: float = 1.0, 86 | entropy_coef: float = 0.0, 87 | learning_rate: float = 1e-3, 88 | max_grad_norm: float = 1.0, 89 | use_clipped_value_loss: bool = True, 90 | schedule: str = "fixed", 91 | desired_kl: float = 0.01, 92 | amp_replay_buffer_size: int = 100000, 93 | use_smooth_ratio_clipping: bool = False, 94 | device: str = "cpu", 95 | ) -> None: 96 | # Set device and learning hyperparameters 97 | self.device: str = device 98 | self.desired_kl: float = desired_kl 99 | self.schedule: str = schedule 100 | self.learning_rate: float = learning_rate 101 | 102 | # Set up the discriminator and move it to the appropriate device. 103 | self.discriminator: Discriminator = discriminator.to(self.device) 104 | self.amp_transition: RolloutStorage.Transition = RolloutStorage.Transition() 105 | # Determine observation dimension used in the replay buffer. 106 | # The discriminator expects concatenated observations, so the replay buffer uses half the dimension. 107 | obs_dim: int = self.discriminator.input_dim // 2 108 | self.amp_storage: ReplayBuffer = ReplayBuffer( 109 | obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=device 110 | ) 111 | self.amp_data: AMPLoader = amp_data 112 | 113 | # Set up the actor-critic (policy) and move it to the device. 
114 | self.actor_critic = actor_critic 115 | self.actor_critic.to(self.device) 116 | self.storage: Optional[RolloutStorage] = ( 117 | None # Will be initialized later once environment parameters are known 118 | ) 119 | 120 | # Create optimizer for both the actor-critic and the discriminator. 121 | # Note: Weight decay is set differently for discriminator trunk and head. 122 | params = [ 123 | {"params": self.actor_critic.parameters(), "name": "actor_critic"}, 124 | { 125 | "params": self.discriminator.trunk.parameters(), 126 | "weight_decay": 10e-4, 127 | "name": "amp_trunk", 128 | }, 129 | { 130 | "params": self.discriminator.linear.parameters(), 131 | "weight_decay": 10e-2, 132 | "name": "amp_head", 133 | }, 134 | ] 135 | self.optimizer: optim.Adam = optim.Adam(params, lr=learning_rate) 136 | self.transition: RolloutStorage.Transition = RolloutStorage.Transition() 137 | # PPO-specific parameters 138 | self.clip_param: float = clip_param 139 | self.num_learning_epochs: int = num_learning_epochs 140 | self.num_mini_batches: int = num_mini_batches 141 | self.value_loss_coef: float = value_loss_coef 142 | self.entropy_coef: float = entropy_coef 143 | self.gamma: float = gamma 144 | self.lam: float = lam 145 | self.max_grad_norm: float = max_grad_norm 146 | self.use_clipped_value_loss: bool = use_clipped_value_loss 147 | self.use_smooth_ratio_clipping: bool = use_smooth_ratio_clipping 148 | 149 | def init_storage( 150 | self, 151 | num_envs: int, 152 | num_transitions_per_env: int, 153 | observations: TensorDict, 154 | action_shape: Tuple[int, ...], 155 | ) -> None: 156 | """Initialize rollout storage for TensorDict observations. 157 | 158 | Parameters 159 | ---------- 160 | num_envs : int 161 | Number of parallel environments. 162 | num_transitions_per_env : int 163 | Horizon (per environment) stored inside the rollout buffer. 164 | observations : TensorDict 165 | Prototype observation structure used to determine buffer shapes. 166 | action_shape : Tuple[int, ...] 167 | Shape of the action vector output by the policy. 168 | """ 169 | self.storage = RolloutStorage( 170 | training_type="rl", 171 | num_envs=num_envs, 172 | num_transitions_per_env=num_transitions_per_env, 173 | obs=observations, 174 | actions_shape=action_shape, 175 | device=self.device, 176 | ) 177 | 178 | def test_mode(self) -> None: 179 | """ 180 | Sets the actor-critic model to evaluation mode. 181 | """ 182 | self.actor_critic.eval() 183 | 184 | def train_mode(self) -> None: 185 | """ 186 | Sets the actor-critic model to training mode. 187 | """ 188 | self.actor_critic.train() 189 | 190 | def act(self, obs: TensorDict) -> torch.Tensor: 191 | """Select an action and value estimate for the current observation. 192 | 193 | Parameters 194 | ---------- 195 | obs : TensorDict 196 | Batched observation TensorDict provided by the environment. 197 | 198 | Returns 199 | ------- 200 | torch.Tensor 201 | Detached action tensor sampled from the actor-critic policy. 
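
        Notes
        -----
        Typical rollout usage, mirroring ``AMPOnPolicyRunner.learn`` (style-reward mixing
        omitted for brevity):

            actions = alg.act(obs)
            alg.act_amp(obs["amp"])
            obs, rewards, dones, extras = env.step(actions)
            alg.process_env_step(obs, rewards, dones, extras)
            alg.process_amp_step(obs["amp"])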
202 | """ 203 | if self.actor_critic.is_recurrent: 204 | self.transition.hidden_states = self.actor_critic.get_hidden_states() 205 | 206 | self.transition.actions = self.actor_critic.act(obs).detach() 207 | self.transition.values = self.actor_critic.evaluate(obs).detach() 208 | self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob( 209 | self.transition.actions 210 | ).detach() 211 | self.transition.action_mean = self.actor_critic.action_mean.detach() 212 | self.transition.action_sigma = self.actor_critic.action_std.detach() 213 | self.transition.observations = obs 214 | return self.transition.actions 215 | 216 | def act_amp(self, amp_obs: torch.Tensor) -> None: 217 | """Store the latest AMP policy observation for later replay insertion. 218 | 219 | Parameters 220 | ---------- 221 | amp_obs : torch.Tensor 222 | Concatenated AMP observation representing the current policy state. 223 | """ 224 | self.amp_transition.observations = amp_obs 225 | 226 | def process_env_step( 227 | self, 228 | obs: TensorDict, 229 | rewards: torch.Tensor, 230 | dones: torch.Tensor, 231 | extras: Dict[str, Any], 232 | ) -> None: 233 | """Record the outcome of an environment step and update normalizers. 234 | 235 | Parameters 236 | ---------- 237 | obs : TensorDict 238 | Observation returned by the environment after stepping. 239 | rewards : torch.Tensor 240 | Reward tensor (batch x 1) after mixing task/style components. 241 | dones : torch.Tensor 242 | Episode termination flags. 243 | extras : dict[str, Any] 244 | Additional metadata from the environment (e.g. ``time_outs``). 245 | """ 246 | self.actor_critic.update_normalization(obs) 247 | 248 | self.transition.rewards = rewards.clone() 249 | self.transition.dones = dones 250 | 251 | if "time_outs" in extras: 252 | self.transition.rewards += self.gamma * torch.squeeze( 253 | self.transition.values 254 | * extras["time_outs"].unsqueeze(1).to(self.device), 255 | 1, 256 | ) 257 | 258 | self.storage.add_transitions(self.transition) 259 | self.transition.clear() 260 | self.actor_critic.reset(dones) 261 | 262 | def process_amp_step(self, amp_obs: torch.Tensor) -> None: 263 | """Insert a policy-generated AMP transition into the replay buffer. 264 | 265 | Parameters 266 | ---------- 267 | amp_obs : torch.Tensor 268 | Next AMP observation paired with the previously stored policy state. 269 | """ 270 | self.amp_storage.insert(self.amp_transition.observations, amp_obs) 271 | self.amp_transition.clear() 272 | 273 | def compute_returns(self, obs: TensorDict) -> None: 274 | """Compute and store GAE-lambda returns from the final observation. 275 | 276 | Parameters 277 | ---------- 278 | obs : TensorDict 279 | Last observation gathered after rollout completion. 280 | """ 281 | 282 | last_values = self.actor_critic.evaluate(obs).detach() 283 | self.storage.compute_returns(last_values, self.gamma, self.lam) 284 | 285 | def update( 286 | self, 287 | ) -> Tuple[float, float, float, float, float, float, float, float, float]: 288 | """ 289 | Performs a single update step for both the actor-critic (PPO) and the AMP discriminator. 290 | It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses, 291 | performs adaptive learning rate scheduling (if enabled), and updates model parameters. 
292 | 293 | Returns 294 | ------- 295 | tuple 296 | A tuple containing mean losses and statistics: 297 | (mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss, 298 | mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert, 299 | mean_kl_divergence) 300 | """ 301 | # Initialize mean loss and accuracy statistics. 302 | mean_value_loss: float = 0.0 303 | mean_surrogate_loss: float = 0.0 304 | mean_amp_loss: float = 0.0 305 | mean_grad_pen_loss: float = 0.0 306 | mean_policy_pred: float = 0.0 307 | mean_expert_pred: float = 0.0 308 | mean_accuracy_policy: float = 0.0 309 | mean_accuracy_expert: float = 0.0 310 | mean_accuracy_policy_elem: float = 0.0 311 | mean_accuracy_expert_elem: float = 0.0 312 | mean_kl_divergence: float = 0.0 313 | 314 | # Create data generators for mini-batch sampling. 315 | if self.actor_critic.is_recurrent: 316 | generator = self.storage.recurrent_mini_batch_generator( 317 | self.num_mini_batches, self.num_learning_epochs 318 | ) 319 | else: 320 | generator = self.storage.mini_batch_generator( 321 | self.num_mini_batches, self.num_learning_epochs 322 | ) 323 | 324 | # Generator for policy-generated AMP transitions. 325 | amp_policy_generator = self.amp_storage.feed_forward_generator( 326 | num_mini_batch=self.num_learning_epochs * self.num_mini_batches, 327 | mini_batch_size=self.storage.num_envs 328 | * self.storage.num_transitions_per_env 329 | // self.num_mini_batches, 330 | allow_replacement=True, 331 | ) 332 | 333 | # Generator for expert AMP data. 334 | amp_expert_generator = self.amp_data.feed_forward_generator( 335 | self.num_learning_epochs * self.num_mini_batches, 336 | self.storage.num_envs 337 | * self.storage.num_transitions_per_env 338 | // self.num_mini_batches, 339 | ) 340 | 341 | # Loop over mini-batches from the environment transitions and AMP data. 342 | for sample, sample_amp_policy, sample_amp_expert in zip( 343 | generator, amp_policy_generator, amp_expert_generator 344 | ): 345 | # Unpack the mini-batch sample from the environment. 346 | ( 347 | obs_batch, 348 | actions_batch, 349 | target_values_batch, 350 | advantages_batch, 351 | returns_batch, 352 | old_actions_log_prob_batch, 353 | old_mu_batch, 354 | old_sigma_batch, 355 | hidden_states_batch, 356 | masks_batch, 357 | ) = sample 358 | 359 | hidden_state_actor, hidden_state_critic = (None, None) 360 | if hidden_states_batch is not None: 361 | hidden_state_actor, hidden_state_critic = hidden_states_batch 362 | 363 | # Forward pass through the actor to get current policy outputs. 364 | self.actor_critic.act( 365 | obs_batch, masks=masks_batch, hidden_states=hidden_state_actor 366 | ) 367 | actions_log_prob_batch = self.actor_critic.get_actions_log_prob( 368 | actions_batch 369 | ) 370 | value_batch = self.actor_critic.evaluate( 371 | obs_batch, masks=masks_batch, hidden_states=hidden_state_critic 372 | ) 373 | mu_batch = self.actor_critic.action_mean 374 | sigma_batch = self.actor_critic.action_std 375 | entropy_batch = self.actor_critic.entropy 376 | 377 | # Adaptive learning rate adjustment based on KL divergence if schedule is "adaptive". 
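            # For the diagonal Gaussian policy this evaluates, per action dimension,
            #   KL(old || new) = log(sigma_new / sigma_old)
            #                    + (sigma_old^2 + (mu_old - mu_new)^2) / (2 * sigma_new^2) - 0.5,
            # summed over action dimensions; the 1e-5 inside the log is a numerical guard.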
378 | if self.desired_kl is not None and self.schedule == "adaptive": 379 | with torch.inference_mode(): 380 | kl = torch.sum( 381 | torch.log(sigma_batch / old_sigma_batch + 1.0e-5) 382 | + ( 383 | torch.square(old_sigma_batch) 384 | + torch.square(old_mu_batch - mu_batch) 385 | ) 386 | / (2.0 * torch.square(sigma_batch)) 387 | - 0.5, 388 | axis=-1, 389 | ) 390 | kl_mean = torch.mean(kl) 391 | mean_kl_divergence += kl_mean.item() 392 | 393 | if kl_mean > self.desired_kl * 2.0: 394 | self.learning_rate = max(1e-5, self.learning_rate / 1.5) 395 | elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: 396 | self.learning_rate = min(1e-2, self.learning_rate * 1.5) 397 | 398 | for param_group in self.optimizer.param_groups: 399 | param_group["lr"] = self.learning_rate 400 | 401 | # Compute the PPO surrogate loss. 402 | ratio = torch.exp( 403 | actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch) 404 | ) 405 | 406 | min_ = 1.0 - self.clip_param 407 | max_ = 1.0 + self.clip_param 408 | # Smooth clipping for the ratio if enabled. 409 | if self.use_smooth_ratio_clipping: 410 | clipped_ratio = ( 411 | 1 412 | / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4)) 413 | * (max_ - min_) 414 | + min_ 415 | ) 416 | else: 417 | clipped_ratio = torch.clamp(ratio, min_, max_) 418 | 419 | surrogate = -torch.squeeze(advantages_batch) * ratio 420 | surrogate_clipped = -torch.squeeze(advantages_batch) * clipped_ratio 421 | surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() 422 | 423 | # Compute the value function loss. 424 | if self.use_clipped_value_loss: 425 | value_clipped = target_values_batch + ( 426 | value_batch - target_values_batch 427 | ).clamp(-self.clip_param, self.clip_param) 428 | value_losses = (value_batch - returns_batch).pow(2) 429 | value_losses_clipped = (value_clipped - returns_batch).pow(2) 430 | value_loss = torch.max(value_losses, value_losses_clipped).mean() 431 | else: 432 | value_loss = (returns_batch - value_batch).pow(2).mean() 433 | 434 | # Combine surrogate loss, value loss and entropy regularization to form PPO loss. 435 | ppo_loss = ( 436 | surrogate_loss 437 | + self.value_loss_coef * value_loss 438 | - self.entropy_coef * entropy_batch.mean() 439 | ) 440 | 441 | # Process AMP loss by unpacking policy and expert AMP samples. 442 | policy_state, policy_next_state = sample_amp_policy 443 | expert_state, expert_next_state = sample_amp_expert 444 | 445 | # Ensure everything is on the right device (AMPLoader may yield CPU tensors) 446 | policy_state = policy_state.to(self.device) 447 | policy_next_state = policy_next_state.to(self.device) 448 | expert_state = expert_state.to(self.device) 449 | expert_next_state = expert_next_state.to(self.device) 450 | 451 | # Keep raw tensors for normalizer updates 452 | policy_state_raw = policy_state.detach().clone() 453 | policy_next_state_raw = policy_next_state.detach().clone() 454 | expert_state_raw = expert_state.detach().clone() 455 | expert_next_state_raw = expert_next_state.detach().clone() 456 | 457 | # Concatenate policy and expert AMP observations for the discriminator input. 
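            # Policy pairs are stacked on top of expert pairs along dim 0: each row is
            # [state, next_state] (size == discriminator.input_dim), the first B_pol rows
            # come from the policy replay buffer and the remaining rows from the expert data,
            # so the logits can be split back with the same offset below.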
458 | B_pol = policy_state.size(0) 459 | discriminator_input = torch.cat( 460 | ( 461 | torch.cat([policy_state, policy_next_state], dim=-1), 462 | torch.cat([expert_state, expert_next_state], dim=-1), 463 | ), 464 | dim=0, 465 | ) 466 | discriminator_output = self.discriminator(discriminator_input) 467 | policy_d, expert_d = ( 468 | discriminator_output[:B_pol], 469 | discriminator_output[B_pol:], 470 | ) 471 | 472 | # Compute discriminator losses 473 | amp_loss, grad_pen_loss = self.discriminator.compute_loss( 474 | policy_d=policy_d, 475 | expert_d=expert_d, 476 | sample_amp_expert=(expert_state, expert_next_state), 477 | sample_amp_policy=(policy_state, policy_next_state), 478 | lambda_=10, 479 | ) 480 | 481 | # The final loss combines the PPO loss with AMP losses. 482 | loss = ppo_loss + (amp_loss + grad_pen_loss) 483 | 484 | # Backpropagation and optimizer step. 485 | self.optimizer.zero_grad() 486 | loss.backward() 487 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) 488 | self.optimizer.step() 489 | 490 | # Update the normalizer with RAW (unnormalized) observations under no_grad 491 | self.discriminator.update_normalization( 492 | expert_state_raw, 493 | expert_next_state_raw, 494 | policy_state_raw, 495 | policy_next_state_raw, 496 | ) 497 | 498 | # Compute probabilities from the discriminator logits. 499 | policy_d_prob = torch.sigmoid(policy_d) 500 | expert_d_prob = torch.sigmoid(expert_d) 501 | 502 | # Update running statistics. 503 | mean_value_loss += value_loss.item() 504 | mean_surrogate_loss += surrogate_loss.item() 505 | mean_amp_loss += amp_loss.item() 506 | mean_grad_pen_loss += grad_pen_loss.item() 507 | mean_policy_pred += policy_d_prob.mean().item() 508 | mean_expert_pred += expert_d_prob.mean().item() 509 | 510 | # Calculate the accuracy of the discriminator. 511 | mean_accuracy_policy += torch.sum( 512 | torch.round(policy_d_prob) == torch.zeros_like(policy_d_prob) 513 | ).item() 514 | mean_accuracy_expert += torch.sum( 515 | torch.round(expert_d_prob) == torch.ones_like(expert_d_prob) 516 | ).item() 517 | 518 | # Record the total number of elements processed. 519 | mean_accuracy_expert_elem += expert_d_prob.numel() 520 | mean_accuracy_policy_elem += policy_d_prob.numel() 521 | 522 | # Average the statistics over all mini-batches. 523 | num_updates = self.num_learning_epochs * self.num_mini_batches 524 | mean_value_loss /= num_updates 525 | mean_surrogate_loss /= num_updates 526 | mean_amp_loss /= num_updates 527 | mean_grad_pen_loss /= num_updates 528 | mean_policy_pred /= num_updates 529 | mean_expert_pred /= num_updates 530 | mean_accuracy_policy /= max(1, mean_accuracy_policy_elem) 531 | mean_accuracy_expert /= max(1, mean_accuracy_expert_elem) 532 | mean_kl_divergence /= num_updates 533 | 534 | # Clear the storage for the next update cycle. 535 | self.storage.clear() 536 | 537 | return ( 538 | mean_value_loss, 539 | mean_surrogate_loss, 540 | mean_amp_loss, 541 | mean_grad_pen_loss, 542 | mean_policy_pred, 543 | mean_expert_pred, 544 | mean_accuracy_policy, 545 | mean_accuracy_expert, 546 | mean_kl_divergence, 547 | ) 548 | -------------------------------------------------------------------------------- /amp_rsl_rl/runners/amp_on_policy_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Istituto Italiano di Tecnologia 2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | 7 | from __future__ import annotations 8 | 9 | import os 10 | import statistics 11 | import time 12 | from collections import deque 13 | 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter 16 | 17 | import rsl_rl 18 | from rsl_rl.env import VecEnv 19 | from rsl_rl.modules import ActorCritic, ActorCriticRecurrent 20 | from rsl_rl.utils import resolve_obs_groups, store_code_state 21 | 22 | import amp_rsl_rl 23 | from amp_rsl_rl.utils import AMPLoader 24 | from amp_rsl_rl.algorithms import AMP_PPO 25 | from amp_rsl_rl.networks import Discriminator, ActorCriticMoE 26 | from amp_rsl_rl.utils import export_policy_as_onnx 27 | 28 | 29 | class AMPOnPolicyRunner: 30 | """ 31 | AMPOnPolicyRunner is a high-level orchestrator that manages the training and evaluation 32 | of a policy using Adversarial Motion Priors (AMP) combined with on-policy reinforcement learning (PPO). 33 | 34 | It brings together multiple components: 35 | - Environment (`VecEnv`) 36 | - Policy (`ActorCritic`, `ActorCriticRecurrent`) 37 | - Discriminator (Discriminator) 38 | - Expert dataset (AMPLoader) 39 | - Reward combination (task + style) 40 | - Logging and checkpointing 41 | 42 | --- 43 | 🔧 Configuration 44 | ---------------- 45 | The class expects a `train_cfg` dictionary structured with keys: 46 | - "obs_groups": optional mapping describing which observation tensors belong to "policy" and "critic" inputs. 47 | - "policy": configuration for the policy network, including `"class_name"` 48 | - "algorithm": configuration for PPO/AMP_PPO, including `"class_name"` 49 | - "discriminator": configuration for the AMP discriminator 50 | - "dataset": dictionary forwarded to `AMPLoader`, containing at least: 51 | * "amp_data_path": folder with the `.npy` expert datasets 52 | * "datasets": mapping of dataset name -> sampling weight (floats) 53 | * "slow_down_factor": slowdown applied to real motion data to match sim dynamics 54 | - "num_steps_per_env": rollout horizon per environment 55 | - "save_interval": frequency (in iterations) for model checkpointing 56 | - "empirical_normalization": (deprecated) legacy flag mirrored to `policy.actor_obs_normalization` 57 | - "logger": one of "tensorboard", "wandb", or "neptune" 58 | 59 | --- 60 | 📦 Dataset format 61 | ------------------ 62 | The expert motion datasets loaded via `AMPLoader` must be `.npy` files with a dictionary containing: 63 | 64 | - `"joints_list"`: List[str] — ordered list of joint names 65 | - `"joint_positions"`: List[np.ndarray] — joint configurations per timestep (1D arrays) 66 | - `"root_position"`: List[np.ndarray] — base position in world coordinates 67 | - `"root_quaternion"`: List[np.ndarray] — base orientation in **`xyzw`** format (SciPy default) 68 | - `"fps"`: float — original dataset frame rate 69 | 70 | Internally: 71 | - Quaternions are interpolated via SLERP and converted to **`wxyz`** format before being used by the model (to match Isaac Gym convention). 72 | - Velocities are estimated with finite differences. 73 | - All data is converted to torch tensors and placed on the desired device. 74 | 75 | --- 76 | 🎓 AMP Reward 77 | ------------- 78 | During each training step, the runner collects AMP-specific observations and computes 79 | a discriminator-based "style reward" from the expert dataset. 
This is combined 80 | with the environment reward as: 81 | 82 | `reward = 0.5 * task_reward + 0.5 * style_reward` 83 | 84 | This can be later generalized into a weighted or learned reward mixing policy. 85 | 86 | --- 87 | 🔁 Training loop 88 | ---------------- 89 | The `learn()` method performs: 90 | - `rollout`: collects TensorDict observations via `self.alg.act()` and `env.step()` 91 | - `style_reward`: computed from discriminator via `predict_reward(...)` 92 | - `storage update`: via `process_env_step()` and `process_amp_step()` 93 | - `return computation`: via `compute_returns()` using the latest observation TensorDict 94 | - `update`: performs backpropagation with `self.alg.update()` 95 | - Logging via TensorBoard/WandB/Neptune 96 | 97 | --- 98 | 💾 Saving and ONNX export 99 | -------------------------- 100 | At each `save_interval`, the runner: 101 | - Saves the full state (`model`, `optimizer`, `discriminator`, etc.) 102 | - Optionally exports the policy as an ONNX model for deployment 103 | - Uploads checkpoints to logging services if enabled 104 | 105 | --- 106 | 📤 Inference policy 107 | ------------------- 108 | `get_inference_policy()` returns a callable that takes an observation and returns an action. 109 | If empirical normalization is enabled, observations are normalized before inference. 110 | 111 | --- 112 | 🛠️ Additional tools 113 | ------------------- 114 | - Git integration via `store_code_state()` to track code changes 115 | - Logging of learning statistics, reward breakdown, discriminator metrics 116 | - Compatible with multi-task setups via dataset weights 117 | 118 | --- 119 | 📚 Notes 120 | -------- 121 | - This runner assumes an AMP-compatible VecEnv, providing `observations["amp"]` 122 | - AMP uses both current and next state to train the discriminator 123 | - Logging behavior is separated from core logic (WandB, Neptune, TensorBoard) 124 | - The Discriminator and AMP_PPO must follow expected APIs 125 | 126 | """ 127 | 128 | def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): 129 | self.cfg = train_cfg 130 | self.alg_cfg = train_cfg["algorithm"] 131 | self.policy_cfg = train_cfg["policy"] 132 | self.discriminator_cfg = train_cfg["discriminator"] 133 | self.dataset_cfg = train_cfg["dataset"] 134 | self.device = device 135 | self.env = env 136 | 137 | observations = self.env.get_observations() 138 | default_sets = ["critic"] 139 | self.cfg["obs_groups"] = resolve_obs_groups( 140 | observations, self.cfg.get("obs_groups"), default_sets 141 | ) 142 | 143 | actor_critic_class = eval(self.policy_cfg.pop("class_name")) # ActorCritic 144 | actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticMoE = ( 145 | actor_critic_class( 146 | observations, 147 | self.cfg["obs_groups"], 148 | self.env.num_actions, 149 | **self.policy_cfg, 150 | ).to(self.device) 151 | ) 152 | # NOTE: to use this we need to configure the observations in the env coherently with amp observation. 
Tested with Manager Based envs in Isaaclab 153 | amp_joint_names = self.env.cfg.observations.amp.joint_pos.params[ 154 | "asset_cfg" 155 | ].joint_names 156 | 157 | # Initialize all the ingredients required for AMP (discriminator, dataset loader) 158 | num_amp_obs = observations["amp"].shape[1] 159 | amp_data = AMPLoader( 160 | device=self.device, 161 | dataset_path_root=self.dataset_cfg["amp_data_path"], 162 | datasets=self.dataset_cfg["datasets"], 163 | simulation_dt=self.env.cfg.sim.dt * self.env.cfg.decimation, 164 | slow_down_factor=self.dataset_cfg["slow_down_factor"], 165 | expected_joint_names=amp_joint_names, 166 | ) 167 | 168 | self.discriminator = Discriminator( 169 | input_dim=num_amp_obs 170 | * 2, # the discriminator takes in the concatenation of the current and next observation 171 | hidden_layer_sizes=self.discriminator_cfg["hidden_dims"], 172 | reward_scale=self.discriminator_cfg["reward_scale"], 173 | device=self.device, 174 | loss_type=self.discriminator_cfg["loss_type"], 175 | empirical_normalization=self.discriminator_cfg["empirical_normalization"], 176 | ).to(self.device) 177 | 178 | # Initialize the PPO algorithm 179 | alg_class = eval(self.alg_cfg.pop("class_name")) # AMP_PPO 180 | # This removes from alg_cfg fields that are not in AMP_PPO but are introduced in rsl_rl 2.2.3 PPO 181 | # normalize_advantage_per_mini_batch=False, 182 | # rnd_cfg: dict | None = None, 183 | # symmetry_cfg: dict | None = None, 184 | # multi_gpu_cfg: dict | None = None, 185 | for key in list(self.alg_cfg.keys()): 186 | if key not in AMP_PPO.__init__.__code__.co_varnames: 187 | self.alg_cfg.pop(key) 188 | 189 | self.alg: AMP_PPO = alg_class( 190 | actor_critic=actor_critic, 191 | discriminator=self.discriminator, 192 | amp_data=amp_data, 193 | device=self.device, 194 | **self.alg_cfg, 195 | ) 196 | self.num_steps_per_env = self.cfg["num_steps_per_env"] 197 | self.save_interval = self.cfg["save_interval"] 198 | # init storage and model 199 | obs_template = observations.clone().detach().to(self.device) 200 | self.alg.init_storage( 201 | self.env.num_envs, 202 | self.num_steps_per_env, 203 | obs_template, 204 | (self.env.num_actions,), 205 | ) 206 | 207 | # Log 208 | self.log_dir = log_dir 209 | self.writer = None 210 | self.logger_type = None 211 | self.tot_timesteps = 0 212 | self.tot_time = 0 213 | self.current_learning_iteration = 0 214 | self.git_status_repos = [rsl_rl.__file__, amp_rsl_rl.__file__] 215 | 216 | def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): 217 | # initialize writer 218 | if self.log_dir is not None and self.writer is None: 219 | # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard. 220 | self.logger_type = self.cfg.get("logger", "tensorboard") 221 | self.logger_type = self.logger_type.lower() 222 | 223 | if self.logger_type == "neptune": 224 | from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter 225 | 226 | self.writer = NeptuneSummaryWriter( 227 | log_dir=self.log_dir, flush_secs=10, cfg=self.cfg 228 | ) 229 | self.writer.log_config( 230 | self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg 231 | ) 232 | elif self.logger_type == "wandb": 233 | from rsl_rl.utils.wandb_utils import WandbSummaryWriter 234 | import wandb 235 | 236 | # Update the run name with a sequence number. 
This function is useful to 237 | # replicate the same behaviour of rsl-rl-lib before v2.3.0 238 | def update_run_name_with_sequence(prefix: str) -> None: 239 | # Retrieve the current wandb run details (project and entity) 240 | project = wandb.run.project 241 | entity = wandb.run.entity 242 | 243 | # Use wandb's API to list all runs in your project 244 | api = wandb.Api() 245 | runs = api.runs(f"{entity}/{project}") 246 | 247 | max_num = 0 248 | # Iterate through runs to extract the numeric suffix after the prefix. 249 | for run in runs: 250 | if run.name.startswith(prefix): 251 | # Extract the numeric part from the run name. 252 | numeric_suffix = run.name[ 253 | len(prefix) : 254 | ] # e.g., from "prefix564", get "564" 255 | try: 256 | run_num = int(numeric_suffix) 257 | if run_num > max_num: 258 | max_num = run_num 259 | except ValueError: 260 | continue 261 | 262 | # Increment to get the new run number 263 | new_num = max_num + 1 264 | new_run_name = f"{prefix}{new_num}" 265 | 266 | # Update the wandb run's name 267 | wandb.run.name = new_run_name 268 | print("Updated run name to:", wandb.run.name) 269 | 270 | self.writer = WandbSummaryWriter( 271 | log_dir=self.log_dir, flush_secs=10, cfg=self.cfg 272 | ) 273 | update_run_name_with_sequence(prefix=self.cfg["wandb_project"]) 274 | 275 | wandb.gym.monitor() 276 | self.writer.log_config( 277 | self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg 278 | ) 279 | elif self.logger_type == "tensorboard": 280 | self.writer = TensorboardSummaryWriter( 281 | log_dir=self.log_dir, flush_secs=10 282 | ) 283 | else: 284 | raise AssertionError("logger type not found") 285 | 286 | if init_at_random_ep_len: 287 | self.env.episode_length_buf = torch.randint_like( 288 | self.env.episode_length_buf, high=int(self.env.max_episode_length) 289 | ) 290 | obs = self.env.get_observations().to(self.device) 291 | amp_obs = obs["amp"].clone() 292 | self.train_mode() # switch to train mode (for dropout for example) 293 | 294 | ep_infos = [] 295 | rewbuffer = deque(maxlen=100) 296 | lenbuffer = deque(maxlen=100) 297 | cur_reward_sum = torch.zeros( 298 | self.env.num_envs, dtype=torch.float, device=self.device 299 | ) 300 | cur_episode_length = torch.zeros( 301 | self.env.num_envs, dtype=torch.float, device=self.device 302 | ) 303 | 304 | start_iter = self.current_learning_iteration 305 | tot_iter = start_iter + num_learning_iterations 306 | for it in range(start_iter, tot_iter): 307 | start = time.time() 308 | # Rollout 309 | 310 | mean_style_reward_log = 0 311 | mean_task_reward_log = 0 312 | 313 | with torch.inference_mode(): 314 | for _ in range(self.num_steps_per_env): 315 | actions = self.alg.act(obs) 316 | self.alg.act_amp(amp_obs) 317 | obs, rewards, dones, extras = self.env.step( 318 | actions.to(self.env.device) 319 | ) 320 | obs = obs.to(self.device) 321 | rewards = rewards.to(self.device) 322 | dones = dones.to(self.device) 323 | 324 | next_amp_obs = obs["amp"].clone() 325 | style_rewards = self.discriminator.predict_reward( 326 | amp_obs, next_amp_obs 327 | ) 328 | 329 | mean_task_reward_log += rewards.mean().item() 330 | mean_style_reward_log += style_rewards.mean().item() 331 | 332 | rewards = 0.5 * rewards + 0.5 * style_rewards 333 | 334 | self.alg.process_env_step(obs, rewards, dones, extras) 335 | self.alg.process_amp_step(next_amp_obs) 336 | 337 | amp_obs = next_amp_obs 338 | 339 | if self.log_dir is not None: 340 | if "episode" in extras: 341 | ep_infos.append(extras["episode"]) 342 | elif "log" in extras: 343 | ep_infos.append(extras["log"]) 344 
| cur_reward_sum += rewards 345 | cur_episode_length += 1 346 | new_ids = torch.nonzero(dones, as_tuple=False) 347 | if new_ids.numel() > 0: 348 | env_indices = new_ids.view(-1) 349 | rewbuffer.extend(cur_reward_sum[env_indices].cpu().tolist()) 350 | lenbuffer.extend( 351 | cur_episode_length[env_indices].cpu().tolist() 352 | ) 353 | cur_reward_sum[env_indices] = 0 354 | cur_episode_length[env_indices] = 0 355 | 356 | stop = time.time() 357 | collection_time = stop - start 358 | 359 | # Learning step 360 | start = stop 361 | self.alg.compute_returns(obs) 362 | 363 | mean_style_reward_log /= self.num_steps_per_env 364 | mean_task_reward_log /= self.num_steps_per_env 365 | 366 | ( 367 | mean_value_loss, 368 | mean_surrogate_loss, 369 | mean_amp_loss, 370 | mean_grad_pen_loss, 371 | mean_policy_pred, 372 | mean_expert_pred, 373 | mean_accuracy_policy, 374 | mean_accuracy_expert, 375 | mean_kl_divergence, 376 | ) = self.alg.update() 377 | stop = time.time() 378 | learn_time = stop - start 379 | self.current_learning_iteration = it 380 | if self.log_dir is not None: 381 | self.log(locals()) 382 | if it % self.save_interval == 0: 383 | self.save(os.path.join(self.log_dir, f"model_{it}.pt"), save_onnx=True) 384 | ep_infos.clear() 385 | if it == start_iter: 386 | # obtain all the diff files 387 | git_file_paths = store_code_state(self.log_dir, self.git_status_repos) 388 | # if possible store them to wandb 389 | if self.logger_type in ["wandb", "neptune"] and git_file_paths: 390 | for path in git_file_paths: 391 | self.writer.save_file(path) 392 | 393 | self.save( 394 | os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"), 395 | save_onnx=True, 396 | ) 397 | 398 | def log(self, locs: dict, width: int = 80, pad: int = 35): 399 | self.tot_timesteps += self.num_steps_per_env * self.env.num_envs 400 | self.tot_time += locs["collection_time"] + locs["learn_time"] 401 | iteration_time = locs["collection_time"] + locs["learn_time"] 402 | 403 | ep_string = "" 404 | if locs["ep_infos"]: 405 | for key in locs["ep_infos"][0]: 406 | infotensor = torch.tensor([], device=self.device) 407 | for ep_info in locs["ep_infos"]: 408 | # handle scalar and zero dimensional tensor infos 409 | if key not in ep_info: 410 | continue 411 | if not isinstance(ep_info[key], torch.Tensor): 412 | ep_info[key] = torch.Tensor([ep_info[key]]) 413 | if len(ep_info[key].shape) == 0: 414 | ep_info[key] = ep_info[key].unsqueeze(0) 415 | infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) 416 | value = torch.mean(infotensor) 417 | # log to logger and terminal 418 | if "/" in key: 419 | self.writer.add_scalar(key, value, locs["it"]) 420 | ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" 421 | else: 422 | self.writer.add_scalar("Episode/" + key, value, locs["it"]) 423 | ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" 424 | if getattr(self.alg.actor_critic, "noise_std_type", "scalar") == "log": 425 | mean_std_value = torch.exp(self.alg.actor_critic.log_std).mean() 426 | else: 427 | mean_std_value = self.alg.actor_critic.std.mean() 428 | fps = int( 429 | self.num_steps_per_env 430 | * self.env.num_envs 431 | / (locs["collection_time"] + locs["learn_time"]) 432 | ) 433 | 434 | self.writer.add_scalar( 435 | "Loss/value_function", locs["mean_value_loss"], locs["it"] 436 | ) 437 | self.writer.add_scalar( 438 | "Loss/surrogate", locs["mean_surrogate_loss"], locs["it"] 439 | ) 440 | 441 | # Adding logging due to AMP 442 | self.writer.add_scalar("Loss/amp_loss", locs["mean_amp_loss"], 
locs["it"]) 443 | self.writer.add_scalar( 444 | "Loss/grad_pen_loss", locs["mean_grad_pen_loss"], locs["it"] 445 | ) 446 | self.writer.add_scalar("Loss/policy_pred", locs["mean_policy_pred"], locs["it"]) 447 | self.writer.add_scalar("Loss/expert_pred", locs["mean_expert_pred"], locs["it"]) 448 | self.writer.add_scalar( 449 | "Loss/accuracy_policy", locs["mean_accuracy_policy"], locs["it"] 450 | ) 451 | self.writer.add_scalar( 452 | "Loss/accuracy_expert", locs["mean_accuracy_expert"], locs["it"] 453 | ) 454 | 455 | self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"]) 456 | self.writer.add_scalar( 457 | "Loss/mean_kl_divergence", locs["mean_kl_divergence"], locs["it"] 458 | ) 459 | self.writer.add_scalar( 460 | "Policy/mean_noise_std", mean_std_value.item(), locs["it"] 461 | ) 462 | self.writer.add_scalar("Perf/total_fps", fps, locs["it"]) 463 | self.writer.add_scalar( 464 | "Perf/collection time", locs["collection_time"], locs["it"] 465 | ) 466 | self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"]) 467 | if len(locs["rewbuffer"]) > 0: 468 | self.writer.add_scalar( 469 | "Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"] 470 | ) 471 | self.writer.add_scalar( 472 | "Train/mean_episode_length", 473 | statistics.mean(locs["lenbuffer"]), 474 | locs["it"], 475 | ) 476 | self.writer.add_scalar( 477 | "Train/mean_style_reward", locs["mean_style_reward_log"], locs["it"] 478 | ) 479 | self.writer.add_scalar( 480 | "Train/mean_task_reward", locs["mean_task_reward_log"], locs["it"] 481 | ) 482 | if ( 483 | self.logger_type != "wandb" 484 | ): # wandb does not support non-integer x-axis logging 485 | self.writer.add_scalar( 486 | "Train/mean_reward/time", 487 | statistics.mean(locs["rewbuffer"]), 488 | self.tot_time, 489 | ) 490 | self.writer.add_scalar( 491 | "Train/mean_episode_length/time", 492 | statistics.mean(locs["lenbuffer"]), 493 | self.tot_time, 494 | ) 495 | 496 | str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " 497 | 498 | if len(locs["rewbuffer"]) > 0: 499 | log_string = ( 500 | f"""{'#' * width}\n""" 501 | f"""{str.center(width, ' ')}\n\n""" 502 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 503 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 504 | f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" 505 | f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" 506 | f"""{'Mean action noise std:':>{pad}} {mean_std_value.item():.2f}\n""" 507 | f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" 508 | f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" 509 | ) 510 | # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" 511 | # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") 512 | else: 513 | log_string = ( 514 | f"""{'#' * width}\n""" 515 | f"""{str.center(width, ' ')}\n\n""" 516 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 517 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 518 | f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" 519 | f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" 520 | f"""{'Mean action noise std:':>{pad}} {mean_std_value.item():.2f}\n""" 521 | ) 522 | # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" 523 | # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") 524 | 
525 | log_string += ep_string 526 | 527 | # make the eta in H:M:S 528 | eta_seconds = ( 529 | self.tot_time 530 | / (locs["it"] + 1) 531 | * (locs["num_learning_iterations"] - locs["it"]) 532 | ) 533 | 534 | # Convert seconds to H:M:S 535 | eta_h, rem = divmod(eta_seconds, 3600) 536 | eta_m, eta_s = divmod(rem, 60) 537 | 538 | log_string += ( 539 | f"""{'-' * width}\n""" 540 | f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" 541 | f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" 542 | f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n""" 543 | f"""{'ETA:':>{pad}} {int(eta_h)}h {int(eta_m)}m {int(eta_s)}s\n""" 544 | ) 545 | print(log_string) 546 | 547 | def save(self, path, infos=None, save_onnx=False): 548 | saved_dict = { 549 | "model_state_dict": self.alg.actor_critic.state_dict(), 550 | "optimizer_state_dict": self.alg.optimizer.state_dict(), 551 | "discriminator_state_dict": self.alg.discriminator.state_dict(), 552 | "iter": self.current_learning_iteration, 553 | "infos": infos, 554 | } 555 | torch.save(saved_dict, path) 556 | 557 | # Upload model to external logging service 558 | if self.logger_type in ["neptune", "wandb"]: 559 | self.writer.save_model(path, self.current_learning_iteration) 560 | 561 | if save_onnx: 562 | # Save the model in ONNX format 563 | # extract the folder path 564 | onnx_folder = os.path.dirname(path) 565 | 566 | # extract the iteration number from the path. The path is expected to be in the format 567 | # model_{iteration}.pt 568 | iteration = int(os.path.basename(path).split("_")[1].split(".")[0]) 569 | onnx_model_name = f"policy_{iteration}.onnx" 570 | 571 | export_policy_as_onnx( 572 | self.alg.actor_critic, 573 | normalizer=self.alg.actor_critic.actor_obs_normalizer, 574 | path=onnx_folder, 575 | filename=onnx_model_name, 576 | ) 577 | 578 | if self.logger_type in ["neptune", "wandb"]: 579 | self.writer.save_model( 580 | os.path.join(onnx_folder, onnx_model_name), 581 | self.current_learning_iteration, 582 | ) 583 | 584 | def load(self, path, load_optimizer=True, weights_only=False): 585 | loaded_dict = torch.load( 586 | path, map_location=self.device, weights_only=weights_only 587 | ) 588 | self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"]) 589 | discriminator_state = loaded_dict["discriminator_state_dict"] 590 | self.alg.discriminator.load_state_dict(discriminator_state, strict=False) 591 | 592 | amp_normalizer_module = loaded_dict.get("amp_normalizer") 593 | if amp_normalizer_module is not None and getattr( 594 | self.alg.discriminator, "empirical_normalization", False 595 | ): 596 | # Old checkpoints stored the empirical normalizer separately; hydrate it if present. 
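            # (load_state_dict above uses strict=False, so checkpoints whose discriminator
            # state dict does not exactly match the current module still load.)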
597 | self.alg.discriminator.amp_normalizer.load_state_dict( 598 | amp_normalizer_module.state_dict() 599 | ) 600 | if load_optimizer: 601 | self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) 602 | self.current_learning_iteration = loaded_dict["iter"] 603 | return loaded_dict["infos"] 604 | 605 | def get_inference_policy(self, device=None): 606 | self.eval_mode() # switch to evaluation mode (dropout for example) 607 | if device is not None: 608 | self.alg.actor_critic.to(device) 609 | return self.alg.actor_critic.act_inference 610 | 611 | def train_mode(self): 612 | self.alg.actor_critic.train() 613 | self.alg.discriminator.train() 614 | 615 | def eval_mode(self): 616 | self.alg.actor_critic.eval() 617 | self.alg.discriminator.eval() 618 | 619 | def add_git_repo_to_log(self, repo_file_path): 620 | self.git_status_repos.append(repo_file_path) 621 | --------------------------------------------------------------------------------
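
A minimal configuration sketch for ``AMPOnPolicyRunner`` (illustrative only; the keys mirror the
runner docstring above, while the values, paths, and dataset names are placeholders, and ``env``
is assumed to be an AMP-compatible, manager-based IsaacLab VecEnv exposing an "amp" observation
group):

    train_cfg = {
        "policy": {"class_name": "ActorCritic"},  # extra keys are forwarded to the policy class
        "algorithm": {"class_name": "AMP_PPO", "num_learning_epochs": 5, "num_mini_batches": 4},
        "discriminator": {
            "hidden_dims": [1024, 512],
            "reward_scale": 1.0,
            "loss_type": "BCEWithLogits",
            "empirical_normalization": True,
        },
        "dataset": {
            "amp_data_path": "/path/to/expert_motions",
            "datasets": {"walk": 1.0},
            "slow_down_factor": 1,
        },
        "num_steps_per_env": 24,
        "save_interval": 100,
        "logger": "tensorboard",
    }

    runner = AMPOnPolicyRunner(env, train_cfg, log_dir="logs/amp", device="cuda:0")
    runner.learn(num_learning_iterations=1000)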