├── .gitignore ├── LICENSE ├── README.md ├── assets ├── main_result.png ├── overview.png ├── overview_gleambench.png └── statistic.png ├── data_gleam └── README.md ├── gleam ├── __init__.py ├── callback.py ├── env │ ├── config_gleam.py │ ├── config_gleam_eval.py │ ├── config_legged_visual.py │ ├── env_base.py │ ├── env_gleam_base.py │ ├── env_gleam_eval.py │ ├── env_gleam_stage1.py │ └── env_gleam_stage2.py ├── network │ ├── base.py │ ├── encoder.py │ ├── init.py │ └── locotransformer.py ├── test │ └── test_gleam_gleambench.py ├── train │ ├── __init__.py │ ├── train_gleam_stage1.py │ └── train_gleam_stage2.py ├── utils │ ├── bfs_cuda_2D.h │ ├── bfs_cuda_kernel_2D.cu │ └── utils.py └── wrapper │ └── env_wrapper_gleam.py ├── legged_gym ├── __init__.py ├── env │ ├── __init__.py │ ├── a1 │ │ └── a1_config.py │ ├── anymal_b │ │ └── anymal_b_config.py │ ├── anymal_c │ │ ├── anymal.py │ │ ├── flat │ │ │ └── anymal_c_flat_config.py │ │ └── mixed_terrains │ │ │ └── anymal_c_rough_config.py │ └── base │ │ ├── base_config.py │ │ ├── base_task.py │ │ ├── drone_robot.py │ │ ├── drone_robot_visual.py │ │ ├── legged_robot.py │ │ └── legged_robot_config.py ├── scripts │ ├── play.py │ └── train.py ├── tests │ └── test_env.py └── utils │ ├── __init__.py │ ├── helpers.py │ ├── logger.py │ ├── math.py │ ├── task_registry.py │ └── terrain.py ├── licenses ├── assets │ ├── ANYmal_b_license.txt │ ├── ANYmal_c_license.txt │ ├── a1_license.txt │ └── cassie_license.txt └── dependencies │ └── matplotlib_license.txt ├── requirements.txt ├── resources └── robots │ └── drone │ └── cf2.dae ├── rsl_rl ├── __init__.py ├── algorithms │ ├── __init__.py │ └── ppo.py ├── env │ ├── __init__.py │ └── vec_env.py ├── modules │ ├── __init__.py │ ├── actor_critic.py │ └── actor_critic_recurrent.py ├── runners │ ├── __init__.py │ └── on_policy_runner.py ├── storage │ ├── __init__.py │ └── rollout_storage.py └── utils │ ├── __init__.py │ └── utils.py ├── setup.py ├── stable_baselines3 ├── __init__.py ├── a2c │ 
├── __init__.py │ ├── a2c.py │ └── policies.py ├── common │ ├── __init__.py │ ├── atari_wrappers.py │ ├── base_class.py │ ├── base_class_grid_obs.py │ ├── buffers.py │ ├── callbacks.py │ ├── callbacks_gleam.py │ ├── distributions.py │ ├── env_checker.py │ ├── env_util.py │ ├── envs │ │ ├── __init__.py │ │ ├── bit_flipping_env.py │ │ ├── identity_env.py │ │ └── multi_input_envs.py │ ├── evaluation.py │ ├── evaluation_gleam.py │ ├── logger.py │ ├── monitor.py │ ├── noise.py │ ├── off_policy_algorithm.py │ ├── on_policy_algorithm.py │ ├── on_policy_algorithm_grid_obs.py │ ├── on_policy_algorithm_hybrid_act.py │ ├── policies.py │ ├── preprocessing.py │ ├── results_plotter.py │ ├── running_mean_std.py │ ├── save_util.py │ ├── sb2_compat │ │ ├── __init__.py │ │ └── rmsprop_tf_like.py │ ├── torch_layers.py │ ├── type_aliases.py │ ├── utils.py │ └── vec_env │ │ ├── __init__.py │ │ ├── base_vec_env.py │ │ ├── dummy_vec_env.py │ │ ├── stacked_observations.py │ │ ├── subproc_vec_env.py │ │ ├── util.py │ │ ├── vec_check_nan.py │ │ ├── vec_extract_dict_obs.py │ │ ├── vec_frame_stack.py │ │ ├── vec_monitor.py │ │ ├── vec_normalize.py │ │ ├── vec_transpose.py │ │ └── vec_video_recorder.py ├── ddpg │ ├── __init__.py │ ├── ddpg.py │ └── policies.py ├── dqn │ ├── __init__.py │ ├── dqn.py │ └── policies.py ├── her │ ├── __init__.py │ ├── goal_selection_strategy.py │ └── her_replay_buffer.py ├── ppo │ ├── __init__.py │ ├── policies.py │ ├── ppo.py │ ├── ppo_grid_hybrid_action.py │ ├── ppo_grid_obs.py │ ├── ppo_occant.py │ └── ppo_test.py ├── sac │ ├── __init__.py │ ├── policies.py │ └── sac.py ├── td3 │ ├── __init__.py │ ├── policies.py │ └── td3.py ├── utils.py └── version.txt └── wandb_utils ├── __init__.py ├── wandb_api_key_file.txt └── wandb_callback.py /.gitignore: -------------------------------------------------------------------------------- 1 | # These are some examples of commonly ignored file patterns. 2 | # You should customize this list as applicable to your project. 
3 | # Learn more about .gitignore: 4 | # https://www.atlassian.com/git/tutorials/saving-changes/gitignore 5 | 6 | # Node artifact files 7 | node_modules/ 8 | 9 | # Compiled Java class files 10 | *.class 11 | 12 | # Compiled Python bytecode 13 | *.py[cod] 14 | 15 | # Log files 16 | *.log 17 | log.txt 18 | 19 | # Package files 20 | *.jar 21 | 22 | # Maven 23 | target/ 24 | dist/ 25 | 26 | # JetBrains IDE 27 | .idea/ 28 | 29 | # Unit test reports 30 | TEST*.xml 31 | 32 | # Generated by MacOS 33 | .DS_Store 34 | 35 | # Generated by Windows 36 | Thumbs.db 37 | 38 | # Applications 39 | *.app 40 | *.exe 41 | *.war 42 | 43 | # Large media files 44 | *.gif 45 | *.mp4 46 | *.tiff 47 | *.avi 48 | *.flv 49 | *.mov 50 | *.wmv 51 | 52 | # VS Code 53 | .vscode 54 | *.vsix 55 | 56 | # logs 57 | logs 58 | runs 59 | 60 | # other 61 | *.egg-info 62 | __pycache__ 63 | venv 64 | *.cpp 65 | 66 | # model 67 | *.pkl 68 | *.zip 69 | *.pt 70 | *.pth 71 | 72 | # log 73 | *.log 74 | *.csv 75 | 76 | # pyc 77 | *.pyc 78 | 79 | # json 80 | *.json 81 | 82 | # build 83 | /build/ 84 | **.so 85 | 86 | # test scripts 87 | **/wandb/ 88 | *.idea 89 | 90 | 91 | # Image 92 | *.jpg 93 | *.jpeg 94 | *.png 95 | 96 | 97 | *.npy 98 | *.tar.gz 99 | 100 | events.out* 101 | *.th 102 | log 103 | *.mtl 104 | *.obj 105 | *.ply 106 | *.glb 107 | *.fbx 108 | *.pcd 109 | *.xyz 110 | *.mlp 111 | *.sh 112 | *.lib 113 | *.whl 114 | *.hdr 115 | *.bin 116 | *.navmesh 117 | *.sur 118 | *.h5 119 | *.urdf 120 | 121 | data_gleam/* 122 | venv/* 123 | 124 | # module 125 | ckpt 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, ETH Zurich, Nikita Rudin 2 | Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors 16 | may be used to endorse or promote products derived from this software without 17 | specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | See licenses/assets for license information for assets included in this repository. 31 | See licenses/dependencies for license information of dependencies of this package. 
32 | -------------------------------------------------------------------------------- /assets/main_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/main_result.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/overview.png -------------------------------------------------------------------------------- /assets/overview_gleambench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/overview_gleambench.png -------------------------------------------------------------------------------- /assets/statistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/statistic.png -------------------------------------------------------------------------------- /data_gleam/README.md: -------------------------------------------------------------------------------- 1 | # GLEAM-Bench 2 | 3 | 6 | 7 | 8 | 9 |

10 | 11 |

12 |

13 | 14 |

