├── LICENSE.txt ├── README.md ├── datasets └── mocap_motions │ ├── hop1.txt │ ├── hop2.txt │ ├── trot1.txt │ └── trot2.txt ├── dreamer ├── __init__.py ├── configs.yaml ├── models.py ├── networks.py └── tools.py ├── legged_gym ├── __init__.py ├── envs │ ├── __init__.py │ ├── a1 │ │ ├── a1_amp_config.py │ │ └── a1_config.py │ └── base │ │ ├── base_config.py │ │ ├── base_task.py │ │ ├── legged_robot.py │ │ ├── legged_robot_config.py │ │ └── observation_buffer.py ├── scripts │ ├── play.py │ └── train.py ├── tests │ └── test_env.py └── utils │ ├── __init__.py │ ├── helpers.py │ ├── logger.py │ ├── math.py │ ├── task_registry.py │ ├── terrain.py │ └── trimesh.py ├── licenses ├── assets │ └── a1_license.txt ├── dependencies │ └── matplotlib_license.txt └── subcomponents │ ├── dreamerv3-torch_license.txt │ ├── leggedgym_license.txt │ └── parkour_license.txt ├── requirements.txt ├── resources └── robots │ └── a1 │ ├── a1_license.txt │ ├── meshes │ ├── calf.dae │ ├── hip.dae │ ├── thigh.dae │ ├── thigh_mirror.dae │ ├── trunk.dae │ └── trunk_A1.png │ └── urdf │ └── a1.urdf └── rsl_rl ├── __init__.py ├── algorithms ├── __init__.py ├── amp_discriminator.py ├── amp_ppo.py └── ppo.py ├── datasets ├── __init__.py ├── motion_loader.py ├── motion_util.py └── pose3d.py ├── env ├── __init__.py └── vec_env.py ├── modules ├── __init__.py ├── actor_critic.py ├── actor_critic_recurrent.py ├── actor_critic_wmp.py └── depth_predictor.py ├── runners ├── __init__.py ├── on_policy_runner.py └── wmp_runner.py ├── storage ├── __init__.py ├── replay_buffer.py └── rollout_storage.py └── utils ├── __init__.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | 204 | 205 | ======================================================================== 206 | Other licenses 207 | ======================================================================== 208 | The following components are provided under other License. See project link for details. 209 | 210 | leggedgym files from leggedgym: https://github.com/jadenvc/leggedgym BSD-3-Clause 211 | dreamer files from dreamerv3-torch: https://github.com/NM512/dreamerv3-torch MIT 212 | amp files from AMP_for_hardware: https://github.com/Alescontrela/AMP_for_hardware BSD-3-Clause 213 | trimesh files from parkour: https://github.com/ZiwenZhuang/parkour MIT -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

WMP

2 | 3 | Code for the paper: 4 | ### World Model-based Perception for Visual Legged Locomotion 5 | [Hang Lai](https://apex.sjtu.edu.cn/members/laihang@apexlab.org), [Jiahang Cao](https://apex.sjtu.edu.cn/members/jhcao@apexlab.org), [JiaFeng Xu](https://scholar.google.com/citations?user=GPmUxtIAAAAJ&hl=zh-CN&oi=ao), [Hongtao Wu](https://scholar.google.com/citations?user=7u0TYgIAAAAJ&hl=zh-CN&oi=ao), [Yunfeng Lin](https://apex.sjtu.edu.cn/members/yflin@apexlab.org), [Tao Kong](https://www.taokong.org/), [Yong Yu](https://scholar.google.com.hk/citations?user=-84M1m0AAAAJ&hl=zh-CN&oi=ao), [Weinan Zhang](https://wnzhang.net/) 6 | 7 | ### [🌐 Project Website](https://wmp-loco.github.io/) | [📄 Paper](https://arxiv.org/abs/2409.16784) 8 | 9 | ## Requirements 10 | 1. Create a new python virtual env with python 3.6, 3.7 or 3.8 (3.8 recommended) 11 | 2. Install pytorch: 12 | - `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117` 13 | 3. Install Isaac Gym 14 | - Download and install Isaac Gym Preview 3 (Preview 2 will not work!) from https://developer.nvidia.com/isaac-gym 15 | - `cd isaacgym/python && pip install -e .` 16 | 4. Install other packages: 17 | - `sudo apt-get install build-essential --fix-missing` 18 | - `sudo apt-get install ninja-build` 19 | - `pip install setuptools==59.5.0` 20 | - `pip install ruamel_yaml==0.17.4` 21 | - `sudo apt install libgl1-mesa-glx -y` 22 | - `pip install opencv-contrib-python` 23 | - `pip install -r requirements.txt` 24 | 25 | ## Training 26 | ``` 27 | python legged_gym/scripts/train.py --task=a1_amp --headless --sim_device=cuda:0 28 | ``` 29 | Training takes about 23G GPU memory, and at least 10k iterations recommended. 
30 | 31 | ## Visualization 32 | **Please make sure you have trained the WMP before** 33 | ``` 34 | python legged_gym/scripts/play.py --task=a1_amp --sim_device=cuda:0 --terrain=climb 35 | ``` 36 | 37 | 38 | ## Acknowledgments 39 | 40 | We thank the authors of the following projects for making their code open source: 41 | 42 | - [leggedgym](https://github.com/leggedrobotics/legged_gym) 43 | - [dreamerv3-torch](https://github.com/NM512/dreamerv3-torch) 44 | - [AMP_for_hardware](https://github.com/Alescontrela/AMP_for_hardware) 45 | - [parkour](https://github.com/ZiwenZhuang/parkour/tree/main) 46 | - [extreme-parkour](https://github.com/chengxuxin/extreme-parkour) 47 | 48 | 49 | 50 | ## Citation 51 | 52 | If you find this project helpful, please consider citing our paper: 53 | ``` 54 | @article{lai2024world, 55 | title={World Model-based Perception for Visual Legged Locomotion}, 56 | author={Lai, Hang and Cao, Jiahang and Xu, Jiafeng and Wu, Hongtao and Lin, Yunfeng and Kong, Tao and Yu, Yong and Zhang, Weinan}, 57 | journal={arXiv preprint arXiv:2409.16784}, 58 | year={2024} 59 | } 60 | ``` -------------------------------------------------------------------------------- /dreamer/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2023 NM512 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | 22 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 23 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 24 | 25 | # from .dreamer import * 26 | from .networks import * 27 | from .tools import * 28 | from .models import * 29 | -------------------------------------------------------------------------------- /dreamer/configs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | device: 'cuda:0' 4 | compile: True 5 | precision: 32 6 | debug: False 7 | video_pred_log: false 8 | 9 | # Environment 10 | task: 'a1' 11 | num_actions: 12 12 | 13 | # Model 14 | dyn_hidden: 512 15 | dyn_deter: 512 16 | dyn_stoch: 32 17 | dyn_discrete: 32 18 | dyn_rec_depth: 1 19 | dyn_mean_act: 'none' 20 | dyn_std_act: 'sigmoid2' 21 | dyn_min_std: 0.1 22 | grad_heads: ['decoder', 'reward'] 23 | units: 512 # = other heads except decoder, e.g., reward 24 | act: 'SiLU' 25 | norm: True 26 | encoder: 27 | {mlp_keys: '.*', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 5, mlp_units: 1024, symlog_inputs: True} 28 | decoder: 29 | {mlp_keys: '.*', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 5, mlp_units: 1024, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0} 30 | actor: 31 | {layers: 2, dist: 'normal', entropy: 
3e-4, unimix_ratio: 0.01, std: 'learned', min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0} 32 | critic: 33 | {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0} 34 | reward_head: 35 | {layers: 2, dist: 'symlog_disc', loss_scale: 0.0, outscale: 0.0} 36 | cont_head: 37 | {layers: 2, loss_scale: 1.0, outscale: 1.0} 38 | dyn_scale: 0.5 39 | rep_scale: 0.1 40 | kl_free: 1.0 41 | weight_decay: 0.0 42 | unimix_ratio: 0.01 43 | initial: 'learned' 44 | 45 | # Training 46 | train_steps_per_iter: 10 47 | train_start_steps: 10000 48 | batch_size: 16 49 | batch_length: 64 50 | train_ratio: 512 51 | pretrain: 100 52 | model_lr: 1e-4 53 | opt_eps: 1e-8 54 | grad_clip: 1000 55 | opt: 'adam' 56 | 57 | -------------------------------------------------------------------------------- /dreamer/models.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2023 NM512 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | 22 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 23 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 24 | 25 | import copy 26 | import torch 27 | from torch import nn 28 | 29 | from . import tools 30 | from . import networks 31 | 32 | to_np = lambda x: x.detach().cpu().numpy() 33 | 34 | class WorldModel(nn.Module): 35 | def __init__(self, config, obs_shape, use_camera): 36 | super(WorldModel, self).__init__() 37 | # self._step = step 38 | self._use_amp = True if config.precision == 16 else False 39 | self._config = config 40 | self.device = self._config.device 41 | 42 | self.encoder = networks.MultiEncoder(obs_shape, **config.encoder, use_camera=use_camera) 43 | self.embed_size = self.encoder.outdim 44 | self.dynamics = networks.RSSM( 45 | config.dyn_stoch, 46 | config.dyn_deter, 47 | config.dyn_hidden, 48 | config.dyn_rec_depth, 49 | config.dyn_discrete, 50 | config.act, 51 | config.norm, 52 | config.dyn_mean_act, 53 | config.dyn_std_act, 54 | config.dyn_min_std, 55 | config.unimix_ratio, 56 | config.initial, 57 | config.num_actions, 58 | self.embed_size, 59 | config.device, 60 | ) 61 | self.heads = nn.ModuleDict() 62 | if config.dyn_discrete: 63 | feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter 64 | else: 65 | feat_size = config.dyn_stoch + config.dyn_deter 66 | self.heads["decoder"] = networks.MultiDecoder( 67 | feat_size, obs_shape, **config.decoder, use_camera=use_camera 68 | ) 69 | self.heads["reward"] = networks.MLP( 70 | feat_size, 71 | (255,) if config.reward_head["dist"] == "symlog_disc" else (), 72 | config.reward_head["layers"], 73 | config.units, 
74 | config.act, 75 | config.norm, 76 | dist=config.reward_head["dist"], 77 | outscale=config.reward_head["outscale"], 78 | device=config.device, 79 | name="Reward", 80 | ) 81 | # self.heads["cont"] = networks.MLP( 82 | # feat_size, 83 | # (), 84 | # config.cont_head["layers"], 85 | # config.units, 86 | # config.act, 87 | # config.norm, 88 | # dist="binary", 89 | # outscale=config.cont_head["outscale"], 90 | # device=config.device, 91 | # name="Cont", 92 | # ) 93 | for name in config.grad_heads: 94 | assert name in self.heads, name 95 | self._model_opt = tools.Optimizer( 96 | "model", 97 | self.parameters(), 98 | config.model_lr, 99 | config.opt_eps, 100 | config.grad_clip, 101 | config.weight_decay, 102 | opt=config.opt, 103 | use_amp=self._use_amp, 104 | ) 105 | print( 106 | f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables." 107 | ) 108 | # other losses are scaled by 1.0. 109 | # can set different scale for terms in decoder here 110 | self._scales = dict( 111 | reward=config.reward_head["loss_scale"], 112 | image = 1.0, 113 | # clean_prop = 0, 114 | # cont=config.cont_head["loss_scale"], 115 | ) 116 | 117 | def _train(self, data): 118 | # action (batch_size, batch_length, act_dim) 119 | # image (batch_size, batch_length, h, w, ch) 120 | # reward (batch_size, batch_length) 121 | # discount (batch_size, batch_length) 122 | data = self.preprocess(data) 123 | 124 | with tools.RequiresGrad(self): 125 | with torch.cuda.amp.autocast(self._use_amp): 126 | embed = self.encoder(data) 127 | post, prior = self.dynamics.observe( 128 | embed, data["action"], data["is_first"] 129 | ) 130 | kl_free = self._config.kl_free 131 | dyn_scale = self._config.dyn_scale 132 | rep_scale = self._config.rep_scale 133 | kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss( 134 | post, prior, kl_free, dyn_scale, rep_scale 135 | ) 136 | assert kl_loss.shape == embed.shape[:2], kl_loss.shape 137 | preds = {} 138 | for name, head in 
self.heads.items(): 139 | grad_head = name in self._config.grad_heads 140 | feat = self.dynamics.get_feat(post) 141 | feat = feat if grad_head else feat.detach() 142 | pred = head(feat) 143 | if type(pred) is dict: 144 | preds.update(pred) 145 | else: 146 | preds[name] = pred 147 | losses = {} 148 | for name, pred in preds.items(): 149 | loss = -pred.log_prob(data[name]) 150 | assert loss.shape == embed.shape[:2], (name, loss.shape) 151 | losses[name] = loss 152 | scaled = { 153 | key: value * self._scales.get(key, 1.0) 154 | for key, value in losses.items() 155 | } 156 | model_loss = sum(scaled.values()) + kl_loss 157 | metrics = self._model_opt(torch.mean(model_loss), self.parameters()) 158 | 159 | metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()}) 160 | metrics["kl_free"] = kl_free 161 | metrics["dyn_scale"] = dyn_scale 162 | metrics["rep_scale"] = rep_scale 163 | metrics["dyn_loss"] = to_np(dyn_loss) 164 | metrics["rep_loss"] = to_np(rep_loss) 165 | metrics["kl"] = to_np(torch.mean(kl_value)) 166 | with torch.cuda.amp.autocast(self._use_amp): 167 | metrics["prior_ent"] = to_np( 168 | torch.mean(self.dynamics.get_dist(prior).entropy()) 169 | ) 170 | metrics["post_ent"] = to_np( 171 | torch.mean(self.dynamics.get_dist(post).entropy()) 172 | ) 173 | context = dict( 174 | embed=embed, 175 | feat=self.dynamics.get_feat(post), 176 | kl=kl_value, 177 | postent=self.dynamics.get_dist(post).entropy(), 178 | ) 179 | post = {k: v.detach() for k, v in post.items()} 180 | return post, context, metrics 181 | 182 | # this function is called during both rollout and training 183 | def preprocess(self, obs): 184 | # obs = obs.copy() 185 | # obs["image"] = torch.Tensor(obs["image"]) / 255.0 186 | 187 | # discount in obs seems useless 188 | # if "discount" in obs: 189 | # obs["discount"] *= self._config.discount 190 | # (batch_size, batch_length) -> (batch_size, batch_length, 1) 191 | # obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1) 192 | 
# 'is_first' is necesarry to initialize hidden state at training 193 | assert "is_first" in obs 194 | # 'is_terminal' is necesarry to train cont_head 195 | # assert "is_terminal" in obs 196 | # obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1) 197 | obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()} 198 | return obs 199 | 200 | def video_pred(self, data): 201 | data = self.preprocess(data) 202 | embed = self.encoder(data) 203 | 204 | states, _ = self.dynamics.observe( 205 | embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5] 206 | ) 207 | recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[ 208 | :6 209 | ] 210 | reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6] 211 | init = {k: v[:, -1] for k, v in states.items()} 212 | prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init) 213 | openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode() 214 | reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode() 215 | # observed image is given until 5 steps 216 | model = torch.cat([recon[:, :5], openl], 1) 217 | truth = data["image"][:6] 218 | model = model 219 | error = (model - truth + 1.0) / 2.0 220 | 221 | return torch.cat([truth, model], 2) 222 | -------------------------------------------------------------------------------- /legged_gym/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | import os 32 | 33 | LEGGED_GYM_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | LEGGED_GYM_ENVS_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, 'legged_gym', 'envs') -------------------------------------------------------------------------------- /legged_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause
#
# (BSD-3-Clause license text unchanged from the original header; see
#  LICENSE.txt at the repository root for the full text.)
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
from .base.legged_robot import LeggedRobot
# FIX: A1RoughCfg/A1RoughCfgPPO were previously imported twice — once via the
# absolute path `legged_gym.envs.a1.a1_config` and once via the relative path
# below. Keep the single relative import for consistency with the sibling
# imports in this package.
from .a1.a1_config import A1RoughCfg, A1RoughCfgPPO
from .a1.a1_amp_config import A1AMPCfg, A1AMPCfgPPO


