├── .flake8
├── .github
│   └── LICENSE_HEADER.txt
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTORS.md
├── LICENSE
├── README.md
├── config
│   └── dummy_config.yaml
├── licenses
│   └── dependencies
│       ├── black-license.txt
│       ├── codespell-license.txt
│       ├── flake8-license.txt
│       ├── isort-license.txt
│       ├── numpy_license.txt
│       ├── onnx-license.txt
│       ├── pre-commit-hooks-license.txt
│       ├── pre-commit-license.txt
│       ├── pyright-license.txt
│       ├── pyupgrade-license.txt
│       └── torch_license.txt
├── pyproject.toml
├── rsl_rl
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── distillation.py
│   │   └── ppo.py
│   ├── env
│   │   ├── __init__.py
│   │   └── vec_env.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── actor_critic.py
│   │   ├── actor_critic_recurrent.py
│   │   ├── normalizer.py
│   │   ├── rnd.py
│   │   ├── student_teacher.py
│   │   └── student_teacher_recurrent.py
│   ├── networks
│   │   ├── __init__.py
│   │   └── memory.py
│   ├── runners
│   │   ├── __init__.py
│   │   └── on_policy_runner.py
│   ├── storage
│   │   ├── __init__.py
│   │   └── rollout_storage.py
│   └── utils
│       ├── __init__.py
│       ├── neptune_utils.py
│       ├── utils.py
│       └── wandb_utils.py
└── setup.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | show-source=True
3 | statistics=True
4 | per-file-ignores=*/__init__.py:F401
5 | # E402: Module level import not at top of file
6 | # E501: Line too long
7 | # W503: Line break before binary operator
8 | # E203: Whitespace before ':' -> conflicts with black
9 | # D401: First line should be in imperative mood
10 | # R504: Unnecessary variable assignment before return statement.
11 | # R505: Unnecessary elif after return statement
12 | # SIM102: Use a single if-statement instead of nested if-statements
13 | # SIM117: Merge with statements for context managers that have same scope.
14 | ignore=E402,E501,W503,E203,D401,R504,R505,SIM102,SIM117
15 | max-line-length = 120
16 | max-complexity = 18
17 | exclude=_*,.vscode,.git,docs/**
18 | # docstrings
19 | docstring-convention=google
20 | # annotations
21 | suppress-none-returning=True
22 | allow-star-arg-any=True
23 |
--------------------------------------------------------------------------------
/.github/LICENSE_HEADER.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | All rights reserved.
3 |
4 | SPDX-License-Identifier: BSD-3-Clause
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDEs
2 | .idea
3 |
4 | # builds
5 | *.egg-info
6 | build/*
7 | dist/*
8 |
9 | # cache
10 | __pycache__
11 | .pytest_cache
12 |
13 | # vs code
14 | .vscode
15 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/python/black
3 |     rev: 23.10.1
4 |     hooks:
5 |       - id: black
6 |         args: ["--line-length", "120", "--preview"]
7 |   - repo: https://github.com/pycqa/flake8
8 |     rev: 6.1.0
9 |     hooks:
10 |       - id: flake8
11 |         additional_dependencies: [flake8-simplify, flake8-return]
12 |   - repo: https://github.com/pre-commit/pre-commit-hooks
13 |     rev: v4.5.0
14 |     hooks:
15 |       - id: trailing-whitespace
16 |       - id: check-symlinks
17 |       - id: destroyed-symlinks
18 |       - id: check-yaml
19 |       - id: check-merge-conflict
20 |       - id: check-case-conflict
21 |       - id: check-executables-have-shebangs
22 |       - id: check-toml
23 |       - id: end-of-file-fixer
24 |       - id: check-shebang-scripts-are-executable
25 |       - id: detect-private-key
26 |       - id: debug-statements
27 |   - repo: https://github.com/pycqa/isort
28 |     rev: 5.12.0
29 |     hooks:
30 |       - id: isort
31 |         name: isort (python)
32 |         args: ["--profile", "black", "--filter-files"]
33 |   - repo: https://github.com/asottile/pyupgrade
34 |     rev: v3.15.0
35 |     hooks:
36 |       - id: pyupgrade
37 |         args: ["--py37-plus"]
38 |   - repo: https://github.com/codespell-project/codespell
39 |     rev: v2.2.6
40 |     hooks:
41 |       - id: codespell
42 |         additional_dependencies:
43 |           - tomli
44 |   - repo: https://github.com/Lucas-C/pre-commit-hooks
45 |     rev: v1.5.1
46 |     hooks:
47 |       - id: insert-license
48 |         files: \.py$
49 |         args:
50 |           # - --remove-header # Remove existing license headers. Useful when updating license.
51 |           - --license-filepath
52 |           - .github/LICENSE_HEADER.txt
53 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # RSL-RL Maintainers and Contributors
2 |
3 | This is the official list of developers and contributors.
4 |
5 | To see the full list of contributors, see the revision history in the source control.
6 |
7 | Names should be added to this file as: individual names or organizations.
8 |
9 | Email addresses are tracked elsewhere to avoid spam.
10 |
11 | Please keep the lists sorted alphabetically.
12 |
13 | ## Maintainers
14 |
15 | * Robotic Systems Lab, ETH Zurich
16 | * NVIDIA Corporation
17 |
18 | ---
19 |
20 | * Mayank Mittal
21 | * Clemens Schwarke
22 |
23 | ## Authors
24 |
25 | * David Hoeller
26 | * Nikita Rudin
27 |
28 | ## Contributors
29 |
30 | * Bikram Pandit
31 | * Eric Vollenweider
32 | * Fabian Jenelten
33 | * Lorenzo Terenzi
34 | * Marko Bjelonic
35 | * Matthijs van der Boon
36 | * Özhan Özen
37 | * Pascal Roth
38 | * Zhang Chong
39 | * Ziqi Fan
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025, ETH Zurich
2 | Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
3 | All rights reserved.
4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors 16 | may be used to endorse or promote products derived from this software without 17 | specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | See licenses/dependencies for license information of dependencies of this package. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RSL RL 2 | 3 | A fast and simple implementation of RL algorithms, designed to run fully on GPU. 4 | This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac Gym. 5 | 6 | Environment repositories using the framework: 7 | 8 | * **`Isaac Lab`** (built on top of NVIDIA Isaac Sim): https://github.com/isaac-sim/IsaacLab 9 | * **`Legged-Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/ 10 | 11 | The main branch supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include: 12 | 13 | * [Random Network Distillation (RND)](https://proceedings.mlr.press/v229/schwarke23a.html) - Encourages exploration by adding 14 | a curiosity driven intrinsic reward. 15 | * [Symmetry-based Augmentation](https://arxiv.org/abs/2403.04359) - Makes the learned behaviors more symmetrical. 16 | 17 | We welcome contributions from the community. Please check our contribution guidelines for more 18 | information. 19 | 20 | **Maintainer**: Mayank Mittal and Clemens Schwarke
21 | **Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
22 | **Contact**: cschwarke@ethz.ch 23 | 24 | > **Note:** The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more). However, it isn't currently actively maintained. 25 | 26 | 27 | ## Setup 28 | 29 | The package can be installed via PyPI with: 30 | 31 | ```bash 32 | pip install rsl-rl-lib 33 | ``` 34 | 35 | or by cloning this repository and installing it with: 36 | 37 | ```bash 38 | git clone https://github.com/leggedrobotics/rsl_rl 39 | cd rsl_rl 40 | pip install -e . 41 | ``` 42 | 43 | The package supports the following logging frameworks which can be configured through `logger`: 44 | 45 | * Tensorboard: https://www.tensorflow.org/tensorboard/ 46 | * Weights & Biases: https://wandb.ai/site 47 | * Neptune: https://docs.neptune.ai/ 48 | 49 | For a demo configuration of PPO, please check the [dummy_config.yaml](config/dummy_config.yaml) file. 50 | 51 | 52 | ## Contribution Guidelines 53 | 54 | For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. Please make sure that your code is well-documented and follows the guidelines. 55 | 56 | We use the following tools for maintaining code quality: 57 | 58 | - [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase. 59 | - [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter. 60 | - [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker. 61 | 62 | Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal: 63 | 64 | ```bash 65 | # for installation (only once) 66 | pre-commit install 67 | # for running 68 | pre-commit run --all-files 69 | ``` 70 | 71 | ## Citing 72 | 73 | **We are working on writing a white paper for this library.** Until then, please cite the following work 74 | if you use this library for your research: 75 | 76 | ```text 77 | @InProceedings{rudin2022learning, 78 | title = {Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning}, 79 | author = {Rudin, Nikita and Hoeller, David and Reist, Philipp and Hutter, Marco}, 80 | booktitle = {Proceedings of the 5th Conference on Robot Learning}, 81 | pages = {91--100}, 82 | year = {2022}, 83 | volume = {164}, 84 | series = {Proceedings of Machine Learning Research}, 85 | publisher = {PMLR}, 86 | url = {https://proceedings.mlr.press/v164/rudin22a.html}, 87 | } 88 | ``` 89 | 90 | If you use the library with curiosity-driven exploration (random network distillation), please cite: 91 | 92 | ```text 93 | @InProceedings{schwarke2023curiosity, 94 | title = {Curiosity-Driven Learning of Joint Locomotion and Manipulation Tasks}, 95 | author = {Schwarke, Clemens and Klemm, Victor and Boon, Matthijs van der and Bjelonic, Marko and Hutter, Marco}, 96 | booktitle = {Proceedings of The 7th Conference on Robot Learning}, 97 | pages = {2594--2610}, 98 | year = {2023}, 99 | volume = {229}, 100 | series = {Proceedings of Machine Learning Research}, 101 | publisher = {PMLR}, 102 | url = {https://proceedings.mlr.press/v229/schwarke23a.html}, 103 | } 104 | ``` 105 | 106 | If you use the library with symmetry augmentation, please cite: 107 | 108 | ```text 109 | @InProceedings{mittal2024symmetry, 110 | author={Mittal, Mayank and Rudin, Nikita and Klemm, Victor and Allshire, Arthur and Hutter, Marco}, 111 | booktitle={2024 IEEE 
International Conference on Robotics and Automation (ICRA)}, 112 | title={Symmetry Considerations for Learning Task Symmetric Robot Policies}, 113 | year={2024}, 114 | pages={7433-7439}, 115 | doi={10.1109/ICRA57147.2024.10611493} 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /config/dummy_config.yaml: -------------------------------------------------------------------------------- 1 | algorithm: 2 | class_name: PPO 3 | # training parameters 4 | # -- advantage normalization 5 | normalize_advantage_per_mini_batch: false 6 | # -- value function 7 | value_loss_coef: 1.0 8 | clip_param: 0.2 9 | use_clipped_value_loss: true 10 | # -- surrogate loss 11 | desired_kl: 0.01 12 | entropy_coef: 0.01 13 | gamma: 0.99 14 | lam: 0.95 15 | max_grad_norm: 1.0 16 | # -- training 17 | learning_rate: 0.001 18 | num_learning_epochs: 5 19 | num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches 20 | schedule: adaptive # adaptive, fixed 21 | 22 | # -- Random Network Distillation 23 | rnd_cfg: 24 | weight: 0.0 # initial weight of the RND reward 25 | 26 | # note: This is a dictionary with a required key called "mode". 27 | # Please check the RND module for more information. 28 | weight_schedule: null 29 | 30 | reward_normalization: false # whether to normalize RND reward 31 | state_normalization: true # whether to normalize RND state observations 32 | 33 | # -- Learning parameters 34 | learning_rate: 0.001 # learning rate for RND 35 | 36 | # -- Network parameters 37 | # note: if -1, then the network will use dimensions of the observation 38 | num_outputs: 1 # number of outputs of RND network 39 | predictor_hidden_dims: [-1] # hidden dimensions of predictor network 40 | target_hidden_dims: [-1] # hidden dimensions of target network 41 | 42 | # -- Symmetry Augmentation 43 | symmetry_cfg: 44 | use_data_augmentation: true # this adds symmetric trajectories to the batch 45 | use_mirror_loss: false # this adds symmetry loss term to the loss function 46 | 47 | # string containing the module and function name to import. 48 | # Example: "legged_gym.envs.locomotion.anymal_c.symmetry:get_symmetric_states" 49 | # 50 | # .. code-block:: python 51 | # 52 | # @torch.no_grad() 53 | # def get_symmetric_states( 54 | # obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy" 55 | # ) -> Tuple[torch.Tensor, torch.Tensor]: 56 | # 57 | data_augmentation_func: null 58 | 59 | # coefficient for symmetry loss term 60 | # if 0, then no symmetry loss is used 61 | mirror_loss_coeff: 0.0 62 | 63 | policy: 64 | class_name: ActorCritic 65 | # for MLP i.e. 
`ActorCritic` 66 | activation: elu 67 | actor_hidden_dims: [128, 128, 128] 68 | critic_hidden_dims: [128, 128, 128] 69 | init_noise_std: 1.0 70 | noise_std_type: "scalar" # 'scalar' or 'log' 71 | 72 | # only needed for `ActorCriticRecurrent` 73 | # rnn_type: 'lstm' 74 | # rnn_hidden_dim: 512 75 | # rnn_num_layers: 1 76 | 77 | runner: 78 | num_steps_per_env: 24 # number of steps per environment per iteration 79 | max_iterations: 1500 # number of policy updates 80 | empirical_normalization: false 81 | # -- logging parameters 82 | save_interval: 50 # check for potential saves every `save_interval` iterations 83 | experiment_name: walking_experiment 84 | run_name: "" 85 | # -- logging writer 86 | logger: tensorboard # tensorboard, neptune, wandb 87 | neptune_project: legged_gym 88 | wandb_project: legged_gym 89 | # -- load and resuming 90 | load_run: -1 # -1 means load latest run 91 | resume_path: null # updated from load_run and checkpoint 92 | checkpoint: -1 # -1 means load latest checkpoint 93 | 94 | runner_class_name: OnPolicyRunner 95 | seed: 1 96 | -------------------------------------------------------------------------------- /licenses/dependencies/black-license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Łukasz Langa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/dependencies/codespell-license.txt: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) 
You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. 
You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. 
You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. 
You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. 
If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 
311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /licenses/dependencies/flake8-license.txt: -------------------------------------------------------------------------------- 1 | == Flake8 License (MIT) == 2 | 3 | Copyright (C) 2011-2013 Tarek Ziade 4 | Copyright (C) 2012-2016 Ian Cordasco 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | this software and associated documentation files (the "Software"), to deal in 8 | the Software without restriction, including without limitation the rights to 9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | of the Software, and to permit persons to whom the Software is furnished to do 11 | so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /licenses/dependencies/isort-license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 Timothy Edmund Crosley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/dependencies/numpy_license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2005-2021, NumPy Developers. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of the NumPy Developers nor the names of any 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /licenses/dependencies/onnx-license.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /licenses/dependencies/pre-commit-hooks-license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /licenses/dependencies/pre-commit-license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /licenses/dependencies/pyright-license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Robert Craigie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | =============================================================================== 25 | 26 | MIT License 27 | 28 | Pyright - A static type checker for the Python language 29 | Copyright (c) Microsoft Corporation. All rights reserved. 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in all 39 | copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE 48 | -------------------------------------------------------------------------------- /licenses/dependencies/pyupgrade-license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Anthony Sottile 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /licenses/dependencies/torch_license.txt: -------------------------------------------------------------------------------- 1 | From PyTorch: 2 | 3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 8 | Copyright (c) 2011-2013 NYU (Clement Farabet) 9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 12 | 13 | From Caffe2: 14 | 15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 16 | 17 | All contributions by Facebook: 18 | Copyright (c) 2016 Facebook Inc. 19 | 20 | All contributions by Google: 21 | Copyright (c) 2015 Google Inc. 22 | All rights reserved. 23 | 24 | All contributions by Yangqing Jia: 25 | Copyright (c) 2015 Yangqing Jia 26 | All rights reserved. 27 | 28 | All contributions by Kakao Brain: 29 | Copyright 2019-2020 Kakao Brain 30 | 31 | All contributions from Caffe: 32 | Copyright(c) 2013, 2014, 2015, the respective contributors 33 | All rights reserved. 34 | 35 | All other contributions: 36 | Copyright(c) 2015, 2016 the respective contributors 37 | All rights reserved. 38 | 39 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 40 | copyright over their contributions to Caffe2. The project versioning records 41 | all such contribution and copyright details. 
If a contributor wants to further 42 | mark their specific copyright on a particular contribution, they should 43 | indicate their copyright solely in the commit message of the change when it is 44 | committed. 45 | 46 | All rights reserved. 47 | 48 | Redistribution and use in source and binary forms, with or without 49 | modification, are permitted provided that the following conditions are met: 50 | 51 | 1. Redistributions of source code must retain the above copyright 52 | notice, this list of conditions and the following disclaimer. 53 | 54 | 2. Redistributions in binary form must reproduce the above copyright 55 | notice, this list of conditions and the following disclaimer in the 56 | documentation and/or other materials provided with the distribution. 57 | 58 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 59 | and IDIAP Research Institute nor the names of its contributors may be 60 | used to endorse or promote products derived from this software without 61 | specific prior written permission. 62 | 63 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 64 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 65 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 66 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 67 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 68 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 69 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 70 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 71 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 72 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 73 | POSSIBILITY OF SUCH DAMAGE. 
74 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rsl-rl-lib" 7 | version = "2.3.3" 8 | keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"] 9 | maintainers = [ 10 | { name="Clemens Schwarke", email="cschwarke@ethz.ch" }, 11 | { name="Mayank Mittal", email="mittalma@ethz.ch" }, 12 | ] 13 | authors = [ 14 | { name="Clemens Schwarke", email="cschwarke@ethz.ch" }, 15 | { name="Mayank Mittal", email="mittalma@ethz.ch" }, 16 | { name="Nikita Rudin", email="rudinn@ethz.ch" }, 17 | { name="David Hoeller", email="holler.david78@gmail.com" }, 18 | ] 19 | description = "Fast and simple RL algorithms implemented in PyTorch" 20 | readme = { file = "README.md", content-type = "text/markdown"} 21 | license = { text = "BSD-3-Clause" } 22 | 23 | requires-python = ">=3.8" 24 | classifiers = [ 25 | "Programming Language :: Python :: 3", 26 | "Operating System :: OS Independent", 27 | ] 28 | dependencies = [ 29 | "torch>=1.10.0", 30 | "torchvision>=0.5.0", 31 | "numpy>=1.16.4", 32 | "GitPython", 33 | "onnx", 34 | ] 35 | 36 | [project.urls] 37 | Homepage = "https://github.com/leggedrobotics/rsl_rl" 38 | Issues = "https://github.com/leggedrobotics/rsl_rl/issues" 39 | 40 | [tool.setuptools.packages.find] 41 | where = ["."] 42 | include = ["rsl_rl*"] 43 | 44 | [tool.setuptools.package-data] 45 | "rsl_rl" = ["config/*", "licenses/*"] 46 | 47 | [tool.isort] 48 | 49 | py_version = 37 50 | line_length = 120 51 | group_by_package = true 52 | 53 | # Files to skip 54 | skip_glob = [".vscode/*"] 55 | 56 | # Order of imports 57 | sections = [ 58 | "FUTURE", 59 | "STDLIB", 60 | "THIRDPARTY", 61 | "FIRSTPARTY", 62 | "LOCALFOLDER", 63 | ] 64 | 65 | # Extra standard libraries considered as part of python (permissive licenses) 66 | extra_standard_library = [ 67 | "numpy", 68 | "torch", 69 | "tensordict", 70 | "warp", 71 | "typing_extensions", 72 | "git", 73 | ] 74 | # Imports from this repository 75 | known_first_party = "rsl_rl" 76 | 77 | [tool.pyright] 78 | 79 | include = ["rsl_rl"] 80 | 81 | typeCheckingMode = "basic" 82 | pythonVersion = "3.7" 83 | pythonPlatform = "Linux" 84 | enableTypeIgnoreComments = true 85 | 86 | # This is required as the CI pre-commit does not download the module (i.e. numpy, torch, prettytable) 87 | # Therefore, we have to ignore missing imports 88 | reportMissingImports = "none" 89 | # This is required to ignore for type checks of modules with stubs missing. 90 | reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers 91 | 92 | reportGeneralTypeIssues = "none" # -> raises 218 errors (usage of literal MISSING in dataclasses) 93 | reportOptionalMemberAccess = "warning" # -> raises 8 errors 94 | reportPrivateUsage = "warning" 95 | -------------------------------------------------------------------------------- /rsl_rl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
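As an aside to the packaging metadata above: the distribution is published as rsl-rl-lib while the importable package is rsl_rl, so a downstream project depends on the former and imports the latter. A minimal sketch of the typical entry points; the OnPolicyRunner import is inferred from the repository layout shown at the top and is an assumption, not confirmed by the listing.

# pip install rsl-rl-lib
from rsl_rl.algorithms import PPO, Distillation
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.runners import OnPolicyRunner  # assumed export of rsl_rl/runners/__init__.py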
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Main module for the rsl_rl package.""" 7 | -------------------------------------------------------------------------------- /rsl_rl/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Implementation of different RL agents.""" 7 | 8 | from .distillation import Distillation 9 | from .ppo import PPO 10 | 11 | __all__ = ["PPO", "Distillation"] 12 | -------------------------------------------------------------------------------- /rsl_rl/algorithms/distillation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | # torch 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | 11 | # rsl-rl 12 | from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent 13 | from rsl_rl.storage import RolloutStorage 14 | 15 | 16 | class Distillation: 17 | """Distillation algorithm for training a student model to mimic a teacher model.""" 18 | 19 | policy: StudentTeacher | StudentTeacherRecurrent 20 | """The student teacher model.""" 21 | 22 | def __init__( 23 | self, 24 | policy, 25 | num_learning_epochs=1, 26 | gradient_length=15, 27 | learning_rate=1e-3, 28 | max_grad_norm=None, 29 | loss_type="mse", 30 | device="cpu", 31 | # Distributed training parameters 32 | multi_gpu_cfg: dict | None = None, 33 | ): 34 | # device-related parameters 35 | self.device = device 36 | self.is_multi_gpu = multi_gpu_cfg is not None 37 | # Multi-GPU parameters 38 | if multi_gpu_cfg is not None: 39 | self.gpu_global_rank = multi_gpu_cfg["global_rank"] 40 | self.gpu_world_size = multi_gpu_cfg["world_size"] 41 | else: 42 | self.gpu_global_rank = 0 43 | self.gpu_world_size = 1 44 | 45 | self.rnd = None # TODO: remove when runner has a proper base class 46 | 47 | # distillation components 48 | self.policy = policy 49 | self.policy.to(self.device) 50 | self.storage = None # initialized later 51 | self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate) 52 | self.transition = RolloutStorage.Transition() 53 | self.last_hidden_states = None 54 | 55 | # distillation parameters 56 | self.num_learning_epochs = num_learning_epochs 57 | self.gradient_length = gradient_length 58 | self.learning_rate = learning_rate 59 | self.max_grad_norm = max_grad_norm 60 | 61 | # initialize the loss function 62 | if loss_type == "mse": 63 | self.loss_fn = nn.functional.mse_loss 64 | elif loss_type == "huber": 65 | self.loss_fn = nn.functional.huber_loss 66 | else: 67 | raise ValueError(f"Unknown loss type: {loss_type}. 
Supported types are: mse, huber") 68 | 69 | self.num_updates = 0 70 | 71 | def init_storage( 72 | self, training_type, num_envs, num_transitions_per_env, student_obs_shape, teacher_obs_shape, actions_shape 73 | ): 74 | # create rollout storage 75 | self.storage = RolloutStorage( 76 | training_type, 77 | num_envs, 78 | num_transitions_per_env, 79 | student_obs_shape, 80 | teacher_obs_shape, 81 | actions_shape, 82 | None, 83 | self.device, 84 | ) 85 | 86 | def act(self, obs, teacher_obs): 87 | # compute the actions 88 | self.transition.actions = self.policy.act(obs).detach() 89 | self.transition.privileged_actions = self.policy.evaluate(teacher_obs).detach() 90 | # record the observations 91 | self.transition.observations = obs 92 | self.transition.privileged_observations = teacher_obs 93 | return self.transition.actions 94 | 95 | def process_env_step(self, rewards, dones, infos): 96 | # record the rewards and dones 97 | self.transition.rewards = rewards 98 | self.transition.dones = dones 99 | # record the transition 100 | self.storage.add_transitions(self.transition) 101 | self.transition.clear() 102 | self.policy.reset(dones) 103 | 104 | def update(self): 105 | self.num_updates += 1 106 | mean_behavior_loss = 0 107 | loss = 0 108 | cnt = 0 109 | 110 | for epoch in range(self.num_learning_epochs): 111 | self.policy.reset(hidden_states=self.last_hidden_states) 112 | self.policy.detach_hidden_states() 113 | for obs, _, _, privileged_actions, dones in self.storage.generator(): 114 | 115 | # inference the student for gradient computation 116 | actions = self.policy.act_inference(obs) 117 | 118 | # behavior cloning loss 119 | behavior_loss = self.loss_fn(actions, privileged_actions) 120 | 121 | # total loss 122 | loss = loss + behavior_loss 123 | mean_behavior_loss += behavior_loss.item() 124 | cnt += 1 125 | 126 | # gradient step 127 | if cnt % self.gradient_length == 0: 128 | self.optimizer.zero_grad() 129 | loss.backward() 130 | if self.is_multi_gpu: 131 | self.reduce_parameters() 132 | if self.max_grad_norm: 133 | nn.utils.clip_grad_norm_(self.policy.student.parameters(), self.max_grad_norm) 134 | self.optimizer.step() 135 | self.policy.detach_hidden_states() 136 | loss = 0 137 | 138 | # reset dones 139 | self.policy.reset(dones.view(-1)) 140 | self.policy.detach_hidden_states(dones.view(-1)) 141 | 142 | mean_behavior_loss /= cnt 143 | self.storage.clear() 144 | self.last_hidden_states = self.policy.get_hidden_states() 145 | self.policy.detach_hidden_states() 146 | 147 | # construct the loss dictionary 148 | loss_dict = {"behavior": mean_behavior_loss} 149 | 150 | return loss_dict 151 | 152 | """ 153 | Helper functions 154 | """ 155 | 156 | def broadcast_parameters(self): 157 | """Broadcast model parameters to all GPUs.""" 158 | # obtain the model parameters on current GPU 159 | model_params = [self.policy.state_dict()] 160 | # broadcast the model parameters 161 | torch.distributed.broadcast_object_list(model_params, src=0) 162 | # load the model parameters on all GPUs from source GPU 163 | self.policy.load_state_dict(model_params[0]) 164 | 165 | def reduce_parameters(self): 166 | """Collect gradients from all GPUs and average them. 167 | 168 | This function is called after the backward pass to synchronize the gradients across all GPUs. 
169 | """ 170 | # Create a tensor to store the gradients 171 | grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None] 172 | all_grads = torch.cat(grads) 173 | # Average the gradients across all GPUs 174 | torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM) 175 | all_grads /= self.gpu_world_size 176 | # Update the gradients for all parameters with the reduced gradients 177 | offset = 0 178 | for param in self.policy.parameters(): 179 | if param.grad is not None: 180 | numel = param.numel() 181 | # copy data back from shared buffer 182 | param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data)) 183 | # update the offset for the next parameter 184 | offset += numel 185 | -------------------------------------------------------------------------------- /rsl_rl/algorithms/ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from itertools import chain 12 | 13 | from rsl_rl.modules import ActorCritic 14 | from rsl_rl.modules.rnd import RandomNetworkDistillation 15 | from rsl_rl.storage import RolloutStorage 16 | from rsl_rl.utils import string_to_callable 17 | 18 | 19 | class PPO: 20 | """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347).""" 21 | 22 | policy: ActorCritic 23 | """The actor critic module.""" 24 | 25 | def __init__( 26 | self, 27 | policy, 28 | num_learning_epochs=1, 29 | num_mini_batches=1, 30 | clip_param=0.2, 31 | gamma=0.998, 32 | lam=0.95, 33 | value_loss_coef=1.0, 34 | entropy_coef=0.0, 35 | learning_rate=1e-3, 36 | max_grad_norm=1.0, 37 | use_clipped_value_loss=True, 38 | schedule="fixed", 39 | desired_kl=0.01, 40 | device="cpu", 41 | normalize_advantage_per_mini_batch=False, 42 | # RND parameters 43 | rnd_cfg: dict | None = None, 44 | # Symmetry parameters 45 | symmetry_cfg: dict | None = None, 46 | # Distributed training parameters 47 | multi_gpu_cfg: dict | None = None, 48 | ): 49 | # device-related parameters 50 | self.device = device 51 | self.is_multi_gpu = multi_gpu_cfg is not None 52 | # Multi-GPU parameters 53 | if multi_gpu_cfg is not None: 54 | self.gpu_global_rank = multi_gpu_cfg["global_rank"] 55 | self.gpu_world_size = multi_gpu_cfg["world_size"] 56 | else: 57 | self.gpu_global_rank = 0 58 | self.gpu_world_size = 1 59 | 60 | # RND components 61 | if rnd_cfg is not None: 62 | # Extract learning rate and remove it from the original dict 63 | learning_rate = rnd_cfg.pop("learning_rate", 1e-3) 64 | # Create RND module 65 | self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg) 66 | # Create RND optimizer 67 | params = self.rnd.predictor.parameters() 68 | self.rnd_optimizer = optim.Adam(params, lr=learning_rate) 69 | else: 70 | self.rnd = None 71 | self.rnd_optimizer = None 72 | 73 | # Symmetry components 74 | if symmetry_cfg is not None: 75 | # Check if symmetry is enabled 76 | use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"] 77 | # Print that we are not using symmetry 78 | if not use_symmetry: 79 | print("Symmetry not used for learning. 
We will use it for logging instead.") 80 | # If function is a string then resolve it to a function 81 | if isinstance(symmetry_cfg["data_augmentation_func"], str): 82 | symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"]) 83 | # Check valid configuration 84 | if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]): 85 | raise ValueError( 86 | "Data augmentation enabled but the function is not callable:" 87 | f" {symmetry_cfg['data_augmentation_func']}" 88 | ) 89 | # Store symmetry configuration 90 | self.symmetry = symmetry_cfg 91 | else: 92 | self.symmetry = None 93 | 94 | # PPO components 95 | self.policy = policy 96 | self.policy.to(self.device) 97 | # Create optimizer 98 | self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate) 99 | # Create rollout storage 100 | self.storage: RolloutStorage = None # type: ignore 101 | self.transition = RolloutStorage.Transition() 102 | 103 | # PPO parameters 104 | self.clip_param = clip_param 105 | self.num_learning_epochs = num_learning_epochs 106 | self.num_mini_batches = num_mini_batches 107 | self.value_loss_coef = value_loss_coef 108 | self.entropy_coef = entropy_coef 109 | self.gamma = gamma 110 | self.lam = lam 111 | self.max_grad_norm = max_grad_norm 112 | self.use_clipped_value_loss = use_clipped_value_loss 113 | self.desired_kl = desired_kl 114 | self.schedule = schedule 115 | self.learning_rate = learning_rate 116 | self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch 117 | 118 | def init_storage( 119 | self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape 120 | ): 121 | # create memory for RND as well :) 122 | if self.rnd: 123 | rnd_state_shape = [self.rnd.num_states] 124 | else: 125 | rnd_state_shape = None 126 | # create rollout storage 127 | self.storage = RolloutStorage( 128 | training_type, 129 | num_envs, 130 | num_transitions_per_env, 131 | actor_obs_shape, 132 | critic_obs_shape, 133 | actions_shape, 134 | rnd_state_shape, 135 | self.device, 136 | ) 137 | 138 | def act(self, obs, critic_obs): 139 | if self.policy.is_recurrent: 140 | self.transition.hidden_states = self.policy.get_hidden_states() 141 | # compute the actions and values 142 | self.transition.actions = self.policy.act(obs).detach() 143 | self.transition.values = self.policy.evaluate(critic_obs).detach() 144 | self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach() 145 | self.transition.action_mean = self.policy.action_mean.detach() 146 | self.transition.action_sigma = self.policy.action_std.detach() 147 | # need to record obs and critic_obs before env.step() 148 | self.transition.observations = obs 149 | self.transition.privileged_observations = critic_obs 150 | return self.transition.actions 151 | 152 | def process_env_step(self, rewards, dones, infos): 153 | # Record the rewards and dones 154 | # Note: we clone here because later on we bootstrap the rewards based on timeouts 155 | self.transition.rewards = rewards.clone() 156 | self.transition.dones = dones 157 | 158 | # Compute the intrinsic rewards and add to extrinsic rewards 159 | if self.rnd: 160 | # Obtain curiosity gates / observations from infos 161 | rnd_state = infos["observations"]["rnd_state"] 162 | # Compute the intrinsic rewards 163 | # note: rnd_state is the gated_state after normalization if normalization is used 164 | self.intrinsic_rewards, rnd_state = 
self.rnd.get_intrinsic_reward(rnd_state) 165 | # Add intrinsic rewards to extrinsic rewards 166 | self.transition.rewards += self.intrinsic_rewards 167 | # Record the curiosity gates 168 | self.transition.rnd_state = rnd_state.clone() 169 | 170 | # Bootstrapping on time outs 171 | if "time_outs" in infos: 172 | self.transition.rewards += self.gamma * torch.squeeze( 173 | self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1 174 | ) 175 | 176 | # record the transition 177 | self.storage.add_transitions(self.transition) 178 | self.transition.clear() 179 | self.policy.reset(dones) 180 | 181 | def compute_returns(self, last_critic_obs): 182 | # compute value for the last step 183 | last_values = self.policy.evaluate(last_critic_obs).detach() 184 | self.storage.compute_returns( 185 | last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch 186 | ) 187 | 188 | def update(self): # noqa: C901 189 | mean_value_loss = 0 190 | mean_surrogate_loss = 0 191 | mean_entropy = 0 192 | # -- RND loss 193 | if self.rnd: 194 | mean_rnd_loss = 0 195 | else: 196 | mean_rnd_loss = None 197 | # -- Symmetry loss 198 | if self.symmetry: 199 | mean_symmetry_loss = 0 200 | else: 201 | mean_symmetry_loss = None 202 | 203 | # generator for mini batches 204 | if self.policy.is_recurrent: 205 | generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) 206 | else: 207 | generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) 208 | 209 | # iterate over batches 210 | for ( 211 | obs_batch, 212 | critic_obs_batch, 213 | actions_batch, 214 | target_values_batch, 215 | advantages_batch, 216 | returns_batch, 217 | old_actions_log_prob_batch, 218 | old_mu_batch, 219 | old_sigma_batch, 220 | hid_states_batch, 221 | masks_batch, 222 | rnd_state_batch, 223 | ) in generator: 224 | 225 | # number of augmentations per sample 226 | # we start with 1 and increase it if we use symmetry augmentation 227 | num_aug = 1 228 | # original batch size 229 | original_batch_size = obs_batch.shape[0] 230 | 231 | # check if we should normalize advantages per mini batch 232 | if self.normalize_advantage_per_mini_batch: 233 | with torch.no_grad(): 234 | advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8) 235 | 236 | # Perform symmetric augmentation 237 | if self.symmetry and self.symmetry["use_data_augmentation"]: 238 | # augmentation using symmetry 239 | data_augmentation_func = self.symmetry["data_augmentation_func"] 240 | # returned shape: [batch_size * num_aug, ...] 
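The call that follows hands the batch to a user-supplied data_augmentation_func resolved from the symmetry configuration. The toy function below is only a sketch of the expected contract and is not part of the library: it accepts obs/actions keyword arguments (either may be None) and returns tensors whose batch dimension is an integer multiple of the input batch, with the original samples first. The mirroring itself (negating every channel) is purely illustrative; a real implementation would permute and negate only the symmetric channels of the robot.

import torch

def mirror_augmentation(obs=None, actions=None, env=None, obs_type="policy"):
    # Hypothetical symmetry map: concatenate the original batch with a mirrored copy.
    # The original samples come first, as assumed by the symmetry loss later on.
    aug_obs = torch.cat([obs, -obs], dim=0) if obs is not None else None
    aug_actions = torch.cat([actions, -actions], dim=0) if actions is not None else None
    return aug_obs, aug_actions

obs = torch.randn(8, 4)
aug_obs, _ = mirror_augmentation(obs=obs)  # shape (16, 4), i.e. num_aug = 2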
241 | obs_batch, actions_batch = data_augmentation_func( 242 | obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], obs_type="policy" 243 | ) 244 | critic_obs_batch, _ = data_augmentation_func( 245 | obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], obs_type="critic" 246 | ) 247 | # compute number of augmentations per sample 248 | num_aug = int(obs_batch.shape[0] / original_batch_size) 249 | # repeat the rest of the batch 250 | # -- actor 251 | old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1) 252 | # -- critic 253 | target_values_batch = target_values_batch.repeat(num_aug, 1) 254 | advantages_batch = advantages_batch.repeat(num_aug, 1) 255 | returns_batch = returns_batch.repeat(num_aug, 1) 256 | 257 | # Recompute actions log prob and entropy for current batch of transitions 258 | # Note: we need to do this because we updated the policy with the new parameters 259 | # -- actor 260 | self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) 261 | actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch) 262 | # -- critic 263 | value_batch = self.policy.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]) 264 | # -- entropy 265 | # we only keep the entropy of the first augmentation (the original one) 266 | mu_batch = self.policy.action_mean[:original_batch_size] 267 | sigma_batch = self.policy.action_std[:original_batch_size] 268 | entropy_batch = self.policy.entropy[:original_batch_size] 269 | 270 | # KL 271 | if self.desired_kl is not None and self.schedule == "adaptive": 272 | with torch.inference_mode(): 273 | kl = torch.sum( 274 | torch.log(sigma_batch / old_sigma_batch + 1.0e-5) 275 | + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) 276 | / (2.0 * torch.square(sigma_batch)) 277 | - 0.5, 278 | axis=-1, 279 | ) 280 | kl_mean = torch.mean(kl) 281 | 282 | # Reduce the KL divergence across all GPUs 283 | if self.is_multi_gpu: 284 | torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM) 285 | kl_mean /= self.gpu_world_size 286 | 287 | # Update the learning rate 288 | # Perform this adaptation only on the main process 289 | # TODO: Is this needed? If KL-divergence is the "same" across all GPUs, 290 | # then the learning rate should be the same across all GPUs. 
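For reference, the expression computed above is the closed-form KL divergence between the previous and the updated diagonal-Gaussian action distributions (the 1.0e-5 inside the logarithm is only a numerical guard), with sigma and mu indexed per action dimension:

\[
\mathrm{KL}\big(\pi_{\text{old}} \,\|\, \pi_{\text{new}}\big)
= \sum_{i}\left[\ln\frac{\sigma_{\text{new},i}}{\sigma_{\text{old},i}}
+ \frac{\sigma_{\text{old},i}^{2} + (\mu_{\text{old},i}-\mu_{\text{new},i})^{2}}{2\,\sigma_{\text{new},i}^{2}}
- \frac{1}{2}\right]
\]

Its batch mean drives the adaptive schedule that follows: the learning rate is divided by 1.5 once the mean KL exceeds twice desired_kl and multiplied by 1.5 once it falls below half of it.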
291 | if self.gpu_global_rank == 0: 292 | if kl_mean > self.desired_kl * 2.0: 293 | self.learning_rate = max(1e-5, self.learning_rate / 1.5) 294 | elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: 295 | self.learning_rate = min(1e-2, self.learning_rate * 1.5) 296 | 297 | # Update the learning rate for all GPUs 298 | if self.is_multi_gpu: 299 | lr_tensor = torch.tensor(self.learning_rate, device=self.device) 300 | torch.distributed.broadcast(lr_tensor, src=0) 301 | self.learning_rate = lr_tensor.item() 302 | 303 | # Update the learning rate for all parameter groups 304 | for param_group in self.optimizer.param_groups: 305 | param_group["lr"] = self.learning_rate 306 | 307 | # Surrogate loss 308 | ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) 309 | surrogate = -torch.squeeze(advantages_batch) * ratio 310 | surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( 311 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param 312 | ) 313 | surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() 314 | 315 | # Value function loss 316 | if self.use_clipped_value_loss: 317 | value_clipped = target_values_batch + (value_batch - target_values_batch).clamp( 318 | -self.clip_param, self.clip_param 319 | ) 320 | value_losses = (value_batch - returns_batch).pow(2) 321 | value_losses_clipped = (value_clipped - returns_batch).pow(2) 322 | value_loss = torch.max(value_losses, value_losses_clipped).mean() 323 | else: 324 | value_loss = (returns_batch - value_batch).pow(2).mean() 325 | 326 | loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() 327 | 328 | # Symmetry loss 329 | if self.symmetry: 330 | # obtain the symmetric actions 331 | # if we did augmentation before then we don't need to augment again 332 | if not self.symmetry["use_data_augmentation"]: 333 | data_augmentation_func = self.symmetry["data_augmentation_func"] 334 | obs_batch, _ = data_augmentation_func( 335 | obs=obs_batch, actions=None, env=self.symmetry["_env"], obs_type="policy" 336 | ) 337 | # compute number of augmentations per sample 338 | num_aug = int(obs_batch.shape[0] / original_batch_size) 339 | 340 | # actions predicted by the actor for symmetrically-augmented observations 341 | mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone()) 342 | 343 | # compute the symmetrically augmented actions 344 | # note: we are assuming the first augmentation is the original one. 345 | # We do not use the action_batch from earlier since that action was sampled from the distribution. 346 | # However, the symmetry loss is computed using the mean of the distribution. 
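To make the step below explicit: writing T_o and T_a for the observation and action symmetry maps implemented by data_augmentation_func (notation introduced here for clarity) and sg[.] for the stop-gradient realized by .detach(), the mirror loss compares the policy mean on augmented observations against the symmetrically transformed original mean, evaluated only on the augmented copies:

\[
\mathcal{L}_{\text{sym}}
= \operatorname{MSE}\Big(\mu_{\theta}\big(T_{o}(s)\big),\;
\operatorname{sg}\big[T_{a}\big(\mu_{\theta}(s)\big)\big]\Big)
\]

It is added to the total loss only when use_mirror_loss is enabled; otherwise it is detached and merely logged.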
347 | action_mean_orig = mean_actions_batch[:original_batch_size] 348 | _, actions_mean_symm_batch = data_augmentation_func( 349 | obs=None, actions=action_mean_orig, env=self.symmetry["_env"], obs_type="policy" 350 | ) 351 | 352 | # compute the loss (we skip the first augmentation as it is the original one) 353 | mse_loss = torch.nn.MSELoss() 354 | symmetry_loss = mse_loss( 355 | mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:] 356 | ) 357 | # add the loss to the total loss 358 | if self.symmetry["use_mirror_loss"]: 359 | loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss 360 | else: 361 | symmetry_loss = symmetry_loss.detach() 362 | 363 | # Random Network Distillation loss 364 | if self.rnd: 365 | # predict the embedding and the target 366 | predicted_embedding = self.rnd.predictor(rnd_state_batch) 367 | target_embedding = self.rnd.target(rnd_state_batch).detach() 368 | # compute the loss as the mean squared error 369 | mseloss = torch.nn.MSELoss() 370 | rnd_loss = mseloss(predicted_embedding, target_embedding) 371 | 372 | # Compute the gradients 373 | # -- For PPO 374 | self.optimizer.zero_grad() 375 | loss.backward() 376 | # -- For RND 377 | if self.rnd: 378 | self.rnd_optimizer.zero_grad() # type: ignore 379 | rnd_loss.backward() 380 | 381 | # Collect gradients from all GPUs 382 | if self.is_multi_gpu: 383 | self.reduce_parameters() 384 | 385 | # Apply the gradients 386 | # -- For PPO 387 | nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) 388 | self.optimizer.step() 389 | # -- For RND 390 | if self.rnd_optimizer: 391 | self.rnd_optimizer.step() 392 | 393 | # Store the losses 394 | mean_value_loss += value_loss.item() 395 | mean_surrogate_loss += surrogate_loss.item() 396 | mean_entropy += entropy_batch.mean().item() 397 | # -- RND loss 398 | if mean_rnd_loss is not None: 399 | mean_rnd_loss += rnd_loss.item() 400 | # -- Symmetry loss 401 | if mean_symmetry_loss is not None: 402 | mean_symmetry_loss += symmetry_loss.item() 403 | 404 | # -- For PPO 405 | num_updates = self.num_learning_epochs * self.num_mini_batches 406 | mean_value_loss /= num_updates 407 | mean_surrogate_loss /= num_updates 408 | mean_entropy /= num_updates 409 | # -- For RND 410 | if mean_rnd_loss is not None: 411 | mean_rnd_loss /= num_updates 412 | # -- For Symmetry 413 | if mean_symmetry_loss is not None: 414 | mean_symmetry_loss /= num_updates 415 | # -- Clear the storage 416 | self.storage.clear() 417 | 418 | # construct the loss dictionary 419 | loss_dict = { 420 | "value_function": mean_value_loss, 421 | "surrogate": mean_surrogate_loss, 422 | "entropy": mean_entropy, 423 | } 424 | if self.rnd: 425 | loss_dict["rnd"] = mean_rnd_loss 426 | if self.symmetry: 427 | loss_dict["symmetry"] = mean_symmetry_loss 428 | 429 | return loss_dict 430 | 431 | """ 432 | Helper functions 433 | """ 434 | 435 | def broadcast_parameters(self): 436 | """Broadcast model parameters to all GPUs.""" 437 | # obtain the model parameters on current GPU 438 | model_params = [self.policy.state_dict()] 439 | if self.rnd: 440 | model_params.append(self.rnd.predictor.state_dict()) 441 | # broadcast the model parameters 442 | torch.distributed.broadcast_object_list(model_params, src=0) 443 | # load the model parameters on all GPUs from source GPU 444 | self.policy.load_state_dict(model_params[0]) 445 | if self.rnd: 446 | self.rnd.predictor.load_state_dict(model_params[1]) 447 | 448 | def reduce_parameters(self): 449 | """Collect gradients from all GPUs and average 
them. 450 | 451 | This function is called after the backward pass to synchronize the gradients across all GPUs. 452 | """ 453 | # Create a tensor to store the gradients 454 | grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None] 455 | if self.rnd: 456 | grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None] 457 | all_grads = torch.cat(grads) 458 | 459 | # Average the gradients across all GPUs 460 | torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM) 461 | all_grads /= self.gpu_world_size 462 | 463 | # Get all parameters 464 | all_params = self.policy.parameters() 465 | if self.rnd: 466 | all_params = chain(all_params, self.rnd.parameters()) 467 | 468 | # Update the gradients for all parameters with the reduced gradients 469 | offset = 0 470 | for param in all_params: 471 | if param.grad is not None: 472 | numel = param.numel() 473 | # copy data back from shared buffer 474 | param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data)) 475 | # update the offset for the next parameter 476 | offset += numel 477 | -------------------------------------------------------------------------------- /rsl_rl/env/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Submodule defining the environment definitions.""" 7 | 8 | from .vec_env import VecEnv 9 | 10 | __all__ = ["VecEnv"] 11 | -------------------------------------------------------------------------------- /rsl_rl/env/vec_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | from abc import ABC, abstractmethod 10 | 11 | 12 | class VecEnv(ABC): 13 | """Abstract class for vectorized environment. 14 | 15 | The vectorized environment is a collection of environments that are synchronized. This means that 16 | the same action is applied to all environments and the same observation is returned from all environments. 17 | 18 | All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the 19 | configuration, the extra observations are used for different purposes. The following keys are used by the 20 | environment: 21 | 22 | - "observations" (dict[str, dict[str, torch.Tensor]]): 23 | Additional observations that are not used by the actor networks. The keys are the names of the observations 24 | and the values are the observations themselves. The following are reserved keys for the observations: 25 | 26 | - "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces. 27 | - "rnd_state": The observation is used as input to the RND network. Useful for random network distillation. 28 | 29 | - "time_outs" (torch.Tensor): Timeouts for the environments. These correspond to terminations that happen due to time limits and 30 | not due to the environment reaching a terminal state. This is useful for environments that have a fixed 31 | episode length. 32 | 33 | - "log" (dict[str, float | torch.Tensor]): Additional information for logging and debugging purposes. 34 | The key should be a string and start with "/" for namespacing. 
The value can be a scalar or a tensor. 35 | If it is a tensor, the mean of the tensor is used for logging. 36 | 37 | .. deprecated:: 2.0.0 38 | 39 | Use "log" in the extra information dictionary instead of the "episode" key. 40 | 41 | """ 42 | 43 | num_envs: int 44 | """Number of environments.""" 45 | 46 | num_actions: int 47 | """Number of actions.""" 48 | 49 | max_episode_length: int | torch.Tensor 50 | """Maximum episode length. 51 | 52 | The maximum episode length can be a scalar or a tensor. If it is a scalar, it is the same for all environments. 53 | If it is a tensor, it is the maximum episode length for each environment. This is useful for dynamic episode 54 | lengths. 55 | """ 56 | 57 | episode_length_buf: torch.Tensor 58 | """Buffer for current episode lengths.""" 59 | 60 | device: torch.device 61 | """Device to use.""" 62 | 63 | cfg: dict | object 64 | """Configuration object.""" 65 | 66 | """ 67 | Operations. 68 | """ 69 | 70 | @abstractmethod 71 | def get_observations(self) -> tuple[torch.Tensor, dict]: 72 | """Return the current observations. 73 | 74 | Returns: 75 | Tuple containing the observations and extras. 76 | """ 77 | raise NotImplementedError 78 | 79 | @abstractmethod 80 | def reset(self) -> tuple[torch.Tensor, dict]: 81 | """Reset all environment instances. 82 | 83 | Returns: 84 | Tuple containing the observations and extras. 85 | """ 86 | raise NotImplementedError 87 | 88 | @abstractmethod 89 | def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: 90 | """Apply input action on the environment. 91 | 92 | The extra information is a dictionary. It includes metrics such as the episode reward, episode length, 93 | etc. Additional information can be stored in the dictionary such as observations for the critic network, etc. 94 | 95 | Args: 96 | actions: Input actions to apply. Shape: (num_envs, num_actions) 97 | 98 | Returns: 99 | A tuple containing the observations, rewards, dones and extra information (metrics). 100 | """ 101 | raise NotImplementedError 102 | -------------------------------------------------------------------------------- /rsl_rl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Definitions for neural-network components for RL-agents.""" 7 | 8 | from .actor_critic import ActorCritic 9 | from .actor_critic_recurrent import ActorCriticRecurrent 10 | from .normalizer import EmpiricalNormalization 11 | from .rnd import RandomNetworkDistillation 12 | from .student_teacher import StudentTeacher 13 | from .student_teacher_recurrent import StudentTeacherRecurrent 14 | 15 | __all__ = [ 16 | "ActorCritic", 17 | "ActorCriticRecurrent", 18 | "EmpiricalNormalization", 19 | "RandomNetworkDistillation", 20 | "StudentTeacher", 21 | "StudentTeacherRecurrent", 22 | ] 23 | -------------------------------------------------------------------------------- /rsl_rl/modules/actor_critic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
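Before continuing with the actor-critic sources, a minimal sketch of an environment conforming to the VecEnv interface above; the class name, shapes, and dynamics are placeholders rather than part of the library, and every termination here is treated as a timeout purely for illustration.

import torch
from rsl_rl.env import VecEnv

class DummyVecEnv(VecEnv):
    """Toy vectorized environment with random placeholder dynamics."""

    def __init__(self, num_envs=8, num_obs=4, num_actions=2, device="cpu"):
        self.num_envs = num_envs
        self.num_actions = num_actions
        self.max_episode_length = 100
        self.device = torch.device(device)
        self.cfg = {}
        self.episode_length_buf = torch.zeros(num_envs, dtype=torch.long, device=self.device)
        self._obs = torch.zeros(num_envs, num_obs, device=self.device)

    def get_observations(self):
        # extras["observations"] may also carry "critic" or "rnd_state" entries
        return self._obs, {"observations": {}}

    def reset(self):
        self._obs.zero_()
        self.episode_length_buf.zero_()
        return self.get_observations()

    def step(self, actions):
        self.episode_length_buf += 1
        self._obs = torch.randn_like(self._obs)          # placeholder dynamics
        rewards = torch.zeros(self.num_envs, device=self.device)
        dones = self.episode_length_buf >= self.max_episode_length
        extras = {"observations": {}, "time_outs": dones.clone()}  # all dones are timeouts here
        self.episode_length_buf[dones] = 0
        return self._obs, rewards, dones, extras

env = DummyVecEnv()
obs, extras = env.reset()
obs, rewards, dones, extras = env.step(torch.zeros(env.num_envs, env.num_actions))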
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.distributions import Normal 11 | 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class ActorCritic(nn.Module): 16 | is_recurrent = False 17 | 18 | def __init__( 19 | self, 20 | num_actor_obs, 21 | num_critic_obs, 22 | num_actions, 23 | actor_hidden_dims=[256, 256, 256], 24 | critic_hidden_dims=[256, 256, 256], 25 | activation="elu", 26 | init_noise_std=1.0, 27 | noise_std_type: str = "scalar", 28 | **kwargs, 29 | ): 30 | if kwargs: 31 | print( 32 | "ActorCritic.__init__ got unexpected arguments, which will be ignored: " 33 | + str([key for key in kwargs.keys()]) 34 | ) 35 | super().__init__() 36 | activation = resolve_nn_activation(activation) 37 | 38 | mlp_input_dim_a = num_actor_obs 39 | mlp_input_dim_c = num_critic_obs 40 | # Policy 41 | actor_layers = [] 42 | actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) 43 | actor_layers.append(activation) 44 | for layer_index in range(len(actor_hidden_dims)): 45 | if layer_index == len(actor_hidden_dims) - 1: 46 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions)) 47 | else: 48 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1])) 49 | actor_layers.append(activation) 50 | self.actor = nn.Sequential(*actor_layers) 51 | 52 | # Value function 53 | critic_layers = [] 54 | critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) 55 | critic_layers.append(activation) 56 | for layer_index in range(len(critic_hidden_dims)): 57 | if layer_index == len(critic_hidden_dims) - 1: 58 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1)) 59 | else: 60 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1])) 61 | critic_layers.append(activation) 62 | self.critic = nn.Sequential(*critic_layers) 63 | 64 | print(f"Actor MLP: {self.actor}") 65 | print(f"Critic MLP: {self.critic}") 66 | 67 | # Action noise 68 | self.noise_std_type = noise_std_type 69 | if self.noise_std_type == "scalar": 70 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 71 | elif self.noise_std_type == "log": 72 | self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) 73 | else: 74 | raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") 75 | 76 | # Action distribution (populated in update_distribution) 77 | self.distribution = None 78 | # disable args validation for speedup 79 | Normal.set_default_validate_args(False) 80 | 81 | @staticmethod 82 | # not used at the moment 83 | def init_weights(sequential, scales): 84 | [ 85 | torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) 86 | for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear)) 87 | ] 88 | 89 | def reset(self, dones=None): 90 | pass 91 | 92 | def forward(self): 93 | raise NotImplementedError 94 | 95 | @property 96 | def action_mean(self): 97 | return self.distribution.mean 98 | 99 | @property 100 | def action_std(self): 101 | return self.distribution.stddev 102 | 103 | @property 104 | def entropy(self): 105 | return self.distribution.entropy().sum(dim=-1) 106 | 107 | def update_distribution(self, observations): 108 | # compute mean 109 | mean = self.actor(observations) 110 | # compute standard deviation 111 | if self.noise_std_type == "scalar": 112 | std = self.std.expand_as(mean) 113 | elif self.noise_std_type == "log": 114 | std = torch.exp(self.log_std).expand_as(mean) 115 | else: 116 | raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") 117 | # create distribution 118 | self.distribution = Normal(mean, std) 119 | 120 | def act(self, observations, **kwargs): 121 | self.update_distribution(observations) 122 | return self.distribution.sample() 123 | 124 | def get_actions_log_prob(self, actions): 125 | return self.distribution.log_prob(actions).sum(dim=-1) 126 | 127 | def act_inference(self, observations): 128 | actions_mean = self.actor(observations) 129 | return actions_mean 130 | 131 | def evaluate(self, critic_observations, **kwargs): 132 | value = self.critic(critic_observations) 133 | return value 134 | 135 | def load_state_dict(self, state_dict, strict=True): 136 | """Load the parameters of the actor-critic model. 137 | 138 | Args: 139 | state_dict (dict): State dictionary of the model. 140 | strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this 141 | module's state_dict() function. 142 | 143 | Returns: 144 | bool: Whether this training resumes a previous training. This flag is used by the `load()` function of 145 | `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation). 146 | """ 147 | 148 | super().load_state_dict(state_dict, strict=strict) 149 | return True 150 | -------------------------------------------------------------------------------- /rsl_rl/modules/actor_critic_recurrent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
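Referring back to the ActorCritic module just shown, a short usage sketch; all dimensions are arbitrary and the package is assumed to be installed as rsl-rl-lib.

import torch
from rsl_rl.modules import ActorCritic

policy = ActorCritic(
    num_actor_obs=48, num_critic_obs=60, num_actions=12,
    actor_hidden_dims=[128, 128], critic_hidden_dims=[128, 128],
    activation="elu", init_noise_std=1.0, noise_std_type="log",
)

obs = torch.randn(16, 48)            # batch of actor observations
critic_obs = torch.randn(16, 60)     # asymmetric critic observations

actions = policy.act(obs)                          # sample from Normal(mean, std)
log_prob = policy.get_actions_log_prob(actions)    # log prob summed over action dims, shape (16,)
values = policy.evaluate(critic_obs)               # value estimates, shape (16, 1)
mean_actions = policy.act_inference(obs)           # deterministic mean for deployment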
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import warnings 9 | 10 | from rsl_rl.modules import ActorCritic 11 | from rsl_rl.networks import Memory 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class ActorCriticRecurrent(ActorCritic): 16 | is_recurrent = True 17 | 18 | def __init__( 19 | self, 20 | num_actor_obs, 21 | num_critic_obs, 22 | num_actions, 23 | actor_hidden_dims=[256, 256, 256], 24 | critic_hidden_dims=[256, 256, 256], 25 | activation="elu", 26 | rnn_type="lstm", 27 | rnn_hidden_dim=256, 28 | rnn_num_layers=1, 29 | init_noise_std=1.0, 30 | **kwargs, 31 | ): 32 | if "rnn_hidden_size" in kwargs: 33 | warnings.warn( 34 | "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. " 35 | "Please use `rnn_hidden_dim` instead.", 36 | DeprecationWarning, 37 | ) 38 | if rnn_hidden_dim == 256: # Only override if the new argument is at its default 39 | rnn_hidden_dim = kwargs.pop("rnn_hidden_size") 40 | if kwargs: 41 | print( 42 | "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()), 43 | ) 44 | 45 | super().__init__( 46 | num_actor_obs=rnn_hidden_dim, 47 | num_critic_obs=rnn_hidden_dim, 48 | num_actions=num_actions, 49 | actor_hidden_dims=actor_hidden_dims, 50 | critic_hidden_dims=critic_hidden_dims, 51 | activation=activation, 52 | init_noise_std=init_noise_std, 53 | ) 54 | 55 | activation = resolve_nn_activation(activation) 56 | 57 | self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) 58 | self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) 59 | 60 | print(f"Actor RNN: {self.memory_a}") 61 | print(f"Critic RNN: {self.memory_c}") 62 | 63 | def reset(self, dones=None): 64 | self.memory_a.reset(dones) 65 | self.memory_c.reset(dones) 66 | 67 | def act(self, observations, masks=None, hidden_states=None): 68 | input_a = self.memory_a(observations, masks, hidden_states) 69 | return super().act(input_a.squeeze(0)) 70 | 71 | def act_inference(self, observations): 72 | input_a = self.memory_a(observations) 73 | return super().act_inference(input_a.squeeze(0)) 74 | 75 | def evaluate(self, critic_observations, masks=None, hidden_states=None): 76 | input_c = self.memory_c(critic_observations, masks, hidden_states) 77 | return super().evaluate(input_c.squeeze(0)) 78 | 79 | def get_hidden_states(self): 80 | return self.memory_a.hidden_states, self.memory_c.hidden_states 81 | -------------------------------------------------------------------------------- /rsl_rl/modules/normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | # Copyright (c) 2020 Preferred Networks, Inc. 7 | 8 | from __future__ import annotations 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class EmpiricalNormalization(nn.Module): 15 | """Normalize mean and variance of values based on empirical values.""" 16 | 17 | def __init__(self, shape, eps=1e-2, until=None): 18 | """Initialize EmpiricalNormalization module. 19 | 20 | Args: 21 | shape (int or tuple of int): Shape of input values except batch axis. 22 | eps (float): Small value for stability. 
23 | until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes 24 | exceeds it. 25 | """ 26 | super().__init__() 27 | self.eps = eps 28 | self.until = until 29 | self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0)) 30 | self.register_buffer("_var", torch.ones(shape).unsqueeze(0)) 31 | self.register_buffer("_std", torch.ones(shape).unsqueeze(0)) 32 | self.register_buffer("count", torch.tensor(0, dtype=torch.long)) 33 | 34 | @property 35 | def mean(self): 36 | return self._mean.squeeze(0).clone() 37 | 38 | @property 39 | def std(self): 40 | return self._std.squeeze(0).clone() 41 | 42 | def forward(self, x): 43 | """Normalize mean and variance of values based on empirical values. 44 | 45 | Args: 46 | x (ndarray or Variable): Input values 47 | 48 | Returns: 49 | ndarray or Variable: Normalized output values 50 | """ 51 | 52 | if self.training: 53 | self.update(x) 54 | return (x - self._mean) / (self._std + self.eps) 55 | 56 | @torch.jit.unused 57 | def update(self, x): 58 | """Learn input values without computing the output values of them""" 59 | 60 | if self.until is not None and self.count >= self.until: 61 | return 62 | 63 | count_x = x.shape[0] 64 | self.count += count_x 65 | rate = count_x / self.count 66 | 67 | var_x = torch.var(x, dim=0, unbiased=False, keepdim=True) 68 | mean_x = torch.mean(x, dim=0, keepdim=True) 69 | delta_mean = mean_x - self._mean 70 | self._mean += rate * delta_mean 71 | self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean)) 72 | self._std = torch.sqrt(self._var) 73 | 74 | @torch.jit.unused 75 | def inverse(self, y): 76 | return y * (self._std + self.eps) + self._mean 77 | 78 | 79 | class EmpiricalDiscountedVariationNormalization(nn.Module): 80 | """Reward normalization from Pathak's large scale study on PPO. 81 | 82 | Reward normalization. Since the reward function is non-stationary, it is useful to normalize 83 | the scale of the rewards so that the value function can learn quickly. We did this by dividing 84 | the rewards by a running estimate of the standard deviation of the sum of discounted rewards. 85 | """ 86 | 87 | def __init__(self, shape, eps=1e-2, gamma=0.99, until=None): 88 | super().__init__() 89 | 90 | self.emp_norm = EmpiricalNormalization(shape, eps, until) 91 | self.disc_avg = DiscountedAverage(gamma) 92 | 93 | def forward(self, rew): 94 | if self.training: 95 | # update discounected rewards 96 | avg = self.disc_avg.update(rew) 97 | 98 | # update moments from discounted rewards 99 | self.emp_norm.update(avg) 100 | 101 | if self.emp_norm._std > 0: 102 | return rew / self.emp_norm._std 103 | else: 104 | return rew 105 | 106 | 107 | class DiscountedAverage: 108 | r"""Discounted average of rewards. 109 | 110 | The discounted average is defined as: 111 | 112 | .. math:: 113 | 114 | \bar{R}_t = \gamma \bar{R}_{t-1} + r_t 115 | 116 | Args: 117 | gamma (float): Discount factor. 118 | """ 119 | 120 | def __init__(self, gamma): 121 | self.avg = None 122 | self.gamma = gamma 123 | 124 | def update(self, rew: torch.Tensor) -> torch.Tensor: 125 | if self.avg is None: 126 | self.avg = rew 127 | else: 128 | self.avg = self.avg * self.gamma + rew 129 | return self.avg 130 | -------------------------------------------------------------------------------- /rsl_rl/modules/rnd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
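A usage sketch for the EmpiricalNormalization module above, on toy data whose statistics the running estimates should converge to:

import torch
from rsl_rl.modules import EmpiricalNormalization

norm = EmpiricalNormalization(shape=4, until=10_000)  # stop updating after 10k samples

for _ in range(100):
    batch = 5.0 + 2.0 * torch.randn(64, 4)  # stream of observation batches
    normalized = norm(batch)                # in training mode: update stats, then normalize

norm.eval()                                 # freeze the statistics for evaluation
print(norm.mean, norm.std)                  # approximately (5, 5, 5, 5) and (2, 2, 2, 2)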
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from rsl_rl.modules.normalizer import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class RandomNetworkDistillation(nn.Module): 16 | """Implementation of Random Network Distillation (RND) [1] 17 | 18 | References: 19 | .. [1] Burda, Yuri, et al. "Exploration by random network distillation." arXiv preprint arXiv:1810.12894 (2018). 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_states: int, 25 | num_outputs: int, 26 | predictor_hidden_dims: list[int], 27 | target_hidden_dims: list[int], 28 | activation: str = "elu", 29 | weight: float = 0.0, 30 | state_normalization: bool = False, 31 | reward_normalization: bool = False, 32 | device: str = "cpu", 33 | weight_schedule: dict | None = None, 34 | ): 35 | """Initialize the RND module. 36 | 37 | - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization layer. 38 | - If :attr:`reward_normalization` is True, then the intrinsic reward is normalized using an Empirical Discounted 39 | Variation Normalization layer. 40 | 41 | .. note:: 42 | If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states 43 | is used as the hidden dimension. 44 | 45 | Args: 46 | num_states: Number of states/inputs to the predictor and target networks. 47 | num_outputs: Number of outputs (embedding size) of the predictor and target networks. 48 | predictor_hidden_dims: List of hidden dimensions of the predictor network. 49 | target_hidden_dims: List of hidden dimensions of the target network. 50 | activation: Activation function. Defaults to "elu". 51 | weight: Scaling factor of the intrinsic reward. Defaults to 0.0. 52 | state_normalization: Whether to normalize the input state. Defaults to False. 53 | reward_normalization: Whether to normalize the intrinsic reward. Defaults to False. 54 | device: Device to use. Defaults to "cpu". 55 | weight_schedule: The type of schedule to use for the RND weight parameter. 56 | Defaults to None, in which case the weight parameter is constant. 57 | It is a dictionary with the following keys: 58 | 59 | - "mode": The type of schedule to use for the RND weight parameter. 60 | - "constant": Constant weight schedule. 61 | - "step": Step weight schedule. 62 | - "linear": Linear weight schedule. 63 | 64 | For the "step" weight schedule, the following parameters are required: 65 | 66 | - "final_step": The step at which the weight parameter is set to the final value. 67 | - "final_value": The final value of the weight parameter. 68 | 69 | For the "linear" weight schedule, the following parameters are required: 70 | - "initial_step": The step at which the weight parameter is set to the initial value. 71 | - "final_step": The step at which the weight parameter is set to the final value. 72 | - "final_value": The final value of the weight parameter. 
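For concreteness, the "step" and "linear" schedules described above would be passed as dictionaries of roughly this shape; the step counts and values are arbitrary examples.

# Step schedule: the weight jumps to final_value once final_step environment steps are reached.
step_schedule = {"mode": "step", "final_step": 5_000, "final_value": 0.5}

# Linear schedule: the weight is interpolated between initial_step and final_step.
linear_schedule = {"mode": "linear", "initial_step": 1_000, "final_step": 10_000, "final_value": 0.0}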
73 | """ 74 | # initialize parent class 75 | super().__init__() 76 | 77 | # Store parameters 78 | self.num_states = num_states 79 | self.num_outputs = num_outputs 80 | self.initial_weight = weight 81 | self.device = device 82 | self.state_normalization = state_normalization 83 | self.reward_normalization = reward_normalization 84 | 85 | # Normalization of input gates 86 | if state_normalization: 87 | self.state_normalizer = EmpiricalNormalization(shape=[self.num_states], until=1.0e8).to(self.device) 88 | else: 89 | self.state_normalizer = torch.nn.Identity() 90 | # Normalization of intrinsic reward 91 | if reward_normalization: 92 | self.reward_normalizer = EmpiricalDiscountedVariationNormalization(shape=[], until=1.0e8).to(self.device) 93 | else: 94 | self.reward_normalizer = torch.nn.Identity() 95 | 96 | # counter for the number of updates 97 | self.update_counter = 0 98 | 99 | # resolve weight schedule 100 | if weight_schedule is not None: 101 | self.weight_scheduler_params = weight_schedule 102 | self.weight_scheduler = getattr(self, f"_{weight_schedule['mode']}_weight_schedule") 103 | else: 104 | self.weight_scheduler = None 105 | # Create network architecture 106 | self.predictor = self._build_mlp(num_states, predictor_hidden_dims, num_outputs, activation).to(self.device) 107 | self.target = self._build_mlp(num_states, target_hidden_dims, num_outputs, activation).to(self.device) 108 | 109 | # make target network not trainable 110 | self.target.eval() 111 | 112 | def get_intrinsic_reward(self, rnd_state) -> tuple[torch.Tensor, torch.Tensor]: 113 | # note: the counter is updated number of env steps per learning iteration 114 | self.update_counter += 1 115 | # Normalize rnd state 116 | rnd_state = self.state_normalizer(rnd_state) 117 | # Obtain the embedding of the rnd state from the target and predictor networks 118 | target_embedding = self.target(rnd_state).detach() 119 | predictor_embedding = self.predictor(rnd_state).detach() 120 | # Compute the intrinsic reward as the distance between the embeddings 121 | intrinsic_reward = torch.linalg.norm(target_embedding - predictor_embedding, dim=1) 122 | # Normalize intrinsic reward 123 | intrinsic_reward = self.reward_normalizer(intrinsic_reward) 124 | 125 | # Check the weight schedule 126 | if self.weight_scheduler is not None: 127 | self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params) 128 | else: 129 | self.weight = self.initial_weight 130 | # Scale intrinsic reward 131 | intrinsic_reward *= self.weight 132 | 133 | return intrinsic_reward, rnd_state 134 | 135 | def forward(self, *args, **kwargs): 136 | raise RuntimeError("Forward method is not implemented. 
Use get_intrinsic_reward instead.") 137 | 138 | def train(self, mode: bool = True): 139 | # sets module into training mode 140 | self.predictor.train(mode) 141 | if self.state_normalization: 142 | self.state_normalizer.train(mode) 143 | if self.reward_normalization: 144 | self.reward_normalizer.train(mode) 145 | return self 146 | 147 | def eval(self): 148 | return self.train(False) 149 | 150 | """ 151 | Private Methods 152 | """ 153 | 154 | @staticmethod 155 | def _build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"): 156 | """Builds target and predictor networks""" 157 | 158 | network_layers = [] 159 | # resolve hidden dimensions 160 | # if dims is -1 then we use the number of observations 161 | hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims] 162 | # resolve activation function 163 | activation = resolve_nn_activation(activation_name) 164 | # first layer 165 | network_layers.append(nn.Linear(input_dims, hidden_dims[0])) 166 | network_layers.append(activation) 167 | # subsequent layers 168 | for layer_index in range(len(hidden_dims)): 169 | if layer_index == len(hidden_dims) - 1: 170 | # last layer 171 | network_layers.append(nn.Linear(hidden_dims[layer_index], output_dims)) 172 | else: 173 | # hidden layers 174 | network_layers.append(nn.Linear(hidden_dims[layer_index], hidden_dims[layer_index + 1])) 175 | network_layers.append(activation) 176 | return nn.Sequential(*network_layers) 177 | 178 | """ 179 | Different weight schedules. 180 | """ 181 | 182 | def _constant_weight_schedule(self, step: int, **kwargs): 183 | return self.initial_weight 184 | 185 | def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs): 186 | return self.initial_weight if step < final_step else final_value 187 | 188 | def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs): 189 | if step < initial_step: 190 | return self.initial_weight 191 | elif step > final_step: 192 | return final_value 193 | else: 194 | return self.initial_weight + (final_value - self.initial_weight) * (step - initial_step) / ( 195 | final_step - initial_step 196 | ) 197 | -------------------------------------------------------------------------------- /rsl_rl/modules/student_teacher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
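The `weight_schedule` dictionary documented in `RandomNetworkDistillation.__init__` above maps directly onto the private `_constant_`, `_step_`, and `_linear_weight_schedule` methods. A minimal usage sketch follows; the state dimension, embedding size, batch size, and schedule values are illustrative assumptions, not values shipped with the library.

```python
# Illustrative only: sizes and schedule values below are assumptions.
import torch
from rsl_rl.modules.rnd import RandomNetworkDistillation

rnd = RandomNetworkDistillation(
    num_states=45,                    # size of the "rnd_state" observation group (assumed)
    num_outputs=16,                   # embedding size of the predictor/target networks
    predictor_hidden_dims=[-1, 128],  # -1 expands to num_states (see the note in __init__)
    target_hidden_dims=[-1, 128],
    weight=0.5,                       # initial intrinsic-reward scale
    state_normalization=True,
    weight_schedule={"mode": "linear", "initial_step": 0, "final_step": 1000, "final_value": 0.0},
)

rnd_state = torch.randn(8, 45)        # one step for 8 environments (assumed)
intrinsic_reward, normalized_state = rnd.get_intrinsic_reward(rnd_state)
print(intrinsic_reward.shape)         # torch.Size([8]); already scaled by the scheduled weight
```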
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.distributions import Normal 11 | 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class StudentTeacher(nn.Module): 16 | is_recurrent = False 17 | 18 | def __init__( 19 | self, 20 | num_student_obs, 21 | num_teacher_obs, 22 | num_actions, 23 | student_hidden_dims=[256, 256, 256], 24 | teacher_hidden_dims=[256, 256, 256], 25 | activation="elu", 26 | init_noise_std=0.1, 27 | **kwargs, 28 | ): 29 | if kwargs: 30 | print( 31 | "StudentTeacher.__init__ got unexpected arguments, which will be ignored: " 32 | + str([key for key in kwargs.keys()]) 33 | ) 34 | super().__init__() 35 | activation = resolve_nn_activation(activation) 36 | self.loaded_teacher = False # indicates if teacher has been loaded 37 | 38 | mlp_input_dim_s = num_student_obs 39 | mlp_input_dim_t = num_teacher_obs 40 | 41 | # student 42 | student_layers = [] 43 | student_layers.append(nn.Linear(mlp_input_dim_s, student_hidden_dims[0])) 44 | student_layers.append(activation) 45 | for layer_index in range(len(student_hidden_dims)): 46 | if layer_index == len(student_hidden_dims) - 1: 47 | student_layers.append(nn.Linear(student_hidden_dims[layer_index], num_actions)) 48 | else: 49 | student_layers.append(nn.Linear(student_hidden_dims[layer_index], student_hidden_dims[layer_index + 1])) 50 | student_layers.append(activation) 51 | self.student = nn.Sequential(*student_layers) 52 | 53 | # teacher 54 | teacher_layers = [] 55 | teacher_layers.append(nn.Linear(mlp_input_dim_t, teacher_hidden_dims[0])) 56 | teacher_layers.append(activation) 57 | for layer_index in range(len(teacher_hidden_dims)): 58 | if layer_index == len(teacher_hidden_dims) - 1: 59 | teacher_layers.append(nn.Linear(teacher_hidden_dims[layer_index], num_actions)) 60 | else: 61 | teacher_layers.append(nn.Linear(teacher_hidden_dims[layer_index], teacher_hidden_dims[layer_index + 1])) 62 | teacher_layers.append(activation) 63 | self.teacher = nn.Sequential(*teacher_layers) 64 | self.teacher.eval() 65 | 66 | print(f"Student MLP: {self.student}") 67 | print(f"Teacher MLP: {self.teacher}") 68 | 69 | # action noise 70 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 71 | self.distribution = None 72 | # disable args validation for speedup 73 | Normal.set_default_validate_args = False 74 | 75 | def reset(self, dones=None, hidden_states=None): 76 | pass 77 | 78 | def forward(self): 79 | raise NotImplementedError 80 | 81 | @property 82 | def action_mean(self): 83 | return self.distribution.mean 84 | 85 | @property 86 | def action_std(self): 87 | return self.distribution.stddev 88 | 89 | @property 90 | def entropy(self): 91 | return self.distribution.entropy().sum(dim=-1) 92 | 93 | def update_distribution(self, observations): 94 | mean = self.student(observations) 95 | std = self.std.expand_as(mean) 96 | self.distribution = Normal(mean, std) 97 | 98 | def act(self, observations): 99 | self.update_distribution(observations) 100 | return self.distribution.sample() 101 | 102 | def act_inference(self, observations): 103 | actions_mean = self.student(observations) 104 | return actions_mean 105 | 106 | def evaluate(self, teacher_observations): 107 | with torch.no_grad(): 108 | actions = self.teacher(teacher_observations) 109 | return actions 110 | 111 | def load_state_dict(self, state_dict, strict=True): 112 | """Load the parameters of the student and teacher networks. 
113 | 114 | Args: 115 | state_dict (dict): State dictionary of the model. 116 | strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this 117 | module's state_dict() function. 118 | 119 | Returns: 120 | bool: Whether this training resumes a previous training. This flag is used by the `load()` function of 121 | `OnPolicyRunner` to determine how to load further parameters. 122 | """ 123 | 124 | # check if state_dict contains teacher and student or just teacher parameters 125 | if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training 126 | # rename keys to match teacher and remove critic parameters 127 | teacher_state_dict = {} 128 | for key, value in state_dict.items(): 129 | if "actor." in key: 130 | teacher_state_dict[key.replace("actor.", "")] = value 131 | self.teacher.load_state_dict(teacher_state_dict, strict=strict) 132 | # also load recurrent memory if teacher is recurrent 133 | if self.is_recurrent and self.teacher_recurrent: 134 | raise NotImplementedError("Loading recurrent memory for the teacher is not implemented yet") # TODO 135 | # set flag for successfully loading the parameters 136 | self.loaded_teacher = True 137 | self.teacher.eval() 138 | return False 139 | elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training 140 | super().load_state_dict(state_dict, strict=strict) 141 | # set flag for successfully loading the parameters 142 | self.loaded_teacher = True 143 | self.teacher.eval() 144 | return True 145 | else: 146 | raise ValueError("state_dict does not contain student or teacher parameters") 147 | 148 | def get_hidden_states(self): 149 | return None 150 | 151 | def detach_hidden_states(self, dones=None): 152 | pass 153 | -------------------------------------------------------------------------------- /rsl_rl/modules/student_teacher_recurrent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import warnings 9 | 10 | from rsl_rl.modules import StudentTeacher 11 | from rsl_rl.networks import Memory 12 | from rsl_rl.utils import resolve_nn_activation 13 | 14 | 15 | class StudentTeacherRecurrent(StudentTeacher): 16 | is_recurrent = True 17 | 18 | def __init__( 19 | self, 20 | num_student_obs, 21 | num_teacher_obs, 22 | num_actions, 23 | student_hidden_dims=[256, 256, 256], 24 | teacher_hidden_dims=[256, 256, 256], 25 | activation="elu", 26 | rnn_type="lstm", 27 | rnn_hidden_dim=256, 28 | rnn_num_layers=1, 29 | init_noise_std=0.1, 30 | teacher_recurrent=False, 31 | **kwargs, 32 | ): 33 | if "rnn_hidden_size" in kwargs: 34 | warnings.warn( 35 | "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. 
" 36 | "Please use `rnn_hidden_dim` instead.", 37 | DeprecationWarning, 38 | ) 39 | if rnn_hidden_dim == 256: # Only override if the new argument is at its default 40 | rnn_hidden_dim = kwargs.pop("rnn_hidden_size") 41 | if kwargs: 42 | print( 43 | "StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: " 44 | + str(kwargs.keys()), 45 | ) 46 | 47 | self.teacher_recurrent = teacher_recurrent 48 | 49 | super().__init__( 50 | num_student_obs=rnn_hidden_dim, 51 | num_teacher_obs=rnn_hidden_dim if teacher_recurrent else num_teacher_obs, 52 | num_actions=num_actions, 53 | student_hidden_dims=student_hidden_dims, 54 | teacher_hidden_dims=teacher_hidden_dims, 55 | activation=activation, 56 | init_noise_std=init_noise_std, 57 | ) 58 | 59 | activation = resolve_nn_activation(activation) 60 | 61 | self.memory_s = Memory(num_student_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) 62 | if self.teacher_recurrent: 63 | self.memory_t = Memory( 64 | num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim 65 | ) 66 | 67 | print(f"Student RNN: {self.memory_s}") 68 | if self.teacher_recurrent: 69 | print(f"Teacher RNN: {self.memory_t}") 70 | 71 | def reset(self, dones=None, hidden_states=None): 72 | if hidden_states is None: 73 | hidden_states = (None, None) 74 | self.memory_s.reset(dones, hidden_states[0]) 75 | if self.teacher_recurrent: 76 | self.memory_t.reset(dones, hidden_states[1]) 77 | 78 | def act(self, observations): 79 | input_s = self.memory_s(observations) 80 | return super().act(input_s.squeeze(0)) 81 | 82 | def act_inference(self, observations): 83 | input_s = self.memory_s(observations) 84 | return super().act_inference(input_s.squeeze(0)) 85 | 86 | def evaluate(self, teacher_observations): 87 | if self.teacher_recurrent: 88 | teacher_observations = self.memory_t(teacher_observations) 89 | return super().evaluate(teacher_observations.squeeze(0)) 90 | 91 | def get_hidden_states(self): 92 | if self.teacher_recurrent: 93 | return self.memory_s.hidden_states, self.memory_t.hidden_states 94 | else: 95 | return self.memory_s.hidden_states, None 96 | 97 | def detach_hidden_states(self, dones=None): 98 | self.memory_s.detach_hidden_states(dones) 99 | if self.teacher_recurrent: 100 | self.memory_t.detach_hidden_states(dones) 101 | -------------------------------------------------------------------------------- /rsl_rl/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Definitions for neural networks.""" 7 | 8 | from .memory import Memory 9 | 10 | __all__ = ["Memory"] 11 | -------------------------------------------------------------------------------- /rsl_rl/networks/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from rsl_rl.utils import unpad_trajectories 12 | 13 | 14 | class Memory(torch.nn.Module): 15 | def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256): 16 | super().__init__() 17 | # RNN 18 | rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM 19 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) 20 | self.hidden_states = None 21 | 22 | def forward(self, input, masks=None, hidden_states=None): 23 | batch_mode = masks is not None 24 | if batch_mode: 25 | # batch mode: needs saved hidden states 26 | if hidden_states is None: 27 | raise ValueError("Hidden states not passed to memory module during policy update") 28 | out, _ = self.rnn(input, hidden_states) 29 | out = unpad_trajectories(out, masks) 30 | else: 31 | # inference/distillation mode: uses hidden states of last step 32 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) 33 | return out 34 | 35 | def reset(self, dones=None, hidden_states=None): 36 | if dones is None: # reset all hidden states 37 | if hidden_states is None: 38 | self.hidden_states = None 39 | else: 40 | self.hidden_states = hidden_states 41 | elif self.hidden_states is not None: # reset hidden states of done environments 42 | if hidden_states is None: 43 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM 44 | for hidden_state in self.hidden_states: 45 | hidden_state[..., dones == 1, :] = 0.0 46 | else: 47 | self.hidden_states[..., dones == 1, :] = 0.0 48 | else: 49 | NotImplementedError( 50 | "Resetting hidden states of done environments with custom hidden states is not implemented" 51 | ) 52 | 53 | def detach_hidden_states(self, dones=None): 54 | if self.hidden_states is not None: 55 | if dones is None: # detach all hidden states 56 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM 57 | self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states) 58 | else: 59 | self.hidden_states = self.hidden_states.detach() 60 | else: # detach hidden states of done environments 61 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM 62 | for hidden_state in self.hidden_states: 63 | hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach() 64 | else: 65 | self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach() 66 | -------------------------------------------------------------------------------- /rsl_rl/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Implementation of runners for environment-agent interaction.""" 7 | 8 | from .on_policy_runner import OnPolicyRunner 9 | 10 | __all__ = ["OnPolicyRunner"] 11 | -------------------------------------------------------------------------------- /rsl_rl/runners/on_policy_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
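The `Memory` module above has two calling modes: a rollout/inference mode that keeps the hidden state internally, and a batch mode used during the policy update that expects padded trajectories, masks, and explicitly passed hidden states. A shape-only sketch (all sizes assumed):

```python
# Sketch of the two calling modes of Memory; shapes are illustrative assumptions.
import torch
from rsl_rl.networks import Memory

memory = Memory(input_size=48, type="lstm", num_layers=1, hidden_size=256)

# Inference/rollout mode: one time step per call, hidden state kept internally.
step_obs = torch.randn(16, 48)                 # (num_envs, obs_dim)
out = memory(step_obs)                         # (1, num_envs, hidden_size)

# Batch mode during the policy update: padded trajectories plus masks and the
# hidden states saved at the start of each trajectory.
padded = torch.randn(24, 10, 48)               # (max_traj_len, num_trajectories, obs_dim)
masks = torch.ones(24, 10, dtype=torch.bool)   # all-valid masks for this toy example
h0 = (torch.zeros(1, 10, 256), torch.zeros(1, 10, 256))  # (h, c) for an LSTM
out_batch = memory(padded, masks=masks, hidden_states=h0)
```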
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import os 9 | import statistics 10 | import time 11 | import torch 12 | from collections import deque 13 | 14 | import rsl_rl 15 | from rsl_rl.algorithms import PPO, Distillation 16 | from rsl_rl.env import VecEnv 17 | from rsl_rl.modules import ( 18 | ActorCritic, 19 | ActorCriticRecurrent, 20 | EmpiricalNormalization, 21 | StudentTeacher, 22 | StudentTeacherRecurrent, 23 | ) 24 | from rsl_rl.utils import store_code_state 25 | 26 | 27 | class OnPolicyRunner: 28 | """On-policy runner for training and evaluation.""" 29 | 30 | def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"): 31 | self.cfg = train_cfg 32 | self.alg_cfg = train_cfg["algorithm"] 33 | self.policy_cfg = train_cfg["policy"] 34 | self.device = device 35 | self.env = env 36 | 37 | # check if multi-gpu is enabled 38 | self._configure_multi_gpu() 39 | 40 | # resolve training type depending on the algorithm 41 | if self.alg_cfg["class_name"] == "PPO": 42 | self.training_type = "rl" 43 | elif self.alg_cfg["class_name"] == "Distillation": 44 | self.training_type = "distillation" 45 | else: 46 | raise ValueError(f"Training type not found for algorithm {self.alg_cfg['class_name']}.") 47 | 48 | # resolve dimensions of observations 49 | obs, extras = self.env.get_observations() 50 | num_obs = obs.shape[1] 51 | 52 | # resolve type of privileged observations 53 | if self.training_type == "rl": 54 | if "critic" in extras["observations"]: 55 | self.privileged_obs_type = "critic" # actor-critic reinforcement learnig, e.g., PPO 56 | else: 57 | self.privileged_obs_type = None 58 | if self.training_type == "distillation": 59 | if "teacher" in extras["observations"]: 60 | self.privileged_obs_type = "teacher" # policy distillation 61 | else: 62 | self.privileged_obs_type = None 63 | 64 | # resolve dimensions of privileged observations 65 | if self.privileged_obs_type is not None: 66 | num_privileged_obs = extras["observations"][self.privileged_obs_type].shape[1] 67 | else: 68 | num_privileged_obs = num_obs 69 | 70 | # evaluate the policy class 71 | policy_class = eval(self.policy_cfg.pop("class_name")) 72 | policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class( 73 | num_obs, num_privileged_obs, self.env.num_actions, **self.policy_cfg 74 | ).to(self.device) 75 | 76 | # resolve dimension of rnd gated state 77 | if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None: 78 | # check if rnd gated state is present 79 | rnd_state = extras["observations"].get("rnd_state") 80 | if rnd_state is None: 81 | raise ValueError("Observations for the key 'rnd_state' not found in infos['observations'].") 82 | # get dimension of rnd gated state 83 | num_rnd_state = rnd_state.shape[1] 84 | # add rnd gated state to config 85 | self.alg_cfg["rnd_cfg"]["num_states"] = num_rnd_state 86 | # scale down the rnd weight with timestep (similar to how rewards are scaled down in legged_gym envs) 87 | self.alg_cfg["rnd_cfg"]["weight"] *= env.unwrapped.step_dt 88 | 89 | # if using symmetry then pass the environment config object 90 | if "symmetry_cfg" in self.alg_cfg and self.alg_cfg["symmetry_cfg"] is not None: 91 | # this is used by the symmetry function for handling different observation terms 92 | self.alg_cfg["symmetry_cfg"]["_env"] = env 93 | 94 | # initialize algorithm 95 | alg_class = eval(self.alg_cfg.pop("class_name")) 96 | self.alg: PPO | Distillation = alg_class( 97 | 
policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg 98 | ) 99 | 100 | # store training configuration 101 | self.num_steps_per_env = self.cfg["num_steps_per_env"] 102 | self.save_interval = self.cfg["save_interval"] 103 | self.empirical_normalization = self.cfg["empirical_normalization"] 104 | if self.empirical_normalization: 105 | self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device) 106 | self.privileged_obs_normalizer = EmpiricalNormalization(shape=[num_privileged_obs], until=1.0e8).to( 107 | self.device 108 | ) 109 | else: 110 | self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization 111 | self.privileged_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization 112 | 113 | # init storage and model 114 | self.alg.init_storage( 115 | self.training_type, 116 | self.env.num_envs, 117 | self.num_steps_per_env, 118 | [num_obs], 119 | [num_privileged_obs], 120 | [self.env.num_actions], 121 | ) 122 | 123 | # Decide whether to disable logging 124 | # We only log from the process with rank 0 (main process) 125 | self.disable_logs = self.is_distributed and self.gpu_global_rank != 0 126 | # Logging 127 | self.log_dir = log_dir 128 | self.writer = None 129 | self.tot_timesteps = 0 130 | self.tot_time = 0 131 | self.current_learning_iteration = 0 132 | self.git_status_repos = [rsl_rl.__file__] 133 | 134 | def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901 135 | # initialize writer 136 | if self.log_dir is not None and self.writer is None and not self.disable_logs: 137 | # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard. 138 | self.logger_type = self.cfg.get("logger", "tensorboard") 139 | self.logger_type = self.logger_type.lower() 140 | 141 | if self.logger_type == "neptune": 142 | from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter 143 | 144 | self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) 145 | self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg) 146 | elif self.logger_type == "wandb": 147 | from rsl_rl.utils.wandb_utils import WandbSummaryWriter 148 | 149 | self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) 150 | self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg) 151 | elif self.logger_type == "tensorboard": 152 | from torch.utils.tensorboard import SummaryWriter 153 | 154 | self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) 155 | else: 156 | raise ValueError("Logger type not found. Please choose 'neptune', 'wandb' or 'tensorboard'.") 157 | 158 | # check if teacher is loaded 159 | if self.training_type == "distillation" and not self.alg.policy.loaded_teacher: 160 | raise ValueError("Teacher model parameters not loaded. 
Please load a teacher model to distill.") 161 | 162 | # randomize initial episode lengths (for exploration) 163 | if init_at_random_ep_len: 164 | self.env.episode_length_buf = torch.randint_like( 165 | self.env.episode_length_buf, high=int(self.env.max_episode_length) 166 | ) 167 | 168 | # start learning 169 | obs, extras = self.env.get_observations() 170 | privileged_obs = extras["observations"].get(self.privileged_obs_type, obs) 171 | obs, privileged_obs = obs.to(self.device), privileged_obs.to(self.device) 172 | self.train_mode() # switch to train mode (for dropout for example) 173 | 174 | # Book keeping 175 | ep_infos = [] 176 | rewbuffer = deque(maxlen=100) 177 | lenbuffer = deque(maxlen=100) 178 | cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 179 | cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 180 | 181 | # create buffers for logging extrinsic and intrinsic rewards 182 | if self.alg.rnd: 183 | erewbuffer = deque(maxlen=100) 184 | irewbuffer = deque(maxlen=100) 185 | cur_ereward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 186 | cur_ireward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 187 | 188 | # Ensure all parameters are in-synced 189 | if self.is_distributed: 190 | print(f"Synchronizing parameters for rank {self.gpu_global_rank}...") 191 | self.alg.broadcast_parameters() 192 | # TODO: Do we need to synchronize empirical normalizers? 193 | # Right now: No, because they all should converge to the same values "asymptotically". 194 | 195 | # Start training 196 | start_iter = self.current_learning_iteration 197 | tot_iter = start_iter + num_learning_iterations 198 | for it in range(start_iter, tot_iter): 199 | start = time.time() 200 | # Rollout 201 | with torch.inference_mode(): 202 | for _ in range(self.num_steps_per_env): 203 | # Sample actions 204 | actions = self.alg.act(obs, privileged_obs) 205 | # Step the environment 206 | obs, rewards, dones, infos = self.env.step(actions.to(self.env.device)) 207 | # Move to device 208 | obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device)) 209 | # perform normalization 210 | obs = self.obs_normalizer(obs) 211 | if self.privileged_obs_type is not None: 212 | privileged_obs = self.privileged_obs_normalizer( 213 | infos["observations"][self.privileged_obs_type].to(self.device) 214 | ) 215 | else: 216 | privileged_obs = obs 217 | 218 | # process the step 219 | self.alg.process_env_step(rewards, dones, infos) 220 | 221 | # Extract intrinsic rewards (only for logging) 222 | intrinsic_rewards = self.alg.intrinsic_rewards if self.alg.rnd else None 223 | 224 | # book keeping 225 | if self.log_dir is not None: 226 | if "episode" in infos: 227 | ep_infos.append(infos["episode"]) 228 | elif "log" in infos: 229 | ep_infos.append(infos["log"]) 230 | # Update rewards 231 | if self.alg.rnd: 232 | cur_ereward_sum += rewards 233 | cur_ireward_sum += intrinsic_rewards # type: ignore 234 | cur_reward_sum += rewards + intrinsic_rewards 235 | else: 236 | cur_reward_sum += rewards 237 | # Update episode length 238 | cur_episode_length += 1 239 | # Clear data for completed episodes 240 | # -- common 241 | new_ids = (dones > 0).nonzero(as_tuple=False) 242 | rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) 243 | lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) 244 | cur_reward_sum[new_ids] = 0 245 | cur_episode_length[new_ids] = 0 246 | # 
-- intrinsic and extrinsic rewards 247 | if self.alg.rnd: 248 | erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist()) 249 | irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist()) 250 | cur_ereward_sum[new_ids] = 0 251 | cur_ireward_sum[new_ids] = 0 252 | 253 | stop = time.time() 254 | collection_time = stop - start 255 | start = stop 256 | 257 | # compute returns 258 | if self.training_type == "rl": 259 | self.alg.compute_returns(privileged_obs) 260 | 261 | # update policy 262 | loss_dict = self.alg.update() 263 | 264 | stop = time.time() 265 | learn_time = stop - start 266 | self.current_learning_iteration = it 267 | # log info 268 | if self.log_dir is not None and not self.disable_logs: 269 | # Log information 270 | self.log(locals()) 271 | # Save model 272 | if it % self.save_interval == 0: 273 | self.save(os.path.join(self.log_dir, f"model_{it}.pt")) 274 | 275 | # Clear episode infos 276 | ep_infos.clear() 277 | # Save code state 278 | if it == start_iter and not self.disable_logs: 279 | # obtain all the diff files 280 | git_file_paths = store_code_state(self.log_dir, self.git_status_repos) 281 | # if possible store them to wandb 282 | if self.logger_type in ["wandb", "neptune"] and git_file_paths: 283 | for path in git_file_paths: 284 | self.writer.save_file(path) 285 | 286 | # Save the final model after training 287 | if self.log_dir is not None and not self.disable_logs: 288 | self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) 289 | 290 | def log(self, locs: dict, width: int = 80, pad: int = 35): 291 | # Compute the collection size 292 | collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size 293 | # Update total time-steps and time 294 | self.tot_timesteps += collection_size 295 | self.tot_time += locs["collection_time"] + locs["learn_time"] 296 | iteration_time = locs["collection_time"] + locs["learn_time"] 297 | 298 | # -- Episode info 299 | ep_string = "" 300 | if locs["ep_infos"]: 301 | for key in locs["ep_infos"][0]: 302 | infotensor = torch.tensor([], device=self.device) 303 | for ep_info in locs["ep_infos"]: 304 | # handle scalar and zero dimensional tensor infos 305 | if key not in ep_info: 306 | continue 307 | if not isinstance(ep_info[key], torch.Tensor): 308 | ep_info[key] = torch.Tensor([ep_info[key]]) 309 | if len(ep_info[key].shape) == 0: 310 | ep_info[key] = ep_info[key].unsqueeze(0) 311 | infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) 312 | value = torch.mean(infotensor) 313 | # log to logger and terminal 314 | if "/" in key: 315 | self.writer.add_scalar(key, value, locs["it"]) 316 | ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" 317 | else: 318 | self.writer.add_scalar("Episode/" + key, value, locs["it"]) 319 | ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" 320 | 321 | mean_std = self.alg.policy.action_std.mean() 322 | fps = int(collection_size / (locs["collection_time"] + locs["learn_time"])) 323 | 324 | # -- Losses 325 | for key, value in locs["loss_dict"].items(): 326 | self.writer.add_scalar(f"Loss/{key}", value, locs["it"]) 327 | self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"]) 328 | 329 | # -- Policy 330 | self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"]) 331 | 332 | # -- Performance 333 | self.writer.add_scalar("Perf/total_fps", fps, locs["it"]) 334 | self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"]) 335 | 
self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"]) 336 | 337 | # -- Training 338 | if len(locs["rewbuffer"]) > 0: 339 | # separate logging for intrinsic and extrinsic rewards 340 | if self.alg.rnd: 341 | self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"]) 342 | self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"]) 343 | self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"]) 344 | # everything else 345 | self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"]) 346 | self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"]) 347 | if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging 348 | self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time) 349 | self.writer.add_scalar( 350 | "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time 351 | ) 352 | 353 | str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " 354 | 355 | if len(locs["rewbuffer"]) > 0: 356 | log_string = ( 357 | f"""{'#' * width}\n""" 358 | f"""{str.center(width, ' ')}\n\n""" 359 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 360 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 361 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" 362 | ) 363 | # -- Losses 364 | for key, value in locs["loss_dict"].items(): 365 | log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n""" 366 | # -- Rewards 367 | if self.alg.rnd: 368 | log_string += ( 369 | f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n""" 370 | f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n""" 371 | ) 372 | log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" 373 | # -- episode info 374 | log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" 375 | else: 376 | log_string = ( 377 | f"""{'#' * width}\n""" 378 | f"""{str.center(width, ' ')}\n\n""" 379 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 380 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 381 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" 382 | ) 383 | for key, value in locs["loss_dict"].items(): 384 | log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" 385 | 386 | log_string += ep_string 387 | log_string += ( 388 | f"""{'-' * width}\n""" 389 | f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" 390 | f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" 391 | f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n""" 392 | f"""{'ETA:':>{pad}} {time.strftime( 393 | "%H:%M:%S", 394 | time.gmtime( 395 | self.tot_time / (locs['it'] - locs['start_iter'] + 1) 396 | * (locs['start_iter'] + locs['num_learning_iterations'] - locs['it']) 397 | ) 398 | )}\n""" 399 | ) 400 | print(log_string) 401 | 402 | def save(self, path: str, infos=None): 403 | # -- Save model 404 | saved_dict = { 405 | "model_state_dict": self.alg.policy.state_dict(), 406 | "optimizer_state_dict": self.alg.optimizer.state_dict(), 407 | "iter": self.current_learning_iteration, 408 | "infos": infos, 409 | } 410 | # -- Save RND model if used 411 | if self.alg.rnd: 412 | saved_dict["rnd_state_dict"] = 
self.alg.rnd.state_dict() 413 | saved_dict["rnd_optimizer_state_dict"] = self.alg.rnd_optimizer.state_dict() 414 | # -- Save observation normalizer if used 415 | if self.empirical_normalization: 416 | saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict() 417 | saved_dict["privileged_obs_norm_state_dict"] = self.privileged_obs_normalizer.state_dict() 418 | 419 | # save model 420 | torch.save(saved_dict, path) 421 | 422 | # upload model to external logging service 423 | if self.logger_type in ["neptune", "wandb"] and not self.disable_logs: 424 | self.writer.save_model(path, self.current_learning_iteration) 425 | 426 | def load(self, path: str, load_optimizer: bool = True): 427 | loaded_dict = torch.load(path, weights_only=False) 428 | # -- Load model 429 | resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"]) 430 | # -- Load RND model if used 431 | if self.alg.rnd: 432 | self.alg.rnd.load_state_dict(loaded_dict["rnd_state_dict"]) 433 | # -- Load observation normalizer if used 434 | if self.empirical_normalization: 435 | if resumed_training: 436 | # if a previous training is resumed, the actor/student normalizer is loaded for the actor/student 437 | # and the critic/teacher normalizer is loaded for the critic/teacher 438 | self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"]) 439 | self.privileged_obs_normalizer.load_state_dict(loaded_dict["privileged_obs_norm_state_dict"]) 440 | else: 441 | # if the training is not resumed but a model is loaded, this run must be distillation training following 442 | # an rl training. Thus the actor normalizer is loaded for the teacher model. The student's normalizer 443 | # is not loaded, as the observation space could differ from the previous rl training. 444 | self.privileged_obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"]) 445 | # -- load optimizer if used 446 | if load_optimizer and resumed_training: 447 | # -- algorithm optimizer 448 | self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) 449 | # -- RND optimizer if used 450 | if self.alg.rnd: 451 | self.alg.rnd_optimizer.load_state_dict(loaded_dict["rnd_optimizer_state_dict"]) 452 | # -- load current learning iteration 453 | if resumed_training: 454 | self.current_learning_iteration = loaded_dict["iter"] 455 | return loaded_dict["infos"] 456 | 457 | def get_inference_policy(self, device=None): 458 | self.eval_mode() # switch to evaluation mode (dropout for example) 459 | if device is not None: 460 | self.alg.policy.to(device) 461 | policy = self.alg.policy.act_inference 462 | if self.cfg["empirical_normalization"]: 463 | if device is not None: 464 | self.obs_normalizer.to(device) 465 | policy = lambda x: self.alg.policy.act_inference(self.obs_normalizer(x)) # noqa: E731 466 | return policy 467 | 468 | def train_mode(self): 469 | # -- PPO 470 | self.alg.policy.train() 471 | # -- RND 472 | if self.alg.rnd: 473 | self.alg.rnd.train() 474 | # -- Normalization 475 | if self.empirical_normalization: 476 | self.obs_normalizer.train() 477 | self.privileged_obs_normalizer.train() 478 | 479 | def eval_mode(self): 480 | # -- PPO 481 | self.alg.policy.eval() 482 | # -- RND 483 | if self.alg.rnd: 484 | self.alg.rnd.eval() 485 | # -- Normalization 486 | if self.empirical_normalization: 487 | self.obs_normalizer.eval() 488 | self.privileged_obs_normalizer.eval() 489 | 490 | def add_git_repo_to_log(self, repo_file_path): 491 | self.git_status_repos.append(repo_file_path) 492 | 493 | """ 494 | Helper functions. 
495 | """ 496 | 497 | def _configure_multi_gpu(self): 498 | """Configure multi-gpu training.""" 499 | # check if distributed training is enabled 500 | self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1")) 501 | self.is_distributed = self.gpu_world_size > 1 502 | 503 | # if not distributed training, set local and global rank to 0 and return 504 | if not self.is_distributed: 505 | self.gpu_local_rank = 0 506 | self.gpu_global_rank = 0 507 | self.multi_gpu_cfg = None 508 | return 509 | 510 | # get rank and world size 511 | self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0")) 512 | self.gpu_global_rank = int(os.getenv("RANK", "0")) 513 | 514 | # make a configuration dictionary 515 | self.multi_gpu_cfg = { 516 | "global_rank": self.gpu_global_rank, # rank of the main process 517 | "local_rank": self.gpu_local_rank, # rank of the current process 518 | "world_size": self.gpu_world_size, # total number of processes 519 | } 520 | 521 | # check if user has device specified for local rank 522 | if self.device != f"cuda:{self.gpu_local_rank}": 523 | raise ValueError( 524 | f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'." 525 | ) 526 | # validate multi-gpu configuration 527 | if self.gpu_local_rank >= self.gpu_world_size: 528 | raise ValueError( 529 | f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'." 530 | ) 531 | if self.gpu_global_rank >= self.gpu_world_size: 532 | raise ValueError( 533 | f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'." 534 | ) 535 | 536 | # initialize torch distributed 537 | torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size) 538 | # set device to the local rank 539 | torch.cuda.set_device(self.gpu_local_rank) 540 | -------------------------------------------------------------------------------- /rsl_rl/storage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Implementation of transitions storage for RL-agent.""" 7 | 8 | from .rollout_storage import RolloutStorage 9 | 10 | __all__ = ["RolloutStorage"] 11 | -------------------------------------------------------------------------------- /rsl_rl/storage/rollout_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import torch 9 | 10 | from rsl_rl.utils import split_and_pad_trajectories 11 | 12 | 13 | class RolloutStorage: 14 | class Transition: 15 | def __init__(self): 16 | self.observations = None 17 | self.privileged_observations = None 18 | self.actions = None 19 | self.privileged_actions = None 20 | self.rewards = None 21 | self.dones = None 22 | self.values = None 23 | self.actions_log_prob = None 24 | self.action_mean = None 25 | self.action_sigma = None 26 | self.hidden_states = None 27 | self.rnd_state = None 28 | 29 | def clear(self): 30 | self.__init__() 31 | 32 | def __init__( 33 | self, 34 | training_type, 35 | num_envs, 36 | num_transitions_per_env, 37 | obs_shape, 38 | privileged_obs_shape, 39 | actions_shape, 40 | rnd_state_shape=None, 41 | device="cpu", 42 | ): 43 | # store inputs 44 | self.training_type = training_type 45 | self.device = device 46 | self.num_transitions_per_env = num_transitions_per_env 47 | self.num_envs = num_envs 48 | self.obs_shape = obs_shape 49 | self.privileged_obs_shape = privileged_obs_shape 50 | self.rnd_state_shape = rnd_state_shape 51 | self.actions_shape = actions_shape 52 | 53 | # Core 54 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) 55 | if privileged_obs_shape is not None: 56 | self.privileged_observations = torch.zeros( 57 | num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device 58 | ) 59 | else: 60 | self.privileged_observations = None 61 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 62 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 63 | self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() 64 | 65 | # for distillation 66 | if training_type == "distillation": 67 | self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 68 | 69 | # for reinforcement learning 70 | if training_type == "rl": 71 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 72 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 73 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 74 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 75 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 76 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 77 | 78 | # For RND 79 | if rnd_state_shape is not None: 80 | self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=self.device) 81 | 82 | # For RNN networks 83 | self.saved_hidden_states_a = None 84 | self.saved_hidden_states_c = None 85 | 86 | # counter for the number of transitions stored 87 | self.step = 0 88 | 89 | def add_transitions(self, transition: Transition): 90 | # check if the transition is valid 91 | if self.step >= self.num_transitions_per_env: 92 | raise OverflowError("Rollout buffer overflow! 
You should call clear() before adding new transitions.") 93 | 94 | # Core 95 | self.observations[self.step].copy_(transition.observations) 96 | if self.privileged_observations is not None: 97 | self.privileged_observations[self.step].copy_(transition.privileged_observations) 98 | self.actions[self.step].copy_(transition.actions) 99 | self.rewards[self.step].copy_(transition.rewards.view(-1, 1)) 100 | self.dones[self.step].copy_(transition.dones.view(-1, 1)) 101 | 102 | # for distillation 103 | if self.training_type == "distillation": 104 | self.privileged_actions[self.step].copy_(transition.privileged_actions) 105 | 106 | # for reinforcement learning 107 | if self.training_type == "rl": 108 | self.values[self.step].copy_(transition.values) 109 | self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) 110 | self.mu[self.step].copy_(transition.action_mean) 111 | self.sigma[self.step].copy_(transition.action_sigma) 112 | 113 | # For RND 114 | if self.rnd_state_shape is not None: 115 | self.rnd_state[self.step].copy_(transition.rnd_state) 116 | 117 | # For RNN networks 118 | self._save_hidden_states(transition.hidden_states) 119 | 120 | # increment the counter 121 | self.step += 1 122 | 123 | def _save_hidden_states(self, hidden_states): 124 | if hidden_states is None or hidden_states == (None, None): 125 | return 126 | # make a tuple out of GRU hidden state sto match the LSTM format 127 | hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) 128 | hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) 129 | # initialize if needed 130 | if self.saved_hidden_states_a is None: 131 | self.saved_hidden_states_a = [ 132 | torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a)) 133 | ] 134 | self.saved_hidden_states_c = [ 135 | torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c)) 136 | ] 137 | # copy the states 138 | for i in range(len(hid_a)): 139 | self.saved_hidden_states_a[i][self.step].copy_(hid_a[i]) 140 | self.saved_hidden_states_c[i][self.step].copy_(hid_c[i]) 141 | 142 | def clear(self): 143 | self.step = 0 144 | 145 | def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True): 146 | advantage = 0 147 | for step in reversed(range(self.num_transitions_per_env)): 148 | # if we are at the last step, bootstrap the return value 149 | if step == self.num_transitions_per_env - 1: 150 | next_values = last_values 151 | else: 152 | next_values = self.values[step + 1] 153 | # 1 if we are not in a terminal state, 0 otherwise 154 | next_is_not_terminal = 1.0 - self.dones[step].float() 155 | # TD error: r_t + gamma * V(s_{t+1}) - V(s_t) 156 | delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] 157 | # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1}) 158 | advantage = delta + next_is_not_terminal * gamma * lam * advantage 159 | # Return: R_t = A(s_t, a_t) + V(s_t) 160 | self.returns[step] = advantage + self.values[step] 161 | 162 | # Compute the advantages 163 | self.advantages = self.returns - self.values 164 | # Normalize the advantages if flag is set 165 | # This is to prevent double normalization (i.e. 
if per minibatch normalization is used) 166 | if normalize_advantage: 167 | self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) 168 | 169 | # for distillation 170 | def generator(self): 171 | if self.training_type != "distillation": 172 | raise ValueError("This function is only available for distillation training.") 173 | 174 | for i in range(self.num_transitions_per_env): 175 | if self.privileged_observations is not None: 176 | privileged_observations = self.privileged_observations[i] 177 | else: 178 | privileged_observations = self.observations[i] 179 | yield self.observations[i], privileged_observations, self.actions[i], self.privileged_actions[ 180 | i 181 | ], self.dones[i] 182 | 183 | # for reinforcement learning with feedforward networks 184 | def mini_batch_generator(self, num_mini_batches, num_epochs=8): 185 | if self.training_type != "rl": 186 | raise ValueError("This function is only available for reinforcement learning training.") 187 | batch_size = self.num_envs * self.num_transitions_per_env 188 | mini_batch_size = batch_size // num_mini_batches 189 | indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device) 190 | 191 | # Core 192 | observations = self.observations.flatten(0, 1) 193 | if self.privileged_observations is not None: 194 | privileged_observations = self.privileged_observations.flatten(0, 1) 195 | else: 196 | privileged_observations = observations 197 | 198 | actions = self.actions.flatten(0, 1) 199 | values = self.values.flatten(0, 1) 200 | returns = self.returns.flatten(0, 1) 201 | 202 | # For PPO 203 | old_actions_log_prob = self.actions_log_prob.flatten(0, 1) 204 | advantages = self.advantages.flatten(0, 1) 205 | old_mu = self.mu.flatten(0, 1) 206 | old_sigma = self.sigma.flatten(0, 1) 207 | 208 | # For RND 209 | if self.rnd_state_shape is not None: 210 | rnd_state = self.rnd_state.flatten(0, 1) 211 | 212 | for epoch in range(num_epochs): 213 | for i in range(num_mini_batches): 214 | # Select the indices for the mini-batch 215 | start = i * mini_batch_size 216 | end = (i + 1) * mini_batch_size 217 | batch_idx = indices[start:end] 218 | 219 | # Create the mini-batch 220 | # -- Core 221 | obs_batch = observations[batch_idx] 222 | privileged_observations_batch = privileged_observations[batch_idx] 223 | actions_batch = actions[batch_idx] 224 | 225 | # -- For PPO 226 | target_values_batch = values[batch_idx] 227 | returns_batch = returns[batch_idx] 228 | old_actions_log_prob_batch = old_actions_log_prob[batch_idx] 229 | advantages_batch = advantages[batch_idx] 230 | old_mu_batch = old_mu[batch_idx] 231 | old_sigma_batch = old_sigma[batch_idx] 232 | 233 | # -- For RND 234 | if self.rnd_state_shape is not None: 235 | rnd_state_batch = rnd_state[batch_idx] 236 | else: 237 | rnd_state_batch = None 238 | 239 | # yield the mini-batch 240 | yield obs_batch, privileged_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( 241 | None, 242 | None, 243 | ), None, rnd_state_batch 244 | 245 | # for reinfrocement learning with recurrent networks 246 | def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8): 247 | if self.training_type != "rl": 248 | raise ValueError("This function is only available for reinforcement learning training.") 249 | padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) 250 | if self.privileged_observations is 
not None: 251 | padded_privileged_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) 252 | else: 253 | padded_privileged_obs_trajectories = padded_obs_trajectories 254 | 255 | if self.rnd_state_shape is not None: 256 | padded_rnd_state_trajectories, _ = split_and_pad_trajectories(self.rnd_state, self.dones) 257 | else: 258 | padded_rnd_state_trajectories = None 259 | 260 | mini_batch_size = self.num_envs // num_mini_batches 261 | for ep in range(num_epochs): 262 | first_traj = 0 263 | for i in range(num_mini_batches): 264 | start = i * mini_batch_size 265 | stop = (i + 1) * mini_batch_size 266 | 267 | dones = self.dones.squeeze(-1) 268 | last_was_done = torch.zeros_like(dones, dtype=torch.bool) 269 | last_was_done[1:] = dones[:-1] 270 | last_was_done[0] = True 271 | trajectories_batch_size = torch.sum(last_was_done[:, start:stop]) 272 | last_traj = first_traj + trajectories_batch_size 273 | 274 | masks_batch = trajectory_masks[:, first_traj:last_traj] 275 | obs_batch = padded_obs_trajectories[:, first_traj:last_traj] 276 | privileged_obs_batch = padded_privileged_obs_trajectories[:, first_traj:last_traj] 277 | 278 | if padded_rnd_state_trajectories is not None: 279 | rnd_state_batch = padded_rnd_state_trajectories[:, first_traj:last_traj] 280 | else: 281 | rnd_state_batch = None 282 | 283 | actions_batch = self.actions[:, start:stop] 284 | old_mu_batch = self.mu[:, start:stop] 285 | old_sigma_batch = self.sigma[:, start:stop] 286 | returns_batch = self.returns[:, start:stop] 287 | advantages_batch = self.advantages[:, start:stop] 288 | values_batch = self.values[:, start:stop] 289 | old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] 290 | 291 | # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim]) 292 | # then take only time steps after dones (flattens num envs and time dimensions), 293 | # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim] 294 | last_was_done = last_was_done.permute(1, 0) 295 | hid_a_batch = [ 296 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] 297 | .transpose(1, 0) 298 | .contiguous() 299 | for saved_hidden_states in self.saved_hidden_states_a 300 | ] 301 | hid_c_batch = [ 302 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] 303 | .transpose(1, 0) 304 | .contiguous() 305 | for saved_hidden_states in self.saved_hidden_states_c 306 | ] 307 | # remove the tuple for GRU 308 | hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch 309 | hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch 310 | 311 | yield obs_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( 312 | hid_a_batch, 313 | hid_c_batch, 314 | ), masks_batch, rnd_state_batch 315 | 316 | first_traj = last_traj 317 | -------------------------------------------------------------------------------- /rsl_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 
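`RolloutStorage` above is filled with one `Transition` per environment step; `compute_returns()` then applies the GAE recursion and `mini_batch_generator()` yields shuffled mini-batches for PPO. A self-contained sketch with assumed sizes:

```python
# Usage sketch for the feedforward RL path; all sizes are illustrative assumptions.
import torch
from rsl_rl.storage import RolloutStorage

storage = RolloutStorage(
    training_type="rl",
    num_envs=16,
    num_transitions_per_env=24,
    obs_shape=[48],
    privileged_obs_shape=[48],
    actions_shape=[12],
)

transition = RolloutStorage.Transition()
for _ in range(24):
    transition.observations = torch.randn(16, 48)
    transition.privileged_observations = torch.randn(16, 48)
    transition.actions = torch.randn(16, 12)
    transition.rewards = torch.randn(16)
    transition.dones = torch.zeros(16)
    transition.values = torch.randn(16, 1)
    transition.actions_log_prob = torch.randn(16)
    transition.action_mean = torch.randn(16, 12)
    transition.action_sigma = torch.rand(16, 12)
    storage.add_transitions(transition)
    transition.clear()

# GAE as implemented in compute_returns():
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
#   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
storage.compute_returns(last_values=torch.randn(16, 1), gamma=0.99, lam=0.95)

for batch in storage.mini_batch_generator(num_mini_batches=4, num_epochs=1):
    obs_b, priv_obs_b, actions_b, values_b, advantages_b, returns_b, *_ = batch
storage.clear()
```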
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | """Helper functions.""" 7 | 8 | from .utils import ( 9 | resolve_nn_activation, 10 | split_and_pad_trajectories, 11 | store_code_state, 12 | string_to_callable, 13 | unpad_trajectories, 14 | ) 15 | -------------------------------------------------------------------------------- /rsl_rl/utils/neptune_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import os 9 | from dataclasses import asdict 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | try: 13 | import neptune 14 | except ModuleNotFoundError: 15 | raise ModuleNotFoundError("neptune-client is required to log to Neptune.") 16 | 17 | 18 | class NeptuneLogger: 19 | def __init__(self, project, token): 20 | self.run = neptune.init_run(project=project, api_token=token) 21 | 22 | def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): 23 | self.run["runner_cfg"] = runner_cfg 24 | self.run["policy_cfg"] = policy_cfg 25 | self.run["alg_cfg"] = alg_cfg 26 | self.run["env_cfg"] = asdict(env_cfg) 27 | 28 | 29 | class NeptuneSummaryWriter(SummaryWriter): 30 | """Summary writer for Neptune.""" 31 | 32 | def __init__(self, log_dir: str, flush_secs: int, cfg): 33 | super().__init__(log_dir, flush_secs) 34 | 35 | try: 36 | project = cfg["neptune_project"] 37 | except KeyError: 38 | raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.") 39 | 40 | try: 41 | token = os.environ["NEPTUNE_API_TOKEN"] 42 | except KeyError: 43 | raise KeyError( 44 | "Neptune api token not found. Please run or add to ~/.bashrc: export NEPTUNE_API_TOKEN=YOUR_API_TOKEN" 45 | ) 46 | 47 | try: 48 | entity = os.environ["NEPTUNE_USERNAME"] 49 | except KeyError: 50 | raise KeyError( 51 | "Neptune username not found. 
Please run or add to ~/.bashrc: export NEPTUNE_USERNAME=YOUR_USERNAME" 52 | ) 53 | 54 | neptune_project = entity + "/" + project 55 | 56 | self.neptune_logger = NeptuneLogger(neptune_project, token) 57 | 58 | self.name_map = { 59 | "Train/mean_reward/time": "Train/mean_reward_time", 60 | "Train/mean_episode_length/time": "Train/mean_episode_length_time", 61 | } 62 | 63 | run_name = os.path.split(log_dir)[-1] 64 | 65 | self.neptune_logger.run["log_dir"].log(run_name) 66 | 67 | def _map_path(self, path): 68 | if path in self.name_map: 69 | return self.name_map[path] 70 | else: 71 | return path 72 | 73 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False): 74 | super().add_scalar( 75 | tag, 76 | scalar_value, 77 | global_step=global_step, 78 | walltime=walltime, 79 | new_style=new_style, 80 | ) 81 | self.neptune_logger.run[self._map_path(tag)].log(scalar_value, step=global_step) 82 | 83 | def stop(self): 84 | self.neptune_logger.run.stop() 85 | 86 | def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): 87 | self.neptune_logger.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) 88 | 89 | def save_model(self, model_path, iter): 90 | self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path) 91 | 92 | def save_file(self, path, iter=None): 93 | name = path.rsplit("/", 1)[-1].split(".")[0] 94 | self.neptune_logger.run["git_diff/" + name].upload(path) 95 | -------------------------------------------------------------------------------- /rsl_rl/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import git 9 | import importlib 10 | import os 11 | import pathlib 12 | import torch 13 | from typing import Callable 14 | 15 | 16 | def resolve_nn_activation(act_name: str) -> torch.nn.Module: 17 | if act_name == "elu": 18 | return torch.nn.ELU() 19 | elif act_name == "selu": 20 | return torch.nn.SELU() 21 | elif act_name == "relu": 22 | return torch.nn.ReLU() 23 | elif act_name == "crelu": 24 | return torch.nn.CELU() 25 | elif act_name == "lrelu": 26 | return torch.nn.LeakyReLU() 27 | elif act_name == "tanh": 28 | return torch.nn.Tanh() 29 | elif act_name == "sigmoid": 30 | return torch.nn.Sigmoid() 31 | elif act_name == "identity": 32 | return torch.nn.Identity() 33 | else: 34 | raise ValueError(f"Invalid activation function '{act_name}'.") 35 | 36 | 37 | def split_and_pad_trajectories(tensor, dones): 38 | """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length og the longest trajectory. 
39 | Returns masks corresponding to valid parts of the trajectories 40 | Example: 41 | Input: [ [a1, a2, a3, a4 | a5, a6], 42 | [b1, b2 | b3, b4, b5 | b6] 43 | ] 44 | 45 | Output:[ [a1, a2, a3, a4], | [ [True, True, True, True], 46 | [a5, a6, 0, 0], | [True, True, False, False], 47 | [b1, b2, 0, 0], | [True, True, False, False], 48 | [b3, b4, b5, 0], | [True, True, True, False], 49 | [b6, 0, 0, 0] | [True, False, False, False], 50 | ] | ] 51 | 52 | Assumes that the inputy has the following dimension order: [time, number of envs, additional dimensions] 53 | """ 54 | dones = dones.clone() 55 | dones[-1] = 1 56 | # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping 57 | flat_dones = dones.transpose(1, 0).reshape(-1, 1) 58 | 59 | # Get length of trajectory by counting the number of successive not done elements 60 | done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])) 61 | trajectory_lengths = done_indices[1:] - done_indices[:-1] 62 | trajectory_lengths_list = trajectory_lengths.tolist() 63 | # Extract the individual trajectories 64 | trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list) 65 | # add at least one full length trajectory 66 | trajectories = trajectories + (torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device),) 67 | # pad the trajectories to the length of the longest trajectory 68 | padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories) 69 | # remove the added tensor 70 | padded_trajectories = padded_trajectories[:, :-1] 71 | 72 | trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1) 73 | return padded_trajectories, trajectory_masks 74 | 75 | 76 | def unpad_trajectories(trajectories, masks): 77 | """Does the inverse operation of split_and_pad_trajectories()""" 78 | # Need to transpose before and after the masking to have proper reshaping 79 | return ( 80 | trajectories.transpose(1, 0)[masks.transpose(1, 0)] 81 | .view(-1, trajectories.shape[0], trajectories.shape[-1]) 82 | .transpose(1, 0) 83 | ) 84 | 85 | 86 | def store_code_state(logdir, repositories) -> list: 87 | git_log_dir = os.path.join(logdir, "git") 88 | os.makedirs(git_log_dir, exist_ok=True) 89 | file_paths = [] 90 | for repository_file_path in repositories: 91 | try: 92 | repo = git.Repo(repository_file_path, search_parent_directories=True) 93 | t = repo.head.commit.tree 94 | except Exception: 95 | print(f"Could not find git repository in {repository_file_path}. Skipping.") 96 | # skip if not a git repository 97 | continue 98 | # get the name of the repository 99 | repo_name = pathlib.Path(repo.working_dir).name 100 | diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff") 101 | # check if the diff file already exists 102 | if os.path.isfile(diff_file_name): 103 | continue 104 | # write the diff file 105 | print(f"Storing git diff for '{repo_name}' in: {diff_file_name}") 106 | with open(diff_file_name, "x", encoding="utf-8") as f: 107 | content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" 108 | f.write(content) 109 | # add the file path to the list of files to be uploaded 110 | file_paths.append(diff_file_name) 111 | return file_paths 112 | 113 | 114 | def string_to_callable(name: str) -> Callable: 115 | """Resolves the module and function names to return the function. 116 | 117 | Args: 118 | name (str): The function name. 
The format should be 'module:attribute_name'. 119 | 120 | Raises: 121 | ValueError: When the resolved attribute is not a function. 122 | ValueError: When unable to resolve the attribute. 123 | 124 | Returns: 125 | Callable: The function loaded from the module. 126 | """ 127 | try: 128 | mod_name, attr_name = name.split(":") 129 | mod = importlib.import_module(mod_name) 130 | callable_object = getattr(mod, attr_name) 131 | # check if attribute is callable 132 | if callable(callable_object): 133 | return callable_object 134 | else: 135 | raise ValueError(f"The imported object is not callable: '{name}'") 136 | except AttributeError as e: 137 | msg = ( 138 | "We could not interpret the entry as a callable object. The format of input should be" 139 | f" 'module:attribute_name'\nWhile processing input '{name}', received the error:\n {e}." 140 | ) 141 | raise ValueError(msg) 142 | -------------------------------------------------------------------------------- /rsl_rl/utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import os 9 | from dataclasses import asdict 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | try: 13 | import wandb 14 | except ModuleNotFoundError: 15 | raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.") 16 | 17 | 18 | class WandbSummaryWriter(SummaryWriter): 19 | """Summary writer for Weights and Biases.""" 20 | 21 | def __init__(self, log_dir: str, flush_secs: int, cfg): 22 | super().__init__(log_dir, flush_secs) 23 | 24 | # Get the run name 25 | run_name = os.path.split(log_dir)[-1] 26 | 27 | try: 28 | project = cfg["wandb_project"] 29 | except KeyError: 30 | raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") 31 | 32 | try: 33 | entity = os.environ["WANDB_USERNAME"] 34 | except KeyError: 35 | entity = None 36 | 37 | # Initialize wandb 38 | wandb.init(project=project, entity=entity, name=run_name) 39 | 40 | # Add log directory to wandb 41 | wandb.config.update({"log_dir": log_dir}) 42 | 43 | self.name_map = { 44 | "Train/mean_reward/time": "Train/mean_reward_time", 45 | "Train/mean_episode_length/time": "Train/mean_episode_length_time", 46 | } 47 | 48 | def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): 49 | wandb.config.update({"runner_cfg": runner_cfg}) 50 | wandb.config.update({"policy_cfg": policy_cfg}) 51 | wandb.config.update({"alg_cfg": alg_cfg}) 52 | try: 53 | wandb.config.update({"env_cfg": env_cfg.to_dict()}) 54 | except Exception: 55 | wandb.config.update({"env_cfg": asdict(env_cfg)}) 56 | 57 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False): 58 | super().add_scalar( 59 | tag, 60 | scalar_value, 61 | global_step=global_step, 62 | walltime=walltime, 63 | new_style=new_style, 64 | ) 65 | wandb.log({self._map_path(tag): scalar_value}, step=global_step) 66 | 67 | def stop(self): 68 | wandb.finish() 69 | 70 | def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): 71 | self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) 72 | 73 | def save_model(self, model_path, iter): 74 | wandb.save(model_path, base_path=os.path.dirname(model_path)) 75 | 76 | def save_file(self, path, iter=None): 77 | wandb.save(path, base_path=os.path.dirname(path)) 78 | 79 | """ 80 | Private methods.
81 | """ 82 | 83 | def _map_path(self, path): 84 | if path in self.name_map: 85 | return self.name_map[path] 86 | else: 87 | return path 88 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from setuptools import setup 7 | 8 | setup() 9 | --------------------------------------------------------------------------------
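
Usage sketch for split_and_pad_trajectories and unpad_trajectories from rsl_rl/utils/utils.py above. It mirrors the two-environment example in the docstring and checks that unpadding restores the original [time, num_envs, obs_dim] tensor. This is a minimal sketch, assuming the package is installed (e.g. via the setup.py above) so that rsl_rl is importable; the tensor values and done flags are illustrative only.

    import torch

    from rsl_rl.utils.utils import split_and_pad_trajectories, unpad_trajectories

    # 6 time steps, 2 environments, 3 observation dimensions (illustrative values).
    obs = torch.arange(6 * 2 * 3, dtype=torch.float32).reshape(6, 2, 3)
    dones = torch.zeros(6, 2)
    dones[3, 0] = 1  # env 0 finishes an episode at step index 3 -> trajectories of length 4 and 2
    dones[1, 1] = 1  # env 1 finishes at step index 1 ...
    dones[4, 1] = 1  # ... and again at step index 4 -> trajectories of length 2, 3 and 1

    # padded has shape [6, 5, 3]: 5 trajectories, zero-padded along time.
    # masks has shape [6, 5] and is True only on the valid (non-padded) steps.
    padded, masks = split_and_pad_trajectories(obs, dones)

    # unpad_trajectories is the inverse operation and restores the original layout.
    restored = unpad_trajectories(padded, masks)
    assert torch.equal(restored, obs)

Note that the padded time dimension equals the full rollout length, because a dummy full-length trajectory is appended internally before padding; the masks, not the padded shape, encode the individual trajectory lengths.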
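string_to_callable from the same file resolves a 'module:attribute_name' string via importlib and returns the attribute if it is callable. A small sketch, using torch (already a dependency of utils.py) as the target module; any other importable module works the same way:

    from rsl_rl.utils.utils import string_to_callable

    # Resolves the module part with importlib and fetches the attribute.
    relu_cls = string_to_callable("torch.nn:ReLU")
    activation = relu_cls()

    # A resolvable but non-callable attribute raises ValueError instead,
    # e.g. string_to_callable("torch:pi").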
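store_code_state writes one <repo_name>.diff file (git status plus git diff against the HEAD tree) into <logdir>/git for every entry in repositories and returns the created file paths, which the Neptune and wandb writers above can upload via save_file. A hedged sketch: log_dir is a hypothetical directory, and passing rsl_rl.__file__ only produces a diff when the package comes from a git checkout; otherwise the function prints a warning and skips it.

    import rsl_rl
    from rsl_rl.utils.utils import store_code_state

    log_dir = "logs/example_run"  # hypothetical log directory

    # Creates <log_dir>/git/ and stores "git status" + "git diff" per repository,
    # skipping paths that are not inside a git repository or that already have a diff file.
    diff_files = store_code_state(log_dir, [rsl_rl.__file__])
    print(diff_files)  # e.g. ["logs/example_run/git/rsl_rl.diff"] for a source checkout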
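Finally, a construction sketch for WandbSummaryWriter, whose public interface (add_scalar, log_config, save_model, save_file, stop) matches the Neptune writer above. It assumes wandb is installed and authenticated; the project name and log directory below are placeholders, not values from the repository.

    from rsl_rl.utils.wandb_utils import WandbSummaryWriter

    # "wandb_project" is the only key the constructor reads from the runner config.
    writer = WandbSummaryWriter(log_dir="logs/example_run", flush_secs=10, cfg={"wandb_project": "my_project"})

    # Logged to TensorBoard under the original tag and to wandb under the remapped
    # name "Train/mean_reward_time" (see name_map above).
    writer.add_scalar("Train/mean_reward/time", 1.0, global_step=0)
    writer.stop()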