├── .gitignore
├── LICENSE
├── README.md
├── assets
├── main_result.png
├── overview.png
├── overview_gleambench.png
└── statistic.png
├── data_gleam
└── README.md
├── gleam
├── __init__.py
├── callback.py
├── env
│ ├── config_gleam.py
│ ├── config_gleam_eval.py
│ ├── config_legged_visual.py
│ ├── env_base.py
│ ├── env_gleam_base.py
│ ├── env_gleam_eval.py
│ ├── env_gleam_stage1.py
│ └── env_gleam_stage2.py
├── network
│ ├── base.py
│ ├── encoder.py
│ ├── init.py
│ └── locotransformer.py
├── test
│ └── test_gleam_gleambench.py
├── train
│ ├── __init__.py
│ ├── train_gleam_stage1.py
│ └── train_gleam_stage2.py
├── utils
│ ├── bfs_cuda_2D.h
│ ├── bfs_cuda_kernel_2D.cu
│ └── utils.py
└── wrapper
│ └── env_wrapper_gleam.py
├── legged_gym
├── __init__.py
├── env
│ ├── __init__.py
│ ├── a1
│ │ └── a1_config.py
│ ├── anymal_b
│ │ └── anymal_b_config.py
│ ├── anymal_c
│ │ ├── anymal.py
│ │ ├── flat
│ │ │ └── anymal_c_flat_config.py
│ │ └── mixed_terrains
│ │ │ └── anymal_c_rough_config.py
│ └── base
│ │ ├── base_config.py
│ │ ├── base_task.py
│ │ ├── drone_robot.py
│ │ ├── drone_robot_visual.py
│ │ ├── legged_robot.py
│ │ └── legged_robot_config.py
├── scripts
│ ├── play.py
│ └── train.py
├── tests
│ └── test_env.py
└── utils
│ ├── __init__.py
│ ├── helpers.py
│ ├── logger.py
│ ├── math.py
│ ├── task_registry.py
│ └── terrain.py
├── licenses
├── assets
│ ├── ANYmal_b_license.txt
│ ├── ANYmal_c_license.txt
│ ├── a1_license.txt
│ └── cassie_license.txt
└── dependencies
│ └── matplotlib_license.txt
├── requirements.txt
├── resources
└── robots
│ └── drone
│ └── cf2.dae
├── rsl_rl
├── __init__.py
├── algorithms
│ ├── __init__.py
│ └── ppo.py
├── env
│ ├── __init__.py
│ └── vec_env.py
├── modules
│ ├── __init__.py
│ ├── actor_critic.py
│ └── actor_critic_recurrent.py
├── runners
│ ├── __init__.py
│ └── on_policy_runner.py
├── storage
│ ├── __init__.py
│ └── rollout_storage.py
└── utils
│ ├── __init__.py
│ └── utils.py
├── setup.py
├── stable_baselines3
├── __init__.py
├── a2c
│ ├── __init__.py
│ ├── a2c.py
│ └── policies.py
├── common
│ ├── __init__.py
│ ├── atari_wrappers.py
│ ├── base_class.py
│ ├── base_class_grid_obs.py
│ ├── buffers.py
│ ├── callbacks.py
│ ├── callbacks_gleam.py
│ ├── distributions.py
│ ├── env_checker.py
│ ├── env_util.py
│ ├── envs
│ │ ├── __init__.py
│ │ ├── bit_flipping_env.py
│ │ ├── identity_env.py
│ │ └── multi_input_envs.py
│ ├── evaluation.py
│ ├── evaluation_gleam.py
│ ├── logger.py
│ ├── monitor.py
│ ├── noise.py
│ ├── off_policy_algorithm.py
│ ├── on_policy_algorithm.py
│ ├── on_policy_algorithm_grid_obs.py
│ ├── on_policy_algorithm_hybrid_act.py
│ ├── policies.py
│ ├── preprocessing.py
│ ├── results_plotter.py
│ ├── running_mean_std.py
│ ├── save_util.py
│ ├── sb2_compat
│ │ ├── __init__.py
│ │ └── rmsprop_tf_like.py
│ ├── torch_layers.py
│ ├── type_aliases.py
│ ├── utils.py
│ └── vec_env
│ │ ├── __init__.py
│ │ ├── base_vec_env.py
│ │ ├── dummy_vec_env.py
│ │ ├── stacked_observations.py
│ │ ├── subproc_vec_env.py
│ │ ├── util.py
│ │ ├── vec_check_nan.py
│ │ ├── vec_extract_dict_obs.py
│ │ ├── vec_frame_stack.py
│ │ ├── vec_monitor.py
│ │ ├── vec_normalize.py
│ │ ├── vec_transpose.py
│ │ └── vec_video_recorder.py
├── ddpg
│ ├── __init__.py
│ ├── ddpg.py
│ └── policies.py
├── dqn
│ ├── __init__.py
│ ├── dqn.py
│ └── policies.py
├── her
│ ├── __init__.py
│ ├── goal_selection_strategy.py
│ └── her_replay_buffer.py
├── ppo
│ ├── __init__.py
│ ├── policies.py
│ ├── ppo.py
│ ├── ppo_grid_hybrid_action.py
│ ├── ppo_grid_obs.py
│ ├── ppo_occant.py
│ └── ppo_test.py
├── sac
│ ├── __init__.py
│ ├── policies.py
│ └── sac.py
├── td3
│ ├── __init__.py
│ ├── policies.py
│ └── td3.py
├── utils.py
└── version.txt
└── wandb_utils
├── __init__.py
├── wandb_api_key_file.txt
└── wandb_callback.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # These are some examples of commonly ignored file patterns.
2 | # You should customize this list as applicable to your project.
3 | # Learn more about .gitignore:
4 | # https://www.atlassian.com/git/tutorials/saving-changes/gitignore
5 |
6 | # Node artifact files
7 | node_modules/
8 |
9 | # Compiled Java class files
10 | *.class
11 |
12 | # Compiled Python bytecode
13 | *.py[cod]
14 |
15 | # Log files
16 | *.log
17 | log.txt
18 |
19 | # Package files
20 | *.jar
21 |
22 | # Maven
23 | target/
24 | dist/
25 |
26 | # JetBrains IDE
27 | .idea/
28 |
29 | # Unit test reports
30 | TEST*.xml
31 |
32 | # Generated by MacOS
33 | .DS_Store
34 |
35 | # Generated by Windows
36 | Thumbs.db
37 |
38 | # Applications
39 | *.app
40 | *.exe
41 | *.war
42 |
43 | # Large media files
44 | *.gif
45 | *.mp4
46 | *.tiff
47 | *.avi
48 | *.flv
49 | *.mov
50 | *.wmv
51 |
52 | # VS Code
53 | .vscode
54 | *.vsix
55 |
56 | # logs
57 | logs
58 | runs
59 |
60 | # other
61 | *.egg-info
62 | __pycache__
63 | venv
64 | *.cpp
65 |
66 | # model
67 | *.pkl
68 | *.zip
69 | *.pt
70 | *.pth
71 |
72 | # log
73 | *.log
74 | *.csv
75 |
76 | # pyc
77 | *.pyc
78 |
79 | # json
80 | *.json
81 |
82 | # build
83 | /build/
84 | **.so
85 |
86 | # test scripts
87 | **/wandb/
88 | *.idea
89 |
90 |
91 | # Image
92 | *.jpg
93 | *.jpeg
94 | *.png
95 |
96 |
97 | *.npy
98 | *.tar.gz
99 |
100 | events.out*
101 | *.th
102 | log
103 | *.mtl
104 | *.obj
105 | *.ply
106 | *.glb
107 | *.fbx
108 | *.pcd
109 | *.xyz
110 | *.mlp
111 | *.sh
112 | *.lib
113 | *.whl
114 | *.hdr
115 | *.bin
116 | *.navmesh
117 | *.sur
118 | *.h5
119 | *.urdf
120 |
121 | data_gleam/*
122 | venv/*
123 |
124 | # module
125 | ckpt
126 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, ETH Zurich, Nikita Rudin
2 | Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice,
9 | this list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its contributors
16 | may be used to endorse or promote products derived from this software without
17 | specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | See licenses/assets for license information for assets included in this repository.
31 | See licenses/dependencies for license information of dependencies of this package.
32 |
--------------------------------------------------------------------------------
/assets/main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/main_result.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/overview.png
--------------------------------------------------------------------------------
/assets/overview_gleambench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/overview_gleambench.png
--------------------------------------------------------------------------------
/assets/statistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/assets/statistic.png
--------------------------------------------------------------------------------
/data_gleam/README.md:
--------------------------------------------------------------------------------
1 | # GLEAM-Bench
2 |
3 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | We introduce GLEAM-Bench, a benchmark for generalizable exploration for active mapping in complex 3D indoor scenes.
18 | These scene meshes are characterized by watertight geometry, diverse floorplans (≥10 types), and complex interconnectivity. We unify and refine multi-source datasets through manual filtering, geometric repair, and task-oriented preprocessing.
19 | To simulate the exploration process, we connect our dataset with NVIDIA Isaac Gym, enabling parallel sensory data simulation and online policy training.
20 |
21 |
22 | ## Download
23 | Please fill in the [form](https://docs.google.com/forms/d/e/1FAIpQLSdq9aX1dwoyBb31nm8L_Mx5FeaVsr5AY538UiwKqg8LPKX9vg/viewform?usp=sharing) to access the download links.
24 |
25 | **Consolidated Version (GLEAM)**
26 |
27 | We provide all the preprocessed data used in our work, including mesh files (in the `obj` folder), ground-truth surface points (in the `gt` folder) and asset indexing files (in the `urdf` folder). The directory structure should be as follows.
28 |
29 | ```
30 | data_gleam
31 | ├── train_stage1_512
32 | │ ├── gt
33 | │ ├── obj
34 | │ ├── urdf
35 | ├── train_stage2_512
36 | │ ├── gt
37 | │ ├── obj
38 | │ ├── urdf
39 | ├── eval_128
40 | │ ├── gt
41 | │ ├── obj
42 | │ ├── urdf
43 | ```
44 |
45 | The standard training process of GLEAM is divided into two stages (i.e., stage 1 and stage 2), each using a different set of 512 indoor training scenes. The evaluation uses 128 unseen test scenes from ProcTHOR, HSSD, Gibson and Matterport3D (cross-dataset generalization).
46 |
47 |
48 |
49 |
50 |
51 |
52 | **Original Separate Version**
53 |
54 | The separate meshes are also provided in the download links:
55 | ```
56 | data_gleam_raw
57 | ├── procthor_12-room_64
58 | │ ├── gt
59 | │ ├── obj
60 | ├── gibson_96
61 | │ ├── gt
62 | │ ├── obj
63 | ├── hssd_32
64 | │ ├── gt
65 | │ ├── obj
66 | ...
67 | ```
68 |
69 | **Raw Version (With Texture and Part-Level Mesh Layers)**
70 |
71 | Due to storage limitations, we will not be releasing scenes with textures or part-level mesh layers. If you want to get the raw version of GLEAM-Bench, we recommend that you:
72 | - **ProcTHOR**: export the meshes using our provided C# script;
73 | - **HSSD**, **Gibson**, **Matterport3D**: download the corresponding raw meshes from the original source.
74 |
75 |
76 | ## Export Meshes from Generated Scenes by ProcTHOR
77 | We also provide a C# script [[HERE](https://github.com/zjwzcx/Batch-Export-ProcTHOR-Meshes)] to export mesh files (.fbx) from scenes generated by ProcTHOR. Note that these exported mesh files have textures and interactive object-level layers.
78 |
79 |
80 | ## Citation
81 |
82 | The GLEAM-Bench dataset comes from the GLEAM paper:
83 |
84 | - **arXiv**: https://arxiv.org/abs/2505.20294
85 |
86 | - **Code**: https://github.com/zjwzcx/GLEAM
87 |
88 | - **BibTex**:
89 | ```bibtex
90 | @misc{chen2025gleam,
91 | title={GLEAM: Learning Generalizable Exploration Policy for Active Mapping in Complex 3D Indoor Scenes},
92 | author={Xiao Chen and Tai Wang and Quanyi Li and Tao Huang and Jiangmiao Pang and Tianfan Xue},
93 | year={2025},
94 | eprint={2505.20294},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.CV},
97 | url={https://arxiv.org/abs/2505.20294},
98 | }
99 | ```
100 |
101 |
102 | If you use our dataset and benchmark, please kindly cite the original datasets involved in our work. BibTex entries are provided below.
103 |
104 | Dataset BibTex
105 |
106 | ```bibtex
107 | @inproceedings{procthor,
108 | author={Matt Deitke and Eli VanderBilt and Alvaro Herrasti and
109 | Luca Weihs and Jordi Salvador and Kiana Ehsani and
110 | Winson Han and Eric Kolve and Ali Farhadi and
111 | Aniruddha Kembhavi and Roozbeh Mottaghi},
112 | title={{ProcTHOR: Large-Scale Embodied AI Using Procedural Generation}},
113 | booktitle={NeurIPS},
114 | year={2022},
115 | note={Outstanding Paper Award}
116 | }
117 | ```
118 | ```bibtex
119 | @inproceedings{xiazamirhe2018gibsonenv,
120 | title={Gibson Env: real-world perception for embodied agents},
121 | author={Xia, Fei and R. Zamir, Amir and He, Zhi-Yang and Sax, Alexander and Malik, Jitendra and Savarese, Silvio},
122 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on},
123 | year={2018},
124 | organization={IEEE}
125 | }
126 | ```
127 | ```bibtex
128 | @article{khanna2023hssd,
129 | author={Khanna*, Mukul and Mao*, Yongsen and Jiang, Hanxiao and Haresh, Sanjay and Shacklett, Brennan and Batra, Dhruv and Clegg, Alexander and Undersander, Eric and Chang, Angel X. and Savva, Manolis},
130 | title={{Habitat Synthetic Scenes Dataset (HSSD-200): An Analysis of 3D Scene Scale and Realism Tradeoffs for ObjectGoal Navigation}},
131 | journal={arXiv preprint},
132 | year={2023},
133 | eprint={2306.11290},
134 | archivePrefix={arXiv},
135 | primaryClass={cs.CV}
136 | }
137 | ```
138 | ```bibtex
139 | @article{Matterport3D,
140 | title={Matterport3D: Learning from RGB-D Data in Indoor Environments},
141 | author={Chang, Angel and Dai, Angela and Funkhouser, Thomas and Halber, Maciej and Niessner, Matthias and Savva, Manolis and Song, Shuran and Zeng, Andy and Zhang, Yinda},
142 | journal={International Conference on 3D Vision (3DV)},
143 | year={2017}
144 | }
145 | ```
146 |
147 |
148 |
--------------------------------------------------------------------------------
/gleam/__init__.py:
--------------------------------------------------------------------------------
1 | from legged_gym.utils.task_registry import task_registry
2 |
3 | from gleam.env.env_gleam_stage1 import Env_GLEAM_Stage1
4 | from gleam.env.env_gleam_stage2 import Env_GLEAM_Stage2
5 | from gleam.env.env_gleam_eval import Env_GLEAM_Eval
6 | from gleam.env.config_gleam import Config_GLEAM, DroneCfgPPO
7 | from gleam.env.config_gleam_eval import Config_GLEAM_Eval
8 | task_registry.register("train_gleam_stage1", Env_GLEAM_Stage1, Config_GLEAM, DroneCfgPPO)
9 | task_registry.register("train_gleam_stage2", Env_GLEAM_Stage2, Config_GLEAM, DroneCfgPPO)
10 | task_registry.register("eval_gleam_gleambench", Env_GLEAM_Eval, Config_GLEAM_Eval, DroneCfgPPO)
11 |
--------------------------------------------------------------------------------
/gleam/callback.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch as th
3 | from stable_baselines3.common.callbacks import CheckpointCallback
4 |
5 |
class BestCKPTCallback(CheckpointCallback):
    """
    Checkpoint callback that also keeps the best model for each tracked metric.

    At the end of every rollout it
      1. saves a periodic checkpoint when ``n_calls`` is a multiple of ``save_freq``;
      2. for every key in ``key_list``, averages that key over the episode-info
         buffer and saves the model when the mean exceeds the best value so far.

    :param save_freq: period (in callback calls) of the regular checkpoints
    :param save_path: directory of the regular checkpoints
    :param name_prefix: common prefix of all saved file names
    :param verbose: verbosity level; regular saves are reported when > 1
    :param key_list: episode-info keys to track. Every key starts with a best
        value of 0, so a metric that never exceeds 0 is never saved.
    """

    def __init__(
        self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0, key_list: list = None
    ):
        key_list = key_list or []
        self.rollout_count = 0
        super(BestCKPTCallback, self).__init__(save_freq, save_path, name_prefix, verbose)
        self.key_highest_value = {k: 0. for k in key_list}

    @property
    def best_save_path(self):
        """Directory of the best checkpoints (the logger directory of the algorithm)."""
        return self.locals["self"].logger.dir

    @property
    def device(self):
        """Device of the algorithm that owns this callback."""
        return self.locals["self"].device

    def _on_rollout_end(self) -> None:
        """Save the periodic checkpoint and update the per-key best checkpoints."""
        if self.n_calls % self.save_freq == 0:
            path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
            self.model.save(path)
            if self.verbose > 1:
                print(f"Saving model checkpoint to {path}")

        # No episode info has been recorded yet (e.g. no episode finished during
        # the first rollouts): there is nothing to average, so skip the "best"
        # bookkeeping instead of crashing on an empty buffer.
        if not self.locals["self"].ep_info_buffer:
            return

        # Iterate over a snapshot of the keys because the values are updated in the loop.
        for key in list(self.key_highest_value):
            value = self.calculate_value(key)
            if value > self.key_highest_value[key]:
                self.key_highest_value[key] = value
                path = os.path.join(self.best_save_path, "{}_best_{}".format(self.name_prefix, key))
                self.model.save(path)
                print("Saving Best {} checkpoint to {}: {}".format(key, path, value))

    def calculate_value(self, key):
        """
        Return the mean of ``key`` over all entries of the episode-info buffer.

        The buffer entries are left untouched; conversions are done on local copies.

        :param key: name of the episode-info entry to average
        :return: the mean as a zero-dimensional numpy array
        :raises AssertionError: if the first buffer entry has no such key
        :raises IndexError: if the buffer is empty
        """
        ep_info_buffer = self.locals["self"].ep_info_buffer
        assert key in ep_info_buffer[0], "no key named {}, can not save checkpoint".format(key)

        # The leading empty float tensor keeps the concatenated result floating
        # point even when every info is an integer tensor.
        pieces = [th.tensor([], device=self.device)]
        for ep_info in ep_info_buffer:
            info = ep_info[key]
            # handle scalar and zero dimensional tensor infos
            if not isinstance(info, th.Tensor):
                info = th.Tensor([info])
            if len(info.shape) == 0:
                info = info.unsqueeze(0)
            pieces.append(info.to(self.device))
        return th.mean(th.cat(pieces)).detach().cpu().numpy()
52 |
class ReconstructionCallBack(BestCKPTCallback):
    """Alias of BestCKPTCallback; adds no behavior of its own."""
55 |
--------------------------------------------------------------------------------
/gleam/env/config_gleam.py:
--------------------------------------------------------------------------------
1 | from gleam.env.config_legged_visual import LeggedVisualInputConfig, LeggedVIsualInputCfgPPO
2 | from isaacgym import gymapi
3 |
4 |
class Config_GLEAM(LeggedVisualInputConfig):
    # Environment configuration used by the GLEAM training tasks
    # ("train_gleam_stage1" and "train_gleam_stage2" in gleam/__init__.py).
    position_use_polar_coordinates = False  # the position will be represented by (r, \theta, \phi) instead of (x, y, z)
    direction_use_vector = False  # (r, p, y)
    debug_save_image_tensor = False
    debug_save_path = None
    max_episode_length = 500  # max_steps_per_episode

    num_sampled_point = 5000

    class rewards:
        class scales:  # * self.dt (0.019999) / episode_length_s (20). For example, when reward_scales=1000, rew_xxx = reward * 1000 * 0.019999 / 20 = reward
            surface_coverage_2d = 1000  # original scale (coverage ratio: [0, 1])

            termination = 50  # Terminal reward / penalty
            collision = -100  # Penalize collisions on selected bodies

        only_positive_rewards = False  # if true negative total rewards are clipped at zero (avoids early termination problems)
        max_contact_force = 100.  # forces above this value are penalized

    class normalization:
        class obs_scales:
            lin_vel = 2.0
            ang_vel = 0.25
            dof_pos = 1.0
            dof_vel = 0.05
            height_measurements = 5.0

        pi = 3.14159265359
        clip_observations = 255.

        # discrete action space
        # NOTE(review): only the first two of the six action dimensions have a
        # non-zero range here; the other four are clipped to [0, 0].
        init_action = [64, 64, 0, 0, 0, 0]  # 64: keep still
        clip_actions_up = [128, 128, 0, 0, 0, 0]
        clip_actions_low = [0, 0, 0, 0, 0, 0]

    class visual_input(LeggedVisualInputConfig.visual_input):
        camera_width = 256
        camera_height = 32
        horizontal_fov = 90.0  # Horizontal field of view in degrees. Vertical field of view is calculated from height to width ratio

        stack = 100
        supersampling_horizontal = 1
        supersampling_vertical = 1
        normalization = True  # normalize pixels to [0, 1]
        far_plane = 2000000.0  # distance in world coordinates to far-clipping plane
        near_plane = 0.0010000000474974513  # distance in world coordinate units to near-clipping plane

        type = gymapi.IMAGE_DEPTH

    class asset(LeggedVisualInputConfig.asset):
        file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/drone/cf2x.urdf'
        name = "cf2x"
        prop_name = "prop"
        penalize_contacts_on = ["prop", "base"]
        terminate_after_contacts_on = ["prop", "base"]
        self_collisions = 1  # 1 to disable, 0 to enable...bitwise filter

    class env(LeggedVisualInputConfig.env):
        num_observations = 6
        episode_length_s = 20  # in second !!!!!
        num_actions = 6
        env_spacing = 20

    class termination:
        # Episode-termination switches.
        collision = True
        max_step_done = True
71 |
72 |
class DroneCfgPPO(LeggedVIsualInputCfgPPO):
    """
    This config is only a placeholder required by the task registry and is thus useless otherwise,
    since we use sb3 instead of RSL_RL.
    """
    seed = 1
    runner_class_name = 'OnPolicyRunner'

    class policy(LeggedVIsualInputCfgPPO.policy):
        init_noise_std = 1.0
        actor_hidden_dims = [512, 256, 128]
        critic_hidden_dims = [512, 256, 128]
        activation = 'elu'  # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid
        # only for 'ActorCriticRecurrent':
        # rnn_type = 'lstm'
        # rnn_hidden_size = 512
        # rnn_num_layers = 1

    class algorithm(LeggedVIsualInputCfgPPO.algorithm):
        # training params
        value_loss_coef = 1.0
        use_clipped_value_loss = True
        clip_param = 0.2
        entropy_coef = 0.01
        num_learning_epochs = 5
        num_mini_batches = 4  # mini batch size = num_envs*nsteps / nminibatches
        learning_rate = 1.e-3  # 5.e-4
        schedule = 'adaptive'  # could be adaptive, fixed
        gamma = 0.99
        lam = 0.95
        desired_kl = 0.01
        max_grad_norm = 1.

    class runner(LeggedVIsualInputCfgPPO.runner):
        policy_class_name = 'ActorCritic'
        algorithm_class_name = 'PPO'
        num_steps_per_env = 24  # per iteration
        max_iterations = 1500  # number of policy updates

        # logging
        save_interval = 50  # check for potential saves every this many iterations
        experiment_name = 'test'
        run_name = ''
        # load and resume
        resume = False
        load_run = -1  # -1 = last run
        checkpoint = -1  # -1 = last saved model
        resume_path = None  # updated from load_run and chkpt

    class visual_input(LeggedVisualInputConfig.visual_input):
        camera_width = 320
        camera_height = 240
        type = gymapi.IMAGE_COLOR
        # stack = 5 # consecutive frames to stack
        normalization = True
        cam_pos = (0.0, 0, 0.0)
128 |
--------------------------------------------------------------------------------
/gleam/env/config_gleam_eval.py:
--------------------------------------------------------------------------------
1 | from gleam.env.config_legged_visual import LeggedVisualInputConfig
2 | from isaacgym import gymapi
3 |
4 |
class Config_GLEAM_Eval(LeggedVisualInputConfig):
    # Environment configuration used by the evaluation task
    # ("eval_gleam_gleambench" in gleam/__init__.py).
    position_use_polar_coordinates = False  # the position will be represented by (r, \theta, \phi) instead of (x, y, z)
    direction_use_vector = False  # (r, p, y)
    debug_save_image_tensor = False
    debug_save_path = None
    max_episode_length = 50  # max_steps_per_episode

    num_sampled_point = 5000

    class rewards:
        class scales:  # * self.dt (0.019999) / episode_length_s (20). For example, when reward_scales=1000, rew_xxx = reward * 1000 * 0.019999 / 20 = reward
            surface_coverage_2d = 1000  # original scale (coverage ratio: [0, 1])

        only_positive_rewards = False  # if true negative total rewards are clipped at zero (avoids early termination problems)
        max_contact_force = 100.  # forces above this value are penalized

    class normalization:
        class obs_scales:
            lin_vel = 2.0
            ang_vel = 0.25
            dof_pos = 1.0
            dof_vel = 0.05
            height_measurements = 5.0

        pi = 3.14159265359
        clip_observations = 255.

        # discrete action space
        init_action = [64, 64, 0, 0, 0, 0]  # 64: keep still
        clip_actions_up = [128, 128, 0, 0, 0, 0]
        clip_actions_low = [0, 0, 0, 0, 0, 0]

    class visual_input(LeggedVisualInputConfig.visual_input):
        camera_width = 256
        camera_height = 32
        horizontal_fov = 90.0  # Horizontal field of view in degrees. Vertical field of view is calculated from height to width ratio

        stack = 100
        supersampling_horizontal = 1
        supersampling_vertical = 1
        normalization = True  # normalize pixels to [0, 1]
        far_plane = 2000000.0  # distance in world coordinates to far-clipping plane
        near_plane = 0.0010000000474974513  # distance in world coordinate units to near-clipping plane

        type = gymapi.IMAGE_DEPTH

    class asset(LeggedVisualInputConfig.asset):
        file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/drone/cf2x.urdf'
        name = "cf2x"
        prop_name = "prop"
        penalize_contacts_on = ["prop", "base"]
        terminate_after_contacts_on = ["prop", "base"]
        self_collisions = 1  # 1 to disable, 0 to enable...bitwise filter

    class env(LeggedVisualInputConfig.env):
        num_observations = 6
        episode_length_s = 20  # in second !!!!!
        num_actions = 6
        env_spacing = 20

    class termination:
        # Episode-termination switches.
        collision = True
        max_step_done = True
--------------------------------------------------------------------------------
/gleam/env/env_gleam_stage2.py:
--------------------------------------------------------------------------------
1 | from isaacgym import gymapi
2 | from isaacgym.torch_utils import *
3 | from legged_gym import OPEN_ROBOT_ROOT_DIR
4 | import os
5 | import torch
6 | import numpy as np
7 | from gleam.env.env_gleam_base import Env_GLEAM_Base
8 | from gleam.env.env_gleam_stage1 import Env_GLEAM_Stage1
9 |
10 |
class Env_GLEAM_Stage2(Env_GLEAM_Stage1):
    """Stage-2 training environment: same logic as stage 1, loaded with the "stage2_512" scene set."""

    def __init__(self, *args, **kwargs):
        """
        Training set of stage 2 includes 512 indoor scenes from:
            procthor_2-bed-2-bath_256,
            procthor_7-room-3-bed_96,
            procthor_12-room_64,
            gibson_96.
        """
        self.visualize_flag = False     # training
        # self.visualize_flag = True      # visualization

        # self.num_scene = 4   # debug
        # print("*"*50, "num_scene: ", self.num_scene, "*"*50)

        self.num_scene = 512

        # NOTE(review): super(Env_GLEAM_Base, self) starts the method lookup
        # *after* Env_GLEAM_Base, so the __init__ of Env_GLEAM_Stage1 and of
        # Env_GLEAM_Base are not executed here. This looks deliberate (the
        # attributes above are set by this class instead) -- confirm.
        super(Env_GLEAM_Base, self).__init__(*args, **kwargs)

    def _additional_create(self, env_handle, env_index):
        """
        Load the scene assets of one environment and create their actors.

        If there are N training scenes and M environments, each environment will load N//M scenes.
        Only the first scene (idx == 0) is active, and the others are inactive
        (parked at a common position beyond the largest environment origin).
        When N is not a multiple of M, the scenes with index >= M * (N//M) are never loaded.
        """
        assert self.cfg.return_visual_observation, "visual observation should be returned!"
        assert self.num_scene >= self.num_envs, "num_scene should be no smaller than num_envs"

        # urdf load, create actor
        dataset_name = "stage2_512"
        # Anchor the asset root at the repository root, like the ground-truth
        # files in _init_load_all, so loading does not depend on the current
        # working directory.
        urdf_path = os.path.join(OPEN_ROBOT_ROOT_DIR, f"data_gleam/train_{dataset_name}/urdf")

        scene_per_env = self.num_scene // self.num_envs

        asset_options = gymapi.AssetOptions()
        asset_options.flip_visual_attachments = self.cfg.asset.flip_visual_attachments
        asset_options.fix_base_link = True
        asset_options.disable_gravity = True

        # Parking position of the inactive scenes.
        inactive_x = self.env_origins[:, 0].max() + 15.
        inactive_y = self.env_origins[:, 1].max() + 15.
        inactive_z = self.env_origins[:, 2].max()

        for idx in range(scene_per_env):
            scene_idx = env_index * scene_per_env + idx
            urdf_name = f"scene_{scene_idx}.urdf"
            asset = self.gym.load_asset(self.sim, urdf_path, urdf_name, asset_options)

            pose = gymapi.Transform()
            if idx == 0:    # Only the first scene (idx == 0) is active
                pose.p = gymapi.Vec3(self.env_origins[env_index][0],
                                     self.env_origins[env_index][1],
                                     self.env_origins[env_index][2])
            else:
                pose.p = gymapi.Vec3(inactive_x, inactive_y, inactive_z)
            # NOTE(review): these components do not form a unit quaternion;
            # presumably normalized downstream -- confirm.
            pose.r = gymapi.Quat(-np.pi/2, 0.0, 0.0, np.pi/2)
            self.gym.create_actor(env_handle, asset, pose, None, env_index, 0)

        # Indices 1..scene_per_env of the scene actors created above.
        self.additional_actors[env_index] = list(range(1, scene_per_env + 1))

    def _init_load_all(self):
        """
        Load all ground truth data.
        """
        self.grid_size = 128
        self.motion_height = 1.5    # 1.5m
        dataset_name = "stage2_512"
        gt_path = os.path.join(OPEN_ROBOT_ROOT_DIR, f"data_gleam/train_{dataset_name}/gt/")
        prefix = f"{dataset_name}_{self.grid_size}"

        def load_gt(suffix):
            # Load one ground-truth tensor and keep the first num_scene entries.
            return torch.load(gt_path + f"{prefix}_{suffix}", map_location=self.device)[:self.num_scene]

        # [num_scene, 3]
        self.voxel_size_gt = load_gt("voxel_size_gt.pt")

        # [num_scene, 6], (x_max, x_min, y_max, y_min, z_max, z_min)
        self.range_gt = load_gt("range_gt.pt")

        # [num_scene, grid_size, grid_size], layout map at the height of {self.motion height}
        self.layout_maps_height = load_gt("occ_map_height_1d5_gt.pt").to(torch.float16)
        self.layout_maps_height /= 255.

        # [num_scene]
        self.num_valid_pixel_gt = self.layout_maps_height.sum(dim=(1, 2))

        # [num_scene, 128, 128]
        # NOTE(review): the in-place division assumes the stored tensor is
        # already floating point -- confirm.
        init_maps = load_gt("init_map_1d5.pt")
        init_maps /= 255.

        # len() == self.num_scene, the shape of element is [num_non_zero_pixel, 2]
        # Pixel indices are mapped to [-1, 1] and scaled by (x_max, y_max) of the scene.
        self.init_maps_list = [
            (torch.nonzero(init_maps[idx]) / (self.grid_size - 1) * 2 - 1) * self.range_gt[idx, :4:2]
            for idx in range(self.num_scene)
        ]

        print("Loaded all ground truth data.")

    def _init_buffers(self):
        """Initialize the parent buffers and, in visualization mode, the visualization state."""
        super()._init_buffers()

        # visualization
        if self.visualize_flag:
            # visualize scenes
            self.vis_obj_idx = 0
            print("Visualization object index: ", self.vis_obj_idx)

            self.reset_once_flag = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
            self.reset_once_cr = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
            self.local_paths = [[] for _ in range(self.num_envs)]

            self.save_path = './gleam/output/train_stage2_512'
            os.makedirs(self.save_path, exist_ok=True)
121 | os.makedirs(self.save_path, exist_ok=True)
122 |
--------------------------------------------------------------------------------
/gleam/network/encoder.py:
--------------------------------------------------------------------------------
1 | # from typing import Any, Dict, List, Optional, Type, Union, Tuple
2 | from typing import List, Tuple
3 |
4 | import gym
5 | import torch as th
6 | from gleam.network.base import LocoTransformerEncoder_Map
7 | from gleam.network.locotransformer import LocoTransformer_GLEAM
8 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
9 |
10 |
class Encoder_GLEAM(BaseFeaturesExtractor):
    """
    We adapt the LocoTransformer to encode raw observations, including:
        - state input: historical poses within this episode.
        - visual input: current egocentric map, with the shape of (128, 128).

    :param observation_space: observation space of the environment
    :param encoder_param: keyword arguments forwarded to ``LocoTransformerEncoder_Map``
    :param net_param: keyword arguments forwarded to ``LocoTransformer_GLEAM``. The last
        entry of ``net_param["append_hidden_shapes"]`` is used as the output feature
        dimension; the remaining entries are passed on. The caller's dict is not modified.
    :param visual_input_shape: (channels, *map_shape) of the visual input
    :param state_input_shape: shape of the state input; only its first entry is used
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        encoder_param=None,
        net_param=None,
        visual_input_shape=None,
        state_input_shape=None,
    ):
        assert encoder_param is not None, "Need parameters !"
        assert net_param is not None, "Need parameters !"
        assert isinstance(visual_input_shape, (list, tuple)), "Use tuple or list"
        assert isinstance(state_input_shape, (list, tuple)), "Use tuple or list"
        self.map_channel = visual_input_shape[0]
        self.map_shape = visual_input_shape[1:]
        self.state_input_shape = state_input_shape

        # Split the hidden shapes on copies. Popping from the caller's list
        # would shrink the stored policy kwargs, so every further construction
        # from the same kwargs would lose one more layer and get a different
        # feature_dim.
        net_param = dict(net_param)
        hidden_shapes = list(net_param["append_hidden_shapes"])
        feature_dim = hidden_shapes.pop()
        net_param["append_hidden_shapes"] = hidden_shapes
        super(Encoder_GLEAM, self).__init__(observation_space, feature_dim)

        # create encoder, share encoder
        self.encoder = LocoTransformerEncoder_Map(
            in_channels=self.map_channel,
            state_input_dim=self.state_input_shape[0],
            **encoder_param
        )

        self.locotransformer = LocoTransformer_GLEAM(
            encoder=self.encoder,
            state_input_shape=self.state_input_shape[0],
            visual_input_shape=(self.map_channel, *self.map_shape),
            output_shape=feature_dim,
            **net_param
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Encode a batch of observations with the LocoTransformer.

        observations: [num_env, concat(state_input, visual_input)]
        """
        return self.locotransformer.forward(observations)
--------------------------------------------------------------------------------
/gleam/network/init.py:
--------------------------------------------------------------------------------
1 | """
2 | Credit to legged_gym and TorchRL.
3 | """
4 |
5 | import numpy as np
6 | import torch.nn as nn
7 |
8 |
9 | def _fanin_init(tensor, alpha=0):
10 | size = tensor.size()
11 | if len(size) == 2:
12 | fan_in = size[0]
13 | elif len(size) > 2:
14 | fan_in = np.prod(size[1:])
15 | else:
16 | raise Exception("Shape must be have dimension at least 2.")
17 | # bound = 1. / np.sqrt(fan_in)
18 | bound = np.sqrt(1. / ((1 + alpha * alpha) * fan_in))
19 | return tensor.data.uniform_(-bound, bound)
20 |
21 |
22 | def _uniform_init(tensor, param=3e-3):
23 | return tensor.data.uniform_(-param, param)
24 |
25 |
26 | def _constant_bias_init(tensor, constant=0.1):
27 | tensor.data.fill_(constant)
28 |
29 |
def layer_init(layer, weight_init=_fanin_init, bias_init=_constant_bias_init):
    """Initialise ``layer`` in place.

    ``weight_init`` is applied to ``layer.weight`` first, then ``bias_init`` is
    applied to ``layer.bias``.
    """
    weights = layer.weight
    weight_init(weights)

    biases = layer.bias
    bias_init(biases)
33 |
34 |
def basic_init(layer):
    """Initialise ``layer`` with fan-in weights and the default constant bias."""
    layer_init(layer, _fanin_init, _constant_bias_init)
37 |
38 |
def uniform_init(layer):
    """Initialise both the weight and the bias of ``layer`` with ``_uniform_init``."""
    layer_init(layer, _uniform_init, _uniform_init)
41 |
42 |
43 | def _orthogonal_init(tensor, gain=np.sqrt(2)):
44 | nn.init.orthogonal_(tensor, gain=gain)
45 |
46 |
def orthogonal_init(layer, scale=np.sqrt(2), constant=0):
    """Initialise ``layer`` with orthogonal weights and a constant bias.

    Args:
        layer: Object exposing ``weight`` and ``bias`` tensors.
        scale: Gain passed to the orthogonal weight initialiser.
        constant: Value written to every bias element. Previously this argument
            was ignored and the bias was always set to 0.
    """
    layer_init(
        layer,
        weight_init=lambda weight: _orthogonal_init(weight, gain=scale),
        bias_init=lambda bias: _constant_bias_init(bias, constant),
    )
51 |
--------------------------------------------------------------------------------
/gleam/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/gleam/train/__init__.py
--------------------------------------------------------------------------------
/gleam/utils/bfs_cuda_2D.h:
--------------------------------------------------------------------------------
1 | #ifndef BFS_CUDA_H
2 | #define BFS_CUDA_H
3 |
// Host-side entry point that launches the batched 2-D BFS kernel
// (see bfs_cuda_kernel_2D.cu). One search is run per environment.
//
// All pointers are handed to the kernel unchanged, so they must be
// dereferenceable from device code.
//
//   occupancy_maps  num_env grids of H * W floats; cell (x, y) of environment e is
//                   read at index e * H * W + x * W + y. 1 = walkable, 0 = obstacle.
//   starts          2 ints per environment: (x, y) of the start cell.
//   goals           2 ints per environment: (x, y) of the goal cell.
//   path_lengths    output, 1 float per environment: number of cells on the path
//                   found (the start cell counts as 1), or -1 if no path was found.
//   num_env         number of environments.
//   H, W            grid extent along x and y respectively.
void BFS_kernel_launcher_2D(
    const float* occupancy_maps,
    const int* starts,
    const int* goals,
    float* path_lengths,
    int num_env,
    int H,
    int W
);
13 |
14 | #endif
--------------------------------------------------------------------------------
/gleam/utils/bfs_cuda_kernel_2D.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include "bfs_cuda_2D.h"
5 |
6 | #define MAX_PATH_LENGTH 16384 // 128 * 128
7 |
// Batched breadth-first search on 2-D occupancy grids: each thread handles one
// environment and writes the length of a shortest 4-connected path from that
// environment's start cell to its goal cell.
//
//   occupancy_maps  num_env grids of H * W floats, cell (x, y) at x * W + y
//   starts / goals  2 ints per environment: (x, y)
//   path_lengths    output, 1 float per environment: number of cells on the path
//                   (the start cell counts as 1), or -1 when no path is found, the
//                   start lies outside the grid, or scratch memory is unavailable
//
// The start cell itself is not tested for walkability.
__global__ void BFS_kernel_2D(
    const float* __restrict__ occupancy_maps,  // Grid map (1 = walkable, 0 = obstacle)
    const int* __restrict__ starts,
    const int* __restrict__ goals,
    float* __restrict__ path_lengths,
    int num_env,
    int H,
    int W)
{
    int env_id = blockIdx.x * blockDim.x + threadIdx.x;
    if (env_id >= num_env) return;

    // Report "no path" unless the search below succeeds.
    path_lengths[env_id] = -1.0f;

    // Extract start and goal positions
    const int start_x = starts[env_id * 2];
    const int start_y = starts[env_id * 2 + 1];
    const int goal_x = goals[env_id * 2];
    const int goal_y = goals[env_id * 2 + 1];

    // A start outside the grid would index the scratch arrays out of bounds.
    // This also rejects non-positive H or W.
    if (start_x < 0 || start_x >= H || start_y < 0 || start_y >= W) return;

    const int num_cells = H * W;

    // Offset of this environment's grid; widened so large batches cannot overflow int.
    const size_t occupancy_offset = (size_t)env_id * (size_t)num_cells;

    // Per-thread scratch buffers. Every cell is enqueued at most once (it is marked
    // visited when enqueued), so a queue of num_cells entries can never overflow.
    bool* visited = new bool[num_cells];
    int* distances = new int[num_cells];
    int* queue = new int[num_cells];

    // In-kernel allocation can fail when the device heap is exhausted; the
    // allocator is expected to signal this with a null pointer.
    if (visited == nullptr || distances == nullptr || queue == nullptr) {
        delete[] queue;
        delete[] visited;
        delete[] distances;
        return;
    }

    // Initialize arrays
    for (int i = 0; i < num_cells; i++) {
        visited[i] = false;
        distances[i] = 0;
    }

    // Initialize BFS
    int front = 0;
    int back = 1;
    const int start_pos = start_x * W + start_y;
    queue[0] = start_pos;
    visited[start_pos] = true;
    distances[start_pos] = 1;
    float path_length = -1.0f;

    // Check 4-connected neighbors (no diagonal movement)
    const int dx[4] = {-1, 0, 1, 0};
    const int dy[4] = {0, 1, 0, -1};

    // BFS loop. It runs until the queue is drained: stopping when the queue merely
    // fills up would discard cells that were enqueued but never examined, and a
    // reachable goal among them would be reported as unreachable.
    while (front < back) {
        const int current_pos = queue[front];
        front++;
        const int current_x = current_pos / W;
        const int current_y = current_pos % W;

        // Give up once the path length reaches MAX_PATH_LENGTH
        if (distances[current_pos] >= MAX_PATH_LENGTH) {
            break;
        }

        if (current_x == goal_x && current_y == goal_y) {
            path_length = (float)distances[current_pos];
            break;
        }

        for (int i = 0; i < 4; i++) {
            const int nx = current_x + dx[i];
            const int ny = current_y + dy[i];

            if (nx >= 0 && nx < H && ny >= 0 && ny < W) {
                const int neighbor_pos = nx * W + ny;
                if (occupancy_maps[occupancy_offset + neighbor_pos] == 1 && !visited[neighbor_pos]) {
                    queue[back] = neighbor_pos;
                    visited[neighbor_pos] = true;
                    distances[neighbor_pos] = distances[current_pos] + 1;
                    back++;
                }
            }
        }
    }

    path_lengths[env_id] = path_length;

    // Clean up
    delete[] queue;
    delete[] visited;
    delete[] distances;
}
100 |
// Launches BFS_kernel_2D with one thread per environment, 256 threads per block.
// The launch is not synchronized here and launch errors are not checked; all
// pointers are passed to the kernel unchanged.
void BFS_kernel_launcher_2D(
    const float* occupancy_maps,
    const int* starts,
    const int* goals,
    float* path_lengths,
    int num_env,
    int H,
    int W)
{
    // Nothing to compute, and a grid of zero blocks is not a valid launch.
    if (num_env <= 0) return;

    const int threads = 256;
    const int blocks = (num_env + threads - 1) / threads;  // ceil(num_env / threads)
    BFS_kernel_2D<<<blocks, threads>>>(
        occupancy_maps,
        starts,
        goals,
        path_lengths,
        num_env,
        H,
        W
    );
}
--------------------------------------------------------------------------------
/gleam/wrapper/env_wrapper_gleam.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """An env wrapper that flattens the observation dictionary to an array."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import gym
21 | import collections
22 |
23 | import torch
24 | from gym import spaces
25 | import numpy as np
26 |
def flatten_observations(observation_dict, key_sequence):
    """Concatenates selected observations into a single tensor.

    Args:
        observation_dict: A dictionary of all the observations (tensors).
        key_sequence: A list/tuple of the keys to pick, in concatenation order.

    Returns:
        The picked observations concatenated along their last dimension.
    """
    selected = [observation_dict[key] for key in key_sequence]
    return torch.concat(selected, dim=-1)  # grid observation
50 |
51 |
def flatten_observation_spaces(observation_spaces, key_sequence):
    """Flattens the dictionary observation spaces to a single gym.spaces.Box.

    Args:
        observation_spaces: A dictionary space; its ``spaces`` mapping is indexed
            with each key.
        key_sequence: A list of the keys of the observation spaces to be added
            during flattening. Entries that are not ``spaces.Box`` are skipped.

    Returns:
        A float32 Box whose bounds are the flattened bounds of the selected Box
        spaces, concatenated in key order.
    """
    assert isinstance(key_sequence, list)

    candidates = (observation_spaces.spaces[key] for key in key_sequence)
    boxes = [space for space in candidates if isinstance(space, spaces.Box)]

    low = np.concatenate([np.asarray(box.low).flatten() for box in boxes])
    high = np.concatenate([np.asarray(box.high).flatten() for box in boxes])

    return spaces.Box(low.astype(np.float32), high.astype(np.float32), dtype=np.float32)
83 |
84 |
class EnvWrapperGLEAM(gym.Env):
    """An env wrapper that flattens the observation dictionary to an array.

    The "state" and "ego_map_2D" entries of the wrapped environment's
    observations are concatenated, in that order. Attributes not found on the
    wrapper are looked up on the wrapped environment.
    """

    # Keys that are flattened, in concatenation order.
    _OBSERVATION_KEYS = ("state", "ego_map_2D")

    def __init__(self, gym_env, observation_excluded=()):
        """Initializes the wrapper.

        Args:
            gym_env: The environment to wrap.
            observation_excluded: Stored on the instance; not used by the wrapper
                itself.
        """
        self.observation_excluded = observation_excluded
        self._gym_env = gym_env
        self.observation_space = self._flatten_observation_spaces(self._gym_env.observation_space)
        self.action_space = self._gym_env.action_space

    def __getattr__(self, attr):
        """Delegates attribute lookups that fail on the wrapper to the wrapped env."""
        # ``__getattr__`` only runs after normal lookup has failed. If ``_gym_env``
        # itself is missing (e.g. on an instance created without ``__init__`` during
        # copy/unpickle), delegating would re-enter this method without end.
        if attr == "_gym_env":
            raise AttributeError(attr)
        return getattr(self._gym_env, attr)

    def _flatten_observation_spaces(self, observation_spaces):
        """Builds the flat Box observation space from the wrapped env's dict space."""
        return flatten_observation_spaces(
            observation_spaces=observation_spaces, key_sequence=list(self._OBSERVATION_KEYS)
        )

    def _flatten_observation(self, input_observation):
        """Flatten the dictionary to an array."""
        return flatten_observations(
            observation_dict=input_observation, key_sequence=list(self._OBSERVATION_KEYS)
        )

    def reset(self):
        """Resets the wrapped environment and returns its flattened observation."""
        observation = self._gym_env.reset()
        return self._flatten_observation(observation)

    def step(self, action):
        """Steps the wrapped environment.

        Args:
            action: The input action from an NN agent; forwarded unchanged to the
                wrapped environment.

        Returns:
            A tuple of the flattened observation, the reward, the episode end
            indicator and the fourth value returned by the wrapped environment.
        """
        observation_dict, reward, done, info = self._gym_env.step(action)
        return self._flatten_observation(observation_dict), reward, done, info

    def render(self, mode='human'):
        """Renders the wrapped environment in the given mode."""
        return self._gym_env.render(mode)

    def close(self):
        """Closes the wrapped environment."""
        self._gym_env.close()
129 |
--------------------------------------------------------------------------------
/legged_gym/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import os
32 |
# Repository root: the directory that contains the ``legged_gym`` package.
OPEN_ROBOT_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Directory of the environment definitions. The sub-package is named ``env`` in
# this repository (modules are imported as ``legged_gym.env...``), not ``envs``.
LEGGED_GYM_ENVS_DIR = os.path.join(OPEN_ROBOT_ROOT_DIR, 'legged_gym', 'env')
35 |
--------------------------------------------------------------------------------
/legged_gym/env/__init__.py:
--------------------------------------------------------------------------------
1 | # # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # # SPDX-License-Identifier: BSD-3-Clause
3 | # #
4 | # # Redistribution and use in source and binary forms, with or without
5 | # # modification, are permitted provided that the following conditions are met:
6 | # #
7 | # # 1. Redistributions of source code must retain the above copyright notice, this
8 | # # list of conditions and the following disclaimer.
9 | # #
10 | # # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # # this list of conditions and the following disclaimer in the documentation
12 | # # and/or other materials provided with the distribution.
13 | # #
14 | # # 3. Neither the name of the copyright holder nor the names of its
15 | # # contributors may be used to endorse or promote products derived from
16 | # # this software without specific prior written permission.
17 | # #
18 | # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | # #
29 | # # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 |
32 | # from legged_gym import OPEN_ROBOT_ROOT_DIR, LEGGED_GYM_ENVS_DIR
33 | # from legged_gym.env.a1.a1_config import A1RoughCfg, A1RoughCfgPPO
34 | # from .base.legged_robot import LeggedRobot
35 | # from .anymal_c.anymal import Anymal
36 | # from .anymal_c.mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg, AnymalCRoughCfgPPO
37 | # from .anymal_c.flat.anymal_c_flat_config import AnymalCFlatCfg, AnymalCFlatCfgPPO
38 | # from .anymal_b.anymal_b_config import AnymalBRoughCfg, AnymalBRoughCfgPPO
39 | # from .cassie.cassie import Cassie
40 | # from .cassie.cassie_config import CassieRoughCfg, CassieRoughCfgPPO
41 | # from .a1.a1_config import A1RoughCfg, A1RoughCfgPPO
42 |
43 | # from legged_gym.utils.task_registry import task_registry
44 |
45 | # task_registry.register("anymal_c_rough", Anymal, AnymalCRoughCfg(), AnymalCRoughCfgPPO())
46 | # task_registry.register("anymal_c_flat", Anymal, AnymalCFlatCfg(), AnymalCFlatCfgPPO())
47 | # task_registry.register("anymal_b", Anymal, AnymalBRoughCfg(), AnymalBRoughCfgPPO())
48 | # task_registry.register("a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO())
49 | # task_registry.register("cassie", Cassie, CassieRoughCfg(), CassieRoughCfgPPO())
50 |
--------------------------------------------------------------------------------
/legged_gym/env/a1/a1_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
32 |
33 |
class A1RoughCfg(LeggedRobotCfg):
    """Environment configuration overrides for the A1 robot."""

    class init_state(LeggedRobotCfg.init_state):
        # Initial base position: x, y, z [m]
        pos = [0.0, 0.0, 0.42]

        # Target joint angles [rad] when action = 0.0
        default_joint_angles = {
            # hips
            "FL_hip_joint": 0.1,
            "RL_hip_joint": 0.1,
            "FR_hip_joint": -0.1,
            "RR_hip_joint": -0.1,
            # thighs
            "FL_thigh_joint": 0.8,
            "RL_thigh_joint": 1.0,
            "FR_thigh_joint": 0.8,
            "RR_thigh_joint": 1.0,
            # calves
            "FL_calf_joint": -1.5,
            "RL_calf_joint": -1.5,
            "FR_calf_joint": -1.5,
            "RR_calf_joint": -1.5,
        }

    class control(LeggedRobotCfg.control):
        # PD drive parameters
        control_type = "P"
        stiffness = {"joint": 20.0}  # [N*m/rad]
        damping = {"joint": 0.5}  # [N*m*s/rad]

        # target angle = action_scale * action + default angle
        action_scale = 0.25

        # Number of control action updates @ sim DT per policy DT
        decimation = 4

    class asset(LeggedRobotCfg.asset):
        file = "{LEGGED_GYM_ROOT_DIR}/resources/robots/a1/urdf/a1.urdf"
        name = "a1"
        foot_name = "foot"
        penalize_contacts_on = ["thigh", "calf"]
        terminate_after_contacts_on = ["base"]
        # bitwise filter: 1 to disable, 0 to enable
        self_collisions = 1

    class rewards(LeggedRobotCfg.rewards):
        soft_dof_pos_limit = 0.9
        base_height_target = 0.25

        class scales(LeggedRobotCfg.rewards.scales):
            torques = -0.0002
            dof_pos_limits = -10.0
79 |
80 |
class A1RoughCfgPPO(LeggedRobotCfgPPO):
    """PPO training configuration overrides for the A1 robot."""

    class algorithm(LeggedRobotCfgPPO.algorithm):
        entropy_coef = 0.01

    class runner(LeggedRobotCfgPPO.runner):
        run_name = ""
        experiment_name = "rough_a1"
88 |
--------------------------------------------------------------------------------
/legged_gym/env/anymal_b/anymal_b_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env import AnymalCRoughCfg, AnymalCRoughCfgPPO
32 |
33 |
class AnymalBRoughCfg(AnymalCRoughCfg):
    """Environment configuration for ANYmal B, derived from the ANYmal C rough config."""

    class asset(AnymalCRoughCfg.asset):
        file = "{LEGGED_GYM_ROOT_DIR}/resources/robots/anymal_b/urdf/anymal_b.urdf"
        name = "anymal_b"
        foot_name = "FOOT"

    class rewards(AnymalCRoughCfg.rewards):
        # Reward scales are inherited unchanged.
        class scales(AnymalCRoughCfg.rewards.scales):
            pass
43 |
44 |
class AnymalBRoughCfgPPO(AnymalCRoughCfgPPO):
    """PPO training configuration for ANYmal B, derived from the ANYmal C rough config."""

    class runner(AnymalCRoughCfgPPO.runner):
        run_name = ""
        experiment_name = "rough_anymal_b"
        load_run = -1
50 |
--------------------------------------------------------------------------------
/legged_gym/env/anymal_c/anymal.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from time import time
32 | import numpy as np
33 | import os
34 |
35 | from isaacgym.torch_utils import *
36 | from isaacgym import gymtorch, gymapi, gymutil
37 |
38 | import torch
39 | # from torch.tensor import Tensor
40 | from typing import Tuple, Dict
41 |
42 | from legged_gym.env import LeggedRobot
43 | from legged_gym import OPEN_ROBOT_ROOT_DIR
44 | from .mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg
45 |
46 |
class Anymal(LeggedRobot):
    """Legged robot task that can compute torques with a learned actuator network.

    When ``cfg.control.use_actuator_network`` is set, torques come from a
    TorchScript network loaded from ``cfg.control.actuator_net_file``; otherwise
    the parent class computes them.
    """

    cfg: AnymalCRoughCfg

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless):
        super().__init__(cfg, sim_params, physics_engine, sim_device, headless)

        if not self.cfg.control.use_actuator_network:
            return

        # load actuator network
        network_file = self.cfg.control.actuator_net_file.format(LEGGED_GYM_ROOT_DIR=OPEN_ROBOT_ROOT_DIR)
        self.actuator_network = torch.jit.load(network_file).to(self.device)

    def reset_idx(self, env_ids):
        """Reset the given environments and clear their actuator network states."""
        super().reset_idx(env_ids)
        # Additionally empty actuator network hidden states
        for per_env_state in (self.sea_hidden_state_per_env, self.sea_cell_state_per_env):
            per_env_state[:, env_ids] = 0.

    def _init_buffers(self):
        """Create the parent's buffers plus the actuator network input/state tensors."""
        super()._init_buffers()

        # Additionally initialize actuator network hidden state tensors.
        # One row per (environment, action) pair.
        num_rows = self.num_envs * self.num_actions
        self.sea_input = torch.zeros(num_rows, 1, 2, device=self.device, requires_grad=False)

        state_shape = (2, num_rows, 8)
        self.sea_hidden_state = torch.zeros(*state_shape, device=self.device, requires_grad=False)
        self.sea_cell_state = torch.zeros(*state_shape, device=self.device, requires_grad=False)

        # Views onto the same storage, indexed per environment so that
        # ``reset_idx`` can zero individual environments.
        per_env_shape = (2, self.num_envs, self.num_actions, 8)
        self.sea_hidden_state_per_env = self.sea_hidden_state.view(*per_env_shape)
        self.sea_cell_state_per_env = self.sea_cell_state.view(*per_env_shape)

    def _compute_torques(self, actions):
        """Return torques from the actuator network, or from the parent class."""
        # Choose between pd controller and actuator network
        if not self.cfg.control.use_actuator_network:
            # pd controller
            return super()._compute_torques(actions)

        with torch.inference_mode():
            position_error = actions * self.cfg.control.action_scale + self.default_dof_pos - self.dof_pos
            self.sea_input[:, 0, 0] = position_error.flatten()
            self.sea_input[:, 0, 1] = self.dof_vel.flatten()

            torques, (hidden_state, cell_state) = self.actuator_network(
                self.sea_input, (self.sea_hidden_state, self.sea_cell_state)
            )
            # Copy in place so the per-environment views keep sharing storage.
            self.sea_hidden_state[:] = hidden_state
            self.sea_cell_state[:] = cell_state
        return torques
91 |
--------------------------------------------------------------------------------
/legged_gym/env/anymal_c/flat/anymal_c_flat_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env import AnymalCRoughCfg, AnymalCRoughCfgPPO
32 |
33 |
class AnymalCFlatCfg(AnymalCRoughCfg):
    """Environment configuration for ANYmal C on a plane, derived from the rough config."""

    class env(AnymalCRoughCfg.env):
        num_observations = 48

    class terrain(AnymalCRoughCfg.terrain):
        mesh_type = "plane"
        measure_heights = False

    class asset(AnymalCRoughCfg.asset):
        # bitwise filter: 1 to disable, 0 to enable
        self_collisions = 0

    class rewards(AnymalCRoughCfg.rewards):
        max_contact_force = 350.0

        class scales(AnymalCRoughCfg.rewards.scales):
            orientation = -5.0
            torques = -0.000025
            feet_air_time = 2.0
            # feet_contact_forces = -0.01

    class commands(AnymalCRoughCfg.commands):
        heading_command = False
        resampling_time = 4.0

        class ranges(AnymalCRoughCfg.commands.ranges):
            ang_vel_yaw = [-1.5, 1.5]

    class domain_rand(AnymalCRoughCfg.domain_rand):
        # On ground planes the friction combination mode is averaging,
        # i.e. total friction = (foot_friction + 1.) / 2.
        friction_range = [0.0, 1.5]
65 |
66 |
class AnymalCFlatCfgPPO(AnymalCRoughCfgPPO):
    """PPO training configuration for ANYmal C on a plane, derived from the rough config."""

    class policy(AnymalCRoughCfgPPO.policy):
        actor_hidden_dims = [128, 64, 32]
        critic_hidden_dims = [128, 64, 32]
        # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid
        activation = "elu"

    class algorithm(AnymalCRoughCfgPPO.algorithm):
        entropy_coef = 0.01

    class runner(AnymalCRoughCfgPPO.runner):
        run_name = ""
        experiment_name = "flat_anymal_c"
        load_run = -1
        max_iterations = 300
81 |
--------------------------------------------------------------------------------
/legged_gym/env/anymal_c/mixed_terrains/anymal_c_rough_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym.env.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
32 |
33 |
class AnymalCRoughCfg(LeggedRobotCfg):
    """Environment configuration for the rough-terrain ANYmal C task.

    Each nested class extends the section of the same name in
    ``LeggedRobotCfg`` and overrides only the values listed here.
    """

    class env(LeggedRobotCfg.env):
        num_envs = 4096
        num_actions = 12

    class terrain(LeggedRobotCfg.terrain):
        mesh_type = 'trimesh'

    class init_state(LeggedRobotCfg.init_state):
        # Initial base position: x, y, z [m]
        pos = [0.0, 0.0, 0.6]
        # Target joint angles [rad] when action = 0.0, one row per joint type.
        default_joint_angles = {
            "LF_HAA": 0.0, "LH_HAA": 0.0, "RF_HAA": -0.0, "RH_HAA": -0.0,
            "LF_HFE": 0.4, "LH_HFE": -0.4, "RF_HFE": 0.4, "RH_HFE": -0.4,
            "LF_KFE": -0.8, "LH_KFE": 0.8, "RF_KFE": -0.8, "RH_KFE": 0.8,
        }

    class control(LeggedRobotCfg.control):
        # PD drive gains, keyed by joint type.
        stiffness = {'HAA': 80., 'HFE': 80., 'KFE': 80.}  # [N*m/rad]
        damping = {'HAA': 2., 'HFE': 2., 'KFE': 2.}  # [N*m*s/rad]
        # target angle = actionScale * action + defaultAngle
        action_scale = 0.5
        # Number of control action updates @ sim DT per policy DT
        decimation = 4
        use_actuator_network = True
        actuator_net_file = "{LEGGED_GYM_ROOT_DIR}/resources/actuator_nets/anydrive_v3_lstm.pt"

    class asset(LeggedRobotCfg.asset):
        file = "{LEGGED_GYM_ROOT_DIR}/resources/robots/anymal_c/urdf/anymal_c.urdf"
        name = "anymal_c"
        foot_name = "FOOT"
        penalize_contacts_on = ["SHANK", "THIGH"]
        terminate_after_contacts_on = ["base"]
        # Bitwise filter: 1 to disable, 0 to enable.
        self_collisions = 1

    class domain_rand(LeggedRobotCfg.domain_rand):
        randomize_base_mass = True
        added_mass_range = [-5., 5.]

    class rewards(LeggedRobotCfg.rewards):
        base_height_target = 0.5
        max_contact_force = 500.
        only_positive_rewards = True

        class scales(LeggedRobotCfg.rewards.scales):
            # No overrides: the inherited reward scales are used unchanged.
            pass
92 |
class AnymalCRoughCfgPPO(LeggedRobotCfgPPO):
    """PPO training configuration for the rough-terrain ANYmal C task."""

    class runner(LeggedRobotCfgPPO.runner):
        # Only the runner section is overridden; the rest is inherited.
        run_name = ''
        experiment_name = 'rough_anymal_c'
        load_run = -1
98 |
--------------------------------------------------------------------------------
/legged_gym/env/base/base_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import inspect
32 |
33 |
class BaseConfig:
    def __init__(self) -> None:
        """Instantiate all member classes recursively.

        After construction, every attribute that was a nested class is an
        instance of that class instead (set on this instance, not on the type).
        """
        self.init_member_classes(self)

    @staticmethod
    def init_member_classes(obj):
        """Replace each class-valued attribute of ``obj`` with an instance of it.

        The only attribute name skipped is ``__class__``; every other name
        reported by ``dir(obj)`` is inspected. Newly created instances are
        processed the same way, so arbitrarily deep nesting is handled.
        """
        for name in dir(obj):
            # ``__class__`` is itself a class; instantiating it would recurse forever.
            if name == "__class__":
                continue
            member = getattr(obj, name)
            if not inspect.isclass(member):
                continue
            # Swap the type for an instance, then descend into that instance.
            instance = member()
            setattr(obj, name, instance)
            BaseConfig.init_member_classes(instance)
57 |
--------------------------------------------------------------------------------
/legged_gym/scripts/play.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from legged_gym import OPEN_ROBOT_ROOT_DIR
32 | import os
33 |
34 | import isaacgym
35 | from legged_gym.env import *
36 | from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
37 |
38 | import numpy as np
39 | import torch
40 |
41 |
def play(args):
    """Roll out a trained policy for ``args.task`` and log what it does.

    The task's environment config is overridden for playback (at most one
    environment, a 1x1 flat plane, no curriculum, no observation noise, no
    friction randomisation, no pushes) and the runner is created with
    ``resume`` enabled. The rollout lasts ``10 * max_episode_length`` steps.

    Reads the module-level flags ``EXPORT_POLICY``, ``RECORD_FRAMES`` and
    ``MOVE_CAMERA``; they must be defined before this function is called.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    env_cfg.return_visual_observation = False

    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
    env_cfg.terrain.num_rows = 1
    env_cfg.terrain.num_cols = 1
    env_cfg.terrain.curriculum = False
    env_cfg.terrain.mesh_type = "plane"
    env_cfg.noise.add_noise = False
    env_cfg.domain_rand.randomize_friction = False
    env_cfg.domain_rand.push_robots = False

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    obs = env.get_observations()

    # load policy
    train_cfg.runner.resume = True
    train_cfg.runner.num_steps_per_env = 1
    ppo_runner, train_cfg = task_registry.make_alg_runner(
        env=env, name=args.task, args=args, train_cfg=train_cfg
    )
    policy = ppo_runner.get_inference_policy(device=env.device)

    # export policy as a jit module (used to run it from C++)
    if EXPORT_POLICY:
        path = os.path.join(
            OPEN_ROBOT_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies'
        )
        export_policy_as_jit(ppo_runner.alg.actor_critic, path)
        print('Exported policy as jit script to: ', path)

    logger = Logger(env.dt)
    robot_index = 0  # which robot is used for logging
    joint_index = 1  # which joint is used for logging
    stop_state_log = 100  # number of steps before plotting states
    stop_rew_log = env.max_episode_length + 1  # number of steps before print average episode rewards

    viewer_pos = env_cfg.viewer.pos
    camera_position = np.array(viewer_pos, dtype=np.float64)
    camera_vel = np.array([1., 1., 0.])
    camera_direction = np.array(env_cfg.viewer.lookat) - np.array(viewer_pos)
    img_idx = 0

    total_steps = 10 * int(env.max_episode_length)
    for step in range(total_steps):
        actions = policy(obs.detach())
        obs, _, _, _, infos = env.step(actions.detach())

        # Save a viewer frame on every odd step.
        if RECORD_FRAMES and step % 2:
            filename = os.path.join(
                OPEN_ROBOT_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'frames',
                f"{img_idx}.png"
            )
            env.gym.write_viewer_image_to_file(env.viewer, filename)
            img_idx += 1

        # Translate the camera while keeping its viewing direction fixed.
        if MOVE_CAMERA:
            camera_position += camera_vel * env.dt
            env.set_camera(camera_position, camera_position + camera_direction)

        # Record states for the first ``stop_state_log`` steps, then plot once.
        if step < stop_state_log:
            snapshot = {
                'dof_pos_target': actions[robot_index, joint_index].item() * env.cfg.control.action_scale,
                'dof_pos': env.dof_pos[robot_index, joint_index].item(),
                'dof_vel': env.dof_vel[robot_index, joint_index].item(),
                'dof_torque': env.torques[robot_index, joint_index].item(),
                'command_x': env.commands[robot_index, 0].item(),
                'command_y': env.commands[robot_index, 1].item(),
                'command_yaw': env.commands[robot_index, 2].item(),
                'base_vel_x': env.base_lin_vel[robot_index, 0].item(),
                'base_vel_y': env.base_lin_vel[robot_index, 1].item(),
                'base_vel_z': env.base_lin_vel[robot_index, 2].item(),
                'base_vel_yaw': env.base_ang_vel[robot_index, 2].item(),
                'contact_forces_z': env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy()
            }
            logger.log_states(snapshot)
        elif step == stop_state_log:
            logger.plot_states()

        # Accumulate episode rewards until ``stop_rew_log``, then print once.
        if 0 < step < stop_rew_log:
            if infos["episode"]:
                num_episodes = torch.sum(env.reset_buf).item()
                if num_episodes > 0:
                    logger.log_rewards(infos["episode"], num_episodes)
        elif step == stop_rew_log:
            logger.print_rewards()
121 |
122 |
if __name__ == '__main__':
    import legged_complex_env

    # Module-level switches read by play().
    EXPORT_POLICY, RECORD_FRAMES, MOVE_CAMERA = True, False, False

    args = get_args()
    # The task is fixed here, replacing whatever get_args() returned.
    args.task = "a1"
    play(args)
131 |
--------------------------------------------------------------------------------
/legged_gym/scripts/train.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 | import os
33 | from datetime import datetime
34 |
35 | import isaacgym
36 | from legged_gym.env import *
37 | from legged_gym.utils import get_args, task_registry
38 | import torch
39 |
40 |
def train(args):
    """Build the environment and runner for ``args.task`` and run training.

    The number of learning iterations is taken from
    ``train_cfg.runner.max_iterations`` as returned by the task registry.
    """
    env, _env_cfg = task_registry.make_env(name=args.task, args=args)
    ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args)
    num_iterations = train_cfg.runner.max_iterations
    ppo_runner.learn(num_learning_iterations=num_iterations, init_at_random_ep_len=True)
45 |
46 |
if __name__ == '__main__':
    # Parse the command line and start training.
    train(get_args())
50 |
--------------------------------------------------------------------------------
/legged_gym/tests/test_env.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 | import os
33 | from datetime import datetime
34 |
35 | import isaacgym
36 | from legged_gym.env import *
37 | from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
38 |
39 | import torch
40 |
41 |
def test_env(args):
    """Step the ``args.task`` environment with all-zero actions as a smoke test.

    At most two environments are created and stepped for
    ``10 * max_episode_length`` iterations; "Done" is printed afterwards.
    """
    env_cfg, _train_cfg = task_registry.get_cfgs(name=args.task)
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 2)

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    num_steps = int(10 * env.max_episode_length)
    for _step in range(num_steps):
        # A fresh zero action tensor is built on every step.
        actions = torch.zeros(env.num_envs, env.num_actions, device=env.device)
        env.step(actions)
    print("Done")
53 |
54 |
if __name__ == '__main__':
    # Parse the command line and run the environment smoke test.
    test_env(get_args())
58 |
--------------------------------------------------------------------------------
/legged_gym/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict
32 | from .task_registry import task_registry
33 | from .logger import Logger
34 | from .math import *
35 | from .terrain import Terrain
36 |
--------------------------------------------------------------------------------
/legged_gym/utils/logger.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import matplotlib.pyplot as plt
32 | import numpy as np
33 | from collections import defaultdict
34 | from multiprocessing import Process, Value
35 |
36 |
class Logger:
    """Collects per-step state samples and per-episode rewards during playback.

    State samples are stored as lists keyed by name in ``state_log``; reward
    terms are stored in ``rew_log`` weighted by the number of episodes they
    summarise, so that ``print_rewards`` can report an episode-weighted mean.
    """

    def __init__(self, dt):
        """Create an empty logger.

        Args:
            dt: time between two logged samples; used to build the time axis
                of the plots.
        """
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        self.plot_process = None

    def log_state(self, key, value):
        """Append ``value`` to the series stored under ``key``."""
        self.state_log[key].append(value)

    # NOTE: the parameter name shadows the builtin ``dict``; it is kept so that
    # existing keyword callers continue to work.
    def log_states(self, dict):
        """Append every ``key: value`` pair of ``dict`` to its state series."""
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Record the reward terms of ``dict`` for ``num_episodes`` episodes.

        Only entries whose key contains ``'rew'`` are stored. Each value must
        provide ``.item()``; it is multiplied by ``num_episodes`` so that the
        later average is weighted by episode count.
        """
        for key, value in dict.items():
            if 'rew' in key:
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Drop all logged states and rewards and restart the episode count."""
        self.state_log.clear()
        self.rew_log.clear()
        # The counter is the denominator in print_rewards(); leaving it stale
        # after clearing rew_log would skew every average computed afterwards.
        self.num_episodes = 0

    def plot_states(self):
        """Start a separate process that renders the logged states."""
        self.plot_process = Process(target=self._plot)
        self.plot_process.start()

    def _plot(self):
        """Draw a 3x3 grid of figures from ``state_log`` and show it.

        Series that were never logged are skipped. The time axis is derived
        from the length of the first logged series and ``dt``.
        """
        _, axs = plt.subplots(3, 3)
        log = self.state_log
        # Empty log -> empty time axis (nothing will be plotted against it).
        first_series = next(iter(log.values()), [])
        time = np.linspace(0, len(first_series) * self.dt, len(first_series))

        # (axis, [(series key, legend label), ...], y label, title)
        time_panels = [
            (axs[1, 0], [("dof_pos", 'measured'), ("dof_pos_target", 'target')],
             'Position [rad]', 'DOF Position'),
            (axs[1, 1], [("dof_vel", 'measured'), ("dof_vel_target", 'target')],
             'Velocity [rad/s]', 'Joint Velocity'),
            (axs[0, 0], [("base_vel_x", 'measured'), ("command_x", 'commanded')],
             'base lin vel [m/s]', 'Base velocity x'),
            (axs[0, 1], [("base_vel_y", 'measured'), ("command_y", 'commanded')],
             'base lin vel [m/s]', 'Base velocity y'),
            (axs[0, 2], [("base_vel_yaw", 'measured'), ("command_yaw", 'commanded')],
             'base ang vel [rad/s]', 'Base velocity yaw'),
            (axs[1, 2], [("base_vel_z", 'measured')],
             'base lin vel [m/s]', 'Base velocity z'),
            (axs[2, 2], [("dof_torque", 'measured')],
             'Joint Torque [Nm]', 'Torque'),
        ]
        for ax, series, ylabel, title in time_panels:
            for key, label in series:
                # .get() avoids inserting empty lists into the defaultdict.
                values = log.get(key)
                if values:
                    ax.plot(time, values, label=label)
            ax.set(xlabel='time [s]', ylabel=ylabel, title=title)
            ax.legend()

        # plot contact forces: one line per column of the stacked samples
        a = axs[2, 0]
        contact_forces = log.get("contact_forces_z")
        if contact_forces:
            forces = np.array(contact_forces)
            for i in range(forces.shape[1]):
                a.plot(time, forces[:, i], label=f'force {i}')
        a.set(xlabel='time [s]', ylabel='Forces z [N]', title='Vertical Contact forces')
        a.legend()

        # plot torque/vel curves
        a = axs[2, 1]
        dof_vel = log.get("dof_vel")
        dof_torque = log.get("dof_torque")
        if dof_vel and dof_torque:
            a.plot(dof_vel, dof_torque, 'x', label='measured')
        a.set(xlabel='Joint vel [rad/s]', ylabel='Joint Torque [Nm]', title='Torque/velocity curves')
        a.legend()

        plt.show()

    def print_rewards(self):
        """Print the episode-weighted mean of every logged reward term."""
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            mean = np.sum(np.array(values)) / self.num_episodes
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        # getattr guards against a partially constructed instance.
        plot_process = getattr(self, "plot_process", None)
        if plot_process is not None:
            plot_process.kill()
152 |
--------------------------------------------------------------------------------
/legged_gym/utils/math.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import torch
32 | from torch import Tensor
33 | import numpy as np
34 | from isaacgym.torch_utils import quat_apply, normalize
35 | from typing import Tuple
36 |
37 |
38 | # @ torch.jit.script
def quat_apply_yaw(quat, vec):
    """Apply a reduced copy of ``quat`` to ``vec``.

    ``quat`` is cloned and viewed as rows of four components; the first two
    components of every row are zeroed and the rows renormalised before
    ``quat_apply`` is called. The input ``quat`` is not modified.
    """
    # presumably (x, y, z, w) ordering, so only the yaw part survives — TODO confirm
    yaw_only = quat.clone().view(-1, 4)
    yaw_only[:, :2] = 0.
    return quat_apply(normalize(yaw_only), vec)
44 |
45 |
46 | # @ torch.jit.script
def wrap_to_pi(angles):
    """Wrap ``angles`` into the interval (-pi, pi], in place.

    The argument itself is modified and also returned.
    """
    full_turn = 2 * np.pi
    # First fold into [0, 2*pi) ...
    angles %= full_turn
    # ... then shift everything above pi down by one full turn.
    above_half_turn = angles > np.pi
    angles -= full_turn * above_half_turn
    return angles
51 |
52 |
53 | # @ torch.jit.script
def torch_rand_sqrt_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Draw random values in ``[lower, upper]`` with a signed-sqrt warp.

    A uniform sample in [-1, 1) is replaced by ``sign(r) * sqrt(|r|)``,
    rescaled to [0, 1] and then mapped linearly onto ``[lower, upper]``.
    """
    uniform = 2 * torch.rand(*shape, device=device) - 1
    magnitude = torch.sqrt(torch.abs(uniform))
    warped = torch.where(uniform < 0., -magnitude, magnitude)
    unit = (warped + 1.) / 2.
    return (upper - lower) * unit + lower
60 |
--------------------------------------------------------------------------------
/licenses/assets/ANYmal_b_license.txt:
--------------------------------------------------------------------------------
1 | Copyright 2019 ANYbotics, https://www.anybotics.com
2 |
3 | Redistribution and use in source and binary forms, with or without modification,
4 | are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this
7 | list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the documentation and/or
11 | other materials provided with the distribution.
12 |
13 | 3. The name of ANYbotics and ANYmal may not be used to endorse or promote products
14 | derived from this software without specific prior written permission.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
20 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25 | POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/licenses/assets/ANYmal_c_license.txt:
--------------------------------------------------------------------------------
1 | Copyright 2020, ANYbotics AG.
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions
5 | are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright
8 | notice, this list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in
12 | the documentation and/or other materials provided with the
13 | distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived
17 | from this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/licenses/assets/cassie_license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jenna Reher, jreher@caltech.edu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/licenses/dependencies/matplotlib_license.txt:
--------------------------------------------------------------------------------
1 | 1. This LICENSE AGREEMENT is between the Matplotlib Development Team ("MDT"), and the Individual or Organization ("Licensee") accessing and otherwise using matplotlib software in source or binary form and its associated documentation.
2 |
3 | 2. Subject to the terms and conditions of this License Agreement, MDT hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use matplotlib 3.4.3 alone or in any derivative version, provided, however, that MDT's License Agreement and MDT's notice of copyright, i.e., "Copyright (c) 2012-2013 Matplotlib Development Team; All Rights Reserved" are retained in matplotlib 3.4.3 alone or in any derivative version prepared by Licensee.
4 |
5 | 3. In the event Licensee prepares a derivative work that is based on or incorporates matplotlib 3.4.3 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to matplotlib 3.4.3.
6 |
7 | 4. MDT is making matplotlib 3.4.3 available to Licensee on an "AS IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 3.4.3 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
8 |
9 | 5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 3.4.3 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING MATPLOTLIB 3.4.3, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
10 |
11 | 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions.
12 |
13 | 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between MDT and Licensee. This License Agreement does not grant permission to use MDT trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party.
14 |
15 | 8. By copying, installing or otherwise using matplotlib 3.4.3, Licensee agrees to be bound by the terms and conditions of this License Agreement.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.3.0
2 | addict==2.4.0
3 | aiohttp==3.8.3
4 | aiosignal==1.2.0
5 | appdirs==1.4.4
6 | asttokens==2.4.1
7 | async-timeout==4.0.2
8 | asyncio==3.4.3
9 | asynctest==0.13.0
10 | attrs==22.1.0
11 | backcall==0.2.0
12 | cachetools==5.2.0
13 | ccimport==0.4.2
14 | certifi==2024.2.2
15 | charset-normalizer==2.1.1
16 | click==8.1.3
17 | cloudpickle==2.2.0
18 | comm==0.2.2
19 | commonmark==0.9.1
20 | ConfigArgParse==1.5.3
21 | cumm==0.4.11
22 | cumm-cu113==0.4.11
23 | cycler==0.11.0
24 | Cython==0.29.33
25 | dash==2.17.0
26 | dash-core-components==2.0.0
27 | dash-html-components==2.0.0
28 | dash-table==5.0.0
29 | decorator==5.1.1
30 | disutils==1.4.32.post2
31 | docker-pycreds==0.4.0
32 | einops==0.8.1
33 | -e git+https://github.com/OpenRobotLab/EmbodiedScan.git@23b97029a9954d81575f28f79921c59a14495211#egg=embodiedscan
34 | executing==2.0.1
35 | fastjsonschema==2.19.1
36 | filelock==3.14.0
37 | fire==0.6.0
38 | Flask==2.2.2
39 | fonttools==4.38.0
40 | frozenlist==1.3.1
41 | fsspec==2024.5.0
42 | fvcore==0.1.5.post20221221
43 | gitdb==4.0.9
44 | GitPython==3.1.43
45 | google-auth==2.14.0
46 | google-auth-oauthlib==0.4.6
47 | grpcio==1.50.0
48 | gym==0.21.0
49 | gym-notices==0.0.8
50 | h5py==3.11.0
51 | huggingface-hub==0.23.0
52 | idna==3.4
53 | imageio==2.22.3
54 | imageio-ffmpeg==0.4.7
55 | importlib-metadata==4.13.0
56 | importlib_resources==6.4.0
57 | iopath==0.1.10
58 | ipython==8.12.3
59 | ipywidgets==8.1.2
60 | itsdangerous==2.1.2
61 | jedi==0.19.1
62 | Jinja2==3.1.2
63 | joblib==1.4.2
64 | jsonschema==4.17.3
65 | jupyter_core==5.7.2
66 | jupyterlab_widgets==3.0.10
67 | kiwisolver==1.4.4
68 | kornia==0.6.8
69 | lark==1.1.9
70 | lpips==0.1.4
71 | Mako==1.3.5
72 | Markdown==3.4.1
73 | MarkupSafe==2.1.1
74 | matplotlib==3.5.3
75 | matplotlib-inline==0.1.7
76 | MinkowskiEngine==0.5.4
77 | mmcv==2.0.0rc4
78 | mmdet==3.3.0
79 | mmengine==0.10.4
80 | mpmath==1.3.0
81 | multidict==6.0.2
82 | nbformat==5.5.0
83 | nest-asyncio==1.6.0
84 | networkx==2.6.3
85 | ninja==1.11.1
86 | numpy==1.23.5
87 | nvidia-cublas-cu12==12.1.3.1
88 | nvidia-cuda-cupti-cu12==12.1.105
89 | nvidia-cuda-nvrtc-cu12==12.1.105
90 | nvidia-cuda-runtime-cu12==12.1.105
91 | nvidia-cudnn-cu12==8.9.2.26
92 | nvidia-cufft-cu12==11.0.2.54
93 | nvidia-curand-cu12==10.3.2.106
94 | nvidia-cusolver-cu12==11.4.5.107
95 | nvidia-cusparse-cu12==12.1.0.106
96 | nvidia-ml-py==12.535.161
97 | nvidia-nccl-cu12==2.20.5
98 | nvidia-nvjitlink-cu12==12.4.127
99 | nvidia-nvtx-cu12==12.1.105
100 | nvitop==1.3.2
101 | oauthlib==3.2.2
102 | open3d==0.16.0
103 | opencv-python==4.9.0.80
104 | packaging==21.3
105 | pandas==1.3.5
106 | parso==0.8.4
107 | pathfinding==1.0.14
108 | pathtools==0.1.2
109 | pccm==0.4.11
110 | pexpect==4.9.0
111 | pickleshare==0.7.5
112 | Pillow==9.3.0
113 | pkgutil_resolve_name==1.3.10
114 | platformdirs==4.2.2
115 | plotly==5.22.0
116 | plyfile==0.7.4
117 | portalocker==2.8.2
118 | promise==2.3
119 | prompt-toolkit==3.0.43
120 | protobuf==3.19.6
121 | psutil==5.9.4
122 | ptyprocess==0.7.0
123 | pure-eval==0.2.2
124 | pyasn1==0.4.8
125 | pyasn1-modules==0.2.8
126 | pybind11==2.10.1
127 | pycocotools==2.0.7
128 | # pycuda==2024.1
129 | pycuda
130 | Pygments==2.13.0
131 | pyparsing==3.0.9
132 | pyquaternion==0.9.9
133 | pyrsistent==0.20.0
134 | python-dateutil==2.8.2
135 | pytools==2024.1.2
136 | -e git+https://github.com/facebookresearch/pytorch3d.git@17117106e4dd8269c02271462539c9c4a5d0d5ec#egg=pytorch3d
137 | pytz==2022.6
138 | PyWavelets==1.3.0
139 | PyYAML==6.0
140 | regex==2024.5.15
141 | requests==2.28.1
142 | requests-oauthlib==1.3.1
143 | retrying==1.3.4
144 | rich==12.6.0
145 | rsa==4.9
146 | safetensors==0.4.3
147 | scikit-image==0.19.3
148 | scikit-learn==1.3.2
149 | scipy==1.8.0
150 | sentry-sdk==2.19.0
151 | setproctitle==1.3.2
152 | shapely==2.0.4
153 | shortuuid==1.0.11
154 | six==1.16.0
155 | smmap==5.0.0
156 | spconv==2.3.6
157 | spconv-cu113==2.3.6
158 | stack-data==0.6.3
159 | sympy==1.12
160 | tabulate==0.9.0
161 | tenacity==8.3.0
162 | tensorboard==2.10.1
163 | tensorboard-data-server==0.6.1
164 | tensorboard-plugin-wit==1.8.1
165 | termcolor==2.4.0
166 | terminaltables==3.1.10
167 | threadpoolctl==3.5.0
168 | tifffile==2021.11.2
169 | tokenizers==0.19.1
170 | torch==1.11.0+cu113
171 | torchaudio==0.11.0+cu113
172 | torchvision==0.12.0+cu113
173 | tqdm==4.64.1
174 | traitlets==5.14.3
175 | transformers==4.40.2
176 | triton==2.3.0
177 | typing_extensions==4.4.0
178 | urllib3==1.26.12
179 | wandb==0.18.7
180 | wcwidth==0.2.13
181 | Werkzeug==2.2.2
182 | widgetsnbextension==4.0.10
183 | yacs==0.1.8
184 | yapf==0.30.0
185 | yarl==1.8.1
186 | zipp==3.10.0
187 |
--------------------------------------------------------------------------------
/rsl_rl/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
--------------------------------------------------------------------------------
/rsl_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .ppo import PPO
32 |
--------------------------------------------------------------------------------
/rsl_rl/env/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .vec_env import VecEnv
32 |
--------------------------------------------------------------------------------
/rsl_rl/env/vec_env.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from abc import ABC, abstractmethod
32 | import torch
33 | from typing import Tuple, Union
34 | from gym.spaces import Box
35 |
36 |
37 | # minimal interface of the environment
class VecEnv(ABC):
    """Abstract description of a batched (vectorized) environment.

    Only attribute types and method signatures are declared here; a concrete
    environment is expected to provide the buffers and implement every
    abstract method.
    """

    # Scalar sizes of the environment batch and its spaces.
    num_envs: int
    num_obs: int
    num_privileged_obs: int
    num_actions: int
    max_episode_length: int
    # Tensor buffers owned by the environment.
    privileged_obs_buf: torch.Tensor
    obs_buf: torch.Tensor
    rew_buf: torch.Tensor
    reset_buf: torch.Tensor
    episode_length_buf: torch.Tensor  # current episode duration
    # Miscellaneous state.
    extras: dict
    device: torch.device
    observation_space: Box
    action_space: Box

    @abstractmethod
    def step(
        self,
        actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        """Advance the environments with ``actions``.

        Returns a 5-tuple; presumably (observations, privileged observations
        or None, rewards, resets, extras) -- confirm against implementations.
        """

    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        """Reset the environments selected by ``env_ids``."""

    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        """Return the observation tensor."""

    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        """Return the privileged observation tensor, or None."""
70 |
--------------------------------------------------------------------------------
/rsl_rl/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .actor_critic import ActorCritic
32 | from .actor_critic_recurrent import ActorCriticRecurrent
33 |
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 |
33 | import torch
34 | import torch.nn as nn
35 | from torch.distributions import Normal
36 | from torch.nn.modules import rnn
37 |
38 |
class ActorCritic(nn.Module):
    """Feed-forward actor-critic with an independent Gaussian per action.

    The actor MLP outputs the action mean; a learnable, state-independent
    standard deviation (``self.std``) supplies the spread. The critic MLP
    outputs a single value per input row.

    Args:
        num_actor_obs: Input width of the actor MLP.
        num_critic_obs: Input width of the critic MLP.
        num_actions: Output width of the actor MLP.
        actor_hidden_dims: Hidden layer widths of the actor. The default list
            is never mutated here.
        critic_hidden_dims: Hidden layer widths of the critic. The default
            list is never mutated here.
        activation: Name understood by ``get_activation``.
        init_noise_std: Initial value of every entry of ``self.std``.
        **kwargs: Ignored; their names are printed as a warning.

    Note:
        Constructing an instance turns off argument validation for *all*
        ``torch.distributions`` objects created afterwards in this process.
    """

    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation='elu',
        init_noise_std=1.0,
        **kwargs
    ):
        if kwargs:
            print(
                "ActorCritic.__init__ got unexpected arguments, which will be ignored: " +
                str(list(kwargs))
            )
        super().__init__()

        # One activation module instance is shared by every layer of both MLPs.
        activation = get_activation(activation)

        # Policy
        self.actor = self._build_mlp(num_actor_obs, actor_hidden_dims, num_actions, activation)

        # Value function
        self.critic = self._build_mlp(num_critic_obs, critic_hidden_dims, 1, activation)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # Disable args validation for speedup. This must be a *call*: assigning
        # to the attribute would only overwrite the static method on Normal and
        # leave validation enabled.
        Normal.set_default_validate_args(False)

        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    def _build_mlp(input_dim, hidden_dims, output_dim, activation):
        """Stack Linear layers with ``activation`` between them (none after the last).

        With an empty ``hidden_dims`` the result is a single Linear layer.
        """
        sizes = [input_dim, *hidden_dims, output_dim]
        last_idx = len(sizes) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx != last_idx:
                layers.append(activation)
        return nn.Sequential(*layers)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        """Orthogonally initialise each Linear layer of ``sequential``.

        ``scales[i]`` is used as the gain of the i-th Linear layer.
        """
        linear_layers = (mod for mod in sequential if isinstance(mod, nn.Linear))
        for idx, module in enumerate(linear_layers):
            torch.nn.init.orthogonal_(module.weight, gain=scales[idx])

    def reset(self, dones=None):
        """No-op: the feed-forward model keeps no per-episode state."""
        pass

    def forward(self):
        """Not supported; use ``act``, ``act_inference`` or ``evaluate``."""
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the distribution set by the last ``update_distribution``."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Standard deviation of the current action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Entropy of the current distribution, summed over the last dim."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """Rebuild ``self.distribution`` from the actor output for ``observations``."""
        mean = self.actor(observations)
        # ``mean * 0. + self.std`` broadcasts the std vector to mean's shape.
        self.distribution = Normal(mean, mean * 0. + self.std)

    def act(self, observations, **kwargs):
        """Update the distribution and return one sampled action per row."""
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        """Log-probability of ``actions`` under the current distribution, summed over the last dim."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        """Return the actor output (the action mean) without sampling."""
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        """Return the critic output for ``critic_observations``."""
        value = self.critic(critic_observations)
        return value
146 |
147 |
def get_activation(act_name):
    """Return a new activation module for ``act_name``.

    Recognised names are "elu", "selu", "relu", "crelu" (which maps to a plain
    ReLU), "lrelu", "tanh" and "sigmoid". For any other value a message is
    printed and None is returned.
    """
    # Ordered (name, module class) pairs, compared with ``==`` one by one.
    known_activations = (
        ("elu", nn.ELU),
        ("selu", nn.SELU),
        ("relu", nn.ReLU),
        ("crelu", nn.ReLU),
        ("lrelu", nn.LeakyReLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
    )
    for name, module_cls in known_activations:
        if act_name == name:
            return module_cls()

    print("invalid activation function!")
    return None
166 |
--------------------------------------------------------------------------------
/rsl_rl/modules/actor_critic_recurrent.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import numpy as np
32 |
33 | import torch
34 | import torch.nn as nn
35 | from torch.distributions import Normal
36 | from torch.nn.modules import rnn
37 | from .actor_critic import ActorCritic, get_activation
38 | from rsl_rl.utils import unpad_trajectories
39 |
40 |
class ActorCriticRecurrent(ActorCritic):
    """Actor-critic whose MLPs are fed by separate recurrent memories.

    Observations pass through ``memory_a`` (actor) or ``memory_c`` (critic);
    the recurrent output, of width ``rnn_hidden_size``, is then handed to the
    MLPs built by the base class.
    """

    is_recurrent = True

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation='elu',
        rnn_type='lstm',
        rnn_hidden_size=256,
        rnn_num_layers=1,
        init_noise_std=1.0,
        **kwargs
    ):
        if kwargs:
            ignored_names = str(kwargs.keys())
            print(
                "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + ignored_names,
            )

        # The MLPs consume the RNN output, so both input widths are the RNN hidden size.
        super().__init__(
            rnn_hidden_size,
            rnn_hidden_size,
            num_actions,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            activation=activation,
            init_noise_std=init_noise_std,
        )

        # The resolved module is not used below.
        get_activation(activation)

        rnn_options = dict(type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_a = Memory(num_actor_obs, **rnn_options)
        self.memory_c = Memory(num_critic_obs, **rnn_options)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        """Reset both memories, forwarding ``dones`` to each."""
        for memory in (self.memory_a, self.memory_c):
            memory.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        """Run the actor memory, then sample an action from its squeezed output."""
        actor_features = self.memory_a(observations, masks, hidden_states).squeeze(0)
        return super().act(actor_features)

    def act_inference(self, observations):
        """Run the actor memory, then return the action mean without sampling."""
        actor_features = self.memory_a(observations).squeeze(0)
        return super().act_inference(actor_features)

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        """Run the critic memory, then return the critic output."""
        critic_features = self.memory_c(critic_observations, masks, hidden_states).squeeze(0)
        return super().evaluate(critic_features)

    def get_hidden_states(self):
        """Return the (actor, critic) pair of stored hidden states."""
        actor_hidden = self.memory_a.hidden_states
        critic_hidden = self.memory_c.hidden_states
        return actor_hidden, critic_hidden
99 |
100 |
class Memory(torch.nn.Module):
    """Recurrent layer that remembers its hidden state between calls.

    Args:
        input_size: Feature width of the input.
        type: 'gru' (case-insensitive) selects ``nn.GRU``; any other value
            selects ``nn.LSTM``.
        num_layers: Number of stacked recurrent layers.
        hidden_size: Width of the hidden state and of the output.
    """

    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        # Populated by the first forward call made without masks.
        self.hidden_states = None

    def forward(self, input, masks=None, hidden_states=None):
        """Run the RNN on ``input``.

        With ``masks`` given (policy update), ``hidden_states`` must be passed
        explicitly and the output is unpadded with ``masks``; the stored state
        is left untouched. Without ``masks`` (collection), a leading dimension
        is added to ``input`` and the stored state is used and updated.

        Raises:
            ValueError: If ``masks`` is given but ``hidden_states`` is None.
        """
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        """Zero the stored hidden state at the entries selected by ``dones``.

        Does nothing when no hidden state has been stored yet.
        """
        if self.hidden_states is None:
            # Nothing to clear before the first collection step.
            return
        # For an LSTM, self.hidden_states is a (hidden_state, cell_state) pair;
        # for a GRU it is a single tensor and the loop walks its first dimension.
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
126 |
--------------------------------------------------------------------------------
/rsl_rl/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .on_policy_runner import OnPolicyRunner
32 |
--------------------------------------------------------------------------------
/rsl_rl/storage/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from .rollout_storage import RolloutStorage
--------------------------------------------------------------------------------
/rsl_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | from .utils import split_and_pad_trajectories, unpad_trajectories
32 |
--------------------------------------------------------------------------------
/rsl_rl/utils/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | #
29 | # Copyright (c) 2021 ETH Zurich, Nikita Rudin
30 |
31 | import torch
32 |
33 |
def split_and_pad_trajectories(tensor, dones):
    """ Splits trajectories at done indices, stacks them and pads each one with zeros up to the
        full number of time steps (``tensor.shape[0]``), so the result always lines up with the masks.
        Returns masks corresponding to valid parts of the trajectories
        Example (6 time steps, 2 envs):
            Input: [ [a1, a2, a3, a4 | a5, a6],
                     [b1, b2 | b3, b4, b5 | b6]
                    ]

            Output:[ [a1, a2, a3, a4, 0, 0], | [ [True, True, True, True, False, False],
                     [a5, a6, 0, 0, 0, 0],   |   [True, True, False, False, False, False],
                     [b1, b2, 0, 0, 0, 0],   |   [True, True, False, False, False, False],
                     [b3, b4, b5, 0, 0, 0],  |   [True, True, True, False, False, False],
                     [b6, 0, 0, 0, 0, 0]     |   [True, False, False, False, False, False],
                    ]                        | ]

        (The example shows one trajectory per row; the returned tensors are time-major, i.e.
        indexed as [time, trajectory, ...].)

        Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
    """
    # Work on a copy: the last step is forced to "done" so every env's final trajectory is closed.
    dones = dones.clone()
    dones[-1] = 1
    # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Get length of trajectory by counting the number of successive not done elements
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()
    # Extract the individual trajectories
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)

    # pad_sequence only pads to the longest sequence it is given. Append one all-zero
    # trajectory of full rollout length so the padded length always equals the number of
    # time steps (the first dimension of the masks), then drop that filler again.
    num_steps = tensor.shape[0]
    filler = tensor.new_zeros((num_steps,) + tuple(tensor.shape[2:]))
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(list(trajectories) + [filler])
    padded_trajectories = padded_trajectories[:, :-1]

    trajectory_masks = trajectory_lengths > torch.arange(0, num_steps, device=tensor.device).unsqueeze(1)
    return padded_trajectories, trajectory_masks
66 |
67 |
def unpad_trajectories(trajectories, masks):
    """ Does the inverse operation of split_and_pad_trajectories(): keeps only the steps
        selected by ``masks`` and regroups them into sequences of length ``trajectories.shape[0]``.
    """
    padded_length = trajectories.shape[0]
    feature_dim = trajectories.shape[-1]
    # Need to transpose before and after the masking to have proper reshaping
    valid_steps = trajectories.transpose(1, 0)[masks.transpose(1, 0)]
    regrouped = valid_steps.view(-1, padded_length, feature_dim)
    return regrouped.transpose(1, 0)
74 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # setup.py
2 | from setuptools import find_packages, setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 |
# CUDA extension that is built into the importable module `bfs_cuda_2D`.
# NOTE(review): the repository listing appears to contain `gleam/utils/bfs_cuda_2D.h`
# rather than the `.cpp` file named below; confirm the source path exists before building.
pathfinding_2D_cuda = CUDAExtension(
    name='bfs_cuda_2D',
    sources=[
        'gleam/utils/bfs_cuda_2D.cpp',
        'gleam/utils/bfs_cuda_kernel_2D.cu',
    ],
)

# Define the setup configuration
# Original author: Nikita Rudin (legged_gym)
# The distribution is published under the name `legged_gym`; find_packages() picks up
# every package found from the directory containing this file.
setup(
    name='legged_gym',
    version='1.0.0',
    author='Xiao Chen',
    license="BSD-3-Clause",
    packages=find_packages(),
    author_email='cx123@ie.cuhk.edu.hk',
    description='GLEAM based on Isaac Gym environments',
    install_requires=[
        'isaacgym',
        'gym',
        'matplotlib',
        "tensorboard",
        "cloudpickle",
        "pandas",
        "yapf~=0.30.0",
        "wandb",
        "opencv-python>=3.0.0"
    ],
    # Compile the CUDA extension defined above as part of the build.
    ext_modules=[
        pathfinding_2D_cuda
    ],
    # BuildExtension is torch's build_ext command for C++/CUDA extensions.
    cmdclass={
        'build_ext': BuildExtension
    },
)
42 |
--------------------------------------------------------------------------------
/stable_baselines3/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from stable_baselines3.a2c import A2C
4 | from stable_baselines3.common.utils import get_system_info
5 | from stable_baselines3.ddpg import DDPG
6 | from stable_baselines3.dqn import DQN
7 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
8 | from stable_baselines3.ppo import PPO
9 | from stable_baselines3.sac import SAC
10 | from stable_baselines3.td3 import TD3
11 |
# Read version from file
# `version.txt` is looked up next to this module; its contents, stripped of
# surrounding whitespace, become the package's `__version__`.
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file) as file_handler:
    __version__ = file_handler.read().strip()
16 |
17 |
def HER(*args, **kwargs):
    """Placeholder for the removed ``HER`` class.

    Accepts any arguments and always raises, pointing the caller to ``HerReplayBuffer``.

    :raises ImportError: on every call
    """
    message = (
        "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
        "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
    )
    raise ImportError(message)
23 |
--------------------------------------------------------------------------------
/stable_baselines3/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.a2c.a2c import A2C
2 | from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/stable_baselines3/a2c/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for A2C
3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
4 |
# A2C has no policy classes of its own: the short policy names are plain aliases
# of the shared actor-critic policies imported above.
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
8 |
--------------------------------------------------------------------------------
/stable_baselines3/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/stable_baselines3/common/__init__.py
--------------------------------------------------------------------------------
/stable_baselines3/common/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
2 | from stable_baselines3.common.envs.identity_env import (
3 | FakeImageEnv,
4 | IdentityEnv,
5 | IdentityEnvBox,
6 | IdentityEnvMultiBinary,
7 | IdentityEnvMultiDiscrete,
8 | )
9 | from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
10 |
--------------------------------------------------------------------------------
/stable_baselines3/common/envs/identity_env.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | from gym import Env, Space
5 | from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
6 |
7 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
8 |
9 |
class IdentityEnv(Env):
    """Test environment that rewards the agent for repeating the current state as its action."""

    def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the action and observation dimension you want
            to learn. Provide at most one of ``dim`` and ``space``. If both are
            None, then initialization proceeds with ``dim=1`` and ``space=None``.
        :param space: the action and observation space. Provide at most one of
            ``dim`` and ``space``.
        :param ep_length: the length of each episode in timesteps
        """
        if space is not None:
            assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
        else:
            # No explicit space: fall back to a discrete space of size ``dim`` (1 by default).
            space = Discrete(1 if dim is None else dim)

        # The same space object serves as both action and observation space.
        self.action_space = self.observation_space = space
        self.ep_length = ep_length
        self.current_step = 0
        self.num_resets = -1  # Becomes 0 after __init__ exits.
        self.reset()

    def reset(self) -> GymObs:
        """Restart the episode, count the reset and return a freshly sampled state."""
        self.current_step = 0
        self.num_resets += 1
        self._choose_next_state()
        return self.state

    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
        """Score ``action`` against the current state, then sample the next state.

        :return: ``(state, reward, done, info)`` where ``done`` becomes true once
            ``ep_length`` steps have been taken and ``info`` is an empty dict
        """
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        episode_over = self.current_step >= self.ep_length
        return self.state, reward, episode_over, {}

    def _choose_next_state(self) -> None:
        """Draw the next state by sampling the action space."""
        self.state = self.action_space.sample()

    def _get_reward(self, action: Union[int, np.ndarray]) -> float:
        """Return 1.0 when every element of the action equals the state, else 0.0."""
        if np.all(self.state == action):
            return 1.0
        return 0.0

    def render(self, mode: str = "human") -> None:
        """Rendering is a no-op for this environment."""
        return None
56 |
57 |
class IdentityEnvBox(IdentityEnv):
    """Identity environment over a one-dimensional continuous ``Box`` space."""

    def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param low: the lower bound of the box dim
        :param high: the upper bound of the box dim
        :param eps: the epsilon bound for correct value
        :param ep_length: the length of each episode in timesteps
        """
        box_space = Box(low=low, high=high, shape=(1, ), dtype=np.float32)
        super().__init__(ep_length=ep_length, space=box_space)
        self.eps = eps

    def step(self, action: np.ndarray) -> GymStepReturn:
        """Same stepping logic as the parent class; the reward comes from this class's tolerance check."""
        return super().step(action)

    def _get_reward(self, action: np.ndarray) -> float:
        """Return 1.0 when the action lies within ``eps`` of the state, else 0.0."""
        lower_bound = self.state - self.eps
        upper_bound = self.state + self.eps
        if lower_bound <= action <= upper_bound:
            return 1.0
        return 0.0
81 |
82 |
class IdentityEnvMultiDiscrete(IdentityEnv):
    """Identity environment over a two-component ``MultiDiscrete`` space."""

    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        # Two discrete components, each with ``dim`` possible values.
        multi_discrete_space = MultiDiscrete([dim, dim])
        super().__init__(ep_length=ep_length, space=multi_discrete_space)
93 |
94 |
class IdentityEnvMultiBinary(IdentityEnv):
    """Identity environment over a ``MultiBinary`` space."""

    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        multi_binary_space = MultiBinary(dim)
        super().__init__(ep_length=ep_length, space=multi_binary_space)
105 |
106 |
class FakeImageEnv(Env):
    """
    Fake image environment for testing purposes, it mimics Atari games.

    Observations are random samples of a uint8 image space, every reward is 0.0
    and each episode lasts 10 steps.

    :param action_dim: Number of discrete actions
    :param screen_height: Height of the image
    :param screen_width: Width of the image
    :param n_channels: Number of color channels
    :param discrete: Create discrete action space instead of continuous
    :param channel_first: Put channels on first axis instead of last
    """
    def __init__(
        self,
        action_dim: int = 6,
        screen_height: int = 84,
        screen_width: int = 84,
        n_channels: int = 1,
        discrete: bool = True,
        channel_first: bool = False,
    ):
        if channel_first:
            self.observation_shape = (n_channels, screen_height, screen_width)
        else:
            self.observation_shape = (screen_height, screen_width, n_channels)
        self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)

        if discrete:
            self.action_space = Discrete(action_dim)
        else:
            # The continuous variant always has 5 action components, whatever ``action_dim`` is.
            self.action_space = Box(low=-1, high=1, shape=(5, ), dtype=np.float32)

        self.ep_length = 10
        self.current_step = 0

    def reset(self) -> np.ndarray:
        """Restart the step counter and return a random observation."""
        self.current_step = 0
        return self.observation_space.sample()

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        """Advance one step; the action is ignored and the reward is always 0.0."""
        self.current_step += 1
        episode_over = self.current_step >= self.ep_length
        return self.observation_space.sample(), 0.0, episode_over, {}

    def render(self, mode: str = "human") -> None:
        """Rendering is a no-op for this environment."""
        return None
150 |
--------------------------------------------------------------------------------
/stable_baselines3/common/envs/multi_input_envs.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from stable_baselines3.common.type_aliases import GymStepReturn
7 |
8 |
class SimpleMultiObsEnv(gym.Env):
    """
    Base class for GridWorld-based MultiObs Environments 4x4 grid world.

    .. code-block:: text

        ____________
       | 0  1  2   3|
       | 4|¯5¯¯6¯| 7|
       | 8|_9_10_|11|
       |12 13  14 15|
       ¯¯¯¯¯¯¯¯¯¯¯¯¯¯

    start is 0
    states 5, 6, 9, and 10 are blocked
    goal is 15
    actions are = [left, down, right, up]

    simple linear state env of 15 states but encoded with a vector and an image observation:
    each column is represented by a random vector and each row is
    represented by a random image, both sampled once at creation time.

    :param num_col: Number of columns in the grid
    :param num_row: Number of rows in the grid
    :param random_start: If true, agent starts in random position
    :param discrete_actions: If true, the action space is ``Discrete(4)``; otherwise it is a
        4-dimensional ``Box`` and the index of the largest component selects the action
    :param channel_last: If true, the image will be channel last, else it will be channel first
    """
    def __init__(
        self,
        num_col: int = 4,
        num_row: int = 4,
        random_start: bool = True,
        discrete_actions: bool = True,
        channel_last: bool = True,
    ):
        super().__init__()

        self.vector_size = 5
        self.img_size = [64, 64, 1] if channel_last else [1, 64, 64]

        self.random_start = random_start
        self.discrete_actions = discrete_actions
        if discrete_actions:
            self.action_space = gym.spaces.Discrete(4)
        else:
            self.action_space = gym.spaces.Box(0, 1, (4, ))

        vec_space = gym.spaces.Box(0, 1, (self.vector_size, ), dtype=np.float64)
        img_space = gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8)
        self.observation_space = gym.spaces.Dict(spaces={"vec": vec_space, "img": img_space})

        self.count = 0
        # Timeout
        self.max_count = 100
        self.log = ""
        self.state = 0
        self.action2str = ["left", "down", "right", "up"]
        self.init_possible_transitions()

        self.num_col = num_col
        self.state_mapping = []
        self.init_state_mapping(num_col, num_row)

        # The last entry of the mapping is the goal state.
        self.max_state = len(self.state_mapping) - 1

    def init_state_mapping(self, num_col: int, num_row: int) -> None:
        """
        Initializes the state_mapping array which holds the observation values for each state

        :param num_col: Number of columns.
        :param num_row: Number of rows.
        """
        # Each column is represented by a random vector
        col_vecs = np.random.random((num_col, self.vector_size))
        # Each row is represented by a random image
        row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)

        # Entries are appended with the column index as the outer loop and the row index as the inner one.
        self.state_mapping.extend(
            {"vec": col_vecs[col], "img": row_imgs[row].reshape(self.img_size)}
            for col in range(num_col)
            for row in range(num_row)
        )

    def get_state_mapping(self) -> Dict[str, np.ndarray]:
        """
        Uses the state to get the observation mapping.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        observation = self.state_mapping[self.state]
        return observation

    def init_possible_transitions(self) -> None:
        """
        Initializes the transitions of the environment
        The environment exploits the cardinal directions of the grid by noting that
        they correspond to simple addition and subtraction from the cell id within the grid

        - up => means moving up a row => means subtracting the length of a column
        - down => means moving down a row => means adding the length of a column
        - left => means moving left by one => means subtracting 1
        - right => means moving right by one => means adding 1

        Thus one only needs to specify in which states each action is possible
        in order to define the transitions of the environment
        """
        self.left_possible = [1, 2, 3, 13, 14, 15]
        self.down_possible = [0, 4, 8, 3, 7, 11]
        self.right_possible = [0, 1, 2, 12, 13, 14]
        self.up_possible = [4, 8, 12, 7, 11, 15]

    def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
        """
        Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).

        :param action: action index, or (for continuous actions) a vector whose argmax is the index
        :return: tuple (observation, reward, done, info).
        """
        if self.discrete_actions:
            action = int(action)
        else:
            action = np.argmax(action)

        self.count += 1
        prev_state = self.state

        # define state transition: (action code, states where it is allowed, change of cell id)
        transitions = (
            (0, self.left_possible, -1),  # left
            (1, self.down_possible, self.num_col),  # down
            (2, self.right_possible, 1),  # right
            (3, self.up_possible, -self.num_col),  # up
        )
        for code, allowed_states, delta in transitions:
            if self.state in allowed_states and action == code:
                self.state += delta
                break

        got_to_end = self.state == self.max_state
        # Every step costs -0.1 unless it reaches the goal, which pays 1.
        reward = 1 if got_to_end else -0.1
        done = self.count > self.max_count or got_to_end

        self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

        return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}

    def render(self, mode: str = "human") -> None:
        """
        Prints the log of the environment.

        :param mode:
        """
        print(self.log)

    def reset(self) -> Dict[str, np.ndarray]:
        """
        Resets the environment state and step count and returns reset observation.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        self.count = 0
        if self.random_start:
            # randint's upper bound is exclusive, so the goal state is never drawn.
            self.state = np.random.randint(0, self.max_state)
        else:
            self.state = 0
        return self.state_mapping[self.state]
180 |
--------------------------------------------------------------------------------
/stable_baselines3/common/evaluation_gleam.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import torch
3 | from typing import List, Tuple, Union
4 | from stable_baselines3.common import base_class
5 | from stable_baselines3.common.vec_env import VecEnv
6 |
7 |
def evaluate_policy_grid_obs(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    deterministic: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Roll out ``model`` in ``env`` in an endless predict/step loop with gradients disabled.

    The environment is reset once; after that each iteration asks the model for
    actions and feeds them to ``env.step``. Rewards, dones and infos returned by the
    environment are not used here.

    .. note::
        This function has no exit condition of its own and never reaches a ``return``
        statement: the loop only stops when ``env`` or ``model`` raises or ends the
        process. The termination condition is expected to be set in the evaluation
        environment.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment.
    :param deterministic: Whether to use deterministic or stochastic actions
    """

    with torch.no_grad():
        observations = env.reset()
        # set termination condition in eval_env
        while True:
            actions, _ = model.predict(observations, state=None, deterministic=deterministic)
            # Only the new observations are carried into the next iteration.
            observations, _rewards, _dones, _infos = env.step(actions)
43 |
44 |
--------------------------------------------------------------------------------
/stable_baselines3/common/noise.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from abc import ABC, abstractmethod
3 | from typing import Iterable, List, Optional
4 |
5 | import numpy as np
6 |
7 |
class ActionNoise(ABC):
    """
    Abstract base class for action noise: calling an instance returns a noise sample.
    """
    def __init__(self):
        super().__init__()

    def reset(self) -> None:
        """
        End-of-episode reset hook for the noise; does nothing by default.
        """
        return None

    @abstractmethod
    def __call__(self) -> np.ndarray:
        """Return a noise sample; concrete subclasses must implement this."""
        raise NotImplementedError()
24 |
25 |
class NormalActionNoise(ActionNoise):
    """
    Gaussian action noise drawn independently on every call.

    :param mean: the mean value of the noise
    :param sigma: the scale of the noise (std here)
    """
    def __init__(self, mean: np.ndarray, sigma: np.ndarray):
        self._mu = mean
        self._sigma = sigma
        super().__init__()

    def __call__(self) -> np.ndarray:
        sample = np.random.normal(loc=self._mu, scale=self._sigma)
        return sample

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
43 |
44 |
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    :param mean: the mean of the noise
    :param sigma: the scale of the noise
    :param theta: the rate of mean reversion
    :param dt: the timestep for the noise
    :param initial_noise: the initial value for the noise output, (if None: 0)
    """
    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ):
        self._theta = theta
        self._mu = mean
        self._sigma = sigma
        self._dt = dt
        self.initial_noise = initial_noise
        self.noise_prev = np.zeros_like(self._mu)
        self.reset()
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Pull towards the mean, scaled by the time step.
        drift = self._theta * (self._mu - self.noise_prev) * self._dt
        # Random perturbation, scaled by sigma and sqrt(dt).
        diffusion = self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
        noise = self.noise_prev + drift + diffusion
        # The new sample is the starting point of the next call.
        self.noise_prev = noise
        return noise

    def reset(self) -> None:
        """
        Put the process back at ``initial_noise``, or at zeros when no initial noise was given.
        """
        if self.initial_noise is not None:
            self.noise_prev = self.initial_noise
        else:
            self.noise_prev = np.zeros_like(self._mu)

    def __repr__(self) -> str:
        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
90 |
91 |
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    Holds one independent deep copy of ``base_noise`` per environment and stacks
    their samples on every call.

    :param base_noise: ActionNoise The noise generator to use
    :param n_envs: The number of parallel environments
    :raises ValueError: if ``n_envs`` cannot be converted to a positive integer
    """
    def __init__(self, base_noise: ActionNoise, n_envs: int):
        try:
            self.n_envs = int(n_envs)
            assert self.n_envs > 0
        except (TypeError, ValueError, AssertionError) as e:
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e

        # Both assignments go through the validating property setters below.
        self.base_noise = base_noise
        # Use the validated integer: the raw argument may be e.g. "3" or 3.0, which
        # int() accepts but range() rejects.
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(self.n_envs)]

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """
        Reset all the noise processes, or those listed in indices

        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object
        """
        noise = np.stack([noise() for noise in self.noises])
        return noise

    @property
    def base_noise(self) -> ActionNoise:
        """The noise generator that the per-environment copies are made from."""
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> List[ActionNoise]:
        """The per-environment noise objects, one per environment."""
        return self._noises

    @noises.setter
    def noises(self, noises: List[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

        # Every entry must be an instance of the base noise's type.
        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if len(different_types):
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise",
                type(self.base_noise)
            )

        self._noises = noises
        # Newly assigned noises always start from their initial position.
        for noise in noises:
            noise.reset()
164 |
--------------------------------------------------------------------------------
/stable_baselines3/common/results_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | # import matplotlib
7 | # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
8 | from matplotlib import pyplot as plt
9 |
10 | from stable_baselines3.common.monitor import load_results
11 |
12 | X_TIMESTEPS = "timesteps"
13 | X_EPISODES = "episodes"
14 | X_WALLTIME = "walltime_hrs"
15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
16 | EPISODES_WINDOW = 100
17 |
18 |
def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
    """
    Build a strided rolling-window view over the last axis of an array.

    The result gains one trailing axis of length ``window``; the previous last
    axis shrinks to ``array.shape[-1] - window + 1`` window positions.

    :param array: the input array
    :param window: length of the rolling window
    :return: rolling window view on the input array (created with ``as_strided``)
    """
    last_axis_len = array.shape[-1]
    view_shape = array.shape[:-1] + (last_axis_len - window + 1, window)
    # Stepping inside a window uses the same stride as stepping between windows
    last_axis_stride = array.strides[-1]
    view_strides = array.strides + (last_axis_stride, )
    return np.lib.stride_tricks.as_strided(array, shape=view_shape, strides=view_strides)
30 |
31 |
def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reduce ``var_2`` over a rolling window and align ``var_1`` with the result.

    :param var_1: variable 1; its first ``window - 1`` entries are dropped
    :param var_2: variable 2; ``func`` is applied to each of its rolling windows
    :param window: length of the rolling window
    :param func: reduction called as ``func(windows, axis=-1)`` (such as np.mean)
    :return: the trimmed ``var_1`` and the reduced ``var_2``
    """
    windowed_var_2 = rolling_window(var_2, window)
    reduced = func(windowed_var_2, axis=-1)
    trimmed_var_1 = var_1[window - 1:]
    return trimmed_var_1, reduced
45 |
46 |
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split a results data frame into x and y arrays for plotting.

    The y values are always the ``r`` column; the x values depend on ``x_axis``:
    cumulative sum of the ``l`` column, the row index, or the ``t`` column
    divided by 3600.

    :param data_frame: the input data
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: the x and y output
    :raises NotImplementedError: if ``x_axis`` is none of the supported values
    """
    if x_axis == X_TIMESTEPS:
        x_var = np.cumsum(data_frame.l.values)
    elif x_axis == X_EPISODES:
        x_var = np.arange(len(data_frame))
    elif x_axis == X_WALLTIME:
        # Convert to hours
        x_var = data_frame.t.values / 3600.0
    else:
        raise NotImplementedError
    # The y values are the same for every supported x axis
    y_var = data_frame.r.values
    return x_var, y_var
69 |
70 |
def plot_curves(
    xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
    """
    Scatter-plot each (x, y) series and overlay its rolling mean.

    Series that are empty (for instance after filtering by timesteps) no longer
    raise ``IndexError`` when the x-axis limit is computed: they are ignored
    for the limit, and the limit falls back to 0 if every series is empty or
    ``xy_list`` itself is empty.

    :param xy_list: the x and y coordinates to plot
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param title: the title of the plot
    :param figsize: Size of the figure (width, height)
    """

    plt.figure(title, figsize=figsize)
    # Right x-limit: largest final x value among the non-empty series
    max_x = max((x[-1] for x, _ in xy_list if len(x) > 0), default=0)
    min_x = 0
    for x, y in xy_list:
        plt.scatter(x, y, s=2)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x.shape[0] >= EPISODES_WINDOW:
            # Compute and plot rolling mean with window of size EPISODES_WINDOW
            x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
            plt.plot(x, y_mean)
    plt.xlim(min_x, max_x)
    plt.title(title)
    plt.xlabel(x_axis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()
99 |
100 |
def plot_results(
    dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
    """
    Load the results found in each directory and plot them as curves.

    :param dirs: the save location of the results to plot
    :param num_timesteps: keep only the rows whose cumulative ``l`` column is at
        most this value; ``None`` keeps every row
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param task_name: the title of the task to plot
    :param figsize: Size of the figure (width, height)
    """

    # Phase 1: load (and optionally truncate) every data frame
    data_frames = []
    for folder in dirs:
        loaded = load_results(folder)
        if num_timesteps is None:
            data_frames.append(loaded)
            continue
        within_budget = loaded.l.cumsum() <= num_timesteps
        data_frames.append(loaded[within_budget])

    # Phase 2: convert each frame to (x, y) arrays, then plot them together
    xy_list = []
    for frame in data_frames:
        xy_list.append(ts2xy(frame, x_axis))
    plot_curves(xy_list, x_axis, task_name, figsize)
123 |
--------------------------------------------------------------------------------
/stable_baselines3/common/running_mean_std.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import numpy as np
4 |
5 |
class RunningMeanStd:
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        Track the running mean and variance of a data stream
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

        :param epsilon: initial sample count; helps with arithmetic issues
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def copy(self) -> "RunningMeanStd":
        """
        :return: a new object holding copies of the current mean, variance and count.
        """
        clone = RunningMeanStd(shape=self.mean.shape)
        clone.mean, clone.var = self.mean.copy(), self.var.copy()
        clone.count = float(self.count)
        return clone

    def combine(self, other: "RunningMeanStd") -> None:
        """
        Merge the statistics of another ``RunningMeanStd`` object into this one.

        :param other: The other object to combine with.
        """
        self.update_from_moments(other.mean, other.var, other.count)

    def update(self, arr: np.ndarray) -> None:
        """
        Fold a batch of samples into the running statistics.

        :param arr: batch of samples; statistics are taken along axis 0
        """
        self.update_from_moments(np.mean(arr, axis=0), np.var(arr, axis=0), arr.shape[0])

    def update_from_moments(
        self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]
    ) -> None:
        """
        Merge precomputed batch moments into the running statistics.

        :param batch_mean: mean of the batch
        :param batch_var: variance of the batch
        :param batch_count: number of samples the batch moments were computed from
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        merged_mean = self.mean + delta * batch_count / tot_count
        # Sum of the two weighted variances plus a correction for the mean shift
        weighted_old = self.var * self.count
        weighted_new = batch_var * batch_count
        m_2 = weighted_old + weighted_new + np.square(delta) * self.count * batch_count / tot_count

        self.mean = merged_mean
        self.var = m_2 / tot_count
        self.count = tot_count
60 |
--------------------------------------------------------------------------------
/stable_baselines3/common/sb2_compat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/stable_baselines3/common/sb2_compat/__init__.py
--------------------------------------------------------------------------------
/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Iterable, Optional
2 |
3 | import torch
4 | from torch.optim import Optimizer
5 |
6 |
class RMSpropTFLike(Optimizer):
    r"""Implements RMSprop algorithm with closer match to Tensorflow version.

    For reproducibility with original stable-baselines. Use this
    version with e.g. A2C for stabler learning than with the PyTorch
    RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.

    See a more thorough conversion in pytorch-image-models repository:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py

    Changes to the original RMSprop:
        - Move epsilon inside square root
        - Initialize squared gradient to ones rather than zeros

    Proposed by G. Hinton in his
    `course <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    The implementation here adds epsilon to the gradient average before
    taking the square root (the PyTorch implementation it is based on adds
    epsilon after the square root). The effective
    learning rate is thus :math:`\alpha/\sqrt{v + \epsilon}` where :math:`\alpha`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups
    :param lr: learning rate (default: 1e-2)
    :param momentum: momentum factor (default: 0)
    :param alpha: smoothing constant (default: 0.99)
    :param eps: term added to the denominator to improve
        numerical stability (default: 1e-8)
    :param centered: if ``True``, compute the centered RMSProp,
        the gradient is normalized by an estimation of its variance
    :param weight_decay: weight decay (L2 penalty) (default: 0)

    """
    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = False,
    ):
        # Every numeric hyper-parameter must be non-negative
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= alpha:
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Per-group defaults handed to the base Optimizer
        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the optimizer state, filling in ``momentum`` and ``centered`` when a group lacks them."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]:
        """Performs a single optimization step.

        :param closure: A closure that reevaluates the model
            and returns the loss.
        :return: loss (``None`` when no closure is given)
        :raises RuntimeError: if a parameter has a sparse gradient
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so gradients are re-enabled for the closure
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                # Parameters without a gradient are left untouched
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("RMSpropTF does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # PyTorch initialized to zeros here
                    state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                square_avg = state["square_avg"]
                alpha = group["alpha"]

                state["step"] += 1

                # L2 penalty: out-of-place add, so p.grad itself is not modified
                if group["weight_decay"] != 0:
                    grad = grad.add(p, alpha=group["weight_decay"])

                # In-place moving average of the squared gradient
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    # In-place moving average of the gradient
                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                    # PyTorch added epsilon after square root
                    # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
                else:
                    # PyTorch added epsilon after square root
                    # avg = square_avg.sqrt().add_(group['eps'])
                    avg = square_avg.add(group["eps"]).sqrt_()

                if group["momentum"] > 0:
                    buf = state["momentum_buffer"]
                    # The buffer accumulates the normalized gradient; the learning rate is applied afterwards
                    buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                    p.add_(buf, alpha=-group["lr"])
                else:
                    p.addcdiv_(grad, avg, value=-group["lr"])

        return loss
136 |
--------------------------------------------------------------------------------
/stable_baselines3/common/type_aliases.py:
--------------------------------------------------------------------------------
1 | """Common aliases for type hints"""
2 |
3 | from enum import Enum
4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
5 |
6 | import gym
7 | import numpy as np
8 | import torch as th
9 |
10 | from stable_baselines3.common import callbacks, vec_env
11 |
# Either a plain gym environment or a vectorized environment
GymEnv = Union[gym.Env, vec_env.VecEnv]
# An observation in any of the forms listed in the Union
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
# 4-tuple typed (GymObs, float, bool, Dict); presumably (obs, reward, done, info) - confirm against env.step
GymStepReturn = Tuple[GymObs, float, bool, Dict]
# Mapping from str or int keys to torch tensors
TensorDict = Dict[Union[str, int], th.Tensor]
# String-keyed dict of arbitrary values
OptimizerStateDict = Dict[str, Any]
# None, a callable, a single BaseCallback, or a list of BaseCallback
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]

# A schedule takes the remaining progress as input
# and outputs a scalar (e.g. learning rate, clip range, ...)
Schedule = Callable[[float], float]
22 |
23 |
class RolloutBufferSamples(NamedTuple):
    """Named tuple of six tensor fields; the declaration order below is the tuple order."""
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
31 |
32 |
class DictRolloutBufferSamples(RolloutBufferSamples):
    """Variant of ``RolloutBufferSamples`` whose ``observations`` field is annotated as a ``TensorDict``."""
    # Only the annotation of ``observations`` differs from the parent class;
    # the remaining annotations repeat the parent's.
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
40 |
41 |
class ReplayBufferSamples(NamedTuple):
    """Named tuple of five tensor fields; the declaration order below is the tuple order."""
    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
48 |
49 |
class DictReplayBufferSamples(ReplayBufferSamples):
    """Variant of ``ReplayBufferSamples`` whose ``observations`` and ``next_observations`` are annotated differently."""
    # ``observations`` is annotated as a TensorDict here; ``next_observations``
    # keeps the plain tensor annotation.
    # NOTE(review): next_observations is presumably also a TensorDict for dict observations - confirm.
    observations: TensorDict
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
56 |
57 |
class RolloutReturn(NamedTuple):
    """Named tuple of two integer counters and one boolean flag; the declaration order is the tuple order."""
    episode_timesteps: int
    n_episodes: int
    continue_training: bool
62 |
63 |
class TrainFrequencyUnit(Enum):
    """Closed set of units for a training frequency; the member values are the strings "step" and "episode"."""
    STEP = "step"
    EPISODE = "episode"
67 |
68 |
class TrainFreq(NamedTuple):
    """Named tuple pairing an integer frequency with the unit it is expressed in."""
    frequency: int
    unit: TrainFrequencyUnit  # either "step" or "episode"
72 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa F401
2 | import typing
3 | from copy import deepcopy
4 | from typing import Optional, Type, Union
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
11 | from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
12 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
13 | from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
14 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
15 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
16 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
17 |
18 | # Avoid circular import
19 | if typing.TYPE_CHECKING:
20 | from stable_baselines3.common.type_aliases import GymEnv
21 |
22 |
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
    """
    Walk down the chain of ``VecEnvWrapper`` objects and return the first match.

    The search follows the ``venv`` attribute of each wrapper and stops at the
    first object that is not a ``VecEnvWrapper``.

    :param env: the (possibly wrapped) environment to search from
    :param vec_wrapper_class: the wrapper class to look for
    :return: the first wrapper that is an instance of ``vec_wrapper_class``,
        or ``None`` if the chain contains none
    """
    current = env
    while True:
        if not isinstance(current, VecEnvWrapper):
            return None
        if isinstance(current, vec_wrapper_class):
            return current
        current = current.venv
37 |
38 |
def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
    """
    Find the ``VecNormalize`` wrapper of an environment, if it has one.

    :param env: the (possibly wrapped) environment to search from
    :return: the ``VecNormalize`` wrapper found by ``unwrap_vec_wrapper``, or ``None``
    """
    vec_normalize = unwrap_vec_wrapper(env, VecNormalize)
    return vec_normalize  # pytype:disable=bad-return-type
45 |
46 |
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
    """
    Tell whether an environment is wrapped by a given ``VecEnvWrapper`` class.

    :param env: the (possibly wrapped) environment to inspect
    :param vec_wrapper_class: the wrapper class to look for
    :return: ``True`` if ``unwrap_vec_wrapper`` finds such a wrapper, ``False`` otherwise
    """
    found_wrapper = unwrap_vec_wrapper(env, vec_wrapper_class)
    return found_wrapper is not None
56 |
57 |
58 | # Define here to avoid circular import
def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
    """
    Sync eval env and train env when using VecNormalize

    Both wrapper chains are walked in lockstep through their ``venv``
    attributes. At every ``VecNormalize`` level of ``env``, deep copies of its
    normalization statistics are assigned to the object at the same depth of
    ``eval_env``.

    :param env: the training environment, source of the statistics
    :param eval_env: the evaluation environment, receives the copies
    """
    env_tmp, eval_env_tmp = env, eval_env
    # NOTE(review): assumes eval_env has the same wrapper stack as env; the loop
    # only checks env_tmp and reads eval_env_tmp.venv unconditionally.
    while isinstance(env_tmp, VecEnvWrapper):
        if isinstance(env_tmp, VecNormalize):
            # Only synchronize if observation normalization exists
            if hasattr(env_tmp, "obs_rms"):
                eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
            eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
        env_tmp = env_tmp.venv
        eval_env_tmp = eval_env_tmp.venv
75 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from copy import deepcopy
3 | from typing import Any, Callable, List, Optional, Sequence, Type, Union
4 |
5 | import gym
6 | import numpy as np
7 | # from torch import is_tensor
8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
9 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
10 |
11 |
class DummyVecEnv(VecEnv):
    """
    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
    Python process. This is useful for computationally simple environments such as ``cartpole-v1``,
    as the overhead of multiprocess or multithread outweighs the environment computation time.
    This can also be used for RL methods that
    require a vectorized environment, but that you want a single environment to train with.

    :param env_fns: a list of functions
        that return environments to vectorize
    """
    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.envs = [fn() for fn in env_fns]
        # The first environment provides the spaces and metadata for the whole vector
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        # One preallocated array per observation key, with a leading axis of size num_envs
        self.buf_obs = OrderedDict(
            [(k, np.zeros((self.num_envs, ) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]
        )
        self.buf_dones = np.zeros((self.num_envs, ), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs, ), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = env.metadata

    def step_async(self, actions: np.ndarray) -> None:
        """Store the actions; they are applied by the next ``step_wait`` call."""
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step every environment in sequence with the stored actions.

        An environment that reports done is reset immediately; its last
        observation is kept in ``info["terminal_observation"]``.

        :return: observations, plus copies of the reward, done and info buffers
        """
        for env_idx in range(self.num_envs):
            obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
                self.actions[env_idx]
            )
            if self.buf_dones[env_idx]:
                # save final observation where user can get it, then reset
                self.buf_infos[env_idx]["terminal_observation"] = obs
                obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed every environment with ``seed + index``.

        :param seed: base seed; when ``None`` a random one is drawn from NumPy's global generator
        :return: the value returned by each environment's ``seed`` call
        """
        if seed is None:
            # Request a 64-bit dtype explicitly: randint otherwise uses the platform
            # default integer, and where that is 32-bit the bound 2**32 - 1 is
            # rejected as out of range. int() keeps the seed a plain Python int.
            seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
        seeds = []
        for idx, env in enumerate(self.envs):
            seeds.append(env.seed(seed + idx))
        return seeds

    def reset(self) -> VecEnvObs:
        """Reset every environment and return the stacked observations."""
        for env_idx in range(self.num_envs):
            obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return self._obs_from_buf()

    def close(self) -> None:
        """Close every environment."""
        for env in self.envs:
            env.close()

    def get_images(self) -> Sequence[np.ndarray]:
        """Render every environment in ``rgb_array`` mode and return the results as a list."""
        return [env.render(mode="rgb_array") for env in self.envs]

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Gym environment rendering. If there are multiple environments then
        rendering is delegated to the base class ``render``.
        Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
        underlying environment.

        Therefore, some arguments such as ``mode`` will have values that are valid
        only when ``num_envs == 1``.

        :param mode: The rendering type.
        """
        if self.num_envs == 1:
            return self.envs[0].render(mode=mode)
        else:
            return super().render(mode=mode)

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        """Write one environment's observation into row ``env_idx`` of the observation buffers."""
        for key in self.keys:
            if key is None:
                # Unstructured observation space: the observation is stored as a whole
                self.buf_obs[key][env_idx] = obs
            else:
                # TODO: tensor observations are not moved to CPU before being stored
                # (a previous draft called obs[key].cpu() when is_tensor(obs[key])).
                self.buf_obs[key][env_idx] = obs[key]

    def _obs_from_buf(self) -> VecEnvObs:
        """Return a copy of the observation buffers in the form matching the observation space."""
        return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, attr_name) for env_i in target_envs]

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        target_envs = self._get_target_envs(indices)
        for env_i in target_envs:
            setattr(env_i, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        target_envs = self._get_target_envs(indices)
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
        """Resolve ``indices`` through ``_get_indices`` and return the matching environments."""
        indices = self._get_indices(indices)
        return [self.envs[i] for i in indices]
131 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 | from collections import OrderedDict
5 | from typing import Any, Dict, List, Tuple
6 |
7 | import gym
8 | import numpy as np
9 |
10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces
11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
12 |
13 |
def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """
    Copy every array of an ordered dict of numpy arrays.

    :param obs: an ``OrderedDict`` of numpy arrays.
    :return: a new ``OrderedDict`` with the same keys and copied arrays.
    """
    assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
    duplicate = OrderedDict()
    for key, array in obs.items():
        duplicate[key] = np.copy(array)
    return duplicate
23 |
24 |
def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
    """
    Convert the internal dict representation of observations into the type
    matching the observation space.

    :param obs_space: an observation space.
    :param obs_dict: a dict of numpy arrays.
    :return: an observation of the type matching the space:
        for a Dict space the dict is returned unchanged; for a Tuple space the
        entries keyed ``0..n-1`` are returned as a tuple; otherwise the space is
        unstructured and the single entry stored under ``None`` is returned.
    """
    if isinstance(obs_space, gym.spaces.Dict):
        return obs_dict

    if isinstance(obs_space, gym.spaces.Tuple):
        assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
        n_subspaces = len(obs_space.spaces)
        return tuple(obs_dict[position] for position in range(n_subspaces))

    assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
    return obs_dict[None]
44 |
45 |
def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
    """
    Describe a gym.Space as keys plus per-key shapes and dtypes.

    Dict spaces are represented directly by their dict of subspaces.
    Tuple spaces are converted into a dict with keys indexing into the tuple.
    Unstructured spaces are represented by {None: obs_space}.

    :param obs_space: an observation space
    :return: A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    check_for_nested_spaces(obs_space)
    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        subspaces = dict(enumerate(obs_space.spaces))
    else:
        assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
        subspaces = {None: obs_space}

    keys = list(subspaces)
    shapes = {key: subspace.shape for key, subspace in subspaces.items()}
    dtypes = {key: subspace.dtype for key, subspace in subspaces.items()}
    return keys, shapes, dtypes
77 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
6 |
7 |
class VecCheckNan(VecEnvWrapper):
    """
    NaN and inf checking wrapper for vectorized environment, will raise a warning by default,
    allowing you to know where the NaN or inf originated from.

    :param venv: the vectorized environment to wrap
    :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
    :param warn_once: Whether or not to only warn once.
    :param check_inf: Whether or not to check for +inf or -inf as well
    """
    def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True):
        VecEnvWrapper.__init__(self, venv)
        self.raise_exception = raise_exception
        self.warn_once = warn_once
        self.check_inf = check_inf
        # Last actions / observations seen, quoted in the diagnostic message
        self._actions = None
        self._observations = None
        self._user_warned = False

    def step_async(self, actions: np.ndarray) -> None:
        """Check the actions, remember them, then forward them to the wrapped environment."""
        self._check_val(async_step=True, actions=actions)

        self._actions = actions
        self.venv.step_async(actions)

    def step_wait(self) -> VecEnvStepReturn:
        """Collect the step result from the wrapped environment and check observations, rewards and dones."""
        step_result = self.venv.step_wait()
        observations, rewards, news, infos = step_result

        self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)

        self._observations = observations
        return observations, rewards, news, infos

    def reset(self) -> VecEnvObs:
        """Reset the wrapped environment, forget the last actions and check the new observations."""
        observations = self.venv.reset()
        self._actions = None

        self._check_val(async_step=False, observations=observations)

        self._observations = observations
        return observations

    def _check_val(self, *, async_step: bool, **kwargs) -> None:
        """
        Look for NaN (and optionally inf) in the given named values and report any finding.

        :param async_step: ``True`` when called from ``step_async`` (values come from the model),
            ``False`` when the values come from the environment
        :param kwargs: the named values to inspect
        :raises ValueError: if something was found and ``raise_exception`` is set
        """
        # In warning mode with warn_once, nothing is checked after the first warning
        already_warned = self.warn_once and self._user_warned
        if already_warned and not self.raise_exception:
            return

        found = []
        for name, val in kwargs.items():
            has_nan = np.any(np.isnan(val))
            has_inf = self.check_inf and np.any(np.isinf(val))
            if has_inf:
                found.append((name, "inf"))
            if has_nan:
                found.append((name, "nan"))

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {type_val} in {name}" for name, type_val in found)
        msg += ".\r\nOriginated from the "

        if async_step:
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        elif self._actions is None:
            msg += "environment observation (at reset)"
        else:
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"

        if self.raise_exception:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)
86 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_extract_dict_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
4 |
5 |
class VecExtractDictObs(VecEnvWrapper):
    """
    Vectorized wrapper that exposes a single entry of a dictionary observation.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """
    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        # The wrapper advertises the subspace stored under ``key``
        extracted_space = venv.observation_space.spaces[self.key]
        super().__init__(venv=venv, observation_space=extracted_space)

    def reset(self) -> np.ndarray:
        """Reset the wrapped environment and return the observation entry under ``key``."""
        return self.venv.reset()[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """Return the wrapped step result with the observation replaced by its entry under ``key``."""
        step_result = self.venv.step_wait()
        obs, reward, done, info = step_result
        return obs[self.key], reward, done, info
24 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
8 |
9 |
class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environment. Designed for image observations.

    Delegates the stacking to ``StackedObservations`` for Box observation spaces
    and to ``StackedDictObservations`` for Dict observation spaces.

    :param venv: the vectorized environment to wrap
    :param n_stack: Number of frames to stack
    :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
        If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
        Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
    """
    def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
        self.venv = venv
        self.n_stack = n_stack

        wrapped_obs_space = venv.observation_space

        # Pick the stacking helper that matches the observation space type
        if isinstance(wrapped_obs_space, spaces.Box):
            assert not isinstance(
                channels_order, dict
            ), f"Expected None or string for channels_order but received {channels_order}"
            stacker_cls = StackedObservations
        elif isinstance(wrapped_obs_space, spaces.Dict):
            stacker_cls = StackedDictObservations
        else:
            raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")

        self.stackedobs = stacker_cls(venv.num_envs, n_stack, wrapped_obs_space, channels_order)

        stacked_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
        VecEnvWrapper.__init__(self, venv, observation_space=stacked_space)

    def step_wait(
        self,
    ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]], ]:
        """Step the wrapped environment and pass observations and infos through the stacking helper."""
        step_result = self.venv.step_wait()
        observations, rewards, dones, infos = step_result

        stacked = self.stackedobs.update(observations, dones, infos)
        observations, infos = stacked

        return observations, rewards, dones, infos

    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Reset all environments
        """
        first_observation = self.venv.reset()  # pytype:disable=annotation-type-mismatch
        return self.stackedobs.reset(first_observation)

    def close(self) -> None:
        """Close the wrapped environment."""
        self.venv.close()
64 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_monitor.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 | from typing import Optional, Tuple
4 |
5 | import numpy as np
6 |
7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
8 |
9 |
class VecMonitor(VecEnvWrapper):
    """
    A vectorized monitor wrapper for *vectorized* Gym environments,
    it is used to record the episode reward, length, time and other data.

    Some environments like `openai/procgen <https://github.com/openai/procgen>`_
    or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
    vectorized environments, without giving us a chance to use the ``Monitor``
    wrapper. So this class simply does the job of the ``Monitor`` wrapper on
    a vectorized level.

    ``reset`` must be called before the first ``step_wait``: the per-environment
    accumulators are only allocated in ``reset``.

    :param venv: The vectorized environment
    :param filename: the location to save a log file, can be None for no log
    :param info_keywords: extra information to log, from the information return of env.step()
    """
    def __init__(
        self,
        venv: VecEnv,
        filename: Optional[str] = None,
        info_keywords: Tuple[str, ...] = (),
    ):
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        # This check is not valid for special `VecEnv`
        # like the ones created by Procgen, that do not completely follow
        # the `VecEnv` interface
        try:
            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            is_wrapped_with_monitor = False

        if is_wrapped_with_monitor:
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper "
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        # Per-environment accumulators, allocated in ``reset``.
        self.episode_returns = None
        self.episode_lengths = None
        self.episode_count = 0
        self.t_start = time.time()

        env_id = None
        if hasattr(venv, "spec") and venv.spec is not None:
            env_id = venv.spec.id

        if filename:
            self.results_writer = ResultsWriter(
                filename, header={
                    "t_start": self.t_start,
                    "env_id": env_id
                }, extra_keys=info_keywords
            )
        else:
            self.results_writer = None
        self.info_keywords = info_keywords

    def reset(self) -> VecEnvObs:
        """Reset the wrapped environments and zero the per-environment return/length accumulators."""
        obs = self.venv.reset()
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped step and update the episode statistics.

        For every environment whose episode just ended, a copy of its info dict
        gets an ``"episode"`` entry with the return (``"r"``), the length (``"l"``),
        the elapsed time since construction (``"t"``) and the requested
        ``info_keywords``; the row is also written to the log file if one is open.

        :return: observations, rewards, dones and the (possibly augmented) infos
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        self.episode_returns += rewards
        self.episode_lengths += 1
        new_infos = list(infos)
        for i, done in enumerate(dones):
            if not done:
                continue
            # Work on a copy so the wrapped environment's info dict is left untouched.
            info = infos[i].copy()
            episode_info = {
                "r": self.episode_returns[i],
                "l": self.episode_lengths[i],
                "t": round(time.time() - self.t_start, 6),
            }
            for key in self.info_keywords:
                episode_info[key] = info[key]
            info["episode"] = episode_info
            self.episode_count += 1
            # Start accumulating the next episode of this environment from scratch.
            self.episode_returns[i] = 0
            self.episode_lengths[i] = 0
            if self.results_writer:
                self.results_writer.write_row(episode_info)
            new_infos[i] = info
        return obs, rewards, dones, new_infos

    def close(self) -> None:
        """Close the log writer (if any) and then the wrapped environment."""
        if self.results_writer:
            self.results_writer.close()
        return self.venv.close()
103 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_transpose.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Dict, Union
3 |
4 | import numpy as np
5 | from gym import spaces
6 |
7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
9 |
10 |
class VecTransposeImage(VecEnvWrapper):
    """
    Move the channel axis of image observations from last to first (HxWxC -> CxHxW),
    which is the layout PyTorch convolution layers require.

    :param venv: The vectorized environment to wrap
    :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not,
        which may result in unwanted behavior, see GH issue #671.
    """
    def __init__(self, venv: VecEnv, skip: bool = False):
        source_space = venv.observation_space
        assert is_image_space(source_space) or isinstance(
            source_space, spaces.dict.Dict
        ), "The observation space must be an image or dictionary observation space"

        self.skip = skip
        if skip:
            # Pass-through mode: keep the wrapped observation space as is.
            super().__init__(venv)
            return

        if not isinstance(source_space, spaces.dict.Dict):
            new_space = self.transpose_space(source_space)
        else:
            # Remember which entries are images so only those get transposed later.
            self.image_space_keys = []
            new_space = deepcopy(source_space)
            for name, sub_space in new_space.spaces.items():
                if not is_image_space(sub_space):
                    continue
                self.image_space_keys.append(name)
                new_space.spaces[name] = self.transpose_space(sub_space, name)
        super().__init__(venv, observation_space=new_space)

    @staticmethod
    def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
        """
        Build the channel-first counterpart of a channel-last image space.

        :param observation_space: Image space following the channel-last convention
        :param key: In case of dictionary space, the key of the observation space
            (only used in the assertion message).
        :return: A ``Box`` with bounds [0, 255], shape (C, H, W) and the same dtype
        """
        # Sanity checks
        assert is_image_space(observation_space), "The observation space must be an image"
        assert not is_image_space_channels_first(
            observation_space
        ), f"The observation space {key} must follow the channel last convention"
        height, width, channels = observation_space.shape
        return spaces.Box(low=0, high=255, shape=(channels, height, width), dtype=observation_space.dtype)

    @staticmethod
    def transpose_image(image: np.ndarray) -> np.ndarray:
        """
        Re-order the axes of one image or of a batch of images so channels come first.

        :param image: Array with 3 axes (single image) or, otherwise, treated as a 4-axis batch
        :return: The transposed array
        """
        # 3 axes: single image (H, W, C); otherwise the leading axis is kept in place.
        axes = (2, 0, 1) if len(image.shape) == 3 else (0, 3, 1, 2)
        return np.transpose(image, axes)

    def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
        """
        Return the observations with image entries transposed, unless ``skip`` is set.

        :param observations: Array or dictionary of observations
        :return: Transposed observations
        """
        if self.skip:
            return observations

        if not isinstance(observations, dict):
            return self.transpose_image(observations)

        # Deep-copy so the caller's dictionary is never modified in place.
        transposed = deepcopy(observations)
        for name in self.image_space_keys:
            transposed[name] = self.transpose_image(transposed[name])
        return transposed

    def step_wait(self) -> VecEnvStepReturn:
        """Wait for the wrapped step and transpose both the observations and any terminal observation."""
        observations, rewards, dones, infos = self.venv.step_wait()

        # Finished episodes may carry their last observation in the info dict.
        for idx, done in enumerate(dones):
            if done and "terminal_observation" in infos[idx]:
                infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])

        return self.transpose_observations(observations), rewards, dones, infos

    def reset(self) -> Union[np.ndarray, Dict]:
        """
        Reset all environments and return the transposed first observations.
        """
        initial_obs = self.venv.reset()
        return self.transpose_observations(initial_obs)

    def close(self) -> None:
        """Close the wrapped vectorized environment."""
        self.venv.close()
113 |
--------------------------------------------------------------------------------
/stable_baselines3/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable
3 |
4 | from gym.wrappers.monitoring import video_recorder
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
9 |
10 |
class VecVideoRecorder(VecEnvWrapper):
    """
    Record the rendered frames of a VecEnv or VecEnvWrapper as mp4 videos.
    It requires ffmpeg or avconv to be installed on the machine.

    :param venv: The vectorized environment to wrap
    :param video_folder: Where to save videos
    :param record_video_trigger: Function that defines when to start recording.
        The function takes the current number of step,
        and returns whether we should start recording or not.
    :param video_length: Length of recorded videos
    :param name_prefix: Prefix to the video name
    """
    def __init__(
        self,
        venv: VecEnv,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ):

        VecEnvWrapper.__init__(self, venv)

        self.env = venv

        # Walk down the wrapper chain to the innermost environment,
        # whose metadata dict is what the gym recorder will read.
        innermost = venv
        while isinstance(innermost, VecEnvWrapper):
            innermost = innermost.venv

        if isinstance(innermost, (DummyVecEnv, SubprocVecEnv)):
            metadata = innermost.get_attr("metadata")[0]
        else:
            metadata = innermost.metadata
        self.env.metadata = metadata

        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        # Create output folder if needed
        self.video_folder = os.path.abspath(video_folder)
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.video_length = video_length
        self.step_id = 0

        self.recording = False
        self.recorded_frames = 0

    def reset(self) -> VecEnvObs:
        """Reset the wrapped environments and start a new recording."""
        first_obs = self.venv.reset()
        self.start_video_recorder()
        return first_obs

    def start_video_recorder(self) -> None:
        """Close any running recorder, open a new one named after the current step range and grab a first frame."""
        self.close_video_recorder()

        last_step = self.step_id + self.video_length
        video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{last_step}"
        self.video_recorder = video_recorder.VideoRecorder(
            env=self.env,
            base_path=os.path.join(self.video_folder, video_name),
            metadata={"step_id": self.step_id},
        )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self) -> bool:
        """Ask the user-supplied trigger whether a recording should start at the current step."""
        return self.record_video_trigger(self.step_id)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped step, then either capture a frame of the running
        recording (closing it once more than ``video_length`` frames were recorded)
        or start a new recording if the trigger fires.
        """
        step_result = self.venv.step_wait()

        self.step_id += 1
        if not self.recording:
            if self._video_enabled():
                self.start_video_recorder()
            return step_result

        self.video_recorder.capture_frame()
        self.recorded_frames += 1
        if self.recorded_frames > self.video_length:
            print(f"Saving video to {self.video_recorder.path}")
            self.close_video_recorder()
        return step_result

    def close_video_recorder(self) -> None:
        """Close the active recorder, if any, and reset the recording state."""
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        self.recorded_frames = 1

    def close(self) -> None:
        """Close the wrapper first, then the video recorder."""
        VecEnvWrapper.close(self)
        self.close_video_recorder()

    def __del__(self):
        self.close()
113 |
--------------------------------------------------------------------------------
/stable_baselines3/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ddpg.ddpg import DDPG
2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/stable_baselines3/ddpg/ddpg.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Type, Union
2 |
3 | import torch as th
4 |
5 | from stable_baselines3.common.buffers import ReplayBuffer
6 | from stable_baselines3.common.noise import ActionNoise
7 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
8 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
9 | from stable_baselines3.td3.policies import TD3Policy
10 | from stable_baselines3.td3.td3 import TD3
11 |
12 |
class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient (DDPG).

    Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
    DDPG Paper: https://arxiv.org/abs/1509.02971
    Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html

    Note: DDPG is implemented as a special case of its successor TD3:
    the constructor forwards everything to ``TD3`` with ``policy_delay=1``,
    ``target_noise_clip=0.0`` and, unless overridden through ``policy_kwargs``,
    a single critic.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param tensorboard_log: forwarded unchanged to ``TD3``
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """
    def __init__(
        self,
        policy: Union[str, Type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 100,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
        gradient_steps: int = -1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[ReplayBuffer] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            # Strip the TD3-specific tricks to get plain DDPG;
            # target_policy_noise still has to be > 0 to avoid errors.
            policy_delay=1,
            target_noise_clip=0.0,
            target_policy_noise=0.1,
            # Model creation is deferred until the critic count below is settled.
            _init_setup_model=False,
        )

        # DDPG uses a single critic unless the caller explicitly asked otherwise.
        self.policy_kwargs.setdefault("n_critics", 1)

        if _init_setup_model:
            self._setup_model()

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "DDPG",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:
        """
        Train the agent by delegating to ``TD3.learn``.

        All arguments are forwarded unchanged; the only difference with the
        parent signature visible here is the ``"DDPG"`` default for ``tb_log_name``.

        :return: whatever the parent ``learn`` returns
        """
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super().learn(**forwarded)
140 |
--------------------------------------------------------------------------------
/stable_baselines3/ddpg/policies.py:
--------------------------------------------------------------------------------
1 | # DDPG can be viewed as a special case of TD3
2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401
3 |
--------------------------------------------------------------------------------
/stable_baselines3/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.dqn.dqn import DQN
2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/stable_baselines3/her/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
2 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
3 |
--------------------------------------------------------------------------------
/stable_baselines3/her/goal_selection_strategy.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class GoalSelectionStrategy(Enum):
    """
    How a replacement goal is picked when
    artificial transitions are created.
    """

    # Goal achieved later in the same episode,
    # i.e. after the current step
    FUTURE = 0
    # Goal achieved at the very end
    # of the episode
    FINAL = 1
    # Any goal achieved during the episode
    EPISODE = 2
18 |
19 |
20 | # For convenience
21 | # that way, we can use string to select a strategy
# Lower-case string aliases, so that a strategy can be selected by name.
KEY_TO_GOAL_STRATEGY = dict(
    future=GoalSelectionStrategy.FUTURE,
    final=GoalSelectionStrategy.FINAL,
    episode=GoalSelectionStrategy.EPISODE,
)
27 |
--------------------------------------------------------------------------------
/stable_baselines3/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.ppo.ppo import PPO
3 | from stable_baselines3.ppo.ppo_test import PPO_Test
4 |
--------------------------------------------------------------------------------
/stable_baselines3/ppo/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for PPO
3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
4 |
# PPO-facing names for the shared actor-critic policy classes:
# plain, CNN-based and multi-input (in that order).
MlpPolicy, CnnPolicy, MultiInputPolicy = (
    ActorCriticPolicy,
    ActorCriticCnnPolicy,
    MultiInputActorCriticPolicy,
)
8 |
--------------------------------------------------------------------------------
/stable_baselines3/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.sac.sac import SAC
3 |
--------------------------------------------------------------------------------
/stable_baselines3/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.td3.td3 import TD3
3 |
--------------------------------------------------------------------------------
/stable_baselines3/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | from legged_gym.env.base.base_task import BaseTask
4 | from legged_gym import OPEN_ROBOT_ROOT_DIR
5 | from gleam.wrapper.env_wrapper_gleam import EnvWrapperGLEAM
6 |
7 |
# Absolute path of the directory two levels above this file
# (the parent of the directory that contains this module).
root = os.path.abspath(
    os.path.dirname(os.path.dirname(__file__))
)
9 |
10 |
def get_api_key_file(wandb_key_file=None):
    """
    Build the path of the W&B API-key file and announce it on stdout.

    The file is looked up in ``<OPEN_ROBOT_ROOT_DIR>/wandb_utils``; the path is
    only assembled here, its existence is not checked.

    :param wandb_key_file: File name of the key file; any falsy value selects
        the default ``"wandb_api_key_file.txt"``.
    :return: The joined path to the key file
    """
    file_name = wandb_key_file if wandb_key_file else "wandb_api_key_file.txt"
    key_path = os.path.join(OPEN_ROBOT_ROOT_DIR, "wandb_utils", file_name)
    print("We are using this wandb key file: ", key_path)
    return key_path
17 |
18 |
def get_time_str():
    """
    Return the current local time as a compact ``MMDD_HHMM_SS`` string.

    :return: e.g. ``"0314_1592_65"``-shaped text: month+day, hour+minute, second
    """
    now = datetime.datetime.now()
    # A non-empty format spec on a datetime is handed to strftime.
    return f"{now:%m%d_%H%M_%S}"
21 |
22 |
def is_isaac_gym_env(env):
    """
    Tell whether ``env`` is one of the environment classes this module treats as Isaac Gym environments.

    :param env: Object to inspect
    :return: True if ``env`` is an instance of ``BaseTask`` or ``EnvWrapperGLEAM``, False otherwise
    """
    isaac_gym_env_types = (BaseTask, EnvWrapperGLEAM)
    return isinstance(env, isaac_gym_env_types)
33 |
--------------------------------------------------------------------------------
/stable_baselines3/version.txt:
--------------------------------------------------------------------------------
1 | 1.6.0
2 |
--------------------------------------------------------------------------------
/wandb_utils/__init__.py:
--------------------------------------------------------------------------------
# Default identifiers for experiment logging.
# NOTE(review): presumably the Weights & Biases project and team (entity) names;
# verify against the callers that import them.
project_name = "active_reconstruction"
team_name = "openrobot"
3 |
--------------------------------------------------------------------------------
/wandb_utils/wandb_api_key_file.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjwzcx/GLEAM/187e496b1c97db15e51dc023612547ac05006455/wandb_utils/wandb_api_key_file.txt
--------------------------------------------------------------------------------
/wandb_utils/wandb_callback.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import wandb
5 | from wandb.sdk.lib import telemetry as wb_telemetry
6 | from stable_baselines3.utils import get_api_key_file
7 |
8 | from stable_baselines3.common.callbacks import BaseCallback
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class WandbCallback(BaseCallback):
    """ Log SB3 experiments to Weights and Biases
        - Added model tracking and uploading
        - Added complete hyperparameters recording
        - Added gradient logging
        - The constructor reads the API key from a key file, exports it as
          ``WANDB_API_KEY`` and calls ``wandb.init(...)`` itself
    Args:
        trial_name: Passed to ``wandb.init`` as the run ``id``
        exp_name: Passed to ``wandb.init`` as the run ``group``
        project_name: Passed to ``wandb.init`` as the ``project``
        config: Optional mapping of hyperparameters; its ``"team_name"`` entry (if any)
            is used as the wandb ``entity`` and selects the key file
        verbose: The verbosity of sb3 output
        model_save_path: Path to the folder where the model will be saved, The default value is `None` so the model is not logged
        model_save_freq: Frequency to save the model
        gradient_save_freq: Frequency to log gradient. The default value is 0 so the gradients are not logged
    """
    def __init__(
        self,
        trial_name,
        exp_name,
        project_name,
        config=None,
        verbose: int = 0,
        model_save_path: str = None,
        model_save_freq: int = 0,
        gradient_save_freq: int = 0,
    ):
        # ``config`` defaults to None, so it must not be subscripted directly;
        # a missing team name leaves the wandb entity unset (None).
        config = config or {}
        team_name = config.get("team_name")

        # Export the API key so that wandb can authenticate.
        WANDB_ENV_VAR = "WANDB_API_KEY"
        # get_api_key_file builds a path under <OPEN_ROBOT_ROOT_DIR>/wandb_utils;
        # passing None selects its default key file name.
        key_file_path = get_api_key_file("wandb_lqy_key_file.text" if team_name == "lqy0057" else None)
        with open(key_file_path, "r") as f:
            key = f.readline()
        # Drop the trailing newline and any stray spaces around/inside the key.
        key = key.replace("\n", "")
        key = key.replace(" ", "")
        os.environ[WANDB_ENV_VAR] = key

        self.run = wandb.init(
            id=trial_name,
            config=config,
            resume=True,
            reinit=True,
            group=exp_name,
            project=project_name,
            entity=team_name,
            sync_tensorboard=True,  # Open this and setup tb in sb3 so that we can get log!
            save_code=False
        )

        super(WandbCallback, self).__init__(verbose)
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")
        with wb_telemetry.context() as tel:
            tel.feature.sb3 = True
        self.model_save_freq = model_save_freq
        self.model_save_path = model_save_path
        self.gradient_save_freq = gradient_save_freq

        # Create folder if needed
        if self.model_save_path is not None:
            os.makedirs(self.model_save_path, exist_ok=True)
            self.path = os.path.join(self.model_save_path, "model.zip")
        else:
            assert (
                self.model_save_freq == 0
            ), "to use the `model_save_freq` you have to set the `model_save_path` parameter"

    def _init_callback(self) -> None:
        """Record the algorithm name and the model's attributes as default wandb config values."""
        d = {"algo": type(self.model).__name__}
        for key, value in self.model.__dict__.items():
            # Values already present in the wandb config take precedence.
            if key in wandb.config:
                continue
            # Exact type check on purpose: only plain float/int/str are stored as is,
            # everything else (including bool) is stringified.
            if type(value) in (float, int, str):
                d[key] = value
            else:
                d[key] = str(value)
        if self.gradient_save_freq > 0:
            wandb.watch(self.model.policy, log_freq=self.gradient_save_freq, log="all")
        wandb.config.setdefaults(d)

    def _on_step(self) -> bool:
        """Save the model every ``model_save_freq`` calls (when saving is enabled); always continue training."""
        saving_enabled = self.model_save_freq > 0 and self.model_save_path is not None
        if saving_enabled and self.n_calls % self.model_save_freq == 0:
            self.save_model()
        return True

    def _on_training_end(self) -> None:
        """Save a final checkpoint if a save path was configured."""
        if self.model_save_path is not None:
            self.save_model()

    def save_model(self) -> None:
        """Write the model to ``self.path`` and hand the file to wandb for upload."""
        self.model.save(self.path)
        wandb.save(self.path, base_path=self.model_save_path)
        if self.verbose > 1:
            logger.info("Saving model checkpoint to %s", self.path)
112 |
--------------------------------------------------------------------------------