import os  # NOTE(review): unused in this module; kept for backward compatibility

from legged_gym.utils.task_registry import task_registry

# Register each task name with its environment class and (env, train) config
# instances so scripts can look them up by name (e.g. `--task a1`).
task_registry.register( "a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO() )
task_registry.register( "a1_amp", LeggedRobot, A1AMPCfg(), A1AMPCfgPPO() )
# --------------------------------------------------------------------------
# legged_gym/envs/a1/a1_config.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text unchanged from the original header.)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

# This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
# All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO

class A1RoughCfg( LeggedRobotCfg ):
    """Environment configuration for the Unitree A1 robot on rough terrain.

    Overrides only the A1-specific fields of :class:`LeggedRobotCfg`
    (robot asset, joint defaults, PD gains, reward tweaks); everything
    else is inherited unchanged.
    """

    class env( LeggedRobotCfg.env ):
        num_envs = 5480
        include_history_steps = None  # Number of steps of history to include.
        num_observations = 235
        # NOTE(review): set equal to num_observations here — presumably the
        # privileged observation has the same layout; confirm against the env.
        num_privileged_obs = 235
        reference_state_initialization = False
        # reference_state_initialization_prob = 0.85
        # amp_motion_files = MOTION_FILES

    class init_state( LeggedRobotCfg.init_state ):
        pos = [0.0, 0.0, 0.42] # x,y,z [m]
        # Per-joint target angles [rad] corresponding to a zero action.
        default_joint_angles = { # = target angles [rad] when action = 0.0
            'FL_hip_joint': 0.1,   # [rad]
            'RL_hip_joint': 0.1,   # [rad]
            'FR_hip_joint': -0.1 ,  # [rad]
            'RR_hip_joint': -0.1,   # [rad]

            'FL_thigh_joint': 0.8,     # [rad]
            'RL_thigh_joint': 1.,   # [rad]
            'FR_thigh_joint': 0.8,     # [rad]
            'RR_thigh_joint': 1.,   # [rad]

            'FL_calf_joint': -1.5,   # [rad]
            'RL_calf_joint': -1.5,    # [rad]
            'FR_calf_joint': -1.5,  # [rad]
            'RR_calf_joint': -1.5,    # [rad]
        }

    class control( LeggedRobotCfg.control ):
        # PD Drive parameters:
        control_type = 'P'
        # A single 'joint' key applies the same gain to every DOF whose name
        # contains "joint" (all twelve A1 joints).
        stiffness = {'joint': 20.}  # [N*m/rad]
        damping = {'joint': 0.5}     # [N*m*s/rad]
        # action scale: target angle = actionScale * action + defaultAngle
        action_scale = 0.25
        # decimation: Number of control action updates @ sim DT per policy DT
        decimation = 4

    class asset( LeggedRobotCfg.asset ):
        file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/a1/urdf/a1.urdf'
        foot_name = "foot"
        penalize_contacts_on = ["thigh", "calf"]
        terminate_after_contacts_on = ["base"]
        self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter

    class rewards( LeggedRobotCfg.rewards ):
        soft_dof_pos_limit = 0.9
        base_height_target = 0.25
        class scales( LeggedRobotCfg.rewards.scales ):
            torques = -0.0002
            dof_pos_limits = -10.0

class A1RoughCfgPPO( LeggedRobotCfgPPO ):
    """PPO training configuration for the A1 rough-terrain task."""

    class algorithm( LeggedRobotCfgPPO.algorithm ):
        entropy_coef = 0.01
    class runner( LeggedRobotCfgPPO.runner ):
        run_name = ''
        experiment_name = 'rough_a1'
-------------------------------------------------------------------------------- /legged_gym/envs/base/base_config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import inspect

class BaseConfig:
    """Config base class that recursively replaces nested classes with
    instances, so configuration can be read and overridden via attributes.

    Only the ``__class__`` attribute is skipped during traversal; every
    other attribute that is itself a class gets instantiated in place.
    """

    def __init__(self) -> None:
        """Initialize all member classes recursively."""
        self.init_member_classes(self)

    @staticmethod
    def init_member_classes(obj):
        """Replace every class-valued attribute of ``obj`` with an instance,
        recursing into each newly created instance."""
        for name in dir(obj):
            # Skip the type pointer itself; instantiating it would recurse
            # into the object's own class.
            if name == "__class__":
                continue
            attr = getattr(obj, name)
            # Guard clause: only class-valued attributes are converted.
            if not inspect.isclass(attr):
                continue
            # Instantiate the nested class and shadow the class attribute
            # with the instance on this object.
            instance = attr()
            setattr(obj, name, instance)
            # Recurse so arbitrarily deep config trees are all instantiated.
            BaseConfig.init_member_classes(instance)
# --------------------------------------------------------------------------
# legged_gym/envs/base/base_task.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text unchanged from the original header.)
# Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import sys
from isaacgym import gymapi
from isaacgym import gymutil
import numpy as np
import torch

from legged_gym.envs.base import observation_buffer


# Base class for RL tasks
class BaseTask():
    """Base class for Isaac Gym RL tasks.

    Acquires the gym handle, allocates the per-environment observation /
    reward / reset buffers, creates the simulation via ``create_sim()``
    (implemented by subclasses), and optionally opens a viewer with
    keyboard shortcuts (ESC quits, V toggles viewer sync).
    """

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless):
        """Set up sim/device handles and allocate buffers.

        Args:
            cfg: config object; this reads cfg.env.{num_envs, num_observations,
                num_privileged_obs, num_actions, include_history_steps,
                height_dim, privileged_dim}.
            sim_params: gymapi.SimParams to create the simulation with.
            physics_engine: gymapi physics engine enum (e.g. SIM_PHYSX).
            sim_device: device string such as 'cuda:0' or 'cpu'.
            headless: if True, no viewer is created and rendering is disabled.
        """
        self.gym = gymapi.acquire_gym()

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.sim_device = sim_device
        sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
        self.headless = headless

        # env device is GPU only if sim is on GPU and use_gpu_pipeline=True, otherwise returned tensors are copied to CPU by physX.
        if sim_device_type=='cuda' and sim_params.use_gpu_pipeline:
            self.device = self.sim_device
        else:
            self.device = 'cpu'

        # graphics device for rendering, -1 for no rendering
        self.graphics_device_id = self.sim_device_id
        if self.headless == True:
            self.graphics_device_id = -1

        self.num_envs = cfg.env.num_envs
        self.num_obs = cfg.env.num_observations
        self.num_privileged_obs = cfg.env.num_privileged_obs
        self.num_actions = cfg.env.num_actions
        self.include_history_steps = cfg.env.include_history_steps

        # Slices of the privileged observation: first `privileged_dim` entries
        # are privileged info, last `height_dim` entries are the heightmap
        # (per the comments in LeggedRobotCfg.env).
        self.height_dim = cfg.env.height_dim
        self.privileged_dim = cfg.env.privileged_dim

        # optimization flags for pytorch JIT
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)

        # allocate buffers
        # When observation history is requested, keep a rolling window of the
        # last `include_history_steps` observations per environment.
        if cfg.env.include_history_steps is not None:
            self.obs_buf_history = observation_buffer.ObservationBuffer(
                self.num_envs, self.num_obs,
                self.include_history_steps, self.device)
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
        self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        # reset_buf starts at 1 so every environment is reset on the first step.
        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)
        self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
        self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        if self.num_privileged_obs is not None:
            self.privileged_obs_buf = torch.zeros(self.num_envs, self.num_privileged_obs, device=self.device, dtype=torch.float)
        else:
            self.privileged_obs_buf = None
            # self.num_privileged_obs = self.num_obs

        self.extras = {}

        # create envs, sim and viewer
        self.create_sim()
        self.gym.prepare_sim(self.sim)

        # todo: read from config
        self.enable_viewer_sync = True
        self.viewer = None

        # if running with a viewer, set up keyboard shortcuts and camera
        if self.headless == False:
            # subscribe to keyboard shortcuts
            self.viewer = self.gym.create_viewer(
                self.sim, gymapi.CameraProperties())
            self.gym.subscribe_viewer_keyboard_event(
                self.viewer, gymapi.KEY_ESCAPE, "QUIT")
            self.gym.subscribe_viewer_keyboard_event(
                self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

    def get_observations(self):
        """Return the current observation buffer (num_envs x num_obs)."""
        return self.obs_buf

    def get_privileged_observations(self):
        """Return the privileged observation buffer, or None if disabled."""
        return self.privileged_obs_buf

    def reset_idx(self, env_ids):
        """Reset selected robots"""
        raise NotImplementedError

    def reset(self):
        """ Reset all robots"""
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        # Step once with zero actions so observations reflect the reset state.
        obs, privileged_obs, _, _, _ = self.step(torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False))
        return obs, privileged_obs

    def step(self, actions):
        """Advance the simulation by one policy step. Implemented by subclasses."""
        raise NotImplementedError

    def render(self, sync_frame_time=True):
        """Draw the viewer (if any), handling the QUIT and sync-toggle keys.

        Args:
            sync_frame_time: if True, throttle rendering to real time.
        """
        if self.viewer:
            # check for window closed
            if self.gym.query_viewer_has_closed(self.viewer):
                sys.exit()

            # check for keyboard events
            for evt in self.gym.query_viewer_action_events(self.viewer):
                if evt.action == "QUIT" and evt.value > 0:
                    sys.exit()
                elif evt.action == "toggle_viewer_sync" and evt.value > 0:
                    self.enable_viewer_sync = not self.enable_viewer_sync

            # fetch results
            if self.device != 'cpu':
                self.gym.fetch_results(self.sim, True)

            # step graphics
            if self.enable_viewer_sync:
                self.gym.step_graphics(self.sim)
                self.gym.draw_viewer(self.viewer, self.sim, True)
                if sync_frame_time:
                    self.gym.sync_frame_time(self.sim)
            else:
                # Keep processing window events even when sync is disabled.
                self.gym.poll_viewer_events(self.viewer)
# --------------------------------------------------------------------------
# legged_gym/envs/base/legged_robot_config.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (BSD-3-Clause license text unchanged from the original header; see
#  LICENSE.txt at the repository root for the full text.)
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

# This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
# All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

from .base_config import BaseConfig

class LeggedRobotCfg(BaseConfig):
    """Default environment configuration for legged-robot tasks.

    Robot-specific configs (e.g. A1RoughCfg) subclass this and override
    the nested classes they need. BaseConfig instantiates every nested
    class recursively at construction time.
    """

    class env:
        num_envs = 4096
        num_observations = 235
        privileged_obs = True # if True, add the privileged information in the obs
        privileged_dim = 24 + 3 # privileged_obs[:,:privileged_dim] is the privileged information in privileged_obs, include 3-dim base linear vel
        height_dim = 187 # privileged_obs[:,-height_dim:] is the heightmap in privileged_obs
        num_privileged_obs = None # if not None a priviledge_obs_buf will be returned by step() (critic obs for assymetric training). None is returned otherwise
        num_actions = 12
        env_spacing = 3.  # not used with heightfields/trimeshes
        send_timeouts = True # send time out information to the algorithm
        episode_length_s = 20 # episode length in seconds
        reference_state_initialization = False # initialize state from reference data

    class terrain:
        mesh_type = 'trimesh' # "heightfield" # none, plane, heightfield or trimesh
        horizontal_scale = 0.1 # [m]
        vertical_scale = 0.005 # [m]
        border_size = 50 # [m] change 25 to 50
        curriculum = True
        static_friction = 1.0
        dynamic_friction = 1.0
        restitution = 0.
        # rough terrain only:
        measure_heights = True
        measured_points_x = [-0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] # 1mx1.6m rectangle (without center line)
        measured_points_y = [-0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5]
        selected = False # select a unique terrain type and pass all arguments
        terrain_kwargs = None # Dict of arguments for selected terrain
        max_init_terrain_level = 0 # starting curriculum state
        terrain_length = 8.
        terrain_width = 8.
        num_rows= 10 # number of terrain rows (levels)
        num_cols = 20 # number of terrain cols (types)
        # terrain types: [wave, rough slope, stairs up, stairs down, discrete, rough_flat]
        terrain_proportions = [0.1, 0.1, 0.30, 0.25, 0.15, 0.1]
        # terrain_proportions = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
        # trimesh only:
        slope_treshold = 0.75 # slopes above this threshold will be corrected to vertical surfaces

    class commands:
        curriculum = False
        max_curriculum = 1.
        num_commands = 4 # default: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error)
        resampling_time = 10. # time before command are changed[s]
        heading_command = True # if true: compute ang vel command from heading error
        class ranges:
            lin_vel_x = [-1.0, 1.0] # min max [m/s]
            lin_vel_y = [-1.0, 1.0]   # min max [m/s]
            ang_vel_yaw = [-1, 1]    # min max [rad/s]
            heading = [-3.14, 3.14]

    class init_state:
        pos = [0.0, 0.0, 1.] # x,y,z [m]
        rot = [0.0, 0.0, 0.0, 1.0] # x,y,z,w [quat]
        lin_vel = [0.0, 0.0, 0.0]  # x,y,z [m/s]
        ang_vel = [0.0, 0.0, 0.0]  # x,y,z [rad/s]
        default_joint_angles = { # target angles when action = 0.0
            "joint_a": 0.,
            "joint_b": 0.}

    class control:
        control_type = 'P' # P: position, V: velocity, T: torques
        # PD Drive parameters:
        stiffness = {'joint_a': 10.0, 'joint_b': 15.}  # [N*m/rad]
        damping = {'joint_a': 1.0, 'joint_b': 1.5}     # [N*m*s/rad]
        # action scale: target angle = actionScale * action + defaultAngle
        action_scale = 0.5
        # decimation: Number of control action updates @ sim DT per policy DT
        decimation = 4

    class depth:
        # Depth-camera settings (used when use_camera is True).
        use_camera = False
        camera_num_envs = 192
        camera_terrain_num_rows = 10
        camera_terrain_num_cols = 20

        position = [0.27, 0, 0.03]  # front camera
        angle = [-5, 5]  # positive pitch down

        update_interval = 5  # 5 works without retraining, 8 worse

        # Raw sensor resolution and the resolution after resizing (width, height).
        original = (106, 60)
        resized = (87, 58)
        horizontal_fov = 87
        buffer_len = 2

        near_clip = 0
        far_clip = 2
        dis_noise = 0.0

        scale = 1
        invert = True

    class asset:
        file = ""
        foot_name = "None" # name of the feet bodies, used to index body state and contact force tensors
        penalize_contacts_on = []
        terminate_after_contacts_on = []
        disable_gravity = False
        collapse_fixed_joints = True # merge bodies connected by fixed joints. Specific fixed joints can be kept by adding " <... dont_collapse="true">
        fix_base_link = False # fixe the base of the robot
        default_dof_drive_mode = 3 # see GymDofDriveModeFlags (0 is none, 1 is pos tgt, 2 is vel tgt, 3 effort)
        self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter
        replace_cylinder_with_capsule = True # replace collision cylinders with capsules, leads to faster/more stable simulation
        flip_visual_attachments = True # Some .obj meshes must be flipped from y-up to z-up

        density = 0.001
        angular_damping = 0.
        linear_damping = 0.
        max_angular_velocity = 1000.
        max_linear_velocity = 1000.
        armature = 0.
        thickness = 0.01

    class domain_rand:
        # Domain-randomization toggles and their sampling ranges.
        randomize_friction = True
        friction_range = [0.25, 1.75]
        randomize_restitution = True
        restitution_range = [0, 1]

        randomize_base_mass = True
        added_mass_range = [-1., 1.] # kg
        randomize_link_mass = True
        link_mass_range = [0.8, 1.2]
        randomize_com_pos = True
        com_pos_range = [-0.05, 0.05]

        push_robots = True
        push_interval_s = 15
        max_push_vel_xy = 1.0

        randomize_gains = True
        stiffness_multiplier_range = [0.9, 1.1]
        damping_multiplier_range = [0.9, 1.1]
        randomize_motor_strength = True

        motor_strength_range = [0.9, 1.1]
        randomize_action_latency = True
        latency_range = [0.00, 0.02]

    class rewards:
        reward_curriculum = True
        reward_curriculum_term = ["lin_vel_z"]
        reward_curriculum_schedule = [0, 1000, 1, 0] #from iter 0 to iter 1000, decrease from 1 to 0
        class scales:
            # Per-term reward weights; negative values are penalties.
            termination = -0.0
            tracking_lin_vel = 1.0
            tracking_ang_vel = 0.5
            lin_vel_z = -2.0
            ang_vel_xy = 0#-0.05
            orientation = -0.
            torques = -0.00001
            dof_vel = -0.
            dof_acc = -2.5e-7
            base_height = -0.
            feet_air_time =  1.0
            collision = -1.
            feet_stumble = -0.0
            action_rate = -0.01
            stand_still = -0.

        only_positive_rewards = True # if true negative total rewards are clipped at zero (avoids early termination problems)
        tracking_sigma = 0.15 # tracking reward = exp(-error^2/sigma)
        soft_dof_pos_limit = 1. # percentage of urdf limits, values above this limit are penalized
        soft_dof_vel_limit = 1.
        soft_torque_limit = 1.
        base_height_target = 1.
        max_contact_force = 100. # forces above this value are penalized

    class normalization:
        class obs_scales:
            lin_vel = 1.0
            ang_vel = 0.25
            dof_pos = 1.0
            dof_vel = 0.05
            # privileged
            height_measurements = 5.0
            contact_force = 0.005
            com_pos = 20
            pd_gains = 5
        # NOTE(review): indentation of the next three fields was ambiguous in
        # the source dump; placed at the `normalization` level — confirm
        # against the upstream file.
        clip_observations = 100.
        clip_actions = 6.0

        base_height = 0.5

    class noise:
        add_noise = True
        noise_level = 1.0 # scales other values
        class noise_scales:
            dof_pos = 0.01
            dof_vel = 1.5
            lin_vel = 0.1
            ang_vel = 0.2
            gravity = 0.05
            height_measurements = 0.1

    # viewer camera:
    class viewer:
        ref_env = 0
        pos = [10, 0, 6]  # [m]
        lookat = [11., 5, 3.]  # [m]

    class sim:
        dt =  0.002
        substeps = 1
        gravity = [0., 0. ,-9.81]  # [m/s^2]
        up_axis = 1  # 0 is y, 1 is z

        class physx:
            num_threads = 10
            solver_type = 1  # 0: pgs, 1: tgs
            num_position_iterations = 4
            num_velocity_iterations = 0
            contact_offset = 0.01  # [m]
            rest_offset = 0.0   # [m]
            bounce_threshold_velocity = 0.5 #0.5 [m/s]
            max_depenetration_velocity = 1.0
            max_gpu_contact_pairs = 2**23 #2**24 -> needed for 8000 envs and more
            default_buffer_size_multiplier = 5
            contact_collection = 2 # 0: never, 1: last sub-step, 2: all sub-steps (default=2)

