├── .gitignore ├── .gitmodules ├── LICENSE.md ├── README.md ├── envs ├── common │ ├── config_builder.py │ ├── mujoco_env.py │ └── robot_interface.py ├── h1 │ ├── __init__.py │ ├── configs │ │ └── base.yaml │ ├── gen_xml.py │ └── h1_env.py └── jvrc │ ├── __init__.py │ ├── configs │ └── base.yaml │ ├── gen_xml.py │ ├── jvrc_step.py │ └── jvrc_walk.py ├── models └── __init__.py ├── rl ├── __init__.py ├── algos │ ├── __init__.py │ └── ppo.py ├── distributions │ ├── __init__.py │ ├── beta.py │ └── gaussian.py ├── envs │ ├── __init__.py │ ├── normalize.py │ └── wrappers.py ├── policies │ ├── __init__.py │ ├── actor.py │ ├── base.py │ └── critic.py ├── storage │ └── rollout_storage.py └── utils │ └── eval.py ├── robots └── robot_base.py ├── run_experiment.py ├── scripts └── debug_stepper.py ├── tasks ├── __init__.py ├── rewards.py ├── stepping_task.py └── walking_task.py └── utils └── footstep_plans.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | __pycache__* 3 | *.pyc 4 | *.egg-info 5 | *.#~ 6 | logs* 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "models/jvrc_mj_description"] 2 | path = models/jvrc_mj_description 3 | url = git@github.com:rohanpsingh/jvrc_mj_description.git 4 | [submodule "models/cassie_mj_description"] 5 | path = models/cassie_mj_description 6 | url = git@github.com:rohanpsingh/cassie_mj_description.git 7 | [submodule "models/mujoco_menagerie"] 8 | path = models/mujoco_menagerie 9 | url = git@github.com:google-deepmind/mujoco_menagerie 10 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Rohan P. Singh 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LearningHumanoidWalking 2 | 3 |

4 | [image: humanoid-walk] 5 |

6 | 7 | Code for the papers: 8 | - [**Robust Humanoid Walking on Compliant and Uneven Terrain with Deep Reinforcement Learning**](https://ieeexplore.ieee.org/abstract/document/10769793) 9 | [Rohan P. Singh](https://rohanpsingh.github.io), [Mitsuharu Morisawa](https://unit.aist.go.jp/jrl-22022/en/members/member-morisawa.html), [Mehdi Benallegue](https://unit.aist.go.jp/jrl-22022/en/members/member-benalleguem.html), [Zhaoming Xie](https://zhaomingxie.github.io/), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html) 10 | 11 | - [**Learning Bipedal Walking for Humanoids with Current Feedback**](https://arxiv.org/pdf/2303.03724.pdf) 12 | [Rohan P. Singh](https://rohanpsingh.github.io), [Zhaoming Xie](https://zhaomingxie.github.io/), [Pierre Gergondet](https://unit.aist.go.jp/jrl-22022/en/members/member-gergondet.html), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html) 13 | 14 | - [**Learning Bipedal Walking On Planned Footsteps For Humanoid Robots**](https://arxiv.org/pdf/2207.12644.pdf) 15 | [Rohan P. Singh](https://rohanpsingh.github.io), [Mehdi Benallegue](https://unit.aist.go.jp/jrl-22022/en/members/member-benalleguem.html), [Mitsuharu Morisawa](https://unit.aist.go.jp/jrl-22022/en/members/member-morisawa.html), [Rafael Cisneros](https://unit.aist.go.jp/jrl-22022/en/members/member-cisneros.html), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html) 16 | 17 | 18 | ## Code structure: 19 | A rough outline for the repository that might be useful for adding your own robot: 20 | ``` 21 | LearningHumanoidWalking/ 22 | ├── envs/ <-- Actions and observation space, PD gains, simulation step, control decimation, init, ... 23 | ├── tasks/ <-- Reward function, termination conditions, and more... 24 | ├── rl/ <-- Code for PPO, actor/critic networks, observation normalization process... 25 | ├── models/ <-- MuJoCo model files: XMLs/meshes/textures 26 | └── scripts/ <-- Utility scripts, etc. 27 | ``` 28 | 29 | ## Requirements: 30 | - Python version: 3.12.4 31 | - pip install: 32 | - mujoco==3.2.2 33 | - ray==2.40.0 34 | - pytorch=2.5.1 35 | - intel-openmp 36 | - [mujoco-python-viewer](https://github.com/rohanpsingh/mujoco-python-viewer) 37 | - transforms3d 38 | - scipy 39 | 40 | ## Usage: 41 | 42 | Environment names supported: 43 | 44 | | Task Description | Environment name | 45 | | ----------- | ----------- | 46 | | Basic Standing Task | 'h1' | 47 | | Basic Walking Task | 'jvrc_walk' | 48 | | Stepping Task (using footsteps) | 'jvrc_step' | 49 | 50 | 51 | #### **To train:** 52 | 53 | ``` 54 | $ python run_experiment.py train --logdir --num_procs --env 55 | ``` 56 | 57 | 58 | #### **To play:** 59 | 60 | ``` 61 | $ python run_experiment.py eval --logdir 62 | ``` 63 | 64 | Or, we could write a rollout script specific to each environment. 65 | For example, `debug_stepper.py` can be used with the `jvrc_step` environment. 
66 | ``` 67 | $ PYTHONPATH=.:$PYTHONPATH python scripts/debug_stepper.py --path 68 | ``` 69 | 70 | #### **What you should see:** 71 | 72 | *Ascending stairs:* 73 | ![climb_up](https://user-images.githubusercontent.com/16384313/180697513-25796b1a-87e0-4ab2-9e5f-d86c58ebea36.gif) 74 | 75 | *Descending stairs:* 76 | ![climb_down](https://user-images.githubusercontent.com/16384313/180697788-d1a2eec0-0d3d-451a-95e0-9f0e60191c34.gif) 77 | 78 | *Walking on curves:* 79 | ![curve](https://user-images.githubusercontent.com/16384313/180697266-7b44beb3-38bf-4494-b568-963919dc1106.gif) 80 | 81 | 82 | ## Citation 83 | If you find this work useful in your own research, please cite the following works: 84 | 85 | For omnidirectional walking: 86 | ``` 87 | @inproceedings{singh2024robust, 88 | title={Robust Humanoid Walking on Compliant and Uneven Terrain with Deep Reinforcement Learning}, 89 | author={Singh, Rohan P and Morisawa, Mitsuharu and Benallegue, Mehdi and Xie, Zhaoming and Kanehiro, Fumio}, 90 | booktitle={2024 IEEE-RAS 23rd International Conference on Humanoid Robots (Humanoids)}, 91 | pages={497--504}, 92 | year={2024}, 93 | organization={IEEE} 94 | } 95 | ``` 96 | 97 | For simulating "back-emf" effect and other randomizations: 98 | ``` 99 | @article{xie2023learning, 100 | title={Learning bipedal walking for humanoids with current feedback}, 101 | author={Singh, Rohan Pratap and Xie, Zhaoming and Gergondet, Pierre and Kanehiro, Fumio}, 102 | journal={IEEE Access}, 103 | volume={11}, 104 | pages={82013--82023}, 105 | year={2023}, 106 | publisher={IEEE} 107 | } 108 | ``` 109 | 110 | For walking on footsteps: 111 | 112 | ``` 113 | @inproceedings{singh2022learning, 114 | title={Learning Bipedal Walking On Planned Footsteps For Humanoid Robots}, 115 | author={Singh, Rohan P and Benallegue, Mehdi and Morisawa, Mitsuharu and Cisneros, Rafael and Kanehiro, Fumio}, 116 | booktitle={2022 IEEE-RAS 21st International Conference on Humanoid Robots (Humanoids)}, 117 | pages={686--693}, 118 | year={2022}, 119 | organization={IEEE} 120 | } 121 | ``` 122 | 123 | ### Credits 124 | The code in this repository was heavily inspired from [apex](https://github.com/osudrl/apex). Clock-based reward terms and some other ideas were originally proposed by the team from OSU DRL for the Cassie robot, so please also consider citing the works of Jonah Siekmann, Helei Duan, Jeremy Dao, and others. 125 | 126 | -------------------------------------------------------------------------------- /envs/common/config_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Any, Dict, Optional, Union, List 4 | 5 | class Configuration: 6 | """A class to handle configuration data with attribute-style access.""" 7 | 8 | def __init__(self, **kwargs: Any) -> None: 9 | """ 10 | Initialize a Configuration object with nested attribute access. 11 | 12 | Args: 13 | **kwargs: Key-value pairs to be set as attributes. 
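
        Example (illustrative; keys taken from envs/h1/configs/base.yaml):
            cfg = Configuration(**{"sim_dt": 0.001,
                                   "pdgains": {"left_knee_joint": [100, 10]}})
            cfg.sim_dt                    # -> 0.001
            cfg.pdgains.left_knee_joint   # -> [100, 10] (nested dicts become Configurations)
            cfg.not_a_key                 # -> None for missing keys (see __getattr__)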
14 | """ 15 | for key, value in kwargs.items(): 16 | if isinstance(value, dict): 17 | setattr(self, key, Configuration(**value)) 18 | elif isinstance(value, list) and all(isinstance(item, dict) for item in value): 19 | setattr(self, key, [Configuration(**item) for item in value]) 20 | else: 21 | setattr(self, key, value) 22 | 23 | def __repr__(self) -> str: 24 | """Return a string representation of the configuration.""" 25 | return str(self.__dict__) 26 | 27 | def __getattr__(self, name: str) -> None: 28 | """Return None for non-existent attributes instead of raising an error.""" 29 | return None 30 | 31 | def to_dict(self) -> Dict[str, Any]: 32 | """Convert the Configuration object back to a dictionary.""" 33 | result = {} 34 | for key, value in self.__dict__.items(): 35 | if isinstance(value, Configuration): 36 | result[key] = value.to_dict() 37 | elif isinstance(value, list) and value and isinstance(value[0], Configuration): 38 | result[key] = [item.to_dict() if isinstance(item, Configuration) else item for item in value] 39 | else: 40 | result[key] = value 41 | return result 42 | 43 | def load_yaml(file_path: str) -> Configuration: 44 | """ 45 | Load configuration from a YAML file. 46 | 47 | Args: 48 | file_path: Path to the YAML file. 49 | 50 | Returns: 51 | Configuration object with data from the YAML file. 52 | 53 | Raises: 54 | FileNotFoundError: If the file doesn't exist. 55 | yaml.YAMLError: If there's an error parsing the YAML. 56 | """ 57 | if not os.path.exists(file_path): 58 | raise FileNotFoundError(f"Configuration file not found: {file_path}") 59 | 60 | with open(file_path, 'r') as file: 61 | try: 62 | config_data = yaml.safe_load(file) 63 | except yaml.YAMLError as e: 64 | raise yaml.YAMLError(f"Error parsing YAML file {file_path}: {e}") 65 | return Configuration(**config_data) 66 | -------------------------------------------------------------------------------- /envs/common/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | import numpy as np 4 | import mujoco 5 | import mujoco_viewer 6 | 7 | DEFAULT_SIZE = 500 8 | 9 | class MujocoEnv(): 10 | """Superclass for all MuJoCo environments. 11 | """ 12 | 13 | def __init__(self, model_path, sim_dt, control_dt): 14 | if model_path.startswith("/"): 15 | fullpath = model_path 16 | else: 17 | raise Exception("Provide full path to robot description package.") 18 | if not os.path.exists(fullpath): 19 | raise IOError("File %s does not exist" % fullpath) 20 | 21 | self.spec = mujoco.MjSpec() 22 | self.spec.from_file(fullpath) 23 | self.model = self.spec.compile() 24 | self.data = mujoco.MjData(self.model) 25 | self.viewer = None 26 | 27 | # set frame skip and sim dt 28 | self.frame_skip = (control_dt/sim_dt) 29 | self.model.opt.timestep = sim_dt 30 | 31 | self.init_qpos = self.data.qpos.ravel().copy() 32 | self.init_qvel = self.data.qvel.ravel().copy() 33 | 34 | # methods to override: 35 | # ---------------------------- 36 | 37 | def reset_model(self): 38 | """ 39 | Reset the robot degrees of freedom (qpos and qvel). 40 | Implement this in each subclass. 41 | """ 42 | raise NotImplementedError 43 | 44 | def viewer_setup(self): 45 | """ 46 | This method is called when the viewer is initialized. 47 | Optionally implement this method, if you need to tinker with camera position 48 | and so forth. 
49 | """ 50 | self.viewer.cam.trackbodyid = 1 51 | self.viewer.cam.distance = self.model.stat.extent * 1.5 52 | self.viewer.cam.lookat[2] = 1.5 53 | self.viewer.cam.lookat[0] = 2.0 54 | self.viewer.cam.elevation = -20 55 | self.viewer.vopt.geomgroup[2] = 0 56 | self.viewer._render_every_frame = True 57 | 58 | def viewer_is_paused(self): 59 | return self.viewer._paused 60 | 61 | # ----------------------------- 62 | # (some methods are taken directly from dm_control) 63 | 64 | @contextlib.contextmanager 65 | def disable(self, *flags): 66 | """Context manager for temporarily disabling MuJoCo flags. 67 | 68 | Args: 69 | *flags: Positional arguments specifying flags to disable. Can be either 70 | lowercase strings (e.g. 'gravity', 'contact') or `mjtDisableBit` enum 71 | values. 72 | 73 | Yields: 74 | None 75 | 76 | Raises: 77 | ValueError: If any item in `flags` is neither a valid name nor a value 78 | from `mujoco.mjtDisableBit`. 79 | """ 80 | old_bitmask = self.model.opt.disableflags 81 | new_bitmask = old_bitmask 82 | for flag in flags: 83 | if isinstance(flag, str): 84 | try: 85 | field_name = "mjDSBL_" + flag.upper() 86 | flag = getattr(mujoco.mjtDisableBit, field_name) 87 | except AttributeError: 88 | valid_names = [ 89 | field_name.split("_")[1].lower() 90 | for field_name in list(mujoco.mjtDisableBit.__members__)[:-1] 91 | ] 92 | raise ValueError("'{}' is not a valid flag name. Valid names: {}" 93 | .format(flag, ", ".join(valid_names))) from None 94 | elif isinstance(flag, int): 95 | flag = mujoco.mjtDisableBit(flag) 96 | new_bitmask |= flag.value 97 | self.model.opt.disableflags = new_bitmask 98 | try: 99 | yield 100 | finally: 101 | self.model.opt.disableflags = old_bitmask 102 | 103 | def reset(self): 104 | mujoco.mj_resetData(self.model, self.data) 105 | ob = self.reset_model() 106 | return ob 107 | 108 | def set_state(self, qpos, qvel): 109 | assert qpos.shape == (self.model.nq,), \ 110 | f"qpos shape {qpos.shape} is expected to be {(self.model.nq,)}" 111 | assert qvel.shape == (self.model.nv,), \ 112 | f"qvel shape {qvel.shape} is expected to be {(self.model.nv,)}" 113 | self.data.qpos[:] = qpos 114 | self.data.qvel[:] = qvel 115 | self.data.act = [] 116 | self.data.plugin_state = [] 117 | # Disable actuation since we don't yet have meaningful control inputs. 
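        # `disable` (defined above) clears the requested mjtDisableBit flags for the
        # duration of the block and restores the previous bitmask on exit.  An
        # illustrative use with several flags at once:
        #     with self.disable('gravity', 'contact'):
        #         mujoco.mj_forward(self.model, self.data)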
118 | with self.disable('actuation'): 119 | mujoco.mj_forward(self.model, self.data) 120 | 121 | @property 122 | def dt(self): 123 | return self.model.opt.timestep * self.frame_skip 124 | 125 | def render(self): 126 | if self.viewer is None: 127 | self.viewer = mujoco_viewer.MujocoViewer(self.model, self.data) 128 | self.viewer_setup() 129 | self.viewer.render() 130 | 131 | def uploadGPU(self, hfieldid=None, meshid=None, texid=None): 132 | # hfield 133 | if hfieldid is not None: 134 | mujoco.mjr_uploadHField(self.model, self.viewer.ctx, hfieldid) 135 | # mesh 136 | if meshid is not None: 137 | mujoco.mjr_uploadMesh(self.model, self.viewer.ctx, meshid) 138 | # texture 139 | if texid is not None: 140 | mujoco.mjr_uploadTexture(self.model, self.viewer.ctx, texid) 141 | 142 | def close(self): 143 | if self.viewer is not None: 144 | self.viewer.close() 145 | self.viewer = None 146 | -------------------------------------------------------------------------------- /envs/common/robot_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import transforms3d as tf3 4 | import mujoco 5 | import torch 6 | import collections 7 | 8 | class RobotInterface(object): 9 | def __init__(self, model, data, rfoot_body_name=None, lfoot_body_name=None, 10 | path_to_nets=None): 11 | self.model = model 12 | self.data = data 13 | 14 | self.rfoot_body_name = rfoot_body_name 15 | self.lfoot_body_name = lfoot_body_name 16 | self.floor_body_name = model.body(0).name 17 | self.robot_root_name = model.body(1).name 18 | 19 | self.stepCounter = 0 20 | 21 | if path_to_nets: 22 | self.load_motor_nets(path_to_nets) 23 | 24 | def load_motor_nets(self, path_to_nets): 25 | self.motor_dyn_nets = {} 26 | for jnt in os.listdir(path_to_nets): 27 | if not os.path.isdir(os.path.join(path_to_nets, jnt)): 28 | continue 29 | net_path = os.path.join(path_to_nets, jnt, "trained_jit.pth") 30 | net = torch.jit.load(net_path) 31 | net.eval() 32 | self.motor_dyn_nets[jnt] = net 33 | self.ctau_buffer = collections.deque(maxlen=25) 34 | self.qdot_buffer = collections.deque(maxlen=25) 35 | return 36 | 37 | def motor_nets_forward(self, cmdTau): 38 | if len(self.ctau_buffer)0) 451 | 452 | def check_lfoot_floor_collision(self): 453 | """ 454 | Returns True if there is a collision between left foot and floor. 455 | """ 456 | return (len(self.get_lfoot_floor_contacts())>0) 457 | 458 | def check_bad_collisions(self, body_names=[]): 459 | """ 460 | Returns True if there are collisions other than specifiedbody-floor, 461 | or feet-floor if body_names is not provided. 462 | """ 463 | num_cons = 0 464 | if not isinstance(body_names, list): 465 | raise TypeError(f"expected list of body names, got '{type(body_names).__name__}'") 466 | if not len(body_names): 467 | num_rcons = len(self.get_rfoot_floor_contacts()) 468 | num_lcons = len(self.get_lfoot_floor_contacts()) 469 | num_cons = num_rcons + num_lcons 470 | for bn in body_names: 471 | num_cons += len(self.get_body_floor_contacts(bn)) 472 | return num_cons != self.data.ncon 473 | 474 | def check_self_collisions(self): 475 | """ 476 | Returns True if there are collisions other than any-geom-floor. 
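        (More precisely: True if any contact pair has both of its geoms attached
        to the robot's kinematic tree, i.e. a robot self-collision.)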
477 | """ 478 | contacts = [self.data.contact[i] for i in range(self.data.ncon)] 479 | for i,c in enumerate(contacts): 480 | geom1_body = self.model.body(self.model.geom_bodyid[c.geom1]) 481 | geom2_body = self.model.body(self.model.geom_bodyid[c.geom2]) 482 | geom1_is_robot = self.model.body(geom1_body.rootid).name==self.robot_root_name 483 | geom2_is_robot = self.model.body(geom2_body.rootid).name==self.robot_root_name 484 | if geom1_is_robot and geom2_is_robot: 485 | return True 486 | return False 487 | 488 | def set_pd_gains(self, kp, kv): 489 | assert kp.size==self.model.nu 490 | assert kv.size==self.model.nu 491 | self.kp = kp.copy() 492 | self.kv = kv.copy() 493 | return 494 | 495 | def step_pd(self, p, v): 496 | target_angles = p 497 | target_speeds = v 498 | 499 | assert type(target_angles)==np.ndarray 500 | assert type(target_speeds)==np.ndarray 501 | 502 | curr_angles = self.get_act_joint_positions() 503 | curr_speeds = self.get_act_joint_velocities() 504 | 505 | perror = target_angles - curr_angles 506 | verror = target_speeds - curr_speeds 507 | 508 | assert self.kp.size==perror.size 509 | assert self.kv.size==verror.size 510 | return self.kp * perror + self.kv * verror 511 | 512 | def set_motor_torque(self, torque, motor_dyn_fwd = False): 513 | """ 514 | Apply torques to motors. 515 | """ 516 | if isinstance(torque, np.ndarray): 517 | assert torque.shape==(self.nu(), ) 518 | ctrl = torque 519 | elif isinstance(torque, list): 520 | assert len(torque)==self.nu() 521 | ctrl = np.copy(torque) 522 | else: 523 | raise Exception("motor torque should be list of ndarray.") 524 | try: 525 | if motor_dyn_fwd: 526 | if not hasattr(self, 'motor_dyn_nets'): 527 | raise Exception("motor dynamics network are not defined.") 528 | gear = self.get_gear_ratios() 529 | ctrl = self.motor_nets_forward(ctrl*gear) 530 | ctrl /= gear 531 | np.copyto(self.data.ctrl, ctrl) 532 | except Exception as e: 533 | print("Could not apply motor torque.") 534 | print(e) 535 | return 536 | 537 | def step(self, mj_step=True, nstep=1): 538 | """ 539 | (Adapted from dm_control/mujoco/engine.py) 540 | 541 | Advances physics with up-to-date position and velocity dependent fields. 542 | Args: 543 | nstep: Optional integer, number of steps to take. 544 | """ 545 | if mj_step: 546 | mujoco.mj_step(self.model, self.data, nstep) 547 | self.stepCounter += nstep 548 | return 549 | 550 | # In the case of Euler integration we assume mj_step1 has already been 551 | # called for this state, finish the step with mj_step2 and then update all 552 | # position and velocity related fields with mj_step1. This ensures that 553 | # (most of) mjData is in sync with qpos and qvel. In the case of non-Euler 554 | # integrators (e.g. RK4) an additional mj_step1 must be called after the 555 | # last mj_step to ensure mjData syncing. 
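        # Illustrative split-step usage when mj_step=False (a sketch; the actual
        # caller in robots/robot_base.py may differ):
        #     mujoco.mj_step1(self.model, self.data)   # positions/velocities valid here
        #     self.set_motor_torque(tau)               # write actuator commands
        #     self.step(mj_step=False)                 # mj_step2 (+ extra steps), then mj_step1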
556 | if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value: 557 | mujoco.mj_step2(self.model, self.data) 558 | if nstep > 1: 559 | mujoco.mj_step(self.model, self.data, nstep-1) 560 | else: 561 | mujoco.mj_step(self.model, self.data, nstep) 562 | 563 | mujoco.mj_step1(self.model, self.data) 564 | 565 | self.stepCounter += nstep 566 | -------------------------------------------------------------------------------- /envs/h1/__init__.py: -------------------------------------------------------------------------------- 1 | from envs.h1.h1_env import H1Env 2 | -------------------------------------------------------------------------------- /envs/h1/configs/base.yaml: -------------------------------------------------------------------------------- 1 | sim_dt: 0.001 2 | control_dt: 0.025 3 | obs_history_len: 1 4 | action_smoothing: 0.5 5 | 6 | ctrllimited: false 7 | jointlimited: false 8 | reduced_xml: true 9 | 10 | init_noise: 3 11 | 12 | pdgains: 13 | "left_hip_yaw_joint": [100, 10] 14 | "left_hip_roll_joint": [100, 10] 15 | "left_hip_pitch_joint": [100, 10] 16 | "left_knee_joint": [100, 10] 17 | "left_ankle_joint": [20, 4] 18 | "right_hip_yaw_joint": [100, 10] 19 | "right_hip_roll_joint": [100, 10] 20 | "right_hip_pitch_joint": [100, 10] 21 | "right_knee_joint": [100, 10] 22 | "right_ankle_joint": [20, 4] 23 | "torso_joint": [40, 4] 24 | "left_shoulder_pitch_joint": [20, 2] 25 | "left_shoulder_roll_joint": [20, 2] 26 | "left_shoulder_yaw_joint": [20, 2] 27 | "left_elbow_joint": [20, 2] 28 | "right_shoulder_pitch_joint": [20, 2] 29 | "right_shoulder_roll_joint": [20, 2] 30 | "right_shoulder_yaw_joint": [20, 2] 31 | "right_elbow_joint": [20, 2] 32 | 33 | observation_noise: 34 | enabled: true 35 | multiplier: 1.0 36 | type: "uniform" # or "gaussian" 37 | scales: 38 | root_orient: 0.05 39 | root_ang_vel: 0.05 40 | motor_pos: 0.02 41 | motor_vel: 0.05 42 | motor_tau: 5.0 43 | 44 | perturbation: 45 | enable: true 46 | bodies: ["pelvis", "torso_link"] 47 | force_magnitude: 10 48 | torque_magnitude: 2 49 | interval: 5 50 | 51 | dynamics_randomization: 52 | enable: true 53 | interval: 0.5 54 | 55 | -------------------------------------------------------------------------------- /envs/h1/gen_xml.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import models 4 | from dm_control import mjcf 5 | import transforms3d as tf3 6 | 7 | H1_DESCRIPTION_PATH=os.path.join(os.path.dirname(models.__file__), "mujoco_menagerie/unitree_h1/scene.xml") 8 | 9 | LEG_JOINTS = ["left_hip_yaw", "left_hip_roll", "left_hip_pitch", "left_knee", "left_ankle", 10 | "right_hip_yaw", "right_hip_roll", "right_hip_pitch", "right_knee", "right_ankle"] 11 | WAIST_JOINTS = ["torso"] 12 | ARM_JOINTS = ["left_shoulder_pitch", "left_shoulder_roll", "left_shoulder_yaw", "left_elbow", 13 | "right_shoulder_pitch", "right_shoulder_roll", "right_shoulder_yaw", "right_elbow"] 14 | 15 | def create_rangefinder_array(mjcf_model, num_rows=4, num_cols=4, spacing=0.4): 16 | for i in range(num_rows*num_cols): 17 | name = 'rf' + repr(i) 18 | u = (i % num_cols) 19 | v = (i // num_rows) 20 | x = (v - (num_cols - 1) / 2) * spacing 21 | y = ((num_rows - 1) / 2 - u) * (-spacing) 22 | # add sites 23 | mjcf_model.find('body', 'pelvis').add('site', 24 | name=name, 25 | pos=[x, y, 0], 26 | quat='0 1 0 0') 27 | 28 | # add sensors 29 | mjcf_model.sensor.add('rangefinder', 30 | name=name, 31 | site=name) 32 | 33 | return mjcf_model 34 | 35 | def 
remove_joints_and_actuators(mjcf_model, config): 36 | # remove joints 37 | for limb in config['unused_joints']: 38 | for joint in limb: 39 | mjcf_model.find('joint', joint).remove() 40 | 41 | # remove all actuators with no corresponding joint 42 | for mot in mjcf_model.actuator.motor: 43 | mot.user = None 44 | if mot.joint==None: 45 | mot.remove() 46 | return mjcf_model 47 | 48 | def builder(export_path, config): 49 | print("Modifying XML model...") 50 | mjcf_model = mjcf.from_path(H1_DESCRIPTION_PATH) 51 | 52 | mjcf_model.model = 'h1' 53 | 54 | # modify model 55 | mjcf_model = remove_joints_and_actuators(mjcf_model, config) 56 | mjcf_model.find('default', 'visual').geom.group = 1 57 | mjcf_model.find('default', 'collision').geom.group = 2 58 | if 'ctrllimited' in config: 59 | mjcf_model.find('default', 'h1').motor.ctrllimited = config['ctrllimited'] 60 | if 'jointlimited' in config: 61 | mjcf_model.find('default', 'h1').joint.limited = config['jointlimited'] 62 | 63 | # rename all motors to fit assumed convention 64 | for mot in mjcf_model.actuator.motor: 65 | mot.name = mot.name + "_motor" 66 | 67 | # remove keyframe for now 68 | mjcf_model.keyframe.remove() 69 | 70 | # remove vis geoms and assets, if needed 71 | if 'minimal' in config: 72 | if config['minimal']: 73 | mjcf_model.find('default', 'collision').geom.group = 1 74 | meshes = mjcf_model.asset.mesh 75 | for mesh in meshes: 76 | mesh.remove() 77 | for geom in mjcf_model.find_all('geom'): 78 | if geom.dclass: 79 | if geom.dclass.dclass=="visual": 80 | geom.remove() 81 | 82 | # set name of freejoint 83 | mjcf_model.find('body', 'pelvis').freejoint.name = 'root' 84 | 85 | # add rangefinder 86 | if 'rangefinder' in config: 87 | if config['rangefinder']: 88 | mjcf_model = create_rangefinder_array(mjcf_model) 89 | 90 | # create a raised platform 91 | if 'raisedplatform' in config: 92 | if config['raisedplatform']: 93 | block_pos = [2.5, 0, 0] 94 | block_size = [3, 3, 1] 95 | name = 'raised-platform' 96 | mjcf_model.worldbody.add('body', name=name, pos=block_pos) 97 | mjcf_model.find('body', name).add('geom', name=name, group='3', 98 | condim='3', friction='.8 .1 .1', size=block_size, 99 | type='box', material='') 100 | 101 | # set some size options 102 | mjcf_model.size.njmax = "-1" 103 | mjcf_model.size.nconmax = "-1" 104 | mjcf_model.size.nuser_actuator = "-1" 105 | 106 | # export model 107 | mjcf.export_with_assets(mjcf_model, out_dir=export_path, precision=5) 108 | path_to_xml = os.path.join(export_path, mjcf_model.model + '.xml') 109 | print("Exporting XML model to ", path_to_xml) 110 | return 111 | -------------------------------------------------------------------------------- /envs/h1/h1_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import transforms3d as tf3 5 | import collections 6 | 7 | from robots.robot_base import RobotBase 8 | from envs.common import mujoco_env 9 | from envs.common import robot_interface 10 | from envs.common import config_builder 11 | 12 | from .gen_xml import * 13 | 14 | class Task: 15 | def __init__(self, client, neutral_pose): 16 | self._client = client 17 | self.neutral_pose = neutral_pose 18 | 19 | def calc_reward(self, prev_torque, prev_action, action): 20 | root_pose = self._client.get_object_affine_by_name("pelvis", 'OBJ_BODY') 21 | 22 | # height reward 23 | target_root_h = 0.98 24 | root_h = root_pose[2, 3] 25 | height_error = np.linalg.norm(root_h - target_root_h) 26 | 27 | # upperbody reward 28 
| head_pose_offset = np.zeros(2) 29 | head_pose = self._client.get_object_affine_by_name("head_link", 'OBJ_BODY') 30 | head_pos_in_robot_base = np.linalg.inv(root_pose).dot(head_pose)[:2, 3] - head_pose_offset 31 | upperbody_error = np.linalg.norm(head_pos_in_robot_base) 32 | 33 | # posture reward 34 | current_pose = np.array(self._client.get_act_joint_positions())[:10] 35 | posture_error = np.linalg.norm(current_pose - self.neutral_pose) 36 | 37 | # torque reward 38 | tau_error = np.linalg.norm(self._client.get_act_joint_torques()) 39 | 40 | # velocity reward 41 | root_vel = self._client.get_body_vel("pelvis", frame=1)[0][:2] 42 | fwd_vel_error = np.linalg.norm(root_vel) 43 | yaw_vel = self._client.get_qvel()[5] 44 | yaw_vel_error = np.linalg.norm(yaw_vel) 45 | 46 | reward = { 47 | "com_vel_error": 0.3 * np.exp(-4 * np.square(fwd_vel_error)), 48 | "yaw_vel_error": 0.3 * np.exp(-4 * np.square(yaw_vel_error)), 49 | "height": 0.1 * np.exp(-0.5 * np.square(height_error)), 50 | "upperbody": 0.1 * np.exp(-40*np.square(upperbody_error)), 51 | "joint_torque_reward": 0.1 * np.exp(-5e-5*np.square(tau_error)), 52 | "posture": 0.1 * np.exp(-1*np.square(posture_error)), 53 | } 54 | return reward 55 | 56 | def step(self): 57 | pass 58 | 59 | def substep(self): 60 | pass 61 | 62 | def done(self): 63 | root_jnt_adr = self._client.model.body("pelvis").jntadr[0] 64 | root_qpos_adr = self._client.model.joint(root_jnt_adr).qposadr[0] 65 | qpos = self._client.get_qpos()[root_qpos_adr:root_qpos_adr+7] 66 | contact_flag = self._client.check_self_collisions() 67 | terminate_conditions = {"qpos[2]_ll":(qpos[2] < 0.9), 68 | "qpos[2]_ul":(qpos[2] > 1.4), 69 | "contact_flag":contact_flag, 70 | } 71 | done = True in terminate_conditions.values() 72 | return done 73 | 74 | def reset(self): 75 | pass 76 | 77 | class H1Env(mujoco_env.MujocoEnv): 78 | def __init__(self, path_to_yaml = None): 79 | 80 | ## Load CONFIG from yaml ## 81 | if path_to_yaml is None: 82 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml') 83 | 84 | self.cfg = config_builder.load_yaml(path_to_yaml) 85 | 86 | sim_dt = self.cfg.sim_dt 87 | control_dt = self.cfg.control_dt 88 | frame_skip = (control_dt/sim_dt) 89 | 90 | self.dynrand_interval = self.cfg.dynamics_randomization.interval/control_dt 91 | self.perturb_interval = self.cfg.perturbation.interval/control_dt 92 | self.history_len = self.cfg.obs_history_len 93 | 94 | path_to_xml = '/tmp/mjcf-export/h1/h1.xml' 95 | if not os.path.exists(path_to_xml): 96 | export_dir = os.path.dirname(path_to_xml) 97 | builder(export_dir, config={ 98 | 'unused_joints': [WAIST_JOINTS, ARM_JOINTS], 99 | 'rangefinder': False, 100 | 'raisedplatform': False, 101 | 'ctrllimited': self.cfg.ctrllimited, 102 | 'jointlimited': self.cfg.jointlimited, 103 | 'minimal': self.cfg.reduced_xml, 104 | }) 105 | 106 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt) 107 | 108 | # the actual weight of the robot is about 7kg higher 109 | self.model.body("pelvis").mass = 8.89 110 | self.model.body("torso_link").mass = 21.289 111 | 112 | # list of desired actuators 113 | self.leg_names = LEG_JOINTS 114 | gains_dict = self.cfg.pdgains.to_dict() 115 | kp, kd = zip(*[gains_dict[jn] for jn in self.leg_names]) 116 | pdgains = np.array([kp, kd]) 117 | 118 | # define nominal pose 119 | base_position = [0, 0, 0.98] 120 | base_orientation = [1, 0, 0, 0] 121 | half_sitting_pose = [ 122 | 0, 0, -0.2, 0.6, -0.4, 123 | 0, 0, -0.2, 0.6, -0.4, 124 | ] 125 | 126 | # set up interface 127 | 
self.interface = robot_interface.RobotInterface(self.model, self.data, 'right_ankle_link', 'left_ankle_link', None) 128 | self.nominal_pose = base_position + base_orientation + half_sitting_pose 129 | 130 | # set up task 131 | self.task = Task(self.interface, half_sitting_pose) 132 | 133 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task) 134 | 135 | # set action space 136 | action_space_size = len(self.leg_names) 137 | action = np.zeros(action_space_size) 138 | self.action_space = np.zeros(action_space_size) 139 | self.prev_prediction = np.zeros(action_space_size) 140 | 141 | # set observation space 142 | self.base_obs_len = 35 143 | self.observation_history = collections.deque(maxlen=self.history_len) 144 | self.observation_space = np.zeros(self.base_obs_len*self.history_len) 145 | 146 | # manually define observation mean and std 147 | self.obs_mean = np.concatenate(( 148 | np.zeros(5), 149 | half_sitting_pose, np.zeros(10), np.zeros(10), 150 | )) 151 | 152 | self.obs_std = np.concatenate(( 153 | [0.2, 0.2, 1, 1, 1], 154 | 0.5*np.ones(10), 4*np.ones(10), 100*np.ones(10), 155 | )) 156 | 157 | self.obs_mean = np.tile(self.obs_mean, self.history_len) 158 | self.obs_std = np.tile(self.obs_std, self.history_len) 159 | 160 | # copy the original model 161 | self.default_model = copy.deepcopy(self.model) 162 | 163 | def get_obs(self): 164 | # internal state 165 | qpos = np.copy(self.interface.get_qpos()) 166 | qvel = np.copy(self.interface.get_qvel()) 167 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2] 168 | root_r = np.array([root_r]) 169 | root_p = np.array([root_p]) 170 | root_ang_vel = qvel[3:6] 171 | motor_pos = self.interface.get_act_joint_positions() 172 | motor_vel = self.interface.get_act_joint_velocities() 173 | motor_tau = self.interface.get_act_joint_torques() 174 | 175 | # add some Gaussian noise to observations 176 | if self.cfg.observation_noise.enabled: 177 | scales = self.cfg.observation_noise.scales 178 | level = self.cfg.observation_noise.multiplier 179 | 180 | # add some noise to observations 181 | if self.cfg.observation_noise.enabled: 182 | noise_type = self.cfg.observation_noise.type 183 | scales = self.cfg.observation_noise.scales 184 | level = self.cfg.observation_noise.multiplier 185 | if noise_type=="uniform": 186 | noise = lambda x, n : np.random.uniform(-x, x, n) 187 | elif noise_type=="gaussian": 188 | noise = lambda x, n : np.random.randn(n) * x 189 | else: 190 | raise Exception("Observation noise type can only be \"uniform\" or \"gaussian\"") 191 | root_r += noise(scales.root_orient * level, 1) 192 | root_p += noise(scales.root_orient * level, 1) 193 | root_ang_vel += noise(scales.root_ang_vel * level, len(root_ang_vel)) 194 | motor_pos += noise(scales.motor_pos * level, len(motor_pos)) 195 | motor_vel += noise(scales.motor_vel * level, len(motor_vel)) 196 | motor_tau += noise(scales.motor_tau * level, len(motor_tau)) 197 | 198 | robot_state = np.concatenate([ 199 | root_r, root_p, root_ang_vel, motor_pos, motor_vel, motor_tau, 200 | ]) 201 | 202 | state = robot_state.copy() 203 | assert state.shape==(self.base_obs_len,), \ 204 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state)) 205 | 206 | if len(self.observation_history)==0: 207 | for _ in range(self.history_len): 208 | self.observation_history.appendleft(np.zeros_like(state)) 209 | self.observation_history.appendleft(state) 210 | else: 211 | self.observation_history.appendleft(state) 212 | return np.array(self.observation_history).flatten() 213 
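    # Per-frame observation layout built by get_obs() above (base_obs_len = 35):
    #   [0:2]   root roll, pitch
    #   [2:5]   root angular velocity
    #   [5:15]  actuated joint positions (10 leg joints)
    #   [15:25] actuated joint velocities
    #   [25:35] actuated joint torques
    # Frames are pushed newest-first into a deque of length obs_history_len and
    # flattened, so the full observation has size 35 * obs_history_len.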
| 214 | def step(self, action): 215 | # Compute the applied action to actuators 216 | # (targets <- Policy predictions) 217 | # (offsets <- Half-sitting pose) 218 | 219 | # action vector assumed to be in the following order: 220 | # [leg_0, leg_1, ..., leg_n, waist_0, ..., waist_n, arm_0, arm_1, ..., arm_n] 221 | targets = self.cfg.action_smoothing * action + \ 222 | (1 - self.cfg.action_smoothing) * self.prev_prediction 223 | offsets = [ 224 | self.nominal_pose[self.interface.get_jnt_qposadr_by_name(jnt)[0]] 225 | for jnt in self.leg_names 226 | ] 227 | 228 | rewards, done = self.robot.step(targets, np.asarray(offsets)) 229 | obs = self.get_obs() 230 | 231 | if self.cfg.dynamics_randomization.enable and np.random.randint(self.dynrand_interval)==0: 232 | self.randomize_dyn() 233 | 234 | if self.cfg.perturbation.enable and np.random.randint(self.perturb_interval)==0: 235 | self.randomize_perturb() 236 | 237 | self.prev_prediction = action 238 | 239 | return obs, sum(rewards.values()), done, rewards 240 | 241 | def reset_model(self): 242 | if self.cfg.dynamics_randomization.enable: 243 | self.randomize_dyn() 244 | 245 | init_qpos, init_qvel = self.nominal_pose.copy(), [0] * self.interface.nv() 246 | 247 | # add some initialization noise to root orientation (roll, pitch) 248 | c = self.cfg.init_noise * np.pi/180 249 | root_adr = self.interface.get_jnt_qposadr_by_name('root')[0] 250 | init_qpos[root_adr+2] = np.random.uniform(1.0, 1.02) 251 | init_qpos[root_adr+3:root_adr+7] = tf3.euler.euler2quat(np.random.uniform(-c, c), np.random.uniform(-c, c), 0) 252 | init_qpos[root_adr+7:] += np.random.uniform(-c, c, len(self.leg_names)) 253 | 254 | # set up init state 255 | self.set_state( 256 | np.asarray(init_qpos), 257 | np.asarray(init_qvel) 258 | ) 259 | # do a few simulation steps to avoid big contact forces in the start 260 | for _ in range(3): 261 | self.interface.step() 262 | 263 | self.task.reset() 264 | 265 | self.prev_prediction = np.zeros_like(self.prev_prediction) 266 | self.observation_history = collections.deque(maxlen=self.history_len) 267 | obs = self.get_obs() 268 | return obs 269 | 270 | #### randomizations and other utility functions ########### 271 | def randomize_perturb(self): 272 | frc_mag = self.cfg.perturbation.force_magnitude 273 | tau_mag = self.cfg.perturbation.force_magnitude 274 | for body in self.cfg.perturbation.bodies: 275 | self.data.body(body).xfrc_applied[:3] = np.random.uniform(-frc_mag, frc_mag, 3) 276 | self.data.body(body).xfrc_applied[3:] = np.random.uniform(-tau_mag, tau_mag, 3) 277 | if np.random.randint(2)==0: 278 | self.data.xfrc_applied = np.zeros_like(self.data.xfrc_applied) 279 | 280 | def randomize_dyn(self): 281 | # dynamics randomization 282 | dofadr = [self.interface.get_jnt_qveladr_by_name(jn) 283 | for jn in self.leg_names] 284 | for jnt in dofadr: 285 | self.model.dof_frictionloss[jnt] = np.random.uniform(0, 2) # actuated joint frictionloss 286 | self.model.dof_damping[jnt] = np.random.uniform(0.02, 2) # actuated joint damping 287 | 288 | # randomize com 289 | bodies = ["pelvis"] 290 | for legjoint in self.leg_names: 291 | bodyid = self.model.joint(legjoint).bodyid 292 | bodyname = self.model.body(bodyid).name 293 | bodies.append(bodyname) 294 | 295 | for body in bodies: 296 | default_mass = self.default_model.body(body).mass[0] 297 | default_ipos = self.default_model.body(body).ipos 298 | self.model.body(body).mass[0] = default_mass*np.random.uniform(0.95, 1.05) 299 | self.model.body(body).ipos = default_ipos + np.random.uniform(-0.01, 0.01, 3) 
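        # Ranges used above: actuated-joint frictionloss ~ U(0, 2), damping ~ U(0.02, 2);
        # pelvis and leg-link masses scaled by U(0.95, 1.05) and CoM offsets (ipos)
        # shifted by up to +/- 1 cm.  Values are always drawn from `self.default_model`,
        # so the randomization does not accumulate across calls.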
300 | 301 | def viewer_setup(self): 302 | super().viewer_setup() 303 | self.viewer.cam.distance = 5 304 | self.viewer.cam.lookat[2] = 1.5 305 | self.viewer.cam.lookat[0] = 1.0 306 | -------------------------------------------------------------------------------- /envs/jvrc/__init__.py: -------------------------------------------------------------------------------- 1 | from envs.jvrc.jvrc_walk import JvrcWalkEnv 2 | from envs.jvrc.jvrc_step import JvrcStepEnv 3 | -------------------------------------------------------------------------------- /envs/jvrc/configs/base.yaml: -------------------------------------------------------------------------------- 1 | sim_dt: 0.001 2 | control_dt: 0.025 3 | obs_history_len: 1 4 | action_smoothing: 0.5 5 | 6 | kp: [200, 200, 200, 250, 80, 80, 7 | 200, 200, 200, 250, 80, 80] 8 | kd: [20, 20, 20, 25, 8, 8, 9 | 20, 20, 20, 25, 8, 8,] 10 | -------------------------------------------------------------------------------- /envs/jvrc/gen_xml.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import models 4 | from dm_control import mjcf 5 | 6 | JVRC_DESCRIPTION_PATH=os.path.join(os.path.dirname(models.__file__), "jvrc_mj_description/xml/scene.xml") 7 | 8 | WAIST_JOINTS = ['WAIST_Y', 'WAIST_P', 'WAIST_R'] 9 | HEAD_JOINTS = ['NECK_Y', 'NECK_R', 'NECK_P'] 10 | HAND_JOINTS = ['R_UTHUMB', 'R_LTHUMB', 'R_UINDEX', 'R_LINDEX', 'R_ULITTLE', 'R_LLITTLE', 11 | 'L_UTHUMB', 'L_LTHUMB', 'L_UINDEX', 'L_LINDEX', 'L_ULITTLE', 'L_LLITTLE'] 12 | ARM_JOINTS = ['R_SHOULDER_P', 'R_SHOULDER_R', 'R_SHOULDER_Y', 'R_ELBOW_P', 'R_ELBOW_Y', 'R_WRIST_R', 'R_WRIST_Y', 13 | 'L_SHOULDER_P', 'L_SHOULDER_R', 'L_SHOULDER_Y', 'L_ELBOW_P', 'L_ELBOW_Y', 'L_WRIST_R', 'L_WRIST_Y'] 14 | LEG_JOINTS = ['R_HIP_P', 'R_HIP_R', 'R_HIP_Y', 'R_KNEE', 'R_ANKLE_R', 'R_ANKLE_P', 15 | 'L_HIP_P', 'L_HIP_R', 'L_HIP_Y', 'L_KNEE', 'L_ANKLE_R', 'L_ANKLE_P'] 16 | 17 | 18 | def builder(export_path, config): 19 | print("Modifying XML model...") 20 | mjcf_model = mjcf.from_path(JVRC_DESCRIPTION_PATH) 21 | 22 | mjcf_model.model = 'jvrc' 23 | 24 | # set njmax and nconmax 25 | mjcf_model.size.njmax = -1 26 | mjcf_model.size.nconmax = -1 27 | mjcf_model.statistic.meansize = 0.1 28 | mjcf_model.statistic.meanmass = 2 29 | 30 | # modify skybox 31 | for tx in mjcf_model.asset.texture: 32 | if tx.type=="skybox": 33 | tx.rgb1 = '1 1 1' 34 | tx.rgb2 = '1 1 1' 35 | 36 | # remove all collisions 37 | mjcf_model.contact.remove() 38 | 39 | # remove actuators except for leg joints 40 | for mot in mjcf_model.actuator.motor: 41 | if mot.joint.name not in LEG_JOINTS: 42 | mot.remove() 43 | 44 | # remove unused joints 45 | for joint in WAIST_JOINTS + HEAD_JOINTS + HAND_JOINTS + ARM_JOINTS: 46 | mjcf_model.find('joint', joint).remove() 47 | 48 | # remove existing equality 49 | mjcf_model.equality.remove() 50 | 51 | # set arm joints to fixed configuration 52 | arm_bodies = { 53 | "R_SHOULDER_P_S":[0, -0.052, 0], "R_SHOULDER_R_S":[-0.17, 0, 0], "R_ELBOW_P_S":[0, -0.524, 0], 54 | "L_SHOULDER_P_S":[0, -0.052, 0], "L_SHOULDER_R_S":[ 0.17, 0, 0], "L_ELBOW_P_S":[0, -0.524, 0], 55 | } 56 | for bname, euler in arm_bodies.items(): 57 | mjcf_model.find('body', bname).euler = euler 58 | 59 | # collision geoms 60 | collision_geoms = [ 61 | 'R_HIP_R_S', 'R_HIP_Y_S', 'R_KNEE_S', 62 | 'L_HIP_R_S', 'L_HIP_Y_S', 'L_KNEE_S', 63 | ] 64 | 65 | # remove unused collision geoms 66 | for body in mjcf_model.worldbody.find_all('body'): 67 | for idx, geom in enumerate(body.geom): 68 | geom.name = 
body.name + '-geom-' + repr(idx) 69 | if (geom.dclass.dclass=="collision"): 70 | if body.name not in collision_geoms: 71 | geom.remove() 72 | 73 | # move collision geoms to different group 74 | mjcf_model.default.default['collision'].geom.group = 3 75 | 76 | # manually create collision geom for feet 77 | mjcf_model.worldbody.find('body', 'R_ANKLE_P_S').add('geom', dclass='collision', size='0.1 0.05 0.01', pos='0.029 0 -0.09778', type='box') 78 | mjcf_model.worldbody.find('body', 'L_ANKLE_P_S').add('geom', dclass='collision', size='0.1 0.05 0.01', pos='0.029 0 -0.09778', type='box') 79 | 80 | # ignore collision 81 | mjcf_model.contact.add('exclude', body1='R_KNEE_S', body2='R_ANKLE_P_S') 82 | mjcf_model.contact.add('exclude', body1='L_KNEE_S', body2='L_ANKLE_P_S') 83 | 84 | # remove unused meshes 85 | meshes = [g.mesh.name for g in mjcf_model.find_all('geom') if g.type=='mesh' or g.type==None] 86 | for mesh in mjcf_model.find_all('mesh'): 87 | if mesh.name not in meshes: 88 | mesh.remove() 89 | 90 | # fix site pos 91 | mjcf_model.worldbody.find('site', 'rf_force').pos = '0.03 0.0 -0.1' 92 | mjcf_model.worldbody.find('site', 'lf_force').pos = '0.03 0.0 -0.1' 93 | 94 | # add box geoms 95 | if 'boxes' in config and config['boxes']==True: 96 | for idx in range(20): 97 | name = 'box' + repr(idx+1).zfill(2) 98 | mjcf_model.worldbody.add('body', name=name, pos=[0, 0, -0.2]) 99 | mjcf_model.find('body', name).add('geom', 100 | name=name, 101 | dclass='collision', 102 | group='0', 103 | size='1 1 0.1', 104 | type='box', 105 | material='') 106 | 107 | # wrap floor geom in a body 108 | mjcf_model.find('geom', 'floor').remove() 109 | mjcf_model.worldbody.add('body', name='floor') 110 | mjcf_model.find('body', 'floor').add('geom', name='floor', type="plane", size="0 0 0.25", material="groundplane") 111 | 112 | # export model 113 | mjcf.export_with_assets(mjcf_model, out_dir=export_path, precision=5) 114 | path_to_xml = os.path.join(export_path, mjcf_model.model + '.xml') 115 | print("Exporting XML model to ", path_to_xml) 116 | return 117 | 118 | if __name__=='__main__': 119 | builder(sys.argv[1], config={}) 120 | 121 | -------------------------------------------------------------------------------- /envs/jvrc/jvrc_step.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import transforms3d as tf3 4 | import collections 5 | 6 | from tasks import stepping_task 7 | from robots.robot_base import RobotBase 8 | from envs.common import mujoco_env 9 | from envs.common import robot_interface 10 | from envs.common import config_builder 11 | from envs.jvrc.jvrc_walk import JvrcWalkEnv 12 | 13 | from .gen_xml import * 14 | 15 | class JvrcStepEnv(JvrcWalkEnv): 16 | def __init__(self, path_to_yaml = None): 17 | 18 | ## Load CONFIG from yaml ## 19 | if path_to_yaml is None: 20 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml') 21 | 22 | self.cfg = config_builder.load_yaml(path_to_yaml) 23 | 24 | sim_dt = self.cfg.sim_dt 25 | control_dt = self.cfg.control_dt 26 | frame_skip = (control_dt/sim_dt) 27 | 28 | self.history_len = self.cfg.obs_history_len 29 | 30 | path_to_xml = '/tmp/mjcf-export/jvrc_step/jvrc.xml' 31 | if not os.path.exists(path_to_xml): 32 | export_dir = os.path.dirname(path_to_xml) 33 | builder(export_dir, config={ 34 | "boxes": True, 35 | }) 36 | 37 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt) 38 | 39 | pdgains = np.zeros((2, 12)) 40 | pdgains[0] = self.cfg.kp 
41 | pdgains[1] = self.cfg.kd 42 | 43 | # list of desired actuators 44 | # RHIP_P, RHIP_R, RHIP_Y, RKNEE, RANKLE_R, RANKLE_P 45 | # LHIP_P, LHIP_R, LHIP_Y, LKNEE, LANKLE_R, LANKLE_P 46 | self.actuators = LEG_JOINTS 47 | 48 | # define nominal pose 49 | base_position = [0, 0, 0.81] 50 | base_orientation = [1, 0, 0, 0] 51 | half_sitting_pose = [-30, 0, 0, 50, 0, -24, 52 | -30, 0, 0, 50, 0, -24, 53 | ] # degrees 54 | self.nominal_pose = base_position + base_orientation + np.deg2rad(half_sitting_pose).tolist() 55 | 56 | # set up interface 57 | self.interface = robot_interface.RobotInterface(self.model, self.data, 'R_ANKLE_P_S', 'L_ANKLE_P_S', None) 58 | 59 | # set up task 60 | self.task = stepping_task.SteppingTask(client=self.interface, 61 | dt=control_dt, 62 | neutral_foot_orient=np.array([1, 0, 0, 0]), 63 | root_body='PELVIS_S', 64 | lfoot_body='L_ANKLE_P_S', 65 | rfoot_body='R_ANKLE_P_S', 66 | head_body='NECK_P_S', 67 | ) 68 | # set goal height 69 | self.task._goal_height_ref = 0.80 70 | self.task._total_duration = 1.1 71 | self.task._swing_duration = 0.75 72 | self.task._stance_duration = 0.35 73 | 74 | # set up robot 75 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task) 76 | 77 | # define indices for action and obs mirror fns 78 | base_mir_obs = [-0.1, 1, # root orient 79 | -2, 3, -4, # root ang vel 80 | 11, -12, -13, 14, -15, 16, # motor pos [1] 81 | 5, -6, -7, 8, -9, 10, # motor pos [2] 82 | 23, -24, -25, 26, -27, 28, # motor vel [1] 83 | 17, -18, -19, 20, -21, 22, # motor vel [2] 84 | ] 85 | append_obs = [(len(base_mir_obs)+i) for i in range(10)] 86 | self.robot.clock_inds = append_obs[0:2] 87 | self.robot.mirrored_obs = np.array(base_mir_obs + append_obs, copy=True).tolist() 88 | self.robot.mirrored_acts = [6, -7, -8, 9, -10, 11, 89 | 0.1, -1, -2, 3, -4, 5,] 90 | 91 | # set action space 92 | action_space_size = len(self.actuators) 93 | action = np.zeros(action_space_size) 94 | self.action_space = np.zeros(action_space_size) 95 | self.prev_prediction = np.zeros(action_space_size) 96 | 97 | # set observation space 98 | self.base_obs_len = 39 99 | self.observation_history = collections.deque(maxlen=self.history_len) 100 | self.observation_space = np.zeros(self.base_obs_len*self.history_len) 101 | 102 | def get_obs(self): 103 | # external state 104 | clock = [np.sin(2 * np.pi * self.task._phase / self.task._period), 105 | np.cos(2 * np.pi * self.task._phase / self.task._period)] 106 | ext_state = np.concatenate((clock, 107 | np.asarray(self.task._goal_steps_x).flatten(), 108 | np.asarray(self.task._goal_steps_y).flatten(), 109 | np.asarray(self.task._goal_steps_z).flatten(), 110 | np.asarray(self.task._goal_steps_theta).flatten())) 111 | 112 | # internal state 113 | qpos = np.copy(self.interface.get_qpos()) 114 | qvel = np.copy(self.interface.get_qvel()) 115 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2] 116 | root_r = np.array([root_r]) 117 | root_p = np.array([root_p]) 118 | root_ang_vel = qvel[3:6] 119 | motor_pos = self.interface.get_act_joint_positions() 120 | motor_vel = self.interface.get_act_joint_velocities() 121 | 122 | robot_state = np.concatenate([ 123 | root_r, root_p, root_ang_vel, motor_pos, motor_vel, 124 | ]) 125 | 126 | state = np.concatenate([robot_state, ext_state]) 127 | assert state.shape==(self.base_obs_len,), \ 128 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state)) 129 | 130 | if len(self.observation_history)==0: 131 | for _ in range(self.history_len): 132 | 
self.observation_history.appendleft(np.zeros_like(state)) 133 | self.observation_history.appendleft(state) 134 | else: 135 | self.observation_history.appendleft(state) 136 | return np.array(self.observation_history).flatten() 137 | -------------------------------------------------------------------------------- /envs/jvrc/jvrc_walk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import transforms3d as tf3 4 | import collections 5 | 6 | from tasks import walking_task 7 | from robots.robot_base import RobotBase 8 | from envs.common import mujoco_env 9 | from envs.common import robot_interface 10 | from envs.common import config_builder 11 | 12 | from .gen_xml import * 13 | 14 | class JvrcWalkEnv(mujoco_env.MujocoEnv): 15 | def __init__(self, path_to_yaml = None): 16 | 17 | ## Load CONFIG from yaml ## 18 | if path_to_yaml is None: 19 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml') 20 | 21 | self.cfg = config_builder.load_yaml(path_to_yaml) 22 | 23 | sim_dt = self.cfg.sim_dt 24 | control_dt = self.cfg.control_dt 25 | frame_skip = (control_dt/sim_dt) 26 | 27 | self.history_len = self.cfg.obs_history_len 28 | 29 | path_to_xml = '/tmp/mjcf-export/jvrc_walk/jvrc.xml' 30 | if not os.path.exists(path_to_xml): 31 | export_dir = os.path.dirname(path_to_xml) 32 | builder(export_dir, config={ 33 | }) 34 | 35 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt) 36 | 37 | pdgains = np.zeros((2, 12)) 38 | pdgains[0] = self.cfg.kp 39 | pdgains[1] = self.cfg.kd 40 | 41 | # list of desired actuators 42 | # RHIP_P, RHIP_R, RHIP_Y, RKNEE, RANKLE_R, RANKLE_P 43 | # LHIP_P, LHIP_R, LHIP_Y, LKNEE, LANKLE_R, LANKLE_P 44 | self.actuators = LEG_JOINTS 45 | 46 | # define nominal pose 47 | base_position = [0, 0, 0.81] 48 | base_orientation = [1, 0, 0, 0] 49 | half_sitting_pose = [-30, 0, 0, 50, 0, -24, 50 | -30, 0, 0, 50, 0, -24, 51 | ] # degrees 52 | self.nominal_pose = base_position + base_orientation + np.deg2rad(half_sitting_pose).tolist() 53 | 54 | # set up interface 55 | self.interface = robot_interface.RobotInterface(self.model, self.data, 'R_ANKLE_P_S', 'L_ANKLE_P_S', None) 56 | 57 | # set up task 58 | self.task = walking_task.WalkingTask(client=self.interface, 59 | dt=control_dt, 60 | neutral_foot_orient=np.array([1, 0, 0, 0]), 61 | root_body='PELVIS_S', 62 | lfoot_body='L_ANKLE_P_S', 63 | rfoot_body='R_ANKLE_P_S', 64 | ) 65 | # set goal height 66 | self.task._goal_height_ref = 0.80 67 | self.task._total_duration = 1.1 68 | self.task._swing_duration = 0.75 69 | self.task._stance_duration = 0.35 70 | 71 | # set up robot 72 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task) 73 | 74 | # define indices for action and obs mirror fns 75 | base_mir_obs = [-0.1, 1, # root orient 76 | -2, 3, -4, # root ang vel 77 | 11, -12, -13, 14, -15, 16, # motor pos [1] 78 | 5, -6, -7, 8, -9, 10, # motor pos [2] 79 | 23, -24, -25, 26, -27, 28, # motor vel [1] 80 | 17, -18, -19, 20, -21, 22, # motor vel [2] 81 | ] 82 | append_obs = [(len(base_mir_obs)+i) for i in range(3)] 83 | self.robot.clock_inds = append_obs[0:2] 84 | self.robot.mirrored_obs = np.array(base_mir_obs + append_obs, copy=True).tolist() 85 | self.robot.mirrored_acts = [6, -7, -8, 9, -10, 11, 86 | 0.1, -1, -2, 3, -4, 5,] 87 | 88 | # set action space 89 | action_space_size = len(self.actuators) 90 | action = np.zeros(action_space_size) 91 | self.action_space = np.zeros(action_space_size) 92 | 
self.prev_prediction = np.zeros(action_space_size) 93 | 94 | # set observation space 95 | self.base_obs_len = 32 96 | self.observation_history = collections.deque(maxlen=self.history_len) 97 | self.observation_space = np.zeros(self.base_obs_len*self.history_len) 98 | 99 | # manually define observation mean and std 100 | self.obs_mean = np.concatenate(( 101 | np.zeros(5), 102 | np.deg2rad(half_sitting_pose), np.zeros(12), 103 | [0.5, 0.5, 0.5] 104 | )) 105 | 106 | self.obs_std = np.concatenate(( 107 | [0.2, 0.2, 1, 1, 1], 108 | 0.5*np.ones(12), 4*np.ones(12), 109 | [1, 1, 1,] 110 | )) 111 | 112 | self.obs_mean = np.tile(self.obs_mean, self.history_len) 113 | self.obs_std = np.tile(self.obs_std, self.history_len) 114 | 115 | def get_obs(self): 116 | # external state 117 | clock = [np.sin(2 * np.pi * self.task._phase / self.task._period), 118 | np.cos(2 * np.pi * self.task._phase / self.task._period)] 119 | ext_state = np.concatenate((clock, [self.task._goal_speed_ref])) 120 | 121 | # internal state 122 | qpos = np.copy(self.interface.get_qpos()) 123 | qvel = np.copy(self.interface.get_qvel()) 124 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2] 125 | root_r = np.array([root_r]) 126 | root_p = np.array([root_p]) 127 | root_ang_vel = qvel[3:6] 128 | motor_pos = self.interface.get_act_joint_positions() 129 | motor_vel = self.interface.get_act_joint_velocities() 130 | 131 | robot_state = np.concatenate([ 132 | root_r, root_p, root_ang_vel, motor_pos, motor_vel, 133 | ]) 134 | 135 | state = np.concatenate([robot_state, ext_state]) 136 | assert state.shape==(self.base_obs_len,), \ 137 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state)) 138 | 139 | if len(self.observation_history)==0: 140 | for _ in range(self.history_len): 141 | self.observation_history.appendleft(np.zeros_like(state)) 142 | self.observation_history.appendleft(state) 143 | else: 144 | self.observation_history.appendleft(state) 145 | return np.array(self.observation_history).flatten() 146 | 147 | def step(self, action): 148 | # Compute the applied action to actuators 149 | # (targets <- Policy predictions) 150 | targets = self.cfg.action_smoothing * action + \ 151 | (1 - self.cfg.action_smoothing) * self.prev_prediction 152 | # (offsets <- Half-sitting pose) 153 | offsets = [ 154 | self.nominal_pose[self.interface.get_jnt_qposadr_by_name(jnt)[0]] 155 | for jnt in self.actuators 156 | ] 157 | 158 | rewards, done = self.robot.step(targets, np.asarray(offsets)) 159 | obs = self.get_obs() 160 | 161 | self.prev_prediction = action 162 | 163 | return obs, sum(rewards.values()), done, rewards 164 | 165 | def reset_model(self): 166 | init_qpos, init_qvel = self.nominal_pose.copy(), [0] * self.interface.nv() 167 | 168 | # set up init state 169 | self.set_state( 170 | np.asarray(init_qpos), 171 | np.asarray(init_qvel) 172 | ) 173 | 174 | self.task.reset(iter_count=self.robot.iteration_count) 175 | 176 | self.prev_prediction = np.zeros_like(self.prev_prediction) 177 | self.observation_history = collections.deque(maxlen=self.history_len) 178 | obs = self.get_obs() 179 | return obs 180 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/models/__init__.py -------------------------------------------------------------------------------- /rl/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/rl/__init__.py -------------------------------------------------------------------------------- /rl/algos/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rl/algos/ppo.py: -------------------------------------------------------------------------------- 1 | """Proximal Policy Optimization (clip objective).""" 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 7 | from torch.nn.utils.rnn import pad_sequence 8 | from torch.nn import functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from pathlib import Path 12 | import sys 13 | import time 14 | import numpy as np 15 | import datetime 16 | 17 | import ray 18 | 19 | from rl.storage.rollout_storage import PPOBuffer 20 | from rl.policies.actor import Gaussian_FF_Actor, Gaussian_LSTM_Actor 21 | from rl.policies.critic import FF_V, LSTM_V 22 | from rl.envs.normalize import get_normalization_params 23 | 24 | class PPO: 25 | def __init__(self, env_fn, args): 26 | self.gamma = args.gamma 27 | self.lam = args.lam 28 | self.lr = args.lr 29 | self.eps = args.eps 30 | self.ent_coeff = args.entropy_coeff 31 | self.clip = args.clip 32 | self.minibatch_size = args.minibatch_size 33 | self.epochs = args.epochs 34 | self.max_traj_len = args.max_traj_len 35 | self.use_gae = args.use_gae 36 | self.n_proc = args.num_procs 37 | self.grad_clip = args.max_grad_norm 38 | self.mirror_coeff = args.mirror_coeff 39 | self.eval_freq = args.eval_freq 40 | self.recurrent = args.recurrent 41 | self.imitate_coeff = args.imitate_coeff 42 | 43 | # batch_size depends on number of parallel envs 44 | self.batch_size = self.n_proc * self.max_traj_len 45 | 46 | self.total_steps = 0 47 | self.highest_reward = -np.inf 48 | 49 | # counter for training iterations 50 | self.iteration_count = 0 51 | 52 | # directory logging and saving weights 53 | self.save_path = Path(args.logdir) 54 | Path.mkdir(self.save_path, parents=True, exist_ok=True) 55 | 56 | # create the summarywriter 57 | self.writer = SummaryWriter(log_dir=self.save_path, flush_secs=10) 58 | 59 | # create networks or load up pretrained 60 | obs_dim = env_fn().observation_space.shape[0] 61 | action_dim = env_fn().action_space.shape[0] 62 | if args.continued: 63 | path_to_actor = args.continued 64 | path_to_critic = Path(args.continued.parent, "critic" + str(args.continued).split('actor')[1]) 65 | policy = torch.load(path_to_actor, weights_only=False) 66 | critic = torch.load(path_to_critic, weights_only=False) 67 | # policy action noise parameters are initialized from scratch and not loaded 68 | if args.learn_std: 69 | policy.stds = torch.nn.Parameter(args.std_dev * torch.ones(action_dim)) 70 | else: 71 | policy.stds = args.std_dev * torch.ones(action_dim) 72 | print("Loaded (pre-trained) actor from: ", path_to_actor) 73 | print("Loaded (pre-trained) critic from: ", path_to_critic) 74 | else: 75 | if args.recurrent: 76 | policy = Gaussian_LSTM_Actor(obs_dim, action_dim, init_std=args.std_dev, 77 | learn_std=args.learn_std) 78 | critic = LSTM_V(obs_dim) 79 | else: 80 | policy = Gaussian_FF_Actor(obs_dim, action_dim, init_std=args.std_dev, 81 | learn_std=args.learn_std, 
bounded=False) 82 | critic = FF_V(obs_dim) 83 | 84 | if hasattr(env_fn(), 'obs_mean') and hasattr(env_fn(), 'obs_std'): 85 | obs_mean, obs_std = env_fn().obs_mean, env_fn().obs_std 86 | else: 87 | obs_mean, obs_std = get_normalization_params(iter=args.input_norm_steps, 88 | noise_std=1, 89 | policy=policy, 90 | env_fn=env_fn, 91 | procs=args.num_procs) 92 | with torch.no_grad(): 93 | policy.obs_mean, policy.obs_std = map(torch.Tensor, (obs_mean, obs_std)) 94 | critic.obs_mean = policy.obs_mean 95 | critic.obs_std = policy.obs_std 96 | 97 | base_policy = None 98 | if args.imitate: 99 | base_policy = torch.load(args.imitate, weights_only=False) 100 | 101 | self.old_policy = deepcopy(policy) 102 | self.policy = policy 103 | self.critic = critic 104 | self.base_policy = base_policy 105 | 106 | @staticmethod 107 | def save(nets, save_path, suffix=""): 108 | filetype = ".pt" 109 | for name, net in nets.items(): 110 | path = Path(save_path, name + suffix + filetype) 111 | torch.save(net, path) 112 | print("Saved {} at {}".format(name, path)) 113 | return 114 | 115 | @ray.remote 116 | @torch.no_grad() 117 | @staticmethod 118 | def sample(env_fn, policy, critic, gamma, lam, iteration_count, max_steps, max_traj_len, deterministic): 119 | """ 120 | Sample max_steps number of total timesteps, truncating 121 | trajectories if they exceed max_traj_len number of timesteps. 122 | """ 123 | env = env_fn() 124 | env.robot.iteration_count = iteration_count 125 | 126 | memory = PPOBuffer(policy.state_dim, policy.action_dim, gamma, lam, size=max_traj_len*2) 127 | memory_full = False 128 | 129 | while not memory_full: 130 | state = torch.tensor(env.reset(), dtype=torch.float) 131 | done = False 132 | traj_len = 0 133 | 134 | if hasattr(policy, 'init_hidden_state'): 135 | policy.init_hidden_state() 136 | 137 | if hasattr(critic, 'init_hidden_state'): 138 | critic.init_hidden_state() 139 | 140 | while not done and traj_len < max_traj_len: 141 | action = policy(state, deterministic=deterministic) 142 | value = critic(state) 143 | 144 | next_state, reward, done, _ = env.step(action.numpy().copy()) 145 | 146 | reward = torch.tensor(reward, dtype=torch.float) 147 | memory.store(state, action, reward, value, done) 148 | memory_full = (len(memory) >= max_steps) 149 | 150 | state = torch.tensor(next_state, dtype=torch.float) 151 | traj_len += 1 152 | 153 | #if memory_full: 154 | # break 155 | 156 | value = critic(state) 157 | memory.finish_path(last_val=(not done) * value) 158 | 159 | return memory.get_data() 160 | 161 | def sample_parallel(self, *args, deterministic=False): 162 | 163 | max_steps = (self.batch_size // self.n_proc) 164 | worker_args = (self.gamma, self.lam, self.iteration_count, max_steps, self.max_traj_len, deterministic) 165 | args = args + worker_args 166 | 167 | # Create pool of workers, each getting data for min_steps 168 | worker = self.sample 169 | workers = [worker.remote(*args) for _ in range(self.n_proc)] 170 | result = ray.get(workers) 171 | 172 | # Aggregate results 173 | keys = result[0].keys() 174 | aggregated_data = { 175 | k: torch.cat([r[k] for r in result]) for k in keys 176 | } 177 | 178 | class Data: 179 | def __init__(self, data): 180 | for key, value in data.items(): 181 | setattr(self, key, value) 182 | data = Data(aggregated_data) 183 | return data 184 | 185 | def update_actor_critic(self, obs_batch, action_batch, return_batch, advantage_batch, mask, mirror_observation=None, mirror_action=None): 186 | 187 | pdf = self.policy.distribution(obs_batch) 188 | log_probs = 
pdf.log_prob(action_batch).sum(-1, keepdim=True) 189 | 190 | old_pdf = self.old_policy.distribution(obs_batch) 191 | old_log_probs = old_pdf.log_prob(action_batch).sum(-1, keepdim=True) 192 | 193 | # ratio between old and new policy, should be one at the first iteration 194 | ratio = (log_probs - old_log_probs).exp() 195 | 196 | # clipped surrogate loss 197 | cpi_loss = ratio * advantage_batch * mask 198 | clip_loss = ratio.clamp(1.0 - self.clip, 1.0 + self.clip) * advantage_batch * mask 199 | actor_loss = -torch.min(cpi_loss, clip_loss).mean() 200 | 201 | # only used for logging 202 | clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip).float()).item() 203 | 204 | # Value loss using the TD(gae_lambda) target 205 | values = self.critic(obs_batch) 206 | critic_loss = F.mse_loss(return_batch, values) 207 | 208 | # Entropy loss favor exploration 209 | entropy_penalty = -(pdf.entropy() * mask).mean() 210 | 211 | # Mirror Symmetry Loss 212 | deterministic_actions = self.policy(obs_batch) 213 | if mirror_observation is not None and mirror_action is not None: 214 | if self.recurrent: 215 | mir_obs = torch.stack([mirror_observation(obs_batch[i,:,:]) for i in range(obs_batch.shape[0])]) 216 | mirror_actions = self.policy(mir_obs) 217 | else: 218 | mir_obs = mirror_observation(obs_batch) 219 | mirror_actions = self.policy(mir_obs) 220 | mirror_actions = mirror_action(mirror_actions) 221 | mirror_loss = (deterministic_actions - mirror_actions).pow(2).mean() 222 | else: 223 | mirror_loss = torch.zeros_like(actor_loss) 224 | 225 | # imitation loss 226 | if self.base_policy is not None: 227 | imitation_loss = (self.base_policy(obs_batch) - deterministic_actions).pow(2).mean() 228 | else: 229 | imitation_loss = torch.zeros_like(actor_loss) 230 | 231 | # Calculate approximate form of reverse KL Divergence for early stopping 232 | # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 233 | # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 234 | # and Schulman blog: http://joschu.net/blog/kl-approx.html 235 | with torch.no_grad(): 236 | log_ratio = log_probs - old_log_probs 237 | approx_kl_div = torch.mean((ratio - 1) - log_ratio) 238 | 239 | self.actor_optimizer.zero_grad() 240 | (actor_loss + self.mirror_coeff*mirror_loss + self.imitate_coeff*imitation_loss + self.ent_coeff*entropy_penalty).backward(retain_graph=True) 241 | 242 | # Clip the gradient norm to prevent "unlucky" minibatches from 243 | # causing pathological updates 244 | torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip) 245 | self.actor_optimizer.step() 246 | 247 | self.critic_optimizer.zero_grad() 248 | critic_loss.backward(retain_graph=True) 249 | 250 | # Clip the gradient norm to prevent "unlucky" minibatches from 251 | # causing pathological updates 252 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_clip) 253 | self.critic_optimizer.step() 254 | 255 | return ( 256 | actor_loss, 257 | entropy_penalty, 258 | critic_loss, 259 | approx_kl_div, 260 | mirror_loss, 261 | imitation_loss, 262 | clip_fraction, 263 | ) 264 | 265 | def evaluate(self, env_fn, nets, itr, num_batches=5): 266 | # set all nets to .eval() mode 267 | for net in nets.values(): 268 | net.eval() 269 | 270 | # collect some batches of data 271 | eval_batches = [] 272 | for _ in range(num_batches): 273 | batch = self.sample_parallel(env_fn, *nets.values(), deterministic=True) 274 | eval_batches.append(batch) 275 | 276 | # save all the networks 277 | self.save(nets, 
self.save_path, "_" + repr(itr)) 278 | 279 | # save as actor.pt, if it is best 280 | eval_ep_rewards = [float(i) for i in batch.ep_rewards for batch in eval_batches] 281 | avg_eval_ep_rewards = np.mean(eval_ep_rewards) 282 | if self.highest_reward < avg_eval_ep_rewards: 283 | self.highest_reward = avg_eval_ep_rewards 284 | self.save(nets, self.save_path) 285 | 286 | return eval_batches 287 | 288 | def train(self, env_fn, n_itr): 289 | 290 | self.actor_optimizer = optim.Adam(self.policy.parameters(), lr=self.lr, eps=self.eps) 291 | self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.lr, eps=self.eps) 292 | 293 | train_start_time = time.time() 294 | 295 | obs_mirr, act_mirr = None, None 296 | if hasattr(env_fn(), 'mirror_observation'): 297 | obs_mirr = env_fn().mirror_clock_observation 298 | 299 | if hasattr(env_fn(), 'mirror_action'): 300 | act_mirr = env_fn().mirror_action 301 | 302 | for itr in range(n_itr): 303 | print("********** Iteration {} ************".format(itr)) 304 | 305 | self.policy.train() 306 | self.critic.train() 307 | 308 | # set iteration count (could be used for curriculum training) 309 | self.iteration_count = itr 310 | 311 | sample_start_time = time.time() 312 | policy_ref = ray.put(self.policy) 313 | critic_ref = ray.put(self.critic) 314 | batch = self.sample_parallel(env_fn, policy_ref, critic_ref) 315 | observations = batch.states.float() 316 | actions = batch.actions.float() 317 | returns = batch.returns.float() 318 | values = batch.values.float() 319 | 320 | num_samples = len(observations) 321 | elapsed = time.time() - sample_start_time 322 | print("Sampling took {:.2f}s for {} steps.".format(elapsed, num_samples)) 323 | 324 | # Normalize advantage 325 | advantages = returns - values 326 | advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps) 327 | 328 | minibatch_size = self.minibatch_size or num_samples 329 | self.total_steps += num_samples 330 | 331 | self.old_policy.load_state_dict(self.policy.state_dict()) 332 | 333 | optimizer_start_time = time.time() 334 | 335 | actor_losses = [] 336 | entropies = [] 337 | critic_losses = [] 338 | kls = [] 339 | mirror_losses = [] 340 | imitation_losses = [] 341 | clip_fractions = [] 342 | for epoch in range(self.epochs): 343 | if self.recurrent: 344 | random_indices = SubsetRandomSampler(range(len(batch.traj_idx)-1)) 345 | sampler = BatchSampler(random_indices, minibatch_size, drop_last=False) 346 | else: 347 | random_indices = SubsetRandomSampler(range(num_samples)) 348 | sampler = BatchSampler(random_indices, minibatch_size, drop_last=True) 349 | 350 | for indices in sampler: 351 | if self.recurrent: 352 | obs_batch = [observations[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices] 353 | action_batch = [actions[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices] 354 | return_batch = [returns[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices] 355 | advantage_batch = [advantages[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices] 356 | mask = [torch.ones_like(r) for r in return_batch] 357 | 358 | obs_batch = pad_sequence(obs_batch, batch_first=False) 359 | action_batch = pad_sequence(action_batch, batch_first=False) 360 | return_batch = pad_sequence(return_batch, batch_first=False) 361 | advantage_batch = pad_sequence(advantage_batch, batch_first=False) 362 | mask = pad_sequence(mask, batch_first=False) 363 | else: 364 | obs_batch = observations[indices] 365 | action_batch = actions[indices] 366 | return_batch = 
returns[indices] 367 | advantage_batch = advantages[indices] 368 | mask = 1 369 | 370 | scalars = self.update_actor_critic(obs_batch, action_batch, return_batch, advantage_batch, mask, mirror_observation=obs_mirr, mirror_action=act_mirr) 371 | actor_loss, entropy_penalty, critic_loss, approx_kl_div, mirror_loss, imitation_loss, clip_fraction = scalars 372 | 373 | actor_losses.append(actor_loss.item()) 374 | entropies.append(entropy_penalty.item()) 375 | critic_losses.append(critic_loss.item()) 376 | kls.append(approx_kl_div.item()) 377 | mirror_losses.append(mirror_loss.item()) 378 | imitation_losses.append(imitation_loss.item()) 379 | clip_fractions.append(clip_fraction) 380 | 381 | elapsed = time.time() - optimizer_start_time 382 | print("Optimizer took: {:.2f}s".format(elapsed)) 383 | 384 | action_noise = self.policy.stds.data.tolist() 385 | 386 | sys.stdout.write("-" * 37 + "\n") 387 | sys.stdout.write("| %15s | %15s |" % ('Mean Eprew', "%8.5g" % torch.mean(batch.ep_rewards)) + "\n") 388 | sys.stdout.write("| %15s | %15s |" % ('Mean Eplen', "%8.5g" % torch.mean(batch.ep_lens)) + "\n") 389 | sys.stdout.write("| %15s | %15s |" % ('Actor loss', "%8.3g" % np.mean(actor_losses)) + "\n") 390 | sys.stdout.write("| %15s | %15s |" % ('Critic loss', "%8.3g" % np.mean(critic_losses)) + "\n") 391 | sys.stdout.write("| %15s | %15s |" % ('Mirror loss', "%8.3g" % np.mean(mirror_losses)) + "\n") 392 | sys.stdout.write("| %15s | %15s |" % ('Imitation loss', "%8.3g" % np.mean(imitation_losses)) + "\n") 393 | sys.stdout.write("| %15s | %15s |" % ('Mean KL Div', "%8.3g" % np.mean(kls)) + "\n") 394 | sys.stdout.write("| %15s | %15s |" % ('Mean Entropy', "%8.3g" % np.mean(entropies)) + "\n") 395 | sys.stdout.write("| %15s | %15s |" % ('Clip Fraction', "%8.3g" % np.mean(clip_fractions)) + "\n") 396 | sys.stdout.write("| %15s | %15s |" % ('Mean noise std', "%8.3g" % np.mean(action_noise)) + "\n") 397 | sys.stdout.write("-" * 37 + "\n") 398 | sys.stdout.flush() 399 | 400 | elapsed = time.time() - train_start_time 401 | iter_avg = elapsed/(itr+1) 402 | ETA = round((n_itr - itr)*iter_avg) 403 | print("Total time elapsed: {:.2f}s. Total steps: {} (fps={:.2f}. iter-avg={:.2f}s. ETA={})".format( 404 | elapsed, self.total_steps, self.total_steps/elapsed, iter_avg, datetime.timedelta(seconds=ETA))) 405 | 406 | # To save time, perform evaluation only after 100 iters 407 | if itr==0 or (itr+1)%self.eval_freq==0: 408 | nets = {"actor": self.policy, "critic": self.critic} 409 | 410 | evaluate_start = time.time() 411 | eval_batches = self.evaluate(env_fn, nets, itr) 412 | eval_time = time.time() - evaluate_start 413 | 414 | eval_ep_lens = [float(i) for b in eval_batches for i in b.ep_lens] 415 | eval_ep_rewards = [float(i) for b in eval_batches for i in b.ep_rewards] 416 | avg_eval_ep_lens = np.mean(eval_ep_lens) 417 | avg_eval_ep_rewards = np.mean(eval_ep_rewards) 418 | print("====EVALUATE EPISODE====") 419 | print("(Episode length:{:.3f}. Reward:{:.3f}. 
Time taken:{:.2f}s)".format( 420 | avg_eval_ep_lens, avg_eval_ep_rewards, eval_time)) 421 | 422 | # tensorboard logging 423 | self.writer.add_scalar("Eval/mean_reward", avg_eval_ep_rewards, itr) 424 | self.writer.add_scalar("Eval/mean_episode_length", avg_eval_ep_lens, itr) 425 | 426 | # tensorboard logging 427 | self.writer.add_scalar("Loss/actor", np.mean(actor_losses), itr) 428 | self.writer.add_scalar("Loss/critic", np.mean(critic_losses), itr) 429 | self.writer.add_scalar("Loss/mirror", np.mean(mirror_losses), itr) 430 | self.writer.add_scalar("Loss/imitation", np.mean(imitation_losses), itr) 431 | self.writer.add_scalar("Train/mean_reward", torch.mean(batch.ep_rewards), itr) 432 | self.writer.add_scalar("Train/mean_episode_length", torch.mean(batch.ep_lens), itr) 433 | self.writer.add_scalar("Train/mean_noise_std", np.mean(action_noise), itr) 434 | -------------------------------------------------------------------------------- /rl/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian import DiagonalGaussian 2 | from .beta import Beta, Beta2 -------------------------------------------------------------------------------- /rl/distributions/beta.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # TODO: extend these for arbitrary bounds 8 | 9 | """A beta distribution, but where the pdf is scaled to (-1, 1)""" 10 | class BoundedBeta(torch.distributions.Beta): 11 | def log_prob(self, x): 12 | return super().log_prob((x + 1) / 2) 13 | 14 | class Beta(nn.Module): 15 | def __init__(self, action_dim): 16 | super(Beta, self).__init__() 17 | 18 | self.action_dim = action_dim 19 | 20 | def forward(self, alpha_beta): 21 | alpha = 1 + F.softplus(alpha_beta[:, :self.action_dim]) 22 | beta = 1 + F.softplus(alpha_beta[:, self.action_dim:]) 23 | return alpha, beta 24 | 25 | def sample(self, x, deterministic): 26 | if deterministic is False: 27 | action = self.evaluate(x).sample() 28 | else: 29 | # E = alpha / (alpha + beta) 30 | return self.evaluate(x).mean 31 | 32 | return 2 * action - 1 33 | 34 | def evaluate(self, x): 35 | alpha, beta = self(x) 36 | return BoundedBeta(alpha, beta) 37 | 38 | 39 | # TODO: think of a better name for this 40 | """Beta distribution parameterized by mean and variance.""" 41 | class Beta2(nn.Module): 42 | def __init__(self, action_dim, init_std=0.25, learn_std=False): 43 | super(Beta2, self).__init__() 44 | 45 | assert init_std < 0.5, "Beta distribution has a max std dev of 0.5" 46 | 47 | self.action_dim = action_dim 48 | 49 | self.logstd = nn.Parameter( 50 | torch.ones(1, action_dim) * np.log(init_std), 51 | requires_grad=learn_std 52 | ) 53 | 54 | self.learn_std = learn_std 55 | 56 | 57 | def forward(self, x): 58 | mean = torch.sigmoid(x) 59 | 60 | var = self.logstd.exp().pow(2) 61 | 62 | """ 63 | alpha = ((1 - mu) / sigma^2 - 1 / mu) * mu^2 64 | beta = alpha * (1 / mu - 1) 65 | 66 | Implemented slightly differently for numerical stability. 
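Expanding these gives the equivalent form used in the code below, which avoids
the explicit division by mu:

    alpha = (1 - mu) * mu^2 / sigma^2 - mu
    beta  = (1 - mu) * mu  / sigma^2 - 1 - alpha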
67 | """ 68 | alpha = ((1 - mean) / var) * mean.pow(2) - mean 69 | beta = ((1 - mean) / var) * mean - 1 - alpha 70 | 71 | # PROBLEM: if alpha or beta < 1 thats not good 72 | 73 | #assert np.allclose(alpha, ((1 - mean) / var - 1 / mean) * mean.pow(2)) 74 | #assert np.allclose(beta, ((1 - mean) / var - 1 / mean) * mean.pow(2) * (1 / mean - 1)) 75 | 76 | #alpha = 1 + F.softplus(alpha) 77 | #beta = 1 + F.softplus(beta) 78 | 79 | # print("alpha",alpha) 80 | # print("beta",beta) 81 | 82 | # #print(alpha / (alpha + beta)) 83 | # print("mu",mean) 84 | 85 | # #print(torch.sqrt(alpha * beta / ((alpha+beta)**2 * (alpha + beta + 1)))) 86 | # print("var", var) 87 | 88 | # import pdb 89 | # pdb.set_trace() 90 | 91 | return alpha, beta 92 | 93 | def sample(self, x, deterministic): 94 | if deterministic is False: 95 | action = self.evaluate(x).sample() 96 | else: 97 | # E = alpha / (alpha + beta) 98 | return self.evaluate(x).mean 99 | 100 | # 2 * a - 1 puts a in (-1, 1) 101 | return 2 * action - 1 102 | 103 | def evaluate(self, x): 104 | alpha, beta = self(x) 105 | return BoundedBeta(alpha, beta) -------------------------------------------------------------------------------- /rl/distributions/gaussian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | # TODO: look at change of variables function for enforcing 8 | # action bounds correctly 9 | class DiagonalGaussian(nn.Module): 10 | def __init__(self, num_outputs, init_std=1, learn_std=True): 11 | super(DiagonalGaussian, self).__init__() 12 | 13 | self.logstd = nn.Parameter( 14 | torch.ones(1, num_outputs) * np.log(init_std), 15 | requires_grad=learn_std 16 | ) 17 | 18 | self.learn_std = learn_std 19 | 20 | def forward(self, x): 21 | mean = x 22 | 23 | std = self.logstd.exp() 24 | 25 | return mean, std 26 | 27 | def sample(self, x, deterministic): 28 | if deterministic is False: 29 | action = self.evaluate(x).sample() 30 | else: 31 | action, _ = self(x) 32 | 33 | return action 34 | 35 | def evaluate(self, x): 36 | mean, std = self(x) 37 | return torch.distributions.Normal(mean, std) 38 | -------------------------------------------------------------------------------- /rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize import * 2 | from .wrappers import * 3 | -------------------------------------------------------------------------------- /rl/envs/normalize.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py 2 | # Thanks to the authors + OpenAI for the code 3 | 4 | import numpy as np 5 | import functools 6 | import torch 7 | import ray 8 | 9 | from .wrappers import WrapEnv 10 | 11 | @ray.remote 12 | def _run_random_actions(iter, policy, env_fn, noise_std): 13 | 14 | env = WrapEnv(env_fn) 15 | states = np.zeros((iter, env.observation_space.shape[0])) 16 | 17 | state = env.reset() 18 | for t in range(iter): 19 | states[t, :] = state 20 | 21 | state = torch.Tensor(state) 22 | 23 | action = policy(state) 24 | 25 | # add gaussian noise to deterministic action 26 | action = action + torch.randn(action.size()) * noise_std 27 | 28 | state, _, done, _ = env.step(action.data.numpy()) 29 | 30 | if done: 31 | state = env.reset() 32 | 33 | return states 34 | 35 | def get_normalization_params(iter, policy, 
env_fn, noise_std, procs=4): 36 | print("Gathering input normalization data using {0} steps, noise = {1}...".format(iter, noise_std)) 37 | 38 | states_ids = [_run_random_actions.remote(iter // procs, policy, env_fn, noise_std) for _ in range(procs)] 39 | 40 | states = [] 41 | for _ in range(procs): 42 | ready_ids, _ = ray.wait(states_ids, num_returns=1) 43 | states.extend(ray.get(ready_ids[0])) 44 | states_ids.remove(ready_ids[0]) 45 | 46 | print("Done gathering input normalization data.") 47 | 48 | return np.mean(states, axis=0), np.sqrt(np.var(states, axis=0) + 1e-8) 49 | 50 | 51 | # returns a function that creates a normalized environment, then pre-normalizes it 52 | # using states sampled from a deterministic policy with some added noise 53 | def PreNormalizer(iter, noise_std, policy, *args, **kwargs): 54 | 55 | # noise is gaussian noise 56 | @torch.no_grad() 57 | def pre_normalize(env, policy, num_iter, noise_std): 58 | # save whether or not the environment is configured to do online normalization 59 | online_val = env.online 60 | env.online = True 61 | 62 | state = env.reset() 63 | 64 | for t in range(num_iter): 65 | state = torch.Tensor(state) 66 | 67 | _, action = policy(state) 68 | 69 | # add gaussian noise to deterministic action 70 | action = action + torch.randn(action.size()) * noise_std 71 | 72 | state, _, done, _ = env.step(action.data.numpy()) 73 | 74 | if done: 75 | state = env.reset() 76 | 77 | env.online = online_val 78 | 79 | def _Normalizer(venv): 80 | venv = Normalize(venv, *args, **kwargs) 81 | 82 | print("Gathering input normalization data using {0} steps, noise = {1}...".format(iter, noise_std)) 83 | pre_normalize(venv, policy, iter, noise_std) 84 | print("Done gathering input normalization data.") 85 | 86 | return venv 87 | 88 | return _Normalizer 89 | 90 | # returns a function that creates a normalized environment 91 | def Normalizer(*args, **kwargs): 92 | def _Normalizer(venv): 93 | return Normalize(venv, *args, **kwargs) 94 | 95 | return _Normalizer 96 | 97 | class Normalize: 98 | """ 99 | Vectorized environment base class 100 | """ 101 | def __init__(self, 102 | venv, 103 | ob_rms=None, 104 | ob=True, 105 | ret=False, 106 | clipob=10., 107 | cliprew=10., 108 | online=True, 109 | gamma=1.0, 110 | epsilon=1e-8): 111 | 112 | self.venv = venv 113 | self._observation_space = venv.observation_space 114 | self._action_space = venv.action_space 115 | 116 | if ob_rms is not None: 117 | self.ob_rms = ob_rms 118 | else: 119 | self.ob_rms = RunningMeanStd(shape=self._observation_space.shape) if ob else None 120 | 121 | self.ret_rms = RunningMeanStd(shape=()) if ret else None 122 | self.clipob = clipob 123 | self.cliprew = cliprew 124 | self.ret = np.zeros(self.num_envs) 125 | self.gamma = gamma 126 | self.epsilon = epsilon 127 | 128 | self.online = online 129 | 130 | def step(self, vac): 131 | obs, rews, news, infos = self.venv.step(vac) 132 | 133 | #self.ret = self.ret * self.gamma + rews 134 | obs = self._obfilt(obs) 135 | 136 | # NOTE: shifting mean of reward seems bad; qualitatively changes MDP 137 | if self.ret_rms: 138 | if self.online: 139 | self.ret_rms.update(self.ret) 140 | 141 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) 142 | 143 | return obs, rews, news, infos 144 | 145 | def _obfilt(self, obs): 146 | if self.ob_rms: 147 | if self.online: 148 | self.ob_rms.update(obs) 149 | 150 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) 151 | return obs 152 | 
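# no ob_rms configured: return the observation unnormalized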
else: 153 | return obs 154 | 155 | def reset(self): 156 | """ 157 | Reset all environments 158 | """ 159 | obs = self.venv.reset() 160 | return self._obfilt(obs) 161 | 162 | @property 163 | def action_space(self): 164 | return self._action_space 165 | 166 | @property 167 | def observation_space(self): 168 | return self._observation_space 169 | 170 | def close(self): 171 | self.venv.close() 172 | 173 | def render(self): 174 | self.venv.render() 175 | 176 | @property 177 | def num_envs(self): 178 | return self.venv.num_envs 179 | 180 | 181 | 182 | class RunningMeanStd(object): 183 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 184 | def __init__(self, epsilon=1e-4, shape=()): 185 | self.mean = np.zeros(shape, 'float64') 186 | self.var = np.zeros(shape, 'float64') 187 | self.count = epsilon 188 | 189 | 190 | def update(self, x): 191 | batch_mean = np.mean(x, axis=0) 192 | batch_var = np.var(x, axis=0) 193 | batch_count = x.shape[0] 194 | 195 | delta = batch_mean - self.mean 196 | tot_count = self.count + batch_count 197 | 198 | new_mean = self.mean + delta * batch_count / tot_count 199 | m_a = self.var * (self.count) 200 | m_b = batch_var * (batch_count) 201 | M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 202 | new_var = M2 / (self.count + batch_count) 203 | 204 | new_count = batch_count + self.count 205 | 206 | self.mean = new_mean 207 | self.var = new_var 208 | self.count = new_count 209 | 210 | def test_runningmeanstd(): 211 | for (x1, x2, x3) in [ 212 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)), 213 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)), 214 | ]: 215 | 216 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:]) 217 | 218 | x = np.concatenate([x1, x2, x3], axis=0) 219 | ms1 = [x.mean(axis=0), x.var(axis=0)] 220 | rms.update(x1) 221 | rms.update(x2) 222 | rms.update(x3) 223 | ms2 = [rms.mean, rms.var] 224 | 225 | assert np.allclose(ms1, ms2) 226 | -------------------------------------------------------------------------------- /rl/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # Gives a vectorized interface to a single environment 5 | class WrapEnv: 6 | def __init__(self, env_fn): 7 | self.env = env_fn() 8 | 9 | def __getattr__(self, attr): 10 | return getattr(self.env, attr) 11 | 12 | def step(self, action): 13 | state, reward, done, info = self.env.step(action[0]) 14 | return np.array([state]), np.array([reward]), np.array([done]), np.array([info]) 15 | 16 | def render(self): 17 | self.env.render() 18 | 19 | def reset(self): 20 | return np.array([self.env.reset()]) 21 | 22 | # TODO: this is probably a better case for inheritance than for a wrapper 23 | # Gives an interface to exploit mirror symmetry 24 | class SymmetricEnv: 25 | def __init__(self, env_fn, mirrored_obs=None, mirrored_act=None, clock_inds=None, obs_fn=None, act_fn=None): 26 | 27 | assert (bool(mirrored_act) ^ bool(act_fn)) and (bool(mirrored_obs) ^ bool(obs_fn)), \ 28 | "You must provide either mirror indices or a mirror function, but not both, for \ 29 | observation and action." 
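# How index-based mirroring works: mirrored_obs / mirrored_act are signed index
# vectors, where entry i holding the value +/-j sends element i of the vector to
# position j of the mirrored vector, multiplied by the sign. _get_symmetry_matrix()
# at the bottom of this file converts such a vector into a signed permutation
# matrix, so mirroring in mirror_observation()/mirror_action() is a single matrix
# multiply. (Since np.sign(0) == 0, an entry of exactly 0 would be dropped, as its
# matrix row is all zeros.)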
30 | 31 | if mirrored_act: 32 | self.act_mirror_matrix = torch.Tensor(_get_symmetry_matrix(mirrored_act)) 33 | 34 | elif act_fn: 35 | assert callable(act_fn), "Action mirror function must be callable" 36 | self.mirror_action = act_fn 37 | 38 | if mirrored_obs: 39 | self.obs_mirror_matrix = torch.Tensor(_get_symmetry_matrix(mirrored_obs)) 40 | 41 | elif obs_fn: 42 | assert callable(obs_fn), "Observation mirror function must be callable" 43 | self.mirror_observation = obs_fn 44 | 45 | self.clock_inds = clock_inds 46 | self.env = env_fn() 47 | 48 | def __getattr__(self, attr): 49 | return getattr(self.env, attr) 50 | 51 | def mirror_action(self, action): 52 | return action @ self.act_mirror_matrix 53 | 54 | def mirror_observation(self, obs): 55 | return obs @ self.obs_mirror_matrix 56 | 57 | # To be used when there is a clock in the observation. In this case, the mirrored_obs vector inputted 58 | # when the SymmeticEnv is created should not move the clock input order. The indices of the obs vector 59 | # where the clocks are located need to be inputted. 60 | def mirror_clock_observation(self, obs): 61 | # print("obs.shape = ", obs.shape) 62 | # print("obs_mirror_matrix.shape = ", self.obs_mirror_matrix.shape) 63 | mirror_obs_batch = torch.zeros_like(obs) 64 | history_len = 1 # FIX HISTORY-OF-STATES LENGTH TO 1 FOR NOW 65 | for block in range(history_len): 66 | obs_ = obs[:, self.base_obs_len*block : self.base_obs_len*(block+1)] 67 | mirror_obs = obs_ @ self.obs_mirror_matrix 68 | clock = mirror_obs[:, self.clock_inds] 69 | for i in range(np.shape(clock)[1]): 70 | mirror_obs[:, self.clock_inds[i]] = np.sin(np.arcsin(clock[:, i]) + np.pi) 71 | mirror_obs_batch[:, self.base_obs_len*block : self.base_obs_len*(block+1)] = mirror_obs 72 | return mirror_obs_batch 73 | 74 | 75 | def _get_symmetry_matrix(mirrored): 76 | numel = len(mirrored) 77 | mat = np.zeros((numel, numel)) 78 | 79 | for (i, j) in zip(np.arange(numel), np.abs(np.array(mirrored).astype(int))): 80 | mat[i, j] = np.sign(mirrored[i]) 81 | 82 | return mat 83 | -------------------------------------------------------------------------------- /rl/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import Gaussian_FF_Actor 2 | -------------------------------------------------------------------------------- /rl/policies/actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch import sqrt 6 | 7 | from rl.policies.base import Net 8 | 9 | class Actor(Net): 10 | def __init__(self): 11 | super(Actor, self).__init__() 12 | 13 | def forward(self): 14 | raise NotImplementedError 15 | 16 | class Linear_Actor(Actor): 17 | def __init__(self, state_dim, action_dim, hidden_size=32): 18 | super(Linear_Actor, self).__init__() 19 | 20 | self.l1 = nn.Linear(state_dim, hidden_size) 21 | self.l2 = nn.Linear(hidden_size, action_dim) 22 | 23 | self.action_dim = action_dim 24 | 25 | for p in self.parameters(): 26 | p.data = torch.zeros(p.shape) 27 | 28 | def forward(self, state): 29 | a = self.l1(state) 30 | a = self.l2(a) 31 | return a 32 | 33 | class FF_Actor(Actor): 34 | def __init__(self, state_dim, action_dim, layers=(256, 256), nonlinearity=F.relu): 35 | super(FF_Actor, self).__init__() 36 | 37 | self.actor_layers = nn.ModuleList() 38 | self.actor_layers += [nn.Linear(state_dim, layers[0])] 39 | for i in range(len(layers)-1): 40 | self.actor_layers += 
[nn.Linear(layers[i], layers[i+1])] 41 | self.network_out = nn.Linear(layers[-1], action_dim) 42 | 43 | self.action_dim = action_dim 44 | self.nonlinearity = nonlinearity 45 | 46 | self.initialize_parameters() 47 | 48 | def forward(self, state, deterministic=True): 49 | x = state 50 | for idx, layer in enumerate(self.actor_layers): 51 | x = self.nonlinearity(layer(x)) 52 | 53 | action = torch.tanh(self.network_out(x)) 54 | return action 55 | 56 | 57 | class LSTM_Actor(Actor): 58 | def __init__(self, state_dim, action_dim, layers=(128, 128), nonlinearity=torch.tanh): 59 | super(LSTM_Actor, self).__init__() 60 | 61 | self.actor_layers = nn.ModuleList() 62 | self.actor_layers += [nn.LSTMCell(state_dim, layers[0])] 63 | for i in range(len(layers)-1): 64 | self.actor_layers += [nn.LSTMCell(layers[i], layers[i+1])] 65 | self.network_out = nn.Linear(layers[i-1], action_dim) 66 | 67 | self.action_dim = action_dim 68 | self.init_hidden_state() 69 | self.nonlinearity = nonlinearity 70 | 71 | def get_hidden_state(self): 72 | return self.hidden, self.cells 73 | 74 | def set_hidden_state(self, data): 75 | if len(data) != 2: 76 | print("Got invalid hidden state data.") 77 | exit(1) 78 | 79 | self.hidden, self.cells = data 80 | 81 | def init_hidden_state(self, batch_size=1): 82 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers] 83 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers] 84 | 85 | def forward(self, x, deterministic=True): 86 | dims = len(x.size()) 87 | 88 | if dims == 3: # if we get a batch of trajectories 89 | self.init_hidden_state(batch_size=x.size(1)) 90 | y = [] 91 | for t, x_t in enumerate(x): 92 | for idx, layer in enumerate(self.actor_layers): 93 | c, h = self.cells[idx], self.hidden[idx] 94 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c)) 95 | x_t = self.hidden[idx] 96 | y.append(x_t) 97 | x = torch.stack([x_t for x_t in y]) 98 | 99 | else: 100 | if dims == 1: # if we get a single timestep (if not, assume we got a batch of single timesteps) 101 | x = x.view(1, -1) 102 | 103 | for idx, layer in enumerate(self.actor_layers): 104 | h, c = self.hidden[idx], self.cells[idx] 105 | self.hidden[idx], self.cells[idx] = layer(x, (h, c)) 106 | x = self.hidden[idx] 107 | x = self.nonlinearity(self.network_out(x)) 108 | 109 | if dims == 1: 110 | x = x.view(-1) 111 | 112 | action = self.network_out(x) 113 | return action 114 | 115 | 116 | class Gaussian_FF_Actor(Actor): # more consistent with other actor naming conventions 117 | def __init__(self, state_dim, action_dim, layers=(256, 256), nonlinearity=torch.nn.functional.relu, 118 | init_std=0.2, learn_std=False, bounded=False, normc_init=True): 119 | super(Gaussian_FF_Actor, self).__init__() 120 | 121 | self.actor_layers = nn.ModuleList() 122 | self.actor_layers += [nn.Linear(state_dim, layers[0])] 123 | for i in range(len(layers)-1): 124 | self.actor_layers += [nn.Linear(layers[i], layers[i+1])] 125 | self.means = nn.Linear(layers[-1], action_dim) 126 | 127 | self.learn_std = learn_std 128 | if self.learn_std: 129 | self.stds = nn.Parameter(init_std * torch.ones(action_dim)) 130 | else: 131 | self.stds = init_std * torch.ones(action_dim) 132 | 133 | self.action_dim = action_dim 134 | self.state_dim = state_dim 135 | self.nonlinearity = nonlinearity 136 | 137 | # Initialized to no input normalization, can be modified later 138 | self.obs_std = 1.0 139 | self.obs_mean = 0.0 140 | 141 | # weight initialization scheme used in PPO paper experiments 142 | self.normc_init = 
normc_init 143 | 144 | self.bounded = bounded 145 | 146 | self.init_parameters() 147 | 148 | def init_parameters(self): 149 | if self.normc_init: 150 | self.apply(normc_fn) 151 | self.means.weight.data.mul_(0.01) 152 | 153 | def _get_dist_params(self, state): 154 | state = (state - self.obs_mean) / self.obs_std 155 | 156 | x = state 157 | for l in self.actor_layers: 158 | x = self.nonlinearity(l(x)) 159 | mean = self.means(x) 160 | 161 | if self.bounded: 162 | mean = torch.tanh(mean) 163 | 164 | sd = torch.zeros_like(mean) 165 | if hasattr(self, 'stds'): 166 | sd = self.stds 167 | return mean, sd 168 | 169 | def forward(self, state, deterministic=True): 170 | mu, sd = self._get_dist_params(state) 171 | 172 | if not deterministic: 173 | action = torch.distributions.Normal(mu, sd).sample() 174 | else: 175 | action = mu 176 | 177 | return action 178 | 179 | def distribution(self, inputs): 180 | mu, sd = self._get_dist_params(inputs) 181 | return torch.distributions.Normal(mu, sd) 182 | 183 | 184 | class Gaussian_LSTM_Actor(Actor): 185 | def __init__(self, state_dim, action_dim, layers=(128, 128), nonlinearity=F.tanh, normc_init=False, 186 | init_std=0.2, learn_std=False): 187 | super(Gaussian_LSTM_Actor, self).__init__() 188 | 189 | self.actor_layers = nn.ModuleList() 190 | self.actor_layers += [nn.LSTMCell(state_dim, layers[0])] 191 | for i in range(len(layers)-1): 192 | self.actor_layers += [nn.LSTMCell(layers[i], layers[i+1])] 193 | self.network_out = nn.Linear(layers[i-1], action_dim) 194 | 195 | self.action_dim = action_dim 196 | self.state_dim = state_dim 197 | self.init_hidden_state() 198 | self.nonlinearity = nonlinearity 199 | 200 | # Initialized to no input normalization, can be modified later 201 | self.obs_std = 1.0 202 | self.obs_mean = 0.0 203 | 204 | self.learn_std = learn_std 205 | if self.learn_std: 206 | self.stds = nn.Parameter(init_std * torch.ones(action_dim)) 207 | else: 208 | self.stds = init_std * torch.ones(action_dim) 209 | 210 | if normc_init: 211 | self.initialize_parameters() 212 | 213 | self.act = self.forward 214 | 215 | def _get_dist_params(self, state): 216 | state = (state - self.obs_mean) / self.obs_std 217 | 218 | dims = len(state.size()) 219 | 220 | x = state 221 | if dims == 3: # if we get a batch of trajectories 222 | self.init_hidden_state(batch_size=x.size(1)) 223 | action = [] 224 | y = [] 225 | for t, x_t in enumerate(x): 226 | for idx, layer in enumerate(self.actor_layers): 227 | c, h = self.cells[idx], self.hidden[idx] 228 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c)) 229 | x_t = self.hidden[idx] 230 | y.append(x_t) 231 | x = torch.stack([x_t for x_t in y]) 232 | 233 | else: 234 | if dims == 1: # if we get a single timestep (if not, assume we got a batch of single timesteps) 235 | x = x.view(1, -1) 236 | 237 | for idx, layer in enumerate(self.actor_layers): 238 | h, c = self.hidden[idx], self.cells[idx] 239 | self.hidden[idx], self.cells[idx] = layer(x, (h, c)) 240 | x = self.hidden[idx] 241 | 242 | if dims == 1: 243 | x = x.view(-1) 244 | 245 | mu = self.network_out(x) 246 | sd = self.stds 247 | return mu, sd 248 | 249 | def init_hidden_state(self, batch_size=1): 250 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers] 251 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers] 252 | 253 | def forward(self, state, deterministic=True): 254 | mu, sd = self._get_dist_params(state) 255 | 256 | if not deterministic: 257 | action = torch.distributions.Normal(mu, sd).sample() 258 | 
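# with deterministic=True (the default), sampling is skipped and the mean action is returned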
else: 259 | action = mu 260 | 261 | return action 262 | 263 | def distribution(self, inputs): 264 | mu, sd = self._get_dist_params(inputs) 265 | return torch.distributions.Normal(mu, sd) 266 | 267 | # Initialization scheme for gaussian mlp (from ppo paper) 268 | # NOTE: the fact that this has the same name as a parameter caused a NASTY bug 269 | # apparently "if " evaluates to True in python... 270 | def normc_fn(m): 271 | classname = m.__class__.__name__ 272 | if classname.find('Linear') != -1: 273 | m.weight.data.normal_(0, 1) 274 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) 275 | if m.bias is not None: 276 | m.bias.data.fill_(0) 277 | -------------------------------------------------------------------------------- /rl/policies/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch import sqrt 6 | 7 | def normc_fn(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Linear') != -1: 10 | m.weight.data.normal_(0, 1) 11 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) 12 | if m.bias is not None: 13 | m.bias.data.fill_(0) 14 | 15 | # The base class for an actor. Includes functions for normalizing state (optional) 16 | class Net(nn.Module): 17 | def __init__(self): 18 | super(Net, self).__init__() 19 | 20 | self.welford_state_mean = torch.zeros(1) 21 | self.welford_state_mean_diff = torch.ones(1) 22 | self.welford_state_n = 1 23 | 24 | def forward(self): 25 | raise NotImplementedError 26 | 27 | def normalize_state(self, state, update=True): 28 | state = torch.Tensor(state) 29 | 30 | if self.welford_state_n == 1: 31 | self.welford_state_mean = torch.zeros(state.size(-1)) 32 | self.welford_state_mean_diff = torch.ones(state.size(-1)) 33 | 34 | if update: 35 | if len(state.size()) == 1: # If we get a single state vector 36 | state_old = self.welford_state_mean 37 | self.welford_state_mean += (state - state_old) / self.welford_state_n 38 | self.welford_state_mean_diff += (state - state_old) * (state - state_old) 39 | self.welford_state_n += 1 40 | elif len(state.size()) == 2: # If we get a batch 41 | print("NORMALIZING 2D TENSOR (this should not be happening)") 42 | for r_n in r: 43 | state_old = self.welford_state_mean 44 | self.welford_state_mean += (state_n - state_old) / self.welford_state_n 45 | self.welford_state_mean_diff += (state_n - state_old) * (state_n - state_old) 46 | self.welford_state_n += 1 47 | elif len(state.size()) == 3: # If we get a batch of sequences 48 | print("NORMALIZING 3D TENSOR (this really should not be happening)") 49 | for r_t in r: 50 | for r_n in r_t: 51 | state_old = self.welford_state_mean 52 | self.welford_state_mean += (state_n - state_old) / self.welford_state_n 53 | self.welford_state_mean_diff += (state_n - state_old) * (state_n - state_old) 54 | self.welford_state_n += 1 55 | return (state - self.welford_state_mean) / sqrt(self.welford_state_mean_diff / self.welford_state_n) 56 | 57 | def copy_normalizer_stats(self, net): 58 | self.welford_state_mean = net.self_state_mean 59 | self.welford_state_mean_diff = net.welford_state_mean_diff 60 | self.welford_state_n = net.welford_state_n 61 | 62 | def initialize_parameters(self): 63 | self.apply(normc_fn) 64 | -------------------------------------------------------------------------------- /rl/policies/critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from rl.policies.base import Net, normc_fn 6 | 7 | # The base class for a critic. Includes functions for normalizing reward and state (optional) 8 | class Critic(Net): 9 | def __init__(self): 10 | super(Critic, self).__init__() 11 | 12 | self.welford_reward_mean = 0.0 13 | self.welford_reward_mean_diff = 1.0 14 | self.welford_reward_n = 1 15 | 16 | def forward(self): 17 | raise NotImplementedError 18 | 19 | def normalize_reward(self, r, update=True): 20 | if update: 21 | if len(r.size()) == 1: 22 | r_old = self.welford_reward_mean 23 | self.welford_reward_mean += (r - r_old) / self.welford_reward_n 24 | self.welford_reward_mean_diff += (r - r_old) * (r - r_old) 25 | self.welford_reward_n += 1 26 | elif len(r.size()) == 2: 27 | for r_n in r: 28 | r_old = self.welford_reward_mean 29 | self.welford_reward_mean += (r_n - r_old) / self.welford_reward_n 30 | self.welford_reward_mean_diff += (r_n - r_old) * (r_n - r_old) 31 | self.welford_reward_n += 1 32 | else: 33 | raise NotImplementedError 34 | 35 | return (r - self.welford_reward_mean) / torch.sqrt(self.welford_reward_mean_diff / self.welford_reward_n) 36 | 37 | class FF_V(Critic): 38 | def __init__(self, state_dim, layers=(256, 256), nonlinearity=torch.nn.functional.relu, normc_init=True, obs_std=None, obs_mean=None): 39 | super(FF_V, self).__init__() 40 | 41 | self.critic_layers = nn.ModuleList() 42 | self.critic_layers += [nn.Linear(state_dim, layers[0])] 43 | for i in range(len(layers)-1): 44 | self.critic_layers += [nn.Linear(layers[i], layers[i+1])] 45 | self.network_out = nn.Linear(layers[-1], 1) 46 | 47 | self.nonlinearity = nonlinearity 48 | 49 | self.obs_std = obs_std 50 | self.obs_mean = obs_mean 51 | 52 | # weight initialization scheme used in PPO paper experiments 53 | self.normc_init = normc_init 54 | 55 | self.init_parameters() 56 | 57 | def init_parameters(self): 58 | if self.normc_init: 59 | print("Doing norm column initialization.") 60 | self.apply(normc_fn) 61 | 62 | def forward(self, inputs): 63 | inputs = (inputs - self.obs_mean) / self.obs_std 64 | 65 | x = inputs 66 | for l in self.critic_layers: 67 | x = self.nonlinearity(l(x)) 68 | value = self.network_out(x) 69 | 70 | return value 71 | 72 | class LSTM_V(Critic): 73 | def __init__(self, input_dim, layers=(128, 128), normc_init=True): 74 | super(LSTM_V, self).__init__() 75 | 76 | self.critic_layers = nn.ModuleList() 77 | self.critic_layers += [nn.LSTMCell(input_dim, layers[0])] 78 | for i in range(len(layers)-1): 79 | self.critic_layers += [nn.LSTMCell(layers[i], layers[i+1])] 80 | self.network_out = nn.Linear(layers[-1], 1) 81 | 82 | self.init_hidden_state() 83 | 84 | if normc_init: 85 | self.initialize_parameters() 86 | 87 | def get_hidden_state(self): 88 | return self.hidden, self.cells 89 | 90 | def init_hidden_state(self, batch_size=1): 91 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.critic_layers] 92 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.critic_layers] 93 | 94 | def forward(self, state): 95 | state = (state - self.obs_mean) / self.obs_std 96 | dims = len(state.size()) 97 | 98 | if dims == 3: # if we get a batch of trajectories 99 | self.init_hidden_state(batch_size=state.size(1)) 100 | value = [] 101 | for t, state_batch_t in enumerate(state): 102 | x_t = state_batch_t 103 | for idx, layer in enumerate(self.critic_layers): 104 | c, h = self.cells[idx], self.hidden[idx] 105 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c)) 106 | x_t = 
self.hidden[idx] 107 | x_t = self.network_out(x_t) 108 | value.append(x_t) 109 | 110 | x = torch.stack([a.float() for a in value]) 111 | 112 | else: 113 | x = state 114 | if dims == 1: 115 | x = x.view(1, -1) 116 | 117 | for idx, layer in enumerate(self.critic_layers): 118 | c, h = self.cells[idx], self.hidden[idx] 119 | self.hidden[idx], self.cells[idx] = layer(x, (h, c)) 120 | x = self.hidden[idx] 121 | x = self.network_out(x) 122 | 123 | if dims == 1: 124 | x = x.view(-1) 125 | 126 | return x 127 | 128 | 129 | GaussianMLP_Critic = FF_V 130 | -------------------------------------------------------------------------------- /rl/storage/rollout_storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class PPOBuffer: 4 | def __init__(self, obs_len=1, act_len=1, gamma=0.99, lam=0.95, use_gae=False, size=1): 5 | self.states = torch.zeros(size, obs_len, dtype=float) 6 | self.actions = torch.zeros(size, act_len, dtype=float) 7 | self.rewards = torch.zeros(size, 1, dtype=float) 8 | self.values = torch.zeros(size, 1, dtype=float) 9 | self.returns = torch.zeros(size, 1, dtype=float) 10 | self.dones = torch.zeros(size, 1, dtype=float) 11 | 12 | self.gamma, self.lam = gamma, lam 13 | self.ptr = 0 14 | self.traj_idx = [0] 15 | 16 | def __len__(self): 17 | return self.ptr 18 | 19 | def store(self, state, action, reward, value, done): 20 | """ 21 | Append one timestep of agent-environment interaction to the buffer. 22 | """ 23 | self.states[self.ptr]= state 24 | self.actions[self.ptr] = action 25 | self.rewards[self.ptr] = reward 26 | self.values[self.ptr] = value 27 | self.dones[self.ptr] = done 28 | self.ptr += 1 29 | 30 | def finish_path(self, last_val=None): 31 | self.traj_idx += [self.ptr] 32 | rewards = self.rewards[self.traj_idx[-2]:self.traj_idx[-1], 0] 33 | R = last_val.squeeze(0) 34 | returns = torch.zeros_like(rewards) 35 | for i in range(len(rewards) - 1, -1, -1): 36 | R = self.gamma * R + rewards[i] 37 | returns[i] = R 38 | self.returns[self.traj_idx[-2]:self.traj_idx[-1], 0] = returns 39 | self.dones[-1] = True 40 | 41 | def get_data(self): 42 | """ 43 | Return collected data and reset buffer. 
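Episode boundaries are recovered from self.traj_idx, which finish_path()
appends to at the end of every trajectory; the per-episode lengths and
rewards below are computed from those indices.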
44 | 45 | Returns: 46 | dict: Collected trajectory data 47 | """ 48 | ep_lens = [j - i for i, j in zip(self.traj_idx, self.traj_idx[1:])] 49 | ep_rewards = [ 50 | float(sum(self.rewards[int(i):int(j)])) for i, j in zip(self.traj_idx, self.traj_idx[1:]) 51 | ] 52 | data = { 53 | 'states': self.states[:self.ptr], 54 | 'actions': self.actions[:self.ptr], 55 | 'rewards': self.rewards[:self.ptr], 56 | 'values': self.values[:self.ptr], 57 | 'returns': self.returns[:self.ptr], 58 | 'dones': self.dones[:self.ptr], 59 | 'traj_idx': torch.Tensor(self.traj_idx), 60 | 'ep_lens': torch.Tensor(ep_lens), 61 | 'ep_rewards': torch.Tensor(ep_rewards), 62 | } 63 | return data 64 | -------------------------------------------------------------------------------- /rl/utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from pathlib import Path 4 | 5 | import mujoco 6 | import mujoco.viewer 7 | 8 | import imageio 9 | from datetime import datetime 10 | 11 | class EvaluateEnv: 12 | def __init__(self, env, policy, args): 13 | self.env = env 14 | self.policy = policy 15 | self.ep_len = args.ep_len 16 | 17 | if args.out_dir is None: 18 | args.out_dir = Path(args.path.parent, "videos") 19 | 20 | video_outdir = Path(args.out_dir) 21 | try: 22 | Path.mkdir(video_outdir, exist_ok=True) 23 | now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 24 | video_fn = Path(video_outdir, args.path.stem + "-" + now + ".mp4") 25 | self.writer = imageio.get_writer(video_fn, fps=60) 26 | except Exception as e: 27 | print("Could not create video writer:", e) 28 | exit(-1) 29 | 30 | @torch.no_grad() 31 | def run(self): 32 | 33 | height = 480 34 | width = 640 35 | renderer = mujoco.Renderer(self.env.model, height, width) 36 | viewer = mujoco.viewer.launch_passive(self.env.model, self.env.data) 37 | frames = [] 38 | 39 | # Make a camera. 
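# The passive viewer's free camera doubles as the offscreen-render camera: we start
# from MuJoCo's defaults, tilt it down and pull it back, and inside the loop below
# keep cam.lookat on body 1 (typically the robot's floating base) so the robot
# stays centered in the recorded video.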
40 | cam = viewer.cam 41 | mujoco.mjv_defaultCamera(cam) 42 | cam.elevation = -20 43 | cam.distance = 4 44 | 45 | reset_counter = 0 46 | observation = self.env.reset() 47 | while self.env.data.time < self.ep_len: 48 | 49 | step_start = time.time() 50 | 51 | # forward pass and step 52 | raw = self.policy.forward(torch.tensor(observation, dtype=torch.float32), deterministic=True).detach().numpy() 53 | observation, reward, done, _ = self.env.step(raw.copy()) 54 | 55 | # render scene 56 | cam.lookat = self.env.data.body(1).xpos.copy() 57 | renderer.update_scene(self.env.data, cam) 58 | pixels = renderer.render() 59 | frames.append(pixels) 60 | 61 | viewer.sync() 62 | 63 | if done and reset_counter < 3: 64 | observation = self.env.reset() 65 | reset_counter += 1 66 | 67 | time_until_next_step = max( 68 | 0, self.env.frame_skip*self.env.model.opt.timestep - (time.time() - step_start)) 69 | time.sleep(time_until_next_step) 70 | 71 | for frame in frames: 72 | self.writer.append_data(frame) 73 | self.writer.close() 74 | self.env.close() 75 | viewer.close() 76 | -------------------------------------------------------------------------------- /robots/robot_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class RobotBase(object): 4 | def __init__(self, pdgains, dt, client, task, pdrand_k = 0, sim_bemf = False, sim_motor_dyn = False): 5 | 6 | self.client = client 7 | self.task = task 8 | self.control_dt = dt 9 | self.pdrand_k = pdrand_k 10 | self.sim_bemf = sim_bemf 11 | self.sim_motor_dyn = sim_motor_dyn 12 | 13 | assert (self.sim_bemf & self.sim_motor_dyn)==False, \ 14 | "You cannot simulate back-EMF and motor dynamics simultaneously!" 15 | 16 | # set PD gains 17 | self.kp = pdgains[0] 18 | self.kd = pdgains[1] 19 | assert self.kp.shape==self.kd.shape==(self.client.nu(),), \ 20 | f"kp shape {self.kp.shape} and kd shape {self.kd.shape} must be {(self.client.nu(),)}" 21 | 22 | # torque damping param 23 | self.tau_d = np.zeros(self.client.nu()) 24 | 25 | self.client.set_pd_gains(self.kp, self.kd) 26 | tau = self.client.step_pd(np.zeros(self.client.nu()), np.zeros(self.client.nu())) 27 | w = self.client.get_act_joint_velocities() 28 | assert len(w)==len(tau) 29 | 30 | self.prev_action = None 31 | self.prev_torque = None 32 | self.iteration_count = np.inf 33 | 34 | # frame skip parameter 35 | if (np.around(self.control_dt%self.client.sim_dt(), 6)): 36 | raise Exception("Control dt should be an integer multiple of Simulation dt.") 37 | self.frame_skip = int(self.control_dt/self.client.sim_dt()) 38 | 39 | def _do_simulation(self, target, n_frames): 40 | # randomize PD gains 41 | if self.pdrand_k: 42 | k = self.pdrand_k 43 | kp = np.random.uniform((1-k)*self.kp, (1+k)*self.kp) 44 | kd = np.random.uniform((1-k)*self.kd, (1+k)*self.kd) 45 | self.client.set_pd_gains(kp, kd) 46 | 47 | assert target.shape == (self.client.nu(),), \ 48 | f"Target shape must be {(self.client.nu(),)}" 49 | 50 | ratio = self.client.get_gear_ratios() 51 | 52 | if self.sim_bemf and np.random.randint(10)==0: 53 | self.tau_d = np.random.uniform(5, 40, self.client.nu()) 54 | 55 | for _ in range(n_frames): 56 | w = self.client.get_act_joint_velocities() 57 | tau = self.client.step_pd(target, np.zeros(self.client.nu())) 58 | tau = tau - self.tau_d*w 59 | tau /= ratio 60 | self.client.set_motor_torque(tau, self.sim_motor_dyn) 61 | self.client.step() 62 | 63 | def step(self, action, offset=None): 64 | 65 | if not isinstance(action, np.ndarray): 66 | raise 
TypeError("Expected action to be a numpy array") 67 | 68 | action = np.copy(action) 69 | 70 | assert action.shape == (self.client.nu(),), \ 71 | f"Action vector length expected to be: {self.client.nu()} but is {action.shape}" 72 | 73 | # If offset is provided, add to action vector 74 | if offset is not None: 75 | if not isinstance(offset, np.ndarray): 76 | raise TypeError("Expected offset to be a numpy array") 77 | assert offset.shape == action.shape, \ 78 | f"Offset shape {offset} must match action shape {action.shape}" 79 | offset = np.copy(offset) 80 | action += offset 81 | 82 | if self.prev_action is None: 83 | self.prev_action = action 84 | if self.prev_torque is None: 85 | self.prev_torque = np.asarray(self.client.get_act_joint_torques()) 86 | 87 | # Perform the simulation 88 | self._do_simulation(action, self.frame_skip) 89 | 90 | # Task-related operations 91 | self.task.step() 92 | rewards = self.task.calc_reward(self.prev_torque, self.prev_action, action) 93 | done = self.task.done() 94 | 95 | self.prev_action = action 96 | self.prev_torque = np.asarray(self.client.get_act_joint_torques()) 97 | 98 | return rewards, done 99 | -------------------------------------------------------------------------------- /run_experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import argparse 4 | import ray 5 | from functools import partial 6 | 7 | import numpy as np 8 | import torch 9 | import pickle 10 | import shutil 11 | 12 | from rl.algos.ppo import PPO 13 | from rl.envs.wrappers import SymmetricEnv 14 | from rl.utils.eval import EvaluateEnv 15 | 16 | def import_env(env_name_str): 17 | if env_name_str=='jvrc_walk': 18 | from envs.jvrc import JvrcWalkEnv as Env 19 | elif env_name_str=='jvrc_step': 20 | from envs.jvrc import JvrcStepEnv as Env 21 | elif env_name_str=='h1': 22 | from envs.h1 import H1Env as Env 23 | else: 24 | raise Exception("Check env name!") 25 | return Env 26 | 27 | def run_experiment(args): 28 | # import the correct environment 29 | Env = import_env(args.env) 30 | 31 | # wrapper function for creating parallelized envs 32 | env_fn = partial(Env, path_to_yaml=args.yaml) 33 | _env = env_fn() 34 | if not args.no_mirror: 35 | try: 36 | print("Wrapping in SymmetricEnv.") 37 | env_fn = partial(SymmetricEnv, env_fn, 38 | mirrored_obs=_env.robot.mirrored_obs, 39 | mirrored_act=_env.robot.mirrored_acts, 40 | clock_inds=_env.robot.clock_inds) 41 | except AttributeError as e: 42 | print("Warning! Cannot use SymmetricEnv.", e) 43 | 44 | # Set up Parallelism 45 | #os.environ['OMP_NUM_THREADS'] = '1' # [TODO: Is this needed?] 
46 | if not ray.is_initialized(): 47 | ray.init(num_cpus=args.num_procs) 48 | 49 | # dump hyperparameters 50 | Path.mkdir(args.logdir, parents=True, exist_ok=True) 51 | pkl_path = Path(args.logdir, "experiment.pkl") 52 | with open(pkl_path, 'wb') as f: 53 | pickle.dump(args, f) 54 | 55 | # copy config file 56 | if args.yaml: 57 | config_out_path = Path(args.logdir, "config.yaml") 58 | shutil.copyfile(args.yaml, config_out_path) 59 | 60 | algo = PPO(env_fn, args) 61 | algo.train(env_fn, args.n_itr) 62 | 63 | if __name__ == "__main__": 64 | 65 | parser = argparse.ArgumentParser() 66 | 67 | if sys.argv[1] == 'train': 68 | sys.argv.remove(sys.argv[1]) 69 | 70 | parser.add_argument("--env", required=True, type=str) 71 | parser.add_argument("--logdir", default=Path("/tmp/logs"), type=Path, help="Path to save weights and logs") 72 | parser.add_argument("--input-norm-steps", type=int, default=100000) 73 | parser.add_argument("--n-itr", type=int, default=20000, help="Number of iterations of the learning algorithm") 74 | parser.add_argument("--lr", type=float, default=1e-4, help="Adam learning rate") # Xie 75 | parser.add_argument("--eps", type=float, default=1e-5, help="Adam epsilon (for numerical stability)") 76 | parser.add_argument("--lam", type=float, default=0.95, help="Generalized advantage estimate discount") 77 | parser.add_argument("--gamma", type=float, default=0.99, help="MDP discount") 78 | parser.add_argument("--std-dev", type=float, default=0.223, help="Action noise for exploration") 79 | parser.add_argument("--learn-std", action="store_true", help="Exploration noise will be learned") 80 | parser.add_argument("--entropy-coeff", type=float, default=0.0, help="Coefficient for entropy regularization") 81 | parser.add_argument("--clip", type=float, default=0.2, help="Clipping parameter for PPO surrogate loss") 82 | parser.add_argument("--minibatch-size", type=int, default=64, help="Batch size for PPO updates") 83 | parser.add_argument("--epochs", type=int, default=3, help="Number of optimization epochs per PPO update") #Xie 84 | parser.add_argument("--use-gae", type=bool, default=True,help="Whether or not to calculate returns using Generalized Advantage Estimation") 85 | parser.add_argument("--num-procs", type=int, default=12, help="Number of threads to train on") 86 | parser.add_argument("--max-grad-norm", type=float, default=0.05, help="Value to clip gradients at") 87 | parser.add_argument("--max-traj-len", type=int, default=400, help="Max episode horizon") 88 | parser.add_argument("--no-mirror", required=False, action="store_true", help="to use SymmetricEnv") 89 | parser.add_argument("--mirror-coeff", required=False, default=0.4, type=float, help="weight for mirror loss") 90 | parser.add_argument("--eval-freq", required=False, default=100, type=int, help="Frequency of performing evaluation") 91 | parser.add_argument("--continued", required=False, type=Path, help="path to pretrained weights") 92 | parser.add_argument("--recurrent", required=False, action="store_true", help="use LSTM instead of FF") 93 | parser.add_argument("--imitate", required=False, type=str, default=None, help="Policy to imitate") 94 | parser.add_argument("--imitate-coeff", required=False, type=float, default=0.3, help="Coefficient for imitation loss") 95 | parser.add_argument("--yaml", required=False, type=str, default=None, help="Path to config file passed to Env class") 96 | args = parser.parse_args() 97 | 98 | run_experiment(args) 99 | 100 | elif sys.argv[1] == 'eval': 101 | sys.argv.remove(sys.argv[1]) 102 | 103 
| parser.add_argument("--path", required=False, type=Path, default=Path("/tmp/logs"), 104 | help="Path to trained model dir") 105 | parser.add_argument("--out-dir", required=False, type=Path, default=None, 106 | help="Path to directory to save videos") 107 | parser.add_argument("--ep-len", required=False, type=int, default=10, 108 | help="Episode length to play (in seconds)") 109 | args = parser.parse_args() 110 | 111 | path_to_actor = "" 112 | if args.path.is_file() and args.path.suffix==".pt": 113 | path_to_actor = args.path 114 | elif args.path.is_dir(): 115 | path_to_actor = Path(args.path, "actor.pt") 116 | else: 117 | raise Exception("Invalid path to actor module: ", args.path) 118 | 119 | path_to_critic = Path(path_to_actor.parent, "critic" + str(path_to_actor).split('actor')[1]) 120 | path_to_pkl = Path(path_to_actor.parent, "experiment.pkl") 121 | 122 | # load experiment args 123 | run_args = pickle.load(open(path_to_pkl, "rb")) 124 | # load trained policy 125 | policy = torch.load(path_to_actor, weights_only=False) 126 | critic = torch.load(path_to_critic, weights_only=False) 127 | policy.eval() 128 | critic.eval() 129 | 130 | # load experiment args 131 | run_args = pickle.load(open(path_to_pkl, "rb")) 132 | 133 | # import the correct environment 134 | Env = import_env(run_args.env) 135 | if "yaml" in run_args and run_args.yaml is not None: 136 | yaml_path = Path(run_args.yaml) 137 | else: 138 | yaml_path = None 139 | env = partial(Env, yaml_path)() 140 | 141 | # run 142 | e = EvaluateEnv(env, policy, args) 143 | e.run() 144 | -------------------------------------------------------------------------------- /scripts/debug_stepper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | import pickle 6 | import mujoco 7 | import numpy as np 8 | import transforms3d as tf3 9 | from run_experiment import import_env 10 | 11 | def print_reward(ep_rewards): 12 | mean_rewards = {k:[] for k in ep_rewards[-1].keys()} 13 | print('*********************************') 14 | for key in mean_rewards.keys(): 15 | l = [step[key] for step in ep_rewards] 16 | mean_rewards[key] = sum(l)/len(l) 17 | print(key, ': ', mean_rewards[key]) 18 | #total_rewards = [r for step in ep_rewards for r in step.values()] 19 | print('*********************************') 20 | print("mean per step reward: ", sum(mean_rewards.values())) 21 | 22 | def draw_targets(task, viewer): 23 | # draw step sequence 24 | arrow_size = [0.02, 0.02, 0.5] 25 | sphere = mujoco.mjtGeom.mjGEOM_SPHERE 26 | arrow = mujoco.mjtGeom.mjGEOM_ARROW 27 | if hasattr(task, 'sequence'): 28 | for idx, step in enumerate(task.sequence): 29 | step_pos = [step[0], step[1], step[2]] 30 | step_theta = step[3] 31 | if step_pos not in [task.sequence[task.t1][0:3].tolist(), task.sequence[task.t2][0:3].tolist()]: 32 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="") 33 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="") 34 | 35 | target_radius = task.target_radius 36 | step_pos = task.sequence[task.t1][0:3].tolist() 37 | step_theta = task.sequence[task.t1][3] 38 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([1, 0, 0, 1]), type=sphere, label="t1") 39 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([1, 0, 0, 1]), type=arrow, 
label="") 40 | viewer.add_marker(pos=step_pos, size=np.ones(3)*target_radius, rgba=np.array([1, 0, 0, 0.1]), type=sphere, label="") 41 | step_pos = task.sequence[task.t2][0:3].tolist() 42 | step_theta = task.sequence[task.t2][3] 43 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([0, 0, 1, 1]), type=sphere, label="t2") 44 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([0, 0, 1, 1]), type=arrow, label="") 45 | viewer.add_marker(pos=step_pos, size=np.ones(3)*target_radius, rgba=np.array([0, 0, 1, 0.1]), type=sphere, label="") 46 | return 47 | 48 | def draw_stuff(task, viewer): 49 | arrow_size = [0.02, 0.02, 0.5] 50 | sphere = mujoco.mjtGeom.mjGEOM_SPHERE 51 | arrow = mujoco.mjtGeom.mjGEOM_ARROW 52 | 53 | # draw observed targets 54 | goalx = task._goal_steps_x 55 | goaly = task._goal_steps_y 56 | goaltheta = task._goal_steps_theta 57 | viewer.add_marker(pos=[goalx[0], goaly[0], 0], size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="G1") 58 | viewer.add_marker(pos=[goalx[0], goaly[0], 0], mat=tf3.euler.euler2mat(0, np.pi/2, goaltheta[0]), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="") 59 | viewer.add_marker(pos=[goalx[1], goaly[1], 0], size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="G2") 60 | viewer.add_marker(pos=[goalx[1], goaly[1], 0], mat=tf3.euler.euler2mat(0, np.pi/2, goaltheta[1]), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="") 61 | 62 | # draw feet pose 63 | lfoot_orient = (tf3.quaternions.quat2mat(task.l_foot_quat)).dot(tf3.euler.euler2mat(0, np.pi/2, 0)) 64 | rfoot_orient = (tf3.quaternions.quat2mat(task.r_foot_quat)).dot(tf3.euler.euler2mat(0, np.pi/2, 0)) 65 | viewer.add_marker(pos=task.l_foot_pos, size=np.ones(3)*0.05, rgba=[0.5, 0.5, 0.5, 1], type=sphere, label="") 66 | viewer.add_marker(pos=task.l_foot_pos, mat=lfoot_orient, size=arrow_size, rgba=[0.5, 0.5, 0.5, 1], type=arrow, label="") 67 | viewer.add_marker(pos=task.r_foot_pos, size=np.ones(3)*0.05, rgba=[0.5, 0.5, 0.5, 1], type=sphere, label="") 68 | viewer.add_marker(pos=task.r_foot_pos, mat=rfoot_orient, size=arrow_size, rgba=[0.5, 0.5, 0.5, 1], type=arrow, label="") 69 | 70 | # draw origin 71 | viewer.add_marker(pos=[0, 0, 0], size=np.ones(3)*0.05, rgba=np.array([1, 1, 1, 1]), type=sphere, label="") 72 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(0, 0, 0), size=[0.01, 0.01, 2], rgba=np.array([0, 0, 1, 0.2]), type=arrow, label="") 73 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(0, np.pi/2, 0), size=[0.01, 0.01, 2], rgba=np.array([1, 0, 0, 0.2]), type=arrow, label="") 74 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(-np.pi/2, np.pi/2, 0), size=[0.01, 0.01, 2], rgba=np.array([0, 1, 0, 0.2]), type=arrow, label="") 75 | return 76 | 77 | def run(env, policy, args): 78 | observation = env.reset() 79 | 80 | env.render() 81 | viewer = env.viewer 82 | viewer._paused = True 83 | done = False 84 | ts, end_ts = 0, 2000 85 | ep_rewards = [] 86 | 87 | while (ts < end_ts): 88 | if hasattr(env, 'frame_skip'): 89 | start = time.time() 90 | 91 | with torch.no_grad(): 92 | action = policy.forward(torch.Tensor(observation), deterministic=True).detach().numpy() 93 | 94 | observation, _, done, info = env.step(action.copy()) 95 | ep_rewards.append(info) 96 | 97 | if env.__class__.__name__ == 'JvrcStepEnv': 98 | draw_targets(env.task, viewer) 99 | draw_stuff(env.task, viewer) 100 | env.render() 101 | 102 | if args.sync and hasattr(env, 
'frame_skip'): 103 | end = time.time() 104 | sim_dt = env.robot.client.sim_dt() 105 | delaytime = max(0, env.frame_skip / (1/sim_dt) - (end-start)) 106 | time.sleep(delaytime) 107 | 108 | if args.quit_on_done and done: 109 | break 110 | 111 | ts+=1 112 | 113 | print("Episode finished after {} timesteps".format(ts)) 114 | print_reward(ep_rewards) 115 | env.close() 116 | 117 | def main(): 118 | # get command line arguments 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--path", 121 | required=True, 122 | type=str, 123 | help="path to trained model dir", 124 | ) 125 | parser.add_argument("--sync", 126 | required=False, 127 | action="store_true", 128 | help="sync the simulation speed with real time", 129 | ) 130 | parser.add_argument("--quit-on-done", 131 | required=False, 132 | action="store_true", 133 | help="Exit when done condition is reached", 134 | ) 135 | args = parser.parse_args() 136 | 137 | path_to_actor = "" 138 | path_to_pkl = "" 139 | if os.path.isfile(args.path) and args.path.endswith(".pt"): 140 | path_to_actor = args.path 141 | path_to_pkl = os.path.join(os.path.dirname(args.path), "experiment.pkl") 142 | if os.path.isdir(args.path): 143 | path_to_actor = os.path.join(args.path, "actor.pt") 144 | path_to_pkl = os.path.join(args.path, "experiment.pkl") 145 | 146 | # load experiment args 147 | run_args = pickle.load(open(path_to_pkl, "rb")) 148 | # load trained policy 149 | policy = torch.load(path_to_actor) 150 | policy.eval() 151 | # import the correct environment 152 | env = import_env(run_args.env)() 153 | 154 | run(env, policy, args) 155 | print("-----------------------------------------") 156 | 157 | if __name__=='__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/tasks/__init__.py -------------------------------------------------------------------------------- /tasks/rewards.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ############################## 4 | ############################## 5 | # Define reward functions here 6 | ############################## 7 | ############################## 8 | 9 | def _calc_fwd_vel_reward(self): 10 | # forward vel reward 11 | root_vel = self._client.get_qvel()[0] 12 | error = np.linalg.norm(root_vel - self._goal_speed_ref) 13 | return np.exp(-error) 14 | 15 | def _calc_action_reward(self, action, prev_action): 16 | # action reward 17 | penalty = 5 * sum(np.abs(prev_action - action)) / len(action) 18 | return np.exp(-penalty) 19 | 20 | def _calc_torque_reward(self, prev_torque): 21 | # actuator torque reward 22 | torque = np.asarray(self._client.get_act_joint_torques()) 23 | penalty = 0.25 * (sum(np.abs(prev_torque - torque)) / len(torque)) 24 | return np.exp(-penalty) 25 | 26 | def _calc_height_reward(self): 27 | # height reward 28 | if self._client.check_rfoot_floor_collision() or self._client.check_lfoot_floor_collision(): 29 | contact_point = min([c.pos[2] for _,c in (self._client.get_rfoot_floor_contacts() + 30 | self._client.get_lfoot_floor_contacts())]) 31 | else: 32 | contact_point = 0 33 | current_height = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[2] 34 | relative_height = current_height - contact_point 35 | error = np.abs(relative_height - 
self._goal_height_ref) 36 | deadzone_size = 0.01 + 0.05 * self._goal_speed_ref 37 | if error < deadzone_size: 38 | error = 0 39 | return np.exp(-40*np.square(error)) 40 | 41 | def _calc_heading_reward(self): 42 | # heading reward 43 | cur_heading = self._client.get_qvel()[:3] 44 | cur_heading /= np.linalg.norm(cur_heading) 45 | error = np.linalg.norm(cur_heading - np.array([1, 0, 0])) 46 | return np.exp(-error) 47 | 48 | def _calc_root_accel_reward(self): 49 | qvel = self._client.get_qvel() 50 | qacc = self._client.get_qacc() 51 | error = 0.25 * (np.abs(qvel[3:6]).sum() + np.abs(qacc[0:3]).sum()) 52 | return np.exp(-error) 53 | 54 | def _calc_feet_separation_reward(self): 55 | # feet y-separation cost 56 | rfoot_pos = self._client.get_rfoot_body_pos()[1] 57 | lfoot_pos = self._client.get_lfoot_body_pos()[1] 58 | foot_dist = np.abs(rfoot_pos-lfoot_pos) 59 | error = 5*np.square(foot_dist-0.35) 60 | if foot_dist < 0.40 and foot_dist > 0.30: 61 | error = 0 62 | return np.exp(-error) 63 | 64 | def _calc_foot_frc_clock_reward(self, left_frc_fn, right_frc_fn): 65 | # constraints of foot forces based on clock 66 | desired_max_foot_frc = self._client.get_robot_mass()*9.8*0.5 67 | #desired_max_foot_frc = self._client.get_robot_mass()*10*1.2 68 | normed_left_frc = min(self.l_foot_frc, desired_max_foot_frc) / desired_max_foot_frc 69 | normed_right_frc = min(self.r_foot_frc, desired_max_foot_frc) / desired_max_foot_frc 70 | normed_left_frc*=2 71 | normed_left_frc-=1 72 | normed_right_frc*=2 73 | normed_right_frc-=1 74 | 75 | left_frc_clock = left_frc_fn(self._phase) 76 | right_frc_clock = right_frc_fn(self._phase) 77 | 78 | left_frc_score = np.tan(np.pi/4 * left_frc_clock * normed_left_frc) 79 | right_frc_score = np.tan(np.pi/4 * right_frc_clock * normed_right_frc) 80 | 81 | foot_frc_score = (left_frc_score + right_frc_score)/2 82 | return foot_frc_score 83 | 84 | def _calc_foot_vel_clock_reward(self, left_vel_fn, right_vel_fn): 85 | # constraints of foot velocities based on clock 86 | desired_max_foot_vel = 0.2 87 | normed_left_vel = min(np.linalg.norm(self.l_foot_vel), desired_max_foot_vel) / desired_max_foot_vel 88 | normed_right_vel = min(np.linalg.norm(self.r_foot_vel), desired_max_foot_vel) / desired_max_foot_vel 89 | normed_left_vel*=2 90 | normed_left_vel-=1 91 | normed_right_vel*=2 92 | normed_right_vel-=1 93 | 94 | left_vel_clock = left_vel_fn(self._phase) 95 | right_vel_clock = right_vel_fn(self._phase) 96 | 97 | left_vel_score = np.tan(np.pi/4 * left_vel_clock * normed_left_vel) 98 | right_vel_score = np.tan(np.pi/4 * right_vel_clock * normed_right_vel) 99 | 100 | foot_vel_score = (left_vel_score + right_vel_score)/2 101 | return foot_vel_score 102 | 103 | def _calc_foot_pos_clock_reward(self): 104 | # constraints of foot height based on clock 105 | desired_max_foot_height = 0.05 106 | l_foot_pos = self._client.get_object_xpos_by_name('lf_force', 'OBJ_SITE')[2] 107 | r_foot_pos = self._client.get_object_xpos_by_name('rf_force', 'OBJ_SITE')[2] 108 | normed_left_pos = min(np.linalg.norm(l_foot_pos), desired_max_foot_height) / desired_max_foot_height 109 | normed_right_pos = min(np.linalg.norm(r_foot_pos), desired_max_foot_height) / desired_max_foot_height 110 | 111 | left_pos_clock = self.left_clock[1](self._phase) 112 | right_pos_clock = self.right_clock[1](self._phase) 113 | 114 | left_pos_score = np.tan(np.pi/4 * left_pos_clock * normed_left_pos) 115 | right_pos_score = np.tan(np.pi/4 * right_pos_clock * normed_right_pos) 116 | 117 | foot_pos_score = left_pos_score + right_pos_score 118 | 
return foot_pos_score 119 | 120 | def _calc_body_orient_reward(self, body_name, quat_ref=[1, 0, 0, 0]): 121 | # body orientation reward 122 | body_quat = self._client.get_object_xquat_by_name(body_name, "OBJ_BODY") 123 | target_quat = np.array(quat_ref) 124 | error = 10 * (1 - np.inner(target_quat, body_quat) ** 2) 125 | return np.exp(-error) 126 | 127 | def _calc_joint_vel_reward(self, enabled, cutoff=0.5): 128 | # joint velocity reward 129 | motor_speeds = self._client.get_motor_velocities() 130 | motor_limits = self._client.get_motor_speed_limits() 131 | motor_speeds = [motor_speeds[i] for i in enabled] 132 | motor_limits = [motor_limits[i] for i in enabled] 133 | error = 5e-6*sum([np.square(q) 134 | for q, qmax in zip(motor_speeds, motor_limits) 135 | if np.abs(q)>np.abs(cutoff*qmax)]) 136 | return np.exp(-error) 137 | 138 | 139 | def _calc_joint_acc_reward(self): 140 | # joint acceleration reward 141 | joint_acc_cost = np.sum(np.square(self._client.get_qacc()[-self._num_joints:])) 142 | return self.wp.joint_acc_weight*joint_acc_cost 143 | 144 | def _calc_ang_vel_reward(self): 145 | # angular vel reward 146 | ang_vel = self._client.get_qvel()[3:6] 147 | ang_vel_cost = np.square(np.linalg.norm(ang_vel)) 148 | return self.wp.ang_vel_weight*ang_vel_cost 149 | 150 | def _calc_impact_reward(self): 151 | # contact reward 152 | ncon = len(self._client.get_rfoot_floor_contacts()) + \ 153 | len(self._client.get_lfoot_floor_contacts()) 154 | if ncon==0: 155 | return 0 156 | quad_impact_cost = np.sum(np.square(self._client.get_body_ext_force()))/ncon 157 | return self.wp.impact_weight*quad_impact_cost 158 | 159 | def _calc_zmp_reward(self): 160 | # zmp reward 161 | self.current_zmp = estimate_zmp(self) 162 | if np.linalg.norm(self.current_zmp - self._prev_zmp) > 1: 163 | self.current_zmp = self._prev_zmp 164 | zmp_cost = np.square(np.linalg.norm(self.current_zmp - self.desired_zmp)) 165 | self._prev_zmp = self.current_zmp 166 | return self.wp.zmp_weight*zmp_cost 167 | 168 | def _calc_foot_contact_reward(self): 169 | right_contacts = self._client.get_rfoot_floor_collisions() 170 | left_contacts = self._client.get_lfoot_floor_collisions() 171 | 172 | radius_thresh = 0.3 173 | f_base = self._client.get_qpos()[0:2] 174 | c_dist_r = [(np.linalg.norm(c.pos[0:2] - f_base)) for _, c in right_contacts] 175 | c_dist_l = [(np.linalg.norm(c.pos[0:2] - f_base)) for _, c in left_contacts] 176 | d = sum([r for r in c_dist_r if r > radius_thresh] + 177 | [r for r in c_dist_l if r > radius_thresh]) 178 | return self.wp.foot_contact_weight*d 179 | 180 | def _calc_gait_reward(self): 181 | if self._period<=0: 182 | raise Exception("Cycle period should be greater than zero.") 183 | 184 | # get foot-ground contact force 185 | rfoot_grf = self._client.get_rfoot_grf() 186 | lfoot_grf = self._client.get_lfoot_grf() 187 | 188 | # get foot speed 189 | rfoot_speed = self._client.get_rfoot_body_speed() 190 | lfoot_speed = self._client.get_lfoot_body_speed() 191 | 192 | # get foot position 193 | rfoot_pos = self._client.get_rfoot_body_pos() 194 | lfoot_pos = self._client.get_lfoot_body_pos() 195 | swing_height = 0.3 196 | stance_height = 0.1 197 | 198 | r = 0.5 199 | if self._phase < r: 200 | # right foot is in contact 201 | # left foot is swinging 202 | cost = (0.01*lfoot_grf)# \ 203 | #+ np.square(lfoot_pos[2]-swing_height) 204 | #+ (10*np.square(rfoot_pos[2]-stance_height)) 205 | else: 206 | # left foot is in contact 207 | # right foot is swinging 208 | cost = (0.01*rfoot_grf) 209 | #+ np.square(rfoot_pos[2]-swing_height)
210 | #+ (10*np.square(lfoot_pos[2]-stance_height)) 211 | return self.wp.gait_weight*cost 212 | 213 | def _calc_reference(self): 214 | if self.ref_poses is None: 215 | raise Exception("Reference trajectory not provided.") 216 | 217 | # get reference pose 218 | phase = self._phase 219 | traj_length = self.traj_len 220 | indx = int(phase*(traj_length-1)) 221 | reference_pose = self.ref_poses[indx,:] 222 | 223 | # get current pose 224 | current_pose = np.array(self._client.get_act_joint_positions()) 225 | 226 | cost = np.square(np.linalg.norm(reference_pose-current_pose)) 227 | return self.wp.ref_traj_weight*cost 228 | 229 | ############################## 230 | ############################## 231 | # Define utility functions 232 | ############################## 233 | ############################## 234 | 235 | def estimate_zmp(self): 236 | Gv = 9.80665 237 | Mg = self._mass * Gv 238 | 239 | com_pos = self._sim.data.subtree_com[1].copy() 240 | lin_mom = self._sim.data.subtree_linvel[1].copy()*self._mass 241 | ang_mom = self._sim.data.subtree_angmom[1].copy() + np.cross(com_pos, lin_mom) 242 | 243 | d_lin_mom = (lin_mom - self._prev_lin_mom)/self._control_dt 244 | d_ang_mom = (ang_mom - self._prev_ang_mom)/self._control_dt 245 | 246 | Fgz = d_lin_mom[2] + Mg 247 | 248 | # check contact with floor 249 | contacts = [self._sim.data.contact[i] for i in range(self._sim.data.ncon)] 250 | contact_flag = [(c.geom1==0 or c.geom2==0) for c in contacts] 251 | 252 | if (True in contact_flag) and Fgz > 20: 253 | zmp_x = (Mg*com_pos[0] - d_ang_mom[1])/Fgz 254 | zmp_y = (Mg*com_pos[1] + d_ang_mom[0])/Fgz 255 | else: 256 | zmp_x = com_pos[0] 257 | zmp_y = com_pos[1] 258 | 259 | self._prev_lin_mom = lin_mom 260 | self._prev_ang_mom = ang_mom 261 | return np.array([zmp_x, zmp_y]) 262 | 263 | ############################## 264 | ############################## 265 | # Based on apex 266 | ############################## 267 | ############################## 268 | 269 | def create_phase_reward(swing_duration, stance_duration, strict_relaxer, stance_mode, FREQ=40): 270 | 271 | from scipy.interpolate import PchipInterpolator 272 | 273 | # NOTE: these times are being converted from time in seconds to phaselength 274 | right_swing = np.array([0.0, swing_duration]) * FREQ 275 | first_dblstance = np.array([swing_duration, swing_duration + stance_duration]) * FREQ 276 | left_swing = np.array([swing_duration + stance_duration, 2 * swing_duration + stance_duration]) * FREQ 277 | second_dblstance = np.array([2 * swing_duration + stance_duration, 2 * (swing_duration + stance_duration)]) * FREQ 278 | 279 | r_frc_phase_points = np.zeros((2, 8)) 280 | r_vel_phase_points = np.zeros((2, 8)) 281 | l_frc_phase_points = np.zeros((2, 8)) 282 | l_vel_phase_points = np.zeros((2, 8)) 283 | 284 | right_swing_relax_offset = (right_swing[1] - right_swing[0]) * strict_relaxer 285 | l_frc_phase_points[0,0] = r_frc_phase_points[0,0] = right_swing[0] + right_swing_relax_offset 286 | l_frc_phase_points[0,1] = r_frc_phase_points[0,1] = right_swing[1] - right_swing_relax_offset 287 | l_vel_phase_points[0,0] = r_vel_phase_points[0,0] = right_swing[0] + right_swing_relax_offset 288 | l_vel_phase_points[0,1] = r_vel_phase_points[0,1] = right_swing[1] - right_swing_relax_offset 289 | # During right swing we want foot velocities and don't want foot forces 290 | l_vel_phase_points[1,:2] = r_frc_phase_points[1,:2] = np.negative(np.ones(2)) # penalize l vel and r force 291 | l_frc_phase_points[1,:2] = r_vel_phase_points[1,:2] = np.ones(2) # incentivize l force 
and r vel 292 | 293 | dbl_stance_relax_offset = (first_dblstance[1] - first_dblstance[0]) * strict_relaxer 294 | l_frc_phase_points[0,2] = r_frc_phase_points[0,2] = first_dblstance[0] + dbl_stance_relax_offset 295 | l_frc_phase_points[0,3] = r_frc_phase_points[0,3] = first_dblstance[1] - dbl_stance_relax_offset 296 | l_vel_phase_points[0,2] = r_vel_phase_points[0,2] = first_dblstance[0] + dbl_stance_relax_offset 297 | l_vel_phase_points[0,3] = r_vel_phase_points[0,3] = first_dblstance[1] - dbl_stance_relax_offset 298 | if stance_mode == "aerial": 299 | # During aerial we want foot velocities and don't want foot forces 300 | # During grounded walking we want foot forces and don't want velocities 301 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.negative(np.ones(2)) # penalize l and r foot force 302 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.ones(2) # incentivize l and r foot velocity 303 | elif stance_mode == "zero": 304 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.zeros(2) 305 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.zeros(2) 306 | else: 307 | # During grounded walking we want foot forces and don't want velocities 308 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.ones(2) # incentivize l and r foot force 309 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.negative(np.ones(2)) # penalize l and r foot velocity 310 | 311 | left_swing_relax_offset = (left_swing[1] - left_swing[0]) * strict_relaxer 312 | l_frc_phase_points[0,4] = r_frc_phase_points[0,4] = left_swing[0] + left_swing_relax_offset 313 | l_frc_phase_points[0,5] = r_frc_phase_points[0,5] = left_swing[1] - left_swing_relax_offset 314 | l_vel_phase_points[0,4] = r_vel_phase_points[0,4] = left_swing[0] + left_swing_relax_offset 315 | l_vel_phase_points[0,5] = r_vel_phase_points[0,5] = left_swing[1] - left_swing_relax_offset 316 | # During left swing we want foot forces and don't want foot velocities (from perspective of right foot) 317 | l_vel_phase_points[1,4:6] = r_frc_phase_points[1,4:6] = np.ones(2) # incentivize l vel and r force 318 | l_frc_phase_points[1,4:6] = r_vel_phase_points[1,4:6] = np.negative(np.ones(2)) # penalize l force and r vel 319 | 320 | dbl_stance_relax_offset = (second_dblstance[1] - second_dblstance[0]) * strict_relaxer 321 | l_frc_phase_points[0,6] = r_frc_phase_points[0,6] = second_dblstance[0] + dbl_stance_relax_offset 322 | l_frc_phase_points[0,7] = r_frc_phase_points[0,7] = second_dblstance[1] - dbl_stance_relax_offset 323 | l_vel_phase_points[0,6] = r_vel_phase_points[0,6] = second_dblstance[0] + dbl_stance_relax_offset 324 | l_vel_phase_points[0,7] = r_vel_phase_points[0,7] = second_dblstance[1] - dbl_stance_relax_offset 325 | if stance_mode == "aerial": 326 | # During aerial we want foot velocities and don't want foot forces 327 | # During grounded walking we want foot forces and don't want velocities 328 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.negative(np.ones(2)) # penalize l and r foot force 329 | l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.ones(2) # incentivize l and r foot velocity 330 | elif stance_mode == "zero": 331 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.zeros(2) 332 | l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.zeros(2) 333 | else: 334 | # During grounded walking we want foot forces and don't want velocities 335 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.ones(2) # incentivize l and r foot force 336 | 
l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.negative(np.ones(2)) # penalize l and r foot velocity 337 | 338 | ## extend the data to three cycles : one before and one after : this ensures continuity 339 | 340 | r_frc_prev_cycle = np.copy(r_frc_phase_points) 341 | r_vel_prev_cycle = np.copy(r_vel_phase_points) 342 | l_frc_prev_cycle = np.copy(l_frc_phase_points) 343 | l_vel_prev_cycle = np.copy(l_vel_phase_points) 344 | l_frc_prev_cycle[0] = r_frc_prev_cycle[0] = r_frc_phase_points[0] - r_frc_phase_points[0, -1] - dbl_stance_relax_offset 345 | l_vel_prev_cycle[0] = r_vel_prev_cycle[0] = r_vel_phase_points[0] - r_vel_phase_points[0, -1] - dbl_stance_relax_offset 346 | 347 | r_frc_second_cycle = np.copy(r_frc_phase_points) 348 | r_vel_second_cycle = np.copy(r_vel_phase_points) 349 | l_frc_second_cycle = np.copy(l_frc_phase_points) 350 | l_vel_second_cycle = np.copy(l_vel_phase_points) 351 | l_frc_second_cycle[0] = r_frc_second_cycle[0] = r_frc_phase_points[0] + r_frc_phase_points[0, -1] + dbl_stance_relax_offset 352 | l_vel_second_cycle[0] = r_vel_second_cycle[0] = r_vel_phase_points[0] + r_vel_phase_points[0, -1] + dbl_stance_relax_offset 353 | 354 | r_frc_phase_points_repeated = np.hstack((r_frc_prev_cycle, r_frc_phase_points, r_frc_second_cycle)) 355 | r_vel_phase_points_repeated = np.hstack((r_vel_prev_cycle, r_vel_phase_points, r_vel_second_cycle)) 356 | l_frc_phase_points_repeated = np.hstack((l_frc_prev_cycle, l_frc_phase_points, l_frc_second_cycle)) 357 | l_vel_phase_points_repeated = np.hstack((l_vel_prev_cycle, l_vel_phase_points, l_vel_second_cycle)) 358 | 359 | ## Create the smoothing function with cubic spline and cutoff at limits -1 and 1 360 | r_frc_phase_spline = PchipInterpolator(r_frc_phase_points_repeated[0], r_frc_phase_points_repeated[1]) 361 | r_vel_phase_spline = PchipInterpolator(r_vel_phase_points_repeated[0], r_vel_phase_points_repeated[1]) 362 | l_frc_phase_spline = PchipInterpolator(l_frc_phase_points_repeated[0], l_frc_phase_points_repeated[1]) 363 | l_vel_phase_spline = PchipInterpolator(l_vel_phase_points_repeated[0], l_vel_phase_points_repeated[1]) 364 | 365 | return [r_frc_phase_spline, r_vel_phase_spline], [l_frc_phase_spline, l_vel_phase_spline] 366 | -------------------------------------------------------------------------------- /tasks/stepping_task.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import transforms3d as tf3 4 | from tasks import rewards 5 | from enum import Enum, auto 6 | 7 | class WalkModes(Enum): 8 | STANDING = auto() 9 | CURVED = auto() 10 | FORWARD = auto() 11 | BACKWARD = auto() 12 | INPLACE = auto() 13 | LATERAL = auto() 14 | 15 | class SteppingTask(object): 16 | """Bipedal locomotion by stepping on targets.""" 17 | 18 | def __init__(self, 19 | client=None, 20 | dt=0.025, 21 | neutral_foot_orient=[], 22 | root_body='pelvis', 23 | lfoot_body='lfoot', 24 | rfoot_body='rfoot', 25 | head_body='head', 26 | ): 27 | 28 | self._client = client 29 | self._control_dt = dt 30 | 31 | self._mass = self._client.get_robot_mass() 32 | 33 | self._goal_speed_ref = 0 34 | self._goal_height_ref = [] 35 | self._swing_duration = [] 36 | self._stance_duration = [] 37 | self._total_duration = [] 38 | 39 | self._head_body_name = head_body 40 | self._root_body_name = root_body 41 | self._lfoot_body_name = lfoot_body 42 | self._rfoot_body_name = rfoot_body 43 | 44 | # read previously generated footstep plans 45 | with open('utils/footstep_plans.txt', 'r') as fn: 46 | 
lines = [l.strip() for l in fn.readlines()] 47 | self.plans = [] 48 | sequence = [] 49 | for line in lines: 50 | if line=='---': 51 | if len(sequence): 52 | self.plans.append(sequence) 53 | sequence=[] 54 | continue 55 | else: 56 | sequence.append(np.array([float(l) for l in line.split(',')])) 57 | 58 | def step_reward(self): 59 | target_pos = self.sequence[self.t1][0:3] 60 | foot_dist_to_target = min([np.linalg.norm(ft-target_pos) for ft in [self.l_foot_pos, 61 | self.r_foot_pos]]) 62 | hit_reward = 0 63 | if self.target_reached: 64 | hit_reward = np.exp(-foot_dist_to_target/0.25) 65 | 66 | target_mp = (self.sequence[self.t1][0:2] + self.sequence[self.t2][0:2])/2 67 | root_xy_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[0:2] 68 | root_dist_to_target = np.linalg.norm(root_xy_pos-target_mp) 69 | progress_reward = np.exp(-root_dist_to_target/2) 70 | return (0.8*hit_reward + 0.2*progress_reward) 71 | 72 | def calc_reward(self, prev_torque, prev_action, action): 73 | orient = tf3.euler.euler2quat(0, 0, self.sequence[self.t1][3]) 74 | r_frc = self.right_clock[0] 75 | l_frc = self.left_clock[0] 76 | r_vel = self.right_clock[1] 77 | l_vel = self.left_clock[1] 78 | if self.mode == WalkModes.STANDING: 79 | r_frc = (lambda _:1) 80 | l_frc = (lambda _:1) 81 | r_vel = (lambda _:-1) 82 | l_vel = (lambda _:-1) 83 | head_pos = self._client.get_object_xpos_by_name(self._head_body_name, 'OBJ_BODY')[0:2] 84 | root_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[0:2] 85 | reward = dict(foot_frc_score=0.150 * rewards._calc_foot_frc_clock_reward(self, l_frc, r_frc), 86 | foot_vel_score=0.150 * rewards._calc_foot_vel_clock_reward(self, l_vel, r_vel), 87 | orient_cost=0.050 * rewards._calc_body_orient_reward(self, 88 | self._root_body_name, 89 | quat_ref=orient), 90 | height_error=0.050 * rewards._calc_height_reward(self), 91 | #torque_penalty=0.050 * rewards._calc_torque_reward(self, prev_torque), 92 | #action_penalty=0.050 * rewards._calc_action_reward(self, prev_action), 93 | step_reward=0.450 * self.step_reward(), 94 | upper_body_reward=0.050 * np.exp(-10*np.square(np.linalg.norm(head_pos-root_pos))) 95 | ) 96 | return reward 97 | 98 | def transform_sequence(self, sequence): 99 | lfoot_pos = self._client.get_lfoot_body_pos() 100 | rfoot_pos = self._client.get_rfoot_body_pos() 101 | root_yaw = tf3.euler.quat2euler(self._client.get_object_xquat_by_name(self._root_body_name, 'OBJ_BODY'))[2] 102 | mid_pt = (lfoot_pos + rfoot_pos)/2 103 | sequence_rel = [] 104 | for x, y, z, theta in sequence: 105 | x_ = mid_pt[0] + x*np.cos(root_yaw) - y*np.sin(root_yaw) 106 | y_ = mid_pt[1] + x*np.sin(root_yaw) + y*np.cos(root_yaw) 107 | theta_ = root_yaw + theta 108 | step = np.array([x_, y_, z, theta_]) 109 | sequence_rel.append(step) 110 | return sequence_rel 111 | 112 | def generate_step_sequence(self, **kwargs): 113 | step_size, step_gap, step_height, num_steps, curved, lateral = kwargs.values() 114 | if curved: 115 | # set 0 height for curved sequences 116 | plan = random.choice(self.plans) 117 | sequence = [[s[0], s[1], 0, s[2]] for s in plan] 118 | return np.array(sequence) 119 | 120 | if lateral: 121 | sequence = [] 122 | y = 0 123 | c = np.random.choice([-1, 1]) 124 | for i in range(1, num_steps): 125 | if i%2: 126 | y += step_size 127 | else: 128 | y -= (2/3)*step_size 129 | step = np.array([0, c*y, 0, 0]) 130 | sequence.append(step) 131 | return sequence 132 | 133 | sequence = [] 134 | if self._phase==(0.5*self._period): 135 | first_step = np.array([0, 
-1*np.random.uniform(0.095, 0.105), 0, 0]) 136 | y = -step_gap 137 | else: 138 | first_step = np.array([0, 1*np.random.uniform(0.095, 0.105), 0, 0]) 139 | y = step_gap 140 | sequence.append(first_step) 141 | x, z = 0, 0 142 | c = np.random.randint(2, 4) 143 | for i in range(1, num_steps-1): 144 | x += step_size 145 | y *= -1 146 | if i > c: # let height of first few steps equal to 0 147 | z += step_height 148 | step = np.array([x, y, z, 0]) 149 | sequence.append(step) 150 | final_step = np.array([x+step_size, -y, z, 0]) 151 | sequence.append(final_step) 152 | return sequence 153 | 154 | def update_goal_steps(self): 155 | self._goal_steps_x[:] = np.zeros(2) 156 | self._goal_steps_y[:] = np.zeros(2) 157 | self._goal_steps_z[:] = np.zeros(2) 158 | self._goal_steps_theta[:] = np.zeros(2) 159 | root_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY') 160 | root_quat = self._client.get_object_xquat_by_name(self._root_body_name, 'OBJ_BODY') 161 | for idx, t in enumerate([self.t1, self.t2]): 162 | ref_frame = tf3.affines.compose(root_pos, tf3.quaternions.quat2mat(root_quat), np.ones(3)) 163 | abs_goal_pos = self.sequence[t][0:3] 164 | abs_goal_rot = tf3.euler.euler2mat(0, 0, self.sequence[t][3]) 165 | absolute_target = tf3.affines.compose(abs_goal_pos, abs_goal_rot, np.ones(3)) 166 | relative_target = np.linalg.inv(ref_frame).dot(absolute_target) 167 | if self.mode != WalkModes.STANDING: 168 | self._goal_steps_x[idx] = relative_target[0, 3] 169 | self._goal_steps_y[idx] = relative_target[1, 3] 170 | self._goal_steps_z[idx] = relative_target[2, 3] 171 | self._goal_steps_theta[idx] = tf3.euler.mat2euler(relative_target[:3, :3])[2] 172 | return 173 | 174 | def update_target_steps(self): 175 | assert len(self.sequence)>0 176 | self.t1 = self.t2 177 | self.t2+=1 178 | if self.t2==len(self.sequence): 179 | self.t2 = len(self.sequence)-1 180 | return 181 | 182 | def step(self): 183 | # increment phase 184 | self._phase+=1 185 | if self._phase>=self._period: 186 | self._phase=0 187 | 188 | self.l_foot_quat = self._client.get_object_xquat_by_name('lf_force', 'OBJ_SITE') 189 | self.r_foot_quat = self._client.get_object_xquat_by_name('rf_force', 'OBJ_SITE') 190 | self.l_foot_pos = self._client.get_object_xpos_by_name('lf_force', 'OBJ_SITE') 191 | self.r_foot_pos = self._client.get_object_xpos_by_name('rf_force', 'OBJ_SITE') 192 | self.l_foot_vel = self._client.get_lfoot_body_vel()[0] 193 | self.r_foot_vel = self._client.get_rfoot_body_vel()[0] 194 | self.l_foot_frc = self._client.get_lfoot_grf() 195 | self.r_foot_frc = self._client.get_rfoot_grf() 196 | 197 | # check if target reached 198 | target_pos = self.sequence[self.t1][0:3] 199 | foot_dist_to_target = min([np.linalg.norm(ft-target_pos) for ft in [self.l_foot_pos, 200 | self.r_foot_pos]]) 201 | 202 | 203 | lfoot_in_target = (np.linalg.norm(self.l_foot_pos-target_pos) < self.target_radius) 204 | rfoot_in_target = (np.linalg.norm(self.r_foot_pos-target_pos) < self.target_radius) 205 | if lfoot_in_target or rfoot_in_target: 206 | self.target_reached = True 207 | self.target_reached_frames+=1 208 | else: 209 | self.target_reached = False 210 | self.target_reached_frames=0 211 | 212 | # update target steps if needed 213 | if self.target_reached and (self.target_reached_frames>=self.delay_frames): 214 | self.update_target_steps() 215 | self.target_reached = False 216 | self.target_reached_frames = 0 217 | 218 | # update goal 219 | self.update_goal_steps() 220 | return 221 | 222 | def substep(self): 223 | pass 224 | 225 | def 
done(self): 226 | contact_flag = self._client.check_self_collisions() 227 | 228 | qpos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY') 229 | foot_pos = min([c[2] for c in (self.l_foot_pos, self.r_foot_pos)]) 230 | root_rel_height = qpos[2] - foot_pos 231 | terminate_conditions = {"qpos[2]_ll":(root_rel_height < 0.6), 232 | "contact_flag":contact_flag, 233 | } 234 | 235 | done = True in terminate_conditions.values() 236 | return done 237 | 238 | def reset(self, iter_count=0): 239 | # training iteration 240 | self.iteration_count = iter_count 241 | 242 | # for steps 243 | self._goal_steps_x = [0, 0] 244 | self._goal_steps_y = [0, 0] 245 | self._goal_steps_z = [0, 0] 246 | self._goal_steps_theta = [0, 0] 247 | 248 | self.target_radius = 0.20 249 | self.delay_frames = int(np.floor(self._swing_duration/self._control_dt)) 250 | self.target_reached = False 251 | self.target_reached_frames = 0 252 | self.t1 = 0 253 | self.t2 = 0 254 | 255 | self.right_clock, self.left_clock = rewards.create_phase_reward(self._swing_duration, 256 | self._stance_duration, 257 | 0.1, 258 | "grounded", 259 | 1/self._control_dt) 260 | 261 | # number of control steps in one full cycle 262 | # (one full cycle includes left swing + right swing) 263 | self._period = np.floor(2*self._total_duration*(1/self._control_dt)) 264 | # randomize phase during initialization 265 | self._phase = int(np.random.choice([0, self._period/2])) 266 | 267 | ## GENERATE STEP SEQUENCE 268 | # select a walking 'mode' 269 | self.mode = np.random.choice( 270 | [WalkModes.CURVED, WalkModes.STANDING, WalkModes.BACKWARD, WalkModes.LATERAL, WalkModes.FORWARD], 271 | p=[0.15, 0.05, 0.2, 0.3, 0.3]) 272 | 273 | d = {'step_size':0.3, 'step_gap':0.15, 'step_height':0, 'num_steps':20, 'curved':False, 'lateral':False} 274 | # generate sequence according to mode 275 | if self.mode == WalkModes.CURVED: 276 | d['curved'] = True 277 | elif self.mode == WalkModes.STANDING: 278 | d['num_steps'] = 1 279 | elif self.mode == WalkModes.BACKWARD: 280 | d['step_size'] = -0.1 281 | elif self.mode == WalkModes.INPLACE: 282 | ss = np.random.uniform(-0.05, 0.05) 283 | d['step_size']=ss 284 | elif self.mode == WalkModes.LATERAL: 285 | d['step_size'] = 0.4 286 | d['lateral'] = True 287 | elif self.mode == WalkModes.FORWARD: 288 | h = np.clip((self.iteration_count-3000)/8000, 0, 1)*0.1 289 | d['step_height']=np.random.choice([-h, h]) 290 | else: 291 | raise Exception("Invalid WalkModes") 292 | sequence = self.generate_step_sequence(**d) 293 | self.sequence = self.transform_sequence(sequence) 294 | self.update_target_steps() 295 | 296 | ## CREATE TERRAIN USING GEOMS 297 | nboxes = 20 298 | boxes = ["box"+repr(i+1).zfill(2) for i in range(nboxes)] 299 | sequence = [np.array([0, 0, -1, 0]) for i in range(nboxes)] 300 | sequence[:len(self.sequence)] = self.sequence 301 | for box, step in zip(boxes, sequence): 302 | box_h = self._client.model.geom(box).size[2] 303 | self._client.model.body(box).pos[:] = step[0:3] - np.array([0, 0, box_h]) 304 | self._client.model.body(box).quat[:] = tf3.euler.euler2quat(0, 0, step[3]) 305 | self._client.model.geom(box).size[:] = np.array([0.15, 1, box_h]) 306 | self._client.model.geom(box).rgba[:] = np.array([0.8, 0.8, 0.8, 1]) 307 | 308 | self._client.model.body('floor').pos[:] = np.array([0, 0, 0]) 309 | if self.mode == WalkModes.FORWARD: 310 | self._client.model.body('floor').pos[:] = np.array([0, 0, -2]) 311 | -------------------------------------------------------------------------------- /tasks/walking_task.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import transforms3d as tf3 3 | from tasks import rewards 4 | 5 | class WalkingTask(object): 6 | """Dynamically stable walking on biped.""" 7 | 8 | def __init__(self, 9 | client=None, 10 | dt=0.025, 11 | neutral_foot_orient=[], 12 | root_body='pelvis', 13 | lfoot_body='lfoot', 14 | rfoot_body='rfoot', 15 | waist_r_joint='waist_r', 16 | waist_p_joint='waist_p', 17 | ): 18 | 19 | self._client = client 20 | self._control_dt = dt 21 | self._neutral_foot_orient=neutral_foot_orient 22 | 23 | self._mass = self._client.get_robot_mass() 24 | 25 | # These depend on the robot, hardcoded for now 26 | # Ideally, they should be arguments to __init__ 27 | self._goal_speed_ref = [] 28 | self._goal_height_ref = [] 29 | self._swing_duration = [] 30 | self._stance_duration = [] 31 | self._total_duration = [] 32 | 33 | self._root_body_name = root_body 34 | self._lfoot_body_name = lfoot_body 35 | self._rfoot_body_name = rfoot_body 36 | 37 | def calc_reward(self, prev_torque, prev_action, action): 38 | self.l_foot_vel = self._client.get_lfoot_body_vel()[0] 39 | self.r_foot_vel = self._client.get_rfoot_body_vel()[0] 40 | self.l_foot_frc = self._client.get_lfoot_grf() 41 | self.r_foot_frc = self._client.get_rfoot_grf() 42 | r_frc = self.right_clock[0] 43 | l_frc = self.left_clock[0] 44 | r_vel = self.right_clock[1] 45 | l_vel = self.left_clock[1] 46 | reward = dict(foot_frc_score=0.150 * rewards._calc_foot_frc_clock_reward(self, l_frc, r_frc), 47 | foot_vel_score=0.150 * rewards._calc_foot_vel_clock_reward(self, l_vel, r_vel), 48 | orient_cost=0.050 * (rewards._calc_body_orient_reward(self, self._lfoot_body_name) + 49 | rewards._calc_body_orient_reward(self, self._rfoot_body_name) + 50 | rewards._calc_body_orient_reward(self, self._root_body_name))/3, 51 | root_accel=0.050 * rewards._calc_root_accel_reward(self), 52 | height_error=0.050 * rewards._calc_height_reward(self), 53 | com_vel_error=0.200 * rewards._calc_fwd_vel_reward(self), 54 | torque_penalty=0.050 * rewards._calc_torque_reward(self, prev_torque), 55 | action_penalty=0.050 * rewards._calc_action_reward(self, action, prev_action), 56 | ) 57 | return reward 58 | 59 | def step(self): 60 | if self._phase>self._period: 61 | self._phase=0 62 | self._phase+=1 63 | return 64 | 65 | def done(self): 66 | contact_flag = self._client.check_self_collisions() 67 | qpos = self._client.get_qpos() 68 | terminate_conditions = {"qpos[2]_ll":(qpos[2] < 0.6), 69 | "qpos[2]_ul":(qpos[2] > 1.4), 70 | "contact_flag":contact_flag, 71 | } 72 | 73 | done = True in terminate_conditions.values() 74 | return done 75 | 76 | def reset(self, iter_count=0): 77 | self._goal_speed_ref = np.random.choice([0, np.random.uniform(0.3, 0.4)]) 78 | self.right_clock, self.left_clock = rewards.create_phase_reward(self._swing_duration, 79 | self._stance_duration, 80 | 0.1, 81 | "grounded", 82 | 1/self._control_dt) 83 | 84 | # number of control steps in one full cycle 85 | # (one full cycle includes left swing + right swing) 86 | self._period = np.floor(2*self._total_duration*(1/self._control_dt)) 87 | # randomize phase during initialization 88 | self._phase = np.random.randint(0, self._period) 89 | --------------------------------------------------------------------------------
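A few of the numerical conventions used above are easier to see with small standalone examples. The phase-clock construction in `tasks/rewards.py` (`create_phase_reward`) can be exercised on its own; a minimal sketch, assuming the repository root is on `PYTHONPATH`, that `_total_duration` equals swing plus stance, and using made-up swing/stance durations (0.3 s / 0.2 s) with the 0.025 s control step the tasks pass in:

```
# Sample the gait-phase clocks returned by tasks/rewards.py:create_phase_reward.
# The durations below are illustrative, not taken from any config in this repo.
import numpy as np
from tasks.rewards import create_phase_reward

control_dt = 0.025
swing_duration, stance_duration = 0.3, 0.2
right_clock, left_clock = create_phase_reward(swing_duration, stance_duration,
                                              strict_relaxer=0.1,
                                              stance_mode="grounded",
                                              FREQ=1/control_dt)

# one full cycle = left swing + right swing, counted in control steps
period = int(np.floor(2 * (swing_duration + stance_duration) / control_dt))
for phase in range(0, period, 4):
    r_frc, r_vel = float(right_clock[0](phase)), float(right_clock[1](phase))
    l_frc, l_vel = float(left_clock[0](phase)), float(left_clock[1](phase))
    print(f"phase={phase:2d}  r_frc={r_frc:+.2f}  r_vel={r_vel:+.2f}  "
          f"l_frc={l_frc:+.2f}  l_vel={l_vel:+.2f}")
```

Each clock moves between roughly +1 (the quantity is rewarded at that phase) and -1 (it is penalized), which is what the clock-based reward terms rely on.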
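`_calc_foot_frc_clock_reward` turns a ground-reaction force into a score by normalizing it against half the robot's weight, mapping it to [-1, 1], and passing its product with the clock value through `tan(pi/4 * x)`. A self-contained numeric sketch (the 40 kg mass is made up):

```
import numpy as np

mass = 40.0                       # assumed robot mass, illustrative only
desired_max = mass * 9.8 * 0.5    # same normalizer as _calc_foot_frc_clock_reward

def frc_score(grf, clock):
    normed = min(grf, desired_max) / desired_max * 2 - 1   # map [0, max] -> [-1, 1]
    return np.tan(np.pi/4 * clock * normed)                # stays within [-1, 1]

print(frc_score(300.0, +1.0))   # loaded stance foot while force is rewarded   -> +1.0
print(frc_score(  0.0, +1.0))   # unloaded foot while force is rewarded        -> -1.0
print(frc_score(  0.0, -1.0))   # unloaded swing foot while force is penalized -> +1.0
```

The best achievable per-foot score is therefore +1 and the worst is -1; `_calc_foot_vel_clock_reward` applies the same shaping to foot speed.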
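`_calc_height_reward` zeroes the height error inside a speed-dependent deadzone before applying a squared-exponential kernel, so small height variations cost nothing while larger ones decay quickly. A small worked example (the 0.35 m/s goal speed is made up):

```
import numpy as np

goal_speed = 0.35
deadzone = 0.01 + 0.05 * goal_speed        # 0.0275 m, as in _calc_height_reward
for height_error in (0.0, 0.02, 0.05, 0.10, 0.20):
    err = 0.0 if height_error < deadzone else height_error
    print(f"error={height_error:.2f} m  reward={np.exp(-40*np.square(err)):.3f}")
# errors inside the deadzone give 1.000; 0.05 -> ~0.905, 0.10 -> ~0.670, 0.20 -> ~0.202
```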
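When a foot is in contact and the vertical force is large enough, `estimate_zmp` reduces to `zmp_x = (M*g*x_com - dL_y) / Fgz` and `zmp_y = (M*g*y_com + dL_x) / Fgz`, with `Fgz` the vertical momentum rate plus the weight. A worked example with made-up momentum rates for a roughly 40 kg robot standing nearly still:

```
import numpy as np

Gv, mass = 9.80665, 40.0                  # assumed mass, illustrative only
Mg = mass * Gv
com = np.array([0.02, 0.00, 0.80])        # CoM position
d_lin_mom = np.array([0.5, 0.0, -2.0])    # finite-difference rate of linear momentum
d_ang_mom = np.array([0.3, -0.6, 0.0])    # finite-difference rate of angular momentum

Fgz = d_lin_mom[2] + Mg
zmp_x = (Mg * com[0] - d_ang_mom[1]) / Fgz
zmp_y = (Mg * com[1] + d_ang_mom[0]) / Fgz
print(zmp_x, zmp_y)   # ~0.022, ~0.001: close to the CoM ground projection, as expected
```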
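`SteppingTask.__init__` reads `utils/footstep_plans.txt` as blocks of comma-separated lines separated by a `---` line; the parser itself only splits on commas, so the three-column `x, y, theta` layout below is an assumption based on how curved plans are consumed (`[s[0], s[1], 0, s[2]]`, with z forced to 0). The plan string here is made up purely to illustrate the format:

```
import numpy as np

# Made-up footstep plans in the format parsed by SteppingTask.__init__.
example_plan = """\
0.20, 0.10, 0.00
0.40, -0.10, 0.10
0.60, 0.10, 0.20
---
0.20, -0.10, 0.00
0.40, 0.10, -0.10
---
"""

plans, sequence = [], []
for line in example_plan.splitlines():
    line = line.strip()
    if line == '---':
        if len(sequence):
            plans.append(sequence)
        sequence = []
    else:
        sequence.append(np.array([float(v) for v in line.split(',')]))
print(len(plans), "plans with", [len(p) for p in plans], "steps")   # 2 plans with [3, 2] steps
```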
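`SteppingTask.update_goal_steps` expresses the next two absolute step targets in the robot's root frame with homogeneous transforms before exposing them as observations. A small numeric check of that math with made-up poses (the root is yawed 45 degrees and the target sits 0.5 m away along the world x-axis, with the same yaw):

```
import numpy as np
import transforms3d as tf3

root_pos = np.array([1.0, 0.5, 0.8])
root_quat = tf3.euler.euler2quat(0, 0, np.pi/4)       # root yawed by 45 degrees
step = np.array([1.5, 0.5, 0.0, np.pi/4])             # absolute target [x, y, z, theta]

ref_frame = tf3.affines.compose(root_pos, tf3.quaternions.quat2mat(root_quat), np.ones(3))
target = tf3.affines.compose(step[:3], tf3.euler.euler2mat(0, 0, step[3]), np.ones(3))
relative = np.linalg.inv(ref_frame).dot(target)

goal_xy = relative[:2, 3]                              # -> _goal_steps_x / _goal_steps_y
goal_theta = tf3.euler.mat2euler(relative[:3, :3])[2]  # -> _goal_steps_theta
print(goal_xy, goal_theta)                             # ~[0.354, -0.354], 0.0
```

Because the target's yaw equals the root yaw, the relative heading comes out as zero, which is exactly the quantity the policy observes.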