├── LICENSE.txt
├── README.md
├── datasets
└── mocap_motions
│ ├── hop1.txt
│ ├── hop2.txt
│ ├── trot1.txt
│ └── trot2.txt
├── dreamer
├── __init__.py
├── configs.yaml
├── models.py
├── networks.py
└── tools.py
├── legged_gym
├── __init__.py
├── envs
│ ├── __init__.py
│ ├── a1
│ │ ├── a1_amp_config.py
│ │ └── a1_config.py
│ └── base
│ │ ├── base_config.py
│ │ ├── base_task.py
│ │ ├── legged_robot.py
│ │ ├── legged_robot_config.py
│ │ └── observation_buffer.py
├── scripts
│ ├── play.py
│ └── train.py
├── tests
│ └── test_env.py
└── utils
│ ├── __init__.py
│ ├── helpers.py
│ ├── logger.py
│ ├── math.py
│ ├── task_registry.py
│ ├── terrain.py
│ └── trimesh.py
├── licenses
├── assets
│ └── a1_license.txt
├── dependencies
│ └── matplotlib_license.txt
└── subcomponents
│ ├── dreamerv3-torch_license.txt
│ ├── leggedgym_license.txt
│ └── parkour_license.txt
├── requirements.txt
├── resources
└── robots
│ └── a1
│ ├── a1_license.txt
│ ├── meshes
│ ├── calf.dae
│ ├── hip.dae
│ ├── thigh.dae
│ ├── thigh_mirror.dae
│ ├── trunk.dae
│ └── trunk_A1.png
│ └── urdf
│ └── a1.urdf
└── rsl_rl
├── __init__.py
├── algorithms
├── __init__.py
├── amp_discriminator.py
├── amp_ppo.py
└── ppo.py
├── datasets
├── __init__.py
├── motion_loader.py
├── motion_util.py
└── pose3d.py
├── env
├── __init__.py
└── vec_env.py
├── modules
├── __init__.py
├── actor_critic.py
├── actor_critic_recurrent.py
├── actor_critic_wmp.py
└── depth_predictor.py
├── runners
├── __init__.py
├── on_policy_runner.py
└── wmp_runner.py
├── storage
├── __init__.py
├── replay_buffer.py
└── rollout_storage.py
└── utils
├── __init__.py
└── utils.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
204 |
205 | ========================================================================
206 | Other licenses
207 | ========================================================================
208 | The following components are provided under other licenses. See each project's link for details.
209 |
210 | leggedgym files from leggedgym: https://github.com/jadenvc/leggedgym BSD-3-Clause
211 | dreamer files from dreamerv3-torch: https://github.com/NM512/dreamerv3-torch MIT
212 | amp files from AMP_for_hardware: https://github.com/Alescontrela/AMP_for_hardware BSD-3-Clause
213 | trimesh files from parkour: https://github.com/ZiwenZhuang/parkour MIT
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
WMP
2 |
3 | Code for the paper:
4 | ### World Model-based Perception for Visual Legged Locomotion
5 | [Hang Lai](https://apex.sjtu.edu.cn/members/laihang@apexlab.org), [Jiahang Cao](https://apex.sjtu.edu.cn/members/jhcao@apexlab.org), [JiaFeng Xu](https://scholar.google.com/citations?user=GPmUxtIAAAAJ&hl=zh-CN&oi=ao), [Hongtao Wu](https://scholar.google.com/citations?user=7u0TYgIAAAAJ&hl=zh-CN&oi=ao), [Yunfeng Lin](https://apex.sjtu.edu.cn/members/yflin@apexlab.org), [Tao Kong](https://www.taokong.org/), [Yong Yu](https://scholar.google.com.hk/citations?user=-84M1m0AAAAJ&hl=zh-CN&oi=ao), [Weinan Zhang](https://wnzhang.net/)
6 |
7 | ### [🌐 Project Website](https://wmp-loco.github.io/) | [📄 Paper](https://arxiv.org/abs/2409.16784)
8 |
9 | ## Requirements
10 | 1. Create a new python virtual env with python 3.6, 3.7 or 3.8 (3.8 recommended)
11 | 2. Install pytorch:
12 | - `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117`
13 | 3. Install Isaac Gym
14 | - Download and install Isaac Gym Preview 3 (Preview 2 will not work!) from https://developer.nvidia.com/isaac-gym
15 | - `cd isaacgym/python && pip install -e .`
16 | 4. Install other packages:
17 | - `sudo apt-get install build-essential --fix-missing`
18 | - `sudo apt-get install ninja-build`
19 | - `pip install setuptools==59.5.0`
20 | - `pip install ruamel_yaml==0.17.4`
21 | - `sudo apt install libgl1-mesa-glx -y`
22 | - `pip install opencv-contrib-python`
23 | - `pip install -r requirements.txt`
24 |
25 | ## Training
26 | ```
27 | python legged_gym/scripts/train.py --task=a1_amp --headless --sim_device=cuda:0
28 | ```
29 | Training takes about 23G GPU memory, and at least 10k iterations recommended.
30 |
31 | ## Visualization
32 | **Please make sure you have trained WMP before running visualization.**
33 | ```
34 | python legged_gym/scripts/play.py --task=a1_amp --sim_device=cuda:0 --terrain=climb
35 | ```
36 |
37 |
38 | ## Acknowledgments
39 |
40 | We thank the authors of the following projects for making their code open source:
41 |
42 | - [leggedgym](https://github.com/leggedrobotics/legged_gym)
43 | - [dreamerv3-torch](https://github.com/NM512/dreamerv3-torch)
44 | - [AMP_for_hardware](https://github.com/Alescontrela/AMP_for_hardware)
45 | - [parkour](https://github.com/ZiwenZhuang/parkour/tree/main)
46 | - [extreme-parkour](https://github.com/chengxuxin/extreme-parkour)
47 |
48 |
49 |
50 | ## Citation
51 |
52 | If you find this project helpful, please consider citing our paper:
53 | ```
54 | @article{lai2024world,
55 | title={World Model-based Perception for Visual Legged Locomotion},
56 | author={Lai, Hang and Cao, Jiahang and Xu, Jiafeng and Wu, Hongtao and Lin, Yunfeng and Kong, Tao and Yu, Yong and Zhang, Weinan},
57 | journal={arXiv preprint arXiv:2409.16784},
58 | year={2024}
59 | }
60 | ```
--------------------------------------------------------------------------------
/dreamer/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2023 NM512
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
22 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
23 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
24 |
25 | # from .dreamer import *
26 | from .networks import *
27 | from .tools import *
28 | from .models import *
29 |
--------------------------------------------------------------------------------
/dreamer/configs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |
3 | device: 'cuda:0'
4 | compile: True
5 | precision: 32
6 | debug: False
7 | video_pred_log: false
8 |
9 | # Environment
10 | task: 'a1'
11 | num_actions: 12
12 |
13 | # Model
14 | dyn_hidden: 512
15 | dyn_deter: 512
16 | dyn_stoch: 32
17 | dyn_discrete: 32
18 | dyn_rec_depth: 1
19 | dyn_mean_act: 'none'
20 | dyn_std_act: 'sigmoid2'
21 | dyn_min_std: 0.1
22 | grad_heads: ['decoder', 'reward']
23 | units: 512 # = other heads except decoder, e.g., reward
24 | act: 'SiLU'
25 | norm: True
26 | encoder:
27 | {mlp_keys: '.*', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 5, mlp_units: 1024, symlog_inputs: True}
28 | decoder:
29 | {mlp_keys: '.*', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 5, mlp_units: 1024, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0}
30 | actor:
31 | {layers: 2, dist: 'normal', entropy: 3e-4, unimix_ratio: 0.01, std: 'learned', min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0}
32 | critic:
33 | {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
34 | reward_head:
35 | {layers: 2, dist: 'symlog_disc', loss_scale: 0.0, outscale: 0.0}
36 | cont_head:
37 | {layers: 2, loss_scale: 1.0, outscale: 1.0}
38 | dyn_scale: 0.5
39 | rep_scale: 0.1
40 | kl_free: 1.0
41 | weight_decay: 0.0
42 | unimix_ratio: 0.01
43 | initial: 'learned'
44 |
45 | # Training
46 | train_steps_per_iter: 10
47 | train_start_steps: 10000
48 | batch_size: 16
49 | batch_length: 64
50 | train_ratio: 512
51 | pretrain: 100
52 | model_lr: 1e-4
53 | opt_eps: 1e-8
54 | grad_clip: 1000
55 | opt: 'adam'
56 |
57 |
--------------------------------------------------------------------------------
/dreamer/models.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2023 NM512
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
22 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
23 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
24 |
25 | import copy
26 | import torch
27 | from torch import nn
28 |
29 | from . import tools
30 | from . import networks
31 |
def to_np(x):
    """Detach *x* from the autograd graph and return it as a NumPy array."""
    return x.detach().cpu().numpy()
33 |
class WorldModel(nn.Module):
    """DreamerV3-style world model: encoder + RSSM latent dynamics + heads.

    Holds an observation encoder, the recurrent state-space model (RSSM),
    and prediction heads (observation decoder, reward), all trained jointly
    with a single optimizer.
    """

    def __init__(self, config, obs_shape, use_camera):
        """Build the encoder, dynamics, heads, and shared optimizer.

        Args:
            config: parsed configuration namespace (see dreamer/configs.yaml).
            obs_shape: mapping from observation names to shapes, consumed by
                the encoder and decoder.
            use_camera: whether image observations are present.
        """
        super(WorldModel, self).__init__()
        # Mixed precision (AMP) is only used for 16-bit training.
        self._use_amp = True if config.precision == 16 else False
        self._config = config
        self.device = self._config.device

        self.encoder = networks.MultiEncoder(obs_shape, **config.encoder, use_camera=use_camera)
        self.embed_size = self.encoder.outdim
        self.dynamics = networks.RSSM(
            config.dyn_stoch,
            config.dyn_deter,
            config.dyn_hidden,
            config.dyn_rec_depth,
            config.dyn_discrete,
            config.act,
            config.norm,
            config.dyn_mean_act,
            config.dyn_std_act,
            config.dyn_min_std,
            config.unimix_ratio,
            config.initial,
            config.num_actions,
            self.embed_size,
            config.device,
        )
        self.heads = nn.ModuleDict()
        # Feature vector fed to all heads: stochastic latent (flattened if
        # discrete) concatenated with the deterministic recurrent state.
        if config.dyn_discrete:
            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
        else:
            feat_size = config.dyn_stoch + config.dyn_deter
        self.heads["decoder"] = networks.MultiDecoder(
            feat_size, obs_shape, **config.decoder, use_camera=use_camera
        )
        self.heads["reward"] = networks.MLP(
            feat_size,
            # symlog_disc predicts a 255-bin discrete distribution; otherwise scalar.
            (255,) if config.reward_head["dist"] == "symlog_disc" else (),
            config.reward_head["layers"],
            config.units,
            config.act,
            config.norm,
            dist=config.reward_head["dist"],
            outscale=config.reward_head["outscale"],
            device=config.device,
            name="Reward",
        )
        # NOTE: upstream dreamerv3-torch also builds a "cont" (episode
        # continuation) head here; it is intentionally disabled in this fork.
        for name in config.grad_heads:
            assert name in self.heads, name
        self._model_opt = tools.Optimizer(
            "model",
            self.parameters(),
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            config.weight_decay,
            opt=config.opt,
            use_amp=self._use_amp,
        )
        print(
            f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
        )
        # Per-term loss scales; any term absent from this dict is scaled by 1.0.
        self._scales = dict(
            reward=config.reward_head["loss_scale"],
            image=1.0,
        )

    def _train(self, data):
        """Run one optimization step of the world model on a batch.

        Expected entries in *data* (before preprocessing):
            action   (batch_size, batch_length, act_dim)
            image    (batch_size, batch_length, h, w, ch)
            reward   (batch_size, batch_length)

        Returns:
            post: posterior latent states, detached from the graph.
            context: dict of intermediate tensors (embed, feat, kl, postent).
            metrics: scalar logging metrics from the optimizer and losses.
        """
        data = self.preprocess(data)

        with tools.RequiresGrad(self):
            with torch.cuda.amp.autocast(self._use_amp):
                embed = self.encoder(data)
                post, prior = self.dynamics.observe(
                    embed, data["action"], data["is_first"]
                )
                kl_free = self._config.kl_free
                dyn_scale = self._config.dyn_scale
                rep_scale = self._config.rep_scale
                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                    post, prior, kl_free, dyn_scale, rep_scale
                )
                assert kl_loss.shape == embed.shape[:2], kl_loss.shape
                preds = {}
                for name, head in self.heads.items():
                    grad_head = name in self._config.grad_heads
                    feat = self.dynamics.get_feat(post)
                    # Heads not listed in grad_heads must not backprop into the RSSM.
                    feat = feat if grad_head else feat.detach()
                    pred = head(feat)
                    if type(pred) is dict:
                        preds.update(pred)
                    else:
                        preds[name] = pred
                losses = {}
                for name, pred in preds.items():
                    loss = -pred.log_prob(data[name])
                    assert loss.shape == embed.shape[:2], (name, loss.shape)
                    losses[name] = loss
                scaled = {
                    key: value * self._scales.get(key, 1.0)
                    for key, value in losses.items()
                }
                model_loss = sum(scaled.values()) + kl_loss
            metrics = self._model_opt(torch.mean(model_loss), self.parameters())

        metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
        metrics["kl_free"] = kl_free
        metrics["dyn_scale"] = dyn_scale
        metrics["rep_scale"] = rep_scale
        metrics["dyn_loss"] = to_np(dyn_loss)
        metrics["rep_loss"] = to_np(rep_loss)
        metrics["kl"] = to_np(torch.mean(kl_value))
        with torch.cuda.amp.autocast(self._use_amp):
            metrics["prior_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(prior).entropy())
            )
            metrics["post_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(post).entropy())
            )
            context = dict(
                embed=embed,
                feat=self.dynamics.get_feat(post),
                kl=kl_value,
                postent=self.dynamics.get_dist(post).entropy(),
            )
        post = {k: v.detach() for k, v in post.items()}
        return post, context, metrics

    # This function is called during both rollout and training.
    def preprocess(self, obs):
        """Convert every entry of *obs* to a float Tensor on the configured device."""
        # 'is_first' is necessary to initialize the hidden state at training time.
        assert "is_first" in obs
        obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
        return obs

    def video_pred(self, data):
        """Return a (truth, model) image strip for logging video predictions.

        For up to 6 sequences, the first 5 steps are posterior
        reconstructions; the remaining steps are imagined open-loop from the
        recorded actions alone. The result concatenates ground-truth and
        model images along the height axis.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)

        states, _ = self.dynamics.observe(
            embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
        )
        recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
            :6
        ]
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
        openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
        # Observed images cover the first 5 steps; open-loop predictions follow.
        model = torch.cat([recon[:, :5], openl], 1)
        truth = data["image"][:6]

        return torch.cat([truth, model], 2)
--------------------------------------------------------------------------------
/legged_gym/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import os
32 |
# Repository root: two levels above this file, with symlinks resolved.
_this_file = os.path.realpath(__file__)
LEGGED_GYM_ROOT_DIR = os.path.dirname(os.path.dirname(_this_file))
# Directory containing the environment definitions.
LEGGED_GYM_ENVS_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, 'legged_gym', 'envs')
--------------------------------------------------------------------------------
/legged_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
from .base.legged_robot import LeggedRobot
# NOTE: a second, absolute import of A1RoughCfg / A1RoughCfgPPO used to
# shadow the relative import below with the exact same names; only one
# import is kept.
from .a1.a1_config import A1RoughCfg, A1RoughCfgPPO
from .a1.a1_amp_config import A1AMPCfg, A1AMPCfgPPO


import os

from legged_gym.utils.task_registry import task_registry

# Register the available tasks with the global registry so that
# scripts/train.py and scripts/play.py can look them up by name.
task_registry.register( "a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO() )
task_registry.register( "a1_amp", LeggedRobot, A1AMPCfg(), A1AMPCfgPPO() )
44 |
--------------------------------------------------------------------------------
/legged_gym/envs/a1/a1_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
35 |
class A1RoughCfg( LeggedRobotCfg ):
    """Unitree A1 rough-terrain configuration (plain PPO baseline, no AMP).

    Each nested class overrides the matching group in LeggedRobotCfg.
    """

    class env( LeggedRobotCfg.env ):
        num_envs = 5480
        include_history_steps = None  # Number of steps of history to include (None = no history buffer).
        num_observations = 235
        num_privileged_obs = 235
        reference_state_initialization = False
        # reference_state_initialization_prob = 0.85
        # amp_motion_files = MOTION_FILES

    class init_state( LeggedRobotCfg.init_state ):
        pos = [0.0, 0.0, 0.42] # x,y,z [m] -- initial base height for the A1
        default_joint_angles = { # = target angles [rad] when action = 0.0
            'FL_hip_joint': 0.1,   # [rad]
            'RL_hip_joint': 0.1,   # [rad]
            'FR_hip_joint': -0.1 ,  # [rad]
            'RR_hip_joint': -0.1,   # [rad]

            'FL_thigh_joint': 0.8,     # [rad]
            'RL_thigh_joint': 1.,   # [rad]
            'FR_thigh_joint': 0.8,     # [rad]
            'RR_thigh_joint': 1.,   # [rad]

            'FL_calf_joint': -1.5,   # [rad]
            'RL_calf_joint': -1.5,    # [rad]
            'FR_calf_joint': -1.5,  # [rad]
            'RR_calf_joint': -1.5,    # [rad]
        }

    class control( LeggedRobotCfg.control ):
        # PD Drive parameters:
        control_type = 'P'
        stiffness = {'joint': 20.}  # [N*m/rad]
        damping = {'joint': 0.5}     # [N*m*s/rad]
        # action scale: target angle = actionScale * action + defaultAngle
        action_scale = 0.25
        # decimation: Number of control action updates @ sim DT per policy DT
        decimation = 4

    class asset( LeggedRobotCfg.asset ):
        # {LEGGED_GYM_ROOT_DIR} is substituted at load time.
        file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/a1/urdf/a1.urdf'
        foot_name = "foot"
        penalize_contacts_on = ["thigh", "calf"]
        terminate_after_contacts_on = ["base"]
        self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter

    class rewards( LeggedRobotCfg.rewards ):
        soft_dof_pos_limit = 0.9
        base_height_target = 0.25
        class scales( LeggedRobotCfg.rewards.scales ):
            torques = -0.0002
            dof_pos_limits = -10.0
class A1RoughCfgPPO( LeggedRobotCfgPPO ):
    """PPO training configuration for the A1 rough-terrain task."""
    class algorithm( LeggedRobotCfgPPO.algorithm ):
        entropy_coef = 0.01
    class runner( LeggedRobotCfgPPO.runner ):
        run_name = ''
        experiment_name = 'rough_a1'  # logs are grouped under this name
96 |
97 |
--------------------------------------------------------------------------------
/legged_gym/envs/base/base_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import inspect
32 |
class BaseConfig:
    """Config base class that turns nested class declarations into instances.

    Subclasses declare their configuration as nested classes; on
    instantiation every nested-class attribute is replaced (recursively)
    by an instance of itself, so individual fields can then be overridden
    per object without touching the class-level defaults.
    """

    def __init__(self) -> None:
        """Initialize all member classes recursively.

        Note: only ``__class__`` is skipped; every other attribute
        (including other dunders) is inspected.
        """
        self.init_member_classes(self)

    @staticmethod
    def init_member_classes(obj) -> None:
        """Recursively replace every nested-class attribute of ``obj`` with an instance."""
        # iterate over all attribute names
        for key in dir(obj):
            # skip the type pointer itself; instantiating it would recurse
            # into the object's own class
            if key == "__class__":
                continue
            # get the corresponding attribute object
            var = getattr(obj, key)
            # only nested class declarations are converted
            if inspect.isclass(var):
                # instantiate the class ...
                i_var = var()
                # ... and set the attribute to the instance instead of the type
                setattr(obj, key, i_var)
                # recursively init members of the freshly created instance
                BaseConfig.init_member_classes(i_var)
--------------------------------------------------------------------------------
/legged_gym/envs/base/base_task.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import sys
32 | from isaacgym import gymapi
33 | from isaacgym import gymutil
34 | import numpy as np
35 | import torch
36 |
37 | from legged_gym.envs.base import observation_buffer
38 |
39 |
40 | # Base class for RL tasks
class BaseTask():
    """Base class for RL tasks on Isaac Gym.

    Acquires the simulator, selects the compute device, allocates the
    per-environment observation/reward/reset buffers and (unless headless)
    creates a viewer. ``step`` and ``reset_idx`` must be provided by
    subclasses; ``create_sim`` is called here but is not defined in this
    class, so subclasses must provide it as well.
    """

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless):
        """Parse the config, create the sim and allocate per-env buffers.

        Args:
            cfg: environment configuration object (LeggedRobotCfg-style).
            sim_params: simulation parameters passed to Isaac Gym.
            physics_engine: gymapi physics engine selector.
            sim_device: device string such as 'cuda:0' or 'cpu'.
            headless: if True, no viewer is created and rendering is disabled.
        """
        self.gym = gymapi.acquire_gym()

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.sim_device = sim_device
        sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
        self.headless = headless

        # env device is GPU only if sim is on GPU and use_gpu_pipeline=True, otherwise returned tensors are copied to CPU by physX.
        if sim_device_type=='cuda' and sim_params.use_gpu_pipeline:
            self.device = self.sim_device
        else:
            self.device = 'cpu'

        # graphics device for rendering, -1 for no rendering
        self.graphics_device_id = self.sim_device_id
        if self.headless == True:
            self.graphics_device_id = -1

        # cache env dimensions from the config
        self.num_envs = cfg.env.num_envs
        self.num_obs = cfg.env.num_observations
        self.num_privileged_obs = cfg.env.num_privileged_obs
        self.num_actions = cfg.env.num_actions
        self.include_history_steps = cfg.env.include_history_steps

        self.height_dim = cfg.env.height_dim
        self.privileged_dim = cfg.env.privileged_dim

        # optimization flags for pytorch JIT
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)

        # allocate buffers; an ObservationBuffer is only kept when an
        # observation history is requested in the config
        if cfg.env.include_history_steps is not None:
            self.obs_buf_history = observation_buffer.ObservationBuffer(
                self.num_envs, self.num_obs,
                self.include_history_steps, self.device)
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
        self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        # reset_buf starts at 1 so that every env is reset on the first step
        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)
        self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
        self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        if self.num_privileged_obs is not None:
            self.privileged_obs_buf = torch.zeros(self.num_envs, self.num_privileged_obs, device=self.device, dtype=torch.float)
        else:
            self.privileged_obs_buf = None
            # self.num_privileged_obs = self.num_obs

        self.extras = {}

        # create envs, sim and viewer
        self.create_sim()
        self.gym.prepare_sim(self.sim)

        # todo: read from config
        self.enable_viewer_sync = True
        self.viewer = None

        # if running with a viewer, set up keyboard shortcuts and camera
        if self.headless == False:
            # subscribe to keyboard shortcuts
            self.viewer = self.gym.create_viewer(
                self.sim, gymapi.CameraProperties())
            self.gym.subscribe_viewer_keyboard_event(
                self.viewer, gymapi.KEY_ESCAPE, "QUIT")
            self.gym.subscribe_viewer_keyboard_event(
                self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

    def get_observations(self):
        """Return the current observation buffer (num_envs x num_obs)."""
        return self.obs_buf

    def get_privileged_observations(self):
        """Return the privileged observation buffer, or None when disabled."""
        return self.privileged_obs_buf

    def reset_idx(self, env_ids):
        """Reset selected robots"""
        raise NotImplementedError

    def reset(self):
        """ Reset all robots"""
        # step once with zero actions to refresh the observation buffers
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        obs, privileged_obs, _, _, _ = self.step(torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False))
        return obs, privileged_obs

    def step(self, actions):
        """Advance the simulation by one policy step; implemented by subclasses."""
        raise NotImplementedError

    def render(self, sync_frame_time=True):
        """Render one frame, handling viewer close/keyboard events.

        Args:
            sync_frame_time: if True, sleep so rendering matches real time.
        """
        if self.viewer:
            # check for window closed
            if self.gym.query_viewer_has_closed(self.viewer):
                sys.exit()

            # check for keyboard events
            for evt in self.gym.query_viewer_action_events(self.viewer):
                if evt.action == "QUIT" and evt.value > 0:
                    sys.exit()
                elif evt.action == "toggle_viewer_sync" and evt.value > 0:
                    self.enable_viewer_sync = not self.enable_viewer_sync

            # fetch results
            if self.device != 'cpu':
                self.gym.fetch_results(self.sim, True)

            # step graphics
            if self.enable_viewer_sync:
                self.gym.step_graphics(self.sim)
                self.gym.draw_viewer(self.viewer, self.sim, True)
                if sync_frame_time:
                    self.gym.sync_frame_time(self.sim)
            else:
                # sync disabled: still drain viewer events so the window stays responsive
                self.gym.poll_viewer_events(self.viewer)
--------------------------------------------------------------------------------
/legged_gym/envs/base/legged_robot_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | from .base_config import BaseConfig
35 |
class LeggedRobotCfg(BaseConfig):
    """Default configuration for legged-robot environments.

    Each nested class is one configuration group; BaseConfig instantiates
    them recursively, and robot-specific configs (e.g. A1RoughCfg) override
    individual fields by subclassing the groups.
    """

    class env:
        num_envs = 4096
        num_observations = 235
        privileged_obs = True # if True, add the privileged information in the obs
        privileged_dim = 24 + 3 # privileged_obs[:,:privileged_dim] is the privileged information in privileged_obs, include 3-dim base linear vel
        height_dim = 187 # privileged_obs[:,-height_dim:] is the heightmap in privileged_obs
        num_privileged_obs = None # if not None a privileged_obs_buf will be returned by step() (critic obs for asymmetric training). None is returned otherwise
        num_actions = 12
        env_spacing = 3. # not used with heightfields/trimeshes
        send_timeouts = True # send time out information to the algorithm
        episode_length_s = 20 # episode length in seconds
        reference_state_initialization = False # initialize state from reference data

    class terrain:
        mesh_type = 'trimesh' # "heightfield" # none, plane, heightfield or trimesh
        horizontal_scale = 0.1 # [m]
        vertical_scale = 0.005 # [m]
        border_size = 50 # [m] change 25 to 50
        curriculum = True
        static_friction = 1.0
        dynamic_friction = 1.0
        restitution = 0.
        # rough terrain only:
        measure_heights = True
        measured_points_x = [-0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] # 1mx1.6m rectangle (without center line)
        measured_points_y = [-0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5]
        selected = False # select a unique terrain type and pass all arguments
        terrain_kwargs = None # Dict of arguments for selected terrain
        max_init_terrain_level = 0 # starting curriculum state
        terrain_length = 8.
        terrain_width = 8.
        num_rows= 10 # number of terrain rows (levels)
        num_cols = 20 # number of terrain cols (types)
        # terrain types: [wave, rough slope, stairs up, stairs down, discrete, rough_flat]
        terrain_proportions = [0.1, 0.1, 0.30, 0.25, 0.15, 0.1]
        # terrain_proportions = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
        # trimesh only:
        slope_treshold = 0.75 # slopes above this threshold will be corrected to vertical surfaces

    class commands:
        curriculum = False
        max_curriculum = 1.
        num_commands = 4 # default: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error)
        resampling_time = 10. # time before command are changed[s]
        heading_command = True # if true: compute ang vel command from heading error
        class ranges:
            lin_vel_x = [-1.0, 1.0] # min max [m/s]
            lin_vel_y = [-1.0, 1.0] # min max [m/s]
            ang_vel_yaw = [-1, 1] # min max [rad/s]
            heading = [-3.14, 3.14]

    class init_state:
        pos = [0.0, 0.0, 1.] # x,y,z [m]
        rot = [0.0, 0.0, 0.0, 1.0] # x,y,z,w [quat]
        lin_vel = [0.0, 0.0, 0.0] # x,y,z [m/s]
        ang_vel = [0.0, 0.0, 0.0] # x,y,z [rad/s]
        default_joint_angles = { # target angles when action = 0.0
            "joint_a": 0.,
            "joint_b": 0.}

    class control:
        control_type = 'P' # P: position, V: velocity, T: torques
        # PD Drive parameters:
        stiffness = {'joint_a': 10.0, 'joint_b': 15.} # [N*m/rad]
        damping = {'joint_a': 1.0, 'joint_b': 1.5} # [N*m*s/rad]
        # action scale: target angle = actionScale * action + defaultAngle
        action_scale = 0.5
        # decimation: Number of control action updates @ sim DT per policy DT
        decimation = 4

    class depth:
        # depth-camera settings (used only when use_camera is True)
        use_camera = False
        camera_num_envs = 192
        camera_terrain_num_rows = 10
        camera_terrain_num_cols = 20

        position = [0.27, 0, 0.03] # front camera
        angle = [-5, 5] # positive pitch down

        update_interval = 5 # 5 works without retraining, 8 worse

        original = (106, 60)   # raw camera resolution (w, h)
        resized = (87, 58)     # resolution after resizing
        horizontal_fov = 87
        buffer_len = 2

        near_clip = 0
        far_clip = 2
        dis_noise = 0.0

        scale = 1
        invert = True

    class asset:
        file = ""
        foot_name = "None" # name of the feet bodies, used to index body state and contact force tensors
        penalize_contacts_on = []
        terminate_after_contacts_on = []
        disable_gravity = False
        collapse_fixed_joints = True # merge bodies connected by fixed joints. Specific fixed joints can be kept by adding " <... dont_collapse="true">
        fix_base_link = False # fix the base of the robot
        default_dof_drive_mode = 3 # see GymDofDriveModeFlags (0 is none, 1 is pos tgt, 2 is vel tgt, 3 effort)
        self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter
        replace_cylinder_with_capsule = True # replace collision cylinders with capsules, leads to faster/more stable simulation
        flip_visual_attachments = True # Some .obj meshes must be flipped from y-up to z-up

        density = 0.001
        angular_damping = 0.
        linear_damping = 0.
        max_angular_velocity = 1000.
        max_linear_velocity = 1000.
        armature = 0.
        thickness = 0.01

    class domain_rand:
        # domain randomization ranges applied at env creation / runtime
        randomize_friction = True
        friction_range = [0.25, 1.75]
        randomize_restitution = True
        restitution_range = [0, 1]

        randomize_base_mass = True
        added_mass_range = [-1., 1.] # kg
        randomize_link_mass = True
        link_mass_range = [0.8, 1.2]
        randomize_com_pos = True
        com_pos_range = [-0.05, 0.05]

        push_robots = True
        push_interval_s = 15
        max_push_vel_xy = 1.0

        randomize_gains = True
        stiffness_multiplier_range = [0.9, 1.1]
        damping_multiplier_range = [0.9, 1.1]
        randomize_motor_strength = True

        motor_strength_range = [0.9, 1.1]
        randomize_action_latency = True
        latency_range = [0.00, 0.02]

    class rewards:
        reward_curriculum = True
        reward_curriculum_term = ["lin_vel_z"]
        reward_curriculum_schedule = [0, 1000, 1, 0] # from iter 0 to iter 1000, decrease from 1 to 0
        class scales:
            # weight of each reward term; negative values are penalties
            termination = -0.0
            tracking_lin_vel = 1.0
            tracking_ang_vel = 0.5
            lin_vel_z = -2.0
            ang_vel_xy = 0#-0.05
            orientation = -0.
            torques = -0.00001
            dof_vel = -0.
            dof_acc = -2.5e-7
            base_height = -0.
            feet_air_time = 1.0
            collision = -1.
            feet_stumble = -0.0
            action_rate = -0.01
            stand_still = -0.

        only_positive_rewards = True # if true negative total rewards are clipped at zero (avoids early termination problems)
        tracking_sigma = 0.15 # tracking reward = exp(-error^2/sigma)
        soft_dof_pos_limit = 1. # percentage of urdf limits, values above this limit are penalized
        soft_dof_vel_limit = 1.
        soft_torque_limit = 1.
        base_height_target = 1.
        max_contact_force = 100. # forces above this value are penalized

    class normalization:
        class obs_scales:
            # multipliers applied to raw observation terms
            lin_vel = 1.0
            ang_vel = 0.25
            dof_pos = 1.0
            dof_vel = 0.05
            # privileged
            height_measurements = 5.0
            contact_force = 0.005
            com_pos = 20
            pd_gains = 5
        clip_observations = 100.
        clip_actions = 6.0

        base_height = 0.5

    class noise:
        add_noise = True
        noise_level = 1.0 # scales other values
        class noise_scales:
            dof_pos = 0.01
            dof_vel = 1.5
            lin_vel = 0.1
            ang_vel = 0.2
            gravity = 0.05
            height_measurements = 0.1

    # viewer camera:
    class viewer:
        ref_env = 0
        pos = [10, 0, 6] # [m]
        lookat = [11., 5, 3.] # [m]

    class sim:
        dt = 0.002
        substeps = 1
        gravity = [0., 0. ,-9.81] # [m/s^2]
        up_axis = 1 # 0 is y, 1 is z

        class physx:
            num_threads = 10
            solver_type = 1 # 0: pgs, 1: tgs
            num_position_iterations = 4
            num_velocity_iterations = 0
            contact_offset = 0.01 # [m]
            rest_offset = 0.0 # [m]
            bounce_threshold_velocity = 0.5 #0.5 [m/s]
            max_depenetration_velocity = 1.0
            max_gpu_contact_pairs = 2**23 #2**24 -> needed for 8000 envs and more
            default_buffer_size_multiplier = 5
            contact_collection = 2 # 0: never, 1: last sub-step, 2: all sub-steps (default=2)
257 |
class LeggedRobotCfgPPO(BaseConfig):
    """Default PPO training configuration (policy, algorithm, runner groups)."""
    seed = 1
    runner_class_name = 'OnPolicyRunner'
    class policy:
        init_noise_std = 1.0
        actor_hidden_dims = [512, 256, 128]
        critic_hidden_dims = [512, 256, 128]
        latent_dim = 32
        # height_latent_dim = 16 # the encoder in teacher policy encodes the heightmap into a height_latent_dim vector
        # privileged_latent_dim = 8 # the encoder in teacher policy encodes the privileged information into a privileged_latent_dim vector
        activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid
        # only for 'ActorCriticRecurrent':
        # rnn_type = 'lstm'
        # rnn_hidden_size = 512
        # rnn_num_layers = 1

    class algorithm:
        # training params
        value_loss_coef = 1.0
        use_clipped_value_loss = True
        clip_param = 0.2
        entropy_coef = 0.01
        num_learning_epochs = 5
        num_mini_batches = 4 # mini batch size = num_envs*nsteps / nminibatches
        learning_rate = 1.e-3 #5.e-4
        schedule = 'adaptive' # could be adaptive, fixed
        gamma = 0.99
        lam = 0.95
        desired_kl = 0.01
        max_grad_norm = 1.

    class runner:
        policy_class_name = 'ActorCritic'
        algorithm_class_name = 'PPO'
        num_steps_per_env = 24 # per iteration
        max_iterations = 1500 # number of policy updates

        # logging
        save_interval = 200 # check for potential saves every this many iterations
        experiment_name = 'test'
        run_name = 'trot'
        # load and resume
        resume = False
        load_run = -1 # -1 = last run
        checkpoint = -1 # -1 = last saved model
        resume_path = None # updated from load_run and chkpt
--------------------------------------------------------------------------------
/legged_gym/envs/base/observation_buffer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import torch
32 |
class ObservationBuffer:
    """Flat per-env history buffer of the most recent observations.

    Each env owns one row laid out as ``[oldest | ... | newest]``, where
    every slot is ``num_obs`` wide and there are ``include_history_steps``
    slots in total.
    """

    def __init__(self, num_envs, num_obs, include_history_steps, device):
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.include_history_steps = include_history_steps
        self.device = device

        # total row width: one num_obs-wide slot per history step
        self.num_obs_total = num_obs * include_history_steps

        self.obs_buf = torch.zeros(num_envs, self.num_obs_total,
                                   dtype=torch.float, device=device)

    def reset(self, reset_idxs, new_obs):
        """Fill every history slot of the selected envs with ``new_obs``."""
        tiled = new_obs.repeat(1, self.include_history_steps)
        self.obs_buf[reset_idxs] = tiled

    def insert(self, new_obs):
        """Append ``new_obs`` as the newest slot, dropping the oldest one."""
        width = self.num_obs
        keep = width * (self.include_history_steps - 1)
        # shift everything one slot toward the front ...
        self.obs_buf[:, :keep] = self.obs_buf[:, width:width + keep]
        # ... then write the fresh observation into the last slot
        self.obs_buf[:, -width:] = new_obs

    def get_obs_vec(self, obs_ids):
        """Gets history of observations indexed by obs_ids.

        Arguments:
            obs_ids: An array of integers with which to index the desired
                observations, where 0 is the latest observation and
                include_history_steps - 1 is the oldest observation.
        Returns the selected slots concatenated oldest-first along dim -1.
        """
        newest = self.include_history_steps - 1
        chunks = [
            self.obs_buf[:, (newest - age) * self.num_obs:(newest - age + 1) * self.num_obs]
            for age in sorted(obs_ids, reverse=True)
        ]
        return torch.cat(chunks, dim=-1)
69 |
70 |
--------------------------------------------------------------------------------
/legged_gym/scripts/play.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import os
35 | import inspect
36 | import time
37 |
38 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
39 | parentdir = os.path.dirname(os.path.dirname(currentdir))
40 | os.sys.path.insert(0, parentdir)
41 | from legged_gym import LEGGED_GYM_ROOT_DIR
42 |
43 | import isaacgym
44 | from legged_gym.envs import *
45 | from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
46 |
47 | import numpy as np
48 | import torch
49 |
50 |
def play(args):
    """Roll out a trained WMP policy in simulation for evaluation.

    Overrides the task's env/train configs for deterministic playback
    (randomization off, fixed velocity command, one terrain type), restores
    the latest checkpoint, then steps the policy while feeding the learned
    world model with proprioception (and depth images when the camera is on).

    Bug fix vs. original: `infos` was read before the first `env.step()`
    assigned it, causing a NameError when the world-model update fired on
    the very first iteration with the depth camera enabled.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 10)
    env_cfg.terrain.num_cols = 1
    env_cfg.terrain.curriculum = False
    env_cfg.noise.add_noise = False
    # env_cfg.domain_rand.randomize_friction = False
    # env_cfg.domain_rand.randomize_restitution = False
    # env_cfg.commands.heading_command = True

    # Pin domain randomization to its nominal values for repeatable playback.
    env_cfg.domain_rand.friction_range = [1.0, 1.0]
    env_cfg.domain_rand.restitution_range = [0.0, 0.0]
    env_cfg.domain_rand.added_mass_range = [0., 0.] # kg
    env_cfg.domain_rand.com_x_pos_range = [-0.0, 0.0]
    env_cfg.domain_rand.com_y_pos_range = [-0.0, 0.0]
    env_cfg.domain_rand.com_z_pos_range = [-0.0, 0.0]

    env_cfg.domain_rand.randomize_action_latency = False
    env_cfg.domain_rand.push_robots = False
    env_cfg.domain_rand.randomize_gains = True
    # env_cfg.domain_rand.randomize_base_mass = False
    env_cfg.domain_rand.randomize_link_mass = False
    # env_cfg.domain_rand.randomize_com_pos = False
    env_cfg.domain_rand.randomize_motor_strength = False

    train_cfg.runner.amp_num_preload_transitions = 1

    env_cfg.domain_rand.stiffness_multiplier_range = [1.0, 1.0]
    env_cfg.domain_rand.damping_multiplier_range = [1.0, 1.0]

    # env_cfg.terrain.mesh_type = 'plane'
    if (env_cfg.terrain.mesh_type == 'plane'):
        # Edge/stumble rewards are meaningless on flat ground.
        env_cfg.rewards.scales.feet_edge = 0
        env_cfg.rewards.scales.feet_stumble = 0

    if (args.terrain not in ['slope', 'stair', 'gap', 'climb', 'crawl', 'tilt']):
        print('terrain should be one of slope, stair, gap, climb, crawl, and tilt, set to climb as default')
        args.terrain = 'climb'
    # One-hot terrain proportion vector selecting only the requested terrain.
    env_cfg.terrain.terrain_proportions = {
        'slope': [0, 1.0, 0.0, 0, 0, 0, 0, 0, 0],
        'stair': [0, 0, 1.0, 0, 0, 0, 0, 0, 0],
        'gap': [0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0],
        'climb': [0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0],
        'tilt': [0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0],
        'crawl': [0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0],
    }[args.terrain]

    # Fixed forward-velocity command for playback.
    env_cfg.commands.ranges.lin_vel_x = [0.6, 0.6]
    env_cfg.commands.ranges.lin_vel_y = [-0.0, -0.0]
    env_cfg.commands.ranges.ang_vel_yaw = [0.0, 0.0]
    env_cfg.commands.ranges.heading = [0, 0]

    env_cfg.commands.ranges.flat_lin_vel_x = [0.6, 0.6]
    env_cfg.commands.ranges.flat_lin_vel_y = [-0.0, -0.0]
    env_cfg.commands.ranges.flat_ang_vel_yaw = [0.0, 0.0]

    env_cfg.depth.use_camera = True

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    _, _ = env.reset()
    obs = env.get_observations()
    # load policy
    train_cfg.runner.resume = True
    train_cfg.runner.load_run = 'WMP'

    train_cfg.runner.checkpoint = -1
    ppo_runner, train_cfg = task_registry.make_wmp_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    policy = ppo_runner.get_inference_policy(device=env.device)

    # export policy as a jit module (used to run it from C++)
    if EXPORT_POLICY:
        path = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies')
        export_policy_as_jit(ppo_runner.alg.actor_critic, path)
        print('Exported policy as jit script to: ', path)

    logger = Logger(env.dt)
    robot_index = 0  # which robot is used for logging
    joint_index = 1  # which joint is used for logging
    stop_state_log = 100  # number of steps before plotting states
    stop_rew_log = env.max_episode_length + 1  # number of steps before print average episode rewards
    camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64)
    camera_vel = np.array([1., 1., 0.])
    camera_direction = np.array(env_cfg.viewer.lookat) - np.array(env_cfg.viewer.pos)
    img_idx = 0

    # Rolling window of command-stripped observations fed to the policy.
    history_length = 5
    trajectory_history = torch.zeros(size=(env.num_envs, history_length, env.num_obs -
                                           env.privileged_dim - env.height_dim - 3), device=env.device)
    # Drop the 3 command entries (indices privileged_dim+6 .. privileged_dim+8)
    # and the trailing height-map slice from the observation.
    obs_without_command = torch.concat((obs[:, env.privileged_dim:env.privileged_dim + 6],
                                        obs[:, env.privileged_dim + 9:-env.height_dim]), dim=1)
    trajectory_history = torch.concat((trajectory_history[:, 1:], obs_without_command.unsqueeze(1)), dim=1)

    world_model = ppo_runner._world_model.to(env.device)
    wm_latent = wm_action = None
    wm_is_first = torch.ones(env.num_envs, device=env.device)
    wm_update_interval = env.cfg.depth.update_interval
    wm_action_history = torch.zeros(size=(env.num_envs, wm_update_interval, env.num_actions),
                                    device=env.device)
    wm_obs = {
        "prop": obs[:, env.privileged_dim: env.privileged_dim + env.cfg.env.prop_dim],
        "is_first": wm_is_first,
    }

    if (env.cfg.depth.use_camera):
        wm_obs["image"] = torch.zeros(((env.num_envs,) + env.cfg.depth.resized + (1,)),
                                      device=world_model.device)

    wm_feature = torch.zeros((env.num_envs, ppo_runner.wm_feature_dim), device=env.device)

    total_reward = 0
    not_dones = torch.ones((env.num_envs,), device=env.device)
    infos = {}  # populated by env.step(); empty before the first step (bug fix)
    for i in range(1 * int(env.max_episode_length) + 3):
        # Refresh the world-model latent at the camera update rate.
        if (env.global_counter % wm_update_interval == 0):
            if (env.cfg.depth.use_camera and "depth" in infos):
                # Guard: no depth image exists until the first env.step().
                wm_obs["image"][env.depth_index] = infos["depth"].unsqueeze(-1).to(world_model.device)

            wm_embed = world_model.encoder(wm_obs)
            wm_latent, _ = world_model.dynamics.obs_step(wm_latent, wm_action, wm_embed, wm_obs["is_first"], sample=True)
            wm_feature = world_model.dynamics.get_deter_feat(wm_latent)
            wm_is_first[:] = 0

        history = trajectory_history.flatten(1).to(env.device)
        actions = policy(obs.detach(), history.detach(), wm_feature.detach())

        obs, _, rews, dones, infos, reset_env_ids, _ = env.step(actions.detach())

        # Only accumulate reward from envs that have never terminated.
        not_dones *= (~dones)
        total_reward += torch.mean(rews * not_dones)

        # update world model input
        wm_action_history = torch.concat(
            (wm_action_history[:, 1:], actions.unsqueeze(1)), dim=1)
        wm_obs = {
            "prop": obs[:, env.privileged_dim: env.privileged_dim + env.cfg.env.prop_dim],
            "is_first": wm_is_first,
        }
        if (env.cfg.depth.use_camera):
            wm_obs["image"] = torch.zeros(((env.num_envs,) + env.cfg.depth.resized + (1,)),
                                          device=world_model.device)

        reset_env_ids = reset_env_ids.cpu().numpy()
        if (len(reset_env_ids) > 0):
            # Reset envs restart with a blank action history and is_first flag.
            wm_action_history[reset_env_ids, :] = 0
            wm_is_first[reset_env_ids] = 1

        wm_action = wm_action_history.flatten(1)

        # process trajectory history
        env_ids = dones.nonzero(as_tuple=False).flatten()
        trajectory_history[env_ids] = 0
        obs_without_command = torch.concat((obs[:, env.privileged_dim:env.privileged_dim + 6],
                                            obs[:, env.privileged_dim + 9:-env.height_dim]),
                                           dim=1)
        trajectory_history = torch.concat(
            (trajectory_history[:, 1:], obs_without_command.unsqueeze(1)), dim=1)

        if RECORD_FRAMES:
            if i % 2:
                filename = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'frames', f"{img_idx}.png")
                env.gym.write_viewer_image_to_file(env.viewer, filename)
                img_idx += 1
        if MOVE_CAMERA:
            # Follow robot 8 from one meter to its side (typo fixes: lookat/cam_pos).
            lookat = env.root_states[8, :3]
            cam_pos = lookat.detach().cpu().numpy() + [0, 1, 0]
            env.set_camera(cam_pos, lookat)

        if i < stop_state_log:
            logger.log_states(
                {
                    'dof_pos_target': actions[robot_index, joint_index].item() * env.cfg.control.action_scale,
                    'dof_pos': env.dof_pos[robot_index, joint_index].item(),
                    'dof_vel': env.dof_vel[robot_index, joint_index].item(),
                    'dof_torque': env.torques[robot_index, joint_index].item(),
                    'command_x': env.commands[robot_index, 0].item(),
                    'command_y': env.commands[robot_index, 1].item(),
                    'command_yaw': env.commands[robot_index, 2].item(),
                    'base_vel_x': env.base_lin_vel[robot_index, 0].item(),
                    'base_vel_y': env.base_lin_vel[robot_index, 1].item(),
                    'base_vel_z': env.base_lin_vel[robot_index, 2].item(),
                    'base_vel_yaw': env.base_ang_vel[robot_index, 2].item(),
                    'contact_forces_z': env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy()
                }
            )
        if 0 < i < stop_rew_log:
            # .get() guards against steps where no episode info is reported.
            if infos.get("episode"):
                num_episodes = torch.sum(env.reset_buf).item()
                if num_episodes > 0:
                    logger.log_rewards(infos["episode"], num_episodes)
        elif i == stop_rew_log:
            logger.print_rewards()

    print('total reward:', total_reward)
250 |
if __name__ == '__main__':
    EXPORT_POLICY = True   # save a TorchScript copy of the actor before playback
    RECORD_FRAMES = False  # write viewer frames to disk every other step
    MOVE_CAMERA = True     # make the viewer camera follow a robot
    args = get_args()
    # Run the RL algorithm on the same device as the simulation.
    args.rl_device = args.sim_device
    play(args)
258 |
--------------------------------------------------------------------------------
/legged_gym/scripts/train.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import numpy as np
35 | import os
36 | from datetime import datetime
37 |
38 | import inspect
39 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
40 | parentdir = os.path.dirname(os.path.dirname(currentdir))
41 | os.sys.path.insert(0, parentdir)
42 |
43 | import isaacgym
44 | from legged_gym.envs import *
45 | from legged_gym.utils import get_args, task_registry
46 | import torch
47 |
def train(args):
    """Train a WMP policy on the task named in *args*.

    Fetches the registered configs, tags the run as 'WMP', stretches the
    training schedule, then builds the environment and runner and learns.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)

    # Tag the run and set a long schedule with periodic checkpoints.
    train_cfg.runner.run_name = 'WMP'
    train_cfg.runner.max_iterations = 100000
    train_cfg.runner.save_interval = 1000

    env, env_cfg = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    runner, train_cfg = task_registry.make_wmp_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    runner.learn(num_learning_iterations=train_cfg.runner.max_iterations, init_at_random_ep_len=True)
59 |
60 |
if __name__ == '__main__':
    args = get_args()
    # Run the RL algorithm on the same device as the simulation.
    args.rl_device = args.sim_device
    train(args)
65 |
--------------------------------------------------------------------------------
/legged_gym/tests/test_env.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 | import os
33 | from datetime import datetime
34 |
35 | import isaacgym
36 | from legged_gym.envs import *
37 | from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
38 |
39 | import torch
40 |
41 |
def test_env(args):
    """Smoke-test an environment: build it and step with zero actions.

    Args:
        args: parsed command-line namespace (see get_args); args.task selects
            the registered environment.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 10)

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    for i in range(int(10 * env.max_episode_length)):
        actions = torch.zeros(env.num_envs, env.num_actions, device=env.device)
        # step() returns 5 values for the base env but 7 for the WMP env
        # (see scripts/play.py); star-unpacking keeps this test valid for both.
        obs, *_ = env.step(actions)
    print("Done")
53 |
if __name__ == '__main__':
    # Parse CLI args (task selection etc.) and run the smoke test.
    args = get_args()
    test_env(args)
57 |
--------------------------------------------------------------------------------
/legged_gym/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict
32 | from .task_registry import task_registry
33 | from .logger import Logger
34 | from .math import *
35 | from .trimesh import *
36 | from .terrain import Terrain
37 |
--------------------------------------------------------------------------------
/legged_gym/utils/helpers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import os
35 | import copy
36 | import torch
37 | import numpy as np
38 | import random
39 | from isaacgym import gymapi
40 | from isaacgym import gymutil
41 |
42 | from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
43 |
def class_to_dict(obj) -> dict:
    """Recursively convert an object's public attributes into a plain dict.

    Objects without a __dict__ (ints, strings, ...) are returned unchanged;
    lists are converted element-wise. Names starting with '_' are skipped.
    """
    if not hasattr(obj, "__dict__"):
        return obj
    converted = {}
    for name in dir(obj):
        if name.startswith("_"):
            continue
        value = getattr(obj, name)
        if isinstance(value, list):
            converted[name] = [class_to_dict(item) for item in value]
        else:
            converted[name] = class_to_dict(value)
    return converted
60 |
def update_class_from_dict(obj, dict):
    """Copy values from *dict* onto matching attributes of *obj*.

    When the current attribute is itself a class, recurse with the nested
    dict; otherwise the attribute is overwritten (or created) directly.
    NOTE: the parameter name shadows the builtin `dict`; kept for
    backward compatibility with keyword callers.
    """
    for name, value in dict.items():
        current = getattr(obj, name, None)
        if isinstance(current, type):
            update_class_from_dict(current, value)
        else:
            setattr(obj, name, value)
    return
69 |
def set_seed(seed):
    """Seed every RNG in the stack: python, numpy, torch (CPU and CUDA).

    A seed of -1 requests a fresh random seed; the chosen value is printed
    so the run remains reproducible.
    """
    if seed == -1:
        seed = np.random.randint(0, 10000)
    print("Setting seed: {}".format(seed))

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
81 |
def parse_sim_params(args, cfg):
    """Build a gymapi.SimParams from CLI args and the 'sim' section of *cfg*.

    Precedence (lowest to highest): command-line physics selections,
    then cfg["sim"] overrides, then an explicit --num_threads.
    """
    # code from Isaac Gym Preview 2
    sim_params = gymapi.SimParams()

    # Command-line selections come first.
    if args.physics_engine == gymapi.SIM_FLEX:
        if args.device != "cpu":
            print("WARNING: Using Flex with GPU instead of PHYSX!")
    elif args.physics_engine == gymapi.SIM_PHYSX:
        sim_params.physx.use_gpu = args.use_gpu
        sim_params.physx.num_subscenes = args.subscenes
    sim_params.use_gpu_pipeline = args.use_gpu_pipeline

    # Any sim options found in the config override the values above.
    if "sim" in cfg:
        gymutil.parse_sim_config(cfg["sim"], sim_params)

    # An explicit --num_threads on the command line wins over everything.
    if args.physics_engine == gymapi.SIM_PHYSX and args.num_threads > 0:
        sim_params.physx.num_threads = args.num_threads

    return sim_params
105 |
def get_load_path(root, load_run=-1, checkpoint=-1):
    """Resolve the path of a saved model checkpoint.

    Args:
        root: directory containing one sub-directory per training run.
        load_run: run directory name, or -1 to pick the lexically last run.
        checkpoint: checkpoint number, or -1 to pick the latest 'model*' file.

    Returns:
        Full path to the selected checkpoint file.

    Raises:
        ValueError: if *root* cannot be listed or contains no runs.
    """
    try:
        runs = os.listdir(root)
        #TODO sort by date to handle change of month
        runs.sort()
        if 'exported' in runs: runs.remove('exported')
        last_run = os.path.join(root, runs[-1])
    except (OSError, IndexError) as err:
        # Was a bare `except:`; narrow it and chain the original cause.
        raise ValueError("No runs in this directory: " + root) from err
    if load_run == -1:
        load_run = last_run
    else:
        load_run = os.path.join(root, load_run)

    if checkpoint == -1:
        models = [file for file in os.listdir(load_run) if 'model' in file]
        # Left-pad names so e.g. model_99.pt sorts before model_100.pt.
        models.sort(key=lambda m: '{0:0>15}'.format(m))
        model = models[-1]
    else:
        model = "model_{}.pt".format(checkpoint)

    load_path = os.path.join(load_run, model)
    return load_path
129 |
def update_cfg_from_args(env_cfg, cfg_train, args):
    """Apply command-line overrides onto the env and train config objects.

    Either config may be None, in which case it is skipped. Only args that
    were actually provided (not None / not False) override the configs.
    """
    if env_cfg is not None:
        # num envs
        if args.num_envs is not None:
            env_cfg.env.num_envs = args.num_envs
    if cfg_train is not None:
        # seed
        if args.seed is not None:
            cfg_train.seed = args.seed
        # alg runner parameters
        if args.max_iterations is not None:
            cfg_train.runner.max_iterations = args.max_iterations
        if args.resume:
            cfg_train.runner.resume = args.resume
        # Remaining runner overrides all follow the same pattern.
        for name in ('experiment_name', 'run_name', 'load_run', 'checkpoint'):
            value = getattr(args, name)
            if value is not None:
                setattr(cfg_train.runner, name, value)

    return env_cfg, cfg_train
154 |
def get_args():
    """Parse command-line arguments via isaacgym's gymutil.

    Adds the WMP-specific options (--terrain, --wm_device) on top of the
    usual legged-gym training/playback flags, then normalizes the sim
    device string (e.g. 'cuda' -> 'cuda:0').
    """
    custom_parameters = [
        {"name": "--task", "type": str, "default": "anymal_c_flat", "help": "Resume training or start testing from a checkpoint. Overrides config file if provided."},
        {"name": "--resume", "action": "store_true", "default": False, "help": "Resume training from a checkpoint"},
        {"name": "--experiment_name", "type": str, "help": "Name of the experiment to run or load. Overrides config file if provided."},
        {"name": "--run_name", "type": str, "help": "Name of the run. Overrides config file if provided."},
        {"name": "--load_run", "type": str, "help": "Name of the run to load when resume=True. If -1: will load the last run. Overrides config file if provided."},
        {"name": "--checkpoint", "type": int, "help": "Saved model checkpoint number. If -1: will load the last checkpoint. Overrides config file if provided."},

        {"name": "--headless", "action": "store_true", "default": False, "help": "Force display off at all times"},
        {"name": "--horovod", "action": "store_true", "default": False, "help": "Use horovod for multi-gpu training"},
        {"name": "--rl_device", "type": str, "default": "cuda:0", "help": 'Device used by the RL algorithm, (cpu, gpu, cuda:0, cuda:1 etc..)'},
        {"name": "--num_envs", "type": int, "help": "Number of environments to create. Overrides config file if provided."},
        {"name": "--seed", "type": int, "help": "Random seed. Overrides config file if provided."},
        {"name": "--max_iterations", "type": int, "help": "Maximum number of training iterations. Overrides config file if provided."},
        {"name": "--terrain", "type": str, "default": "climb",
         "help": 'Only for play'},
        {"name": "--wm_device", "type": str, "default": "None", "help": 'World model device. Overrides config file in dreamer/config.yaml if provided'},

    ]
    # parse arguments
    args = gymutil.parse_arguments(
        description="RL Policy",
        custom_parameters=custom_parameters)

    # name alignment: expose the compute device under the names used elsewhere
    args.sim_device_id = args.compute_device_id
    args.sim_device = args.sim_device_type
    if args.sim_device=='cuda':
        args.sim_device += f":{args.sim_device_id}"
    return args
186 |
def export_policy_as_jit(actor_critic, path):
    """Serialize the policy's actor as a TorchScript module under *path*.

    Recurrent policies (those with a `memory_a` attribute) are exported via
    PolicyExporterLSTM; feed-forward ones are scripted directly to
    'policy_1.pt'.
    """
    if hasattr(actor_critic, 'memory_a'):
        # assumes LSTM: TODO add GRU
        PolicyExporterLSTM(actor_critic).export(path)
        return
    os.makedirs(path, exist_ok=True)
    target = os.path.join(path, 'policy_1.pt')
    cpu_actor = copy.deepcopy(actor_critic.actor).to('cpu')
    torch.jit.script(cpu_actor).save(target)
198 |
class PolicyExporterLSTM(torch.nn.Module):
    """Wrap a recurrent actor-critic so LSTM + actor can be TorchScript-exported.

    The hidden/cell states are registered buffers, so the exported module is
    stateful across forward() calls (batch size fixed at 1) and can be reset
    explicitly from C++ via reset_memory().
    """
    def __init__(self, actor_critic):
        super().__init__()
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        self.memory = copy.deepcopy(actor_critic.memory_a.rnn)
        self.memory.cpu()
        # Plain string names (the originals were f-strings with no placeholders).
        self.register_buffer('hidden_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))
        self.register_buffer('cell_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))

    def forward(self, x):
        """Run one observation through the LSTM (persisting state) and the actor."""
        out, (h, c) = self.memory(x.unsqueeze(0), (self.hidden_state, self.cell_state))
        # Persist the recurrent state in-place so buffers stay registered.
        self.hidden_state[:] = h
        self.cell_state[:] = c
        return self.actor(out.squeeze(0))

    @torch.jit.export
    def reset_memory(self):
        """Zero the persistent LSTM state (call at episode boundaries)."""
        self.hidden_state[:] = 0.
        self.cell_state[:] = 0.

    def export(self, path):
        """Script this module and save it as 'policy_lstm_1.pt' under *path*."""
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, 'policy_lstm_1.pt')
        self.to('cpu')
        traced_script_module = torch.jit.script(self)
        traced_script_module.save(path)
226 |
227 |
228 |
--------------------------------------------------------------------------------
/legged_gym/utils/logger.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import matplotlib.pyplot as plt
32 | import numpy as np
33 | from collections import defaultdict
34 | from multiprocessing import Process, Value
35 |
class Logger:
    """Collects per-step robot states and per-episode rewards, and renders
    matplotlib plots / console summaries of the collected data.

    Args:
        dt (float): simulation time step, used to build the time axis of plots.
    """
    def __init__(self, dt):
        self.state_log = defaultdict(list)  # key -> list of per-step samples
        self.rew_log = defaultdict(list)    # key -> list of per-episode reward sums
        self.dt = dt
        self.num_episodes = 0
        self.plot_process = None            # background process spawned by plot_states()

    def log_state(self, key, value):
        """Append one state sample under `key`."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        # NOTE: the parameter name shadows the builtin `dict`; kept unchanged
        # for backward compatibility with existing keyword callers.
        """Append one sample for every (key, value) pair of the given mapping."""
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Accumulate reward terms (keys containing 'rew'), scaled back from a
        per-episode mean to a sum over `num_episodes` episodes."""
        for key, value in dict.items():
            if 'rew' in key:
                # value is expected to expose .item() (0-dim tensor / numpy scalar)
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Clear all logged states and rewards (the episode counter is kept)."""
        self.state_log.clear()
        self.rew_log.clear()

    def plot_states(self):
        """Spawn a background process that renders the state plots."""
        self.plot_process = Process(target=self._plot)
        self.plot_process.start()

    def _plot(self):
        """Render a 3x3 grid of state plots; runs inside the plotting process."""
        # BUGFIX: original code left `time` undefined (NameError) when the
        # state log was empty; bail out early instead.
        if not self.state_log:
            print("Logger: nothing to plot, state log is empty")
            return
        nb_rows = 3
        nb_cols = 3
        fig, axs = plt.subplots(nb_rows, nb_cols)
        # all logged series share one time axis, derived from the first series
        first_series = next(iter(self.state_log.values()))
        time = np.linspace(0, len(first_series) * self.dt, len(first_series))
        log = self.state_log
        # plot joint targets and measured positions
        a = axs[1, 0]
        if log["dof_pos"]: a.plot(time, log["dof_pos"], label='measured')
        if log["dof_pos_target"]: a.plot(time, log["dof_pos_target"], label='target')
        a.set(xlabel='time [s]', ylabel='Position [rad]', title='DOF Position')
        a.legend()
        # plot joint velocity
        a = axs[1, 1]
        if log["dof_vel"]: a.plot(time, log["dof_vel"], label='measured')
        if log["dof_vel_target"]: a.plot(time, log["dof_vel_target"], label='target')
        a.set(xlabel='time [s]', ylabel='Velocity [rad/s]', title='Joint Velocity')
        a.legend()
        # plot base vel x
        a = axs[0, 0]
        if log["base_vel_x"]: a.plot(time, log["base_vel_x"], label='measured')
        if log["command_x"]: a.plot(time, log["command_x"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity x')
        a.legend()
        # plot base vel y
        a = axs[0, 1]
        if log["base_vel_y"]: a.plot(time, log["base_vel_y"], label='measured')
        if log["command_y"]: a.plot(time, log["command_y"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity y')
        a.legend()
        # plot base vel yaw
        a = axs[0, 2]
        if log["base_vel_yaw"]: a.plot(time, log["base_vel_yaw"], label='measured')
        if log["command_yaw"]: a.plot(time, log["command_yaw"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base ang vel [rad/s]', title='Base velocity yaw')
        a.legend()
        # plot base vel z
        a = axs[1, 2]
        if log["base_vel_z"]: a.plot(time, log["base_vel_z"], label='measured')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity z')
        a.legend()
        # plot contact forces (one line per contact point)
        a = axs[2, 0]
        if log["contact_forces_z"]:
            forces = np.array(log["contact_forces_z"])
            for i in range(forces.shape[1]):
                a.plot(time, forces[:, i], label=f'force {i}')
        a.set(xlabel='time [s]', ylabel='Forces z [N]', title='Vertical Contact forces')
        a.legend()
        # plot torque/vel curves
        a = axs[2, 1]
        if log["dof_vel"]!=[] and log["dof_torque"]!=[]: a.plot(log["dof_vel"], log["dof_torque"], 'x', label='measured')
        a.set(xlabel='Joint vel [rad/s]', ylabel='Joint Torque [Nm]', title='Torque/velocity curves')
        a.legend()
        # plot torques
        a = axs[2, 2]
        if log["dof_torque"]!=[]: a.plot(time, log["dof_torque"], label='measured')
        a.set(xlabel='time [s]', ylabel='Joint Torque [Nm]', title='Torque')
        a.legend()
        plt.show()

    def print_rewards(self):
        """Print the average of each logged reward term per episode."""
        # BUGFIX: original raised ZeroDivisionError when nothing was logged yet.
        if self.num_episodes == 0:
            print("No episodes logged yet.")
            return
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            mean = np.sum(np.array(values)) / self.num_episodes
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        # make sure the background plotting process does not outlive the logger
        if self.plot_process is not None:
            self.plot_process.kill()
--------------------------------------------------------------------------------
/legged_gym/utils/math.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import torch
32 | from torch import Tensor
33 | import numpy as np
34 | from isaacgym.torch_utils import quat_apply, normalize
35 | from typing import Tuple
36 |
# @ torch.jit.script
def quat_apply_yaw(quat, vec):
    """Rotate `vec` by only the yaw component of `quat`.

    The quaternion's x and y components are zeroed (dropping roll and pitch),
    the result is re-normalized, and then applied to `vec`.
    The input quaternion is not modified.
    """
    yaw_only = quat.clone().view(-1, 4)
    yaw_only[:, 0:2] = 0.0  # keep only z (yaw) and w components
    return quat_apply(normalize(yaw_only), vec)
43 |
# @ torch.jit.script
def wrap_to_pi(angles):
    """Wrap angles (rad) into the interval (-pi, pi].

    Works element-wise on tensors/arrays and modifies `angles` in place
    (the mutated input is also returned for convenience).
    """
    two_pi = 2 * np.pi
    angles %= two_pi                       # fold into [0, 2*pi)
    angles -= two_pi * (angles > np.pi)    # shift the upper half down to (-pi, pi]
    return angles
49 |
# @ torch.jit.script
def torch_rand_sqrt_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Sample uniformly-shaped noise in [lower, upper] with density pushed
    toward the interval bounds (square-root warp of a uniform sample)."""
    u = 2.0 * torch.rand(*shape, device=device) - 1.0   # uniform in [-1, 1)
    u = torch.sign(u) * torch.sqrt(torch.abs(u))        # odd sqrt warp, keeps sign
    u = (u + 1.0) / 2.0                                 # back to [0, 1]
    return (upper - lower) * u + lower
--------------------------------------------------------------------------------
/legged_gym/utils/task_registry.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import os
35 | from datetime import datetime
36 | from typing import Tuple
37 | import torch
38 | import numpy as np
39 |
40 | from rsl_rl.env import VecEnv
41 | from rsl_rl.runners import OnPolicyRunner, WMPRunner
42 |
43 | from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
44 | from .helpers import get_args, update_cfg_from_args, class_to_dict, get_load_path, set_seed, parse_sim_params
45 | from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
46 |
class TaskRegistry():
    """Global registry mapping task names to their environment class and config
    objects; also builds environments and training runners from them."""
    def __init__(self):
        self.task_classes = {}
        self.env_cfgs = {}
        self.train_cfgs = {}

    def register(self, name: str, task_class: VecEnv, env_cfg: LeggedRobotCfg, train_cfg: LeggedRobotCfgPPO):
        """Register a task under `name` with its env class and configs."""
        self.task_classes[name] = task_class
        self.env_cfgs[name] = env_cfg
        self.train_cfgs[name] = train_cfg

    def get_task_class(self, name: str) -> VecEnv:
        """Return the environment class registered under `name` (KeyError if absent)."""
        return self.task_classes[name]

    def get_cfgs(self, name) -> Tuple[LeggedRobotCfg, LeggedRobotCfgPPO]:
        """Return (env_cfg, train_cfg) for `name`, copying the training seed
        into the env config so both share one RNG seed."""
        train_cfg = self.train_cfgs[name]
        env_cfg = self.env_cfgs[name]
        # copy seed
        env_cfg.seed = train_cfg.seed
        return env_cfg, train_cfg

    # ------- internal helpers shared by make_alg_runner / make_wmp_runner -------

    def _resolve_train_cfg(self, name, args, train_cfg):
        """Load the training config from `name` (or use the provided one) and
        apply command-line overrides. Raises ValueError if both are None."""
        if train_cfg is None:
            if name is None:
                raise ValueError("Either 'name' or 'train_cfg' must be not None")
            # load config files
            _, train_cfg = self.get_cfgs(name)
        else:
            if name is not None:
                print(f"'train_cfg' provided -> Ignoring 'name={name}'")
        # override cfg from args (if specified)
        _, train_cfg = update_cfg_from_args(None, train_cfg, args)
        return train_cfg

    def _resolve_log_paths(self, log_root, train_cfg):
        """Return (log_root, log_dir). `log_dir` is None when `log_root` is None
        (logging disabled); "default" maps to <LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>."""
        if log_root == "default":
            log_root = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name)
            log_dir = os.path.join(log_root, train_cfg.runner.run_name)
        elif log_root is None:
            log_dir = None
        else:
            log_dir = os.path.join(log_root, train_cfg.runner.run_name)
        return log_root, log_dir

    def _load_checkpoint(self, runner, log_root, train_cfg):
        """Load a previously trained model into `runner` (resume path from cfg)."""
        resume_path = get_load_path(log_root, load_run=train_cfg.runner.load_run,
                                    checkpoint=train_cfg.runner.checkpoint)
        print(f"Loading model from: {resume_path}")
        runner.load(resume_path)

    def make_env(self, name, args=None, env_cfg=None) -> Tuple[VecEnv, LeggedRobotCfg]:
        """ Creates an environment either from a registered name or from the provided config file.

        Args:
            name (string): Name of a registered env.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            env_cfg (Dict, optional): Environment config file used to override the registered config. Defaults to None.

        Raises:
            ValueError: Error if no registered env corresponds to 'name'

        Returns:
            isaacgym.VecTaskPython: The created environment
            Dict: the corresponding config file
        """
        # if no args passed get command line arguments
        if args is None:
            args = get_args()
        # check if there is a registered env with that name
        if name in self.task_classes:
            task_class = self.get_task_class(name)
        else:
            raise ValueError(f"Task with name: {name} was not registered")
        if env_cfg is None:
            # load config files
            env_cfg, _ = self.get_cfgs(name)
        # override cfg from args (if specified)
        env_cfg, _ = update_cfg_from_args(env_cfg, None, args)
        set_seed(env_cfg.seed)
        # parse sim params (convert to dict first)
        sim_params = {"sim": class_to_dict(env_cfg.sim)}
        sim_params = parse_sim_params(args, sim_params)
        env = task_class(cfg=env_cfg,
                         sim_params=sim_params,
                         physics_engine=args.physics_engine,
                         sim_device=args.sim_device,
                         headless=args.headless)
        return env, env_cfg

    def make_alg_runner(self, env, name=None, args=None, train_cfg=None, log_root="default") -> Tuple[OnPolicyRunner, LeggedRobotCfgPPO]:
        """ Creates the training algorithm either from a registered name or from the provided config file.

        Args:
            env (isaacgym.VecTaskPython): The environment to train (TODO: remove from within the algorithm)
            name (string, optional): Name of a registered env. If None, the config file will be used instead. Defaults to None.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            train_cfg (Dict, optional): Training config file. If None 'name' will be used to get the config file. Defaults to None.
            log_root (str, optional): Logging directory for Tensorboard. Set to 'None' to avoid logging (at test time for example).
                                      Logs are saved in <log_root>/<run_name>. Defaults to "default"=<LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>.

        Raises:
            ValueError: Error if neither 'name' nor 'train_cfg' are provided
            Warning: If both 'name' and 'train_cfg' are provided 'name' is ignored

        Returns:
            PPO: The created algorithm
            Dict: the corresponding config file
        """
        if args is None:
            args = get_args()
        train_cfg = self._resolve_train_cfg(name, args, train_cfg)
        log_root, log_dir = self._resolve_log_paths(log_root, train_cfg)
        # NOTE(review): eval() resolves the class name among the runner classes
        # imported in this module (OnPolicyRunner, WMPRunner). Kept for
        # compatibility; the cfg value is project-controlled, not user input.
        runner_class = eval(train_cfg.runner_class_name)
        # save resume flag before creating the runner (which creates a new log_dir)
        resume = train_cfg.runner.resume
        runner = runner_class(env, class_to_dict(train_cfg), log_dir, device=args.rl_device)
        if resume:
            self._load_checkpoint(runner, log_root, train_cfg)
        return runner, train_cfg

    def make_wmp_runner(self, env, name=None, args=None, train_cfg=None, log_root="default") -> Tuple[WMPRunner, LeggedRobotCfgPPO]:
        """ Creates a WMP training runner either from a registered name or from the provided config file.

        Args:
            env (isaacgym.VecTaskPython): The environment to train (TODO: remove from within the algorithm)
            name (string, optional): Name of a registered env. If None, the config file will be used instead. Defaults to None.
            args (Args, optional): Isaac Gym command line arguments. If None get_args() will be called. Defaults to None.
            train_cfg (Dict, optional): Training config file. If None 'name' will be used to get the config file. Defaults to None.
            log_root (str, optional): Logging directory for Tensorboard. Set to 'None' to avoid logging (at test time for example).
                                      Logs are saved in <log_root>/<run_name>. Defaults to "default"=<LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>.

        Raises:
            ValueError: Error if neither 'name' nor 'train_cfg' are provided
            Warning: If both 'name' and 'train_cfg' are provided 'name' is ignored

        Returns:
            WMPRunner: The created runner
            Dict: the corresponding config file
        """
        if args is None:
            args = get_args()
        train_cfg = self._resolve_train_cfg(name, args, train_cfg)
        log_root, log_dir = self._resolve_log_paths(log_root, train_cfg)
        # save resume flag before creating the runner (which creates a new log_dir)
        resume = train_cfg.runner.resume
        runner = WMPRunner(env, class_to_dict(train_cfg), log_dir, device=args.rl_device)
        if resume:
            self._load_checkpoint(runner, log_root, train_cfg)
        return runner, train_cfg


# make global task registry
task_registry = TaskRegistry()
--------------------------------------------------------------------------------
/legged_gym/utils/trimesh.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2023 Ziwen Zhuang
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | """ This file defines a mesh as a tuple of (vertices, triangles)
24 | All operations are based on numpy ndarray
25 | - vertices: np ndarray of shape (n, 3) np.float32
26 | - triangles: np ndarray of shape (n_, 3) np.uint32
27 | """
28 | import numpy as np
29 |
def box_trimesh(
        size,             # float [3]: x, y, z edge lengths (in meter) under box frame
        center_position,  # float [3]: box center position (in meter) in world frame
        rpy=None,         # float [3]: euler angles (rad); only all-zero is implemented
    ):
    """Build an axis-aligned box triangle mesh.

    Returns:
        (vertices, triangles): vertices is (8, 3) float32, triangles is (12, 3) uint32.

    Raises:
        NotImplementedError: if a non-zero rotation is requested.
    """
    # BUGFIX(best practice): the default was a shared mutable np.zeros(3) array;
    # use a None sentinel with identical semantics.
    if rpy is None:
        rpy = np.zeros(3)
    if not (rpy == 0).all():
        raise NotImplementedError("Only axis-aligned box triangle mesh is implemented")

    # 8 corners: offset half the edge length from the center along each axis
    vertices = np.empty((8, 3), dtype=np.float32)
    vertices[:] = center_position
    vertices[[0, 4, 2, 6], 0] -= size[0] / 2
    vertices[[1, 5, 3, 7], 0] += size[0] / 2
    vertices[[0, 1, 2, 3], 1] -= size[1] / 2
    vertices[[4, 5, 6, 7], 1] += size[1] / 2
    vertices[[2, 3, 6, 7], 2] -= size[2] / 2
    vertices[[0, 1, 4, 5], 2] += size[2] / 2

    # 12 triangles, two per face (every row below is fully assigned)
    triangles = np.empty((12, 3), dtype=np.uint32)
    triangles[0] = [0, 2, 1]  # -y face
    triangles[1] = [1, 2, 3]
    triangles[2] = [0, 4, 2]  # -x face
    triangles[3] = [2, 4, 6]
    triangles[4] = [4, 5, 6]  # +y face
    triangles[5] = [5, 7, 6]
    triangles[6] = [1, 3, 5]  # +x face
    triangles[7] = [3, 7, 5]
    triangles[8] = [0, 1, 4]  # +z face
    triangles[9] = [1, 5, 4]
    triangles[10] = [2, 6, 3]  # -z face
    triangles[11] = [3, 6, 7]

    return vertices, triangles
62 |
def combine_trimeshes(*trimeshes):
    """Merge any number of (vertices, triangles) meshes into a single mesh.

    The mesh with more triangles is placed first and the other mesh's triangle
    indices are offset by the first mesh's vertex count.

    Returns:
        (vertices, triangles): concatenated arrays forming one mesh.
    """
    # robustness: a single mesh combines to itself
    if len(trimeshes) == 1:
        return trimeshes[0]
    if len(trimeshes) > 2:
        # BUGFIX: original passed `trimeshes[1:]` as ONE positional argument
        # (a tuple of meshes) instead of unpacking it, which made combining
        # 3+ meshes raise ValueError. Unpack with * to recurse correctly.
        return combine_trimeshes(
            trimeshes[0],
            combine_trimeshes(*trimeshes[1:])
        )

    # only two trimeshes to combine; keep the larger one first
    trimesh_0, trimesh_1 = trimeshes
    if trimesh_0[1].shape[0] < trimesh_1[1].shape[0]:
        trimesh_0, trimesh_1 = trimesh_1, trimesh_0

    # shift the second mesh's triangle indices past the first mesh's vertices
    trimesh_1 = (trimesh_1[0], trimesh_1[1] + trimesh_0[0].shape[0])
    vertices = np.concatenate((trimesh_0[0], trimesh_1[0]), axis=0)
    triangles = np.concatenate((trimesh_0[1], trimesh_1[1]), axis=0)

    return vertices, triangles
80 |
def move_trimesh(trimesh, move: np.ndarray):
    """Translate every vertex of `trimesh` by `move`.

    In-place operation on trimesh[0] (the vertex array); returns None.
    """
    np.add(trimesh[0], move, out=trimesh[0])
84 |
--------------------------------------------------------------------------------
/licenses/dependencies/matplotlib_license.txt:
--------------------------------------------------------------------------------
1 | 1. This LICENSE AGREEMENT is between the Matplotlib Development Team ("MDT"), and the Individual or Organization ("Licensee") accessing and otherwise using matplotlib software in source or binary form and its associated documentation.
2 |
3 | 2. Subject to the terms and conditions of this License Agreement, MDT hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use matplotlib 3.4.3 alone or in any derivative version, provided, however, that MDT's License Agreement and MDT's notice of copyright, i.e., "Copyright (c) 2012-2013 Matplotlib Development Team; All Rights Reserved" are retained in matplotlib 3.4.3 alone or in any derivative version prepared by Licensee.
4 |
5 | 3. In the event Licensee prepares a derivative work that is based on or incorporates matplotlib 3.4.3 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to matplotlib 3.4.3.
6 |
7 | 4. MDT is making matplotlib 3.4.3 available to Licensee on an "AS IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 3.4.3 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
8 |
9 | 5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 3.4.3 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING MATPLOTLIB 3.4.3, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
10 |
11 | 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions.
12 |
13 | 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between MDT and Licensee. This License Agreement does not grant permission to use MDT trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party.
14 |
15 | 8. By copying, installing or otherwise using matplotlib 3.4.3, Licensee agrees to be bound by the terms and conditions of this License Agreement.
--------------------------------------------------------------------------------
/licenses/subcomponents/dreamerv3-torch_license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 NM512
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/subcomponents/leggedgym_license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, ETH Zurich, Nikita Rudin
2 | Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice,
9 | this list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its contributors
16 | may be used to endorse or promote products derived from this software without
17 | specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | See licenses/assets for license information for assets included in this repository.
31 | See licenses/dependencies for license information of dependencies of this package.
32 |
--------------------------------------------------------------------------------
/licenses/subcomponents/parkour_license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ziwen Zhuang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | cachetools==4.2.4
3 | certifi==2021.10.8
4 | charset-normalizer==2.0.12
5 | cycler==0.11.0
6 | fastrlock==0.8
7 | fonttools==4.32.0
8 | google-auth==1.35.0
9 | google-auth-oauthlib==0.4.6
10 | grpcio==1.44.0
11 | idna==3.3
12 | imageio==2.17.0
13 | importlib-metadata==4.11.3
14 | joblib==1.1.0
15 | kiwisolver==1.4.2
16 | Markdown==3.3.6
17 | matplotlib==3.5.1
18 | ninja==1.11.1
19 | numpy==1.22.3
20 | oauthlib==3.2.0
21 | opencv-python==4.6.0.66
22 | packaging==21.3
23 | Pillow==9.1.0
24 | protobuf==3.20.0
25 | pyasn1==0.4.8
26 | pyasn1-modules==0.2.8
27 | pybullet==3.2.5
28 | pyparsing==3.0.8
29 | python-dateutil==2.8.2
30 | PyYAML==6.0
31 | requests==2.27.1
32 | requests-oauthlib==1.3.1
33 | rsa==4.8
34 | scikit-learn==1.1.2
35 | scipy==1.8.0
36 | six==1.16.0
37 | sklearn==0.0
38 | tensorboard==2.8.0
39 | tensorboard-data-server==0.6.1
40 | tensorboard-plugin-wit==1.8.1
41 | threadpoolctl==3.1.0
42 | trimesh==3.20.1
43 | typing_extensions==4.2.0
44 | urllib3==1.26.9
45 | Werkzeug==2.1.1
46 | zipp==3.8.0
47 |
--------------------------------------------------------------------------------
/resources/robots/a1/meshes/trunk_A1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/WMP/c232c115ada4517453ebded5019078ba055456de/resources/robots/a1/meshes/trunk_A1.png
--------------------------------------------------------------------------------
/rsl_rl/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
--------------------------------------------------------------------------------
/rsl_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .ppo import PPO
32 | from .amp_ppo import AMPPPO
--------------------------------------------------------------------------------
/rsl_rl/algorithms/amp_discriminator.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 |
32 | import torch
33 | import torch.nn as nn
34 | import torch.utils.data
35 | from torch import autograd
36 |
37 | from rsl_rl.utils import utils
38 |
39 |
class AMPDiscriminator(nn.Module):
    """Adversarial Motion Prior (AMP) discriminator.

    Scores (state, next_state) transition pairs with an MLP trunk followed
    by a single linear output.  Higher scores mean a transition looks more
    like the expert reference motion; predict_amp_reward() maps the score
    to a bounded style reward for the policy.
    """

    def __init__(
            self, input_dim, amp_reward_coef, hidden_layer_sizes, device, task_reward_lerp=0.0):
        super().__init__()

        self.device = device
        self.input_dim = input_dim
        self.amp_reward_coef = amp_reward_coef
        self.task_reward_lerp = task_reward_lerp

        # Trunk: one (Linear, ReLU) pair per requested hidden size.
        layers = []
        prev_dim = input_dim
        for size in hidden_layer_sizes:
            layers += [nn.Linear(prev_dim, size), nn.ReLU()]
            prev_dim = size
        self.trunk = nn.Sequential(*layers).to(device)
        # Scalar discriminator head on top of the trunk features.
        self.amp_linear = nn.Linear(hidden_layer_sizes[-1], 1).to(device)

        self.trunk.train()
        self.amp_linear.train()

    def forward(self, x):
        """Return the raw discriminator score for a batch of transitions."""
        return self.amp_linear(self.trunk(x))

    def compute_grad_pen(self,
                         expert_state,
                         expert_next_state,
                         lambda_=10):
        """Gradient penalty regularizer on expert transitions.

        Penalizes the squared L2 norm of d(score)/d(input) so the
        discriminator stays smooth around the expert data manifold.

        Args:
            expert_state: Batch of expert states.
            expert_next_state: Batch of expert successor states.
            lambda_: Penalty weight.

        Returns:
            Scalar gradient-penalty loss.
        """
        expert_data = torch.cat([expert_state, expert_next_state], dim=-1)
        expert_data.requires_grad = True

        scores = self.amp_linear(self.trunk(expert_data))
        grad = autograd.grad(
            outputs=scores,
            inputs=expert_data,
            grad_outputs=torch.ones_like(scores),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # Drive the input-gradient norm toward zero.
        return lambda_ * grad.norm(2, dim=1).pow(2).mean()

    def predict_amp_reward(
            self, state, next_state, task_reward, normalizer=None):
        """Compute the style reward for a batch of transitions.

        The raw score d is squashed with max(0, 1 - 0.25 * (d - 1)^2),
        scaled by amp_reward_coef, and (when task_reward_lerp > 0)
        combined with the task reward.  Runs in eval mode under no_grad,
        then restores train mode.

        Args:
            state: Batch of states.
            next_state: Batch of successor states.
            task_reward: Batch of task rewards to blend in.
            normalizer: Optional normalizer applied to both state batches
                before scoring.

        Returns:
            Tuple of (reward with trailing dim squeezed, raw scores d).
        """
        with torch.no_grad():
            self.eval()
            if normalizer is not None:
                state = normalizer.normalize_torch(state, self.device)
                next_state = normalizer.normalize_torch(next_state, self.device)

            transition = torch.cat([state, next_state], dim=-1)
            d = self.amp_linear(self.trunk(transition))
            style_reward = self.amp_reward_coef * torch.clamp(
                1 - (1 / 4) * torch.square(d - 1), min=0)
            if self.task_reward_lerp > 0:
                style_reward = self._lerp_reward(style_reward, task_reward.unsqueeze(-1))
            self.train()
        return style_reward.squeeze(), d

    def _lerp_reward(self, disc_r, task_r):
        """Combine style and task rewards.

        NOTE(review): a true lerp is commented out upstream; the shipped
        behavior is a plain sum, preserved here.
        """
        # r = (1.0 - self.task_reward_lerp) * disc_r + self.task_reward_lerp * task_r
        return disc_r + task_r
--------------------------------------------------------------------------------
/rsl_rl/algorithms/ppo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.optim as optim
34 |
35 | from rsl_rl.modules import ActorCritic
36 | from rsl_rl.storage import RolloutStorage
37 |
class PPO:
    """Proximal Policy Optimization (clipped-surrogate variant).

    Transitions are cached one step at a time via act() and
    process_env_step() into a RolloutStorage, then update() runs minibatch
    PPO epochs over the collected rollout.  With schedule='adaptive' the
    learning rate is adjusted to track desired_kl.
    """
    # Declared for type checkers; assigned in __init__.
    actor_critic: ActorCritic
    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 ):
        """Store hyperparameters, move the policy to `device`, and build
        the Adam optimizer.  `self.storage` stays None until init_storage().
        """

        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None # initialized later
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        # Reusable scratch transition, filled in act()/process_env_step().
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, history_dim, wm_feature_dim):
        """Allocate the rollout buffer; must be called before act()/update()."""
        self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, history_dim=history_dim,
                                      wm_feature_dim = wm_feature_dim, device = self.device)

    def test_mode(self):
        """Switch the policy to evaluation mode."""
        self.actor_critic.test()

    def train_mode(self):
        """Switch the policy to training mode."""
        self.actor_critic.train()

    def act(self, obs, critic_obs):
        """Sample actions for one step and cache everything the PPO update
        needs (values, log-probs, action distribution stats) in
        self.transition.  Returns the sampled actions.
        """
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()
        # need to record obs and critic_obs before env.step()
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        return self.transition.actions

    def process_env_step(self, rewards, dones, infos):
        """Finalize the cached transition with the env step's outcome and
        push it into storage.  Episodes ended by a time limit (rather than
        failure) are bootstrapped with the critic's value estimate.
        """
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        # Bootstrapping on time outs
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

    def compute_returns(self, last_critic_obs):
        """Compute GAE(lambda) returns/advantages for the stored rollout,
        bootstrapping from the value of the final observation.
        """
        last_values= self.actor_critic.evaluate(last_critic_obs).detach()
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update(self):
        """Run num_learning_epochs x num_mini_batches PPO gradient steps
        over the stored rollout, then clear the storage.

        Returns:
            (mean_value_loss, mean_surrogate_loss) averaged over all updates.
        """
        mean_value_loss = 0
        mean_surrogate_loss = 0
        if self.actor_critic.is_recurrent:
            # (sic) spelling must match the generator defined on the storage.
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
            old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:


                # Re-evaluate the current policy on the stored observations.
                self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
                actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
                value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
                mu_batch = self.actor_critic.action_mean
                sigma_batch = self.actor_critic.action_std
                entropy_batch = self.actor_critic.entropy

                # KL
                if self.desired_kl != None and self.schedule == 'adaptive':
                    with torch.inference_mode():
                        # Per-dimension Gaussian KL(old || new), summed over
                        # action dims; the 1e-5 guards the log argument.
                        kl = torch.sum(
                            torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                        kl_mean = torch.mean(kl)

                        # Shrink the LR when the policy moves too fast,
                        # grow it when it moves too slowly; clamp to [1e-5, 1e-2].
                        if kl_mean > self.desired_kl * 2.0:
                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = self.learning_rate


                # Surrogate loss
                ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
                surrogate = -torch.squeeze(advantages_batch) * ratio
                surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                                   1.0 + self.clip_param)
                # Pessimistic (max of negated objectives) = clipped PPO objective.
                surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

                # Value function loss
                if self.use_clipped_value_loss:
                    # Clip the value update around the rollout-time estimate.
                    value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
                                                                                                    self.clip_param)
                    value_losses = (value_batch - returns_batch).pow(2)
                    value_losses_clipped = (value_clipped - returns_batch).pow(2)
                    value_loss = torch.max(value_losses, value_losses_clipped).mean()
                else:
                    value_loss = (returns_batch - value_batch).pow(2).mean()

                loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

                # Gradient step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
                self.optimizer.step()

                mean_value_loss += value_loss.item()
                mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss
189 |
--------------------------------------------------------------------------------
/rsl_rl/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/WMP/c232c115ada4517453ebded5019078ba055456de/rsl_rl/datasets/__init__.py
--------------------------------------------------------------------------------
/rsl_rl/datasets/motion_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for processing motion clips."""
17 |
18 | import os
19 | import inspect
20 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
21 | parentdir = os.path.dirname(os.path.dirname(currentdir))
22 | os.sys.path.insert(0, parentdir)
23 |
24 | import numpy as np
25 |
26 | from rsl_rl.datasets import pose3d
27 | from pybullet_utils import transformations
28 |
29 |
def standardize_quaternion(q):
    """Returns a quaternion where q.w >= 0 to remove redundancy due to q = -q.

    q and -q encode the same rotation; flipping the sign whenever the
    scalar part is negative picks one canonical representative.

    Args:
      q: A quaternion [x, y, z, w] to be standardized.

    Returns:
      A quaternion with q.w >= 0.
    """
    return -q if q[-1] < 0 else q
43 |
44 |
def normalize_rotation_angle(theta):
    """Returns a rotation angle normalized between [-pi, pi].

    Reduces theta modulo 2*pi with np.fmod (whose result keeps theta's sign
    and satisfies |result| < 2*pi), then shifts by one period only when the
    reduced value still lies outside [-pi, pi].  The previous version
    shifted unconditionally whenever |theta| > pi, which mapped e.g.
    5*pi/2 (already reduced to pi/2) to -3*pi/2 and 2*pi to -2*pi, both
    outside the documented range.

    Args:
      theta: angle of rotation (radians).

    Returns:
      An angle of rotation normalized between [-pi, pi].
    """
    norm_theta = np.fmod(theta, 2 * np.pi)
    if norm_theta > np.pi:
        norm_theta -= 2 * np.pi
    elif norm_theta < -np.pi:
        norm_theta += 2 * np.pi
    return norm_theta
64 |
65 |
def calc_heading(q):
    """Returns the heading of a rotation q, specified as a quaternion.

    The heading is the rotational component of q about the vertical (z)
    axis, recovered by rotating the world x axis by q and taking the planar
    angle of the rotated vector.

    Args:
      q: A quaternion that the heading is to be computed from.

    Returns:
      An angle representing the rotation about the z axis.
    """
    rotated_x = pose3d.QuaternionRotatePoint(np.array([1, 0, 0]), q)
    return np.arctan2(rotated_x[1], rotated_x[0])
83 |
84 |
def calc_heading_rot(q):
    """Return a quaternion for the heading rotation of q about the z axis.

    Args:
      q: A quaternion that the heading is to be computed from.

    Returns:
      A quaternion representing the rotation about the z axis.
    """
    yaw = calc_heading(q)
    return transformations.quaternion_about_axis(yaw, [0, 0, 1])
98 |
--------------------------------------------------------------------------------
/rsl_rl/datasets/pose3d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Utilities for 3D pose conversion."""
16 | import math
17 | import numpy as np
18 |
19 | from pybullet_utils import transformations
20 |
21 | VECTOR3_0 = np.zeros(3, dtype=np.float64)
22 | VECTOR3_1 = np.ones(3, dtype=np.float64)
23 | VECTOR3_X = np.array([1, 0, 0], dtype=np.float64)
24 | VECTOR3_Y = np.array([0, 1, 0], dtype=np.float64)
25 | VECTOR3_Z = np.array([0, 0, 1], dtype=np.float64)
26 |
27 | # QUATERNION_IDENTITY is the multiplicative identity 1.0 + 0i + 0j + 0k.
28 | # When interpreted as a rotation, it is the identity rotation.
29 | QUATERNION_IDENTITY = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float64)
30 |
31 |
def Vector3RandomNormal(sigma, mu=VECTOR3_0):
    """Returns a random 3D vector from a normal distribution.

    Each component is drawn independently from N(0, sigma) and shifted
    by the corresponding component of mu.

    Args:
      sigma: Scale (or stddev) of distribution for all variables.
      mu: Mean of distribution for each variable.

    Returns:
      A 3D vector in a numpy array.
    """
    return mu + np.random.normal(scale=sigma, size=3)