class LeggedRobotCfgPPO(BaseConfig):
    """Default PPO training configuration for legged-robot tasks."""

    seed = 1
    runner_class_name = 'OnPolicyRunner'
    class policy:
        init_noise_std = 1.0
        actor_hidden_dims = [512, 256, 128]
        critic_hidden_dims = [512, 256, 128]
        latent_dim = 32
        # height_latent_dim = 16 # the encoder in teacher policy encodes the heightmap into a height_latent_dim vector
        # privileged_latent_dim = 8 # the encoder in teacher policy encodes the privileged infomation into a privileged_latent_dim vector
        activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid
        # only for 'ActorCriticRecurrent':
        # rnn_type = 'lstm'
        # rnn_hidden_size = 512
        # rnn_num_layers = 1

    class algorithm:
        # training params
        value_loss_coef = 1.0
        use_clipped_value_loss = True
        clip_param = 0.2
        entropy_coef = 0.01
        num_learning_epochs = 5
        num_mini_batches = 4 # mini batch size = num_envs*nsteps / nminibatches
        learning_rate = 1.e-3 #5.e-4
        schedule = 'adaptive' # could be adaptive, fixed
        gamma = 0.99
        lam = 0.95
        desired_kl = 0.01
        max_grad_norm = 1.

    class runner:
        policy_class_name = 'ActorCritic'
        algorithm_class_name = 'PPO'
        num_steps_per_env = 24 # per iteration
        max_iterations = 1500 # number of policy updates

        # logging
        save_interval = 200 # check for potential saves every this many iterations
        experiment_name = 'test'
        run_name = 'trot'
        # load and resume
        resume = False
        load_run = -1 # -1 = last run
        checkpoint = -1 # -1 = last saved model
        resume_path = None # updated from load_run and chkpt
# --------------------------------------------------------------------------
# legged_gym/envs/base/observation_buffer.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text unchanged from the original header.)
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch

class ObservationBuffer:
    """Rolling history of the last ``include_history_steps`` observations
    per environment, stored as one flat (num_envs, num_obs * steps) tensor
    ordered oldest -> newest along the feature dimension."""

    def __init__(self, num_envs, num_obs, include_history_steps, device):
        """Allocate the zero-initialized history buffer.

        Args:
            num_envs: number of parallel environments (rows).
            num_obs: size of a single observation vector.
            include_history_steps: number of observations kept per env.
            device: torch device for the buffer.
        """
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.include_history_steps = include_history_steps
        self.device = device

        self.num_obs_total = num_obs * include_history_steps

        self.obs_buf = torch.zeros(self.num_envs, self.num_obs_total, device=self.device, dtype=torch.float)

    def reset(self, reset_idxs, new_obs):
        """Fill the whole history of the given envs with ``new_obs`` (the
        freshly observed state is replicated into every history slot)."""
        self.obs_buf[reset_idxs] = new_obs.repeat(1, self.include_history_steps)

    def insert(self, new_obs):
        """Append ``new_obs`` as the latest observation, discarding the oldest."""
        # Shift observations back.
        # FIX: for include_history_steps > 2 the source and destination
        # column ranges overlap, and in-place copies between overlapping
        # tensor views have undefined results in PyTorch. Snapshot the
        # source slice with .clone() before writing.
        shifted = self.obs_buf[:, self.num_obs : self.num_obs * self.include_history_steps].clone()
        self.obs_buf[:, : self.num_obs * (self.include_history_steps - 1)] = shifted

        # Add new observation.
        self.obs_buf[:, -self.num_obs:] = new_obs

    def get_obs_vec(self, obs_ids):
        """Gets history of observations indexed by obs_ids.

        Arguments:
            obs_ids: An array of integers with which to index the desired
                observations, where 0 is the latest observation and
                include_history_steps - 1 is the oldest observation.

        Returns:
            Concatenation of the selected observations, ordered
            oldest-first, shape (num_envs, num_obs * len(obs_ids)).
        """

        obs = []
        # Largest id = oldest observation = leftmost slot, so iterate ids
        # in descending order to emit oldest -> newest.
        for obs_id in reversed(sorted(obs_ids)):
            slice_idx = self.include_history_steps - obs_id - 1
            obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
        return torch.cat(obs, dim=-1)
# --------------------------------------------------------------------------
# legged_gym/scripts/play.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text unchanged from the original header.)
def play(args):
    """Roll out a trained WMP policy for ``args.task`` and log diagnostics.

    Overrides the task config for a small deterministic test (few envs, fixed
    terrain, no noise, nominal domain randomization), loads the latest 'WMP'
    checkpoint, then steps one episode (plus a few steps) while feeding
    proprioception/depth into the world model. Behavior flags EXPORT_POLICY,
    RECORD_FRAMES and MOVE_CAMERA are module-level globals set in __main__.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 10)
    env_cfg.terrain.num_cols = 1
    env_cfg.terrain.curriculum = False
    env_cfg.noise.add_noise = False
    # env_cfg.domain_rand.randomize_friction = False
    # env_cfg.domain_rand.randomize_restitution = False
    # env_cfg.commands.heading_command = True

    # Pin domain-randomization ranges to single nominal values for playback.
    env_cfg.domain_rand.friction_range = [1.0, 1.0]
    env_cfg.domain_rand.restitution_range = [0.0, 0.0]
    env_cfg.domain_rand.added_mass_range = [0., 0.]  # kg
    env_cfg.domain_rand.com_x_pos_range = [-0.0, 0.0]
    env_cfg.domain_rand.com_y_pos_range = [-0.0, 0.0]
    env_cfg.domain_rand.com_z_pos_range = [-0.0, 0.0]

    env_cfg.domain_rand.randomize_action_latency = False
    env_cfg.domain_rand.push_robots = False
    env_cfg.domain_rand.randomize_gains = True
    # env_cfg.domain_rand.randomize_base_mass = False
    env_cfg.domain_rand.randomize_link_mass = False
    # env_cfg.domain_rand.randomize_com_pos = False
    env_cfg.domain_rand.randomize_motor_strength = False

    train_cfg.runner.amp_num_preload_transitions = 1

    env_cfg.domain_rand.stiffness_multiplier_range = [1.0, 1.0]
    env_cfg.domain_rand.damping_multiplier_range = [1.0, 1.0]


    # env_cfg.terrain.mesh_type = 'plane'
    if(env_cfg.terrain.mesh_type == 'plane'):
        # Edge/stumble penalties are meaningless on flat ground.
        env_cfg.rewards.scales.feet_edge = 0
        env_cfg.rewards.scales.feet_stumble = 0


    if(args.terrain not in ['slope', 'stair', 'gap', 'climb', 'crawl', 'tilt']):
        print('terrain should be one of slope, stair, gap, climb, crawl, and tilt, set to climb as default')
        args.terrain = 'climb'
    # One-hot terrain mix: all probability mass on the requested terrain type.
    # NOTE(review): row lengths differ (9 vs 10 entries) — presumably the
    # terrain generator tolerates ragged proportion lists; confirm against
    # legged_gym/utils/terrain.py.
    env_cfg.terrain.terrain_proportions = {
        'slope': [0, 1.0, 0.0, 0, 0, 0, 0, 0, 0],
        'stair': [0, 0, 1.0, 0, 0, 0, 0, 0, 0],
        'gap': [0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0],
        'climb': [0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0],
        'tilt': [0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0],
        'crawl': [0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0],
    }[args.terrain]

    # Fixed forward-walk command: 0.6 m/s, no lateral or yaw motion.
    env_cfg.commands.ranges.lin_vel_x = [0.6, 0.6]
    env_cfg.commands.ranges.lin_vel_y = [-0.0, -0.0]
    env_cfg.commands.ranges.ang_vel_yaw = [0.0, 0.0]
    env_cfg.commands.ranges.heading = [0, 0]

    env_cfg.commands.ranges.flat_lin_vel_x = [0.6, 0.6]
    env_cfg.commands.ranges.flat_lin_vel_y = [-0.0, -0.0]
    env_cfg.commands.ranges.flat_ang_vel_yaw = [0.0, 0.0]

    env_cfg.depth.use_camera = True

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    _, _ = env.reset()
    obs = env.get_observations()
    # load policy (latest checkpoint of the 'WMP' run)
    train_cfg.runner.resume = True
    train_cfg.runner.load_run = 'WMP'


    train_cfg.runner.checkpoint = -1
    ppo_runner, train_cfg = task_registry.make_wmp_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    policy = ppo_runner.get_inference_policy(device=env.device)

    # export policy as a jit module (used to run it from C++)
    if EXPORT_POLICY:
        path = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies')
        export_policy_as_jit(ppo_runner.alg.actor_critic, path)
        print('Exported policy as jit script to: ', path)

    logger = Logger(env.dt)
    robot_index = 0  # which robot is used for logging
    joint_index = 1  # which joint is used for logging
    stop_state_log = 100  # number of steps before plotting states
    stop_rew_log = env.max_episode_length + 1  # number of steps before print average episode rewards
    camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64)
    camera_vel = np.array([1., 1., 0.])
    camera_direction = np.array(env_cfg.viewer.lookat) - np.array(env_cfg.viewer.pos)
    img_idx = 0

    # The policy conditions on a short history of observations with the
    # privileged slice, height-map slice and 3 command entries removed.
    history_length = 5
    trajectory_history = torch.zeros(size=(env.num_envs, history_length, env.num_obs -
                                           env.privileged_dim - env.height_dim - 3), device = env.device)
    obs_without_command = torch.concat((obs[:, env.privileged_dim:env.privileged_dim + 6],
                                        obs[:, env.privileged_dim + 9:-env.height_dim]), dim=1)
    trajectory_history = torch.concat((trajectory_history[:, 1:], obs_without_command.unsqueeze(1)), dim=1)

    world_model = ppo_runner._world_model.to(env.device)
    wm_latent = wm_action = None
    wm_is_first = torch.ones(env.num_envs, device=env.device)
    wm_update_interval = env.cfg.depth.update_interval
    wm_action_history = torch.zeros(size=(env.num_envs, wm_update_interval, env.num_actions),
                                    device=env.device)
    wm_obs = {
        "prop": obs[:, env.privileged_dim: env.privileged_dim + env.cfg.env.prop_dim],
        "is_first": wm_is_first,
    }

    if (env.cfg.depth.use_camera):
        wm_obs["image"] = torch.zeros(((env.num_envs,) + env.cfg.depth.resized + (1,)),
                                      device=world_model.device)

    wm_feature = torch.zeros((env.num_envs, ppo_runner.wm_feature_dim), device=env.device)

    total_reward = 0
    not_dones = torch.ones((env.num_envs,), device=env.device)
    for i in range(1*int(env.max_episode_length) + 3):
        # Refresh the world-model latent only every `wm_update_interval` steps.
        if (env.global_counter % wm_update_interval == 0):
            if (env.cfg.depth.use_camera):
                # NOTE(review): `infos` is only assigned after env.step() below,
                # so this line raises NameError if this branch runs on the very
                # first loop iteration — presumably env.global_counter is
                # nonzero right after reset; confirm.
                wm_obs["image"][env.depth_index] = infos["depth"].unsqueeze(-1).to(world_model.device)

            wm_embed = world_model.encoder(wm_obs)
            wm_latent, _ = world_model.dynamics.obs_step(wm_latent, wm_action, wm_embed, wm_obs["is_first"], sample=True)
            wm_feature = world_model.dynamics.get_deter_feat(wm_latent)
            wm_is_first[:] = 0

        history = trajectory_history.flatten(1).to(env.device)
        actions = policy(obs.detach(), history.detach(), wm_feature.detach())


        obs, _, rews, dones, infos, reset_env_ids, _ = env.step(actions.detach())

        # Accumulate the mean reward only over envs that have never terminated.
        not_dones *= (~dones)
        total_reward += torch.mean(rews * not_dones)

        # update world model input
        wm_action_history = torch.concat(
            (wm_action_history[:, 1:], actions.unsqueeze(1)), dim=1)
        wm_obs = {
            "prop": obs[:, env.privileged_dim: env.privileged_dim + env.cfg.env.prop_dim],
            "is_first": wm_is_first,
        }
        if (env.cfg.depth.use_camera):
            wm_obs["image"] = torch.zeros(((env.num_envs,) + env.cfg.depth.resized + (1,)),
                                          device=world_model.device)

        reset_env_ids = reset_env_ids.cpu().numpy()
        if (len(reset_env_ids) > 0):
            # Restart world-model state for environments that just reset.
            wm_action_history[reset_env_ids, :] = 0
            wm_is_first[reset_env_ids] = 1

        wm_action = wm_action_history.flatten(1)


        # process trajectory history: zero the history of done envs, then
        # shift in the newest command-free observation.
        env_ids = dones.nonzero(as_tuple=False).flatten()
        trajectory_history[env_ids] = 0
        obs_without_command = torch.concat((obs[:, env.privileged_dim:env.privileged_dim + 6],
                                            obs[:, env.privileged_dim + 9:-env.height_dim]),
                                           dim=1)
        trajectory_history = torch.concat(
            (trajectory_history[:, 1:], obs_without_command.unsqueeze(1)), dim=1)

        if RECORD_FRAMES:
            if i % 2:
                filename = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'frames', f"{img_idx}.png")
                env.gym.write_viewer_image_to_file(env.viewer, filename)
                img_idx += 1
        if MOVE_CAMERA:
            # NOTE(review): `lootat` / `camara_position` are misspelled local
            # names (lookat / camera_position); left unchanged here.
            lootat = env.root_states[8, :3]
            camara_position = lootat.detach().cpu().numpy() + [0, 1, 0]
            env.set_camera(camara_position, lootat)

        if i < stop_state_log:
            logger.log_states(
                {
                    'dof_pos_target': actions[robot_index, joint_index].item() * env.cfg.control.action_scale,
                    'dof_pos': env.dof_pos[robot_index, joint_index].item(),
                    'dof_vel': env.dof_vel[robot_index, joint_index].item(),
                    'dof_torque': env.torques[robot_index, joint_index].item(),
                    'command_x': env.commands[robot_index, 0].item(),
                    'command_y': env.commands[robot_index, 1].item(),
                    'command_yaw': env.commands[robot_index, 2].item(),
                    'base_vel_x': env.base_lin_vel[robot_index, 0].item(),
                    'base_vel_y': env.base_lin_vel[robot_index, 1].item(),
                    'base_vel_z': env.base_lin_vel[robot_index, 2].item(),
                    'base_vel_yaw': env.base_ang_vel[robot_index, 2].item(),
                    'contact_forces_z': env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy()
                }
            )
        if 0 < i < stop_rew_log:
            if infos["episode"]:
                num_episodes = torch.sum(env.reset_buf).item()
                if num_episodes>0:
                    logger.log_rewards(infos["episode"], num_episodes)
        elif i==stop_rew_log:
            logger.print_rewards()

    print('total reward:', total_reward)

if __name__ == '__main__':
    EXPORT_POLICY = True
    RECORD_FRAMES = False
    MOVE_CAMERA = True
    args = get_args()
    args.rl_device = args.sim_device
    play(args)
def train(args):
    """Build the task's environment and WMP runner, then run training.

    `args` comes from get_args(); the task name selects both configs from
    the task registry before the runner-level overrides below are applied.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)

    # Training-run bookkeeping overrides.
    runner_cfg = train_cfg.runner
    runner_cfg.run_name = 'WMP'
    runner_cfg.max_iterations = 100000
    runner_cfg.save_interval = 1000

    env, env_cfg = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    ppo_runner, train_cfg = task_registry.make_wmp_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    ppo_runner.learn(num_learning_iterations=train_cfg.runner.max_iterations, init_at_random_ep_len=True)


