├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── fld.png ├── humanoid_gym ├── __init__.py ├── envs │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── legged_robot.py │ │ └── legged_robot_config.py │ ├── base_task.py │ └── mit_humanoid │ │ ├── mit_humanoid.py │ │ └── mit_humanoid_config.py └── utils │ ├── README.md │ ├── __init__.py │ ├── base_config.py │ ├── helpers.py │ ├── keyboard_controller.py │ ├── logger.py │ ├── math.py │ ├── task_registry.py │ └── terrain.py ├── learning ├── __init__.py ├── algorithms │ ├── __init__.py │ └── ppo.py ├── datasets │ ├── __init__.py │ └── motion_loader.py ├── env │ ├── __init__.py │ └── vec_env.py ├── modules │ ├── __init__.py │ ├── actor_critic.py │ ├── actor_critic_recurrent.py │ ├── discriminator.py │ ├── fld.py │ ├── gmm.py │ ├── normalizer.py │ └── plotter.py ├── runners │ ├── __init__.py │ └── gld_on_policy_runner.py ├── samplers │ ├── __init__.py │ ├── alp_gmm.py │ ├── base.py │ ├── gmm.py │ ├── offline.py │ └── random.py ├── storage │ ├── __init__.py │ ├── distribution_buffer.py │ ├── replay_buffer.py │ └── rollout_storage.py └── utils │ ├── __init__.py │ └── utils.py ├── resources └── robots │ └── mit_humanoid │ ├── datasets │ ├── anomalous │ │ ├── motion_data_crossover.pt │ │ ├── motion_data_jump.pt │ │ ├── motion_data_kick.pt │ │ └── motion_data_spinkick.pt │ ├── decoded │ │ └── reference_state_idx_dict.json │ └── misc │ │ ├── motion_data_back.pt │ │ ├── motion_data_jog.pt │ │ ├── motion_data_jog_slow.pt │ │ ├── motion_data_run.pt │ │ ├── motion_data_side_left.pt │ │ ├── motion_data_side_right.pt │ │ ├── motion_data_step.pt │ │ ├── motion_data_step_fast.pt │ │ ├── motion_data_stride.pt │ │ └── reference_state_idx_dict.json │ ├── meshes │ ├── back.stl │ ├── left_foot.stl │ ├── left_forearm.stl │ ├── left_hip_abad.stl │ ├── left_hip_yaw.stl │ ├── left_leg_lower.stl │ ├── left_leg_upper.stl │ ├── left_shoulder1.stl │ ├── left_shoulder2.stl │ ├── left_shoulder3.stl │ └── torso.stl │ └── urdf │ ├── README.md │ ├── humanoid_F_ht.urdf │ ├── humanoid_F_ht_b.urdf │ ├── humanoid_F_sf.urdf │ ├── humanoid_R_ht.urdf │ ├── humanoid_R_sf.urdf │ └── humanoid_R_sf_b.urdf ├── scripts ├── fld │ ├── evaluate.py │ ├── experiment.py │ ├── preview.py │ └── training.py ├── play.py └── train.py └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.dae filter=lfs diff=lfs merge=lfs -text 2 | *.obj filter=lfs diff=lfs merge=lfs -text 3 | *.obj text !filter !merge !diff 4 | *.dae text !filter !merge !diff 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # These are some examples of commonly ignored file patterns. 2 | # You should customize this list as applicable to your project. 
3 | # Learn more about .gitignore:
4 | # https://www.atlassian.com/git/tutorials/saving-changes/gitignore
5 |
6 | # Node artifact files
7 | node_modules/
8 | dist/
9 |
10 | # Compiled Java class files
11 | *.class
12 |
13 | # Compiled Python bytecode
14 | *.py[cod]
15 |
16 | # Log files
17 | *.log
18 |
19 | # Package files
20 | *.jar
21 |
22 | # Maven
23 | target/
24 | dist/
25 |
26 | # JetBrains IDE
27 | .idea/
28 |
29 | # Unit test reports
30 | TEST*.xml
31 |
32 | # Generated by MacOS
33 | .DS_Store
34 |
35 | # Generated by Windows
36 | Thumbs.db
37 |
38 | # Applications
39 | *.app
40 | *.exe
41 | *.war
42 |
43 | # Large media files
44 | *.mp4
45 | *.tiff
46 | *.avi
47 | *.flv
48 | *.mov
49 | *.wmv
50 |
51 | # VS Code
52 | .vscode
53 | # logs
54 | logs
55 | runs
56 |
57 | # other
58 | *.egg-info
59 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, ETH Zurich, NVIDIA Corporation
2 |
3 | All rights reserved
4 | Parts of the code are released under BSD-3-Clause license.
5 |
6 | See licenses in resources/robots for license information for assets included in this repository
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FLD with MIT Humanoid
2 |
3 | This repository provides the [Fourier Latent Dynamics (FLD)](https://arxiv.org/abs/2402.13820) algorithm that represents high-dimensional, long-horizon, highly nonlinear, periodic or quasi-periodic data in a continuously parameterized space. This work demonstrates its representation and generation capability with a robotic motion tracking task on the [MIT Humanoid](https://spectrum.ieee.org/mit-dynamic-acrobatic-humanoid-robot) using [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym).
4 |
5 | ![fld](fld.png)
6 |
7 | **Paper**: [FLD: Fourier Latent Dynamics for Structured Motion Representation and Learning](https://arxiv.org/abs/2402.13820)
8 | **Project website**: https://sites.google.com/view/iclr2024-fld/home
9 |
10 | **Maintainer**: [Chenhao Li](https://breadli428.github.io/)
11 | **Affiliation**: [Biomimetic Robotics Lab](https://biomimetics.mit.edu/), [Massachusetts Institute of Technology](https://www.mit.edu/)
12 | **Contact**: [chenhli@mit.edu](mailto:chenhli@mit.edu)
13 |
14 | ## Installation
15 |
16 | 1. Create a new Python virtual environment with `python 3.8`
17 | 2. Install `pytorch 1.10` with `cuda-11.3`
18 |
19 |        pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
20 |
21 | 3. Install Isaac Gym
22 |
23 |    - Download and install [Isaac Gym Preview 4](https://developer.nvidia.com/isaac-gym)
24 |
25 |    ```
26 |    cd isaacgym/python
27 |    pip install -e .
28 |    ```
29 |
30 |    - Try running an example
31 |
32 |    ```
33 |    cd examples
34 |    python 1080_balls_of_solitude.py
35 |    ```
36 |
37 |    - For troubleshooting, check the docs in `isaacgym/docs/index.html`
38 |
39 | 4. Install `humanoid_gym`
40 |
41 |        git clone https://github.com/mit-biomimetics/fld.git
42 |        cd fld
43 |        pip install -e .
44 |
45 | ## Configuration
46 | - The workflow consists of two main stages: motion representation and motion learning. In the first stage, the motion data is represented in the latent space using FLD. In the second stage, the latent space is used to train a policy for the robot.
47 | - The provided code exemplifies the training of FLD with human motion data retargeted to the MIT Humanoid. The dataset of [9 different motions](https://youtu.be/MVkg18c5aaU) is stored under `resources/robots/mit_humanoid/datasets/misc`. For each motion, 10 trajectories of 240 frames are stored in a separate `.pt` file with the format `motion_data_<motion_name>.pt`. The state dimension indices are specified in `reference_state_idx_dict.json` under `resources/robots/mit_humanoid/datasets/misc`.
48 | - The MIT Humanoid environment is defined by an env file `mit_humanoid.py` and a config file `mit_humanoid_config.py` under `humanoid_gym/envs/mit_humanoid/`. The config file sets both the environment parameters in class `MITHumanoidFlatCfg` and the training parameters in class `MITHumanoidFlatCfgPPO`.
49 |
50 |
51 | ## Usage
52 |
53 | ### FLD Training
54 |
55 | ```
56 | python scripts/fld/experiment.py
57 | ```
58 |
59 | - `history_horizon` denotes the window size of the input data. A good practice is to set it such that it contains at least one period of the motion.
60 | - `forecast_horizon` denotes the number of future steps to predict while maintaining the quasi-constant latent parameterization. For motions with high aperiodicity, this value should be set small. It falls back to PAE when `forecast_horizon` is set to 1.
61 | - The training process is visualized by inspecting the Tensorboard logs at `logs/<experiment_name>/fld/misc/`. The figures include the FLD loss, the reconstruction of sampled trajectories for each motion, the latent parameters in each latent channel along sampled trajectories for each motion with the formed latent manifold, and the latent parameter distribution.
62 | - The trained FLD model is saved in `logs/<experiment_name>/fld/misc/model_<iteration>.pt`, where `<experiment_name>` and `<iteration>` are defined in the experiment config.
63 | - The training process is logged in the same folder. Run `tensorboard --logdir logs/<experiment_name>/fld/misc/ --samples_per_plugin images=100` to visualize the training loss and plots.
64 | - A `statistics.pt` file is saved in the same folder, containing the mean and standard deviation of the input data and the statistics of the latent parameterization space. This file is used to normalize the input data and to define plotting ranges during policy training. (A minimal loading sketch is provided at the end of this README.)
65 |
66 | ### FLD Evaluation
67 |
68 | ```
69 | python scripts/fld/evaluate.py
70 | ```
71 |
72 | - A `latent_params.pt` file is saved in the same folder, containing the latent parameters of the input data. This file is used to define the input data for policy training with the offline task sampler.
73 | - A `gmm.pt` file is saved in the same folder, containing the Gaussian Mixture Model (GMM) of the latent parameters. This file is used to define the input data distribution for policy training with the GMM task sampler.
74 | - A set of latent parameters is sampled and reconstructed to the original motion space. The decoded motion is saved in `resources/robots/mit_humanoid/datasets/decoded/motion_data.pt`. Figure 1 shows the latent sample and the reconstructed motion trajectory. Figure 2 shows the sampled latent parameters. Figure 3 shows the latent manifold of the sampled trajectory, along with the original ones. Figure 4 shows the GMM of the latent parameters.
75 | - Note that the motion contains only kinematic and proprioceptive information. For visualization only, the global position and orientation of the robot base are approximated by integrating the velocity information with finite differences.
Depending on the finite difference method and the initial states, the global position and orientation may be inaccurate and drift over time.
76 |
77 | ### Motion Visualization
78 |
79 | ```
80 | python scripts/fld/preview.py
81 | ```
82 | - To visualize the original motions in the training dataset or the sampled and decoded motions in the Isaac Gym environment, set `motion_file` to the corresponding motion file.
83 | - Alternatively, the latent parameters can be interactively modified by setting `PLAY_LOADED_DATA` to `False`. The modified latent parameters are then decoded to the original motion space and visualized.
84 |
85 | ### Policy Training
86 |
87 | ```
88 | python scripts/train.py --task mit_humanoid
89 | ```
90 |
91 | - Configure the training parameters in `humanoid_gym/envs/mit_humanoid/mit_humanoid_config.py`.
92 | - Choose the task sampler by setting `MITHumanoidFlatCfgPPO.runner.task_sampler_class_name` to `OfflineSampler`, `GMMSampler`, `RandomSampler` or `ALPGMMSampler`.
93 | - The trained policy is saved in `logs/<experiment_name>/<date_time>_<run_name>/model_<iteration>.pt`, where `<experiment_name>` and `<run_name>` are defined in the train config.
94 | - To disable rendering, append `--headless`.
95 |
96 | ### Policy Playing
97 |
98 | ```
99 | python scripts/play.py --load_run "<date_time>_<run_name>"
100 | ```
101 |
102 | - By default, the loaded policy is the last model of the last run of the experiment folder.
103 | - Other runs/model iterations can be selected by setting `load_run` and `checkpoint` in the train config.
104 | - The target motions are randomly selected from the dataset under the path specified by `datasets_root`. These motions are first encoded to the latent space and then sent to the policy for execution.
105 | - The fallback mechanism is enabled by default with a threshold of 1.0 on `dynamics_error`.
106 |
107 | ## Troubleshooting
108 | ```
109 | RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)
110 | ```
111 | - This error occurs when the CUDA version is not compatible with the installed PyTorch version. A quick fix is to comment out the decorator `@torch.jit.script` in `isaacgym/python/isaacgym/torch_utils.py`.
112 |
113 |
114 | ## Known Issues
115 | The `ALPGMMSampler` utilizes [faiss](https://github.com/facebookresearch/faiss) for efficient similarity search and clustering of dense vectors in the latent parameterization space. The installation of `faiss` requires a compatible CUDA version. The current implementation is tested with `faiss-cpu` and `faiss-gpu` with `cuda-10.2`.
116 |
117 |
118 | ## Citation
119 | ```
120 | @article{li2024fld,
121 |   title={FLD: Fourier Latent Dynamics for Structured Motion Representation and Learning},
122 |   author={Li, Chenhao and Stanger-Jones, Elijah and Heim, Steve and Kim, Sangbae},
123 |   journal={arXiv preprint arXiv:2402.13820},
124 |   year={2024}
125 | }
126 | ```
127 |
128 | ## References
129 |
130 | The code is built upon the open-sourced [Periodic Autoencoder (PAE) Implementation](https://github.com/sebastianstarke/AI4Animation/tree/master/AI4Animation/SIGGRAPH_2022/PyTorch/PAE), [Isaac Gym Environments for Legged Robots](https://github.com/leggedrobotics/legged_gym) and the [PPO implementation](https://github.com/leggedrobotics/rsl_rl). We refer to the original repositories for more details.
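## Inspecting Saved Artifacts

The `statistics.pt` and `latent_params.pt` files described in the FLD training and evaluation sections can be inspected directly with PyTorch. The snippet below is a minimal sketch rather than part of the repository's tooling; it assumes the default log layout used above (the `load_root` set in `mit_humanoid_config.py`) and that the saved files are ordinary `torch.save` artifacts.

```python
import torch

# Example experiment folder; substitute your own <experiment_name>.
log_dir = "logs/flat_mit_humanoid/fld/misc"

# statistics.pt: mean/std of the input data plus latent parameterization statistics
# (written by scripts/fld/experiment.py).
statistics = torch.load(f"{log_dir}/statistics.pt", map_location="cpu")

# latent_params.pt: latent parameters of the training motions
# (written by scripts/fld/evaluate.py).
latent_params = torch.load(f"{log_dir}/latent_params.pt", map_location="cpu")

# Print the top-level structure before wiring the files into a task sampler.
for name, obj in (("statistics", statistics), ("latent_params", latent_params)):
    summary = list(obj.keys()) if isinstance(obj, dict) else type(obj)
    print(name, summary)
```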
131 |
--------------------------------------------------------------------------------
/fld.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/fld.png
--------------------------------------------------------------------------------
/humanoid_gym/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | LEGGED_GYM_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
4 | """Absolute path to the humanoid-gym repository."""
5 |
6 | LEGGED_GYM_ENVS_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, "humanoid_gym", "envs")
7 | """Absolute path to the module `humanoid_gym.envs` in the humanoid-gym repository."""
8 |
--------------------------------------------------------------------------------
/humanoid_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | ##
2 | # Locomotion environments.
3 | ##
4 | # fmt: off
5 | from .base.legged_robot import LeggedRobot
6 | from .mit_humanoid.mit_humanoid import MITHumanoid
7 | from .mit_humanoid.mit_humanoid_config import (
8 |     MITHumanoidFlatCfg,
9 |     MITHumanoidFlatCfgPPO
10 | )
11 |
12 | # fmt: on
13 |
14 | ##
15 | # Task registration
16 | ##
17 | from humanoid_gym.utils.task_registry import task_registry
18 |
19 | task_registry.register("mit_humanoid", MITHumanoid, MITHumanoidFlatCfg, MITHumanoidFlatCfgPPO)
--------------------------------------------------------------------------------
/humanoid_gym/envs/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .legged_robot import LeggedRobot
2 | from .legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
3 |
4 | __all__ = ["LeggedRobot", "LeggedRobotCfg", "LeggedRobotCfgPPO"]
5 |
--------------------------------------------------------------------------------
/humanoid_gym/envs/base/legged_robot_config.py:
--------------------------------------------------------------------------------
1 | # humanoid-gym
2 | from humanoid_gym.utils.base_config import BaseConfig
3 |
4 |
5 | class LeggedRobotCfg(BaseConfig):
6 |     class env:
7 |         num_envs = 4096
8 |         num_observations = 235  # robot state (48) + height scans (17*11=187)
9 |         num_privileged_obs = None  # if not None a privileged_obs_buf will be returned by step() (critic obs for asymmetric training).
None is returned otherwise 10 | num_actions = 12 # joint positions, velocities or torques 11 | env_spacing = 3.0 # not used with heightfields/trimeshes 12 | send_timeouts = True # send time out information to the algorithm 13 | episode_length_s = 20 # episode length in seconds 14 | 15 | class terrain: 16 | mesh_type = "trimesh" # none, plane, heightfield or trimesh 17 | horizontal_scale = 0.1 # [m] 18 | vertical_scale = 0.005 # [m] 19 | border_size = 25 # [m] 20 | curriculum = True 21 | static_friction = 1.0 22 | dynamic_friction = 1.0 23 | restitution = 0.0 24 | # rough terrain only: 25 | measure_heights = True 26 | # 1mx1.6m rectangle (without center line) 27 | # fmt: off 28 | measured_points_x = [-0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] 29 | measured_points_y = [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5] 30 | # fmt: on 31 | selected = False # select a unique terrain type and pass all arguments 32 | terrain_kwargs = None # Dict of arguments for selected terrain 33 | max_init_terrain_level = 5 # starting curriculum state 34 | terrain_length = 8.0 35 | terrain_width = 8.0 36 | num_rows = 10 # number of terrain rows (levels) 37 | num_cols = 20 # number of terrain cols (types) 38 | # terrain types: [smooth slope, rough slope, stairs up, stairs down, discrete] 39 | terrain_proportions = [0.1, 0.1, 0.35, 0.25, 0.2] 40 | # trimesh only: 41 | # slopes above this threshold will be corrected to vertical surfaces 42 | slope_treshold = 0.75 43 | 44 | class commands: 45 | curriculum = False 46 | max_curriculum = 1.0 47 | num_commands = 4 # default: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error) 48 | resampling_time = 10.0 # time before commands are changed [s] 49 | heading_command = True # if true: compute ang vel command from heading error 50 | 51 | class ranges: 52 | lin_vel_x = [-1.0, 1.0] # min max [m/s] 53 | lin_vel_y = [-1.0, 1.0] # min max [m/s] 54 | ang_vel_yaw = [-1, 1] # min max [rad/s] 55 | heading = [-3.14, 3.14] # [rad] 56 | 57 | class init_state: 58 | pos = [0.0, 0.0, 1.0] # x,y,z [m] 59 | rot = [0.0, 0.0, 0.0, 1.0] # x,y,z,w [quat] 60 | lin_vel = [0.0, 0.0, 0.0] # x,y,z [m/s] 61 | ang_vel = [0.0, 0.0, 0.0] # x,y,z [rad/s] 62 | default_joint_angles = { # target angles when action = 0.0 63 | "joint_a": 0.0, 64 | "joint_b": 0.0, 65 | } 66 | 67 | class control: 68 | control_type = "P" # P: position, V: velocity, T: torques 69 | # PD Drive parameters: 70 | stiffness = {"joint_a": 10.0, "joint_b": 15.0} # [N*m/rad] 71 | damping = {"joint_a": 1.0, "joint_b": 1.5} # [N*m*s/rad] 72 | # action scale: target angle = actionScale * action + defaultAngle 73 | action_scale = 0.5 74 | # decimation: Number of control action updates @ sim DT per policy DT 75 | decimation = 4 76 | 77 | class asset: 78 | file = "" 79 | foot_name = "None" # name of the feet bodies, used to index body state and contact force tensors 80 | penalize_contacts_on = [] 81 | terminate_after_contacts_on = [] 82 | disable_gravity = False 83 | collapse_fixed_joints = True # merge bodies connected by fixed joints. Specific fixed joints can be kept by adding " <... 
dont_collapse="true"> 84 | fix_base_link = False # fixe the base of the robot 85 | default_dof_drive_mode = 3 # see GymDofDriveModeFlags (0 is none, 1 is pos tgt, 2 is vel tgt, 3 effort) 86 | self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter 87 | # replace collision cylinders with capsules, leads to faster/more stable simulation 88 | replace_cylinder_with_capsule = True 89 | flip_visual_attachments = True # Some .obj meshes must be flipped from y-up to z-up 90 | enable_joint_force_sensors = False # Check out isaacgym_lib/docs/programming/forcesensors.html 91 | 92 | density = 0.001 93 | angular_damping = 0.0 94 | linear_damping = 0.0 95 | max_angular_velocity = 1000.0 96 | max_linear_velocity = 1000.0 97 | armature = 0.0 98 | thickness = 0.01 99 | 100 | class domain_rand: 101 | randomize_friction = True 102 | friction_range = [0.5, 1.25] 103 | randomize_base_mass = False 104 | added_mass_range = [-1.0, 1.0] 105 | push_robots = True 106 | push_interval_s = 15 # push applied each time interval [s] 107 | max_push_vel_xy = 1.0 # velocity offset added by push [m/s] 108 | 109 | class rewards: 110 | class scales: 111 | termination = -0.0 112 | tracking_lin_vel = 1.0 113 | tracking_ang_vel = 0.5 114 | lin_vel_z = -2.0 115 | ang_vel_xy = -0.05 116 | orientation = -0.0 117 | torques = -0.00001 118 | dof_vel = -0.0 119 | dof_acc = -2.5e-7 120 | base_height = -0.0 121 | feet_air_time = 1.0 122 | collision = -1.0 123 | feet_stumble = -0.0 124 | action_rate = -0.01 125 | stand_still = -0.0 126 | 127 | # if true negative total rewards are clipped at zero (avoids early termination problems) 128 | only_positive_rewards = True 129 | tracking_sigma = 0.25 # tracking reward = exp(-error^2/sigma) 130 | # percentage of urdf limits, values above this limit are penalized 131 | soft_dof_pos_limit = 1.0 132 | soft_dof_vel_limit = 1.0 133 | soft_torque_limit = 1.0 134 | base_height_target = 1.0 135 | max_contact_force = 100.0 # forces above this value are penalized 136 | 137 | class normalization: 138 | class obs_scales: 139 | lin_vel = 2.0 140 | ang_vel = 0.25 141 | dof_pos = 1.0 142 | dof_vel = 0.05 143 | height_measurements = 5.0 144 | 145 | clip_observations = 100.0 146 | clip_actions = 100.0 147 | 148 | class noise: 149 | add_noise = True 150 | noise_level = 1.0 # scales other values 151 | 152 | class noise_scales: 153 | dof_pos = 0.01 154 | dof_vel = 1.5 155 | lin_vel = 0.1 156 | ang_vel = 0.2 157 | gravity = 0.05 158 | height_measurements = 0.1 159 | 160 | # viewer camera: 161 | class viewer: 162 | ref_env = 0 163 | pos = [10, 0, 6] # [m] 164 | lookat = [11.0, 5, 3.0] # [m] 165 | 166 | class sim: 167 | dt = 0.005 168 | substeps = 1 169 | gravity = [0.0, 0.0, -9.81] # [m/s^2] 170 | up_axis = 1 # 0 is y, 1 is z 171 | 172 | class physx: 173 | num_threads = 10 174 | solver_type = 1 # 0: pgs, 1: tgs 175 | num_position_iterations = 4 176 | num_velocity_iterations = 0 177 | contact_offset = 0.01 # [m] 178 | rest_offset = 0.0 # [m] 179 | bounce_threshold_velocity = 0.5 # [m/s] 180 | max_depenetration_velocity = 1.0 181 | max_gpu_contact_pairs = 2 ** 23 # 2**24 -> needed for 8000 envs and more 182 | default_buffer_size_multiplier = 5 183 | # 0: never, 1: last sub-step, 2: all sub-steps (default=2) 184 | contact_collection = 2 185 | 186 | 187 | class LeggedRobotCfgPPO(BaseConfig): 188 | seed = 1 189 | runner_class_name = "OnPolicyRunner" 190 | 191 | class policy: 192 | init_noise_std = 1.0 193 | actor_hidden_dims = [512, 256, 128] 194 | critic_hidden_dims = [512, 256, 128] 195 | activation = "elu" # can 
be elu, relu, selu, crelu, lrelu, tanh, sigmoid
196 |         # only for 'ActorCriticRecurrent':
197 |         # rnn_type = 'lstm'
198 |         # rnn_hidden_size = 512
199 |         # rnn_num_layers = 1
200 |
201 |     class algorithm:
202 |         # training params
203 |         value_loss_coef = 1.0
204 |         use_clipped_value_loss = True
205 |         clip_param = 0.2
206 |         entropy_coef = 0.01
207 |         num_learning_epochs = 5
208 |         num_mini_batches = 4  # mini batch size = num_envs * nsteps / nminibatches
209 |         learning_rate = 1.0e-3  # 5.e-4
210 |         schedule = "adaptive"  # adaptive, fixed
211 |         gamma = 0.99
212 |         lam = 0.95
213 |         desired_kl = 0.01
214 |         max_grad_norm = 1.0
215 |
216 |     class runner:
217 |         policy_class_name = "ActorCritic"
218 |         algorithm_class_name = "PPO"
219 |         num_steps_per_env = 24  # per iteration
220 |         max_iterations = 1500  # number of policy updates
221 |
222 |         # logging
223 |         save_interval = 50  # check for potential saves every this many iterations
224 |         experiment_name = "test"
225 |         run_name = ""
226 |         # load and resume
227 |         resume = False
228 |         load_run = -1  # -1 = last run
229 |         checkpoint = -1  # -1 = last saved model
230 |         resume_path = None  # updated from load_run and chkpt
231 |
--------------------------------------------------------------------------------
/humanoid_gym/envs/base_task.py:
--------------------------------------------------------------------------------
1 | # isaacgym
2 | from isaacgym import gymapi
3 | from isaacgym import gymutil
4 |
5 | # python
6 | import sys
7 | import torch
8 | import abc
9 | from typing import Tuple, Union
10 |
11 | # humanoid-gym
12 | from humanoid_gym.utils.base_config import BaseConfig
13 |
14 |
15 | class BaseTask:
16 |     """Base class for RL tasks."""
17 |
18 |     def __init__(
19 |         self,
20 |         cfg: BaseConfig,
21 |         sim_params: gymapi.SimParams,
22 |         physics_engine: gymapi.SimType,
23 |         sim_device: str,
24 |         headless: bool,
25 |     ):
26 |         """Initialize the base class for RL.
27 |
28 |         The class initializes the simulation. It also allocates buffers for observations,
29 |         actions, rewards, reset, episode length, episode timeout and privileged observations.
30 |
31 |         The :obj:`cfg` must contain the following:
32 |
33 |         - num_envs (int): Number of environment instances.
34 |         - num_observations (int): Number of observations.
35 |         - num_privileged_obs (int): Number of privileged observations.
36 |         - num_actions (int): Number of actions.
37 |
38 |         Note:
39 |             If :obj:`cfg.num_privileged_obs` is not :obj:`None`, a buffer for privileged
40 |             observations is returned. This is useful for critic observations in asymmetric
41 |             actor-critic.
42 |
43 |         Args:
44 |             cfg (BaseConfig): Configuration for the environment.
45 |             sim_params (gymapi.SimParams): The simulation parameters.
46 |             physics_engine (gymapi.SimType): Simulation type (must be gymapi.SIM_PHYSX).
47 |             sim_device (str): The simulation device (ex: `cuda:0` or `cpu`).
48 |             headless (bool): If true, run without rendering.
49 |         """
50 |         # copy input arguments into class members
51 |         self.sim_params = sim_params
52 |         self.physics_engine = physics_engine
53 |         self.sim_device = sim_device
54 |         self.headless = headless
55 |         sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
56 |         # env device is GPU only if sim is on GPU and use_gpu_pipeline is True.
57 |         # otherwise returned tensors are copied to CPU by PhysX.
58 | if sim_device_type == "cuda" and sim_params.use_gpu_pipeline: 59 | self.device = self.sim_device 60 | else: 61 | self.device = "cpu" 62 | # graphics device for rendering, -1 for no rendering 63 | self.graphics_device_id = self.sim_device_id 64 | if self.headless is True: 65 | self.graphics_device_id = -1 66 | 67 | # store the environment information 68 | self.num_envs = cfg.env.num_envs 69 | self.num_obs = cfg.env.num_observations 70 | self.num_privileged_obs = cfg.env.num_privileged_obs 71 | self.num_actions = cfg.env.num_actions 72 | 73 | # optimization flags for pytorch JIT 74 | torch._C._jit_set_profiling_mode(False) 75 | torch._C._jit_set_profiling_executor(False) 76 | 77 | # allocate buffers 78 | self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float) 79 | self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float) 80 | self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long) 81 | self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) 82 | self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) 83 | if self.num_privileged_obs is not None: 84 | self.privileged_obs_buf = torch.zeros( 85 | self.num_envs, 86 | self.num_privileged_obs, 87 | device=self.device, 88 | dtype=torch.float, 89 | ) 90 | else: 91 | self.privileged_obs_buf = None 92 | # allocate dictionary to store metrics 93 | self.extras = {} 94 | 95 | # create envs, sim 96 | self.gym = gymapi.acquire_gym() 97 | self.create_sim() 98 | self.gym.prepare_sim(self.sim) 99 | 100 | # create viewer 101 | # Todo: read from config 102 | self.enable_viewer_sync = True 103 | self.viewer = None 104 | # if running with a viewer, set up keyboard shortcuts and camera 105 | if self.headless is False: 106 | # subscribe to keyboard shortcuts 107 | self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties()) 108 | self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT") 109 | self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync") 110 | 111 | def __del__(self): 112 | """Cleanup in the end.""" 113 | try: 114 | if self.sim is not None: 115 | self.gym.destroy_sim(self.sim) 116 | if self.viewer is not None: 117 | self.gym.destroy_viewer(self.viewer) 118 | except: 119 | pass 120 | 121 | """ 122 | Properties. 123 | """ 124 | 125 | def get_observations(self) -> torch.Tensor: 126 | return self.obs_buf 127 | 128 | def get_privileged_observations(self) -> Union[torch.Tensor, None]: 129 | return self.privileged_obs_buf 130 | 131 | """ 132 | Operations. 133 | """ 134 | 135 | def set_camera_view(self, position: Tuple[float, float, float], lookat: Tuple[float, float, float]) -> None: 136 | """Set camera position and direction.""" 137 | cam_pos = gymapi.Vec3(position[0], position[1], position[2]) 138 | cam_target = gymapi.Vec3(lookat[0], lookat[1], lookat[2]) 139 | self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target) 140 | 141 | def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 142 | """Reset all environment instances. 143 | 144 | Returns: 145 | Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations. 
146 | """ 147 | # reset environments 148 | self.reset_idx(torch.arange(self.num_envs, device=self.device)) 149 | # perform single-step to get observations 150 | zero_actions = torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False) 151 | obs, privileged_obs, _, _, _ = self.step(zero_actions) 152 | # return obs 153 | return obs, privileged_obs 154 | 155 | @abc.abstractmethod 156 | def step( 157 | self, actions: torch.Tensor 158 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: 159 | """Apply input action on the environment. 160 | 161 | Args: 162 | actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) 163 | 164 | Returns: 165 | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]: 166 | A tuple containing the observations, privileged observations, rewards, dones and 167 | extra information (metrics). 168 | """ 169 | raise NotImplementedError 170 | 171 | def render(self, sync_frame_time=True): 172 | """Render the viewer.""" 173 | if self.viewer: 174 | # check for window closed 175 | if self.gym.query_viewer_has_closed(self.viewer): 176 | sys.exit() 177 | # check for keyboard events 178 | for evt in self.gym.query_viewer_action_events(self.viewer): 179 | if evt.action == "QUIT" and evt.value > 0: 180 | sys.exit() 181 | elif evt.action == "toggle_viewer_sync" and evt.value > 0: 182 | self.enable_viewer_sync = not self.enable_viewer_sync 183 | # fetch results 184 | if self.device != "cpu": 185 | self.gym.fetch_results(self.sim, True) 186 | # step graphics 187 | if self.enable_viewer_sync: 188 | self.gym.step_graphics(self.sim) 189 | self.gym.draw_viewer(self.viewer, self.sim, True) 190 | if sync_frame_time: 191 | self.gym.sync_frame_time(self.sim) 192 | else: 193 | self.gym.poll_viewer_events(self.viewer) 194 | 195 | """ 196 | Protected Methods. 197 | """ 198 | 199 | @abc.abstractmethod 200 | def create_sim(self): 201 | """Creates simulation, terrain and environments""" 202 | raise NotImplementedError 203 | 204 | @abc.abstractmethod 205 | def reset_idx(self, env_ids: torch.Tensor) -> None: 206 | """Resets the MDP for given environment instances. 207 | 208 | Args: 209 | env_ids (torch.Tensor): A tensor containing indices of environment instances to reset. 
210 | """ 211 | raise NotImplementedError 212 | -------------------------------------------------------------------------------- /humanoid_gym/envs/mit_humanoid/mit_humanoid_config.py: -------------------------------------------------------------------------------- 1 | from humanoid_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO 2 | from humanoid_gym import LEGGED_GYM_ROOT_DIR 3 | 4 | 5 | class MITHumanoidFlatCfg(LeggedRobotCfg): 6 | class env(LeggedRobotCfg.env): 7 | num_observations = 106 8 | num_actions = 18 9 | 10 | class terrain(LeggedRobotCfg.terrain): 11 | mesh_type = "plane" 12 | curriculum = False 13 | measure_heights = False 14 | terrain_proportions = [0.0, 1.0] 15 | num_rows = 5 16 | max_init_terrain_level = 4 17 | 18 | class init_state(LeggedRobotCfg.init_state): 19 | pos = [0.0, 0.0, 0.7] # x,y,z [m] 20 | default_joint_angles = { # = target angles [rad] when action = 0.0 21 | "left_hip_yaw": 0.0, 22 | "left_hip_abad": 0.0, 23 | "left_hip_pitch": -0.4, 24 | "left_knee": 0.9, 25 | "left_ankle": -0.45, 26 | "left_shoulder_pitch": 0.0, 27 | "left_shoulder_abad": 0.0, 28 | "left_shoulder_yaw": 0.0, 29 | "left_elbow": 0.0, 30 | 31 | "right_hip_yaw": 0.0, 32 | "right_hip_abad": 0.0, 33 | "right_hip_pitch": -0.4, 34 | "right_knee": 0.9, 35 | "right_ankle": -0.45, 36 | "right_shoulder_pitch": 0.0, 37 | "right_shoulder_abad": 0.0, 38 | "right_shoulder_yaw": 0.0, 39 | "right_elbow": 0.0, 40 | } 41 | 42 | class control(LeggedRobotCfg.control): 43 | stiffness = { 44 | "hip_yaw": 30.0, 45 | "hip_abad": 30.0, 46 | "hip_pitch": 30.0, 47 | "knee": 30.0, 48 | "ankle": 30.0, 49 | "shoulder_pitch": 40.0, 50 | "shoulder_abad": 40.0, 51 | "shoulder_yaw": 40.0, 52 | "elbow": 40.0, 53 | } # [N*m/rad] 54 | damping = { 55 | "hip_yaw": 5.0, 56 | "hip_abad": 5.0, 57 | "hip_pitch": 5.0, 58 | "knee": 5.0, 59 | "ankle": 5.0, 60 | "shoulder_pitch": 5.0, 61 | "shoulder_abad": 5.0, 62 | "shoulder_yaw": 5.0, 63 | "elbow": 5.0, 64 | } # [N*m*s/rad] 65 | 66 | class asset(LeggedRobotCfg.asset): 67 | file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/mit_humanoid/urdf/humanoid_R_sf.urdf' 68 | foot_name = "foot" 69 | penalize_contacts_on = ["arm"] 70 | terminate_after_contacts_on = ["base"] 71 | self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter 72 | flip_visual_attachments = False 73 | 74 | class rewards(LeggedRobotCfg.rewards): 75 | soft_dof_pos_limit = 0.85 76 | soft_dof_vel_limit = 0.9 77 | soft_torque_limit = 0.9 78 | base_height_target = 0.66 79 | max_contact_force = 350.0 80 | tracking_reconstructed_lin_vel_scale = 0.2 81 | tracking_reconstructed_ang_vel_scale = 0.2 82 | tracking_reconstructed_projected_gravity_scale = 1.0 83 | tracking_reconstructed_dof_pos_leg_l_scale = 1.0 84 | tracking_reconstructed_dof_pos_arm_l_scale = 1.0 85 | tracking_reconstructed_dof_pos_leg_r_scale = 1.0 86 | tracking_reconstructed_dof_pos_arm_r_scale = 1.0 87 | class scales(LeggedRobotCfg.rewards.scales): 88 | orientation = -0.0 89 | feet_air_time = 0.0 90 | lin_vel_z = -0.0 91 | ang_vel_xy = -0.0 92 | stand_still = -0.0 93 | base_height = -0.0 94 | tracking_lin_vel = 0.0 95 | tracking_ang_vel = 0.0 96 | tracking_reconstructed_lin_vel = 1.0 97 | tracking_reconstructed_ang_vel = 1.0 98 | tracking_reconstructed_projected_gravity = 1.0 99 | tracking_reconstructed_dof_pos_leg_l = 1.0 100 | tracking_reconstructed_dof_pos_arm_l = 1.0 101 | tracking_reconstructed_dof_pos_leg_r = 1.0 102 | tracking_reconstructed_dof_pos_arm_r = 1.0 103 | arm_near_home = -0.0 104 | leg_near_home = -0.0 105 | class 
commands(LeggedRobotCfg.commands):
106 |         curriculum = False
107 |         resampling_time = 5.0
108 |         heading_command = False
109 |         class ranges:
110 |             lin_vel_x = [-0.0, 0.0]  # min max [m/s]
111 |             lin_vel_y = [-0.0, 0.0]  # min max [m/s]
112 |             ang_vel_yaw = [-0.0, 0.0]  # min max [rad/s]
113 |
114 |     class domain_rand(LeggedRobotCfg.domain_rand):
115 |         push_robots = True
116 |         max_push_vel_xy = 0.5
117 |         randomize_base_mass = True
118 |         added_mass_range = [-1.0, 1.0]
119 |         latent_encoding_update_noise_level = 0.0
120 |
121 |     class fld:
122 |         latent_channel = 8
123 |         history_horizon = 51
124 |         encoder_shape = [64, 64]
125 |         decoder_shape = [64, 64]
126 |         state_idx_dict = {
127 |             "base_pos": [0, 1, 2],
128 |             "base_quat": [3, 4, 5, 6],
129 |             "base_lin_vel": [7, 8, 9],
130 |             "base_ang_vel": [10, 11, 12],
131 |             "projected_gravity": [13, 14, 15],
132 |             "dof_pos_leg_l": [16, 17, 18, 19, 20],
133 |             "dof_pos_arm_l": [21, 22, 23, 24],
134 |             "dof_pos_leg_r": [25, 26, 27, 28, 29],
135 |             "dof_pos_arm_r": [30, 31, 32, 33],
136 |         }
137 |         load_root = LEGGED_GYM_ROOT_DIR + "/logs/flat_mit_humanoid/fld/misc"
138 |         load_model = "model_5000.pt"
139 |
140 |
141 |     class task_sampler:
142 |         collect_samples = True
143 |         collect_sample_step_interval = 5
144 |         collect_elite_performance_threshold = 1.0
145 |         library_size = 5000
146 |         update_interval = 50
147 |         elite_buffer_size = 1
148 |         max_num_updates = 1000
149 |         curriculum = True
150 |         curriculum_scale = 1.25
151 |         curriculum_performance_threshold = 0.8
152 |         max_num_curriculum_updates = 5
153 |         check_update_interval = 100
154 |         class offline:
155 |             pass
156 |         class random:
157 |             pass
158 |         class gmm:
159 |             num_components = 8
160 |         class alp_gmm:
161 |             init_num_components = 8
162 |             min_num_components = 2
163 |             max_num_components = 10
164 |             random_type = "uniform"  # "uniform" or "gmm"
165 |         class classifier:
166 |             enabled = False
167 |             num_classes = 9
168 |
169 |
170 | class MITHumanoidFlatCfgPPO(LeggedRobotCfgPPO):
171 |     runner_class_name = "FLDOnPolicyRunner"
172 |     class policy(LeggedRobotCfgPPO.policy):
173 |         actor_hidden_dims = [128, 128, 128]
174 |         critic_hidden_dims = [128, 128, 128]
175 |         init_noise_std = 1.0
176 |     class runner(LeggedRobotCfgPPO.runner):
177 |         run_name = 'min_forces'
178 |         experiment_name = 'flat_mit_humanoid'
179 |         algorithm_class_name = "PPO"
180 |         policy_class_name = "ActorCritic"
181 |         task_sampler_class_name = "OfflineSampler"  # "OfflineSampler", "GMMSampler", "RandomSampler", "ALPGMMSampler"
182 |         load_run = -1
183 |         max_iterations = 3000
--------------------------------------------------------------------------------
/humanoid_gym/utils/README.md:
--------------------------------------------------------------------------------
1 | # Legged Gym Utilities
2 |
3 | ## Keyboard Controller
4 |
5 | By overriding the `_get_keyboard_events()` method, a custom keyboard controller can be added to the environment. The keyboard controller subscribes to IsaacGym's keyboard system, so events are only caught while the IsaacGym window is focused.
6 |
7 | ### Example
8 |
9 | ```python
10 | from humanoid_gym.utils.keyboard_controller import KeyboardAction, Button, Delta, Switch
11 |
12 | def _get_keyboard_events(self) -> Dict[str, KeyboardAction]:
13 |     # Simple keyboard controller for linear and angular velocity
14 |
15 |     def print_command():
16 |         print(f"New command: {self.commands[0]}")
17 |
18 |     key_board_events = {
19 |         'u' : Delta("lin_vel_x", amount = 0.1, variable_reference = self.commands[:, 0], callback = print_command),
20 |         'j' : Delta("lin_vel_x", amount = -0.1, variable_reference = self.commands[:, 0], callback = print_command),
21 |         'h' : Delta("lin_vel_y", amount = 0.1, variable_reference = self.commands[:, 1], callback = print_command),
22 |         'k' : Delta("lin_vel_y", amount = -0.1, variable_reference = self.commands[:, 1], callback = print_command),
23 |         'y' : Delta("ang_vel_z", amount = 0.1, variable_reference = self.commands[:, 2], callback = print_command),
24 |         'i' : Delta("ang_vel_z", amount = -0.1, variable_reference = self.commands[:, 2], callback = print_command),
25 |         'm' : Button("some_var", 0, 1, self.commands[:, someIndex], print_command),
26 |         'n' : Switch("some_other_var", 0, 1, self.commands[:, someIndex], print_command)
27 |     }
28 |     return key_board_events
29 | ```
30 |
31 | A parent keyboard controller can also be extended by calling the `super()` method:
32 |
33 | ```python
34 | def _get_keyboard_events(self) -> Dict[str, KeyboardAction]:
35 |     basic_keyboard = super()._get_keyboard_events()
36 |     basic_keyboard['x'] = Button("new_var", 0, 1, self.commands[:, someIndex], None)
37 |     return basic_keyboard
38 | ```
39 |
40 | The following keyboard events are available:
41 |
42 | |**Classname** | **Parameters** | **Description** |
43 | |--------------|----------------|-----------------|
44 | | Delta | `amount`, `variable_reference`, `callback` (optional) | Increments `variable_reference` by `amount` and calls the `callback` if it was passed |
45 | | Button | `start_state`, `toggle_state`, `variable_reference`, `callback` (optional) | Sets `variable_reference[:] = toggle_state` for the duration the button is held down. Resets to `start_state` afterwards. Calls the `callback` if it was passed. |
46 | | Switch | `start_state`, `toggle_state`, `variable_reference`, `callback` (optional) | Toggles `variable_reference[:]` between the `toggle_state` and `start_state` each time the button is pressed. Calls the `callback` if it was passed. |
47 | | DelegateHandle | `delegate`, `edge_detection`, `callback` | Executes the function handle `delegate` when the key is pressed. If `edge_detection` is true, it only executes on rising edges (key presses). Executes the `callback` whenever the function handle is called. |
48 |
49 | With the `DelegateHandle` keyboard event, essentially any desired action can be implemented; `Delta`, `Button` and `Switch` are just commonly used helpers.
50 |
51 | ### **Available keys**
52 |
53 | The list of keys you can use (e.g. `basic_keyboard['KEY_NAME']`) can be found below. The controller takes the key in the dictionary (e.g. `x`), transforms it to capital letters and prepends `KEY_` to it. So to get an event on `KEY_RIGHT_ALT`, you have to add `basic_keyboard['right_alt'] = [...]`.
54 |
55 |
56 | Click here to see all options 57 | KEY_SPACE, 58 | KEY_APOSTROPHE, 59 | KEY_COMMA, 60 | KEY_MINUS, 61 | KEY_PERIOD, 62 | KEY_SLASH, 63 | KEY_0, 64 | KEY_1, 65 | KEY_2, 66 | KEY_3, 67 | KEY_4, 68 | KEY_5, 69 | KEY_6, 70 | KEY_7, 71 | KEY_8, 72 | KEY_9, 73 | KEY_SEMICOLON, 74 | KEY_EQUAL, 75 | KEY_A, 76 | KEY_B, 77 | KEY_C, 78 | KEY_D, 79 | KEY_E, 80 | KEY_F, 81 | KEY_G, 82 | KEY_H, 83 | KEY_I, 84 | KEY_J, 85 | KEY_K, 86 | KEY_L, 87 | KEY_M, 88 | KEY_N, 89 | KEY_O, 90 | KEY_P, 91 | KEY_Q, 92 | KEY_R, 93 | KEY_S, 94 | KEY_T, 95 | KEY_U, 96 | KEY_V, 97 | KEY_W, 98 | KEY_X, 99 | KEY_Y, 100 | KEY_Z, 101 | KEY_LEFT_BRACKET, 102 | KEY_BACKSLASH, 103 | KEY_RIGHT_BRACKET, 104 | KEY_GRAVE_ACCENT, 105 | KEY_ESCAPE, 106 | KEY_TAB, 107 | KEY_ENTER, 108 | KEY_BACKSPACE, 109 | KEY_INSERT, 110 | KEY_DEL, 111 | KEY_RIGHT, 112 | KEY_LEFT, 113 | KEY_DOWN, 114 | KEY_UP, 115 | KEY_PAGE_UP, 116 | KEY_PAGE_DOWN, 117 | KEY_HOME, 118 | KEY_END, 119 | KEY_CAPS_LOCK, 120 | KEY_SCROLL_LOCK, 121 | KEY_NUM_LOCK, 122 | KEY_PRINT_SCREEN, 123 | KEY_PAUSE, 124 | KEY_F1, 125 | KEY_F2, 126 | KEY_F3, 127 | KEY_F4, 128 | KEY_F5, 129 | KEY_F6, 130 | KEY_F7, 131 | KEY_F8, 132 | KEY_F9, 133 | KEY_F10, 134 | KEY_F11, 135 | KEY_F12, 136 | KEY_NUMPAD_0, 137 | KEY_NUMPAD_1, 138 | KEY_NUMPAD_2, 139 | KEY_NUMPAD_3, 140 | KEY_NUMPAD_4, 141 | KEY_NUMPAD_5, 142 | KEY_NUMPAD_6, 143 | KEY_NUMPAD_7, 144 | KEY_NUMPAD_8, 145 | KEY_NUMPAD_9, 146 | KEY_NUMPAD_DEL, 147 | KEY_NUMPAD_DIVIDE, 148 | KEY_NUMPAD_MULTIPLY, 149 | KEY_NUMPAD_SUBTRACT, 150 | KEY_NUMPAD_ADD, 151 | KEY_NUMPAD_ENTER, 152 | KEY_NUMPAD_EQUAL, 153 | KEY_LEFT_SHIFT, 154 | KEY_LEFT_CONTROL, 155 | KEY_LEFT_ALT, 156 | KEY_LEFT_SUPER, 157 | KEY_RIGHT_SHIFT, 158 | KEY_RIGHT_CONTROL, 159 | KEY_RIGHT_ALT, 160 | KEY_RIGHT_SUPER, 161 | KEY_MENU 162 |
163 |
164 | The exact list depends on your isaacgym version and can be found in the docs folder of your local isaacgym_lib copy: `isaacgym_lib/docs/api/python/enum_py.html#isaacgym.gymapi.KeyboardInput`.
165 |
--------------------------------------------------------------------------------
/humanoid_gym/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .helpers import (
2 |     class_to_dict,
3 |     get_load_path,
4 |     get_args,
5 |     export_policy_as_jit,
6 |     export_policy_as_onnx,
7 |     set_seed,
8 |     update_class_from_dict,
9 | )
10 | from .task_registry import task_registry
11 | from .logger import Logger
12 | from .math import *
13 | from .terrain import Terrain
14 |
--------------------------------------------------------------------------------
/humanoid_gym/utils/base_config.py:
--------------------------------------------------------------------------------
1 | # python
2 | import inspect
3 |
4 |
5 | class BaseConfig:
6 |     """Data structure for handling Python class configurations."""
7 |
8 |     def __init__(self) -> None:
9 |         """Initializes all member classes recursively."""
10 |         self.init_member_classes(self)
11 |
12 |     @staticmethod
13 |     def init_member_classes(obj) -> None:
14 |         """Initializes all member classes recursively.
15 |
16 |         Note:
17 |             Ignores all names starting with "__" (i.e. built-in methods).
18 |         """
19 |         # iterate over all attribute names
20 |         for key in dir(obj):
21 |             # disregard builtin attributes
22 |             # if key.startswith("__"):
23 |             if key == "__class__":
24 |                 continue
25 |             # get the corresponding attribute object
26 |             var = getattr(obj, key)
27 |             # check if the attribute is a class
28 |             if inspect.isclass(var):
29 |                 # instantiate the class
30 |                 i_var = var()
31 |                 # set the attribute to the instance instead of the type
32 |                 setattr(obj, key, i_var)
33 |                 # recursively init members of the attribute
34 |                 BaseConfig.init_member_classes(i_var)
35 |
--------------------------------------------------------------------------------
/humanoid_gym/utils/keyboard_controller.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | # isaacgym
4 | import isaacgym
5 |
6 | # python
7 | from abc import abstractmethod, ABC
8 | from typing import Callable, Any, Dict
9 | import torch
10 |
11 |
12 | # Callback function for when event is triggered.
13 | CallbackFn = Callable[[], None]
14 | # Action function called when event is triggered.
15 | # Takes the environment instance and event value.
16 | ActionFn = Callable[[Any, int], None]
17 |
18 |
19 | class KeyboardAction(ABC):
20 |     """Base class for keyboard event."""
21 |
22 |     def __init__(
23 |         self,
24 |         name: str,
25 |         variable_reference: torch.Tensor = None,
26 |         member_name: str = None,
27 |     ):
28 |         """Initializes the keyboard action event with variable attributes.
29 |
30 |         Note:
31 |             The keyboard action can be applied to a variable (passed via reference) or
32 |             on a member in the environment class instance.
33 |
34 |         Args:
35 |             name (str): Name of the affected value.
36 |             variable_reference (torch.Tensor, optional): Reference variable to alter value. Defaults to None.
37 |             member_name (str, optional): Name of the variable in the environment. Defaults to None.
38 |
39 |         Raises:
40 |             ValueError -- If both reference variable and environment's member name are None or not None.
41 | """ 42 | # check input 43 | if (variable_reference is None and member_name is None) or ( 44 | variable_reference is not None and member_name is not None 45 | ): 46 | msg = "Invalid arguments: Action can only be applied on either reference variable or environment's member variable." 47 | raise ValueError(msg) 48 | # store input arguments 49 | self.name = name 50 | self.variable_reference = variable_reference 51 | self.member_name = member_name 52 | # disambiguate the type of mode 53 | if variable_reference is not None and member_name is None: 54 | self._ref_mode = True 55 | elif variable_reference is None and member_name is not None: 56 | self._ref_mode = False 57 | 58 | def __str__(self) -> str: 59 | """Helper string to explain keyboard action.""" 60 | return f"Keyboard action on {self.name}." 61 | 62 | def get_reference(self, env) -> torch.Tensor: 63 | """Retrieve the variable on which event action is applied. 64 | 65 | Args: 66 | env (BaseTask): The environment/task instance. 67 | 68 | Returns: 69 | torch.Tensor: The passed variable reference or environment instance's member. 70 | """ 71 | if self._ref_mode: 72 | return self.variable_reference 73 | else: 74 | return getattr(env, self.member_name) 75 | 76 | @abstractmethod 77 | def do(self, env, value: int): 78 | """Action applied by the keyboard event. 79 | 80 | Args: 81 | env (BaseTask): The environment/task instance. 82 | value (int): The event triggered when keyboard button pressed. 83 | """ 84 | raise NotImplementedError 85 | 86 | 87 | class DelegateHandle(KeyboardAction): 88 | """Pre-defined delegate that executes an event handler. 89 | 90 | This class exectues the function handle `delegate` when the key is pressed. If `edge_detection` is 91 | true, then the function executes only on rising edges (i.e. release of the key). 92 | 93 | The `callback` function is executed whenever the function handle is called. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | name: str, 99 | delegate: ActionFn, 100 | edge_detection: bool = True, 101 | callback: CallbackFn = None, 102 | variable_reference: torch.Tensor = None, 103 | member_name: str = None, 104 | ): 105 | """Initializes the class. 106 | 107 | Args: 108 | name (str): Name of the affected value. 109 | delegate (ActionFn): The function called when keyboard is pressed/released. 110 | edge_detection (bool, optional): Decides whether to change value on press/release. Defaults to True. 111 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 112 | variable_reference (torch.Tensor, optional): Reference variable to alter value. Defaults to None. 113 | member_name (str, optional): Name of the variable in the environment. Defaults to None. 114 | """ 115 | super().__init__(name, variable_reference, member_name) 116 | # store inputs 117 | self._delegate = delegate 118 | self._edge_detection = edge_detection 119 | self._callback = callback 120 | 121 | def do(self, env, value): 122 | """Action applied by the keyboard event. 123 | 124 | Args: 125 | env (BaseTask): The environment/task instance. 126 | value (int): The event triggered when keyboard button pressed. 127 | """ 128 | # if no event triggered return. 
129 | if self._edge_detection and value == 0: 130 | return 131 | # resolve action based on press/release 132 | self._delegate(env, value) 133 | # trigger callback function 134 | if self._callback is not None: 135 | self._callback() 136 | 137 | 138 | class Delta(DelegateHandle): 139 | """Keyboard action that increments the value of reference variable by scalar amount.""" 140 | 141 | def __init__( 142 | self, 143 | name: str, 144 | amount: float, 145 | variable_reference: torch.Tensor, 146 | callback: CallbackFn = None, 147 | ): 148 | """Initializes the class. 149 | 150 | Args: 151 | name (str): Name of the affected value. 152 | amount (float): The amount by which to increment. 153 | variable_reference (torch.Tensor): Reference variable to alter value. 154 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 155 | """ 156 | self.amount = amount 157 | 158 | # delegate function 159 | def addDelta(env, value): 160 | self.variable_reference += self.amount 161 | 162 | # initialize parent 163 | super().__init__(name, addDelta, True, callback, variable_reference, None) 164 | 165 | def __str__(self) -> str: 166 | if self.amount >= 0: 167 | return f"Increments the variable {self.name} by {self.amount}" 168 | else: 169 | return f"Decrements the variable {self.name} by {-self.amount}" 170 | 171 | 172 | class Switch(DelegateHandle): 173 | """Keyboard action that toggles between values of reference variable.""" 174 | 175 | def __init__( 176 | self, 177 | name: str, 178 | start_state: torch.Tensor, 179 | toggle_state: torch.Tensor, 180 | variable_reference: torch.Tensor, 181 | callback: CallbackFn = None, 182 | ): 183 | """Initializes the class. 184 | 185 | Args: 186 | name (str): Name of the affected value. 187 | start_state (torch.Tensor): Initial value of reference variable. 188 | toggle_state (torch.Tensor): Toggled value of reference variable. 189 | variable_reference (torch.Tensor): Reference variable to alter value. 190 | callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None. 191 | """ 192 | # copy inputs to class 193 | self.start_state = start_state 194 | self.toggle_state = toggle_state 195 | self.variable_reference = variable_reference 196 | # initial state of toggle switch 197 | self.switch_value = True 198 | 199 | # delegate function 200 | def switchState(env, value): 201 | # switch between state depending on switch's value 202 | if self.switch_value: 203 | new_state = self.toggle_state 204 | else: 205 | new_state = self.start_state 206 | # store value into reference variable 207 | self.variable_reference[:] = new_state 208 | # toggle switch to other state 209 | self.switch_value = not self.switch_value 210 | 211 | # initialize parent 212 | super().__init__(name, switchState, True, callback, variable_reference, None) 213 | 214 | def __str__(self) -> str: 215 | return f"Toggles the variable {self.name} between {self.toggle_state} and {self.start_state}." 216 | 217 | 218 | class Button(Switch): 219 | """Sets the variable to value only while keyboard button is pressed.""" 220 | 221 | def __init__( 222 | self, 223 | name: str, 224 | start_state: torch.Tensor, 225 | toggle_state: torch.Tensor, 226 | variable_reference: torch.Tensor, 227 | callback: CallbackFn = None, 228 | ): 229 | """Initializes the class. 230 | 231 | Args: 232 | name (str): Name of the affected value. 233 | start_state (torch.Tensor): Initial value of reference variable. 234 | toggle_state (torch.Tensor): Toggled value of reference variable. 
235 |             variable_reference (torch.Tensor): Reference variable to alter value.
236 |             callback (CallbackFn, optional): Function called whenever key triggered. Defaults to None.
237 |         """
238 |         # initialize toggle switch
239 |         super().__init__(name, start_state, toggle_state, variable_reference, callback)
240 |         # react to both press and release, so the value is set while held and reset on release
241 |         self._edge_detection = False
242 |
243 |     def __str__(self) -> str:
244 |         return f"Sets the variable {self.name} to {self.toggle_state} only while key is pressed."
245 |
246 |
247 | class KeyBoardController:
248 |     """Wrapper around IsaacGym viewer to handle different keyboard actions."""
249 |
250 |     def __init__(self, env, key_actions: Dict[str, KeyboardAction]):
251 |         """Initializes the class.
252 |
253 |         Args:
254 |             env (BaseTask): The environment/task instance.
255 |             key_actions (Dict[str, KeyboardAction]): The pairs of key buttons and their actions.
256 |         """
257 |         # store inputs
258 |         self._env = env
259 |         self._key_actions = key_actions
260 |         # setup the keyboard event subscriber
261 |         for key_name in self._key_actions.keys():
262 |             key_enum = getattr(isaacgym.gymapi.KeyboardInput, f"KEY_{key_name.upper()}")  # e.g. "right_alt" -> KEY_RIGHT_ALT
263 |             env.gym.subscribe_viewer_keyboard_event(env.viewer, key_enum, key_name)
264 |
265 |     def update(self, env):
266 |         """Update the reference variables by querying viewer events."""
267 |         # gather all events on viewer
268 |         events = env.gym.query_viewer_action_events(env.viewer)
269 |         # iterate over events
270 |         for event in events:
271 |             key_pressed = event.action
272 |             if key_pressed in self._key_actions:
273 |                 cfg = self._key_actions[key_pressed]
274 |                 cfg.do(env, event.value)
275 |
276 |     def print_options(self):
277 |         print("[KeyboardController] Key-action pairs:")
278 |         for key_name, action in self._key_actions.items():
279 |             print(f"\t{key_name}: {action}")
280 |
--------------------------------------------------------------------------------
/humanoid_gym/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | # python
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from collections import defaultdict
8 | from multiprocessing import Process
9 |
10 |
11 | class Logger:
12 |     def __init__(self, dt):
13 |         self.state_log = defaultdict(list)
14 |         self.rew_log = defaultdict(list)
15 |         self.dt = dt
16 |         self.num_episodes = 0
17 |         self.plot_process = None
18 |
19 |     def log_state(self, key, value):
20 |         self.state_log[key].append(value)
21 |
22 |     def log_states(self, dict):
23 |         for key, value in dict.items():
24 |             self.log_state(key, value)
25 |
26 |     def log_rewards(self, dict, num_episodes):
27 |         for key, value in dict.items():
28 |             if "rew" in key:
29 |                 self.rew_log[key].append(value.item() * num_episodes)
30 |         self.num_episodes += num_episodes
31 |
32 |     def reset(self):
33 |         self.state_log.clear()
34 |         self.rew_log.clear()
35 |
36 |     def plot_states(self):
37 |         self.plot_process = Process(target=self._plot)
38 |         self.plot_process.start()
39 |
40 |     def _plot(self):
41 |         nb_rows = 3
42 |         nb_cols = 3
43 |         fig, axs = plt.subplots(nb_rows, nb_cols)
44 |         for key, value in self.state_log.items():
45 |             time = np.linspace(0, len(value) * self.dt, len(value))
46 |             break
47 |         log = self.state_log
48 |         # plot joint targets and measured positions
49 |         a = axs[1, 0]
50 |         if log["dof_pos"]:
51 |             a.plot(time, log["dof_pos"], label="measured")
52 |         if log["dof_pos_target"]:
53 |
a.plot(time, log["dof_pos_target"], label="target") 54 | a.set(xlabel="time [s]", ylabel="Position [rad]", title="DOF Position") 55 | a.legend() 56 | # plot joint velocity 57 | a = axs[1, 1] 58 | if log["dof_vel"]: 59 | a.plot(time, log["dof_vel"], label="measured") 60 | if log["dof_vel_target"]: 61 | a.plot(time, log["dof_vel_target"], label="target") 62 | a.set(xlabel="time [s]", ylabel="Velocity [rad/s]", title="Joint Velocity") 63 | a.legend() 64 | # plot base vel x 65 | a = axs[0, 0] 66 | if log["base_vel_x"]: 67 | a.plot(time, log["base_vel_x"], label="measured") 68 | if log["command_x"]: 69 | a.plot(time, log["command_x"], label="commanded") 70 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity x") 71 | a.legend() 72 | # plot base vel y 73 | a = axs[0, 1] 74 | if log["base_vel_y"]: 75 | a.plot(time, log["base_vel_y"], label="measured") 76 | if log["command_y"]: 77 | a.plot(time, log["command_y"], label="commanded") 78 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity y") 79 | a.legend() 80 | # plot base vel yaw 81 | a = axs[0, 2] 82 | if log["base_vel_yaw"]: 83 | a.plot(time, log["base_vel_yaw"], label="measured") 84 | if log["command_yaw"]: 85 | a.plot(time, log["command_yaw"], label="commanded") 86 | a.set(xlabel="time [s]", ylabel="base ang vel [rad/s]", title="Base velocity yaw") 87 | a.legend() 88 | # plot base vel z 89 | a = axs[1, 2] 90 | if log["base_vel_z"]: 91 | a.plot(time, log["base_vel_z"], label="measured") 92 | a.set(xlabel="time [s]", ylabel="base lin vel [m/s]", title="Base velocity z") 93 | a.legend() 94 | # plot contact forces 95 | a = axs[2, 0] 96 | if log["contact_forces_z"]: 97 | forces = np.array(log["contact_forces_z"]) 98 | for i in range(forces.shape[1]): 99 | a.plot(time, forces[:, i], label=f"force {i}") 100 | a.set(xlabel="time [s]", ylabel="Forces z [N]", title="Vertical Contact forces") 101 | a.legend() 102 | # plot torque/vel curves 103 | a = axs[2, 1] 104 | if log["dof_vel"] != [] and log["dof_torque"] != []: 105 | a.plot(log["dof_vel"], log["dof_torque"], "x", label="measured") 106 | a.set( 107 | xlabel="Joint vel [rad/s]", 108 | ylabel="Joint Torque [Nm]", 109 | title="Torque/velocity curves", 110 | ) 111 | a.legend() 112 | # plot torques 113 | a = axs[2, 2] 114 | if log["dof_torque"] != []: 115 | a.plot(time, log["dof_torque"], label="measured") 116 | a.set(xlabel="time [s]", ylabel="Joint Torque [Nm]", title="Torque") 117 | a.legend() 118 | plt.show() 119 | 120 | def print_rewards(self): 121 | print("Average rewards per second:") 122 | for key, values in self.rew_log.items(): 123 | mean = np.sum(np.array(values)) / self.num_episodes 124 | print(f" - {key}: {mean}") 125 | print(f"Total number of episodes: {self.num_episodes}") 126 | 127 | def __del__(self): 128 | if self.plot_process is not None: 129 | self.plot_process.kill() 130 | -------------------------------------------------------------------------------- /humanoid_gym/utils/math.py: -------------------------------------------------------------------------------- 1 | # isaac-gym 2 | from isaacgym.torch_utils import quat_apply, normalize, quat_mul, quat_conjugate 3 | 4 | # python 5 | import torch 6 | import numpy as np 7 | from typing import Tuple 8 | 9 | 10 | # @ torch.jit.script 11 | def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: 12 | """Rotate a vector only around the yaw-direction. 13 | 14 | Args: 15 | quat (torch.Tensor): Input orientation to extract yaw from. 16 | vec (torch.Tensor): Input vector. 
17 | 
18 | Returns:
19 | torch.Tensor: Rotated vector.
20 | """
21 | quat_yaw = quat.clone().view(-1, 4)
22 | quat_yaw[:, :2] = 0.0
23 | quat_yaw = normalize(quat_yaw)
24 | return quat_apply(quat_yaw, vec)
25 | 
26 | 
27 | # @ torch.jit.script
28 | def box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
29 | """Implements the box-minus operator (quaternion difference).
30 | https://docs.leggedrobotics.com/kindr/cheatsheet_latest.pdf
31 | 
32 | Args:
33 | q1 (torch.Tensor): quaternion
34 | q2 (torch.Tensor): quaternion
35 | 
36 | Returns:
37 | torch.Tensor: q1 box-minus q2
38 | """
39 | quat_diff = quat_mul(q1, quat_conjugate(q2)) # q1 * q2^-1
40 | re = quat_diff[:, -1] # real part, q = [x, y, z, w] = [re, im]
41 | im = quat_diff[:, 0:3] # imaginary part
42 | norm_im = torch.norm(im, dim=1)
43 | scale = 2.0 * torch.where(norm_im > 1.0e-7, torch.atan(norm_im / re) / norm_im, torch.sign(re))
44 | return scale.unsqueeze(-1) * im
45 | 
46 | 
47 | # @ torch.jit.script
48 | def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor:
49 | """Wraps input angles (in radians) to the range [-pi, pi].
50 | 
51 | Args:
52 | angles (torch.Tensor): Input angles.
53 | 
54 | Returns:
55 | torch.Tensor: Angles in the range [-pi, pi].
56 | """
57 | angles %= 2 * np.pi
58 | angles -= 2 * np.pi * (angles > np.pi)
59 | return angles
60 | 
61 | 
62 | # @ torch.jit.script
63 | def torch_rand_sqrt_float(lower: float, upper: float, size: Tuple[int, int], device: str) -> torch.Tensor:
64 | """Randomly samples tensor from a triangular distribution.
65 | 
66 | Args:
67 | lower (float): The lower range of the sampled tensor.
68 | upper (float): The upper range of the sampled tensor.
69 | size (Tuple[int, int]): The shape of the tensor.
70 | device (str): Device to create tensor on.
71 | 
72 | Returns:
73 | torch.Tensor: Sampled tensor of shape :obj:`size`.
74 | """
75 | # create random tensor in the range [-1, 1]
76 | r = 2 * torch.rand(*size, device=device) - 1
77 | # convert to triangular distribution
78 | r = torch.where(r < 0.0, -torch.sqrt(-r), torch.sqrt(r))
79 | # rescale back to [0, 1]
80 | r = (r + 1.0) / 2.0
81 | # rescale to range [lower, upper]
82 | return (upper - lower) * r + lower
83 | 
--------------------------------------------------------------------------------
/humanoid_gym/utils/task_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | 
4 | # python
5 | import argparse
6 | import os
7 | from datetime import datetime
8 | from typing import Tuple, Type, List
9 | 
10 | # humanoid-gym
11 | from humanoid_gym import LEGGED_GYM_ROOT_DIR
12 | from humanoid_gym.utils.base_config import BaseConfig
13 | from humanoid_gym.utils.helpers import (
14 | get_args,
15 | update_cfg_from_args,
16 | class_to_dict,
17 | get_load_path,
18 | set_seed,
19 | parse_sim_params,
20 | )
21 | 
22 | # learning
23 | from learning.env import VecEnv
24 | from learning.runners import FLDOnPolicyRunner
25 | 
26 | 
27 | class TaskRegistry:
28 | """This class simplifies creation of environments and agents."""
29 | 
30 | def __init__(self):
31 | self.task_classes = dict()
32 | self.env_cfgs = dict()
33 | self.train_cfgs = dict()
34 | 
35 | """
36 | Properties
37 | """
38 | 
39 | def get_task_names(self) -> List[str]:
40 | """Returns a list of registered task names.
41 | 
42 | Returns:
43 | List[str]: List of registered task names.
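
Example (sketch; the task name and classes are hypothetical):
    task_registry.register("mit_humanoid", MITHumanoid, MITHumanoidCfg, MITHumanoidCfgPPO)
    task_registry.get_task_names()  # -> ["mit_humanoid"]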
44 | """ 45 | return list(self.task_classes.keys()) 46 | 47 | def get_task_class(self, name: str) -> Type[VecEnv]: 48 | """Retrieve the class object corresponding to input name. 49 | 50 | Args: 51 | name (str): name of the registered environment. 52 | 53 | Raises: 54 | ValueError: When there is no registered environment with input `name`. 55 | 56 | Returns: 57 | Type[VecEnv]: The environment class object. 58 | """ 59 | # check if there is a registered env with that name 60 | if self._check_valid_task(name): 61 | return self.task_classes[name] 62 | 63 | def get_cfgs(self, name: str) -> Tuple[BaseConfig, BaseConfig]: 64 | """Retrieve the default environment and training configurations. 65 | 66 | Args: 67 | name (str): Name of the environment. 68 | 69 | Raises: 70 | ValueError: When there is no registered environment with input `name`. 71 | 72 | Returns: 73 | Tuple[BaseConfig, BaseConfig]: The default environment and training configurations. 74 | """ 75 | # check if there is a registered env with that name 76 | if self._check_valid_task(name): 77 | # retrieve configurations 78 | train_cfg = self.train_cfgs[name] 79 | env_cfg = self.env_cfgs[name] 80 | # copy seed between environment and agent 81 | env_cfg.seed = train_cfg.seed 82 | return env_cfg, train_cfg 83 | 84 | """ 85 | Operations 86 | """ 87 | 88 | def register( 89 | self, 90 | name: str, 91 | task_class: Type[VecEnv], 92 | env_cfg: Type[BaseConfig], 93 | train_cfg: Type[BaseConfig], 94 | ): 95 | """Add a particular environment to the task registry. 96 | 97 | Args: 98 | name (str): Name of the environment. 99 | task_class (Type[VecEnv]): The corresponding task class. 100 | env_cfg (Type[BaseConfig]): The corresponding environment configuration file. 101 | train_cfg (Type[BaseConfig]): The corresponding agent configuration file. 102 | """ 103 | self.task_classes[name] = task_class 104 | self.env_cfgs[name] = env_cfg() 105 | self.train_cfgs[name] = train_cfg() 106 | 107 | def make_env( 108 | self, name: str, args: argparse.Namespace = None, env_cfg: BaseConfig = None 109 | ) -> Tuple[VecEnv, BaseConfig]: 110 | """Creates an environment from the registry. 111 | 112 | Args: 113 | name (str): Name of a registered env. 114 | args (argparse.Namespace, optional): Parsed CLI arguments. If :obj:`None`, then 115 | `get_args()` is called to obtain arguments. Defaults to None. 116 | env_cfg (BaseConfig, optional): Environment configuration class instance used to 117 | overwrite the default registered configuration. Defaults to None. 118 | 119 | Raises: 120 | ValueError: When there is no registered environment with input `name`. 121 | 122 | Returns: 123 | Tuple[VecEnv, BaseConfig]: Tuple containing the created class instance and corresponding 124 | configuration class instance. 
125 | """ 126 | # check if there is a registered env with that name 127 | task_class = self.get_task_class(name) 128 | # if no args passed, get command line arguments 129 | if args is None: 130 | args = get_args() 131 | # if no config passed, use default env config 132 | if env_cfg is None: 133 | # load config files 134 | env_cfg, _ = self.get_cfgs(name) 135 | # override cfg from args (if specified) 136 | env_cfg, _ = update_cfg_from_args(env_cfg, None, args) 137 | # set seed 138 | set_seed(env_cfg.seed) 139 | # parse sim params (convert to dict first) 140 | sim_params = {"sim": class_to_dict(env_cfg.sim)} 141 | sim_params = parse_sim_params(args, sim_params) 142 | # create environment instance 143 | env = task_class( 144 | cfg=env_cfg, 145 | sim_params=sim_params, 146 | physics_engine=args.physics_engine, 147 | sim_device=args.sim_device, 148 | headless=args.headless, 149 | ) 150 | return env, env_cfg 151 | 152 | def make_alg_runner( 153 | self, 154 | env: VecEnv, 155 | name: str = None, 156 | args: argparse.Namespace = None, 157 | train_cfg: BaseConfig = None, 158 | log_root: str = "default", 159 | ) -> Tuple[FLDOnPolicyRunner, BaseConfig]: 160 | """Creates the training algorithm either from a registered name or from the provided 161 | config file. 162 | 163 | TODO (@nrudin): Remove environment from within the algorithm. 164 | 165 | Note: 166 | The training/agent configuration is loaded from either "name" or "train_cfg". If both are 167 | passed then the default configuration via "name" is ignored. 168 | 169 | Args: 170 | env (VecEnv): The environment to train on. 171 | name (str, optional): The environment name to retrieve corresponding training configuration. 172 | Defaults to None. 173 | args (argparse.Namespace, optional): Parsed CLI arguments. If :obj:`None`, then 174 | `get_args()` is called to obtain arguments. Defaults to None. 175 | train_cfg (BaseConfig, optional): Instance of training configuration class. If 176 | :obj:`None`, then `name` is used to retrieve default training configuration. 177 | Defaults to None. 178 | log_root (str, optional): Logging directory for TensorBoard. Set to obj:`None` to avoid 179 | logging (such as during test-time). Logs are saved in the `/_` 180 | directory. If "default", then `log_root` is set to 181 | "{LEGGED_GYM_ROOT_DIR}/logs/{train_cfg.runner.experiment_name}". Defaults to "default". 182 | 183 | Raises: 184 | ValueError: If neither "name" or "train_cfg" are provided for loading training configuration. 185 | 186 | Returns: 187 | Tuple[WASABIOnPolicyRunner, BaseConfig]: Tuple containing the training runner and configuration instances. 188 | """ 189 | # if config files are passed use them, otherwise load default from the name 190 | if train_cfg is None: 191 | if name is None: 192 | raise ValueError("No training configuration provided. Either 'name' or 'train_cfg' must not be None.") 193 | else: 194 | # load config files 195 | _, train_cfg = self.get_cfgs(name) 196 | else: 197 | if name is not None: 198 | print(f"Training configuration instance provided. 
198 | print(f"Training configuration instance provided. Ignoring default configuration for 'name={name}'.")
199 | # if no args passed, get command line arguments
200 | if args is None:
201 | args = get_args()
202 | # override cfg from args (if specified)
203 | _, train_cfg = update_cfg_from_args(None, train_cfg, args)
204 | # resolve logging
205 | if log_root is None:
206 | log_dir_path = None
207 | else:
208 | # default location for logs
209 | if log_root == "default":
210 | log_root = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", train_cfg.runner.experiment_name)
211 | # log directory
212 | log_dir_path = os.path.join(
213 | log_root,
214 | datetime.now().strftime("%b%d_%H-%M-%S") + "_" + train_cfg.runner.run_name,
215 | )
216 | # create training runner
217 | runner_class = eval(train_cfg.runner_class_name)
218 | train_cfg_dict = class_to_dict(train_cfg)
219 | runner = runner_class(env, train_cfg_dict, log_dir_path, device=args.rl_device)
220 | # store the git repository state in the log directory
221 | runner.add_git_repo_to_log(__file__)
222 | # save resume path before creating a new log_dir
223 | resume = train_cfg.runner.resume
224 | if resume:
225 | # load previously trained model
226 | resume_path = get_load_path(
227 | log_root,
228 | load_run=train_cfg.runner.load_run,
229 | checkpoint=train_cfg.runner.checkpoint,
230 | )
231 | print(f"Loading model from: {resume_path}")
232 | runner.load(resume_path)
233 | 
234 | return runner, train_cfg
235 | 
236 | """
237 | Private helpers.
238 | """
239 | 
240 | def _check_valid_task(self, name: str) -> bool:
241 | """Checks if input task name is valid.
242 | 
243 | Args:
244 | name (str): Name of the registered task.
245 | 
246 | Raises:
247 | ValueError: When there is no registered environment with input `name`.
248 | 
249 | Returns:
250 | bool: True if the task exists.
251 | """
252 | registered_tasks = self.get_task_names()
253 | if name not in registered_tasks:
254 | print(f"The task '{name}' is not registered. Please use one of the following: ")
255 | for name in registered_tasks:
256 | print(f"\t - {name}")
257 | raise ValueError(f"[TaskRegistry]: Task with name: {name} is not registered.")
258 | else:
259 | return True
260 | 
261 | 
262 | # make global task registry
263 | task_registry = TaskRegistry()
264 | 
--------------------------------------------------------------------------------
/humanoid_gym/utils/terrain.py:
--------------------------------------------------------------------------------
1 | # isaacgym
2 | from isaacgym import terrain_utils
3 | 
4 | # python
5 | import numpy as np
6 | 
7 | # humanoid-gym
8 | from humanoid_gym.utils.base_config import BaseConfig
9 | 
10 | 
11 | class Terrain:
12 | """Wrapper around terrain-utils to generate terrains."""
13 | 
14 | def __init__(self, cfg: BaseConfig, num_robots: int) -> None:
15 | 
16 | self.cfg = cfg
17 | self.num_robots = num_robots
18 | self.type = cfg.mesh_type
19 | if self.type in ["none", "plane"]:
20 | return
21 | self.env_length = cfg.terrain_length
22 | self.env_width = cfg.terrain_width
23 | self.proportions = [np.sum(cfg.terrain_proportions[: i + 1]) for i in range(len(cfg.terrain_proportions))]
24 | 
25 | self.cfg.num_sub_terrains = cfg.num_rows * cfg.num_cols
26 | self.env_origins = np.zeros((cfg.num_rows, cfg.num_cols, 3))
27 | 
28 | self.width_per_env_pixels = int(self.env_width / cfg.horizontal_scale)
29 | self.length_per_env_pixels = int(self.env_length / cfg.horizontal_scale)
30 | 
31 | self.border = int(cfg.border_size / self.cfg.horizontal_scale)
32 | self.tot_cols = int(cfg.num_cols * self.width_per_env_pixels) + 2 * self.border
33 | self.tot_rows = int(cfg.num_rows * self.length_per_env_pixels) + 2 * self.border
34 | 
35 | self.height_field_raw = np.zeros((self.tot_rows, self.tot_cols), dtype=np.int16)
36 | if cfg.curriculum:
37 | self.curriculum()
38 | elif cfg.selected:
39 | self.selected_terrain()
40 | else:
41 | self.randomized_terrain()
42 | 
43 | self.heightsamples = self.height_field_raw
44 | if self.type == "trimesh":
45 | (self.vertices, self.triangles,) = terrain_utils.convert_heightfield_to_trimesh(
46 | self.height_field_raw,
47 | self.cfg.horizontal_scale,
48 | self.cfg.vertical_scale,
49 | self.cfg.slope_treshold,
50 | )
51 | 
52 | def randomized_terrain(self):
53 | for k in range(self.cfg.num_sub_terrains):
54 | # Env coordinates in the world
55 | (i, j) = np.unravel_index(k, (self.cfg.num_rows, self.cfg.num_cols))
56 | 
57 | choice = np.random.uniform(0, 1)
58 | difficulty = np.random.choice([0.5, 0.75, 0.9])
59 | terrain = self.make_terrain(choice, difficulty)
60 | self.add_terrain_to_map(terrain, i, j)
61 | 
62 | def curriculum(self):
63 | for j in range(self.cfg.num_cols):
64 | for i in range(self.cfg.num_rows):
65 | difficulty = i / self.cfg.num_rows
66 | choice = j / self.cfg.num_cols + 0.001
67 | 
68 | terrain = self.make_terrain(choice, difficulty)
69 | self.add_terrain_to_map(terrain, i, j)
70 | 
71 | def selected_terrain(self):
72 | terrain_type = self.cfg.terrain_kwargs.pop("type")
73 | for k in range(self.cfg.num_sub_terrains):
74 | # Env coordinates in the world
75 | (i, j) = np.unravel_index(k, (self.cfg.num_rows, self.cfg.num_cols))
76 | 
77 | terrain = terrain_utils.SubTerrain(
78 | "terrain",
79 | width=self.width_per_env_pixels,
80 | length=self.width_per_env_pixels,
81 | vertical_scale=self.cfg.vertical_scale,
82 | horizontal_scale=self.cfg.horizontal_scale,
83 | )
84 | 
85 | eval(terrain_type)(terrain, **self.cfg.terrain_kwargs)
86 | self.add_terrain_to_map(terrain, i, j)
87 | 
88 | def 
make_terrain(self, choice, difficulty): 89 | terrain = terrain_utils.SubTerrain( 90 | "terrain", 91 | width=self.width_per_env_pixels, 92 | length=self.width_per_env_pixels, 93 | vertical_scale=self.cfg.vertical_scale, 94 | horizontal_scale=self.cfg.horizontal_scale, 95 | ) 96 | slope = difficulty * 0.4 97 | step_height = 0.05 + 0.18 * difficulty 98 | discrete_obstacles_height = 0.05 + difficulty * 0.2 99 | stepping_stones_size = 1.5 * (1.05 - difficulty) 100 | stone_distance = 0.05 if difficulty == 0 else 0.1 101 | gap_size = 1.0 * difficulty 102 | pit_depth = 1.0 * difficulty 103 | if choice < self.proportions[0]: 104 | if choice < self.proportions[0] / 2: 105 | slope *= -1 106 | terrain_utils.pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.0) 107 | elif choice < self.proportions[1]: 108 | terrain_utils.pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.0) 109 | terrain_utils.random_uniform_terrain( 110 | terrain, 111 | min_height=-0.05, 112 | max_height=0.05, 113 | step=0.005, 114 | downsampled_scale=0.2, 115 | ) 116 | elif choice < self.proportions[3]: 117 | if choice < self.proportions[2]: 118 | step_height *= -1 119 | terrain_utils.pyramid_stairs_terrain(terrain, step_width=0.31, step_height=step_height, platform_size=3.0) 120 | elif choice < self.proportions[4]: 121 | num_rectangles = 20 122 | rectangle_min_size = 1.0 123 | rectangle_max_size = 2.0 124 | terrain_utils.discrete_obstacles_terrain( 125 | terrain, 126 | discrete_obstacles_height, 127 | rectangle_min_size, 128 | rectangle_max_size, 129 | num_rectangles, 130 | platform_size=3.0, 131 | ) 132 | elif choice < self.proportions[5]: 133 | terrain_utils.stepping_stones_terrain( 134 | terrain, 135 | stone_size=stepping_stones_size, 136 | stone_distance=stone_distance, 137 | max_height=0.0, 138 | platform_size=4.0, 139 | ) 140 | elif choice < self.proportions[6]: 141 | gap_terrain(terrain, gap_size=gap_size, platform_size=3.0) 142 | else: 143 | pit_terrain(terrain, depth=pit_depth, platform_size=4.0) 144 | 145 | return terrain 146 | 147 | def add_terrain_to_map(self, terrain, row, col): 148 | i = row 149 | j = col 150 | # map coordinate system 151 | start_x = self.border + i * self.length_per_env_pixels 152 | end_x = self.border + (i + 1) * self.length_per_env_pixels 153 | start_y = self.border + j * self.width_per_env_pixels 154 | end_y = self.border + (j + 1) * self.width_per_env_pixels 155 | self.height_field_raw[start_x:end_x, start_y:end_y] = terrain.height_field_raw 156 | 157 | env_origin_x = (i + 0.5) * self.env_length 158 | env_origin_y = (j + 0.5) * self.env_width 159 | x1 = int((self.env_length / 2.0 - 1) / terrain.horizontal_scale) 160 | x2 = int((self.env_length / 2.0 + 1) / terrain.horizontal_scale) 161 | y1 = int((self.env_width / 2.0 - 1) / terrain.horizontal_scale) 162 | y2 = int((self.env_width / 2.0 + 1) / terrain.horizontal_scale) 163 | env_origin_z = np.max(terrain.height_field_raw[x1:x2, y1:y2]) * terrain.vertical_scale 164 | self.env_origins[i, j] = [env_origin_x, env_origin_y, env_origin_z] 165 | 166 | 167 | def gap_terrain(terrain: BaseConfig, gap_size: float, platform_size: float = 1.0): 168 | gap_size = int(gap_size / terrain.horizontal_scale) 169 | platform_size = int(platform_size / terrain.horizontal_scale) 170 | 171 | center_x = terrain.length // 2 172 | center_y = terrain.width // 2 173 | x1 = (terrain.length - platform_size) // 2 174 | x2 = x1 + gap_size 175 | y1 = (terrain.width - platform_size) // 2 176 | y2 = y1 + gap_size 177 | 178 | terrain.height_field_raw[center_x - x2 
: center_x + x2, center_y - y2 : center_y + y2] = -1000 179 | terrain.height_field_raw[center_x - x1 : center_x + x1, center_y - y1 : center_y + y1] = 0 180 | 181 | 182 | def pit_terrain(terrain: BaseConfig, depth, platform_size: float = 1.0): 183 | depth = int(depth / terrain.vertical_scale) 184 | platform_size = int(platform_size / terrain.horizontal_scale / 2) 185 | x1 = terrain.length // 2 - platform_size 186 | x2 = terrain.length // 2 + platform_size 187 | y1 = terrain.width // 2 - platform_size 188 | y2 = terrain.width // 2 + platform_size 189 | terrain.height_field_raw[x1:x2, y1:y2] = -depth 190 | -------------------------------------------------------------------------------- /learning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | -------------------------------------------------------------------------------- /learning/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of different RL agents.""" 5 | 6 | from .ppo import PPO 7 | 8 | __all__ = ["PPO"] 9 | -------------------------------------------------------------------------------- /learning/algorithms/ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | # learning 10 | from learning.modules import ActorCritic 11 | from learning.storage import RolloutStorage 12 | 13 | 14 | class PPO: 15 | actor_critic: ActorCritic 16 | 17 | def __init__( 18 | self, 19 | actor_critic, 20 | num_learning_epochs=1, 21 | num_mini_batches=1, 22 | clip_param=0.2, 23 | gamma=0.998, 24 | lam=0.95, 25 | value_loss_coef=1.0, 26 | entropy_coef=0.0, 27 | learning_rate=1e-3, 28 | max_grad_norm=1.0, 29 | use_clipped_value_loss=True, 30 | schedule="fixed", 31 | desired_kl=0.01, 32 | device="cpu", 33 | ): 34 | 35 | self.device = device 36 | 37 | self.desired_kl = desired_kl 38 | self.schedule = schedule 39 | self.learning_rate = learning_rate 40 | 41 | # PPO components 42 | self.actor_critic = actor_critic 43 | self.actor_critic.to(self.device) 44 | self.storage = None # initialized later 45 | self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate) 46 | self.transition = RolloutStorage.Transition() 47 | 48 | # PPO parameters 49 | self.clip_param = clip_param 50 | self.num_learning_epochs = num_learning_epochs 51 | self.num_mini_batches = num_mini_batches 52 | self.value_loss_coef = value_loss_coef 53 | self.entropy_coef = entropy_coef 54 | self.gamma = gamma 55 | self.lam = lam 56 | self.max_grad_norm = max_grad_norm 57 | self.use_clipped_value_loss = use_clipped_value_loss 58 | 59 | def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape): 60 | self.storage = RolloutStorage( 61 | num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device 62 | ) 63 | 64 | def test_mode(self): 65 | self.actor_critic.test() 66 | 67 | def train_mode(self): 68 | self.actor_critic.train() 69 | 70 | def act(self, obs, critic_obs): 71 | if self.actor_critic.is_recurrent: 72 | self.transition.hidden_states = 
self.actor_critic.get_hidden_states() 73 | # Compute the actions and values 74 | self.transition.actions = self.actor_critic.act(obs).detach() 75 | self.transition.values = self.actor_critic.evaluate(critic_obs).detach() 76 | self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach() 77 | self.transition.action_mean = self.actor_critic.action_mean.detach() 78 | self.transition.action_sigma = self.actor_critic.action_std.detach() 79 | # need to record obs and critic_obs before env.step() 80 | self.transition.observations = obs 81 | self.transition.critic_observations = critic_obs 82 | return self.transition.actions 83 | 84 | def process_env_step(self, rewards, dones, infos): 85 | self.transition.rewards = rewards.clone() 86 | self.transition.dones = dones 87 | # Bootstrapping on time outs 88 | if "time_outs" in infos: 89 | self.transition.rewards += self.gamma * torch.squeeze( 90 | self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1 91 | ) 92 | 93 | # Record the transition 94 | self.storage.add_transitions(self.transition) 95 | self.transition.clear() 96 | self.actor_critic.reset(dones) 97 | 98 | def compute_returns(self, last_critic_obs): 99 | last_values = self.actor_critic.evaluate(last_critic_obs).detach() 100 | self.storage.compute_returns(last_values, self.gamma, self.lam) 101 | 102 | def update(self): 103 | mean_value_loss = 0 104 | mean_surrogate_loss = 0 105 | if self.actor_critic.is_recurrent: 106 | generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) 107 | else: 108 | generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) 109 | for ( 110 | obs_batch, 111 | critic_obs_batch, 112 | actions_batch, 113 | target_values_batch, 114 | advantages_batch, 115 | returns_batch, 116 | old_actions_log_prob_batch, 117 | old_mu_batch, 118 | old_sigma_batch, 119 | hid_states_batch, 120 | masks_batch, 121 | ) in generator: 122 | 123 | self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) 124 | actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch) 125 | value_batch = self.actor_critic.evaluate( 126 | critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1] 127 | ) 128 | mu_batch = self.actor_critic.action_mean 129 | sigma_batch = self.actor_critic.action_std 130 | entropy_batch = self.actor_critic.entropy 131 | 132 | # KL 133 | if self.desired_kl is not None and self.schedule == "adaptive": 134 | with torch.inference_mode(): 135 | kl = torch.sum( 136 | torch.log(sigma_batch / old_sigma_batch + 1.0e-5) 137 | + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) 138 | / (2.0 * torch.square(sigma_batch)) 139 | - 0.5, 140 | axis=-1, 141 | ) 142 | kl_mean = torch.mean(kl) 143 | 144 | if kl_mean > self.desired_kl * 2.0: 145 | self.learning_rate = max(1e-5, self.learning_rate / 1.5) 146 | elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: 147 | self.learning_rate = min(1e-2, self.learning_rate * 1.5) 148 | 149 | for param_group in self.optimizer.param_groups: 150 | param_group["lr"] = self.learning_rate 151 | 152 | # Surrogate loss 153 | ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) 154 | surrogate = -torch.squeeze(advantages_batch) * ratio 155 | surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( 156 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param 157 | ) 158 | surrogate_loss = 
torch.max(surrogate, surrogate_clipped).mean() 159 | 160 | # Value function loss 161 | if self.use_clipped_value_loss: 162 | value_clipped = target_values_batch + (value_batch - target_values_batch).clamp( 163 | -self.clip_param, self.clip_param 164 | ) 165 | value_losses = (value_batch - returns_batch).pow(2) 166 | value_losses_clipped = (value_clipped - returns_batch).pow(2) 167 | value_loss = torch.max(value_losses, value_losses_clipped).mean() 168 | else: 169 | value_loss = (returns_batch - value_batch).pow(2).mean() 170 | 171 | loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() 172 | 173 | # Gradient step 174 | self.optimizer.zero_grad() 175 | loss.backward() 176 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) 177 | self.optimizer.step() 178 | 179 | mean_value_loss += value_loss.item() 180 | mean_surrogate_loss += surrogate_loss.item() 181 | 182 | num_updates = self.num_learning_epochs * self.num_mini_batches 183 | mean_value_loss /= num_updates 184 | mean_surrogate_loss /= num_updates 185 | self.storage.clear() 186 | 187 | return mean_value_loss, mean_surrogate_loss 188 | -------------------------------------------------------------------------------- /learning/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/learning/datasets/__init__.py -------------------------------------------------------------------------------- /learning/datasets/motion_loader.py: -------------------------------------------------------------------------------- 1 | from humanoid_gym import LEGGED_GYM_ROOT_DIR 2 | from isaacgym.torch_utils import ( 3 | quat_mul, 4 | quat_conjugate, 5 | normalize, 6 | quat_from_angle_axis, 7 | ) 8 | import os 9 | import json 10 | import torch 11 | 12 | class MotionLoader: 13 | 14 | def __init__(self, device, motion_file=None, corruption_level=0.0, reference_history_horizon=2, test_mode=False, test_observation_dim=None): 15 | self.device = device 16 | self.reference_history_horizon = reference_history_horizon 17 | if motion_file is None: 18 | motion_file = LEGGED_GYM_ROOT_DIR + "/resources/robots/anymal_c/datasets/motion_data.pt" 19 | self.reference_state_idx_dict_file = os.path.join(os.path.dirname(motion_file), "reference_state_idx_dict.json") 20 | with open(self.reference_state_idx_dict_file, 'r') as f: 21 | self.state_idx_dict = json.load(f) 22 | self.observation_dim = sum([ids[1] - ids[0] for state, ids in self.state_idx_dict.items() if ((state != "base_pos") and (state != "base_quat"))]) 23 | self.observation_start_dim = self.state_idx_dict["base_lin_vel"][0] 24 | loaded_data = torch.load(motion_file, map_location=self.device) 25 | 26 | # Normalize and standardize quaternions 27 | base_quat = normalize(loaded_data[:, :, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]]) 28 | base_quat[base_quat[:, :, -1] < 0] = -base_quat[base_quat[:, :, -1] < 0] 29 | loaded_data[:, :, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = base_quat 30 | 31 | # Load data for DTW 32 | motion_file_dtw = os.path.join(os.path.dirname(motion_file), "motion_data_original.pt") 33 | try: 34 | self.dtw_reference = torch.load(motion_file_dtw, map_location=self.device)[:, :, self.observation_start_dim:] 35 | print(f"[MotionLoader] Loaded DTW reference motion clips.") 36 | except: 37 | self.dtw_reference = None 38 | 
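# For reference, the PPO class above is driven by an on-policy runner in roughly
# this order (sketch; `env`, `obs` and the dimension variables are assumptions,
# not code from this repository):
#
#     ppo.init_storage(num_envs, num_steps, [num_obs], [num_obs], [num_actions])
#     for it in range(num_iterations):
#         for _ in range(num_steps):
#             actions = ppo.act(obs, obs)   # actor and critic share observations here
#             obs, _, rewards, dones, infos = env.step(actions)
#             ppo.process_env_step(rewards, dones, infos)
#         ppo.compute_returns(obs)
#         mean_value_loss, mean_surrogate_loss = ppo.update()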
print(f"[MotionLoader] No DTW reference motion clips provided.") 39 | 40 | self.data = self._data_corruption(loaded_data, level=corruption_level) 41 | self.num_motion_clips, self.num_steps, self.reference_full_dim = self.data.size() 42 | print(f"[MotionLoader] Loaded {self.num_motion_clips} motion clips from {motion_file}. Each records {self.num_steps} steps and {self.reference_full_dim} states.") 43 | 44 | # Preload transitions 45 | self.num_preload_transitions = 500000 46 | motion_clip_sample_ids = torch.randint(0, self.num_motion_clips, (self.num_preload_transitions,), device=self.device) 47 | step_sample = torch.rand(self.num_preload_transitions, device=self.device) * (self.num_steps - self.reference_history_horizon) 48 | self.preloaded_states = torch.zeros( 49 | self.num_preload_transitions, 50 | self.reference_history_horizon, 51 | self.reference_full_dim, 52 | dtype=torch.float, 53 | device=self.device, 54 | requires_grad=False 55 | ) 56 | for i in range(self.reference_history_horizon): 57 | self.preloaded_states[:, i] = self._get_frame_at_step(motion_clip_sample_ids, step_sample + i) 58 | 59 | if test_mode: 60 | self.observation_dim = test_observation_dim 61 | 62 | def _data_corruption(self, loaded_data, level=0): 63 | if level == 0: 64 | print(f"[MotionLoader] Proceeded without processing the loaded data.") 65 | else: 66 | loaded_data = self._rand_dropout(loaded_data, level) 67 | loaded_data = self._rand_noise(loaded_data, level) 68 | loaded_data = self._rand_interpolation(loaded_data, level) 69 | loaded_data = self._rand_duplication(loaded_data, level) 70 | return loaded_data 71 | 72 | def _rand_dropout(self, data, level=0): 73 | num_motion_clips, num_steps, reference_full_dim = data.size() 74 | num_dropouts = round(num_steps * level) 75 | if num_dropouts == 0: 76 | return data 77 | dropped_data = torch.zeros(num_motion_clips, num_steps - num_dropouts, reference_full_dim, dtype=torch.float, device=self.device, requires_grad=False) 78 | for i in range(num_motion_clips): 79 | step_ids = torch.randperm(num_steps)[:-num_dropouts].sort()[0] 80 | dropped_data[i] = data[i, step_ids] 81 | return dropped_data 82 | 83 | def _rand_interpolation(self, data, level=0): 84 | num_motion_clips, num_steps, reference_full_dim = data.size() 85 | num_interpolations = round((num_steps - 2) * level) 86 | if num_interpolations == 0: 87 | return data 88 | interpolated_data = data 89 | for i in range(num_motion_clips): 90 | step_ids = torch.randperm(num_steps) 91 | step_ids = step_ids[(step_ids != 0) * (step_ids != num_steps - 1)] 92 | step_ids = step_ids[:num_interpolations].sort()[0] 93 | interpolated_data[i, step_ids] = self.slerp(data[i, step_ids - 1], data[i, step_ids + 1], 0.5) 94 | interpolated_data[i, step_ids, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = self.quaternion_slerp( 95 | data[i, step_ids - 1, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 96 | data[i, step_ids + 1, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 97 | 0.5 98 | ) 99 | return interpolated_data 100 | 101 | def _rand_duplication(self, data, level=0): 102 | num_motion_clips, num_steps, reference_full_dim = data.size() 103 | num_duplications = round(num_steps * level) * 10 104 | if num_duplications == 0: 105 | return data 106 | duplicated_data = torch.zeros(num_motion_clips, num_steps + num_duplications, reference_full_dim, dtype=torch.float, device=self.device, requires_grad=False) 107 | step_ids = torch.randint(0, num_steps, 
(num_motion_clips, num_duplications), device=self.device) 108 | for i in range(num_motion_clips): 109 | duplicated_step_ids = torch.cat((torch.arange(num_steps, device=self.device), step_ids[i])).sort()[0] 110 | duplicated_data[i] = data[i, duplicated_step_ids] 111 | return duplicated_data 112 | 113 | def _rand_noise(self, data, level=0): 114 | noise_scales_dict = { 115 | "base_pos": 0.1, 116 | "base_quat": 0.01, 117 | "base_lin_vel": 0.1, 118 | "base_ang_vel": 0.2, 119 | "projected_gravity": 0.05, 120 | "base_height": 0.1, 121 | "dof_pos": 0.01, 122 | "dof_vel": 1.5 123 | } 124 | noise_scale_vec = torch.zeros_like(data[0, 0], device=self.device, dtype=torch.float, requires_grad=False) 125 | for key, value in self.state_idx_dict.items(): 126 | noise_scale_vec[value[0]:value[1]] = noise_scales_dict[key] * level 127 | data += (2 * torch.randn_like(data) - 1) * noise_scale_vec 128 | return data 129 | 130 | def _get_frame_at_step(self, motion_clip_sample_ids, step_sample): 131 | step_low, step_high = step_sample.floor().long(), step_sample.ceil().long() 132 | blend = (step_sample - step_low).unsqueeze(-1) 133 | frame = self.slerp(self.data[motion_clip_sample_ids, step_low], self.data[motion_clip_sample_ids, step_high], blend) 134 | frame[:, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]] = self.quaternion_slerp( 135 | self.data[motion_clip_sample_ids, step_low, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 136 | self.data[motion_clip_sample_ids, step_high, self.state_idx_dict["base_quat"][0]:self.state_idx_dict["base_quat"][1]], 137 | blend 138 | ) 139 | return frame 140 | 141 | def get_frames(self, num_frames): 142 | ids = torch.randint(0, self.num_preload_transitions, (num_frames,), device=self.device) 143 | return self.preloaded_states[ids, 0] 144 | 145 | def get_transitions(self, num_transitions): 146 | ids = torch.randint(0, self.num_preload_transitions, (num_transitions,), device=self.device) 147 | return self.preloaded_states[ids, :] 148 | 149 | def slerp(self, value_low, value_high, blend): 150 | return (1.0 - blend) * value_low + blend * value_high 151 | 152 | def quaternion_slerp(self, quat_low, quat_high, blend): 153 | relative_quat = normalize(quat_mul(quat_high, quat_conjugate(quat_low))) 154 | angle = 2 * torch.acos(relative_quat[:, -1]).unsqueeze(-1) 155 | axis = normalize(relative_quat[:, :3]) 156 | angle_slerp = self.slerp(torch.zeros_like(angle), angle, blend).squeeze(-1) 157 | relative_quat_slerp = quat_from_angle_axis(angle_slerp, axis) 158 | return normalize(quat_mul(relative_quat_slerp, quat_low)) 159 | 160 | def feed_forward_generator(self, num_mini_batch, mini_batch_size): 161 | for _ in range(num_mini_batch): 162 | ids = torch.randint(0, self.num_preload_transitions, (mini_batch_size,), device=self.device) 163 | states = self.preloaded_states[ids, :, self.observation_start_dim:] 164 | yield states 165 | 166 | @staticmethod 167 | def get_base_pos(state_idx_dict, frames): 168 | if "base_pos" in state_idx_dict: 169 | return frames[:, state_idx_dict["base_pos"][0]:state_idx_dict["base_pos"][1]] 170 | else: 171 | raise Exception("[MotionLoader] base_pos not specified in the state_idx_dict") 172 | 173 | @staticmethod 174 | def get_base_quat(state_idx_dict, frames): 175 | if "base_quat" in state_idx_dict: 176 | return frames[:, state_idx_dict["base_quat"][0]:state_idx_dict["base_quat"][1]] 177 | else: 178 | raise Exception("[MotionLoader] base_quat not specified in the state_idx_dict") 179 | 180 | @staticmethod 181 | def 
get_base_lin_vel(state_idx_dict, frames): 182 | if "base_lin_vel" in state_idx_dict: 183 | return frames[:, state_idx_dict["base_lin_vel"][0]:state_idx_dict["base_lin_vel"][1]] 184 | else: 185 | raise Exception("[MotionLoader] base_lin_vel not specified in the state_idx_dict") 186 | 187 | @staticmethod 188 | def get_base_ang_vel(state_idx_dict, frames): 189 | if "base_ang_vel" in state_idx_dict: 190 | return frames[:, state_idx_dict["base_ang_vel"][0]:state_idx_dict["base_ang_vel"][1]] 191 | else: 192 | raise Exception("[MotionLoader] base_ang_vel not specified in the state_idx_dict") 193 | 194 | @staticmethod 195 | def get_projected_gravity(state_idx_dict, frames): 196 | if "projected_gravity" in state_idx_dict: 197 | return frames[:, state_idx_dict["projected_gravity"][0]:state_idx_dict["projected_gravity"][1]] 198 | else: 199 | raise Exception("[MotionLoader] projected_gravity not specified in the state_idx_dict") 200 | 201 | @staticmethod 202 | def get_dof_pos(state_idx_dict, frames): 203 | if "dof_pos" in state_idx_dict: 204 | return frames[:, state_idx_dict["dof_pos"][0]:state_idx_dict["dof_pos"][1]] 205 | else: 206 | raise Exception("[MotionLoader] dof_pos not specified in the state_idx_dict") 207 | 208 | @staticmethod 209 | def get_dof_vel(state_idx_dict, frames): 210 | if "dof_vel" in state_idx_dict: 211 | return frames[:, state_idx_dict["dof_vel"][0]:state_idx_dict["dof_vel"][1]] 212 | else: 213 | raise Exception("[MotionLoader] dof_vel not specified in the state_idx_dict") 214 | 215 | @staticmethod 216 | def get_feet_pos(state_idx_dict, frames): 217 | if "feet_pos" in state_idx_dict: 218 | return frames[:, state_idx_dict["feet_pos"][0]:state_idx_dict["feet_pos"][1]] 219 | else: 220 | raise Exception("[MotionLoader] feet_pos not specified in the state_idx_dict") 221 | -------------------------------------------------------------------------------- /learning/env/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | """Submodule defining the environment definitions.""" 4 | 5 | from .vec_env import VecEnv 6 | 7 | __all__ = ["VecEnv"] 8 | -------------------------------------------------------------------------------- /learning/env/vec_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # python 5 | from abc import ABC, abstractmethod 6 | from typing import Tuple, Union 7 | 8 | # torch 9 | import torch 10 | 11 | 12 | # minimal interface of the environment 13 | class VecEnv(ABC): 14 | """Abstract class for vectorized environment.""" 15 | 16 | num_envs: int 17 | num_obs: int 18 | num_privileged_obs: int 19 | num_actions: int 20 | max_episode_length: int 21 | privileged_obs_buf: torch.Tensor 22 | obs_buf: torch.Tensor 23 | rew_buf: torch.Tensor 24 | reset_buf: torch.Tensor 25 | episode_length_buf: torch.Tensor # current episode duration 26 | extras: dict 27 | device: torch.device 28 | 29 | """ 30 | Properties 31 | """ 32 | 33 | @abstractmethod 34 | def get_observations(self) -> torch.Tensor: 35 | pass 36 | 37 | @abstractmethod 38 | def get_privileged_observations(self) -> Union[torch.Tensor, None]: 39 | pass 40 | 41 | """ 42 | Operations. 
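# The runners consume implementations of this interface roughly as follows
# (sketch; `policy` is any callable mapping observations to actions):
#
#     obs = env.get_observations()
#     for _ in range(1000):
#         actions = policy(obs)
#         obs, privileged_obs, rewards, dones, infos = env.step(actions)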
43 | """ 44 | 45 | @abstractmethod 46 | def step( 47 | self, actions: torch.Tensor 48 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: 49 | """Apply input action on the environment. 50 | 51 | Args: 52 | actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) 53 | 54 | Returns: 55 | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]: 56 | A tuple containing the observations, privileged observations, rewards, dones and 57 | extra information (metrics). 58 | """ 59 | raise NotImplementedError 60 | 61 | @abstractmethod 62 | def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 63 | """Reset all environment instances. 64 | 65 | Returns: 66 | Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations. 67 | """ 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /learning/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Definitions for neural-network components for RL-agents.""" 5 | 6 | from .actor_critic import ActorCritic 7 | from .actor_critic_recurrent import ActorCriticRecurrent 8 | from .normalizer import Normalizer 9 | 10 | __all__ = ["ActorCritic", "ActorCriticRecurrent", "Normalizer"] 11 | -------------------------------------------------------------------------------- /learning/modules/actor_critic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | from torch.distributions import Normal 8 | from learning.modules.normalizer import EmpiricalNormalization 9 | 10 | 11 | class ActorCritic(nn.Module): 12 | is_recurrent = False 13 | 14 | def __init__( 15 | self, 16 | num_actor_obs, 17 | num_critic_obs, 18 | num_actions, 19 | actor_hidden_dims=[256, 256, 256], 20 | critic_hidden_dims=[256, 256, 256], 21 | activation="elu", 22 | init_noise_std=1.0, 23 | update_obs_norm=True, 24 | **kwargs, 25 | ): 26 | if kwargs: 27 | print( 28 | "ActorCritic.__init__ got unexpected arguments, which will be ignored: " 29 | + str([key for key in kwargs.keys()]) 30 | ) 31 | super(ActorCritic, self).__init__() 32 | activation = get_activation(activation) 33 | 34 | mlp_input_dim_a = num_actor_obs 35 | mlp_input_dim_c = num_critic_obs 36 | 37 | # Policy 38 | actor_layers = [] 39 | actor_layers.append( 40 | EmpiricalNormalization(shape=[mlp_input_dim_a], update_obs_norm=update_obs_norm, until=1.0e8) 41 | ) 42 | actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) 43 | actor_layers.append(activation) 44 | for layer_index in range(len(actor_hidden_dims)): 45 | if layer_index == len(actor_hidden_dims) - 1: 46 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions)) 47 | else: 48 | actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1])) 49 | actor_layers.append(activation) 50 | self.actor = nn.Sequential(*actor_layers) 51 | 52 | # Value function 53 | critic_layers = [] 54 | critic_layers.append( 55 | EmpiricalNormalization(shape=[mlp_input_dim_c], update_obs_norm=update_obs_norm, until=1.0e8) 56 | ) 57 | critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) 58 | 
critic_layers.append(activation) 59 | for layer_index in range(len(critic_hidden_dims)): 60 | if layer_index == len(critic_hidden_dims) - 1: 61 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1)) 62 | else: 63 | critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1])) 64 | critic_layers.append(activation) 65 | self.critic = nn.Sequential(*critic_layers) 66 | 67 | print(f"Actor MLP: {self.actor}") 68 | print(f"Critic MLP: {self.critic}") 69 | 70 | # Action noise 71 | self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) 72 | self.distribution = None 73 | # disable args validation for speedup 74 | Normal.set_default_validate_args = False 75 | 76 | # seems that we get better performance without init 77 | # self.init_memory_weights(self.memory_a, 0.001, 0.) 78 | # self.init_memory_weights(self.memory_c, 0.001, 0.) 79 | 80 | @staticmethod 81 | # not used at the moment 82 | def init_weights(sequential, scales): 83 | [ 84 | torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) 85 | for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear)) 86 | ] 87 | 88 | def reset(self, dones=None): 89 | pass 90 | 91 | def forward(self): 92 | raise NotImplementedError 93 | 94 | @property 95 | def action_mean(self): 96 | return self.distribution.mean 97 | 98 | @property 99 | def action_std(self): 100 | return self.distribution.stddev 101 | 102 | @property 103 | def entropy(self): 104 | return self.distribution.entropy().sum(dim=-1) 105 | 106 | def update_distribution(self, observations): 107 | mean = self.actor(observations) 108 | self.distribution = Normal(mean, mean * 0.0 + self.std) 109 | 110 | def act(self, observations, **kwargs): 111 | self.update_distribution(observations) 112 | return self.distribution.sample() 113 | 114 | def get_actions_log_prob(self, actions): 115 | return self.distribution.log_prob(actions).sum(dim=-1) 116 | 117 | def act_inference(self, observations): 118 | actions_mean = self.actor(observations) 119 | return actions_mean 120 | 121 | def evaluate(self, critic_observations, **kwargs): 122 | value = self.critic(critic_observations) 123 | return value 124 | 125 | 126 | def get_activation(act_name): 127 | if act_name == "elu": 128 | return nn.ELU() 129 | elif act_name == "selu": 130 | return nn.SELU() 131 | elif act_name == "relu": 132 | return nn.ReLU() 133 | elif act_name == "crelu": 134 | return nn.ReLU() 135 | elif act_name == "lrelu": 136 | return nn.LeakyReLU() 137 | elif act_name == "tanh": 138 | return nn.Tanh() 139 | elif act_name == "sigmoid": 140 | return nn.Sigmoid() 141 | else: 142 | print("invalid activation function!") 143 | return None 144 | -------------------------------------------------------------------------------- /learning/modules/actor_critic_recurrent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | import torch.nn as nn 7 | 8 | # learning 9 | from learning.modules.actor_critic import ActorCritic, get_activation 10 | from learning.utils import unpad_trajectories 11 | 12 | 13 | class ActorCriticRecurrent(ActorCritic): 14 | is_recurrent = True 15 | 16 | def __init__( 17 | self, 18 | num_actor_obs, 19 | num_critic_obs, 20 | num_actions, 21 | actor_hidden_dims=[256, 256, 256], 22 | critic_hidden_dims=[256, 256, 256], 23 | activation="elu", 24 | rnn_type="lstm", 25 | rnn_hidden_size=256, 26 | 
rnn_num_layers=1, 27 | init_noise_std=1.0, 28 | **kwargs, 29 | ): 30 | if kwargs: 31 | print( 32 | "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()), 33 | ) 34 | 35 | super().__init__( 36 | num_actor_obs=rnn_hidden_size, 37 | num_critic_obs=rnn_hidden_size, 38 | num_actions=num_actions, 39 | actor_hidden_dims=actor_hidden_dims, 40 | critic_hidden_dims=critic_hidden_dims, 41 | activation=activation, 42 | init_noise_std=init_noise_std, 43 | ) 44 | 45 | activation = get_activation(activation) 46 | 47 | self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 48 | self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) 49 | 50 | print(f"Actor RNN: {self.memory_a}") 51 | print(f"Critic RNN: {self.memory_c}") 52 | 53 | def reset(self, dones=None): 54 | self.memory_a.reset(dones) 55 | self.memory_c.reset(dones) 56 | 57 | def act(self, observations, masks=None, hidden_states=None): 58 | input_a = self.memory_a(observations, masks, hidden_states) 59 | return super().act(input_a.squeeze(0)) 60 | 61 | def act_inference(self, observations): 62 | input_a = self.memory_a(observations) 63 | return super().act_inference(input_a.squeeze(0)) 64 | 65 | def evaluate(self, critic_observations, masks=None, hidden_states=None): 66 | input_c = self.memory_c(critic_observations, masks, hidden_states) 67 | return super().evaluate(input_c.squeeze(0)) 68 | 69 | def get_hidden_states(self): 70 | return self.memory_a.hidden_states, self.memory_c.hidden_states 71 | 72 | 73 | class Memory(torch.nn.Module): 74 | def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256): 75 | super().__init__() 76 | # RNN 77 | rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM 78 | self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) 79 | self.hidden_states = None 80 | 81 | def forward(self, input, masks=None, hidden_states=None): 82 | batch_mode = masks is not None 83 | if batch_mode: 84 | # batch mode (policy update): need saved hidden states 85 | if hidden_states is None: 86 | raise ValueError("Hidden states not passed to memory module during policy update") 87 | out, _ = self.rnn(input, hidden_states) 88 | out = unpad_trajectories(out, masks) 89 | else: 90 | # inference mode (collection): use hidden states of last step 91 | out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) 92 | return out 93 | 94 | def reset(self, dones=None): 95 | # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state 96 | for hidden_state in self.hidden_states: 97 | hidden_state[..., dones, :] = 0.0 98 | -------------------------------------------------------------------------------- /learning/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Discriminator(nn.Module): 5 | def __init__(self, 6 | latent_channel, 7 | num_classes, 8 | device, 9 | shape=[1024, 512], 10 | ): 11 | super(Discriminator, self).__init__() 12 | self.input_dim = latent_channel 13 | self.num_classes = num_classes 14 | self.device = device 15 | self.shape = shape 16 | 17 | discriminator_layers = [] 18 | curr_in_dim = self.input_dim 19 | for hidden_dim in self.shape: 20 | discriminator_layers.append(nn.Linear(curr_in_dim, hidden_dim)) 21 | discriminator_layers.append(nn.ReLU()) 22 | curr_in_dim = hidden_dim 23 | 
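# Hidden-state handling of the Memory module above, as exercised during rollouts
# (sketch; dimensions are assumptions):
#
#     memory = Memory(input_size=48, type="lstm", hidden_size=256)
#     out = memory(torch.randn(4, 48))   # inference path: input unsqueezed to (1, 4, 48)
#     memory.reset(dones=torch.tensor([False, True, False, False]))  # zero finished envs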
discriminator_layers.append(nn.Linear(self.shape[-1], self.num_classes)) 24 | self.architecture = nn.Sequential(*discriminator_layers).to(self.device) 25 | self.architecture.train() 26 | 27 | def forward(self, x): 28 | return self.architecture(x) 29 | -------------------------------------------------------------------------------- /learning/modules/fld.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FLD(nn.Module): 5 | def __init__(self, 6 | observation_dim, 7 | history_horizon, 8 | latent_channel, 9 | device, 10 | dt=0.02, 11 | encoder_shape=None, 12 | decoder_shape=None, 13 | **kwargs, 14 | ): 15 | if kwargs: 16 | print("FLD.__init__ got unexpected arguments, which will be ignored: " 17 | + str([key for key in kwargs.keys()])) 18 | super(FLD, self).__init__() 19 | self.input_channel = observation_dim 20 | self.history_horizon = history_horizon 21 | self.latent_channel = latent_channel 22 | self.device = device 23 | self.dt = dt 24 | 25 | self.args = torch.linspace(-(history_horizon - 1) * self.dt / 2, (history_horizon - 1) * self.dt / 2, self.history_horizon, dtype=torch.float, device=self.device) 26 | self.freqs = torch.fft.rfftfreq(history_horizon, device=self.device)[1:] * history_horizon 27 | self.encoder_shape = encoder_shape if encoder_shape is not None else [int(self.input_channel / 3)] 28 | self.decoder_shape = decoder_shape if decoder_shape is not None else [int(self.input_channel / 3)] 29 | 30 | encoder_layers = [] 31 | curr_in_channel = self.input_channel 32 | for hidden_channel in self.encoder_shape: 33 | encoder_layers.append( 34 | nn.Conv1d( 35 | curr_in_channel, 36 | hidden_channel, 37 | history_horizon, 38 | stride=1, 39 | padding=int((history_horizon - 1) / 2), 40 | dilation=1, 41 | groups=1, 42 | bias=True, 43 | padding_mode='zeros') 44 | ) 45 | encoder_layers.append(nn.BatchNorm1d(num_features=hidden_channel)) 46 | encoder_layers.append(nn.ELU()) 47 | curr_in_channel = hidden_channel 48 | encoder_layers.append( 49 | nn.Conv1d( 50 | self.encoder_shape[-1], 51 | latent_channel, 52 | history_horizon, 53 | stride=1, 54 | padding=int((history_horizon - 1) / 2), 55 | dilation=1, 56 | groups=1, 57 | bias=True, 58 | padding_mode='zeros') 59 | ) 60 | encoder_layers.append(nn.BatchNorm1d(num_features=latent_channel)) 61 | encoder_layers.append(nn.ELU()) 62 | 63 | self.encoder = nn.Sequential(*encoder_layers).to(self.device) 64 | self.encoder.train() 65 | 66 | self.phase_encoder = nn.ModuleList() 67 | for _ in range(latent_channel): 68 | phase_encoder_layers = [] 69 | phase_encoder_layers.append(nn.Linear(history_horizon, 2)) 70 | phase_encoder_layers.append(nn.BatchNorm1d(num_features=2)) 71 | phase_encoder = nn.Sequential(*phase_encoder_layers).to(self.device) 72 | self.phase_encoder.append(phase_encoder) 73 | self.phase_encoder.train() 74 | 75 | decoder_layers = [] 76 | curr_in_channel = latent_channel 77 | for hidden_channel in self.decoder_shape: 78 | decoder_layers.append( 79 | nn.Conv1d( 80 | curr_in_channel, 81 | hidden_channel, 82 | history_horizon, 83 | stride=1, 84 | padding=int((history_horizon - 1) / 2), 85 | dilation=1, 86 | groups=1, 87 | bias=True, 88 | padding_mode='zeros') 89 | ) 90 | decoder_layers.append(nn.BatchNorm1d(num_features=hidden_channel)) 91 | decoder_layers.append(nn.ELU()) 92 | curr_in_channel = hidden_channel 93 | decoder_layers.append( 94 | nn.Conv1d( 95 | self.decoder_shape[-1], 96 | self.input_channel, 97 | history_horizon, 98 | stride=1, 99 | 
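# Shape note: with kernel_size = history_horizon and padding = (history_horizon - 1) / 2,
# each Conv1d in this autoencoder preserves the temporal length, provided
# history_horizon is odd. A quick check (illustrative dimensions):
#
#     conv = nn.Conv1d(8, 8, kernel_size=51, padding=25)
#     conv(torch.randn(1, 8, 51)).shape   # -> torch.Size([1, 8, 51])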
padding=int((history_horizon - 1) / 2), 100 | dilation=1, 101 | groups=1, 102 | bias=True, 103 | padding_mode='zeros') 104 | ) 105 | self.decoder = nn.Sequential(*decoder_layers).to(self.device) 106 | self.decoder.train() 107 | 108 | def forward(self, x, k=1): 109 | x = self.encoder(x) 110 | latent = x 111 | frequency, amplitude, offset = self.fft(x) 112 | phase = torch.zeros((x.size(0), self.latent_channel), device=self.device, dtype=torch.float) 113 | for i in range(self.latent_channel): 114 | phase_shift = self.phase_encoder[i](x[:, i, :]) 115 | phase[:, i] = torch.atan2(phase_shift[:, 1], phase_shift[:, 0]) / (2 * torch.pi) 116 | 117 | params = [phase, frequency, amplitude, offset] # (batch_size, latent_channel) 118 | 119 | phase_dynamics = phase.unsqueeze(0) + frequency.unsqueeze(0) * self.dt * torch.arange(0, k, device=self.device, dtype=torch.float, requires_grad=False).view(-1, 1, 1) # (k, batch_size, latent_channel) 120 | z = amplitude.unsqueeze(-1).unsqueeze(0) * torch.sin(2 * torch.pi * ((frequency.unsqueeze(-1) * self.args).unsqueeze(0) + phase_dynamics.unsqueeze(-1))) + offset.unsqueeze(-1).unsqueeze(0) # (k, batch_size, latent_channel, history_horizon) 121 | signal = z[0] 122 | pred_dynamics = self.decoder(z.flatten(0, 1)).view(k, -1, self.input_channel, self.history_horizon) # (k, batch_size, input_channel, history_horizon) 123 | 124 | return pred_dynamics, latent, signal, params 125 | 126 | def fft(self, x): 127 | rfft = torch.fft.rfft(x, dim=2) 128 | magnitude = rfft.abs() 129 | spectrum = magnitude[:, :, 1:] 130 | power = torch.square(spectrum) 131 | frequency = torch.sum(self.freqs * power, dim=2) / torch.sum(power, dim=2) 132 | amplitude = 2 * torch.sqrt(torch.sum(power, dim=2)) / self.history_horizon 133 | offset = rfft.real[:, :, 0] / self.history_horizon 134 | return frequency, amplitude, offset 135 | 136 | def get_dynamics_error(self, state_transitions, k): 137 | self.eval() 138 | state_transitions_sequence = torch.zeros( 139 | state_transitions.size(0), 140 | state_transitions.size(1) - self.history_horizon + 1, 141 | self.history_horizon, 142 | state_transitions.size(2), 143 | dtype=torch.float, 144 | device=self.device, 145 | requires_grad=False 146 | ) 147 | for step in range(state_transitions.size(1) - self.history_horizon + 1): 148 | state_transitions_sequence[:, step] = state_transitions[:, step:step + self.history_horizon, :] 149 | with torch.no_grad(): 150 | pred_dynamics, _, _, _ = self.forward(state_transitions_sequence.flatten(0, 1).swapaxes(1, 2), k) 151 | pred_dynamics = pred_dynamics.swapaxes(2, 3).view(k, -1, state_transitions.size(1) - self.history_horizon + 1, self.history_horizon, state_transitions.size(2)) 152 | error = torch.zeros(state_transitions.size(0), device=self.device, dtype=torch.float, requires_grad=False) 153 | for i in range(k): 154 | error[:] += torch.square((pred_dynamics[i, :, :state_transitions_sequence.size(1) - i] - state_transitions_sequence[:, i:])).mean(dim=(1, 2, 3)) 155 | return error 156 | -------------------------------------------------------------------------------- /learning/modules/gmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from sklearn.mixture import GaussianMixture as skGaussianMixture 4 | 5 | class GaussianMixtures(nn.Module): 6 | 7 | def __init__(self, min_n_components, max_n_components, n_features, device, covariance_type="full"): 8 | super(GaussianMixtures, self).__init__() 9 | self.candidates = nn.ModuleList( 10 | [ 11 
|                 GaussianMixture(
12 |                     n_components,
13 |                     n_features,
14 |                     device,
15 |                     covariance_type=covariance_type,
16 |                 )
17 |                 for n_components in range(min_n_components, max_n_components + 1)
18 |             ]
19 |         )
20 |         self.num_gmms = len(self.candidates)
21 |         self.n_features = n_features
22 |         self.device = device
23 | 
24 | 
25 |     def aic(self, x):
26 |         aics = torch.zeros(self.num_gmms, device=self.device, dtype=torch.float, requires_grad=False)
27 |         for i in range(self.num_gmms):
28 |             aics[i] = self.candidates[i].aic(x)
29 |         return aics
30 | 
31 | 
32 |     def bic(self, x):
33 |         bics = torch.zeros(self.num_gmms, device=self.device, dtype=torch.float, requires_grad=False)
34 |         for i in range(self.num_gmms):
35 |             bics[i] = self.candidates[i].bic(x)
36 |         return bics
37 | 
38 | 
39 |     def fit(self, x):
40 |         for i in range(self.num_gmms):
41 |             self.candidates[i].fit(x)
42 | 
43 | 
44 |     def get_best_gmm_idx(self, x, criterion="bic"):
45 |         if criterion == "bic":
46 |             bics = self.bic(x)
47 |             idx = torch.argmin(bics)
48 |         elif criterion == "aic":
49 |             aics = self.aic(x)
50 |             idx = torch.argmin(aics)
51 |         else:
52 |             raise ValueError("[GMMs] Invalid criterion.")
53 |         return idx
54 | 
55 | 
56 |     def sample(self, n, idx):
57 |         alp_means = self.candidates[idx].mu[:, -1]
58 |         k = torch.multinomial(alp_means, n, replacement=True)
59 |         samples = self.candidates[idx].sample_class(n, k)
60 |         return samples
61 | 
62 | 
63 | class GaussianMixture(nn.Module):
64 |     def __init__(self, n_components, n_features, device, covariance_type="full"):
65 |         super().__init__()
66 | 
67 |         self.n_components = n_components
68 |         self.n_features = n_features
69 |         self.device = device
70 |         self.covariance_type = covariance_type
71 | 
72 |         self.gmm = skGaussianMixture(n_components=n_components, covariance_type=covariance_type, n_init=10)
73 |         self._init_params()
74 | 
75 | 
76 |     def _init_params(self):
77 |         self.mu = nn.Parameter(torch.randn(self.n_components, self.n_features, device=self.device), requires_grad=False)
78 |         if self.covariance_type == "diag":
79 |             self.var = nn.Parameter(torch.ones(self.n_components, self.n_features, device=self.device), requires_grad=False)
80 |             self.var_chol = nn.Parameter(torch.ones(self.n_components, self.n_features, device=self.device), requires_grad=False)
81 |         elif self.covariance_type == "full":
82 |             self.var = nn.Parameter(torch.eye(self.n_features, device=self.device).repeat(self.n_components, 1, 1), requires_grad=False)
83 |             self.var_chol = nn.Parameter(torch.eye(self.n_features, device=self.device).repeat(self.n_components, 1, 1), requires_grad=False)
84 |         else:
85 |             raise Exception("[GaussianMixture] __init__ got invalid covariance_type: {}".format(self.covariance_type))
86 |         self.pi = nn.Parameter(torch.ones(self.n_components, device=self.device) / self.n_components, requires_grad=False)
87 | 
88 | 
89 |     def load_state_dict(self, state_dict):
90 |         super().load_state_dict(state_dict)
91 |         self.gmm.means_ = self.mu.detach().cpu().numpy()
92 |         self.gmm.covariances_ = self.var.detach().cpu().numpy()
93 |         self.gmm.weights_ = self.pi.detach().cpu().numpy()
94 |         if self.covariance_type == "diag":
95 |             self.gmm.precisions_cholesky_ = (1.0 / self.var.sqrt()).detach().cpu().numpy()  # sklearn stores diagonal precisions as 1 / sqrt(var); a matrix inverse is neither needed nor defined for the (n_components, n_features) case
96 |         elif self.covariance_type == "full":
97 |             self.gmm.precisions_cholesky_ = torch.linalg.cholesky(torch.linalg.inv(self.var)).detach().cpu().numpy()
98 |         else:
99 |             raise Exception("[GaussianMixture] load_state_dict got invalid covariance_type: {}".format(self.covariance_type))
100 | 
101 | 
102 |     def set_variances(self, var):
103 |         self.var[:] = var
104 |         if self.covariance_type == "diag":
105 |             self.var_chol[:] = var.sqrt()
106 |         elif self.covariance_type == "full":
107 |             self.var_chol[:] = torch.linalg.cholesky(var)
108 |         else:
109 |             raise Exception("[GaussianMixture] set_variances got invalid covariance_type: {}".format(self.covariance_type))
110 |         self.gmm.covariances_ = var.detach().cpu().numpy()
111 |         self.gmm.precisions_cholesky_ = ((1.0 / var.sqrt()) if self.covariance_type == "diag" else torch.linalg.cholesky(torch.linalg.inv(var))).detach().cpu().numpy()  # diagonal precisions are 1 / sqrt(var); the Cholesky of the inverse applies only to full covariances
112 | 
113 | 
114 |     def aic(self, x):
115 |         x = x.cpu().numpy()
116 |         aic = self.gmm.aic(x)
117 |         return aic
118 | 
119 | 
120 |     def bic(self, x):
121 |         x = x.cpu().numpy()
122 |         bic = self.gmm.bic(x)
123 |         return bic
124 | 
125 | 
126 |     def fit(self, x):
127 |         x = x.cpu().numpy()
128 |         self.gmm.fit(x)
129 |         self.mu[:] = torch.tensor(self.gmm.means_, device=self.device, dtype=torch.float, requires_grad=False)
130 |         self.var[:] = torch.tensor(self.gmm.covariances_, device=self.device, dtype=torch.float, requires_grad=False)
131 |         self.pi[:] = torch.tensor(self.gmm.weights_, device=self.device, dtype=torch.float, requires_grad=False)
132 |         if self.covariance_type == "diag":
133 |             self.var_chol[:] = self.var.sqrt()
134 |         elif self.covariance_type == "full":
135 |             self.var_chol[:] = torch.linalg.cholesky(self.var)
136 |         else:
137 |             raise Exception("[GaussianMixture] fit got invalid covariance_type: {}".format(self.covariance_type))
138 | 
139 | 
140 |     def predict(self, x):
141 |         x = x.cpu().numpy()
142 |         y = self.gmm.predict(x)
143 |         return torch.tensor(y, device=self.device, dtype=torch.long, requires_grad=False)
144 | 
145 | 
146 |     def predict_proba(self, x):
147 |         x = x.cpu().numpy()
148 |         resp = self.gmm.predict_proba(x)
149 |         return torch.tensor(resp, device=self.device, dtype=torch.float, requires_grad=False)
150 | 
151 | 
152 |     def sample(self, n):
153 |         x, y = self.gmm.sample(n)
154 |         return torch.tensor(x, device=self.device, dtype=torch.float, requires_grad=False), torch.tensor(y, device=self.device, dtype=torch.long, requires_grad=False)
155 | 
156 | 
157 |     def sample_class(self, n, k):
158 |         mu_k = self.mu[k, :]
159 |         var_chol_k = self.var_chol[k, :]
160 |         if self.covariance_type == "diag":
161 |             return torch.randn(n, self.n_features, device=self.device, dtype=torch.float, requires_grad=False) * var_chol_k + mu_k
162 |         elif self.covariance_type == "full":
163 |             return (var_chol_k @ torch.randn(n, self.n_features, 1, device=self.device, dtype=torch.float, requires_grad=False)).squeeze(-1) + mu_k
164 |         else:
165 |             raise Exception("[GaussianMixture] sample_class got invalid covariance_type: {}".format(self.covariance_type))
166 | 
167 | 
168 |     def score(self, x):
169 |         x = x.cpu().numpy()
170 |         score = self.gmm.score(x)
171 |         return torch.tensor(score, device=self.device, dtype=torch.float, requires_grad=False)
172 | 
173 | 
174 |     def score_samples(self, x):
175 |         x = x.cpu().numpy()
176 |         score_samples = self.gmm.score_samples(x)
177 |         return torch.tensor(score_samples, device=self.device, dtype=torch.float, requires_grad=False)
178 | 
179 | 
180 |     def get_block_parameters(self, latent_dim):
181 |         mu = []
182 |         var = []
183 |         num_slices = int(self.n_features / latent_dim)
184 |         for i in range(num_slices):
185 |             mu.append(self.mu[:, i * latent_dim:(i + 1) * latent_dim])
186 |             if self.covariance_type == "full":
187 |                 var.append(self.var[:, i * latent_dim:(i + 1) * latent_dim, i * latent_dim:(i + 1) * latent_dim])
188 |             else:
189 |                 var.append(self.var[:, i * latent_dim:(i + 1) * latent_dim])
190 |         return mu, var
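
A minimal usage sketch of the mixture wrappers above (not part of the repository; the synthetic data is a stand-in for flattened FLD latent parameters). It fits one sklearn-backed mixture per candidate component count, selects the count by BIC, and samples from the winner:

```python
import torch
from learning.modules.gmm import GaussianMixtures

device = "cpu"
x = torch.randn(1024, 8)  # stand-in for latent parameters (phase, frequency, amplitude, offset)

gmms = GaussianMixtures(min_n_components=2, max_n_components=6, n_features=8, device=device)
gmms.fit(x)                                         # fits every candidate component count
best = gmms.get_best_gmm_idx(x, criterion="bic")    # argmin over per-candidate BIC
samples, labels = gmms.candidates[best].sample(64)  # draw from the selected mixture
```

Note that `GaussianMixtures.sample(n, idx)` itself weights the components by the mean of the last feature; the ALP-GMM sampler below exploits this by appending learning progress as that last column.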
-------------------------------------------------------------------------------- /learning/modules/normalizer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Preferred Networks, Inc. 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | import numpy as np 23 | 24 | import torch 25 | from torch import nn 26 | from typing import Tuple 27 | 28 | 29 | class EmpiricalNormalization(nn.Module): 30 | """Normalize mean and variance of values based on empirical values. 31 | Args: 32 | shape (int or tuple of int): Shape of input values except batch axis. 33 | batch_axis (int): Batch axis. 34 | eps (float): Small value for stability. 35 | dtype (dtype): Dtype of input values. 36 | until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes 37 | exceeds it. 
38 |         update_obs_norm (bool): If true, updates the mean and variance estimates from incoming values
39 |     """
40 | 
41 |     def __init__(
42 |         self,
43 |         shape,
44 |         batch_axis=0,
45 |         eps=1e-2,
46 |         dtype=np.float32,
47 |         until=None,
48 |         clip_threshold=None,
49 |         update_obs_norm=True,
50 |     ):
51 |         super(EmpiricalNormalization, self).__init__()
52 |         dtype = np.dtype(dtype)
53 |         self.batch_axis = batch_axis
54 |         self.eps = eps
55 |         self.until = until
56 |         self.clip_threshold = clip_threshold
57 |         self.register_buffer(
58 |             "_mean",
59 |             torch.tensor(np.expand_dims(np.zeros(shape, dtype=dtype), batch_axis)),
60 |         )
61 |         self.register_buffer(
62 |             "_var",
63 |             torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis)),
64 |         )
65 |         self.register_buffer("count", torch.tensor(0))
66 |         self.in_features = shape[0]
67 | 
68 |         # cache
69 |         self._cached_std_inverse = torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis))
70 |         self._is_std_cached = False
71 |         self._is_training = update_obs_norm
72 | 
73 |     @property
74 |     def mean(self):
75 |         return torch.squeeze(self._mean, self.batch_axis).clone()
76 | 
77 |     @property
78 |     def std(self):
79 |         return torch.sqrt(torch.squeeze(self._var, self.batch_axis)).clone()
80 | 
81 |     @property
82 |     def _std_inverse(self):
83 |         if self._is_std_cached is False:
84 |             self._cached_std_inverse = (self._var + self.eps) ** -0.5
85 | 
86 |         return self._cached_std_inverse
87 | 
88 |     @torch.jit.unused
89 |     @torch.no_grad()
90 |     def experience(self, x):
91 |         """Learn input values without computing their output values"""
92 | 
93 |         if self.until is not None:
94 |             if self.count >= self.until:
95 |                 return
96 | 
97 |         count_x = x.shape[self.batch_axis]
98 |         if count_x == 0:
99 |             return
100 | 
101 |         self.count += count_x
102 |         rate = count_x / self.count.float()
103 |         assert rate > 0
104 |         assert rate <= 1
105 | 
106 |         var_x = torch.var(x, dim=self.batch_axis, unbiased=False, keepdim=True)
107 |         mean_x = torch.mean(x, dim=self.batch_axis, keepdim=True)
108 |         delta_mean = mean_x - self._mean
109 |         self._mean += rate * delta_mean
110 |         self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean))
111 | 
112 |         # clear cache
113 |         self._is_std_cached = False
114 | 
115 |     def forward(self, x):
116 |         """Normalize mean and variance of values based on empirical values.
117 | Args: 118 | x (ndarray or Variable): Input values 119 | Returns: 120 | ndarray or Variable: Normalized output values 121 | """ 122 | 123 | if self._is_training: 124 | self.experience(x) 125 | 126 | if not x.is_cuda: 127 | self._is_std_cached = False 128 | normalized = (x - self._mean) * self._std_inverse 129 | if self.clip_threshold is not None: 130 | normalized = torch.clamp(normalized, -self.clip_threshold, self.clip_threshold) 131 | if not x.is_cuda: 132 | self._is_std_cached = False 133 | return normalized 134 | 135 | @torch.jit.unused 136 | def inverse(self, y): 137 | std = torch.sqrt(self._var + self.eps) 138 | return y * std + self._mean 139 | 140 | def load_numpy(self, mean, var, count, device="cpu"): 141 | self._mean = torch.from_numpy(np.expand_dims(mean, self.batch_axis)).to(device) 142 | self._var = torch.from_numpy(np.expand_dims(var, self.batch_axis)).to(device) 143 | self.count = torch.tensor(count).to(device) 144 | 145 | class Normalizer: 146 | def __init__(self, input_dim, device, epsilon=1e-2, clip=10.0): 147 | self.device = device 148 | self.mean = torch.zeros(input_dim, device=self.device) 149 | self.var = torch.ones(input_dim, device=self.device) 150 | self.count = epsilon 151 | self.epsilon = epsilon 152 | self.clip = clip 153 | 154 | def normalize(self, data): 155 | mean_ = self.mean 156 | std_ = torch.sqrt(self.var + self.epsilon) 157 | return torch.clamp((data - mean_) / std_, -self.clip, self.clip) 158 | 159 | def update(self, data): 160 | batch_mean = torch.mean(data, dim=0) 161 | batch_var = torch.var(data, dim=0) 162 | batch_count = data.shape[0] 163 | self.update_from_moments(batch_mean, batch_var, batch_count) 164 | 165 | def update_from_moments(self, batch_mean, batch_var, batch_count): 166 | delta = batch_mean - self.mean 167 | tot_count = self.count + batch_count 168 | 169 | new_mean = self.mean + delta * batch_count / tot_count 170 | new_var = (self.var * self.count + 171 | batch_var * batch_count + 172 | torch.square(delta) * self.count * batch_count / tot_count) / tot_count 173 | self.mean = new_mean 174 | self.var = new_var 175 | self.count = tot_count 176 | -------------------------------------------------------------------------------- /learning/modules/plotter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.decomposition import PCA 3 | from matplotlib.patches import Ellipse 4 | import matplotlib.colors as mcolors 5 | from matplotlib.colors import to_rgba 6 | 7 | class Plotter: 8 | def __init__(self) -> None: 9 | self.pca = PCA(n_components=2) 10 | 11 | def plot_pca(self, ax, manifold_collection, title=None, point_color=None, draw_line=True, draw_arrow=True): 12 | ax.cla() 13 | 14 | point_alpha = 0.3 15 | line_alpha = 0.2 16 | arrow_alpha = 1.0 17 | arrow_step = 50 18 | arrow_size = 0.015 19 | arrow_power = 1.0 20 | arrow_color = (0.25, 0.25, 0.5) 21 | 22 | num_steps = [len(manifold) for manifold in manifold_collection] 23 | manifolds = torch.cat(manifold_collection, dim=0).cpu() 24 | 25 | manifolds_pca = torch.tensor(self.pca.fit_transform(manifolds)).split(num_steps, dim=0) 26 | 27 | for i, manifold in enumerate(manifolds_pca): 28 | if draw_line: 29 | ax.plot(manifold[:, 0], manifold[:, 1], color=(0.0, 0.0, 0.0), alpha=line_alpha) 30 | if point_color is None: 31 | ax.scatter(manifold[:, 0], manifold[:, 1], alpha=point_alpha, label=i) 32 | else: 33 | ax.scatter(manifold[:, 0], manifold[:, 1], color=point_color[i], alpha=point_alpha, label=i) 34 | if draw_arrow: 35 | for j 
in range(0, len(manifold) - arrow_step, arrow_step): 36 | d = torch.norm(manifold[j, :], dim=-1) 37 | d = torch.pow(d, arrow_power) 38 | ax.arrow( 39 | manifold[j, 0], 40 | manifold[j, 1], 41 | manifold[j + 1, 0] - manifold[j, 0], 42 | manifold[j + 1, 1] - manifold[j, 1], 43 | alpha=arrow_alpha, width=d * arrow_size, color=arrow_color 44 | ) 45 | ax.legend() 46 | ax.set_axis_off() 47 | if title != None: 48 | ax.set_title(title) 49 | 50 | def append_pca_gmm(self, ax, mean, variance, color=None, alphas=None): 51 | mean = mean.cpu() 52 | variance = variance.cpu() 53 | n_components = mean.size(0) 54 | if color is None: 55 | color = 'red' 56 | if alphas is None: 57 | alphas = [0.8] * n_components 58 | else: 59 | alphas = alphas.cpu().tolist() 60 | for i in range(n_components): 61 | mu = mean[i] 62 | var = variance[i] 63 | mu_transformed = torch.tensor(self.pca.transform(mu.unsqueeze(0)), dtype=torch.float).squeeze(0) 64 | eigval, eigvec = torch.linalg.eigh(var) 65 | pca_components = torch.tensor(self.pca.components_, dtype=torch.float) 66 | eigvec_transformed = torch.matmul(pca_components, eigvec) 67 | var_projected = torch.matmul(eigvec_transformed, torch.matmul(torch.diag(eigval), eigvec_transformed.T)) 68 | eigval_transformed, eigvec_transformed = torch.linalg.eigh(var_projected) 69 | width = eigval_transformed[0].sqrt() * 2 70 | height = eigval_transformed[1].sqrt() * 2 71 | std_scale = 3.0 72 | angle = torch.atan2(eigvec_transformed[1, 0], eigvec_transformed[0, 0]) * 180 / torch.pi 73 | ell = Ellipse(mu_transformed, width * std_scale, height * std_scale, angle=angle, fc=to_rgba(color, 0.2 * alphas[i]), ec=to_rgba(color, alphas[i]), lw=3) 74 | ax.add_artist(ell) 75 | 76 | def plot_pca_intensity(self, ax, manifold_collection, values, cmap='YlOrRd', vmin=0.0, vmax=1.0, xmin=None, xmax=None, ymin=None, ymax=None, title=None): 77 | ax.cla() 78 | 79 | point_alpha = 0.5 80 | 81 | num_steps = [len(manifold) for manifold in manifold_collection] 82 | manifolds = torch.cat(manifold_collection, dim=0).cpu() 83 | 84 | manifolds_pca = torch.tensor(self.pca.transform(manifolds)).split(num_steps, dim=0) 85 | 86 | for i, manifold in enumerate(manifolds_pca): 87 | ax.scatter(manifold[:, 0], manifold[:, 1], alpha=point_alpha, c=values[i].cpu(), cmap=cmap, vmin=vmin, vmax=vmax) 88 | 89 | ax.set_axis_off() 90 | if title != None: 91 | ax.set_title(title) 92 | if xmin != None and xmax != None: 93 | ax.set_xlim(xmin, xmax) 94 | if ymin != None and ymax != None: 95 | ax.set_ylim(ymin, ymax) 96 | 97 | def plot_distribution(self, ax, values, title): 98 | ax.cla() 99 | values = values.cpu() 100 | means = torch.mean(values, dim=0) 101 | std = torch.std(values, dim=0) 102 | args = torch.arange(values.size(1)) 103 | ax.bar(args, means, yerr=std, align='center', alpha=0.5, ecolor='black', capsize=10) 104 | ax.set_xlabel('Channel') 105 | ax.set_xticks(args) 106 | ax.set_title(title) 107 | ax.yaxis.grid(True) 108 | 109 | def plot_histogram(self, ax, values, title): 110 | ax.cla() 111 | values = values.cpu() 112 | ax.hist(values, bins=50, density=True) 113 | ax.set_title(title) 114 | ax.yaxis.grid(True) 115 | 116 | def plot_gmm(self, ax, data, pred_mean, pred_var, color=None, ymin=None, ymax=None, title=None): 117 | # ax.cla() 118 | data = data.cpu() 119 | if pred_mean is not None and pred_var is not None: 120 | pred_mean = pred_mean.cpu() 121 | pred_var = torch.linalg.cholesky(pred_var).diagonal(dim1=-2, dim2=-1) if pred_var.dim() == 3 else pred_var 122 | pred_std = torch.sqrt(pred_var.cpu()) 123 | if data.dim() == 2: 124 
| if color == None: 125 | color = "lightgrey" 126 | args = torch.arange(data.size(1)) 127 | theta_args = args * torch.pi / 4 128 | ax.plot(theta_args.repeat(data.size(0), 1).cpu(), data, "o", color="lightgrey", alpha=0.01, markersize=10, mew=0.0) 129 | elif data.dim() == 3: 130 | if color == None: 131 | color = list(mcolors.TABLEAU_COLORS.keys()) 132 | args = torch.arange(data.size(2)) 133 | theta_args = args * torch.pi / 4 134 | for i in range(data.size(0)): 135 | ax.plot(theta_args.repeat(data.size(1), 1).cpu(), data[i], "o", color=color[i], alpha=0.01, markersize=10, mew=0.0) 136 | if pred_mean is not None and pred_var is not None: 137 | theta_args_mean = torch.cat([theta_args, theta_args[0].unsqueeze(0)], dim=0) 138 | pred_mean = torch.cat([pred_mean, pred_mean[:, 0].unsqueeze(1)], dim=1) 139 | pred_std = torch.cat([pred_std, pred_std[:, 0].unsqueeze(1)], dim=1) 140 | for i in range(pred_mean.size(0)): 141 | ax.errorbar(theta_args_mean, pred_mean[i], yerr=pred_std[i], fmt="o-", markersize=4, capsize=5, alpha=0.5, linewidth=3) 142 | ax.set_xlabel('Channel') 143 | ax.set_xticks(theta_args) 144 | ax.set_xticklabels(args.tolist()) 145 | if title != None: 146 | ax.set_title(title) 147 | if ymin != None and ymax != None: 148 | ymin = ymin.cpu() 149 | ymax = ymax.cpu() 150 | ax.set_ylim(ymin, ymax) 151 | ax.yaxis.grid(True) 152 | 153 | def plot_correlation(self, ax, performance, score, title=None): 154 | ax.cla() 155 | performance = performance.cpu() 156 | score = score.cpu() 157 | ax.plot(performance, score, "o", color="tab:grey", alpha=0.002, markersize=10) 158 | ax.set_xlabel('Performance') 159 | ax.set_ylabel('Score') 160 | ax.set_xlim(0.0, 1.0) 161 | if title != None: 162 | ax.set_title(title) 163 | 164 | def plot_circles(self, ax, phase, amplitude, title=None, show_axes=True): 165 | ax.cla() 166 | phase = phase.cpu() 167 | amplitude = amplitude.cpu() 168 | 169 | aspect = 0.5 170 | ax.set_aspect(aspect) 171 | channel = phase.shape[0] 172 | ax.set_xlim(0.0, channel + 1.0) 173 | ax.set_ylim(-1.0, 1.0) 174 | theta = torch.linspace(0.0, 2 * torch.pi, 100) 175 | 176 | for i in range(channel): 177 | p = phase[i] 178 | a = amplitude[i] 179 | x1 = aspect * a * torch.cos(theta) + i + 1 180 | x2 = a * torch.sin(theta) 181 | ax.plot(x1, x2) 182 | line_x1 = [i + 1, i + 1 + aspect * a * torch.cos(2 * torch.pi * p)] 183 | line_x2 = [0.0, a * torch.sin(2 * torch.pi * p)] 184 | ax.plot(line_x1, line_x2, color=(0, 0, 0)) 185 | 186 | if title != None: 187 | ax.set_title(title) 188 | if show_axes == False: 189 | ax.axes.xaxis.set_visible(False) 190 | ax.axes.yaxis.set_visible(False) 191 | 192 | def plot_curves(self, ax, values, xmin, xmax, ymin, ymax, title=None, show_axes=True): 193 | ax.cla() 194 | values = values.cpu() 195 | args = torch.linspace(xmin, xmax, values.size(1)).repeat(values.size(0), 1) 196 | ax.plot(args.swapaxes(0, 1), values.swapaxes(0, 1)) 197 | ax.set_ylim(ymin, ymax) 198 | if title != None: 199 | ax.set_title(title) 200 | if show_axes == False: 201 | ax.axes.xaxis.set_visible(False) 202 | ax.axes.yaxis.set_visible(False) 203 | 204 | def plot_phase_1d(self, ax, phase, amplitude, title=None, show_axes=True): 205 | ax.cla() 206 | phase = phase.cpu() 207 | amplitude = amplitude.cpu() 208 | 209 | phase = torch.where(phase < 0, phase, phase + 1) 210 | phase = phase % 1.0 211 | args = torch.arange(phase.size(0)) 212 | amplitude = torch.clip(amplitude, 0.0, 1.0) 213 | for i in range(1, phase.size(0)): 214 | ax.plot( 215 | [args[i - 1].item(), args[i].item()], 216 | [phase[i - 1].item(), 
phase[i].item()], 217 | color=(0.0, 0.0, 0.0), 218 | alpha=amplitude[i].item() 219 | ) 220 | ax.set_ylim(0.0, 1.0) 221 | 222 | if title != None: 223 | ax.set_title(title) 224 | if show_axes == False: 225 | ax.axes.xaxis.set_visible(False) 226 | ax.axes.yaxis.set_visible(False) 227 | 228 | def plot_phase_2d(self, ax, phase, amplitude, title=None, show_axes=True): 229 | ax.cla() 230 | phase = phase.cpu() 231 | amplitude = amplitude.cpu() 232 | 233 | args = torch.arange(phase.size(0)) 234 | 235 | phase_x1 = amplitude * torch.sin(2 * torch.pi * phase) 236 | phase_x2 = amplitude * torch.cos(2 * torch.pi * phase) 237 | 238 | ax.plot(args, phase_x1) 239 | ax.plot(args, phase_x2) 240 | ax.set_ylim(-1.0, 1.0) 241 | 242 | if title != None: 243 | ax.set_title(title) 244 | if show_axes == False: 245 | ax.axes.xaxis.set_visible(False) 246 | ax.axes.yaxis.set_visible(False) -------------------------------------------------------------------------------- /learning/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of runners for environment-agent interaction.""" 5 | 6 | from .gld_on_policy_runner import FLDOnPolicyRunner 7 | 8 | __all__ = ["FLDOnPolicyRunner"] 9 | -------------------------------------------------------------------------------- /learning/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of task samplers.""" 2 | 3 | from .offline import OfflineSampler 4 | from .random import RandomSampler 5 | from .gmm import GMMSampler 6 | from .alp_gmm import ALPGMMSampler 7 | 8 | __all__ = ["OfflineSampler", "RandomSampler", "GMMSampler", "ALPGMMSampler"] 9 | -------------------------------------------------------------------------------- /learning/samplers/alp_gmm.py: -------------------------------------------------------------------------------- 1 | from learning.samplers.base import BaseSampler 2 | from learning.modules.gmm import GaussianMixture, GaussianMixtures 3 | import torch 4 | import torch.nn as nn 5 | import faiss 6 | 7 | 8 | class ALPGMMSampler(BaseSampler): 9 | def __init__(self, init_n_components, min_n_components, max_n_components, n_features, min, max, device, covariance_type="full", curriculum_scale=1.0, random_type="uniform"): 10 | super().__init__() 11 | self.min_n_components = min_n_components 12 | self.max_n_components = max_n_components 13 | self.n_features = n_features 14 | self.min = nn.Parameter(min.clone().detach(), requires_grad=False) 15 | self.max = nn.Parameter(max.clone().detach(), requires_grad=False) 16 | self.device = device 17 | self.curriculum_scale = curriculum_scale 18 | self.random_type = random_type 19 | self.gmms = GaussianMixtures(min_n_components, max_n_components, n_features + 1, device=device, covariance_type=covariance_type).eval() 20 | self.gmm_idx = nn.Parameter(torch.tensor(-1, device=device, dtype=torch.long), requires_grad=False) 21 | self.init_random_sampling = nn.Parameter(torch.tensor(True, device=device, dtype=torch.bool), requires_grad=False) 22 | self.task_performance_buffer = KNNBuffer(n_features, 1, device=device) 23 | if random_type == "gmm": 24 | self.gmm = GaussianMixture(init_n_components, n_features, device=device, covariance_type=covariance_type).eval() 25 | 26 | def load_gmm(self, load_path): 27 | self.gmm.load_state_dict(torch.load(load_path)["gmm_state_dict"]) 28 | 29 | def compute_alp(self, 
task, performance): 30 | old_performance = self.task_performance_buffer.get_k_nearest_neighbors(task, performance) 31 | alp = (performance - old_performance).abs() 32 | return alp 33 | 34 | 35 | def update(self, task, performance): 36 | alp = self.compute_alp(task, performance) 37 | self.task_performance_buffer.insert(task, performance) 38 | x = torch.cat((task, alp), dim=-1) 39 | self.gmms.fit(x) 40 | self.gmm_idx.data = self.gmms.get_best_gmm_idx(x) 41 | self.init_random_sampling.data = torch.tensor(False, device=self.device, dtype=torch.bool) 42 | 43 | 44 | def update_curriculum(self): 45 | if self.init_random_sampling: 46 | pass 47 | else: 48 | self.gmms.candidates[self.gmm_idx].set_variances(self.gmms.candidates[self.gmm_idx].var * self.curriculum_scale ** 2) 49 | 50 | 51 | def sample(self, n_samples, random_ratio=0.2): 52 | if self.init_random_sampling: 53 | return (self.max - self.min) * torch.rand(n_samples, self.n_features, device=self.device, dtype=torch.float, requires_grad=False) + self.min 54 | random = torch.rand(1, device=self.device, dtype=torch.float, requires_grad=False) < random_ratio 55 | if random: 56 | if self.random_type == "uniform": 57 | return (self.max - self.min) * torch.rand(n_samples, self.n_features, device=self.device, dtype=torch.float, requires_grad=False) + self.min 58 | elif self.random_type == "gmm": 59 | return self.gmm.sample(n_samples)[0] 60 | else: 61 | return self.gmms.sample(n_samples, self.gmm_idx)[:, :-1] 62 | 63 | 64 | def get_knn_buffer_size(self): 65 | return self.task_performance_buffer.key_index.ntotal 66 | 67 | 68 | def get_knn_buffer(self): 69 | keys, values = self.task_performance_buffer.get_samples(10000) 70 | return keys, values 71 | 72 | 73 | class KNNBuffer: 74 | def __init__(self, key_dim, value_dim, device): 75 | res = faiss.StandardGpuResources() 76 | key_index = faiss.IndexFlatL2(key_dim) 77 | self.key_index = faiss.index_cpu_to_gpu(res, 0, key_index) 78 | self.values = torch.zeros(0, value_dim, device=device, dtype=torch.float, requires_grad=False) 79 | self.key_dim = key_dim 80 | self.value_dim = value_dim 81 | self.device = device 82 | 83 | 84 | def insert(self, keys, values): 85 | self.key_index.add(keys.cpu().numpy()) 86 | self.values = torch.cat((self.values, values), dim=0) 87 | 88 | 89 | def get_k_nearest_neighbors(self, keys, values, k=1, average=True): 90 | if self.key_index.ntotal == 0: 91 | return torch.zeros((keys.size(0), self.value_dim), device=self.device, dtype=torch.float) 92 | _, k_nearest_neighbors_ids = self.key_index.search(keys.cpu().numpy(), k) 93 | k_nearest_neighbors_values = self.values[k_nearest_neighbors_ids, :].clone() 94 | # update the values of the k nearest neighbors 95 | self.values[k_nearest_neighbors_ids, :] = values.unsqueeze(1).repeat(1, k, 1) 96 | if average: 97 | return k_nearest_neighbors_values.mean(dim=1) 98 | else: 99 | return k_nearest_neighbors_values 100 | 101 | 102 | def get_samples(self, n_samples): 103 | if self.key_index.ntotal <= n_samples: 104 | return torch.tensor(self.key_index.reconstruct_n(0, self.key_index.ntotal), device=self.device, dtype=torch.float, requires_grad=False), self.values 105 | else: 106 | sample_ids = torch.randperm(self.key_index.ntotal, device=self.device, dtype=torch.long, requires_grad=False)[:n_samples] 107 | samples = torch.tensor(self.key_index.reconstruct_batch(sample_ids.cpu().numpy()), device=self.device, dtype=torch.float, requires_grad=False) 108 | return samples, self.values[sample_ids, :] 109 | 
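
The sampler above is easiest to see end to end. Below is a hypothetical curriculum loop, assuming a CUDA device and an installed `faiss-gpu` (the `KNNBuffer` allocates `faiss.StandardGpuResources()`); `evaluate_policy` is a dummy stand-in for a real rollout that scores each sampled task:

```python
import torch
from learning.samplers.alp_gmm import ALPGMMSampler

device = "cuda"
n_features = 4  # dimension of the sampled task, e.g. a latent parameterization

def evaluate_policy(tasks):
    # dummy stand-in: a real implementation would roll out the policy on each
    # task and return a per-task tracking performance in [0, 1], shape (n, 1)
    return torch.rand(tasks.size(0), 1, device=tasks.device)

sampler = ALPGMMSampler(
    init_n_components=8, min_n_components=2, max_n_components=10,
    n_features=n_features,
    min=torch.zeros(n_features, device=device),
    max=torch.ones(n_features, device=device),
    device=device,
)

for it in range(10):
    tasks = sampler.sample(256)          # uniform before the first update, then mostly GMM-guided
    performance = evaluate_policy(tasks)
    sampler.update(tasks, performance)   # ALP = |performance - nearest past performance|, refits the GMMs
```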
-------------------------------------------------------------------------------- /learning/samplers/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class BaseSampler(nn.Module): 4 | 5 | def __init__(self): 6 | super().__init__() 7 | 8 | 9 | def update(self): 10 | raise NotImplementedError 11 | 12 | 13 | def update_curriculum(self): 14 | raise NotImplementedError 15 | 16 | 17 | def sample(self): 18 | raise NotImplementedError -------------------------------------------------------------------------------- /learning/samplers/gmm.py: -------------------------------------------------------------------------------- 1 | from learning.samplers.base import BaseSampler 2 | from learning.modules.gmm import GaussianMixture 3 | import torch 4 | 5 | 6 | class GMMSampler(BaseSampler): 7 | def __init__(self, n_components, n_features, device, covariance_type="full", curriculum_scale=1.0): 8 | super().__init__() 9 | self.n_components = n_components 10 | self.n_features = n_features 11 | self.device = device 12 | self.curriculum_scale = curriculum_scale 13 | self.gmm = GaussianMixture(n_components, n_features, device=device, covariance_type=covariance_type).eval() 14 | 15 | 16 | def load_gmm(self, load_path): 17 | self.gmm.load_state_dict(torch.load(load_path)["gmm_state_dict"]) 18 | 19 | 20 | def update(self, x): 21 | self.gmm.fit(x) 22 | 23 | 24 | def update_curriculum(self): 25 | self.gmm.set_variances(self.gmm.var * self.curriculum_scale ** 2) 26 | 27 | 28 | def sample(self, n_samples): 29 | return self.gmm.sample(n_samples)[0] 30 | -------------------------------------------------------------------------------- /learning/samplers/offline.py: -------------------------------------------------------------------------------- 1 | from learning.samplers.base import BaseSampler 2 | import torch 3 | 4 | 5 | class OfflineSampler(BaseSampler): 6 | def __init__(self, device): 7 | super().__init__() 8 | self.device = device 9 | 10 | 11 | def load_data(self, load_path): 12 | self.data = torch.load(load_path) 13 | 14 | 15 | def update(self, x): 16 | pass 17 | 18 | 19 | def update_curriculum(self): 20 | pass 21 | 22 | 23 | def sample(self, n_samples): 24 | sample_ids = torch.randint(0, self.data.size(0), (n_samples,), device=self.device, dtype=torch.long, requires_grad=False) 25 | return self.data[sample_ids] 26 | -------------------------------------------------------------------------------- /learning/samplers/random.py: -------------------------------------------------------------------------------- 1 | from learning.samplers.base import BaseSampler 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class RandomSampler(BaseSampler): 7 | def __init__(self, n_features, min, max, device, curriculum_scale=1.0): 8 | super().__init__() 9 | self.n_features = n_features 10 | self.min = nn.Parameter(min.clone().detach(), requires_grad=False) 11 | self.max = nn.Parameter(max.clone().detach(), requires_grad=False) 12 | self.device = device 13 | self.curriculum_scale = curriculum_scale 14 | 15 | 16 | def update(self, x): 17 | pass 18 | 19 | 20 | def update_curriculum(self): 21 | mean = (self.min + self.max) / 2 22 | std = (self.max - self.min) / 2 23 | std *= self.curriculum_scale 24 | self.max.data = mean + std 25 | self.min.data = mean - std 26 | 27 | 28 | def sample(self, n_samples): 29 | return (self.max - self.min) * torch.rand(n_samples, self.n_features, device=self.device, dtype=torch.float, requires_grad=False) + self.min 30 | 
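
All four samplers expose the same `BaseSampler` surface (`update`, `update_curriculum`, `sample`), so a runner can swap them by configuration. A small sketch with `RandomSampler`, using illustrative bounds:

```python
import torch
from learning.samplers.random import RandomSampler

sampler = RandomSampler(
    n_features=3,
    min=torch.tensor([-1.0, -1.0, 0.0]),
    max=torch.tensor([1.0, 1.0, 2.0]),
    device="cpu",
    curriculum_scale=0.9,       # < 1 contracts the box around its midpoint
)
tasks = sampler.sample(16)      # uniform in [min, max], shape (16, 3)
sampler.update_curriculum()     # bounds shrink to 90% of their half-range
```

`OfflineSampler`, by contrast, replays random rows of a tensor loaded from disk via `load_data`, and `GMMSampler` draws from a single fitted mixture whose variances the curriculum rescales.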
-------------------------------------------------------------------------------- /learning/storage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Implementation of transitions storage for RL-agent.""" 5 | 6 | from .rollout_storage import RolloutStorage 7 | 8 | __all__ = ["RolloutStorage"] 9 | -------------------------------------------------------------------------------- /learning/storage/distribution_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DistributionBuffer: 4 | def __init__(self, buffer_dim, buffer_size, device) -> None: 5 | self.distribution_buffer = torch.zeros(buffer_size, buffer_dim, dtype=torch.float, requires_grad=False).to(device) 6 | self.buffer_size = buffer_size 7 | self.device = device 8 | self.step = 0 9 | self.num_samples = 0 10 | 11 | def insert(self, data): 12 | num_data = data.shape[0] 13 | start_idx = self.step 14 | end_idx = self.step + num_data 15 | if end_idx > self.buffer_size: 16 | self.distribution_buffer[self.step:self.buffer_size] = data[:self.buffer_size - self.step] 17 | self.distribution_buffer[:end_idx - self.buffer_size] = data[self.buffer_size - self.step:] 18 | else: 19 | self.distribution_buffer[start_idx:end_idx] = data 20 | 21 | self.num_samples = min(self.buffer_size, max(end_idx, self.num_samples)) 22 | self.step = (self.step + num_data) % self.buffer_size 23 | 24 | def get_distribution(self): 25 | return self.distribution_buffer 26 | -------------------------------------------------------------------------------- /learning/storage/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ReplayBuffer: 6 | """Fixed-size buffer to store experience tuples.""" 7 | 8 | def __init__(self, obs_dim, obs_horizon, buffer_size, device): 9 | """Initialize a ReplayBuffer object. 
10 | Arguments: 11 | buffer_size (int): maximum size of buffer 12 | """ 13 | self.state_buf = torch.zeros(buffer_size, obs_horizon, obs_dim).to(device) 14 | self.buffer_size = buffer_size 15 | self.device = device 16 | 17 | self.step = 0 18 | self.num_samples = 0 19 | 20 | def insert(self, state_buf): 21 | """Add new states to memory.""" 22 | 23 | num_states = state_buf.shape[0] 24 | start_idx = self.step 25 | end_idx = self.step + num_states 26 | if end_idx > self.buffer_size: 27 | self.state_buf[self.step:self.buffer_size] = state_buf[:self.buffer_size - self.step] 28 | self.state_buf[:end_idx - self.buffer_size] = state_buf[self.buffer_size - self.step:] 29 | else: 30 | self.state_buf[start_idx:end_idx] = state_buf 31 | 32 | self.num_samples = min(self.buffer_size, max(end_idx, self.num_samples)) 33 | self.step = (self.step + num_states) % self.buffer_size 34 | 35 | def feed_forward_generator(self, num_mini_batch, mini_batch_size): 36 | for _ in range(num_mini_batch): 37 | sample_idxs = np.random.choice(self.num_samples, size=mini_batch_size) 38 | yield self.state_buf[sample_idxs, :].to(self.device) 39 | -------------------------------------------------------------------------------- /learning/storage/rollout_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ETH Zurich, NVIDIA CORPORATION 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | # torch 5 | import torch 6 | 7 | # learning 8 | from learning.utils import split_and_pad_trajectories 9 | 10 | 11 | class RolloutStorage: 12 | class Transition: 13 | def __init__(self): 14 | self.observations = None 15 | self.critic_observations = None 16 | self.actions = None 17 | self.rewards = None 18 | self.dones = None 19 | self.values = None 20 | self.actions_log_prob = None 21 | self.action_mean = None 22 | self.action_sigma = None 23 | self.hidden_states = None 24 | 25 | def clear(self): 26 | self.__init__() 27 | 28 | def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"): 29 | 30 | self.device = device 31 | 32 | self.obs_shape = obs_shape 33 | self.privileged_obs_shape = privileged_obs_shape 34 | self.actions_shape = actions_shape 35 | 36 | # Core 37 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) 38 | if privileged_obs_shape[0] is not None: 39 | self.privileged_observations = torch.zeros( 40 | num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device 41 | ) 42 | else: 43 | self.privileged_observations = None 44 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 45 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 46 | self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() 47 | 48 | # For PPO 49 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 50 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 51 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 52 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 53 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 54 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 55 | 56 | self.num_transitions_per_env = num_transitions_per_env 57 | 
self.num_envs = num_envs
58 | 
59 |         # rnn
60 |         self.saved_hidden_states_a = None
61 |         self.saved_hidden_states_c = None
62 | 
63 |         self.step = 0
64 | 
65 |     def add_transitions(self, transition: Transition):
66 |         if self.step >= self.num_transitions_per_env:
67 |             raise AssertionError("Rollout buffer overflow")
68 |         self.observations[self.step].copy_(transition.observations)
69 |         if self.privileged_observations is not None:
70 |             self.privileged_observations[self.step].copy_(transition.critic_observations)
71 |         self.actions[self.step].copy_(transition.actions)
72 |         self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
73 |         self.dones[self.step].copy_(transition.dones.view(-1, 1))
74 |         self.values[self.step].copy_(transition.values)
75 |         self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
76 |         self.mu[self.step].copy_(transition.action_mean)
77 |         self.sigma[self.step].copy_(transition.action_sigma)
78 |         self._save_hidden_states(transition.hidden_states)
79 |         self.step += 1
80 | 
81 |     def _save_hidden_states(self, hidden_states):
82 |         if hidden_states is None or hidden_states == (None, None):
83 |             return
84 |         # make a tuple out of GRU hidden states to match the LSTM format
85 |         hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
86 |         hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
87 | 
88 |         # initialize if needed
89 |         if self.saved_hidden_states_a is None:
90 |             self.saved_hidden_states_a = [
91 |                 torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
92 |             ]
93 |             self.saved_hidden_states_c = [
94 |                 torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
95 |             ]
96 |         # copy the states
97 |         for i in range(len(hid_a)):
98 |             self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
99 |             self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
100 | 
101 |     def clear(self):
102 |         self.step = 0
103 | 
104 |     def compute_returns(self, last_values, gamma, lam):
105 |         advantage = 0
106 |         for step in reversed(range(self.num_transitions_per_env)):
107 |             if step == self.num_transitions_per_env - 1:
108 |                 next_values = last_values
109 |             else:
110 |                 next_values = self.values[step + 1]
111 |             next_is_not_terminal = 1.0 - self.dones[step].float()
112 |             delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
113 |             advantage = delta + next_is_not_terminal * gamma * lam * advantage
114 |             self.returns[step] = advantage + self.values[step]
115 | 
116 |         # Compute and normalize the advantages
117 |         self.advantages = self.returns - self.values
118 |         self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
119 | 
120 |     def get_statistics(self):
121 |         done = self.dones
122 |         done[-1] = 1
123 |         flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
124 |         done_indices = torch.cat(
125 |             (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
126 |         )
127 |         trajectory_lengths = done_indices[1:] - done_indices[:-1]
128 |         return trajectory_lengths.float().mean(), self.rewards.mean()
129 | 
130 |     def mini_batch_generator(self, num_mini_batches, num_epochs=8):
131 |         batch_size = self.num_envs * self.num_transitions_per_env
132 |         mini_batch_size = batch_size // num_mini_batches
133 |         indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)
134 | 
135 |         observations = 
self.observations.flatten(0, 1) 136 | if self.privileged_observations is not None: 137 | critic_observations = self.privileged_observations.flatten(0, 1) 138 | else: 139 | critic_observations = observations 140 | 141 | actions = self.actions.flatten(0, 1) 142 | values = self.values.flatten(0, 1) 143 | returns = self.returns.flatten(0, 1) 144 | old_actions_log_prob = self.actions_log_prob.flatten(0, 1) 145 | advantages = self.advantages.flatten(0, 1) 146 | old_mu = self.mu.flatten(0, 1) 147 | old_sigma = self.sigma.flatten(0, 1) 148 | 149 | for epoch in range(num_epochs): 150 | for i in range(num_mini_batches): 151 | 152 | start = i * mini_batch_size 153 | end = (i + 1) * mini_batch_size 154 | batch_idx = indices[start:end] 155 | 156 | obs_batch = observations[batch_idx] 157 | critic_observations_batch = critic_observations[batch_idx] 158 | actions_batch = actions[batch_idx] 159 | target_values_batch = values[batch_idx] 160 | returns_batch = returns[batch_idx] 161 | old_actions_log_prob_batch = old_actions_log_prob[batch_idx] 162 | advantages_batch = advantages[batch_idx] 163 | old_mu_batch = old_mu[batch_idx] 164 | old_sigma_batch = old_sigma[batch_idx] 165 | yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( 166 | None, 167 | None, 168 | ), None 169 | 170 | # for RNNs only 171 | def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8): 172 | 173 | padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) 174 | if self.privileged_observations is not None: 175 | padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) 176 | else: 177 | padded_critic_obs_trajectories = padded_obs_trajectories 178 | 179 | mini_batch_size = self.num_envs // num_mini_batches 180 | for ep in range(num_epochs): 181 | first_traj = 0 182 | for i in range(num_mini_batches): 183 | start = i * mini_batch_size 184 | stop = (i + 1) * mini_batch_size 185 | 186 | dones = self.dones.squeeze(-1) 187 | last_was_done = torch.zeros_like(dones, dtype=torch.bool) 188 | last_was_done[1:] = dones[:-1] 189 | last_was_done[0] = True 190 | trajectories_batch_size = torch.sum(last_was_done[:, start:stop]) 191 | last_traj = first_traj + trajectories_batch_size 192 | 193 | masks_batch = trajectory_masks[:, first_traj:last_traj] 194 | obs_batch = padded_obs_trajectories[:, first_traj:last_traj] 195 | critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj] 196 | 197 | actions_batch = self.actions[:, start:stop] 198 | old_mu_batch = self.mu[:, start:stop] 199 | old_sigma_batch = self.sigma[:, start:stop] 200 | returns_batch = self.returns[:, start:stop] 201 | advantages_batch = self.advantages[:, start:stop] 202 | values_batch = self.values[:, start:stop] 203 | old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] 204 | 205 | # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim]) 206 | # then take only time steps after dones (flattens num envs and time dimensions), 207 | # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim] 208 | last_was_done = last_was_done.permute(1, 0) 209 | hid_a_batch = [ 210 | saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] 211 | .transpose(1, 0) 212 | .contiguous() 213 | for saved_hidden_states in self.saved_hidden_states_a 214 | ] 215 
|                 hid_c_batch = [
216 |                     saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
217 |                     .transpose(1, 0)
218 |                     .contiguous()
219 |                     for saved_hidden_states in self.saved_hidden_states_c
220 |                 ]
221 |                 # remove the tuple for GRU
222 |                 hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
223 |                 hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
224 | 
225 |                 yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
226 |                     hid_a_batch,
227 |                     hid_c_batch,
228 |                 ), masks_batch
229 | 
230 |                 first_traj = last_traj
231 | 
--------------------------------------------------------------------------------
/learning/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Helper functions."""
2 | 
3 | from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state, get_skew_matrix, get_base_ang_vel_from_base_quat, get_base_quat_from_base_ang_vel
4 | 
--------------------------------------------------------------------------------
/learning/utils/utils.py:
--------------------------------------------------------------------------------
1 | # python
2 | import os
3 | import git
4 | import pathlib
5 | 
6 | # torch
7 | import torch
8 | 
9 | 
10 | def split_and_pad_trajectories(tensor, dones):
11 |     """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest trajectory.
12 |     Returns masks corresponding to valid parts of the trajectories
13 |     Example: 
14 |         Input: [ [a1, a2, a3, a4 | a5, a6],
15 |                  [b1, b2 | b3, b4, b5 | b6]
16 |                ]
17 | 
18 |         Output:[ [a1, a2, a3, a4], | [ [True, True, True, True],
19 |                  [a5, a6, 0, 0],   |   [True, True, False, False],
20 |                  [b1, b2, 0, 0],   |   [True, True, False, False],
21 |                  [b3, b4, b5, 0],  |   [True, True, True, False],
22 |                  [b6, 0, 0, 0]     |   [True, False, False, False],
23 |                ]                   | ]
24 | 
25 |     Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
26 |     """
27 |     dones = dones.clone()
28 |     dones[-1] = 1
29 |     # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
30 |     flat_dones = dones.transpose(1, 0).reshape(-1, 1)
31 | 
32 |     # Get length of trajectory by counting the number of successive not done elements
33 |     done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
34 |     trajectory_lengths = done_indices[1:] - done_indices[:-1]
35 |     trajectory_lengths_list = trajectory_lengths.tolist()
36 |     # Extract the individual trajectories
37 |     trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
38 |     padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
39 | 
40 |     trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
41 |     return padded_trajectories, trajectory_masks
42 | 
43 | 
44 | def unpad_trajectories(trajectories, masks):
45 |     """Does the inverse operation of split_and_pad_trajectories()"""
46 |     # Need to transpose before and after the masking to have proper reshaping
47 |     return (
48 |         trajectories.transpose(1, 0)[masks.transpose(1, 0)]
49 |         .view(-1, trajectories.shape[0], trajectories.shape[-1])
50 |         .transpose(1, 0)
51 |     )
52 | 
53 | 
54 | def store_code_state(logdir, repositories):
55 |     for repository_file_path in repositories:
56 |         repo = git.Repo(repository_file_path, 
search_parent_directories=True) 57 | repo_name = pathlib.Path(repo.working_dir).name 58 | t = repo.head.commit.tree 59 | content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" 60 | with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x") as f: 61 | f.write(content) 62 | 63 | 64 | def get_skew_matrix(vec): 65 | """Get the skew matrix of a vector.""" 66 | matrix = torch.zeros(vec.shape[0], 3, 3, dtype=torch.float, requires_grad=False) 67 | matrix[:, 0, 1] = -vec[:, 2] 68 | matrix[:, 0, 2] = vec[:, 1] 69 | matrix[:, 1, 0] = vec[:, 2] 70 | matrix[:, 1, 2] = -vec[:, 0] 71 | matrix[:, 2, 0] = -vec[:, 1] 72 | matrix[:, 2, 1] = vec[:, 0] 73 | return matrix 74 | 75 | 76 | def get_base_ang_vel_from_base_quat(base_quat, dt, target_frame="local"): 77 | """ 78 | Get the base angular velocity from the base quaternion. 79 | args: 80 | base_quat: torch.Tensor (num_trajs, num_steps, 4) 81 | dt: float 82 | returns: 83 | base_ang_vel: torch.Tensor (num_trajs, num_steps, 3) expressed in the target frame 84 | """ 85 | num_trajs, num_steps, _ = base_quat.size() 86 | device = base_quat.device 87 | mapping = torch.zeros(num_trajs, num_steps, 3, 4, device=device, dtype=torch.float, requires_grad=False) 88 | mapping[:, :, :, -1] = -base_quat[:, :, :-1] 89 | if target_frame == "local": 90 | mapping[:, :, :, :-1] = get_skew_matrix(-base_quat[:, :, :-1].flatten(0, 1)).view(num_trajs, num_steps, 3, 3) 91 | elif target_frame == "global": 92 | mapping[:, :, :, :-1] = get_skew_matrix(base_quat[:, :, :-1].flatten(0, 1)).view(num_trajs, num_steps, 3, 3) 93 | else: 94 | raise ValueError(f"Unknown target frame {target_frame}") 95 | mapping[:, :, :, :-1] += torch.eye(3, device=device, dtype=torch.float, requires_grad=False).repeat(num_trajs, num_steps, 1, 1) * base_quat[:, :, -1].unsqueeze(-1).unsqueeze(-1) 96 | base_ang_vel = 2 * mapping[:, :-1, :, :] @ ((base_quat[:, 1:, :] - base_quat[:, :-1, :]) / dt).unsqueeze(-1) 97 | base_ang_vel = torch.cat((base_ang_vel[:, [0]], base_ang_vel), dim=1).squeeze(-1) 98 | return base_ang_vel 99 | 100 | 101 | def get_base_quat_from_base_ang_vel(base_ang_vel, dt, source_frame="local", init_base_quat=None): 102 | """ 103 | Get the base quaternion from the base angular velocity. 
104 | args: 105 | base_ang_vel: torch.Tensor (num_trajs, num_steps, 3) expressed in the source frame 106 | dt: float 107 | returns: 108 | base_quat: torch.Tensor (num_trajs, num_steps, 4) 109 | """ 110 | num_trajs, num_steps, _ = base_ang_vel.size() 111 | device = base_ang_vel.device 112 | if init_base_quat is None: 113 | init_base_quat = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device, dtype=torch.float, requires_grad=False).repeat(num_trajs, 1) 114 | base_quat = torch.zeros(num_trajs, num_steps, 4, device=device, dtype=torch.float, requires_grad=False) 115 | base_quat[:, 0, :] = init_base_quat 116 | for step in range(num_steps - 1): 117 | base_quat_step = base_quat[:, step, :] 118 | mapping = torch.zeros(num_trajs, 3, 4, device=device, dtype=torch.float, requires_grad=False) 119 | mapping[:, :, -1] = -base_quat_step[:, :-1] 120 | if source_frame == "local": 121 | mapping[:, :, :-1] = get_skew_matrix(-base_quat_step[:, :-1]) 122 | elif source_frame == "global": 123 | mapping[:, :, :-1] = get_skew_matrix(base_quat_step[:, :-1]) 124 | else: 125 | raise ValueError(f"Unknown source frame {source_frame}") 126 | mapping[:, :, :-1] += torch.eye(3, device=device, dtype=torch.float, requires_grad=False).repeat(num_trajs, 1, 1) * base_quat_step[:, -1].unsqueeze(-1).unsqueeze(-1) 127 | base_ang_vel_step = base_ang_vel[:, step, :].unsqueeze(-1) 128 | quat_change_est = (0.5 * dt * mapping.transpose(-2, -1) @ base_ang_vel_step).squeeze(-1) 129 | base_quat[:, step + 1, :] = quat_change_est + base_quat_step 130 | return base_quat 131 | -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/anomalous/motion_data_crossover.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/anomalous/motion_data_crossover.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/anomalous/motion_data_jump.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/anomalous/motion_data_jump.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/anomalous/motion_data_kick.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/anomalous/motion_data_kick.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/anomalous/motion_data_spinkick.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/anomalous/motion_data_spinkick.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/decoded/reference_state_idx_dict.json: -------------------------------------------------------------------------------- 1 | {"base_pos": [0, 3], "base_quat": [3, 7], "base_lin_vel": [7, 10], "base_ang_vel": [10, 13], "projected_gravity": [13, 16], "dof_pos": [16, 34], "dof_vel": 
[34, 52]} -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_back.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_back.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_jog.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_jog.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_jog_slow.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_jog_slow.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_run.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_run.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_side_left.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_side_left.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_side_right.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_side_right.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_step.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_step.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_step_fast.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_step_fast.pt -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/motion_data_stride.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/datasets/misc/motion_data_stride.pt 
-------------------------------------------------------------------------------- /resources/robots/mit_humanoid/datasets/misc/reference_state_idx_dict.json: -------------------------------------------------------------------------------- 1 | {"base_pos": [0, 3], "base_quat": [3, 7], "base_lin_vel": [7, 10], "base_ang_vel": [10, 13], "projected_gravity": [13, 16], "dof_pos": [16, 34], "dof_vel": [34, 52]} -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/back.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/back.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_foot.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_foot.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_forearm.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_hip_abad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_hip_abad.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_hip_yaw.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_hip_yaw.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_leg_lower.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_leg_lower.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_leg_upper.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_leg_upper.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_shoulder1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_shoulder1.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_shoulder2.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_shoulder2.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/left_shoulder3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/left_shoulder3.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/meshes/torso.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-biomimetics/fld/a2b04c51492042bf478a99226942d741af5b7e17/resources/robots/mit_humanoid/meshes/torso.stl -------------------------------------------------------------------------------- /resources/robots/mit_humanoid/urdf/README.md: -------------------------------------------------------------------------------- 1 | # MIT Humanoid URDFs 2 | 3 | ### Naming convention: 4 | F: Full collision boxes (each link gets a collision box) 5 | R: Reduced collision boxes (only feet, body, and hands get a collision box) 6 | sf: single-foot (one collision box for the entire foot) 7 | ht: heel-toe (two collision boxes for each foot) 8 | 9 | ### Notes 10 | - colors don't seem to get rendered 11 | - visualization is _a lot_ faster with just the collision boxes. -------------------------------------------------------------------------------- /scripts/fld/experiment.py: -------------------------------------------------------------------------------- 1 | from humanoid_gym import LEGGED_GYM_ROOT_DIR 2 | from scripts.fld.training import FLDTraining 3 | import os 4 | import torch 5 | 6 | 7 | class FLDExperiment: 8 | """ 9 | Represents an experiment for FLD (Fourier Latent Dynamics). 10 | 11 | Args: 12 | state_idx_dict (dict): A dictionary mapping state names to their corresponding indices. 13 | history_horizon (int): The length of the input observation window. 14 | forecast_horizon (int): The number of consecutive future steps to predict while maintaining the quasi-constant latent parameterization. 15 | device (str): The device to use for computation. 16 | 17 | """ 18 | 19 | def __init__(self, state_idx_dict, history_horizon, forecast_horizon, device): 20 | self.state_idx_dict = state_idx_dict 21 | self.history_horizon = history_horizon 22 | self.forecast_horizon = forecast_horizon 23 | self.dim_of_interest = torch.cat( 24 | [ 25 | torch.tensor(ids, device=device, dtype=torch.long, requires_grad=False) 26 | for state, ids in state_idx_dict.items() 27 | if ((state != "base_pos") and (state != "base_quat")) 28 | ] 29 | ) 30 | self.device = device 31 | 32 | def prepare_data(self): 33 | """ 34 | Loads and prepares the motion data. 35 | 36 | This method loads the motion data from the specified directory, normalizes it, 37 | and calculates the mean and standard deviation of the state transitions data. 
38 | 39 | """ 40 | datasets_root = os.path.join(LEGGED_GYM_ROOT_DIR + "/resources/robots/mit_humanoid/datasets/misc") 41 | motion_data = os.listdir(datasets_root) 42 | motion_name_set = [data.replace('motion_data_', '').replace('.pt', '') for data in motion_data if "combined" not in data and ".pt" in data] 43 | motion_data_collection = [] 44 | 45 | for i, motion_name in enumerate(motion_name_set): 46 | motion_path = os.path.join(datasets_root, "motion_data_" + motion_name + ".pt") 47 | motion_data = torch.load(motion_path, map_location=self.device)[:, :, self.dim_of_interest] # (num_trajs, traj_len, obs_dim) 48 | loaded_num_trajs, loaded_num_steps, loaded_obs_dim = motion_data.size() 49 | print(f"[Motion Loader] Loaded motion {motion_name} with {loaded_num_trajs} trajectories, {loaded_num_steps} steps with {loaded_obs_dim} dimensions.") 50 | motion_data_collection.append(motion_data.unsqueeze(0)) 51 | 52 | motion_data_collection = torch.cat(motion_data_collection, dim=0) # (num_motions, num_trajs, traj_len, obs_dim) 53 | self.state_transitions_mean = motion_data_collection.flatten(0, 2).mean(dim=0) 54 | self.state_transitions_std = motion_data_collection.flatten(0, 2).std(dim=0) + 1e-6 55 | 56 | # Unfold the data to prepare for training 57 | # num_steps denotes the trajectory length induced by bootstrapping the window of history_horizon forward with forecast_horizon steps 58 | # num_groups denotes the number of such num_steps 59 | motion_data_collection = motion_data_collection.unfold(2, self.history_horizon + self.forecast_horizon - 1, 1).swapaxes(-2, -1) # (num_motions, num_trajs, num_groups, num_steps, obs_dim) 60 | self.state_transitions_data = (motion_data_collection - self.state_transitions_mean) / self.state_transitions_std # (num_motions, num_trajs, num_groups, num_steps, obs_dim) 61 | 62 | def train(self, log_dir, latent_dim): 63 | """ 64 | Trains the FLD model. 65 | 66 | Args: 67 | log_dir (str): The directory to save the training logs. 68 | latent_dim (int): The dimensionality of the latent space. 
69 | 70 | """ 71 | fld_training = FLDTraining( 72 | log_dir, 73 | latent_dim, 74 | self.history_horizon, 75 | self.forecast_horizon, 76 | self.state_idx_dict, 77 | self.state_transitions_data, 78 | self.state_transitions_mean, 79 | self.state_transitions_std, 80 | fld_encoder_shape=[64, 64], 81 | fld_decoder_shape=[64, 64], 82 | fld_learning_rate=0.0001, 83 | fld_weight_decay=0.0005, 84 | fld_num_mini_batches=10, 85 | device="cuda", 86 | loss_function="geometric", 87 | noise_level=0.1, 88 | ) 89 | fld_training.train(max_iterations=5000) 90 | fld_training.fit_gmm(covariance_type="full") 91 | 92 | 93 | if __name__ == "__main__": 94 | state_idx_dict = { 95 | "base_pos": [0, 1, 2], 96 | "base_quat": [3, 4, 5, 6], 97 | "base_lin_vel": [7, 8, 9], 98 | "base_ang_vel": [10, 11, 12], 99 | "projected_gravity": [13, 14, 15], 100 | "dof_pos_leg_l": [16, 17, 18, 19, 20], 101 | "dof_pos_arm_l": [21, 22, 23, 24], 102 | "dof_pos_leg_r": [25, 26, 27, 28, 29], 103 | "dof_pos_arm_r": [30, 31, 32, 33], 104 | } 105 | history_horizon = 51 # the window size of the input state transitions 106 | latent_dim = 8 107 | forecast_horizon = 50 # the autoregressive prediction steps while obeying the quasi-constant latent parameterization 108 | device = "cuda" 109 | log_dir_root = LEGGED_GYM_ROOT_DIR + "/logs/flat_mit_humanoid/fld/" 110 | log_dir = log_dir_root + "misc" 111 | fld_experiment = FLDExperiment(state_idx_dict, history_horizon, forecast_horizon, device) 112 | fld_experiment.prepare_data() 113 | fld_experiment.train(log_dir, latent_dim) 114 | -------------------------------------------------------------------------------- /scripts/fld/preview.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plays a trained policy and logs statistics. 
3 | """ 4 | 5 | # humanoid-gym 6 | from humanoid_gym import LEGGED_GYM_ROOT_DIR 7 | from humanoid_gym.envs import task_registry 8 | from humanoid_gym.utils import get_args, export_policy_as_jit, export_policy_as_onnx, Logger 9 | from humanoid_gym.utils.keyboard_controller import KeyBoardController, KeyboardAction, Delta, Switch 10 | 11 | # isaac-gym 12 | from isaacgym import gymtorch 13 | from isaacgym.torch_utils import ( 14 | quat_rotate, 15 | ) 16 | 17 | # learning 18 | from learning.datasets.motion_loader import MotionLoader 19 | 20 | # python 21 | import argparse 22 | import os 23 | import numpy as np 24 | import json 25 | import torch 26 | 27 | # global settings 28 | EXPORT_POLICY = True 29 | RECORD_FRAMES = False 30 | MOVE_CAMERA = False 31 | PLAY_LOADED_DATA = True # True: play loaded data, False: play with FLD latent parameters 32 | PLOT = True 33 | 34 | from learning.modules.fld import FLD 35 | from learning.modules.plotter import Plotter 36 | import matplotlib.pyplot as plt 37 | 38 | 39 | def preview(args: argparse.Namespace): 40 | args.task = "mit_humanoid" 41 | env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) 42 | # override some parameters for testing 43 | env_cfg.env.num_envs = min(env_cfg.env.num_envs, 2) 44 | env_cfg.terrain.num_rows = 5 45 | env_cfg.terrain.num_cols = 5 46 | env_cfg.terrain.curriculum = False 47 | env_cfg.noise.add_noise = False 48 | env_cfg.domain_rand.randomize_friction = False 49 | env_cfg.domain_rand.push_robots = False 50 | 51 | env_cfg.commands.resampling_time = 1000 52 | env_cfg.commands.ranges.lin_vel_x = [0.0, 0.0] 53 | env_cfg.commands.ranges.lin_vel_y = [0.0, 0.0] 54 | env_cfg.commands.ranges.ang_vel_yaw = [0.0, 0.0] 55 | env_cfg.domain_rand.added_mass_range = [0.0, 0.0] 56 | 57 | # prepare environment 58 | env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg) 59 | obs = env.get_observations() 60 | 61 | def _zero_torques(self, actions): 62 | return torch.zeros_like(actions) 63 | 64 | env._compute_torques = type(env._compute_torques)(_zero_torques, env) 65 | 66 | if PLAY_LOADED_DATA: 67 | motion_file = LEGGED_GYM_ROOT_DIR + "/resources/robots/mit_humanoid/datasets/misc/motion_data_run.pt" 68 | motion_loader = MotionLoader( 69 | device=env.device, 70 | motion_file=motion_file, 71 | reference_history_horizon=env.fld_history_horizon, 72 | ) 73 | else: 74 | state_idx_dict = { 75 | "base_pos": [0, 1, 2], 76 | "base_quat": [3, 4, 5, 6], 77 | "base_lin_vel": [7, 8, 9], 78 | "base_ang_vel": [10, 11, 12], 79 | "projected_gravity": [13, 14, 15], 80 | "dof_pos_leg_l": [16, 17, 18, 19, 20], 81 | "dof_pos_arm_l": [21, 22, 23, 24], 82 | "dof_pos_leg_r": [25, 26, 27, 28, 29], 83 | "dof_pos_arm_r": [30, 31, 32, 33], 84 | } 85 | 86 | dim_of_interest = torch.cat([torch.tensor(ids, device=env.device, dtype=torch.long, requires_grad=False) for state, ids in state_idx_dict.items() if ((state != "base_pos") and (state != "base_quat"))]) 87 | observation_dim = dim_of_interest.size(0) 88 | history_horizon = 51 89 | log_dir_root = LEGGED_GYM_ROOT_DIR + "/logs/flat_mit_humanoid/fld/" 90 | latent_dim = 8 91 | 92 | fld = FLD(observation_dim, history_horizon, latent_dim, env.device, encoder_shape=env_cfg.fld.encoder_shape, decoder_shape=env_cfg.fld.decoder_shape) 93 | 94 | runs = os.listdir(log_dir_root) 95 | runs.sort() 96 | last_run = os.path.join(log_dir_root, runs[-1]) 97 | load_run = last_run 98 | models = [file for file in os.listdir(load_run) if "model" in file] 99 | models.sort(key=lambda m: "{0:0>15}".format(m)) 100 | model = 
models[-1] 101 | 102 | loaded_dict = torch.load(os.path.join(load_run, model)) 103 | fld.load_state_dict(loaded_dict["fld_state_dict"]) 104 | 105 | datasets_root = os.path.join(LEGGED_GYM_ROOT_DIR + "/resources/robots/mit_humanoid/datasets/misc") 106 | motion_data = os.listdir(datasets_root) 107 | motion_name_set = [data.replace('motion_data_', '').replace('.pt', '') for data in motion_data if "combined" not in data and ".pt" in data] 108 | aggregated_data_collection = [] 109 | 110 | if PLOT: 111 | plotter = Plotter() 112 | plt.ion() 113 | fig1, ax1 = plt.subplots(4, 1) 114 | 115 | for i, motion_name in enumerate(motion_name_set): 116 | motion_path = os.path.join(datasets_root, "motion_data_" + motion_name + ".pt") 117 | motion_data = torch.load(motion_path, map_location=env.device)[:, :, dim_of_interest] 118 | loaded_num_trajs, loaded_num_steps, loaded_obs_dim = motion_data.size() 119 | print(f"[Motion Loader] Loaded motion {motion_name} with {loaded_num_trajs} trajectories, {loaded_num_steps} steps with {loaded_obs_dim} dimensions.") 120 | aggregated_data = torch.zeros(loaded_num_trajs, 121 | loaded_num_steps - history_horizon + 1, 122 | history_horizon, 123 | loaded_obs_dim, 124 | dtype=torch.float, 125 | device=env.device, 126 | requires_grad=False 127 | ) 128 | for step in range(loaded_num_steps - history_horizon + 1): 129 | aggregated_data[:, step] = motion_data[:, step:step+history_horizon, :] 130 | aggregated_data_collection.append(aggregated_data.unsqueeze(0)) 131 | 132 | aggregated_data_collection = torch.cat(aggregated_data_collection, dim=0) 133 | 134 | state_transitions_mean = aggregated_data_collection.flatten(0, 3).mean(dim=0) 135 | state_transitions_std = aggregated_data_collection.flatten(0, 3).std(dim=0) + 1e-6 # epsilon guards against zero std on constant dimensions, matching experiment.py 136 | state_transitions_data = (aggregated_data_collection - state_transitions_mean) / state_transitions_std 137 | 138 | fld.eval() 139 | 140 | eval_traj = state_transitions_data[0, 0].swapaxes(1, 2) 141 | with torch.no_grad(): 142 | _, _, _, params = fld(eval_traj) 143 | env.latent_sample_phase = params[0][0, :] 144 | env.latent_sample_frequency = params[1][0, :] 145 | env.latent_sample_amplitude = params[2][0, :] 146 | env.latent_sample_offset = params[3][0, :] 147 | 148 | 149 | latent_variable_list = ['phase', 'frequency', 'amplitude', 'offset'] 150 | env.latent_variable_selector = torch.tensor([-1], device=env.device, dtype=torch.long, requires_grad=False) 151 | env.latent_channel_selector = torch.tensor([-1], device=env.device, dtype=torch.long, requires_grad=False) 152 | env.latent_value_modifier = torch.tensor([0.0], device=env.device, dtype=torch.float, requires_grad=False) 153 | 154 | def print_selector(): 155 | print(f"latent_variable_selector: {env.latent_variable_selector}") 156 | print(f"latent_channel_selector: {env.latent_channel_selector}") 157 | 158 | def print_command(): 159 | latent_variable = latent_variable_list[env.latent_variable_selector] 160 | latent_variable_name = "latent_sample_" + str(latent_variable) 161 | print(f"{latent_variable}: {getattr(env, latent_variable_name)}") 162 | 163 | key_board_events = { 164 | "p": Switch("latent_variable", start_state=0, toggle_state=0, variable_reference=env.latent_variable_selector, callback=print_selector), 165 | "f": Switch("latent_variable", start_state=1, toggle_state=1, variable_reference=env.latent_variable_selector, callback=print_selector), 166 | "a": Switch("latent_variable", start_state=2, toggle_state=2, variable_reference=env.latent_variable_selector, callback=print_selector), 167 | "o": 
Switch("latent_variable", start_state=3, toggle_state=3, variable_reference=env.latent_variable_selector, callback=print_selector), 168 | "u": Delta("latent_value_modifier", amount=0.1, variable_reference=env.latent_value_modifier, callback=print_command), 169 | "j": Delta("latent_value_modifier", amount=-0.1, variable_reference=env.latent_value_modifier, callback=print_command), 170 | } 171 | for i in range(latent_dim): 172 | key_board_events[f"{i}"] = Switch("latent_channel", start_state=i, toggle_state=i, variable_reference=env.latent_channel_selector, callback=print_selector) 173 | 174 | env.keyboard_controller = KeyBoardController(env, key_board_events) 175 | env.keyboard_controller.print_options() 176 | 177 | 178 | for i in range(10 * int(env.max_episode_length)): 179 | env.update_keyboard_events() 180 | # actions = policy(obs.detach()) 181 | actions = torch.zeros_like(env.actions) 182 | # obs, _, rews, dones, infos = env.step(actions.detach()) 183 | env.render() 184 | env.gym.simulate(env.sim) 185 | 186 | if PLAY_LOADED_DATA: 187 | frames = motion_loader.data[0, i % motion_loader.num_steps].repeat(env.num_envs, 1) 188 | state_idx_dict = motion_loader.state_idx_dict 189 | else: 190 | with open(datasets_root + "/reference_state_idx_dict.json", 'r') as f: 191 | state_idx_dict = json.load(f) 192 | getattr(env, f"latent_sample_{latent_variable_list[env.latent_variable_selector]}")[env.latent_channel_selector] += env.latent_value_modifier 193 | env.latent_value_modifier[:] = 0.0 194 | env.latent_sample_phase += env.latent_sample_frequency * env.dt 195 | latent_sample_z = env.latent_sample_amplitude.unsqueeze(-1) * torch.sin(2 * torch.pi * (env.latent_sample_frequency.unsqueeze(-1) * fld.args + env.latent_sample_phase.unsqueeze(-1))) + env.latent_sample_offset.unsqueeze(-1) 196 | with torch.no_grad(): 197 | decoded_traj_pred = fld.decoder(latent_sample_z.unsqueeze(0)) 198 | decoded_traj_raw = decoded_traj_pred.swapaxes(1, 2) 199 | decoded_traj = decoded_traj_raw * state_transitions_std + state_transitions_mean 200 | 201 | decoded_traj = decoded_traj[0, 0, :] 202 | 203 | decoded_traj_buf = torch.zeros(52, device=env.device, dtype=torch.float, requires_grad=False) 204 | decoded_traj_buf[2] = 0.66 205 | decoded_traj_buf[6] = 1.0 206 | 207 | decoded_traj_buf[dim_of_interest] = decoded_traj 208 | 209 | frames = decoded_traj_buf.repeat(env.num_envs, 1) 210 | 211 | if PLOT: 212 | plotter.plot_circles(ax1[0], env.latent_sample_phase, env.latent_sample_amplitude, title="Learned Phase Timing" + " " + str(latent_dim) + "x" + str(2), show_axes=False) 213 | plotter.plot_curves(ax1[1], latent_sample_z, -1.0, 1.0, -2.0, 2.0, title="Latent Parametrized Signal" + " " + str(latent_dim) + "x" + str(history_horizon), show_axes=False) 214 | plotter.plot_curves(ax1[2], decoded_traj_pred.squeeze(0), -1.0, 1.0, -5.0, 5.0, title="Curve Reconstruction" + " " + str(fld.input_channel) + "x" + str(history_horizon), show_axes=False) 215 | plotter.plot_curves(ax1[3], decoded_traj_pred.flatten(1, 2), -1.0, 1.0, -5.0, 5.0, title="Curve Reconstruction (Flattened)" + " " + str(1) + "x" + str(fld.input_channel*history_horizon), show_axes=False) 216 | fig1.canvas.draw() 217 | fig1.canvas.flush_events() 218 | 219 | 220 | env.dof_pos[:] = MotionLoader.get_dof_pos(state_idx_dict, frames) 221 | env.dof_vel[:] = MotionLoader.get_dof_vel(state_idx_dict, frames) 222 | root_pos = MotionLoader.get_base_pos(state_idx_dict, frames) 223 | root_pos[:, :2] = root_pos[:, :2] + env.env_origins[:, :2] 224 | env.root_states[:, :3] = root_pos 
225 | root_ori = MotionLoader.get_base_quat(state_idx_dict, frames) 226 | env.root_states[:, 3:7] = root_ori 227 | env.root_states[:, 7:10] = quat_rotate(root_ori, MotionLoader.get_base_lin_vel(state_idx_dict, frames)) 228 | env.root_states[:, 10:13] = quat_rotate(root_ori, MotionLoader.get_base_ang_vel(state_idx_dict, frames)) 229 | 230 | env_ids_int32 = torch.arange(env.num_envs, device=env.device).to(dtype=torch.int32) 231 | env.gym.set_dof_state_tensor_indexed(env.sim, 232 | gymtorch.unwrap_tensor(env.dof_state), 233 | gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) 234 | 235 | env.gym.set_actor_root_state_tensor_indexed(env.sim, 236 | gymtorch.unwrap_tensor(env.root_states), 237 | gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) 238 | 239 | 240 | if __name__ == "__main__": 241 | args = get_args() 242 | preview(args) 243 | -------------------------------------------------------------------------------- /scripts/play.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plays a trained policy and logs statistics. 3 | """ 4 | 5 | # humanoid-gym 6 | from humanoid_gym import LEGGED_GYM_ROOT_DIR 7 | from humanoid_gym.envs import task_registry 8 | from humanoid_gym.utils import get_args, export_policy_as_jit, export_policy_as_onnx, Logger 9 | 10 | # python 11 | import argparse 12 | import os 13 | import numpy as np 14 | import torch 15 | 16 | # global settings 17 | EXPORT_POLICY = True 18 | MOVE_CAMERA = True 19 | 20 | def play(args: argparse.Namespace): 21 | args.task = "mit_humanoid" 22 | env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) 23 | # override some parameters for testing 24 | env_cfg.env.num_envs = min(env_cfg.env.num_envs, 2) 25 | env_cfg.terrain.num_rows = 5 26 | env_cfg.terrain.num_cols = 5 27 | env_cfg.terrain.curriculum = False 28 | env_cfg.noise.add_noise = False 29 | env_cfg.domain_rand.randomize_friction = False 30 | env_cfg.domain_rand.push_robots = False 31 | 32 | env_cfg.domain_rand.added_mass_range = [0.0, 0.0] 33 | env_cfg.commands.resampling_time = 1000.0 34 | env_cfg.commands.ranges.lin_vel_x = [0.3, 0.3] 35 | env_cfg.env.episode_length_s = 1000.0 36 | env_cfg.env.env_spacing = 100.0 37 | env_cfg.domain_rand.latent_encoding_update_noise_level = 0.0 38 | 39 | # camera 40 | env_cfg.viewer.pos = [0.0, -2.13, 1.22] 41 | dir = [0.0, 1.0, -0.4] 42 | env_cfg.viewer.lookat = [a + b for a, b in zip(env_cfg.viewer.pos, dir)] 43 | 44 | # prepare environment 45 | env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg) 46 | obs = env.get_observations() 47 | # load policy 48 | train_cfg.runner.resume = True 49 | ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg) 50 | policy = ppo_runner.get_inference_policy(device=env.device) 51 | 52 | motion_data_collection = [] 53 | datasets_root = os.path.join(LEGGED_GYM_ROOT_DIR + "/resources/robots/mit_humanoid/datasets/misc") 54 | motion_data = os.listdir(datasets_root) 55 | motion_name_set = [data.replace('motion_data_', '').replace('.pt', '') for data in motion_data if "combined" not in data and ".pt" in data] 56 | for i, motion_name in enumerate(motion_name_set): 57 | motion_path = os.path.join(datasets_root, "motion_data_" + motion_name + ".pt") 58 | motion_data = torch.load(motion_path, map_location=env.device).to(env.device) 59 | motion_data_collection.append(motion_data.unsqueeze(0)) 60 | motion_data_collection = torch.cat(motion_data_collection, dim=0) 61 | 
motion_data_num_motions, motion_data_num_trajs, motion_data_num_steps, motion_data_observation_dim = motion_data_collection.size() 62 | motion_idx = 0 63 | traj_idx = 0 64 | t = 0 65 | ood_motion = torch.tensor([], device=env.device, dtype=torch.float, requires_grad=False) 66 | last_input = "" 67 | 68 | # export policy as a jit module and as onnx model (used to run it from C++) 69 | if EXPORT_POLICY: 70 | path = os.path.join( 71 | LEGGED_GYM_ROOT_DIR, 72 | "logs", 73 | train_cfg.runner.experiment_name, 74 | "exported", 75 | "policies", 76 | ) 77 | name = "policy" 78 | export_policy_as_jit(ppo_runner.alg.actor_critic, path, filename=f"{name}.pt") 79 | export_policy_as_onnx(ppo_runner.alg.actor_critic, path, filename=f"{name}.onnx") 80 | print("Exported policy to: ", path) 81 | 82 | logger = Logger(env.dt) 83 | robot_index = 1 # which robot is used for logging 84 | joint_index = 3 # which joint is used for logging 85 | stop_state_log = 100 # number of steps before plotting states 86 | stop_rew_log = env.max_episode_length + 1 # number of steps before print average episode rewards 87 | camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64) 88 | camera_direction = np.array(env_cfg.viewer.lookat) - np.array(env_cfg.viewer.pos) 89 | 90 | for i in range(10 * int(env.max_episode_length)): 91 | actions = policy(obs.detach()) 92 | obs, _, rews, dones, infos = env.step(actions.detach()) 93 | 94 | user_input = motion_data_collection[motion_idx, traj_idx, t:t+env.fld_history_horizon, :] 95 | if t + env.fld_history_horizon >= motion_data_num_steps: 96 | motion_idx = torch.randint(0, motion_data_num_motions, (1,)).item() 97 | traj_idx = torch.randint(0, motion_data_num_trajs, (1,)).item() 98 | print(f"[Motion] Motion name {motion_name_set[motion_idx]}") 99 | print(f"[Motion] Trajectory index {traj_idx}") 100 | t = 0 101 | else: 102 | t += 1 103 | 104 | # compute dynamics error 105 | dynamics_error = env.get_latent_dynamics_error(user_input.unsqueeze(0), k=1) 106 | # fallback mechanism 107 | if dynamics_error < 1.0: 108 | if last_input != "user": 109 | print("[Input] User") 110 | latent_encoding = env.latent_encoding[:].clone() 111 | latent_encoding[:, :, 1:] = env.get_latent_encoding_from_transitions(user_input.repeat(env.num_envs, 1, 1))[:, :, 1:] 112 | env.latent_encoding[:] = latent_encoding 113 | last_input = "user" 114 | else: 115 | if last_input != "default": 116 | print("[Input] Default") 117 | ood_motion = torch.cat((ood_motion, user_input[[-1]]), dim=0) 118 | last_input = "default" 119 | 120 | # compute tracking error 121 | tracking_reconstructed_terms = [name for name in env.reward_scales.keys() if "tracking_reconstructed" in name] 122 | tracking_error = torch.mean(torch.hstack([env.rew[name].unsqueeze(-1) / env.reward_scales[name] * env.dt for name in tracking_reconstructed_terms]), dim=-1) 123 | print(f"[FLD] Tracking error: {tracking_error}") 124 | 125 | if MOVE_CAMERA: 126 | camera_position = env.root_states[0, :3].cpu().numpy() 127 | camera_position[1] -= 2.13 128 | camera_position[2] = 1.22 129 | env.set_camera(camera_position, camera_position + camera_direction) 130 | 131 | if i < stop_state_log: 132 | logger.log_states( 133 | { 134 | "dof_pos_target": (actions * env.cfg.control.action_scale + env.default_dof_pos)[robot_index, joint_index].item(), 135 | "dof_pos": env.dof_pos[robot_index, joint_index].item(), 136 | "dof_vel": env.dof_vel[robot_index, joint_index].item(), 137 | "dof_torque": env.torques[robot_index, joint_index].item(), 138 | "command_x": 
env.commands[robot_index, 0].item(), 139 | "command_y": env.commands[robot_index, 1].item(), 140 | "command_yaw": env.commands[robot_index, 2].item(), 141 | "base_vel_x": env.base_lin_vel[robot_index, 0].item(), 142 | "base_vel_y": env.base_lin_vel[robot_index, 1].item(), 143 | "base_vel_z": env.base_lin_vel[robot_index, 2].item(), 144 | "base_vel_yaw": env.base_ang_vel[robot_index, 2].item(), 145 | "contact_forces_z": env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy(), 146 | "base_lin_vel": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["base_lin_vel"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 147 | "base_lin_vel_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["base_lin_vel"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 148 | "base_ang_vel": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["base_ang_vel"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 149 | "base_ang_vel_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["base_ang_vel"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 150 | "projected_gravity": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["projected_gravity"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 151 | "projected_gravity_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["projected_gravity"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 152 | "dof_pos_leg_l": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_leg_l"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 153 | "dof_pos_leg_l_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_leg_l"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 154 | "dof_pos_arm_l": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_arm_l"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 155 | "dof_pos_arm_l_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_arm_l"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 156 | "dof_pos_leg_r": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_leg_r"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 157 | "dof_pos_leg_r_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_leg_r"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 158 | "dof_pos_arm_r": env.fld_state[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_arm_r"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 159 | "dof_pos_arm_r_ref": env.decoded_obs[robot_index, torch.tensor(env.decoded_obs_state_idx_dict["dof_pos_arm_r"], device=env.device, dtype=torch.long, requires_grad=False)].tolist(), 160 | } 161 | ) 162 | elif i == stop_state_log: 163 | logger.plot_states() 164 | if 0 < i < stop_rew_log: 165 | if infos["episode"]: 166 | num_episodes = torch.sum(env.reset_buf).item() 167 | if num_episodes > 0: 168 | logger.log_rewards(infos["episode"], num_episodes) 169 | elif i == stop_rew_log: 170 | logger.print_rewards() 171 | 172 | 173 | if __name__ == "__main__": 174 | args = get_args() 175 | play(args) 176 | 
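Both `scripts/fld/preview.py` and `scripts/play.py` above drive playback from the four per-channel FLD latent parameters (phase, frequency, amplitude, offset). A minimal self-contained sketch of that parameterization — the time grid `t` stands in for the model's precomputed `fld.args`, whose normalization is an assumption here; only the sinusoid and the phase update mirror preview.py:

```
import torch

latent_dim, history_horizon = 8, 51
t = torch.linspace(-1.0, 1.0, history_horizon)  # assumed stand-in for fld.args

# Illustrative latent parameters, one value per latent channel.
phase = torch.zeros(latent_dim)
frequency = torch.ones(latent_dim)
amplitude = torch.ones(latent_dim)
offset = torch.zeros(latent_dim)

# As in preview.py: each channel is a sinusoid z = A * sin(2*pi*(f*t + p)) + b,
# giving a (latent_dim, history_horizon) latent trajectory to feed the decoder.
z = amplitude.unsqueeze(-1) * torch.sin(
    2 * torch.pi * (frequency.unsqueeze(-1) * t + phase.unsqueeze(-1))
) + offset.unsqueeze(-1)

# Advancing playback by one control step keeps (f, A, b) quasi-constant and
# advances only the phase, again as in preview.py.
dt = 0.02  # assumed control step
phase += frequency * dt
```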
-------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main script for launching a training session. 3 | """ 4 | # humanoid-gym 5 | from humanoid_gym.envs import task_registry 6 | from humanoid_gym.utils import get_args 7 | 8 | 9 | def train(args): 10 | env, env_cfg = task_registry.make_env(name=args.task, args=args) 11 | ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args) 12 | ppo_runner.learn( 13 | num_learning_iterations=train_cfg.runner.max_iterations, 14 | init_at_random_ep_len=True, 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | args = get_args() 20 | train(args) 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Installation script for the 'humanoid_gym' python package.""" 2 | 3 | from setuptools import setup, find_packages 4 | 5 | # Minimum dependencies required prior to installation 6 | INSTALL_REQUIRES = [ 7 | "isaacgym", 8 | "matplotlib", 9 | "tensorboard<=2.11.0", 10 | "torch>=1.4.0", 11 | "torchvision>=0.5.0", 12 | "onnx", 13 | "numpy>=1.16.4,<=1.22.4", 14 | "setuptools==59.5.0", 15 | "gym>=0.17.1", 16 | "GitPython", 17 | "scikit-learn>=1.2.1", 18 | "faiss-gpu>=1.7.2", 19 | "Pillow==9.5.0", 20 | ] 21 | 22 | # Installation operation 23 | setup( 24 | name="humanoid_gym", 25 | version="1.0.0", 26 | author="Chenhao Li", 27 | packages=find_packages(), 28 | author_email="chenhli@ethz.ch", 29 | description="Isaac Gym environments for MIT Humanoid", 30 | install_requires=INSTALL_REQUIRES, 31 | ) 32 | --------------------------------------------------------------------------------
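As a closing illustration of the dataset pipeline: `FLDExperiment.prepare_data` in `scripts/fld/experiment.py` extracts overlapping training windows with `torch.Tensor.unfold`. A minimal sketch with illustrative tensor sizes (the 45 feature dimensions follow from dropping `base_pos` and `base_quat` from the 52-dimensional frames; the trajectory length is made up for the example):

```
import torch

history_horizon, forecast_horizon = 51, 50
num_trajs, traj_len, obs_dim = 10, 240, 45  # illustrative sizes

motion = torch.randn(1, num_trajs, traj_len, obs_dim)  # (num_motions, num_trajs, traj_len, obs_dim)

# As in prepare_data: slide a window of history_horizon + forecast_horizon - 1 steps
# along the time dimension with stride 1, then swap the last two axes so time precedes features.
window = history_horizon + forecast_horizon - 1  # 100
windows = motion.unfold(2, window, 1).swapaxes(-2, -1)
print(windows.shape)  # torch.Size([1, 10, 141, 100, 45]) = (num_motions, num_trajs, num_groups, num_steps, obs_dim)
```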