47 |
48 |
def Vector3RandomUniform(low=VECTOR3_0, high=VECTOR3_1):
    """Returns a 3D vector selected uniformly from the input box.

    Args:
      low: The min-value corner of the box.
      high: The max-value corner of the box.

    Returns:
      A 3D vector in a numpy array.
    """
    # One independent uniform draw per axis, in x, y, z order (same RNG
    # consumption order as drawing each component separately).
    return np.array(
        [np.random.uniform(low=low[axis], high=high[axis]) for axis in range(3)])
64 |
65 |
def Vector3RandomUnit():
    """Returns a random 3D vector with unit length.

    Samples uniformly on the unit sphere: longitude uniform in [-pi, pi)
    and z = sin(latitude) uniform in [-1, 1] give equal-area coverage.

    Returns:
      A normalized 3D vector in a numpy array.
    """
    lon = np.random.uniform(low=-math.pi, high=math.pi)
    z = np.random.uniform(low=-1.0, high=1.0)
    planar_radius = math.sqrt(1.0 - z * z)
    return np.array(
        [math.cos(lon) * planar_radius, math.sin(lon) * planar_radius, z],
        dtype=np.float64)
81 |
82 |
def QuaternionNormalize(q):
    """Normalizes the quaternion to length 1.

    Divides the quaternion by its magnitude; refuses to normalize a
    quaternion whose magnitude is near zero.

    Args:
      q: A quaternion to be normalized.

    Raises:
      ValueError: If input quaternion has length near zero.

    Returns:
      A quaternion with magnitude 1 in a numpy array [x, y, z, w].
    """
    magnitude = np.linalg.norm(q)
    if np.isclose(magnitude, 0.0):
        raise ValueError(
            'Quaternion may not be zero in QuaternionNormalize: |q| = %f, q = %s' %
            (magnitude, q))
    return q / magnitude