if __name__ == '__main__':
    args = get_args()
    args.rl_device = args.sim_device
    train(args)
-------------------------------------------------------------------------------- /legged_gym/tests/test_env.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
def test_env(args):
    """Smoke-test the environment: step a few envs with all-zero actions."""
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    # Keep the test cheap: cap the number of parallel environments.
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 10)

    # Build the environment and drive it for ten episode lengths.
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    for _step in range(int(10 * env.max_episode_length)):
        actions = torch.zeros(env.num_envs, env.num_actions, device=env.device)
        obs, _, rew, done, info = env.step(actions)
    print("Done")

if __name__ == '__main__':
    args = get_args()
    test_env(args)
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict 32 | from .task_registry import task_registry 33 | from .logger import Logger 34 | from .math import * 35 | from .trimesh import * 36 | from .terrain import Terrain 37 | -------------------------------------------------------------------------------- /legged_gym/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. 
Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
def class_to_dict(obj) -> dict:
    """Recursively convert a config class (and nested classes/lists) to a dict.

    Leaves (objects without ``__dict__``) are returned unchanged; attribute
    names starting with '_' are skipped.
    """
    if not hasattr(obj,"__dict__"):
        return obj
    result = {}
    for key in dir(obj):
        if key.startswith("_"):
            continue
        element = []
        val = getattr(obj, key)
        if isinstance(val, list):
            # Convert each list element independently.
            for item in val:
                element.append(class_to_dict(item))
        else:
            element = class_to_dict(val)
        result[key] = element
    return result

def update_class_from_dict(obj, dict):
    """Write dict entries onto obj's attributes, recursing into nested classes.

    NOTE(review): the parameter name shadows the builtin ``dict``; kept
    unchanged for interface compatibility.
    """
    for key, val in dict.items():
        attr = getattr(obj, key, None)
        if isinstance(attr, type):
            # The attribute is itself a (config) class: update it in place.
            update_class_from_dict(attr, val)
        else:
            setattr(obj, key, val)
    return

def set_seed(seed):
    """Seed python, numpy and torch RNGs; ``seed == -1`` draws a random seed."""
    if seed == -1:
        seed = np.random.randint(0, 10000)
    print("Setting seed: {}".format(seed))

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # NOTE: setting PYTHONHASHSEED here only affects subprocesses, not the
    # already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def parse_sim_params(args, cfg):
    """Build a gymapi.SimParams from CLI args, then let cfg['sim'] override."""
    # code from Isaac Gym Preview 2
    # initialize sim params
    sim_params = gymapi.SimParams()

    # set some values from args
    if args.physics_engine == gymapi.SIM_FLEX:
        if args.device != "cpu":
            print("WARNING: Using Flex with GPU instead of PHYSX!")
    elif args.physics_engine == gymapi.SIM_PHYSX:
        sim_params.physx.use_gpu = args.use_gpu
        sim_params.physx.num_subscenes = args.subscenes
    sim_params.use_gpu_pipeline = args.use_gpu_pipeline

    # if sim options are provided in cfg, parse them and update/override above:
    if "sim" in cfg:
        gymutil.parse_sim_config(cfg["sim"], sim_params)

    # Override num_threads if passed on the command line
    if args.physics_engine == gymapi.SIM_PHYSX and args.num_threads > 0:
        sim_params.physx.num_threads = args.num_threads

    return sim_params

def get_load_path(root, load_run=-1, checkpoint=-1):
    """Resolve a checkpoint file path under ``root``.

    ``load_run == -1`` selects the lexicographically last run directory
    (the 'exported' directory is skipped); ``checkpoint == -1`` selects the
    last 'model*' file using a zero-padded sort so model_9 < model_10.
    Raises ValueError when ``root`` cannot be listed or has no runs.
    """
    try:
        runs = os.listdir(root)
        #TODO sort by date to handle change of month
        runs.sort()
        if 'exported' in runs: runs.remove('exported')
        last_run = os.path.join(root, runs[-1])
    except:
        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
        # a narrower `except (OSError, IndexError)` would be safer.
        raise ValueError("No runs in this directory: " + root)
    if load_run==-1:
        load_run = last_run
    else:
        load_run = os.path.join(root, load_run)

    if checkpoint==-1:
        models = [file for file in os.listdir(load_run) if 'model' in file]
        # Left-pad to 15 chars so numeric suffixes sort correctly.
        models.sort(key=lambda m: '{0:0>15}'.format(m))
        model = models[-1]
    else:
        model = "model_{}.pt".format(checkpoint)

    load_path = os.path.join(load_run, model)
    return load_path

def update_cfg_from_args(env_cfg, cfg_train, args):
    """Apply CLI overrides onto the env/train configs; returns both (mutated)."""
    # seed
    if env_cfg is not None:
        # num envs
        if args.num_envs is not None:
            env_cfg.env.num_envs = args.num_envs
    if cfg_train is not None:
        if args.seed is not None:
            cfg_train.seed = args.seed
        # alg runner parameters
        if args.max_iterations is not None:
            cfg_train.runner.max_iterations = args.max_iterations
        if args.resume:
            cfg_train.runner.resume = args.resume
        if args.experiment_name is not None:
            cfg_train.runner.experiment_name = args.experiment_name
        if args.run_name is not None:
            cfg_train.runner.run_name = args.run_name
        if args.load_run is not None:
            cfg_train.runner.load_run = args.load_run
        if args.checkpoint is not None:
            cfg_train.runner.checkpoint = args.checkpoint

    return env_cfg, cfg_train

def get_args():
    """Parse command-line arguments via gymutil, adding project options."""
    custom_parameters = [
        {"name": "--task", "type": str, "default": "anymal_c_flat", "help": "Resume training or start testing from a checkpoint. Overrides config file if provided."},
        {"name": "--resume", "action": "store_true", "default": False, "help": "Resume training from a checkpoint"},
        {"name": "--experiment_name", "type": str, "help": "Name of the experiment to run or load. Overrides config file if provided."},
        {"name": "--run_name", "type": str, "help": "Name of the run. Overrides config file if provided."},
        {"name": "--load_run", "type": str, "help": "Name of the run to load when resume=True. If -1: will load the last run. Overrides config file if provided."},
        {"name": "--checkpoint", "type": int, "help": "Saved model checkpoint number. If -1: will load the last checkpoint. Overrides config file if provided."},

        {"name": "--headless", "action": "store_true", "default": False, "help": "Force display off at all times"},
        {"name": "--horovod", "action": "store_true", "default": False, "help": "Use horovod for multi-gpu training"},
        {"name": "--rl_device", "type": str, "default": "cuda:0", "help": 'Device used by the RL algorithm, (cpu, gpu, cuda:0, cuda:1 etc..)'},
        {"name": "--num_envs", "type": int, "help": "Number of environments to create. Overrides config file if provided."},
        {"name": "--seed", "type": int, "help": "Random seed. Overrides config file if provided."},
        {"name": "--max_iterations", "type": int, "help": "Maximum number of training iterations. Overrides config file if provided."},
        {"name": "--terrain", "type": str, "default": "climb",
         "help": 'Only for play'},
        {"name": "--wm_device", "type": str, "default": "None", "help": 'World model device. Overrides config file in dreamer/config.yaml if provided'},

    ]
    # parse arguments
    args = gymutil.parse_arguments(
        description="RL Policy",
        custom_parameters=custom_parameters)

    # name alignment: mirror gymutil's device fields onto the names the
    # rest of the codebase expects.
    args.sim_device_id = args.compute_device_id
    args.sim_device = args.sim_device_type
    if args.sim_device=='cuda':
        args.sim_device += f":{args.sim_device_id}"
    return args

def export_policy_as_jit(actor_critic, path):
    """Export the actor as TorchScript under ``path`` (LSTM-aware)."""
    if hasattr(actor_critic, 'memory_a'):
        # assumes LSTM: TODO add GRU
        exporter = PolicyExporterLSTM(actor_critic)
        exporter.export(path)
    else:
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, 'policy_1.pt')
        model = copy.deepcopy(actor_critic.actor).to('cpu')
        traced_script_module = torch.jit.script(model)
        traced_script_module.save(path)