15 | 16 | 17 | We introduce GLEAM-Bench, a benchmark for generalizable exploration for active mapping in complex 3D indoor scenes. 18 | These scene meshes are characterized by watertight geometry, diverse floorplans (≥10 types), and complex interconnectivity. We unify and refine multi-source datasets through manual filtering, geometric repair, and task-oriented preprocessing. 19 | To simulate the exploration process, we connect our dataset with NVIDIA Isaac Gym, enabling parallel sensory data simulation and online policy training. 20 | 21 | 22 | ## Download 23 | Please fill in the [form](https://docs.google.com/forms/d/e/1FAIpQLSdq9aX1dwoyBb31nm8L_Mx5FeaVsr5AY538UiwKqg8LPKX9vg/viewform?usp=sharing) to access the download links. 24 | 25 | **Consolidated Version (GLEAM)** 26 | 27 | We provide all the preprocessed data used in our work, including mesh files (in `obj` folder), ground-truth surface points (in `gt` folder) and asset indexing files (in `urdf` folder). The directory structure should be as follows. 28 | 29 | ``` 30 | data_gleam 31 | ├── train_stage1_512 32 | │ ├── gt 33 | │ ├── obj 34 | │ ├── urdf 35 | ├── train_stage2_512 36 | │ ├── gt 37 | │ ├── obj 38 | │ ├── urdf 39 | ├── eval_128 40 | │ ├── gt 41 | │ ├── obj 42 | │ ├── urdf 43 | ``` 44 | 45 | The standard training process of GLEAM is divided into two stages (i.e., stage1 and stage2), each involving 512 different indoor training scenes. The evaluation involves 128 unseen testing scenes from ProcTHOR, HSSD, Gibson and Matterport3D (cross-dataset generalization). 46 | 47 | 48 | 49 | 50 | 51 | 52 | **Original Separate Version** 53 | 54 | The separate meshes are also provided in the download links: 55 | ``` 56 | data_gleam_raw 57 | ├── procthor_12-room_64 58 | │ ├── gt 59 | │ ├── obj 60 | ├── gibson_96 61 | │ ├── gt 62 | │ ├── obj 63 | ├── hssd_32 64 | │ ├── gt 65 | │ ├── obj 66 | ...
67 | ``` 68 | 69 | **Raw Version (With Texture and Part-Level Mesh Layers)** 70 | 71 | Due to storage limitations, we will not be releasing scenes with textures or part-level mesh layers. If you want to get the raw version of GLEAM-Bench, we recommend you: 72 | - **ProcTHOR**: export the meshes using our provided C# script; 73 | - **HSSD**, **Gibson**, **Matterport3D**: download the corresponding raw meshes from the original source. 74 | 75 | 76 | ## Export Meshes from Generated Scenes by ProcTHOR 77 | We also provide the C# script [[HERE](https://github.com/zjwzcx/Batch-Export-ProcTHOR-Meshes)] to export mesh files (.fbx) from generated scenes by ProcTHOR. Note that these generated mesh files have textures and interactive object-level layers. 78 | 79 | 80 | ## Citation 81 | 82 | The GLEAM-Bench dataset comes from the GLEAM paper: 83 | 84 | - **arXiv**: https://arxiv.org/abs/2505.20294 85 | 86 | - **Code**: https://github.com/zjwzcx/GLEAM 87 | 88 | - **BibTex**: 89 | ```bibtex 90 | @misc{chen2025gleam, 91 | title={GLEAM: Learning Generalizable Exploration Policy for Active Mapping in Complex 3D Indoor Scenes}, 92 | author={Xiao Chen and Tai Wang and Quanyi Li and Tao Huang and Jiangmiao Pang and Tianfan Xue}, 93 | year={2025}, 94 | eprint={2505.20294}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.CV}, 97 | url={https://arxiv.org/abs/2505.20294}, 98 | } 99 | ``` 100 | 101 | 102 | If you use our dataset and benchmark, please kindly cite the original datasets involved in our work. BibTex entries are provided below. 103 | 104 |
Dataset BibTex 105 | 106 | ```bibtex 107 | @inproceedings{procthor, 108 | author={Matt Deitke and Eli VanderBilt and Alvaro Herrasti and 109 | Luca Weihs and Jordi Salvador and Kiana Ehsani and 110 | Winson Han and Eric Kolve and Ali Farhadi and 111 | Aniruddha Kembhavi and Roozbeh Mottaghi}, 112 | title={{ProcTHOR: Large-Scale Embodied AI Using Procedural Generation}}, 113 | booktitle={NeurIPS}, 114 | year={2022}, 115 | note={Outstanding Paper Award} 116 | } 117 | ``` 118 | ```bibtex 119 | @inproceedings{xiazamirhe2018gibsonenv, 120 | title={Gibson Env: real-world perception for embodied agents}, 121 | author={Xia, Fei and R. Zamir, Amir and He, Zhi-Yang and Sax, Alexander and Malik, Jitendra and Savarese, Silvio}, 122 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on}, 123 | year={2018}, 124 | organization={IEEE} 125 | } 126 | ``` 127 | ```bibtex 128 | @article{khanna2023hssd, 129 | author={Khanna*, Mukul and Mao*, Yongsen and Jiang, Hanxiao and Haresh, Sanjay and Shacklett, Brennan and Batra, Dhruv and Clegg, Alexander and Undersander, Eric and Chang, Angel X. and Savva, Manolis}, 130 | title={{Habitat Synthetic Scenes Dataset (HSSD-200): An Analysis of 3D Scene Scale and Realism Tradeoffs for ObjectGoal Navigation}}, 131 | journal={arXiv preprint}, 132 | year={2023}, 133 | eprint={2306.11290}, 134 | archivePrefix={arXiv}, 135 | primaryClass={cs.CV} 136 | } 137 | ``` 138 | ```bibtex 139 | @article{Matterport3D, 140 | title={Matterport3D: Learning from RGB-D Data in Indoor Environments}, 141 | author={Chang, Angel and Dai, Angela and Funkhouser, Thomas and Halber, Maciej and Niessner, Matthias and Savva, Manolis and Song, Shuran and Zeng, Andy and Zhang, Yinda}, 142 | journal={International Conference on 3D Vision (3DV)}, 143 | year={2017} 144 | } 145 | ``` 146 | 147 |
148 | -------------------------------------------------------------------------------- /gleam/__init__.py: -------------------------------------------------------------------------------- 1 | from legged_gym.utils.task_registry import task_registry 2 | 3 | from gleam.env.env_gleam_stage1 import Env_GLEAM_Stage1 4 | from gleam.env.env_gleam_stage2 import Env_GLEAM_Stage2 5 | from gleam.env.env_gleam_eval import Env_GLEAM_Eval 6 | from gleam.env.config_gleam import Config_GLEAM, DroneCfgPPO 7 | from gleam.env.config_gleam_eval import Config_GLEAM_Eval 8 | task_registry.register("train_gleam_stage1", Env_GLEAM_Stage1, Config_GLEAM, DroneCfgPPO) 9 | task_registry.register("train_gleam_stage2", Env_GLEAM_Stage2, Config_GLEAM, DroneCfgPPO) 10 | task_registry.register("eval_gleam_gleambench", Env_GLEAM_Eval, Config_GLEAM_Eval, DroneCfgPPO) 11 | -------------------------------------------------------------------------------- /gleam/callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as th 3 | from stable_baselines3.common.callbacks import CheckpointCallback 4 | 5 | 6 | class BestCKPTCallback(CheckpointCallback): 7 | def __init__( 8 | self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0, key_list: list = None 9 | ): 10 | key_list = key_list or [] 11 | self.rollout_count = 0 12 | super(BestCKPTCallback, self).__init__(save_freq, save_path, name_prefix, verbose) 13 | self.key_highest_value = {k: 0. 
for k in key_list} 14 | 15 | @property 16 | def best_save_path(self): 17 | return self.locals["self"].logger.dir 18 | 19 | @property 20 | def device(self): 21 | return self.locals["self"].device 22 | 23 | def _on_rollout_end(self) -> None: 24 | if self.n_calls % self.save_freq == 0: 25 | path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps") 26 | self.model.save(path) 27 | if self.verbose > 1: 28 | print(f"Saving model checkpoint to {path}") 29 | 30 | for key in self.key_highest_value.keys(): 31 | value = self.calculate_value(key) 32 | if value > self.key_highest_value[key]: 33 | self.key_highest_value[key] = value 34 | path = os.path.join(self.best_save_path, "{}_best_{}".format(self.name_prefix, key)) 35 | self.model.save(path) 36 | # if self.verbose > 1: 37 | print("Saving Best {} checkpoint to {}: {}".format(key, path, value)) 38 | 39 | def calculate_value(self, key): 40 | ep_info_buffer = self.locals["self"].ep_info_buffer 41 | infotensor = th.tensor([], device=self.device) 42 | assert key in ep_info_buffer[0], "no key named {}, can not save checkpoint".format(key) 43 | for ep_info in ep_info_buffer: 44 | # handle scalar and zero dimensional tensor infos 45 | if not isinstance(ep_info[key], th.Tensor): 46 | ep_info[key] = th.Tensor([ep_info[key]]) 47 | if len(ep_info[key].shape) == 0: 48 | ep_info[key] = ep_info[key].unsqueeze(0) 49 | infotensor = th.cat((infotensor, ep_info[key].to(self.device))) 50 | value = th.mean(infotensor).detach().cpu().numpy() 51 | return value 52 | 53 | class ReconstructionCallBack(BestCKPTCallback): 54 | pass 55 | -------------------------------------------------------------------------------- /gleam/env/config_gleam.py: -------------------------------------------------------------------------------- 1 | from gleam.env.config_legged_visual import LeggedVisualInputConfig, LeggedVIsualInputCfgPPO 2 | from isaacgym import gymapi 3 | 4 | 5 | class Config_GLEAM(LeggedVisualInputConfig): 6 | 
position_use_polar_coordinates = False # the position will be represented by (r, \theta, \phi) instead of (x, y, z) 7 | direction_use_vector = False # (r, p, y) 8 | debug_save_image_tensor = False 9 | debug_save_path = None 10 | max_episode_length = 500 # max_steps_per_episode 11 | 12 | num_sampled_point = 5000 13 | 14 | class rewards: 15 | class scales: # * self.dt (0.019999) / episode_length_s (20). For example, when reward_scales=1000, rew_xxx = reward * 1000 * 0.019999 / 20 = reward 16 | surface_coverage_2d = 1000 # original scale (coverage ratio: [0, 1]) 17 | 18 | termination = 50 # Terminal reward / penalty 19 | collision = -100 # Penalize collisions on selected bodies 20 | 21 | only_positive_rewards = False # if true negative total rewards are clipped at zero (avoids early termination problems) 22 | max_contact_force = 100. # forces above this value are penalized 23 | 24 | class normalization: 25 | class obs_scales: 26 | lin_vel = 2.0 27 | ang_vel = 0.25 28 | dof_pos = 1.0 29 | dof_vel = 0.05 30 | height_measurements = 5.0 31 | 32 | pi = 3.14159265359 33 | clip_observations = 255. 34 | 35 | # discrete action space 36 | init_action = [64, 64, 0, 0, 0, 0] # 64: keep still 37 | clip_actions_up = [128, 128, 0, 0, 0, 0] 38 | clip_actions_low = [0, 0, 0, 0, 0, 0] 39 | 40 | class visual_input(LeggedVisualInputConfig.visual_input): 41 | camera_width = 256 42 | camera_height = 32 43 | horizontal_fov = 90.0 # Horizontal field of view in degrees.
Vertical field of view is calculated from height to width ratio 44 | 45 | stack = 100 46 | supersampling_horizontal = 1 47 | supersampling_vertical = 1 48 | normalization = True # normalize pixels to [0, 1] 49 | far_plane = 2000000.0 # distance in world coordinates to far-clipping plane 50 | near_plane = 0.0010000000474974513 # distance in world coordinate units to near-clipping plane 51 | 52 | type = gymapi.IMAGE_DEPTH 53 | 54 | class asset(LeggedVisualInputConfig.asset): 55 | file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/drone/cf2x.urdf' 56 | name = "cf2x" 57 | prop_name = "prop" 58 | penalize_contacts_on = ["prop", "base"] 59 | terminate_after_contacts_on = ["prop", "base"] 60 | self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter 61 | 62 | class env(LeggedVisualInputConfig.env): 63 | num_observations = 6 64 | episode_length_s = 20 # in seconds 65 | num_actions = 6 66 | env_spacing = 20 67 | 68 | class termination: 69 | collision = True 70 | max_step_done = True 71 | 72 | 73 | class DroneCfgPPO(LeggedVIsualInputCfgPPO): 74 | """ 75 | This config is only a placeholder required to register the tasks with task_registry; it is otherwise unused, since we use sb3 instead of RSL_RL 76 | """ 77 | seed = 1 78 | runner_class_name = 'OnPolicyRunner' 79 | 80 | class policy(LeggedVIsualInputCfgPPO.policy): 81 | init_noise_std = 1.0 82 | actor_hidden_dims = [512, 256, 128] 83 | critic_hidden_dims = [512, 256, 128] 84 | activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid 85 | # only for 'ActorCriticRecurrent': 86 | # rnn_type = 'lstm' 87 | # rnn_hidden_size = 512 88 | # rnn_num_layers = 1 89 | 90 | class algorithm(LeggedVIsualInputCfgPPO.algorithm): 91 | # training params 92 | value_loss_coef = 1.0 93 | use_clipped_value_loss = True 94 | clip_param = 0.2 95 | entropy_coef = 0.01 96 | num_learning_epochs = 5 97 | num_mini_batches = 4 # mini batch size = num_envs*nsteps / nminibatches 98 | learning_rate = 1.e-3 # 5.e-4 99 | schedule = 'adaptive' # could be adaptive, fixed
100 | gamma = 0.99 101 | lam = 0.95 102 | desired_kl = 0.01 103 | max_grad_norm = 1. 104 | 105 | class runner(LeggedVIsualInputCfgPPO.runner): 106 | policy_class_name = 'ActorCritic' 107 | algorithm_class_name = 'PPO' 108 | num_steps_per_env = 24 # per iteration 109 | max_iterations = 1500 # number of policy updates 110 | 111 | # logging 112 | save_interval = 50 # check for potential saves every this many iterations 113 | experiment_name = 'test' 114 | run_name = '' 115 | # load and resume 116 | resume = False 117 | load_run = -1 # -1 = last run 118 | checkpoint = -1 # -1 = last saved model 119 | resume_path = None # updated from load_run and chkpt 120 | 121 | class visual_input(LeggedVisualInputConfig.visual_input): 122 | camera_width = 320 123 | camera_height = 240 124 | type = gymapi.IMAGE_COLOR 125 | # stack = 5 # consecutive frames to stack 126 | normalization = True 127 | cam_pos = (0.0, 0, 0.0) 128 | -------------------------------------------------------------------------------- /gleam/env/config_gleam_eval.py: -------------------------------------------------------------------------------- 1 | from gleam.env.config_legged_visual import LeggedVisualInputConfig 2 | from isaacgym import gymapi 3 | 4 | 5 | class Config_GLEAM_Eval(LeggedVisualInputConfig): 6 | position_use_polar_coordinates = False # the position will be represented by (r, \theta, \phi) instead of (x, y, z) 7 | direction_use_vector = False # (r, p, y) 8 | debug_save_image_tensor = False 9 | debug_save_path = None 10 | max_episode_length = 50 # max_steps_per_episode (the training config Config_GLEAM uses 500) 11 | 12 | num_sampled_point = 5000 13 | 14 | class rewards: 15 | class scales: # * self.dt (0.019999) / episode_length_s (20).
For example, when reward_scales=1000, rew_xxx = reward * 1000 * 0.019999 / 20 = reward 16 | surface_coverage_2d = 1000 # original scale (coverage ratio: [0, 1]); the only reward scale here, whereas Config_GLEAM also defines termination and collision 17 | 18 | only_positive_rewards = False # if true negative total rewards are clipped at zero (avoids early termination problems) 19 | max_contact_force = 100. # forces above this value are penalized 20 | 21 | class normalization: 22 | class obs_scales: 23 | lin_vel = 2.0 24 | ang_vel = 0.25 25 | dof_pos = 1.0 26 | dof_vel = 0.05 27 | height_measurements = 5.0 28 | 29 | pi = 3.14159265359 30 | clip_observations = 255. 31 | 32 | # discrete action space 33 | init_action = [64, 64, 0, 0, 0, 0] # 64: keep still 34 | clip_actions_up = [128, 128, 0, 0, 0, 0] 35 | clip_actions_low = [0, 0, 0, 0, 0, 0] 36 | 37 | class visual_input(LeggedVisualInputConfig.visual_input): 38 | camera_width = 256 39 | camera_height = 32 40 | horizontal_fov = 90.0 # Horizontal field of view in degrees. Vertical field of view is calculated from height to width ratio 41 | 42 | stack = 100 43 | supersampling_horizontal = 1 44 | supersampling_vertical = 1 45 | normalization = True # normalize pixels to [0, 1] 46 | far_plane = 2000000.0 # distance in world coordinates to far-clipping plane 47 | near_plane = 0.0010000000474974513 # distance in world coordinate units to near-clipping plane 48 | 49 | type = gymapi.IMAGE_DEPTH 50 | 51 | class asset(LeggedVisualInputConfig.asset): 52 | file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/drone/cf2x.urdf' 53 | name = "cf2x" 54 | prop_name = "prop" 55 | penalize_contacts_on = ["prop", "base"] 56 | terminate_after_contacts_on = ["prop", "base"] 57 | self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter 58 | 59 | class env(LeggedVisualInputConfig.env): 60 | num_observations = 6 61 | episode_length_s = 20 # in seconds
62 | num_actions = 6 63 | env_spacing = 20 64 | 65 | class termination: 66 | collision = True 67 | max_step_done = True -------------------------------------------------------------------------------- /gleam/env/env_gleam_stage2.py: -------------------------------------------------------------------------------- 1 | from isaacgym import gymapi 2 | from isaacgym.torch_utils import * 3 | from legged_gym import OPEN_ROBOT_ROOT_DIR 4 | import os 5 | import torch 6 | import numpy as np 7 | from gleam.env.env_gleam_base import Env_GLEAM_Base 8 | from gleam.env.env_gleam_stage1 import Env_GLEAM_Stage1 9 | 10 | 11 | class Env_GLEAM_Stage2(Env_GLEAM_Stage1): 12 | def __init__(self, *args, **kwargs): 13 | """ 14 | Training set of stage 2 includes 512 indoor scenes (256 + 96 + 64 + 96) from: 15 | procthor_2-bed-2-bath_256, 16 | procthor_7-room-3-bed_96, 17 | procthor_12-room_64, 18 | gibson_96. 19 | """ 20 | self.visualize_flag = False # training 21 | # self.visualize_flag = True # visualization 22 | 23 | # self.num_scene = 4 # debug 24 | # print("*"*50, "num_scene: ", self.num_scene, "*"*50) 25 | 26 | self.num_scene = 512 # set before the parent __init__ below; presumably read during it (e.g. by _additional_create / _init_load_all) — confirm 27 | 28 | super(Env_GLEAM_Base, self).__init__(*args, **kwargs) # NOTE(review): super(Env_GLEAM_Base, self) starts the MRO lookup after Env_GLEAM_Base, so Env_GLEAM_Stage1.__init__ and Env_GLEAM_Base.__init__ are skipped (assuming Stage1 subclasses Base); presumably intentional so they cannot overwrite num_scene / visualize_flag — confirm 29 | 30 | def _additional_create(self, env_handle, env_index): 31 | """ 32 | If there are N training scenes and M environments, each environment will load N//M scenes. 33 | Only the first scene (idx == 0) is active, and the others are inactive. 34 | """ 35 | assert self.cfg.return_visual_observation, "visual observation should be returned!"
36 | assert self.num_scene >= self.num_envs, "num_scene should be larger than num_envs" 37 | 38 | # urdf load, create actor 39 | dataset_name = "stage2_512" 40 | urdf_path = f"data_gleam/train_{dataset_name}/urdf" 41 | 42 | scene_per_env = self.num_scene // self.num_envs 43 | 44 | asset_options = gymapi.AssetOptions() 45 | asset_options.flip_visual_attachments = self.cfg.asset.flip_visual_attachments 46 | asset_options.fix_base_link = True 47 | asset_options.disable_gravity = True 48 | 49 | inactive_x = self.env_origins[:, 0].max() + 15. 50 | inactive_y = self.env_origins[:, 1].max() + 15. 51 | inactive_z = self.env_origins[:, 2].max() 52 | 53 | for idx in range(scene_per_env): 54 | scene_idx = env_index * scene_per_env + idx 55 | urdf_name = f"scene_{scene_idx}.urdf" 56 | asset = self.gym.load_asset(self.sim, urdf_path, urdf_name, asset_options) 57 | 58 | pose = gymapi.Transform() 59 | if idx == 0: # Only the first scene (idx == 0) is active 60 | pose.p = gymapi.Vec3(self.env_origins[env_index][0], 61 | self.env_origins[env_index][1], 62 | self.env_origins[env_index][2]) 63 | else: 64 | pose.p = gymapi.Vec3(inactive_x, inactive_y, inactive_z) 65 | pose.r = gymapi.Quat(-np.pi/2, 0.0, 0.0, np.pi/2) 66 | self.gym.create_actor(env_handle, asset, pose, None, env_index, 0) 67 | 68 | self.additional_actors[env_index] = [i+1 for i in range(scene_per_env)] 69 | 70 | def _init_load_all(self): 71 | """ 72 | Load all ground truth data. 
73 | """ 74 | self.grid_size = 128 75 | self.motion_height = 1.5 # 1.5m 76 | dataset_name = "stage2_512" 77 | gt_path = os.path.join(OPEN_ROBOT_ROOT_DIR, f"data_gleam/train_{dataset_name}/gt/") 78 | 79 | # [num_scene, 3] 80 | self.voxel_size_gt = torch.load(gt_path+f"{dataset_name}_{self.grid_size}_voxel_size_gt.pt", 81 | map_location=self.device)[:self.num_scene] 82 | 83 | # [num_scene, 6], (x_max, x_min, y_max, y_min, z_max, z_min) 84 | self.range_gt = torch.load(gt_path+f"{dataset_name}_{self.grid_size}_range_gt.pt", 85 | map_location=self.device)[:self.num_scene] 86 | 87 | # [num_scene, grid_size, grid_size], layout map at the height of {self.motion height} 88 | self.layout_maps_height = torch.load(gt_path+f"{dataset_name}_{self.grid_size}_occ_map_height_1d5_gt.pt", 89 | map_location=self.device)[:self.num_scene].to(torch.float16) 90 | self.layout_maps_height /= 255. 91 | 92 | # [num_scene] 93 | self.num_valid_pixel_gt = self.layout_maps_height.sum(dim=(1, 2)) 94 | 95 | # [num_scene, 128, 128] 96 | init_maps = torch.load(gt_path+f"{dataset_name}_{self.grid_size}_init_map_1d5.pt", 97 | map_location=self.device)[:self.num_scene] 98 | init_maps /= 255. 
99 | 100 | # len() == self.num_scene, the shape of element is [num_non_zero_pixel, 2] 101 | self.init_maps_list = [(torch.nonzero(init_maps[idx]) / (self.grid_size - 1) * 2 - 1)\ 102 | * self.range_gt[idx, :4:2] 103 | for idx in range(self.num_scene)] 104 | 105 | print("Loaded all ground truth data.") 106 | 107 | def _init_buffers(self): 108 | super()._init_buffers() 109 | 110 | # visualization 111 | if self.visualize_flag: 112 | # visualize scenes 113 | self.vis_obj_idx = 0 114 | print("Visualization object index: ", self.vis_obj_idx) 115 | 116 | self.reset_once_flag = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) 117 | self.reset_once_cr = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) 118 | self.local_paths = [[] for _ in range(self.num_envs)] 119 | 120 | self.save_path = f'./gleam/output/train_stage2_512' 121 | os.makedirs(self.save_path, exist_ok=True) 122 | -------------------------------------------------------------------------------- /gleam/network/encoder.py: -------------------------------------------------------------------------------- 1 | # from typing import Any, Dict, List, Optional, Type, Union, Tuple 2 | from typing import List, Tuple 3 | 4 | import gym 5 | import torch as th 6 | from gleam.network.base import LocoTransformerEncoder_Map 7 | from gleam.network.locotransformer import LocoTransformer_GLEAM 8 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 9 | 10 | 11 | class Encoder_GLEAM(BaseFeaturesExtractor): 12 | """ 13 | We adapt the LocoTransformer to encode raw observations, including: 14 | - state input: historical poses within this episode. 15 | - visual input: current egocentric map, with the shape of (128, 128). 16 | """ 17 | 18 | def __init__( 19 | self, 20 | observation_space: gym.spaces.Space, 21 | encoder_param=None, 22 | net_param=None, 23 | visual_input_shape=None, 24 | state_input_shape=None, 25 | ): 26 | assert encoder_param is not None, "Need parameters !" 
27 | assert net_param is not None, "Need parameters !" 28 | assert isinstance(visual_input_shape, List) or isinstance(visual_input_shape, Tuple), "Use tuple or list" 29 | assert isinstance(state_input_shape, List) or isinstance(state_input_shape, Tuple), "Use tuple or list" 30 | self.map_channel = visual_input_shape[0] 31 | self.map_shape = visual_input_shape[1:] 32 | self.state_input_shape = state_input_shape 33 | feature_dim = net_param["append_hidden_shapes"][-1] 34 | net_param["append_hidden_shapes"].pop() 35 | super(Encoder_GLEAM, self).__init__(observation_space, feature_dim) 36 | 37 | # create encoder, share encoder 38 | self.encoder = LocoTransformerEncoder_Map( 39 | in_channels=self.map_channel, 40 | state_input_dim=self.state_input_shape[0], 41 | **encoder_param 42 | ) 43 | 44 | self.locotransformer = LocoTransformer_GLEAM( 45 | encoder=self.encoder, 46 | state_input_shape=self.state_input_shape[0], 47 | visual_input_shape=(self.map_channel, *self.map_shape), 48 | output_shape=feature_dim, 49 | **net_param 50 | ) 51 | 52 | def forward(self, observations: th.Tensor) -> th.Tensor: 53 | """ 54 | observations: [num_env, concat(state_input, visual_input)] 55 | """ 56 | return self.locotransformer.forward(observations) -------------------------------------------------------------------------------- /gleam/network/init.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credit to legged_gym and TorchRL. 3 | """ 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | 8 | 9 | def _fanin_init(tensor, alpha=0): 10 | size = tensor.size() 11 | if len(size) == 2: 12 | fan_in = size[0] 13 | elif len(size) > 2: 14 | fan_in = np.prod(size[1:]) 15 | else: 16 | raise Exception("Shape must be have dimension at least 2.") 17 | # bound = 1. / np.sqrt(fan_in) 18 | bound = np.sqrt(1. 
/ ((1 + alpha * alpha) * fan_in)) 19 | return tensor.data.uniform_(-bound, bound) 20 | 21 | 22 | def _uniform_init(tensor, param=3e-3): 23 | return tensor.data.uniform_(-param, param) 24 | 25 | 26 | def _constant_bias_init(tensor, constant=0.1): 27 | tensor.data.fill_(constant) 28 | 29 | 30 | def layer_init(layer, weight_init=_fanin_init, bias_init=_constant_bias_init): 31 | weight_init(layer.weight) 32 | bias_init(layer.bias) 33 | 34 | 35 | def basic_init(layer): 36 | layer_init(layer, weight_init=_fanin_init, bias_init=_constant_bias_init) 37 | 38 | 39 | def uniform_init(layer): 40 | layer_init(layer, weight_init=_uniform_init, bias_init=_uniform_init) 41 | 42 | 43 | def _orthogonal_init(tensor, gain=np.sqrt(2)): 44 | nn.init.orthogonal_(tensor, gain=gain) 45 | 46 | 47 | def orthogonal_init(layer, scale=np.sqrt(2), constant=0): 48 | layer_init( 49 | layer, weight_init=lambda x: _orthogonal_init(x, gain=scale), bias_init=lambda x: _constant_bias_init(x, 0) 50 | ) 51 | -------------------------------------------------------------------------------- /gleam/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/gleam/train/__init__.py -------------------------------------------------------------------------------- /gleam/utils/bfs_cuda_2D.h: -------------------------------------------------------------------------------- 1 | #ifndef BFS_CUDA_H 2 | #define BFS_CUDA_H 3 | 4 | void BFS_kernel_launcher_2D( 5 | const float* occupancy_maps, 6 | const int* starts, 7 | const int* goals, 8 | float* path_lengths, 9 | int num_env, 10 | int H, 11 | int W 12 | ); 13 | 14 | #endif -------------------------------------------------------------------------------- /gleam/utils/bfs_cuda_kernel_2D.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "bfs_cuda_2D.h" 5 | 6 | 
#define MAX_PATH_LENGTH 16384 // 128 * 128; longer paths are reported as -1
7 |
8 | // Breadth-first search on a 4-connected H x W grid, one thread per environment.
9 | //
10 | // occupancy_maps is indexed as [env_id * H * W + x * W + y]; starts/goals hold
11 | // one (x, y) pair per environment with x in [0, H) and y in [0, W).
12 | // path_lengths[env_id] receives the number of cells on the shortest path
13 | // (the start cell counts as 1), or -1.0f when the goal is unreachable, lies
14 | // MAX_PATH_LENGTH or more cells away, start/goal are out of bounds, or the
15 | // per-thread device-heap allocation fails.
16 | __global__ void BFS_kernel_2D(
17 |     const float* __restrict__ occupancy_maps, // Grid map (1 = walkable, 0 = obstacle)
18 |     const int* __restrict__ starts,
19 |     const int* __restrict__ goals,
20 |     float* __restrict__ path_lengths,
21 |     int num_env,
22 |     int H,
23 |     int W)
24 | {
25 |     int env_id = blockIdx.x * blockDim.x + threadIdx.x;
26 |     if (env_id >= num_env) return;
27 |
28 |     const int num_cells = H * W;
29 |     // 64-bit offset: env_id * H * W can exceed INT_MAX for many large maps.
30 |     const long long occupancy_offset = (long long)env_id * num_cells;
31 |
32 |     // Extract start and goal positions
33 |     const int start_x = starts[env_id * 2];
34 |     const int start_y = starts[env_id * 2 + 1];
35 |     const int goal_x = goals[env_id * 2];
36 |     const int goal_y = goals[env_id * 2 + 1];
37 |
38 |     // An out-of-range start or goal would index outside the scratch buffers.
39 |     if (start_x < 0 || start_x >= H || start_y < 0 || start_y >= W ||
40 |         goal_x < 0 || goal_x >= H || goal_y < 0 || goal_y >= W) {
41 |         path_lengths[env_id] = -1.0f;
42 |         return;
43 |     }
44 |
45 |     // Per-thread scratch on the device heap. Every cell is enqueued at most once,
46 |     // so a queue of num_cells entries cannot overflow; a fixed MAX_PATH_LENGTH-sized
47 |     // queue would stop the search early on maps larger than 128 x 128.
48 |     // distances[i] == 0 doubles as "not visited": the start cell is stored as 1.
49 |     int* distances = new int[num_cells];
50 |     int* queue = new int[num_cells];
51 |     if (distances == nullptr || queue == nullptr) {
52 |         // Device heap exhausted (its size is set via cudaLimitMallocHeapSize).
53 |         delete[] distances;
54 |         delete[] queue;
55 |         path_lengths[env_id] = -1.0f;
56 |         return;
57 |     }
58 |     for (int i = 0; i < num_cells; i++) {
59 |         distances[i] = 0;
60 |     }
61 |
62 |     // 4-connected neighbours (no diagonal movement).
63 |     const int dx[4] = {-1, 0, 1, 0};
64 |     const int dy[4] = {0, 1, 0, -1};
65 |
66 |     const int start_pos = start_x * W + start_y;
67 |     const int goal_pos = goal_x * W + goal_y;
68 |     int front = 0;
69 |     int back = 0;
70 |     queue[back++] = start_pos;
71 |     distances[start_pos] = 1;
72 |     float path_length = -1.0f;
73 |
74 |     while (front < back) {
75 |         const int current_pos = queue[front++];
76 |         const int current_dist = distances[current_pos];
77 |
78 |         // Stop once paths reach MAX_PATH_LENGTH cells.
79 |         if (current_dist >= MAX_PATH_LENGTH) break;
80 |
81 |         if (current_pos == goal_pos) {
82 |             path_length = (float)current_dist;
83 |             break;
84 |         }
85 |
86 |         const int current_x = current_pos / W;
87 |         const int current_y = current_pos % W;
88 |         for (int i = 0; i < 4; i++) {
89 |             const int nx = current_x + dx[i];
90 |             const int ny = current_y + dy[i];
91 |             if (nx < 0 || nx >= H || ny < 0 || ny >= W) continue;
92 |             const int neighbor_pos = nx * W + ny;
93 |             if (occupancy_maps[occupancy_offset + neighbor_pos] == 1 && distances[neighbor_pos] == 0) {
94 |                 distances[neighbor_pos] = current_dist + 1;
95 |                 queue[back++] = neighbor_pos;
96 |             }
97 |         }
98 |     }
99 |
100 |     path_lengths[env_id] = path_length;
101 |
102 |     delete[] queue;
103 |     delete[] distances;
104 | }
105 |
106 | void BFS_kernel_launcher_2D(
107 |     const float* occupancy_maps,
108 |     const int* starts,
109 |     const int* goals,
110 |     float* path_lengths,
111 |     int num_env,
112 |     int H,
113 |     int W)
114 | {
115 |     // A launch with zero blocks is an invalid configuration; nothing to do.
116 |     if (num_env <= 0) return;
117 |     const int threads = 256;
118 |     const int blocks = (num_env + threads - 1) / threads;
119 |     BFS_kernel_2D<<<blocks, threads>>>(
120 |         occupancy_maps,
121 |         starts,
122 |         goals,
123 |         path_lengths,
124 |         num_env,
125 |         H,
126 |         W
127 |     );
128 | }
--------------------------------------------------------------------------------
/gleam/wrapper/env_wrapper_gleam.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """An env wrapper that flattens the observation dictionary to an array.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import gym 21 | import collections 22 | 23 | import torch 24 | from gym import spaces 25 | import numpy as np 26 | 27 | def flatten_observations(observation_dict, key_sequence): 28 | """Flattens the observation dictionary to an array. 29 | 30 | If observation_excluded is passed in, it will still return a dictionary, 31 | which includes all the (key, observation_dict[key]) in observation_excluded, 32 | and ('other': the flattened array). 33 | 34 | Args: 35 | observation_dict: A dictionary of all the observations. 36 | key_sequence: A list/tuple of all the keys of the observations to be 37 | added during flattening. 38 | 39 | Returns: 40 | An array or a dictionary of observations based on whether 41 | observation_excluded is empty. 42 | """ 43 | observations = [] 44 | for key in key_sequence: 45 | value = observation_dict[key] 46 | observations.append(value) 47 | 48 | flat_observations = torch.concat(observations, dim=-1) # grid observation 49 | return flat_observations 50 | 51 | 52 | def flatten_observation_spaces(observation_spaces, key_sequence): 53 | """Flattens the dictionary observation spaces to gym.spaces.Box. 54 | 55 | If observation_excluded is passed in, it will still return a dictionary, 56 | which includes all the (key, observation_spaces[key]) in observation_excluded, 57 | and ('other': the flattened Box space). 58 | 59 | Args: 60 | observation_spaces: A dictionary of all the observation spaces. 61 | key_sequence: A list/tuple of all the keys of the observations to be 62 | added during flattening. 63 | 64 | Returns: 65 | A box space or a dictionary of observation spaces based on whether 66 | observation_excluded is empty. 
67 | """ 68 | assert isinstance(key_sequence, list) 69 | lower_bound = [] 70 | upper_bound = [] 71 | for key in key_sequence: 72 | value = observation_spaces.spaces[key] 73 | if isinstance(value, spaces.Box): 74 | lower_bound.append(np.asarray(value.low).flatten()) 75 | upper_bound.append(np.asarray(value.high).flatten()) 76 | else: 77 | continue 78 | 79 | lower_bound = np.concatenate(lower_bound) 80 | upper_bound = np.concatenate(upper_bound) 81 | observation_space = spaces.Box(np.array(lower_bound, dtype=np.float32), np.array(upper_bound, dtype=np.float32), dtype=np.float32) 82 | return observation_space 83 | 84 | 85 | class EnvWrapperGLEAM(gym.Env): 86 | """An env wrapper that flattens the observation dictionary to an array.""" 87 | def __init__(self, gym_env, observation_excluded=()): 88 | """Initializes the wrapper.""" 89 | self.observation_excluded = observation_excluded 90 | self._gym_env = gym_env 91 | self.observation_space = self._flatten_observation_spaces(self._gym_env.observation_space) 92 | self.action_space = self._gym_env.action_space 93 | 94 | def __getattr__(self, attr): 95 | return getattr(self._gym_env, attr) 96 | 97 | def _flatten_observation_spaces(self, observation_spaces): 98 | flat_observation_space = flatten_observation_spaces( 99 | observation_spaces=observation_spaces, key_sequence=["state", "ego_map_2D"] 100 | ) 101 | return flat_observation_space 102 | 103 | def _flatten_observation(self, input_observation): 104 | """Flatten the dictionary to an array.""" 105 | return flatten_observations(observation_dict=input_observation, key_sequence=["state", "ego_map_2D"]) 106 | 107 | def reset(self): 108 | observation = self._gym_env.reset() 109 | return self._flatten_observation(observation) 110 | 111 | def step(self, action): 112 | """Steps the wrapped environment. 113 | 114 | Args: 115 | action: Numpy array. The input action from an NN agent. 
116 | 117 | Returns: 118 | The tuple containing the flattened observation, the reward, the epsiode 119 | end indicator. 120 | """ 121 | observation_dict, reward, done, _ = self._gym_env.step(action) 122 | return self._flatten_observation(observation_dict), reward, done, _ 123 | 124 | def render(self, mode='human'): 125 | return self._gym_env.render(mode) 126 | 127 | def close(self): 128 | self._gym_env.close() 129 | -------------------------------------------------------------------------------- /legged_gym/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | import os 32 | 33 | OPEN_ROBOT_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | LEGGED_GYM_ENVS_DIR = os.path.join(OPEN_ROBOT_ROOT_DIR, 'legged_gym', 'envs') 35 | -------------------------------------------------------------------------------- /legged_gym/env/__init__.py: -------------------------------------------------------------------------------- 1 | # # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # # SPDX-License-Identifier: BSD-3-Clause 3 | # # 4 | # # Redistribution and use in source and binary forms, with or without 5 | # # modification, are permitted provided that the following conditions are met: 6 | # # 7 | # # 1. Redistributions of source code must retain the above copyright notice, this 8 | # # list of conditions and the following disclaimer. 9 | # # 10 | # # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # # this list of conditions and the following disclaimer in the documentation 12 | # # and/or other materials provided with the distribution. 13 | # # 14 | # # 3. Neither the name of the copyright holder nor the names of its 15 | # # contributors may be used to endorse or promote products derived from 16 | # # this software without specific prior written permission. 
17 | # # 18 | # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # # 29 | # # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | 32 | # from legged_gym import OPEN_ROBOT_ROOT_DIR, LEGGED_GYM_ENVS_DIR 33 | # from legged_gym.env.a1.a1_config import A1RoughCfg, A1RoughCfgPPO 34 | # from .base.legged_robot import LeggedRobot 35 | # from .anymal_c.anymal import Anymal 36 | # from .anymal_c.mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg, AnymalCRoughCfgPPO 37 | # from .anymal_c.flat.anymal_c_flat_config import AnymalCFlatCfg, AnymalCFlatCfgPPO 38 | # from .anymal_b.anymal_b_config import AnymalBRoughCfg, AnymalBRoughCfgPPO 39 | # from .cassie.cassie import Cassie 40 | # from .cassie.cassie_config import CassieRoughCfg, CassieRoughCfgPPO 41 | # from .a1.a1_config import A1RoughCfg, A1RoughCfgPPO 42 | 43 | # from legged_gym.utils.task_registry import task_registry 44 | 45 | # task_registry.register("anymal_c_rough", Anymal, AnymalCRoughCfg(), AnymalCRoughCfgPPO()) 46 | # task_registry.register("anymal_c_flat", Anymal, AnymalCFlatCfg(), AnymalCFlatCfgPPO()) 47 | # task_registry.register("anymal_b", Anymal, AnymalBRoughCfg(), AnymalBRoughCfgPPO()) 48 | # 
task_registry.register("a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO()) 49 | # task_registry.register("cassie", Cassie, CassieRoughCfg(), CassieRoughCfgPPO()) 50 | -------------------------------------------------------------------------------- /legged_gym/env/a1/a1_config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
32 |
33 |
34 | class A1RoughCfg(LeggedRobotCfg):  # environment config: overrides selected LeggedRobotCfg sections for the "a1" asset
35 |     class init_state(LeggedRobotCfg.init_state):
36 |         pos = [0.0, 0.0, 0.42] # x,y,z [m]
37 |         default_joint_angles = { # = target angles [rad] when action = 0.0
38 |             'FL_hip_joint': 0.1, # [rad]
39 |             'RL_hip_joint': 0.1, # [rad]
40 |             'FR_hip_joint': -0.1 , # [rad]
41 |             'RR_hip_joint': -0.1, # [rad]
42 |
43 |             'FL_thigh_joint': 0.8, # [rad]
44 |             'RL_thigh_joint': 1., # [rad]
45 |             'FR_thigh_joint': 0.8, # [rad]
46 |             'RR_thigh_joint': 1., # [rad]
47 |
48 |             'FL_calf_joint': -1.5, # [rad]
49 |             'RL_calf_joint': -1.5, # [rad]
50 |             'FR_calf_joint': -1.5, # [rad]
51 |             'RR_calf_joint': -1.5, # [rad]
52 |         }
53 |
54 |     class control(LeggedRobotCfg.control):
55 |         # PD Drive parameters:
56 |         control_type = 'P'
57 |         stiffness = {'joint': 20.} # [N*m/rad]
58 |         damping = {'joint': 0.5} # [N*m*s/rad]
59 |         # action scale: target angle = actionScale * action + defaultAngle
60 |         action_scale = 0.25
61 |         # decimation: Number of control action updates @ sim DT per policy DT
62 |         decimation = 4
63 |
64 |     class asset(LeggedRobotCfg.asset):
65 |         file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/a1/urdf/a1.urdf'  # NOTE(review): this snapshot only ships resources/robots/drone; confirm the a1 URDF exists before use
66 |         name = "a1"
67 |         foot_name = "foot"
68 |         penalize_contacts_on = ["thigh", "calf"]
69 |         terminate_after_contacts_on = ["base"]
70 |         self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter
71 |
72 |     class rewards(LeggedRobotCfg.rewards):
73 |         soft_dof_pos_limit = 0.9
74 |         base_height_target = 0.25
75 |
76 |         class scales(LeggedRobotCfg.rewards.scales):
77 |             torques = -0.0002
78 |             dof_pos_limits = -10.0
79 |
80 |
81 | class A1RoughCfgPPO(LeggedRobotCfgPPO):  # PPO training config paired with A1RoughCfg (see the commented "a1" task registration)
82 |     class algorithm(LeggedRobotCfgPPO.algorithm):
83 |         entropy_coef = 0.01
84 |
85 |     class runner(LeggedRobotCfgPPO.runner):
86 |         run_name = ''
87 |         experiment_name = 'rough_a1'
88 |
--------------------------------------------------------------------------------
/legged_gym/env/anymal_b/anymal_b_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from legged_gym.env import AnymalCRoughCfg, AnymalCRoughCfgPPO 32 | 33 | 34 | class AnymalBRoughCfg(AnymalCRoughCfg): 35 | class asset(AnymalCRoughCfg.asset): 36 | file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/anymal_b/urdf/anymal_b.urdf' 37 | name = "anymal_b" 38 | foot_name = 'FOOT' 39 | 40 | class rewards(AnymalCRoughCfg.rewards): 41 | class scales(AnymalCRoughCfg.rewards.scales): 42 | pass 43 | 44 | 45 | class AnymalBRoughCfgPPO(AnymalCRoughCfgPPO): 46 | class runner(AnymalCRoughCfgPPO.runner): 47 | run_name = '' 48 | experiment_name = 'rough_anymal_b' 49 | load_run = -1 50 | -------------------------------------------------------------------------------- /legged_gym/env/anymal_c/anymal.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from time import time 32 | import numpy as np 33 | import os 34 | 35 | from isaacgym.torch_utils import * 36 | from isaacgym import gymtorch, gymapi, gymutil 37 | 38 | import torch 39 | # from torch.tensor import Tensor 40 | from typing import Tuple, Dict 41 | 42 | from legged_gym.env import LeggedRobot 43 | from legged_gym import OPEN_ROBOT_ROOT_DIR 44 | from .mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg 45 | 46 | 47 | class Anymal(LeggedRobot): 48 | cfg: AnymalCRoughCfg 49 | 50 | def __init__(self, cfg, sim_params, physics_engine, sim_device, headless): 51 | super().__init__(cfg, sim_params, physics_engine, sim_device, headless) 52 | 53 | # load actuator network 54 | if self.cfg.control.use_actuator_network: 55 | actuator_network_path = self.cfg.control.actuator_net_file.format(LEGGED_GYM_ROOT_DIR=OPEN_ROBOT_ROOT_DIR) 56 | self.actuator_network = torch.jit.load(actuator_network_path).to(self.device) 57 | 58 | def reset_idx(self, env_ids): 59 | super().reset_idx(env_ids) 60 | # Additionaly empty actuator network hidden states 61 | self.sea_hidden_state_per_env[:, env_ids] = 0. 62 | self.sea_cell_state_per_env[:, env_ids] = 0. 
63 | 64 | def _init_buffers(self): 65 | super()._init_buffers() 66 | # Additionally initialize actuator network hidden state tensors 67 | self.sea_input = torch.zeros(self.num_envs * self.num_actions, 1, 2, device=self.device, requires_grad=False) 68 | self.sea_hidden_state = torch.zeros( 69 | 2, self.num_envs * self.num_actions, 8, device=self.device, requires_grad=False 70 | ) 71 | self.sea_cell_state = torch.zeros( 72 | 2, self.num_envs * self.num_actions, 8, device=self.device, requires_grad=False 73 | ) 74 | self.sea_hidden_state_per_env = self.sea_hidden_state.view(2, self.num_envs, self.num_actions, 8) 75 | self.sea_cell_state_per_env = self.sea_cell_state.view(2, self.num_envs, self.num_actions, 8) 76 | 77 | def _compute_torques(self, actions): 78 | # Choose between pd controller and actuator network 79 | if self.cfg.control.use_actuator_network: 80 | with torch.inference_mode(): 81 | self.sea_input[:, 0, 82 | 0] = (actions * self.cfg.control.action_scale + self.default_dof_pos - 83 | self.dof_pos).flatten() 84 | self.sea_input[:, 0, 1] = self.dof_vel.flatten() 85 | torques, (self.sea_hidden_state[:], self.sea_cell_state[:] 86 | ) = self.actuator_network(self.sea_input, (self.sea_hidden_state, self.sea_cell_state)) 87 | return torques 88 | else: 89 | # pd controller 90 | return super()._compute_torques(actions) 91 | -------------------------------------------------------------------------------- /legged_gym/env/anymal_c/flat/anymal_c_flat_config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. 
Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from legged_gym.env import AnymalCRoughCfg, AnymalCRoughCfgPPO 32 | 33 | 34 | class AnymalCFlatCfg(AnymalCRoughCfg): 35 | class env(AnymalCRoughCfg.env): 36 | num_observations = 48 37 | 38 | class terrain(AnymalCRoughCfg.terrain): 39 | mesh_type = 'plane' 40 | measure_heights = False 41 | 42 | class asset(AnymalCRoughCfg.asset): 43 | self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter 44 | 45 | class rewards(AnymalCRoughCfg.rewards): 46 | max_contact_force = 350. 
47 | 48 | class scales(AnymalCRoughCfg.rewards.scales): 49 | orientation = -5.0 50 | torques = -0.000025 51 | feet_air_time = 2. 52 | # feet_contact_forces = -0.01 53 | 54 | class commands(AnymalCRoughCfg.commands): 55 | heading_command = False 56 | resampling_time = 4. 57 | 58 | class ranges(AnymalCRoughCfg.commands.ranges): 59 | ang_vel_yaw = [-1.5, 1.5] 60 | 61 | class domain_rand(AnymalCRoughCfg.domain_rand): 62 | friction_range = [ 63 | 0., 1.5 64 | ] # on ground planes the friction combination mode is averaging, i.e total friction = (foot_friction + 1.)/2. 65 | 66 | 67 | class AnymalCFlatCfgPPO(AnymalCRoughCfgPPO): 68 | class policy(AnymalCRoughCfgPPO.policy): 69 | actor_hidden_dims = [128, 64, 32] 70 | critic_hidden_dims = [128, 64, 32] 71 | activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid 72 | 73 | class algorithm(AnymalCRoughCfgPPO.algorithm): 74 | entropy_coef = 0.01 75 | 76 | class runner(AnymalCRoughCfgPPO.runner): 77 | run_name = '' 78 | experiment_name = 'flat_anymal_c' 79 | load_run = -1 80 | max_iterations = 300 81 | -------------------------------------------------------------------------------- /legged_gym/env/anymal_c/mixed_terrains/anymal_c_rough_config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 
13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
32 |
33 |
34 | class AnymalCRoughCfg(LeggedRobotCfg):  # environment config: overrides selected LeggedRobotCfg sections for the anymal_c asset
35 |     class env(LeggedRobotCfg.env):
36 |         num_envs = 4096
37 |         num_actions = 12
38 |
39 |     class terrain(LeggedRobotCfg.terrain):
40 |         mesh_type = 'trimesh'
41 |
42 |     class init_state(LeggedRobotCfg.init_state):
43 |         pos = [0.0, 0.0, 0.6] # x,y,z [m]
44 |         default_joint_angles = { # = target angles [rad] when action = 0.0
45 |             "LF_HAA": 0.0,
46 |             "LH_HAA": 0.0,
47 |             "RF_HAA": -0.0,
48 |             "RH_HAA": -0.0,
49 |
50 |             "LF_HFE": 0.4,
51 |             "LH_HFE": -0.4,
52 |             "RF_HFE": 0.4,
53 |             "RH_HFE": -0.4,
54 |
55 |             "LF_KFE": -0.8,
56 |             "LH_KFE": 0.8,
57 |             "RF_KFE": -0.8,
58 |             "RH_KFE": 0.8,
59 |         }
60 |
61 |     class control(LeggedRobotCfg.control):
62 |         # PD Drive parameters:
63 |         stiffness = {'HAA': 80., 'HFE': 80., 'KFE': 80.} # [N*m/rad]
64 |         damping = {'HAA': 2., 'HFE': 2., 'KFE': 2.} # [N*m*s/rad]
65 |         # action scale: target angle = actionScale * action + defaultAngle
66 |         action_scale = 0.5
67 |         # decimation: Number of control action updates @ sim DT per policy DT
68 |         decimation = 4
69 |         use_actuator_network = True  # when True, Anymal._compute_torques runs the network below instead of the PD path
70 |         actuator_net_file = "{LEGGED_GYM_ROOT_DIR}/resources/actuator_nets/anydrive_v3_lstm.pt"  # loaded with torch.jit.load in Anymal.__init__; NOTE(review): resources/actuator_nets is not in this snapshot — confirm
71 |
72 |     class asset(LeggedRobotCfg.asset):
73 |         file = "{LEGGED_GYM_ROOT_DIR}/resources/robots/anymal_c/urdf/anymal_c.urdf"  # NOTE(review): this snapshot only ships resources/robots/drone; confirm the URDF exists
74 |         name = "anymal_c"
75 |         foot_name = "FOOT"
76 |         penalize_contacts_on = ["SHANK", "THIGH"]
77 |         terminate_after_contacts_on = ["base"]
78 |         self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter
79 |
80 |     class domain_rand(LeggedRobotCfg.domain_rand):
81 |         randomize_base_mass = True
82 |         added_mass_range = [-5., 5.]
83 |
84 |     class rewards(LeggedRobotCfg.rewards):
85 |         base_height_target = 0.5
86 |         max_contact_force = 500.
87 |         only_positive_rewards = True
88 |
89 |         class scales(LeggedRobotCfg.rewards.scales):
90 |             pass
91 |
92 |
93 | class AnymalCRoughCfgPPO(LeggedRobotCfgPPO):  # PPO training config paired with AnymalCRoughCfg
94 |     class runner(LeggedRobotCfgPPO.runner):
95 |         run_name = ''
96 |         experiment_name = 'rough_anymal_c'
97 |         load_run = -1
98 |
--------------------------------------------------------------------------------
/legged_gym/env/base/base_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import inspect


class BaseConfig:
    def __init__(self) -> None:
        """Instantiate every nested config class of this object, recursively.

        Each attribute that is a class is replaced on the instance by an
        instance of that class, and the same is then done inside that new
        instance. The only attribute skipped is ``__class__``.
        """
        self.init_member_classes(self)

    @staticmethod
    def init_member_classes(obj) -> None:
        """Replace every class-valued attribute of ``obj`` by an instance of it, recursively.

        The class attributes themselves are left untouched: the instances are
        set on ``obj`` and shadow them.

        Args:
            obj: object whose attributes (as listed by ``dir(obj)``) are scanned.
        """
        # dir() returns a snapshot list, so setattr() inside the loop is safe.
        for key in dir(obj):
            # ``__class__`` is itself a class; instantiating it would re-enter
            # this scan on a fresh object of the same type and never terminate.
            if key == "__class__":
                continue
            var = getattr(obj, key)
            if not inspect.isclass(var):
                continue
            i_var = var()
            setattr(obj, key, i_var)
            # Descend so that classes nested inside this one are instantiated too.
            BaseConfig.init_member_classes(i_var)

# --------------------------------------------------------------------------------
# /legged_gym/scripts/play.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1.
Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

from legged_gym import OPEN_ROBOT_ROOT_DIR
import os

import isaacgym
from legged_gym.env import *
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger

import numpy as np
import torch

# Script toggles. They are defined at module level (not only under __main__)
# so that ``play`` also works when this module is imported.
EXPORT_POLICY = True  # export the loaded policy as a jit module
RECORD_FRAMES = False  # write every other viewer frame to <experiment>/exported/frames
MOVE_CAMERA = False  # translate the viewer camera at a constant velocity


def play(args):
    """Replay a trained policy in a single environment and log its behaviour.

    Loads the configs registered for ``args.task`` and reduces the scene to one
    environment on a plane with observation noise, friction randomisation and
    pushes disabled. It then resumes the trained policy, optionally exports it,
    and steps the environment for ``10 * max_episode_length`` steps. States of
    one robot/joint are logged for the first ``stop_state_log`` steps and then
    plotted. Episode reward terms are accumulated until ``stop_rew_log`` and
    then printed.

    Args:
        args: parsed command-line arguments (from ``get_args``); ``args.task``
            selects the registered task.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    env_cfg.return_visual_observation = False
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
    env_cfg.terrain.num_rows = 1
    env_cfg.terrain.num_cols = 1
    env_cfg.terrain.curriculum = False
    env_cfg.terrain.mesh_type = "plane"
    env_cfg.noise.add_noise = False
    env_cfg.domain_rand.randomize_friction = False
    env_cfg.domain_rand.push_robots = False

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    obs = env.get_observations()
    # load policy
    train_cfg.runner.resume = True
    train_cfg.runner.num_steps_per_env = 1
    ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    policy = ppo_runner.get_inference_policy(device=env.device)

    # export policy as a jit module (used to run it from C++)
    if EXPORT_POLICY:
        path = os.path.join(OPEN_ROBOT_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies')
        export_policy_as_jit(ppo_runner.alg.actor_critic, path)
        print('Exported policy as jit script to: ', path)

    logger = Logger(env.dt)
    robot_index = 0  # which robot is used for logging
    joint_index = 1  # which joint is used for logging
    stop_state_log = 100  # number of steps before plotting states
    stop_rew_log = env.max_episode_length + 1  # number of steps before print average episode rewards
    camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64)
    camera_vel = np.array([1., 1., 0.])
    camera_direction = np.array(env_cfg.viewer.lookat) - np.array(env_cfg.viewer.pos)
    img_idx = 0

    # Loop-invariant output directory for recorded frames.
    frames_dir = os.path.join(OPEN_ROBOT_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'frames')
    if RECORD_FRAMES:
        # Make sure the target directory exists before any frame is written into it.
        os.makedirs(frames_dir, exist_ok=True)

    for i in range(10 * int(env.max_episode_length)):
        actions = policy(obs.detach())
        obs, _, rews, dones, infos = env.step(actions.detach())
        # Record only on odd steps, i.e. every other step.
        if RECORD_FRAMES and i % 2:
            filename = os.path.join(frames_dir, f"{img_idx}.png")
            env.gym.write_viewer_image_to_file(env.viewer, filename)
            img_idx += 1
        if MOVE_CAMERA:
            camera_position += camera_vel * env.dt
            env.set_camera(camera_position, camera_position + camera_direction)

        if i < stop_state_log:
            logger.log_states(
                {
                    'dof_pos_target': actions[robot_index, joint_index].item() * env.cfg.control.action_scale,
                    'dof_pos': env.dof_pos[robot_index, joint_index].item(),
                    'dof_vel': env.dof_vel[robot_index, joint_index].item(),
                    'dof_torque': env.torques[robot_index, joint_index].item(),
                    'command_x': env.commands[robot_index, 0].item(),
                    'command_y': env.commands[robot_index, 1].item(),
                    'command_yaw': env.commands[robot_index, 2].item(),
                    'base_vel_x': env.base_lin_vel[robot_index, 0].item(),
                    'base_vel_y': env.base_lin_vel[robot_index, 1].item(),
                    'base_vel_z': env.base_lin_vel[robot_index, 2].item(),
                    'base_vel_yaw': env.base_ang_vel[robot_index, 2].item(),
                    'contact_forces_z': env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy()
                }
            )
        elif i == stop_state_log:
            logger.plot_states()
        if 0 < i < stop_rew_log:
            # .get() tolerates an info dict without an "episode" entry
            # instead of raising KeyError.
            episode_info = infos.get("episode")
            if episode_info:
                num_episodes = torch.sum(env.reset_buf).item()
                if num_episodes > 0:
                    logger.log_rewards(episode_info, num_episodes)
        elif i == stop_rew_log:
            logger.print_rewards()