105 |
106 |
def QuaternionFromAxisAngle(axis, angle):
    """Returns a quaternion that generates the given axis-angle rotation.

    Builds sin(angle/2) * axis/|axis| + cos(angle/2), laid out as
    [x, y, z, w].

    Args:
      axis: Axis of rotation, a 3D vector in a numpy array.
      angle: The angle of rotation (radians).

    Raises:
      ValueError: If input axis is not a normalizable 3D vector.

    Returns:
      A unit quaternion in a numpy array.
    """
    if len(axis) != 3:
        raise ValueError('Axis vector should have three components: %s' % axis)
    axis_norm = np.linalg.norm(axis)
    if np.isclose(axis_norm, 0.0):
        raise ValueError('Axis vector may not have zero length: |v| = %f, v = %s' %
                         (axis_norm, axis))
    half_angle = angle * 0.5
    # Single scalar factor both normalizes the axis and applies the sine.
    vec_part = np.asarray(axis, dtype=np.float64) * (math.sin(half_angle) / axis_norm)
    return np.append(vec_part, math.cos(half_angle))
135 |
136 |
def QuaternionToAxisAngle(quat, default_axis=VECTOR3_Z, direction_axis=None):
    """Calculates axis and angle of rotation performed by a quaternion.

    The quaternion must have four components and be normalized.

    Args:
      quat: Unit quaternion in a numpy array.
      default_axis: 3D vector axis used if the rotation is near to zero.
        Without this default, small rotations would raise; a default axis is
        reasonable because zero-angle rotations about any axis are equivalent.
      direction_axis: Used to disambiguate rotation directions.  When given,
        the returned axis is chosen so that its inner product with
        direction_axis is non-negative.

    Raises:
      ValueError: If quat is not a normalized quaternion.

    Returns:
      axis: Axis of rotation.
      angle: Angle in radians.
    """
    if len(quat) != 4:
        raise ValueError(
            'Quaternion should have four components [x, y, z, w]: %s' % quat)
    if not np.isclose(1.0, np.linalg.norm(quat)):
        raise ValueError('Quaternion should have unit length: |q| = %f, q = %s' %
                         (np.linalg.norm(quat), quat))
    axis = quat[:3].copy()
    # For a unit quaternion, |vector part| == sin(angle / 2).
    sin_half_angle = np.linalg.norm(axis)
    if sin_half_angle < 1e-8:
        # Near-zero rotation: the vector part carries no usable direction,
        # so fall back to the caller-provided default axis (validated here).
        axis = default_axis
        if len(default_axis) != 3:
            raise ValueError('Axis vector should have three components: %s' % axis)
        if not np.isclose(np.linalg.norm(axis), 1.0):
            raise ValueError('Axis vector should have unit length: |v| = %f, v = %s' %
                             (np.linalg.norm(axis), axis))
    else:
        axis /= sin_half_angle
    if direction_axis is not None and np.inner(axis, direction_axis) < 0:
        # Flip to the equivalent (-axis, -angle) representation so the axis
        # points along direction_axis.
        sin_half_angle = -sin_half_angle
        axis = -axis
    angle = 2 * math.atan2(sin_half_angle, quat[3])
    return axis, angle