class PolicyExporterLSTM(torch.nn.Module):
    """Wraps an LSTM actor so the recurrent state lives inside the scripted
    module (as buffers), making it callable step-by-step (e.g. from C++)."""

    def __init__(self, actor_critic):
        super().__init__()
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        self.memory = copy.deepcopy(actor_critic.memory_a.rnn)
        self.memory.cpu()
        # Persistent hidden/cell state (batch size 1) carried across forward calls.
        self.register_buffer(f'hidden_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))
        self.register_buffer(f'cell_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))

    def forward(self, x):
        # Run one LSTM step and persist the updated recurrent state.
        out, (h, c) = self.memory(x.unsqueeze(0), (self.hidden_state, self.cell_state))
        self.hidden_state[:] = h
        self.cell_state[:] = c
        return self.actor(out.squeeze(0))

    @torch.jit.export
    def reset_memory(self):
        """Zero the recurrent state (call at episode boundaries)."""
        self.hidden_state[:] = 0.
        self.cell_state[:] = 0.

    def export(self, path):
        """Script this module and save it as ``path``/policy_lstm_1.pt."""
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, 'policy_lstm_1.pt')
        self.to('cpu')
        traced_script_module = torch.jit.script(self)
        traced_script_module.save(path)
class Logger:
    """Accumulates per-step robot state and per-episode reward statistics,
    and renders diagnostic plots in a separate process."""

    def __init__(self, dt):
        # dt: simulation step length [s]; used to build the plot time axis.
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        self.plot_process = None

    def log_state(self, key, value):
        """Append one sample for a single state channel."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        """Append one sample per channel from a mapping of channel -> value."""
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Accumulate episode-weighted sums for every entry whose key contains 'rew'."""
        for key, value in dict.items():
            if 'rew' in key:
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Drop all logged state and reward data."""
        self.state_log.clear()
        self.rew_log.clear()

    def plot_states(self):
        """Render the state plots in a background process (non-blocking)."""
        self.plot_process = Process(target=self._plot)
        self.plot_process.start()

    def _plot(self):
        """Draw the 3x3 diagnostic figure from whatever has been logged."""
        fig, axs = plt.subplots(3, 3)
        # Build the time axis from the first logged channel's length.
        for _key, series in self.state_log.items():
            time = np.linspace(0, len(series) * self.dt, len(series))
            break
        data = self.state_log
        # DOF position: measured vs target
        ax = axs[1, 0]
        if data["dof_pos"]:
            ax.plot(time, data["dof_pos"], label='measured')
        if data["dof_pos_target"]:
            ax.plot(time, data["dof_pos_target"], label='target')
        ax.set(xlabel='time [s]', ylabel='Position [rad]', title='DOF Position')
        ax.legend()
        # DOF velocity: measured vs target
        ax = axs[1, 1]
        if data["dof_vel"]:
            ax.plot(time, data["dof_vel"], label='measured')
        if data["dof_vel_target"]:
            ax.plot(time, data["dof_vel_target"], label='target')
        ax.set(xlabel='time [s]', ylabel='Velocity [rad/s]', title='Joint Velocity')
        ax.legend()
        # Base linear velocity x: measured vs commanded
        ax = axs[0, 0]
        if data["base_vel_x"]:
            ax.plot(time, data["base_vel_x"], label='measured')
        if data["command_x"]:
            ax.plot(time, data["command_x"], label='commanded')
        ax.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity x')
        ax.legend()
        # Base linear velocity y: measured vs commanded
        ax = axs[0, 1]
        if data["base_vel_y"]:
            ax.plot(time, data["base_vel_y"], label='measured')
        if data["command_y"]:
            ax.plot(time, data["command_y"], label='commanded')
        ax.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity y')
        ax.legend()
        # Base yaw rate: measured vs commanded
        ax = axs[0, 2]
        if data["base_vel_yaw"]:
            ax.plot(time, data["base_vel_yaw"], label='measured')
        if data["command_yaw"]:
            ax.plot(time, data["command_yaw"], label='commanded')
        ax.set(xlabel='time [s]', ylabel='base ang vel [rad/s]', title='Base velocity yaw')
        ax.legend()
        # Base vertical velocity
        ax = axs[1, 2]
        if data["base_vel_z"]:
            ax.plot(time, data["base_vel_z"], label='measured')
        ax.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity z')
        ax.legend()
        # Vertical contact forces, one trace per foot
        ax = axs[2, 0]
        if data["contact_forces_z"]:
            forces = np.array(data["contact_forces_z"])
            for i in range(forces.shape[1]):
                ax.plot(time, forces[:, i], label=f'force {i}')
        ax.set(xlabel='time [s]', ylabel='Forces z [N]', title='Vertical Contact forces')
        ax.legend()
        # Torque-vs-velocity scatter
        ax = axs[2, 1]
        if data["dof_vel"] and data["dof_torque"]:
            ax.plot(data["dof_vel"], data["dof_torque"], 'x', label='measured')
        ax.set(xlabel='Joint vel [rad/s]', ylabel='Joint Torque [Nm]', title='Torque/velocity curves')
        ax.legend()
        # Torque over time
        ax = axs[2, 2]
        if data["dof_torque"]:
            ax.plot(time, data["dof_torque"], label='measured')
        ax.set(xlabel='time [s]', ylabel='Joint Torque [Nm]', title='Torque')
        ax.legend()
        plt.show()

    def print_rewards(self):
        """Print each reward term's mean per second, normalized by episode count."""
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            mean = np.sum(np.array(values)) / self.num_episodes
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        # Terminate any live plotting process when the logger is collected.
        if self.plot_process is not None:
            self.plot_process.kill()
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
from torch import Tensor
import numpy as np
from isaacgym.torch_utils import quat_apply, normalize
from typing import Tuple

# @ torch.jit.script
def quat_apply_yaw(quat, vec):
    """Rotate *vec* by only the yaw component of quaternion *quat*.

    The x/y quaternion components are zeroed and the quaternion is
    re-normalized, leaving a pure z-axis (yaw) rotation which is then
    applied to *vec*.
    """
    yaw_only = quat.clone().view(-1, 4)
    yaw_only[:, :2] = 0.  # drop roll/pitch components
    return quat_apply(normalize(yaw_only), vec)

# @ torch.jit.script
def wrap_to_pi(angles):
    """Wrap *angles* (radians) into (-pi, pi].

    NOTE: operates in place on *angles* (callers see the mutation).
    """
    two_pi = 2 * np.pi
    angles %= two_pi                     # -> [0, 2*pi)
    angles -= two_pi * (angles > np.pi)  # -> (-pi, pi]
    return angles

# @ torch.jit.script
def torch_rand_sqrt_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Sample from [lower, upper] with density concentrated near the bounds.

    A uniform sample in [-1, 1] is pushed outward by a signed square root,
    shifted back to [0, 1], then rescaled to the requested interval.
    """
    sample = 2 * torch.rand(*shape, device=device) - 1
    sample = torch.where(sample < 0., -torch.sqrt(-sample), torch.sqrt(sample))
    sample = (sample + 1.) / 2.
    return (upper - lower) * sample + lower
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

# This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
# All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

import os
from datetime import datetime
from typing import Tuple
import torch
import numpy as np

from rsl_rl.env import VecEnv
from rsl_rl.runners import OnPolicyRunner, WMPRunner

from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
from .helpers import get_args, update_cfg_from_args, class_to_dict, get_load_path, set_seed, parse_sim_params
from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO

class TaskRegistry():
    """Global registry mapping task names to env classes and their configs."""

    def __init__(self):
        self.task_classes = {}  # name -> VecEnv subclass
        self.env_cfgs = {}      # name -> LeggedRobotCfg
        self.train_cfgs = {}    # name -> LeggedRobotCfgPPO

    def register(self, name: str, task_class: VecEnv, env_cfg: LeggedRobotCfg, train_cfg: LeggedRobotCfgPPO):
        """Register a task class together with its env and train configs under *name*."""
        self.task_classes[name] = task_class
        self.env_cfgs[name] = env_cfg
        self.train_cfgs[name] = train_cfg

    def get_task_class(self, name: str) -> VecEnv:
        """Return the registered task class for *name* (raises KeyError if absent)."""
        return self.task_classes[name]

    def get_cfgs(self, name) -> Tuple[LeggedRobotCfg, LeggedRobotCfgPPO]:
        """Return (env_cfg, train_cfg) for *name*, copying the train seed into the env cfg."""
        train_cfg = self.train_cfgs[name]
        env_cfg = self.env_cfgs[name]
        # copy seed
        env_cfg.seed = train_cfg.seed
        return env_cfg, train_cfg

    def _resolve_train_cfg(self, name, train_cfg, args):
        """Load train_cfg from the registry when not provided and apply CLI overrides.

        Raises:
            ValueError: if neither 'name' nor 'train_cfg' is provided.
        """
        if train_cfg is None:
            if name is None:
                raise ValueError("Either 'name' or 'train_cfg' must be not None")
            # load config files
            _, train_cfg = self.get_cfgs(name)
        else:
            if name is not None:
                print(f"'train_cfg' provided -> Ignoring 'name={name}'")
        # override cfg from args (if specified)
        _, train_cfg = update_cfg_from_args(None, train_cfg, args)
        return train_cfg

    def _resolve_log_dir(self, log_root, train_cfg):
        """Return (log_root, log_dir); log_dir is None when logging is disabled."""
        if log_root == "default":
            log_root = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name)
            log_dir = os.path.join(log_root, train_cfg.runner.run_name)
        elif log_root is None:
            log_dir = None
        else:
            log_dir = os.path.join(log_root, train_cfg.runner.run_name)
        return log_root, log_dir

    def _maybe_resume(self, runner, log_root, train_cfg):
        """Load a previously trained checkpoint into *runner* when resume is requested."""
        if train_cfg.runner.resume:
            resume_path = get_load_path(log_root, load_run=train_cfg.runner.load_run,
                                        checkpoint=train_cfg.runner.checkpoint)
            print(f"Loading model from: {resume_path}")
            runner.load(resume_path)

    def make_env(self, name, args=None, env_cfg=None) -> Tuple[VecEnv, LeggedRobotCfg]:
        """ Creates an environment either from a registered name or from the provided config file.

        Args:
            name (string): Name of a registered env.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            env_cfg (Dict, optional): Environment config file used to override the registered config. Defaults to None.

        Raises:
            ValueError: Error if no registered env corresponds to 'name'

        Returns:
            isaacgym.VecTaskPython: The created environment
            Dict: the corresponding config file
        """
        # if no args passed get command line arguments
        if args is None:
            args = get_args()
        # check if there is a registered env with that name
        if name in self.task_classes:
            task_class = self.get_task_class(name)
        else:
            raise ValueError(f"Task with name: {name} was not registered")
        if env_cfg is None:
            # load config files
            env_cfg, _ = self.get_cfgs(name)
        # override cfg from args (if specified)
        env_cfg, _ = update_cfg_from_args(env_cfg, None, args)
        set_seed(env_cfg.seed)
        # parse sim params (convert to dict first)
        sim_params = {"sim": class_to_dict(env_cfg.sim)}
        sim_params = parse_sim_params(args, sim_params)
        env = task_class(cfg=env_cfg,
                         sim_params=sim_params,
                         physics_engine=args.physics_engine,
                         sim_device=args.sim_device,
                         headless=args.headless)
        return env, env_cfg

    def make_alg_runner(self, env, name=None, args=None, train_cfg=None, log_root="default") -> Tuple[OnPolicyRunner, LeggedRobotCfgPPO]:
        """ Creates the training algorithm either from a registered name or from the provided config file.

        Args:
            env (isaacgym.VecTaskPython): The environment to train (TODO: remove from within the algorithm)
            name (string, optional): Name of a registered env. If None, the config file will be used instead. Defaults to None.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            train_cfg (Dict, optional): Training config file. If None 'name' will be used to get the config file. Defaults to None.
            log_root (str, optional): Logging directory for Tensorboard. Set to 'None' to avoid logging (at test time for example).
                                      Logs will be saved in <log_root>/<run_name>. Defaults to "default"=<LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>.

        Raises:
            ValueError: Error if neither 'name' or 'train_cfg' are provided
            Warning: If both 'name' or 'train_cfg' are provided 'name' is ignored

        Returns:
            PPO: The created algorithm
            Dict: the corresponding config file
        """
        # if no args passed get command line arguments
        if args is None:
            args = get_args()
        train_cfg = self._resolve_train_cfg(name, train_cfg, args)
        # save resume path (log_root) before creating a new log_dir
        log_root, log_dir = self._resolve_log_dir(log_root, train_cfg)
        # NOTE(review): eval() on a config-supplied class name — safe only for
        # trusted configs; consider an explicit name -> class mapping.
        runner_class = eval(train_cfg.runner_class_name)
        train_cfg_dict = class_to_dict(train_cfg)
        runner = runner_class(env, train_cfg_dict, log_dir, device=args.rl_device)
        self._maybe_resume(runner, log_root, train_cfg)
        return runner, train_cfg

    def make_wmp_runner(self, env, name=None, args=None, train_cfg=None, log_root="default") -> Tuple[WMPRunner, LeggedRobotCfgPPO]:
        """ Creates the WMP training algorithm either from a registered name or from the provided config file.

        Args:
            env (isaacgym.VecTaskPython): The environment to train (TODO: remove from within the algorithm)
            name (string, optional): Name of a registered env. If None, the config file will be used instead. Defaults to None.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            train_cfg (Dict, optional): Training config file. If None 'name' will be used to get the config file. Defaults to None.
            log_root (str, optional): Logging directory for Tensorboard. Set to 'None' to avoid logging (at test time for example).
                                      Logs will be saved in <log_root>/<run_name>. Defaults to "default"=<LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>.

        Raises:
            ValueError: Error if neither 'name' or 'train_cfg' are provided
            Warning: If both 'name' or 'train_cfg' are provided 'name' is ignored

        Returns:
            WMPRunner: The created algorithm
            Dict: the corresponding config file
        """
        # if no args passed get command line arguments
        if args is None:
            args = get_args()
        train_cfg = self._resolve_train_cfg(name, train_cfg, args)
        # save resume path (log_root) before creating a new log_dir
        log_root, log_dir = self._resolve_log_dir(log_root, train_cfg)
        train_cfg_dict = class_to_dict(train_cfg)
        # unlike make_alg_runner, the runner class is fixed to WMPRunner here
        runner = WMPRunner(env, train_cfg_dict, log_dir, device=args.rl_device)
        self._maybe_resume(runner, log_root, train_cfg)
        return runner, train_cfg


# make global task registry
task_registry = TaskRegistry()
14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | """ This file defines a mesh as a tuple of (vertices, triangles) 24 | All operations are based on numpy ndarray 25 | - vertices: np ndarray of shape (n, 3) np.float32 26 | - triangles: np ndarray of shape (n_, 3) np.uint32 27 | """ 28 | import numpy as np 29 | 30 | def box_trimesh( 31 | size, # float [3] for x, y, z axis length (in meter) under box frame 32 | center_position, # float [3] position (in meter) in world frame 33 | rpy= np.zeros(3), # euler angle (in rad) not implemented yet. 
34 | ): 35 | if not (rpy == 0).all(): 36 | raise NotImplementedError("Only axis-aligned box triangle mesh is implemented") 37 | 38 | vertices = np.empty((8, 3), dtype= np.float32) 39 | vertices[:] = center_position 40 | vertices[[0, 4, 2, 6], 0] -= size[0] / 2 41 | vertices[[1, 5, 3, 7], 0] += size[0] / 2 42 | vertices[[0, 1, 2, 3], 1] -= size[1] / 2 43 | vertices[[4, 5, 6, 7], 1] += size[1] / 2 44 | vertices[[2, 3, 6, 7], 2] -= size[2] / 2 45 | vertices[[0, 1, 4, 5], 2] += size[2] / 2 46 | 47 | triangles = -np.ones((12, 3), dtype= np.uint32) 48 | triangles[0] = [0, 2, 1] # 49 | triangles[1] = [1, 2, 3] 50 | triangles[2] = [0, 4, 2] # 51 | triangles[3] = [2, 4, 6] 52 | triangles[4] = [4, 5, 6] # 53 | triangles[5] = [5, 7, 6] 54 | triangles[6] = [1, 3, 5] # 55 | triangles[7] = [3, 7, 5] 56 | triangles[8] = [0, 1, 4] # 57 | triangles[9] = [1, 5, 4] 58 | triangles[10]= [2, 6, 3] # 59 | triangles[11]= [3, 6, 7] 60 | 61 | return vertices, triangles 62 | 63 | def combine_trimeshes(*trimeshes): 64 | if len(trimeshes) > 2: 65 | return combine_trimeshes( 66 | trimeshes[0], 67 | combine_trimeshes(trimeshes[1:]) 68 | ) 69 | 70 | # only two trimesh to combine 71 | trimesh_0, trimesh_1 = trimeshes 72 | if trimesh_0[1].shape[0] < trimesh_1[1].shape[0]: 73 | trimesh_0, trimesh_1 = trimesh_1, trimesh_0 74 | 75 | trimesh_1 = (trimesh_1[0], trimesh_1[1] + trimesh_0[0].shape[0]) 76 | vertices = np.concatenate((trimesh_0[0], trimesh_1[0]), axis= 0) 77 | triangles = np.concatenate((trimesh_0[1], trimesh_1[1]), axis= 0) 78 | 79 | return vertices, triangles 80 | 81 | def move_trimesh(trimesh, move: np.ndarray): 82 | """ inplace operation """ 83 | trimesh[0] += move 84 | -------------------------------------------------------------------------------- /licenses/dependencies/matplotlib_license.txt: -------------------------------------------------------------------------------- 1 | 1. 
This LICENSE AGREEMENT is between the Matplotlib Development Team ("MDT"), and the Individual or Organization ("Licensee") accessing and otherwise using matplotlib software in source or binary form and its associated documentation. 2 | 3 | 2. Subject to the terms and conditions of this License Agreement, MDT hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use matplotlib 3.4.3 alone or in any derivative version, provided, however, that MDT's License Agreement and MDT's notice of copyright, i.e., "Copyright (c) 2012-2013 Matplotlib Development Team; All Rights Reserved" are retained in matplotlib 3.4.3 alone or in any derivative version prepared by Licensee. 4 | 5 | 3. In the event Licensee prepares a derivative work that is based on or incorporates matplotlib 3.4.3 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to matplotlib 3.4.3. 6 | 7 | 4. MDT is making matplotlib 3.4.3 available to Licensee on an "AS IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 3.4.3 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 8 | 9 | 5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 3.4.3 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING MATPLOTLIB 3.4.3, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 10 | 11 | 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 12 | 13 | 7. 
Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between MDT and Licensee. This License Agreement does not grant permission to use MDT trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 14 | 15 | 8. By copying, installing or otherwise using matplotlib 3.4.3, Licensee agrees to be bound by the terms and conditions of this License Agreement. -------------------------------------------------------------------------------- /licenses/subcomponents/dreamerv3-torch_license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 NM512 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /licenses/subcomponents/leggedgym_license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, ETH Zurich, Nikita Rudin 2 | Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors 16 | may be used to endorse or promote products derived from this software without 17 | specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | See licenses/assets for license information for assets included in this repository. 
31 | See licenses/dependencies for license information of dependencies of this package. 32 | -------------------------------------------------------------------------------- /licenses/subcomponents/parkour_license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ziwen Zhuang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | cachetools==4.2.4 3 | certifi==2021.10.8 4 | charset-normalizer==2.0.12 5 | cycler==0.11.0 6 | fastrlock==0.8 7 | fonttools==4.32.0 8 | google-auth==1.35.0 9 | google-auth-oauthlib==0.4.6 10 | grpcio==1.44.0 11 | idna==3.3 12 | imageio==2.17.0 13 | importlib-metadata==4.11.3 14 | joblib==1.1.0 15 | kiwisolver==1.4.2 16 | Markdown==3.3.6 17 | matplotlib==3.5.1 18 | ninja==1.11.1 19 | numpy==1.22.3 20 | oauthlib==3.2.0 21 | opencv-python==4.6.0.66 22 | packaging==21.3 23 | Pillow==9.1.0 24 | protobuf==3.20.0 25 | pyasn1==0.4.8 26 | pyasn1-modules==0.2.8 27 | pybullet==3.2.5 28 | pyparsing==3.0.8 29 | python-dateutil==2.8.2 30 | PyYAML==6.0 31 | requests==2.27.1 32 | requests-oauthlib==1.3.1 33 | rsa==4.8 34 | scikit-learn==1.1.2 35 | scipy==1.8.0 36 | six==1.16.0 37 | sklearn==0.0 38 | tensorboard==2.8.0 39 | tensorboard-data-server==0.6.1 40 | tensorboard-plugin-wit==1.8.1 41 | threadpoolctl==3.1.0 42 | trimesh==3.20.1 43 | typing_extensions==4.2.0 44 | urllib3==1.26.9 45 | Werkzeug==2.1.1 46 | zipp==3.8.0 47 | -------------------------------------------------------------------------------- /resources/robots/a1/meshes/trunk_A1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/WMP/c232c115ada4517453ebded5019078ba055456de/resources/robots/a1/meshes/trunk_A1.png -------------------------------------------------------------------------------- /rsl_rl/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin -------------------------------------------------------------------------------- /rsl_rl/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .ppo import PPO 32 | from .amp_ppo import AMPPPO -------------------------------------------------------------------------------- /rsl_rl/algorithms/amp_discriminator.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021 ETH Zurich, Nikita Rudin


import torch
import torch.nn as nn
import torch.utils.data
from torch import autograd

from rsl_rl.utils import utils


class AMPDiscriminator(nn.Module):
    """Adversarial Motion Prior discriminator.

    Scores (state, next_state) transition pairs with an MLP trunk plus a
    scalar head; the score is converted into a style reward that can be
    combined with the task reward.
    """

    def __init__(
            self, input_dim, amp_reward_coef, hidden_layer_sizes, device, task_reward_lerp=0.0):
        super(AMPDiscriminator, self).__init__()

        self.device = device
        self.input_dim = input_dim

        self.amp_reward_coef = amp_reward_coef
        # build the trunk: one Linear + ReLU pair per requested hidden width
        layers = []
        prev_width = input_dim
        for width in hidden_layer_sizes:
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.ReLU())
            prev_width = width
        self.trunk = nn.Sequential(*layers).to(device)
        # scalar logit head on top of the last hidden layer
        self.amp_linear = nn.Linear(hidden_layer_sizes[-1], 1).to(device)

        self.trunk.train()
        self.amp_linear.train()

        self.task_reward_lerp = task_reward_lerp

    def forward(self, x):
        """Return the raw discriminator logit for a batch of transitions."""
        return self.amp_linear(self.trunk(x))

    def compute_grad_pen(self,
                         expert_state,
                         expert_next_state,
                         lambda_=10):
        """Gradient penalty that pushes the gradient norm at expert data toward 0."""
        expert_data = torch.cat([expert_state, expert_next_state], dim=-1)
        expert_data.requires_grad = True

        disc = self.amp_linear(self.trunk(expert_data))
        grad = autograd.grad(
            outputs=disc,
            inputs=expert_data,
            grad_outputs=torch.ones(disc.size(), device=disc.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]

        # Enforce that the grad norm approaches 0.
        return lambda_ * grad.norm(2, dim=1).pow(2).mean()

    def predict_amp_reward(
            self, state, next_state, task_reward, normalizer=None):
        """Compute the clamped, scaled style reward for a transition batch.

        Returns (reward, logits); eval/train mode is toggled around the
        no-grad forward pass.
        """
        with torch.no_grad():
            self.eval()
            if normalizer is not None:
                state = normalizer.normalize_torch(state, self.device)
                next_state = normalizer.normalize_torch(next_state, self.device)

            d = self.amp_linear(self.trunk(torch.cat([state, next_state], dim=-1)))
            reward = self.amp_reward_coef * torch.clamp(1 - (1/4) * torch.square(d - 1), min=0)
            if self.task_reward_lerp > 0:
                reward = self._lerp_reward(reward, task_reward.unsqueeze(-1))
            self.train()
        return reward.squeeze(), d

    def _lerp_reward(self, disc_r, task_r):
        # NOTE: the lerp is intentionally disabled; style and task rewards
        # are simply summed.
        # r = (1.0 - self.task_reward_lerp) * disc_r + self.task_reward_lerp * task_r
        return disc_r + task_r
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import torch
import torch.nn as nn
import torch.optim as optim

from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage

class PPO:
    """Proximal Policy Optimization (clipped-surrogate variant).

    Collects transitions into a RolloutStorage, then runs several epochs of
    minibatch updates with a clipped policy loss, an (optionally clipped)
    value loss, and an entropy bonus.  Supports a KL-driven adaptive
    learning rate when schedule == 'adaptive'.
    """
    actor_critic: ActorCritic

    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 ):
        """
        Args:
            actor_critic: ActorCritic-compatible policy/value network.
            num_learning_epochs: passes over the rollout per update().
            num_mini_batches: minibatches per epoch.
            clip_param: PPO clip range (used for policy and value clipping).
            gamma: discount factor.
            lam: GAE lambda.
            value_loss_coef: weight of the value loss.
            entropy_coef: weight of the entropy bonus.
            learning_rate: initial Adam step size.
            max_grad_norm: global gradient-norm clip.
            use_clipped_value_loss: use PPO-style clipped value loss.
            schedule: 'fixed' or 'adaptive' (KL-driven) learning rate.
            desired_kl: target KL for the adaptive schedule.
            device: torch device for networks and rollout storage.
        """
        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None  # initialized later via init_storage()
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, history_dim, wm_feature_dim):
        """Allocate the rollout storage (must be called before act/update)."""
        self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, history_dim=history_dim,
                                      wm_feature_dim=wm_feature_dim, device=self.device)

    def test_mode(self):
        """Switch the policy to evaluation mode."""
        self.actor_critic.test()

    def train_mode(self):
        """Switch the policy to training mode."""
        self.actor_critic.train()

    def act(self, obs, critic_obs):
        """Sample actions and cache everything needed for the PPO update.

        Records actions, values, log-probs and the action distribution
        parameters in the current transition; obs/critic_obs are stored
        before env.step() mutates the buffers.
        """
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()
        # need to record obs and critic_obs before env.step()
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        return self.transition.actions

    def process_env_step(self, rewards, dones, infos):
        """Store a completed env step and reset recurrent state on done envs."""
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        # Bootstrapping on time outs: episodes cut by the time limit get the
        # discounted value estimate added so they are not treated as failures.
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

    def compute_returns(self, last_critic_obs):
        """Compute GAE returns in storage using a bootstrap from the last obs."""
        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update(self):
        """Run the PPO optimization epochs over the stored rollout.

        Returns:
            (mean_value_loss, mean_surrogate_loss) averaged over all
            minibatch updates; storage is cleared afterwards.
        """
        mean_value_loss = 0
        mean_surrogate_loss = 0
        if self.actor_critic.is_recurrent:
            # (sic) method name as defined by RolloutStorage.
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
                old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:

            # Re-evaluate the current policy on the stored observations.
            self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
            value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            mu_batch = self.actor_critic.action_mean
            sigma_batch = self.actor_critic.action_std
            entropy_batch = self.actor_critic.entropy

            # KL-adaptive learning rate (was `!= None`; `is not None` is the
            # correct identity test per PEP 8).
            if self.desired_kl is not None and self.schedule == 'adaptive':
                with torch.inference_mode():
                    # Closed-form KL between two diagonal Gaussians.
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, dim=-1)
                    kl_mean = torch.mean(kl)

                    if kl_mean > self.desired_kl * 2.0:
                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = self.learning_rate

            # Surrogate loss (clipped policy objective).
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                               1.0 + self.clip_param)
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # Value function loss
            if self.use_clipped_value_loss:
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
                                                                                                self.clip_param)
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

            # Gradient step
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
            self.optimizer.step()

            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss
"""Utility functions for processing motion clips."""

import os
import inspect

# Make the repository root importable when this file is run directly.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(os.path.dirname(currentdir))
os.sys.path.insert(0, parentdir)

import numpy as np


def standardize_quaternion(q):
    """Returns a quaternion where q.w >= 0 to remove redundancy due to q = -q.

    Args:
      q: A quaternion to be standardized.

    Returns:
      A quaternion with q.w >= 0.
    """
    if q[-1] < 0:
        q = -q
    return q


def normalize_rotation_angle(theta):
    """Returns a rotation angle normalized between [-pi, pi].

    Fixes the original folding logic, which unconditionally shifted the
    fmod result by 2*pi and therefore produced out-of-range results for
    inputs with |theta| > 2*pi (e.g. 2.5*pi -> -1.5*pi instead of 0.5*pi).

    Args:
      theta: angle of rotation (radians).

    Returns:
      An angle of rotation normalized between [-pi, pi].
    """
    # fmod keeps the sign of theta, so the result lies in (-2*pi, 2*pi).
    norm_theta = np.fmod(theta, 2 * np.pi)
    if norm_theta > np.pi:
        norm_theta -= 2 * np.pi
    elif norm_theta < -np.pi:
        norm_theta += 2 * np.pi
    return norm_theta


def calc_heading(q):
    """Returns the heading of a rotation q, specified as a quaternion.

    The heading represents the rotational component of q along the vertical
    axis (z axis).

    Args:
      q: A quaternion that the heading is to be computed from.

    Returns:
      An angle representing the rotation about the z axis.
    """
    # Imported lazily so this module can be loaded without pybullet
    # (pose3d pulls in pybullet_utils at import time).
    from rsl_rl.datasets import pose3d
    ref_dir = np.array([1, 0, 0])
    rot_dir = pose3d.QuaternionRotatePoint(ref_dir, q)
    heading = np.arctan2(rot_dir[1], rot_dir[0])
    return heading


def calc_heading_rot(q):
    """Return a quaternion for the heading rotation of q about the z axis.

    Args:
      q: A quaternion that the heading is to be computed from.

    Returns:
      A quaternion representing the rotation about the z axis.
    """
    from pybullet_utils import transformations  # lazy: optional dependency
    heading = calc_heading(q)
    q_heading = transformations.quaternion_about_axis(heading, [0, 0, 1])
    return q_heading
VECTOR3_0 = np.zeros(3, dtype=np.float64)
VECTOR3_1 = np.ones(3, dtype=np.float64)
VECTOR3_X = np.array([1, 0, 0], dtype=np.float64)
VECTOR3_Y = np.array([0, 1, 0], dtype=np.float64)
VECTOR3_Z = np.array([0, 0, 1], dtype=np.float64)

# Multiplicative identity 1.0 + 0i + 0j + 0k, stored as [x, y, z, w];
# as a rotation it is the identity.
QUATERNION_IDENTITY = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float64)


def Vector3RandomNormal(sigma, mu=VECTOR3_0):
    """Sample a 3D vector with i.i.d. normally distributed components.

    Args:
      sigma: Standard deviation shared by all three components.
      mu: Per-component mean.

    Returns:
      A 3D vector in a numpy array.
    """
    return np.random.normal(scale=sigma, size=3) + mu


def Vector3RandomUniform(low=VECTOR3_0, high=VECTOR3_1):
    """Sample a 3D vector uniformly from the axis-aligned box [low, high].

    Args:
      low: The min-value corner of the box.
      high: The max-value corner of the box.

    Returns:
      A 3D vector in a numpy array.
    """
    # One draw per axis, in x/y/z order (matches the original RNG sequence).
    coords = [np.random.uniform(low=low[axis], high=high[axis]) for axis in range(3)]
    return np.array(coords)


def Vector3RandomUnit():
    """Sample a unit-length 3D vector uniformly over the sphere surface.

    Returns:
      A normalized 3D vector in a numpy array.
    """
    lon = np.random.uniform(low=-math.pi, high=math.pi)
    sin_lat = np.random.uniform(low=-1.0, high=1.0)
    cos_lat = math.sqrt(1.0 - sin_lat * sin_lat)
    return np.array(
        [math.cos(lon) * cos_lat, math.sin(lon) * cos_lat, sin_lat],
        dtype=np.float64)


def QuaternionNormalize(q):
    """Scale quaternion q to unit length.

    Args:
      q: A quaternion to be normalized.

    Raises:
      ValueError: If q has (numerically) zero magnitude.

    Returns:
      A quaternion with magnitude 1 in a numpy array [x, y, z, w].
    """
    magnitude = np.linalg.norm(q)
    if np.isclose(magnitude, 0.0):
        raise ValueError(
            'Quaternion may not be zero in QuaternionNormalize: |q| = %f, q = %s' %
            (magnitude, q))
    return q / magnitude


def QuaternionFromAxisAngle(axis, angle):
    """Build the unit quaternion sin(angle/2) * axis_hat + cos(angle/2).

    Args:
      axis: Axis of rotation, a 3D vector in a numpy array.
      angle: The angle of rotation (radians).

    Raises:
      ValueError: If axis is not a normalizable 3D vector.

    Returns:
      A unit quaternion in a numpy array.
    """
    if len(axis) != 3:
        raise ValueError('Axis vector should have three components: %s' % axis)
    axis_norm = np.linalg.norm(axis)
    if np.isclose(axis_norm, 0.0):
        raise ValueError('Axis vector may not have zero length: |v| = %f, v = %s' %
                         (axis_norm, axis))
    half_angle = angle * 0.5
    scale = math.sin(half_angle) / axis_norm
    vec_part = np.asarray(axis, dtype=np.float64) * scale
    return np.append(vec_part, math.cos(half_angle))


def QuaternionToAxisAngle(quat, default_axis=VECTOR3_Z, direction_axis=None):
    """Decompose a unit quaternion into its rotation axis and angle.

    Args:
      quat: Unit quaternion in a numpy array [x, y, z, w].
      default_axis: Axis reported when the rotation is numerically zero; any
        axis is equivalent for a zero-angle rotation, so this avoids raising
        on tiny rotations.
      direction_axis: If given, the returned axis is flipped so its inner
        product with this vector is non-negative (disambiguates direction).

    Raises:
      ValueError: If quat is not a normalized 4-vector, or default_axis is
        invalid when it is needed.

    Returns:
      axis: Axis of rotation.
      angle: Angle in radians.
    """
    if len(quat) != 4:
        raise ValueError(
            'Quaternion should have four components [x, y, z, w]: %s' % quat)
    if not np.isclose(1.0, np.linalg.norm(quat)):
        raise ValueError('Quaternion should have unit length: |q| = %f, q = %s' %
                         (np.linalg.norm(quat), quat))
    axis = quat[:3].copy()
    sin_half_angle = np.linalg.norm(axis)
    if sin_half_angle < 1e-8:
        # Near-zero rotation: fall back to the caller-supplied default axis.
        axis = default_axis
        if len(default_axis) != 3:
            raise ValueError('Axis vector should have three components: %s' % axis)
        if not np.isclose(np.linalg.norm(axis), 1.0):
            raise ValueError('Axis vector should have unit length: |v| = %f, v = %s' %
                             (np.linalg.norm(axis), axis))
    else:
        axis /= sin_half_angle
    if direction_axis is not None and np.inner(axis, direction_axis) < 0:
        sin_half_angle = -sin_half_angle
        axis = -axis
    angle = 2 * math.atan2(sin_half_angle, quat[3])
    return axis, angle


def QuaternionRandomRotation(max_angle=math.pi):
    """Sample a random rotation: uniform axis, angle uniform in [0, max_angle].

    With the default max_angle the rotation angle covers the full range.

    Args:
      max_angle: The maximum angle of rotation (radians).

    Returns:
      A unit quaternion in a numpy array.
    """
    # Angle is drawn before the axis to match the original RNG sequence.
    angle = np.random.uniform(low=0, high=max_angle)
    return QuaternionFromAxisAngle(Vector3RandomUnit(), angle)


def QuaternionRotatePoint(point, quat):
    """Rotate a point by a quaternion.

    Computes q * p * q^-1 with quaternion multiplications, without
    constructing a rotation matrix.

    Args:
      point: The point to be rotated.
      quat: The rotation represented as a quaternion [x, y, z, w].

    Returns:
      A 3D vector in a numpy array.
    """
    p = np.array([point[0], point[1], point[2], 0.0])
    q_inv = transformations.quaternion_inverse(quat)
    rotated = transformations.quaternion_multiply(
        transformations.quaternion_multiply(quat, p), q_inv)
    return rotated[:3]


def IsRotationMatrix(m):
    """Return True if the top-left 3x3 submatrix of m is a rotation.

    Args:
      m: A transformation matrix.

    Raises:
      ValueError: If input is not a matrix of size at least 3x3.

    Returns:
      True if the 3x3 submatrix is orthogonal.
    """
    if len(m.shape) != 2 or m.shape[0] < 3 or m.shape[1] < 3:
        raise ValueError('Matrix should be 3x3 or 4x4: %s\n %s' % (m.shape, m))
    rot = m[:3, :3]
    product = rot @ rot.T
    return np.isclose(product, np.identity(3), atol=1e-4).all()
261 | # robot_pose_tool.quaternion = transformations.quaternion_multiply( 262 | # RotationBetween( 263 | # robot_pose_tool.matrix4x4[0:3, 0:3].dot(np.array([0, 0, 1])), 264 | # np.array([0.0, 0.0, -1.0])), robot_pose_tool.quaternion) 265 | # return robot_pose_tool 266 | 267 | # def RotationBetween(a_translation_b, a_translation_c): 268 | # """Computes the rotation from one vector to another. 269 | 270 | # The computed rotation has the property that: 271 | 272 | # a_translation_c = a_rotation_b_to_c * a_translation_b 273 | 274 | # Args: 275 | # a_translation_b: vec3, vector to rotate from 276 | # a_translation_c: vec3, vector to rotate to 277 | 278 | # Returns: 279 | # a_rotation_b_to_c: new Orientation 280 | # """ 281 | # rotation = rotation3.Rotation3.rotation_between( 282 | # a_translation_b, a_translation_c, err_msg='RotationBetween') 283 | # return rotation.quaternion.xyzw 284 | -------------------------------------------------------------------------------- /rsl_rl/env/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 
17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .vec_env import VecEnv -------------------------------------------------------------------------------- /rsl_rl/env/vec_env.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. 
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from abc import ABC, abstractmethod
import torch
from typing import Tuple, Union

# minimal interface of the environment
class VecEnv(ABC):
    """Abstract base class declaring the minimal vectorized-environment
    interface that the training runners and algorithms program against.

    All buffers are batched over the parallel environment instances;
    shapes below are assumptions from the attribute names — confirm against
    a concrete implementation (e.g. legged_robot.py).
    """
    num_envs: int  # number of parallel environment instances
    num_obs: int  # size of the actor observation vector
    num_privileged_obs: int  # size of the privileged (critic) observation vector
    num_actions: int  # size of the action vector
    max_episode_length: int  # episode length cap, in steps
    privileged_obs_buf: torch.Tensor  # presumably (num_envs, num_privileged_obs) — TODO confirm
    obs_buf: torch.Tensor  # presumably (num_envs, num_obs) — TODO confirm
    rew_buf: torch.Tensor  # per-env step rewards
    reset_buf: torch.Tensor  # per-env done/reset flags
    episode_length_buf: torch.Tensor # current episode duration
    extras: dict  # auxiliary info passed to the algorithm (e.g. time-outs)
    device: torch.device  # device on which all buffers live
    @abstractmethod
    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        """Advance all environments one step with `actions`.

        Per the return annotation: (obs, privileged_obs or None, rewards,
        dones, extras); exact semantics are defined by the implementation.
        """
        pass
    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        """Reset the environments selected by `env_ids`."""
        pass
    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        """Return the current actor observation buffer."""
        pass
    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        """Return the privileged observation buffer, or None if unused."""
        pass
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .actor_critic import ActorCritic 32 | from .actor_critic_wmp import ActorCriticWMP 33 | from .actor_critic_recurrent import ActorCriticRecurrent 34 | from .depth_predictor import DepthPredictor -------------------------------------------------------------------------------- /rsl_rl/modules/actor_critic.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn


def _build_mlp(input_dim, hidden_dims, output_dim, activation):
    """Return an MLP ``nn.Sequential``: input_dim -> hidden_dims... -> output_dim.

    The (stateless) activation module instance is shared between layers and
    no activation is applied after the final linear layer, matching the
    original inline construction loops.
    """
    layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
    for i in range(len(hidden_dims) - 1):
        layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        layers.append(activation)
    layers.append(nn.Linear(hidden_dims[-1], output_dim))
    return nn.Sequential(*layers)


class ActorCritic(nn.Module):
    """MLP actor-critic with a state-independent diagonal Gaussian policy.

    The actor maps observations to action means; the critic maps (possibly
    privileged) observations to a scalar value estimate. The policy noise is
    a per-action standard deviation, either fixed or learned.

    Args:
        num_actor_obs: size of the actor observation vector.
        num_critic_obs: size of the critic observation vector.
        num_actions: action dimensionality.
        actor_hidden_dims / critic_hidden_dims: hidden-layer sizes of the MLPs.
        activation: activation name resolved by :func:`get_activation`.
        init_noise_std: initial per-action standard deviation.
        fixed_std: if True, ``self.std`` is a constant tensor; otherwise a
            learnable ``nn.Parameter``.
    """

    is_recurrent = False

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 fixed_std=False,
                 **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str(
                [key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        # Policy network: observations -> action mean.
        self.actor = _build_mlp(num_actor_obs, actor_hidden_dims, num_actions, activation)
        # Value network: (privileged) observations -> scalar value.
        self.critic = _build_mlp(num_critic_obs, critic_hidden_dims, 1, activation)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise: per-action std, constant or learnable.
        self.fixed_std = fixed_std
        std = init_noise_std * torch.ones(num_actions)
        # Fix: `torch.tensor(std)` on an existing tensor emits a UserWarning;
        # `clone().detach()` is the supported way to snapshot a tensor.
        self.std = std.clone().detach() if fixed_std else nn.Parameter(std)
        self.distribution = None
        # Fix: `set_default_validate_args` is a static method. The previous
        # `Normal.set_default_validate_args = False` only shadowed the method
        # with a bool and never actually disabled argument validation.
        Normal.set_default_validate_args(False)

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        """No recurrent state to reset for the feed-forward actor-critic."""
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recently updated action distribution."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Std of the most recently updated action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Per-sample entropy, summed over action dimensions."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """Refresh ``self.distribution`` from the actor output for `observations`."""
        mean = self.actor(observations)
        std = self.std.to(mean.device)
        # `mean * 0. + std` broadcasts the per-action std to the batch shape.
        self.distribution = Normal(mean, mean * 0. + std)

    def act(self, observations, **kwargs):
        """Sample a stochastic action; also updates ``self.distribution``."""
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        """Log-probability of `actions` under the current distribution."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        """Deterministic action (distribution mean) for deployment."""
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        """Return the critic's value estimate, shape (batch, 1)."""
        value = self.critic(critic_observations)
        return value


def get_activation(act_name):
    """Map an activation name to a module instance; returns None if unknown."""
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        # NOTE: torch has no CReLU module; fall back to ReLU as upstream did.
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    else:
        print("invalid activation function!")
        return None
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories


class ActorCriticRecurrent(ActorCritic):
    """Actor-critic whose observations are first embedded by an RNN.

    Two independent memories (`memory_a` for the actor, `memory_c` for the
    critic) map raw observations to `rnn_hidden_size`-dimensional features;
    the MLP heads inherited from ``ActorCritic`` then act on those features.
    """

    is_recurrent = True

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 rnn_type='lstm',
                 rnn_hidden_size=256,
                 rnn_num_layers=1,
                 init_noise_std=1.0,
                 **kwargs):
        if kwargs:
            print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)

        # The MLP heads consume RNN features, not raw observations, hence the
        # hidden size is passed as both obs dimensions.
        super().__init__(num_actor_obs=rnn_hidden_size,
                         num_critic_obs=rnn_hidden_size,
                         num_actions=num_actions,
                         actor_hidden_dims=actor_hidden_dims,
                         critic_hidden_dims=critic_hidden_dims,
                         activation=activation,
                         init_noise_std=init_noise_std)

        activation = get_activation(activation)  # kept for parity with upstream; heads are already built

        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        """Zero the recurrent state of environments flagged in `dones`."""
        for memory in (self.memory_a, self.memory_c):
            memory.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        features = self.memory_a(observations, masks, hidden_states)
        return super().act(features.squeeze(0))

    def act_inference(self, observations):
        features = self.memory_a(observations)
        return super().act_inference(features.squeeze(0))

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        features = self.memory_c(critic_observations, masks, hidden_states)
        return super().evaluate(features.squeeze(0))

    def get_hidden_states(self):
        """Return (actor, critic) hidden states for storage in the rollout."""
        return self.memory_a.hidden_states, self.memory_c.hidden_states


class Memory(torch.nn.Module):
    """Thin GRU/LSTM wrapper that owns its rollout-time hidden state."""

    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        rnn_class = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_class(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        if masks is None:
            # Inference mode (collection): advance the stored state one step.
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
            return out
        # Batch mode (policy update): replay trajectories from saved states.
        if hidden_states is None:
            raise ValueError("Hidden states not passed to memory module during policy update")
        out, _ = self.rnn(input, hidden_states)
        return unpad_trajectories(out, masks)

    def reset(self, dones=None):
        # For an LSTM the state is the (h, c) pair; for a GRU it is a single
        # tensor whose first axis is the layer index — iterating works for both.
        if self.hidden_states is None:
            return
        for state in self.hidden_states:
            state[..., dones, :] = 0.0
Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn


def _make_mlp(input_dim, hidden_dims, output_dim, activation):
    """Return an MLP ``nn.Sequential``: input_dim -> hidden_dims... -> output_dim.

    The (stateless) activation instance is shared between layers; no
    activation follows the final linear layer.
    """
    layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
    for i in range(len(hidden_dims) - 1):
        layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        layers.append(activation)
    layers.append(nn.Linear(hidden_dims[-1], output_dim))
    return nn.Sequential(*layers)


class ActorCriticWMP(nn.Module):
    """Actor-critic for world-model-based policy (WMP) training.

    The actor does not see raw observations directly; it consumes
      * a latent vector from ``history_encoder`` (proprioceptive history),
      * the 3-dim command sliced out of the observation, and
      * a compressed world-model feature from ``wm_feature_encoder``.
    The critic consumes the full privileged observation concatenated with its
    own world-model feature embedding.
    """

    is_recurrent = False

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 encoder_hidden_dims=[256, 128],
                 wm_encoder_hidden_dims = [64, 32],
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 fixed_std=False,
                 latent_dim = 32,
                 height_dim=187,
                 privileged_dim=3 + 24,
                 history_dim = 42*5,
                 wm_feature_dim = 1536,
                 wm_latent_dim=16,
                 **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str(
                [key for key in kwargs.keys()]))
        super(ActorCriticWMP, self).__init__()

        activation = get_activation(activation)

        self.latent_dim = latent_dim
        self.height_dim = height_dim
        self.privileged_dim = privileged_dim

        mlp_input_dim_a = latent_dim + 3 + wm_latent_dim  # latent vector + command + wm_latent
        mlp_input_dim_c = num_critic_obs + wm_latent_dim

        # History encoder: proprioceptive history -> latent vector.
        self.history_encoder = _make_mlp(history_dim, encoder_hidden_dims, latent_dim, activation)
        # World-model feature encoders (separate nets for actor and critic).
        self.wm_feature_encoder = _make_mlp(wm_feature_dim, wm_encoder_hidden_dims, wm_latent_dim, activation)
        self.critic_wm_feature_encoder = _make_mlp(wm_feature_dim, wm_encoder_hidden_dims, wm_latent_dim, activation)
        # Policy and value heads.
        self.actor = _make_mlp(mlp_input_dim_a, actor_hidden_dims, num_actions, activation)
        self.critic = _make_mlp(mlp_input_dim_c, critic_hidden_dims, 1, activation)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise: per-action std, constant or learnable.
        self.fixed_std = fixed_std
        std = init_noise_std * torch.ones(num_actions)
        # Fix: `torch.tensor(std)` on an existing tensor emits a UserWarning;
        # `clone().detach()` is the supported way to snapshot a tensor.
        self.std = std.clone().detach() if fixed_std else nn.Parameter(std)
        self.distribution = None
        # Fix: `set_default_validate_args` is a static method. Assigning False
        # to it (as before) shadowed the method and never disabled validation.
        Normal.set_default_validate_args(False)

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        """No recurrent state to reset."""
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recently updated action distribution."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Std of the most recently updated action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Per-sample entropy, summed over action dimensions."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """Refresh ``self.distribution`` from the actor output."""
        mean = self.actor(observations)
        std = self.std.to(mean.device)
        # `mean * 0. + std` broadcasts the per-action std to the batch shape.
        self.distribution = Normal(mean, mean * 0. + std)

    def _actor_input(self, observations, history, wm_feature):
        """Assemble the actor input: [history latent, command, wm latent].

        The command is assumed to live at obs[privileged_dim+6 : privileged_dim+9]
        — matches the legged_gym observation layout; confirm if the layout changes.
        """
        latent_vector = self.history_encoder(history)
        command = observations[:, self.privileged_dim + 6:self.privileged_dim + 9]
        wm_latent_vector = self.wm_feature_encoder(wm_feature)
        return torch.concat((latent_vector, command, wm_latent_vector), dim=-1)

    def act(self, observations, history, wm_feature, **kwargs):
        """Sample a stochastic action; also updates ``self.distribution``."""
        self.update_distribution(self._actor_input(observations, history, wm_feature))
        return self.distribution.sample()

    def get_latent_vector(self, observations, history, **kwargs):
        """Return the history-encoder latent (observations is unused)."""
        latent_vector = self.history_encoder(history)
        return latent_vector

    def get_linear_vel(self, observations, history, **kwargs):
        """Return the last 3 dims of the latent.

        NOTE(review): presumably trained to regress base linear velocity —
        confirm against the training loss.
        """
        latent_vector = self.history_encoder(history)
        linear_vel = latent_vector[:, -3:]
        return linear_vel

    def get_actions_log_prob(self, actions):
        """Log-probability of `actions` under the current distribution."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations, history, wm_feature):
        """Deterministic action (distribution mean) for deployment."""
        return self.actor(self._actor_input(observations, history, wm_feature))

    def evaluate(self, critic_observations, wm_feature, **kwargs):
        """Return the critic's value estimate, shape (batch, 1)."""
        wm_latent_vector = self.critic_wm_feature_encoder(wm_feature)
        concat_observations = torch.concat((critic_observations, wm_latent_vector),
                                           dim=-1)
        value = self.critic(concat_observations)
        return value


def get_activation(act_name):
    """Map an activation name to a module instance; returns None if unknown."""
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        # NOTE: torch has no CReLU module; fall back to ReLU as upstream did.
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    else:
        print("invalid activation function!")
        return None
import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
import torch.nn.functional as F
from dreamer import tools

class DepthPredictor(nn.Module):
    """Predict a depth image from the forward heightmap and proprioception.

    An MLP encoder embeds [heightmap, prop] into a small spatial feature map
    which a stack of ConvTranspose2d layers upsamples back to
    ``depth_image_dims``. The transposed-conv geometry mirrors (inverts) a
    stride-2 conv encoder: ``h_list``/``w_list`` record the spatial sizes at
    every stage so padding/output_padding can be chosen to hit each target
    size exactly.
    """

    def __init__(self, forward_heightamp_dim = 525,
                 prop_dim = 33,
                 depth_image_dims = [64, 64],
                 encoder_hidden_dims=[256, 128],
                 depth=32,
                 act="ELU",
                 norm=True,
                 kernel_size=4,
                 minres=4,
                 outscale=1.0,
                 cnn_sigmoid=False,):

        # Recover the per-stage spatial sizes a stride-2 conv encoder would
        # produce, then reverse them: the decoder visits them smallest-first.
        # add this to fully recover the process of conv encoder
        h, w = depth_image_dims
        stages = int(np.log2(w) - np.log2(minres))
        self.h_list = []
        self.w_list = []
        for i in range(stages):
            # (x+1)//2 is ceil-division: matches conv downsampling of odd sizes.
            h, w = (h+1) // 2, (w+1) // 2
            self.h_list.append(h)
            self.w_list.append(w)
        self.h_list = self.h_list[::-1]
        self.w_list = self.w_list[::-1]
        # Append the full-resolution target as the final stage.
        self.h_list.append(depth_image_dims[0])
        self.w_list.append(depth_image_dims[1])

        super(DepthPredictor, self).__init__()
        # `act` is rebound from a name ("ELU") to the module class itself.
        act = getattr(torch.nn, act)
        self._cnn_sigmoid = cnn_sigmoid
        layer_num = len(self.h_list) - 1
        # layer_num = int(np.log2(shape[2]) - np.log2(minres))
        # self._minres = minres
        # out_ch = minres**2 * depth * 2 ** (layer_num - 1)
        # Flattened size of the smallest feature map (h0 * w0 * channels).
        out_ch = self.h_list[0] * self.w_list[0] * depth * 2 ** (len(self.h_list) - 2)
        self._embed_size = out_ch

        # Channel count of the first decoder stage; halves at each stage.
        in_dim = out_ch // (self.h_list[0] * self.w_list[0])
        out_dim = in_dim // 2

        # Encoder: [heightmap, prop] -> flattened seed feature map.
        # Encoder
        encoder_layers = []
        encoder_layers.append(nn.Linear(forward_heightamp_dim + prop_dim, encoder_hidden_dims[0]))
        encoder_layers.append(act())
        for l in range(len(encoder_hidden_dims)):
            if l == len(encoder_hidden_dims) - 1:
                encoder_layers.append(nn.Linear(encoder_hidden_dims[l], self._embed_size))
            else:
                encoder_layers.append(nn.Linear(encoder_hidden_dims[l], encoder_hidden_dims[l + 1]))
                encoder_layers.append(act())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: stride-2 transposed convs back up to the image size.
        layers = []
        # h, w = minres, minres
        for i in range(layer_num):
            bias = False
            if i == layer_num - 1:
                # Final stage: 1 output channel (depth), bias on, and the
                # loop-level `act`/`norm` flags are cleared so no activation
                # or norm follows the last conv. NOTE(review): this rebinds
                # the `act` class and `norm` argument — safe only because it
                # happens on the final iteration.
                out_dim = 1
                act = False
                bias = True
                norm = False

            if i != 0:
                # Channel count for stage i (halved per stage from the seed).
                in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
            # Choose padding/output_padding so the output size is exactly
            # h_list[i+1] x w_list[i+1]: doubling sizes use (1, 0), the
            # odd (ceil-division) sizes need (2, 1).
            if(self.h_list[i] * 2 == self.h_list[i+1]):
                pad_h, outpad_h = 1, 0
            else:
                pad_h, outpad_h = 2, 1

            if(self.w_list[i] * 2 == self.w_list[i+1]):
                pad_w, outpad_w = 1, 0
            else:
                pad_w, outpad_w = 2, 1

            layers.append(
                nn.ConvTranspose2d(
                    in_dim,
                    out_dim,
                    kernel_size,
                    2,
                    padding=(pad_h, pad_w),
                    output_padding=(outpad_h, outpad_w),
                    bias=bias,
                )
            )
            if norm:
                layers.append(ImgChLayerNorm(out_dim))
            if act:
                layers.append(act())
            in_dim = out_dim
            out_dim //= 2
            # h, w = h * 2, w * 2
        # DreamerV3-style init: normal init everywhere, uniform on the head.
        [m.apply(tools.weight_init) for m in layers[:-1]]
        layers[-1].apply(tools.uniform_weight_init(outscale))
        self.layers = nn.Sequential(*layers)

    def forward(self, forward_heightmap, prop):
        """Return the predicted depth image mean, shape (batch, H, W, 1)."""
        x = torch.concat((forward_heightmap, prop), dim=-1)
        x = self.encoder(x)
        # (batch, time, -1) -> (batch * time, h, w, ch)
        x = x.reshape(
            [-1, self.h_list[0], self.w_list[0], self._embed_size // (self.h_list[0] * self.w_list[0])]
        )
        # (batch, time, -1) -> (batch * time, ch, h, w)
        x = x.permute(0, 3, 1, 2)
        # print('init decoder shape:', x.shape)
        # for layer in self.layers:
        #     x = layer(x)
        #     print(x.shape)
        x = self.layers(x)
        # Back to channels-last for the image-space loss.
        mean = x.permute(0, 2, 3, 1)
        if self._cnn_sigmoid:
            mean = F.sigmoid(mean)
        # else:
        #     mean += 0.5
        return mean


class ImgChLayerNorm(nn.Module):
    """LayerNorm over the channel dim of a channels-first (N, C, H, W) image."""

    def __init__(self, ch, eps=1e-03):
        super(ImgChLayerNorm, self).__init__()
        self.norm = torch.nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        # NCHW -> NHWC so LayerNorm normalizes channels, then back.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 33 | 34 | from .on_policy_runner import OnPolicyRunner 35 | from .wmp_runner import WMPRunner -------------------------------------------------------------------------------- /rsl_rl/runners/on_policy_runner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. 
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
import time
import os
from collections import deque
import statistics

from torch.utils.tensorboard import SummaryWriter
import torch

from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.env import VecEnv


class OnPolicyRunner:
    """Orchestrates on-policy (PPO) training on a vectorized environment.

    Responsibilities: rollout collection, policy/value updates, TensorBoard
    logging, and periodic model checkpointing.
    """

    def __init__(self,
                 env: VecEnv,
                 train_cfg,
                 log_dir=None,
                 device='cpu'):
        """
        Args:
            env: vectorized environment implementing the VecEnv interface.
            train_cfg: dict with "runner", "algorithm" and "policy" sub-configs.
            log_dir: directory for TensorBoard logs and checkpoints; if None,
                nothing is logged and no checkpoints are written.
            device: torch device string used for all runner-side tensors.
        """
        self.cfg = train_cfg["runner"]
        self.alg_cfg = train_cfg["algorithm"]
        self.policy_cfg = train_cfg["policy"]
        self.device = device
        self.env = env
        # The critic may consume privileged observations when the env provides them.
        if self.env.num_privileged_obs is not None:
            num_critic_obs = self.env.num_privileged_obs
        else:
            num_critic_obs = self.env.num_obs
        # NOTE(security): eval() on config-provided class names executes arbitrary
        # code; acceptable only because train configs are trusted local files.
        actor_critic_class = eval(self.cfg["policy_class_name"])  # ActorCritic
        actor_critic: ActorCritic = actor_critic_class(self.env.num_obs,
                                                       num_critic_obs,
                                                       self.env.num_actions,
                                                       **self.policy_cfg).to(self.device)
        alg_class = eval(self.cfg["algorithm_class_name"])  # PPO
        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        self.save_interval = self.cfg["save_interval"]

        # init storage and model
        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])

        # Logging state
        self.log_dir = log_dir
        self.writer = None
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0

        _, _ = self.env.reset()

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        """Collect rollouts and run PPO updates for `num_learning_iterations` iterations.

        Args:
            num_learning_iterations: number of outer train iterations to run.
            init_at_random_ep_len: randomize the per-env episode-length counters
                so that environment resets are de-synchronized across envs.
        """
        # initialize writer
        if self.log_dir is not None and self.writer is None:
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf,
                                                             high=int(self.env.max_episode_length))
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.alg.actor_critic.train()  # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout (no-grad region; returns are also computed inside it)
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    # NOTE: this env's step() returns 7 values (extra trailing outputs ignored here)
                    obs, privileged_obs, rewards, dones, infos, _, _ = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards, dones, infos)

                    if self.log_dir is not None:
                        # Book keeping: per-episode reward/length statistics
                        if 'episode' in infos:
                            ep_infos.append(infos['episode'])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)

            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            if self.log_dir is not None:
                self.log(locals())
                if it % self.save_interval == 0:
                    self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
            ep_infos.clear()

        self.current_learning_iteration += num_learning_iterations
        # FIX: guard the final checkpoint — os.path.join(None, ...) raised a
        # TypeError when the runner was constructed without a log_dir.
        if self.log_dir is not None:
            self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

    def log(self, locs, width=80, pad=35):
        """Write scalars to TensorBoard and print a formatted iteration summary.

        Args:
            locs: the locals() of learn(), providing timings, losses and buffers.
            width: width of the printed banner.
            pad: column used to right-align the printed labels.
        """
        self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
        self.tot_time += locs['collection_time'] + locs['learn_time']
        iteration_time = locs['collection_time'] + locs['learn_time']

        ep_string = f''
        if locs['ep_infos']:
            for key in locs['ep_infos'][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs['ep_infos']:
                    # handle scalar and zero dimensional tensor infos
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                value = torch.mean(infotensor)
                self.writer.add_scalar('Episode/' + key, value, locs['it'])
                ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
        mean_std = self.alg.actor_critic.std.mean()
        fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time']))

        self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
        self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
        self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
        self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
        self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
        self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it'])
        self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it'])
        if len(locs['rewbuffer']) > 0:
            self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
            self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)

        # FIX: renamed from `str`, which shadowed the builtin
        header = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "

        if len(locs['rewbuffer']) > 0:
            log_string = (f"""{'#' * width}\n"""
                          f"""{header.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                          f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
                          f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
        else:
            log_string = (f"""{'#' * width}\n"""
                          f"""{header.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")

        log_string += ep_string
        log_string += (f"""{'-' * width}\n"""
                       f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
                       f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
                       f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
                       f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
                               locs['num_learning_iterations'] - locs['it']):.1f}s\n""")
        print(log_string)

    def save(self, path, infos=None):
        """Checkpoint model, optimizer and iteration counter to `path`."""
        torch.save({
            'model_state_dict': self.alg.actor_critic.state_dict(),
            'optimizer_state_dict': self.alg.optimizer.state_dict(),
            'iter': self.current_learning_iteration,
            'infos': infos,
        }, path)

    def load(self, path, load_optimizer=True):
        """Restore a checkpoint produced by save(); returns the stored infos."""
        loaded_dict = torch.load(path)
        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
        if load_optimizer:
            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
        self.current_learning_iteration = loaded_dict['iter']
        return loaded_dict['infos']

    def get_inference_policy(self, device=None):
        """Return the deterministic inference policy callable (actor in eval mode)."""
        self.alg.actor_critic.eval()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_inference
# --------------------------------------------------------------------------------
# /rsl_rl/storage/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause

from .rollout_storage import RolloutStorage
# --------------------------------------------------------------------------------
# /rsl_rl/storage/replay_buffer.py:
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin


import torch
import numpy as np


class ReplayBuffer:
    """Fixed-size circular buffer of (state, next_state) transition pairs."""

    def __init__(self, obs_dim, buffer_size, device):
        """Create an empty buffer.

        Arguments:
            obs_dim (int): dimensionality of a single state vector
            buffer_size (int): maximum number of transitions held
            device: torch device on which the storage tensors live
        """
        self.states = torch.zeros(buffer_size, obs_dim).to(device)
        self.next_states = torch.zeros(buffer_size, obs_dim).to(device)
        self.buffer_size = buffer_size
        self.device = device

        self.step = 0         # next write position (wraps around)
        self.num_samples = 0  # count of valid entries, capped at buffer_size

    def insert(self, states, next_states):
        """Write a batch of transitions, wrapping and overwriting the oldest data."""
        count = states.shape[0]
        tail_room = self.buffer_size - self.step
        if count > tail_room:
            # Batch wraps: fill to the physical end, then continue from index 0.
            self.states[self.step:] = states[:tail_room]
            self.next_states[self.step:] = next_states[:tail_room]
            self.states[:count - tail_room] = states[tail_room:]
            self.next_states[:count - tail_room] = next_states[tail_room:]
        else:
            self.states[self.step:self.step + count] = states
            self.next_states[self.step:self.step + count] = next_states

        self.num_samples = min(self.buffer_size, max(self.step + count, self.num_samples))
        self.step = (self.step + count) % self.buffer_size

    def feed_forward_generator(self, num_mini_batch, mini_batch_size):
        """Yield `num_mini_batch` mini-batches sampled uniformly with replacement."""
        for _ in range(num_mini_batch):
            picks = np.random.choice(self.num_samples, size=mini_batch_size)
            yield (self.states[picks].to(self.device),
                   self.next_states[picks].to(self.device))
# --------------------------------------------------------------------------------
# /rsl_rl/utils/__init__.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c)
2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .utils import split_and_pad_trajectories, unpad_trajectories -------------------------------------------------------------------------------- /rsl_rl/utils/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from typing import Tuple

import torch
import numpy as np

# Tolerance for detecting degenerate cases in quaternion interpolation.
_EPS = np.finfo(float).eps * 4.0


def split_and_pad_trajectories(tensor, dones):
    """ Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest trajectory.
    Returns masks corresponding to valid parts of the trajectories
    Example:
        Input: [ [a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]
               ]

        Output:[ [a1, a2, a3, a4], | [ [True, True, True, True],
                 [a5, a6, 0, 0],   |   [True, True, False, False],
                 [b1, b2, 0, 0],   |   [True, True, False, False],
                 [b3, b4, b5, 0],  |   [True, True, True, False],
                 [b6, 0, 0, 0]     |   [True, False, False, False],
               ]                   | ]

    Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
    """
    dones = dones.clone()
    dones[-1] = 1  # force a trajectory boundary at the end of the rollout
    # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Get length of trajectory by counting the number of successive not done elements
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()
    # Extract the individual trajectories and pad them to a common length
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)

    # mask[t, j] is True while trajectory j is still running at step t
    trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
    return padded_trajectories, trajectory_masks


def unpad_trajectories(trajectories, masks):
    """ Does the inverse operation of split_and_pad_trajectories()
    """
    # Need to transpose before and after the masking to have proper reshaping
    return trajectories.transpose(1, 0)[masks.transpose(1, 0)].view(-1, trajectories.shape[0], trajectories.shape[-1]).transpose(1, 0)


class RunningMeanStd(object):
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        Calculates the running mean and std of a data stream
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        :param epsilon: helps with arithmetic issues
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        """Fold a batch (first axis = samples) into the running moments."""
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        """Merge precomputed batch moments using Chan et al.'s parallel formula."""
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = m_2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count


class Normalizer(RunningMeanStd):
    """Observation normalizer: standardizes inputs with running moments, then clips."""

    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        """Standardize a numpy array and clip to [-clip_obs, clip_obs]."""
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs)

    def normalize_torch(self, input, device):
        """Same as normalize(), but for torch tensors on `device`."""
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32)
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32))
        return torch.clamp(
            (input - mean_torch) / std_torch, -self.clip_obs, self.clip_obs)

    def update_normalizer(self, rollouts, expert_loader):
        """Update running moments from both policy rollouts and expert data."""
        policy_data_generator = rollouts.feed_forward_generator_amp(
            None, mini_batch_size=expert_loader.batch_size)
        expert_data_generator = expert_loader.dataset.feed_forward_generator_amp(
            expert_loader.batch_size)

        for expert_batch, policy_batch in zip(expert_data_generator, policy_data_generator):
            self.update(
                torch.vstack(tuple(policy_batch) + tuple(expert_batch)).cpu().numpy())


