├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── abstract.png ├── learning ├── __init__.py ├── algorithms │ ├── __init__.py │ └── cassi.py ├── datasets │ ├── __init__.py │ └── motion_loader.py ├── env │ ├── __init__.py │ └── vec_env.py ├── modules │ ├── __init__.py │ ├── actor_critic.py │ ├── actor_critic_recurrent.py │ ├── discriminator.py │ ├── discriminator_ensemble.py │ └── normalizer.py ├── runners │ ├── __init__.py │ └── cassi_on_policy_runner.py ├── storage │ ├── __init__.py │ ├── replay_buffer.py │ └── rollout_storage.py └── utils │ ├── __init__.py │ └── utils.py ├── resources └── robots │ └── solo8 │ ├── datasets │ ├── motion_data.pt │ └── reference_state_idx_dict.json │ ├── meshes │ ├── solo_body.stl │ ├── solo_foot.stl │ ├── solo_lower_leg_left_side.stl │ ├── solo_lower_leg_right_side.stl │ ├── solo_upper_leg_left_side.stl │ └── solo_upper_leg_right_side.stl │ └── urdf │ └── solo8.urdf ├── scripts ├── play.py └── train.py ├── setup.py └── solo_gym ├── __init__.py ├── envs ├── __init__.py ├── base │ ├── __init__.py │ ├── legged_robot.py │ └── legged_robot_config.py ├── base_task.py └── solo8 │ ├── solo8.py │ └── solo8_config.py └── utils ├── README.md ├── __init__.py ├── base_config.py ├── helpers.py ├── keyboard_controller.py ├── logger.py ├── math.py ├── task_registry.py └── terrain.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.dae filter=lfs diff=lfs merge=lfs -text 2 | *.obj filter=lfs diff=lfs merge=lfs -text 3 | *.obj text !filter !merge !diff 4 | *.dae text !filter !merge !diff 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # These are some examples of commonly ignored file patterns. 2 | # You should customize this list as applicable to your project. 3 | # Learn more about .gitignore: 4 | # https://www.atlassian.com/git/tutorials/saving-changes/gitignore 5 | 6 | # Node artifact files 7 | node_modules/ 8 | dist/ 9 | 10 | # Compiled Java class files 11 | *.class 12 | 13 | # Compiled Python bytecode 14 | *.py[cod] 15 | 16 | # Log files 17 | *.log 18 | 19 | # Package files 20 | *.jar 21 | 22 | # Maven 23 | target/ 24 | dist/ 25 | 26 | # JetBrains IDE 27 | .idea/ 28 | 29 | # Unit test reports 30 | TEST*.xml 31 | 32 | # Generated by MacOS 33 | .DS_Store 34 | 35 | # Generated by Windows 36 | Thumbs.db 37 | 38 | # Applications 39 | *.app 40 | *.exe 41 | *.war 42 | 43 | # Large media files 44 | *.mp4 45 | *.tiff 46 | *.avi 47 | *.flv 48 | *.mov 49 | *.wmv 50 | 51 | # VS Code 52 | .vscode 53 | # logs 54 | logs 55 | runs 56 | 57 | # other 58 | *.egg-info 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, ETH Zurich, NVIDIA Corporation 2 | 3 | All rights reserved 4 | Parts of the code are released under BSD-3-Clause license. 
5 | 6 | See licenses in resources/robots for license information for assets included in this repository -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CASSI with Solo 2 | 3 | This repository provides the [Cooperative Adversarial Self-supervised Skill Imitation (CASSI)](https://arxiv.org/abs/2209.07899) algorithm that enables [Solo](https://open-dynamic-robot-initiative.github.io/) to extract diverse skills through adversarial imitation from unlabeled, mixed motions using [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym). 4 | 5 | ![abstract](abstract.png) 6 | 7 | **Paper**: [Versatile Skill Control via Self-supervised Adversarial Imitation of Unlabeled Mixed Motions](https://arxiv.org/abs/2209.07899) 8 | **Project website**: https://sites.google.com/view/icra2023-cassi/home 9 | 10 | **Maintainer**: [Chenhao Li](https://breadli428.github.io/) 11 | **Affiliation**: [Autonomous Learning Group](https://al.is.mpg.de/), [Max Planck Institute for Intelligent Systems](https://is.mpg.de/), and [Robotic Systems Lab](https://rsl.ethz.ch/), [ETH Zurich](https://ethz.ch/en.html) 12 | **Contact**: [chenhli@ethz.ch](mailto:chenhli@ethz.ch) 13 | 14 | ## Installation 15 | 16 | 1. Create a new Python virtual environment with `python 3.8` 17 | 2. Install `pytorch 1.10` with `cuda-11.3` 18 | 19 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 20 | 21 | 3. Install Isaac Gym 22 | 23 | - Download and install [Isaac Gym Preview 4](https://developer.nvidia.com/isaac-gym) 24 | 25 | ``` 26 | cd isaacgym/python 27 | pip install -e . 28 | ``` 29 | 30 | - Try running an example 31 | 32 | ``` 33 | cd examples 34 | python 1080_balls_of_solitude.py 35 | ``` 36 | 37 | - For troubleshooting, check docs in `isaacgym/docs/index.html` 38 | 39 | 4. Install `solo_gym` 40 | 41 | git clone https://github.com/martius-lab/cassi.git 42 | cd solo_gym 43 | pip install -e . 44 | 45 | ## Configuration 46 | - The Solo environment is defined by an env file `solo8.py` and a config file `solo8_config.py` under `solo_gym/envs/solo8/`. The config file sets both the environment parameters in class `Solo8FlatCfg` and the training parameters in class `Solo8FlatCfgPPO`. 47 | - The provided code exemplifies the training of Solo 8 with [unlabeled mixed motions](https://youtu.be/SUQ_FoaJgnA?feature=shared). Demonstrations induced by 6 locomotion gaits are randomly mixed, augmented with perturbations into 6000 trajectories of 120 frames each, and stored in `resources/robots/solo8/datasets/motion_data.pt`. The state dimension indices are specified in `reference_state_idx_dict.json`. To train with other demonstrations, replace `motion_data.pt` and adapt the reward functions defined in `solo_gym/envs/solo8/solo8.py` accordingly. 48 | 49 | 50 | ## Usage 51 | 52 | ### Train 53 | 54 | ``` 55 | python scripts/train.py --task solo8 56 | ``` 57 | 58 | - The trained policy is saved in `logs/<experiment_name>/<date_time>_<run_name>/model_<iteration>.pt`, where `<experiment_name>` and `<run_name>` are defined in the train config. 59 | - To disable rendering, append `--headless`. 60 | 61 | ### Play a trained policy 62 | 63 | ``` 64 | python scripts/play.py 65 | ``` 66 | 67 | - By default the loaded policy is the last model of the last run of the experiment folder. 68 | - Other runs/model iterations can be selected by setting `load_run` and `checkpoint` in the train config, for example as sketched below.
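A minimal sketch of what that could look like, assuming the `legged_gym`-style config layout in which `load_run` and `checkpoint` sit in the `runner` section of `Solo8FlatCfgPPO` (the base-class import below follows `solo_gym/envs/base/legged_robot_config.py`; the run folder name and iteration number are placeholders, not values shipped with the repository):

```
# solo_gym/envs/solo8/solo8_config.py (illustrative excerpt only)
from solo_gym.envs.base.legged_robot_config import LeggedRobotCfgPPO

class Solo8FlatCfgPPO(LeggedRobotCfgPPO):
    class runner(LeggedRobotCfgPPO.runner):
        load_run = "Jan01_00-00-00_example_run"  # placeholder run folder under logs/<experiment_name>; -1 keeps the latest run
        checkpoint = 500                         # loads model_500.pt; -1 keeps the latest saved model
```

With these values set, `python scripts/play.py` would load `model_500.pt` from the chosen run folder instead of the most recent checkpoint.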
69 | - Use `u` and `j` to command the forward velocity, `h` and `k` to switch between the extracted skills. 70 | 71 | ## Citation 72 | ``` 73 | @inproceedings{li2023versatile, 74 | title={Versatile skill control via self-supervised adversarial imitation of unlabeled mixed motions}, 75 | author={Li, Chenhao and Blaes, Sebastian and Kolev, Pavel and Vlastelica, Marin and Frey, Jonas and Martius, Georg}, 76 | booktitle={2023 IEEE international conference on robotics and automation (ICRA)}, 77 | pages={2944--2950}, 78 | year={2023}, 79 | organization={IEEE} 80 | } 81 | ``` 82 | 83 | ## References 84 | 85 | The code is built upon the open-sourced [Isaac Gym Environments for Legged Robots](https://github.com/leggedrobotics/legged_gym) and the [PPO implementation](https://github.com/leggedrobotics/rsl_rl). We refer to the original repositories for more details. 86 | -------------------------------------------------------------------------------- /abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/abstract.png -------------------------------------------------------------------------------- /learning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | -------------------------------------------------------------------------------- /learning/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of different RL agents.""" 5 | 6 | from .cassi import CASSI 7 | 8 | __all__ = ["CASSI"] 9 | -------------------------------------------------------------------------------- /learning/algorithms/cassi.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | # learning 10 | from learning.modules import ActorCritic 11 | from learning.modules.discriminator_ensemble import DiscriminatorEnsemble 12 | from learning.storage import RolloutStorage 13 | from learning.storage.replay_buffer import ReplayBuffer 14 | 15 | 16 | class CASSI: 17 | actor_critic: ActorCritic 18 | discriminator_ensemble: DiscriminatorEnsemble 19 | 20 | def __init__( 21 | self, 22 | actor_critic, 23 | discriminator, 24 | cassi_expert_data, 25 | discriminator_ensemble, 26 | num_learning_epochs=1, 27 | num_mini_batches=1, 28 | clip_param=0.2, 29 | gamma=0.998, 30 | lam=0.95, 31 | value_loss_coef=1.0, 32 | entropy_coef=0.0, 33 | policy_learning_rate=1e-3, 34 | max_grad_norm=1.0, 35 | use_clipped_value_loss=True, 36 | schedule="fixed", 37 | desired_kl=0.01, 38 | device="cpu", 39 | discriminator_learning_rate=0.000025, 40 | discriminator_momentum=0.9, 41 | discriminator_weight_decay=0.0005, 42 | discriminator_gradient_penalty_coef=5, 43 | discriminator_loss_function="MSELoss", 44 | discriminator_num_mini_batches=10, 45 | cassi_replay_buffer_size=100000, 46 | discriminator_ensemble_learning_rate=0.001, 47 | discriminator_ensemble_weight_decay=0.0005, 48 | discriminator_ensemble_num_mini_batches=10, 49 | discriminator_ensemble_replay_buffer_size=100000, 50 | **kwargs, 51 | ): 52 | if kwargs: 
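        # The constructor arguments above fall into three groups: PPO hyperparameters for the
        # actor-critic update, hyperparameters for the style discriminator trained against the
        # expert motion data (style reward), and hyperparameters for the skill discriminator
        # ensemble that classifies which skill produced a rollout (skill and DISDAIN rewards).
        # Anything that matches none of these ends up in `kwargs` and is only reported below,
        # not applied.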
53 | print("CASSI.__init__ got unexpected arguments, which will be ignored: " 54 | + str([key for key in kwargs.keys()])) 55 | 56 | self.device = device 57 | 58 | self.desired_kl = desired_kl 59 | self.schedule = schedule 60 | 61 | # PPO components 62 | self.actor_critic = actor_critic 63 | self.actor_critic.to(self.device) 64 | self.storage = None # initialized later 65 | self.transition = RolloutStorage.Transition() # actor_critic transition 66 | 67 | # PPO parameters 68 | self.clip_param = clip_param 69 | self.num_learning_epochs = num_learning_epochs 70 | self.num_mini_batches = num_mini_batches 71 | self.value_loss_coef = value_loss_coef 72 | self.entropy_coef = entropy_coef 73 | self.gamma = gamma 74 | self.lam = lam 75 | self.max_grad_norm = max_grad_norm 76 | self.use_clipped_value_loss = use_clipped_value_loss 77 | self.policy_learning_rate = policy_learning_rate 78 | 79 | self.policy_optimizer = optim.Adam(self.actor_critic.parameters(), lr=self.policy_learning_rate) 80 | 81 | # Discriminator components 82 | self.discriminator = discriminator 83 | self.discriminator.to(self.device) 84 | self.cassi_policy_data = ReplayBuffer(discriminator.observation_dim, discriminator.observation_horizon, cassi_replay_buffer_size, device) 85 | self.cassi_expert_data = cassi_expert_data 86 | self.cassi_state_normalizer = discriminator.state_normalizer 87 | self.cassi_style_reward_normalizer = discriminator.reward_normalizer 88 | 89 | # Discriminator parameters 90 | self.discriminator_learning_rate = discriminator_learning_rate 91 | self.discriminator_momentum = discriminator_momentum 92 | self.discriminator_weight_decay = discriminator_weight_decay 93 | self.discriminator_gradient_penalty_coef = discriminator_gradient_penalty_coef 94 | self.discriminator_loss_function = discriminator_loss_function 95 | self.discriminator_num_mini_batches = discriminator_num_mini_batches 96 | 97 | if self.discriminator_loss_function == "WassersteinLoss": 98 | discriminator_optimizer = optim.RMSprop 99 | else: 100 | discriminator_optimizer = optim.SGD 101 | self.discriminator_optimizer = discriminator_optimizer( 102 | self.discriminator.parameters(), 103 | lr=self.discriminator_learning_rate, 104 | momentum=self.discriminator_momentum, 105 | weight_decay=self.discriminator_weight_decay, 106 | ) 107 | 108 | # Discriminator Ensemble components 109 | self.discriminator_ensemble = discriminator_ensemble 110 | self.discriminator_ensemble_num_classes = discriminator_ensemble.num_classes 111 | self.discriminator_ensemble_ensemble_size = discriminator_ensemble.ensemble_size 112 | self.discriminator_ensemble_incremental_input = discriminator_ensemble.incremental_input 113 | self.discriminator_ensemble.to(self.device) 114 | self.discriminator_ensemble_state_transitions = [] 115 | for _ in range(self.discriminator_ensemble_num_classes): 116 | self.discriminator_ensemble_state_transitions.append(ReplayBuffer(discriminator_ensemble.observation_dim, discriminator_ensemble.observation_horizon, discriminator_ensemble_replay_buffer_size, device)) 117 | 118 | # Discriminator Ensemble parameters 119 | self.discriminator_ensemble_learning_rate = discriminator_ensemble_learning_rate 120 | self.discriminator_ensemble_weight_decay = discriminator_ensemble_weight_decay 121 | self.discriminator_ensemble_num_mini_batches = discriminator_ensemble_num_mini_batches 122 | 123 | discriminator_ensemble_optimizer = optim.Adam 124 | self.discriminator_ensemble_optimizer = [discriminator_ensemble_optimizer( 125 | d.parameters(), 126 | 
lr=self.discriminator_ensemble_learning_rate, 127 | weight_decay=self.discriminator_ensemble_weight_decay, 128 | ) for d in self.discriminator_ensemble.ensemble] 129 | 130 | def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape): 131 | self.storage = RolloutStorage( 132 | num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device 133 | ) 134 | 135 | def test_mode(self): 136 | self.actor_critic.test() 137 | 138 | def train_mode(self): 139 | self.actor_critic.train() 140 | 141 | def act(self, obs, critic_obs, cassi_observation_buf, discriminator_ensemble_observation_buf): 142 | if self.actor_critic.is_recurrent: 143 | self.transition.hidden_states = self.actor_critic.get_hidden_states() 144 | # Compute the actions and values 145 | self.transition.actions = self.actor_critic.act(obs).detach() 146 | self.transition.values = self.actor_critic.evaluate(critic_obs).detach() 147 | self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach() 148 | self.transition.action_mean = self.actor_critic.action_mean.detach() 149 | self.transition.action_sigma = self.actor_critic.action_std.detach() 150 | # need to record obs and critic_obs before env.step() 151 | self.transition.observations = obs 152 | self.transition.critic_observations = critic_obs 153 | self.cassi_observation_buf = cassi_observation_buf.clone() 154 | self.discriminator_ensemble_observation_buf = discriminator_ensemble_observation_buf.clone() 155 | return self.transition.actions 156 | 157 | def process_env_step(self, rewards, dones, infos, cassi_obs, dis_obs, style_selector): 158 | self.transition.rewards = rewards.clone() 159 | self.transition.dones = dones 160 | # Bootstrapping on time outs 161 | if "time_outs" in infos: 162 | self.transition.rewards += self.gamma * torch.squeeze( 163 | self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1 164 | ) 165 | 166 | # Record the transition 167 | self.storage.add_transitions(self.transition) 168 | cassi_observation_buf = torch.cat((self.cassi_observation_buf[:, 1:], cassi_obs.unsqueeze(1)), dim=1) 169 | self.cassi_policy_data.insert(cassi_observation_buf) 170 | discriminator_ensemble_observation_buf = torch.cat((self.discriminator_ensemble_observation_buf[:, 1:], dis_obs.unsqueeze(1)), dim=1) 171 | for i in range(self.discriminator_ensemble_num_classes): 172 | self.discriminator_ensemble_state_transitions[i].insert(discriminator_ensemble_observation_buf[style_selector==i]) 173 | self.transition.clear() 174 | self.actor_critic.reset(dones) 175 | 176 | def compute_returns(self, last_critic_obs): 177 | last_values = self.actor_critic.evaluate(last_critic_obs).detach() 178 | self.storage.compute_returns(last_values, self.gamma, self.lam) 179 | 180 | def update(self): 181 | mean_value_loss = 0 182 | mean_surrogate_loss = 0 183 | mean_cassi_loss = 0 184 | mean_grad_pen_loss = 0 185 | mean_policy_pred = 0 186 | mean_expert_pred = 0 187 | mean_discriminator_ensemble_loss = torch.zeros(self.discriminator_ensemble_ensemble_size, dtype=torch.float, device=self.device) 188 | total = torch.zeros(self.discriminator_ensemble_ensemble_size, dtype=torch.int, device=self.device) 189 | correct = torch.zeros(self.discriminator_ensemble_ensemble_size, dtype=torch.int, device=self.device) 190 | 191 | # Policy update 192 | if self.actor_critic.is_recurrent: 193 | generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, 
self.num_learning_epochs) 194 | else: 195 | generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) 196 | for ( 197 | obs_batch, 198 | critic_obs_batch, 199 | actions_batch, 200 | target_values_batch, 201 | advantages_batch, 202 | returns_batch, 203 | old_actions_log_prob_batch, 204 | old_mu_batch, 205 | old_sigma_batch, 206 | hid_states_batch, 207 | masks_batch, 208 | ) in generator: 209 | 210 | self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) 211 | actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch) 212 | value_batch = self.actor_critic.evaluate( 213 | critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1] 214 | ) 215 | mu_batch = self.actor_critic.action_mean 216 | sigma_batch = self.actor_critic.action_std 217 | entropy_batch = self.actor_critic.entropy 218 | 219 | # KL 220 | if self.desired_kl is not None and self.schedule == "adaptive": 221 | with torch.inference_mode(): 222 | kl = torch.sum( 223 | torch.log(sigma_batch / old_sigma_batch + 1.0e-5) 224 | + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) 225 | / (2.0 * torch.square(sigma_batch)) 226 | - 0.5, 227 | axis=-1, 228 | ) 229 | kl_mean = torch.mean(kl) 230 | 231 | if kl_mean > self.desired_kl * 2.0: 232 | self.policy_learning_rate = max(1e-5, self.policy_learning_rate / 1.5) 233 | elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: 234 | self.policy_learning_rate = min(1e-2, self.policy_learning_rate * 1.5) 235 | 236 | for param_group in self.policy_optimizer.param_groups: 237 | param_group["lr"] = self.policy_learning_rate 238 | 239 | # Surrogate loss 240 | ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) 241 | surrogate = -torch.squeeze(advantages_batch) * ratio 242 | surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( 243 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param 244 | ) 245 | surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() 246 | 247 | # Value function loss 248 | if self.use_clipped_value_loss: 249 | value_clipped = target_values_batch + (value_batch - target_values_batch).clamp( 250 | -self.clip_param, self.clip_param 251 | ) 252 | value_losses = (value_batch - returns_batch).pow(2) 253 | value_losses_clipped = (value_clipped - returns_batch).pow(2) 254 | value_loss = torch.max(value_losses, value_losses_clipped).mean() 255 | else: 256 | value_loss = (returns_batch - value_batch).pow(2).mean() 257 | 258 | ppo_loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() 259 | 260 | # Gradient step 261 | self.policy_optimizer.zero_grad() 262 | ppo_loss.backward() 263 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) 264 | self.policy_optimizer.step() 265 | 266 | mean_value_loss += value_loss.item() 267 | mean_surrogate_loss += surrogate_loss.item() 268 | 269 | # Discriminator update 270 | cassi_policy_generator = self.cassi_policy_data.feed_forward_generator( 271 | self.discriminator_num_mini_batches, 272 | self.storage.num_envs * self.storage.num_transitions_per_env // self.discriminator_num_mini_batches) 273 | cassi_expert_generator = self.cassi_expert_data.feed_forward_generator( 274 | self.discriminator_num_mini_batches, 275 | self.storage.num_envs * self.storage.num_transitions_per_env // self.discriminator_num_mini_batches) 276 | 277 | for sample_cassi_policy, sample_cassi_expert in zip(cassi_policy_generator, 
cassi_expert_generator): 278 | 279 | # Discriminator loss 280 | policy_state_buf = torch.zeros_like(sample_cassi_policy) 281 | expert_state_buf = torch.zeros_like(sample_cassi_expert) 282 | if self.cassi_state_normalizer is not None: 283 | for i in range(self.discriminator.observation_horizon): 284 | with torch.no_grad(): 285 | policy_state_buf[:, i] = self.cassi_state_normalizer.normalize(sample_cassi_policy[:, i]) 286 | expert_state_buf[:, i] = self.cassi_state_normalizer.normalize(sample_cassi_expert[:, i]) 287 | policy_d = self.discriminator(policy_state_buf.flatten(1, 2)) 288 | expert_d = self.discriminator(expert_state_buf.flatten(1, 2)) 289 | if self.discriminator_loss_function == "BCEWithLogitsLoss": 290 | expert_loss = torch.nn.BCEWithLogitsLoss()(expert_d, torch.ones_like(expert_d)) 291 | policy_loss = torch.nn.BCEWithLogitsLoss()(policy_d, torch.zeros_like(policy_d)) 292 | elif self.discriminator_loss_function == "MSELoss": 293 | expert_loss = torch.nn.MSELoss()(expert_d, torch.ones(expert_d.size(), device=self.device)) 294 | policy_loss = torch.nn.MSELoss()(policy_d, -1 * torch.ones(policy_d.size(), device=self.device)) 295 | elif self.discriminator_loss_function == "WassersteinLoss": 296 | expert_loss = -expert_d.mean() 297 | policy_loss = policy_d.mean() 298 | else: 299 | raise ValueError("Unexpected loss function specified") 300 | cassi_loss = 0.5 * (expert_loss + policy_loss) 301 | grad_pen_loss = self.discriminator.compute_grad_pen(sample_cassi_expert, 302 | lambda_=self.discriminator_gradient_penalty_coef) 303 | 304 | # Gradient step 305 | discriminator_loss = cassi_loss + grad_pen_loss 306 | self.discriminator_optimizer.zero_grad() 307 | discriminator_loss.backward() 308 | self.discriminator_optimizer.step() 309 | 310 | if self.cassi_state_normalizer is not None: 311 | self.cassi_state_normalizer.update(sample_cassi_policy[:, 0]) 312 | self.cassi_state_normalizer.update(sample_cassi_expert[:, 0]) 313 | 314 | mean_cassi_loss += cassi_loss.item() 315 | mean_grad_pen_loss += grad_pen_loss.item() 316 | mean_policy_pred += policy_d.mean().item() 317 | mean_expert_pred += expert_d.mean().item() 318 | 319 | # Discriminator Ensemble update 320 | discriminator_ensemble_data_generator = [] 321 | for i in range(self.discriminator_ensemble_num_classes): 322 | discriminator_ensemble_data_generator.append( 323 | self.discriminator_ensemble_state_transitions[i].feed_forward_generator( 324 | self.discriminator_ensemble_num_mini_batches, 325 | self.storage.num_envs * self.storage.num_transitions_per_env // self.discriminator_ensemble_num_mini_batches 326 | ) 327 | ) 328 | for batch in zip(*discriminator_ensemble_data_generator): 329 | batch_sample_collection = [] 330 | batch_label_collection = [] 331 | correct_batch = torch.zeros(self.discriminator_ensemble_ensemble_size, dtype=torch.int, device=self.device) 332 | total_batch = torch.zeros(self.discriminator_ensemble_ensemble_size, dtype=torch.int, device=self.device) 333 | 334 | for i in range(self.discriminator_ensemble_num_classes): 335 | class_sample = batch[i] 336 | class_label = torch.ones(class_sample.size(0), dtype=torch.long, device=self.device) * i 337 | batch_sample_collection.append(class_sample) 338 | batch_label_collection.append(class_label) 339 | 340 | batch_sample = torch.cat(batch_sample_collection, dim=0) 341 | batch_label = torch.cat(batch_label_collection, dim=0) 342 | 343 | if self.discriminator_ensemble_incremental_input: 344 | ensemble_out, ensemble_indices = self.discriminator_ensemble((batch_sample - 
batch_sample[:, :1, :]).flatten(1, 2)) 345 | else: 346 | ensemble_out, ensemble_indices = self.discriminator_ensemble((batch_sample).flatten(1, 2)) 347 | for i in range(self.discriminator_ensemble_ensemble_size): 348 | out = ensemble_out[i] 349 | idx = ensemble_indices[i] 350 | label = batch_label[idx] 351 | _, pred = torch.max(out.data, dim=1) 352 | correct_batch[i] += (pred == label).sum().item() 353 | total_batch[i] += label.size(0) 354 | discriminator_ensemble_loss = torch.nn.CrossEntropyLoss()( 355 | out, 356 | label 357 | ) 358 | # Gradient step 359 | self.discriminator_ensemble_optimizer[i].zero_grad() 360 | discriminator_ensemble_loss.backward() 361 | self.discriminator_ensemble_optimizer[i].step() 362 | mean_discriminator_ensemble_loss[i] += discriminator_ensemble_loss.item() 363 | total[i] += total_batch[i] 364 | correct[i] += correct_batch[i] 365 | 366 | policy_num_updates = self.num_learning_epochs * self.num_mini_batches 367 | mean_value_loss /= policy_num_updates 368 | mean_surrogate_loss /= policy_num_updates 369 | 370 | discriminator_num_updates = self.discriminator_num_mini_batches 371 | mean_cassi_loss /= discriminator_num_updates 372 | mean_grad_pen_loss /= discriminator_num_updates 373 | mean_policy_pred /= discriminator_num_updates 374 | mean_expert_pred /= discriminator_num_updates 375 | 376 | discriminator_ensemble_num_updates = self.discriminator_ensemble_num_mini_batches 377 | mean_discriminator_ensemble_loss /= discriminator_ensemble_num_updates 378 | discriminator_ensemble_accuracy = correct / total 379 | 380 | self.storage.clear() 381 | 382 | return mean_value_loss, mean_surrogate_loss, mean_cassi_loss, mean_grad_pen_loss, mean_policy_pred, mean_expert_pred, mean_discriminator_ensemble_loss, discriminator_ensemble_accuracy 383 | -------------------------------------------------------------------------------- /learning/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/learning/datasets/__init__.py -------------------------------------------------------------------------------- /learning/datasets/motion_loader.py: -------------------------------------------------------------------------------- 1 | from solo_gym import LEGGED_GYM_ROOT_DIR 2 | from isaacgym.torch_utils import ( 3 | quat_mul, 4 | quat_conjugate, 5 | normalize, 6 | quat_from_angle_axis, 7 | ) 8 | import os 9 | import json 10 | import torch 11 | 12 | class MotionLoader: 13 | 14 | def __init__(self, device, motion_file=None, corruption_level=0.0, reference_observation_horizon=2, test_mode=False, test_observation_dim=None): 15 | self.device = device 16 | self.reference_observation_horizon = reference_observation_horizon 17 | if motion_file is None: 18 | motion_file = LEGGED_GYM_ROOT_DIR + "/resources/robots/anymal_c/datasets/motion_data.pt" 19 | self.reference_state_idx_dict_file = os.path.join(os.path.dirname(motion_file), "reference_state_idx_dict.json") 20 | with open(self.reference_state_idx_dict_file, 'r') as f: 21 | self.state_idx_dict = json.load(f) 22 | self.observation_dim = sum([ids[1] - ids[0] for state, ids in self.state_idx_dict.items() if ((state != "base_pos") and (state != "base_quat"))]) 23 | self.observation_start_dim = self.state_idx_dict["base_lin_vel"][0] 24 | loaded_data = torch.load(motion_file, map_location=self.device) 25 | 26 | # Normalize and standardize quaternions 27 | base_quat = normalize(loaded_data[:, :, 
self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]]) 28 | base_quat[base_quat[:, :, -1] < 0] = -base_quat[base_quat[:, :, -1] < 0] 29 | loaded_data[:, :, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = base_quat 30 | 31 | # Load data for DTW 32 | motion_file_dtw = os.path.join(os.path.dirname(motion_file), "motion_data_original.pt") 33 | try: 34 | self.dtw_reference = torch.load(motion_file_dtw, map_location=self.device)[:, :, self.observation_start_dim:] 35 | print(f"[MotionLoader] Loaded DTW reference motion clips.") 36 | except: 37 | self.dtw_reference = None 38 | print(f"[MotionLoader] No DTW reference motion clips provided.") 39 | 40 | self.data = self._data_corruption(loaded_data, level=corruption_level) 41 | self.num_motion_clips, self.num_steps, self.reference_full_dim = self.data.size() 42 | print(f"[MotionLoader] Loaded {self.num_motion_clips} motion clips from {motion_file}. Each records {self.num_steps} steps and {self.reference_full_dim} states.") 43 | 44 | # Preload transitions 45 | self.num_preload_transitions = 500000 46 | motion_clip_sample_ids = torch.randint(0, self.num_motion_clips, (self.num_preload_transitions,), device=self.device) 47 | step_sample = torch.rand(self.num_preload_transitions, device=self.device) * (self.num_steps - self.reference_observation_horizon) 48 | self.preloaded_states = torch.zeros( 49 | self.num_preload_transitions, 50 | self.reference_observation_horizon, 51 | self.reference_full_dim, 52 | dtype=torch.float, 53 | device=self.device, 54 | requires_grad=False 55 | ) 56 | for i in range(self.reference_observation_horizon): 57 | self.preloaded_states[:, i] = self._get_frame_at_step(motion_clip_sample_ids, step_sample + i) 58 | 59 | if test_mode: 60 | self.observation_dim = test_observation_dim 61 | 62 | def _data_corruption(self, loaded_data, level=0): 63 | if level == 0: 64 | print(f"[MotionLoader] Proceeded without processing the loaded data.") 65 | else: 66 | loaded_data = self._rand_dropout(loaded_data, level) 67 | loaded_data = self._rand_noise(loaded_data, level) 68 | loaded_data = self._rand_interpolation(loaded_data, level) 69 | loaded_data = self._rand_duplication(loaded_data, level) 70 | return loaded_data 71 | 72 | def _rand_dropout(self, data, level=0): 73 | num_motion_clips, num_steps, reference_full_dim = data.size() 74 | num_dropouts = round(num_steps * level) 75 | if num_dropouts == 0: 76 | return data 77 | dropped_data = torch.zeros(num_motion_clips, num_steps - num_dropouts, reference_full_dim, dtype=torch.float, device=self.device, requires_grad=False) 78 | for i in range(num_motion_clips): 79 | step_ids = torch.randperm(num_steps)[:-num_dropouts].sort()[0] 80 | dropped_data[i] = data[i, step_ids] 81 | return dropped_data 82 | 83 | def _rand_interpolation(self, data, level=0): 84 | num_motion_clips, num_steps, reference_full_dim = data.size() 85 | num_interpolations = round((num_steps - 2) * level) 86 | if num_interpolations == 0: 87 | return data 88 | interpolated_data = data 89 | for i in range(num_motion_clips): 90 | step_ids = torch.randperm(num_steps) 91 | step_ids = step_ids[(step_ids != 0) * (step_ids != num_steps - 1)] 92 | step_ids = step_ids[:num_interpolations].sort()[0] 93 | interpolated_data[i, step_ids] = self.slerp(data[i, step_ids - 1], data[i, step_ids + 1], 0.5) 94 | interpolated_data[i, step_ids, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = self.quaternion_slerp( 95 | data[i, step_ids - 1, 
self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 96 | data[i, step_ids + 1, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 97 | 0.5 98 | ) 99 | return interpolated_data 100 | 101 | def _rand_duplication(self, data, level=0): 102 | num_motion_clips, num_steps, reference_full_dim = data.size() 103 | num_duplications = round(num_steps * level) * 10 104 | if num_duplications == 0: 105 | return data 106 | duplicated_data = torch.zeros(num_motion_clips, num_steps + num_duplications, reference_full_dim, dtype=torch.float, device=self.device, requires_grad=False) 107 | step_ids = torch.randint(0, num_steps, (num_motion_clips, num_duplications), device=self.device) 108 | for i in range(num_motion_clips): 109 | duplicated_step_ids = torch.cat((torch.arange(num_steps, device=self.device), step_ids[i])).sort()[0] 110 | duplicated_data[i] = data[i, duplicated_step_ids] 111 | return duplicated_data 112 | 113 | def _rand_noise(self, data, level=0): 114 | noise_scales_dict = { 115 | "base_pos": 0.1, 116 | "base_quat": 0.01, 117 | "base_lin_vel": 0.1, 118 | "base_ang_vel": 0.2, 119 | "projected_gravity": 0.05, 120 | "base_height": 0.1, 121 | "dof_pos": 0.01, 122 | "dof_vel": 1.5 123 | } 124 | noise_scale_vec = torch.zeros_like(data[0, 0], device=self.device, dtype=torch.float, requires_grad=False) 125 | for key, value in self.state_idx_dict.items(): 126 | noise_scale_vec[value[0]:value[1]] = noise_scales_dict[key] * level 127 | data += (2 * torch.randn_like(data) - 1) * noise_scale_vec 128 | return data 129 | 130 | def _get_frame_at_step(self, motion_clip_sample_ids, step_sample): 131 | step_low, step_high = step_sample.floor().long(), step_sample.ceil().long() 132 | blend = (step_sample - step_low).unsqueeze(-1) 133 | frame = self.slerp(self.data[motion_clip_sample_ids, step_low], self.data[motion_clip_sample_ids, step_high], blend) 134 | frame[:, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = self.quaternion_slerp( 135 | self.data[motion_clip_sample_ids, step_low, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 136 | self.data[motion_clip_sample_ids, step_high, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 137 | blend 138 | ) 139 | return frame 140 | 141 | def get_frames(self, num_frames): 142 | ids = torch.randint(0, self.num_preload_transitions, (num_frames,), device=self.device) 143 | return self.preloaded_states[ids, 0] 144 | 145 | def slerp(self, value_low, value_high, blend): 146 | return (1.0 - blend) * value_low + blend * value_high 147 | 148 | def quaternion_slerp(self, quat_low, quat_high, blend): 149 | relative_quat = normalize(quat_mul(quat_high, quat_conjugate(quat_low))) 150 | angle = 2 * torch.acos(relative_quat[:, -1]).unsqueeze(-1) 151 | axis = normalize(relative_quat[:, :3]) 152 | angle_slerp = self.slerp(torch.zeros_like(angle), angle, blend).squeeze(-1) 153 | relative_quat_slerp = quat_from_angle_axis(angle_slerp, axis) 154 | return normalize(quat_mul(relative_quat_slerp, quat_low)) 155 | 156 | def feed_forward_generator(self, num_mini_batch, mini_batch_size): 157 | for _ in range(num_mini_batch): 158 | ids = torch.randint(0, self.num_preload_transitions, (mini_batch_size,), device=self.device) 159 | states = self.preloaded_states[ids, :, self.observation_start_dim:] 160 | yield states 161 | 162 | def get_base_pos(self, frames): 163 | if "base_pos" in self.state_idx_dict: 164 | return frames[:, 
self.state_idx_dict["base_pos"][0]:self.state_idx_dict["base_pos"][1]] 165 | else: 166 | raise Exception("[MotionLoader] base_pos not specified in the state_idx_dict") 167 | 168 | def get_base_quat(self, frames): 169 | if "base_quat" in self.state_idx_dict: 170 | return frames[:, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] 171 | else: 172 | raise Exception("[MotionLoader] base_quat not specified in the state_idx_dict") 173 | 174 | def get_base_lin_vel(self, frames): 175 | if "base_lin_vel" in self.state_idx_dict: 176 | return frames[:, self.state_idx_dict["base_lin_vel"][0]:self.state_idx_dict["base_lin_vel"][1]] 177 | else: 178 | raise Exception("[MotionLoader] base_lin_vel not specified in the state_idx_dict") 179 | 180 | def get_base_ang_vel(self, frames): 181 | if "base_ang_vel" in self.state_idx_dict: 182 | return frames[:, self.state_idx_dict["base_ang_vel"][0]:self.state_idx_dict["base_ang_vel"][1]] 183 | else: 184 | raise Exception("[MotionLoader] base_ang_vel not specified in the state_idx_dict") 185 | 186 | def get_projected_gravity(self, frames): 187 | if "projected_gravity" in self.state_idx_dict: 188 | return frames[:, self.state_idx_dict["projected_gravity"][0]:self.state_idx_dict["projected_gravity"][1]] 189 | else: 190 | raise Exception("[MotionLoader] projected_gravity not specified in the state_idx_dict") 191 | 192 | def get_dof_pos(self, frames): 193 | if "dof_pos" in self.state_idx_dict: 194 | return frames[:, self.state_idx_dict["dof_pos"][0]:self.state_idx_dict["dof_pos"][1]] 195 | else: 196 | raise Exception("[MotionLoader] dof_pos not specified in the state_idx_dict") 197 | 198 | def get_dof_vel(self, frames): 199 | if "dof_vel" in self.state_idx_dict: 200 | return frames[:, self.state_idx_dict["dof_vel"][0]:self.state_idx_dict["dof_vel"][1]] 201 | else: 202 | raise Exception("[MotionLoader] dof_vel not specified in the state_idx_dict") 203 | 204 | def get_feet_pos(self, frames): 205 | if "feet_pos" in self.state_idx_dict: 206 | return frames[:, self.state_idx_dict["feet_pos"][0]:self.state_idx_dict["feet_pos"][1]] 207 | else: 208 | raise Exception("[MotionLoader] feet_pos not specified in the state_idx_dict") 209 | -------------------------------------------------------------------------------- /learning/env/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | """Submodule defining the environment definitions.""" 4 | 5 | from .vec_env import VecEnv 6 | 7 | __all__ = ["VecEnv"] 8 | -------------------------------------------------------------------------------- /learning/env/vec_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # python 5 | from abc import ABC, abstractmethod 6 | from typing import Tuple, Union 7 | 8 | # torch 9 | import torch 10 | 11 | 12 | # minimal interface of the environment 13 | class VecEnv(ABC): 14 | """Abstract class for vectorized environment.""" 15 | 16 | num_envs: int 17 | num_obs: int 18 | num_privileged_obs: int 19 | num_actions: int 20 | max_episode_length: int 21 | privileged_obs_buf: torch.Tensor 22 | obs_buf: torch.Tensor 23 | rew_buf: torch.Tensor 24 | reset_buf: torch.Tensor 25 | episode_length_buf: torch.Tensor # current episode duration 26 | extras: dict 27 | device: torch.device 28 | 29 | """ 30 | Properties 31 | """ 32 
| 33 | @abstractmethod 34 | def get_observations(self) -> torch.Tensor: 35 | pass 36 | 37 | @abstractmethod 38 | def get_privileged_observations(self) -> Union[torch.Tensor, None]: 39 | pass 40 | 41 | """ 42 | Operations. 43 | """ 44 | 45 | @abstractmethod 46 | def step( 47 | self, actions: torch.Tensor 48 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: 49 | """Apply input action on the environment. 50 | 51 | Args: 52 | actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) 53 | 54 | Returns: 55 | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]: 56 | A tuple containing the observations, privileged observations, rewards, dones and 57 | extra information (metrics). 58 | """ 59 | raise NotImplementedError 60 | 61 | @abstractmethod 62 | def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 63 | """Reset all environment instances. 64 | 65 | Returns: 66 | Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations. 67 | """ 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /learning/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Definitions for neural-network components for RL-agents.""" 5 | 6 | from .actor_critic import ActorCritic 7 | from .actor_critic_recurrent import ActorCriticRecurrent 8 | from .normalizer import Normalizer 9 | 10 | __all__ = ["ActorCritic", "ActorCriticRecurrent", "Normalizer"] 11 | -------------------------------------------------------------------------------- /learning/modules/actor_critic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | from torch.distributions import Normal 8 | from learning.modules.normalizer import EmpiricalNormalization 9 | 10 | 11 | class ActorCritic(nn.Module): 12 | is_recurrent = False 13 | 14 | def __init__( 15 | self, 16 | num_actor_obs, 17 | num_critic_obs, 18 | num_actions, 19 | actor_hidden_dims=[256, 256, 256], 20 | critic_hidden_dims=[256, 256, 256], 21 | activation="elu", 22 | init_noise_std=1.0, 23 | update_obs_norm=True, 24 | **kwargs, 25 | ): 26 | if kwargs: 27 | print( 28 | "ActorCritic.__init__ got unexpected arguments, which will be ignored: " 29 | + str([key for key in kwargs.keys()]) 30 | ) 31 | super(ActorCritic, self).__init__() 32 | activation = get_activation(activation) 33 | 34 | mlp_input_dim_a = num_actor_obs 35 | mlp_input_dim_c = num_critic_obs 36 | 37 | # Policy 38 | actor_layers = [] 39 | actor_layers.append( 40 | EmpiricalNormalization(shape=[mlp_input_dim_a], update_obs_norm=update_obs_norm, until=1.0e8) 41 | ) 42 | actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) 43 | actor_layers.append(activation) 44 | for layer_index in range(len(actor_hidden_dims)): 45 | if layer_index == len(actor_hidden_dims) - 1: 46 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions)) 47 | else: 48 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1])) 49 | actor_layers.append(activation) 50 | self.actor = nn.Sequential(*actor_layers) 51 | 52 | # Value function 53 | critic_layers = [] 54 
| critic_layers.append( 55 | EmpiricalNormalization(shape=[mlp_input_dim_c], update_obs_norm=update_obs_norm, until=1.0e8) 56 | ) 57 | critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) 58 | critic_layers.append(activation) 59 | for layer_index in range(len(critic_hidden_dims)): 60 | if layer_index == len(critic_hidden_dims) - 1: 61 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1)) 62 | else: 63 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1])) 64 | critic_layers.append(activation) 65 | self.critic = nn.Sequential(*critic_layers) 66 | 67 | print(f"Actor MLP: {self.actor}") 68 | print(f"Critic MLP: {self.critic}") 69 | 70 | # Action noise 71 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 72 | self.distribution = None 73 | # disable args validation for speedup 74 | Normal.set_default_validate_args = False 75 | 76 | # seems that we get better performance without init 77 | # self.init_memory_weights(self.memory_a, 0.001, 0.) 78 | # self.init_memory_weights(self.memory_c, 0.001, 0.) 79 | 80 | @staticmethod 81 | # not used at the moment 82 | def init_weights(sequential, scales): 83 | [ 84 | torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) 85 | for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear)) 86 | ] 87 | 88 | def reset(self, dones=None): 89 | pass 90 | 91 | def forward(self): 92 | raise NotImplementedError 93 | 94 | @property 95 | def action_mean(self): 96 | return self.distribution.mean 97 | 98 | @property 99 | def action_std(self): 100 | return self.distribution.stddev 101 | 102 | @property 103 | def entropy(self): 104 | return self.distribution.entropy().sum(dim=-1) 105 | 106 | def update_distribution(self, observations): 107 | mean = self.actor(observations) 108 | self.distribution = Normal(mean, mean * 0.0 + self.std) 109 | 110 | def act(self, observations, **kwargs): 111 | self.update_distribution(observations) 112 | return self.distribution.sample() 113 | 114 | def get_actions_log_prob(self, actions): 115 | return self.distribution.log_prob(actions).sum(dim=-1) 116 | 117 | def act_inference(self, observations): 118 | actions_mean = self.actor(observations) 119 | return actions_mean 120 | 121 | def evaluate(self, critic_observations, **kwargs): 122 | value = self.critic(critic_observations) 123 | return value 124 | 125 | 126 | def get_activation(act_name): 127 | if act_name == "elu": 128 | return nn.ELU() 129 | elif act_name == "selu": 130 | return nn.SELU() 131 | elif act_name == "relu": 132 | return nn.ReLU() 133 | elif act_name == "crelu": 134 | return nn.ReLU() 135 | elif act_name == "lrelu": 136 | return nn.LeakyReLU() 137 | elif act_name == "tanh": 138 | return nn.Tanh() 139 | elif act_name == "sigmoid": 140 | return nn.Sigmoid() 141 | else: 142 | print("invalid activation function!") 143 | return None 144 | -------------------------------------------------------------------------------- /learning/modules/actor_critic_recurrent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | 8 | # learning 9 | from learning.modules.actor_critic import ActorCritic, get_activation 10 | from learning.utils import unpad_trajectories 11 | 12 | 13 | class ActorCriticRecurrent(ActorCritic): 14 | is_recurrent = True 15 | 16 | def __init__( 17 | self, 18 | 
num_actor_obs, 19 | num_critic_obs, 20 | num_actions, 21 | actor_hidden_dims=[256, 256, 256], 22 | critic_hidden_dims=[256, 256, 256], 23 | activation="elu", 24 | rnn_type="lstm", 25 | rnn_hidden_size=256, 26 | rnn_num_layers=1, 27 | init_noise_std=1.0, 28 | **kwargs, 29 | ): 30 | if kwargs: 31 | print( 32 | "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()), 33 | ) 34 | 35 | super().__init__( 36 | num_actor_obs=rnn_hidden_size, 37 | num_critic_obs=rnn_hidden_size, 38 | num_actions=num_actions, 39 | actor_hidden_dims=actor_hidden_dims, 40 | critic_hidden_dims=critic_hidden_dims, 41 | activation=activation, 42 | init_noise_std=init_noise_std, 43 | ) 44 | 45 | activation = get_activation(activation) 46 | 47 | self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 48 | self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 49 | 50 | print(f"Actor RNN: {self.memory_a}") 51 | print(f"Critic RNN: {self.memory_c}") 52 | 53 | def reset(self, dones=None): 54 | self.memory_a.reset(dones) 55 | self.memory_c.reset(dones) 56 | 57 | def act(self, observations, masks=None, hidden_states=None): 58 | input_a = self.memory_a(observations, masks, hidden_states) 59 | return super().act(input_a.squeeze(0)) 60 | 61 | def act_inference(self, observations): 62 | input_a = self.memory_a(observations) 63 | return super().act_inference(input_a.squeeze(0)) 64 | 65 | def evaluate(self, critic_observations, masks=None, hidden_states=None): 66 | input_c = self.memory_c(critic_observations, masks, hidden_states) 67 | return super().evaluate(input_c.squeeze(0)) 68 | 69 | def get_hidden_states(self): 70 | return self.memory_a.hidden_states, self.memory_c.hidden_states 71 | 72 | 73 | class Memory(torch.nn.Module): 74 | def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256): 75 | super().__init__() 76 | # RNN 77 | rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM 78 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) 79 | self.hidden_states = None 80 | 81 | def forward(self, input, masks=None, hidden_states=None): 82 | batch_mode = masks is not None 83 | if batch_mode: 84 | # batch mode (policy update): need saved hidden states 85 | if hidden_states is None: 86 | raise ValueError("Hidden states not passed to memory module during policy update") 87 | out, _ = self.rnn(input, hidden_states) 88 | out = unpad_trajectories(out, masks) 89 | else: 90 | # inference mode (collection): use hidden states of last step 91 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) 92 | return out 93 | 94 | def reset(self, dones=None): 95 | # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state 96 | for hidden_state in self.hidden_states: 97 | hidden_state[..., dones, :] = 0.0 98 | -------------------------------------------------------------------------------- /learning/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data 4 | from torch import autograd 5 | 6 | 7 | class Discriminator(nn.Module): 8 | def __init__(self, 9 | observation_dim, 10 | observation_horizon, 11 | device, 12 | state_normalizer=None, 13 | reward_normalizer=None, 14 | shape=[1024, 512], 15 | style_reward_function="quad_mapping", 16 | **kwargs, 17 | ): 18 | if kwargs: 19 | 
print("Discriminator.__init__ got unexpected arguments, which will be ignored: " 20 | + str([key for key in kwargs.keys()])) 21 | super(Discriminator, self).__init__() 22 | self.observation_dim = observation_dim 23 | self.observation_horizon = observation_horizon 24 | self.input_dim = observation_dim * observation_horizon 25 | self.device = device 26 | self.state_normalizer = state_normalizer 27 | self.reward_normalizer = reward_normalizer 28 | self.style_reward_function = style_reward_function 29 | self.shape = shape 30 | 31 | discriminator_layers = [] 32 | curr_in_dim = self.input_dim 33 | for hidden_dim in self.shape: 34 | discriminator_layers.append(nn.Linear(curr_in_dim, hidden_dim)) 35 | discriminator_layers.append(nn.ReLU()) 36 | curr_in_dim = hidden_dim 37 | discriminator_layers.append(nn.Linear(self.shape[-1], 1)) 38 | self.architecture = nn.Sequential(*discriminator_layers).to(self.device) 39 | self.architecture.train() 40 | 41 | def forward(self, x): 42 | return self.architecture(x) 43 | 44 | def compute_grad_pen(self, expert_state_buf, lambda_=10): 45 | expert_data = expert_state_buf.flatten(1, 2) 46 | expert_data.requires_grad = True 47 | 48 | disc = self.architecture(expert_data) 49 | ones = torch.ones(disc.size(), device=disc.device) 50 | grad = autograd.grad( 51 | outputs=disc, inputs=expert_data, 52 | grad_outputs=ones, create_graph=True, 53 | retain_graph=True, only_inputs=True)[0] 54 | 55 | # Enforce that the grad norm approaches 0. 56 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 0).pow(2).mean() 57 | return grad_pen 58 | 59 | def predict_cassi_reward(self, state_buf): 60 | with torch.no_grad(): 61 | self.eval() 62 | if self.state_normalizer is not None: 63 | for i in range(self.observation_horizon): 64 | state_buf[:, i] = self.state_normalizer.normalize(state_buf[:, i].clone()) 65 | d = self.architecture(state_buf.flatten(1, 2)) 66 | if self.style_reward_function == "quad_mapping": 67 | style_reward = torch.clamp(1 - (1/4) * torch.square(d - 1), min=0) 68 | elif self.style_reward_function == "log_mapping": 69 | style_reward = -torch.log(torch.maximum(1 - 1 / (1 + torch.exp(-d)), torch.tensor(0.0001, device=self.device))) 70 | elif self.style_reward_function == "wasserstein_mapping": 71 | if self.reward_normalizer is not None: 72 | style_reward = self.reward_normalizer.normalize(d.clone()) 73 | self.reward_normalizer.update(d) 74 | else: 75 | style_reward = d 76 | else: 77 | raise ValueError("Unexpected style reward mapping specified") 78 | self.train() 79 | return style_reward.squeeze() 80 | -------------------------------------------------------------------------------- /learning/modules/discriminator_ensemble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DiscriminatorEnsemble(nn.Module): 6 | def __init__(self, 7 | observation_dim, 8 | observation_horizon, 9 | num_classes, 10 | device, 11 | shape=[1024, 512], 12 | ensemble_size=5, 13 | eps=1e-7, 14 | incremental_input=False, 15 | ): 16 | super(DiscriminatorEnsemble, self).__init__() 17 | self.observation_dim = observation_dim 18 | self.observation_horizon = observation_horizon 19 | self.input_dim = observation_dim * observation_horizon 20 | self.device = device 21 | self.num_classes = num_classes 22 | self.shape = shape 23 | self.ensemble_size = ensemble_size 24 | self.eps = eps 25 | self.incremental_input = incremental_input 26 | 27 | self.ensemble = nn.ModuleList( 28 | [ 29 | Discriminator( 30 | observation_dim, 31 | 
observation_horizon, 32 | num_classes, 33 | device, 34 | shape, 35 | eps, 36 | incremental_input, 37 | ) 38 | for _ in range(ensemble_size) 39 | ] 40 | ) 41 | 42 | def forward(self, x): 43 | # bootstrapping 44 | indices = [] 45 | out = [] 46 | for discriminator in self.ensemble: 47 | idx = torch.randint(0, x.size(0), (x.size(0),), device=self.device) 48 | indices.append(idx) 49 | out.append(discriminator.architecture(x[idx])) 50 | return out, indices 51 | 52 | def compute_dis_skill_reward(self, observation_buf, style_selector): 53 | with torch.no_grad(): 54 | self.eval() 55 | logp_ensemble = [] 56 | for discriminator in self.ensemble: 57 | discriminator.eval() 58 | logp = discriminator.predict_logp(observation_buf, style_selector) 59 | logp_ensemble.append(logp) 60 | discriminator.train() 61 | logp_avg = torch.log(torch.exp(torch.cat(logp_ensemble, dim=1)).mean(dim=1)) 62 | skill_reward = logp_avg - torch.log(torch.tensor(1 / self.num_classes, device=self.device)) 63 | self.train() 64 | return torch.clip(skill_reward, min=0.0) 65 | 66 | def compute_dis_disdain_reward(self, observation_buf): 67 | with torch.no_grad(): 68 | self.eval() 69 | entropy_ensemble = [] 70 | probs_ensemble = [] 71 | for discriminator in self.ensemble: 72 | discriminator.eval() 73 | entropy_ensemble.append(discriminator.predict_entropy(observation_buf).unsqueeze(1)) 74 | logits = discriminator.predict_logits(observation_buf).unsqueeze(1) 75 | probs = nn.functional.softmax(logits, dim=2) 76 | probs_ensemble.append(probs) 77 | discriminator.train() 78 | # mean of the entropy 79 | entropy_avg = torch.cat(entropy_ensemble, dim=1).mean(dim=1) 80 | # entropy of the mean 81 | probs_avg = torch.cat(probs_ensemble, dim=1).mean(dim=1) 82 | entropy = (-probs_avg * torch.log(probs_avg + self.eps)).sum(dim=1) 83 | # DISDAIN reward 84 | disdain_reward = entropy - entropy_avg 85 | self.train() 86 | return disdain_reward 87 | 88 | class Discriminator(nn.Module): 89 | def __init__(self, 90 | observation_dim, 91 | observation_horizon, 92 | num_classes, 93 | device, 94 | shape=[1024, 512], 95 | eps=1e-7, 96 | incremental_input=False, 97 | ): 98 | super(Discriminator, self).__init__() 99 | self.observation_dim = observation_dim 100 | self.observation_horizon = observation_horizon 101 | self.input_dim = observation_dim * observation_horizon 102 | self.device = device 103 | self.num_classes = num_classes 104 | self.shape = shape 105 | self.eps = eps 106 | self.incremental_input = incremental_input 107 | 108 | discriminator_layers = [] 109 | curr_in_dim = self.input_dim 110 | discriminator_layers.append(nn.BatchNorm1d(curr_in_dim)) 111 | for hidden_dim in self.shape: 112 | discriminator_layers.append(nn.Linear(curr_in_dim, hidden_dim)) 113 | discriminator_layers.append(nn.ReLU()) 114 | curr_in_dim = hidden_dim 115 | discriminator_layers.append(nn.Linear(self.shape[-1], self.num_classes)) 116 | self.architecture = nn.Sequential(*discriminator_layers).to(self.device) 117 | self.architecture.train() 118 | 119 | def forward(self, x): 120 | return self.architecture(x) 121 | 122 | def predict_logits(self, observation_buf): 123 | with torch.no_grad(): 124 | self.eval() 125 | if self.incremental_input: 126 | logits = self.architecture((observation_buf - observation_buf[:, :1, :]).flatten(1, 2)) 127 | else: 128 | logits = self.architecture((observation_buf).flatten(1, 2)) 129 | self.train() 130 | return logits 131 | 132 | def predict_logp(self, observation_buf, style_selector): 133 | with torch.no_grad(): 134 | self.eval() 135 | logits = 
self.predict_logits(observation_buf) 136 | logp = torch.gather(nn.functional.log_softmax(logits, dim=1), dim=1, index=style_selector.unsqueeze(-1)) 137 | self.train() 138 | return logp 139 | 140 | def predict_entropy(self, observation_buf): 141 | with torch.no_grad(): 142 | self.eval() 143 | logits = self.predict_logits(observation_buf) 144 | probs = nn.functional.softmax(logits, dim=1) 145 | entropy = (-probs * torch.log(probs + self.eps)).sum(dim=1) 146 | self.train() 147 | return entropy 148 | -------------------------------------------------------------------------------- /learning/modules/normalizer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Preferred Networks, Inc. 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | import numpy as np 23 | 24 | import torch 25 | from torch import nn 26 | from typing import Tuple 27 | 28 | 29 | class EmpiricalNormalization(nn.Module): 30 | """Normalize mean and variance of values based on empirical values. 31 | Args: 32 | shape (int or tuple of int): Shape of input values except batch axis. 33 | batch_axis (int): Batch axis. 34 | eps (float): Small value for stability. 35 | dtype (dtype): Dtype of input values. 36 | until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes 37 | exceeds it. 
38 | update_obs_norm (bool): If true, learns and updates mean and variance 39 | """ 40 | 41 | def __init__( 42 | self, 43 | shape, 44 | batch_axis=0, 45 | eps=1e-2, 46 | dtype=np.float32, 47 | until=None, 48 | clip_threshold=None, 49 | update_obs_norm=True, 50 | ): 51 | super(EmpiricalNormalization, self).__init__() 52 | dtype = np.dtype(dtype) 53 | self.batch_axis = batch_axis 54 | self.eps = eps 55 | self.until = until 56 | self.clip_threshold = clip_threshold 57 | self.register_buffer( 58 | "_mean", 59 | torch.tensor(np.expand_dims(np.zeros(shape, dtype=dtype), batch_axis)), 60 | ) 61 | self.register_buffer( 62 | "_var", 63 | torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis)), 64 | ) 65 | self.register_buffer("count", torch.tensor(0)) 66 | self.in_features = shape[0] 67 | 68 | # cache 69 | self._cached_std_inverse = torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis)) 70 | self._is_std_cached = False 71 | self._is_training = update_obs_norm 72 | 73 | @property 74 | def mean(self): 75 | return torch.squeeze(self._mean, self.batch_axis).clone() 76 | 77 | @property 78 | def std(self): 79 | return torch.sqrt(torch.squeeze(self._var, self.batch_axis)).clone() 80 | 81 | @property 82 | def _std_inverse(self): 83 | if self._is_std_cached is False: 84 | self._cached_std_inverse = (self._var + self.eps) ** -0.5 85 | 86 | return self._cached_std_inverse 87 | 88 | @torch.jit.unused 89 | @torch.no_grad() 90 | def experience(self, x): 91 | """Learn input values without computing the output values of them""" 92 | 93 | if self.until is not None: 94 | if self.count >= self.until: 95 | return 96 | 97 | count_x = x.shape[self.batch_axis] 98 | if count_x == 0: 99 | return 100 | 101 | self.count += count_x 102 | rate = count_x / self.count.float() 103 | assert rate > 0 104 | assert rate <= 1 105 | 106 | var_x = torch.var(x, dim=self.batch_axis, unbiased=False, keepdim=True) 107 | mean_x = torch.mean(x, dim=self.batch_axis, keepdim=True) 108 | delta_mean = mean_x - self._mean 109 | self._mean += rate * delta_mean 110 | self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean)) 111 | 112 | # clear cache 113 | self._is_std_cached = False 114 | 115 | def forward(self, x): 116 | """Normalize mean and variance of values based on empirical values.
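# Sanity-check sketch for the incremental update in experience() above: feeding batches one by one
# reproduces the mean/variance of the concatenated data up to float32 round-off. Assumes the
# `learning` package of this repository (and torch) is importable; sizes are arbitrary.
import torch
from learning.modules.normalizer import EmpiricalNormalization

norm = EmpiricalNormalization(shape=(3,))
batch_a = torch.randn(64, 3)
batch_b = 2.0 + 0.5 * torch.randn(32, 3)
norm.experience(batch_a)
norm.experience(batch_b)
full = torch.cat([batch_a, batch_b], dim=0)
assert torch.allclose(norm.mean, full.mean(dim=0), atol=1e-5)
assert torch.allclose(norm.std, full.var(dim=0, unbiased=False).sqrt(), atol=1e-4)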
117 | Args: 118 | x (ndarray or Variable): Input values 119 | Returns: 120 | ndarray or Variable: Normalized output values 121 | """ 122 | 123 | if self._is_training: 124 | self.experience(x) 125 | 126 | if not x.is_cuda: 127 | self._is_std_cached = False 128 | normalized = (x - self._mean) * self._std_inverse 129 | if self.clip_threshold is not None: 130 | normalized = torch.clamp(normalized, -self.clip_threshold, self.clip_threshold) 131 | if not x.is_cuda: 132 | self._is_std_cached = False 133 | return normalized 134 | 135 | @torch.jit.unused 136 | def inverse(self, y): 137 | std = torch.sqrt(self._var + self.eps) 138 | return y * std + self._mean 139 | 140 | def load_numpy(self, mean, var, count, device="cpu"): 141 | self._mean = torch.from_numpy(np.expand_dims(mean, self.batch_axis)).to(device) 142 | self._var = torch.from_numpy(np.expand_dims(var, self.batch_axis)).to(device) 143 | self.count = torch.tensor(count).to(device) 144 | 145 | class Normalizer: 146 | def __init__(self, input_dim, device, epsilon=1e-2, clip=10.0): 147 | self.device = device 148 | self.mean = torch.zeros(input_dim, device=self.device) 149 | self.var = torch.ones(input_dim, device=self.device) 150 | self.count = epsilon 151 | self.epsilon = epsilon 152 | self.clip = clip 153 | 154 | def normalize(self, data): 155 | mean_ = self.mean 156 | std_ = torch.sqrt(self.var + self.epsilon) 157 | return torch.clamp((data - mean_) / std_, -self.clip, self.clip) 158 | 159 | def update(self, data): 160 | batch_mean = torch.mean(data, dim=0) 161 | batch_var = torch.var(data, dim=0) 162 | batch_count = data.shape[0] 163 | self.update_from_moments(batch_mean, batch_var, batch_count) 164 | 165 | def update_from_moments(self, batch_mean, batch_var, batch_count): 166 | delta = batch_mean - self.mean 167 | tot_count = self.count + batch_count 168 | 169 | new_mean = self.mean + delta * batch_count / tot_count 170 | new_var = (self.var * self.count + 171 | batch_var * batch_count + 172 | torch.square(delta) * self.count * batch_count / tot_count) / tot_count 173 | self.mean = new_mean 174 | self.var = new_var 175 | self.count = tot_count 176 | -------------------------------------------------------------------------------- /learning/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of runners for environment-agent interaction.""" 5 | 6 | from .cassi_on_policy_runner import CASSIOnPolicyRunner 7 | 8 | __all__ = ["CASSIOnPolicyRunner"] 9 | -------------------------------------------------------------------------------- /learning/runners/cassi_on_policy_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # python 5 | import os 6 | import time 7 | import statistics 8 | from collections import deque 9 | from typing import Union 10 | 11 | # torch 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | # learning 16 | import learning 17 | from learning.algorithms import CASSI 18 | from learning.modules import ActorCritic, ActorCriticRecurrent 19 | from learning.modules.discriminator import Discriminator 20 | from learning.modules.normalizer import Normalizer 21 | from learning.modules.discriminator_ensemble import DiscriminatorEnsemble 22 | from learning.env import VecEnv 23 | 24 | 25 | class 
CASSIOnPolicyRunner: 26 | def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): 27 | 28 | self.cfg = train_cfg["runner"] 29 | self.alg_cfg = train_cfg["algorithm"] 30 | self.policy_cfg = train_cfg["policy"] 31 | self.discriminator_cfg = train_cfg["discriminator"] 32 | self.discriminator_ensemble_cfg = train_cfg["discriminator_ensemble"] 33 | 34 | self.device = device 35 | self.env = env 36 | if self.env.num_privileged_obs is not None: 37 | num_critic_obs = self.env.num_privileged_obs 38 | else: 39 | num_critic_obs = self.env.num_obs 40 | actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic 41 | actor_critic: Union[ActorCritic, ActorCriticRecurrent] = actor_critic_class( 42 | self.env.num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg 43 | ).to(self.device) 44 | alg_class = eval(self.cfg["algorithm_class_name"]) # CASSI 45 | cassi_expert_data = self.env.motion_loader 46 | cassi_state_normalizer = Normalizer(cassi_expert_data.observation_dim, self.device) 47 | if self.cfg["normalize_style_reward"]: 48 | cassi_style_reward_normalizer = Normalizer(1, self.device) 49 | else: 50 | cassi_style_reward_normalizer = None 51 | discriminator = Discriminator( 52 | observation_dim=cassi_expert_data.observation_dim, 53 | observation_horizon=self.env.reference_observation_horizon, 54 | device=self.device, 55 | state_normalizer=cassi_state_normalizer, 56 | reward_normalizer=cassi_style_reward_normalizer, 57 | **self.discriminator_cfg).to(self.device) 58 | discriminator_ensemble = DiscriminatorEnsemble( 59 | observation_dim=self.env.dis_observation_dim, 60 | observation_horizon=self.env.dis_observation_horizon, 61 | num_classes=self.env.dis_num_classes, 62 | device=self.device, 63 | **self.discriminator_ensemble_cfg).to(self.device) 64 | 65 | self.alg: CASSI = alg_class(actor_critic, discriminator, cassi_expert_data, discriminator_ensemble, device=self.device, **self.alg_cfg) 66 | self.env.discriminator = discriminator 67 | self.env.discriminator_ensemble = discriminator_ensemble 68 | self.num_steps_per_env = self.cfg["num_steps_per_env"] 69 | self.save_interval = self.cfg["save_interval"] 70 | 71 | # init storage and model 72 | self.alg.init_storage( 73 | self.env.num_envs, 74 | self.num_steps_per_env, 75 | [self.env.num_obs], 76 | [self.env.num_privileged_obs], 77 | [self.env.num_actions], 78 | ) 79 | 80 | # Log 81 | self.log_dir = log_dir 82 | self.writer = None 83 | self.tot_timesteps = 0 84 | self.tot_time = 0 85 | self.current_learning_iteration = 0 86 | self.git_status_repos = [learning.__file__] 87 | 88 | _, _ = self.env.reset() 89 | 90 | def learn(self, num_learning_iterations, init_at_random_ep_len=False): 91 | # initialize writer 92 | if self.log_dir is not None and self.writer is None: 93 | self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) 94 | if init_at_random_ep_len: 95 | self.env.episode_length_buf = torch.randint_like( 96 | self.env.episode_length_buf, high=int(self.env.max_episode_length) 97 | ) 98 | obs = self.env.get_observations() 99 | privileged_obs = self.env.get_privileged_observations() 100 | cassi_observation_buf = self.env.get_cassi_observation_buf() 101 | dis_observation_buf = self.env.get_dis_observation_buf() 102 | critic_obs = privileged_obs if privileged_obs is not None else obs 103 | obs, critic_obs, cassi_observation_buf, dis_observation_buf = obs.to(self.device), critic_obs.to(self.device), cassi_observation_buf.to(self.device), dis_observation_buf.to(self.device) 104 | self.alg.actor_critic.train() # switch 
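# Sketch (illustrative values, not the shipped configuration) of the nested train_cfg layout this
# constructor reads. In practice the dictionary is produced from Solo8FlatCfgPPO by the task
# registry, and each sub-dict is forwarded as keyword arguments to the corresponding constructor.
train_cfg_example = {
    "runner": {
        "policy_class_name": "ActorCritic",
        "algorithm_class_name": "CASSI",
        "normalize_style_reward": True,   # assumed here; the real value comes from the train config
        "num_steps_per_env": 24,
        "save_interval": 50,
    },
    "policy": {},                  # kwargs for ActorCritic / ActorCriticRecurrent
    "algorithm": {},               # kwargs for CASSI
    "discriminator": {},           # kwargs for Discriminator
    "discriminator_ensemble": {},  # kwargs for DiscriminatorEnsemble
}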
to train mode (for dropout for example) 105 | self.alg.discriminator.train() 106 | self.alg.discriminator_ensemble.train() 107 | 108 | ep_infos = [] 109 | rewbuffer = deque(maxlen=100) 110 | lenbuffer = deque(maxlen=100) 111 | cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 112 | cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) 113 | 114 | tot_iter = self.current_learning_iteration + num_learning_iterations 115 | for it in range(self.current_learning_iteration, tot_iter): 116 | start = time.time() 117 | # Rollout 118 | with torch.inference_mode(): 119 | for i in range(self.num_steps_per_env): 120 | actions = self.alg.act(obs, critic_obs, cassi_observation_buf, dis_observation_buf) 121 | obs, privileged_obs, rewards, dones, infos = self.env.step(actions) 122 | next_cassi_obs = self.env.get_cassi_observations() 123 | next_dis_obs = self.env.get_dis_observations() 124 | critic_obs = privileged_obs if privileged_obs is not None else obs 125 | obs, critic_obs, next_cassi_obs, next_dis_obs, rewards, dones = ( 126 | obs.to(self.device), 127 | critic_obs.to(self.device), 128 | next_cassi_obs.to(self.device), 129 | next_dis_obs.to(self.device), 130 | rewards.to(self.device), 131 | dones.to(self.device), 132 | ) 133 | cassi_observation_buf[:, :-1] = cassi_observation_buf[:, 1:].clone() 134 | cassi_observation_buf[:, -1] = next_cassi_obs.clone() 135 | dis_observation_buf[:, :-1] = dis_observation_buf[:, 1:].clone() 136 | dis_observation_buf[:, -1] = next_dis_obs.clone() 137 | style_selector = self.env.get_style_selector() 138 | self.alg.process_env_step(rewards, dones, infos, next_cassi_obs, next_dis_obs, style_selector) 139 | 140 | if self.log_dir is not None: 141 | # Book keeping 142 | if "episode" in infos: 143 | ep_infos.append(infos["episode"]) 144 | cur_reward_sum += rewards 145 | cur_episode_length += 1 146 | new_ids = (dones > 0).nonzero(as_tuple=False) 147 | rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) 148 | lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) 149 | cur_reward_sum[new_ids] = 0 150 | cur_episode_length[new_ids] = 0 151 | 152 | stop = time.time() 153 | collection_time = stop - start 154 | 155 | # Learning step 156 | start = stop 157 | self.alg.compute_returns(critic_obs) 158 | 159 | mean_value_loss, mean_surrogate_loss, mean_cassi_loss, mean_grad_pen_loss, mean_policy_pred, mean_expert_pred, mean_discriminator_ensemble_loss, discriminator_ensemble_accuracy = self.alg.update() 160 | 161 | stop = time.time() 162 | learn_time = stop - start 163 | if self.log_dir is not None: 164 | self.log(locals()) 165 | if it % self.save_interval == 0: 166 | self.save(os.path.join(self.log_dir, "model_{}.pt".format(it))) 167 | ep_infos.clear() 168 | 169 | self.current_learning_iteration += num_learning_iterations 170 | self.save(os.path.join(self.log_dir, "model_{}.pt".format(self.current_learning_iteration))) 171 | 172 | def log(self, locs, width=80, pad=35): 173 | self.tot_timesteps += self.num_steps_per_env * self.env.num_envs 174 | self.tot_time += locs["collection_time"] + locs["learn_time"] 175 | iteration_time = locs["collection_time"] + locs["learn_time"] 176 | 177 | ep_string = "" 178 | if locs["ep_infos"]: 179 | for key in locs["ep_infos"][0]: 180 | infotensor = torch.tensor([], device=self.device) 181 | for ep_info in locs["ep_infos"]: 182 | # handle scalar and zero dimensional tensor infos 183 | if not isinstance(ep_info[key], torch.Tensor): 184 | 
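# Minimal sketch (toy sizes, torch only) of the history-buffer update performed in the rollout
# above: the oldest frame is dropped and the newest observation appended along the horizon axis.
import torch

num_envs, horizon, obs_dim = 2, 3, 4
buf = torch.zeros(num_envs, horizon, obs_dim)
for t in range(5):
    next_obs = torch.full((num_envs, obs_dim), float(t))
    buf[:, :-1] = buf[:, 1:].clone()
    buf[:, -1] = next_obs
print(buf[0, :, 0])  # tensor([2., 3., 4.]) -> the three most recent frames, oldest first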
ep_info[key] = torch.Tensor([ep_info[key]]) 185 | if len(ep_info[key].shape) == 0: 186 | ep_info[key] = ep_info[key].unsqueeze(0) 187 | infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) 188 | value = torch.mean(infotensor) 189 | self.writer.add_scalar("Episode/" + key, value, locs["it"]) 190 | ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" 191 | mean_std = self.alg.actor_critic.std.mean() 192 | fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"])) 193 | 194 | self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"]) 195 | self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"]) 196 | self.writer.add_scalar("Loss/learning_rate", self.alg.policy_learning_rate, locs["it"]) 197 | self.writer.add_scalar("Loss/CASSI", locs["mean_cassi_loss"], locs["it"]) 198 | self.writer.add_scalar("Loss/cassi_grad", locs["mean_grad_pen_loss"], locs["it"]) 199 | self.writer.add_scalar("Discriminator/policy_pred", locs["mean_policy_pred"], locs["it"]) 200 | self.writer.add_scalar("Discriminator/expert_pred", locs["mean_expert_pred"], locs["it"]) 201 | for i in range(self.alg.discriminator_ensemble_ensemble_size): 202 | self.writer.add_scalar(f"Discriminator_ensemble/loss/component_{i}", locs["mean_discriminator_ensemble_loss"][i], locs["it"]) 203 | self.writer.add_scalar(f"Discriminator_ensemble/accuracy/component_{i}", locs["discriminator_ensemble_accuracy"][i], locs["it"]) 204 | self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"]) 205 | self.writer.add_scalar("Perf/total_fps", fps, locs["it"]) 206 | self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"]) 207 | self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"]) 208 | if len(locs["rewbuffer"]) > 0: 209 | self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"]) 210 | self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"]) 211 | self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time) 212 | self.writer.add_scalar("Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time) 213 | 214 | str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m " 215 | 216 | if len(locs["rewbuffer"]) > 0: 217 | log_string = ( 218 | f"""{'#' * width}\n""" 219 | f"""{str.center(width, ' ')}\n\n""" 220 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 221 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 222 | f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" 223 | f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" 224 | f"""{'CASSI loss:':>{pad}} {locs['mean_cassi_loss']:.4f}\n""" 225 | f"""{'CASSI grad pen loss:':>{pad}} {locs['mean_grad_pen_loss']:.4f}\n""" 226 | f"""{'CASSI mean policy pred:':>{pad}} {locs['mean_policy_pred']:.4f}\n""" 227 | f"""{'CASSI mean expert pred:':>{pad}} {locs['mean_expert_pred']:.4f}\n""" 228 | f"""{'Discriminator ensemble loss:':>{pad}} {locs['mean_discriminator_ensemble_loss'].mean():.4f}\n""" 229 | f"""{'Discriminator ensemble accuracy:':>{pad}} {locs['discriminator_ensemble_accuracy'].mean():.4f}\n""" 230 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" 231 | f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" 232 | 
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" 233 | ) 234 | # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" 235 | # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") 236 | else: 237 | log_string = ( 238 | f"""{'#' * width}\n""" 239 | f"""{str.center(width, ' ')}\n\n""" 240 | f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ 241 | 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" 242 | f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" 243 | f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" 244 | f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" 245 | ) 246 | # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" 247 | # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") 248 | 249 | log_string += ep_string 250 | log_string += ( 251 | f"""{'-' * width}\n""" 252 | f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" 253 | f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" 254 | f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n""" 255 | f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * ( 256 | locs['num_learning_iterations'] - locs['it']):.1f}s\n""" 257 | ) 258 | print(log_string) 259 | 260 | def save(self, path, infos=None): 261 | torch.save( 262 | { 263 | "model_state_dict": self.alg.actor_critic.state_dict(), 264 | "discriminator_state_dict": self.alg.discriminator.state_dict(), 265 | "discriminator_ensemble_state_dict": self.alg.discriminator_ensemble.state_dict(), 266 | "policy_optimizer_state_dict": self.alg.policy_optimizer.state_dict(), 267 | "discriminator_optimizer_state_dict": self.alg.discriminator_optimizer.state_dict(), 268 | "cassi_state_normalizer": self.alg.cassi_state_normalizer, 269 | "cassi_style_reward_normalizer": self.alg.cassi_style_reward_normalizer, 270 | "discriminator_ensemble_optimizer_state_dict": [discriminator_ensemble_optimizer.state_dict() for discriminator_ensemble_optimizer in self.alg.discriminator_ensemble_optimizer], 271 | "iter": self.current_learning_iteration, 272 | "infos": infos, 273 | }, 274 | path, 275 | ) 276 | 277 | def load(self, path, load_optimizer=True): 278 | loaded_dict = torch.load(path) 279 | self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"]) 280 | self.alg.discriminator.load_state_dict(loaded_dict["discriminator_state_dict"]) 281 | self.alg.cassi_state_normalizer = loaded_dict["cassi_state_normalizer"] 282 | self.alg.cassi_style_reward_normalizer = loaded_dict["cassi_style_reward_normalizer"] 283 | self.alg.discriminator_ensemble.load_state_dict(loaded_dict["discriminator_ensemble_state_dict"]) 284 | if load_optimizer: 285 | self.alg.policy_optimizer.load_state_dict(loaded_dict["policy_optimizer_state_dict"]) 286 | self.alg.discriminator_optimizer.load_state_dict(loaded_dict["discriminator_optimizer_state_dict"]) 287 | for i in range(self.alg.discriminator_ensemble_ensemble_size): 288 | self.alg.discriminator_ensemble_optimizer[i].load_state_dict(loaded_dict["discriminator_ensemble_optimizer_state_dict"][i]) 289 | self.current_learning_iteration = loaded_dict["iter"] 290 | return loaded_dict["infos"] 291 | 292 | def get_inference_policy(self, device=None): 293 | self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example) 294 | if device is not None: 295 | self.alg.actor_critic.to(device) 296 | return self.alg.actor_critic.act_inference 297 | 298 | def 
add_git_repo_to_log(self, repo_file_path): 299 | self.git_status_repos.append(repo_file_path) 300 | -------------------------------------------------------------------------------- /learning/storage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of transitions storage for RL-agent.""" 5 | 6 | from .rollout_storage import RolloutStorage 7 | 8 | __all__ = ["RolloutStorage"] 9 | -------------------------------------------------------------------------------- /learning/storage/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ReplayBuffer: 6 | """Fixed-size buffer to store experience tuples.""" 7 | 8 | def __init__(self, obs_dim, obs_horizon, buffer_size, device): 9 | """Initialize a ReplayBuffer object. 10 | Arguments: 11 | buffer_size (int): maximum size of buffer 12 | """ 13 | self.state_buf = torch.zeros(buffer_size, obs_horizon, obs_dim).to(device) 14 | self.buffer_size = buffer_size 15 | self.device = device 16 | 17 | self.step = 0 18 | self.num_samples = 0 19 | 20 | def insert(self, state_buf): 21 | """Add new states to memory.""" 22 | 23 | num_states = state_buf.shape[0] 24 | start_idx = self.step 25 | end_idx = self.step + num_states 26 | if end_idx > self.buffer_size: 27 | self.state_buf[self.step:self.buffer_size] = state_buf[:self.buffer_size - self.step] 28 | self.state_buf[:end_idx - self.buffer_size] = state_buf[self.buffer_size - self.step:] 29 | else: 30 | self.state_buf[start_idx:end_idx] = state_buf 31 | 32 | self.num_samples = min(self.buffer_size, max(end_idx, self.num_samples)) 33 | self.step = (self.step + num_states) % self.buffer_size 34 | 35 | def feed_forward_generator(self, num_mini_batch, mini_batch_size): 36 | for _ in range(num_mini_batch): 37 | sample_idxs = np.random.choice(self.num_samples, size=mini_batch_size) 38 | yield self.state_buf[sample_idxs, :].to(self.device) 39 | -------------------------------------------------------------------------------- /learning/storage/rollout_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | 7 | # learning 8 | from learning.utils import split_and_pad_trajectories 9 | 10 | 11 | class RolloutStorage: 12 | class Transition: 13 | def __init__(self): 14 | self.observations = None 15 | self.critic_observations = None 16 | self.actions = None 17 | self.rewards = None 18 | self.dones = None 19 | self.values = None 20 | self.actions_log_prob = None 21 | self.action_mean = None 22 | self.action_sigma = None 23 | self.hidden_states = None 24 | 25 | def clear(self): 26 | self.__init__() 27 | 28 | def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"): 29 | 30 | self.device = device 31 | 32 | self.obs_shape = obs_shape 33 | self.privileged_obs_shape = privileged_obs_shape 34 | self.actions_shape = actions_shape 35 | 36 | # Core 37 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) 38 | if privileged_obs_shape[0] is not None: 39 | self.privileged_observations = torch.zeros( 40 | num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device 41 | ) 42 | else: 43 | 
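# Standalone sketch (toy sizes) of the ring-buffer wraparound implemented in insert() above: once
# step + num_states passes the end of the buffer, the overflow is written back to the front.
# Assumes the `learning` package of this repository (and torch) is importable.
import torch
from learning.storage.replay_buffer import ReplayBuffer

buf = ReplayBuffer(obs_dim=1, obs_horizon=1, buffer_size=5, device="cpu")
buf.insert(torch.arange(4, dtype=torch.float).view(4, 1, 1))     # fills slots 0-3
buf.insert(torch.arange(4, 7, dtype=torch.float).view(3, 1, 1))  # slot 4, then wraps to slots 0-1
print(buf.state_buf.squeeze())    # tensor([5., 6., 2., 3., 4.])
print(buf.num_samples, buf.step)  # 5 2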
self.privileged_observations = None 44 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 45 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 46 | self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() 47 | 48 | # For PPO 49 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 50 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 51 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 52 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 53 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 54 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 55 | 56 | self.num_transitions_per_env = num_transitions_per_env 57 | self.num_envs = num_envs 58 | 59 | # rnn 60 | self.saved_hidden_states_a = None 61 | self.saved_hidden_states_c = None 62 | 63 | self.step = 0 64 | 65 | def add_transitions(self, transition: Transition): 66 | if self.step >= self.num_transitions_per_env: 67 | raise AssertionError("Rollout buffer overflow") 68 | self.observations[self.step].copy_(transition.observations) 69 | if self.privileged_observations is not None: 70 | self.privileged_observations[self.step].copy_(transition.critic_observations) 71 | self.actions[self.step].copy_(transition.actions) 72 | self.rewards[self.step].copy_(transition.rewards.view(-1, 1)) 73 | self.dones[self.step].copy_(transition.dones.view(-1, 1)) 74 | self.values[self.step].copy_(transition.values) 75 | self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) 76 | self.mu[self.step].copy_(transition.action_mean) 77 | self.sigma[self.step].copy_(transition.action_sigma) 78 | self._save_hidden_states(transition.hidden_states) 79 | self.step += 1 80 | 81 | def _save_hidden_states(self, hidden_states): 82 | if hidden_states is None or hidden_states == (None, None): 83 | return 84 | # make a tuple out of GRU hidden states to match the LSTM format 85 | hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) 86 | hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) 87 | 88 | # initialize if needed 89 | if self.saved_hidden_states_a is None: 90 | self.saved_hidden_states_a = [ 91 | torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a)) 92 | ] 93 | self.saved_hidden_states_c = [ 94 | torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c)) 95 | ] 96 | # copy the states 97 | for i in range(len(hid_a)): 98 | self.saved_hidden_states_a[i][self.step].copy_(hid_a[i]) 99 | self.saved_hidden_states_c[i][self.step].copy_(hid_c[i]) 100 | 101 | def clear(self): 102 | self.step = 0 103 | 104 | def compute_returns(self, last_values, gamma, lam): 105 | advantage = 0 106 | for step in reversed(range(self.num_transitions_per_env)): 107 | if step == self.num_transitions_per_env - 1: 108 | next_values = last_values 109 | else: 110 | next_values = self.values[step + 1] 111 | next_is_not_terminal = 1.0 - self.dones[step].float() 112 | delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] 113 | advantage = delta + next_is_not_terminal * gamma * lam * advantage 114 | self.returns[step] = advantage
+ self.values[step] 115 | 116 | # Compute and normalize the advantages 117 | self.advantages = self.returns - self.values 118 | self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) 119 | 120 | def get_statistics(self): 121 | done = self.dones 122 | done[-1] = 1 123 | flat_dones = done.permute(1, 0, 2).reshape(-1, 1) 124 | done_indices = torch.cat( 125 | (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]) 126 | ) 127 | trajectory_lengths = done_indices[1:] - done_indices[:-1] 128 | return trajectory_lengths.float().mean(), self.rewards.mean() 129 | 130 | def mini_batch_generator(self, num_mini_batches, num_epochs=8): 131 | batch_size = self.num_envs * self.num_transitions_per_env 132 | mini_batch_size = batch_size // num_mini_batches 133 | indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device) 134 | 135 | observations = self.observations.flatten(0, 1) 136 | if self.privileged_observations is not None: 137 | critic_observations = self.privileged_observations.flatten(0, 1) 138 | else: 139 | critic_observations = observations 140 | 141 | actions = self.actions.flatten(0, 1) 142 | values = self.values.flatten(0, 1) 143 | returns = self.returns.flatten(0, 1) 144 | old_actions_log_prob = self.actions_log_prob.flatten(0, 1) 145 | advantages = self.advantages.flatten(0, 1) 146 | old_mu = self.mu.flatten(0, 1) 147 | old_sigma = self.sigma.flatten(0, 1) 148 | 149 | for epoch in range(num_epochs): 150 | for i in range(num_mini_batches): 151 | 152 | start = i * mini_batch_size 153 | end = (i + 1) * mini_batch_size 154 | batch_idx = indices[start:end] 155 | 156 | obs_batch = observations[batch_idx] 157 | critic_observations_batch = critic_observations[batch_idx] 158 | actions_batch = actions[batch_idx] 159 | target_values_batch = values[batch_idx] 160 | returns_batch = returns[batch_idx] 161 | old_actions_log_prob_batch = old_actions_log_prob[batch_idx] 162 | advantages_batch = advantages[batch_idx] 163 | old_mu_batch = old_mu[batch_idx] 164 | old_sigma_batch = old_sigma[batch_idx] 165 | yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( 166 | None, 167 | None, 168 | ), None 169 | 170 | # for RNNs only 171 | def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8): 172 | 173 | padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) 174 | if self.privileged_observations is not None: 175 | padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) 176 | else: 177 | padded_critic_obs_trajectories = padded_obs_trajectories 178 | 179 | mini_batch_size = self.num_envs // num_mini_batches 180 | for ep in range(num_epochs): 181 | first_traj = 0 182 | for i in range(num_mini_batches): 183 | start = i * mini_batch_size 184 | stop = (i + 1) * mini_batch_size 185 | 186 | dones = self.dones.squeeze(-1) 187 | last_was_done = torch.zeros_like(dones, dtype=torch.bool) 188 | last_was_done[1:] = dones[:-1] 189 | last_was_done[0] = True 190 | trajectories_batch_size = torch.sum(last_was_done[:, start:stop]) 191 | last_traj = first_traj + trajectories_batch_size 192 | 193 | masks_batch = trajectory_masks[:, first_traj:last_traj] 194 | obs_batch = padded_obs_trajectories[:, first_traj:last_traj] 195 | critic_obs_batch = padded_critic_obs_trajectories[:, 
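# Standalone numeric sketch (toy values, one environment, no terminations, torch only) of the GAE
# recursion used in compute_returns() above.
import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 1.0, 1.0])
values = torch.tensor([0.5, 0.5, 0.5])
last_values = torch.tensor(0.5)

advantage = 0.0
returns = torch.zeros(3)
for step in reversed(range(3)):
    next_values = last_values if step == 2 else values[step + 1]
    delta = rewards[step] + gamma * next_values - values[step]   # TD error
    advantage = delta + gamma * lam * advantage                  # discounted sum of TD errors
    returns[step] = advantage + values[step]
print(returns, returns - values)  # bootstrapped returns and the corresponding advantages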
first_traj:last_traj] 196 | 197 | actions_batch = self.actions[:, start:stop] 198 | old_mu_batch = self.mu[:, start:stop] 199 | old_sigma_batch = self.sigma[:, start:stop] 200 | returns_batch = self.returns[:, start:stop] 201 | advantages_batch = self.advantages[:, start:stop] 202 | values_batch = self.values[:, start:stop] 203 | old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] 204 | 205 | # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim]) 206 | # then take only time steps after dones (flattens num envs and time dimensions), 207 | # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim] 208 | last_was_done = last_was_done.permute(1, 0) 209 | hid_a_batch = [ 210 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] 211 | .transpose(1, 0) 212 | .contiguous() 213 | for saved_hidden_states in self.saved_hidden_states_a 214 | ] 215 | hid_c_batch = [ 216 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] 217 | .transpose(1, 0) 218 | .contiguous() 219 | for saved_hidden_states in self.saved_hidden_states_c 220 | ] 221 | # remove the tuple for GRU 222 | hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch 223 | hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch 224 | 225 | yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( 226 | hid_a_batch, 227 | hid_c_batch, 228 | ), masks_batch 229 | 230 | first_traj = last_traj 231 | -------------------------------------------------------------------------------- /learning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Helper functions.""" 2 | 3 | from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state 4 | -------------------------------------------------------------------------------- /learning/utils/utils.py: -------------------------------------------------------------------------------- 1 | # python 2 | import os 3 | import git 4 | import pathlib 5 | 6 | # torch 7 | import torch 8 | 9 | 10 | def split_and_pad_trajectories(tensor, dones): 11 | """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest trajectory.
12 | Returns masks corresponding to valid parts of the trajectories 13 | Example: 14 | Input: [ [a1, a2, a3, a4 | a5, a6], 15 | [b1, b2 | b3, b4, b5 | b6] 16 | ] 17 | 18 | Output:[ [a1, a2, a3, a4], | [ [True, True, True, True], 19 | [a5, a6, 0, 0], | [True, True, False, False], 20 | [b1, b2, 0, 0], | [True, True, False, False], 21 | [b3, b4, b5, 0], | [True, True, True, False], 22 | [b6, 0, 0, 0] | [True, False, False, False], 23 | ] | ] 24 | 25 | Assumes that the input has the following dimension order: [time, number of envs, additional dimensions] 26 | """ 27 | dones = dones.clone() 28 | dones[-1] = 1 29 | # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping 30 | flat_dones = dones.transpose(1, 0).reshape(-1, 1) 31 | 32 | # Get length of trajectory by counting the number of successive not done elements 33 | done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])) 34 | trajectory_lengths = done_indices[1:] - done_indices[:-1] 35 | trajectory_lengths_list = trajectory_lengths.tolist() 36 | # Extract the individual trajectories 37 | trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list) 38 | padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories) 39 | 40 | trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1) 41 | return padded_trajectories, trajectory_masks 42 | 43 | 44 | def unpad_trajectories(trajectories, masks): 45 | """Does the inverse operation of split_and_pad_trajectories()""" 46 | # Need to transpose before and after the masking to have proper reshaping 47 | return ( 48 | trajectories.transpose(1, 0)[masks.transpose(1, 0)] 49 | .view(-1, trajectories.shape[0], trajectories.shape[-1]) 50 | .transpose(1, 0) 51 | ) 52 | 53 | 54 | def store_code_state(logdir, repositories): 55 | for repository_file_path in repositories: 56 | repo = git.Repo(repository_file_path, search_parent_directories=True) 57 | repo_name = pathlib.Path(repo.working_dir).name 58 | t = repo.head.commit.tree 59 | content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" 60 | with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x") as f: 61 | f.write(content) 62 | -------------------------------------------------------------------------------- /resources/robots/solo8/datasets/motion_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/datasets/motion_data.pt -------------------------------------------------------------------------------- /resources/robots/solo8/datasets/reference_state_idx_dict.json: -------------------------------------------------------------------------------- 1 | {"base_pos": [0, 3], "base_quat": [3, 7], "base_lin_vel": [7, 10], "base_ang_vel": [10, 13], "projected_gravity": [13, 16], "base_height": [16, 17], "dof_pos": [17, 25], "dof_vel": [25, 33]} -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_body.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_body.stl -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_foot.stl:
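# Usage sketch for split_and_pad_trajectories defined above, mirroring its docstring example with a
# toy [time, num_envs, dim] tensor. Assumes the `learning` package of this repository is importable.
import torch
from learning.utils import split_and_pad_trajectories

T, num_envs, dim = 6, 2, 1
tensor = torch.arange(T * num_envs, dtype=torch.float).view(T, num_envs, dim)
dones = torch.zeros(T, num_envs, 1)
dones[3, 0] = 1  # environment 0 terminates after its fourth step; environment 1 never does
padded, masks = split_and_pad_trajectories(tensor, dones)
print(padded.shape, masks.shape)  # torch.Size([6, 3, 1]) torch.Size([6, 3]) -> three padded trajectories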
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_foot.stl -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_lower_leg_left_side.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_lower_leg_left_side.stl -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_lower_leg_right_side.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_lower_leg_right_side.stl -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_upper_leg_left_side.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_upper_leg_left_side.stl -------------------------------------------------------------------------------- /resources/robots/solo8/meshes/solo_upper_leg_right_side.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/cassi/a762cf516594593519dad9d9eb8a8471c4dc861e/resources/robots/solo8/meshes/solo_upper_leg_right_side.stl -------------------------------------------------------------------------------- /scripts/play.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plays a trained policy and logs statistics. 
3 | """ 4 | 5 | # solo-gym 6 | from solo_gym import LEGGED_GYM_ROOT_DIR 7 | from solo_gym.envs import task_registry 8 | from solo_gym.utils import get_args, export_policy_as_jit, export_policy_as_onnx, Logger 9 | 10 | # python 11 | import argparse 12 | import os 13 | import numpy as np 14 | import torch 15 | 16 | # global settings 17 | EXPORT_POLICY = True 18 | MOVE_CAMERA = True 19 | 20 | def play(args: argparse.Namespace): 21 | args.task = "solo8" 22 | env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) 23 | # override some parameters for testing 24 | env_cfg.env.num_envs = min(env_cfg.env.num_envs, 2) 25 | env_cfg.terrain.num_rows = 5 26 | env_cfg.terrain.num_cols = 5 27 | env_cfg.terrain.curriculum = False 28 | env_cfg.noise.add_noise = False 29 | env_cfg.domain_rand.randomize_friction = False 30 | env_cfg.domain_rand.push_robots = False 31 | 32 | env_cfg.domain_rand.added_mass_range = [0.0, 0.0] 33 | env_cfg.commands.resampling_time = 1000.0 34 | env_cfg.commands.ranges.lin_vel_x = [0.5, 0.5] 35 | env_cfg.env.episode_length_s = 1000.0 36 | env_cfg.env.env_spacing = 100.0 37 | 38 | # camera 39 | env_cfg.viewer.pos = [0.0, -2.13, 1.22] 40 | dir = [0.0, 1.0, -0.4] 41 | env_cfg.viewer.lookat = [a + b for a, b in zip(env_cfg.viewer.pos, dir)] 42 | 43 | # prepare environment 44 | env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg) 45 | obs = env.get_observations() 46 | # load policy 47 | train_cfg.runner.resume = True 48 | ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg) 49 | policy = ppo_runner.get_inference_policy(device=env.device) 50 | 51 | # export policy as a jit module and as onnx model (used to run it from C++) 52 | if EXPORT_POLICY: 53 | path = os.path.join( 54 | LEGGED_GYM_ROOT_DIR, 55 | "logs", 56 | train_cfg.runner.experiment_name, 57 | "exported", 58 | "policies", 59 | ) 60 | name = "policy" 61 | export_policy_as_jit(ppo_runner.alg.actor_critic, path, filename=f"{name}.pt") 62 | export_policy_as_onnx(ppo_runner.alg.actor_critic, path, filename=f"{name}.onnx") 63 | print("Exported policy to: ", path) 64 | 65 | logger = Logger(env.dt) 66 | robot_index = 1 # which robot is used for logging 67 | joint_index = 3 # which joint is used for logging 68 | stop_state_log = 100 # number of steps before plotting states 69 | stop_rew_log = env.max_episode_length + 1 # number of steps before print average episode rewards 70 | camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64) 71 | camera_direction = np.array(env_cfg.viewer.lookat) - np.array(env_cfg.viewer.pos) 72 | 73 | env.keyboard_controller.print_options() 74 | env.style_selector[:] = 0 75 | 76 | for i in range(10 * int(env.max_episode_length)): 77 | env.update_keyboard_events() 78 | actions = policy(obs.detach()) 79 | obs, _, rews, dones, infos = env.step(actions.detach()) 80 | 81 | if MOVE_CAMERA: 82 | camera_position = env.root_states[0, :3].cpu().numpy() 83 | camera_position[1] -= 2.0 84 | camera_position[2] = 1.0 85 | env.set_camera(camera_position, camera_position + camera_direction) 86 | 87 | if i < stop_state_log: 88 | logger.log_states( 89 | { 90 | "dof_pos_target": (actions * env.cfg.control.action_scale + env.default_dof_pos)[robot_index, joint_index].item(), 91 | "dof_pos": env.dof_pos[robot_index, joint_index].item(), 92 | "dof_vel": env.dof_vel[robot_index, joint_index].item(), 93 | "dof_torque": env.torques[robot_index, joint_index].item(), 94 | "command_x": env.commands[robot_index, 0].item(), 95 | "command_y": 0.0, 
96 | "command_yaw": 0.0, 97 | "base_vel_x": env.base_lin_vel[robot_index, 0].item(), 98 | "base_vel_y": env.base_lin_vel[robot_index, 1].item(), 99 | "base_vel_z": env.base_lin_vel[robot_index, 2].item(), 100 | "base_vel_yaw": env.base_ang_vel[robot_index, 2].item(), 101 | "contact_forces_z": env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy(), 102 | } 103 | ) 104 | elif i == stop_state_log: 105 | logger.plot_states() 106 | if 0 < i < stop_rew_log: 107 | if infos["episode"]: 108 | num_episodes = torch.sum(env.reset_buf).item() 109 | if num_episodes > 0: 110 | logger.log_rewards(infos["episode"], num_episodes) 111 | elif i == stop_rew_log: 112 | logger.print_rewards() 113 | 114 | 115 | if __name__ == "__main__": 116 | args = get_args() 117 | play(args) 118 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main script for launching a training session. 3 | """ 4 | # solo-gym 5 | from solo_gym.envs import task_registry 6 | from solo_gym.utils import get_args 7 | 8 | 9 | def train(args): 10 | env, env_cfg = task_registry.make_env(name=args.task, args=args) 11 | ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args) 12 | ppo_runner.learn( 13 | num_learning_iterations=train_cfg.runner.max_iterations, 14 | init_at_random_ep_len=True, 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | args = get_args() 20 | train(args) 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Installation script for the 'solo_gym' python package.""" 2 | 3 | from setuptools import setup, find_packages 4 | 5 | # Minimum dependencies required prior to installation 6 | INSTALL_REQUIRES = [ 7 | "isaacgym", 8 | "matplotlib", 9 | "tensorboard", 10 | "torch>=1.4.0", 11 | "torchvision>=0.5.0", 12 | "numpy>=1.16.4,<=1.22.4", 13 | "setuptools==59.5.0", 14 | "gym>=0.17.1", 15 | "GitPython", 16 | ] 17 | 18 | # Installation operation 19 | setup( 20 | name="solo_gym", 21 | version="1.0.0", 22 | author="Chenhao Li", 23 | packages=find_packages(), 24 | author_email="chenhli@ethz.ch", 25 | description="Isaac Gym environments for Solo", 26 | install_requires=INSTALL_REQUIRES, 27 | ) 28 | -------------------------------------------------------------------------------- /solo_gym/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | LEGGED_GYM_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | """Absolute path to the solo-gym repository.""" 5 | 6 | LEGGED_GYM_ENVS_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, "solo_gym", "envs") 7 | """Absolute path to the module `solo_gym.envs` in solo-gym repository.""" 8 | -------------------------------------------------------------------------------- /solo_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | ## 2 | # Locomotion environments. 
3 | ## 4 | # fmt: off 5 | from .base.legged_robot import LeggedRobot 6 | from .solo8.solo8 import Solo8 7 | from .solo8.solo8_config import ( 8 | Solo8FlatCfg, 9 | Solo8FlatCfgPPO 10 | ) 11 | 12 | # fmt: on 13 | 14 | ## 15 | # Task registration 16 | ## 17 | from solo_gym.utils.task_registry import task_registry 18 | 19 | task_registry.register("solo8", Solo8, Solo8FlatCfg, Solo8FlatCfgPPO) -------------------------------------------------------------------------------- /solo_gym/envs/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .legged_robot import LeggedRobot 2 | from .legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO 3 | 4 | __all__ = ["LeggedRobot", "LeggedRobotCfg", "LeggedRobotCfgPPO"] 5 | -------------------------------------------------------------------------------- /solo_gym/envs/base/legged_robot_config.py: -------------------------------------------------------------------------------- 1 | # solo-gym 2 | from solo_gym.utils.base_config import BaseConfig 3 | 4 | 5 | class LeggedRobotCfg(BaseConfig): 6 | class env: 7 | num_envs = 4096 8 | num_observations = 235 # robot state (48) + height scans (17*11=187) 9 | num_privileged_obs = None # if not None a priviledge_obs_buf will be returned by step() (critic obs for assymetric training). None is returned otherwise 10 | num_actions = 12 # joint positions, velocities or torques 11 | env_spacing = 3.0 # not used with heightfields/trimeshes 12 | send_timeouts = True # send time out information to the algorithm 13 | episode_length_s = 20 # episode length in seconds 14 | 15 | class terrain: 16 | mesh_type = "trimesh" # none, plane, heightfield or trimesh 17 | horizontal_scale = 0.1 # [m] 18 | vertical_scale = 0.005 # [m] 19 | border_size = 25 # [m] 20 | curriculum = True 21 | static_friction = 1.0 22 | dynamic_friction = 1.0 23 | restitution = 0.0 24 | # rough terrain only: 25 | measure_heights = True 26 | # 1mx1.6m rectangle (without center line) 27 | # fmt: off 28 | measured_points_x = [-0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] 29 | measured_points_y = [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5] 30 | # fmt: on 31 | selected = False # select a unique terrain type and pass all arguments 32 | terrain_kwargs = None # Dict of arguments for selected terrain 33 | max_init_terrain_level = 5 # starting curriculum state 34 | terrain_length = 8.0 35 | terrain_width = 8.0 36 | num_rows = 10 # number of terrain rows (levels) 37 | num_cols = 20 # number of terrain cols (types) 38 | # terrain types: [smooth slope, rough slope, stairs up, stairs down, discrete] 39 | terrain_proportions = [0.1, 0.1, 0.35, 0.25, 0.2] 40 | # trimesh only: 41 | # slopes above this threshold will be corrected to vertical surfaces 42 | slope_treshold = 0.75 43 | 44 | class commands: 45 | curriculum = False 46 | max_curriculum = 1.0 47 | num_commands = 4 # default: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error) 48 | resampling_time = 10.0 # time before commands are changed [s] 49 | heading_command = True # if true: compute ang vel command from heading error 50 | 51 | class ranges: 52 | lin_vel_x = [-1.0, 1.0] # min max [m/s] 53 | lin_vel_y = [-1.0, 1.0] # min max [m/s] 54 | ang_vel_yaw = [-1, 1] # min max [rad/s] 55 | heading = [-3.14, 3.14] # [rad] 56 | 57 | class init_state: 58 | pos = [0.0, 0.0, 1.0] # x,y,z [m] 59 | rot = [0.0, 0.0, 0.0, 1.0] # x,y,z,w [quat] 60 | lin_vel = 
[0.0, 0.0, 0.0] # x,y,z [m/s] 61 | ang_vel = [0.0, 0.0, 0.0] # x,y,z [rad/s] 62 | default_joint_angles = { # target angles when action = 0.0 63 | "joint_a": 0.0, 64 | "joint_b": 0.0, 65 | } 66 | 67 | class control: 68 | control_type = "P" # P: position, V: velocity, T: torques 69 | # PD Drive parameters: 70 | stiffness = {"joint_a": 10.0, "joint_b": 15.0} # [N*m/rad] 71 | damping = {"joint_a": 1.0, "joint_b": 1.5} # [N*m*s/rad] 72 | # action scale: target angle = actionScale * action + defaultAngle 73 | action_scale = 0.5 74 | # decimation: Number of control action updates @ sim DT per policy DT 75 | decimation = 4 76 | 77 | class asset: 78 | file = "" 79 | foot_name = "None" # name of the feet bodies, used to index body state and contact force tensors 80 | penalize_contacts_on = [] 81 | terminate_after_contacts_on = [] 82 | disable_gravity = False 83 | collapse_fixed_joints = True # merge bodies connected by fixed joints. Specific fixed joints can be kept by adding " <... dont_collapse="true"> 84 | fix_base_link = False # fixe the base of the robot 85 | default_dof_drive_mode = 3 # see GymDofDriveModeFlags (0 is none, 1 is pos tgt, 2 is vel tgt, 3 effort) 86 | self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter 87 | # replace collision cylinders with capsules, leads to faster/more stable simulation 88 | replace_cylinder_with_capsule = True 89 | flip_visual_attachments = True # Some .obj meshes must be flipped from y-up to z-up 90 | enable_joint_force_sensors = False # Check out isaacgym_lib/docs/programming/forcesensors.html 91 | 92 | density = 0.001 93 | angular_damping = 0.0 94 | linear_damping = 0.0 95 | max_angular_velocity = 1000.0 96 | max_linear_velocity = 1000.0 97 | armature = 0.0 98 | thickness = 0.01 99 | 100 | class domain_rand: 101 | randomize_friction = True 102 | friction_range = [0.5, 1.25] 103 | randomize_base_mass = False 104 | added_mass_range = [-1.0, 1.0] 105 | push_robots = True 106 | push_interval_s = 15 # push applied each time interval [s] 107 | max_push_vel_xy = 1.0 # velocity offset added by push [m/s] 108 | 109 | class rewards: 110 | class scales: 111 | termination = -0.0 112 | tracking_lin_vel = 1.0 113 | tracking_ang_vel = 0.5 114 | lin_vel_z = -2.0 115 | ang_vel_xy = -0.05 116 | orientation = -0.0 117 | torques = -0.00001 118 | dof_vel = -0.0 119 | dof_acc = -2.5e-7 120 | base_height = -0.0 121 | feet_air_time = 1.0 122 | collision = -1.0 123 | feet_stumble = -0.0 124 | action_rate = -0.01 125 | stand_still = -0.0 126 | 127 | # if true negative total rewards are clipped at zero (avoids early termination problems) 128 | only_positive_rewards = True 129 | tracking_sigma = 0.25 # tracking reward = exp(-error^2/sigma) 130 | # percentage of urdf limits, values above this limit are penalized 131 | soft_dof_pos_limit = 1.0 132 | soft_dof_vel_limit = 1.0 133 | soft_torque_limit = 1.0 134 | base_height_target = 1.0 135 | max_contact_force = 100.0 # forces above this value are penalized 136 | 137 | class normalization: 138 | class obs_scales: 139 | lin_vel = 2.0 140 | ang_vel = 0.25 141 | dof_pos = 1.0 142 | dof_vel = 0.05 143 | height_measurements = 5.0 144 | 145 | clip_observations = 100.0 146 | clip_actions = 100.0 147 | 148 | class noise: 149 | add_noise = True 150 | noise_level = 1.0 # scales other values 151 | 152 | class noise_scales: 153 | dof_pos = 0.01 154 | dof_vel = 1.5 155 | lin_vel = 0.1 156 | ang_vel = 0.2 157 | gravity = 0.05 158 | height_measurements = 0.1 159 | 160 | # viewer camera: 161 | class viewer: 162 | ref_env = 0 163 | pos = [10, 
0, 6] # [m] 164 | lookat = [11.0, 5, 3.0] # [m] 165 | 166 | class sim: 167 | dt = 0.005 168 | substeps = 1 169 | gravity = [0.0, 0.0, -9.81] # [m/s^2] 170 | up_axis = 1 # 0 is y, 1 is z 171 | 172 | class physx: 173 | num_threads = 10 174 | solver_type = 1 # 0: pgs, 1: tgs 175 | num_position_iterations = 4 176 | num_velocity_iterations = 0 177 | contact_offset = 0.01 # [m] 178 | rest_offset = 0.0 # [m] 179 | bounce_threshold_velocity = 0.5 # [m/s] 180 | max_depenetration_velocity = 1.0 181 | max_gpu_contact_pairs = 2 ** 23 # 2**24 -> needed for 8000 envs and more 182 | default_buffer_size_multiplier = 5 183 | # 0: never, 1: last sub-step, 2: all sub-steps (default=2) 184 | contact_collection = 2 185 | 186 | 187 | class LeggedRobotCfgPPO(BaseConfig): 188 | seed = 1 189 | runner_class_name = "OnPolicyRunner" 190 | 191 | class policy: 192 | init_noise_std = 1.0 193 | actor_hidden_dims = [512, 256, 128] 194 | critic_hidden_dims = [512, 256, 128] 195 | activation = "elu" # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid 196 | # only for 'ActorCriticRecurrent': 197 | # rnn_type = 'lstm' 198 | # rnn_hidden_size = 512 199 | # rnn_num_layers = 1 200 | 201 | class algorithm: 202 | # training params 203 | value_loss_coef = 1.0 204 | use_clipped_value_loss = True 205 | clip_param = 0.2 206 | entropy_coef = 0.01 207 | num_learning_epochs = 5 208 | num_mini_batches = 4 # mini batch size = num_envs * nsteps / nminibatches 209 | learning_rate = 1.0e-3 # 5.e-4 210 | schedule = "adaptive" # adaptive, fixed 211 | gamma = 0.99 212 | lam = 0.95 213 | desired_kl = 0.01 214 | max_grad_norm = 1.0 215 | 216 | class runner: 217 | policy_class_name = "ActorCritic" 218 | algorithm_class_name = "PPO" 219 | num_steps_per_env = 24 # per iteration 220 | max_iterations = 1500 # number of policy updates 221 | 222 | # logging 223 | save_interval = 50 # check for potential saves every this many iterations 224 | experiment_name = "test" 225 | run_name = "" 226 | # load and resume 227 | resume = False 228 | load_run = -1 # -1 = last run 229 | checkpoint = -1 # -1 = last saved model 230 | resume_path = None # updated from load_run and chkpt 231 | -------------------------------------------------------------------------------- /solo_gym/envs/base_task.py: -------------------------------------------------------------------------------- 1 | # isaacgym 2 | from isaacgym import gymapi 3 | from isaacgym import gymutil 4 | 5 | # python 6 | import sys 7 | import torch 8 | import abc 9 | from typing import Tuple, Union 10 | 11 | # solo-gym 12 | from solo_gym.utils.base_config import BaseConfig 13 | 14 | 15 | class BaseTask: 16 | """Base class for RL tasks.""" 17 | 18 | def __init__( 19 | self, 20 | cfg: BaseConfig, 21 | sim_params: gymapi.SimParams, 22 | physics_engine: gymapi.SimType, 23 | sim_device: str, 24 | headless: bool, 25 | ): 26 | """Initialize the base class for RL. 27 | 28 | The class initializes the simulation. It also allocates buffers for observations, 29 | actions, rewards, reset, episode length, episode timetout and privileged observations. 30 | 31 | The :obj:`cfg` must contain the following: 32 | 33 | - num_envs (int): Number of environment instances. 34 | - num_observations (int): Number of observations. 35 | - num_privileged_obs (int): Number of privileged observations. 36 | - num_actions (int): Number of actions. 37 | 38 | Note: 39 | If :obj:`cfg.num_privileged_obs` is not :obj:`None`, a buffer for privileged 40 | observations is returned. This is useful for critic observations in asymmetric 41 | actor-critic. 
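# Hypothetical sketch of how a task config specializes the two base classes above by subclassing
# their nested classes and overriding individual fields. The class names and field values here are
# invented for illustration; the real Solo8 configs live in solo_gym/envs/solo8/solo8_config.py.
from solo_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO

class MyFlatCfg(LeggedRobotCfg):
    class env(LeggedRobotCfg.env):
        num_envs = 1024               # fewer parallel environments, e.g. for a small GPU

    class terrain(LeggedRobotCfg.terrain):
        mesh_type = "plane"           # flat ground instead of the procedural trimesh

class MyFlatCfgPPO(LeggedRobotCfgPPO):
    class runner(LeggedRobotCfgPPO.runner):
        experiment_name = "my_flat_experiment"
        max_iterations = 500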
42 | 43 | Args: 44 | cfg (BaseConfig): Configuration for the environment. 45 | sim_params (gymapi.SimParams): The simulation parameters. 46 | physics_engine (gymapi.SimType): Simulation type (must be gymapi.SIM_PHYSX). 47 | sim_device (str): The simulation device (ex: `cuda:0` or `cpu`). 48 | headless (bool): If true, run without rendering. 49 | """ 50 | # copy input arguments into class members 51 | self.sim_params = sim_params 52 | self.physics_engine = physics_engine 53 | self.sim_device = sim_device 54 | self.headless = headless 55 | sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device) 56 | # env device is GPU only if sim is on GPU and use_gpu_pipeline is True. 57 | # otherwise returned tensors are copied to CPU by PhysX. 58 | if sim_device_type == "cuda" and sim_params.use_gpu_pipeline: 59 | self.device = self.sim_device 60 | else: 61 | self.device = "cpu" 62 | # graphics device for rendering, -1 for no rendering 63 | self.graphics_device_id = self.sim_device_id 64 | if self.headless is True: 65 | self.graphics_device_id = -1 66 | 67 | # store the environment information 68 | self.num_envs = cfg.env.num_envs 69 | self.num_obs = cfg.env.num_observations 70 | self.num_privileged_obs = cfg.env.num_privileged_obs 71 | self.num_actions = cfg.env.num_actions 72 | 73 | # optimization flags for pytorch JIT 74 | torch._C._jit_set_profiling_mode(False) 75 | torch._C._jit_set_profiling_executor(False) 76 | 77 | # allocate buffers 78 | self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float) 79 | self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float) 80 | self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long) 81 | self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) 82 | self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) 83 | if self.num_privileged_obs is not None: 84 | self.privileged_obs_buf = torch.zeros( 85 | self.num_envs, 86 | self.num_privileged_obs, 87 | device=self.device, 88 | dtype=torch.float, 89 | ) 90 | else: 91 | self.privileged_obs_buf = None 92 | # allocate dictionary to store metrics 93 | self.extras = {} 94 | 95 | # create envs, sim 96 | self.gym = gymapi.acquire_gym() 97 | self.create_sim() 98 | self.gym.prepare_sim(self.sim) 99 | 100 | # create viewer 101 | # Todo: read from config 102 | self.enable_viewer_sync = True 103 | self.viewer = None 104 | # if running with a viewer, set up keyboard shortcuts and camera 105 | if self.headless is False: 106 | # subscribe to keyboard shortcuts 107 | self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties()) 108 | self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT") 109 | self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync") 110 | 111 | def __del__(self): 112 | """Cleanup in the end.""" 113 | try: 114 | if self.sim is not None: 115 | self.gym.destroy_sim(self.sim) 116 | if self.viewer is not None: 117 | self.gym.destroy_viewer(self.viewer) 118 | except: 119 | pass 120 | 121 | """ 122 | Properties. 123 | """ 124 | 125 | def get_observations(self) -> torch.Tensor: 126 | return self.obs_buf 127 | 128 | def get_privileged_observations(self) -> Union[torch.Tensor, None]: 129 | return self.privileged_obs_buf 130 | 131 | """ 132 | Operations. 
133 | """ 134 | 135 | def set_camera_view(self, position: Tuple[float, float, float], lookat: Tuple[float, float, float]) -> None: 136 | """Set camera position and direction.""" 137 | cam_pos = gymapi.Vec3(position[0], position[1], position[2]) 138 | cam_target = gymapi.Vec3(lookat[0], lookat[1], lookat[2]) 139 | self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target) 140 | 141 | def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 142 | """Reset all environment instances. 143 | 144 | Returns: 145 | Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations. 146 | """ 147 | # reset environments 148 | self.reset_idx(torch.arange(self.num_envs, device=self.device)) 149 | # perform single-step to get observations 150 | zero_actions = torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False) 151 | obs, privileged_obs, _, _, _ = self.step(zero_actions) 152 | # return obs 153 | return obs, privileged_obs 154 | 155 | @abc.abstractmethod 156 | def step( 157 | self, actions: torch.Tensor 158 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: 159 | """Apply input action on the environment. 160 | 161 | Args: 162 | actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) 163 | 164 | Returns: 165 | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]: 166 | A tuple containing the observations, privileged observations, rewards, dones and 167 | extra information (metrics). 168 | """ 169 | raise NotImplementedError 170 | 171 | def render(self, sync_frame_time=True): 172 | """Render the viewer.""" 173 | if self.viewer: 174 | # check for window closed 175 | if self.gym.query_viewer_has_closed(self.viewer): 176 | sys.exit() 177 | # check for keyboard events 178 | for evt in self.gym.query_viewer_action_events(self.viewer): 179 | if evt.action == "QUIT" and evt.value > 0: 180 | sys.exit() 181 | elif evt.action == "toggle_viewer_sync" and evt.value > 0: 182 | self.enable_viewer_sync = not self.enable_viewer_sync 183 | # fetch results 184 | if self.device != "cpu": 185 | self.gym.fetch_results(self.sim, True) 186 | # step graphics 187 | if self.enable_viewer_sync: 188 | self.gym.step_graphics(self.sim) 189 | self.gym.draw_viewer(self.viewer, self.sim, True) 190 | if sync_frame_time: 191 | self.gym.sync_frame_time(self.sim) 192 | else: 193 | self.gym.poll_viewer_events(self.viewer) 194 | 195 | """ 196 | Protected Methods. 197 | """ 198 | 199 | @abc.abstractmethod 200 | def create_sim(self): 201 | """Creates simulation, terrain and environments""" 202 | raise NotImplementedError 203 | 204 | @abc.abstractmethod 205 | def reset_idx(self, env_ids: torch.Tensor) -> None: 206 | """Resets the MDP for given environment instances. 207 | 208 | Args: 209 | env_ids (torch.Tensor): A tensor containing indices of environment instances to reset. 
210 | """ 211 | raise NotImplementedError 212 | -------------------------------------------------------------------------------- /solo_gym/envs/solo8/solo8.py: -------------------------------------------------------------------------------- 1 | # python 2 | import torch 3 | 4 | # solo-gym 5 | from solo_gym.envs import LeggedRobot 6 | from .solo8_config import Solo8FlatCfg 7 | from isaacgym import gymtorch, gymapi 8 | from isaacgym.torch_utils import ( 9 | torch_rand_float, 10 | quat_rotate, 11 | quat_rotate_inverse, 12 | ) 13 | from learning.datasets.motion_loader import MotionLoader 14 | from typing import Dict 15 | from solo_gym.utils.keyboard_controller import KeyboardAction, Delta 16 | 17 | class Solo8(LeggedRobot): 18 | cfg: Solo8FlatCfg 19 | 20 | def __init__(self, cfg, sim_params, physics_engine, sim_device, headless): 21 | super().__init__(cfg, sim_params, physics_engine, sim_device, headless) 22 | # load AMP components 23 | self.reference_motion_file = self.cfg.motion_loader.reference_motion_file 24 | self.test_mode = self.cfg.motion_loader.test_mode 25 | self.test_observation_dim = self.cfg.motion_loader.test_observation_dim 26 | self.reference_observation_horizon = self.cfg.motion_loader.reference_observation_horizon 27 | self.motion_loader = MotionLoader( 28 | device=self.device, 29 | motion_file=self.reference_motion_file, 30 | corruption_level=self.cfg.motion_loader.corruption_level, 31 | reference_observation_horizon=self.reference_observation_horizon, 32 | test_mode=self.test_mode, 33 | test_observation_dim=self.test_observation_dim 34 | ) 35 | self.reference_state_idx_dict = self.motion_loader.state_idx_dict 36 | self.reference_full_dim = sum([ids[1] - ids[0] for ids in self.reference_state_idx_dict.values()]) 37 | self.reference_observation_dim = sum([ids[1] - ids[0] for state, ids in self.reference_state_idx_dict.items() if ((state != "base_pos") and (state != "base_quat"))]) 38 | self.cassi_states = torch.zeros( 39 | self.num_envs, self.reference_full_dim, dtype=torch.float, device=self.device, requires_grad=False 40 | ) 41 | self.discriminator = None # assigned in runner 42 | self.cassi_observation_buf = torch.zeros( 43 | self.num_envs, self.reference_observation_horizon, self.reference_observation_dim, dtype=torch.float, device=self.device, requires_grad=False 44 | ) 45 | self.cassi_observation_buf[:, -1] = self.get_cassi_observations() 46 | 47 | # load DISDAIN components 48 | self.dis_observation_horizon = self.cfg.discriminator_ensemble.observation_horizon 49 | self.dis_state_idx_dict = self.cfg.discriminator_ensemble.state_idx_dict 50 | self.dis_full_dim = sum([ids[1] - ids[0] for ids in self.dis_state_idx_dict.values()]) 51 | self.dis_observation_dim = sum([ids[1] - ids[0] for state, ids in self.dis_state_idx_dict.items() if ((state != "base_pos") and (state != "base_quat"))]) 52 | self.dis_observation_start_dim = self.cfg.discriminator_ensemble.observation_start_dim 53 | self.dis_num_classes = self.cfg.discriminator_ensemble.num_classes 54 | self.dis_states = torch.zeros( 55 | self.num_envs, self.dis_full_dim, dtype=torch.float, device=self.device, requires_grad=False 56 | ) 57 | self.dis_observation_buf = torch.zeros( 58 | self.num_envs, self.dis_observation_horizon, self.dis_observation_dim, dtype=torch.float, device=self.device, requires_grad=False 59 | ) 60 | self.dis_observation_buf[:, -1] = self.get_dis_observations() 61 | self.discriminator_ensemble = None 62 | 63 | def post_physics_step(self): 64 | """check terminations, compute observations and rewards 
65 | calls self._post_physics_step_callback() for common computations 66 | calls self._draw_debug_vis() if needed 67 | """ 68 | self.gym.refresh_actor_root_state_tensor(self.sim) 69 | self.gym.refresh_net_contact_force_tensor(self.sim) 70 | self.gym.refresh_rigid_body_state_tensor(self.sim) 71 | if self.cfg.asset.enable_joint_force_sensors: 72 | self.gym.refresh_dof_force_tensor(self.sim) 73 | 74 | self.episode_length_buf += 1 75 | self.common_step_counter += 1 76 | 77 | # prepare quantities 78 | self.base_quat[:] = self.root_states[:, 3:7] 79 | self.base_lin_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 7:10]) 80 | self.base_ang_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 10:13]) 81 | self.projected_gravity[:] = quat_rotate_inverse(self.base_quat, self.gravity_vec) 82 | self.base_height[:] = torch.mean(self.root_states[:, 2].unsqueeze(1) - self.measured_heights, dim=1, keepdim=True) 83 | self.base_lin_vel_x[:] = self.base_lin_vel[:, :1] 84 | self.base_pos_x[:] = self.root_states[:, :1] - self.env_origins[:, :1] 85 | 86 | self._post_physics_step_callback() 87 | 88 | # compute observations, rewards, resets, ... 89 | self.check_termination() 90 | self.cassi_record_states() 91 | self.dis_record_states() 92 | self.next_cassi_observations = self.get_cassi_observations() 93 | self.next_dis_observations = self.get_dis_observations() 94 | self.compute_reward() 95 | env_ids = self.reset_buf.nonzero(as_tuple=False).flatten() 96 | self.reset_idx(env_ids) 97 | self.compute_observations() # in some cases a simulation step might be required to refresh some obs (for example body positions) 98 | 99 | self.last_actions[:] = self.actions[:] 100 | self.last_dof_vel[:] = self.dof_vel[:] 101 | self.last_root_vel[:] = self.root_states[:, 7:13] 102 | self.update_cassi_observation_buf() 103 | self.update_dis_observation_buf() 104 | 105 | if self.viewer and self.enable_viewer_sync and self.debug_viz: 106 | self._draw_debug_vis() 107 | 108 | def update_cassi_observation_buf(self): 109 | self.cassi_observation_buf[:, :-1] = self.cassi_observation_buf[:, 1:].clone() 110 | self.cassi_observation_buf[:, -1] = self.next_cassi_observations.clone() 111 | 112 | def get_cassi_observation_buf(self): 113 | return self.cassi_observation_buf.clone() 114 | 115 | def update_dis_observation_buf(self): 116 | self.dis_observation_buf[:, :-1] = self.dis_observation_buf[:, 1:].clone() 117 | self.dis_observation_buf[:, -1] = self.next_dis_observations.clone() 118 | 119 | def get_dis_observation_buf(self): 120 | return self.dis_observation_buf.clone() 121 | 122 | def compute_observations(self): 123 | """Computes observations""" 124 | self.obs_buf = torch.cat( 125 | ( 126 | self.base_lin_vel * self.obs_scales.lin_vel, 127 | self.base_ang_vel * self.obs_scales.ang_vel, 128 | self.projected_gravity, 129 | self.commands * self.commands_scale, 130 | (self.dof_pos - self.default_dof_pos) * self.obs_scales.dof_pos, 131 | self.dof_vel * self.obs_scales.dof_vel, 132 | self.actions, 133 | ), 134 | dim=-1, 135 | ) 136 | style_selector_one_hot = torch.nn.functional.one_hot(self.style_selector, self.dis_num_classes) 137 | self.obs_buf = torch.cat((self.obs_buf, style_selector_one_hot), dim=-1) 138 | # add noise if needed 139 | if self.add_noise: 140 | self.obs_buf += (2 * torch.rand_like(self.obs_buf) - 1) * self.noise_scale_vec 141 | 142 | def _init_buffers(self): 143 | super()._init_buffers() 144 | self.commands_scale = torch.tensor([self.obs_scales.lin_vel], device=self.device, requires_grad=False) 145 | 
self.hip_indices = torch.tensor([i for i in range(self.num_dof) if "KFE" not in self.dof_names[i]], device=self.device, requires_grad=False) 146 | self.knee_indices = torch.tensor([i for i in range(self.num_dof) if "KFE" in self.dof_names[i]], device=self.device, requires_grad=False) 147 | self.desired_torques = torch.zeros_like(self.torques) 148 | self.max_torque = torch.zeros_like(self.torques) 149 | self.min_torque = torch.zeros_like(self.torques) 150 | self.dof_vel_limits = torch.zeros_like(self.dof_vel) 151 | self.max_torque[:] = self.cfg.control.torque_limit 152 | self.min_torque[:] = -self.cfg.control.torque_limit 153 | self.dof_vel_limits[:, self.hip_indices] = 14.0 154 | self.dof_vel_limits[:, self.knee_indices] = 5.0 155 | self.base_height = torch.zeros(self.num_envs, 1, dtype=torch.float, device=self.device, requires_grad=False) 156 | self.style_selector = torch.zeros(self.num_envs, dtype=torch.long, device=self.device, requires_grad=False) 157 | self.base_lin_vel_x = torch.zeros(self.num_envs, 1, dtype=torch.float, device=self.device, requires_grad=False) 158 | self.base_pos_x = torch.zeros(self.num_envs, 1, dtype=torch.float, device=self.device, requires_grad=False) 159 | 160 | def _resample_commands(self, env_ids): 161 | self.commands[env_ids, 0] = torch_rand_float( 162 | self.command_ranges["lin_vel_x"][0], 163 | self.command_ranges["lin_vel_x"][1], 164 | (len(env_ids), 1), 165 | device=self.device, 166 | ).squeeze(1) 167 | 168 | # set small commands to zero 169 | self.commands[env_ids, :1] *= (torch.norm(self.commands[env_ids, :1], dim=1) > 0.2).unsqueeze(1) 170 | 171 | def _resample_style_selector(self, env_ids): 172 | self.style_selector[env_ids] = torch.randint(self.dis_num_classes, (len(env_ids),), device=self.device, requires_grad=False) 173 | 174 | def _get_keyboard_events(self) -> Dict[str, KeyboardAction]: 175 | """Simple keyboard controller for linear and angular velocity.""" 176 | 177 | def print_command(): 178 | print("[LeggedRobot]: Environment 0 command: ", self.commands[0]) 179 | print("[LeggedRobot]: Environment 0 style selector: ", self.style_selector[0]) 180 | 181 | key_board_events = { 182 | "u": Delta("lin_vel_x", amount=0.1, variable_reference=self.commands[:, 0], callback=print_command), 183 | "j": Delta("lin_vel_x", amount=-0.1, variable_reference=self.commands[:, 0], callback=print_command), 184 | "h": Delta("style_selector", amount=1, variable_reference=self.style_selector, callback=print_command), 185 | "k": Delta("style_selector", amount=-1, variable_reference=self.style_selector, callback=print_command), 186 | } 187 | return key_board_events 188 | 189 | def _get_noise_scale_vec(self, cfg): 190 | noise_vec = torch.zeros_like(self.obs_buf[0]) 191 | self.add_noise = self.cfg.noise.add_noise 192 | noise_scales = self.cfg.noise.noise_scales 193 | noise_level = self.cfg.noise.noise_level 194 | noise_vec[:3] = noise_scales.lin_vel * noise_level * self.obs_scales.lin_vel 195 | noise_vec[3:6] = noise_scales.ang_vel * noise_level * self.obs_scales.ang_vel 196 | noise_vec[6:9] = noise_scales.gravity * noise_level 197 | noise_vec[9:10] = 0.0 # commands 198 | noise_vec[10:18] = noise_scales.dof_pos * noise_level * self.obs_scales.dof_pos 199 | noise_vec[18:26] = noise_scales.dof_vel * noise_level * self.obs_scales.dof_vel 200 | noise_vec[26:34] = 0.0 # previous actions 201 | return noise_vec 202 | 203 | def reset_idx(self, env_ids): 204 | """Reset some environments. 
205 | Calls self._reset_dofs(env_ids), self._reset_root_states(env_ids), and self._resample_commands(env_ids) 206 | [Optional] calls self._update_terrain_curriculum(env_ids), self.update_command_curriculum(env_ids) and 207 | Logs episode info 208 | Resets some buffers 209 | 210 | Args: 211 | env_ids (list[int]): List of environment ids which must be reset 212 | """ 213 | if len(env_ids) == 0: 214 | return 215 | # update curriculum 216 | if self.cfg.terrain.curriculum: 217 | self._update_terrain_curriculum(env_ids) 218 | # avoid updating command curriculum at each step since the maximum command is common to all envs 219 | if self.cfg.commands.curriculum and (self.common_step_counter % self.max_episode_length == 0): 220 | self.update_command_curriculum(env_ids) 221 | 222 | # reset robot states 223 | if self.cfg.domain_rand.reference_state_initialization: 224 | frames = self.motion_loader.get_frames(len(env_ids)) 225 | env_ids_mask = torch.rand(len(env_ids), device=self.device, requires_grad=False) <= self.cfg.domain_rand.reference_state_initialization_prob 226 | else: 227 | frames = None 228 | env_ids_mask = None 229 | self._reset_dofs(env_ids, frames, env_ids_mask) 230 | self._reset_root_states(env_ids, frames, env_ids_mask) 231 | 232 | self._resample_commands(env_ids) 233 | 234 | # reset buffers 235 | self.last_actions[env_ids] = 0.0 236 | self.last_dof_vel[env_ids] = 0.0 237 | self.feet_air_time[env_ids] = 0.0 238 | self.episode_length_buf[env_ids] = 0 239 | self.reset_buf[env_ids] = 1 240 | # fill extras 241 | self.extras["episode"] = {} 242 | for key in self.episode_sums.keys(): 243 | self.extras["episode"]["rew_" + key] = ( 244 | torch.mean(self.episode_sums[key][env_ids]) / self.max_episode_length_s 245 | ) 246 | self.episode_sums[key][env_ids] = 0.0 247 | # log additional curriculum info 248 | if self.cfg.terrain.curriculum: 249 | self.extras["episode"]["terrain_level"] = torch.mean(self.terrain_levels.float()) 250 | if self.cfg.commands.curriculum: 251 | self.extras["episode"]["max_command_x"] = self.command_ranges["lin_vel_x"][1] 252 | # send timeout info to the algorithm 253 | if self.cfg.env.send_timeouts: 254 | self.extras["time_outs"] = self.time_out_buf 255 | self._resample_style_selector(env_ids) 256 | 257 | def get_cassi_observations(self): 258 | if self.test_mode: 259 | cassi_obs = torch.zeros(self.num_envs, self.test_observation_dim, device=self.device, requires_grad=False) 260 | else: 261 | cassi_obs = self.cassi_states[:, self.motion_loader.observation_start_dim:].clone() 262 | return cassi_obs 263 | 264 | def cassi_record_states(self): 265 | for key, value in self.reference_state_idx_dict.items(): 266 | if key == "base_pos": 267 | self.cassi_states[:, value[0]: value[1]] = self._get_base_pos() 268 | elif key == "feet_pos": 269 | self.cassi_states[:, value[0]: value[1]] = self._get_feet_pos() 270 | else: 271 | self.cassi_states[:, value[0]: value[1]] = getattr(self, key) 272 | 273 | def get_dis_observations(self): 274 | dis_obs = self.dis_states[:, self.dis_observation_start_dim:].clone() 275 | return dis_obs 276 | 277 | def get_style_selector(self): 278 | style_selector = self.style_selector.clone() 279 | return style_selector 280 | 281 | def dis_record_states(self): 282 | for key, value in self.dis_state_idx_dict.items(): 283 | if key == "base_pos": 284 | self.dis_states[:, value[0]: value[1]] = self._get_base_pos() 285 | elif key == "feet_pos": 286 | self.dis_states[:, value[0]: value[1]] = self._get_feet_pos() 287 | else: 288 | self.dis_states[:, value[0]: value[1]] 
= getattr(self, key) 289 | 290 | def _get_base_pos(self): 291 | return self.root_states[:, :3] - self.env_origins[:, :3] 292 | 293 | def _get_feet_pos(self): 294 | feet_pos_global = self.rigid_body_pos[:, self.feet_indices, :3] 295 | feet_pos_local = torch.zeros_like(feet_pos_global) 296 | for i in range(len(self.feet_indices)): 297 | feet_pos_local[:, i] = quat_rotate_inverse( 298 | self.base_quat, 299 | feet_pos_global[:, i] 300 | ) 301 | return feet_pos_local.flatten(1, 2) 302 | 303 | def set_camera(self, position, lookat): 304 | """ Set camera position and direction 305 | """ 306 | cam_pos = gymapi.Vec3(position[0], position[1], position[2]) 307 | cam_target = gymapi.Vec3(lookat[0], lookat[1], lookat[2]) 308 | self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target) 309 | 310 | def _reset_dofs(self, env_ids, frames, env_ids_mask): 311 | if frames is not None: 312 | self.dof_pos[env_ids[env_ids_mask]] = self.motion_loader.get_dof_pos(frames[env_ids_mask]) 313 | self.dof_vel[env_ids[env_ids_mask]] = self.motion_loader.get_dof_vel(frames[env_ids_mask]) 314 | self.dof_pos[env_ids[~env_ids_mask]] = self.default_dof_pos * torch_rand_float( 315 | 0.5, 1.5, (len(env_ids[~env_ids_mask]), self.num_dof), device=self.device 316 | ) 317 | self.dof_vel[env_ids[~env_ids_mask]] = 0.0 318 | 319 | env_ids_int32 = env_ids.to(dtype=torch.int32) 320 | self.gym.set_dof_state_tensor_indexed(self.sim, 321 | gymtorch.unwrap_tensor(self.dof_state), 322 | gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) 323 | else: 324 | super()._reset_dofs(env_ids) 325 | 326 | def _reset_root_states(self, env_ids, frames, env_ids_mask): 327 | if frames is not None: 328 | root_pos = self.motion_loader.get_base_pos(frames[env_ids_mask]) 329 | root_pos[:, :2] = root_pos[:, :2] + self.env_origins[env_ids[env_ids_mask], :2] 330 | self.root_states[env_ids[env_ids_mask], :3] = root_pos 331 | root_ori = self.motion_loader.get_base_quat(frames[env_ids_mask]) 332 | self.root_states[env_ids[env_ids_mask], 3:7] = root_ori 333 | self.root_states[env_ids[env_ids_mask], 7:10] = quat_rotate(root_ori, self.motion_loader.get_base_lin_vel(frames[env_ids_mask])) 334 | self.root_states[env_ids[env_ids_mask], 10:13] = quat_rotate(root_ori, self.motion_loader.get_base_ang_vel(frames[env_ids_mask])) 335 | 336 | if self.custom_origins: 337 | self.root_states[env_ids[~env_ids_mask]] = self.base_init_state 338 | self.root_states[env_ids[~env_ids_mask], :3] += self.env_origins[env_ids[~env_ids_mask]] 339 | self.root_states[env_ids[~env_ids_mask], :2] += torch_rand_float( 340 | -1.0, 1.0, (len(env_ids[~env_ids_mask]), 2), device=self.device 341 | ) 342 | else: 343 | self.root_states[env_ids[~env_ids_mask]] = self.base_init_state 344 | self.root_states[env_ids[~env_ids_mask], :3] += self.env_origins[env_ids[~env_ids_mask]] 345 | self.root_states[env_ids[~env_ids_mask], 7:13] = torch_rand_float( 346 | -0.5, 0.5, (len(env_ids[~env_ids_mask]), 6), device=self.device 347 | ) 348 | 349 | env_ids_int32 = env_ids.to(dtype=torch.int32) 350 | self.gym.set_actor_root_state_tensor_indexed(self.sim, 351 | gymtorch.unwrap_tensor(self.root_states), 352 | gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) 353 | else: 354 | super()._reset_root_states(env_ids) 355 | 356 | def _compute_torques(self, actions): 357 | # save desired torques before clipping 358 | self.desired_torques = super()._compute_torques(actions) 359 | return torch.clip(self.desired_torques, min=self.min_torque, max=self.max_torque) 360 | 361 | def 
_reward_tracking_lin_vel(self): 362 | # Tracking of linear velocity commands (only x axis) 363 | lin_vel_error = torch.sum(torch.square(self.commands[:, :1] - self.base_lin_vel[:, :1]), dim=1) 364 | return torch.exp(-lin_vel_error / self.cfg.rewards.tracking_sigma) 365 | 366 | def _reward_impacts(self): 367 | body_indices = [1, 2, 4, 5, 7, 8, 10, 11] # knees and feet 368 | # body_indices = self.feet_indices 369 | # acc = torch.norm(self.last_bodies_vel[:, body_indices] - self.rigid_body_states[:, body_indices, 7:10], dim=2) / self.dt 370 | acc = torch.abs((self.last_bodies_vel[:, body_indices, 2] - self.rigid_body_states[:, body_indices, 9]) / self.dt) 371 | acc = acc.clip(min=20.) - 20. 372 | # self.last_bodies_vel[:] = self.rigid_body_states[:, :, 7:10] 373 | # print(acc[0]) 374 | # acc = (acc / (torch.norm(self.rigid_body_states[:, body_indices, 7:10], dim=2) + 1.e-6)).clip(min=50.) - 50. 375 | self.impact_reward = acc[0, 0] #torch.sum(torch.square(acc), dim=1) * (self.episode_length_buf != 1) 376 | return torch.sum(torch.square(acc), dim=1) * (self.episode_length_buf != 1) 377 | 378 | def _reward_cassi_style(self): 379 | cassi_observation_buf = torch.cat((self.cassi_observation_buf[:, 1:], self.next_cassi_observations.unsqueeze(1)), dim=1) 380 | cassi_style_reward = self.discriminator.predict_cassi_reward(cassi_observation_buf) 381 | return cassi_style_reward 382 | 383 | def _reward_dis_skill(self): 384 | dis_observation_buf = torch.cat((self.dis_observation_buf[:, 1:], self.next_dis_observations.unsqueeze(1)), dim=1) 385 | style_selector = self.get_style_selector() 386 | dis_skill_reward = self.discriminator_ensemble.compute_dis_skill_reward(dis_observation_buf, style_selector) 387 | return dis_skill_reward 388 | 389 | def _reward_dis_disdain(self): 390 | dis_observation_buf = torch.cat((self.dis_observation_buf[:, 1:], self.next_dis_observations.unsqueeze(1)), dim=1) 391 | dis_disdain_reward = self.discriminator_ensemble.compute_dis_disdain_reward(dis_observation_buf) 392 | return dis_disdain_reward 393 | 394 | def _reward_ang_vel_x(self): 395 | return torch.abs(self.base_ang_vel[:, 0]) 396 | 397 | def _reward_lin_vel_y(self): 398 | return torch.abs(self.base_lin_vel[:, 1]) 399 | 400 | def _reward_ang_vel_z(self): 401 | return torch.abs(self.base_ang_vel[:, 2]) 402 | -------------------------------------------------------------------------------- /solo_gym/envs/solo8/solo8_config.py: -------------------------------------------------------------------------------- 1 | from solo_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO 2 | from solo_gym import LEGGED_GYM_ROOT_DIR 3 | 4 | 5 | class Solo8FlatCfg(LeggedRobotCfg): 6 | class env(LeggedRobotCfg.env): 7 | num_observations = 40 # 34 + cla_num_classes 8 | num_actions = 8 9 | 10 | class terrain(LeggedRobotCfg.terrain): 11 | mesh_type = "plane" 12 | curriculum = False 13 | measure_heights = False 14 | terrain_proportions = [0.0, 1.0] 15 | num_rows = 5 16 | max_init_terrain_level = 4 17 | 18 | class init_state(LeggedRobotCfg.init_state): 19 | pos = [0.0, 0.0, 0.35] # x,y,z [m] 20 | default_joint_angles = { # = target angles [rad] when action = 0.0 21 | "FL_HFE": 1.0, 22 | "HL_HFE": -1.0, 23 | "FR_HFE": 1.0, 24 | "HR_HFE": -1.0, 25 | 26 | "FL_KFE": -2.0, 27 | "HL_KFE": 2.0, 28 | "FR_KFE": -2.0, 29 | "HR_KFE": 2.0, 30 | } 31 | 32 | class control(LeggedRobotCfg.control): 33 | # PD Drive parameters: 34 | stiffness = {'HFE': 5.0, 'KFE': 5.0} # [N*m/rad] 35 | damping = {'HFE': 0.1, 'KFE': 0.1} # [N*m*s/rad] 36 | 
torque_limit = 2.5 37 | 38 | class asset(LeggedRobotCfg.asset): 39 | file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/solo8/urdf/solo8.urdf' 40 | foot_name = "FOOT" 41 | terminate_after_contacts_on = ["base", "UPPER"] 42 | self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter 43 | 44 | class rewards(LeggedRobotCfg.rewards): 45 | soft_dof_pos_limit = 0.85 46 | soft_dof_vel_limit = 0.9 47 | soft_torque_limit = 0.9 48 | base_height_target = 0.24 49 | max_contact_force = 350.0 50 | only_positive_rewards = True 51 | class scales(LeggedRobotCfg.rewards.scales): 52 | orientation = -0.0 53 | torques = -0.000025 54 | feet_air_time = 0.5 55 | collision = -0.0 56 | lin_vel_z = -0.0 57 | ang_vel_xy = -0.0 58 | stand_still = -0.02 59 | base_height = -0.0 60 | tracking_lin_vel = 1.0 61 | tracking_ang_vel = 0.0 62 | ang_vel_x = -0.02 63 | lin_vel_y = -0.02 64 | ang_vel_z = -0.02 65 | cassi_style = 1.0 66 | dis_skill = 1.0 67 | dis_disdain = 10.0 68 | 69 | class commands(LeggedRobotCfg.commands): 70 | num_commands = 1 71 | curriculum = False 72 | max_curriculum = 5.0 73 | resampling_time = 5.0 74 | heading_command = False 75 | class ranges(LeggedRobotCfg.commands.ranges): 76 | lin_vel_x = [0.0, 1.0] 77 | 78 | class domain_rand(LeggedRobotCfg.domain_rand): 79 | push_robots = True 80 | max_push_vel_xy = 0.5 81 | randomize_base_mass = True 82 | added_mass_range = [-0.5, 1.0] 83 | reference_state_initialization = True 84 | reference_state_initialization_prob = 0.85 85 | 86 | class motion_loader: 87 | reference_motion_file = LEGGED_GYM_ROOT_DIR + "/resources/robots/solo8/datasets/motion_data.pt" 88 | corruption_level = 0.0 89 | reference_observation_horizon = 2 90 | test_mode = False 91 | test_observation_dim = None # observation_dim of reference motion 92 | 93 | class discriminator_ensemble: 94 | state_idx_dict = { 95 | "base_pos": [0, 3], 96 | "base_quat": [3, 7], 97 | "base_lin_vel": [7, 10], 98 | "base_ang_vel": [10, 13], 99 | "projected_gravity": [13, 16], 100 | "base_height": [16, 17], 101 | "dof_pos": [17, 25], 102 | "dof_vel": [25, 33], 103 | } 104 | num_classes = 6 105 | observation_start_dim = 7 106 | observation_horizon = 8 107 | 108 | class Solo8FlatCfgPPO(LeggedRobotCfgPPO): 109 | runner_class_name = "CASSIOnPolicyRunner" 110 | class policy(LeggedRobotCfgPPO.policy): 111 | actor_hidden_dims = [128, 128, 128] 112 | critic_hidden_dims = [128, 128, 128] 113 | init_noise_std = 1.0 114 | 115 | class discriminator: 116 | style_reward_function = "quad_mapping" # log_mapping, quad_mapping, wasserstein_mapping 117 | shape = [512, 256] 118 | 119 | class discriminator_ensemble: 120 | shape = [256, 256] 121 | ensemble_size = 5 122 | incremental_input = False 123 | 124 | class algorithm(LeggedRobotCfgPPO.algorithm): 125 | cassi_replay_buffer_size = 100000 126 | discriminator_ensemble_replay_buffer_size = 100000 127 | policy_learning_rate = 1e-3 128 | discriminator_learning_rate = 1e-4 129 | discriminator_momentum = 0.5 130 | discriminator_weight_decay = 1e-4 131 | discriminator_gradient_penalty_coef = 5 132 | discriminator_loss_function = "MSELoss" # MSELoss, BCEWithLogitsLoss, WassersteinLoss 133 | discriminator_num_mini_batches = 80 134 | discriminator_ensemble_learning_rate = 1e-4 135 | discriminator_ensemble_weight_decay = 0.0005 136 | discriminator_ensemble_num_mini_batches = 80 137 | 138 | class runner(LeggedRobotCfgPPO.runner): 139 | run_name = 'cassi' 140 | experiment_name = 'flat_solo8' 141 | algorithm_class_name = "CASSI" 142 | policy_class_name = "ActorCritic" 143 | load_run = -1 144 | 
max_iterations = 5000 145 | normalize_style_reward = True 146 | master_classifier_file = LEGGED_GYM_ROOT_DIR + "/resources/robots/solo8/master_classifier/model.pt" 147 | -------------------------------------------------------------------------------- /solo_gym/utils/README.md: -------------------------------------------------------------------------------- 1 | # Legged Gym Utilities 2 | 3 | ## Keyboard Controller 4 | 5 | By overwriting the `_get_keyboard_events()` method, a custom keyboard controller can be added to the environment. The keyboard controller subscribes to IsaacGym's keyboard system, so events are only caught if the IsaacGym window is focused. 6 | 7 | ### Example 8 | 9 | ```python 10 | from solo_gym.utils.keyboard_controller import KeyboardAction, Button, Delta, Switch 11 | 12 | def _get_keyboard_events(self) -> Dict[str, KeyboardAction]: 13 | # Simple keyboard controller for linear and angular velocity 14 | 15 | def print_command(): 16 | print(f"New command: {self.commands[0]}") 17 | 18 | key_board_events = { 19 | 'u' : Delta("lin_vel_x", amount = 0.1, variable_reference = self.commands[:, 0], callback = print_command), 20 | 'j' : Delta("lin_vel_x", amount = -0.1, variable_reference = self.commands[:, 0], callback = print_command), 21 | 'h' : Delta("lin_vel_y", amount = 0.1, variable_reference = self.commands[:, 1], callback = print_command), 22 | 'k' : Delta("lin_vel_y", amount = -0.1, variable_reference = self.commands[:, 1], callback = print_command), 23 | 'y' : Delta("ang_vel_z", amount = 0.1, variable_reference = self.commands[:, 2], callback = print_command), 24 | 'i' : Delta("ang_vel_z", amount = -0.1, variable_reference = self.commands[:, 2], callback = print_command), 25 | 'm' : Button("some_var", 0, 1, self.commands[:, someIndex], print_command), 26 | 'n' : Switch("some_other_var", 0, 1, self.commands[:, someIndex], print_command), 27 | } 28 | return key_board_events 29 | ``` 30 | 31 | The parent class's keyboard mapping can also be extended by calling the `super()` method: 32 | 33 | ```python 34 | def _get_keyboard_events(self) -> Dict[str, KeyboardAction]: 35 | basic_keyboard = super()._get_keyboard_events() 36 | basic_keyboard['x'] = Button("new_var", 0, 1, self.commands[:, someIndex], None) 37 | return basic_keyboard 38 | ``` 39 | 40 | The following keyboard events are available: 41 | 42 | |**Classname** | **Parameters** | **Description** | 43 | |--------------|----------------|-----------------| 44 | | Delta | `amount`, `variable_reference`, `callback` (optional) | Increments `variable_reference` by `amount` on every key press and calls the `callback` if it was passed. | 45 | | Button | `start_state`, `toggle_state`, `variable_reference`, `callback` (optional) | Sets `variable_reference[:] = toggle_state` for the duration the button is held down. Resets to `start_state` afterwards. Calls the `callback` if it was passed. | 46 | | Switch | `start_state`, `toggle_state`, `variable_reference`, `callback` (optional) | Toggles `variable_reference[:]` between the `toggle_state` and `start_state` every time the key is pressed. Calls the `callback` if it was passed. | 47 | | DelegateHandle | `delegate`, `edge_detection`, `callback` (optional) | Executes the function handle `delegate` when the key is pressed. If `edge_detection` is true, it only executes on key press (rising edge) and ignores the release. Executes the `callback` whenever the function handle was called. | 48 | 49 | With the `DelegateHandle` keyboard event, essentially any desired action can be implemented; `Delta`, `Button` and `Switch` are just commonly used helpers built on top of it. A minimal, hypothetical example is sketched below.
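The snippet below is a minimal, hypothetical sketch of registering a `DelegateHandle`; it is not part of this repository, and the key `space`, the name `stop_command`, and the `stop_all` delegate are made up for illustration. Note that the base `KeyboardAction` class requires either a `variable_reference` or a `member_name`, so one of them must be passed even if the delegate itself does not use it.

```python
from typing import Dict

from solo_gym.utils.keyboard_controller import DelegateHandle, KeyboardAction


def _get_keyboard_events(self) -> Dict[str, KeyboardAction]:
    key_board_events = super()._get_keyboard_events()

    # Delegate functions receive the environment instance and the raw key-event value.
    def stop_all(env, value):
        env.commands[:, 0] = 0.0  # zero the forward velocity command of all envs

    key_board_events["space"] = DelegateHandle(
        name="stop_command",
        delegate=stop_all,
        edge_detection=True,  # run only on key press, not on release
        callback=lambda: print("[KeyboardController] Commands zeroed"),
        variable_reference=self.commands[:, 0],  # required by the base class, unused by the delegate
    )
    return key_board_events
```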
50 | 51 | ### **Available keys** 52 | 53 | The list of keys you can use (e.g. `basic_keyboard['KEY_NAME']`) can be found below. The controller takes the key from the dictionary (e.g. `x`), transforms it to capital letters, and prepends `KEY_` to it (see the lookup sketch after the list below). So to get an event on `KEY_RIGHT_ALT`, you have to add `basic_keyboard['right_alt'] = [...]`. 54 | 55 |
56 | Click here to see all options 57 | KEY_SPACE, 58 | KEY_APOSTROPHE, 59 | KEY_COMMA, 60 | KEY_MINUS, 61 | KEY_PERIOD, 62 | KEY_SLASH, 63 | KEY_0, 64 | KEY_1, 65 | KEY_2, 66 | KEY_3, 67 | KEY_4, 68 | KEY_5, 69 | KEY_6, 70 | KEY_7, 71 | KEY_8, 72 | KEY_9, 73 | KEY_SEMICOLON, 74 | KEY_EQUAL, 75 | KEY_A, 76 | KEY_B, 77 | KEY_C, 78 | KEY_D, 79 | KEY_E, 80 | KEY_F, 81 | KEY_G, 82 | KEY_H, 83 | KEY_I, 84 | KEY_J, 85 | KEY_K, 86 | KEY_L, 87 | KEY_M, 88 | KEY_N, 89 | KEY_O, 90 | KEY_P, 91 | KEY_Q, 92 | KEY_R, 93 | KEY_S, 94 | KEY_T, 95 | KEY_U, 96 | KEY_V, 97 | KEY_W, 98 | KEY_X, 99 | KEY_Y, 100 | KEY_Z, 101 | KEY_LEFT_BRACKET, 102 | KEY_BACKSLASH, 103 | KEY_RIGHT_BRACKET, 104 | KEY_GRAVE_ACCENT, 105 | KEY_ESCAPE, 106 | KEY_TAB, 107 | KEY_ENTER, 108 | KEY_BACKSPACE, 109 | KEY_INSERT, 110 | KEY_DEL, 111 | KEY_RIGHT, 112 | KEY_LEFT, 113 | KEY_DOWN, 114 | KEY_UP, 115 | KEY_PAGE_UP, 116 | KEY_PAGE_DOWN, 117 | KEY_HOME, 118 | KEY_END, 119 | KEY_CAPS_LOCK, 120 | KEY_SCROLL_LOCK, 121 | KEY_NUM_LOCK, 122 | KEY_PRINT_SCREEN, 123 | KEY_PAUSE, 124 | KEY_F1, 125 | KEY_F2, 126 | KEY_F3, 127 | KEY_F4, 128 | KEY_F5, 129 | KEY_F6, 130 | KEY_F7, 131 | KEY_F8, 132 | KEY_F9, 133 | KEY_F10, 134 | KEY_F11, 135 | KEY_F12, 136 | KEY_NUMPAD_0, 137 | KEY_NUMPAD_1, 138 | KEY_NUMPAD_2, 139 | KEY_NUMPAD_3, 140 | KEY_NUMPAD_4, 141 | KEY_NUMPAD_5, 142 | KEY_NUMPAD_6, 143 | KEY_NUMPAD_7, 144 | KEY_NUMPAD_8, 145 | KEY_NUMPAD_9, 146 | KEY_NUMPAD_DEL, 147 | KEY_NUMPAD_DIVIDE, 148 | KEY_NUMPAD_MULTIPLY, 149 | KEY_NUMPAD_SUBTRACT, 150 | KEY_NUMPAD_ADD, 151 | KEY_NUMPAD_ENTER, 152 | KEY_NUMPAD_EQUAL, 153 | KEY_LEFT_SHIFT, 154 | KEY_LEFT_CONTROL, 155 | KEY_LEFT_ALT, 156 | KEY_LEFT_SUPER, 157 | KEY_RIGHT_SHIFT, 158 | KEY_RIGHT_CONTROL, 159 | KEY_RIGHT_ALT, 160 | KEY_RIGHT_SUPER, 161 | KEY_MENU 162 |
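The lookup from a dictionary key to the corresponding `gymapi` enum member can be sketched as follows. This is an illustration rather than code from the repository, and it assumes that the upper-cased name exists in your Isaac Gym version's `KeyboardInput` enum.

```python
from isaacgym import gymapi


def to_key_enum(key_name: str) -> gymapi.KeyboardInput:
    # 'x' -> KEY_X, 'right_alt' -> KEY_RIGHT_ALT (assuming the member exists in your Isaac Gym build)
    return getattr(gymapi.KeyboardInput, f"KEY_{key_name.upper()}")


print(to_key_enum("x"), to_key_enum("right_alt"))
```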
163 | 164 | The exact list depends on your isaacgym version and can be found in the docs folder of your local isaacgym_lib copy: `isaacgym_lib/docs/api/python/enum_py.html#isaacgym.gymapi.KeyboardInput`. 165 | -------------------------------------------------------------------------------- /solo_gym/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .helpers import ( 2 | class_to_dict, 3 | get_load_path, 4 | get_args, 5 | export_policy_as_jit, 6 | export_policy_as_onnx, 7 | set_seed, 8 | update_class_from_dict, 9 | ) 10 | from .task_registry import task_registry 11 | from .logger import Logger 12 | from .math import * 13 | from .terrain import Terrain 14 | -------------------------------------------------------------------------------- /solo_gym/utils/base_config.py: -------------------------------------------------------------------------------- 1 | # python 2 | import inspect 3 | 4 | 5 | class BaseConfig: 6 | """Data structure for handling python-classes configurations.""" 7 | 8 | def __init__(self) -> None: 9 | """Initializes all member classes recursively.""" 10 | self.init_member_classes(self) 11 | 12 | @staticmethod 13 | def init_member_classes(obj) -> None: 14 | """Initializes all member classes recursively. 15 | 16 | Note: 17 | Ignores all names starting with "__" (i.e. built-in methods). 18 | """ 19 | # iterate over all attributes names 20 | for key in dir(obj): 21 | # disregard builtin attributes 22 | # if key.startswith("__"): 23 | if key == "__class__": 24 | continue 25 | # get the corresponding attribute object 26 | var = getattr(obj, key) 27 | # check if the attribute is a class 28 | if inspect.isclass(var): 29 | # instantiate the class 30 | i_var = var() 31 | # set the attribute to the instance instead of the type 32 | setattr(obj, key, i_var) 33 | # recursively init members of the attribute 34 | BaseConfig.init_member_classes(i_var) 35 | -------------------------------------------------------------------------------- /solo_gym/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # isaacgym 5 | from isaacgym import gymapi 6 | from isaacgym import gymutil 7 | 8 | # python 9 | import os 10 | import copy 11 | import torch 12 | import numpy as np 13 | import random 14 | import argparse 15 | 16 | 17 | """ 18 | Dictionary <-> Class operations. 19 | """ 20 | 21 | 22 | def class_to_dict(obj) -> dict: 23 | """Convert a class to dictionary data. 24 | 25 | Args: 26 | obj ([type]): Input class instance. 27 | 28 | Returns: 29 | dict: Dictionary containing values from the class. 30 | """ 31 | if not hasattr(obj, "__dict__"): 32 | return obj 33 | result = {} 34 | for key in dir(obj): 35 | if key.startswith("_"): 36 | continue 37 | element = [] 38 | val = getattr(obj, key) 39 | if isinstance(val, list): 40 | for item in val: 41 | element.append(class_to_dict(item)) 42 | else: 43 | element = class_to_dict(val) 44 | result[key] = element 45 | return result 46 | 47 | 48 | def update_class_from_dict(obj, d: dict) -> None: 49 | """Updates a class with values from dictionary. 50 | 51 | Args: 52 | obj ([type]): Input class instance. 53 | d (dict): Dictionary to update values from.
54 | """ 55 | for key, val in d.items(): 56 | attr = getattr(obj, key, None) 57 | if isinstance(attr, type): 58 | update_class_from_dict(attr, val) 59 | else: 60 | setattr(obj, key, val) 61 | return 62 | 63 | 64 | """ 65 | Parsing configurations. 66 | """ 67 | 68 | 69 | def parse_sim_params(args: argparse.Namespace, cfg: dict) -> gymapi.SimParams: 70 | """Parses CLI args and dictionary to generate physics stepping settings. 71 | 72 | Args: 73 | args (argparse.Namespace): [description] 74 | cfg (dict): [description] 75 | 76 | Returns: 77 | gymapi.SimParams: IsaacGym SimParams object with updated settings. 78 | """ 79 | # Note: Some of the code from Isaac Gym Preview 2 80 | # initialize sim params 81 | sim_params = gymapi.SimParams() 82 | 83 | # set some values from args 84 | if args.physics_engine == gymapi.SIM_FLEX: 85 | if args.device != "cpu": 86 | print("WARNING: Using Flex with GPU instead of PHYSX!") 87 | elif args.physics_engine == gymapi.SIM_PHYSX: 88 | sim_params.physx.use_gpu = args.use_gpu 89 | sim_params.physx.num_subscenes = args.subscenes 90 | sim_params.use_gpu_pipeline = args.use_gpu_pipeline 91 | 92 | # if sim options are provided in cfg, parse them and update/override above: 93 | if "sim" in cfg: 94 | gymutil.parse_sim_config(cfg["sim"], sim_params) 95 | 96 | # Override num_threads if passed on the command line 97 | if args.physics_engine == gymapi.SIM_PHYSX and args.num_threads > 0: 98 | sim_params.physx.num_threads = args.num_threads 99 | 100 | return sim_params 101 | 102 | 103 | def update_cfg_from_args(env_cfg, cfg_train, args: argparse.Namespace) -> tuple: 104 | """Updates environment and training configrations from CLI args. 105 | 106 | Args: 107 | env_cfg ([type]): Environment configuration instance. 108 | cfg_train ([type]): Training configuration instance. 109 | args (argparse.Namespace): Parsed CLI arguments. 110 | 111 | Returns: 112 | tuple: Tuple containing the environment and training configurations. 113 | """ 114 | # environment configuration 115 | if env_cfg is not None: 116 | # num envs 117 | if args.num_envs is not None: 118 | env_cfg.env.num_envs = args.num_envs 119 | # training configuration 120 | if cfg_train is not None: 121 | # seed 122 | if args.seed is not None: 123 | cfg_train.seed = args.seed 124 | # alg runner parameters 125 | if args.max_iterations is not None: 126 | cfg_train.runner.max_iterations = args.max_iterations 127 | if args.resume: 128 | cfg_train.runner.resume = args.resume 129 | if args.experiment_name is not None: 130 | cfg_train.runner.experiment_name = args.experiment_name 131 | if args.run_name is not None: 132 | cfg_train.runner.run_name = args.run_name 133 | if args.load_run is not None: 134 | cfg_train.runner.load_run = args.load_run 135 | if args.checkpoint is not None: 136 | cfg_train.runner.checkpoint = args.checkpoint 137 | 138 | return env_cfg, cfg_train 139 | 140 | 141 | def get_args() -> argparse.Namespace: 142 | """Defines custom command-line arguments and parses them. 143 | 144 | Returns: 145 | argparse.Namespace: Parsed CLI arguments. 146 | """ 147 | custom_parameters = [ 148 | { 149 | "name": "--task", 150 | "type": str, 151 | "default": "anymal_c_flat", 152 | "help": "Resume training or start testing from a checkpoint. 
Overrides config file if provided.", 153 | }, 154 | { 155 | "name": "--resume", 156 | "action": "store_true", 157 | "default": False, 158 | "help": "Resume training from a checkpoint", 159 | }, 160 | { 161 | "name": "--experiment_name", 162 | "type": str, 163 | "help": "Name of the experiment to run or load. Overrides config file if provided.", 164 | }, 165 | { 166 | "name": "--run_name", 167 | "type": str, 168 | "help": "Name of the run. Overrides config file if provided.", 169 | }, 170 | { 171 | "name": "--load_run", 172 | "type": str, 173 | "help": "Name of the run to load when resume=True. If -1: will load the last run. Overrides config file if provided.", 174 | }, 175 | { 176 | "name": "--checkpoint", 177 | "type": int, 178 | "help": "Saved model checkpoint number. If -1: will load the last checkpoint. Overrides config file if provided.", 179 | }, 180 | { 181 | "name": "--headless", 182 | "action": "store_true", 183 | "default": False, 184 | "help": "Force display off at all times", 185 | }, 186 | { 187 | "name": "--horovod", 188 | "action": "store_true", 189 | "default": False, 190 | "help": "Use horovod for multi-gpu training", 191 | }, 192 | { 193 | "name": "--rl_device", 194 | "type": str, 195 | "default": "cuda:0", 196 | "help": "Device used by the RL algorithm, (cpu, gpu, cuda:0, cuda:1 etc..)", 197 | }, 198 | { 199 | "name": "--num_envs", 200 | "type": int, 201 | "help": "Number of environments to create. Overrides config file if provided.", 202 | }, 203 | { 204 | "name": "--seed", 205 | "type": int, 206 | "help": "Random seed. Overrides config file if provided.", 207 | }, 208 | { 209 | "name": "--max_iterations", 210 | "type": int, 211 | "help": "Maximum number of training iterations. Overrides config file if provided.", 212 | }, 213 | ] 214 | # parse arguments 215 | args = gymutil.parse_arguments(description="RL Policy using IsaacGym", custom_parameters=custom_parameters) 216 | # name alignment 217 | args.sim_device_id = args.compute_device_id 218 | args.sim_device = args.sim_device_type 219 | if args.sim_device == "cuda": 220 | args.sim_device += f":{args.sim_device_id}" 221 | return args 222 | 223 | 224 | """ 225 | MDP-related operations. 226 | """ 227 | 228 | 229 | def set_seed(seed: int): 230 | """Set the seeding of the experiment. 231 | 232 | Note: 233 | If input is -1, then a random integer between (0, 10000) is sampled. 234 | 235 | Args: 236 | seed (int): The seed value to set. 237 | """ 238 | # default argument 239 | if seed == -1: 240 | seed = np.random.randint(0, 10000) 241 | print("[Utils] Setting seed: {}".format(seed)) 242 | # set seed 243 | random.seed(seed) 244 | np.random.seed(seed) 245 | torch.manual_seed(seed) 246 | os.environ["PYTHONHASHSEED"] = str(seed) 247 | torch.cuda.manual_seed(seed) 248 | torch.cuda.manual_seed_all(seed) 249 | 250 | 251 | """ 252 | Model loading and saving. 
253 | """ 254 | 255 | 256 | def get_load_path(root: str, load_run=-1, checkpoint: int = -1) -> str: 257 | # check if runs present in directory 258 | try: 259 | runs = os.listdir(root) 260 | # TODO: sort by date to handle change of month 261 | runs.sort() 262 | if "exported" in runs: 263 | runs.remove("exported") 264 | last_run = os.path.join(root, runs[-1]) 265 | except IndexError: 266 | raise ValueError(f"No runs present in directory: {root}") 267 | # path to the directory containing the run 268 | if load_run == -1: 269 | load_run = last_run 270 | else: 271 | load_run = os.path.join(root, load_run) 272 | # name of model checkpoint 273 | if checkpoint == -1: 274 | models = [file for file in os.listdir(load_run) if "model" in file] 275 | models.sort(key=lambda m: "{0:0>15}".format(m)) 276 | model = models[-1] 277 | else: 278 | model = "model_{}.pt".format(checkpoint) 279 | 280 | return os.path.join(load_run, model) 281 | 282 | 283 | def export_policy_as_jit(actor_critic, path, filename="policy.pt"): 284 | policy_exporter = TorchPolicyExporter(actor_critic) 285 | policy_exporter.export(path, filename) 286 | 287 | 288 | def export_policy_as_onnx(actor_critic, path, filename="policy.onnx"): 289 | policy_exporter = OnnxPolicyExporter(actor_critic) 290 | policy_exporter.export(path, filename) 291 | 292 | 293 | class TorchPolicyExporter(torch.nn.Module): 294 | def __init__(self, actor_critic): 295 | super().__init__() 296 | self.actor = copy.deepcopy(actor_critic.actor) 297 | self.is_recurrent = actor_critic.is_recurrent 298 | if self.is_recurrent: 299 | self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) 300 | self.rnn.cpu() 301 | self.register_buffer( 302 | "hidden_state", 303 | torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size), 304 | ) 305 | self.register_buffer( 306 | "cell_state", 307 | torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size), 308 | ) 309 | self.forward = self.forward_lstm 310 | self.reset = self.reset_memory 311 | 312 | def forward_lstm(self, x): 313 | x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) 314 | self.hidden_state[:] = h 315 | self.cell_state[:] = c 316 | x = x.squeeze(0) 317 | return self.actor(x) 318 | 319 | def forward(self, x): 320 | return self.actor(x) 321 | 322 | @torch.jit.export 323 | def reset(self): 324 | pass 325 | 326 | def reset_memory(self): 327 | self.hidden_state[:] = 0.0 328 | self.cell_state[:] = 0.0 329 | 330 | def export(self, path, filename): 331 | os.makedirs(path, exist_ok=True) 332 | path = os.path.join(path, filename) 333 | self.to("cpu") 334 | traced_script_module = torch.jit.script(self) 335 | traced_script_module.save(path) 336 | 337 | 338 | class OnnxPolicyExporter(torch.nn.Module): 339 | def __init__(self, actor_critic): 340 | super().__init__() 341 | self.actor = copy.deepcopy(actor_critic.actor) 342 | self.is_recurrent = actor_critic.is_recurrent 343 | if self.is_recurrent: 344 | self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) 345 | self.rnn.cpu() 346 | self.forward = self.forward_lstm 347 | 348 | def forward_lstm(self, x_in, h_in, c_in): 349 | x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in)) 350 | x = x.squeeze(0) 351 | return self.actor(x), h, c 352 | 353 | def forward(self, x): 354 | return self.actor(x) 355 | 356 | def export(self, path, filename): 357 | self.to("cpu") 358 | if self.is_recurrent: 359 | obs = torch.zeros(1, self.rnn.input_size) 360 | h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) 361 | c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) 
362 | actions, h_out, c_out = self(obs, h_in, c_in) 363 | torch.onnx.export( 364 | self, 365 | (obs, h_in, c_in), 366 | os.path.join(path, filename), 367 | export_params=True, 368 | opset_version=11, 369 | verbose=True, 370 | input_names=["obs", "h_in", "c_in"], 371 | output_names=["actions", "h_out", "c_out"], 372 | dynamic_axes={}, 373 | ) 374 | else: 375 | obs = torch.zeros(1, self.actor[0].in_features) 376 | torch.onnx.export( 377 | self, 378 | obs, 379 | os.path.join(path, filename), 380 | export_params=True, 381 | opset_version=11, 382 | verbose=True, 383 | input_names=["obs"], 384 | output_names=["actions"], 385 | dynamic_axes={}, 386 | ) 387 | -------------------------------------------------------------------------------- /solo_gym/utils/keyboard_controller.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # isaacgym 4 | import isaacgym 5 | 6 | # python 7 | from abc import abstractmethod, ABC 8 | from typing import Callable, Any, Dict 9 | import torch 10 | 11 | 12 | # Callback function for when event is triggered. 13 | CallbackFn = Callable[[], None] 14 | # Action function called when event is trigged. 15 | # Takes the environment instance and event value. 16 | ActionFn = Callable[[Any, int], None] 17 | 18 | 19 | class KeyboardAction(ABC): 20 | """Base class for keyboard event.""" 21 | 22 | def __init__( 23 | self, 24 | name: str, 25 | variable_reference: torch.Tensor = None, 26 | member_name: str = None, 27 | ): 28 | """Initializes the keyboard action event with variable attributes. 29 | 30 | Note: 31 | The keyboard action can be applied to a variable (passed via reference) or 32 | on a member in the environment class instance. 33 | 34 | Args: 35 | name (str): Name of the affected value. 36 | variable_reference (torch.Tensor, optional): Reference variable to alter value. Defaults to None. 37 | member_name (str, optional): Name of the variable in the environment. Defaults to None. 38 | 39 | Raises: 40 | ValueError -- If both reference variable and environment's member name are None or not None. 41 | """ 42 | # check input 43 | if (variable_reference is None and member_name is None) or ( 44 | variable_reference is not None and member_name is not None 45 | ): 46 | msg = "Invalid arguments: Action can only be applied on either reference variable or environment's member variable." 47 | raise ValueError(msg) 48 | # store input arguments 49 | self.name = name 50 | self.variable_reference = variable_reference 51 | self.member_name = member_name 52 | # disambiguate the type of mode 53 | if variable_reference is not None and member_name is None: 54 | self._ref_mode = True 55 | elif variable_reference is None and member_name is not None: 56 | self._ref_mode = False 57 | 58 | def __str__(self) -> str: 59 | """Helper string to explain keyboard action.""" 60 | return f"Keyboard action on {self.name}." 61 | 62 | def get_reference(self, env) -> torch.Tensor: 63 | """Retrieve the variable on which event action is applied. 64 | 65 | Args: 66 | env (BaseTask): The environment/task instance. 67 | 68 | Returns: 69 | torch.Tensor: The passed variable reference or environment instance's member. 70 | """ 71 | if self._ref_mode: 72 | return self.variable_reference 73 | else: 74 | return getattr(env, self.member_name) 75 | 76 | @abstractmethod 77 | def do(self, env, value: int): 78 | """Action applied by the keyboard event. 79 | 80 | Args: 81 | env (BaseTask): The environment/task instance. 
82 | value (int): The event triggered when keyboard button pressed. 83 | """ 84 | raise NotImplementedError 85 | 86 | 87 | class DelegateHandle(KeyboardAction): 88 | """Pre-defined delegate that executes an event handler. 89 | 90 | This class executes the function handle `delegate` when the key is pressed. If `edge_detection` is 91 | true, the function executes only on the rising edge, i.e. when the key is pressed and not when it is released. 92 | 93 | The `callback` function is executed whenever the function handle is called. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | name: str, 99 | delegate: ActionFn, 100 | edge_detection: bool = True, 101 | callback: CallbackFn = None, 102 | variable_reference: torch.Tensor = None, 103 | member_name: str = None, 104 | ): 105 | """Initializes the class. 106 | 107 | Args: 108 | name (str): Name of the affected value. 109 | delegate (ActionFn): The function called when keyboard is pressed/released. 110 | edge_detection (bool, optional): If True, the delegate runs only on key press; if False, also on release. Defaults to True. 111 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 112 | variable_reference (torch.Tensor, optional): Reference variable to alter value. Defaults to None. 113 | member_name (str, optional): Name of the variable in the environment. Defaults to None. 114 | """ 115 | super().__init__(name, variable_reference, member_name) 116 | # store inputs 117 | self._delegate = delegate 118 | self._edge_detection = edge_detection 119 | self._callback = callback 120 | 121 | def do(self, env, value): 122 | """Action applied by the keyboard event. 123 | 124 | Args: 125 | env (BaseTask): The environment/task instance. 126 | value (int): The event triggered when keyboard button pressed. 127 | """ 128 | # if no event triggered return. 129 | if self._edge_detection and value == 0: 130 | return 131 | # resolve action based on press/release 132 | self._delegate(env, value) 133 | # trigger callback function 134 | if self._callback is not None: 135 | self._callback() 136 | 137 | 138 | class Delta(DelegateHandle): 139 | """Keyboard action that increments the value of reference variable by scalar amount.""" 140 | 141 | def __init__( 142 | self, 143 | name: str, 144 | amount: float, 145 | variable_reference: torch.Tensor, 146 | callback: CallbackFn = None, 147 | ): 148 | """Initializes the class. 149 | 150 | Args: 151 | name (str): Name of the affected value. 152 | amount (float): The amount by which to increment. 153 | variable_reference (torch.Tensor): Reference variable to alter value. 154 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 155 | """ 156 | self.amount = amount 157 | 158 | # delegate function 159 | def addDelta(env, value): 160 | self.variable_reference += self.amount 161 | 162 | # initialize parent 163 | super().__init__(name, addDelta, True, callback, variable_reference, None) 164 | 165 | def __str__(self) -> str: 166 | if self.amount >= 0: 167 | return f"Increments the variable {self.name} by {self.amount}" 168 | else: 169 | return f"Decrements the variable {self.name} by {-self.amount}" 170 | 171 | 172 | class Switch(DelegateHandle): 173 | """Keyboard action that toggles between values of reference variable.""" 174 | 175 | def __init__( 176 | self, 177 | name: str, 178 | start_state: torch.Tensor, 179 | toggle_state: torch.Tensor, 180 | variable_reference: torch.Tensor, 181 | callback: CallbackFn = None, 182 | ): 183 | """Initializes the class.
184 | 185 | Args: 186 | name (str): Name of the affected value. 187 | start_state (torch.Tensor): Initial value of reference variable. 188 | toggle_state (torch.Tensor): Toggled value of reference variable. 189 | variable_reference (torch.Tensor): Reference variable to alter value. 190 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 191 | """ 192 | # copy inputs to class 193 | self.start_state = start_state 194 | self.toggle_state = toggle_state 195 | self.variable_reference = variable_reference 196 | # initial state of toggle switch 197 | self.switch_value = True 198 | 199 | # delegate function 200 | def switchState(env, value): 201 | # switch between state depending on switch's value 202 | if self.switch_value: 203 | new_state = self.toggle_state 204 | else: 205 | new_state = self.start_state 206 | # store value into reference variable 207 | self.variable_reference[:] = new_state 208 | # toggle switch to other state 209 | self.switch_value = not self.switch_value 210 | 211 | # initialize parent 212 | super().__init__(name, switchState, True, callback, variable_reference, None) 213 | 214 | def __str__(self) -> str: 215 | return f"Toggles the variable {self.name} between {self.toggle_state} and {self.start_state}." 216 | 217 | 218 | class Button(Switch): 219 | """Sets the variable to value only while keyboard button is pressed.""" 220 | 221 | def __init__( 222 | self, 223 | name: str, 224 | start_state: torch.Tensor, 225 | toggle_state: torch.Tensor, 226 | variable_reference: torch.Tensor, 227 | callback: CallbackFn = None, 228 | ): 229 | """Initializes the class. 230 | 231 | Args: 232 | name (str): Name of the affected value. 233 | start_state (torch.Tensor): Initial value of reference variable. 234 | toggle_state (torch.Tensor): Toggled value of reference variable. 235 | variable_reference (torch.Tensor): Reference variable to alter value. 236 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 237 | """ 238 | # initialize toggle switch 239 | super().__init__(name, start_state, toggle_state, variable_reference, callback) 240 | # trigger event only when key is pressed 241 | self._edge_detection = False 242 | 243 | def __str__(self) -> str: 244 | return f"Sets the variable {self.name} to {self.toggle_state} only while key is pressed." 245 | 246 | 247 | class KeyBoardController: 248 | """Wrapper around IsaacGym viewer to handle different keyboard actions.""" 249 | 250 | def __init__(self, env, key_actions: Dict[str, KeyboardAction]): 251 | """Initializes the class. 252 | 253 | Args: 254 | env (BaseTask): The environment/task instance. 255 | key_actions (Dict[str, KeyboardAction]): The pairs of key buttons and their actions. 
256 | """ 257 | # store inputs 258 | self._env = env 259 | self._key_actions = key_actions 260 | # setup the keyboard event subscriber 261 | for key_name in self._key_actions.keys(): 262 | key_enum = getattr(isaacgym.gymapi.KeyboardInput, f"KEY_{key_name.capitalize()}") 263 | env.gym.subscribe_viewer_keyboard_event(env.viewer, key_enum, key_name) 264 | 265 | def update(self, env): 266 | """Update the reference variables by querying viewer events.""" 267 | # gather all events on viewer 268 | events = env.gym.query_viewer_action_events(env.viewer) 269 | # iterate over events 270 | for event in events: 271 | key_pressed = event.action 272 | if key_pressed in self._key_actions: 273 | cfg = self._key_actions[key_pressed] 274 | cfg.do(env, event.value) 275 | 276 | def print_options(self): 277 | print("[KeyboardController] Key-action pairs:") 278 | for key_name, action in self._key_actions.items(): 279 | print(f"\t{key_name}: {action}") 280 | -------------------------------------------------------------------------------- /solo_gym/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # python 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from collections import defaultdict 8 | from multiprocessing import Process 9 | 10 | 11 | class Logger: 12 | def __init__(self, dt): 13 | self.state_log = defaultdict(list) 14 | self.rew_log = defaultdict(list) 15 | self.dt = dt 16 | self.num_episodes = 0 17 | self.plot_process = None 18 | 19 | def log_state(self, key, value): 20 | self.state_log[key].append(value) 21 | 22 | def log_states(self, dict): 23 | for key, value in dict.items(): 24 | self.log_state(key, value) 25 | 26 | def log_rewards(self, dict, num_episodes): 27 | for key, value in dict.items(): 28 | if "rew" in key: 29 | self.rew_log[key].append(value.item() * num_episodes) 30 | self.num_episodes += num_episodes 31 | 32 | def reset(self): 33 | self.state_log.clear() 34 | self.rew_log.clear() 35 | 36 | def plot_states(self): 37 | self.plot_process = Process(target=self._plot) 38 | self.plot_process.start() 39 | 40 | def _plot(self): 41 | nb_rows = 3 42 | nb_cols = 3 43 | fig, axs = plt.subplots(nb_rows, nb_cols) 44 | for key, value in self.state_log.items(): 45 | time = np.linspace(0, len(value) * self.dt, len(value)) 46 | break 47 | log = self.state_log 48 | # plot joint targets and measured positions 49 | a = axs[1, 0] 50 | if log["dof_pos"]: 51 | a.plot(time, log["dof_pos"], label="measured") 52 | if log["dof_pos_target"]: 53 | a.plot(time, log["dof_pos_target"], label="target") 54 | a.set(xlabel="time [s]", ylabel="Position [rad]", title="DOF Position") 55 | a.legend() 56 | # plot joint velocity 57 | a = axs[1, 1] 58 | if log["dof_vel"]: 59 | a.plot(time, log["dof_vel"], label="measured") 60 | if log["dof_vel_target"]: 61 | a.plot(time, log["dof_vel_target"], label="target") 62 | a.set(xlabel="time [s]", ylabel="Velocity [rad/s]", title="Joint Velocity") 63 | a.legend() 64 | # plot base vel x 65 | a = axs[0, 0] 66 | if log["base_vel_x"]: 67 | a.plot(time, log["base_vel_x"], label="measured") 68 | if log["command_x"]: 69 | a.plot(time, log["command_x"], label="commanded") 70 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity x") 71 | a.legend() 72 | # plot base vel y 73 | a = axs[0, 1] 74 | if log["base_vel_y"]: 75 | a.plot(time, log["base_vel_y"], label="measured") 76 | if log["command_y"]: 77 | a.plot(time, 
log["command_y"], label="commanded") 78 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity y") 79 | a.legend() 80 | # plot base vel yaw 81 | a = axs[0, 2] 82 | if log["base_vel_yaw"]: 83 | a.plot(time, log["base_vel_yaw"], label="measured") 84 | if log["command_yaw"]: 85 | a.plot(time, log["command_yaw"], label="commanded") 86 | a.set(xlabel="time [s]", ylabel="base ang vel [rad/s]", title="Base velocity yaw") 87 | a.legend() 88 | # plot base vel z 89 | a = axs[1, 2] 90 | if log["base_vel_z"]: 91 | a.plot(time, log["base_vel_z"], label="measured") 92 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity z") 93 | a.legend() 94 | # plot contact forces 95 | a = axs[2, 0] 96 | if log["contact_forces_z"]: 97 | forces = np.array(log["contact_forces_z"]) 98 | for i in range(forces.shape[1]): 99 | a.plot(time, forces[:, i], label=f"force {i}") 100 | a.set(xlabel="time [s]", ylabel="Forces z [N]", title="Vertical Contact forces") 101 | a.legend() 102 | # plot torque/vel curves 103 | a = axs[2, 1] 104 | if log["dof_vel"] != [] and log["dof_torque"] != []: 105 | a.plot(log["dof_vel"], log["dof_torque"], "x", label="measured") 106 | a.set( 107 | xlabel="Joint vel [rad/s]", 108 | ylabel="Joint Torque [Nm]", 109 | title="Torque/velocity curves", 110 | ) 111 | a.legend() 112 | # plot torques 113 | a = axs[2, 2] 114 | if log["dof_torque"] != []: 115 | a.plot(time, log["dof_torque"], label="measured") 116 | a.set(xlabel="time [s]", ylabel="Joint Torque [Nm]", title="Torque") 117 | a.legend() 118 | plt.show() 119 | 120 | def print_rewards(self): 121 | print("Average rewards per second:") 122 | for key, values in self.rew_log.items(): 123 | mean = np.sum(np.array(values)) / self.num_episodes 124 | print(f" - {key}: {mean}") 125 | print(f"Total number of episodes: {self.num_episodes}") 126 | 127 | def __del__(self): 128 | if self.plot_process is not None: 129 | self.plot_process.kill() 130 | -------------------------------------------------------------------------------- /solo_gym/utils/math.py: -------------------------------------------------------------------------------- 1 | # isaac-gym 2 | from isaacgym.torch_utils import quat_apply, normalize, quat_mul, quat_conjugate 3 | 4 | # python 5 | import torch 6 | import numpy as np 7 | from typing import Tuple 8 | 9 | 10 | # @ torch.jit.script 11 | def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: 12 | """Rotate a vector only around the yaw-direction. 13 | 14 | Args: 15 | quat (torch.Tensor): Input orientation to extract yaw from. 16 | vec (torch.Tensor): Input vector. 17 | 18 | Returns: 19 | torch.Tensor: Rotated vector. 
20 | """ 21 | quat_yaw = quat.clone().view(-1, 4) 22 | quat_yaw[:, :2] = 0.0 23 | quat_yaw = normalize(quat_yaw) 24 | return quat_apply(quat_yaw, vec) 25 | 26 | 27 | # @ torch.jit.script 28 | def box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: 29 | """Implements box-minur operator (quaternion difference) 30 | https://docs.leggedrobotics.com/kindr/cheatsheet_latest.pdf 31 | 32 | Args: 33 | q1 (torch.Tensor): quaternion 34 | q2 (torch.Tensor): quaternion 35 | 36 | Returns: 37 | torch.Tensor: q1 box-minus q2 38 | """ 39 | quat_diff = quat_mul(q1, quat_conjugate(q2)) # q1 * q2^-1 40 | re = quat_diff[:, -1] # real part, q = [x, y, z, w] = [re, im] 41 | im = quat_diff[:, 0:3] # imaginary part 42 | norm_im = torch.norm(im, dim=1) 43 | scale = 2.0 * torch.where(norm_im > 1.0e-7, torch.atan(norm_im / re) / norm_im, torch.sign(re)) 44 | return scale.unsqueeze(-1) * im 45 | 46 | 47 | # @ torch.jit.script 48 | def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor: 49 | """Wraps input angles (in radians) to the range [-pi, pi]. 50 | 51 | Args: 52 | angles (torch.Tensor): Input angles. 53 | 54 | Returns: 55 | torch.Tensor: Angles in the range [-pi, pi]. 56 | """ 57 | angles %= 2 * np.pi 58 | angles -= 2 * np.pi * (angles > np.pi) 59 | return angles 60 | 61 | 62 | # @ torch.jit.script 63 | def torch_rand_sqrt_float(lower: float, upper: float, size: Tuple[int, int], device: str) -> torch.Tensor: 64 | """Randomly samples tensor from a triangular distribution. 65 | 66 | Args: 67 | lower (float): The lower range of the sampled tensor. 68 | upper (float): The upper range of the sampled tensor. 69 | size (Tuple[int, int]): The shape of the tensor. 70 | device (str): Device to create tensor on. 71 | 72 | Returns: 73 | torch.Tensor: Sampled tensor of shape :obj:`size`. 74 | """ 75 | # create random tensor in the range [-1, 1] 76 | r = 2 * torch.rand(*size, device=device) - 1 77 | # convert to triangular distribution 78 | r = torch.where(r < 0.0, -torch.sqrt(-r), torch.sqrt(r)) 79 | # rescale back to [0, 1] 80 | r = (r + 1.0) / 2.0 81 | # rescale to range [lower, upper] 82 | return (upper - lower) * r + lower 83 | -------------------------------------------------------------------------------- /solo_gym/utils/task_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # python 5 | import argparse 6 | import os 7 | from datetime import datetime 8 | from typing import Tuple, Type, List 9 | 10 | # solo-gym 11 | from solo_gym import LEGGED_GYM_ROOT_DIR 12 | from solo_gym.utils.base_config import BaseConfig 13 | from solo_gym.utils.helpers import ( 14 | get_args, 15 | update_cfg_from_args, 16 | class_to_dict, 17 | get_load_path, 18 | set_seed, 19 | parse_sim_params, 20 | ) 21 | 22 | # learning 23 | from learning.env import VecEnv 24 | from learning.runners import CASSIOnPolicyRunner 25 | 26 | 27 | class TaskRegistry: 28 | """This class simplifies creation of environments and agents.""" 29 | 30 | def __init__(self): 31 | self.task_classes = dict() 32 | self.env_cfgs = dict() 33 | self.train_cfgs = dict() 34 | 35 | """ 36 | Properties 37 | """ 38 | 39 | def get_task_names(self) -> List[str]: 40 | """Returns a list of registered task names. 41 | 42 | Returns: 43 | List[str]: List of registered task names. 
44 | """ 45 | return list(self.task_classes.keys()) 46 | 47 | def get_task_class(self, name: str) -> Type[VecEnv]: 48 | """Retrieve the class object corresponding to input name. 49 | 50 | Args: 51 | name (str): name of the registered environment. 52 | 53 | Raises: 54 | ValueError: When there is no registered environment with input `name`. 55 | 56 | Returns: 57 | Type[VecEnv]: The environment class object. 58 | """ 59 | # check if there is a registered env with that name 60 | if self._check_valid_task(name): 61 | return self.task_classes[name] 62 | 63 | def get_cfgs(self, name: str) -> Tuple[BaseConfig, BaseConfig]: 64 | """Retrieve the default environment and training configurations. 65 | 66 | Args: 67 | name (str): Name of the environment. 68 | 69 | Raises: 70 | ValueError: When there is no registered environment with input `name`. 71 | 72 | Returns: 73 | Tuple[BaseConfig, BaseConfig]: The default environment and training configurations. 74 | """ 75 | # check if there is a registered env with that name 76 | if self._check_valid_task(name): 77 | # retrieve configurations 78 | train_cfg = self.train_cfgs[name] 79 | env_cfg = self.env_cfgs[name] 80 | # copy seed between environment and agent 81 | env_cfg.seed = train_cfg.seed 82 | return env_cfg, train_cfg 83 | 84 | """ 85 | Operations 86 | """ 87 | 88 | def register( 89 | self, 90 | name: str, 91 | task_class: Type[VecEnv], 92 | env_cfg: Type[BaseConfig], 93 | train_cfg: Type[BaseConfig], 94 | ): 95 | """Add a particular environment to the task registry. 96 | 97 | Args: 98 | name (str): Name of the environment. 99 | task_class (Type[VecEnv]): The corresponding task class. 100 | env_cfg (Type[BaseConfig]): The corresponding environment configuration file. 101 | train_cfg (Type[BaseConfig]): The corresponding agent configuration file. 102 | """ 103 | self.task_classes[name] = task_class 104 | self.env_cfgs[name] = env_cfg() 105 | self.train_cfgs[name] = train_cfg() 106 | 107 | def make_env( 108 | self, name: str, args: argparse.Namespace = None, env_cfg: BaseConfig = None 109 | ) -> Tuple[VecEnv, BaseConfig]: 110 | """Creates an environment from the registry. 111 | 112 | Args: 113 | name (str): Name of a registered env. 114 | args (argparse.Namespace, optional): Parsed CLI arguments. If :obj:`None`, then 115 | `get_args()` is called to obtain arguments. Defaults to None. 116 | env_cfg (BaseConfig, optional): Environment configuration class instance used to 117 | overwrite the default registered configuration. Defaults to None. 118 | 119 | Raises: 120 | ValueError: When there is no registered environment with input `name`. 121 | 122 | Returns: 123 | Tuple[VecEnv, BaseConfig]: Tuple containing the created class instance and corresponding 124 | configuration class instance. 
125 | """ 126 | # check if there is a registered env with that name 127 | task_class = self.get_task_class(name) 128 | # if no args passed, get command line arguments 129 | if args is None: 130 | args = get_args() 131 | # if no config passed, use default env config 132 | if env_cfg is None: 133 | # load config files 134 | env_cfg, _ = self.get_cfgs(name) 135 | # override cfg from args (if specified) 136 | env_cfg, _ = update_cfg_from_args(env_cfg, None, args) 137 | # set seed 138 | set_seed(env_cfg.seed) 139 | # parse sim params (convert to dict first) 140 | sim_params = {"sim": class_to_dict(env_cfg.sim)} 141 | sim_params = parse_sim_params(args, sim_params) 142 | # create environment instance 143 | env = task_class( 144 | cfg=env_cfg, 145 | sim_params=sim_params, 146 | physics_engine=args.physics_engine, 147 | sim_device=args.sim_device, 148 | headless=args.headless, 149 | ) 150 | return env, env_cfg 151 | 152 | def make_alg_runner( 153 | self, 154 | env: VecEnv, 155 | name: str = None, 156 | args: argparse.Namespace = None, 157 | train_cfg: BaseConfig = None, 158 | log_root: str = "default", 159 | ) -> Tuple[CASSIOnPolicyRunner, BaseConfig]: 160 | """Creates the training algorithm either from a registered name or from the provided 161 | config file. 162 | 163 | TODO (@nrudin): Remove environment from within the algorithm. 164 | 165 | Note: 166 | The training/agent configuration is loaded from either "name" or "train_cfg". If both are 167 | passed then the default configuration via "name" is ignored. 168 | 169 | Args: 170 | env (VecEnv): The environment to train on. 171 | name (str, optional): The environment name to retrieve corresponding training configuration. 172 | Defaults to None. 173 | args (argparse.Namespace, optional): Parsed CLI arguments. If :obj:`None`, then 174 | `get_args()` is called to obtain arguments. Defaults to None. 175 | train_cfg (BaseConfig, optional): Instance of training configuration class. If 176 | :obj:`None`, then `name` is used to retrieve default training configuration. 177 | Defaults to None. 178 | log_root (str, optional): Logging directory for TensorBoard. Set to obj:`None` to avoid 179 | logging (such as during test-time). Logs are saved in the `/_` 180 | directory. If "default", then `log_root` is set to 181 | "{LEGGED_GYM_ROOT_DIR}/logs/{train_cfg.runner.experiment_name}". Defaults to "default". 182 | 183 | Raises: 184 | ValueError: If neither "name" or "train_cfg" are provided for loading training configuration. 185 | 186 | Returns: 187 | Tuple[CASSIOnPolicyRunner, BaseConfig]: Tuple containing the training runner and configuration instances. 188 | """ 189 | # if config files are passed use them, otherwise load default from the name 190 | if train_cfg is None: 191 | if name is None: 192 | raise ValueError("No training configuration provided. Either 'name' or 'train_cfg' must not be None.") 193 | else: 194 | # load config files 195 | _, train_cfg = self.get_cfgs(name) 196 | else: 197 | if name is not None: 198 | print(f"Training configuration instance provided. 
Ignoring default configuration for 'name={name}'.") 199 | # if no args passed, get command line arguments 200 | if args is None: 201 | args = get_args() 202 | # override cfg from args (if specified) 203 | _, train_cfg = update_cfg_from_args(None, train_cfg, args) 204 | # resolve logging 205 | if log_root is None: 206 | log_dir_path = None 207 | else: 208 | # default location for logs 209 | if log_root == "default": 210 | log_root = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", train_cfg.runner.experiment_name) 211 | # log directory 212 | log_dir_path = os.path.join( 213 | log_root, 214 | datetime.now().strftime("%b%d_%H-%M-%S") + "_" + train_cfg.runner.run_name, 215 | ) 216 | # create training runner 217 | runner_class = eval(train_cfg.runner_class_name) 218 | train_cfg_dict = class_to_dict(train_cfg) 219 | runner = runner_class(env, train_cfg_dict, log_dir_path, device=args.rl_device) 220 | # register this repository so its git state can be stored with the logs 221 | runner.add_git_repo_to_log(__file__) 222 | # resolve whether to resume from a previously trained model 223 | resume = train_cfg.runner.resume 224 | if resume: 225 | # load previously trained model 226 | resume_path = get_load_path( 227 | log_root, 228 | load_run=train_cfg.runner.load_run, 229 | checkpoint=train_cfg.runner.checkpoint, 230 | ) 231 | print(f"Loading model from: {resume_path}") 232 | runner.load(resume_path) 233 | 234 | return runner, train_cfg 235 | 236 | """ 237 | Private helpers. 238 | """ 239 | 240 | def _check_valid_task(self, name: str) -> bool: 241 | """Checks if input task name is valid. 242 | 243 | Args: 244 | name (str): Name of the registered task. 245 | 246 | Raises: 247 | ValueError: When there is no registered environment with input `name`. 248 | 249 | Returns: 250 | bool: True if the task exists. 251 | """ 252 | registered_tasks = self.get_task_names() 253 | if name not in registered_tasks: 254 | print(f"The task '{name}' is not registered. 
Please use one of the following: ") 255 | for registered_name in registered_tasks: 256 | print(f"\t - {registered_name}") 257 | raise ValueError(f"[TaskRegistry]: Task with name: {name} is not registered.") 258 | else: 259 | return True 260 | 261 | 262 | # make global task registry 263 | task_registry = TaskRegistry() 264 | -------------------------------------------------------------------------------- /solo_gym/utils/terrain.py: -------------------------------------------------------------------------------- 1 | # isaacgym 2 | from isaacgym import terrain_utils 3 | 4 | # python 5 | import numpy as np 6 | 7 | # solo-gym 8 | from solo_gym.utils.base_config import BaseConfig 9 | 10 | 11 | class Terrain: 12 | """Wrapper around terrain-utils to generate terrains.""" 13 | 14 | def __init__(self, cfg: BaseConfig, num_robots: int) -> None: 15 | 16 | self.cfg = cfg 17 | self.num_robots = num_robots 18 | self.type = cfg.mesh_type 19 | if self.type in ["none", "plane"]: 20 | return 21 | self.env_length = cfg.terrain_length 22 | self.env_width = cfg.terrain_width 23 | self.proportions = [np.sum(cfg.terrain_proportions[: i + 1]) for i in range(len(cfg.terrain_proportions))] 24 | 25 | self.cfg.num_sub_terrains = cfg.num_rows * cfg.num_cols 26 | self.env_origins = np.zeros((cfg.num_rows, cfg.num_cols, 3)) 27 | 28 | self.width_per_env_pixels = int(self.env_width / cfg.horizontal_scale) 29 | self.length_per_env_pixels = int(self.env_length / cfg.horizontal_scale) 30 | 31 | self.border = int(cfg.border_size / self.cfg.horizontal_scale) 32 | self.tot_cols = int(cfg.num_cols * self.width_per_env_pixels) + 2 * self.border 33 | self.tot_rows = int(cfg.num_rows * self.length_per_env_pixels) + 2 * self.border 34 | 35 | self.height_field_raw = np.zeros((self.tot_rows, self.tot_cols), dtype=np.int16) 36 | if cfg.curriculum: 37 | self.curriculum() 38 | elif cfg.selected: 39 | self.selected_terrain() 40 | else: 41 | self.randomized_terrain() 42 | 43 | self.heightsamples = self.height_field_raw 44 | if self.type == "trimesh": 45 | (self.vertices, self.triangles,) = terrain_utils.convert_heightfield_to_trimesh( 46 | self.height_field_raw, 47 | self.cfg.horizontal_scale, 48 | self.cfg.vertical_scale, 49 | self.cfg.slope_treshold, 50 | ) 51 | 52 | def randomized_terrain(self): 53 | for k in range(self.cfg.num_sub_terrains): 54 | # Env coordinates in the world 55 | (i, j) = np.unravel_index(k, (self.cfg.num_rows, self.cfg.num_cols)) 56 | 57 | choice = np.random.uniform(0, 1) 58 | difficulty = np.random.choice([0.5, 0.75, 0.9]) 59 | terrain = self.make_terrain(choice, difficulty) 60 | self.add_terrain_to_map(terrain, i, j) 61 | 62 | def curriculum(self): 63 | for j in range(self.cfg.num_cols): 64 | for i in range(self.cfg.num_rows): 65 | difficulty = i / self.cfg.num_rows 66 | choice = j / self.cfg.num_cols + 0.001 67 | 68 | terrain = self.make_terrain(choice, difficulty) 69 | self.add_terrain_to_map(terrain, i, j) 70 | 71 | def selected_terrain(self): 72 | terrain_type = self.cfg.terrain_kwargs.pop("type") 73 | for k in range(self.cfg.num_sub_terrains): 74 | # Env coordinates in the world 75 | (i, j) = np.unravel_index(k, (self.cfg.num_rows, self.cfg.num_cols)) 76 | 77 | terrain = terrain_utils.SubTerrain( 78 | "terrain", 79 | width=self.width_per_env_pixels, 80 | length=self.width_per_env_pixels, 81 | vertical_scale=self.cfg.vertical_scale, 82 | horizontal_scale=self.cfg.horizontal_scale, 83 | ) 84 | 85 | eval(terrain_type)(terrain, **self.cfg.terrain_kwargs) 86 | self.add_terrain_to_map(terrain, i, j) 87 | 88 | def 
make_terrain(self, choice, difficulty): 89 | terrain = terrain_utils.SubTerrain( 90 | "terrain", 91 | width=self.width_per_env_pixels, 92 | length=self.width_per_env_pixels, 93 | vertical_scale=self.cfg.vertical_scale, 94 | horizontal_scale=self.cfg.horizontal_scale, 95 | ) 96 | slope = difficulty * 0.4 97 | step_height = 0.05 + 0.18 * difficulty 98 | discrete_obstacles_height = 0.05 + difficulty * 0.2 99 | stepping_stones_size = 1.5 * (1.05 - difficulty) 100 | stone_distance = 0.05 if difficulty == 0 else 0.1 101 | gap_size = 1.0 * difficulty 102 | pit_depth = 1.0 * difficulty 103 | if choice < self.proportions[0]: 104 | if choice < self.proportions[0] / 2: 105 | slope *= -1 106 | terrain_utils.pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.0) 107 | elif choice < self.proportions[1]: 108 | terrain_utils.pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.0) 109 | terrain_utils.random_uniform_terrain( 110 | terrain, 111 | min_height=-0.05, 112 | max_height=0.05, 113 | step=0.005, 114 | downsampled_scale=0.2, 115 | ) 116 | elif choice < self.proportions[3]: 117 | if choice < self.proportions[2]: 118 | step_height *= -1 119 | terrain_utils.pyramid_stairs_terrain(terrain, step_width=0.31, step_height=step_height, platform_size=3.0) 120 | elif choice < self.proportions[4]: 121 | num_rectangles = 20 122 | rectangle_min_size = 1.0 123 | rectangle_max_size = 2.0 124 | terrain_utils.discrete_obstacles_terrain( 125 | terrain, 126 | discrete_obstacles_height, 127 | rectangle_min_size, 128 | rectangle_max_size, 129 | num_rectangles, 130 | platform_size=3.0, 131 | ) 132 | elif choice < self.proportions[5]: 133 | terrain_utils.stepping_stones_terrain( 134 | terrain, 135 | stone_size=stepping_stones_size, 136 | stone_distance=stone_distance, 137 | max_height=0.0, 138 | platform_size=4.0, 139 | ) 140 | elif choice < self.proportions[6]: 141 | gap_terrain(terrain, gap_size=gap_size, platform_size=3.0) 142 | else: 143 | pit_terrain(terrain, depth=pit_depth, platform_size=4.0) 144 | 145 | return terrain 146 | 147 | def add_terrain_to_map(self, terrain, row, col): 148 | i = row 149 | j = col 150 | # map coordinate system 151 | start_x = self.border + i * self.length_per_env_pixels 152 | end_x = self.border + (i + 1) * self.length_per_env_pixels 153 | start_y = self.border + j * self.width_per_env_pixels 154 | end_y = self.border + (j + 1) * self.width_per_env_pixels 155 | self.height_field_raw[start_x:end_x, start_y:end_y] = terrain.height_field_raw 156 | 157 | env_origin_x = (i + 0.5) * self.env_length 158 | env_origin_y = (j + 0.5) * self.env_width 159 | x1 = int((self.env_length / 2.0 - 1) / terrain.horizontal_scale) 160 | x2 = int((self.env_length / 2.0 + 1) / terrain.horizontal_scale) 161 | y1 = int((self.env_width / 2.0 - 1) / terrain.horizontal_scale) 162 | y2 = int((self.env_width / 2.0 + 1) / terrain.horizontal_scale) 163 | env_origin_z = np.max(terrain.height_field_raw[x1:x2, y1:y2]) * terrain.vertical_scale 164 | self.env_origins[i, j] = [env_origin_x, env_origin_y, env_origin_z] 165 | 166 | 167 | def gap_terrain(terrain: BaseConfig, gap_size: float, platform_size: float = 1.0): 168 | gap_size = int(gap_size / terrain.horizontal_scale) 169 | platform_size = int(platform_size / terrain.horizontal_scale) 170 | 171 | center_x = terrain.length // 2 172 | center_y = terrain.width // 2 173 | x1 = (terrain.length - platform_size) // 2 174 | x2 = x1 + gap_size 175 | y1 = (terrain.width - platform_size) // 2 176 | y2 = y1 + gap_size 177 | 178 | terrain.height_field_raw[center_x - x2 
: center_x + x2, center_y - y2 : center_y + y2] = -1000 179 | terrain.height_field_raw[center_x - x1 : center_x + x1, center_y - y1 : center_y + y1] = 0 180 | 181 | 182 | def pit_terrain(terrain: BaseConfig, depth, platform_size: float = 1.0): 183 | depth = int(depth / terrain.vertical_scale) 184 | platform_size = int(platform_size / terrain.horizontal_scale / 2) 185 | x1 = terrain.length // 2 - platform_size 186 | x2 = terrain.length // 2 + platform_size 187 | y1 = terrain.width // 2 - platform_size 188 | y2 = terrain.width // 2 + platform_size 189 | terrain.height_field_raw[x1:x2, y1:y2] = -depth 190 | --------------------------------------------------------------------------------
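The `Logger` in `solo_gym/utils/logger.py` records per-step signals during evaluation and renders them with matplotlib in a separate process. Below is a minimal usage sketch, not taken from `scripts/play.py`; the `dt` value and the logged values are illustrative, and only keys that `_plot` knows about (e.g. `base_vel_x`, `command_x`) end up in the figure.

```
# Minimal Logger usage sketch (illustrative values; in a real script they come from the environment).
from solo_gym.utils.logger import Logger

logger = Logger(dt=0.02)  # dt should match the environment control step
for step in range(100):
    logger.log_states({"base_vel_x": 0.5, "command_x": 0.6})
logger.plot_states()           # spawns a child process that opens the matplotlib window
logger.plot_process.join()     # keep the script alive until the window is closed
```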
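The quaternion helpers in `solo_gym/utils/math.py` follow Isaac Gym's `(x, y, z, w)` layout, with the scalar part last. A quick numerical sanity check, assuming Isaac Gym is installed (the module imports `isaacgym.torch_utils`):

```
# Sanity check of the math utilities; requires isaacgym since math.py imports isaacgym.torch_utils.
import torch
from solo_gym.utils.math import quat_apply_yaw, wrap_to_pi

# ~45 deg rotation about z in (x, y, z, w) convention: only the yaw component rotates the x-axis
quat = torch.tensor([[0.0, 0.0, 0.3827, 0.9239]])
vec = torch.tensor([[1.0, 0.0, 0.0]])
print(quat_apply_yaw(quat, vec))              # approximately [[0.707, 0.707, 0.000]]

# angles outside [-pi, pi] are wrapped back into the interval (note: modifies the tensor in place)
print(wrap_to_pi(torch.tensor([3.5, -4.0])))  # approximately [-2.783, 2.283]
```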
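`task_registry` ties the pieces together: `register` stores a task class with its configurations, `make_env` instantiates the vectorized environment, and `make_alg_runner` builds the `CASSIOnPolicyRunner` and resolves logging and resuming. The sketch below shows a typical training entry point built on this API; it is an assumption-based illustration, not the repository's `scripts/train.py`, and the `runner.learn(...)` signature and `train_cfg.runner.max_iterations` field are assumed from the rsl_rl-style runner interface.

```
# Hypothetical training entry point using the global task registry (a sketch, not scripts/train.py).
import solo_gym.envs  # noqa: F401  (assumed to register the available tasks on import)
from solo_gym.utils.helpers import get_args
from solo_gym.utils.task_registry import task_registry


def train():
    args = get_args()  # parses --task, --headless, device options, etc.
    env, env_cfg = task_registry.make_env(name=args.task, args=args)
    runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args)
    # assumed rsl_rl-style interface of CASSIOnPolicyRunner
    runner.learn(num_learning_iterations=train_cfg.runner.max_iterations)


if __name__ == "__main__":
    train()
```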
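In `Terrain.__init__`, `cfg.terrain_proportions` is converted into cumulative thresholds, and `make_terrain` compares the random `choice` in [0, 1) against them to decide which terrain type to generate (sloped, rough sloped, stairs, discrete obstacles, stepping stones, gap, pit). A small illustration with made-up proportions, not the repository's defaults:

```
# Illustrative only: how terrain_proportions become the cumulative thresholds used by make_terrain.
import numpy as np

terrain_proportions = [0.1, 0.1, 0.15, 0.15, 0.2, 0.15, 0.15]  # example values
proportions = [np.sum(terrain_proportions[: i + 1]) for i in range(len(terrain_proportions))]
print(proportions)  # roughly [0.1, 0.2, 0.35, 0.5, 0.7, 0.85, 1.0]; e.g. choice = 0.4 falls in the stairs band
```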