186 |
187 |
def QuaternionRandomRotation(max_angle=math.pi):
    """Creates a random small rotation around a random axis.

    The angle is selected uniformly from [0, max_angle] and the axis
    uniformly from the unit sphere.  With the default max_angle the
    rotation is selected uniformly over all possible rotation angles.

    Args:
      max_angle: The maximum angle of rotation (radians).

    Returns:
      A unit quaternion in a numpy array.
    """
    # Draw the angle first, then the axis, matching the original RNG
    # stream order.
    random_angle = np.random.uniform(low=0, high=max_angle)
    return QuaternionFromAxisAngle(Vector3RandomUnit(), random_angle)
209 |
210 |
def QuaternionRotatePoint(point, quat):
    """Performs a rotation by quaternion.

    Computes q * p * q^-1 with quaternion products, where p is the point
    embedded as a pure quaternion (zero scalar part), without constructing
    a rotation matrix.

    Args:
      point: The point to be rotated.
      quat: The rotation represented as a quaternion [x, y, z, w].

    Returns:
      A 3D vector in a numpy array.
    """
    pure = np.array([point[0], point[1], point[2], 0.0])
    rotated = transformations.quaternion_multiply(
        transformations.quaternion_multiply(quat, pure),
        transformations.quaternion_inverse(quat))
    return rotated[:3]