if __name__ == '__main__':
    # NOTE(review): ``legged_complex_env`` is not in the visible repository tree;
    # presumably it registers extra tasks on import. Confirm it is installed.
    import legged_complex_env
    args = get_args()
    # NOTE: this replaces whatever task get_args() parsed from the command line.
    args.task = "a1"
    play(args)

# --------------------------------------------------------------------------------
# /legged_gym/scripts/train.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np
import os
from datetime import datetime

import isaacgym
from legged_gym.env import *
from legged_gym.utils import get_args, task_registry
import torch


def train(args):
    """Build the task's environment and PPO runner, then run training.

    The number of learning iterations comes from the runner config returned
    by ``make_alg_runner``.

    Args:
        args: parsed command-line arguments (from ``get_args``); ``args.task``
            selects the registered task.
    """
    env, _env_cfg = task_registry.make_env(name=args.task, args=args)
    runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args)
    n_iterations = train_cfg.runner.max_iterations
    runner.learn(num_learning_iterations=n_iterations, init_at_random_ep_len=True)


if __name__ == '__main__':
    train(get_args())

# --------------------------------------------------------------------------------
# /legged_gym/tests/test_env.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2.
# Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np
import os
from datetime import datetime

import isaacgym
from legged_gym.env import *
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger

import torch