class Normalize(torch.nn.Module):
    """Module wrapper around L2 normalization along the last dimension."""

    def __init__(self):
        super(Normalize, self).__init__()
        self.normalize = torch.nn.functional.normalize

    def forward(self, x):
        x = self.normalize(x, dim=-1)
        return x


def quaternion_slerp(q0, q1, fraction, spin=0, shortestpath=True):
    """Batch quaternion spherical linear interpolation.

    Args:
        q0, q1: (batch, 4) endpoint quaternions (assumed unit-norm — TODO confirm at call sites).
        fraction: (batch, 1) interpolation fractions in [0, 1].
        spin: extra half revolutions added to the interpolation arc.
        shortestpath: flip q1 where needed to interpolate along the shorter arc.

    Returns:
        (batch, 4) interpolated quaternions. Inputs are left unmodified.
    """
    out = torch.zeros_like(q0)

    # Degenerate fractions: return an endpoint directly.
    zero_mask = torch.isclose(fraction, torch.zeros_like(fraction)).squeeze()
    ones_mask = torch.isclose(fraction, torch.ones_like(fraction)).squeeze()
    out[zero_mask] = q0[zero_mask]
    out[ones_mask] = q1[ones_mask]

    # (Nearly) parallel endpoints: interpolation is ill-conditioned, return q0.
    d = torch.sum(q0 * q1, dim=-1, keepdim=True)
    dist_mask = (torch.abs(torch.abs(d) - 1.0) < _EPS).squeeze()
    out[dist_mask] = q0[dist_mask]

    if shortestpath:
        d_old = torch.clone(d)
        d = torch.where(d_old < 0, -d, d)
        q1 = torch.where(d_old < 0, -q1, q1)  # torch.where builds a new tensor; caller's q1 untouched

    angle = torch.acos(d) + spin * torch.pi
    angle_mask = (torch.abs(angle) < _EPS).squeeze()
    out[angle_mask] = q0[angle_mask]

    final_mask = torch.logical_or(zero_mask, ones_mask)
    final_mask = torch.logical_or(final_mask, dist_mask)
    final_mask = torch.logical_or(final_mask, angle_mask)
    final_mask = torch.logical_not(final_mask)

    # FIX: slerp weights are sin((1-f)*angle)/sin(angle) and sin(f*angle)/sin(angle);
    # the original divided by `angle` instead of sin(angle), which does not even
    # reproduce the endpoints in the limit f -> 0 or 1.
    # FIX: computed out-of-place; the original `q0 *= ...` / `q0 += q1` mutated
    # the caller's input tensors as a side effect.
    inv_sin = 1.0 / torch.sin(angle)
    blended = (q0 * (torch.sin((1.0 - fraction) * angle) * inv_sin)
               + q1 * (torch.sin(fraction * angle) * inv_sin))
    out[final_mask] = blended[final_mask]
    return out
# --------------------------------------------------------------------------------