230 |
231 |
def IsRotationMatrix(m):
    """Returns true if the 3x3 submatrix represents a rotation.

    A rotation matrix is orthogonal, so R @ R.T must equal the identity
    (checked with an absolute tolerance of 1e-4).

    Args:
      m: A transformation matrix, at least 3x3.

    Raises:
      ValueError: If input is not a matrix of size at least 3x3.

    Returns:
      True if the 3x3 submatrix is a rotation (orthogonal).
    """
    if len(m.shape) != 2 or m.shape[0] < 3 or m.shape[1] < 3:
        raise ValueError('Matrix should be 3x3 or 4x4: %s\n %s' % (m.shape, m))
    rot = m[:3, :3]
    product = rot @ rot.T
    return np.isclose(product, np.identity(3), atol=1e-4).all()
249 |
250 | # def ZAxisAlignedRobotPoseTool(robot_pose_tool):
251 | # """Returns the current gripper pose rotated for alignment with the z-axis.
252 |
253 | # Args:
254 | # robot_pose_tool: a pose3d.Pose3d() instance.
255 |
256 | # Returns:
257 | # An instance of pose.Transform representing the current gripper pose
258 | # rotated for alignment with the z-axis.
259 | # """
260 | # # Align the current pose to the z-axis.
261 | # robot_pose_tool.quaternion = transformations.quaternion_multiply(
262 | # RotationBetween(
263 | # robot_pose_tool.matrix4x4[0:3, 0:3].dot(np.array([0, 0, 1])),
264 | # np.array([0.0, 0.0, -1.0])), robot_pose_tool.quaternion)
265 | # return robot_pose_tool
266 |
267 | # def RotationBetween(a_translation_b, a_translation_c):
268 | # """Computes the rotation from one vector to another.
269 |
270 | # The computed rotation has the property that:
271 |
272 | # a_translation_c = a_rotation_b_to_c * a_translation_b
273 |
274 | # Args:
275 | # a_translation_b: vec3, vector to rotate from
276 | # a_translation_c: vec3, vector to rotate to
277 |
278 | # Returns:
279 | # a_rotation_b_to_c: new Orientation
280 | # """
281 | # rotation = rotation3.Rotation3.rotation_between(
282 | # a_translation_b, a_translation_c, err_msg='RotationBetween')
283 | # return rotation.quaternion.xyzw
284 |
--------------------------------------------------------------------------------
/rsl_rl/env/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .vec_env import VecEnv
--------------------------------------------------------------------------------
/rsl_rl/env/vec_env.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from abc import ABC, abstractmethod
32 | import torch
33 | from typing import Tuple, Union
34 |
35 | # minimal interface of the environment
class VecEnv(ABC):
    """Minimal abstract interface of a vectorized (batched) environment.

    All buffers are torch tensors batched over ``num_envs`` parallel
    environment instances living on ``device``; concrete implementations
    fill these buffers each step.
    """
    num_envs: int  # number of parallel environment instances
    num_obs: int  # size of one observation vector
    num_privileged_obs: int  # size of one privileged observation vector
    num_actions: int  # size of one action vector
    max_episode_length: int  # episode duration limit, in steps
    privileged_obs_buf: torch.Tensor  # privileged observations (may be unused by some envs)
    obs_buf: torch.Tensor  # observations returned to the policy
    rew_buf: torch.Tensor  # per-env rewards of the last step
    reset_buf: torch.Tensor  # per-env done/reset flags
    episode_length_buf: torch.Tensor # current episode duration
    extras: dict  # auxiliary information (e.g. episode statistics for logging)
    device: torch.device  # device hosting all buffers
    @abstractmethod
    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        """Apply ``actions`` to all envs; return (obs, privileged_obs or None, rewards, dones, extras)."""
        pass
    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        """Reset the environments selected by ``env_ids``."""
        pass
    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        """Return the current observation buffer."""
        pass
    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        """Return the privileged observation buffer, or None if not used."""
        pass