def test_env(args):
    """Smoke-test a registered task by stepping it with all-zero actions.

    At most two environments are created. They are stepped for
    ``10 * max_episode_length`` steps, then "Done" is printed.

    Args:
        args: parsed command-line arguments (from ``get_args``); ``args.task``
            selects the registered task.
    """
    env_cfg, _train_cfg = task_registry.get_cfgs(name=args.task)
    # keep the test cheap: never more than two environments
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 2)

    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    n_steps = int(10 * env.max_episode_length)
    for _step in range(n_steps):
        zero_actions = torch.zeros(env.num_envs, env.num_actions, device=env.device)
        _obs, _, _rew, _done, _info = env.step(zero_actions)
    print("Done")


if __name__ == '__main__':
    test_env(get_args())

# --------------------------------------------------------------------------------
# /legged_gym/utils/__init__.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict 32 | from .task_registry import task_registry 33 | from .logger import Logger 34 | from .math import * 35 | from .terrain import Terrain 36 | -------------------------------------------------------------------------------- /legged_gym/utils/logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from multiprocessing import Process, Value


class Logger:
    """Collect state samples and reward terms, then plot / summarise them.

    States are plotted in a separate process by ``plot_states``; reward terms
    are summarised on stdout by ``print_rewards``.
    """

    def __init__(self, dt):
        """
        Args:
            dt: time between two logged state samples [s]; used to build the
                time axis of the plots.
        """
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        self.plot_process = None

    def log_state(self, key, value):
        """Append one sample ``value`` to the state series ``key``."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        """Append one sample per entry of the mapping ``dict`` (name -> value)."""
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Record every entry whose key contains 'rew', weighted by ``num_episodes``.

        Args:
            dict: mapping name -> value; ``.item()`` is called on each kept
                value, so it must be a single-element tensor/array.
            num_episodes: episode count the values were averaged over; added
                to the running total used by ``print_rewards``.
        """
        for key, value in dict.items():
            if 'rew' in key:
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Discard all logged states and rewards, including the episode count."""
        self.state_log.clear()
        self.rew_log.clear()
        # Must be cleared together with rew_log; otherwise print_rewards would
        # divide freshly logged sums by a stale episode count.
        self.num_episodes = 0

    def plot_states(self):
        """Plot the logged states in a separate process, so the caller is not blocked."""
        self.plot_process = Process(target=self._plot)
        self.plot_process.start()

    def _plot(self):
        """Draw a 3x3 grid of the logged state series and show it.

        Panels are only filled for keys that were actually logged. Nothing is
        drawn when no state was logged at all.
        """
        log = self.state_log
        if not log:
            # No samples means no time axis to plot against.
            return
        # Time axis taken from the first logged series. The plot calls below
        # assume every series has that same length.
        num_samples = len(next(iter(log.values())))
        time = np.linspace(0, num_samples * self.dt, num_samples)
        nb_rows = 3
        nb_cols = 3
        fig, axs = plt.subplots(nb_rows, nb_cols)
        # plot joint targets and measured positions
        a = axs[1, 0]
        if log["dof_pos"]:
            a.plot(time, log["dof_pos"], label='measured')
        if log["dof_pos_target"]:
            a.plot(time, log["dof_pos_target"], label='target')
        a.set(xlabel='time [s]', ylabel='Position [rad]', title='DOF Position')
        a.legend()
        # plot joint velocity
        a = axs[1, 1]
        if log["dof_vel"]:
            a.plot(time, log["dof_vel"], label='measured')
        if log["dof_vel_target"]:
            a.plot(time, log["dof_vel_target"], label='target')
        a.set(xlabel='time [s]', ylabel='Velocity [rad/s]', title='Joint Velocity')
        a.legend()
        # plot base vel x
        a = axs[0, 0]
        if log["base_vel_x"]:
            a.plot(time, log["base_vel_x"], label='measured')
        if log["command_x"]:
            a.plot(time, log["command_x"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity x')
        a.legend()
        # plot base vel y
        a = axs[0, 1]
        if log["base_vel_y"]:
            a.plot(time, log["base_vel_y"], label='measured')
        if log["command_y"]:
            a.plot(time, log["command_y"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity y')
        a.legend()
        # plot base vel yaw
        a = axs[0, 2]
        if log["base_vel_yaw"]:
            a.plot(time, log["base_vel_yaw"], label='measured')
        if log["command_yaw"]:
            a.plot(time, log["command_yaw"], label='commanded')
        a.set(xlabel='time [s]', ylabel='base ang vel [rad/s]', title='Base velocity yaw')
        a.legend()
        # plot base vel z
        a = axs[1, 2]
        if log["base_vel_z"]:
            a.plot(time, log["base_vel_z"], label='measured')
        a.set(xlabel='time [s]', ylabel='base lin vel [m/s]', title='Base velocity z')
        a.legend()
        # plot contact forces (one curve per column of the logged arrays)
        a = axs[2, 0]
        if log["contact_forces_z"]:
            forces = np.array(log["contact_forces_z"])
            for i in range(forces.shape[1]):
                a.plot(time, forces[:, i], label=f'force {i}')
        a.set(xlabel='time [s]', ylabel='Forces z [N]', title='Vertical Contact forces')
        a.legend()
        # plot torque/vel curves
        a = axs[2, 1]
        if log["dof_vel"] != [] and log["dof_torque"] != []:
            a.plot(log["dof_vel"], log["dof_torque"], 'x', label='measured')
        a.set(xlabel='Joint vel [rad/s]', ylabel='Joint Torque [Nm]', title='Torque/velocity curves')
        a.legend()
        # plot torques
        a = axs[2, 2]
        if log["dof_torque"] != []:
            a.plot(time, log["dof_torque"], label='measured')
        a.set(xlabel='time [s]', ylabel='Joint Torque [Nm]', title='Torque')
        a.legend()
        plt.show()

    def print_rewards(self):
        """Print, per logged reward term, its summed weighted values divided by the episode count."""
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            mean = np.sum(np.array(values)) / self.num_episodes
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        # Do not leave a plotting window process behind when the logger goes away.
        if self.plot_process is not None:
            self.plot_process.kill()

# --------------------------------------------------------------------------------
# /legged_gym/utils/math.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2.
# Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
from torch import Tensor
import numpy as np
from isaacgym.torch_utils import quat_apply, normalize
from typing import Tuple


# @ torch.jit.script
def quat_apply_yaw(quat, vec):
    """Rotate ``vec`` by ``quat`` after zeroing the quaternion's first two components.

    The truncated quaternion is renormalised before it is applied. This keeps
    only the rotation about z (yaw) if quaternions are stored as (x, y, z, w).
    That layout is assumed from isaacgym's convention; TODO confirm.
    """
    quat_yaw = quat.clone().view(-1, 4)
    quat_yaw[:, :2] = 0.
    quat_yaw = normalize(quat_yaw)
    return quat_apply(quat_yaw, vec)


# @ torch.jit.script
def wrap_to_pi(angles):
    """Wrap angles [rad] into the interval (-pi, pi].

    A new value is returned and the argument is left unmodified. This matters
    when callers pass a view of a state buffer.

    Args:
        angles: tensor, array or scalar of angles in radians.

    Returns:
        The wrapped angles, same shape as ``angles``.
    """
    # ``%`` takes the sign of the divisor, so this lands in [0, 2*pi).
    wrapped = angles % (2 * np.pi)
    # Shift the upper half (> pi) down by one full turn.
    return wrapped - 2 * np.pi * (wrapped > np.pi)


# @ torch.jit.script
def torch_rand_sqrt_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Sample values in [lower, upper) whose density is pushed toward both bounds.

    A uniform sample r in [-1, 1) is mapped to sign(r) * sqrt(|r|). This moves
    values away from the midpoint toward the ends. The result is then rescaled
    to [lower, upper).
    """
    r = 2 * torch.rand(*shape, device=device) - 1
    r = torch.where(r < 0., -torch.sqrt(-r), torch.sqrt(r))
    r = (r + 1.) / 2.
    return (upper - lower) * r + lower

