├── .flake8
├── .github
│   └── LICENSE_HEADER.txt
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTORS.md
├── LICENSE
├── README.md
├── config
│   └── dummy_config.yaml
├── licenses
│   └── dependencies
│       ├── black-license.txt
│       ├── codespell-license.txt
│       ├── flake8-license.txt
│       ├── isort-license.txt
│       ├── numpy_license.txt
│       ├── onnx-license.txt
│       ├── pre-commit-hooks-license.txt
│       ├── pre-commit-license.txt
│       ├── pyright-license.txt
│       ├── pyupgrade-license.txt
│       └── torch_license.txt
├── pyproject.toml
├── rsl_rl
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── distillation.py
│   │   └── ppo.py
│   ├── env
│   │   ├── __init__.py
│   │   └── vec_env.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── actor_critic.py
│   │   ├── actor_critic_recurrent.py
│   │   ├── normalizer.py
│   │   ├── rnd.py
│   │   ├── student_teacher.py
│   │   └── student_teacher_recurrent.py
│   ├── networks
│   │   ├── __init__.py
│   │   └── memory.py
│   ├── runners
│   │   ├── __init__.py
│   │   └── on_policy_runner.py
│   ├── storage
│   │   ├── __init__.py
│   │   └── rollout_storage.py
│   └── utils
│       ├── __init__.py
│       ├── neptune_utils.py
│       ├── utils.py
│       └── wandb_utils.py
└── setup.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | show-source=True
3 | statistics=True
4 | per-file-ignores=*/__init__.py:F401
5 | # E402: Module level import not at top of file
6 | # E501: Line too long
7 | # W503: Line break before binary operator
8 | # E203: Whitespace before ':' -> conflicts with black
9 | # D401: First line should be in imperative mood
10 | # R504: Unnecessary variable assignment before return statement.
11 | # R505: Unnecessary elif after return statement
12 | # SIM102: Use a single if-statement instead of nested if-statements
13 | # SIM117: Merge with statements for context managers that have same scope.
14 | ignore=E402,E501,W503,E203,D401,R504,R505,SIM102,SIM117
15 | max-line-length = 120
16 | max-complexity = 18
17 | exclude=_*,.vscode,.git,docs/**
18 | # docstrings
19 | docstring-convention=google
20 | # annotations
21 | suppress-none-returning=True
22 | allow-star-arg-any=True
23 |
--------------------------------------------------------------------------------
/.github/LICENSE_HEADER.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | All rights reserved.
3 |
4 | SPDX-License-Identifier: BSD-3-Clause
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDEs
2 | .idea
3 |
4 | # builds
5 | *.egg-info
6 | build/*
7 | dist/*
8 |
9 | # cache
10 | __pycache__
11 | .pytest_cache
12 |
13 | # vs code
14 | .vscode
15 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/python/black
3 | rev: 23.10.1
4 | hooks:
5 | - id: black
6 | args: ["--line-length", "120", "--preview"]
7 | - repo: https://github.com/pycqa/flake8
8 | rev: 6.1.0
9 | hooks:
10 | - id: flake8
11 | additional_dependencies: [flake8-simplify, flake8-return]
12 | - repo: https://github.com/pre-commit/pre-commit-hooks
13 | rev: v4.5.0
14 | hooks:
15 | - id: trailing-whitespace
16 | - id: check-symlinks
17 | - id: destroyed-symlinks
18 | - id: check-yaml
19 | - id: check-merge-conflict
20 | - id: check-case-conflict
21 | - id: check-executables-have-shebangs
22 | - id: check-toml
23 | - id: end-of-file-fixer
24 | - id: check-shebang-scripts-are-executable
25 | - id: detect-private-key
26 | - id: debug-statements
27 | - repo: https://github.com/pycqa/isort
28 | rev: 5.12.0
29 | hooks:
30 | - id: isort
31 | name: isort (python)
32 | args: ["--profile", "black", "--filter-files"]
33 | - repo: https://github.com/asottile/pyupgrade
34 | rev: v3.15.0
35 | hooks:
36 | - id: pyupgrade
37 | args: ["--py37-plus"]
38 | - repo: https://github.com/codespell-project/codespell
39 | rev: v2.2.6
40 | hooks:
41 | - id: codespell
42 | additional_dependencies:
43 | - tomli
44 | - repo: https://github.com/Lucas-C/pre-commit-hooks
45 | rev: v1.5.1
46 | hooks:
47 | - id: insert-license
48 | files: \.py$
49 | args:
50 | # - --remove-header # Remove existing license headers. Useful when updating license.
51 | - --license-filepath
52 | - .github/LICENSE_HEADER.txt
53 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # RSL-RL Maintainers and Contributors
2 |
3 | This is the official list of developers and contributors.
4 |
5 | For the full list of contributors, please see the revision history in source control.
6 |
7 | Names should be added to this file as individual names or organization names.
8 |
9 | Email addresses are tracked elsewhere to avoid spam.
10 |
11 | Please keep the lists sorted alphabetically.
12 |
13 | ## Maintainers
14 |
15 | * Robotic Systems Lab, ETH Zurich
16 | * NVIDIA Corporation
17 |
18 | ---
19 |
20 | * Mayank Mittal
21 | * Clemens Schwarke
22 |
23 | ## Authors
24 |
25 | * David Hoeller
26 | * Nikita Rudin
27 |
28 | ## Contributors
29 |
30 | * Bikram Pandit
31 | * Eric Vollenweider
32 | * Fabian Jenelten
33 | * Lorenzo Terenzi
34 | * Marko Bjelonic
35 | * Matthijs van der Boon
36 | * Özhan Özen
37 | * Pascal Roth
38 | * Zhang Chong
39 | * Ziqi Fan
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025, ETH Zurich
2 | Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice,
9 | this list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its contributors
16 | may be used to endorse or promote products derived from this software without
17 | specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | See licenses/dependencies for license information of dependencies of this package.
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RSL RL
2 |
3 | A fast and simple implementation of RL algorithms, designed to run fully on GPU.
4 | This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac Gym.
5 |
6 | Environment repositories using the framework:
7 |
8 | * **`Isaac Lab`** (built on top of NVIDIA Isaac Sim): https://github.com/isaac-sim/IsaacLab
9 | * **`Legged-Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
10 |
11 | The main branch supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include:
12 |
13 | * [Random Network Distillation (RND)](https://proceedings.mlr.press/v229/schwarke23a.html) - Encourages exploration by adding
14 |   a curiosity-driven intrinsic reward.
15 | * [Symmetry-based Augmentation](https://arxiv.org/abs/2403.04359) - Makes the learned behaviors more symmetrical.
16 |
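Both features are configured through the `algorithm` section of the training configuration. A rough sketch of the relevant entries (the keys mirror those in [dummy_config.yaml](config/dummy_config.yaml); the values and the symmetry-function path are illustrative placeholders, not recommended settings):

```yaml
algorithm:
  # -- Random Network Distillation
  rnd_cfg:
    weight: 0.1                  # initial weight of the intrinsic RND reward
    state_normalization: true    # normalize RND state observations
  # -- Symmetry Augmentation
  symmetry_cfg:
    use_data_augmentation: true  # add mirrored trajectories to each batch
    use_mirror_loss: false       # alternatively, add a symmetry loss term
    data_augmentation_func: "my_envs.symmetry:get_symmetric_states"  # hypothetical "module:function" path
```
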
17 | We welcome contributions from the community. Please check our contribution guidelines for more
18 | information.
19 |
20 | **Maintainers**: Mayank Mittal and Clemens Schwarke
21 | **Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
22 | **Contact**: cschwarke@ethz.ch
23 |
24 | > **Note:** The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more). However, it is currently not actively maintained.
25 |
26 |
27 | ## Setup
28 |
29 | The package can be installed via PyPI with:
30 |
31 | ```bash
32 | pip install rsl-rl-lib
33 | ```
34 |
35 | or by cloning this repository and installing it with:
36 |
37 | ```bash
38 | git clone https://github.com/leggedrobotics/rsl_rl
39 | cd rsl_rl
40 | pip install -e .
41 | ```
42 |
43 | The package supports the following logging frameworks, which can be selected through the `logger` parameter:
44 |
45 | * Tensorboard: https://www.tensorflow.org/tensorboard/
46 | * Weights & Biases: https://wandb.ai/site
47 | * Neptune: https://docs.neptune.ai/
48 |
49 | For a demo configuration of PPO, please check the [dummy_config.yaml](config/dummy_config.yaml) file.
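
For example, to log to Weights & Biases instead of Tensorboard, only the logger-related entries of the runner configuration need to change (the keys below mirror those in [dummy_config.yaml](config/dummy_config.yaml)):

```yaml
runner:
  logger: wandb              # one of: tensorboard, neptune, wandb
  wandb_project: legged_gym  # target Weights & Biases project
```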
50 |
51 |
52 | ## Contribution Guidelines
53 |
54 | For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. Please make sure that your code is well-documented and follows the guidelines.
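
As a point of reference, here is a minimal, illustrative docstring in the expected Google style (not taken from the codebase):

```python
from __future__ import annotations


def discounted_return(rewards: list[float], gamma: float = 0.99) -> float:
    """Compute the discounted return of a reward sequence.

    Args:
        rewards: Rewards collected at each step, ordered in time.
        gamma: Discount factor in [0, 1].

    Returns:
        The discounted sum of rewards.
    """
    total = 0.0
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total
```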
55 |
56 | We use the following tools for maintaining code quality:
57 |
58 | - [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
59 | - [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
60 | - [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
61 |
62 | Please check [here](https://pre-commit.com/#install) for instructions on setting these up. To run the checks over the entire repository, execute the following commands in the terminal:
63 |
64 | ```bash
65 | # for installation (only once)
66 | pre-commit install
67 | # for running
68 | pre-commit run --all-files
69 | ```
70 |
71 | ## Citing
72 |
73 | **We are working on a white paper for this library.** Until then, please cite the following work
74 | if you use this library for your research:
75 |
76 | ```text
77 | @InProceedings{rudin2022learning,
78 | title = {Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning},
79 | author = {Rudin, Nikita and Hoeller, David and Reist, Philipp and Hutter, Marco},
80 | booktitle = {Proceedings of the 5th Conference on Robot Learning},
81 | pages = {91--100},
82 | year = {2022},
83 | volume = {164},
84 | series = {Proceedings of Machine Learning Research},
85 | publisher = {PMLR},
86 | url = {https://proceedings.mlr.press/v164/rudin22a.html},
87 | }
88 | ```
89 |
90 | If you use the library with curiosity-driven exploration (random network distillation), please cite:
91 |
92 | ```text
93 | @InProceedings{schwarke2023curiosity,
94 | title = {Curiosity-Driven Learning of Joint Locomotion and Manipulation Tasks},
95 | author = {Schwarke, Clemens and Klemm, Victor and Boon, Matthijs van der and Bjelonic, Marko and Hutter, Marco},
96 | booktitle = {Proceedings of The 7th Conference on Robot Learning},
97 | pages = {2594--2610},
98 | year = {2023},
99 | volume = {229},
100 | series = {Proceedings of Machine Learning Research},
101 | publisher = {PMLR},
102 | url = {https://proceedings.mlr.press/v229/schwarke23a.html},
103 | }
104 | ```
105 |
106 | If you use the library with symmetry augmentation, please cite:
107 |
108 | ```text
109 | @InProceedings{mittal2024symmetry,
110 |   author = {Mittal, Mayank and Rudin, Nikita and Klemm, Victor and Allshire, Arthur and Hutter, Marco},
111 |   booktitle = {2024 IEEE International Conference on Robotics and Automation (ICRA)},
112 |   title = {Symmetry Considerations for Learning Task Symmetric Robot Policies},
113 |   year = {2024},
114 |   pages = {7433--7439},
115 |   doi = {10.1109/ICRA57147.2024.10611493}
116 | }
117 | ```
118 |
--------------------------------------------------------------------------------
/config/dummy_config.yaml:
--------------------------------------------------------------------------------
1 | algorithm:
2 | class_name: PPO
3 | # training parameters
4 | # -- advantage normalization
5 | normalize_advantage_per_mini_batch: false
6 | # -- value function
7 | value_loss_coef: 1.0
8 | clip_param: 0.2
9 | use_clipped_value_loss: true
10 | # -- surrogate loss
11 | desired_kl: 0.01
12 | entropy_coef: 0.01
13 | gamma: 0.99
14 | lam: 0.95
15 | max_grad_norm: 1.0
16 | # -- training
17 | learning_rate: 0.001
18 | num_learning_epochs: 5
19 | num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches
20 | schedule: adaptive # adaptive, fixed
21 |
22 | # -- Random Network Distillation
23 | rnd_cfg:
24 | weight: 0.0 # initial weight of the RND reward
25 |
26 | # note: This is a dictionary with a required key called "mode".
27 | # Please check the RND module for more information.
28 | weight_schedule: null
29 |
30 | reward_normalization: false # whether to normalize RND reward
31 | state_normalization: true # whether to normalize RND state observations
32 |
33 | # -- Learning parameters
34 | learning_rate: 0.001 # learning rate for RND
35 |
36 | # -- Network parameters
37 | # note: if -1, then the network will use dimensions of the observation
38 | num_outputs: 1 # number of outputs of RND network
39 | predictor_hidden_dims: [-1] # hidden dimensions of predictor network
40 | target_hidden_dims: [-1] # hidden dimensions of target network
41 |
42 | # -- Symmetry Augmentation
43 | symmetry_cfg:
44 | use_data_augmentation: true # this adds symmetric trajectories to the batch
45 | use_mirror_loss: false # this adds symmetry loss term to the loss function
46 |
47 | # string containing the module and function name to import.
48 | # Example: "legged_gym.envs.locomotion.anymal_c.symmetry:get_symmetric_states"
49 | #
50 | # .. code-block:: python
51 | #
52 | # @torch.no_grad()
53 | # def get_symmetric_states(
54 | # obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy"
55 | # ) -> Tuple[torch.Tensor, torch.Tensor]:
56 | #
57 | data_augmentation_func: null
58 |
59 | # coefficient for symmetry loss term
60 | # if 0, then no symmetry loss is used
61 | mirror_loss_coeff: 0.0
62 |
63 | policy:
64 | class_name: ActorCritic
65 | # for MLP i.e. `ActorCritic`
66 | activation: elu
67 | actor_hidden_dims: [128, 128, 128]
68 | critic_hidden_dims: [128, 128, 128]
69 | init_noise_std: 1.0
70 | noise_std_type: "scalar" # 'scalar' or 'log'
71 |
72 | # only needed for `ActorCriticRecurrent`
73 | # rnn_type: 'lstm'
74 | # rnn_hidden_dim: 512
75 | # rnn_num_layers: 1
76 |
77 | runner:
78 | num_steps_per_env: 24 # number of steps per environment per iteration
79 | max_iterations: 1500 # number of policy updates
80 | empirical_normalization: false
81 | # -- logging parameters
82 | save_interval: 50 # check for potential saves every `save_interval` iterations
83 | experiment_name: walking_experiment
84 | run_name: ""
85 | # -- logging writer
86 | logger: tensorboard # tensorboard, neptune, wandb
87 | neptune_project: legged_gym
88 | wandb_project: legged_gym
89 | # -- loading and resuming
90 | load_run: -1 # -1 means load latest run
91 | resume_path: null # updated from load_run and checkpoint
92 | checkpoint: -1 # -1 means load latest checkpoint
93 |
94 | runner_class_name: OnPolicyRunner
95 | seed: 1
96 |
--------------------------------------------------------------------------------
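
The configuration above is consumed by the training runner. Below is a minimal, illustrative sketch of how such a file might be loaded and handed to `OnPolicyRunner`; the environment class is a hypothetical placeholder, and the exact dictionary structure expected by the runner should be checked against `rsl_rl/runners/on_policy_runner.py` and `rsl_rl/env/vec_env.py`.

```python
# Illustrative sketch only; `MyVecEnv` is a hypothetical user-provided
# environment implementing the rsl_rl.env.VecEnv interface.
import yaml

from rsl_rl.runners import OnPolicyRunner

# Load the training configuration shown above.
with open("config/dummy_config.yaml") as f:
    train_cfg = yaml.safe_load(f)

env = MyVecEnv()  # hypothetical placeholder environment
runner = OnPolicyRunner(env, train_cfg, log_dir="logs/walking_experiment", device="cuda:0")
runner.learn(num_learning_iterations=train_cfg["runner"]["max_iterations"])
```
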
/licenses/dependencies/black-license.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2018 Łukasz Langa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/dependencies/codespell-license.txt:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 2, June 1991
3 |
4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6 | Everyone is permitted to copy and distribute verbatim copies
7 | of this license document, but changing it is not allowed.
8 |
9 | Preamble
10 |
11 | The licenses for most software are designed to take away your
12 | freedom to share and change it. By contrast, the GNU General Public
13 | License is intended to guarantee your freedom to share and change free
14 | software--to make sure the software is free for all its users. This
15 | General Public License applies to most of the Free Software
16 | Foundation's software and to any other program whose authors commit to
17 | using it. (Some other Free Software Foundation software is covered by
18 | the GNU Lesser General Public License instead.) You can apply it to
19 | your programs, too.
20 |
21 | When we speak of free software, we are referring to freedom, not
22 | price. Our General Public Licenses are designed to make sure that you
23 | have the freedom to distribute copies of free software (and charge for
24 | this service if you wish), that you receive source code or can get it
25 | if you want it, that you can change the software or use pieces of it
26 | in new free programs; and that you know you can do these things.
27 |
28 | To protect your rights, we need to make restrictions that forbid
29 | anyone to deny you these rights or to ask you to surrender the rights.
30 | These restrictions translate to certain responsibilities for you if you
31 | distribute copies of the software, or if you modify it.
32 |
33 | For example, if you distribute copies of such a program, whether
34 | gratis or for a fee, you must give the recipients all the rights that
35 | you have. You must make sure that they, too, receive or can get the
36 | source code. And you must show them these terms so they know their
37 | rights.
38 |
39 | We protect your rights with two steps: (1) copyright the software, and
40 | (2) offer you this license which gives you legal permission to copy,
41 | distribute and/or modify the software.
42 |
43 | Also, for each author's protection and ours, we want to make certain
44 | that everyone understands that there is no warranty for this free
45 | software. If the software is modified by someone else and passed on, we
46 | want its recipients to know that what they have is not the original, so
47 | that any problems introduced by others will not reflect on the original
48 | authors' reputations.
49 |
50 | Finally, any free program is threatened constantly by software
51 | patents. We wish to avoid the danger that redistributors of a free
52 | program will individually obtain patent licenses, in effect making the
53 | program proprietary. To prevent this, we have made it clear that any
54 | patent must be licensed for everyone's free use or not licensed at all.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | GNU GENERAL PUBLIC LICENSE
60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61 |
62 | 0. This License applies to any program or other work which contains
63 | a notice placed by the copyright holder saying it may be distributed
64 | under the terms of this General Public License. The "Program", below,
65 | refers to any such program or work, and a "work based on the Program"
66 | means either the Program or any derivative work under copyright law:
67 | that is to say, a work containing the Program or a portion of it,
68 | either verbatim or with modifications and/or translated into another
69 | language. (Hereinafter, translation is included without limitation in
70 | the term "modification".) Each licensee is addressed as "you".
71 |
72 | Activities other than copying, distribution and modification are not
73 | covered by this License; they are outside its scope. The act of
74 | running the Program is not restricted, and the output from the Program
75 | is covered only if its contents constitute a work based on the
76 | Program (independent of having been made by running the Program).
77 | Whether that is true depends on what the Program does.
78 |
79 | 1. You may copy and distribute verbatim copies of the Program's
80 | source code as you receive it, in any medium, provided that you
81 | conspicuously and appropriately publish on each copy an appropriate
82 | copyright notice and disclaimer of warranty; keep intact all the
83 | notices that refer to this License and to the absence of any warranty;
84 | and give any other recipients of the Program a copy of this License
85 | along with the Program.
86 |
87 | You may charge a fee for the physical act of transferring a copy, and
88 | you may at your option offer warranty protection in exchange for a fee.
89 |
90 | 2. You may modify your copy or copies of the Program or any portion
91 | of it, thus forming a work based on the Program, and copy and
92 | distribute such modifications or work under the terms of Section 1
93 | above, provided that you also meet all of these conditions:
94 |
95 | a) You must cause the modified files to carry prominent notices
96 | stating that you changed the files and the date of any change.
97 |
98 | b) You must cause any work that you distribute or publish, that in
99 | whole or in part contains or is derived from the Program or any
100 | part thereof, to be licensed as a whole at no charge to all third
101 | parties under the terms of this License.
102 |
103 | c) If the modified program normally reads commands interactively
104 | when run, you must cause it, when started running for such
105 | interactive use in the most ordinary way, to print or display an
106 | announcement including an appropriate copyright notice and a
107 | notice that there is no warranty (or else, saying that you provide
108 | a warranty) and that users may redistribute the program under
109 | these conditions, and telling the user how to view a copy of this
110 | License. (Exception: if the Program itself is interactive but
111 | does not normally print such an announcement, your work based on
112 | the Program is not required to print an announcement.)
113 |
114 | These requirements apply to the modified work as a whole. If
115 | identifiable sections of that work are not derived from the Program,
116 | and can be reasonably considered independent and separate works in
117 | themselves, then this License, and its terms, do not apply to those
118 | sections when you distribute them as separate works. But when you
119 | distribute the same sections as part of a whole which is a work based
120 | on the Program, the distribution of the whole must be on the terms of
121 | this License, whose permissions for other licensees extend to the
122 | entire whole, and thus to each and every part regardless of who wrote it.
123 |
124 | Thus, it is not the intent of this section to claim rights or contest
125 | your rights to work written entirely by you; rather, the intent is to
126 | exercise the right to control the distribution of derivative or
127 | collective works based on the Program.
128 |
129 | In addition, mere aggregation of another work not based on the Program
130 | with the Program (or with a work based on the Program) on a volume of
131 | a storage or distribution medium does not bring the other work under
132 | the scope of this License.
133 |
134 | 3. You may copy and distribute the Program (or a work based on it,
135 | under Section 2) in object code or executable form under the terms of
136 | Sections 1 and 2 above provided that you also do one of the following:
137 |
138 | a) Accompany it with the complete corresponding machine-readable
139 | source code, which must be distributed under the terms of Sections
140 | 1 and 2 above on a medium customarily used for software interchange; or,
141 |
142 | b) Accompany it with a written offer, valid for at least three
143 | years, to give any third party, for a charge no more than your
144 | cost of physically performing source distribution, a complete
145 | machine-readable copy of the corresponding source code, to be
146 | distributed under the terms of Sections 1 and 2 above on a medium
147 | customarily used for software interchange; or,
148 |
149 | c) Accompany it with the information you received as to the offer
150 | to distribute corresponding source code. (This alternative is
151 | allowed only for noncommercial distribution and only if you
152 | received the program in object code or executable form with such
153 | an offer, in accord with Subsection b above.)
154 |
155 | The source code for a work means the preferred form of the work for
156 | making modifications to it. For an executable work, complete source
157 | code means all the source code for all modules it contains, plus any
158 | associated interface definition files, plus the scripts used to
159 | control compilation and installation of the executable. However, as a
160 | special exception, the source code distributed need not include
161 | anything that is normally distributed (in either source or binary
162 | form) with the major components (compiler, kernel, and so on) of the
163 | operating system on which the executable runs, unless that component
164 | itself accompanies the executable.
165 |
166 | If distribution of executable or object code is made by offering
167 | access to copy from a designated place, then offering equivalent
168 | access to copy the source code from the same place counts as
169 | distribution of the source code, even though third parties are not
170 | compelled to copy the source along with the object code.
171 |
172 | 4. You may not copy, modify, sublicense, or distribute the Program
173 | except as expressly provided under this License. Any attempt
174 | otherwise to copy, modify, sublicense or distribute the Program is
175 | void, and will automatically terminate your rights under this License.
176 | However, parties who have received copies, or rights, from you under
177 | this License will not have their licenses terminated so long as such
178 | parties remain in full compliance.
179 |
180 | 5. You are not required to accept this License, since you have not
181 | signed it. However, nothing else grants you permission to modify or
182 | distribute the Program or its derivative works. These actions are
183 | prohibited by law if you do not accept this License. Therefore, by
184 | modifying or distributing the Program (or any work based on the
185 | Program), you indicate your acceptance of this License to do so, and
186 | all its terms and conditions for copying, distributing or modifying
187 | the Program or works based on it.
188 |
189 | 6. Each time you redistribute the Program (or any work based on the
190 | Program), the recipient automatically receives a license from the
191 | original licensor to copy, distribute or modify the Program subject to
192 | these terms and conditions. You may not impose any further
193 | restrictions on the recipients' exercise of the rights granted herein.
194 | You are not responsible for enforcing compliance by third parties to
195 | this License.
196 |
197 | 7. If, as a consequence of a court judgment or allegation of patent
198 | infringement or for any other reason (not limited to patent issues),
199 | conditions are imposed on you (whether by court order, agreement or
200 | otherwise) that contradict the conditions of this License, they do not
201 | excuse you from the conditions of this License. If you cannot
202 | distribute so as to satisfy simultaneously your obligations under this
203 | License and any other pertinent obligations, then as a consequence you
204 | may not distribute the Program at all. For example, if a patent
205 | license would not permit royalty-free redistribution of the Program by
206 | all those who receive copies directly or indirectly through you, then
207 | the only way you could satisfy both it and this License would be to
208 | refrain entirely from distribution of the Program.
209 |
210 | If any portion of this section is held invalid or unenforceable under
211 | any particular circumstance, the balance of the section is intended to
212 | apply and the section as a whole is intended to apply in other
213 | circumstances.
214 |
215 | It is not the purpose of this section to induce you to infringe any
216 | patents or other property right claims or to contest validity of any
217 | such claims; this section has the sole purpose of protecting the
218 | integrity of the free software distribution system, which is
219 | implemented by public license practices. Many people have made
220 | generous contributions to the wide range of software distributed
221 | through that system in reliance on consistent application of that
222 | system; it is up to the author/donor to decide if he or she is willing
223 | to distribute software through any other system and a licensee cannot
224 | impose that choice.
225 |
226 | This section is intended to make thoroughly clear what is believed to
227 | be a consequence of the rest of this License.
228 |
229 | 8. If the distribution and/or use of the Program is restricted in
230 | certain countries either by patents or by copyrighted interfaces, the
231 | original copyright holder who places the Program under this License
232 | may add an explicit geographical distribution limitation excluding
233 | those countries, so that distribution is permitted only in or among
234 | countries not thus excluded. In such case, this License incorporates
235 | the limitation as if written in the body of this License.
236 |
237 | 9. The Free Software Foundation may publish revised and/or new versions
238 | of the General Public License from time to time. Such new versions will
239 | be similar in spirit to the present version, but may differ in detail to
240 | address new problems or concerns.
241 |
242 | Each version is given a distinguishing version number. If the Program
243 | specifies a version number of this License which applies to it and "any
244 | later version", you have the option of following the terms and conditions
245 | either of that version or of any later version published by the Free
246 | Software Foundation. If the Program does not specify a version number of
247 | this License, you may choose any version ever published by the Free Software
248 | Foundation.
249 |
250 | 10. If you wish to incorporate parts of the Program into other free
251 | programs whose distribution conditions are different, write to the author
252 | to ask for permission. For software which is copyrighted by the Free
253 | Software Foundation, write to the Free Software Foundation; we sometimes
254 | make exceptions for this. Our decision will be guided by the two goals
255 | of preserving the free status of all derivatives of our free software and
256 | of promoting the sharing and reuse of software generally.
257 |
258 | NO WARRANTY
259 |
260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268 | REPAIR OR CORRECTION.
269 |
270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278 | POSSIBILITY OF SUCH DAMAGES.
279 |
280 | END OF TERMS AND CONDITIONS
281 |
282 | How to Apply These Terms to Your New Programs
283 |
284 | If you develop a new program, and you want it to be of the greatest
285 | possible use to the public, the best way to achieve this is to make it
286 | free software which everyone can redistribute and change under these terms.
287 |
288 | To do so, attach the following notices to the program. It is safest
289 | to attach them to the start of each source file to most effectively
290 | convey the exclusion of warranty; and each file should have at least
291 | the "copyright" line and a pointer to where the full notice is found.
292 |
293 | <one line to give the program's name and a brief idea of what it does.>
294 | Copyright (C) <year>  <name of author>
295 |
296 | This program is free software; you can redistribute it and/or modify
297 | it under the terms of the GNU General Public License as published by
298 | the Free Software Foundation; either version 2 of the License, or
299 | (at your option) any later version.
300 |
301 | This program is distributed in the hope that it will be useful,
302 | but WITHOUT ANY WARRANTY; without even the implied warranty of
303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
304 | GNU General Public License for more details.
305 |
306 | You should have received a copy of the GNU General Public License along
307 | with this program; if not, write to the Free Software Foundation, Inc.,
308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
309 |
310 | Also add information on how to contact you by electronic and paper mail.
311 |
312 | If the program is interactive, make it output a short notice like this
313 | when it starts in an interactive mode:
314 |
315 | Gnomovision version 69, Copyright (C) year name of author
316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317 | This is free software, and you are welcome to redistribute it
318 | under certain conditions; type `show c' for details.
319 |
320 | The hypothetical commands `show w' and `show c' should show the appropriate
321 | parts of the General Public License. Of course, the commands you use may
322 | be called something other than `show w' and `show c'; they could even be
323 | mouse-clicks or menu items--whatever suits your program.
324 |
325 | You should also get your employer (if you work as a programmer) or your
326 | school, if any, to sign a "copyright disclaimer" for the program, if
327 | necessary. Here is a sample; alter the names:
328 |
329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330 | `Gnomovision' (which makes passes at compilers) written by James Hacker.
331 |
332 | <signature of Ty Coon>, 1 April 1989
333 | Ty Coon, President of Vice
334 |
335 | This General Public License does not permit incorporating your program into
336 | proprietary programs. If your program is a subroutine library, you may
337 | consider it more useful to permit linking proprietary applications with the
338 | library. If this is what you want to do, use the GNU Lesser General
339 | Public License instead of this License.
340 |
--------------------------------------------------------------------------------
/licenses/dependencies/flake8-license.txt:
--------------------------------------------------------------------------------
1 | == Flake8 License (MIT) ==
2 |
3 | Copyright (C) 2011-2013 Tarek Ziade
4 | Copyright (C) 2012-2016 Ian Cordasco
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy of
7 | this software and associated documentation files (the "Software"), to deal in
8 | the Software without restriction, including without limitation the rights to
9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
10 | of the Software, and to permit persons to whom the Software is furnished to do
11 | so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/licenses/dependencies/isort-license.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2013 Timothy Edmund Crosley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/dependencies/numpy_license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2005-2021, NumPy Developers.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are
6 | met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above
12 | copyright notice, this list of conditions and the following
13 | disclaimer in the documentation and/or other materials provided
14 | with the distribution.
15 |
16 | * Neither the name of the NumPy Developers nor the names of any
17 | contributors may be used to endorse or promote products derived
18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/licenses/dependencies/onnx-license.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/licenses/dependencies/pre-commit-hooks-license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/licenses/dependencies/pre-commit-license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/licenses/dependencies/pyright-license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Robert Craigie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 | ===============================================================================
25 |
26 | MIT License
27 |
28 | Pyright - A static type checker for the Python language
29 | Copyright (c) Microsoft Corporation. All rights reserved.
30 |
31 | Permission is hereby granted, free of charge, to any person obtaining a copy
32 | of this software and associated documentation files (the "Software"), to deal
33 | in the Software without restriction, including without limitation the rights
34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
35 | copies of the Software, and to permit persons to whom the Software is
36 | furnished to do so, subject to the following conditions:
37 |
38 | The above copyright notice and this permission notice shall be included in all
39 | copies or substantial portions of the Software.
40 |
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 | SOFTWARE
48 |
--------------------------------------------------------------------------------
/licenses/dependencies/pyupgrade-license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Anthony Sottile
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/licenses/dependencies/torch_license.txt:
--------------------------------------------------------------------------------
1 | From PyTorch:
2 |
3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
8 | Copyright (c) 2011-2013 NYU (Clement Farabet)
9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
12 |
13 | From Caffe2:
14 |
15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
16 |
17 | All contributions by Facebook:
18 | Copyright (c) 2016 Facebook Inc.
19 |
20 | All contributions by Google:
21 | Copyright (c) 2015 Google Inc.
22 | All rights reserved.
23 |
24 | All contributions by Yangqing Jia:
25 | Copyright (c) 2015 Yangqing Jia
26 | All rights reserved.
27 |
28 | All contributions by Kakao Brain:
29 | Copyright 2019-2020 Kakao Brain
30 |
31 | All contributions from Caffe:
32 | Copyright(c) 2013, 2014, 2015, the respective contributors
33 | All rights reserved.
34 |
35 | All other contributions:
36 | Copyright(c) 2015, 2016 the respective contributors
37 | All rights reserved.
38 |
39 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
40 | copyright over their contributions to Caffe2. The project versioning records
41 | all such contribution and copyright details. If a contributor wants to further
42 | mark their specific copyright on a particular contribution, they should
43 | indicate their copyright solely in the commit message of the change when it is
44 | committed.
45 |
46 | All rights reserved.
47 |
48 | Redistribution and use in source and binary forms, with or without
49 | modification, are permitted provided that the following conditions are met:
50 |
51 | 1. Redistributions of source code must retain the above copyright
52 | notice, this list of conditions and the following disclaimer.
53 |
54 | 2. Redistributions in binary form must reproduce the above copyright
55 | notice, this list of conditions and the following disclaimer in the
56 | documentation and/or other materials provided with the distribution.
57 |
58 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
59 | and IDIAP Research Institute nor the names of its contributors may be
60 | used to endorse or promote products derived from this software without
61 | specific prior written permission.
62 |
63 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
64 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
65 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
66 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
67 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
68 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
69 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
70 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
71 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
72 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
73 | POSSIBILITY OF SUCH DAMAGE.
74 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "rsl-rl-lib"
7 | version = "2.3.3"
8 | keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"]
9 | maintainers = [
10 | { name="Clemens Schwarke", email="cschwarke@ethz.ch" },
11 | { name="Mayank Mittal", email="mittalma@ethz.ch" },
12 | ]
13 | authors = [
14 | { name="Clemens Schwarke", email="cschwarke@ethz.ch" },
15 | { name="Mayank Mittal", email="mittalma@ethz.ch" },
16 | { name="Nikita Rudin", email="rudinn@ethz.ch" },
17 | { name="David Hoeller", email="holler.david78@gmail.com" },
18 | ]
19 | description = "Fast and simple RL algorithms implemented in PyTorch"
20 | readme = { file = "README.md", content-type = "text/markdown"}
21 | license = { text = "BSD-3-Clause" }
22 |
23 | requires-python = ">=3.8"
24 | classifiers = [
25 | "Programming Language :: Python :: 3",
26 | "Operating System :: OS Independent",
27 | ]
28 | dependencies = [
29 | "torch>=1.10.0",
30 | "torchvision>=0.5.0",
31 | "numpy>=1.16.4",
32 | "GitPython",
33 | "onnx",
34 | ]
35 |
36 | [project.urls]
37 | Homepage = "https://github.com/leggedrobotics/rsl_rl"
38 | Issues = "https://github.com/leggedrobotics/rsl_rl/issues"
39 |
40 | [tool.setuptools.packages.find]
41 | where = ["."]
42 | include = ["rsl_rl*"]
43 |
44 | [tool.setuptools.package-data]
45 | "rsl_rl" = ["config/*", "licenses/*"]
46 |
47 | [tool.isort]
48 |
49 | py_version = 37
50 | line_length = 120
51 | group_by_package = true
52 |
53 | # Files to skip
54 | skip_glob = [".vscode/*"]
55 |
56 | # Order of imports
57 | sections = [
58 | "FUTURE",
59 | "STDLIB",
60 | "THIRDPARTY",
61 | "FIRSTPARTY",
62 | "LOCALFOLDER",
63 | ]
64 |
65 | # Extra standard libraries considered as part of python (permissive licenses)
66 | extra_standard_library = [
67 | "numpy",
68 | "torch",
69 | "tensordict",
70 | "warp",
71 | "typing_extensions",
72 | "git",
73 | ]
74 | # Imports from this repository
75 | known_first_party = "rsl_rl"
76 |
77 | [tool.pyright]
78 |
79 | include = ["rsl_rl"]
80 |
81 | typeCheckingMode = "basic"
82 | pythonVersion = "3.7"
83 | pythonPlatform = "Linux"
84 | enableTypeIgnoreComments = true
85 |
86 | # This is required as the CI pre-commit does not download the modules (e.g. numpy, torch, prettytable).
87 | # Therefore, we have to ignore missing imports.
88 | reportMissingImports = "none"
89 | # This is required to ignore type checks for modules with missing stubs.
90 | reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
91 |
92 | reportGeneralTypeIssues = "none" # -> raises 218 errors (usage of literal MISSING in dataclasses)
93 | reportOptionalMemberAccess = "warning" # -> raises 8 errors
94 | reportPrivateUsage = "warning"
95 |
--------------------------------------------------------------------------------
/rsl_rl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Main module for the rsl_rl package."""
7 |
--------------------------------------------------------------------------------
/rsl_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Implementation of different RL agents."""
7 |
8 | from .distillation import Distillation
9 | from .ppo import PPO
10 |
11 | __all__ = ["PPO", "Distillation"]
12 |
--------------------------------------------------------------------------------
/rsl_rl/algorithms/distillation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 |
11 | # rsl-rl
12 | from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
13 | from rsl_rl.storage import RolloutStorage
14 |
15 |
16 | class Distillation:
17 | """Distillation algorithm for training a student model to mimic a teacher model."""
18 |
19 | policy: StudentTeacher | StudentTeacherRecurrent
20 | """The student teacher model."""
21 |
22 | def __init__(
23 | self,
24 | policy,
25 | num_learning_epochs=1,
26 | gradient_length=15,
27 | learning_rate=1e-3,
28 | max_grad_norm=None,
29 | loss_type="mse",
30 | device="cpu",
31 | # Distributed training parameters
32 | multi_gpu_cfg: dict | None = None,
33 | ):
34 | # device-related parameters
35 | self.device = device
36 | self.is_multi_gpu = multi_gpu_cfg is not None
37 | # Multi-GPU parameters
38 | if multi_gpu_cfg is not None:
39 | self.gpu_global_rank = multi_gpu_cfg["global_rank"]
40 | self.gpu_world_size = multi_gpu_cfg["world_size"]
41 | else:
42 | self.gpu_global_rank = 0
43 | self.gpu_world_size = 1
44 |
45 | self.rnd = None # TODO: remove when runner has a proper base class
46 |
47 | # distillation components
48 | self.policy = policy
49 | self.policy.to(self.device)
50 | self.storage = None # initialized later
51 | self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
52 | self.transition = RolloutStorage.Transition()
53 | self.last_hidden_states = None
54 |
55 | # distillation parameters
56 | self.num_learning_epochs = num_learning_epochs
57 | self.gradient_length = gradient_length
58 | self.learning_rate = learning_rate
59 | self.max_grad_norm = max_grad_norm
60 |
61 | # initialize the loss function
62 | if loss_type == "mse":
63 | self.loss_fn = nn.functional.mse_loss
64 | elif loss_type == "huber":
65 | self.loss_fn = nn.functional.huber_loss
66 | else:
67 | raise ValueError(f"Unknown loss type: {loss_type}. Supported types are: mse, huber")
68 |
69 | self.num_updates = 0
70 |
71 | def init_storage(
72 | self, training_type, num_envs, num_transitions_per_env, student_obs_shape, teacher_obs_shape, actions_shape
73 | ):
74 | # create rollout storage
75 | self.storage = RolloutStorage(
76 | training_type,
77 | num_envs,
78 | num_transitions_per_env,
79 | student_obs_shape,
80 | teacher_obs_shape,
81 | actions_shape,
82 | None,
83 | self.device,
84 | )
85 |
86 | def act(self, obs, teacher_obs):
87 | # compute the actions
88 | self.transition.actions = self.policy.act(obs).detach()
89 | self.transition.privileged_actions = self.policy.evaluate(teacher_obs).detach()
90 | # record the observations
91 | self.transition.observations = obs
92 | self.transition.privileged_observations = teacher_obs
93 | return self.transition.actions
94 |
95 | def process_env_step(self, rewards, dones, infos):
96 | # record the rewards and dones
97 | self.transition.rewards = rewards
98 | self.transition.dones = dones
99 | # record the transition
100 | self.storage.add_transitions(self.transition)
101 | self.transition.clear()
102 | self.policy.reset(dones)
103 |
104 | def update(self):
105 | self.num_updates += 1
106 | mean_behavior_loss = 0
107 | loss = 0
108 | cnt = 0
109 |
110 | for epoch in range(self.num_learning_epochs):
111 | self.policy.reset(hidden_states=self.last_hidden_states)
112 | self.policy.detach_hidden_states()
113 | for obs, _, _, privileged_actions, dones in self.storage.generator():
114 |
115 | # run inference on the student for gradient computation
116 | actions = self.policy.act_inference(obs)
117 |
118 | # behavior cloning loss
119 | behavior_loss = self.loss_fn(actions, privileged_actions)
120 |
121 | # total loss
122 | loss = loss + behavior_loss
123 | mean_behavior_loss += behavior_loss.item()
124 | cnt += 1
125 |
126 | # gradient step
127 | if cnt % self.gradient_length == 0:
128 | self.optimizer.zero_grad()
129 | loss.backward()
130 | if self.is_multi_gpu:
131 | self.reduce_parameters()
132 | if self.max_grad_norm:
133 | nn.utils.clip_grad_norm_(self.policy.student.parameters(), self.max_grad_norm)
134 | self.optimizer.step()
135 | self.policy.detach_hidden_states()
136 | loss = 0
137 |
138 | # reset dones
139 | self.policy.reset(dones.view(-1))
140 | self.policy.detach_hidden_states(dones.view(-1))
141 |
142 | mean_behavior_loss /= cnt
143 | self.storage.clear()
144 | self.last_hidden_states = self.policy.get_hidden_states()
145 | self.policy.detach_hidden_states()
146 |
147 | # construct the loss dictionary
148 | loss_dict = {"behavior": mean_behavior_loss}
149 |
150 | return loss_dict
151 |
152 | """
153 | Helper functions
154 | """
155 |
156 | def broadcast_parameters(self):
157 | """Broadcast model parameters to all GPUs."""
158 | # obtain the model parameters on current GPU
159 | model_params = [self.policy.state_dict()]
160 | # broadcast the model parameters
161 | torch.distributed.broadcast_object_list(model_params, src=0)
162 | # load the model parameters on all GPUs from source GPU
163 | self.policy.load_state_dict(model_params[0])
164 |
165 | def reduce_parameters(self):
166 | """Collect gradients from all GPUs and average them.
167 |
168 | This function is called after the backward pass to synchronize the gradients across all GPUs.
169 | """
170 | # Create a tensor to store the gradients
171 | grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
172 | all_grads = torch.cat(grads)
173 | # Average the gradients across all GPUs
174 | torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
175 | all_grads /= self.gpu_world_size
176 | # Update the gradients for all parameters with the reduced gradients
177 | offset = 0
178 | for param in self.policy.parameters():
179 | if param.grad is not None:
180 | numel = param.numel()
181 | # copy data back from shared buffer
182 | param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
183 | # update the offset for the next parameter
184 | offset += numel
185 |
--------------------------------------------------------------------------------
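Illustrative usage sketch (not part of the repository): the loop below shows how a runner-style script might drive the Distillation class above. The environment attributes (num_envs, num_student_obs, num_teacher_obs, num_actions), the "teacher" observation key, and the "distillation" training type are assumptions made for the example; in practice the OnPolicyRunner wires these up from its configuration.

from rsl_rl.algorithms import Distillation


def distill(env, student_teacher, num_iterations=100, num_steps_per_env=24, device="cpu"):
    # `student_teacher` is expected to be a StudentTeacher or StudentTeacherRecurrent module
    algo = Distillation(policy=student_teacher, device=device)
    algo.init_storage(
        "distillation", env.num_envs, num_steps_per_env,
        [env.num_student_obs], [env.num_teacher_obs], [env.num_actions],  # assumed env attributes
    )
    obs, extras = env.get_observations()
    teacher_obs = extras["observations"]["teacher"]  # assumed key for teacher observations
    for _ in range(num_iterations):
        # collect a rollout with the current student and the frozen teacher
        for _ in range(num_steps_per_env):
            actions = algo.act(obs, teacher_obs)
            obs, rewards, dones, infos = env.step(actions)
            teacher_obs = infos["observations"]["teacher"]
            algo.process_env_step(rewards, dones, infos)
        # one behavior-cloning update over the stored rollout
        loss_dict = algo.update()  # e.g. {"behavior": 0.012}
    return student_teacher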
/rsl_rl/algorithms/ppo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from itertools import chain
12 |
13 | from rsl_rl.modules import ActorCritic
14 | from rsl_rl.modules.rnd import RandomNetworkDistillation
15 | from rsl_rl.storage import RolloutStorage
16 | from rsl_rl.utils import string_to_callable
17 |
18 |
19 | class PPO:
20 | """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
21 |
22 | policy: ActorCritic
23 | """The actor critic module."""
24 |
25 | def __init__(
26 | self,
27 | policy,
28 | num_learning_epochs=1,
29 | num_mini_batches=1,
30 | clip_param=0.2,
31 | gamma=0.998,
32 | lam=0.95,
33 | value_loss_coef=1.0,
34 | entropy_coef=0.0,
35 | learning_rate=1e-3,
36 | max_grad_norm=1.0,
37 | use_clipped_value_loss=True,
38 | schedule="fixed",
39 | desired_kl=0.01,
40 | device="cpu",
41 | normalize_advantage_per_mini_batch=False,
42 | # RND parameters
43 | rnd_cfg: dict | None = None,
44 | # Symmetry parameters
45 | symmetry_cfg: dict | None = None,
46 | # Distributed training parameters
47 | multi_gpu_cfg: dict | None = None,
48 | ):
49 | # device-related parameters
50 | self.device = device
51 | self.is_multi_gpu = multi_gpu_cfg is not None
52 | # Multi-GPU parameters
53 | if multi_gpu_cfg is not None:
54 | self.gpu_global_rank = multi_gpu_cfg["global_rank"]
55 | self.gpu_world_size = multi_gpu_cfg["world_size"]
56 | else:
57 | self.gpu_global_rank = 0
58 | self.gpu_world_size = 1
59 |
60 | # RND components
61 | if rnd_cfg is not None:
62 | # Extract learning rate and remove it from the original dict
63 | learning_rate = rnd_cfg.pop("learning_rate", 1e-3)
64 | # Create RND module
65 | self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg)
66 | # Create RND optimizer
67 | params = self.rnd.predictor.parameters()
68 | self.rnd_optimizer = optim.Adam(params, lr=learning_rate)
69 | else:
70 | self.rnd = None
71 | self.rnd_optimizer = None
72 |
73 | # Symmetry components
74 | if symmetry_cfg is not None:
75 | # Check if symmetry is enabled
76 | use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
77 | # Print that we are not using symmetry
78 | if not use_symmetry:
79 | print("Symmetry not used for learning. We will use it for logging instead.")
80 | # If function is a string then resolve it to a function
81 | if isinstance(symmetry_cfg["data_augmentation_func"], str):
82 | symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
83 | # Check valid configuration
84 | if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]):
85 | raise ValueError(
86 | "Data augmentation enabled but the function is not callable:"
87 | f" {symmetry_cfg['data_augmentation_func']}"
88 | )
89 | # Store symmetry configuration
90 | self.symmetry = symmetry_cfg
91 | else:
92 | self.symmetry = None
93 |
94 | # PPO components
95 | self.policy = policy
96 | self.policy.to(self.device)
97 | # Create optimizer
98 | self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
99 | # Create rollout storage
100 | self.storage: RolloutStorage = None # type: ignore
101 | self.transition = RolloutStorage.Transition()
102 |
103 | # PPO parameters
104 | self.clip_param = clip_param
105 | self.num_learning_epochs = num_learning_epochs
106 | self.num_mini_batches = num_mini_batches
107 | self.value_loss_coef = value_loss_coef
108 | self.entropy_coef = entropy_coef
109 | self.gamma = gamma
110 | self.lam = lam
111 | self.max_grad_norm = max_grad_norm
112 | self.use_clipped_value_loss = use_clipped_value_loss
113 | self.desired_kl = desired_kl
114 | self.schedule = schedule
115 | self.learning_rate = learning_rate
116 | self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
117 |
118 | def init_storage(
119 | self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape
120 | ):
121 | # create memory for RND as well :)
122 | if self.rnd:
123 | rnd_state_shape = [self.rnd.num_states]
124 | else:
125 | rnd_state_shape = None
126 | # create rollout storage
127 | self.storage = RolloutStorage(
128 | training_type,
129 | num_envs,
130 | num_transitions_per_env,
131 | actor_obs_shape,
132 | critic_obs_shape,
133 | actions_shape,
134 | rnd_state_shape,
135 | self.device,
136 | )
137 |
138 | def act(self, obs, critic_obs):
139 | if self.policy.is_recurrent:
140 | self.transition.hidden_states = self.policy.get_hidden_states()
141 | # compute the actions and values
142 | self.transition.actions = self.policy.act(obs).detach()
143 | self.transition.values = self.policy.evaluate(critic_obs).detach()
144 | self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
145 | self.transition.action_mean = self.policy.action_mean.detach()
146 | self.transition.action_sigma = self.policy.action_std.detach()
147 | # need to record obs and critic_obs before env.step()
148 | self.transition.observations = obs
149 | self.transition.privileged_observations = critic_obs
150 | return self.transition.actions
151 |
152 | def process_env_step(self, rewards, dones, infos):
153 | # Record the rewards and dones
154 | # Note: we clone here because later on we bootstrap the rewards based on timeouts
155 | self.transition.rewards = rewards.clone()
156 | self.transition.dones = dones
157 |
158 | # Compute the intrinsic rewards and add to extrinsic rewards
159 | if self.rnd:
160 | # Obtain curiosity gates / observations from infos
161 | rnd_state = infos["observations"]["rnd_state"]
162 | # Compute the intrinsic rewards
163 | # note: rnd_state is the gated_state after normalization if normalization is used
164 | self.intrinsic_rewards, rnd_state = self.rnd.get_intrinsic_reward(rnd_state)
165 | # Add intrinsic rewards to extrinsic rewards
166 | self.transition.rewards += self.intrinsic_rewards
167 | # Record the curiosity gates
168 | self.transition.rnd_state = rnd_state.clone()
169 |
170 | # Bootstrapping on time outs
171 | if "time_outs" in infos:
172 | self.transition.rewards += self.gamma * torch.squeeze(
173 | self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
174 | )
175 |
176 | # record the transition
177 | self.storage.add_transitions(self.transition)
178 | self.transition.clear()
179 | self.policy.reset(dones)
180 |
181 | def compute_returns(self, last_critic_obs):
182 | # compute value for the last step
183 | last_values = self.policy.evaluate(last_critic_obs).detach()
184 | self.storage.compute_returns(
185 | last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
186 | )
187 |
188 | def update(self): # noqa: C901
189 | mean_value_loss = 0
190 | mean_surrogate_loss = 0
191 | mean_entropy = 0
192 | # -- RND loss
193 | if self.rnd:
194 | mean_rnd_loss = 0
195 | else:
196 | mean_rnd_loss = None
197 | # -- Symmetry loss
198 | if self.symmetry:
199 | mean_symmetry_loss = 0
200 | else:
201 | mean_symmetry_loss = None
202 |
203 | # generator for mini batches
204 | if self.policy.is_recurrent:
205 | generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
206 | else:
207 | generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
208 |
209 | # iterate over batches
210 | for (
211 | obs_batch,
212 | critic_obs_batch,
213 | actions_batch,
214 | target_values_batch,
215 | advantages_batch,
216 | returns_batch,
217 | old_actions_log_prob_batch,
218 | old_mu_batch,
219 | old_sigma_batch,
220 | hid_states_batch,
221 | masks_batch,
222 | rnd_state_batch,
223 | ) in generator:
224 |
225 | # number of augmentations per sample
226 | # we start with 1 and increase it if we use symmetry augmentation
227 | num_aug = 1
228 | # original batch size
229 | original_batch_size = obs_batch.shape[0]
230 |
231 | # check if we should normalize advantages per mini batch
232 | if self.normalize_advantage_per_mini_batch:
233 | with torch.no_grad():
234 | advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)
235 |
236 | # Perform symmetric augmentation
237 | if self.symmetry and self.symmetry["use_data_augmentation"]:
238 | # augmentation using symmetry
239 | data_augmentation_func = self.symmetry["data_augmentation_func"]
240 | # returned shape: [batch_size * num_aug, ...]
241 | obs_batch, actions_batch = data_augmentation_func(
242 | obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], obs_type="policy"
243 | )
244 | critic_obs_batch, _ = data_augmentation_func(
245 | obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], obs_type="critic"
246 | )
247 | # compute number of augmentations per sample
248 | num_aug = int(obs_batch.shape[0] / original_batch_size)
249 | # repeat the rest of the batch
250 | # -- actor
251 | old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
252 | # -- critic
253 | target_values_batch = target_values_batch.repeat(num_aug, 1)
254 | advantages_batch = advantages_batch.repeat(num_aug, 1)
255 | returns_batch = returns_batch.repeat(num_aug, 1)
256 |
257 | # Recompute actions log prob and entropy for current batch of transitions
258 | # Note: we need to do this because we updated the policy with the new parameters
259 | # -- actor
260 | self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
261 | actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
262 | # -- critic
263 | value_batch = self.policy.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
264 | # -- entropy
265 | # we only keep the entropy of the first augmentation (the original one)
266 | mu_batch = self.policy.action_mean[:original_batch_size]
267 | sigma_batch = self.policy.action_std[:original_batch_size]
268 | entropy_batch = self.policy.entropy[:original_batch_size]
269 |
270 | # KL
271 | if self.desired_kl is not None and self.schedule == "adaptive":
272 | with torch.inference_mode():
273 | kl = torch.sum(
274 | torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
275 | + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
276 | / (2.0 * torch.square(sigma_batch))
277 | - 0.5,
278 | axis=-1,
279 | )
280 | kl_mean = torch.mean(kl)
281 |
282 | # Reduce the KL divergence across all GPUs
283 | if self.is_multi_gpu:
284 | torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
285 | kl_mean /= self.gpu_world_size
286 |
287 | # Update the learning rate
288 | # Perform this adaptation only on the main process
289 | # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
290 | # then the learning rate should be the same across all GPUs.
291 | if self.gpu_global_rank == 0:
292 | if kl_mean > self.desired_kl * 2.0:
293 | self.learning_rate = max(1e-5, self.learning_rate / 1.5)
294 | elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
295 | self.learning_rate = min(1e-2, self.learning_rate * 1.5)
296 |
297 | # Update the learning rate for all GPUs
298 | if self.is_multi_gpu:
299 | lr_tensor = torch.tensor(self.learning_rate, device=self.device)
300 | torch.distributed.broadcast(lr_tensor, src=0)
301 | self.learning_rate = lr_tensor.item()
302 |
303 | # Update the learning rate for all parameter groups
304 | for param_group in self.optimizer.param_groups:
305 | param_group["lr"] = self.learning_rate
306 |
307 | # Surrogate loss
308 | ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
309 | surrogate = -torch.squeeze(advantages_batch) * ratio
310 | surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
311 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
312 | )
313 | surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
314 |
315 | # Value function loss
316 | if self.use_clipped_value_loss:
317 | value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
318 | -self.clip_param, self.clip_param
319 | )
320 | value_losses = (value_batch - returns_batch).pow(2)
321 | value_losses_clipped = (value_clipped - returns_batch).pow(2)
322 | value_loss = torch.max(value_losses, value_losses_clipped).mean()
323 | else:
324 | value_loss = (returns_batch - value_batch).pow(2).mean()
325 |
326 | loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
327 |
328 | # Symmetry loss
329 | if self.symmetry:
330 | # obtain the symmetric actions
331 | # if we did augmentation before then we don't need to augment again
332 | if not self.symmetry["use_data_augmentation"]:
333 | data_augmentation_func = self.symmetry["data_augmentation_func"]
334 | obs_batch, _ = data_augmentation_func(
335 | obs=obs_batch, actions=None, env=self.symmetry["_env"], obs_type="policy"
336 | )
337 | # compute number of augmentations per sample
338 | num_aug = int(obs_batch.shape[0] / original_batch_size)
339 |
340 | # actions predicted by the actor for symmetrically-augmented observations
341 | mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())
342 |
343 | # compute the symmetrically augmented actions
344 | # note: we are assuming the first augmentation is the original one.
345 | # We do not use the action_batch from earlier since that action was sampled from the distribution.
346 | # However, the symmetry loss is computed using the mean of the distribution.
347 | action_mean_orig = mean_actions_batch[:original_batch_size]
348 | _, actions_mean_symm_batch = data_augmentation_func(
349 | obs=None, actions=action_mean_orig, env=self.symmetry["_env"], obs_type="policy"
350 | )
351 |
352 | # compute the loss (we skip the first augmentation as it is the original one)
353 | mse_loss = torch.nn.MSELoss()
354 | symmetry_loss = mse_loss(
355 | mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:]
356 | )
357 | # add the loss to the total loss
358 | if self.symmetry["use_mirror_loss"]:
359 | loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
360 | else:
361 | symmetry_loss = symmetry_loss.detach()
362 |
363 | # Random Network Distillation loss
364 | if self.rnd:
365 | # predict the embedding and the target
366 | predicted_embedding = self.rnd.predictor(rnd_state_batch)
367 | target_embedding = self.rnd.target(rnd_state_batch).detach()
368 | # compute the loss as the mean squared error
369 | mseloss = torch.nn.MSELoss()
370 | rnd_loss = mseloss(predicted_embedding, target_embedding)
371 |
372 | # Compute the gradients
373 | # -- For PPO
374 | self.optimizer.zero_grad()
375 | loss.backward()
376 | # -- For RND
377 | if self.rnd:
378 | self.rnd_optimizer.zero_grad() # type: ignore
379 | rnd_loss.backward()
380 |
381 | # Collect gradients from all GPUs
382 | if self.is_multi_gpu:
383 | self.reduce_parameters()
384 |
385 | # Apply the gradients
386 | # -- For PPO
387 | nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
388 | self.optimizer.step()
389 | # -- For RND
390 | if self.rnd_optimizer:
391 | self.rnd_optimizer.step()
392 |
393 | # Store the losses
394 | mean_value_loss += value_loss.item()
395 | mean_surrogate_loss += surrogate_loss.item()
396 | mean_entropy += entropy_batch.mean().item()
397 | # -- RND loss
398 | if mean_rnd_loss is not None:
399 | mean_rnd_loss += rnd_loss.item()
400 | # -- Symmetry loss
401 | if mean_symmetry_loss is not None:
402 | mean_symmetry_loss += symmetry_loss.item()
403 |
404 | # -- For PPO
405 | num_updates = self.num_learning_epochs * self.num_mini_batches
406 | mean_value_loss /= num_updates
407 | mean_surrogate_loss /= num_updates
408 | mean_entropy /= num_updates
409 | # -- For RND
410 | if mean_rnd_loss is not None:
411 | mean_rnd_loss /= num_updates
412 | # -- For Symmetry
413 | if mean_symmetry_loss is not None:
414 | mean_symmetry_loss /= num_updates
415 | # -- Clear the storage
416 | self.storage.clear()
417 |
418 | # construct the loss dictionary
419 | loss_dict = {
420 | "value_function": mean_value_loss,
421 | "surrogate": mean_surrogate_loss,
422 | "entropy": mean_entropy,
423 | }
424 | if self.rnd:
425 | loss_dict["rnd"] = mean_rnd_loss
426 | if self.symmetry:
427 | loss_dict["symmetry"] = mean_symmetry_loss
428 |
429 | return loss_dict
430 |
431 | """
432 | Helper functions
433 | """
434 |
435 | def broadcast_parameters(self):
436 | """Broadcast model parameters to all GPUs."""
437 | # obtain the model parameters on current GPU
438 | model_params = [self.policy.state_dict()]
439 | if self.rnd:
440 | model_params.append(self.rnd.predictor.state_dict())
441 | # broadcast the model parameters
442 | torch.distributed.broadcast_object_list(model_params, src=0)
443 | # load the model parameters on all GPUs from source GPU
444 | self.policy.load_state_dict(model_params[0])
445 | if self.rnd:
446 | self.rnd.predictor.load_state_dict(model_params[1])
447 |
448 | def reduce_parameters(self):
449 | """Collect gradients from all GPUs and average them.
450 |
451 | This function is called after the backward pass to synchronize the gradients across all GPUs.
452 | """
453 | # Create a tensor to store the gradients
454 | grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
455 | if self.rnd:
456 | grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
457 | all_grads = torch.cat(grads)
458 |
459 | # Average the gradients across all GPUs
460 | torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
461 | all_grads /= self.gpu_world_size
462 |
463 | # Get all parameters
464 | all_params = self.policy.parameters()
465 | if self.rnd:
466 | all_params = chain(all_params, self.rnd.parameters())
467 |
468 | # Update the gradients for all parameters with the reduced gradients
469 | offset = 0
470 | for param in all_params:
471 | if param.grad is not None:
472 | numel = param.numel()
473 | # copy data back from shared buffer
474 | param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
475 | # update the offset for the next parameter
476 | offset += numel
477 |
--------------------------------------------------------------------------------
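Illustrative configuration sketch (not part of the repository): the dictionaries below show one way the optional rnd_cfg and symmetry_cfg arguments of PPO could be populated. The keys follow what PPO.__init__ and RandomNetworkDistillation read in the code above; the concrete values, observation sizes, and the identity augmentation function are made up for the example.

from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic


def identity_augmentation(obs=None, actions=None, env=None, obs_type="policy"):
    # placeholder with the expected call signature; a real function would return
    # mirrored observations/actions stacked on top of the originals
    return obs, actions


policy = ActorCritic(num_actor_obs=48, num_critic_obs=60, num_actions=12)

rnd_cfg = {
    "learning_rate": 1e-3,          # popped by PPO before the remaining keys go to RandomNetworkDistillation
    "num_states": 16,               # size of the "rnd_state" observation
    "num_outputs": 8,               # embedding size of the predictor/target networks
    "predictor_hidden_dims": [-1],  # -1 -> use num_states as the hidden dimension
    "target_hidden_dims": [-1],
    "weight": 0.5,                  # scaling of the intrinsic reward
}

symmetry_cfg = {
    "use_data_augmentation": True,
    "use_mirror_loss": False,
    "data_augmentation_func": identity_augmentation,  # may also be given as a string and resolved via string_to_callable
    "mirror_loss_coeff": 1.0,
    "_env": None,  # the runner normally injects the environment object here
}

algo = PPO(policy, device="cpu", rnd_cfg=rnd_cfg, symmetry_cfg=symmetry_cfg)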
/rsl_rl/env/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Submodule defining the environment definitions."""
7 |
8 | from .vec_env import VecEnv
9 |
10 | __all__ = ["VecEnv"]
11 |
--------------------------------------------------------------------------------
/rsl_rl/env/vec_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | from abc import ABC, abstractmethod
10 |
11 |
12 | class VecEnv(ABC):
13 | """Abstract class for vectorized environment.
14 |
15 | The vectorized environment is a collection of environments that are stepped synchronously. A batch of
16 | actions is applied across all environments, and batched observations, rewards, and dones are returned.
17 |
18 | All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the
19 | configuration, the extra observations are used for different purposes. The following keys are used by the
20 | environment:
21 |
22 | - "observations" (dict[str, dict[str, torch.Tensor]]):
23 | Additional observations that are not used by the actor networks. The keys are the names of the observations
24 | and the values are the observations themselves. The following are reserved keys for the observations:
25 |
26 | - "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces.
27 | - "rnd_state": The observation is used as input to the RND network. Useful for random network distillation.
28 |
29 | - "time_outs" (torch.Tensor): Timeouts for the environments. These correspond to terminations that happen due to time limits and
30 | not due to the environment reaching a terminal state. This is useful for environments that have a fixed
31 | episode length.
32 |
33 | - "log" (dict[str, float | torch.Tensor]): Additional information for logging and debugging purposes.
34 | The key should be a string and start with "/" for namespacing. The value can be a scalar or a tensor.
35 | If it is a tensor, the mean of the tensor is used for logging.
36 |
37 | .. deprecated:: 2.0.0
38 |
39 | Use "log" in the extra information dictionary instead of the "episode" key.
40 |
41 | """
42 |
43 | num_envs: int
44 | """Number of environments."""
45 |
46 | num_actions: int
47 | """Number of actions."""
48 |
49 | max_episode_length: int | torch.Tensor
50 | """Maximum episode length.
51 |
52 | The maximum episode length can be a scalar or a tensor. If it is a scalar, it is the same for all environments.
53 | If it is a tensor, it is the maximum episode length for each environment. This is useful for dynamic episode
54 | lengths.
55 | """
56 |
57 | episode_length_buf: torch.Tensor
58 | """Buffer for current episode lengths."""
59 |
60 | device: torch.device
61 | """Device to use."""
62 |
63 | cfg: dict | object
64 | """Configuration object."""
65 |
66 | """
67 | Operations.
68 | """
69 |
70 | @abstractmethod
71 | def get_observations(self) -> tuple[torch.Tensor, dict]:
72 | """Return the current observations.
73 |
74 | Returns:
75 | Tuple containing the observations and extras.
76 | """
77 | raise NotImplementedError
78 |
79 | @abstractmethod
80 | def reset(self) -> tuple[torch.Tensor, dict]:
81 | """Reset all environment instances.
82 |
83 | Returns:
84 | Tuple containing the observations and extras.
85 | """
86 | raise NotImplementedError
87 |
88 | @abstractmethod
89 | def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
90 | """Apply input action on the environment.
91 |
92 | The extra information is a dictionary. It includes metrics such as the episode reward, episode length,
93 | etc. Additional information can be stored in the dictionary such as observations for the critic network, etc.
94 |
95 | Args:
96 | actions: Input actions to apply. Shape: (num_envs, num_actions)
97 |
98 | Returns:
99 | A tuple containing the observations, rewards, dones and extra information (metrics).
100 | """
101 | raise NotImplementedError
102 |
--------------------------------------------------------------------------------
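Illustrative sketch (not part of the repository): a minimal VecEnv implementation that follows the "extras" conventions documented in the docstring above. The dynamics are random placeholders; only the interface and the reserved keys ("critic", "time_outs", "log") are the point of the example.

import torch

from rsl_rl.env import VecEnv


class DummyVecEnv(VecEnv):
    def __init__(self, num_envs=4, num_obs=8, num_actions=2, max_episode_length=100, device="cpu"):
        self.num_envs = num_envs
        self.num_actions = num_actions
        self.max_episode_length = max_episode_length
        self.device = torch.device(device)
        self.cfg = {}
        self.num_obs = num_obs
        self.episode_length_buf = torch.zeros(num_envs, dtype=torch.long, device=self.device)

    def get_observations(self):
        obs = torch.randn(self.num_envs, self.num_obs, device=self.device)
        # an asymmetric setup would return different critic observations here
        extras = {"observations": {"critic": obs}}
        return obs, extras

    def reset(self):
        self.episode_length_buf.zero_()
        return self.get_observations()

    def step(self, actions):
        self.episode_length_buf += 1
        obs = torch.randn(self.num_envs, self.num_obs, device=self.device)
        rewards = torch.randn(self.num_envs, device=self.device)
        # terminations caused purely by the time limit are reported via "time_outs"
        time_outs = self.episode_length_buf >= self.max_episode_length
        dones = time_outs.clone()
        self.episode_length_buf[dones] = 0
        extras = {
            "observations": {"critic": obs},
            "time_outs": time_outs,
            "log": {"/dummy/mean_action": actions.mean()},
        }
        return obs, rewards, dones, extras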
/rsl_rl/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Definitions for neural-network components for RL-agents."""
7 |
8 | from .actor_critic import ActorCritic
9 | from .actor_critic_recurrent import ActorCriticRecurrent
10 | from .normalizer import EmpiricalNormalization
11 | from .rnd import RandomNetworkDistillation
12 | from .student_teacher import StudentTeacher
13 | from .student_teacher_recurrent import StudentTeacherRecurrent
14 |
15 | __all__ = [
16 | "ActorCritic",
17 | "ActorCriticRecurrent",
18 | "EmpiricalNormalization",
19 | "RandomNetworkDistillation",
20 | "StudentTeacher",
21 | "StudentTeacherRecurrent",
22 | ]
23 |
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.distributions import Normal
11 |
12 | from rsl_rl.utils import resolve_nn_activation
13 |
14 |
15 | class ActorCritic(nn.Module):
16 | is_recurrent = False
17 |
18 | def __init__(
19 | self,
20 | num_actor_obs,
21 | num_critic_obs,
22 | num_actions,
23 | actor_hidden_dims=[256, 256, 256],
24 | critic_hidden_dims=[256, 256, 256],
25 | activation="elu",
26 | init_noise_std=1.0,
27 | noise_std_type: str = "scalar",
28 | **kwargs,
29 | ):
30 | if kwargs:
31 | print(
32 | "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
33 | + str([key for key in kwargs.keys()])
34 | )
35 | super().__init__()
36 | activation = resolve_nn_activation(activation)
37 |
38 | mlp_input_dim_a = num_actor_obs
39 | mlp_input_dim_c = num_critic_obs
40 | # Policy
41 | actor_layers = []
42 | actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
43 | actor_layers.append(activation)
44 | for layer_index in range(len(actor_hidden_dims)):
45 | if layer_index == len(actor_hidden_dims) - 1:
46 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
47 | else:
48 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
49 | actor_layers.append(activation)
50 | self.actor = nn.Sequential(*actor_layers)
51 |
52 | # Value function
53 | critic_layers = []
54 | critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
55 | critic_layers.append(activation)
56 | for layer_index in range(len(critic_hidden_dims)):
57 | if layer_index == len(critic_hidden_dims) - 1:
58 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
59 | else:
60 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
61 | critic_layers.append(activation)
62 | self.critic = nn.Sequential(*critic_layers)
63 |
64 | print(f"Actor MLP: {self.actor}")
65 | print(f"Critic MLP: {self.critic}")
66 |
67 | # Action noise
68 | self.noise_std_type = noise_std_type
69 | if self.noise_std_type == "scalar":
70 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
71 | elif self.noise_std_type == "log":
72 | self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
73 | else:
74 | raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
75 |
76 | # Action distribution (populated in update_distribution)
77 | self.distribution = None
78 | # disable args validation for speedup
79 | Normal.set_default_validate_args(False)
80 |
81 | @staticmethod
82 | # not used at the moment
83 | def init_weights(sequential, scales):
84 | [
85 | torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
86 | for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
87 | ]
88 |
89 | def reset(self, dones=None):
90 | pass
91 |
92 | def forward(self):
93 | raise NotImplementedError
94 |
95 | @property
96 | def action_mean(self):
97 | return self.distribution.mean
98 |
99 | @property
100 | def action_std(self):
101 | return self.distribution.stddev
102 |
103 | @property
104 | def entropy(self):
105 | return self.distribution.entropy().sum(dim=-1)
106 |
107 | def update_distribution(self, observations):
108 | # compute mean
109 | mean = self.actor(observations)
110 | # compute standard deviation
111 | if self.noise_std_type == "scalar":
112 | std = self.std.expand_as(mean)
113 | elif self.noise_std_type == "log":
114 | std = torch.exp(self.log_std).expand_as(mean)
115 | else:
116 | raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
117 | # create distribution
118 | self.distribution = Normal(mean, std)
119 |
120 | def act(self, observations, **kwargs):
121 | self.update_distribution(observations)
122 | return self.distribution.sample()
123 |
124 | def get_actions_log_prob(self, actions):
125 | return self.distribution.log_prob(actions).sum(dim=-1)
126 |
127 | def act_inference(self, observations):
128 | actions_mean = self.actor(observations)
129 | return actions_mean
130 |
131 | def evaluate(self, critic_observations, **kwargs):
132 | value = self.critic(critic_observations)
133 | return value
134 |
135 | def load_state_dict(self, state_dict, strict=True):
136 | """Load the parameters of the actor-critic model.
137 |
138 | Args:
139 | state_dict (dict): State dictionary of the model.
140 | strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
141 | module's state_dict() function.
142 |
143 | Returns:
144 | bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
145 | `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
146 | """
147 |
148 | super().load_state_dict(state_dict, strict=strict)
149 | return True
150 |
--------------------------------------------------------------------------------
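Illustrative usage sketch (not part of the repository): basic calls on the ActorCritic module above, with made-up observation and action sizes.

import torch

from rsl_rl.modules import ActorCritic

ac = ActorCritic(num_actor_obs=10, num_critic_obs=12, num_actions=3, noise_std_type="log")
obs = torch.randn(16, 10)          # batch of actor observations
critic_obs = torch.randn(16, 12)   # batch of (possibly privileged) critic observations

actions = ac.act(obs)                        # stochastic actions sampled from Normal(mean, std)
log_prob = ac.get_actions_log_prob(actions)  # log-probabilities summed over the action dimension
values = ac.evaluate(critic_obs)             # value estimates, shape (16, 1)
mean_actions = ac.act_inference(obs)         # deterministic mean actions, used at deployment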
/rsl_rl/modules/actor_critic_recurrent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import warnings
9 |
10 | from rsl_rl.modules import ActorCritic
11 | from rsl_rl.networks import Memory
12 | from rsl_rl.utils import resolve_nn_activation
13 |
14 |
15 | class ActorCriticRecurrent(ActorCritic):
16 | is_recurrent = True
17 |
18 | def __init__(
19 | self,
20 | num_actor_obs,
21 | num_critic_obs,
22 | num_actions,
23 | actor_hidden_dims=[256, 256, 256],
24 | critic_hidden_dims=[256, 256, 256],
25 | activation="elu",
26 | rnn_type="lstm",
27 | rnn_hidden_dim=256,
28 | rnn_num_layers=1,
29 | init_noise_std=1.0,
30 | **kwargs,
31 | ):
32 | if "rnn_hidden_size" in kwargs:
33 | warnings.warn(
34 | "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
35 | "Please use `rnn_hidden_dim` instead.",
36 | DeprecationWarning,
37 | )
38 | if rnn_hidden_dim == 256: # Only override if the new argument is at its default
39 | rnn_hidden_dim = kwargs.pop("rnn_hidden_size")
40 | if kwargs:
41 | print(
42 | "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
43 | )
44 |
45 | super().__init__(
46 | num_actor_obs=rnn_hidden_dim,
47 | num_critic_obs=rnn_hidden_dim,
48 | num_actions=num_actions,
49 | actor_hidden_dims=actor_hidden_dims,
50 | critic_hidden_dims=critic_hidden_dims,
51 | activation=activation,
52 | init_noise_std=init_noise_std,
53 | )
54 |
55 | activation = resolve_nn_activation(activation)
56 |
57 | self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
58 | self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
59 |
60 | print(f"Actor RNN: {self.memory_a}")
61 | print(f"Critic RNN: {self.memory_c}")
62 |
63 | def reset(self, dones=None):
64 | self.memory_a.reset(dones)
65 | self.memory_c.reset(dones)
66 |
67 | def act(self, observations, masks=None, hidden_states=None):
68 | input_a = self.memory_a(observations, masks, hidden_states)
69 | return super().act(input_a.squeeze(0))
70 |
71 | def act_inference(self, observations):
72 | input_a = self.memory_a(observations)
73 | return super().act_inference(input_a.squeeze(0))
74 |
75 | def evaluate(self, critic_observations, masks=None, hidden_states=None):
76 | input_c = self.memory_c(critic_observations, masks, hidden_states)
77 | return super().evaluate(input_c.squeeze(0))
78 |
79 | def get_hidden_states(self):
80 | return self.memory_a.hidden_states, self.memory_c.hidden_states
81 |
--------------------------------------------------------------------------------
/rsl_rl/modules/normalizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | # Copyright (c) 2020 Preferred Networks, Inc.
7 |
8 | from __future__ import annotations
9 |
10 | import torch
11 | from torch import nn
12 |
13 |
14 | class EmpiricalNormalization(nn.Module):
15 | """Normalize mean and variance of values based on empirical values."""
16 |
17 | def __init__(self, shape, eps=1e-2, until=None):
18 | """Initialize EmpiricalNormalization module.
19 |
20 | Args:
21 | shape (int or tuple of int): Shape of input values except batch axis.
22 | eps (float): Small value for stability.
23 | until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
24 | exceeds it.
25 | """
26 | super().__init__()
27 | self.eps = eps
28 | self.until = until
29 | self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
30 | self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
31 | self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
32 | self.register_buffer("count", torch.tensor(0, dtype=torch.long))
33 |
34 | @property
35 | def mean(self):
36 | return self._mean.squeeze(0).clone()
37 |
38 | @property
39 | def std(self):
40 | return self._std.squeeze(0).clone()
41 |
42 | def forward(self, x):
43 | """Normalize mean and variance of values based on empirical values.
44 |
45 | Args:
46 | x (torch.Tensor): Input values.
47 |
48 | Returns:
49 | torch.Tensor: Normalized output values.
50 | """
51 |
52 | if self.training:
53 | self.update(x)
54 | return (x - self._mean) / (self._std + self.eps)
55 |
56 | @torch.jit.unused
57 | def update(self, x):
58 | """Learn input values without computing output values for them."""
59 |
60 | if self.until is not None and self.count >= self.until:
61 | return
62 |
63 | count_x = x.shape[0]
64 | self.count += count_x
65 | rate = count_x / self.count
66 |
67 | var_x = torch.var(x, dim=0, unbiased=False, keepdim=True)
68 | mean_x = torch.mean(x, dim=0, keepdim=True)
69 | delta_mean = mean_x - self._mean
70 | self._mean += rate * delta_mean
71 | self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean))
72 | self._std = torch.sqrt(self._var)
73 |
74 | @torch.jit.unused
75 | def inverse(self, y):
76 | return y * (self._std + self.eps) + self._mean
77 |
78 |
79 | class EmpiricalDiscountedVariationNormalization(nn.Module):
80 | """Reward normalization from Pathak's large scale study on PPO.
81 |
82 | Since the reward function is non-stationary, it is useful to normalize the scale of the rewards
83 | so that the value function can learn quickly. This is done by dividing the rewards by a running
84 | estimate of the standard deviation of the sum of discounted rewards.
85 | """
86 |
87 | def __init__(self, shape, eps=1e-2, gamma=0.99, until=None):
88 | super().__init__()
89 |
90 | self.emp_norm = EmpiricalNormalization(shape, eps, until)
91 | self.disc_avg = DiscountedAverage(gamma)
92 |
93 | def forward(self, rew):
94 | if self.training:
95 | # update discounted rewards
96 | avg = self.disc_avg.update(rew)
97 |
98 | # update moments from discounted rewards
99 | self.emp_norm.update(avg)
100 |
101 | if self.emp_norm._std > 0:
102 | return rew / self.emp_norm._std
103 | else:
104 | return rew
105 |
106 |
107 | class DiscountedAverage:
108 | r"""Discounted average of rewards.
109 |
110 | The discounted average is defined as:
111 |
112 | .. math::
113 |
114 | \bar{R}_t = \gamma \bar{R}_{t-1} + r_t
115 |
116 | Args:
117 | gamma (float): Discount factor.
118 | """
119 |
120 | def __init__(self, gamma):
121 | self.avg = None
122 | self.gamma = gamma
123 |
124 | def update(self, rew: torch.Tensor) -> torch.Tensor:
125 | if self.avg is None:
126 | self.avg = rew
127 | else:
128 | self.avg = self.avg * self.gamma + rew
129 | return self.avg
130 |
--------------------------------------------------------------------------------
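Illustrative usage sketch (not part of the repository): using the normalizers above. In training mode the running statistics are updated on every forward call; eval() freezes them. Shapes and values are made up for the example.

import torch

from rsl_rl.modules import EmpiricalNormalization
from rsl_rl.modules.normalizer import EmpiricalDiscountedVariationNormalization

obs_norm = EmpiricalNormalization(shape=[8], until=1.0e8)
obs = torch.randn(32, 8)
normalized = obs_norm(obs)   # updates the running mean/std, then normalizes the batch
obs_norm.eval()              # stop updating, e.g. for evaluation or deployment

rew_norm = EmpiricalDiscountedVariationNormalization(shape=[], gamma=0.99)
rewards = torch.randn(32)
scaled = rew_norm(rewards)   # divides by a running std of the discounted reward sum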
/rsl_rl/modules/rnd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from rsl_rl.modules.normalizer import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
12 | from rsl_rl.utils import resolve_nn_activation
13 |
14 |
15 | class RandomNetworkDistillation(nn.Module):
16 | """Implementation of Random Network Distillation (RND) [1]
17 |
18 | References:
19 | .. [1] Burda, Yuri, et al. "Exploration by random network distillation." arXiv preprint arXiv:1810.12894 (2018).
20 | """
21 |
22 | def __init__(
23 | self,
24 | num_states: int,
25 | num_outputs: int,
26 | predictor_hidden_dims: list[int],
27 | target_hidden_dims: list[int],
28 | activation: str = "elu",
29 | weight: float = 0.0,
30 | state_normalization: bool = False,
31 | reward_normalization: bool = False,
32 | device: str = "cpu",
33 | weight_schedule: dict | None = None,
34 | ):
35 | """Initialize the RND module.
36 |
37 | - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization layer.
38 | - If :attr:`reward_normalization` is True, then the intrinsic reward is normalized using an Empirical Discounted
39 | Variation Normalization layer.
40 |
41 | .. note::
42 | If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states
43 | is used as the hidden dimension.
44 |
45 | Args:
46 | num_states: Number of states/inputs to the predictor and target networks.
47 | num_outputs: Number of outputs (embedding size) of the predictor and target networks.
48 | predictor_hidden_dims: List of hidden dimensions of the predictor network.
49 | target_hidden_dims: List of hidden dimensions of the target network.
50 | activation: Activation function. Defaults to "elu".
51 | weight: Scaling factor of the intrinsic reward. Defaults to 0.0.
52 | state_normalization: Whether to normalize the input state. Defaults to False.
53 | reward_normalization: Whether to normalize the intrinsic reward. Defaults to False.
54 | device: Device to use. Defaults to "cpu".
55 | weight_schedule: The type of schedule to use for the RND weight parameter.
56 | Defaults to None, in which case the weight parameter is constant.
57 | It is a dictionary with the following keys:
58 |
59 | - "mode": The type of schedule to use for the RND weight parameter.
60 | - "constant": Constant weight schedule.
61 | - "step": Step weight schedule.
62 | - "linear": Linear weight schedule.
63 |
64 | For the "step" weight schedule, the following parameters are required:
65 |
66 | - "final_step": The step at which the weight parameter is set to the final value.
67 | - "final_value": The final value of the weight parameter.
68 |
69 | For the "linear" weight schedule, the following parameters are required:
70 | - "initial_step": The step at which the weight parameter is set to the initial value.
71 | - "final_step": The step at which the weight parameter is set to the final value.
72 | - "final_value": The final value of the weight parameter.
73 | """
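        # Example weight_schedule configurations (illustrative values, based on the docstring above):
        #   {"mode": "step", "final_step": 10000, "final_value": 0.0}
        #   {"mode": "linear", "initial_step": 0, "final_step": 10000, "final_value": 0.0}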
74 | # initialize parent class
75 | super().__init__()
76 |
77 | # Store parameters
78 | self.num_states = num_states
79 | self.num_outputs = num_outputs
80 | self.initial_weight = weight
81 | self.device = device
82 | self.state_normalization = state_normalization
83 | self.reward_normalization = reward_normalization
84 |
85 |         # Normalization of input states
86 | if state_normalization:
87 | self.state_normalizer = EmpiricalNormalization(shape=[self.num_states], until=1.0e8).to(self.device)
88 | else:
89 | self.state_normalizer = torch.nn.Identity()
90 | # Normalization of intrinsic reward
91 | if reward_normalization:
92 | self.reward_normalizer = EmpiricalDiscountedVariationNormalization(shape=[], until=1.0e8).to(self.device)
93 | else:
94 | self.reward_normalizer = torch.nn.Identity()
95 |
96 | # counter for the number of updates
97 | self.update_counter = 0
98 |
99 | # resolve weight schedule
100 | if weight_schedule is not None:
101 | self.weight_scheduler_params = weight_schedule
102 | self.weight_scheduler = getattr(self, f"_{weight_schedule['mode']}_weight_schedule")
103 | else:
104 | self.weight_scheduler = None
105 | # Create network architecture
106 | self.predictor = self._build_mlp(num_states, predictor_hidden_dims, num_outputs, activation).to(self.device)
107 | self.target = self._build_mlp(num_states, target_hidden_dims, num_outputs, activation).to(self.device)
108 |
109 | # make target network not trainable
110 | self.target.eval()
111 |
112 | def get_intrinsic_reward(self, rnd_state) -> tuple[torch.Tensor, torch.Tensor]:
113 |         # note: the counter increases by the number of env steps per learning iteration
114 | self.update_counter += 1
115 | # Normalize rnd state
116 | rnd_state = self.state_normalizer(rnd_state)
117 | # Obtain the embedding of the rnd state from the target and predictor networks
118 | target_embedding = self.target(rnd_state).detach()
119 | predictor_embedding = self.predictor(rnd_state).detach()
120 | # Compute the intrinsic reward as the distance between the embeddings
121 | intrinsic_reward = torch.linalg.norm(target_embedding - predictor_embedding, dim=1)
122 | # Normalize intrinsic reward
123 | intrinsic_reward = self.reward_normalizer(intrinsic_reward)
124 |
125 | # Check the weight schedule
126 | if self.weight_scheduler is not None:
127 | self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params)
128 | else:
129 | self.weight = self.initial_weight
130 | # Scale intrinsic reward
131 | intrinsic_reward *= self.weight
132 |
133 | return intrinsic_reward, rnd_state
134 |
135 | def forward(self, *args, **kwargs):
136 | raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")
137 |
138 | def train(self, mode: bool = True):
139 | # sets module into training mode
140 | self.predictor.train(mode)
141 | if self.state_normalization:
142 | self.state_normalizer.train(mode)
143 | if self.reward_normalization:
144 | self.reward_normalizer.train(mode)
145 | return self
146 |
147 | def eval(self):
148 | return self.train(False)
149 |
150 | """
151 | Private Methods
152 | """
153 |
154 | @staticmethod
155 | def _build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
156 | """Builds target and predictor networks"""
157 |
158 | network_layers = []
159 | # resolve hidden dimensions
160 |         # if a dim is -1, then use the number of inputs (i.e. num_states)
161 | hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims]
162 | # resolve activation function
163 | activation = resolve_nn_activation(activation_name)
164 | # first layer
165 | network_layers.append(nn.Linear(input_dims, hidden_dims[0]))
166 | network_layers.append(activation)
167 | # subsequent layers
168 | for layer_index in range(len(hidden_dims)):
169 | if layer_index == len(hidden_dims) - 1:
170 | # last layer
171 | network_layers.append(nn.Linear(hidden_dims[layer_index], output_dims))
172 | else:
173 | # hidden layers
174 | network_layers.append(nn.Linear(hidden_dims[layer_index], hidden_dims[layer_index + 1]))
175 | network_layers.append(activation)
176 | return nn.Sequential(*network_layers)
177 |
178 | """
179 | Different weight schedules.
180 | """
181 |
182 | def _constant_weight_schedule(self, step: int, **kwargs):
183 | return self.initial_weight
184 |
185 | def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs):
186 | return self.initial_weight if step < final_step else final_value
187 |
188 | def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs):
189 | if step < initial_step:
190 | return self.initial_weight
191 | elif step > final_step:
192 | return final_value
193 | else:
194 | return self.initial_weight + (final_value - self.initial_weight) * (step - initial_step) / (
195 | final_step - initial_step
196 | )
197 |
--------------------------------------------------------------------------------
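A minimal usage sketch for the RandomNetworkDistillation module above; the state/embedding sizes and schedule values below are illustrative assumptions, not defaults shipped with this repository:

import torch

from rsl_rl.modules.rnd import RandomNetworkDistillation

# hypothetical dimensions: 48-dimensional RND state, 16-dimensional embeddings
rnd = RandomNetworkDistillation(
    num_states=48,
    num_outputs=16,
    predictor_hidden_dims=[-1, 128],  # -1 resolves to num_states (see _build_mlp)
    target_hidden_dims=[-1, 128],
    weight=0.5,
    state_normalization=True,
    weight_schedule={"mode": "linear", "initial_step": 0, "final_step": 1000, "final_value": 0.0},
)

rnd_state = torch.randn(4096, 48)  # (num_envs, num_states)
intrinsic_reward, normalized_state = rnd.get_intrinsic_reward(rnd_state)
print(intrinsic_reward.shape)  # torch.Size([4096])
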
/rsl_rl/modules/student_teacher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.distributions import Normal
11 |
12 | from rsl_rl.utils import resolve_nn_activation
13 |
14 |
15 | class StudentTeacher(nn.Module):
16 | is_recurrent = False
17 |
18 | def __init__(
19 | self,
20 | num_student_obs,
21 | num_teacher_obs,
22 | num_actions,
23 | student_hidden_dims=[256, 256, 256],
24 | teacher_hidden_dims=[256, 256, 256],
25 | activation="elu",
26 | init_noise_std=0.1,
27 | **kwargs,
28 | ):
29 | if kwargs:
30 | print(
31 | "StudentTeacher.__init__ got unexpected arguments, which will be ignored: "
32 | + str([key for key in kwargs.keys()])
33 | )
34 | super().__init__()
35 | activation = resolve_nn_activation(activation)
36 | self.loaded_teacher = False # indicates if teacher has been loaded
37 |
38 | mlp_input_dim_s = num_student_obs
39 | mlp_input_dim_t = num_teacher_obs
40 |
41 | # student
42 | student_layers = []
43 | student_layers.append(nn.Linear(mlp_input_dim_s, student_hidden_dims[0]))
44 | student_layers.append(activation)
45 | for layer_index in range(len(student_hidden_dims)):
46 | if layer_index == len(student_hidden_dims) - 1:
47 | student_layers.append(nn.Linear(student_hidden_dims[layer_index], num_actions))
48 | else:
49 | student_layers.append(nn.Linear(student_hidden_dims[layer_index], student_hidden_dims[layer_index + 1]))
50 | student_layers.append(activation)
51 | self.student = nn.Sequential(*student_layers)
52 |
53 | # teacher
54 | teacher_layers = []
55 | teacher_layers.append(nn.Linear(mlp_input_dim_t, teacher_hidden_dims[0]))
56 | teacher_layers.append(activation)
57 | for layer_index in range(len(teacher_hidden_dims)):
58 | if layer_index == len(teacher_hidden_dims) - 1:
59 | teacher_layers.append(nn.Linear(teacher_hidden_dims[layer_index], num_actions))
60 | else:
61 | teacher_layers.append(nn.Linear(teacher_hidden_dims[layer_index], teacher_hidden_dims[layer_index + 1]))
62 | teacher_layers.append(activation)
63 | self.teacher = nn.Sequential(*teacher_layers)
64 | self.teacher.eval()
65 |
66 | print(f"Student MLP: {self.student}")
67 | print(f"Teacher MLP: {self.teacher}")
68 |
69 | # action noise
70 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
71 | self.distribution = None
72 | # disable args validation for speedup
73 | Normal.set_default_validate_args = False
74 |
75 | def reset(self, dones=None, hidden_states=None):
76 | pass
77 |
78 | def forward(self):
79 | raise NotImplementedError
80 |
81 | @property
82 | def action_mean(self):
83 | return self.distribution.mean
84 |
85 | @property
86 | def action_std(self):
87 | return self.distribution.stddev
88 |
89 | @property
90 | def entropy(self):
91 | return self.distribution.entropy().sum(dim=-1)
92 |
93 | def update_distribution(self, observations):
94 | mean = self.student(observations)
95 | std = self.std.expand_as(mean)
96 | self.distribution = Normal(mean, std)
97 |
98 | def act(self, observations):
99 | self.update_distribution(observations)
100 | return self.distribution.sample()
101 |
102 | def act_inference(self, observations):
103 | actions_mean = self.student(observations)
104 | return actions_mean
105 |
106 | def evaluate(self, teacher_observations):
107 | with torch.no_grad():
108 | actions = self.teacher(teacher_observations)
109 | return actions
110 |
111 | def load_state_dict(self, state_dict, strict=True):
112 | """Load the parameters of the student and teacher networks.
113 |
114 | Args:
115 | state_dict (dict): State dictionary of the model.
116 | strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
117 | module's state_dict() function.
118 |
119 | Returns:
120 | bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
121 | `OnPolicyRunner` to determine how to load further parameters.
122 | """
123 |
124 | # check if state_dict contains teacher and student or just teacher parameters
125 | if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training
126 | # rename keys to match teacher and remove critic parameters
127 | teacher_state_dict = {}
128 | for key, value in state_dict.items():
129 | if "actor." in key:
130 | teacher_state_dict[key.replace("actor.", "")] = value
131 | self.teacher.load_state_dict(teacher_state_dict, strict=strict)
132 | # also load recurrent memory if teacher is recurrent
133 | if self.is_recurrent and self.teacher_recurrent:
134 | raise NotImplementedError("Loading recurrent memory for the teacher is not implemented yet") # TODO
135 | # set flag for successfully loading the parameters
136 | self.loaded_teacher = True
137 | self.teacher.eval()
138 | return False
139 | elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training
140 | super().load_state_dict(state_dict, strict=strict)
141 | # set flag for successfully loading the parameters
142 | self.loaded_teacher = True
143 | self.teacher.eval()
144 | return True
145 | else:
146 | raise ValueError("state_dict does not contain student or teacher parameters")
147 |
148 | def get_hidden_states(self):
149 | return None
150 |
151 | def detach_hidden_states(self, dones=None):
152 | pass
153 |
--------------------------------------------------------------------------------
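A minimal interface sketch for the StudentTeacher module above; the observation and action sizes are illustrative, and in practice the teacher weights would come from load_state_dict() with an RL checkpoint rather than random initialization:

import torch

from rsl_rl.modules import StudentTeacher

policy = StudentTeacher(num_student_obs=48, num_teacher_obs=96, num_actions=12)

student_obs = torch.randn(4096, 48)
teacher_obs = torch.randn(4096, 96)

actions = policy.act(student_obs)                 # stochastic student actions (Normal sample)
mean_actions = policy.act_inference(student_obs)  # deterministic student actions
teacher_actions = policy.evaluate(teacher_obs)    # teacher actions, computed without gradients
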
/rsl_rl/modules/student_teacher_recurrent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import warnings
9 |
10 | from rsl_rl.modules import StudentTeacher
11 | from rsl_rl.networks import Memory
12 | from rsl_rl.utils import resolve_nn_activation
13 |
14 |
15 | class StudentTeacherRecurrent(StudentTeacher):
16 | is_recurrent = True
17 |
18 | def __init__(
19 | self,
20 | num_student_obs,
21 | num_teacher_obs,
22 | num_actions,
23 | student_hidden_dims=[256, 256, 256],
24 | teacher_hidden_dims=[256, 256, 256],
25 | activation="elu",
26 | rnn_type="lstm",
27 | rnn_hidden_dim=256,
28 | rnn_num_layers=1,
29 | init_noise_std=0.1,
30 | teacher_recurrent=False,
31 | **kwargs,
32 | ):
33 | if "rnn_hidden_size" in kwargs:
34 | warnings.warn(
35 | "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
36 | "Please use `rnn_hidden_dim` instead.",
37 | DeprecationWarning,
38 | )
39 | if rnn_hidden_dim == 256: # Only override if the new argument is at its default
40 | rnn_hidden_dim = kwargs.pop("rnn_hidden_size")
41 | if kwargs:
42 | print(
43 | "StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: "
44 |                 + str(list(kwargs.keys())),
45 | )
46 |
47 | self.teacher_recurrent = teacher_recurrent
48 |
49 | super().__init__(
50 | num_student_obs=rnn_hidden_dim,
51 | num_teacher_obs=rnn_hidden_dim if teacher_recurrent else num_teacher_obs,
52 | num_actions=num_actions,
53 | student_hidden_dims=student_hidden_dims,
54 | teacher_hidden_dims=teacher_hidden_dims,
55 | activation=activation,
56 | init_noise_std=init_noise_std,
57 | )
58 |
59 | activation = resolve_nn_activation(activation)
60 |
61 | self.memory_s = Memory(num_student_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
62 | if self.teacher_recurrent:
63 | self.memory_t = Memory(
64 | num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim
65 | )
66 |
67 | print(f"Student RNN: {self.memory_s}")
68 | if self.teacher_recurrent:
69 | print(f"Teacher RNN: {self.memory_t}")
70 |
71 | def reset(self, dones=None, hidden_states=None):
72 | if hidden_states is None:
73 | hidden_states = (None, None)
74 | self.memory_s.reset(dones, hidden_states[0])
75 | if self.teacher_recurrent:
76 | self.memory_t.reset(dones, hidden_states[1])
77 |
78 | def act(self, observations):
79 | input_s = self.memory_s(observations)
80 | return super().act(input_s.squeeze(0))
81 |
82 | def act_inference(self, observations):
83 | input_s = self.memory_s(observations)
84 | return super().act_inference(input_s.squeeze(0))
85 |
86 | def evaluate(self, teacher_observations):
87 | if self.teacher_recurrent:
88 | teacher_observations = self.memory_t(teacher_observations)
89 | return super().evaluate(teacher_observations.squeeze(0))
90 |
91 | def get_hidden_states(self):
92 | if self.teacher_recurrent:
93 | return self.memory_s.hidden_states, self.memory_t.hidden_states
94 | else:
95 | return self.memory_s.hidden_states, None
96 |
97 | def detach_hidden_states(self, dones=None):
98 | self.memory_s.detach_hidden_states(dones)
99 | if self.teacher_recurrent:
100 | self.memory_t.detach_hidden_states(dones)
101 |
--------------------------------------------------------------------------------
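A minimal rollout-loop sketch for StudentTeacherRecurrent above; sizes are illustrative. The per-step hidden state is kept inside the Memory module and is cleared for finished environments via reset():

import torch

from rsl_rl.modules import StudentTeacherRecurrent

policy = StudentTeacherRecurrent(
    num_student_obs=48, num_teacher_obs=96, num_actions=12, rnn_type="gru", rnn_hidden_dim=128
)

with torch.inference_mode():  # rollouts run without gradients, as in OnPolicyRunner
    for _ in range(3):  # a few environment steps
        student_obs = torch.randn(4096, 48)
        actions = policy.act(student_obs)  # hidden state is carried over between calls
        dones = torch.zeros(4096)
        dones[:10] = 1.0                   # pretend a few environments terminated
        policy.reset(dones)                # zero the hidden states of the done environments
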
/rsl_rl/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Definitions for neural networks."""
7 |
8 | from .memory import Memory
9 |
10 | __all__ = ["Memory"]
11 |
--------------------------------------------------------------------------------
/rsl_rl/networks/memory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from rsl_rl.utils import unpad_trajectories
12 |
13 |
14 | class Memory(torch.nn.Module):
15 | def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
16 | super().__init__()
17 | # RNN
18 | rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
19 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
20 | self.hidden_states = None
21 |
22 | def forward(self, input, masks=None, hidden_states=None):
23 | batch_mode = masks is not None
24 | if batch_mode:
25 | # batch mode: needs saved hidden states
26 | if hidden_states is None:
27 | raise ValueError("Hidden states not passed to memory module during policy update")
28 | out, _ = self.rnn(input, hidden_states)
29 | out = unpad_trajectories(out, masks)
30 | else:
31 | # inference/distillation mode: uses hidden states of last step
32 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
33 | return out
34 |
35 | def reset(self, dones=None, hidden_states=None):
36 | if dones is None: # reset all hidden states
37 | if hidden_states is None:
38 | self.hidden_states = None
39 | else:
40 | self.hidden_states = hidden_states
41 | elif self.hidden_states is not None: # reset hidden states of done environments
42 | if hidden_states is None:
43 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
44 | for hidden_state in self.hidden_states:
45 | hidden_state[..., dones == 1, :] = 0.0
46 | else:
47 | self.hidden_states[..., dones == 1, :] = 0.0
48 | else:
49 |                 raise NotImplementedError(
50 | "Resetting hidden states of done environments with custom hidden states is not implemented"
51 | )
52 |
53 | def detach_hidden_states(self, dones=None):
54 | if self.hidden_states is not None:
55 | if dones is None: # detach all hidden states
56 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
57 | self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
58 | else:
59 | self.hidden_states = self.hidden_states.detach()
60 | else: # detach hidden states of done environments
61 | if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
62 | for hidden_state in self.hidden_states:
63 | hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
64 | else:
65 | self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()
66 |
--------------------------------------------------------------------------------
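A minimal sketch of the two Memory.forward() modes above, single-step inference versus padded-trajectory batch mode; all shapes are illustrative:

import torch

from rsl_rl.networks import Memory

memory = Memory(input_size=48, type="lstm", num_layers=1, hidden_size=64)

# rollout/inference mode: one time step, hidden state kept inside the module
obs = torch.randn(4096, 48)                   # (num_envs, input_size)
step_out = memory(obs)                        # (1, num_envs, hidden_size)

# batch mode during an update: padded trajectories plus saved hidden states and masks
padded = torch.randn(24, 32, 48)              # (max_traj_len, num_trajectories, input_size)
masks = torch.ones(24, 32, dtype=torch.bool)  # as produced by split_and_pad_trajectories
h0 = (torch.zeros(1, 32, 64), torch.zeros(1, 32, 64))  # LSTM (h, c) at the trajectory starts
batch_out = memory(padded, masks=masks, hidden_states=h0)
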
/rsl_rl/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Implementation of runners for environment-agent interaction."""
7 |
8 | from .on_policy_runner import OnPolicyRunner
9 |
10 | __all__ = ["OnPolicyRunner"]
11 |
--------------------------------------------------------------------------------
/rsl_rl/runners/on_policy_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import os
9 | import statistics
10 | import time
11 | import torch
12 | from collections import deque
13 |
14 | import rsl_rl
15 | from rsl_rl.algorithms import PPO, Distillation
16 | from rsl_rl.env import VecEnv
17 | from rsl_rl.modules import (
18 | ActorCritic,
19 | ActorCriticRecurrent,
20 | EmpiricalNormalization,
21 | StudentTeacher,
22 | StudentTeacherRecurrent,
23 | )
24 | from rsl_rl.utils import store_code_state
25 |
26 |
27 | class OnPolicyRunner:
28 | """On-policy runner for training and evaluation."""
29 |
30 | def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"):
31 | self.cfg = train_cfg
32 | self.alg_cfg = train_cfg["algorithm"]
33 | self.policy_cfg = train_cfg["policy"]
34 | self.device = device
35 | self.env = env
36 |
37 | # check if multi-gpu is enabled
38 | self._configure_multi_gpu()
39 |
40 | # resolve training type depending on the algorithm
41 | if self.alg_cfg["class_name"] == "PPO":
42 | self.training_type = "rl"
43 | elif self.alg_cfg["class_name"] == "Distillation":
44 | self.training_type = "distillation"
45 | else:
46 | raise ValueError(f"Training type not found for algorithm {self.alg_cfg['class_name']}.")
47 |
48 | # resolve dimensions of observations
49 | obs, extras = self.env.get_observations()
50 | num_obs = obs.shape[1]
51 |
52 | # resolve type of privileged observations
53 | if self.training_type == "rl":
54 | if "critic" in extras["observations"]:
55 |                 self.privileged_obs_type = "critic"  # actor-critic reinforcement learning, e.g., PPO
56 | else:
57 | self.privileged_obs_type = None
58 | if self.training_type == "distillation":
59 | if "teacher" in extras["observations"]:
60 | self.privileged_obs_type = "teacher" # policy distillation
61 | else:
62 | self.privileged_obs_type = None
63 |
64 | # resolve dimensions of privileged observations
65 | if self.privileged_obs_type is not None:
66 | num_privileged_obs = extras["observations"][self.privileged_obs_type].shape[1]
67 | else:
68 | num_privileged_obs = num_obs
69 |
70 | # evaluate the policy class
71 | policy_class = eval(self.policy_cfg.pop("class_name"))
72 | policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class(
73 | num_obs, num_privileged_obs, self.env.num_actions, **self.policy_cfg
74 | ).to(self.device)
75 |
76 | # resolve dimension of rnd gated state
77 | if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None:
78 | # check if rnd gated state is present
79 | rnd_state = extras["observations"].get("rnd_state")
80 | if rnd_state is None:
81 | raise ValueError("Observations for the key 'rnd_state' not found in infos['observations'].")
82 | # get dimension of rnd gated state
83 | num_rnd_state = rnd_state.shape[1]
84 | # add rnd gated state to config
85 | self.alg_cfg["rnd_cfg"]["num_states"] = num_rnd_state
86 | # scale down the rnd weight with timestep (similar to how rewards are scaled down in legged_gym envs)
87 | self.alg_cfg["rnd_cfg"]["weight"] *= env.unwrapped.step_dt
88 |
89 | # if using symmetry then pass the environment config object
90 | if "symmetry_cfg" in self.alg_cfg and self.alg_cfg["symmetry_cfg"] is not None:
91 | # this is used by the symmetry function for handling different observation terms
92 | self.alg_cfg["symmetry_cfg"]["_env"] = env
93 |
94 | # initialize algorithm
95 | alg_class = eval(self.alg_cfg.pop("class_name"))
96 | self.alg: PPO | Distillation = alg_class(
97 | policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
98 | )
99 |
100 | # store training configuration
101 | self.num_steps_per_env = self.cfg["num_steps_per_env"]
102 | self.save_interval = self.cfg["save_interval"]
103 | self.empirical_normalization = self.cfg["empirical_normalization"]
104 | if self.empirical_normalization:
105 | self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
106 | self.privileged_obs_normalizer = EmpiricalNormalization(shape=[num_privileged_obs], until=1.0e8).to(
107 | self.device
108 | )
109 | else:
110 | self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
111 | self.privileged_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
112 |
113 | # init storage and model
114 | self.alg.init_storage(
115 | self.training_type,
116 | self.env.num_envs,
117 | self.num_steps_per_env,
118 | [num_obs],
119 | [num_privileged_obs],
120 | [self.env.num_actions],
121 | )
122 |
123 | # Decide whether to disable logging
124 | # We only log from the process with rank 0 (main process)
125 | self.disable_logs = self.is_distributed and self.gpu_global_rank != 0
126 | # Logging
127 | self.log_dir = log_dir
128 | self.writer = None
129 | self.tot_timesteps = 0
130 | self.tot_time = 0
131 | self.current_learning_iteration = 0
132 | self.git_status_repos = [rsl_rl.__file__]
133 |
134 | def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901
135 | # initialize writer
136 | if self.log_dir is not None and self.writer is None and not self.disable_logs:
137 | # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
138 | self.logger_type = self.cfg.get("logger", "tensorboard")
139 | self.logger_type = self.logger_type.lower()
140 |
141 | if self.logger_type == "neptune":
142 | from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter
143 |
144 | self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
145 | self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
146 | elif self.logger_type == "wandb":
147 | from rsl_rl.utils.wandb_utils import WandbSummaryWriter
148 |
149 | self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
150 | self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
151 | elif self.logger_type == "tensorboard":
152 | from torch.utils.tensorboard import SummaryWriter
153 |
154 | self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
155 | else:
156 | raise ValueError("Logger type not found. Please choose 'neptune', 'wandb' or 'tensorboard'.")
157 |
158 | # check if teacher is loaded
159 | if self.training_type == "distillation" and not self.alg.policy.loaded_teacher:
160 | raise ValueError("Teacher model parameters not loaded. Please load a teacher model to distill.")
161 |
162 | # randomize initial episode lengths (for exploration)
163 | if init_at_random_ep_len:
164 | self.env.episode_length_buf = torch.randint_like(
165 | self.env.episode_length_buf, high=int(self.env.max_episode_length)
166 | )
167 |
168 | # start learning
169 | obs, extras = self.env.get_observations()
170 | privileged_obs = extras["observations"].get(self.privileged_obs_type, obs)
171 | obs, privileged_obs = obs.to(self.device), privileged_obs.to(self.device)
172 | self.train_mode() # switch to train mode (for dropout for example)
173 |
174 | # Book keeping
175 | ep_infos = []
176 | rewbuffer = deque(maxlen=100)
177 | lenbuffer = deque(maxlen=100)
178 | cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
179 | cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
180 |
181 | # create buffers for logging extrinsic and intrinsic rewards
182 | if self.alg.rnd:
183 | erewbuffer = deque(maxlen=100)
184 | irewbuffer = deque(maxlen=100)
185 | cur_ereward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
186 | cur_ireward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
187 |
188 |         # Ensure all parameters are in sync
189 | if self.is_distributed:
190 | print(f"Synchronizing parameters for rank {self.gpu_global_rank}...")
191 | self.alg.broadcast_parameters()
192 | # TODO: Do we need to synchronize empirical normalizers?
193 | # Right now: No, because they all should converge to the same values "asymptotically".
194 |
195 | # Start training
196 | start_iter = self.current_learning_iteration
197 | tot_iter = start_iter + num_learning_iterations
198 | for it in range(start_iter, tot_iter):
199 | start = time.time()
200 | # Rollout
201 | with torch.inference_mode():
202 | for _ in range(self.num_steps_per_env):
203 | # Sample actions
204 | actions = self.alg.act(obs, privileged_obs)
205 | # Step the environment
206 | obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
207 | # Move to device
208 | obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
209 | # perform normalization
210 | obs = self.obs_normalizer(obs)
211 | if self.privileged_obs_type is not None:
212 | privileged_obs = self.privileged_obs_normalizer(
213 | infos["observations"][self.privileged_obs_type].to(self.device)
214 | )
215 | else:
216 | privileged_obs = obs
217 |
218 | # process the step
219 | self.alg.process_env_step(rewards, dones, infos)
220 |
221 | # Extract intrinsic rewards (only for logging)
222 | intrinsic_rewards = self.alg.intrinsic_rewards if self.alg.rnd else None
223 |
224 | # book keeping
225 | if self.log_dir is not None:
226 | if "episode" in infos:
227 | ep_infos.append(infos["episode"])
228 | elif "log" in infos:
229 | ep_infos.append(infos["log"])
230 | # Update rewards
231 | if self.alg.rnd:
232 | cur_ereward_sum += rewards
233 | cur_ireward_sum += intrinsic_rewards # type: ignore
234 | cur_reward_sum += rewards + intrinsic_rewards
235 | else:
236 | cur_reward_sum += rewards
237 | # Update episode length
238 | cur_episode_length += 1
239 | # Clear data for completed episodes
240 | # -- common
241 | new_ids = (dones > 0).nonzero(as_tuple=False)
242 | rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
243 | lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
244 | cur_reward_sum[new_ids] = 0
245 | cur_episode_length[new_ids] = 0
246 | # -- intrinsic and extrinsic rewards
247 | if self.alg.rnd:
248 | erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist())
249 | irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist())
250 | cur_ereward_sum[new_ids] = 0
251 | cur_ireward_sum[new_ids] = 0
252 |
253 | stop = time.time()
254 | collection_time = stop - start
255 | start = stop
256 |
257 | # compute returns
258 | if self.training_type == "rl":
259 | self.alg.compute_returns(privileged_obs)
260 |
261 | # update policy
262 | loss_dict = self.alg.update()
263 |
264 | stop = time.time()
265 | learn_time = stop - start
266 | self.current_learning_iteration = it
267 | # log info
268 | if self.log_dir is not None and not self.disable_logs:
269 | # Log information
270 | self.log(locals())
271 | # Save model
272 | if it % self.save_interval == 0:
273 | self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
274 |
275 | # Clear episode infos
276 | ep_infos.clear()
277 | # Save code state
278 | if it == start_iter and not self.disable_logs:
279 | # obtain all the diff files
280 | git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
281 | # if possible store them to wandb
282 | if self.logger_type in ["wandb", "neptune"] and git_file_paths:
283 | for path in git_file_paths:
284 | self.writer.save_file(path)
285 |
286 | # Save the final model after training
287 | if self.log_dir is not None and not self.disable_logs:
288 | self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
289 |
290 | def log(self, locs: dict, width: int = 80, pad: int = 35):
291 | # Compute the collection size
292 | collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size
293 | # Update total time-steps and time
294 | self.tot_timesteps += collection_size
295 | self.tot_time += locs["collection_time"] + locs["learn_time"]
296 | iteration_time = locs["collection_time"] + locs["learn_time"]
297 |
298 | # -- Episode info
299 | ep_string = ""
300 | if locs["ep_infos"]:
301 | for key in locs["ep_infos"][0]:
302 | infotensor = torch.tensor([], device=self.device)
303 | for ep_info in locs["ep_infos"]:
304 | # handle scalar and zero dimensional tensor infos
305 | if key not in ep_info:
306 | continue
307 | if not isinstance(ep_info[key], torch.Tensor):
308 | ep_info[key] = torch.Tensor([ep_info[key]])
309 | if len(ep_info[key].shape) == 0:
310 | ep_info[key] = ep_info[key].unsqueeze(0)
311 | infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
312 | value = torch.mean(infotensor)
313 | # log to logger and terminal
314 | if "/" in key:
315 | self.writer.add_scalar(key, value, locs["it"])
316 | ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
317 | else:
318 | self.writer.add_scalar("Episode/" + key, value, locs["it"])
319 | ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
320 |
321 | mean_std = self.alg.policy.action_std.mean()
322 | fps = int(collection_size / (locs["collection_time"] + locs["learn_time"]))
323 |
324 | # -- Losses
325 | for key, value in locs["loss_dict"].items():
326 | self.writer.add_scalar(f"Loss/{key}", value, locs["it"])
327 | self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
328 |
329 | # -- Policy
330 | self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
331 |
332 | # -- Performance
333 | self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
334 | self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
335 | self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
336 |
337 | # -- Training
338 | if len(locs["rewbuffer"]) > 0:
339 | # separate logging for intrinsic and extrinsic rewards
340 | if self.alg.rnd:
341 | self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"])
342 | self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"])
343 | self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"])
344 | # everything else
345 | self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
346 | self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
347 | if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging
348 | self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
349 | self.writer.add_scalar(
350 | "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
351 | )
352 |
353 | str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
354 |
355 | if len(locs["rewbuffer"]) > 0:
356 | log_string = (
357 | f"""{'#' * width}\n"""
358 | f"""{str.center(width, ' ')}\n\n"""
359 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
360 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
361 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
362 | )
363 | # -- Losses
364 | for key, value in locs["loss_dict"].items():
365 | log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n"""
366 | # -- Rewards
367 | if self.alg.rnd:
368 | log_string += (
369 | f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n"""
370 | f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n"""
371 | )
372 | log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
373 | # -- episode info
374 | log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
375 | else:
376 | log_string = (
377 | f"""{'#' * width}\n"""
378 | f"""{str.center(width, ' ')}\n\n"""
379 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
380 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
381 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
382 | )
383 | for key, value in locs["loss_dict"].items():
384 | log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
385 |
386 | log_string += ep_string
387 | log_string += (
388 | f"""{'-' * width}\n"""
389 | f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
390 | f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
391 | f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n"""
392 | f"""{'ETA:':>{pad}} {time.strftime(
393 | "%H:%M:%S",
394 | time.gmtime(
395 | self.tot_time / (locs['it'] - locs['start_iter'] + 1)
396 | * (locs['start_iter'] + locs['num_learning_iterations'] - locs['it'])
397 | )
398 | )}\n"""
399 | )
400 | print(log_string)
401 |
402 | def save(self, path: str, infos=None):
403 | # -- Save model
404 | saved_dict = {
405 | "model_state_dict": self.alg.policy.state_dict(),
406 | "optimizer_state_dict": self.alg.optimizer.state_dict(),
407 | "iter": self.current_learning_iteration,
408 | "infos": infos,
409 | }
410 | # -- Save RND model if used
411 | if self.alg.rnd:
412 | saved_dict["rnd_state_dict"] = self.alg.rnd.state_dict()
413 | saved_dict["rnd_optimizer_state_dict"] = self.alg.rnd_optimizer.state_dict()
414 | # -- Save observation normalizer if used
415 | if self.empirical_normalization:
416 | saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict()
417 | saved_dict["privileged_obs_norm_state_dict"] = self.privileged_obs_normalizer.state_dict()
418 |
419 | # save model
420 | torch.save(saved_dict, path)
421 |
422 | # upload model to external logging service
423 | if self.logger_type in ["neptune", "wandb"] and not self.disable_logs:
424 | self.writer.save_model(path, self.current_learning_iteration)
425 |
426 | def load(self, path: str, load_optimizer: bool = True):
427 | loaded_dict = torch.load(path, weights_only=False)
428 | # -- Load model
429 | resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"])
430 | # -- Load RND model if used
431 | if self.alg.rnd:
432 | self.alg.rnd.load_state_dict(loaded_dict["rnd_state_dict"])
433 | # -- Load observation normalizer if used
434 | if self.empirical_normalization:
435 | if resumed_training:
436 | # if a previous training is resumed, the actor/student normalizer is loaded for the actor/student
437 | # and the critic/teacher normalizer is loaded for the critic/teacher
438 | self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"])
439 | self.privileged_obs_normalizer.load_state_dict(loaded_dict["privileged_obs_norm_state_dict"])
440 | else:
441 | # if the training is not resumed but a model is loaded, this run must be distillation training following
442 | # an rl training. Thus the actor normalizer is loaded for the teacher model. The student's normalizer
443 | # is not loaded, as the observation space could differ from the previous rl training.
444 | self.privileged_obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"])
445 | # -- load optimizer if used
446 | if load_optimizer and resumed_training:
447 | # -- algorithm optimizer
448 | self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
449 | # -- RND optimizer if used
450 | if self.alg.rnd:
451 | self.alg.rnd_optimizer.load_state_dict(loaded_dict["rnd_optimizer_state_dict"])
452 | # -- load current learning iteration
453 | if resumed_training:
454 | self.current_learning_iteration = loaded_dict["iter"]
455 | return loaded_dict["infos"]
456 |
457 | def get_inference_policy(self, device=None):
458 | self.eval_mode() # switch to evaluation mode (dropout for example)
459 | if device is not None:
460 | self.alg.policy.to(device)
461 | policy = self.alg.policy.act_inference
462 | if self.cfg["empirical_normalization"]:
463 | if device is not None:
464 | self.obs_normalizer.to(device)
465 | policy = lambda x: self.alg.policy.act_inference(self.obs_normalizer(x)) # noqa: E731
466 | return policy
467 |
468 | def train_mode(self):
469 | # -- PPO
470 | self.alg.policy.train()
471 | # -- RND
472 | if self.alg.rnd:
473 | self.alg.rnd.train()
474 | # -- Normalization
475 | if self.empirical_normalization:
476 | self.obs_normalizer.train()
477 | self.privileged_obs_normalizer.train()
478 |
479 | def eval_mode(self):
480 | # -- PPO
481 | self.alg.policy.eval()
482 | # -- RND
483 | if self.alg.rnd:
484 | self.alg.rnd.eval()
485 | # -- Normalization
486 | if self.empirical_normalization:
487 | self.obs_normalizer.eval()
488 | self.privileged_obs_normalizer.eval()
489 |
490 | def add_git_repo_to_log(self, repo_file_path):
491 | self.git_status_repos.append(repo_file_path)
492 |
493 | """
494 | Helper functions.
495 | """
496 |
497 | def _configure_multi_gpu(self):
498 | """Configure multi-gpu training."""
499 | # check if distributed training is enabled
500 | self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1"))
501 | self.is_distributed = self.gpu_world_size > 1
502 |
503 | # if not distributed training, set local and global rank to 0 and return
504 | if not self.is_distributed:
505 | self.gpu_local_rank = 0
506 | self.gpu_global_rank = 0
507 | self.multi_gpu_cfg = None
508 | return
509 |
510 | # get rank and world size
511 | self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0"))
512 | self.gpu_global_rank = int(os.getenv("RANK", "0"))
513 |
514 | # make a configuration dictionary
515 | self.multi_gpu_cfg = {
516 | "global_rank": self.gpu_global_rank, # rank of the main process
517 | "local_rank": self.gpu_local_rank, # rank of the current process
518 | "world_size": self.gpu_world_size, # total number of processes
519 | }
520 |
521 | # check if user has device specified for local rank
522 | if self.device != f"cuda:{self.gpu_local_rank}":
523 | raise ValueError(
524 | f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'."
525 | )
526 | # validate multi-gpu configuration
527 | if self.gpu_local_rank >= self.gpu_world_size:
528 | raise ValueError(
529 | f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
530 | )
531 | if self.gpu_global_rank >= self.gpu_world_size:
532 | raise ValueError(
533 | f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
534 | )
535 |
536 | # initialize torch distributed
537 | torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size)
538 | # set device to the local rank
539 | torch.cuda.set_device(self.gpu_local_rank)
540 |
--------------------------------------------------------------------------------
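A hedged sketch of the train_cfg layout that OnPolicyRunner reads in its constructor; the keys follow the accesses above, while the concrete values and the VecEnv instance are illustrative assumptions:

train_cfg = {
    "num_steps_per_env": 24,
    "save_interval": 50,
    "empirical_normalization": True,
    "logger": "tensorboard",          # or "wandb" / "neptune"
    "policy": {
        "class_name": "ActorCritic",  # resolved via eval(); remaining keys go to the policy class
    },
    "algorithm": {
        "class_name": "PPO",          # "PPO" -> rl training, "Distillation" -> distillation
        # remaining keys (e.g. optional "rnd_cfg", "symmetry_cfg") are forwarded to the algorithm
    },
}

# env = ...  # a VecEnv instance providing get_observations(), step(), num_envs, num_actions
# runner = OnPolicyRunner(env, train_cfg, log_dir="logs/run_0", device="cpu")
# runner.learn(num_learning_iterations=1500)

# Multi-GPU note: _configure_multi_gpu() reads WORLD_SIZE, RANK and LOCAL_RANK (e.g. set by
# torchrun) and requires device == f"cuda:{LOCAL_RANK}" for each process.
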
/rsl_rl/storage/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Implementation of transitions storage for RL-agent."""
7 |
8 | from .rollout_storage import RolloutStorage
9 |
10 | __all__ = ["RolloutStorage"]
11 |
--------------------------------------------------------------------------------
/rsl_rl/storage/rollout_storage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import torch
9 |
10 | from rsl_rl.utils import split_and_pad_trajectories
11 |
12 |
13 | class RolloutStorage:
14 | class Transition:
15 | def __init__(self):
16 | self.observations = None
17 | self.privileged_observations = None
18 | self.actions = None
19 | self.privileged_actions = None
20 | self.rewards = None
21 | self.dones = None
22 | self.values = None
23 | self.actions_log_prob = None
24 | self.action_mean = None
25 | self.action_sigma = None
26 | self.hidden_states = None
27 | self.rnd_state = None
28 |
29 | def clear(self):
30 | self.__init__()
31 |
32 | def __init__(
33 | self,
34 | training_type,
35 | num_envs,
36 | num_transitions_per_env,
37 | obs_shape,
38 | privileged_obs_shape,
39 | actions_shape,
40 | rnd_state_shape=None,
41 | device="cpu",
42 | ):
43 | # store inputs
44 | self.training_type = training_type
45 | self.device = device
46 | self.num_transitions_per_env = num_transitions_per_env
47 | self.num_envs = num_envs
48 | self.obs_shape = obs_shape
49 | self.privileged_obs_shape = privileged_obs_shape
50 | self.rnd_state_shape = rnd_state_shape
51 | self.actions_shape = actions_shape
52 |
53 | # Core
54 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
55 | if privileged_obs_shape is not None:
56 | self.privileged_observations = torch.zeros(
57 | num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device
58 | )
59 | else:
60 | self.privileged_observations = None
61 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
62 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
63 | self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
64 |
65 | # for distillation
66 | if training_type == "distillation":
67 | self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
68 |
69 | # for reinforcement learning
70 | if training_type == "rl":
71 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
72 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
73 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
74 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
75 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
76 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
77 |
78 | # For RND
79 | if rnd_state_shape is not None:
80 | self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=self.device)
81 |
82 | # For RNN networks
83 | self.saved_hidden_states_a = None
84 | self.saved_hidden_states_c = None
85 |
86 | # counter for the number of transitions stored
87 | self.step = 0
88 |
89 | def add_transitions(self, transition: Transition):
90 | # check if the transition is valid
91 | if self.step >= self.num_transitions_per_env:
92 | raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.")
93 |
94 | # Core
95 | self.observations[self.step].copy_(transition.observations)
96 | if self.privileged_observations is not None:
97 | self.privileged_observations[self.step].copy_(transition.privileged_observations)
98 | self.actions[self.step].copy_(transition.actions)
99 | self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
100 | self.dones[self.step].copy_(transition.dones.view(-1, 1))
101 |
102 | # for distillation
103 | if self.training_type == "distillation":
104 | self.privileged_actions[self.step].copy_(transition.privileged_actions)
105 |
106 | # for reinforcement learning
107 | if self.training_type == "rl":
108 | self.values[self.step].copy_(transition.values)
109 | self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
110 | self.mu[self.step].copy_(transition.action_mean)
111 | self.sigma[self.step].copy_(transition.action_sigma)
112 |
113 | # For RND
114 | if self.rnd_state_shape is not None:
115 | self.rnd_state[self.step].copy_(transition.rnd_state)
116 |
117 | # For RNN networks
118 | self._save_hidden_states(transition.hidden_states)
119 |
120 | # increment the counter
121 | self.step += 1
122 |
123 | def _save_hidden_states(self, hidden_states):
124 | if hidden_states is None or hidden_states == (None, None):
125 | return
126 |         # make a tuple out of GRU hidden states to match the LSTM format
127 | hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
128 | hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
129 | # initialize if needed
130 | if self.saved_hidden_states_a is None:
131 | self.saved_hidden_states_a = [
132 | torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
133 | ]
134 | self.saved_hidden_states_c = [
135 | torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
136 | ]
137 | # copy the states
138 | for i in range(len(hid_a)):
139 | self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
140 | self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
141 |
142 | def clear(self):
143 | self.step = 0
144 |
145 | def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True):
146 | advantage = 0
147 | for step in reversed(range(self.num_transitions_per_env)):
148 | # if we are at the last step, bootstrap the return value
149 | if step == self.num_transitions_per_env - 1:
150 | next_values = last_values
151 | else:
152 | next_values = self.values[step + 1]
153 | # 1 if we are not in a terminal state, 0 otherwise
154 | next_is_not_terminal = 1.0 - self.dones[step].float()
155 | # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
156 | delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
157 | # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
158 | advantage = delta + next_is_not_terminal * gamma * lam * advantage
159 | # Return: R_t = A(s_t, a_t) + V(s_t)
160 | self.returns[step] = advantage + self.values[step]
161 |
162 | # Compute the advantages
163 | self.advantages = self.returns - self.values
164 | # Normalize the advantages if flag is set
165 | # This is to prevent double normalization (i.e. if per minibatch normalization is used)
166 | if normalize_advantage:
167 | self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
168 |
169 | # for distillation
170 | def generator(self):
171 | if self.training_type != "distillation":
172 | raise ValueError("This function is only available for distillation training.")
173 |
174 | for i in range(self.num_transitions_per_env):
175 | if self.privileged_observations is not None:
176 | privileged_observations = self.privileged_observations[i]
177 | else:
178 | privileged_observations = self.observations[i]
179 | yield self.observations[i], privileged_observations, self.actions[i], self.privileged_actions[
180 | i
181 | ], self.dones[i]
182 |
183 | # for reinforcement learning with feedforward networks
184 | def mini_batch_generator(self, num_mini_batches, num_epochs=8):
185 | if self.training_type != "rl":
186 | raise ValueError("This function is only available for reinforcement learning training.")
187 | batch_size = self.num_envs * self.num_transitions_per_env
188 | mini_batch_size = batch_size // num_mini_batches
189 | indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)
190 |
191 | # Core
192 | observations = self.observations.flatten(0, 1)
193 | if self.privileged_observations is not None:
194 | privileged_observations = self.privileged_observations.flatten(0, 1)
195 | else:
196 | privileged_observations = observations
197 |
198 | actions = self.actions.flatten(0, 1)
199 | values = self.values.flatten(0, 1)
200 | returns = self.returns.flatten(0, 1)
201 |
202 | # For PPO
203 | old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
204 | advantages = self.advantages.flatten(0, 1)
205 | old_mu = self.mu.flatten(0, 1)
206 | old_sigma = self.sigma.flatten(0, 1)
207 |
208 | # For RND
209 | if self.rnd_state_shape is not None:
210 | rnd_state = self.rnd_state.flatten(0, 1)
211 |
212 | for epoch in range(num_epochs):
213 | for i in range(num_mini_batches):
214 | # Select the indices for the mini-batch
215 | start = i * mini_batch_size
216 | end = (i + 1) * mini_batch_size
217 | batch_idx = indices[start:end]
218 |
219 | # Create the mini-batch
220 | # -- Core
221 | obs_batch = observations[batch_idx]
222 | privileged_observations_batch = privileged_observations[batch_idx]
223 | actions_batch = actions[batch_idx]
224 |
225 | # -- For PPO
226 | target_values_batch = values[batch_idx]
227 | returns_batch = returns[batch_idx]
228 | old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
229 | advantages_batch = advantages[batch_idx]
230 | old_mu_batch = old_mu[batch_idx]
231 | old_sigma_batch = old_sigma[batch_idx]
232 |
233 | # -- For RND
234 | if self.rnd_state_shape is not None:
235 | rnd_state_batch = rnd_state[batch_idx]
236 | else:
237 | rnd_state_batch = None
238 |
239 | # yield the mini-batch
240 | yield obs_batch, privileged_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
241 | None,
242 | None,
243 | ), None, rnd_state_batch
244 |
245 |     # for reinforcement learning with recurrent networks
246 | def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
247 | if self.training_type != "rl":
248 | raise ValueError("This function is only available for reinforcement learning training.")
249 | padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
250 | if self.privileged_observations is not None:
251 | padded_privileged_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
252 | else:
253 | padded_privileged_obs_trajectories = padded_obs_trajectories
254 |
255 | if self.rnd_state_shape is not None:
256 | padded_rnd_state_trajectories, _ = split_and_pad_trajectories(self.rnd_state, self.dones)
257 | else:
258 | padded_rnd_state_trajectories = None
259 |
260 | mini_batch_size = self.num_envs // num_mini_batches
261 | for ep in range(num_epochs):
262 | first_traj = 0
263 | for i in range(num_mini_batches):
264 | start = i * mini_batch_size
265 | stop = (i + 1) * mini_batch_size
266 |
267 | dones = self.dones.squeeze(-1)
268 | last_was_done = torch.zeros_like(dones, dtype=torch.bool)
269 | last_was_done[1:] = dones[:-1]
270 | last_was_done[0] = True
271 | trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
272 | last_traj = first_traj + trajectories_batch_size
273 |
274 | masks_batch = trajectory_masks[:, first_traj:last_traj]
275 | obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
276 | privileged_obs_batch = padded_privileged_obs_trajectories[:, first_traj:last_traj]
277 |
278 | if padded_rnd_state_trajectories is not None:
279 | rnd_state_batch = padded_rnd_state_trajectories[:, first_traj:last_traj]
280 | else:
281 | rnd_state_batch = None
282 |
283 | actions_batch = self.actions[:, start:stop]
284 | old_mu_batch = self.mu[:, start:stop]
285 | old_sigma_batch = self.sigma[:, start:stop]
286 | returns_batch = self.returns[:, start:stop]
287 | advantages_batch = self.advantages[:, start:stop]
288 | values_batch = self.values[:, start:stop]
289 | old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
290 |
291 | # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
292 | # then take only time steps after dones (flattens num envs and time dimensions),
293 | # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
294 | last_was_done = last_was_done.permute(1, 0)
295 | hid_a_batch = [
296 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
297 | .transpose(1, 0)
298 | .contiguous()
299 | for saved_hidden_states in self.saved_hidden_states_a
300 | ]
301 | hid_c_batch = [
302 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
303 | .transpose(1, 0)
304 | .contiguous()
305 | for saved_hidden_states in self.saved_hidden_states_c
306 | ]
307 | # remove the tuple for GRU
308 | hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
309 | hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
310 |
311 | yield obs_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
312 | hid_a_batch,
313 | hid_c_batch,
314 | ), masks_batch, rnd_state_batch
315 |
316 | first_traj = last_traj
317 |
--------------------------------------------------------------------------------
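A minimal RolloutStorage sketch for "rl" training, showing the add/compute_returns/mini-batch cycle driven by the algorithm; all shapes are illustrative:

import torch

from rsl_rl.storage import RolloutStorage

storage = RolloutStorage(
    training_type="rl",
    num_envs=8,
    num_transitions_per_env=16,
    obs_shape=[48],
    privileged_obs_shape=[48],
    actions_shape=[12],
)

transition = RolloutStorage.Transition()
for _ in range(16):
    transition.observations = torch.randn(8, 48)
    transition.privileged_observations = torch.randn(8, 48)
    transition.actions = torch.randn(8, 12)
    transition.rewards = torch.randn(8)
    transition.dones = torch.zeros(8)
    transition.values = torch.randn(8, 1)
    transition.actions_log_prob = torch.randn(8)
    transition.action_mean = torch.randn(8, 12)
    transition.action_sigma = torch.rand(8, 12)
    storage.add_transitions(transition)

storage.compute_returns(last_values=torch.randn(8, 1), gamma=0.99, lam=0.95)
for batch in storage.mini_batch_generator(num_mini_batches=4, num_epochs=1):
    obs_batch, privileged_obs_batch, actions_batch = batch[0], batch[1], batch[2]
storage.clear()
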
/rsl_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | """Helper functions."""
7 |
8 | from .utils import (
9 | resolve_nn_activation,
10 | split_and_pad_trajectories,
11 | store_code_state,
12 | string_to_callable,
13 | unpad_trajectories,
14 | )
15 |
--------------------------------------------------------------------------------
/rsl_rl/utils/neptune_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import os
9 | from dataclasses import asdict
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | try:
13 | import neptune
14 | except ModuleNotFoundError:
15 | raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
16 |
17 |
18 | class NeptuneLogger:
19 | def __init__(self, project, token):
20 | self.run = neptune.init_run(project=project, api_token=token)
21 |
22 | def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
23 | self.run["runner_cfg"] = runner_cfg
24 | self.run["policy_cfg"] = policy_cfg
25 | self.run["alg_cfg"] = alg_cfg
26 | self.run["env_cfg"] = asdict(env_cfg)
27 |
28 |
29 | class NeptuneSummaryWriter(SummaryWriter):
30 | """Summary writer for Neptune."""
31 |
32 | def __init__(self, log_dir: str, flush_secs: int, cfg):
33 | super().__init__(log_dir, flush_secs)
34 |
35 | try:
36 | project = cfg["neptune_project"]
37 | except KeyError:
38 | raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.")
39 |
40 | try:
41 | token = os.environ["NEPTUNE_API_TOKEN"]
42 | except KeyError:
43 | raise KeyError(
44 | "Neptune api token not found. Please run or add to ~/.bashrc: export NEPTUNE_API_TOKEN=YOUR_API_TOKEN"
45 | )
46 |
47 | try:
48 | entity = os.environ["NEPTUNE_USERNAME"]
49 | except KeyError:
50 | raise KeyError(
51 | "Neptune username not found. Please run or add to ~/.bashrc: export NEPTUNE_USERNAME=YOUR_USERNAME"
52 | )
53 |
54 | neptune_project = entity + "/" + project
55 |
56 | self.neptune_logger = NeptuneLogger(neptune_project, token)
57 |
58 | self.name_map = {
59 | "Train/mean_reward/time": "Train/mean_reward_time",
60 | "Train/mean_episode_length/time": "Train/mean_episode_length_time",
61 | }
62 |
63 | run_name = os.path.split(log_dir)[-1]
64 |
65 | self.neptune_logger.run["log_dir"].log(run_name)
66 |
67 | def _map_path(self, path):
68 | if path in self.name_map:
69 | return self.name_map[path]
70 | else:
71 | return path
72 |
73 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False):
74 | super().add_scalar(
75 | tag,
76 | scalar_value,
77 | global_step=global_step,
78 | walltime=walltime,
79 | new_style=new_style,
80 | )
81 | self.neptune_logger.run[self._map_path(tag)].log(scalar_value, step=global_step)
82 |
83 | def stop(self):
84 | self.neptune_logger.run.stop()
85 |
86 | def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
87 | self.neptune_logger.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
88 |
89 | def save_model(self, model_path, iter):
90 | self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
91 |
92 | def save_file(self, path, iter=None):
93 | name = path.rsplit("/", 1)[-1].split(".")[0]
94 | self.neptune_logger.run["git_diff/" + name].upload(path)
95 |
--------------------------------------------------------------------------------
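A hedged usage sketch for NeptuneSummaryWriter: the project name, credentials, and log directory below are placeholders; only the required config key and environment variables mirror the checks in __init__ above.

import os

from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter

# Both variables must be set before the writer is constructed (placeholders shown).
os.environ.setdefault("NEPTUNE_API_TOKEN", "YOUR_API_TOKEN")
os.environ.setdefault("NEPTUNE_USERNAME", "YOUR_USERNAME")

writer = NeptuneSummaryWriter(
    log_dir="logs/rsl_rl/example_run",      # the run name is the last path component
    flush_secs=10,
    cfg={"neptune_project": "legged_gym"},  # 'neptune_project' is required
)
writer.add_scalar("Train/mean_reward/time", 0.0, global_step=0)  # remapped to 'Train/mean_reward_time'
writer.stop()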
/rsl_rl/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import git
9 | import importlib
10 | import os
11 | import pathlib
12 | import torch
13 | from typing import Callable
14 |
15 |
16 | def resolve_nn_activation(act_name: str) -> torch.nn.Module:
17 | if act_name == "elu":
18 | return torch.nn.ELU()
19 | elif act_name == "selu":
20 | return torch.nn.SELU()
21 | elif act_name == "relu":
22 | return torch.nn.ReLU()
23 | elif act_name == "crelu":
24 |         return torch.nn.CELU()  # note: "crelu" resolves to CELU here
25 | elif act_name == "lrelu":
26 | return torch.nn.LeakyReLU()
27 | elif act_name == "tanh":
28 | return torch.nn.Tanh()
29 | elif act_name == "sigmoid":
30 | return torch.nn.Sigmoid()
31 | elif act_name == "identity":
32 | return torch.nn.Identity()
33 | else:
34 | raise ValueError(f"Invalid activation function '{act_name}'.")
35 |
36 |
37 | def split_and_pad_trajectories(tensor, dones):
38 |     """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest trajectory.
39 |     Returns masks corresponding to the valid parts of the trajectories.
40 | Example:
41 | Input: [ [a1, a2, a3, a4 | a5, a6],
42 | [b1, b2 | b3, b4, b5 | b6]
43 | ]
44 |
45 | Output:[ [a1, a2, a3, a4], | [ [True, True, True, True],
46 | [a5, a6, 0, 0], | [True, True, False, False],
47 | [b1, b2, 0, 0], | [True, True, False, False],
48 | [b3, b4, b5, 0], | [True, True, True, False],
49 | [b6, 0, 0, 0] | [True, False, False, False],
50 | ] | ]
51 |
52 |     Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
53 | """
54 | dones = dones.clone()
55 | dones[-1] = 1
56 | # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
57 | flat_dones = dones.transpose(1, 0).reshape(-1, 1)
58 |
59 | # Get length of trajectory by counting the number of successive not done elements
60 | done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
61 | trajectory_lengths = done_indices[1:] - done_indices[:-1]
62 | trajectory_lengths_list = trajectory_lengths.tolist()
63 | # Extract the individual trajectories
64 | trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
65 | # add at least one full length trajectory
66 | trajectories = trajectories + (torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device),)
67 | # pad the trajectories to the length of the longest trajectory
68 | padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
69 | # remove the added tensor
70 | padded_trajectories = padded_trajectories[:, :-1]
71 |
72 | trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
73 | return padded_trajectories, trajectory_masks
74 |
75 |
76 | def unpad_trajectories(trajectories, masks):
77 | """Does the inverse operation of split_and_pad_trajectories()"""
78 | # Need to transpose before and after the masking to have proper reshaping
79 | return (
80 | trajectories.transpose(1, 0)[masks.transpose(1, 0)]
81 | .view(-1, trajectories.shape[0], trajectories.shape[-1])
82 | .transpose(1, 0)
83 | )
84 |
85 |
86 | def store_code_state(logdir, repositories) -> list:
87 | git_log_dir = os.path.join(logdir, "git")
88 | os.makedirs(git_log_dir, exist_ok=True)
89 | file_paths = []
90 | for repository_file_path in repositories:
91 | try:
92 | repo = git.Repo(repository_file_path, search_parent_directories=True)
93 | t = repo.head.commit.tree
94 | except Exception:
95 | print(f"Could not find git repository in {repository_file_path}. Skipping.")
96 | # skip if not a git repository
97 | continue
98 | # get the name of the repository
99 | repo_name = pathlib.Path(repo.working_dir).name
100 | diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
101 | # check if the diff file already exists
102 | if os.path.isfile(diff_file_name):
103 | continue
104 | # write the diff file
105 | print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
106 | with open(diff_file_name, "x", encoding="utf-8") as f:
107 | content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
108 | f.write(content)
109 | # add the file path to the list of files to be uploaded
110 | file_paths.append(diff_file_name)
111 | return file_paths
112 |
113 |
114 | def string_to_callable(name: str) -> Callable:
115 | """Resolves the module and function names to return the function.
116 |
117 | Args:
118 | name (str): The function name. The format should be 'module:attribute_name'.
119 |
120 | Raises:
121 |         ValueError: When the resolved attribute is not callable.
122 | ValueError: When unable to resolve the attribute.
123 |
124 | Returns:
125 | Callable: The function loaded from the module.
126 | """
127 | try:
128 | mod_name, attr_name = name.split(":")
129 | mod = importlib.import_module(mod_name)
130 | callable_object = getattr(mod, attr_name)
131 | # check if attribute is callable
132 | if callable(callable_object):
133 | return callable_object
134 | else:
135 | raise ValueError(f"The imported object is not callable: '{name}'")
136 | except AttributeError as e:
137 | msg = (
138 | "We could not interpret the entry as a callable object. The format of input should be"
139 | f" 'module:attribute_name'\nWhile processing input '{name}', received the error:\n {e}."
140 | )
141 |         raise ValueError(msg) from e
142 |
--------------------------------------------------------------------------------
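A small self-contained sketch of the trajectory helpers above: split_and_pad_trajectories reorders [time, num_envs, ...] data into per-trajectory columns padded to the longest trajectory, and unpad_trajectories inverts the operation. The shapes and done pattern below are illustrative.

import torch

from rsl_rl.utils import split_and_pad_trajectories, unpad_trajectories

# Six time steps, two environments, one observation dimension.
obs = torch.arange(12.0).reshape(6, 2, 1)
dones = torch.zeros(6, 2, dtype=torch.bool)
dones[3, 0] = True  # env 0 finishes a trajectory after its fourth step
dones[1, 1] = True  # env 1 finishes a trajectory after its second step

padded, masks = split_and_pad_trajectories(obs, dones)
# padded: [max_traj_len, num_trajectories, 1]; masks flags the valid (unpadded) entries.
recovered = unpad_trajectories(padded, masks)
assert torch.equal(recovered, obs)  # the round trip restores the original layout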
/rsl_rl/utils/wandb_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from __future__ import annotations
7 |
8 | import os
9 | from dataclasses import asdict
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | try:
13 | import wandb
14 | except ModuleNotFoundError:
15 | raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.")
16 |
17 |
18 | class WandbSummaryWriter(SummaryWriter):
19 | """Summary writer for Weights and Biases."""
20 |
21 | def __init__(self, log_dir: str, flush_secs: int, cfg):
22 |         super().__init__(log_dir, flush_secs=flush_secs)
23 |
24 | # Get the run name
25 | run_name = os.path.split(log_dir)[-1]
26 |
27 | try:
28 | project = cfg["wandb_project"]
29 | except KeyError:
30 | raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.")
31 |
32 | try:
33 | entity = os.environ["WANDB_USERNAME"]
34 | except KeyError:
35 | entity = None
36 |
37 | # Initialize wandb
38 | wandb.init(project=project, entity=entity, name=run_name)
39 |
40 | # Add log directory to wandb
41 | wandb.config.update({"log_dir": log_dir})
42 |
43 | self.name_map = {
44 | "Train/mean_reward/time": "Train/mean_reward_time",
45 | "Train/mean_episode_length/time": "Train/mean_episode_length_time",
46 | }
47 |
48 | def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
49 | wandb.config.update({"runner_cfg": runner_cfg})
50 | wandb.config.update({"policy_cfg": policy_cfg})
51 | wandb.config.update({"alg_cfg": alg_cfg})
52 | try:
53 | wandb.config.update({"env_cfg": env_cfg.to_dict()})
54 | except Exception:
55 | wandb.config.update({"env_cfg": asdict(env_cfg)})
56 |
57 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False):
58 | super().add_scalar(
59 | tag,
60 | scalar_value,
61 | global_step=global_step,
62 | walltime=walltime,
63 | new_style=new_style,
64 | )
65 | wandb.log({self._map_path(tag): scalar_value}, step=global_step)
66 |
67 | def stop(self):
68 | wandb.finish()
69 |
70 | def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
71 | self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
72 |
73 | def save_model(self, model_path, iter):
74 | wandb.save(model_path, base_path=os.path.dirname(model_path))
75 |
76 | def save_file(self, path, iter=None):
77 | wandb.save(path, base_path=os.path.dirname(path))
78 |
79 | """
80 | Private methods.
81 | """
82 |
83 | def _map_path(self, path):
84 | if path in self.name_map:
85 | return self.name_map[path]
86 | else:
87 | return path
88 |
--------------------------------------------------------------------------------
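A hedged usage sketch for WandbSummaryWriter, mirroring the Neptune example: the project name and log directory are placeholders, and WANDB_USERNAME is optional (the entity falls back to None when it is unset).

from rsl_rl.utils.wandb_utils import WandbSummaryWriter

writer = WandbSummaryWriter(
    log_dir="logs/rsl_rl/example_run",    # the run name is the last path component
    flush_secs=10,
    cfg={"wandb_project": "legged_gym"},  # 'wandb_project' is required
)
writer.add_scalar("Train/mean_reward/time", 0.0, global_step=0)  # logged as 'Train/mean_reward_time'
writer.stop()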
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 |
6 | from setuptools import setup
7 |
8 | setup()
9 |
--------------------------------------------------------------------------------