--------------------------------------------------------------------------------
/rsl_rl/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .actor_critic import ActorCritic
32 | from .actor_critic_wmp import ActorCriticWMP
33 | from .actor_critic_recurrent import ActorCriticRecurrent
34 | from .depth_predictor import DepthPredictor
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 |
33 | import torch
34 | import torch.nn as nn
35 | from torch.distributions import Normal
36 | from torch.nn.modules import rnn
37 |
38 |
class ActorCritic(nn.Module):
    """Feed-forward MLP actor-critic for on-policy RL.

    The actor maps observations to the mean of a diagonal Gaussian action
    distribution; the critic maps (possibly privileged) observations to a
    scalar state-value estimate.  The per-action std is either a fixed
    tensor or a learnable ``nn.Parameter`` depending on ``fixed_std``.
    """
    is_recurrent = False

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 fixed_std=False,
                 **kwargs):
        """Args:
            num_actor_obs: dimension of the actor observation vector.
            num_critic_obs: dimension of the critic observation vector.
            num_actions: dimension of the action vector.
            actor_hidden_dims: hidden layer sizes of the actor MLP.
            critic_hidden_dims: hidden layer sizes of the critic MLP.
            activation: activation name understood by ``get_activation``.
            init_noise_std: initial std of the action distribution.
            fixed_std: if True the std is a constant tensor, otherwise a
                learnable ``nn.Parameter``.
        """
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str(
                [key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        # Policy network: observation -> action mean.
        self.actor = _build_mlp(num_actor_obs, actor_hidden_dims, num_actions, activation)
        # Value network: critic observation -> scalar value.
        self.critic = _build_mlp(num_critic_obs, critic_hidden_dims, 1, activation)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise.
        self.fixed_std = fixed_std
        std = init_noise_std * torch.ones(num_actions)
        # BUGFIX: the original used torch.tensor(std) on an existing tensor,
        # which emits a UserWarning and makes an unnecessary copy;
        # clone().detach() is the supported equivalent.
        self.std = std.clone().detach() if fixed_std else nn.Parameter(std)
        self.distribution = None
        # Disable distribution args validation for speedup.
        # BUGFIX: this must be *called*; the original assigned False over the
        # staticmethod, which left validation at its default.
        Normal.set_default_validate_args(False)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        """Orthogonally initialize the Linear layers of ``sequential`` (unused)."""
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        """No-op: this policy is stateless (kept for interface parity)."""
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recent action distribution (set by ``act``)."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Std of the most recent action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Per-sample entropy of the most recent action distribution."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """Recompute the Gaussian action distribution for ``observations``."""
        mean = self.actor(observations)
        std = self.std.to(mean.device)
        # `mean * 0. + std` broadcasts the std to the batch shape of the mean.
        self.distribution = Normal(mean, mean * 0. + std)

    def act(self, observations, **kwargs):
        """Sample stochastic actions for ``observations``."""
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        """Joint log-probability of ``actions`` under the latest distribution."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        """Deterministic action (the distribution mean) for deployment."""
        return self.actor(observations)

    def evaluate(self, critic_observations, **kwargs):
        """State-value estimates for ``critic_observations``."""
        return self.critic(critic_observations)


def _build_mlp(input_dim, hidden_dims, output_dim, activation):
    """Build an MLP ``input_dim -> hidden_dims... -> output_dim``.

    ``activation`` is a module instance shared between layers (all supported
    activations are stateless, matching the original construction).  No
    activation is applied after the output layer.
    """
    layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
    for l in range(len(hidden_dims)):
        if l == len(hidden_dims) - 1:
            layers.append(nn.Linear(hidden_dims[l], output_dim))
        else:
            layers.append(nn.Linear(hidden_dims[l], hidden_dims[l + 1]))
            layers.append(activation)
    return nn.Sequential(*layers)


def get_activation(act_name):
    """Return a fresh activation module for ``act_name``.

    Raises:
        ValueError: if the name is unknown.  (The original printed a message
        and returned None, deferring the failure to an opaque TypeError
        inside ``nn.Sequential``.)
    """
    activations = {
        "elu": nn.ELU,
        "selu": nn.SELU,
        "relu": nn.ReLU,
        # NOTE: torch.nn provides no CReLU; ReLU is used as a stand-in,
        # exactly as in the original implementation.
        "crelu": nn.ReLU,
        "lrelu": nn.LeakyReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }
    if act_name not in activations:
        raise ValueError(f"invalid activation function: {act_name!r}")
    return activations[act_name]()
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic_recurrent.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 |
33 | import torch
34 | import torch.nn as nn
35 | from torch.distributions import Normal
36 | from torch.nn.modules import rnn
37 | from .actor_critic import ActorCritic, get_activation
38 | from rsl_rl.utils import unpad_trajectories
39 |
class ActorCriticRecurrent(ActorCritic):
    """Actor-critic whose observations first pass through RNN memory modules.

    One ``Memory`` feeds the actor MLP and one feeds the critic MLP; the MLP
    heads built by the ``ActorCritic`` base therefore take the RNN hidden
    size as their input dimension.
    """
    is_recurrent = True

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 rnn_type='lstm',
                 rnn_hidden_size=256,
                 rnn_num_layers=1,
                 init_noise_std=1.0,
                 **kwargs):
        if kwargs:
            warning = "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys())
            print(warning)

        # The MLP heads consume RNN features, hence rnn_hidden_size here.
        super().__init__(num_actor_obs=rnn_hidden_size,
                         num_critic_obs=rnn_hidden_size,
                         num_actions=num_actions,
                         actor_hidden_dims=actor_hidden_dims,
                         critic_hidden_dims=critic_hidden_dims,
                         activation=activation,
                         init_noise_std=init_noise_std)

        # Result unused; call kept for exact parity with the original code.
        get_activation(activation)

        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        """Zero both memories' hidden states for the done environments."""
        for memory in (self.memory_a, self.memory_c):
            memory.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        """Sample actions from RNN features of ``observations``."""
        features = self.memory_a(observations, masks, hidden_states)
        return super().act(features.squeeze(0))

    def act_inference(self, observations):
        """Deterministic action from RNN features of ``observations``."""
        features = self.memory_a(observations)
        return super().act_inference(features.squeeze(0))

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        """State value from RNN features of ``critic_observations``."""
        features = self.memory_c(critic_observations, masks, hidden_states)
        return super().evaluate(features.squeeze(0))

    def get_hidden_states(self):
        """Return the (actor, critic) RNN hidden states for rollout storage."""
        return self.memory_a.hidden_states, self.memory_c.hidden_states
90 |
91 |
92 | class Memory(torch.nn.Module):
93 | def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
94 | super().__init__()
95 | # RNN
96 | rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
97 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
98 | self.hidden_states = None
99 |
100 | def forward(self, input, masks=None, hidden_states=None):
101 | batch_mode = masks is not None
102 | if batch_mode:
103 | # batch mode (policy update): need saved hidden states
104 | if hidden_states is None:
105 | raise ValueError("Hidden states not passed to memory module during policy update")
106 | out, _ = self.rnn(input, hidden_states)
107 | out = unpad_trajectories(out, masks)
108 | else:
109 | # inference mode (collection): use hidden states of last step
110 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
111 | return out
112 |
113 | def reset(self, dones=None):
114 | # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
115 | if (self.hidden_states is not None):
116 | for i in range(len(self.hidden_states)):
117 | self.hidden_states[i][..., dones, :] = 0.0
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic_wmp.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import numpy as np
35 |
36 | import torch
37 | import torch.nn as nn
38 | from torch.distributions import Normal
39 | from torch.nn.modules import rnn
40 |
41 |
class ActorCriticWMP(nn.Module):
    """Actor-critic used with a world model (WMP).

    The actor does not consume raw observations directly; it consumes
    (a) a latent vector encoded from the proprioceptive history,
    (b) the 3-d command sliced out of the observation, and
    (c) a latent encoding of the world-model feature.
    The critic consumes the full critic observation concatenated with its
    own (separately encoded) world-model latent.
    """
    is_recurrent = False

    def __init__(self, num_actor_obs,
                 num_critic_obs,
                 num_actions,
                 encoder_hidden_dims=[256, 128],
                 wm_encoder_hidden_dims = [64, 32],
                 actor_hidden_dims=[256, 256, 256],
                 critic_hidden_dims=[256, 256, 256],
                 activation='elu',
                 init_noise_std=1.0,
                 fixed_std=False,
                 latent_dim = 32,
                 height_dim=187,
                 privileged_dim=3 + 24,
                 history_dim = 42*5,
                 wm_feature_dim = 1536,
                 wm_latent_dim=16,
                 **kwargs):
        """Args:
            num_actor_obs: dimension of the actor observation vector.
            num_critic_obs: dimension of the critic observation vector.
            num_actions: dimension of the action vector.
            encoder_hidden_dims: hidden sizes of the history encoder MLP.
            wm_encoder_hidden_dims: hidden sizes of the world-model encoders.
            actor_hidden_dims: hidden sizes of the actor MLP.
            critic_hidden_dims: hidden sizes of the critic MLP.
            activation: activation name understood by ``get_activation``.
            init_noise_std: initial std of the action distribution.
            fixed_std: if True the std is constant, else a learnable parameter.
            latent_dim: size of the history latent vector.
            height_dim: size of the heightmap portion of the observation.
            privileged_dim: size of the privileged portion of the observation.
            history_dim: size of the flattened proprioceptive history.
            wm_feature_dim: size of the world-model feature vector.
            wm_latent_dim: size of the encoded world-model latent.
        """
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str(
                [key for key in kwargs.keys()]))
        super(ActorCriticWMP, self).__init__()

        activation = get_activation(activation)

        self.latent_dim = latent_dim
        self.height_dim = height_dim
        self.privileged_dim = privileged_dim

        # Actor input: history latent + 3-d command + world-model latent.
        mlp_input_dim_a = latent_dim + 3 + wm_latent_dim
        # Critic input: full critic observation + world-model latent.
        mlp_input_dim_c = num_critic_obs + wm_latent_dim

        # History encoder: proprioceptive history -> latent vector.
        self.history_encoder = _build_mlp(history_dim, encoder_hidden_dims, latent_dim, activation)
        # World-model feature encoders (separate nets for actor and critic).
        self.wm_feature_encoder = _build_mlp(wm_feature_dim, wm_encoder_hidden_dims, wm_latent_dim, activation)
        self.critic_wm_feature_encoder = _build_mlp(wm_feature_dim, wm_encoder_hidden_dims, wm_latent_dim, activation)
        # Policy and value networks.
        self.actor = _build_mlp(mlp_input_dim_a, actor_hidden_dims, num_actions, activation)
        self.critic = _build_mlp(mlp_input_dim_c, critic_hidden_dims, 1, activation)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise.
        self.fixed_std = fixed_std
        std = init_noise_std * torch.ones(num_actions)
        # BUGFIX: the original used torch.tensor(std) on an existing tensor,
        # which emits a UserWarning and makes an unnecessary copy.
        self.std = std.clone().detach() if fixed_std else nn.Parameter(std)
        self.distribution = None
        # Disable distribution args validation for speedup.
        # BUGFIX: must be *called*; assigning False over the staticmethod
        # (as the original did) left validation at its default.
        Normal.set_default_validate_args(False)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        """Orthogonally initialize the Linear layers of ``sequential`` (unused)."""
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        """No-op: this policy is stateless (kept for interface parity)."""
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recent action distribution (set by ``act``)."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Std of the most recent action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Per-sample entropy of the most recent action distribution."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """Recompute the Gaussian action distribution for ``observations``."""
        mean = self.actor(observations)
        std = self.std.to(mean.device)
        # `mean * 0. + std` broadcasts the std to the batch shape of the mean.
        self.distribution = Normal(mean, mean * 0. + std)

    def _actor_input(self, observations, history, wm_feature):
        """Assemble the actor input: [history latent, command, wm latent]."""
        latent_vector = self.history_encoder(history)
        # Command assumed at obs[:, privileged_dim+6 : privileged_dim+9]
        # (layout inherited from the original implementation).
        command = observations[:, self.privileged_dim + 6:self.privileged_dim + 9]
        wm_latent_vector = self.wm_feature_encoder(wm_feature)
        return torch.concat((latent_vector, command, wm_latent_vector), dim=-1)

    def act(self, observations, history, wm_feature, **kwargs):
        """Sample stochastic actions from encoded inputs."""
        self.update_distribution(self._actor_input(observations, history, wm_feature))
        return self.distribution.sample()

    def get_latent_vector(self, observations, history, **kwargs):
        """Return the history-encoder latent for ``history``."""
        return self.history_encoder(history)

    def get_linear_vel(self, observations, history, **kwargs):
        """Return the last 3 latent dims (interpreted as linear velocity)."""
        latent_vector = self.history_encoder(history)
        return latent_vector[:, -3:]

    def get_actions_log_prob(self, actions):
        """Joint log-probability of ``actions`` under the latest distribution."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations, history, wm_feature):
        """Deterministic action (the distribution mean) for deployment."""
        return self.actor(self._actor_input(observations, history, wm_feature))

    def evaluate(self, critic_observations, wm_feature, **kwargs):
        """State value for critic observations plus the critic's wm latent."""
        wm_latent_vector = self.critic_wm_feature_encoder(wm_feature)
        concat_observations = torch.concat((critic_observations, wm_latent_vector), dim=-1)
        return self.critic(concat_observations)


def _build_mlp(input_dim, hidden_dims, output_dim, activation):
    """Build an MLP ``input_dim -> hidden_dims... -> output_dim``.

    ``activation`` is a module instance shared between layers (all supported
    activations are stateless, matching the original construction).  No
    activation is applied after the output layer.
    """
    layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
    for l in range(len(hidden_dims)):
        if l == len(hidden_dims) - 1:
            layers.append(nn.Linear(hidden_dims[l], output_dim))
        else:
            layers.append(nn.Linear(hidden_dims[l], hidden_dims[l + 1]))
            layers.append(activation)
    return nn.Sequential(*layers)


def get_activation(act_name):
    """Return a fresh activation module for ``act_name``.

    Raises:
        ValueError: if the name is unknown.  (The original printed a message
        and returned None, deferring the failure to an opaque TypeError.)
    """
    activations = {
        "elu": nn.ELU,
        "selu": nn.SELU,
        "relu": nn.ReLU,
        # NOTE: torch.nn provides no CReLU; ReLU is used as a stand-in,
        # exactly as in the original implementation.
        "crelu": nn.ReLU,
        "lrelu": nn.LeakyReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }
    if act_name not in activations:
        raise ValueError(f"invalid activation function: {act_name!r}")
    return activations[act_name]()
--------------------------------------------------------------------------------
/rsl_rl/modules/depth_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (2024) Bytedance Ltd. and/or its affiliates
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.distributions import Normal
20 | from torch.nn.modules import rnn
21 | import torch.nn.functional as F
22 | from dreamer import tools
23 |
class DepthPredictor(nn.Module):
    """Predict a depth image from a forward heightmap and proprioception.

    An MLP encoder maps the concatenated (heightmap, proprioception) vector
    to a flattened low-resolution feature map; a stack of stride-2 transposed
    convolutions then upsamples it to a single-channel image of
    ``depth_image_dims``.  The decoder mirrors the dreamerv3 ConvDecoder
    (hence the use of ``tools.weight_init`` / ``tools.uniform_weight_init``).
    """
    def __init__(self, forward_heightamp_dim = 525,
                 prop_dim = 33,
                 depth_image_dims = [64, 64],
                 encoder_hidden_dims=[256, 128],
                 depth=32,
                 act="ELU",
                 norm=True,
                 kernel_size=4,
                 minres=4,
                 outscale=1.0,
                 cnn_sigmoid=False,):
        """Args:
            forward_heightamp_dim: size of the flattened forward heightmap input.
            prop_dim: size of the proprioception input.
            depth_image_dims: (height, width) of the predicted depth image.
            encoder_hidden_dims: hidden layer sizes of the MLP encoder.
            depth: base channel count of the deconv stack.
            act: name of a torch.nn activation class (resolved via getattr).
            norm: apply ImgChLayerNorm after intermediate deconv layers.
            kernel_size: transposed-convolution kernel size.
            minres: minimum feature-map resolution; determines stage count.
            outscale: scale for the uniform init of the final layer.
            cnn_sigmoid: if True, squash the output with a sigmoid.
        """


        # add this to fully recover the process of conv encoder
        # Re-derive the per-stage feature-map sizes a stride-2 conv encoder
        # would have produced by repeatedly ceil-halving H and W, then
        # reverse so h_list/w_list run from the coarsest decoder input up to
        # the full output resolution.
        h, w = depth_image_dims
        stages = int(np.log2(w) - np.log2(minres))
        self.h_list = []
        self.w_list = []
        for i in range(stages):
            h, w = (h+1) // 2, (w+1) // 2
            self.h_list.append(h)
            self.w_list.append(w)
        self.h_list = self.h_list[::-1]
        self.w_list = self.w_list[::-1]
        # The final entries are the full output resolution.
        self.h_list.append(depth_image_dims[0])
        self.w_list.append(depth_image_dims[1])

        super(DepthPredictor, self).__init__()
        act = getattr(torch.nn, act)
        self._cnn_sigmoid = cnn_sigmoid
        layer_num = len(self.h_list) - 1
        # layer_num = int(np.log2(shape[2]) - np.log2(minres))
        # self._minres = minres
        # out_ch = minres**2 * depth * 2 ** (layer_num - 1)
        # Flattened size of the coarsest feature map: H0 * W0 * channels,
        # where the channel count doubles once per remaining stage from
        # `depth`.
        out_ch = self.h_list[0] * self.w_list[0] * depth * 2 ** (len(self.h_list) - 2)
        self._embed_size = out_ch

        # Channel count of the coarsest feature map / first deconv input.
        in_dim = out_ch // (self.h_list[0] * self.w_list[0])
        out_dim = in_dim // 2

        # Encoder: MLP from (heightmap + proprioception) to the flattened
        # coarsest feature map.
        encoder_layers = []
        encoder_layers.append(nn.Linear(forward_heightamp_dim + prop_dim, encoder_hidden_dims[0]))
        encoder_layers.append(act())
        for l in range(len(encoder_hidden_dims)):
            if l == len(encoder_hidden_dims) - 1:
                encoder_layers.append(nn.Linear(encoder_hidden_dims[l], self._embed_size))
            else:
                encoder_layers.append(nn.Linear(encoder_hidden_dims[l], encoder_hidden_dims[l + 1]))
                encoder_layers.append(act())
        self.encoder = nn.Sequential(*encoder_layers)


        layers = []
        # h, w = minres, minres
        for i in range(layer_num):
            bias = False
            # Final stage: single output channel, bias on, no act/norm.
            if i == layer_num - 1:
                out_dim = 1
                act = False
                bias = True
                norm = False

            if i != 0:
                in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
            # With kernel_size=4 and stride 2, a ConvTranspose2d produces
            # out = 2*in with (padding=1, output_padding=0) and
            # out = 2*in - 1 with (padding=2, output_padding=1); pick
            # whichever maps this stage's size exactly onto the next one.
            if(self.h_list[i] * 2 == self.h_list[i+1]):
                pad_h, outpad_h = 1, 0
            else:
                pad_h, outpad_h = 2, 1

            if(self.w_list[i] * 2 == self.w_list[i+1]):
                pad_w, outpad_w = 1, 0
            else:
                pad_w, outpad_w = 2, 1

            layers.append(
                nn.ConvTranspose2d(
                    in_dim,
                    out_dim,
                    kernel_size,
                    2,
                    padding=(pad_h, pad_w),
                    output_padding=(outpad_h, outpad_w),
                    bias=bias,
                )
            )
            if norm:
                layers.append(ImgChLayerNorm(out_dim))
            if act:
                layers.append(act())
            in_dim = out_dim
            out_dim //= 2
            # h, w = h * 2, w * 2
        # dreamerv3-style init: normal init everywhere, uniform on the head.
        [m.apply(tools.weight_init) for m in layers[:-1]]
        layers[-1].apply(tools.uniform_weight_init(outscale))
        self.layers = nn.Sequential(*layers)


    def forward(self, forward_heightmap, prop):
        """Return the predicted depth-image mean, shape (batch, H, W, 1)."""
        x = torch.concat((forward_heightmap, prop), dim=-1)
        x = self.encoder(x)
        # (batch, time, -1) -> (batch * time, h, w, ch)
        x = x.reshape(
            [-1, self.h_list[0], self.w_list[0], self._embed_size // (self.h_list[0] * self.w_list[0])]
        )
        # (batch, time, -1) -> (batch * time, ch, h, w)
        x = x.permute(0, 3, 1, 2)
        # print('init decoder shape:', x.shape)
        # for layer in self.layers:
        #     x = layer(x)
        #     print(x.shape)
        x = self.layers(x)
        # Back to channels-last for the image-shaped prediction.
        mean = x.permute(0, 2, 3, 1)
        if self._cnn_sigmoid:
            mean = F.sigmoid(mean)
        # else:
        #     mean += 0.5
        return mean
143 |
144 |
class ImgChLayerNorm(nn.Module):
    """Channel-wise LayerNorm for NCHW feature maps.

    ``nn.LayerNorm`` normalizes the trailing dimension, so the input is
    transposed to NHWC, normalized over the channel axis, and transposed
    back to NCHW.
    """

    def __init__(self, ch, eps=1e-03):
        super().__init__()
        self.norm = nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        # NCHW -> NHWC, normalize channels, NHWC -> NCHW.
        channels_last = x.permute(0, 2, 3, 1)
        return self.norm(channels_last).permute(0, 3, 1, 2)
155 |
--------------------------------------------------------------------------------
/rsl_rl/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | from .on_policy_runner import OnPolicyRunner
35 | from .wmp_runner import WMPRunner
--------------------------------------------------------------------------------
/rsl_rl/runners/on_policy_runner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
32 | # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
33 |
34 | import time
35 | import os
36 | from collections import deque
37 | import statistics
38 |
39 | from torch.utils.tensorboard import SummaryWriter
40 | import torch
41 |
42 | from rsl_rl.algorithms import PPO
43 | from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
44 | from rsl_rl.env import VecEnv
45 |
46 |
class OnPolicyRunner:
    """On-policy training runner: alternates rollout collection on a VecEnv with
    PPO updates, and handles TensorBoard logging plus periodic checkpointing.
    """

    def __init__(self,
                 env: VecEnv,
                 train_cfg,
                 log_dir=None,
                 device='cpu'):
        """Build the actor-critic policy and PPO algorithm from the config.

        Args:
            env: vectorized environment to train on.
            train_cfg: dict with "runner", "algorithm" and "policy" sections.
            log_dir: directory for TensorBoard logs and checkpoints
                (None disables logging and checkpointing).
            device: torch device string for the policy and algorithm.
        """
        self.cfg = train_cfg["runner"]
        self.alg_cfg = train_cfg["algorithm"]
        self.policy_cfg = train_cfg["policy"]
        self.device = device
        self.env = env
        # The critic consumes privileged observations when the env provides them.
        if self.env.num_privileged_obs is not None:
            num_critic_obs = self.env.num_privileged_obs
        else:
            num_critic_obs = self.env.num_obs
        # Resolve config class names through an explicit whitelist instead of
        # eval(): eval on config strings is an injection hazard and silently
        # resolves unrelated module-level names on typos.
        policy_classes = {"ActorCritic": ActorCritic,
                          "ActorCriticRecurrent": ActorCriticRecurrent}
        algorithm_classes = {"PPO": PPO}
        actor_critic_class = policy_classes[self.cfg["policy_class_name"]]
        actor_critic: ActorCritic = actor_critic_class(self.env.num_obs,
                                                       num_critic_obs,
                                                       self.env.num_actions,
                                                       **self.policy_cfg).to(self.device)
        alg_class = algorithm_classes[self.cfg["algorithm_class_name"]]
        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        self.save_interval = self.cfg["save_interval"]

        # init storage and model
        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs],
                              [self.env.num_privileged_obs], [self.env.num_actions])

        # Logging / bookkeeping state.
        self.log_dir = log_dir
        self.writer = None
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0

        _, _ = self.env.reset()

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        """Run the main collect-then-update loop.

        Args:
            num_learning_iterations: number of rollout+update iterations to run.
            init_at_random_ep_len: randomize initial episode lengths so env
                resets are spread over time instead of happening in lockstep.
        """
        # initialize writer
        if self.log_dir is not None and self.writer is None:
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf,
                                                             high=int(self.env.max_episode_length))
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.alg.actor_critic.train()  # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout: no gradients are needed while collecting experience.
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs, privileged_obs, rewards, dones, infos, _, _ = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards, dones, infos)

                    if self.log_dir is not None:
                        # Book keeping: accumulate per-env returns/lengths and
                        # flush them into the rolling buffers on episode end.
                        if 'episode' in infos:
                            ep_infos.append(infos['episode'])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)

            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            if self.log_dir is not None:
                self.log(locals())
                # Checkpointing moved under the log_dir guard: the original
                # crashed on os.path.join(None, ...) when log_dir was None.
                if it % self.save_interval == 0:
                    self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
            ep_infos.clear()

        self.current_learning_iteration += num_learning_iterations
        if self.log_dir is not None:
            self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

    def log(self, locs, width=80, pad=35):
        """Write training statistics to TensorBoard and print a console summary.

        Args:
            locs: the locals() of learn(); reads it, collection_time, learn_time,
                ep_infos, rewbuffer, lenbuffer, mean_value_loss,
                mean_surrogate_loss and num_learning_iterations.
            width: width of the printed banner.
            pad: left-justification width of the printed key column.
        """
        self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
        self.tot_time += locs['collection_time'] + locs['learn_time']
        iteration_time = locs['collection_time'] + locs['learn_time']

        ep_string = f''
        if locs['ep_infos']:
            for key in locs['ep_infos'][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs['ep_infos']:
                    # handle scalar and zero dimensional tensor infos
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                value = torch.mean(infotensor)
                self.writer.add_scalar('Episode/' + key, value, locs['it'])
                ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
        mean_std = self.alg.actor_critic.std.mean()
        fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time']))

        self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
        self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
        self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
        self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
        self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
        self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it'])
        self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it'])
        if len(locs['rewbuffer']) > 0:
            self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
            self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)

        # renamed from `str`, which shadowed the builtin
        header = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "

        if len(locs['rewbuffer']) > 0:
            log_string = (f"""{'#' * width}\n"""
                          f"""{header.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                          f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
                          f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
        else:
            log_string = (f"""{'#' * width}\n"""
                          f"""{header.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")

        log_string += ep_string
        log_string += (f"""{'-' * width}\n"""
                       f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
                       f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
                       f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
                       f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
                               locs['num_learning_iterations'] - locs['it']):.1f}s\n""")
        print(log_string)

    def save(self, path, infos=None):
        """Serialize model, optimizer and iteration counter to `path`."""
        torch.save({
            'model_state_dict': self.alg.actor_critic.state_dict(),
            'optimizer_state_dict': self.alg.optimizer.state_dict(),
            'iter': self.current_learning_iteration,
            'infos': infos,
        }, path)

    def load(self, path, load_optimizer=True):
        """Restore a checkpoint saved by save(); returns the stored 'infos'.

        map_location is set so checkpoints trained on GPU also load on
        CPU-only machines (the original torch.load(path) failed there).
        """
        loaded_dict = torch.load(path, map_location=self.device)
        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
        if load_optimizer:
            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
        self.current_learning_iteration = loaded_dict['iter']
        return loaded_dict['infos']

    def get_inference_policy(self, device=None):
        """Return the deterministic inference callable of the actor-critic."""
        self.alg.actor_critic.eval()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_inference
237 |
--------------------------------------------------------------------------------
/rsl_rl/storage/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from .rollout_storage import RolloutStorage
--------------------------------------------------------------------------------
/rsl_rl/storage/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 |
32 | import torch
33 | import numpy as np
34 |
35 |
class ReplayBuffer:
    """Fixed-size circular buffer of (state, next_state) transition pairs."""

    def __init__(self, obs_dim, buffer_size, device):
        """Preallocate the transition storage on the target device.

        Arguments:
            obs_dim (int): dimensionality of a single state vector
            buffer_size (int): maximum number of stored transitions
            device: torch device the buffers live on
        """
        self.states = torch.zeros(buffer_size, obs_dim).to(device)
        self.next_states = torch.zeros(buffer_size, obs_dim).to(device)
        self.buffer_size = buffer_size
        self.device = device

        self.step = 0         # next write position (wraps around)
        self.num_samples = 0  # number of valid entries, capped at buffer_size

    def insert(self, states, next_states):
        """Write a batch of transitions, wrapping around at the buffer end."""
        count = states.shape[0]
        end = self.step + count
        if end <= self.buffer_size:
            # Contiguous write.
            self.states[self.step:end] = states
            self.next_states[self.step:end] = next_states
        else:
            # Split write: fill up to the end, then wrap to the front.
            tail = self.buffer_size - self.step
            self.states[self.step:self.buffer_size] = states[:tail]
            self.next_states[self.step:self.buffer_size] = next_states[:tail]
            self.states[:end - self.buffer_size] = states[tail:]
            self.next_states[:end - self.buffer_size] = next_states[tail:]

        self.num_samples = min(self.buffer_size, max(end, self.num_samples))
        self.step = (self.step + count) % self.buffer_size

    def feed_forward_generator(self, num_mini_batch, mini_batch_size):
        """Yield `num_mini_batch` minibatches sampled uniformly with replacement."""
        for _ in range(num_mini_batch):
            idxs = np.random.choice(self.num_samples, size=mini_batch_size)
            yield (self.states[idxs].to(self.device),
                   self.next_states[idxs].to(self.device))
75 |
--------------------------------------------------------------------------------
/rsl_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .utils import split_and_pad_trajectories, unpad_trajectories
--------------------------------------------------------------------------------
/rsl_rl/utils/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 | from typing import Tuple
31 |
32 | import torch
33 | import numpy as np
34 |
35 | _EPS = np.finfo(float).eps * 4.0
36 |
37 |
def split_and_pad_trajectories(tensor, dones):
    """ Splits trajectories at done indices, then concatenates them and pads with
    zeros up to the length of the longest trajectory.
    Returns masks corresponding to valid parts of the trajectories
    Example:
        Input: [ [a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]
                ]

        Output:[ [a1, a2, a3, a4], | [  [True, True, True, True],
                 [a5, a6, 0, 0],   |    [True, True, False, False],
                 [b1, b2, 0, 0],   |    [True, True, False, False],
                 [b3, b4, b5, 0],  |    [True, True, True, False],
                 [b6, 0, 0, 0]     |    [True, False, False, False],
                ]                  | ]

    Assumes the input has dimension order [time, number of envs, additional dimensions]
    """
    dones = dones.clone()
    # Force a trajectory boundary at the final step so every env ends a trajectory.
    dones[-1] = 1
    # Env-major flattening so each env's timeline is contiguous.
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Trajectory lengths = gaps between consecutive done positions
    # (with a sentinel -1 in front of the first trajectory).
    done_positions = flat_dones.nonzero()[:, 0]
    boundaries = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), done_positions))
    lengths = boundaries[1:] - boundaries[:-1]

    # Cut the flattened data at those boundaries and pad to the longest piece.
    pieces = torch.split(tensor.transpose(1, 0).flatten(0, 1), lengths.tolist())
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(pieces)

    # mask[t, k] is True while trajectory k is still running at step t.
    time_index = torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
    trajectory_masks = lengths > time_index
    return padded_trajectories, trajectory_masks
71 |
def unpad_trajectories(trajectories, masks):
    """ Does the inverse operation of split_and_pad_trajectories()
    """
    # Go trajectory-major, drop the padded entries via the mask, then reshape
    # back to [time, num_envs, features] (transposing restores time-major order).
    traj_major = trajectories.transpose(1, 0)
    valid = traj_major[masks.transpose(1, 0)]
    stacked = valid.view(-1, trajectories.shape[0], trajectories.shape[-1])
    return stacked.transpose(1, 0)
77 |
78 |
class RunningMeanStd(object):
    """Tracks the running mean and variance of a data stream using the
    parallel-merge algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        :param epsilon: initial pseudo-count; helps with arithmetic issues
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        """Fold a batch (stacked along axis 0) into the running statistics."""
        self.update_from_moments(np.mean(arr, axis=0), np.var(arr, axis=0), arr.shape[0])

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        """Merge precomputed batch moments into the running moments."""
        total = self.count + batch_count
        delta = batch_mean - self.mean

        merged_mean = self.mean + delta * batch_count / total
        # Combine second moments of the two partitions (Chan et al.).
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        merged_m2 = m_a + m_b + np.square(delta) * self.count * batch_count / total

        self.mean = merged_mean
        self.var = merged_m2 / total
        self.count = total
112 |
113 |
class Normalizer(RunningMeanStd):
    """RunningMeanStd plus clipped observation normalization (numpy and torch)."""

    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon    # added inside the sqrt for numerical stability
        self.clip_obs = clip_obs  # symmetric clipping bound after normalization

    def normalize(self, input):
        """Normalize a numpy array with the running stats, clipped to +/- clip_obs."""
        standardized = (input - self.mean) / np.sqrt(self.var + self.epsilon)
        return np.clip(standardized, -self.clip_obs, self.clip_obs)

    def normalize_torch(self, input, device):
        """Torch version of normalize(); stats are copied to `device` as float32."""
        mean_t = torch.tensor(
            self.mean, device=device, dtype=torch.float32)
        std_t = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32))
        return torch.clamp(
            (input - mean_t) / std_t, -self.clip_obs, self.clip_obs)

    def update_normalizer(self, rollouts, expert_loader):
        """Refresh the running stats from paired policy/expert AMP minibatches."""
        policy_gen = rollouts.feed_forward_generator_amp(
            None, mini_batch_size=expert_loader.batch_size)
        expert_gen = expert_loader.dataset.feed_forward_generator_amp(
            expert_loader.batch_size)

        for expert_batch, policy_batch in zip(expert_gen, policy_gen):
            stacked = torch.vstack(tuple(policy_batch) + tuple(expert_batch))
            self.update(stacked.cpu().numpy())