# --------------------------------------------------------------------------------
# /licenses/assets/ANYmal_b_license.txt:
# --------------------------------------------------------------------------------
# Copyright 2019 ANYbotics, https://www.anybotics.com
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
#
# 3. The name of ANYbotics and ANYmal may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 20 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 21 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 23 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /licenses/assets/ANYmal_c_license.txt: -------------------------------------------------------------------------------- 1 | Copyright 2020, ANYbotics AG. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in 12 | the documentation and/or other materials provided with the 13 | distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived 17 | from this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 23 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /licenses/assets/cassie_license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jenna Reher, jreher@caltech.edu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /licenses/dependencies/matplotlib_license.txt: -------------------------------------------------------------------------------- 1 | 1. This LICENSE AGREEMENT is between the Matplotlib Development Team ("MDT"), and the Individual or Organization ("Licensee") accessing and otherwise using matplotlib software in source or binary form and its associated documentation. 2 | 3 | 2. Subject to the terms and conditions of this License Agreement, MDT hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use matplotlib 3.4.3 alone or in any derivative version, provided, however, that MDT's License Agreement and MDT's notice of copyright, i.e., "Copyright (c) 2012-2013 Matplotlib Development Team; All Rights Reserved" are retained in matplotlib 3.4.3 alone or in any derivative version prepared by Licensee. 4 | 5 | 3. In the event Licensee prepares a derivative work that is based on or incorporates matplotlib 3.4.3 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to matplotlib 3.4.3. 6 | 7 | 4. MDT is making matplotlib 3.4.3 available to Licensee on an "AS IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 3.4.3 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 8 | 9 | 5. 
MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 3.4.3 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING MATPLOTLIB 3.4.3, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 10 | 11 | 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 12 | 13 | 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between MDT and Licensee. This License Agreement does not grant permission to use MDT trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 14 | 15 | 8. By copying, installing or otherwise using matplotlib 3.4.3, Licensee agrees to be bound by the terms and conditions of this License Agreement. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | addict==2.4.0 3 | aiohttp==3.8.3 4 | aiosignal==1.2.0 5 | appdirs==1.4.4 6 | asttokens==2.4.1 7 | async-timeout==4.0.2 8 | asyncio==3.4.3 9 | asynctest==0.13.0 10 | attrs==22.1.0 11 | backcall==0.2.0 12 | cachetools==5.2.0 13 | ccimport==0.4.2 14 | certifi==2024.2.2 15 | charset-normalizer==2.1.1 16 | click==8.1.3 17 | cloudpickle==2.2.0 18 | comm==0.2.2 19 | commonmark==0.9.1 20 | ConfigArgParse==1.5.3 21 | cumm==0.4.11 22 | cumm-cu113==0.4.11 23 | cycler==0.11.0 24 | Cython==0.29.33 25 | dash==2.17.0 26 | dash-core-components==2.0.0 27 | dash-html-components==2.0.0 28 | dash-table==5.0.0 29 | decorator==5.1.1 30 | disutils==1.4.32.post2 31 | docker-pycreds==0.4.0 32 | einops==0.8.1 33 | -e git+https://github.com/OpenRobotLab/EmbodiedScan.git@23b97029a9954d81575f28f79921c59a14495211#egg=embodiedscan 34 | executing==2.0.1 35 | fastjsonschema==2.19.1 36 | filelock==3.14.0 37 | 
fire==0.6.0 38 | Flask==2.2.2 39 | fonttools==4.38.0 40 | frozenlist==1.3.1 41 | fsspec==2024.5.0 42 | fvcore==0.1.5.post20221221 43 | gitdb==4.0.9 44 | GitPython==3.1.43 45 | google-auth==2.14.0 46 | google-auth-oauthlib==0.4.6 47 | grpcio==1.50.0 48 | gym==0.21.0 49 | gym-notices==0.0.8 50 | h5py==3.11.0 51 | huggingface-hub==0.23.0 52 | idna==3.4 53 | imageio==2.22.3 54 | imageio-ffmpeg==0.4.7 55 | importlib-metadata==4.13.0 56 | importlib_resources==6.4.0 57 | iopath==0.1.10 58 | ipython==8.12.3 59 | ipywidgets==8.1.2 60 | itsdangerous==2.1.2 61 | jedi==0.19.1 62 | Jinja2==3.1.2 63 | joblib==1.4.2 64 | jsonschema==4.17.3 65 | jupyter_core==5.7.2 66 | jupyterlab_widgets==3.0.10 67 | kiwisolver==1.4.4 68 | kornia==0.6.8 69 | lark==1.1.9 70 | lpips==0.1.4 71 | Mako==1.3.5 72 | Markdown==3.4.1 73 | MarkupSafe==2.1.1 74 | matplotlib==3.5.3 75 | matplotlib-inline==0.1.7 76 | MinkowskiEngine==0.5.4 77 | mmcv==2.0.0rc4 78 | mmdet==3.3.0 79 | mmengine==0.10.4 80 | mpmath==1.3.0 81 | multidict==6.0.2 82 | nbformat==5.5.0 83 | nest-asyncio==1.6.0 84 | networkx==2.6.3 85 | ninja==1.11.1 86 | numpy==1.23.5 87 | nvidia-cublas-cu12==12.1.3.1 88 | nvidia-cuda-cupti-cu12==12.1.105 89 | nvidia-cuda-nvrtc-cu12==12.1.105 90 | nvidia-cuda-runtime-cu12==12.1.105 91 | nvidia-cudnn-cu12==8.9.2.26 92 | nvidia-cufft-cu12==11.0.2.54 93 | nvidia-curand-cu12==10.3.2.106 94 | nvidia-cusolver-cu12==11.4.5.107 95 | nvidia-cusparse-cu12==12.1.0.106 96 | nvidia-ml-py==12.535.161 97 | nvidia-nccl-cu12==2.20.5 98 | nvidia-nvjitlink-cu12==12.4.127 99 | nvidia-nvtx-cu12==12.1.105 100 | nvitop==1.3.2 101 | oauthlib==3.2.2 102 | open3d==0.16.0 103 | opencv-python==4.9.0.80 104 | packaging==21.3 105 | pandas==1.3.5 106 | parso==0.8.4 107 | pathfinding==1.0.14 108 | pathtools==0.1.2 109 | pccm==0.4.11 110 | pexpect==4.9.0 111 | pickleshare==0.7.5 112 | Pillow==9.3.0 113 | pkgutil_resolve_name==1.3.10 114 | platformdirs==4.2.2 115 | plotly==5.22.0 116 | plyfile==0.7.4 117 | portalocker==2.8.2 118 | 
promise==2.3 119 | prompt-toolkit==3.0.43 120 | protobuf==3.19.6 121 | psutil==5.9.4 122 | ptyprocess==0.7.0 123 | pure-eval==0.2.2 124 | pyasn1==0.4.8 125 | pyasn1-modules==0.2.8 126 | pybind11==2.10.1 127 | pycocotools==2.0.7 128 | # pycuda==2024.1 129 | pycuda 130 | Pygments==2.13.0 131 | pyparsing==3.0.9 132 | pyquaternion==0.9.9 133 | pyrsistent==0.20.0 134 | python-dateutil==2.8.2 135 | pytools==2024.1.2 136 | -e git+https://github.com/facebookresearch/pytorch3d.git@17117106e4dd8269c02271462539c9c4a5d0d5ec#egg=pytorch3d 137 | pytz==2022.6 138 | PyWavelets==1.3.0 139 | PyYAML==6.0 140 | regex==2024.5.15 141 | requests==2.28.1 142 | requests-oauthlib==1.3.1 143 | retrying==1.3.4 144 | rich==12.6.0 145 | rsa==4.9 146 | safetensors==0.4.3 147 | scikit-image==0.19.3 148 | scikit-learn==1.3.2 149 | scipy==1.8.0 150 | sentry-sdk==2.19.0 151 | setproctitle==1.3.2 152 | shapely==2.0.4 153 | shortuuid==1.0.11 154 | six==1.16.0 155 | smmap==5.0.0 156 | spconv==2.3.6 157 | spconv-cu113==2.3.6 158 | stack-data==0.6.3 159 | sympy==1.12 160 | tabulate==0.9.0 161 | tenacity==8.3.0 162 | tensorboard==2.10.1 163 | tensorboard-data-server==0.6.1 164 | tensorboard-plugin-wit==1.8.1 165 | termcolor==2.4.0 166 | terminaltables==3.1.10 167 | threadpoolctl==3.5.0 168 | tifffile==2021.11.2 169 | tokenizers==0.19.1 170 | torch==1.11.0+cu113 171 | torchaudio==0.11.0+cu113 172 | torchvision==0.12.0+cu113 173 | tqdm==4.64.1 174 | traitlets==5.14.3 175 | transformers==4.40.2 176 | triton==2.3.0 177 | typing_extensions==4.4.0 178 | urllib3==1.26.12 179 | wandb==0.18.7 180 | wcwidth==0.2.13 181 | Werkzeug==2.2.2 182 | widgetsnbextension==4.0.10 183 | yacs==0.1.8 184 | yapf==0.30.0 185 | yarl==1.8.1 186 | zipp==3.10.0 187 | -------------------------------------------------------------------------------- /rsl_rl/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | -------------------------------------------------------------------------------- /rsl_rl/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .ppo import PPO 32 | -------------------------------------------------------------------------------- /rsl_rl/env/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .vec_env import VecEnv 32 | -------------------------------------------------------------------------------- /rsl_rl/env/vec_env.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 
17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from abc import ABC, abstractmethod 32 | import torch 33 | from typing import Tuple, Union 34 | from gym.spaces import Box 35 | 36 | 37 | # minimal interface of the environment 38 | class VecEnv(ABC): 39 | num_envs: int 40 | num_obs: int 41 | num_privileged_obs: int 42 | num_actions: int 43 | max_episode_length: int 44 | privileged_obs_buf: torch.Tensor 45 | obs_buf: torch.Tensor 46 | rew_buf: torch.Tensor 47 | reset_buf: torch.Tensor 48 | episode_length_buf: torch.Tensor # current episode duration 49 | extras: dict 50 | device: torch.device 51 | observation_space: Box 52 | action_space: Box 53 | 54 | @abstractmethod 55 | def step(self, 56 | actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: 57 | pass 58 | 59 | @abstractmethod 60 | def reset(self, env_ids: Union[list, torch.Tensor]): 61 | pass 62 | 63 | @abstractmethod 64 | def get_observations(self) -> torch.Tensor: 65 | pass 66 | 67 | @abstractmethod 68 | def get_privileged_observations(self) -> Union[torch.Tensor, None]: 69 | pass 70 | 
-------------------------------------------------------------------------------- /rsl_rl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .actor_critic import ActorCritic 32 | from .actor_critic_recurrent import ActorCriticRecurrent 33 | -------------------------------------------------------------------------------- /rsl_rl/modules/actor_critic.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | import numpy as np 32 | 33 | import torch 34 | import torch.nn as nn 35 | from torch.distributions import Normal 36 | from torch.nn.modules import rnn 37 | 38 | 39 | class ActorCritic(nn.Module): 40 | is_recurrent = False 41 | 42 | def __init__( 43 | self, 44 | num_actor_obs, 45 | num_critic_obs, 46 | num_actions, 47 | actor_hidden_dims=[256, 256, 256], 48 | critic_hidden_dims=[256, 256, 256], 49 | activation='elu', 50 | init_noise_std=1.0, 51 | **kwargs 52 | ): 53 | if kwargs: 54 | print( 55 | "ActorCritic.__init__ got unexpected arguments, which will be ignored: " + 56 | str([key for key in kwargs.keys()]) 57 | ) 58 | super(ActorCritic, self).__init__() 59 | 60 | activation = get_activation(activation) 61 | 62 | mlp_input_dim_a = num_actor_obs 63 | mlp_input_dim_c = num_critic_obs 64 | 65 | # Policy 66 | actor_layers = [] 67 | actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) 68 | actor_layers.append(activation) 69 | for l in range(len(actor_hidden_dims)): 70 | if l == len(actor_hidden_dims) - 1: 71 | actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions)) 72 | else: 73 | actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1])) 74 | actor_layers.append(activation) 75 | self.actor = nn.Sequential(*actor_layers) 76 | 77 | # Value function 78 | critic_layers = [] 79 | 
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) 80 | critic_layers.append(activation) 81 | for l in range(len(critic_hidden_dims)): 82 | if l == len(critic_hidden_dims) - 1: 83 | critic_layers.append(nn.Linear(critic_hidden_dims[l], 1)) 84 | else: 85 | critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1])) 86 | critic_layers.append(activation) 87 | self.critic = nn.Sequential(*critic_layers) 88 | 89 | print(f"Actor MLP: {self.actor}") 90 | print(f"Critic MLP: {self.critic}") 91 | 92 | # Action noise 93 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 94 | self.distribution = None 95 | # disable args validation for speedup 96 | Normal.set_default_validate_args = False 97 | 98 | # seems that we get better performance without init 99 | # self.init_memory_weights(self.memory_a, 0.001, 0.) 100 | # self.init_memory_weights(self.memory_c, 0.001, 0.) 101 | 102 | @staticmethod 103 | # not used at the moment 104 | def init_weights(sequential, scales): 105 | [ 106 | torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) 107 | for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear)) 108 | ] 109 | 110 | def reset(self, dones=None): 111 | pass 112 | 113 | def forward(self): 114 | raise NotImplementedError 115 | 116 | @property 117 | def action_mean(self): 118 | return self.distribution.mean 119 | 120 | @property 121 | def action_std(self): 122 | return self.distribution.stddev 123 | 124 | @property 125 | def entropy(self): 126 | return self.distribution.entropy().sum(dim=-1) 127 | 128 | def update_distribution(self, observations): 129 | mean = self.actor(observations) 130 | self.distribution = Normal(mean, mean * 0. 
+ self.std) 131 | 132 | def act(self, observations, **kwargs): 133 | self.update_distribution(observations) 134 | return self.distribution.sample() 135 | 136 | def get_actions_log_prob(self, actions): 137 | return self.distribution.log_prob(actions).sum(dim=-1) 138 | 139 | def act_inference(self, observations): 140 | actions_mean = self.actor(observations) 141 | return actions_mean 142 | 143 | def evaluate(self, critic_observations, **kwargs): 144 | value = self.critic(critic_observations) 145 | return value 146 | 147 | 148 | def get_activation(act_name): 149 | if act_name == "elu": 150 | return nn.ELU() 151 | elif act_name == "selu": 152 | return nn.SELU() 153 | elif act_name == "relu": 154 | return nn.ReLU() 155 | elif act_name == "crelu": 156 | return nn.ReLU() 157 | elif act_name == "lrelu": 158 | return nn.LeakyReLU() 159 | elif act_name == "tanh": 160 | return nn.Tanh() 161 | elif act_name == "sigmoid": 162 | return nn.Sigmoid() 163 | else: 164 | print("invalid activation function!") 165 | return None 166 | -------------------------------------------------------------------------------- /rsl_rl/modules/actor_critic_recurrent.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. 
Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | import numpy as np 32 | 33 | import torch 34 | import torch.nn as nn 35 | from torch.distributions import Normal 36 | from torch.nn.modules import rnn 37 | from .actor_critic import ActorCritic, get_activation 38 | from rsl_rl.utils import unpad_trajectories 39 | 40 | 41 | class ActorCriticRecurrent(ActorCritic): 42 | is_recurrent = True 43 | 44 | def __init__( 45 | self, 46 | num_actor_obs, 47 | num_critic_obs, 48 | num_actions, 49 | actor_hidden_dims=[256, 256, 256], 50 | critic_hidden_dims=[256, 256, 256], 51 | activation='elu', 52 | rnn_type='lstm', 53 | rnn_hidden_size=256, 54 | rnn_num_layers=1, 55 | init_noise_std=1.0, 56 | **kwargs 57 | ): 58 | if kwargs: 59 | print( 60 | "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()), 61 | ) 62 | 63 | super().__init__( 64 | num_actor_obs=rnn_hidden_size, 65 | num_critic_obs=rnn_hidden_size, 66 | num_actions=num_actions, 67 | actor_hidden_dims=actor_hidden_dims, 68 | critic_hidden_dims=critic_hidden_dims, 69 | activation=activation, 70 | init_noise_std=init_noise_std 71 | ) 72 | 73 | activation = get_activation(activation) 74 | 75 | self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 76 | self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 77 | 78 | print(f"Actor RNN: {self.memory_a}") 79 | print(f"Critic RNN: {self.memory_c}") 80 | 81 | def reset(self, dones=None): 82 | self.memory_a.reset(dones) 83 | self.memory_c.reset(dones) 84 | 85 | def act(self, observations, masks=None, hidden_states=None): 86 | input_a = self.memory_a(observations, masks, hidden_states) 87 | return super().act(input_a.squeeze(0)) 88 | 89 | def act_inference(self, observations): 90 | input_a = self.memory_a(observations) 91 | return super().act_inference(input_a.squeeze(0)) 92 | 93 | def evaluate(self, 
critic_observations, masks=None, hidden_states=None): 94 | input_c = self.memory_c(critic_observations, masks, hidden_states) 95 | return super().evaluate(input_c.squeeze(0)) 96 | 97 | def get_hidden_states(self): 98 | return self.memory_a.hidden_states, self.memory_c.hidden_states 99 | 100 | 101 | class Memory(torch.nn.Module): 102 | def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256): 103 | super().__init__() 104 | # RNN 105 | rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM 106 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) 107 | self.hidden_states = None 108 | 109 | def forward(self, input, masks=None, hidden_states=None): 110 | batch_mode = masks is not None 111 | if batch_mode: 112 | # batch mode (policy update): need saved hidden states 113 | if hidden_states is None: 114 | raise ValueError("Hidden states not passed to memory module during policy update") 115 | out, _ = self.rnn(input, hidden_states) 116 | out = unpad_trajectories(out, masks) 117 | else: 118 | # inference mode (collection): use hidden states of last step 119 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) 120 | return out 121 | 122 | def reset(self, dones=None): 123 | # When the RNN is an LSTM, self.hidden_states is a tuple (hidden_state, cell_state) 124 | for hidden_state in self.hidden_states: 125 | hidden_state[..., dones, :] = 0.0 126 | -------------------------------------------------------------------------------- /rsl_rl/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1.
Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .on_policy_runner import OnPolicyRunner 32 | -------------------------------------------------------------------------------- /rsl_rl/storage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from .rollout_storage import RolloutStorage -------------------------------------------------------------------------------- /rsl_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | # 29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin 30 | 31 | from .utils import split_and_pad_trajectories, unpad_trajectories 32 | -------------------------------------------------------------------------------- /rsl_rl/utils/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch


def split_and_pad_trajectories(tensor, dones):
    """ Splits trajectories at done indices, then concatenates them and pads with zeros up to the
    length of the time axis (``tensor.shape[0]``), so that the padded trajectories and the
    returned masks always share the same time dimension.
    Returns masks corresponding to valid parts of the trajectories.

    Example (shown trajectory-major for readability; the returned tensors are time-major,
    i.e. shaped [time, number of trajectories, additional dimensions]):
        Input: [ [a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]
               ]

        Output:[ [a1, a2, a3, a4, 0, 0], | [ [True, True, True, True, False, False],
                 [a5, a6, 0, 0, 0, 0],   |   [True, True, False, False, False, False],
                 [b1, b2, 0, 0, 0, 0],   |   [True, True, False, False, False, False],
                 [b3, b4, b5, 0, 0, 0],  |   [True, True, True, False, False, False],
                 [b6, 0, 0, 0, 0, 0]     |   [True, False, False, False, False, False],
               ]                         | ]

    Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
    and that ``dones`` is indexed as [time, number of envs].
    """
    dones = dones.clone()
    # Force a trajectory boundary at the last step so every env's tail is closed off.
    dones[-1] = 1
    # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Get length of trajectory by counting the number of successive not done elements
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()
    # Extract the individual trajectories
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
    # pad_sequence only pads to the longest trajectory, which can be shorter than the time axis the
    # masks below are built on. Append one all-zero trajectory spanning the full time axis so the
    # padded length is always tensor.shape[0], then drop that helper trajectory again.
    full_length = torch.zeros(tensor.shape[0], *tensor.shape[2:], dtype=tensor.dtype, device=tensor.device)
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories + (full_length,))[:, :-1]

    trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
    return padded_trajectories, trajectory_masks


def unpad_trajectories(trajectories, masks):
    """ Does the inverse operation of split_and_pad_trajectories().

    :param trajectories: padded trajectories shaped [time, number of trajectories, features]
    :param masks: boolean validity masks shaped [time, number of trajectories]
    :return: tensor shaped [time, number of envs, features]
    """
    # Need to transpose before and after the masking to have proper reshaping
    return trajectories.transpose(1, 0)[masks.transpose(1, 0)].view(-1, trajectories.shape[0],
                                                                   trajectories.shape[-1]).transpose(1, 0)

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
# setup.py
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


# CUDA extension providing the 2D BFS pathfinding kernel.
# NOTE(review): the repository tree lists gleam/utils/bfs_cuda_2D.h and
# bfs_cuda_kernel_2D.cu but no bfs_cuda_2D.cpp — confirm the C++ binding source
# exists, otherwise building this extension will fail.
pathfinding_2D_cuda = CUDAExtension(
    name='bfs_cuda_2D',
    sources=[
        'gleam/utils/bfs_cuda_2D.cpp',
        'gleam/utils/bfs_cuda_kernel_2D.cu',
    ],
)

# Define the setup configuration
# Original author: Nikita Rudin (legged_gym)
setup(
    name='legged_gym',
    version='1.0.0',
    author='Xiao Chen',
    license="BSD-3-Clause",
    packages=find_packages(),
    author_email='cx123@ie.cuhk.edu.hk',
    description='GLEAM based on Isaac Gym environments',
    install_requires=[
        'isaacgym',
        'gym',
        'matplotlib',
        "tensorboard",
        "cloudpickle",
        "pandas",
        "yapf~=0.30.0",
        "wandb",
        "opencv-python>=3.0.0"
    ],
    ext_modules=[
        pathfinding_2D_cuda
    ],
    # BuildExtension supplies the nvcc/C++ compile logic for the CUDA extension.
    cmdclass={
        'build_ext': BuildExtension
    },
)

# --------------------------------------------------------------------------------
# /stable_baselines3/__init__.py:
# --------------------------------------------------------------------------------
import os

from stable_baselines3.a2c import A2C
from stable_baselines3.common.utils import get_system_info
from stable_baselines3.ddpg import DDPG
from stable_baselines3.dqn import DQN
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3

# Read version from file (explicit encoding so the result does not depend on the locale)
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, encoding="utf-8") as file_handler:
    __version__ = file_handler.read().strip()


def HER(*args, **kwargs):
    """Removed entry point kept only to give a helpful error.

    :raises ImportError: always; ``HerReplayBuffer`` replaces the old ``HER`` class.
    """
    raise ImportError(
        "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
        "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
    )