142 |
143 |
class Normalize(torch.nn.Module):
    """Module wrapper around F.normalize: L2-normalizes input along the last dim."""

    def __init__(self):
        super(Normalize, self).__init__()
        # Kept as an attribute to mirror the original module's state.
        self.normalize = torch.nn.functional.normalize

    def forward(self, x):
        # Unit L2 norm along the feature (last) dimension.
        return self.normalize(x, dim=-1)
152 |
153 |
def quaternion_slerp(q0, q1, fraction, spin=0, shortestpath=True):
    """Batch quaternion spherical linear interpolation.

    Args:
        q0, q1: (..., 4) start and end quaternions.
        fraction: (..., 1) interpolation fraction in [0, 1].
        spin: extra half-revolutions added to the interpolation angle.
        shortestpath: flip q1 where needed so interpolation takes the short arc.

    Returns:
        Interpolated quaternions with the same shape as q0.
        The input tensors are NOT modified.
    """
    # Bug fix: the previous implementation mutated q0/q1 in place
    # (q0 *= ..., q0 += q1), corrupting the caller's tensors.
    q0 = q0.clone()
    q1 = q1.clone()
    # Local copy of the module-level _EPS tolerance so the function is
    # self-contained (same value: 4x float64 machine epsilon).
    eps = np.finfo(float).eps * 4.0

    out = torch.zeros_like(q0)

    # Degenerate fractions: return an endpoint directly.
    zero_mask = torch.isclose(fraction, torch.zeros_like(fraction)).squeeze()
    ones_mask = torch.isclose(fraction, torch.ones_like(fraction)).squeeze()
    out[zero_mask] = q0[zero_mask]
    out[ones_mask] = q1[ones_mask]

    # (Nearly) identical or antipodal quaternions: nothing to interpolate.
    d = torch.sum(q0 * q1, dim=-1, keepdim=True)
    dist_mask = (torch.abs(torch.abs(d) - 1.0) < eps).squeeze()
    out[dist_mask] = q0[dist_mask]

    if shortestpath:
        d_old = torch.clone(d)
        d = torch.where(d_old < 0, -d, d)
        q1 = torch.where(d_old < 0, -q1, q1)

    angle = torch.acos(d) + spin * torch.pi
    angle_mask = (torch.abs(angle) < eps).squeeze()
    out[angle_mask] = q0[angle_mask]

    # Rows already handled by any special case above keep their value.
    final_mask = torch.logical_or(zero_mask, ones_mask)
    final_mask = torch.logical_or(final_mask, dist_mask)
    final_mask = torch.logical_or(final_mask, angle_mask)
    final_mask = torch.logical_not(final_mask)

    # Standard slerp weights sin((1-f)*a)/sin(a) and sin(f*a)/sin(a).
    # Bug fix: the previous code used 1/angle instead of 1/sin(angle),
    # which de-normalized the interpolated quaternions.
    isin = 1.0 / torch.sin(angle)
    interp = (q0 * (torch.sin((1.0 - fraction) * angle) * isin)
              + q1 * (torch.sin(fraction * angle) * isin))
    out[final_mask] = interp[final_mask]
    return out
188 |
--------------------------------------------------------------------------------