# --------------------------------------------------------------------------------
# /stable_baselines3/a2c/__init__.py:
# --------------------------------------------------------------------------------
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

# --------------------------------------------------------------------------------
# /stable_baselines3/a2c/policies.py:
# --------------------------------------------------------------------------------
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

# --------------------------------------------------------------------------------
# /stable_baselines3/common/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/stable_baselines3/common/__init__.py
# --------------------------------------------------------------------------------
# /stable_baselines3/common/envs/__init__.py:
# --------------------------------------------------------------------------------
from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.envs.identity_env import (
    FakeImageEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
)
from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv

# --------------------------------------------------------------------------------
# /stable_baselines3/common/envs/identity_env.py:
# --------------------------------------------------------------------------------
from typing import Optional, Union

import numpy as np
from gym import Env, Space
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete

from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


class IdentityEnv(Env):
    def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
        """
        Identity environment for testing purposes: the reward is 1.0 when the
        action equals the current observation, 0.0 otherwise.

        :param dim: the size of the action and observation dimension you want
            to learn. Provide at most one of ``dim`` and ``space``. If both are
            None, ``dim`` defaults to 1 and the space becomes ``Discrete(1)``.
        :param space: the action and observation space. Provide at most one of
            ``dim`` and ``space``.
        :param ep_length: the length of each episode in timesteps
        """
        if space is None:
            if dim is None:
                dim = 1
            space = Discrete(dim)
        else:
            assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"

        self.action_space = self.observation_space = space
        self.ep_length = ep_length
        self.current_step = 0
        self.num_resets = -1  # Becomes 0 after __init__ exits.
        self.reset()

    def reset(self) -> GymObs:
        self.current_step = 0
        self.num_resets += 1
        self._choose_next_state()
        return self.state

    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
        # The reward is computed against the state *before* a new one is sampled.
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        done = self.current_step >= self.ep_length
        return self.state, reward, done, {}

    def _choose_next_state(self) -> None:
        self.state = self.action_space.sample()

    def _get_reward(self, action: Union[int, np.ndarray]) -> float:
        return 1.0 if np.all(self.state == action) else 0.0

    def render(self, mode: str = "human") -> None:
        pass


class IdentityEnvBox(IdentityEnv):
    def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param low: the lower bound of the box dim
        :param high: the upper bound of the box dim
        :param eps: the epsilon bound for correct value
        :param ep_length: the length of each episode in timesteps
        """
        space = Box(low=low, high=high, shape=(1, ), dtype=np.float32)
        super().__init__(ep_length=ep_length, space=space)
        self.eps = eps

    # ``step`` is inherited unchanged from IdentityEnv; only the reward rule differs.
    def _get_reward(self, action: np.ndarray) -> float:
        return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0


class IdentityEnvMultiDiscrete(IdentityEnv):
    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        space = MultiDiscrete([dim, dim])
        super().__init__(ep_length=ep_length, space=space)


class IdentityEnvMultiBinary(IdentityEnv):
    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        space = MultiBinary(dim)
        super().__init__(ep_length=ep_length, space=space)


class FakeImageEnv(Env):
    """
    Fake image environment for testing purposes, it mimics Atari games.
    Observations are random images, the reward is always 0.0 and every
    episode lasts 10 steps.

    :param action_dim: Number of discrete actions (only used when ``discrete`` is True;
        the continuous action space always has shape ``(5,)``)
    :param screen_height: Height of the image
    :param screen_width: Width of the image
    :param n_channels: Number of color channels
    :param discrete: Create discrete action space instead of continuous
    :param channel_first: Put channels on first axis instead of last
    """
    def __init__(
        self,
        action_dim: int = 6,
        screen_height: int = 84,
        screen_width: int = 84,
        n_channels: int = 1,
        discrete: bool = True,
        channel_first: bool = False,
    ):
        self.observation_shape = (screen_height, screen_width, n_channels)
        if channel_first:
            self.observation_shape = (n_channels, screen_height, screen_width)
        self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
        if discrete:
            self.action_space = Discrete(action_dim)
        else:
            self.action_space = Box(low=-1, high=1, shape=(5, ), dtype=np.float32)
        self.ep_length = 10
        self.current_step = 0

    def reset(self) -> np.ndarray:
        self.current_step = 0
        return self.observation_space.sample()

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        reward = 0.0
        self.current_step += 1
        done = self.current_step >= self.ep_length
        return self.observation_space.sample(), reward, done, {}

    def render(self, mode: str = "human") -> None:
        pass

# --------------------------------------------------------------------------------
# /stable_baselines3/common/envs/multi_input_envs.py:
# --------------------------------------------------------------------------------
from typing import Dict, Union

import gym
import numpy as np

from stable_baselines3.common.type_aliases import GymStepReturn


class SimpleMultiObsEnv(gym.Env):
    """
    Base class for GridWorld-based MultiObs Environments 4x4 grid world.

    .. code-block:: text

        ____________
       | 0  1  2   3|
       | 4|¯5¯¯6¯| 7|
       | 8|_9_10_|11|
       |12 13  14 15|
       ¯¯¯¯¯¯¯¯¯¯¯¯¯¯

    start is 0
    states 5, 6, 9, and 10 are blocked
    goal is 15
    actions are = [left, down, right, up]

    simple grid env of 16 states (0-15) but encoded with a vector and an image observation:
    each column is represented by a random vector and each row is
    represented by a random image, both sampled once at creation time.
    The allowed transitions are hard-coded for the 4x4 layout above.

    :param num_col: Number of columns in the grid
    :param num_row: Number of rows in the grid
    :param random_start: If true, agent starts in random position
    :param discrete_actions: If true, actions are a ``Discrete(4)`` index, else a length-4
        ``Box`` whose argmax selects the action
    :param channel_last: If true, the image will be channel last, else it will be channel first
    """
    def __init__(
        self,
        num_col: int = 4,
        num_row: int = 4,
        random_start: bool = True,
        discrete_actions: bool = True,
        channel_last: bool = True,
    ):
        super().__init__()

        self.vector_size = 5
        if channel_last:
            self.img_size = [64, 64, 1]
        else:
            self.img_size = [1, 64, 64]

        self.random_start = random_start
        self.discrete_actions = discrete_actions
        if discrete_actions:
            self.action_space = gym.spaces.Discrete(4)
        else:
            self.action_space = gym.spaces.Box(0, 1, (4, ))

        self.observation_space = gym.spaces.Dict(
            spaces={
                "vec": gym.spaces.Box(0, 1, (self.vector_size, ), dtype=np.float64),
                "img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8),
            }
        )
        self.count = 0
        # Timeout
        self.max_count = 100
        self.log = ""
        self.state = 0
        self.action2str = ["left", "down", "right", "up"]
        self.init_possible_transitions()

        self.num_col = num_col
        self.state_mapping = []
        self.init_state_mapping(num_col, num_row)

        self.max_state = len(self.state_mapping) - 1

    def init_state_mapping(self, num_col: int, num_row: int) -> None:
        """
        Initializes the state_mapping array which holds the observation values for each state

        :param num_col: Number of columns.
        :param num_row: Number of rows.
        """
        # Each column is represented by a random vector
        col_vecs = np.random.random((num_col, self.vector_size))
        # Each row is represented by a random image
        row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)

        for i in range(num_col):
            for j in range(num_row):
                self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})

    def get_state_mapping(self) -> Dict[str, np.ndarray]:
        """
        Uses the state to get the observation mapping.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        return self.state_mapping[self.state]

    def init_possible_transitions(self) -> None:
        """
        Initializes the transitions of the environment
        The environment exploits the cardinal directions of the grid by noting that
        they correspond to simple addition and subtraction from the cell id within the grid

        - up => means moving up a row => means subtracting the length of a column
        - down => means moving down a row => means adding the length of a column
        - left => means moving left by one => means subtracting 1
        - right => means moving right by one => means adding 1

        Thus one only needs to specify in which states each action is possible
        in order to define the transitions of the environment
        """
        self.left_possible = [1, 2, 3, 13, 14, 15]
        self.down_possible = [0, 4, 8, 3, 7, 11]
        self.right_possible = [0, 1, 2, 12, 13, 14]
        self.up_possible = [4, 8, 12, 7, 11, 15]

    def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
        """
        Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).

        :param action: action index, or a length-4 vector whose argmax is used
            when actions are continuous
        :return: tuple (observation, reward, done, info).
        """
        if not self.discrete_actions:
            action = np.argmax(action)
        else:
            action = int(action)

        self.count += 1

        prev_state = self.state

        reward = -0.1
        # define state transition; actions not allowed from the current state leave it unchanged
        if self.state in self.left_possible and action == 0:  # left
            self.state -= 1
        elif self.state in self.down_possible and action == 1:  # down
            self.state += self.num_col
        elif self.state in self.right_possible and action == 2:  # right
            self.state += 1
        elif self.state in self.up_possible and action == 3:  # up
            self.state -= self.num_col

        got_to_end = self.state == self.max_state
        reward = 1 if got_to_end else reward
        done = self.count > self.max_count or got_to_end

        self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

        return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}

    def render(self, mode: str = "human") -> None:
        """
        Prints the log of the environment.

        :param mode: unused
        """
        print(self.log)

    def reset(self) -> Dict[str, np.ndarray]:
        """
        Resets the environment state and step count and returns reset observation.
        With ``random_start`` the start state is drawn from ``[0, max_state)``,
        so the goal state itself is never a start state.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        self.count = 0
        if not self.random_start:
            self.state = 0
        else:
            self.state = np.random.randint(0, self.max_state)
        return self.state_mapping[self.state]

# --------------------------------------------------------------------------------
# /stable_baselines3/common/evaluation_gleam.py:
# --------------------------------------------------------------------------------
import gym
import torch
from typing import List, Tuple, Union
from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv


def evaluate_policy_grid_obs(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    deterministic: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Roll out ``model`` in ``env`` with gradient tracking disabled, for evaluation.

    The environment is reset once, then actions from ``model.predict`` are fed to
    ``env.step`` in an unbounded loop. The rewards, dones and infos returned by
    ``env.step`` are not accumulated here. This function has no exit condition of
    its own and never returns normally; the evaluation environment is expected to
    end the run. NOTE(review): presumably the eval env raises or exits once its own
    termination condition is met — confirm against the eval env implementation.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment.
    :param deterministic: Whether to use deterministic or stochastic actions
    :return: Nothing in practice; the declared return type is kept unchanged only
        for compatibility with existing type hints.
    """

    with torch.no_grad():
        observations = env.reset()
        # set termination condition in eval_env
        while True:
            actions, _ = model.predict(observations, state=None, deterministic=deterministic)
            observations, rewards, dones, infos = env.step(actions)

# --------------------------------------------------------------------------------
# /stable_baselines3/common/noise.py:
# --------------------------------------------------------------------------------
import copy
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional

import numpy as np


class ActionNoise(ABC):
    """
    The action noise base class
    """
    def __init__(self):
        super().__init__()

    def reset(self) -> None:
        """
        call end of episode reset for the noise
        """
        pass

    @abstractmethod
    def __call__(self) -> np.ndarray:
        raise NotImplementedError()


class NormalActionNoise(ActionNoise):
    """
    A Gaussian action noise

    :param mean: the mean value of the noise
    :param sigma: the scale of the noise (std here)
    """
    def __init__(self, mean: np.ndarray, sigma: np.ndarray):
        self._mu = mean
        self._sigma = sigma
        super().__init__()

    def __call__(self) -> np.ndarray:
        return np.random.normal(self._mu, self._sigma)

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"


class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    :param mean: the mean of the noise
    :param sigma: the scale of the noise
    :param theta: the rate of mean reversion
    :param dt: the timestep for the noise
    :param initial_noise: the initial value for the noise output, (if None: 0)
    """
    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ):
        self._theta = theta
        self._mu = mean
        self._sigma = sigma
        self._dt = dt
        self.initial_noise = initial_noise
        self.noise_prev = np.zeros_like(self._mu)
        self.reset()
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Euler-Maruyama step: mean reversion towards mu plus a scaled Gaussian increment.
        noise = (
            self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt +
            self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
        )
        self.noise_prev = noise
        return noise

    def reset(self) -> None:
        """
        reset the Ornstein Uhlenbeck noise, to the initial position
        """
        self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)

    def __repr__(self) -> str:
        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"


class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    :param base_noise: ActionNoise The noise generator to use
    :param n_envs: The number of parallel environments
    :raises ValueError: if ``n_envs`` is not a positive integer
    """
    def __init__(self, base_noise: ActionNoise, n_envs: int):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        try:
            self.n_envs = int(n_envs)
        except TypeError as e:
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
        if self.n_envs <= 0:
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")

        self.base_noise = base_noise
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """
        Reset all the noise processes, or those listed in indices

        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        return f"VecNoise(BaseNoise={self.base_noise!r}, n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object
        """
        noise = np.stack([noise() for noise in self.noises])
        return noise

    @property
    def base_noise(self) -> ActionNoise:
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> List[ActionNoise]:
        return self._noises

    @noises.setter
    def noises(self, noises: List[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if len(different_types):
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise",
                type(self.base_noise)
            )

        self._noises = noises
        for noise in noises:
            noise.reset()

# --------------------------------------------------------------------------------
# /stable_baselines3/common/results_plotter.py:
# --------------------------------------------------------------------------------
from typing import Callable, List, Optional, Tuple

import numpy as np
import pandas as pd

# import matplotlib
# matplotlib.use('TkAgg')  # Can change to 'Agg' for non-interactive mode
from matplotlib import pyplot as plt

from stable_baselines3.common.monitor import load_results

X_TIMESTEPS = "timesteps"
X_EPISODES = "episodes"
X_WALLTIME = "walltime_hrs"
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
EPISODES_WINDOW = 100


def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
    """
    Apply a rolling window to a np.ndarray

    :param array: the input Array
    :param window: length of the rolling window
    :return: rolling window on the input array (a strided view, not a copy)
    """
    shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
    strides = array.strides + (array.strides[-1], )
    return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)


def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
    """
    Apply a function to the rolling window of 2 arrays

    :param var_1: variable 1
    :param var_2: variable 2
    :param window: length of the rolling window
    :param func: function to apply on the rolling window on variable 2 (such as np.mean)
    :return: the rolling output with applied function
    """
    var_2_window = rolling_window(var_2, window)
    function_on_var2 = func(var_2_window, axis=-1)
    return var_1[window - 1:], function_on_var2


def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decompose a data frame variable to x and ys

    :param data_frame: the input data
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: the x and y output
    :raises NotImplementedError: for any other ``x_axis`` value
    """
    if x_axis == X_TIMESTEPS:
        x_var = np.cumsum(data_frame.l.values)
        y_var = data_frame.r.values
    elif x_axis == X_EPISODES:
        x_var = np.arange(len(data_frame))
        y_var = data_frame.r.values
    elif x_axis == X_WALLTIME:
        # Convert to hours
        x_var = data_frame.t.values / 3600.0
        y_var = data_frame.r.values
    else:
        raise NotImplementedError
    return x_var, y_var


def plot_curves(
    xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
    """
    plot the curves

    :param xy_list: the x and y coordinates to plot
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param title: the title of the plot
    :param figsize: Size of the figure (width, height)
    """

    plt.figure(title, figsize=figsize)
    max_x = max(xy[0][-1] for xy in xy_list)
    min_x = 0
    for x, y in xy_list:
        plt.scatter(x, y, s=2)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x.shape[0] >= EPISODES_WINDOW:
            # Compute and plot rolling mean with window of size EPISODE_WINDOW
            x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
            plt.plot(x, y_mean)
    plt.xlim(min_x, max_x)
    plt.title(title)
    plt.xlabel(x_axis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()


def plot_results(
    dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
    """
    Plot the results using csv files from ``Monitor`` wrapper.

    :param dirs: the save location of the results to plot
    :param num_timesteps: only plot the points below this value
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param task_name: the title of the task to plot
    :param figsize: Size of the figure (width, height)
    """

    data_frames = []
    for folder in dirs:
        data_frame = load_results(folder)
        if num_timesteps is not None:
            data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
        data_frames.append(data_frame)
    xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
    plot_curves(xy_list, x_axis, task_name, figsize)

# --------------------------------------------------------------------------------
# /stable_baselines3/common/running_mean_std.py:
# --------------------------------------------------------------------------------
from typing import Tuple, Union

import numpy as np


class RunningMeanStd:
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        Calculates the running mean and std of a data stream
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

        :param epsilon: helps with arithmetic issues (initial pseudo-count, avoids division by zero)
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def copy(self) -> "RunningMeanStd":
        """
        :return: Return a copy of the current object.
        """
        new_object = RunningMeanStd(shape=self.mean.shape)
        new_object.mean = self.mean.copy()
        new_object.var = self.var.copy()
        new_object.count = float(self.count)
        return new_object

    def combine(self, other: "RunningMeanStd") -> None:
        """
        Combine stats from another ``RunningMeanStd`` object.

        :param other: The other object to combine with.
        """
        self.update_from_moments(other.mean, other.var, other.count)

    def update(self, arr: np.ndarray) -> None:
        """Update the statistics with a batch whose samples lie along axis 0."""
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(
        self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]
    ) -> None:
        """Merge a batch's mean/variance/count into the running statistics (parallel algorithm)."""
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        new_var = m_2 / tot_count

        self.mean = new_mean
        self.var = new_var
        self.count = tot_count

# --------------------------------------------------------------------------------
# /stable_baselines3/common/sb2_compat/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/stable_baselines3/common/sb2_compat/__init__.py
# --------------------------------------------------------------------------------
# /stable_baselines3/common/sb2_compat/rmsprop_tf_like.py:
# --------------------------------------------------------------------------------
from typing import Any, Callable, Dict, Iterable, Optional

import torch
from torch.optim import Optimizer


class RMSpropTFLike(Optimizer):
    r"""Implements RMSprop algorithm with closer match to Tensorflow version.

    For reproducibility with original stable-baselines. Use this
    version with e.g. A2C for stabler learning than with the PyTorch
    RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.

    See a more thorough conversion in pytorch-image-models repository:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py

    Changes to the original (PyTorch) RMSprop:

    - Move epsilon inside square root
    - Initialize squared gradient to ones rather than zeros

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    PyTorch's ``RMSprop`` takes the square root of the squared-gradient average
    *before* adding epsilon; this implementation follows TensorFlow and adds
    epsilon *inside* the square root. Without momentum the update is therefore

    .. math::
        p \leftarrow p - \eta \, g / \sqrt{v + \epsilon}

    where :math:`\eta` is the learning rate ``lr``, :math:`g` is the gradient
    (plus ``weight_decay * p`` when weight decay is enabled) and :math:`v` is the
    exponentially weighted moving average of the squared gradient (minus the
    squared moving-average gradient when ``centered=True``). With momentum
    :math:`\mu`, a buffer :math:`b \leftarrow \mu b + g / \sqrt{v + \epsilon}`
    is accumulated and :math:`p \leftarrow p - \eta b`.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups
    :param lr: learning rate (default: 1e-2)
    :param alpha: smoothing constant of the moving averages (default: 0.99)
    :param eps: term added to the squared-gradient average inside the square root
        to improve numerical stability (default: 1e-8)
    :param weight_decay: weight decay (L2 penalty) (default: 0)
    :param momentum: momentum factor (default: 0)
    :param centered: if ``True``, compute the centered RMSProp,
        the gradient is normalized by an estimation of its variance
    :raises ValueError: if ``lr``, ``eps``, ``momentum``, ``weight_decay`` or ``alpha`` is negative
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= alpha:
            raise ValueError(f"Invalid alpha value: {alpha}")

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore pickled optimizer state, defaulting ``momentum`` and ``centered``
        in any parameter group that lacks them."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]:
        """Performs a single optimization step.

        The optional ``momentum_buffer`` and ``grad_avg`` state entries are created
        (as zeros) the first time they are needed, so ``momentum`` or ``centered``
        can be switched on in a parameter group after training has started.

        :param closure: A closure that reevaluates the model
            and returns the loss.
        :return: the loss returned by ``closure``, or ``None`` if no closure was given
        :raises RuntimeError: if a parameter has a sparse gradient
        """
        loss = None
        if closure is not None:
            # ``step`` runs under no_grad, but the closure needs autograd to backpropagate.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            alpha = group["alpha"]
            eps = group["eps"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("RMSpropTF does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # TF initializes the squared-gradient average to ones (PyTorch uses zeros)
                    state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)

                square_avg = state["square_avg"]
                state["step"] += 1

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if group["centered"]:
                    grad_avg = state.get("grad_avg")
                    if grad_avg is None:
                        grad_avg = state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                    # TF-style: epsilon goes inside the square root (PyTorch adds it after)
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(eps).sqrt_()
                else:
                    # TF-style: epsilon goes inside the square root (PyTorch adds it after)
                    avg = square_avg.add(eps).sqrt_()

                if momentum > 0:
                    buf = state.get("momentum_buffer")
                    if buf is None:
                        buf = state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    buf.mul_(momentum).addcdiv_(grad, avg)
                    p.add_(buf, alpha=-lr)
                else:
                    p.addcdiv_(grad, avg, value=-lr)

        return loss


# --------------------------------------------------------------------------------
# /stable_baselines3/common/type_aliases.py:
-------------------------------------------------------------------------------- 1 | """Common aliases for type hints""" 2 | 3 | from enum import Enum 4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union 5 | 6 | import gym 7 | import numpy as np 8 | import torch as th 9 | 10 | from stable_baselines3.common import callbacks, vec_env 11 | 12 | GymEnv = Union[gym.Env, vec_env.VecEnv] 13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] 14 | GymStepReturn = Tuple[GymObs, float, bool, Dict] 15 | TensorDict = Dict[Union[str, int], th.Tensor] 16 | OptimizerStateDict = Dict[str, Any] 17 | MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] 18 | 19 | # A schedule takes the remaining progress as input 20 | # and ouputs a scalar (e.g. learning rate, clip range, ...) 21 | Schedule = Callable[[float], float] 22 | 23 | 24 | class RolloutBufferSamples(NamedTuple): 25 | observations: th.Tensor 26 | actions: th.Tensor 27 | old_values: th.Tensor 28 | old_log_prob: th.Tensor 29 | advantages: th.Tensor 30 | returns: th.Tensor 31 | 32 | 33 | class DictRolloutBufferSamples(RolloutBufferSamples): 34 | observations: TensorDict 35 | actions: th.Tensor 36 | old_values: th.Tensor 37 | old_log_prob: th.Tensor 38 | advantages: th.Tensor 39 | returns: th.Tensor 40 | 41 | 42 | class ReplayBufferSamples(NamedTuple): 43 | observations: th.Tensor 44 | actions: th.Tensor 45 | next_observations: th.Tensor 46 | dones: th.Tensor 47 | rewards: th.Tensor 48 | 49 | 50 | class DictReplayBufferSamples(ReplayBufferSamples): 51 | observations: TensorDict 52 | actions: th.Tensor 53 | next_observations: th.Tensor 54 | dones: th.Tensor 55 | rewards: th.Tensor 56 | 57 | 58 | class RolloutReturn(NamedTuple): 59 | episode_timesteps: int 60 | n_episodes: int 61 | continue_training: bool 62 | 63 | 64 | class TrainFrequencyUnit(Enum): 65 | STEP = "step" 66 | EPISODE = "episode" 67 | 68 | 69 | class TrainFreq(NamedTuple): 70 | frequency: int 71 | 
unit: TrainFrequencyUnit # either "step" or "episode" 72 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | import typing 3 | from copy import deepcopy 4 | from typing import Optional, Type, Union 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan 11 | from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs 12 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack 13 | from stable_baselines3.common.vec_env.vec_monitor import VecMonitor 14 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize 15 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage 16 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder 17 | 18 | # Avoid circular import 19 | if typing.TYPE_CHECKING: 20 | from stable_baselines3.common.type_aliases import GymEnv 21 | 22 | 23 | def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: 24 | """ 25 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 
26 | 27 | :param env: 28 | :param vec_wrapper_class: 29 | :return: 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, vec_wrapper_class): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: 40 | """ 41 | :param env: 42 | :return: 43 | """ 44 | return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type 45 | 46 | 47 | def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: 48 | """ 49 | Check if an environment is already wrapped by a given ``VecEnvWrapper``. 50 | 51 | :param env: 52 | :param vec_wrapper_class: 53 | :return: 54 | """ 55 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None 56 | 57 | 58 | # Define here to avoid circular import 59 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: 60 | """ 61 | Sync eval env and train env when using VecNormalize 62 | 63 | :param env: 64 | :param eval_env: 65 | """ 66 | env_tmp, eval_env_tmp = env, eval_env 67 | while isinstance(env_tmp, VecEnvWrapper): 68 | if isinstance(env_tmp, VecNormalize): 69 | # Only synchronize if observation normalization exists 70 | if hasattr(env_tmp, "obs_rms"): 71 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 72 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 73 | env_tmp = env_tmp.venv 74 | eval_env_tmp = eval_env_tmp.venv 75 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Any, Callable, List, Optional, Sequence, Type, Union 4 | 5 | import gym 6 | import numpy as np 7 | # from torch import is_tensor 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, 
VecEnvIndices, VecEnvObs, VecEnvStepReturn 9 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 10 | 11 | 12 | class DummyVecEnv(VecEnv): 13 | """ 14 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 15 | Python process. This is useful for computationally simple environment such as ``cartpole-v1``, 16 | as the overhead of multiprocess or multithread outweighs the environment computation time. 17 | This can also be used for RL methods that 18 | require a vectorized environment, but that you want a single environments to train with. 19 | 20 | :param env_fns: a list of functions 21 | that return environments to vectorize 22 | """ 23 | def __init__(self, env_fns: List[Callable[[], gym.Env]]): 24 | self.envs = [fn() for fn in env_fns] 25 | env = self.envs[0] 26 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 27 | obs_space = env.observation_space 28 | self.keys, shapes, dtypes = obs_space_info(obs_space) 29 | 30 | self.buf_obs = OrderedDict( 31 | [(k, np.zeros((self.num_envs, ) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys] 32 | ) 33 | self.buf_dones = np.zeros((self.num_envs, ), dtype=bool) 34 | self.buf_rews = np.zeros((self.num_envs, ), dtype=np.float32) 35 | self.buf_infos = [{} for _ in range(self.num_envs)] 36 | self.actions = None 37 | self.metadata = env.metadata 38 | 39 | def step_async(self, actions: np.ndarray) -> None: 40 | self.actions = actions 41 | 42 | def step_wait(self) -> VecEnvStepReturn: 43 | for env_idx in range(self.num_envs): 44 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 45 | self.actions[env_idx] 46 | ) 47 | if self.buf_dones[env_idx]: 48 | # save final observation where user can get it, then reset 49 | self.buf_infos[env_idx]["terminal_observation"] = obs 50 | obs = self.envs[env_idx].reset() 51 | self._save_obs(env_idx, obs) 52 | 
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 53 | 54 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 55 | if seed is None: 56 | seed = np.random.randint(0, 2**32 - 1) 57 | seeds = [] 58 | for idx, env in enumerate(self.envs): 59 | seeds.append(env.seed(seed + idx)) 60 | return seeds 61 | 62 | def reset(self) -> VecEnvObs: 63 | for env_idx in range(self.num_envs): 64 | obs = self.envs[env_idx].reset() 65 | self._save_obs(env_idx, obs) 66 | return self._obs_from_buf() 67 | 68 | def close(self) -> None: 69 | for env in self.envs: 70 | env.close() 71 | 72 | def get_images(self) -> Sequence[np.ndarray]: 73 | return [env.render(mode="rgb_array") for env in self.envs] 74 | 75 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 76 | """ 77 | Gym environment rendering. If there are multiple environments then 78 | they are tiled together in one image via ``BaseVecEnv.render()``. 79 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the 80 | underlying environment. 81 | 82 | Therefore, some arguments such as ``mode`` will have values that are valid 83 | only when ``num_envs == 1``. 84 | 85 | :param mode: The rendering type. 
86 | """ 87 | if self.num_envs == 1: 88 | return self.envs[0].render(mode=mode) 89 | else: 90 | return super().render(mode=mode) 91 | 92 | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: 93 | for key in self.keys: 94 | if key is None: 95 | self.buf_obs[key][env_idx] = obs 96 | else: 97 | # if is_tensor(obs[key]): # TODO 98 | # obs[key] = obs[key].cpu() 99 | self.buf_obs[key][env_idx] = obs[key] 100 | 101 | def _obs_from_buf(self) -> VecEnvObs: 102 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 103 | 104 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 105 | """Return attribute from vectorized environment (see base class).""" 106 | target_envs = self._get_target_envs(indices) 107 | return [getattr(env_i, attr_name) for env_i in target_envs] 108 | 109 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 110 | """Set attribute inside vectorized environments (see base class).""" 111 | target_envs = self._get_target_envs(indices) 112 | for env_i in target_envs: 113 | setattr(env_i, attr_name, value) 114 | 115 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 116 | """Call instance methods of vectorized environments.""" 117 | target_envs = self._get_target_envs(indices) 118 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 119 | 120 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 121 | """Check if worker environments are wrapped with a given wrapper""" 122 | target_envs = self._get_target_envs(indices) 123 | # Import here to avoid a circular import 124 | from stable_baselines3.common import env_util 125 | 126 | return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] 127 | 128 | def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: 129 | indices = 
self._get_indices(indices) 130 | return [self.envs[i] for i in indices] 131 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | from collections import OrderedDict 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import gym 8 | import numpy as np 9 | 10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs 12 | 13 | 14 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 15 | """ 16 | Deep-copy a dict of numpy arrays. 17 | 18 | :param obs: a dict of numpy arrays. 19 | :return: a dict of copied numpy arrays. 20 | """ 21 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 22 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 23 | 24 | 25 | def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: 26 | """ 27 | Convert an internal representation raw_obs into the appropriate type 28 | specified by space. 29 | 30 | :param obs_space: an observation space. 31 | :param obs_dict: a dict of numpy arrays. 32 | :return: returns an observation of the same type as space. 33 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; 34 | otherwise, space is unstructured and returns the value raw_obs[None]. 
35 | """ 36 | if isinstance(obs_space, gym.spaces.Dict): 37 | return obs_dict 38 | elif isinstance(obs_space, gym.spaces.Tuple): 39 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" 40 | return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) 41 | else: 42 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 43 | return obs_dict[None] 44 | 45 | 46 | def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: 47 | """ 48 | Get dict-structured information about a gym.Space. 49 | 50 | Dict spaces are represented directly by their dict of subspaces. 51 | Tuple spaces are converted into a dict with keys indexing into the tuple. 52 | Unstructured spaces are represented by {None: obs_space}. 53 | 54 | :param obs_space: an observation space 55 | :return: A tuple (keys, shapes, dtypes): 56 | keys: a list of dict keys. 57 | shapes: a dict mapping keys to shapes. 58 | dtypes: a dict mapping keys to dtypes. 
59 | """ 60 | check_for_nested_spaces(obs_space) 61 | if isinstance(obs_space, gym.spaces.Dict): 62 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 63 | subspaces = obs_space.spaces 64 | elif isinstance(obs_space, gym.spaces.Tuple): 65 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 66 | else: 67 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 68 | subspaces = {None: obs_space} 69 | keys = [] 70 | shapes = {} 71 | dtypes = {} 72 | for key, box in subspaces.items(): 73 | keys.append(key) 74 | shapes[key] = box.shape 75 | dtypes[key] = box.dtype 76 | return keys, shapes, dtypes 77 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: the vectorized environment to wrap 14 | :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: Whether or not to only warn once. 
16 | :param check_inf: Whether or not to check for +inf or -inf as well 17 | """ 18 | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True): 19 | VecEnvWrapper.__init__(self, venv) 20 | self.raise_exception = raise_exception 21 | self.warn_once = warn_once 22 | self.check_inf = check_inf 23 | self._actions = None 24 | self._observations = None 25 | self._user_warned = False 26 | 27 | def step_async(self, actions: np.ndarray) -> None: 28 | self._check_val(async_step=True, actions=actions) 29 | 30 | self._actions = actions 31 | self.venv.step_async(actions) 32 | 33 | def step_wait(self) -> VecEnvStepReturn: 34 | observations, rewards, news, infos = self.venv.step_wait() 35 | 36 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 37 | 38 | self._observations = observations 39 | return observations, rewards, news, infos 40 | 41 | def reset(self) -> VecEnvObs: 42 | observations = self.venv.reset() 43 | self._actions = None 44 | 45 | self._check_val(async_step=False, observations=observations) 46 | 47 | self._observations = observations 48 | return observations 49 | 50 | def _check_val(self, *, async_step: bool, **kwargs) -> None: 51 | # if warn and warn once and have warned once: then stop checking 52 | if not self.raise_exception and self.warn_once and self._user_warned: 53 | return 54 | 55 | found = [] 56 | for name, val in kwargs.items(): 57 | has_nan = np.any(np.isnan(val)) 58 | has_inf = self.check_inf and np.any(np.isinf(val)) 59 | if has_inf: 60 | found.append((name, "inf")) 61 | if has_nan: 62 | found.append((name, "nan")) 63 | 64 | if found: 65 | self._user_warned = True 66 | msg = "" 67 | for i, (name, type_val) in enumerate(found): 68 | msg += f"found {type_val} in {name}" 69 | if i != len(found) - 1: 70 | msg += ", " 71 | 72 | msg += ".\r\nOriginated from the " 73 | 74 | if not async_step: 75 | if self._actions is None: 76 | msg += "environment observation (at 
reset)" 77 | else: 78 | msg += f"environment, Last given value was: \r\n\taction={self._actions}" 79 | else: 80 | msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" 81 | 82 | if self.raise_exception: 83 | raise ValueError(msg) 84 | else: 85 | warnings.warn(msg, UserWarning) 86 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_extract_dict_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 4 | 5 | 6 | class VecExtractDictObs(VecEnvWrapper): 7 | """ 8 | A vectorized wrapper for extracting dictionary observations. 9 | 10 | :param venv: The vectorized environment 11 | :param key: The key of the dictionary observation 12 | """ 13 | def __init__(self, venv: VecEnv, key: str): 14 | self.key = key 15 | super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) 16 | 17 | def reset(self) -> np.ndarray: 18 | obs = self.venv.reset() 19 | return obs[self.key] 20 | 21 | def step_wait(self) -> VecEnvStepReturn: 22 | obs, reward, done, info = self.venv.step_wait() 23 | return obs[self.key], reward, done, info 24 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 8 | 9 | 10 | class VecFrameStack(VecEnvWrapper): 11 | """ 12 | Frame stacking wrapper for vectorized environment. Designed for image observations. 
13 | 14 | Uses the StackedObservations class, or StackedDictObservations depending on the observations space 15 | 16 | :param venv: the vectorized environment to wrap 17 | :param n_stack: Number of frames to stack 18 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 19 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 20 | Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces 21 | """ 22 | def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): 23 | self.venv = venv 24 | self.n_stack = n_stack 25 | 26 | wrapped_obs_space = venv.observation_space 27 | 28 | if isinstance(wrapped_obs_space, spaces.Box): 29 | assert not isinstance( 30 | channels_order, dict 31 | ), f"Expected None or string for channels_order but received {channels_order}" 32 | self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 33 | 34 | elif isinstance(wrapped_obs_space, spaces.Dict): 35 | self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 36 | 37 | else: 38 | raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") 39 | 40 | observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) 41 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 42 | 43 | def step_wait( 44 | self, 45 | ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]], ]: 46 | 47 | observations, rewards, dones, infos = self.venv.step_wait() 48 | 49 | observations, infos = self.stackedobs.update(observations, dones, infos) 50 | 51 | return observations, rewards, dones, infos 52 | 53 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 54 | """ 55 | Reset all environments 56 | 
""" 57 | observation = self.venv.reset() # pytype:disable=annotation-type-mismatch 58 | 59 | observation = self.stackedobs.reset(observation) 60 | return observation 61 | 62 | def close(self) -> None: 63 | self.venv.close() 64 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from typing import Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 8 | 9 | 10 | class VecMonitor(VecEnvWrapper): 11 | """ 12 | A vectorized monitor wrapper for *vectorized* Gym environments, 13 | it is used to record the episode reward, length, time and other data. 14 | 15 | Some environments like `openai/procgen `_ 16 | or `gym3 `_ directly initialize the 17 | vectorized environments, without giving us a chance to use the ``Monitor`` 18 | wrapper. So this class simply does the job of the ``Monitor`` wrapper on 19 | a vectorized level. 20 | 21 | :param venv: The vectorized environment 22 | :param filename: the location to save a log file, can be None for no log 23 | :param info_keywords: extra information to log, from the information return of env.step() 24 | """ 25 | def __init__( 26 | self, 27 | venv: VecEnv, 28 | filename: Optional[str] = None, 29 | info_keywords: Tuple[str, ...] 
= (), 30 | ): 31 | # Avoid circular import 32 | from stable_baselines3.common.monitor import Monitor, ResultsWriter 33 | 34 | # This check is not valid for special `VecEnv` 35 | # like the ones created by Procgen, that does follow completely 36 | # the `VecEnv` interface 37 | try: 38 | is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] 39 | except AttributeError: 40 | is_wrapped_with_monitor = False 41 | 42 | if is_wrapped_with_monitor: 43 | warnings.warn( 44 | "The environment is already wrapped with a `Monitor` wrapper" 45 | "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be" 46 | "overwritten by the `VecMonitor` ones.", 47 | UserWarning, 48 | ) 49 | 50 | VecEnvWrapper.__init__(self, venv) 51 | self.episode_returns = None 52 | self.episode_lengths = None 53 | self.episode_count = 0 54 | self.t_start = time.time() 55 | 56 | env_id = None 57 | if hasattr(venv, "spec") and venv.spec is not None: 58 | env_id = venv.spec.id 59 | 60 | if filename: 61 | self.results_writer = ResultsWriter( 62 | filename, header={ 63 | "t_start": self.t_start, 64 | "env_id": env_id 65 | }, extra_keys=info_keywords 66 | ) 67 | else: 68 | self.results_writer = None 69 | self.info_keywords = info_keywords 70 | 71 | def reset(self) -> VecEnvObs: 72 | obs = self.venv.reset() 73 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 74 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 75 | return obs 76 | 77 | def step_wait(self) -> VecEnvStepReturn: 78 | obs, rewards, dones, infos = self.venv.step_wait() 79 | self.episode_returns += rewards 80 | self.episode_lengths += 1 81 | new_infos = list(infos[:]) 82 | for i in range(len(dones)): 83 | if dones[i]: 84 | info = infos[i].copy() 85 | episode_return = self.episode_returns[i] 86 | episode_length = self.episode_lengths[i] 87 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} 88 | for key in self.info_keywords: 89 | 
episode_info[key] = info[key] 90 | info["episode"] = episode_info 91 | self.episode_count += 1 92 | self.episode_returns[i] = 0 93 | self.episode_lengths[i] = 0 94 | if self.results_writer: 95 | self.results_writer.write_row(episode_info) 96 | new_infos[i] = info 97 | return obs, rewards, dones, new_infos 98 | 99 | def close(self) -> None: 100 | if self.results_writer: 101 | self.results_writer.close() 102 | return self.venv.close() 103 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_transpose.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 9 | 10 | 11 | class VecTransposeImage(VecEnvWrapper): 12 | """ 13 | Re-order channels, from HxWxC to CxHxW. 14 | It is required for PyTorch convolution layers. 15 | 16 | :param venv: 17 | :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, 18 | which may result in unwanted behavior, see GH issue #671. 
19 | """ 20 | def __init__(self, venv: VecEnv, skip: bool = False): 21 | assert is_image_space(venv.observation_space) or isinstance( 22 | venv.observation_space, spaces.dict.Dict 23 | ), "The observation space must be an image or dictionary observation space" 24 | 25 | self.skip = skip 26 | # Do nothing 27 | if skip: 28 | super().__init__(venv) 29 | return 30 | 31 | if isinstance(venv.observation_space, spaces.dict.Dict): 32 | self.image_space_keys = [] 33 | observation_space = deepcopy(venv.observation_space) 34 | for key, space in observation_space.spaces.items(): 35 | if is_image_space(space): 36 | # Keep track of which keys should be transposed later 37 | self.image_space_keys.append(key) 38 | observation_space.spaces[key] = self.transpose_space(space, key) 39 | else: 40 | observation_space = self.transpose_space(venv.observation_space) 41 | super().__init__(venv, observation_space=observation_space) 42 | 43 | @staticmethod 44 | def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: 45 | """ 46 | Transpose an observation space (re-order channels). 47 | 48 | :param observation_space: 49 | :param key: In case of dictionary space, the key of the observation space. 50 | :return: 51 | """ 52 | # Sanity checks 53 | assert is_image_space(observation_space), "The observation space must be an image" 54 | assert not is_image_space_channels_first( 55 | observation_space 56 | ), f"The observation space {key} must follow the channel last convention" 57 | height, width, channels = observation_space.shape 58 | new_shape = (channels, height, width) 59 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) 60 | 61 | @staticmethod 62 | def transpose_image(image: np.ndarray) -> np.ndarray: 63 | """ 64 | Transpose an image or batch of images (re-order channels). 
65 | 66 | :param image: 67 | :return: 68 | """ 69 | if len(image.shape) == 3: 70 | return np.transpose(image, (2, 0, 1)) 71 | return np.transpose(image, (0, 3, 1, 2)) 72 | 73 | def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: 74 | """ 75 | Transpose (if needed) and return new observations. 76 | 77 | :param observations: 78 | :return: Transposed observations 79 | """ 80 | # Do nothing 81 | if self.skip: 82 | return observations 83 | 84 | if isinstance(observations, dict): 85 | # Avoid modifying the original object in place 86 | observations = deepcopy(observations) 87 | for k in self.image_space_keys: 88 | observations[k] = self.transpose_image(observations[k]) 89 | else: 90 | observations = self.transpose_image(observations) 91 | return observations 92 | 93 | def step_wait(self) -> VecEnvStepReturn: 94 | observations, rewards, dones, infos = self.venv.step_wait() 95 | 96 | # Transpose the terminal observations 97 | for idx, done in enumerate(dones): 98 | if not done: 99 | continue 100 | if "terminal_observation" in infos[idx]: 101 | infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) 102 | 103 | return self.transpose_observations(observations), rewards, dones, infos 104 | 105 | def reset(self) -> Union[np.ndarray, Dict]: 106 | """ 107 | Reset all environments 108 | """ 109 | return self.transpose_observations(self.venv.reset()) 110 | 111 | def close(self) -> None: 112 | self.venv.close() 113 | -------------------------------------------------------------------------------- /stable_baselines3/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | from gym.wrappers.monitoring import video_recorder 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 7 | from 
stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | 10 | 11 | class VecVideoRecorder(VecEnvWrapper): 12 | """ 13 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 14 | It requires ffmpeg or avconv to be installed on the machine. 15 | 16 | :param venv: 17 | :param video_folder: Where to save videos 18 | :param record_video_trigger: Function that defines when to start recording. 19 | The function takes the current number of step, 20 | and returns whether we should start recording or not. 21 | :param video_length: Length of recorded videos 22 | :param name_prefix: Prefix to the video name 23 | """ 24 | def __init__( 25 | self, 26 | venv: VecEnv, 27 | video_folder: str, 28 | record_video_trigger: Callable[[int], bool], 29 | video_length: int = 200, 30 | name_prefix: str = "rl-video", 31 | ): 32 | 33 | VecEnvWrapper.__init__(self, venv) 34 | 35 | self.env = venv 36 | # Temp variable to retrieve metadata 37 | temp_env = venv 38 | 39 | # Unwrap to retrieve metadata dict 40 | # that will be used by gym recorder 41 | while isinstance(temp_env, VecEnvWrapper): 42 | temp_env = temp_env.venv 43 | 44 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 45 | metadata = temp_env.get_attr("metadata")[0] 46 | else: 47 | metadata = temp_env.metadata 48 | 49 | self.env.metadata = metadata 50 | 51 | self.record_video_trigger = record_video_trigger 52 | self.video_recorder = None 53 | 54 | self.video_folder = os.path.abspath(video_folder) 55 | # Create output folder if needed 56 | os.makedirs(self.video_folder, exist_ok=True) 57 | 58 | self.name_prefix = name_prefix 59 | self.step_id = 0 60 | self.video_length = video_length 61 | 62 | self.recording = False 63 | self.recorded_frames = 0 64 | 65 | def reset(self) -> VecEnvObs: 66 | obs = self.venv.reset() 67 | self.start_video_recorder() 68 | return obs 69 | 70 | def 
start_video_recorder(self) -> None: 71 | self.close_video_recorder() 72 | 73 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" 74 | base_path = os.path.join(self.video_folder, video_name) 75 | self.video_recorder = video_recorder.VideoRecorder( 76 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id} 77 | ) 78 | 79 | self.video_recorder.capture_frame() 80 | self.recorded_frames = 1 81 | self.recording = True 82 | 83 | def _video_enabled(self) -> bool: 84 | return self.record_video_trigger(self.step_id) 85 | 86 | def step_wait(self) -> VecEnvStepReturn: 87 | obs, rews, dones, infos = self.venv.step_wait() 88 | 89 | self.step_id += 1 90 | if self.recording: 91 | self.video_recorder.capture_frame() 92 | self.recorded_frames += 1 93 | if self.recorded_frames > self.video_length: 94 | print(f"Saving video to {self.video_recorder.path}") 95 | self.close_video_recorder() 96 | elif self._video_enabled(): 97 | self.start_video_recorder() 98 | 99 | return obs, rews, dones, infos 100 | 101 | def close_video_recorder(self) -> None: 102 | if self.recording: 103 | self.video_recorder.close() 104 | self.recording = False 105 | self.recorded_frames = 1 106 | 107 | def close(self) -> None: 108 | VecEnvWrapper.close(self) 109 | self.close_video_recorder() 110 | 111 | def __del__(self): 112 | self.close() 113 | -------------------------------------------------------------------------------- /stable_baselines3/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ddpg.ddpg import DDPG 2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /stable_baselines3/ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Type, Union 2 | 3 | import torch 
as th 4 | 5 | from stable_baselines3.common.buffers import ReplayBuffer 6 | from stable_baselines3.common.noise import ActionNoise 7 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm 8 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule 9 | from stable_baselines3.td3.policies import TD3Policy 10 | from stable_baselines3.td3.td3 import TD3 11 | 12 | 13 | class DDPG(TD3): 14 | """ 15 | Deep Deterministic Policy Gradient (DDPG). 16 | 17 | Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf 18 | DDPG Paper: https://arxiv.org/abs/1509.02971 19 | Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html 20 | 21 | Note: we treat DDPG as a special case of its successor TD3. 22 | 23 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 24 | :param env: The environment to learn from (if registered in Gym, can be str) 25 | :param learning_rate: learning rate for adam optimizer, 26 | the same learning rate will be used for all networks (Q-Values, Actor and Value function) 27 | it can be a function of the current progress remaining (from 1 to 0) 28 | :param buffer_size: size of the replay buffer 29 | :param learning_starts: how many steps of the model to collect transitions for before learning starts 30 | :param batch_size: Minibatch size for each gradient update 31 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1) 32 | :param gamma: the discount factor 33 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit 34 | like ``(5, "step")`` or ``(2, "episode")``. 35 | :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) 36 | Set to ``-1`` means to do as many gradient steps as steps done in the environment 37 | during the rollout. 
38 | :param action_noise: the action noise type (None by default), this can help 39 | for hard exploration problem. Cf common.noise for the different action noise type. 40 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). 41 | If ``None``, it will be automatically selected. 42 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. 43 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer 44 | at a cost of more complexity. 45 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 46 | :param create_eval_env: Whether to create a second environment that will be 47 | used for evaluating the agent periodically. (Only available when passing string for the environment) 48 | :param policy_kwargs: additional arguments to be passed to the policy on creation 49 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug 50 | :param seed: Seed for the pseudo random generators 51 | :param device: Device (cpu, cuda, ...) on which the code should be run. 52 | Setting it to auto, the code will be run on the GPU if possible. 
53 | :param _init_setup_model: Whether or not to build the network at the creation of the instance 54 | """ 55 | def __init__( 56 | self, 57 | policy: Union[str, Type[TD3Policy]], 58 | env: Union[GymEnv, str], 59 | learning_rate: Union[float, Schedule] = 1e-3, 60 | buffer_size: int = 1_000_000, # 1e6 61 | learning_starts: int = 100, 62 | batch_size: int = 100, 63 | tau: float = 0.005, 64 | gamma: float = 0.99, 65 | train_freq: Union[int, Tuple[int, str]] = (1, "episode"), 66 | gradient_steps: int = -1, 67 | action_noise: Optional[ActionNoise] = None, 68 | replay_buffer_class: Optional[ReplayBuffer] = None, 69 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 70 | optimize_memory_usage: bool = False, 71 | tensorboard_log: Optional[str] = None, 72 | create_eval_env: bool = False, 73 | policy_kwargs: Optional[Dict[str, Any]] = None, 74 | verbose: int = 0, 75 | seed: Optional[int] = None, 76 | device: Union[th.device, str] = "auto", 77 | _init_setup_model: bool = True, 78 | ): 79 | 80 | super().__init__( 81 | policy=policy, 82 | env=env, 83 | learning_rate=learning_rate, 84 | buffer_size=buffer_size, 85 | learning_starts=learning_starts, 86 | batch_size=batch_size, 87 | tau=tau, 88 | gamma=gamma, 89 | train_freq=train_freq, 90 | gradient_steps=gradient_steps, 91 | action_noise=action_noise, 92 | replay_buffer_class=replay_buffer_class, 93 | replay_buffer_kwargs=replay_buffer_kwargs, 94 | policy_kwargs=policy_kwargs, 95 | tensorboard_log=tensorboard_log, 96 | verbose=verbose, 97 | device=device, 98 | create_eval_env=create_eval_env, 99 | seed=seed, 100 | optimize_memory_usage=optimize_memory_usage, 101 | # Remove all tricks from TD3 to obtain DDPG: 102 | # we still need to specify target_policy_noise > 0 to avoid errors 103 | policy_delay=1, 104 | target_noise_clip=0.0, 105 | target_policy_noise=0.1, 106 | _init_setup_model=False, 107 | ) 108 | 109 | # Use only one critic 110 | if "n_critics" not in self.policy_kwargs: 111 | self.policy_kwargs["n_critics"] = 1 
112 | 113 | if _init_setup_model: 114 | self._setup_model() 115 | 116 | def learn( 117 | self, 118 | total_timesteps: int, 119 | callback: MaybeCallback = None, 120 | log_interval: int = 4, 121 | eval_env: Optional[GymEnv] = None, 122 | eval_freq: int = -1, 123 | n_eval_episodes: int = 5, 124 | tb_log_name: str = "DDPG", 125 | eval_log_path: Optional[str] = None, 126 | reset_num_timesteps: bool = True, 127 | ) -> OffPolicyAlgorithm: 128 | 129 | return super().learn( 130 | total_timesteps=total_timesteps, 131 | callback=callback, 132 | log_interval=log_interval, 133 | eval_env=eval_env, 134 | eval_freq=eval_freq, 135 | n_eval_episodes=n_eval_episodes, 136 | tb_log_name=tb_log_name, 137 | eval_log_path=eval_log_path, 138 | reset_num_timesteps=reset_num_timesteps, 139 | ) 140 | -------------------------------------------------------------------------------- /stable_baselines3/ddpg/policies.py: -------------------------------------------------------------------------------- 1 | # DDPG can be view as a special case of TD3 2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401 3 | -------------------------------------------------------------------------------- /stable_baselines3/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.dqn.dqn import DQN 2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /stable_baselines3/her/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy 2 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer 3 | -------------------------------------------------------------------------------- /stable_baselines3/her/goal_selection_strategy.py: 
# ------------------------------------------------------------------------------
from enum import Enum


class GoalSelectionStrategy(Enum):
    """
    The strategies for selecting new goals when
    creating artificial transitions.
    """

    # Select a goal that was achieved
    # after the current step, in the same episode
    FUTURE = 0
    # Select the goal that was achieved
    # at the end of the episode
    FINAL = 1
    # Select a goal that was achieved in the episode
    EPISODE = 2


# For convenience:
# this mapping lets callers select a strategy by its string name
KEY_TO_GOAL_STRATEGY = {
    "future": GoalSelectionStrategy.FUTURE,
    "final": GoalSelectionStrategy.FINAL,
    "episode": GoalSelectionStrategy.EPISODE,
}

# ------------------------------------------------------------------------------
# /stable_baselines3/ppo/__init__.py
# ------------------------------------------------------------------------------
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.ppo.ppo import PPO
from stable_baselines3.ppo.ppo_test import PPO_Test

# ------------------------------------------------------------------------------
# /stable_baselines3/ppo/policies.py
# ------------------------------------------------------------------------------
# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

# ------------------------------------------------------------------------------
# /stable_baselines3/sac/__init__.py
# ------------------------------------------------------------------------------
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.sac.sac import SAC

# ------------------------------------------------------------------------------
# /stable_baselines3/td3/__init__.py
# ------------------------------------------------------------------------------
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.td3.td3 import TD3

# ------------------------------------------------------------------------------
# /stable_baselines3/utils.py
# ------------------------------------------------------------------------------
import os
import datetime
from legged_gym.env.base.base_task import BaseTask
from legged_gym import OPEN_ROBOT_ROOT_DIR
from gleam.wrapper.env_wrapper_gleam import EnvWrapperGLEAM


# Parent directory of this package's directory.
root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

# Environment classes treated as Isaac Gym environments by is_isaac_gym_env().
_ISAAC_GYM_ENV_TYPES = (BaseTask, EnvWrapperGLEAM)


def get_api_key_file(wandb_key_file=None):
    """
    Return the path of a W&B API key file under ``OPEN_ROBOT_ROOT_DIR/wandb_utils``.

    :param wandb_key_file: file name inside ``wandb_utils``; any falsy value selects
        ``"wandb_api_key_file.txt"``.
    :return: the joined path (existence is not checked).
    """
    wandb_key_file = wandb_key_file or "wandb_api_key_file.txt"
    # Uses OPEN_ROBOT_ROOT_DIR, not the module-level ``root`` above.
    path = os.path.join(OPEN_ROBOT_ROOT_DIR, "wandb_utils", wandb_key_file)
    print("We are using this wandb key file: ", path)
    return path


def get_time_str():
    """Return the current local time formatted as ``MMDD_HHMM_SS``."""
    return datetime.datetime.now().strftime("%m%d_%H%M_%S")


def is_isaac_gym_env(env):
    """Return True if ``env`` is an instance of one of ``_ISAAC_GYM_ENV_TYPES``."""
    return isinstance(env, _ISAAC_GYM_ENV_TYPES)

# ------------------------------------------------------------------------------
# /stable_baselines3/version.txt
# ------------------------------------------------------------------------------
# 1.6.0

# ------------------------------------------------------------------------------
# /wandb_utils/__init__.py
# ------------------------------------------------------------------------------
project_name = "active_reconstruction"
team_name = "openrobot"
# ------------------------------------------------------------------------------
# /wandb_utils/wandb_api_key_file.txt
# ------------------------------------------------------------------------------
# The dump shows only a link for this file's content:
# https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/wandb_utils/wandb_api_key_file.txt
# NOTE(review): WandbCallback reads its W&B API key from this file. Committing an API
# key to version control exposes it; prefer the WANDB_API_KEY environment variable
# or an untracked file.

# ------------------------------------------------------------------------------
# /wandb_utils/wandb_callback.py
# ------------------------------------------------------------------------------
import logging
import os
from typing import Optional

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry
from stable_baselines3.utils import get_api_key_file

from stable_baselines3.common.callbacks import BaseCallback

logger = logging.getLogger(__name__)


class WandbCallback(BaseCallback):
    """ Log SB3 experiments to Weights and Biases
        - Added model tracking and uploading
        - Added complete hyperparameters recording
        - Added gradient logging
        - This callback calls `wandb.init(...)` itself, after exporting the API key
          read from the key file to the ``WANDB_API_KEY`` environment variable
    Args:
        trial_name: W&B run id (passed to ``wandb.init(id=...)``; the run is resumed if it exists)
        exp_name: W&B group name
        project_name: W&B project name
        config: run config dict; its optional ``"team_name"`` entry is used as the W&B entity
        verbose: The verbosity of sb3 output
        model_save_path: Path to the folder where the model will be saved, The default value is `None` so the model is not logged
        model_save_freq: Frequency (in callback calls) to save the model
        gradient_save_freq: Frequency to log gradient. The default value is 0 so the gradients are not logged
    """

    def __init__(
        self,
        trial_name,
        exp_name,
        project_name,
        config=None,
        verbose: int = 0,
        model_save_path: Optional[str] = None,
        model_save_freq: int = 0,
        gradient_save_freq: int = 0,
    ):
        # ``config`` defaults to None, so normalise it before reading from it.
        config = config or {}
        team_name = config.get("team_name")
        # PZH: Setup our key.
        # Team "lqy0057" uses its own key file; everyone else uses the default one.
        WANDB_ENV_VAR = "WANDB_API_KEY"
        key_file_path = get_api_key_file(
            "wandb_lqy_key_file.text" if team_name == "lqy0057" else None
        )
        with open(key_file_path, "r") as f:
            key = f.readline()
        key = key.replace("\n", "")
        key = key.replace(" ", "")
        os.environ[WANDB_ENV_VAR] = key

        self.run = wandb.init(
            id=trial_name,
            config=config,
            resume=True,
            reinit=True,
            group=exp_name,
            project=project_name,
            entity=team_name,
            sync_tensorboard=True,  # Open this and setup tb in sb3 so that we can get log!
            save_code=False
        )

        super().__init__(verbose)
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")
        with wb_telemetry.context() as tel:
            tel.feature.sb3 = True
        self.model_save_freq = model_save_freq
        self.model_save_path = model_save_path
        self.gradient_save_freq = gradient_save_freq

        # Create folder if needed
        if self.model_save_path is not None:
            os.makedirs(self.model_save_path, exist_ok=True)
            self.path = os.path.join(self.model_save_path, "model.zip")
        else:
            assert (
                self.model_save_freq == 0
            ), "to use the `model_save_freq` you have to set the `model_save_path` parameter"

    def _init_callback(self) -> None:
        """Record the model's attributes (not already in ``wandb.config``) as config defaults."""
        d = {"algo": type(self.model).__name__}
        for key, value in self.model.__dict__.items():
            if key in wandb.config:
                continue
            # Scalars are stored as-is; everything else is stringified.
            d[key] = value if type(value) in [float, int, str] else str(value)
        if self.gradient_save_freq > 0:
            wandb.watch(self.model.policy, log_freq=self.gradient_save_freq, log="all")
        wandb.config.setdefaults(d)

    def _on_step(self) -> bool:
        """Save the model every ``model_save_freq`` calls (if saving is configured)."""
        if (
            self.model_save_freq > 0
            and self.model_save_path is not None
            and self.n_calls % self.model_save_freq == 0
        ):
            self.save_model()
        return True

    def _on_training_end(self) -> None:
        if self.model_save_path is not None:
            self.save_model()

    def save_model(self) -> None:
        """Save the model to ``self.path`` and upload it with ``wandb.save``."""
        self.model.save(self.path)
        wandb.save(self.path, base_path=self.model_save_path)
        if self.verbose > 1:
            logger.info("Saving model checkpoint to %s", self.path)

# ------------------------------------------------------------------------------