├── .gitignore
├── .gitmodules
├── LICENSE.md
├── README.md
├── envs
│   ├── common
│   │   ├── config_builder.py
│   │   ├── mujoco_env.py
│   │   └── robot_interface.py
│   ├── h1
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   └── base.yaml
│   │   ├── gen_xml.py
│   │   └── h1_env.py
│   └── jvrc
│       ├── __init__.py
│       ├── configs
│       │   └── base.yaml
│       ├── gen_xml.py
│       ├── jvrc_step.py
│       └── jvrc_walk.py
├── models
│   └── __init__.py
├── rl
│   ├── __init__.py
│   ├── algos
│   │   ├── __init__.py
│   │   └── ppo.py
│   ├── distributions
│   │   ├── __init__.py
│   │   ├── beta.py
│   │   └── gaussian.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── normalize.py
│   │   └── wrappers.py
│   ├── policies
│   │   ├── __init__.py
│   │   ├── actor.py
│   │   ├── base.py
│   │   └── critic.py
│   ├── storage
│   │   └── rollout_storage.py
│   └── utils
│       └── eval.py
├── robots
│   └── robot_base.py
├── run_experiment.py
├── scripts
│   └── debug_stepper.py
├── tasks
│   ├── __init__.py
│   ├── rewards.py
│   ├── stepping_task.py
│   └── walking_task.py
└── utils
    └── footstep_plans.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | __pycache__*
3 | *.pyc
4 | *.egg-info
5 | *.#~
6 | logs*
7 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "models/jvrc_mj_description"]
2 | path = models/jvrc_mj_description
3 | url = git@github.com:rohanpsingh/jvrc_mj_description.git
4 | [submodule "models/cassie_mj_description"]
5 | path = models/cassie_mj_description
6 | url = git@github.com:rohanpsingh/cassie_mj_description.git
7 | [submodule "models/mujoco_menagerie"]
8 | path = models/mujoco_menagerie
9 | url = git@github.com:google-deepmind/mujoco_menagerie
10 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2022, Rohan P. Singh
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LearningHumanoidWalking
2 |
3 |
4 |
5 |
6 |
7 | Code for the papers:
8 | - [**Robust Humanoid Walking on Compliant and Uneven Terrain with Deep Reinforcement Learning**](https://ieeexplore.ieee.org/abstract/document/10769793)
9 | [Rohan P. Singh](https://rohanpsingh.github.io), [Mitsuharu Morisawa](https://unit.aist.go.jp/jrl-22022/en/members/member-morisawa.html), [Mehdi Benallegue](https://unit.aist.go.jp/jrl-22022/en/members/member-benalleguem.html), [Zhaoming Xie](https://zhaomingxie.github.io/), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html)
10 |
11 | - [**Learning Bipedal Walking for Humanoids with Current Feedback**](https://arxiv.org/pdf/2303.03724.pdf)
12 | [Rohan P. Singh](https://rohanpsingh.github.io), [Zhaoming Xie](https://zhaomingxie.github.io/), [Pierre Gergondet](https://unit.aist.go.jp/jrl-22022/en/members/member-gergondet.html), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html)
13 |
14 | - [**Learning Bipedal Walking On Planned Footsteps For Humanoid Robots**](https://arxiv.org/pdf/2207.12644.pdf)
15 | [Rohan P. Singh](https://rohanpsingh.github.io), [Mehdi Benallegue](https://unit.aist.go.jp/jrl-22022/en/members/member-benalleguem.html), [Mitsuharu Morisawa](https://unit.aist.go.jp/jrl-22022/en/members/member-morisawa.html), [Rafael Cisneros](https://unit.aist.go.jp/jrl-22022/en/members/member-cisneros.html), [Fumio Kanehiro](https://unit.aist.go.jp/jrl-22022/en/members/member-kanehiro.html)
16 |
17 |
18 | ## Code structure:
19 | A rough outline of the repository, which might be useful when adding your own robot:
20 | ```
21 | LearningHumanoidWalking/
22 | ├── envs/ <-- Actions and observation space, PD gains, simulation step, control decimation, init, ...
23 | ├── tasks/ <-- Reward function, termination conditions, and more...
24 | ├── rl/ <-- Code for PPO, actor/critic networks, observation normalization process...
25 | ├── models/ <-- MuJoCo model files: XMLs/meshes/textures
26 | └── scripts/ <-- Utility scripts, etc.
27 | ```
28 |
29 | ## Requirements:
30 | - Python version: 3.12.4
31 | - pip install:
32 | - mujoco==3.2.2
33 | - ray==2.40.0
34 |   - pytorch==2.5.1
35 | - intel-openmp
36 | - [mujoco-python-viewer](https://github.com/rohanpsingh/mujoco-python-viewer)
37 | - transforms3d
38 | - scipy
39 |
40 | ## Usage:
41 |
42 | Environment names supported:
43 |
44 | | Task Description | Environment name |
45 | | ----------- | ----------- |
46 | | Basic Standing Task | 'h1' |
47 | | Basic Walking Task | 'jvrc_walk' |
48 | | Stepping Task (using footsteps) | 'jvrc_step' |
49 |
50 |
51 | #### **To train:**
52 |
53 | ```
54 | $ python run_experiment.py train --logdir <path_to_exp_dir> --num_procs <num_of_cpu_procs> --env <name_of_environment>
55 | ```
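For example, to train a walking policy for the JVRC robot with 4 parallel workers (the log directory name is only an illustration): `python run_experiment.py train --logdir logs/jvrc_walk --num_procs 4 --env jvrc_walk`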
56 |
57 |
58 | #### **To play:**
59 |
60 | ```
61 | $ python run_experiment.py eval --logdir <path_to_exp_dir>
62 | ```
63 |
64 | Alternatively, you can write a rollout script specific to each environment.
65 | For example, `debug_stepper.py` can be used with the `jvrc_step` environment.
66 | ```
67 | $ PYTHONPATH=.:$PYTHONPATH python scripts/debug_stepper.py --path <path_to_exp_dir>
68 | ```
69 |
70 | #### **What you should see:**
71 |
72 | *Ascending stairs:*
73 | 
74 |
75 | *Descending stairs:*
76 | 
77 |
78 | *Walking on curves:*
79 | 
80 |
81 |
82 | ## Citation
83 | If you find this work useful in your own research, please cite the following works:
84 |
85 | For omnidirectional walking:
86 | ```
87 | @inproceedings{singh2024robust,
88 | title={Robust Humanoid Walking on Compliant and Uneven Terrain with Deep Reinforcement Learning},
89 | author={Singh, Rohan P and Morisawa, Mitsuharu and Benallegue, Mehdi and Xie, Zhaoming and Kanehiro, Fumio},
90 | booktitle={2024 IEEE-RAS 23rd International Conference on Humanoid Robots (Humanoids)},
91 | pages={497--504},
92 | year={2024},
93 | organization={IEEE}
94 | }
95 | ```
96 |
97 | For simulating the "back-EMF" effect and other randomizations:
98 | ```
99 | @article{xie2023learning,
100 | title={Learning bipedal walking for humanoids with current feedback},
101 | author={Singh, Rohan Pratap and Xie, Zhaoming and Gergondet, Pierre and Kanehiro, Fumio},
102 | journal={IEEE Access},
103 | volume={11},
104 | pages={82013--82023},
105 | year={2023},
106 | publisher={IEEE}
107 | }
108 | ```
109 |
110 | For walking on footsteps:
111 |
112 | ```
113 | @inproceedings{singh2022learning,
114 | title={Learning Bipedal Walking On Planned Footsteps For Humanoid Robots},
115 | author={Singh, Rohan P and Benallegue, Mehdi and Morisawa, Mitsuharu and Cisneros, Rafael and Kanehiro, Fumio},
116 | booktitle={2022 IEEE-RAS 21st International Conference on Humanoid Robots (Humanoids)},
117 | pages={686--693},
118 | year={2022},
119 | organization={IEEE}
120 | }
121 | ```
122 |
123 | ### Credits
124 | The code in this repository was heavily inspired by [apex](https://github.com/osudrl/apex). Clock-based reward terms and some other ideas were originally proposed by the OSU DRL team for the Cassie robot, so please also consider citing the works of Jonah Siekmann, Helei Duan, Jeremy Dao, and others.
125 |
126 |
--------------------------------------------------------------------------------
/envs/common/config_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from typing import Any, Dict, Optional, Union, List
4 |
5 | class Configuration:
6 | """A class to handle configuration data with attribute-style access."""
7 |
8 | def __init__(self, **kwargs: Any) -> None:
9 | """
10 | Initialize a Configuration object with nested attribute access.
11 |
12 | Args:
13 | **kwargs: Key-value pairs to be set as attributes.
14 | """
15 | for key, value in kwargs.items():
16 | if isinstance(value, dict):
17 | setattr(self, key, Configuration(**value))
18 | elif isinstance(value, list) and all(isinstance(item, dict) for item in value):
19 | setattr(self, key, [Configuration(**item) for item in value])
20 | else:
21 | setattr(self, key, value)
22 |
23 | def __repr__(self) -> str:
24 | """Return a string representation of the configuration."""
25 | return str(self.__dict__)
26 |
27 | def __getattr__(self, name: str) -> None:
28 | """Return None for non-existent attributes instead of raising an error."""
29 | return None
30 |
31 | def to_dict(self) -> Dict[str, Any]:
32 | """Convert the Configuration object back to a dictionary."""
33 | result = {}
34 | for key, value in self.__dict__.items():
35 | if isinstance(value, Configuration):
36 | result[key] = value.to_dict()
37 | elif isinstance(value, list) and value and isinstance(value[0], Configuration):
38 | result[key] = [item.to_dict() if isinstance(item, Configuration) else item for item in value]
39 | else:
40 | result[key] = value
41 | return result
42 |
43 | def load_yaml(file_path: str) -> Configuration:
44 | """
45 | Load configuration from a YAML file.
46 |
47 | Args:
48 | file_path: Path to the YAML file.
49 |
50 | Returns:
51 | Configuration object with data from the YAML file.
52 |
53 | Raises:
54 | FileNotFoundError: If the file doesn't exist.
55 | yaml.YAMLError: If there's an error parsing the YAML.
56 | """
57 | if not os.path.exists(file_path):
58 | raise FileNotFoundError(f"Configuration file not found: {file_path}")
59 |
60 | with open(file_path, 'r') as file:
61 | try:
62 | config_data = yaml.safe_load(file)
63 | except yaml.YAMLError as e:
64 | raise yaml.YAMLError(f"Error parsing YAML file {file_path}: {e}")
65 | return Configuration(**config_data)
66 |
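A minimal usage sketch of `load_yaml` and `Configuration` (run from the repository root; the keys shown come from `envs/h1/configs/base.yaml`):

```python
from envs.common.config_builder import load_yaml

cfg = load_yaml("envs/h1/configs/base.yaml")
print(cfg.sim_dt, cfg.control_dt)               # top-level scalars
print(cfg.observation_noise.scales.motor_pos)   # nested dicts become Configuration objects
print(cfg.pdgains.to_dict())                    # convert a subtree back to a plain dict
print(cfg.does_not_exist)                       # missing keys return None instead of raising
```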
--------------------------------------------------------------------------------
/envs/common/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import os
3 | import numpy as np
4 | import mujoco
5 | import mujoco_viewer
6 |
7 | DEFAULT_SIZE = 500
8 |
9 | class MujocoEnv():
10 | """Superclass for all MuJoCo environments.
11 | """
12 |
13 | def __init__(self, model_path, sim_dt, control_dt):
14 | if model_path.startswith("/"):
15 | fullpath = model_path
16 | else:
17 | raise Exception("Provide full path to robot description package.")
18 | if not os.path.exists(fullpath):
19 | raise IOError("File %s does not exist" % fullpath)
20 |
21 | self.spec = mujoco.MjSpec()
22 | self.spec.from_file(fullpath)
23 | self.model = self.spec.compile()
24 | self.data = mujoco.MjData(self.model)
25 | self.viewer = None
26 |
27 | # set frame skip and sim dt
28 | self.frame_skip = (control_dt/sim_dt)
29 | self.model.opt.timestep = sim_dt
30 |
31 | self.init_qpos = self.data.qpos.ravel().copy()
32 | self.init_qvel = self.data.qvel.ravel().copy()
33 |
34 | # methods to override:
35 | # ----------------------------
36 |
37 | def reset_model(self):
38 | """
39 | Reset the robot degrees of freedom (qpos and qvel).
40 | Implement this in each subclass.
41 | """
42 | raise NotImplementedError
43 |
44 | def viewer_setup(self):
45 | """
46 | This method is called when the viewer is initialized.
47 | Optionally implement this method, if you need to tinker with camera position
48 | and so forth.
49 | """
50 | self.viewer.cam.trackbodyid = 1
51 | self.viewer.cam.distance = self.model.stat.extent * 1.5
52 | self.viewer.cam.lookat[2] = 1.5
53 | self.viewer.cam.lookat[0] = 2.0
54 | self.viewer.cam.elevation = -20
55 | self.viewer.vopt.geomgroup[2] = 0
56 | self.viewer._render_every_frame = True
57 |
58 | def viewer_is_paused(self):
59 | return self.viewer._paused
60 |
61 | # -----------------------------
62 | # (some methods are taken directly from dm_control)
63 |
64 | @contextlib.contextmanager
65 | def disable(self, *flags):
66 | """Context manager for temporarily disabling MuJoCo flags.
67 |
68 | Args:
69 | *flags: Positional arguments specifying flags to disable. Can be either
70 | lowercase strings (e.g. 'gravity', 'contact') or `mjtDisableBit` enum
71 | values.
72 |
73 | Yields:
74 | None
75 |
76 | Raises:
77 | ValueError: If any item in `flags` is neither a valid name nor a value
78 | from `mujoco.mjtDisableBit`.
79 | """
80 | old_bitmask = self.model.opt.disableflags
81 | new_bitmask = old_bitmask
82 | for flag in flags:
83 | if isinstance(flag, str):
84 | try:
85 | field_name = "mjDSBL_" + flag.upper()
86 | flag = getattr(mujoco.mjtDisableBit, field_name)
87 | except AttributeError:
88 | valid_names = [
89 | field_name.split("_")[1].lower()
90 | for field_name in list(mujoco.mjtDisableBit.__members__)[:-1]
91 | ]
92 | raise ValueError("'{}' is not a valid flag name. Valid names: {}"
93 | .format(flag, ", ".join(valid_names))) from None
94 | elif isinstance(flag, int):
95 | flag = mujoco.mjtDisableBit(flag)
96 | new_bitmask |= flag.value
97 | self.model.opt.disableflags = new_bitmask
98 | try:
99 | yield
100 | finally:
101 | self.model.opt.disableflags = old_bitmask
102 |
103 | def reset(self):
104 | mujoco.mj_resetData(self.model, self.data)
105 | ob = self.reset_model()
106 | return ob
107 |
108 | def set_state(self, qpos, qvel):
109 | assert qpos.shape == (self.model.nq,), \
110 | f"qpos shape {qpos.shape} is expected to be {(self.model.nq,)}"
111 | assert qvel.shape == (self.model.nv,), \
112 | f"qvel shape {qvel.shape} is expected to be {(self.model.nv,)}"
113 | self.data.qpos[:] = qpos
114 | self.data.qvel[:] = qvel
115 | self.data.act = []
116 | self.data.plugin_state = []
117 | # Disable actuation since we don't yet have meaningful control inputs.
118 | with self.disable('actuation'):
119 | mujoco.mj_forward(self.model, self.data)
120 |
121 | @property
122 | def dt(self):
123 | return self.model.opt.timestep * self.frame_skip
124 |
125 | def render(self):
126 | if self.viewer is None:
127 | self.viewer = mujoco_viewer.MujocoViewer(self.model, self.data)
128 | self.viewer_setup()
129 | self.viewer.render()
130 |
131 | def uploadGPU(self, hfieldid=None, meshid=None, texid=None):
132 | # hfield
133 | if hfieldid is not None:
134 | mujoco.mjr_uploadHField(self.model, self.viewer.ctx, hfieldid)
135 | # mesh
136 | if meshid is not None:
137 | mujoco.mjr_uploadMesh(self.model, self.viewer.ctx, meshid)
138 | # texture
139 | if texid is not None:
140 | mujoco.mjr_uploadTexture(self.model, self.viewer.ctx, texid)
141 |
142 | def close(self):
143 | if self.viewer is not None:
144 | self.viewer.close()
145 | self.viewer = None
146 |
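A short usage sketch of the `disable` context manager above, assuming `env` is an already-constructed `MujocoEnv` subclass; flag names follow `mujoco.mjtDisableBit` (e.g. 'gravity', 'contact', 'actuation'):

```python
import mujoco

# Gravity and contacts are switched off only inside the block;
# the original opt.disableflags bitmask is restored on exit.
with env.disable('gravity', 'contact'):
    mujoco.mj_forward(env.model, env.data)
```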
--------------------------------------------------------------------------------
/envs/common/robot_interface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import transforms3d as tf3
4 | import mujoco
5 | import torch
6 | import collections
7 |
8 | class RobotInterface(object):
9 | def __init__(self, model, data, rfoot_body_name=None, lfoot_body_name=None,
10 | path_to_nets=None):
11 | self.model = model
12 | self.data = data
13 |
14 | self.rfoot_body_name = rfoot_body_name
15 | self.lfoot_body_name = lfoot_body_name
16 | self.floor_body_name = model.body(0).name
17 | self.robot_root_name = model.body(1).name
18 |
19 | self.stepCounter = 0
20 |
21 | if path_to_nets:
22 | self.load_motor_nets(path_to_nets)
23 |
24 | def load_motor_nets(self, path_to_nets):
25 | self.motor_dyn_nets = {}
26 | for jnt in os.listdir(path_to_nets):
27 | if not os.path.isdir(os.path.join(path_to_nets, jnt)):
28 | continue
29 | net_path = os.path.join(path_to_nets, jnt, "trained_jit.pth")
30 | net = torch.jit.load(net_path)
31 | net.eval()
32 | self.motor_dyn_nets[jnt] = net
33 | self.ctau_buffer = collections.deque(maxlen=25)
34 | self.qdot_buffer = collections.deque(maxlen=25)
35 | return
36 |
37 | def motor_nets_forward(self, cmdTau):
38 | if len(self.ctau_buffer)0)
451 |
452 | def check_lfoot_floor_collision(self):
453 | """
454 | Returns True if there is a collision between left foot and floor.
455 | """
456 | return (len(self.get_lfoot_floor_contacts())>0)
457 |
458 | def check_bad_collisions(self, body_names=[]):
459 | """
460 | Returns True if there are collisions other than the specified body-floor
461 | contacts (or feet-floor contacts, if body_names is not provided).
462 | """
463 | num_cons = 0
464 | if not isinstance(body_names, list):
465 | raise TypeError(f"expected list of body names, got '{type(body_names).__name__}'")
466 | if not len(body_names):
467 | num_rcons = len(self.get_rfoot_floor_contacts())
468 | num_lcons = len(self.get_lfoot_floor_contacts())
469 | num_cons = num_rcons + num_lcons
470 | for bn in body_names:
471 | num_cons += len(self.get_body_floor_contacts(bn))
472 | return num_cons != self.data.ncon
473 |
474 | def check_self_collisions(self):
475 | """
476 | Returns True if there are collisions other than any-geom-floor.
477 | """
478 | contacts = [self.data.contact[i] for i in range(self.data.ncon)]
479 | for i,c in enumerate(contacts):
480 | geom1_body = self.model.body(self.model.geom_bodyid[c.geom1])
481 | geom2_body = self.model.body(self.model.geom_bodyid[c.geom2])
482 | geom1_is_robot = self.model.body(geom1_body.rootid).name==self.robot_root_name
483 | geom2_is_robot = self.model.body(geom2_body.rootid).name==self.robot_root_name
484 | if geom1_is_robot and geom2_is_robot:
485 | return True
486 | return False
487 |
488 | def set_pd_gains(self, kp, kv):
489 | assert kp.size==self.model.nu
490 | assert kv.size==self.model.nu
491 | self.kp = kp.copy()
492 | self.kv = kv.copy()
493 | return
494 |
495 | def step_pd(self, p, v):
496 | target_angles = p
497 | target_speeds = v
498 |
499 | assert type(target_angles)==np.ndarray
500 | assert type(target_speeds)==np.ndarray
501 |
502 | curr_angles = self.get_act_joint_positions()
503 | curr_speeds = self.get_act_joint_velocities()
504 |
505 | perror = target_angles - curr_angles
506 | verror = target_speeds - curr_speeds
507 |
508 | assert self.kp.size==perror.size
509 | assert self.kv.size==verror.size
510 | return self.kp * perror + self.kv * verror
511 |
512 | def set_motor_torque(self, torque, motor_dyn_fwd = False):
513 | """
514 | Apply torques to motors.
515 | """
516 | if isinstance(torque, np.ndarray):
517 | assert torque.shape==(self.nu(), )
518 | ctrl = torque
519 | elif isinstance(torque, list):
520 | assert len(torque)==self.nu()
521 | ctrl = np.copy(torque)
522 | else:
523 | raise Exception("motor torque should be a list or ndarray.")
524 | try:
525 | if motor_dyn_fwd:
526 | if not hasattr(self, 'motor_dyn_nets'):
527 | raise Exception("motor dynamics networks are not defined.")
528 | gear = self.get_gear_ratios()
529 | ctrl = self.motor_nets_forward(ctrl*gear)
530 | ctrl /= gear
531 | np.copyto(self.data.ctrl, ctrl)
532 | except Exception as e:
533 | print("Could not apply motor torque.")
534 | print(e)
535 | return
536 |
537 | def step(self, mj_step=True, nstep=1):
538 | """
539 | (Adapted from dm_control/mujoco/engine.py)
540 |
541 | Advances physics with up-to-date position and velocity dependent fields.
542 | Args:
543 | nstep: Optional integer, number of steps to take.
544 | """
545 | if mj_step:
546 | mujoco.mj_step(self.model, self.data, nstep)
547 | self.stepCounter += nstep
548 | return
549 |
550 | # In the case of Euler integration we assume mj_step1 has already been
551 | # called for this state, finish the step with mj_step2 and then update all
552 | # position and velocity related fields with mj_step1. This ensures that
553 | # (most of) mjData is in sync with qpos and qvel. In the case of non-Euler
554 | # integrators (e.g. RK4) an additional mj_step1 must be called after the
555 | # last mj_step to ensure mjData syncing.
556 | if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value:
557 | mujoco.mj_step2(self.model, self.data)
558 | if nstep > 1:
559 | mujoco.mj_step(self.model, self.data, nstep-1)
560 | else:
561 | mujoco.mj_step(self.model, self.data, nstep)
562 |
563 | mujoco.mj_step1(self.model, self.data)
564 |
565 | self.stepCounter += nstep
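A minimal sketch of how the PD helpers above fit together in one control step. The actual loop lives in `robots/robot_base.py` (not shown in this excerpt); `interface`, `targets` and `frame_skip` below are placeholder names:

```python
import numpy as np

# interface: a constructed RobotInterface; targets: desired actuated-joint positions (ndarray)
interface.set_pd_gains(kp=np.full(interface.nu(), 100.0),
                       kv=np.full(interface.nu(), 10.0))
for _ in range(frame_skip):                                   # inner loop runs at sim_dt
    tau = interface.step_pd(targets, np.zeros_like(targets))  # PD law: kp*pos_err + kv*vel_err
    interface.set_motor_torque(tau)                           # writes data.ctrl
    interface.step()                                          # advances the simulation by one step
```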
566 |
--------------------------------------------------------------------------------
/envs/h1/__init__.py:
--------------------------------------------------------------------------------
1 | from envs.h1.h1_env import H1Env
2 |
--------------------------------------------------------------------------------
/envs/h1/configs/base.yaml:
--------------------------------------------------------------------------------
1 | sim_dt: 0.001
2 | control_dt: 0.025
3 | obs_history_len: 1
4 | action_smoothing: 0.5
5 |
6 | ctrllimited: false
7 | jointlimited: false
8 | reduced_xml: true
9 |
10 | init_noise: 3
11 |
12 | pdgains:
13 | "left_hip_yaw_joint": [100, 10]
14 | "left_hip_roll_joint": [100, 10]
15 | "left_hip_pitch_joint": [100, 10]
16 | "left_knee_joint": [100, 10]
17 | "left_ankle_joint": [20, 4]
18 | "right_hip_yaw_joint": [100, 10]
19 | "right_hip_roll_joint": [100, 10]
20 | "right_hip_pitch_joint": [100, 10]
21 | "right_knee_joint": [100, 10]
22 | "right_ankle_joint": [20, 4]
23 | "torso_joint": [40, 4]
24 | "left_shoulder_pitch_joint": [20, 2]
25 | "left_shoulder_roll_joint": [20, 2]
26 | "left_shoulder_yaw_joint": [20, 2]
27 | "left_elbow_joint": [20, 2]
28 | "right_shoulder_pitch_joint": [20, 2]
29 | "right_shoulder_roll_joint": [20, 2]
30 | "right_shoulder_yaw_joint": [20, 2]
31 | "right_elbow_joint": [20, 2]
32 |
33 | observation_noise:
34 | enabled: true
35 | multiplier: 1.0
36 | type: "uniform" # or "gaussian"
37 | scales:
38 | root_orient: 0.05
39 | root_ang_vel: 0.05
40 | motor_pos: 0.02
41 | motor_vel: 0.05
42 | motor_tau: 5.0
43 |
44 | perturbation:
45 | enable: true
46 | bodies: ["pelvis", "torso_link"]
47 | force_magnitude: 10
48 | torque_magnitude: 2
49 | interval: 5
50 |
51 | dynamics_randomization:
52 | enable: true
53 | interval: 0.5
54 |
55 |
--------------------------------------------------------------------------------
/envs/h1/gen_xml.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import models
4 | from dm_control import mjcf
5 | import transforms3d as tf3
6 |
7 | H1_DESCRIPTION_PATH=os.path.join(os.path.dirname(models.__file__), "mujoco_menagerie/unitree_h1/scene.xml")
8 |
9 | LEG_JOINTS = ["left_hip_yaw", "left_hip_roll", "left_hip_pitch", "left_knee", "left_ankle",
10 | "right_hip_yaw", "right_hip_roll", "right_hip_pitch", "right_knee", "right_ankle"]
11 | WAIST_JOINTS = ["torso"]
12 | ARM_JOINTS = ["left_shoulder_pitch", "left_shoulder_roll", "left_shoulder_yaw", "left_elbow",
13 | "right_shoulder_pitch", "right_shoulder_roll", "right_shoulder_yaw", "right_elbow"]
14 |
15 | def create_rangefinder_array(mjcf_model, num_rows=4, num_cols=4, spacing=0.4):
16 | for i in range(num_rows*num_cols):
17 | name = 'rf' + repr(i)
18 | u = (i % num_cols)
19 | v = (i // num_rows)
20 | x = (v - (num_cols - 1) / 2) * spacing
21 | y = ((num_rows - 1) / 2 - u) * (-spacing)
22 | # add sites
23 | mjcf_model.find('body', 'pelvis').add('site',
24 | name=name,
25 | pos=[x, y, 0],
26 | quat='0 1 0 0')
27 |
28 | # add sensors
29 | mjcf_model.sensor.add('rangefinder',
30 | name=name,
31 | site=name)
32 |
33 | return mjcf_model
34 |
35 | def remove_joints_and_actuators(mjcf_model, config):
36 | # remove joints
37 | for limb in config['unused_joints']:
38 | for joint in limb:
39 | mjcf_model.find('joint', joint).remove()
40 |
41 | # remove all actuators with no corresponding joint
42 | for mot in mjcf_model.actuator.motor:
43 | mot.user = None
44 | if mot.joint==None:
45 | mot.remove()
46 | return mjcf_model
47 |
48 | def builder(export_path, config):
49 | print("Modifying XML model...")
50 | mjcf_model = mjcf.from_path(H1_DESCRIPTION_PATH)
51 |
52 | mjcf_model.model = 'h1'
53 |
54 | # modify model
55 | mjcf_model = remove_joints_and_actuators(mjcf_model, config)
56 | mjcf_model.find('default', 'visual').geom.group = 1
57 | mjcf_model.find('default', 'collision').geom.group = 2
58 | if 'ctrllimited' in config:
59 | mjcf_model.find('default', 'h1').motor.ctrllimited = config['ctrllimited']
60 | if 'jointlimited' in config:
61 | mjcf_model.find('default', 'h1').joint.limited = config['jointlimited']
62 |
63 | # rename all motors to fit assumed convention
64 | for mot in mjcf_model.actuator.motor:
65 | mot.name = mot.name + "_motor"
66 |
67 | # remove keyframe for now
68 | mjcf_model.keyframe.remove()
69 |
70 | # remove vis geoms and assets, if needed
71 | if 'minimal' in config:
72 | if config['minimal']:
73 | mjcf_model.find('default', 'collision').geom.group = 1
74 | meshes = mjcf_model.asset.mesh
75 | for mesh in meshes:
76 | mesh.remove()
77 | for geom in mjcf_model.find_all('geom'):
78 | if geom.dclass:
79 | if geom.dclass.dclass=="visual":
80 | geom.remove()
81 |
82 | # set name of freejoint
83 | mjcf_model.find('body', 'pelvis').freejoint.name = 'root'
84 |
85 | # add rangefinder
86 | if 'rangefinder' in config:
87 | if config['rangefinder']:
88 | mjcf_model = create_rangefinder_array(mjcf_model)
89 |
90 | # create a raised platform
91 | if 'raisedplatform' in config:
92 | if config['raisedplatform']:
93 | block_pos = [2.5, 0, 0]
94 | block_size = [3, 3, 1]
95 | name = 'raised-platform'
96 | mjcf_model.worldbody.add('body', name=name, pos=block_pos)
97 | mjcf_model.find('body', name).add('geom', name=name, group='3',
98 | condim='3', friction='.8 .1 .1', size=block_size,
99 | type='box', material='')
100 |
101 | # set some size options
102 | mjcf_model.size.njmax = "-1"
103 | mjcf_model.size.nconmax = "-1"
104 | mjcf_model.size.nuser_actuator = "-1"
105 |
106 | # export model
107 | mjcf.export_with_assets(mjcf_model, out_dir=export_path, precision=5)
108 | path_to_xml = os.path.join(export_path, mjcf_model.model + '.xml')
109 | print("Exporting XML model to ", path_to_xml)
110 | return
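For reference, this builder can be invoked directly with the same config keys that `envs/h1/h1_env.py` passes; the export path below is illustrative:

```python
if __name__ == '__main__':
    builder('/tmp/mjcf-export/h1', config={
        'unused_joints': [WAIST_JOINTS, ARM_JOINTS],  # drop waist and arm joints
        'rangefinder': False,
        'raisedplatform': False,
        'ctrllimited': False,
        'jointlimited': False,
        'minimal': True,
    })
```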
111 |
--------------------------------------------------------------------------------
/envs/h1/h1_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import transforms3d as tf3
5 | import collections
6 |
7 | from robots.robot_base import RobotBase
8 | from envs.common import mujoco_env
9 | from envs.common import robot_interface
10 | from envs.common import config_builder
11 |
12 | from .gen_xml import *
13 |
14 | class Task:
15 | def __init__(self, client, neutral_pose):
16 | self._client = client
17 | self.neutral_pose = neutral_pose
18 |
19 | def calc_reward(self, prev_torque, prev_action, action):
20 | root_pose = self._client.get_object_affine_by_name("pelvis", 'OBJ_BODY')
21 |
22 | # height reward
23 | target_root_h = 0.98
24 | root_h = root_pose[2, 3]
25 | height_error = np.linalg.norm(root_h - target_root_h)
26 |
27 | # upperbody reward
28 | head_pose_offset = np.zeros(2)
29 | head_pose = self._client.get_object_affine_by_name("head_link", 'OBJ_BODY')
30 | head_pos_in_robot_base = np.linalg.inv(root_pose).dot(head_pose)[:2, 3] - head_pose_offset
31 | upperbody_error = np.linalg.norm(head_pos_in_robot_base)
32 |
33 | # posture reward
34 | current_pose = np.array(self._client.get_act_joint_positions())[:10]
35 | posture_error = np.linalg.norm(current_pose - self.neutral_pose)
36 |
37 | # torque reward
38 | tau_error = np.linalg.norm(self._client.get_act_joint_torques())
39 |
40 | # velocity reward
41 | root_vel = self._client.get_body_vel("pelvis", frame=1)[0][:2]
42 | fwd_vel_error = np.linalg.norm(root_vel)
43 | yaw_vel = self._client.get_qvel()[5]
44 | yaw_vel_error = np.linalg.norm(yaw_vel)
45 |
46 | reward = {
47 | "com_vel_error": 0.3 * np.exp(-4 * np.square(fwd_vel_error)),
48 | "yaw_vel_error": 0.3 * np.exp(-4 * np.square(yaw_vel_error)),
49 | "height": 0.1 * np.exp(-0.5 * np.square(height_error)),
50 | "upperbody": 0.1 * np.exp(-40*np.square(upperbody_error)),
51 | "joint_torque_reward": 0.1 * np.exp(-5e-5*np.square(tau_error)),
52 | "posture": 0.1 * np.exp(-1*np.square(posture_error)),
53 | }
54 | return reward
55 |
56 | def step(self):
57 | pass
58 |
59 | def substep(self):
60 | pass
61 |
62 | def done(self):
63 | root_jnt_adr = self._client.model.body("pelvis").jntadr[0]
64 | root_qpos_adr = self._client.model.joint(root_jnt_adr).qposadr[0]
65 | qpos = self._client.get_qpos()[root_qpos_adr:root_qpos_adr+7]
66 | contact_flag = self._client.check_self_collisions()
67 | terminate_conditions = {"qpos[2]_ll":(qpos[2] < 0.9),
68 | "qpos[2]_ul":(qpos[2] > 1.4),
69 | "contact_flag":contact_flag,
70 | }
71 | done = True in terminate_conditions.values()
72 | return done
73 |
74 | def reset(self):
75 | pass
76 |
77 | class H1Env(mujoco_env.MujocoEnv):
78 | def __init__(self, path_to_yaml = None):
79 |
80 | ## Load CONFIG from yaml ##
81 | if path_to_yaml is None:
82 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml')
83 |
84 | self.cfg = config_builder.load_yaml(path_to_yaml)
85 |
86 | sim_dt = self.cfg.sim_dt
87 | control_dt = self.cfg.control_dt
88 | frame_skip = (control_dt/sim_dt)
89 |
90 | self.dynrand_interval = self.cfg.dynamics_randomization.interval/control_dt
91 | self.perturb_interval = self.cfg.perturbation.interval/control_dt
92 | self.history_len = self.cfg.obs_history_len
93 |
94 | path_to_xml = '/tmp/mjcf-export/h1/h1.xml'
95 | if not os.path.exists(path_to_xml):
96 | export_dir = os.path.dirname(path_to_xml)
97 | builder(export_dir, config={
98 | 'unused_joints': [WAIST_JOINTS, ARM_JOINTS],
99 | 'rangefinder': False,
100 | 'raisedplatform': False,
101 | 'ctrllimited': self.cfg.ctrllimited,
102 | 'jointlimited': self.cfg.jointlimited,
103 | 'minimal': self.cfg.reduced_xml,
104 | })
105 |
106 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt)
107 |
108 | # the actual weight of the robot is about 7kg higher
109 | self.model.body("pelvis").mass = 8.89
110 | self.model.body("torso_link").mass = 21.289
111 |
112 | # list of desired actuators
113 | self.leg_names = LEG_JOINTS
114 | gains_dict = self.cfg.pdgains.to_dict()
115 | kp, kd = zip(*[gains_dict[jn] for jn in self.leg_names])
116 | pdgains = np.array([kp, kd])
117 |
118 | # define nominal pose
119 | base_position = [0, 0, 0.98]
120 | base_orientation = [1, 0, 0, 0]
121 | half_sitting_pose = [
122 | 0, 0, -0.2, 0.6, -0.4,
123 | 0, 0, -0.2, 0.6, -0.4,
124 | ]
125 |
126 | # set up interface
127 | self.interface = robot_interface.RobotInterface(self.model, self.data, 'right_ankle_link', 'left_ankle_link', None)
128 | self.nominal_pose = base_position + base_orientation + half_sitting_pose
129 |
130 | # set up task
131 | self.task = Task(self.interface, half_sitting_pose)
132 |
133 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task)
134 |
135 | # set action space
136 | action_space_size = len(self.leg_names)
137 | action = np.zeros(action_space_size)
138 | self.action_space = np.zeros(action_space_size)
139 | self.prev_prediction = np.zeros(action_space_size)
140 |
141 | # set observation space
142 | self.base_obs_len = 35
143 | self.observation_history = collections.deque(maxlen=self.history_len)
144 | self.observation_space = np.zeros(self.base_obs_len*self.history_len)
145 |
146 | # manually define observation mean and std
147 | self.obs_mean = np.concatenate((
148 | np.zeros(5),
149 | half_sitting_pose, np.zeros(10), np.zeros(10),
150 | ))
151 |
152 | self.obs_std = np.concatenate((
153 | [0.2, 0.2, 1, 1, 1],
154 | 0.5*np.ones(10), 4*np.ones(10), 100*np.ones(10),
155 | ))
156 |
157 | self.obs_mean = np.tile(self.obs_mean, self.history_len)
158 | self.obs_std = np.tile(self.obs_std, self.history_len)
159 |
160 | # copy the original model
161 | self.default_model = copy.deepcopy(self.model)
162 |
163 | def get_obs(self):
164 | # internal state
165 | qpos = np.copy(self.interface.get_qpos())
166 | qvel = np.copy(self.interface.get_qvel())
167 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2]
168 | root_r = np.array([root_r])
169 | root_p = np.array([root_p])
170 | root_ang_vel = qvel[3:6]
171 | motor_pos = self.interface.get_act_joint_positions()
172 | motor_vel = self.interface.get_act_joint_velocities()
173 | motor_tau = self.interface.get_act_joint_torques()
174 |
175 | # add some noise to observations (uniform or gaussian)
176 | if self.cfg.observation_noise.enabled:
177 | noise_type = self.cfg.observation_noise.type
178 | scales = self.cfg.observation_noise.scales
179 | level = self.cfg.observation_noise.multiplier
185 | if noise_type=="uniform":
186 | noise = lambda x, n : np.random.uniform(-x, x, n)
187 | elif noise_type=="gaussian":
188 | noise = lambda x, n : np.random.randn(n) * x
189 | else:
190 | raise Exception("Observation noise type can only be \"uniform\" or \"gaussian\"")
191 | root_r += noise(scales.root_orient * level, 1)
192 | root_p += noise(scales.root_orient * level, 1)
193 | root_ang_vel += noise(scales.root_ang_vel * level, len(root_ang_vel))
194 | motor_pos += noise(scales.motor_pos * level, len(motor_pos))
195 | motor_vel += noise(scales.motor_vel * level, len(motor_vel))
196 | motor_tau += noise(scales.motor_tau * level, len(motor_tau))
197 |
198 | robot_state = np.concatenate([
199 | root_r, root_p, root_ang_vel, motor_pos, motor_vel, motor_tau,
200 | ])
201 |
202 | state = robot_state.copy()
203 | assert state.shape==(self.base_obs_len,), \
204 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state))
205 |
206 | if len(self.observation_history)==0:
207 | for _ in range(self.history_len):
208 | self.observation_history.appendleft(np.zeros_like(state))
209 | self.observation_history.appendleft(state)
210 | else:
211 | self.observation_history.appendleft(state)
212 | return np.array(self.observation_history).flatten()
213 |
214 | def step(self, action):
215 | # Compute the applied action to actuators
216 | # (targets <- Policy predictions)
217 | # (offsets <- Half-sitting pose)
218 |
219 | # action vector assumed to be in the following order:
220 | # [leg_0, leg_1, ..., leg_n, waist_0, ..., waist_n, arm_0, arm_1, ..., arm_n]
221 | targets = self.cfg.action_smoothing * action + \
222 | (1 - self.cfg.action_smoothing) * self.prev_prediction
223 | offsets = [
224 | self.nominal_pose[self.interface.get_jnt_qposadr_by_name(jnt)[0]]
225 | for jnt in self.leg_names
226 | ]
227 |
228 | rewards, done = self.robot.step(targets, np.asarray(offsets))
229 | obs = self.get_obs()
230 |
231 | if self.cfg.dynamics_randomization.enable and np.random.randint(self.dynrand_interval)==0:
232 | self.randomize_dyn()
233 |
234 | if self.cfg.perturbation.enable and np.random.randint(self.perturb_interval)==0:
235 | self.randomize_perturb()
236 |
237 | self.prev_prediction = action
238 |
239 | return obs, sum(rewards.values()), done, rewards
240 |
241 | def reset_model(self):
242 | if self.cfg.dynamics_randomization.enable:
243 | self.randomize_dyn()
244 |
245 | init_qpos, init_qvel = self.nominal_pose.copy(), [0] * self.interface.nv()
246 |
247 | # add some initialization noise to root orientation (roll, pitch)
248 | c = self.cfg.init_noise * np.pi/180
249 | root_adr = self.interface.get_jnt_qposadr_by_name('root')[0]
250 | init_qpos[root_adr+2] = np.random.uniform(1.0, 1.02)
251 | init_qpos[root_adr+3:root_adr+7] = tf3.euler.euler2quat(np.random.uniform(-c, c), np.random.uniform(-c, c), 0)
252 | init_qpos[root_adr+7:] += np.random.uniform(-c, c, len(self.leg_names))
253 |
254 | # set up init state
255 | self.set_state(
256 | np.asarray(init_qpos),
257 | np.asarray(init_qvel)
258 | )
259 | # do a few simulation steps to avoid big contact forces in the start
260 | for _ in range(3):
261 | self.interface.step()
262 |
263 | self.task.reset()
264 |
265 | self.prev_prediction = np.zeros_like(self.prev_prediction)
266 | self.observation_history = collections.deque(maxlen=self.history_len)
267 | obs = self.get_obs()
268 | return obs
269 |
270 | #### randomizations and other utility functions ###########
271 | def randomize_perturb(self):
272 | frc_mag = self.cfg.perturbation.force_magnitude
273 | tau_mag = self.cfg.perturbation.torque_magnitude
274 | for body in self.cfg.perturbation.bodies:
275 | self.data.body(body).xfrc_applied[:3] = np.random.uniform(-frc_mag, frc_mag, 3)
276 | self.data.body(body).xfrc_applied[3:] = np.random.uniform(-tau_mag, tau_mag, 3)
277 | if np.random.randint(2)==0:
278 | self.data.xfrc_applied = np.zeros_like(self.data.xfrc_applied)
279 |
280 | def randomize_dyn(self):
281 | # dynamics randomization
282 | dofadr = [self.interface.get_jnt_qveladr_by_name(jn)
283 | for jn in self.leg_names]
284 | for jnt in dofadr:
285 | self.model.dof_frictionloss[jnt] = np.random.uniform(0, 2) # actuated joint frictionloss
286 | self.model.dof_damping[jnt] = np.random.uniform(0.02, 2) # actuated joint damping
287 |
288 | # randomize com
289 | bodies = ["pelvis"]
290 | for legjoint in self.leg_names:
291 | bodyid = self.model.joint(legjoint).bodyid
292 | bodyname = self.model.body(bodyid).name
293 | bodies.append(bodyname)
294 |
295 | for body in bodies:
296 | default_mass = self.default_model.body(body).mass[0]
297 | default_ipos = self.default_model.body(body).ipos
298 | self.model.body(body).mass[0] = default_mass*np.random.uniform(0.95, 1.05)
299 | self.model.body(body).ipos = default_ipos + np.random.uniform(-0.01, 0.01, 3)
300 |
301 | def viewer_setup(self):
302 | super().viewer_setup()
303 | self.viewer.cam.distance = 5
304 | self.viewer.cam.lookat[2] = 1.5
305 | self.viewer.cam.lookat[0] = 1.0
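A minimal rollout sketch for this environment (the zero action is a stand-in for a trained policy; `step()` returns the observation, the summed reward, the done flag, and the per-term reward dict):

```python
import numpy as np
from envs.h1.h1_env import H1Env

env = H1Env()
obs = env.reset()
for _ in range(1000):
    action = np.zeros(env.action_space.shape[0])   # zero offsets around the nominal pose
    obs, reward, done, reward_terms = env.step(action)
    env.render()
    if done:
        obs = env.reset()
```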
306 |
--------------------------------------------------------------------------------
/envs/jvrc/__init__.py:
--------------------------------------------------------------------------------
1 | from envs.jvrc.jvrc_walk import JvrcWalkEnv
2 | from envs.jvrc.jvrc_step import JvrcStepEnv
3 |
--------------------------------------------------------------------------------
/envs/jvrc/configs/base.yaml:
--------------------------------------------------------------------------------
1 | sim_dt: 0.001
2 | control_dt: 0.025
3 | obs_history_len: 1
4 | action_smoothing: 0.5
5 |
6 | kp: [200, 200, 200, 250, 80, 80,
7 | 200, 200, 200, 250, 80, 80]
8 | kd: [20, 20, 20, 25, 8, 8,
9 | 20, 20, 20, 25, 8, 8,]
10 |
--------------------------------------------------------------------------------
/envs/jvrc/gen_xml.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import models
4 | from dm_control import mjcf
5 |
6 | JVRC_DESCRIPTION_PATH=os.path.join(os.path.dirname(models.__file__), "jvrc_mj_description/xml/scene.xml")
7 |
8 | WAIST_JOINTS = ['WAIST_Y', 'WAIST_P', 'WAIST_R']
9 | HEAD_JOINTS = ['NECK_Y', 'NECK_R', 'NECK_P']
10 | HAND_JOINTS = ['R_UTHUMB', 'R_LTHUMB', 'R_UINDEX', 'R_LINDEX', 'R_ULITTLE', 'R_LLITTLE',
11 | 'L_UTHUMB', 'L_LTHUMB', 'L_UINDEX', 'L_LINDEX', 'L_ULITTLE', 'L_LLITTLE']
12 | ARM_JOINTS = ['R_SHOULDER_P', 'R_SHOULDER_R', 'R_SHOULDER_Y', 'R_ELBOW_P', 'R_ELBOW_Y', 'R_WRIST_R', 'R_WRIST_Y',
13 | 'L_SHOULDER_P', 'L_SHOULDER_R', 'L_SHOULDER_Y', 'L_ELBOW_P', 'L_ELBOW_Y', 'L_WRIST_R', 'L_WRIST_Y']
14 | LEG_JOINTS = ['R_HIP_P', 'R_HIP_R', 'R_HIP_Y', 'R_KNEE', 'R_ANKLE_R', 'R_ANKLE_P',
15 | 'L_HIP_P', 'L_HIP_R', 'L_HIP_Y', 'L_KNEE', 'L_ANKLE_R', 'L_ANKLE_P']
16 |
17 |
18 | def builder(export_path, config):
19 | print("Modifying XML model...")
20 | mjcf_model = mjcf.from_path(JVRC_DESCRIPTION_PATH)
21 |
22 | mjcf_model.model = 'jvrc'
23 |
24 | # set njmax and nconmax
25 | mjcf_model.size.njmax = -1
26 | mjcf_model.size.nconmax = -1
27 | mjcf_model.statistic.meansize = 0.1
28 | mjcf_model.statistic.meanmass = 2
29 |
30 | # modify skybox
31 | for tx in mjcf_model.asset.texture:
32 | if tx.type=="skybox":
33 | tx.rgb1 = '1 1 1'
34 | tx.rgb2 = '1 1 1'
35 |
36 | # remove all collisions
37 | mjcf_model.contact.remove()
38 |
39 | # remove actuators except for leg joints
40 | for mot in mjcf_model.actuator.motor:
41 | if mot.joint.name not in LEG_JOINTS:
42 | mot.remove()
43 |
44 | # remove unused joints
45 | for joint in WAIST_JOINTS + HEAD_JOINTS + HAND_JOINTS + ARM_JOINTS:
46 | mjcf_model.find('joint', joint).remove()
47 |
48 | # remove existing equality
49 | mjcf_model.equality.remove()
50 |
51 | # set arm joints to fixed configuration
52 | arm_bodies = {
53 | "R_SHOULDER_P_S":[0, -0.052, 0], "R_SHOULDER_R_S":[-0.17, 0, 0], "R_ELBOW_P_S":[0, -0.524, 0],
54 | "L_SHOULDER_P_S":[0, -0.052, 0], "L_SHOULDER_R_S":[ 0.17, 0, 0], "L_ELBOW_P_S":[0, -0.524, 0],
55 | }
56 | for bname, euler in arm_bodies.items():
57 | mjcf_model.find('body', bname).euler = euler
58 |
59 | # collision geoms
60 | collision_geoms = [
61 | 'R_HIP_R_S', 'R_HIP_Y_S', 'R_KNEE_S',
62 | 'L_HIP_R_S', 'L_HIP_Y_S', 'L_KNEE_S',
63 | ]
64 |
65 | # remove unused collision geoms
66 | for body in mjcf_model.worldbody.find_all('body'):
67 | for idx, geom in enumerate(body.geom):
68 | geom.name = body.name + '-geom-' + repr(idx)
69 | if (geom.dclass.dclass=="collision"):
70 | if body.name not in collision_geoms:
71 | geom.remove()
72 |
73 | # move collision geoms to different group
74 | mjcf_model.default.default['collision'].geom.group = 3
75 |
76 | # manually create collision geom for feet
77 | mjcf_model.worldbody.find('body', 'R_ANKLE_P_S').add('geom', dclass='collision', size='0.1 0.05 0.01', pos='0.029 0 -0.09778', type='box')
78 | mjcf_model.worldbody.find('body', 'L_ANKLE_P_S').add('geom', dclass='collision', size='0.1 0.05 0.01', pos='0.029 0 -0.09778', type='box')
79 |
80 | # ignore collision
81 | mjcf_model.contact.add('exclude', body1='R_KNEE_S', body2='R_ANKLE_P_S')
82 | mjcf_model.contact.add('exclude', body1='L_KNEE_S', body2='L_ANKLE_P_S')
83 |
84 | # remove unused meshes
85 | meshes = [g.mesh.name for g in mjcf_model.find_all('geom') if g.type=='mesh' or g.type==None]
86 | for mesh in mjcf_model.find_all('mesh'):
87 | if mesh.name not in meshes:
88 | mesh.remove()
89 |
90 | # fix site pos
91 | mjcf_model.worldbody.find('site', 'rf_force').pos = '0.03 0.0 -0.1'
92 | mjcf_model.worldbody.find('site', 'lf_force').pos = '0.03 0.0 -0.1'
93 |
94 | # add box geoms
95 | if 'boxes' in config and config['boxes']==True:
96 | for idx in range(20):
97 | name = 'box' + repr(idx+1).zfill(2)
98 | mjcf_model.worldbody.add('body', name=name, pos=[0, 0, -0.2])
99 | mjcf_model.find('body', name).add('geom',
100 | name=name,
101 | dclass='collision',
102 | group='0',
103 | size='1 1 0.1',
104 | type='box',
105 | material='')
106 |
107 | # wrap floor geom in a body
108 | mjcf_model.find('geom', 'floor').remove()
109 | mjcf_model.worldbody.add('body', name='floor')
110 | mjcf_model.find('body', 'floor').add('geom', name='floor', type="plane", size="0 0 0.25", material="groundplane")
111 |
112 | # export model
113 | mjcf.export_with_assets(mjcf_model, out_dir=export_path, precision=5)
114 | path_to_xml = os.path.join(export_path, mjcf_model.model + '.xml')
115 | print("Exporting XML model to ", path_to_xml)
116 | return
117 |
118 | if __name__=='__main__':
119 | builder(sys.argv[1], config={})
120 |
121 |
--------------------------------------------------------------------------------
/envs/jvrc/jvrc_step.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import transforms3d as tf3
4 | import collections
5 |
6 | from tasks import stepping_task
7 | from robots.robot_base import RobotBase
8 | from envs.common import mujoco_env
9 | from envs.common import robot_interface
10 | from envs.common import config_builder
11 | from envs.jvrc.jvrc_walk import JvrcWalkEnv
12 |
13 | from .gen_xml import *
14 |
15 | class JvrcStepEnv(JvrcWalkEnv):
16 | def __init__(self, path_to_yaml = None):
17 |
18 | ## Load CONFIG from yaml ##
19 | if path_to_yaml is None:
20 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml')
21 |
22 | self.cfg = config_builder.load_yaml(path_to_yaml)
23 |
24 | sim_dt = self.cfg.sim_dt
25 | control_dt = self.cfg.control_dt
26 | frame_skip = (control_dt/sim_dt)
27 |
28 | self.history_len = self.cfg.obs_history_len
29 |
30 | path_to_xml = '/tmp/mjcf-export/jvrc_step/jvrc.xml'
31 | if not os.path.exists(path_to_xml):
32 | export_dir = os.path.dirname(path_to_xml)
33 | builder(export_dir, config={
34 | "boxes": True,
35 | })
36 |
37 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt)
38 |
39 | pdgains = np.zeros((2, 12))
40 | pdgains[0] = self.cfg.kp
41 | pdgains[1] = self.cfg.kd
42 |
43 | # list of desired actuators
44 | # RHIP_P, RHIP_R, RHIP_Y, RKNEE, RANKLE_R, RANKLE_P
45 | # LHIP_P, LHIP_R, LHIP_Y, LKNEE, LANKLE_R, LANKLE_P
46 | self.actuators = LEG_JOINTS
47 |
48 | # define nominal pose
49 | base_position = [0, 0, 0.81]
50 | base_orientation = [1, 0, 0, 0]
51 | half_sitting_pose = [-30, 0, 0, 50, 0, -24,
52 | -30, 0, 0, 50, 0, -24,
53 | ] # degrees
54 | self.nominal_pose = base_position + base_orientation + np.deg2rad(half_sitting_pose).tolist()
55 |
56 | # set up interface
57 | self.interface = robot_interface.RobotInterface(self.model, self.data, 'R_ANKLE_P_S', 'L_ANKLE_P_S', None)
58 |
59 | # set up task
60 | self.task = stepping_task.SteppingTask(client=self.interface,
61 | dt=control_dt,
62 | neutral_foot_orient=np.array([1, 0, 0, 0]),
63 | root_body='PELVIS_S',
64 | lfoot_body='L_ANKLE_P_S',
65 | rfoot_body='R_ANKLE_P_S',
66 | head_body='NECK_P_S',
67 | )
68 | # set goal height
69 | self.task._goal_height_ref = 0.80
70 | self.task._total_duration = 1.1
71 | self.task._swing_duration = 0.75
72 | self.task._stance_duration = 0.35
73 |
74 | # set up robot
75 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task)
76 |
77 | # define indices for action and obs mirror fns
78 | base_mir_obs = [-0.1, 1, # root orient
79 | -2, 3, -4, # root ang vel
80 | 11, -12, -13, 14, -15, 16, # motor pos [1]
81 | 5, -6, -7, 8, -9, 10, # motor pos [2]
82 | 23, -24, -25, 26, -27, 28, # motor vel [1]
83 | 17, -18, -19, 20, -21, 22, # motor vel [2]
84 | ]
85 | append_obs = [(len(base_mir_obs)+i) for i in range(10)]
86 | self.robot.clock_inds = append_obs[0:2]
87 | self.robot.mirrored_obs = np.array(base_mir_obs + append_obs, copy=True).tolist()
88 | self.robot.mirrored_acts = [6, -7, -8, 9, -10, 11,
89 | 0.1, -1, -2, 3, -4, 5,]
90 |
91 | # set action space
92 | action_space_size = len(self.actuators)
93 | action = np.zeros(action_space_size)
94 | self.action_space = np.zeros(action_space_size)
95 | self.prev_prediction = np.zeros(action_space_size)
96 |
97 | # set observation space
98 | self.base_obs_len = 39
99 | self.observation_history = collections.deque(maxlen=self.history_len)
100 | self.observation_space = np.zeros(self.base_obs_len*self.history_len)
101 |
102 | def get_obs(self):
103 | # external state
104 | clock = [np.sin(2 * np.pi * self.task._phase / self.task._period),
105 | np.cos(2 * np.pi * self.task._phase / self.task._period)]
106 | ext_state = np.concatenate((clock,
107 | np.asarray(self.task._goal_steps_x).flatten(),
108 | np.asarray(self.task._goal_steps_y).flatten(),
109 | np.asarray(self.task._goal_steps_z).flatten(),
110 | np.asarray(self.task._goal_steps_theta).flatten()))
111 |
112 | # internal state
113 | qpos = np.copy(self.interface.get_qpos())
114 | qvel = np.copy(self.interface.get_qvel())
115 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2]
116 | root_r = np.array([root_r])
117 | root_p = np.array([root_p])
118 | root_ang_vel = qvel[3:6]
119 | motor_pos = self.interface.get_act_joint_positions()
120 | motor_vel = self.interface.get_act_joint_velocities()
121 |
122 | robot_state = np.concatenate([
123 | root_r, root_p, root_ang_vel, motor_pos, motor_vel,
124 | ])
125 |
126 | state = np.concatenate([robot_state, ext_state])
127 | assert state.shape==(self.base_obs_len,), \
128 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state))
129 |
130 | if len(self.observation_history)==0:
131 | for _ in range(self.history_len):
132 | self.observation_history.appendleft(np.zeros_like(state))
133 | self.observation_history.appendleft(state)
134 | else:
135 | self.observation_history.appendleft(state)
136 | return np.array(self.observation_history).flatten()
137 |
--------------------------------------------------------------------------------
/envs/jvrc/jvrc_walk.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import transforms3d as tf3
4 | import collections
5 |
6 | from tasks import walking_task
7 | from robots.robot_base import RobotBase
8 | from envs.common import mujoco_env
9 | from envs.common import robot_interface
10 | from envs.common import config_builder
11 |
12 | from .gen_xml import *
13 |
14 | class JvrcWalkEnv(mujoco_env.MujocoEnv):
15 | def __init__(self, path_to_yaml = None):
16 |
17 | ## Load CONFIG from yaml ##
18 | if path_to_yaml is None:
19 | path_to_yaml = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/base.yaml')
20 |
21 | self.cfg = config_builder.load_yaml(path_to_yaml)
22 |
23 | sim_dt = self.cfg.sim_dt
24 | control_dt = self.cfg.control_dt
25 | frame_skip = (control_dt/sim_dt)
26 |
27 | self.history_len = self.cfg.obs_history_len
28 |
29 | path_to_xml = '/tmp/mjcf-export/jvrc_walk/jvrc.xml'
30 | if not os.path.exists(path_to_xml):
31 | export_dir = os.path.dirname(path_to_xml)
32 | builder(export_dir, config={
33 | })
34 |
35 | mujoco_env.MujocoEnv.__init__(self, path_to_xml, sim_dt, control_dt)
36 |
37 | pdgains = np.zeros((2, 12))
38 | pdgains[0] = self.cfg.kp
39 | pdgains[1] = self.cfg.kd
40 |
41 | # list of desired actuators
42 | # RHIP_P, RHIP_R, RHIP_Y, RKNEE, RANKLE_R, RANKLE_P
43 | # LHIP_P, LHIP_R, LHIP_Y, LKNEE, LANKLE_R, LANKLE_P
44 | self.actuators = LEG_JOINTS
45 |
46 | # define nominal pose
47 | base_position = [0, 0, 0.81]
48 | base_orientation = [1, 0, 0, 0]
49 | half_sitting_pose = [-30, 0, 0, 50, 0, -24,
50 | -30, 0, 0, 50, 0, -24,
51 | ] # degrees
52 | self.nominal_pose = base_position + base_orientation + np.deg2rad(half_sitting_pose).tolist()
53 |
54 | # set up interface
55 | self.interface = robot_interface.RobotInterface(self.model, self.data, 'R_ANKLE_P_S', 'L_ANKLE_P_S', None)
56 |
57 | # set up task
58 | self.task = walking_task.WalkingTask(client=self.interface,
59 | dt=control_dt,
60 | neutral_foot_orient=np.array([1, 0, 0, 0]),
61 | root_body='PELVIS_S',
62 | lfoot_body='L_ANKLE_P_S',
63 | rfoot_body='R_ANKLE_P_S',
64 | )
65 | # set goal height
66 | self.task._goal_height_ref = 0.80
67 | self.task._total_duration = 1.1
68 | self.task._swing_duration = 0.75
69 | self.task._stance_duration = 0.35
70 |
71 | # set up robot
72 | self.robot = RobotBase(pdgains, control_dt, self.interface, self.task)
73 |
74 | # define indices for action and obs mirror fns
75 | base_mir_obs = [-0.1, 1, # root orient
76 | -2, 3, -4, # root ang vel
77 | 11, -12, -13, 14, -15, 16, # motor pos [1]
78 | 5, -6, -7, 8, -9, 10, # motor pos [2]
79 | 23, -24, -25, 26, -27, 28, # motor vel [1]
80 | 17, -18, -19, 20, -21, 22, # motor vel [2]
81 | ]
82 | append_obs = [(len(base_mir_obs)+i) for i in range(3)]
83 | self.robot.clock_inds = append_obs[0:2]
84 | self.robot.mirrored_obs = np.array(base_mir_obs + append_obs, copy=True).tolist()
85 | self.robot.mirrored_acts = [6, -7, -8, 9, -10, 11,
86 | 0.1, -1, -2, 3, -4, 5,]
87 |
88 | # set action space
89 | action_space_size = len(self.actuators)
90 | action = np.zeros(action_space_size)
91 | self.action_space = np.zeros(action_space_size)
92 | self.prev_prediction = np.zeros(action_space_size)
93 |
94 | # set observation space
95 | self.base_obs_len = 32
96 | self.observation_history = collections.deque(maxlen=self.history_len)
97 | self.observation_space = np.zeros(self.base_obs_len*self.history_len)
98 |
99 | # manually define observation mean and std
100 | self.obs_mean = np.concatenate((
101 | np.zeros(5),
102 | np.deg2rad(half_sitting_pose), np.zeros(12),
103 | [0.5, 0.5, 0.5]
104 | ))
105 |
106 | self.obs_std = np.concatenate((
107 | [0.2, 0.2, 1, 1, 1],
108 | 0.5*np.ones(12), 4*np.ones(12),
109 | [1, 1, 1,]
110 | ))
111 |
112 | self.obs_mean = np.tile(self.obs_mean, self.history_len)
113 | self.obs_std = np.tile(self.obs_std, self.history_len)
114 |
115 | def get_obs(self):
116 | # external state
117 | clock = [np.sin(2 * np.pi * self.task._phase / self.task._period),
118 | np.cos(2 * np.pi * self.task._phase / self.task._period)]
119 | ext_state = np.concatenate((clock, [self.task._goal_speed_ref]))
120 |
121 | # internal state
122 | qpos = np.copy(self.interface.get_qpos())
123 | qvel = np.copy(self.interface.get_qvel())
124 | root_r, root_p = tf3.euler.quat2euler(qpos[3:7])[0:2]
125 | root_r = np.array([root_r])
126 | root_p = np.array([root_p])
127 | root_ang_vel = qvel[3:6]
128 | motor_pos = self.interface.get_act_joint_positions()
129 | motor_vel = self.interface.get_act_joint_velocities()
130 |
131 | robot_state = np.concatenate([
132 | root_r, root_p, root_ang_vel, motor_pos, motor_vel,
133 | ])
134 |
135 | state = np.concatenate([robot_state, ext_state])
136 | assert state.shape==(self.base_obs_len,), \
137 | "State vector length expected to be: {} but is {}".format(self.base_obs_len, len(state))
138 |
139 | if len(self.observation_history)==0:
140 | for _ in range(self.history_len):
141 | self.observation_history.appendleft(np.zeros_like(state))
142 | self.observation_history.appendleft(state)
143 | else:
144 | self.observation_history.appendleft(state)
145 | return np.array(self.observation_history).flatten()
146 |
147 | def step(self, action):
148 | # Compute the applied action to actuators
149 | # (targets <- Policy predictions)
150 | targets = self.cfg.action_smoothing * action + \
151 | (1 - self.cfg.action_smoothing) * self.prev_prediction
152 | # (offsets <- Half-sitting pose)
153 | offsets = [
154 | self.nominal_pose[self.interface.get_jnt_qposadr_by_name(jnt)[0]]
155 | for jnt in self.actuators
156 | ]
157 |
158 | rewards, done = self.robot.step(targets, np.asarray(offsets))
159 | obs = self.get_obs()
160 |
161 | self.prev_prediction = action
162 |
163 | return obs, sum(rewards.values()), done, rewards
164 |
165 | def reset_model(self):
166 | init_qpos, init_qvel = self.nominal_pose.copy(), [0] * self.interface.nv()
167 |
168 | # set up init state
169 | self.set_state(
170 | np.asarray(init_qpos),
171 | np.asarray(init_qvel)
172 | )
173 |
174 | self.task.reset(iter_count=self.robot.iteration_count)
175 |
176 | self.prev_prediction = np.zeros_like(self.prev_prediction)
177 | self.observation_history = collections.deque(maxlen=self.history_len)
178 | obs = self.get_obs()
179 | return obs
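The `mirrored_obs` / `mirrored_acts` index lists defined above are consumed by the RL code (presumably the symmetry machinery under `rl/`, not shown in this excerpt). A sketch of the presumed convention: each entry stores a source index, its sign negates the value, and ±0.1 stands in for index 0 so that index 0 can still carry a sign:

```python
import numpy as np

def mirror(vec, inds):
    # e.g. inds = [-0.1, 1, -2, ...]: take -vec[0], vec[1], -vec[2], ...
    return np.array([np.sign(i) * vec[int(abs(i))] for i in inds])

# mirrored_action = mirror(action, env.robot.mirrored_acts)
```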
180 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/models/__init__.py
--------------------------------------------------------------------------------
/rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/rl/__init__.py
--------------------------------------------------------------------------------
/rl/algos/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rl/algos/ppo.py:
--------------------------------------------------------------------------------
1 | """Proximal Policy Optimization (clip objective)."""
2 | from copy import deepcopy
3 |
4 | import torch
5 | import torch.optim as optim
6 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
7 | from torch.nn.utils.rnn import pad_sequence
8 | from torch.nn import functional as F
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from pathlib import Path
12 | import sys
13 | import time
14 | import numpy as np
15 | import datetime
16 |
17 | import ray
18 |
19 | from rl.storage.rollout_storage import PPOBuffer
20 | from rl.policies.actor import Gaussian_FF_Actor, Gaussian_LSTM_Actor
21 | from rl.policies.critic import FF_V, LSTM_V
22 | from rl.envs.normalize import get_normalization_params
23 |
24 | class PPO:
25 | def __init__(self, env_fn, args):
26 | self.gamma = args.gamma
27 | self.lam = args.lam
28 | self.lr = args.lr
29 | self.eps = args.eps
30 | self.ent_coeff = args.entropy_coeff
31 | self.clip = args.clip
32 | self.minibatch_size = args.minibatch_size
33 | self.epochs = args.epochs
34 | self.max_traj_len = args.max_traj_len
35 | self.use_gae = args.use_gae
36 | self.n_proc = args.num_procs
37 | self.grad_clip = args.max_grad_norm
38 | self.mirror_coeff = args.mirror_coeff
39 | self.eval_freq = args.eval_freq
40 | self.recurrent = args.recurrent
41 | self.imitate_coeff = args.imitate_coeff
42 |
43 | # batch_size depends on number of parallel envs
44 | self.batch_size = self.n_proc * self.max_traj_len
45 |
46 | self.total_steps = 0
47 | self.highest_reward = -np.inf
48 |
49 | # counter for training iterations
50 | self.iteration_count = 0
51 |
52 | # directory logging and saving weights
53 | self.save_path = Path(args.logdir)
54 | Path.mkdir(self.save_path, parents=True, exist_ok=True)
55 |
56 | # create the summarywriter
57 | self.writer = SummaryWriter(log_dir=self.save_path, flush_secs=10)
58 |
59 | # create networks or load up pretrained
60 | obs_dim = env_fn().observation_space.shape[0]
61 | action_dim = env_fn().action_space.shape[0]
62 | if args.continued:
63 | path_to_actor = args.continued
64 | path_to_critic = Path(args.continued.parent, "critic" + str(args.continued).split('actor')[1])
65 | policy = torch.load(path_to_actor, weights_only=False)
66 | critic = torch.load(path_to_critic, weights_only=False)
67 | # policy action noise parameters are initialized from scratch and not loaded
68 | if args.learn_std:
69 | policy.stds = torch.nn.Parameter(args.std_dev * torch.ones(action_dim))
70 | else:
71 | policy.stds = args.std_dev * torch.ones(action_dim)
72 | print("Loaded (pre-trained) actor from: ", path_to_actor)
73 | print("Loaded (pre-trained) critic from: ", path_to_critic)
74 | else:
75 | if args.recurrent:
76 | policy = Gaussian_LSTM_Actor(obs_dim, action_dim, init_std=args.std_dev,
77 | learn_std=args.learn_std)
78 | critic = LSTM_V(obs_dim)
79 | else:
80 | policy = Gaussian_FF_Actor(obs_dim, action_dim, init_std=args.std_dev,
81 | learn_std=args.learn_std, bounded=False)
82 | critic = FF_V(obs_dim)
83 |
84 | if hasattr(env_fn(), 'obs_mean') and hasattr(env_fn(), 'obs_std'):
85 | obs_mean, obs_std = env_fn().obs_mean, env_fn().obs_std
86 | else:
87 | obs_mean, obs_std = get_normalization_params(iter=args.input_norm_steps,
88 | noise_std=1,
89 | policy=policy,
90 | env_fn=env_fn,
91 | procs=args.num_procs)
92 | with torch.no_grad():
93 | policy.obs_mean, policy.obs_std = map(torch.Tensor, (obs_mean, obs_std))
94 | critic.obs_mean = policy.obs_mean
95 | critic.obs_std = policy.obs_std
96 |
97 | base_policy = None
98 | if args.imitate:
99 | base_policy = torch.load(args.imitate, weights_only=False)
100 |
101 | self.old_policy = deepcopy(policy)
102 | self.policy = policy
103 | self.critic = critic
104 | self.base_policy = base_policy
105 |
106 | @staticmethod
107 | def save(nets, save_path, suffix=""):
108 | filetype = ".pt"
109 | for name, net in nets.items():
110 | path = Path(save_path, name + suffix + filetype)
111 | torch.save(net, path)
112 | print("Saved {} at {}".format(name, path))
113 | return
114 |
115 | @ray.remote
116 | @torch.no_grad()
117 | @staticmethod
118 | def sample(env_fn, policy, critic, gamma, lam, iteration_count, max_steps, max_traj_len, deterministic):
119 | """
120 | Sample max_steps number of total timesteps, truncating
121 | trajectories if they exceed max_traj_len number of timesteps.
122 | """
123 | env = env_fn()
124 | env.robot.iteration_count = iteration_count
125 |
126 | memory = PPOBuffer(policy.state_dim, policy.action_dim, gamma, lam, size=max_traj_len*2)
127 | memory_full = False
128 |
129 | while not memory_full:
130 | state = torch.tensor(env.reset(), dtype=torch.float)
131 | done = False
132 | traj_len = 0
133 |
134 | if hasattr(policy, 'init_hidden_state'):
135 | policy.init_hidden_state()
136 |
137 | if hasattr(critic, 'init_hidden_state'):
138 | critic.init_hidden_state()
139 |
140 | while not done and traj_len < max_traj_len:
141 | action = policy(state, deterministic=deterministic)
142 | value = critic(state)
143 |
144 | next_state, reward, done, _ = env.step(action.numpy().copy())
145 |
146 | reward = torch.tensor(reward, dtype=torch.float)
147 | memory.store(state, action, reward, value, done)
148 | memory_full = (len(memory) >= max_steps)
149 |
150 | state = torch.tensor(next_state, dtype=torch.float)
151 | traj_len += 1
152 |
153 | #if memory_full:
154 | # break
155 |
156 | value = critic(state)
157 | memory.finish_path(last_val=(not done) * value)
158 |
159 | return memory.get_data()
160 |
161 | def sample_parallel(self, *args, deterministic=False):
162 |
163 | max_steps = (self.batch_size // self.n_proc)
164 | worker_args = (self.gamma, self.lam, self.iteration_count, max_steps, self.max_traj_len, deterministic)
165 | args = args + worker_args
166 |
167 |         # Create pool of workers, each collecting up to max_steps timesteps
168 | worker = self.sample
169 | workers = [worker.remote(*args) for _ in range(self.n_proc)]
170 | result = ray.get(workers)
171 |
172 | # Aggregate results
173 | keys = result[0].keys()
174 | aggregated_data = {
175 | k: torch.cat([r[k] for r in result]) for k in keys
176 | }
177 |
178 | class Data:
179 | def __init__(self, data):
180 | for key, value in data.items():
181 | setattr(self, key, value)
182 | data = Data(aggregated_data)
183 | return data
184 |
185 | def update_actor_critic(self, obs_batch, action_batch, return_batch, advantage_batch, mask, mirror_observation=None, mirror_action=None):
186 |
187 | pdf = self.policy.distribution(obs_batch)
188 | log_probs = pdf.log_prob(action_batch).sum(-1, keepdim=True)
189 |
190 | old_pdf = self.old_policy.distribution(obs_batch)
191 | old_log_probs = old_pdf.log_prob(action_batch).sum(-1, keepdim=True)
192 |
193 | # ratio between old and new policy, should be one at the first iteration
194 | ratio = (log_probs - old_log_probs).exp()
195 |
196 | # clipped surrogate loss
197 | cpi_loss = ratio * advantage_batch * mask
198 | clip_loss = ratio.clamp(1.0 - self.clip, 1.0 + self.clip) * advantage_batch * mask
199 | actor_loss = -torch.min(cpi_loss, clip_loss).mean()
200 |
201 | # only used for logging
202 | clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip).float()).item()
203 |
204 | # Value loss using the TD(gae_lambda) target
205 | values = self.critic(obs_batch)
206 | critic_loss = F.mse_loss(return_batch, values)
207 |
208 | # Entropy loss favor exploration
209 | entropy_penalty = -(pdf.entropy() * mask).mean()
210 |
211 | # Mirror Symmetry Loss
212 | deterministic_actions = self.policy(obs_batch)
213 | if mirror_observation is not None and mirror_action is not None:
214 | if self.recurrent:
215 | mir_obs = torch.stack([mirror_observation(obs_batch[i,:,:]) for i in range(obs_batch.shape[0])])
216 | mirror_actions = self.policy(mir_obs)
217 | else:
218 | mir_obs = mirror_observation(obs_batch)
219 | mirror_actions = self.policy(mir_obs)
220 | mirror_actions = mirror_action(mirror_actions)
221 | mirror_loss = (deterministic_actions - mirror_actions).pow(2).mean()
222 | else:
223 | mirror_loss = torch.zeros_like(actor_loss)
224 |
225 | # imitation loss
226 | if self.base_policy is not None:
227 | imitation_loss = (self.base_policy(obs_batch) - deterministic_actions).pow(2).mean()
228 | else:
229 | imitation_loss = torch.zeros_like(actor_loss)
230 |
231 | # Calculate approximate form of reverse KL Divergence for early stopping
232 | # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
233 | # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
234 | # and Schulman blog: http://joschu.net/blog/kl-approx.html
235 | with torch.no_grad():
236 | log_ratio = log_probs - old_log_probs
237 | approx_kl_div = torch.mean((ratio - 1) - log_ratio)
238 |
239 | self.actor_optimizer.zero_grad()
240 | (actor_loss + self.mirror_coeff*mirror_loss + self.imitate_coeff*imitation_loss + self.ent_coeff*entropy_penalty).backward(retain_graph=True)
241 |
242 | # Clip the gradient norm to prevent "unlucky" minibatches from
243 | # causing pathological updates
244 | torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip)
245 | self.actor_optimizer.step()
246 |
247 | self.critic_optimizer.zero_grad()
248 | critic_loss.backward(retain_graph=True)
249 |
250 | # Clip the gradient norm to prevent "unlucky" minibatches from
251 | # causing pathological updates
252 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_clip)
253 | self.critic_optimizer.step()
254 |
255 | return (
256 | actor_loss,
257 | entropy_penalty,
258 | critic_loss,
259 | approx_kl_div,
260 | mirror_loss,
261 | imitation_loss,
262 | clip_fraction,
263 | )
264 |
265 | def evaluate(self, env_fn, nets, itr, num_batches=5):
266 | # set all nets to .eval() mode
267 | for net in nets.values():
268 | net.eval()
269 |
270 | # collect some batches of data
271 | eval_batches = []
272 | for _ in range(num_batches):
273 | batch = self.sample_parallel(env_fn, *nets.values(), deterministic=True)
274 | eval_batches.append(batch)
275 |
276 | # save all the networks
277 | self.save(nets, self.save_path, "_" + repr(itr))
278 |
279 | # save as actor.pt, if it is best
280 |         eval_ep_rewards = [float(i) for b in eval_batches for i in b.ep_rewards]
281 | avg_eval_ep_rewards = np.mean(eval_ep_rewards)
282 | if self.highest_reward < avg_eval_ep_rewards:
283 | self.highest_reward = avg_eval_ep_rewards
284 | self.save(nets, self.save_path)
285 |
286 | return eval_batches
287 |
288 | def train(self, env_fn, n_itr):
289 |
290 | self.actor_optimizer = optim.Adam(self.policy.parameters(), lr=self.lr, eps=self.eps)
291 | self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.lr, eps=self.eps)
292 |
293 | train_start_time = time.time()
294 |
295 | obs_mirr, act_mirr = None, None
296 | if hasattr(env_fn(), 'mirror_observation'):
297 | obs_mirr = env_fn().mirror_clock_observation
298 |
299 | if hasattr(env_fn(), 'mirror_action'):
300 | act_mirr = env_fn().mirror_action
301 |
302 | for itr in range(n_itr):
303 | print("********** Iteration {} ************".format(itr))
304 |
305 | self.policy.train()
306 | self.critic.train()
307 |
308 | # set iteration count (could be used for curriculum training)
309 | self.iteration_count = itr
310 |
311 | sample_start_time = time.time()
312 | policy_ref = ray.put(self.policy)
313 | critic_ref = ray.put(self.critic)
314 | batch = self.sample_parallel(env_fn, policy_ref, critic_ref)
315 | observations = batch.states.float()
316 | actions = batch.actions.float()
317 | returns = batch.returns.float()
318 | values = batch.values.float()
319 |
320 | num_samples = len(observations)
321 | elapsed = time.time() - sample_start_time
322 | print("Sampling took {:.2f}s for {} steps.".format(elapsed, num_samples))
323 |
324 | # Normalize advantage
325 | advantages = returns - values
326 | advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)
327 |
328 | minibatch_size = self.minibatch_size or num_samples
329 | self.total_steps += num_samples
330 |
331 | self.old_policy.load_state_dict(self.policy.state_dict())
332 |
333 | optimizer_start_time = time.time()
334 |
335 | actor_losses = []
336 | entropies = []
337 | critic_losses = []
338 | kls = []
339 | mirror_losses = []
340 | imitation_losses = []
341 | clip_fractions = []
342 | for epoch in range(self.epochs):
343 | if self.recurrent:
344 | random_indices = SubsetRandomSampler(range(len(batch.traj_idx)-1))
345 | sampler = BatchSampler(random_indices, minibatch_size, drop_last=False)
346 | else:
347 | random_indices = SubsetRandomSampler(range(num_samples))
348 | sampler = BatchSampler(random_indices, minibatch_size, drop_last=True)
349 |
350 | for indices in sampler:
351 | if self.recurrent:
352 | obs_batch = [observations[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices]
353 | action_batch = [actions[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices]
354 | return_batch = [returns[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices]
355 | advantage_batch = [advantages[int(batch.traj_idx[i]):int(batch.traj_idx[i+1])] for i in indices]
356 | mask = [torch.ones_like(r) for r in return_batch]
357 |
358 | obs_batch = pad_sequence(obs_batch, batch_first=False)
359 | action_batch = pad_sequence(action_batch, batch_first=False)
360 | return_batch = pad_sequence(return_batch, batch_first=False)
361 | advantage_batch = pad_sequence(advantage_batch, batch_first=False)
362 | mask = pad_sequence(mask, batch_first=False)
363 | else:
364 | obs_batch = observations[indices]
365 | action_batch = actions[indices]
366 | return_batch = returns[indices]
367 | advantage_batch = advantages[indices]
368 | mask = 1
369 |
370 | scalars = self.update_actor_critic(obs_batch, action_batch, return_batch, advantage_batch, mask, mirror_observation=obs_mirr, mirror_action=act_mirr)
371 | actor_loss, entropy_penalty, critic_loss, approx_kl_div, mirror_loss, imitation_loss, clip_fraction = scalars
372 |
373 | actor_losses.append(actor_loss.item())
374 | entropies.append(entropy_penalty.item())
375 | critic_losses.append(critic_loss.item())
376 | kls.append(approx_kl_div.item())
377 | mirror_losses.append(mirror_loss.item())
378 | imitation_losses.append(imitation_loss.item())
379 | clip_fractions.append(clip_fraction)
380 |
381 | elapsed = time.time() - optimizer_start_time
382 | print("Optimizer took: {:.2f}s".format(elapsed))
383 |
384 | action_noise = self.policy.stds.data.tolist()
385 |
386 | sys.stdout.write("-" * 37 + "\n")
387 | sys.stdout.write("| %15s | %15s |" % ('Mean Eprew', "%8.5g" % torch.mean(batch.ep_rewards)) + "\n")
388 | sys.stdout.write("| %15s | %15s |" % ('Mean Eplen', "%8.5g" % torch.mean(batch.ep_lens)) + "\n")
389 | sys.stdout.write("| %15s | %15s |" % ('Actor loss', "%8.3g" % np.mean(actor_losses)) + "\n")
390 | sys.stdout.write("| %15s | %15s |" % ('Critic loss', "%8.3g" % np.mean(critic_losses)) + "\n")
391 | sys.stdout.write("| %15s | %15s |" % ('Mirror loss', "%8.3g" % np.mean(mirror_losses)) + "\n")
392 | sys.stdout.write("| %15s | %15s |" % ('Imitation loss', "%8.3g" % np.mean(imitation_losses)) + "\n")
393 | sys.stdout.write("| %15s | %15s |" % ('Mean KL Div', "%8.3g" % np.mean(kls)) + "\n")
394 | sys.stdout.write("| %15s | %15s |" % ('Mean Entropy', "%8.3g" % np.mean(entropies)) + "\n")
395 | sys.stdout.write("| %15s | %15s |" % ('Clip Fraction', "%8.3g" % np.mean(clip_fractions)) + "\n")
396 | sys.stdout.write("| %15s | %15s |" % ('Mean noise std', "%8.3g" % np.mean(action_noise)) + "\n")
397 | sys.stdout.write("-" * 37 + "\n")
398 | sys.stdout.flush()
399 |
400 | elapsed = time.time() - train_start_time
401 | iter_avg = elapsed/(itr+1)
402 | ETA = round((n_itr - itr)*iter_avg)
403 | print("Total time elapsed: {:.2f}s. Total steps: {} (fps={:.2f}. iter-avg={:.2f}s. ETA={})".format(
404 | elapsed, self.total_steps, self.total_steps/elapsed, iter_avg, datetime.timedelta(seconds=ETA)))
405 |
406 |             # To save time, perform evaluation only at the first iteration and then every eval_freq iterations
407 | if itr==0 or (itr+1)%self.eval_freq==0:
408 | nets = {"actor": self.policy, "critic": self.critic}
409 |
410 | evaluate_start = time.time()
411 | eval_batches = self.evaluate(env_fn, nets, itr)
412 | eval_time = time.time() - evaluate_start
413 |
414 | eval_ep_lens = [float(i) for b in eval_batches for i in b.ep_lens]
415 | eval_ep_rewards = [float(i) for b in eval_batches for i in b.ep_rewards]
416 | avg_eval_ep_lens = np.mean(eval_ep_lens)
417 | avg_eval_ep_rewards = np.mean(eval_ep_rewards)
418 | print("====EVALUATE EPISODE====")
419 | print("(Episode length:{:.3f}. Reward:{:.3f}. Time taken:{:.2f}s)".format(
420 | avg_eval_ep_lens, avg_eval_ep_rewards, eval_time))
421 |
422 | # tensorboard logging
423 | self.writer.add_scalar("Eval/mean_reward", avg_eval_ep_rewards, itr)
424 | self.writer.add_scalar("Eval/mean_episode_length", avg_eval_ep_lens, itr)
425 |
426 | # tensorboard logging
427 | self.writer.add_scalar("Loss/actor", np.mean(actor_losses), itr)
428 | self.writer.add_scalar("Loss/critic", np.mean(critic_losses), itr)
429 | self.writer.add_scalar("Loss/mirror", np.mean(mirror_losses), itr)
430 | self.writer.add_scalar("Loss/imitation", np.mean(imitation_losses), itr)
431 | self.writer.add_scalar("Train/mean_reward", torch.mean(batch.ep_rewards), itr)
432 | self.writer.add_scalar("Train/mean_episode_length", torch.mean(batch.ep_lens), itr)
433 | self.writer.add_scalar("Train/mean_noise_std", np.mean(action_noise), itr)
434 |
--------------------------------------------------------------------------------
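
As a reference for update_actor_critic(), the clipped surrogate loss and the approximate reverse-KL logged above can be reproduced in isolation. This is a small sketch with made-up log-probabilities and advantages, not a drop-in for the training code:

    import torch

    clip = 0.2
    log_probs = torch.tensor([[-1.1], [-0.9], [-1.4]])
    old_log_probs = torch.tensor([[-1.0], [-1.0], [-1.0]])
    advantages = torch.tensor([[0.5], [-0.3], [1.2]])

    ratio = (log_probs - old_log_probs).exp()
    cpi_loss = ratio * advantages
    clip_loss = ratio.clamp(1.0 - clip, 1.0 + clip) * advantages
    actor_loss = -torch.min(cpi_loss, clip_loss).mean()

    # same approximate reverse-KL used above for logging
    approx_kl = torch.mean((ratio - 1) - (log_probs - old_log_probs))
    print(actor_loss.item(), approx_kl.item())
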
/rl/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian import DiagonalGaussian
2 | from .beta import Beta, Beta2
--------------------------------------------------------------------------------
/rl/distributions/beta.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | # TODO: extend these for arbitrary bounds
8 |
9 | """A beta distribution, but where the pdf is scaled to (-1, 1)"""
10 | class BoundedBeta(torch.distributions.Beta):
11 | def log_prob(self, x):
12 | return super().log_prob((x + 1) / 2)
13 |
14 | class Beta(nn.Module):
15 | def __init__(self, action_dim):
16 | super(Beta, self).__init__()
17 |
18 | self.action_dim = action_dim
19 |
20 | def forward(self, alpha_beta):
21 | alpha = 1 + F.softplus(alpha_beta[:, :self.action_dim])
22 | beta = 1 + F.softplus(alpha_beta[:, self.action_dim:])
23 | return alpha, beta
24 |
25 | def sample(self, x, deterministic):
26 | if deterministic is False:
27 | action = self.evaluate(x).sample()
28 | else:
29 | # E = alpha / (alpha + beta)
30 | return self.evaluate(x).mean
31 |
32 | return 2 * action - 1
33 |
34 | def evaluate(self, x):
35 | alpha, beta = self(x)
36 | return BoundedBeta(alpha, beta)
37 |
38 |
39 | # TODO: think of a better name for this
40 | """Beta distribution parameterized by mean and variance."""
41 | class Beta2(nn.Module):
42 | def __init__(self, action_dim, init_std=0.25, learn_std=False):
43 | super(Beta2, self).__init__()
44 |
45 | assert init_std < 0.5, "Beta distribution has a max std dev of 0.5"
46 |
47 | self.action_dim = action_dim
48 |
49 | self.logstd = nn.Parameter(
50 | torch.ones(1, action_dim) * np.log(init_std),
51 | requires_grad=learn_std
52 | )
53 |
54 | self.learn_std = learn_std
55 |
56 |
57 | def forward(self, x):
58 | mean = torch.sigmoid(x)
59 |
60 | var = self.logstd.exp().pow(2)
61 |
62 | """
63 | alpha = ((1 - mu) / sigma^2 - 1 / mu) * mu^2
64 | beta = alpha * (1 / mu - 1)
65 |
66 | Implemented slightly differently for numerical stability.
67 | """
68 | alpha = ((1 - mean) / var) * mean.pow(2) - mean
69 | beta = ((1 - mean) / var) * mean - 1 - alpha
70 |
71 |         # PROBLEM: if alpha or beta < 1, that's not good
72 |
73 | #assert np.allclose(alpha, ((1 - mean) / var - 1 / mean) * mean.pow(2))
74 | #assert np.allclose(beta, ((1 - mean) / var - 1 / mean) * mean.pow(2) * (1 / mean - 1))
75 |
76 | #alpha = 1 + F.softplus(alpha)
77 | #beta = 1 + F.softplus(beta)
78 |
79 | # print("alpha",alpha)
80 | # print("beta",beta)
81 |
82 | # #print(alpha / (alpha + beta))
83 | # print("mu",mean)
84 |
85 | # #print(torch.sqrt(alpha * beta / ((alpha+beta)**2 * (alpha + beta + 1))))
86 | # print("var", var)
87 |
88 | # import pdb
89 | # pdb.set_trace()
90 |
91 | return alpha, beta
92 |
93 | def sample(self, x, deterministic):
94 | if deterministic is False:
95 | action = self.evaluate(x).sample()
96 | else:
97 | # E = alpha / (alpha + beta)
98 | return self.evaluate(x).mean
99 |
100 | # 2 * a - 1 puts a in (-1, 1)
101 | return 2 * action - 1
102 |
103 | def evaluate(self, x):
104 | alpha, beta = self(x)
105 | return BoundedBeta(alpha, beta)
--------------------------------------------------------------------------------
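
The (mean, variance) to (alpha, beta) conversion in Beta2.forward() can be sanity-checked numerically: feeding the resulting parameters into torch.distributions.Beta should recover the original mean and variance. A small standalone check with hand-picked values (chosen so that alpha, beta > 0):

    import torch

    mu = torch.tensor([0.3, 0.6, 0.8])
    var = torch.tensor([0.04, 0.02, 0.01])

    # same algebra as Beta2.forward()
    alpha = ((1 - mu) / var) * mu.pow(2) - mu
    beta = ((1 - mu) / var) * mu - 1 - alpha

    dist = torch.distributions.Beta(alpha, beta)
    assert torch.allclose(dist.mean, mu, atol=1e-5)
    assert torch.allclose(dist.variance, var, atol=1e-5)
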
/rl/distributions/gaussian.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 |
7 | # TODO: look at change of variables function for enforcing
8 | # action bounds correctly
9 | class DiagonalGaussian(nn.Module):
10 | def __init__(self, num_outputs, init_std=1, learn_std=True):
11 | super(DiagonalGaussian, self).__init__()
12 |
13 | self.logstd = nn.Parameter(
14 | torch.ones(1, num_outputs) * np.log(init_std),
15 | requires_grad=learn_std
16 | )
17 |
18 | self.learn_std = learn_std
19 |
20 | def forward(self, x):
21 | mean = x
22 |
23 | std = self.logstd.exp()
24 |
25 | return mean, std
26 |
27 | def sample(self, x, deterministic):
28 | if deterministic is False:
29 | action = self.evaluate(x).sample()
30 | else:
31 | action, _ = self(x)
32 |
33 | return action
34 |
35 | def evaluate(self, x):
36 | mean, std = self(x)
37 | return torch.distributions.Normal(mean, std)
38 |
--------------------------------------------------------------------------------
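
A short usage sketch of DiagonalGaussian, assuming the repository root is on PYTHONPATH; the mean tensor here is just a placeholder for a policy network's output:

    import torch
    from rl.distributions.gaussian import DiagonalGaussian

    head = DiagonalGaussian(num_outputs=4, init_std=0.2, learn_std=False)
    mean = torch.zeros(1, 4)                               # stand-in for a network output

    stochastic = head.sample(mean, deterministic=False)    # mean + Gaussian noise
    deterministic = head.sample(mean, deterministic=True)  # returns the mean itself
    log_prob = head.evaluate(mean).log_prob(stochastic).sum(-1)
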
/rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .normalize import *
2 | from .wrappers import *
3 |
--------------------------------------------------------------------------------
/rl/envs/normalize.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
2 | # Thanks to the authors + OpenAI for the code
3 |
4 | import numpy as np
5 | import functools
6 | import torch
7 | import ray
8 |
9 | from .wrappers import WrapEnv
10 |
11 | @ray.remote
12 | def _run_random_actions(iter, policy, env_fn, noise_std):
13 |
14 | env = WrapEnv(env_fn)
15 | states = np.zeros((iter, env.observation_space.shape[0]))
16 |
17 | state = env.reset()
18 | for t in range(iter):
19 | states[t, :] = state
20 |
21 | state = torch.Tensor(state)
22 |
23 | action = policy(state)
24 |
25 | # add gaussian noise to deterministic action
26 | action = action + torch.randn(action.size()) * noise_std
27 |
28 | state, _, done, _ = env.step(action.data.numpy())
29 |
30 | if done:
31 | state = env.reset()
32 |
33 | return states
34 |
35 | def get_normalization_params(iter, policy, env_fn, noise_std, procs=4):
36 | print("Gathering input normalization data using {0} steps, noise = {1}...".format(iter, noise_std))
37 |
38 | states_ids = [_run_random_actions.remote(iter // procs, policy, env_fn, noise_std) for _ in range(procs)]
39 |
40 | states = []
41 | for _ in range(procs):
42 | ready_ids, _ = ray.wait(states_ids, num_returns=1)
43 | states.extend(ray.get(ready_ids[0]))
44 | states_ids.remove(ready_ids[0])
45 |
46 | print("Done gathering input normalization data.")
47 |
48 | return np.mean(states, axis=0), np.sqrt(np.var(states, axis=0) + 1e-8)
49 |
50 |
51 | # returns a function that creates a normalized environment, then pre-normalizes it
52 | # using states sampled from a deterministic policy with some added noise
53 | def PreNormalizer(iter, noise_std, policy, *args, **kwargs):
54 |
55 | # noise is gaussian noise
56 | @torch.no_grad()
57 | def pre_normalize(env, policy, num_iter, noise_std):
58 | # save whether or not the environment is configured to do online normalization
59 | online_val = env.online
60 | env.online = True
61 |
62 | state = env.reset()
63 |
64 | for t in range(num_iter):
65 | state = torch.Tensor(state)
66 |
67 | _, action = policy(state)
68 |
69 | # add gaussian noise to deterministic action
70 | action = action + torch.randn(action.size()) * noise_std
71 |
72 | state, _, done, _ = env.step(action.data.numpy())
73 |
74 | if done:
75 | state = env.reset()
76 |
77 | env.online = online_val
78 |
79 | def _Normalizer(venv):
80 | venv = Normalize(venv, *args, **kwargs)
81 |
82 | print("Gathering input normalization data using {0} steps, noise = {1}...".format(iter, noise_std))
83 | pre_normalize(venv, policy, iter, noise_std)
84 | print("Done gathering input normalization data.")
85 |
86 | return venv
87 |
88 | return _Normalizer
89 |
90 | # returns a function that creates a normalized environment
91 | def Normalizer(*args, **kwargs):
92 | def _Normalizer(venv):
93 | return Normalize(venv, *args, **kwargs)
94 |
95 | return _Normalizer
96 |
97 | class Normalize:
98 | """
99 | Vectorized environment base class
100 | """
101 | def __init__(self,
102 | venv,
103 | ob_rms=None,
104 | ob=True,
105 | ret=False,
106 | clipob=10.,
107 | cliprew=10.,
108 | online=True,
109 | gamma=1.0,
110 | epsilon=1e-8):
111 |
112 | self.venv = venv
113 | self._observation_space = venv.observation_space
114 | self._action_space = venv.action_space
115 |
116 | if ob_rms is not None:
117 | self.ob_rms = ob_rms
118 | else:
119 | self.ob_rms = RunningMeanStd(shape=self._observation_space.shape) if ob else None
120 |
121 | self.ret_rms = RunningMeanStd(shape=()) if ret else None
122 | self.clipob = clipob
123 | self.cliprew = cliprew
124 | self.ret = np.zeros(self.num_envs)
125 | self.gamma = gamma
126 | self.epsilon = epsilon
127 |
128 | self.online = online
129 |
130 | def step(self, vac):
131 | obs, rews, news, infos = self.venv.step(vac)
132 |
133 | #self.ret = self.ret * self.gamma + rews
134 | obs = self._obfilt(obs)
135 |
136 | # NOTE: shifting mean of reward seems bad; qualitatively changes MDP
137 | if self.ret_rms:
138 | if self.online:
139 | self.ret_rms.update(self.ret)
140 |
141 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
142 |
143 | return obs, rews, news, infos
144 |
145 | def _obfilt(self, obs):
146 | if self.ob_rms:
147 | if self.online:
148 | self.ob_rms.update(obs)
149 |
150 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
151 | return obs
152 | else:
153 | return obs
154 |
155 | def reset(self):
156 | """
157 | Reset all environments
158 | """
159 | obs = self.venv.reset()
160 | return self._obfilt(obs)
161 |
162 | @property
163 | def action_space(self):
164 | return self._action_space
165 |
166 | @property
167 | def observation_space(self):
168 | return self._observation_space
169 |
170 | def close(self):
171 | self.venv.close()
172 |
173 | def render(self):
174 | self.venv.render()
175 |
176 | @property
177 | def num_envs(self):
178 | return self.venv.num_envs
179 |
180 |
181 |
182 | class RunningMeanStd(object):
183 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
184 | def __init__(self, epsilon=1e-4, shape=()):
185 | self.mean = np.zeros(shape, 'float64')
186 | self.var = np.zeros(shape, 'float64')
187 | self.count = epsilon
188 |
189 |
190 | def update(self, x):
191 | batch_mean = np.mean(x, axis=0)
192 | batch_var = np.var(x, axis=0)
193 | batch_count = x.shape[0]
194 |
195 | delta = batch_mean - self.mean
196 | tot_count = self.count + batch_count
197 |
198 | new_mean = self.mean + delta * batch_count / tot_count
199 | m_a = self.var * (self.count)
200 | m_b = batch_var * (batch_count)
201 | M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
202 | new_var = M2 / (self.count + batch_count)
203 |
204 | new_count = batch_count + self.count
205 |
206 | self.mean = new_mean
207 | self.var = new_var
208 | self.count = new_count
209 |
210 | def test_runningmeanstd():
211 | for (x1, x2, x3) in [
212 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
213 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
214 | ]:
215 |
216 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
217 |
218 | x = np.concatenate([x1, x2, x3], axis=0)
219 | ms1 = [x.mean(axis=0), x.var(axis=0)]
220 | rms.update(x1)
221 | rms.update(x2)
222 | rms.update(x3)
223 | ms2 = [rms.mean, rms.var]
224 |
225 | assert np.allclose(ms1, ms2)
226 |
--------------------------------------------------------------------------------
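
Besides the built-in test above, the observation filter applied in Normalize._obfilt() is essentially a running standardization followed by clipping. A standalone sketch of that filter using RunningMeanStd directly, with randomly generated observations standing in for environment states (assumes the repository root is on PYTHONPATH):

    import numpy as np
    from rl.envs.normalize import RunningMeanStd

    clipob, epsilon = 10.0, 1e-8
    ob_rms = RunningMeanStd(shape=(3,))

    rng = np.random.default_rng(0)
    for _ in range(5):
        obs = rng.normal(loc=2.0, scale=0.5, size=(16, 3))   # fake batch of observations
        ob_rms.update(obs)
        filtered = np.clip((obs - ob_rms.mean) / np.sqrt(ob_rms.var + epsilon),
                           -clipob, clipob)
    print(ob_rms.mean, ob_rms.var)
    print(filtered[0])
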
/rl/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | # Gives a vectorized interface to a single environment
5 | class WrapEnv:
6 | def __init__(self, env_fn):
7 | self.env = env_fn()
8 |
9 | def __getattr__(self, attr):
10 | return getattr(self.env, attr)
11 |
12 | def step(self, action):
13 | state, reward, done, info = self.env.step(action[0])
14 | return np.array([state]), np.array([reward]), np.array([done]), np.array([info])
15 |
16 | def render(self):
17 | self.env.render()
18 |
19 | def reset(self):
20 | return np.array([self.env.reset()])
21 |
22 | # TODO: this is probably a better case for inheritance than for a wrapper
23 | # Gives an interface to exploit mirror symmetry
24 | class SymmetricEnv:
25 | def __init__(self, env_fn, mirrored_obs=None, mirrored_act=None, clock_inds=None, obs_fn=None, act_fn=None):
26 |
27 | assert (bool(mirrored_act) ^ bool(act_fn)) and (bool(mirrored_obs) ^ bool(obs_fn)), \
28 | "You must provide either mirror indices or a mirror function, but not both, for \
29 | observation and action."
30 |
31 | if mirrored_act:
32 | self.act_mirror_matrix = torch.Tensor(_get_symmetry_matrix(mirrored_act))
33 |
34 | elif act_fn:
35 | assert callable(act_fn), "Action mirror function must be callable"
36 | self.mirror_action = act_fn
37 |
38 | if mirrored_obs:
39 | self.obs_mirror_matrix = torch.Tensor(_get_symmetry_matrix(mirrored_obs))
40 |
41 | elif obs_fn:
42 | assert callable(obs_fn), "Observation mirror function must be callable"
43 | self.mirror_observation = obs_fn
44 |
45 | self.clock_inds = clock_inds
46 | self.env = env_fn()
47 |
48 | def __getattr__(self, attr):
49 | return getattr(self.env, attr)
50 |
51 | def mirror_action(self, action):
52 | return action @ self.act_mirror_matrix
53 |
54 | def mirror_observation(self, obs):
55 | return obs @ self.obs_mirror_matrix
56 |
57 |     # To be used when there is a clock in the observation. In this case, the mirrored_obs vector passed in
58 |     # when the SymmetricEnv is created should not reorder the clock inputs. The indices of the obs vector
59 |     # where the clocks are located must be provided via clock_inds.
60 | def mirror_clock_observation(self, obs):
61 | # print("obs.shape = ", obs.shape)
62 | # print("obs_mirror_matrix.shape = ", self.obs_mirror_matrix.shape)
63 | mirror_obs_batch = torch.zeros_like(obs)
64 | history_len = 1 # FIX HISTORY-OF-STATES LENGTH TO 1 FOR NOW
65 | for block in range(history_len):
66 | obs_ = obs[:, self.base_obs_len*block : self.base_obs_len*(block+1)]
67 | mirror_obs = obs_ @ self.obs_mirror_matrix
68 | clock = mirror_obs[:, self.clock_inds]
69 | for i in range(np.shape(clock)[1]):
70 | mirror_obs[:, self.clock_inds[i]] = np.sin(np.arcsin(clock[:, i]) + np.pi)
71 | mirror_obs_batch[:, self.base_obs_len*block : self.base_obs_len*(block+1)] = mirror_obs
72 | return mirror_obs_batch
73 |
74 |
75 | def _get_symmetry_matrix(mirrored):
76 | numel = len(mirrored)
77 | mat = np.zeros((numel, numel))
78 |
79 | for (i, j) in zip(np.arange(numel), np.abs(np.array(mirrored).astype(int))):
80 | mat[i, j] = np.sign(mirrored[i])
81 |
82 | return mat
83 |
--------------------------------------------------------------------------------
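
The mirror matrices used by SymmetricEnv are signed permutation matrices built by _get_symmetry_matrix(). The sketch below copies that helper and applies it to a hypothetical 4-element observation in which entries 0 and 1 are negated and entries 2 and 3 are swapped; the fractional -0.1 encodes "index 0, negated", since the sign of an integer 0 would otherwise be lost:

    import numpy as np
    import torch

    def _get_symmetry_matrix(mirrored):
        # same construction as in wrappers.py
        numel = len(mirrored)
        mat = np.zeros((numel, numel))
        for (i, j) in zip(np.arange(numel), np.abs(np.array(mirrored).astype(int))):
            mat[i, j] = np.sign(mirrored[i])
        return mat

    mirrored_obs = [-0.1, -1, 3, 2]            # hypothetical mirror indices
    mat = torch.Tensor(_get_symmetry_matrix(mirrored_obs))

    obs = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    print(obs @ mat)                           # tensor([[-1., -2., 4., 3.]])
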
/rl/policies/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import Gaussian_FF_Actor
2 |
--------------------------------------------------------------------------------
/rl/policies/actor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch import sqrt
6 |
7 | from rl.policies.base import Net
8 |
9 | class Actor(Net):
10 | def __init__(self):
11 | super(Actor, self).__init__()
12 |
13 | def forward(self):
14 | raise NotImplementedError
15 |
16 | class Linear_Actor(Actor):
17 | def __init__(self, state_dim, action_dim, hidden_size=32):
18 | super(Linear_Actor, self).__init__()
19 |
20 | self.l1 = nn.Linear(state_dim, hidden_size)
21 | self.l2 = nn.Linear(hidden_size, action_dim)
22 |
23 | self.action_dim = action_dim
24 |
25 | for p in self.parameters():
26 | p.data = torch.zeros(p.shape)
27 |
28 | def forward(self, state):
29 | a = self.l1(state)
30 | a = self.l2(a)
31 | return a
32 |
33 | class FF_Actor(Actor):
34 | def __init__(self, state_dim, action_dim, layers=(256, 256), nonlinearity=F.relu):
35 | super(FF_Actor, self).__init__()
36 |
37 | self.actor_layers = nn.ModuleList()
38 | self.actor_layers += [nn.Linear(state_dim, layers[0])]
39 | for i in range(len(layers)-1):
40 | self.actor_layers += [nn.Linear(layers[i], layers[i+1])]
41 | self.network_out = nn.Linear(layers[-1], action_dim)
42 |
43 | self.action_dim = action_dim
44 | self.nonlinearity = nonlinearity
45 |
46 | self.initialize_parameters()
47 |
48 | def forward(self, state, deterministic=True):
49 | x = state
50 | for idx, layer in enumerate(self.actor_layers):
51 | x = self.nonlinearity(layer(x))
52 |
53 | action = torch.tanh(self.network_out(x))
54 | return action
55 |
56 |
57 | class LSTM_Actor(Actor):
58 | def __init__(self, state_dim, action_dim, layers=(128, 128), nonlinearity=torch.tanh):
59 | super(LSTM_Actor, self).__init__()
60 |
61 | self.actor_layers = nn.ModuleList()
62 | self.actor_layers += [nn.LSTMCell(state_dim, layers[0])]
63 | for i in range(len(layers)-1):
64 | self.actor_layers += [nn.LSTMCell(layers[i], layers[i+1])]
65 |         self.network_out = nn.Linear(layers[-1], action_dim)
66 |
67 | self.action_dim = action_dim
68 | self.init_hidden_state()
69 | self.nonlinearity = nonlinearity
70 |
71 | def get_hidden_state(self):
72 | return self.hidden, self.cells
73 |
74 | def set_hidden_state(self, data):
75 | if len(data) != 2:
76 | print("Got invalid hidden state data.")
77 | exit(1)
78 |
79 | self.hidden, self.cells = data
80 |
81 | def init_hidden_state(self, batch_size=1):
82 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers]
83 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers]
84 |
85 | def forward(self, x, deterministic=True):
86 | dims = len(x.size())
87 |
88 | if dims == 3: # if we get a batch of trajectories
89 | self.init_hidden_state(batch_size=x.size(1))
90 | y = []
91 | for t, x_t in enumerate(x):
92 | for idx, layer in enumerate(self.actor_layers):
93 | c, h = self.cells[idx], self.hidden[idx]
94 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c))
95 | x_t = self.hidden[idx]
96 | y.append(x_t)
97 | x = torch.stack([x_t for x_t in y])
98 |
99 | else:
100 | if dims == 1: # if we get a single timestep (if not, assume we got a batch of single timesteps)
101 | x = x.view(1, -1)
102 |
103 | for idx, layer in enumerate(self.actor_layers):
104 | h, c = self.hidden[idx], self.cells[idx]
105 | self.hidden[idx], self.cells[idx] = layer(x, (h, c))
106 | x = self.hidden[idx]
107 |             x = self.nonlinearity(x)  # network_out is applied once, below
108 |
109 | if dims == 1:
110 | x = x.view(-1)
111 |
112 | action = self.network_out(x)
113 | return action
114 |
115 |
116 | class Gaussian_FF_Actor(Actor): # more consistent with other actor naming conventions
117 | def __init__(self, state_dim, action_dim, layers=(256, 256), nonlinearity=torch.nn.functional.relu,
118 | init_std=0.2, learn_std=False, bounded=False, normc_init=True):
119 | super(Gaussian_FF_Actor, self).__init__()
120 |
121 | self.actor_layers = nn.ModuleList()
122 | self.actor_layers += [nn.Linear(state_dim, layers[0])]
123 | for i in range(len(layers)-1):
124 | self.actor_layers += [nn.Linear(layers[i], layers[i+1])]
125 | self.means = nn.Linear(layers[-1], action_dim)
126 |
127 | self.learn_std = learn_std
128 | if self.learn_std:
129 | self.stds = nn.Parameter(init_std * torch.ones(action_dim))
130 | else:
131 | self.stds = init_std * torch.ones(action_dim)
132 |
133 | self.action_dim = action_dim
134 | self.state_dim = state_dim
135 | self.nonlinearity = nonlinearity
136 |
137 | # Initialized to no input normalization, can be modified later
138 | self.obs_std = 1.0
139 | self.obs_mean = 0.0
140 |
141 | # weight initialization scheme used in PPO paper experiments
142 | self.normc_init = normc_init
143 |
144 | self.bounded = bounded
145 |
146 | self.init_parameters()
147 |
148 | def init_parameters(self):
149 | if self.normc_init:
150 | self.apply(normc_fn)
151 | self.means.weight.data.mul_(0.01)
152 |
153 | def _get_dist_params(self, state):
154 | state = (state - self.obs_mean) / self.obs_std
155 |
156 | x = state
157 | for l in self.actor_layers:
158 | x = self.nonlinearity(l(x))
159 | mean = self.means(x)
160 |
161 | if self.bounded:
162 | mean = torch.tanh(mean)
163 |
164 | sd = torch.zeros_like(mean)
165 | if hasattr(self, 'stds'):
166 | sd = self.stds
167 | return mean, sd
168 |
169 | def forward(self, state, deterministic=True):
170 | mu, sd = self._get_dist_params(state)
171 |
172 | if not deterministic:
173 | action = torch.distributions.Normal(mu, sd).sample()
174 | else:
175 | action = mu
176 |
177 | return action
178 |
179 | def distribution(self, inputs):
180 | mu, sd = self._get_dist_params(inputs)
181 | return torch.distributions.Normal(mu, sd)
182 |
183 |
184 | class Gaussian_LSTM_Actor(Actor):
185 | def __init__(self, state_dim, action_dim, layers=(128, 128), nonlinearity=F.tanh, normc_init=False,
186 | init_std=0.2, learn_std=False):
187 | super(Gaussian_LSTM_Actor, self).__init__()
188 |
189 | self.actor_layers = nn.ModuleList()
190 | self.actor_layers += [nn.LSTMCell(state_dim, layers[0])]
191 | for i in range(len(layers)-1):
192 | self.actor_layers += [nn.LSTMCell(layers[i], layers[i+1])]
193 |         self.network_out = nn.Linear(layers[-1], action_dim)
194 |
195 | self.action_dim = action_dim
196 | self.state_dim = state_dim
197 | self.init_hidden_state()
198 | self.nonlinearity = nonlinearity
199 |
200 | # Initialized to no input normalization, can be modified later
201 | self.obs_std = 1.0
202 | self.obs_mean = 0.0
203 |
204 | self.learn_std = learn_std
205 | if self.learn_std:
206 | self.stds = nn.Parameter(init_std * torch.ones(action_dim))
207 | else:
208 | self.stds = init_std * torch.ones(action_dim)
209 |
210 | if normc_init:
211 | self.initialize_parameters()
212 |
213 | self.act = self.forward
214 |
215 | def _get_dist_params(self, state):
216 | state = (state - self.obs_mean) / self.obs_std
217 |
218 | dims = len(state.size())
219 |
220 | x = state
221 | if dims == 3: # if we get a batch of trajectories
222 | self.init_hidden_state(batch_size=x.size(1))
223 | action = []
224 | y = []
225 | for t, x_t in enumerate(x):
226 | for idx, layer in enumerate(self.actor_layers):
227 | c, h = self.cells[idx], self.hidden[idx]
228 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c))
229 | x_t = self.hidden[idx]
230 | y.append(x_t)
231 | x = torch.stack([x_t for x_t in y])
232 |
233 | else:
234 | if dims == 1: # if we get a single timestep (if not, assume we got a batch of single timesteps)
235 | x = x.view(1, -1)
236 |
237 | for idx, layer in enumerate(self.actor_layers):
238 | h, c = self.hidden[idx], self.cells[idx]
239 | self.hidden[idx], self.cells[idx] = layer(x, (h, c))
240 | x = self.hidden[idx]
241 |
242 | if dims == 1:
243 | x = x.view(-1)
244 |
245 | mu = self.network_out(x)
246 | sd = self.stds
247 | return mu, sd
248 |
249 | def init_hidden_state(self, batch_size=1):
250 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers]
251 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.actor_layers]
252 |
253 | def forward(self, state, deterministic=True):
254 | mu, sd = self._get_dist_params(state)
255 |
256 | if not deterministic:
257 | action = torch.distributions.Normal(mu, sd).sample()
258 | else:
259 | action = mu
260 |
261 | return action
262 |
263 | def distribution(self, inputs):
264 | mu, sd = self._get_dist_params(inputs)
265 | return torch.distributions.Normal(mu, sd)
266 |
267 | # Initialization scheme for gaussian mlp (from ppo paper)
268 | # NOTE: the fact that this has the same name as a parameter caused a NASTY bug
269 | # apparently "if " evaluates to True in python...
270 | def normc_fn(m):
271 | classname = m.__class__.__name__
272 | if classname.find('Linear') != -1:
273 | m.weight.data.normal_(0, 1)
274 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
275 | if m.bias is not None:
276 | m.bias.data.fill_(0)
277 |
--------------------------------------------------------------------------------
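
A minimal usage sketch of Gaussian_FF_Actor as PPO consumes it, assuming the repository root is on PYTHONPATH. The dimensions are arbitrary, and obs_mean/obs_std are set to identity placeholders instead of the statistics computed in ppo.py:

    import torch
    from rl.policies.actor import Gaussian_FF_Actor

    obs_dim, act_dim = 10, 4
    policy = Gaussian_FF_Actor(obs_dim, act_dim, init_std=0.2, learn_std=False)
    policy.obs_mean = torch.zeros(obs_dim)
    policy.obs_std = torch.ones(obs_dim)

    obs = torch.randn(1, obs_dim)
    mean_action = policy(obs)                        # deterministic=True by default
    sampled = policy(obs, deterministic=False)       # mean + N(0, stds) noise
    log_prob = policy.distribution(obs).log_prob(sampled).sum(-1, keepdim=True)
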
/rl/policies/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch import sqrt
6 |
7 | def normc_fn(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Linear') != -1:
10 | m.weight.data.normal_(0, 1)
11 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
12 | if m.bias is not None:
13 | m.bias.data.fill_(0)
14 |
15 | # The base class for an actor. Includes functions for normalizing state (optional)
16 | class Net(nn.Module):
17 | def __init__(self):
18 | super(Net, self).__init__()
19 |
20 | self.welford_state_mean = torch.zeros(1)
21 | self.welford_state_mean_diff = torch.ones(1)
22 | self.welford_state_n = 1
23 |
24 | def forward(self):
25 | raise NotImplementedError
26 |
27 | def normalize_state(self, state, update=True):
28 | state = torch.Tensor(state)
29 |
30 | if self.welford_state_n == 1:
31 | self.welford_state_mean = torch.zeros(state.size(-1))
32 | self.welford_state_mean_diff = torch.ones(state.size(-1))
33 |
34 | if update:
35 | if len(state.size()) == 1: # If we get a single state vector
36 | state_old = self.welford_state_mean
37 | self.welford_state_mean += (state - state_old) / self.welford_state_n
38 | self.welford_state_mean_diff += (state - state_old) * (state - state_old)
39 | self.welford_state_n += 1
40 |             elif len(state.size()) == 2:  # If we get a batch
41 |                 print("NORMALIZING 2D TENSOR (this should not be happening)")
42 |                 for state_n in state:
43 |                     state_old = self.welford_state_mean
44 |                     self.welford_state_mean += (state_n - state_old) / self.welford_state_n
45 |                     self.welford_state_mean_diff += (state_n - state_old) * (state_n - state_old)
46 |                     self.welford_state_n += 1
47 |             elif len(state.size()) == 3:  # If we get a batch of sequences
48 |                 print("NORMALIZING 3D TENSOR (this really should not be happening)")
49 |                 for state_t in state:
50 |                     for state_n in state_t:
51 |                         state_old = self.welford_state_mean
52 |                         self.welford_state_mean += (state_n - state_old) / self.welford_state_n
53 |                         self.welford_state_mean_diff += (state_n - state_old) * (state_n - state_old)
54 |                         self.welford_state_n += 1
55 | return (state - self.welford_state_mean) / sqrt(self.welford_state_mean_diff / self.welford_state_n)
56 |
57 | def copy_normalizer_stats(self, net):
58 |         self.welford_state_mean = net.welford_state_mean
59 | self.welford_state_mean_diff = net.welford_state_mean_diff
60 | self.welford_state_n = net.welford_state_n
61 |
62 | def initialize_parameters(self):
63 | self.apply(normc_fn)
64 |
--------------------------------------------------------------------------------
/rl/policies/critic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from rl.policies.base import Net, normc_fn
6 |
7 | # The base class for a critic. Includes functions for normalizing reward and state (optional)
8 | class Critic(Net):
9 | def __init__(self):
10 | super(Critic, self).__init__()
11 |
12 | self.welford_reward_mean = 0.0
13 | self.welford_reward_mean_diff = 1.0
14 | self.welford_reward_n = 1
15 |
16 | def forward(self):
17 | raise NotImplementedError
18 |
19 | def normalize_reward(self, r, update=True):
20 | if update:
21 | if len(r.size()) == 1:
22 | r_old = self.welford_reward_mean
23 | self.welford_reward_mean += (r - r_old) / self.welford_reward_n
24 | self.welford_reward_mean_diff += (r - r_old) * (r - r_old)
25 | self.welford_reward_n += 1
26 | elif len(r.size()) == 2:
27 | for r_n in r:
28 | r_old = self.welford_reward_mean
29 | self.welford_reward_mean += (r_n - r_old) / self.welford_reward_n
30 | self.welford_reward_mean_diff += (r_n - r_old) * (r_n - r_old)
31 | self.welford_reward_n += 1
32 | else:
33 | raise NotImplementedError
34 |
35 | return (r - self.welford_reward_mean) / torch.sqrt(self.welford_reward_mean_diff / self.welford_reward_n)
36 |
37 | class FF_V(Critic):
38 | def __init__(self, state_dim, layers=(256, 256), nonlinearity=torch.nn.functional.relu, normc_init=True, obs_std=None, obs_mean=None):
39 | super(FF_V, self).__init__()
40 |
41 | self.critic_layers = nn.ModuleList()
42 | self.critic_layers += [nn.Linear(state_dim, layers[0])]
43 | for i in range(len(layers)-1):
44 | self.critic_layers += [nn.Linear(layers[i], layers[i+1])]
45 | self.network_out = nn.Linear(layers[-1], 1)
46 |
47 | self.nonlinearity = nonlinearity
48 |
49 | self.obs_std = obs_std
50 | self.obs_mean = obs_mean
51 |
52 | # weight initialization scheme used in PPO paper experiments
53 | self.normc_init = normc_init
54 |
55 | self.init_parameters()
56 |
57 | def init_parameters(self):
58 | if self.normc_init:
59 | print("Doing norm column initialization.")
60 | self.apply(normc_fn)
61 |
62 | def forward(self, inputs):
63 | inputs = (inputs - self.obs_mean) / self.obs_std
64 |
65 | x = inputs
66 | for l in self.critic_layers:
67 | x = self.nonlinearity(l(x))
68 | value = self.network_out(x)
69 |
70 | return value
71 |
72 | class LSTM_V(Critic):
73 | def __init__(self, input_dim, layers=(128, 128), normc_init=True):
74 | super(LSTM_V, self).__init__()
75 |
76 | self.critic_layers = nn.ModuleList()
77 | self.critic_layers += [nn.LSTMCell(input_dim, layers[0])]
78 | for i in range(len(layers)-1):
79 | self.critic_layers += [nn.LSTMCell(layers[i], layers[i+1])]
80 | self.network_out = nn.Linear(layers[-1], 1)
81 |
82 | self.init_hidden_state()
83 |
84 | if normc_init:
85 | self.initialize_parameters()
86 |
87 | def get_hidden_state(self):
88 | return self.hidden, self.cells
89 |
90 | def init_hidden_state(self, batch_size=1):
91 | self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.critic_layers]
92 | self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.critic_layers]
93 |
94 | def forward(self, state):
95 | state = (state - self.obs_mean) / self.obs_std
96 | dims = len(state.size())
97 |
98 | if dims == 3: # if we get a batch of trajectories
99 | self.init_hidden_state(batch_size=state.size(1))
100 | value = []
101 | for t, state_batch_t in enumerate(state):
102 | x_t = state_batch_t
103 | for idx, layer in enumerate(self.critic_layers):
104 | c, h = self.cells[idx], self.hidden[idx]
105 | self.hidden[idx], self.cells[idx] = layer(x_t, (h, c))
106 | x_t = self.hidden[idx]
107 | x_t = self.network_out(x_t)
108 | value.append(x_t)
109 |
110 | x = torch.stack([a.float() for a in value])
111 |
112 | else:
113 | x = state
114 | if dims == 1:
115 | x = x.view(1, -1)
116 |
117 | for idx, layer in enumerate(self.critic_layers):
118 | c, h = self.cells[idx], self.hidden[idx]
119 | self.hidden[idx], self.cells[idx] = layer(x, (h, c))
120 | x = self.hidden[idx]
121 | x = self.network_out(x)
122 |
123 | if dims == 1:
124 | x = x.view(-1)
125 |
126 | return x
127 |
128 |
129 | GaussianMLP_Critic = FF_V
130 |
--------------------------------------------------------------------------------
/rl/storage/rollout_storage.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class PPOBuffer:
4 | def __init__(self, obs_len=1, act_len=1, gamma=0.99, lam=0.95, use_gae=False, size=1):
5 | self.states = torch.zeros(size, obs_len, dtype=float)
6 | self.actions = torch.zeros(size, act_len, dtype=float)
7 | self.rewards = torch.zeros(size, 1, dtype=float)
8 | self.values = torch.zeros(size, 1, dtype=float)
9 | self.returns = torch.zeros(size, 1, dtype=float)
10 | self.dones = torch.zeros(size, 1, dtype=float)
11 |
12 | self.gamma, self.lam = gamma, lam
13 | self.ptr = 0
14 | self.traj_idx = [0]
15 |
16 | def __len__(self):
17 | return self.ptr
18 |
19 | def store(self, state, action, reward, value, done):
20 | """
21 | Append one timestep of agent-environment interaction to the buffer.
22 | """
23 | self.states[self.ptr]= state
24 | self.actions[self.ptr] = action
25 | self.rewards[self.ptr] = reward
26 | self.values[self.ptr] = value
27 | self.dones[self.ptr] = done
28 | self.ptr += 1
29 |
30 | def finish_path(self, last_val=None):
31 | self.traj_idx += [self.ptr]
32 | rewards = self.rewards[self.traj_idx[-2]:self.traj_idx[-1], 0]
33 | R = last_val.squeeze(0)
34 | returns = torch.zeros_like(rewards)
35 | for i in range(len(rewards) - 1, -1, -1):
36 | R = self.gamma * R + rewards[i]
37 | returns[i] = R
38 | self.returns[self.traj_idx[-2]:self.traj_idx[-1], 0] = returns
39 | self.dones[-1] = True
40 |
41 | def get_data(self):
42 | """
43 | Return collected data and reset buffer.
44 |
45 | Returns:
46 | dict: Collected trajectory data
47 | """
48 | ep_lens = [j - i for i, j in zip(self.traj_idx, self.traj_idx[1:])]
49 | ep_rewards = [
50 | float(sum(self.rewards[int(i):int(j)])) for i, j in zip(self.traj_idx, self.traj_idx[1:])
51 | ]
52 | data = {
53 | 'states': self.states[:self.ptr],
54 | 'actions': self.actions[:self.ptr],
55 | 'rewards': self.rewards[:self.ptr],
56 | 'values': self.values[:self.ptr],
57 | 'returns': self.returns[:self.ptr],
58 | 'dones': self.dones[:self.ptr],
59 | 'traj_idx': torch.Tensor(self.traj_idx),
60 | 'ep_lens': torch.Tensor(ep_lens),
61 | 'ep_rewards': torch.Tensor(ep_rewards),
62 | }
63 | return data
64 |
--------------------------------------------------------------------------------
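
A short sketch of the PPOBuffer lifecycle as used in PPO.sample(): store a few fake timesteps, bootstrap a truncated trajectory with a last value, and read back the discounted returns. All numbers are made up; assumes the repository root is on PYTHONPATH:

    import torch
    from rl.storage.rollout_storage import PPOBuffer

    gamma = 0.99
    buf = PPOBuffer(obs_len=3, act_len=2, gamma=gamma, lam=0.95, size=8)

    for t in range(3):
        buf.store(state=torch.randn(3), action=torch.randn(2),
                  reward=torch.tensor(1.0), value=torch.tensor([0.5]), done=False)

    # bootstrap value for a truncated (not terminated) trajectory
    buf.finish_path(last_val=torch.tensor([0.5]))

    data = buf.get_data()
    print(data['returns'].squeeze(-1))     # discounted sums: R_t = r_t + gamma * R_{t+1}
    print(data['ep_lens'], data['ep_rewards'])
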
/rl/utils/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from pathlib import Path
4 |
5 | import mujoco
6 | import mujoco.viewer
7 |
8 | import imageio
9 | from datetime import datetime
10 |
11 | class EvaluateEnv:
12 | def __init__(self, env, policy, args):
13 | self.env = env
14 | self.policy = policy
15 | self.ep_len = args.ep_len
16 |
17 | if args.out_dir is None:
18 | args.out_dir = Path(args.path.parent, "videos")
19 |
20 | video_outdir = Path(args.out_dir)
21 | try:
22 | Path.mkdir(video_outdir, exist_ok=True)
23 | now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
24 | video_fn = Path(video_outdir, args.path.stem + "-" + now + ".mp4")
25 | self.writer = imageio.get_writer(video_fn, fps=60)
26 | except Exception as e:
27 | print("Could not create video writer:", e)
28 | exit(-1)
29 |
30 | @torch.no_grad()
31 | def run(self):
32 |
33 | height = 480
34 | width = 640
35 | renderer = mujoco.Renderer(self.env.model, height, width)
36 | viewer = mujoco.viewer.launch_passive(self.env.model, self.env.data)
37 | frames = []
38 |
39 | # Make a camera.
40 | cam = viewer.cam
41 | mujoco.mjv_defaultCamera(cam)
42 | cam.elevation = -20
43 | cam.distance = 4
44 |
45 | reset_counter = 0
46 | observation = self.env.reset()
47 | while self.env.data.time < self.ep_len:
48 |
49 | step_start = time.time()
50 |
51 | # forward pass and step
52 | raw = self.policy.forward(torch.tensor(observation, dtype=torch.float32), deterministic=True).detach().numpy()
53 | observation, reward, done, _ = self.env.step(raw.copy())
54 |
55 | # render scene
56 | cam.lookat = self.env.data.body(1).xpos.copy()
57 | renderer.update_scene(self.env.data, cam)
58 | pixels = renderer.render()
59 | frames.append(pixels)
60 |
61 | viewer.sync()
62 |
63 | if done and reset_counter < 3:
64 | observation = self.env.reset()
65 | reset_counter += 1
66 |
67 | time_until_next_step = max(
68 | 0, self.env.frame_skip*self.env.model.opt.timestep - (time.time() - step_start))
69 | time.sleep(time_until_next_step)
70 |
71 | for frame in frames:
72 | self.writer.append_data(frame)
73 | self.writer.close()
74 | self.env.close()
75 | viewer.close()
76 |
--------------------------------------------------------------------------------
/robots/robot_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class RobotBase(object):
4 | def __init__(self, pdgains, dt, client, task, pdrand_k = 0, sim_bemf = False, sim_motor_dyn = False):
5 |
6 | self.client = client
7 | self.task = task
8 | self.control_dt = dt
9 | self.pdrand_k = pdrand_k
10 | self.sim_bemf = sim_bemf
11 | self.sim_motor_dyn = sim_motor_dyn
12 |
13 | assert (self.sim_bemf & self.sim_motor_dyn)==False, \
14 | "You cannot simulate back-EMF and motor dynamics simultaneously!"
15 |
16 | # set PD gains
17 | self.kp = pdgains[0]
18 | self.kd = pdgains[1]
19 | assert self.kp.shape==self.kd.shape==(self.client.nu(),), \
20 | f"kp shape {self.kp.shape} and kd shape {self.kd.shape} must be {(self.client.nu(),)}"
21 |
22 | # torque damping param
23 | self.tau_d = np.zeros(self.client.nu())
24 |
25 | self.client.set_pd_gains(self.kp, self.kd)
26 | tau = self.client.step_pd(np.zeros(self.client.nu()), np.zeros(self.client.nu()))
27 | w = self.client.get_act_joint_velocities()
28 | assert len(w)==len(tau)
29 |
30 | self.prev_action = None
31 | self.prev_torque = None
32 | self.iteration_count = np.inf
33 |
34 | # frame skip parameter
35 | if (np.around(self.control_dt%self.client.sim_dt(), 6)):
36 | raise Exception("Control dt should be an integer multiple of Simulation dt.")
37 | self.frame_skip = int(self.control_dt/self.client.sim_dt())
38 |
39 | def _do_simulation(self, target, n_frames):
40 | # randomize PD gains
41 | if self.pdrand_k:
42 | k = self.pdrand_k
43 | kp = np.random.uniform((1-k)*self.kp, (1+k)*self.kp)
44 | kd = np.random.uniform((1-k)*self.kd, (1+k)*self.kd)
45 | self.client.set_pd_gains(kp, kd)
46 |
47 | assert target.shape == (self.client.nu(),), \
48 | f"Target shape must be {(self.client.nu(),)}"
49 |
50 | ratio = self.client.get_gear_ratios()
51 |
52 | if self.sim_bemf and np.random.randint(10)==0:
53 | self.tau_d = np.random.uniform(5, 40, self.client.nu())
54 |
55 | for _ in range(n_frames):
56 | w = self.client.get_act_joint_velocities()
57 | tau = self.client.step_pd(target, np.zeros(self.client.nu()))
58 | tau = tau - self.tau_d*w
59 | tau /= ratio
60 | self.client.set_motor_torque(tau, self.sim_motor_dyn)
61 | self.client.step()
62 |
63 | def step(self, action, offset=None):
64 |
65 | if not isinstance(action, np.ndarray):
66 | raise TypeError("Expected action to be a numpy array")
67 |
68 | action = np.copy(action)
69 |
70 | assert action.shape == (self.client.nu(),), \
71 | f"Action vector length expected to be: {self.client.nu()} but is {action.shape}"
72 |
73 | # If offset is provided, add to action vector
74 | if offset is not None:
75 | if not isinstance(offset, np.ndarray):
76 | raise TypeError("Expected offset to be a numpy array")
77 | assert offset.shape == action.shape, \
78 |                 f"Offset shape {offset.shape} must match action shape {action.shape}"
79 | offset = np.copy(offset)
80 | action += offset
81 |
82 | if self.prev_action is None:
83 | self.prev_action = action
84 | if self.prev_torque is None:
85 | self.prev_torque = np.asarray(self.client.get_act_joint_torques())
86 |
87 | # Perform the simulation
88 | self._do_simulation(action, self.frame_skip)
89 |
90 | # Task-related operations
91 | self.task.step()
92 | rewards = self.task.calc_reward(self.prev_torque, self.prev_action, action)
93 | done = self.task.done()
94 |
95 | self.prev_action = action
96 | self.prev_torque = np.asarray(self.client.get_act_joint_torques())
97 |
98 | return rewards, done
99 |
--------------------------------------------------------------------------------
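
RobotBase runs the PD controller at the simulation rate while the policy acts every control_dt; frame_skip is the ratio of the two. The sketch below illustrates only that relationship with a made-up single-joint PD law and a toy unit-inertia plant; the actual torque computation lives in the client's step_pd(), which is not part of this file:

    import numpy as np

    # hypothetical rates; the real values come from the env config and the MuJoCo model
    sim_dt, control_dt = 0.001, 0.025
    if np.around(control_dt % sim_dt, 6):
        raise Exception("Control dt should be an integer multiple of Simulation dt.")
    frame_skip = int(control_dt / sim_dt)      # 25 simulation steps per policy step

    kp, kd = 80.0, 2.0                         # made-up gains for a single joint
    q, qd = 0.0, 0.0                           # joint position / velocity
    target = 0.3                               # position target coming from the policy
    for _ in range(frame_skip):
        tau = kp * (target - q) - kd * qd      # generic PD law, for illustration only
        qd += tau * sim_dt                     # toy unit-inertia "plant"
        q += qd * sim_dt
    print(frame_skip, q)
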
/run_experiment.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import sys
3 | import argparse
4 | import ray
5 | from functools import partial
6 |
7 | import numpy as np
8 | import torch
9 | import pickle
10 | import shutil
11 |
12 | from rl.algos.ppo import PPO
13 | from rl.envs.wrappers import SymmetricEnv
14 | from rl.utils.eval import EvaluateEnv
15 |
16 | def import_env(env_name_str):
17 | if env_name_str=='jvrc_walk':
18 | from envs.jvrc import JvrcWalkEnv as Env
19 | elif env_name_str=='jvrc_step':
20 | from envs.jvrc import JvrcStepEnv as Env
21 | elif env_name_str=='h1':
22 | from envs.h1 import H1Env as Env
23 | else:
24 | raise Exception("Check env name!")
25 | return Env
26 |
27 | def run_experiment(args):
28 | # import the correct environment
29 | Env = import_env(args.env)
30 |
31 | # wrapper function for creating parallelized envs
32 | env_fn = partial(Env, path_to_yaml=args.yaml)
33 | _env = env_fn()
34 | if not args.no_mirror:
35 | try:
36 | print("Wrapping in SymmetricEnv.")
37 | env_fn = partial(SymmetricEnv, env_fn,
38 | mirrored_obs=_env.robot.mirrored_obs,
39 | mirrored_act=_env.robot.mirrored_acts,
40 | clock_inds=_env.robot.clock_inds)
41 | except AttributeError as e:
42 | print("Warning! Cannot use SymmetricEnv.", e)
43 |
44 | # Set up Parallelism
45 | #os.environ['OMP_NUM_THREADS'] = '1' # [TODO: Is this needed?]
46 | if not ray.is_initialized():
47 | ray.init(num_cpus=args.num_procs)
48 |
49 | # dump hyperparameters
50 | Path.mkdir(args.logdir, parents=True, exist_ok=True)
51 | pkl_path = Path(args.logdir, "experiment.pkl")
52 | with open(pkl_path, 'wb') as f:
53 | pickle.dump(args, f)
54 |
55 | # copy config file
56 | if args.yaml:
57 | config_out_path = Path(args.logdir, "config.yaml")
58 | shutil.copyfile(args.yaml, config_out_path)
59 |
60 | algo = PPO(env_fn, args)
61 | algo.train(env_fn, args.n_itr)
62 |
63 | if __name__ == "__main__":
64 |
65 | parser = argparse.ArgumentParser()
66 |
67 | if sys.argv[1] == 'train':
68 | sys.argv.remove(sys.argv[1])
69 |
70 | parser.add_argument("--env", required=True, type=str)
71 | parser.add_argument("--logdir", default=Path("/tmp/logs"), type=Path, help="Path to save weights and logs")
72 | parser.add_argument("--input-norm-steps", type=int, default=100000)
73 | parser.add_argument("--n-itr", type=int, default=20000, help="Number of iterations of the learning algorithm")
74 | parser.add_argument("--lr", type=float, default=1e-4, help="Adam learning rate") # Xie
75 | parser.add_argument("--eps", type=float, default=1e-5, help="Adam epsilon (for numerical stability)")
76 | parser.add_argument("--lam", type=float, default=0.95, help="Generalized advantage estimate discount")
77 | parser.add_argument("--gamma", type=float, default=0.99, help="MDP discount")
78 | parser.add_argument("--std-dev", type=float, default=0.223, help="Action noise for exploration")
79 | parser.add_argument("--learn-std", action="store_true", help="Exploration noise will be learned")
80 | parser.add_argument("--entropy-coeff", type=float, default=0.0, help="Coefficient for entropy regularization")
81 | parser.add_argument("--clip", type=float, default=0.2, help="Clipping parameter for PPO surrogate loss")
82 | parser.add_argument("--minibatch-size", type=int, default=64, help="Batch size for PPO updates")
83 | parser.add_argument("--epochs", type=int, default=3, help="Number of optimization epochs per PPO update") #Xie
84 | parser.add_argument("--use-gae", type=bool, default=True,help="Whether or not to calculate returns using Generalized Advantage Estimation")
85 | parser.add_argument("--num-procs", type=int, default=12, help="Number of threads to train on")
86 | parser.add_argument("--max-grad-norm", type=float, default=0.05, help="Value to clip gradients at")
87 | parser.add_argument("--max-traj-len", type=int, default=400, help="Max episode horizon")
88 | parser.add_argument("--no-mirror", required=False, action="store_true", help="to use SymmetricEnv")
89 | parser.add_argument("--mirror-coeff", required=False, default=0.4, type=float, help="weight for mirror loss")
90 | parser.add_argument("--eval-freq", required=False, default=100, type=int, help="Frequency of performing evaluation")
91 | parser.add_argument("--continued", required=False, type=Path, help="path to pretrained weights")
92 | parser.add_argument("--recurrent", required=False, action="store_true", help="use LSTM instead of FF")
93 | parser.add_argument("--imitate", required=False, type=str, default=None, help="Policy to imitate")
94 | parser.add_argument("--imitate-coeff", required=False, type=float, default=0.3, help="Coefficient for imitation loss")
95 | parser.add_argument("--yaml", required=False, type=str, default=None, help="Path to config file passed to Env class")
96 | args = parser.parse_args()
97 |
98 | run_experiment(args)
99 |
100 | elif sys.argv[1] == 'eval':
101 | sys.argv.remove(sys.argv[1])
102 |
103 | parser.add_argument("--path", required=False, type=Path, default=Path("/tmp/logs"),
104 | help="Path to trained model dir")
105 | parser.add_argument("--out-dir", required=False, type=Path, default=None,
106 | help="Path to directory to save videos")
107 | parser.add_argument("--ep-len", required=False, type=int, default=10,
108 | help="Episode length to play (in seconds)")
109 | args = parser.parse_args()
110 |
111 | path_to_actor = ""
112 | if args.path.is_file() and args.path.suffix==".pt":
113 | path_to_actor = args.path
114 | elif args.path.is_dir():
115 | path_to_actor = Path(args.path, "actor.pt")
116 | else:
117 | raise Exception("Invalid path to actor module: ", args.path)
118 |
119 | path_to_critic = Path(path_to_actor.parent, "critic" + str(path_to_actor).split('actor')[1])
120 | path_to_pkl = Path(path_to_actor.parent, "experiment.pkl")
121 |
122 | # load experiment args
123 | run_args = pickle.load(open(path_to_pkl, "rb"))
124 | # load trained policy
125 | policy = torch.load(path_to_actor, weights_only=False)
126 | critic = torch.load(path_to_critic, weights_only=False)
127 | policy.eval()
128 | critic.eval()
129 |
130 | # import the correct environment
131 | Env = import_env(run_args.env)
132 | if "yaml" in run_args and run_args.yaml is not None:
133 | yaml_path = Path(run_args.yaml)
134 | else:
135 | yaml_path = None
136 | env = partial(Env, yaml_path)()
137 |
138 | # run
139 | e = EvaluateEnv(env, policy, args)
140 | e.run()
141 |
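142 | # Example invocations (illustrative; the log directory is a placeholder):
143 | #   python run_experiment.py train --env jvrc_walk --logdir /tmp/logs/jvrc_walk --num-procs 12
144 | #   python run_experiment.py eval --path /tmp/logs/jvrc_walk/actor.pt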
--------------------------------------------------------------------------------
/scripts/debug_stepper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import torch
5 | import pickle
6 | import mujoco
7 | import numpy as np
8 | import transforms3d as tf3
9 | from run_experiment import import_env
10 |
11 | def print_reward(ep_rewards):
12 | mean_rewards = {k:[] for k in ep_rewards[-1].keys()}
13 | print('*********************************')
14 | for key in mean_rewards.keys():
15 | l = [step[key] for step in ep_rewards]
16 | mean_rewards[key] = sum(l)/len(l)
17 | print(key, ': ', mean_rewards[key])
18 | #total_rewards = [r for step in ep_rewards for r in step.values()]
19 | print('*********************************')
20 | print("mean per step reward: ", sum(mean_rewards.values()))
21 |
22 | def draw_targets(task, viewer):
23 | # draw step sequence
24 | arrow_size = [0.02, 0.02, 0.5]
25 | sphere = mujoco.mjtGeom.mjGEOM_SPHERE
26 | arrow = mujoco.mjtGeom.mjGEOM_ARROW
27 | if hasattr(task, 'sequence'):
28 | for idx, step in enumerate(task.sequence):
29 | step_pos = [step[0], step[1], step[2]]
30 | step_theta = step[3]
31 | if step_pos not in [task.sequence[task.t1][0:3].tolist(), task.sequence[task.t2][0:3].tolist()]:
32 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="")
33 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="")
34 |
35 | target_radius = task.target_radius
36 | step_pos = task.sequence[task.t1][0:3].tolist()
37 | step_theta = task.sequence[task.t1][3]
38 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([1, 0, 0, 1]), type=sphere, label="t1")
39 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([1, 0, 0, 1]), type=arrow, label="")
40 | viewer.add_marker(pos=step_pos, size=np.ones(3)*target_radius, rgba=np.array([1, 0, 0, 0.1]), type=sphere, label="")
41 | step_pos = task.sequence[task.t2][0:3].tolist()
42 | step_theta = task.sequence[task.t2][3]
43 | viewer.add_marker(pos=step_pos, size=np.ones(3)*0.05, rgba=np.array([0, 0, 1, 1]), type=sphere, label="t2")
44 | viewer.add_marker(pos=step_pos, mat=tf3.euler.euler2mat(0, np.pi/2, step_theta), size=arrow_size, rgba=np.array([0, 0, 1, 1]), type=arrow, label="")
45 | viewer.add_marker(pos=step_pos, size=np.ones(3)*target_radius, rgba=np.array([0, 0, 1, 0.1]), type=sphere, label="")
46 | return
47 |
48 | def draw_stuff(task, viewer):
49 | arrow_size = [0.02, 0.02, 0.5]
50 | sphere = mujoco.mjtGeom.mjGEOM_SPHERE
51 | arrow = mujoco.mjtGeom.mjGEOM_ARROW
52 |
53 | # draw observed targets
54 | goalx = task._goal_steps_x
55 | goaly = task._goal_steps_y
56 | goaltheta = task._goal_steps_theta
57 | viewer.add_marker(pos=[goalx[0], goaly[0], 0], size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="G1")
58 | viewer.add_marker(pos=[goalx[0], goaly[0], 0], mat=tf3.euler.euler2mat(0, np.pi/2, goaltheta[0]), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="")
59 | viewer.add_marker(pos=[goalx[1], goaly[1], 0], size=np.ones(3)*0.05, rgba=np.array([0, 1, 1, 1]), type=sphere, label="G2")
60 | viewer.add_marker(pos=[goalx[1], goaly[1], 0], mat=tf3.euler.euler2mat(0, np.pi/2, goaltheta[1]), size=arrow_size, rgba=np.array([0, 1, 1, 1]), type=arrow, label="")
61 |
62 | # draw feet pose
63 | lfoot_orient = (tf3.quaternions.quat2mat(task.l_foot_quat)).dot(tf3.euler.euler2mat(0, np.pi/2, 0))
64 | rfoot_orient = (tf3.quaternions.quat2mat(task.r_foot_quat)).dot(tf3.euler.euler2mat(0, np.pi/2, 0))
65 | viewer.add_marker(pos=task.l_foot_pos, size=np.ones(3)*0.05, rgba=[0.5, 0.5, 0.5, 1], type=sphere, label="")
66 | viewer.add_marker(pos=task.l_foot_pos, mat=lfoot_orient, size=arrow_size, rgba=[0.5, 0.5, 0.5, 1], type=arrow, label="")
67 | viewer.add_marker(pos=task.r_foot_pos, size=np.ones(3)*0.05, rgba=[0.5, 0.5, 0.5, 1], type=sphere, label="")
68 | viewer.add_marker(pos=task.r_foot_pos, mat=rfoot_orient, size=arrow_size, rgba=[0.5, 0.5, 0.5, 1], type=arrow, label="")
69 |
70 | # draw origin
71 | viewer.add_marker(pos=[0, 0, 0], size=np.ones(3)*0.05, rgba=np.array([1, 1, 1, 1]), type=sphere, label="")
72 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(0, 0, 0), size=[0.01, 0.01, 2], rgba=np.array([0, 0, 1, 0.2]), type=arrow, label="")
73 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(0, np.pi/2, 0), size=[0.01, 0.01, 2], rgba=np.array([1, 0, 0, 0.2]), type=arrow, label="")
74 | viewer.add_marker(pos=[0, 0, 0], mat=tf3.euler.euler2mat(-np.pi/2, np.pi/2, 0), size=[0.01, 0.01, 2], rgba=np.array([0, 1, 0, 0.2]), type=arrow, label="")
75 | return
76 |
77 | def run(env, policy, args):
78 | observation = env.reset()
79 |
80 | env.render()
81 | viewer = env.viewer
82 | viewer._paused = True
83 | done = False
84 | ts, end_ts = 0, 2000
85 | ep_rewards = []
86 |
87 | while (ts < end_ts):
88 | if hasattr(env, 'frame_skip'):
89 | start = time.time()
90 |
91 | with torch.no_grad():
92 | action = policy.forward(torch.Tensor(observation), deterministic=True).detach().numpy()
93 |
94 | observation, _, done, info = env.step(action.copy())
95 | ep_rewards.append(info)
96 |
97 | if env.__class__.__name__ == 'JvrcStepEnv':
98 | draw_targets(env.task, viewer)
99 | draw_stuff(env.task, viewer)
100 | env.render()
101 |
102 | if args.sync and hasattr(env, 'frame_skip'):
103 | end = time.time()
104 | sim_dt = env.robot.client.sim_dt()
105 | delaytime = max(0, env.frame_skip * sim_dt - (end-start))
106 | time.sleep(delaytime)
107 |
108 | if args.quit_on_done and done:
109 | break
110 |
111 | ts+=1
112 |
113 | print("Episode finished after {} timesteps".format(ts))
114 | print_reward(ep_rewards)
115 | env.close()
116 |
117 | def main():
118 | # get command line arguments
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument("--path",
121 | required=True,
122 | type=str,
123 | help="path to trained model dir",
124 | )
125 | parser.add_argument("--sync",
126 | required=False,
127 | action="store_true",
128 | help="sync the simulation speed with real time",
129 | )
130 | parser.add_argument("--quit-on-done",
131 | required=False,
132 | action="store_true",
133 | help="Exit when done condition is reached",
134 | )
135 | args = parser.parse_args()
136 |
137 | path_to_actor = ""
138 | path_to_pkl = ""
139 | if os.path.isfile(args.path) and args.path.endswith(".pt"):
140 | path_to_actor = args.path
141 | path_to_pkl = os.path.join(os.path.dirname(args.path), "experiment.pkl")
142 | if os.path.isdir(args.path):
143 | path_to_actor = os.path.join(args.path, "actor.pt")
144 | path_to_pkl = os.path.join(args.path, "experiment.pkl")
145 |
146 | # load experiment args
147 | run_args = pickle.load(open(path_to_pkl, "rb"))
148 | # load trained policy
149 | policy = torch.load(path_to_actor, weights_only=False)
150 | policy.eval()
151 | # import the correct environment
152 | env = import_env(run_args.env)()
153 |
154 | run(env, policy, args)
155 | print("-----------------------------------------")
156 |
157 | if __name__=='__main__':
158 | main()
159 |
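160 | # Example invocation (illustrative; assumes the repo root is on PYTHONPATH so
161 | # that 'from run_experiment import import_env' resolves):
162 | #   PYTHONPATH=. python scripts/debug_stepper.py --path /tmp/logs/jvrc_step --sync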
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohanpsingh/LearningHumanoidWalking/e517f531433777cc59bd33e85a0d0f4884f5b39b/tasks/__init__.py
--------------------------------------------------------------------------------
/tasks/rewards.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | ##############################
4 | ##############################
5 | # Define reward functions here
6 | ##############################
7 | ##############################
8 |
9 | def _calc_fwd_vel_reward(self):
10 | # forward vel reward
11 | root_vel = self._client.get_qvel()[0]
12 | error = np.linalg.norm(root_vel - self._goal_speed_ref)
13 | return np.exp(-error)
14 |
15 | def _calc_action_reward(self, action, prev_action):
16 | # action reward
17 | penalty = 5 * sum(np.abs(prev_action - action)) / len(action)
18 | return np.exp(-penalty)
19 |
20 | def _calc_torque_reward(self, prev_torque):
21 | # actuator torque reward
22 | torque = np.asarray(self._client.get_act_joint_torques())
23 | penalty = 0.25 * (sum(np.abs(prev_torque - torque)) / len(torque))
24 | return np.exp(-penalty)
25 |
26 | def _calc_height_reward(self):
27 | # height reward
28 | if self._client.check_rfoot_floor_collision() or self._client.check_lfoot_floor_collision():
29 | contact_point = min([c.pos[2] for _,c in (self._client.get_rfoot_floor_contacts() +
30 | self._client.get_lfoot_floor_contacts())])
31 | else:
32 | contact_point = 0
33 | current_height = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[2]
34 | relative_height = current_height - contact_point
35 | error = np.abs(relative_height - self._goal_height_ref)
36 | deadzone_size = 0.01 + 0.05 * self._goal_speed_ref
37 | if error < deadzone_size:
38 | error = 0
39 | return np.exp(-40*np.square(error))
40 |
41 | def _calc_heading_reward(self):
42 | # heading reward
43 | cur_heading = self._client.get_qvel()[:3]
44 | cur_heading /= np.linalg.norm(cur_heading)
45 | error = np.linalg.norm(cur_heading - np.array([1, 0, 0]))
46 | return np.exp(-error)
47 |
48 | def _calc_root_accel_reward(self):
49 | qvel = self._client.get_qvel()
50 | qacc = self._client.get_qacc()
51 | error = 0.25 * (np.abs(qvel[3:6]).sum() + np.abs(qacc[0:3]).sum())
52 | return np.exp(-error)
53 |
54 | def _calc_feet_separation_reward(self):
55 | # feet y-separation cost
56 | rfoot_pos = self._client.get_rfoot_body_pos()[1]
57 | lfoot_pos = self._client.get_lfoot_body_pos()[1]
58 | foot_dist = np.abs(rfoot_pos-lfoot_pos)
59 | error = 5*np.square(foot_dist-0.35)
60 | if foot_dist < 0.40 and foot_dist > 0.30:
61 | error = 0
62 | return np.exp(-error)
63 |
64 | def _calc_foot_frc_clock_reward(self, left_frc_fn, right_frc_fn):
65 | # constraints of foot forces based on clock
66 | desired_max_foot_frc = self._client.get_robot_mass()*9.8*0.5
67 | #desired_max_foot_frc = self._client.get_robot_mass()*10*1.2
68 | normed_left_frc = min(self.l_foot_frc, desired_max_foot_frc) / desired_max_foot_frc
69 | normed_right_frc = min(self.r_foot_frc, desired_max_foot_frc) / desired_max_foot_frc
70 | normed_left_frc*=2
71 | normed_left_frc-=1
72 | normed_right_frc*=2
73 | normed_right_frc-=1
74 |
75 | left_frc_clock = left_frc_fn(self._phase)
76 | right_frc_clock = right_frc_fn(self._phase)
77 |
78 | left_frc_score = np.tan(np.pi/4 * left_frc_clock * normed_left_frc)
79 | right_frc_score = np.tan(np.pi/4 * right_frc_clock * normed_right_frc)
80 |
81 | foot_frc_score = (left_frc_score + right_frc_score)/2
82 | return foot_frc_score
83 |
84 | def _calc_foot_vel_clock_reward(self, left_vel_fn, right_vel_fn):
85 | # constraints of foot velocities based on clock
86 | desired_max_foot_vel = 0.2
87 | normed_left_vel = min(np.linalg.norm(self.l_foot_vel), desired_max_foot_vel) / desired_max_foot_vel
88 | normed_right_vel = min(np.linalg.norm(self.r_foot_vel), desired_max_foot_vel) / desired_max_foot_vel
89 | normed_left_vel*=2
90 | normed_left_vel-=1
91 | normed_right_vel*=2
92 | normed_right_vel-=1
93 |
94 | left_vel_clock = left_vel_fn(self._phase)
95 | right_vel_clock = right_vel_fn(self._phase)
96 |
97 | left_vel_score = np.tan(np.pi/4 * left_vel_clock * normed_left_vel)
98 | right_vel_score = np.tan(np.pi/4 * right_vel_clock * normed_right_vel)
99 |
100 | foot_vel_score = (left_vel_score + right_vel_score)/2
101 | return foot_vel_score
102 |
103 | def _calc_foot_pos_clock_reward(self):
104 | # constraints of foot height based on clock
105 | desired_max_foot_height = 0.05
106 | l_foot_pos = self._client.get_object_xpos_by_name('lf_force', 'OBJ_SITE')[2]
107 | r_foot_pos = self._client.get_object_xpos_by_name('rf_force', 'OBJ_SITE')[2]
108 | normed_left_pos = min(np.linalg.norm(l_foot_pos), desired_max_foot_height) / desired_max_foot_height
109 | normed_right_pos = min(np.linalg.norm(r_foot_pos), desired_max_foot_height) / desired_max_foot_height
110 |
111 | left_pos_clock = self.left_clock[1](self._phase)
112 | right_pos_clock = self.right_clock[1](self._phase)
113 |
114 | left_pos_score = np.tan(np.pi/4 * left_pos_clock * normed_left_pos)
115 | right_pos_score = np.tan(np.pi/4 * right_pos_clock * normed_right_pos)
116 |
117 | foot_pos_score = left_pos_score + right_pos_score
118 | return foot_pos_score
119 |
120 | def _calc_body_orient_reward(self, body_name, quat_ref=[1, 0, 0, 0]):
121 | # body orientation reward
122 | body_quat = self._client.get_object_xquat_by_name(body_name, "OBJ_BODY")
123 | target_quat = np.array(quat_ref)
124 | error = 10 * (1 - np.inner(target_quat, body_quat) ** 2)
125 | return np.exp(-error)
126 |
127 | def _calc_joint_vel_reward(self, enabled, cutoff=0.5):
128 | # joint velocity reward
129 | motor_speeds = self._client.get_motor_velocities()
130 | motor_limits = self._client.get_motor_speed_limits()
131 | motor_speeds = [motor_speeds[i] for i in enabled]
132 | motor_limits = [motor_limits[i] for i in enabled]
133 | error = 5e-6*sum([np.square(q)
134 | for q, qmax in zip(motor_speeds, motor_limits)
135 | if np.abs(q)>np.abs(cutoff*qmax)])
136 | return np.exp(-error)
137 |
138 |
139 | def _calc_joint_acc_reward(self):
140 | # joint acceleration reward
141 | joint_acc_cost = np.sum(np.square(self._client.get_qacc()[-self._num_joints:]))
142 | return self.wp.joint_acc_weight*joint_acc_cost
143 |
144 | def _calc_ang_vel_reward(self):
145 | # angular vel reward
146 | ang_vel = self._client.get_qvel()[3:6]
147 | ang_vel_cost = np.square(np.linalg.norm(ang_vel))
148 | return self.wp.ang_vel_weight*ang_vel_cost
149 |
150 | def _calc_impact_reward(self):
151 | # contact reward
152 | ncon = len(self._client.get_rfoot_floor_contacts()) + \
153 | len(self._client.get_lfoot_floor_contacts())
154 | if ncon==0:
155 | return 0
156 | quad_impact_cost = np.sum(np.square(self._client.get_body_ext_force()))/ncon
157 | return self.wp.impact_weight*quad_impact_cost
158 |
159 | def _calc_zmp_reward(self):
160 | # zmp reward
161 | self.current_zmp = estimate_zmp(self)
162 | if np.linalg.norm(self.current_zmp - self._prev_zmp) > 1:
163 | self.current_zmp = self._prev_zmp
164 | zmp_cost = np.square(np.linalg.norm(self.current_zmp - self.desired_zmp))
165 | self._prev_zmp = self.current_zmp
166 | return self.wp.zmp_weight*zmp_cost
167 |
168 | def _calc_foot_contact_reward(self):
169 | right_contacts = self._client.get_rfoot_floor_collisions()
170 | left_contacts = self._client.get_lfoot_floor_collisions()
171 |
172 | radius_thresh = 0.3
173 | f_base = self._client.get_qpos()[0:2]
174 | c_dist_r = [(np.linalg.norm(c.pos[0:2] - f_base)) for _, c in right_contacts]
175 | c_dist_l = [(np.linalg.norm(c.pos[0:2] - f_base)) for _, c in left_contacts]
176 | d = sum([r for r in c_dist_r if r > radius_thresh] +
177 | [r for r in c_dist_l if r > radius_thresh])
178 | return self.wp.foot_contact_weight*d
179 |
180 | def _calc_gait_reward(self):
181 | if self._period<=0:
182 | raise Exception("Cycle period should be greater than zero.")
183 |
184 | # get foot-ground contact force
185 | rfoot_grf = self._client.get_rfoot_grf()
186 | lfoot_grf = self._client.get_lfoot_grf()
187 |
188 | # get foot speed
189 | rfoot_speed = self._client.get_rfoot_body_speed()
190 | lfoot_speed = self._client.get_lfoot_body_speed()
191 |
192 | # get foot position
193 | rfoot_pos = self._client.get_rfoot_body_pos()
194 | lfoot_pos = self._client.get_lfoot_body_pos()
195 | swing_height = 0.3
196 | stance_height = 0.1
197 |
198 | r = 0.5
199 | if self._phase < r:
200 | # right foot is in contact
201 | # left foot is swinging
202 | cost = (0.01*lfoot_grf)# \
203 | #+ np.square(lfoot_pos[2]-swing_height)
204 | #+ (10*np.square(rfoot_pos[2]-stance_height))
205 | else:
206 | # left foot is in contact
207 | # right foot is swinging
208 | cost = (0.01*rfoot_grf)
209 | #+ np.square(rfoot_pos[2]-swing_height)
210 | #+ (10*np.square(lfoot_pos[2]-stance_height))
211 | return self.wp.gait_weight*cost
212 |
213 | def _calc_reference(self):
214 | if self.ref_poses is None:
215 | raise Exception("Reference trajectory not provided.")
216 |
217 | # get reference pose
218 | phase = self._phase
219 | traj_length = self.traj_len
220 | indx = int(phase*(traj_length-1))
221 | reference_pose = self.ref_poses[indx,:]
222 |
223 | # get current pose
224 | current_pose = np.array(self._client.get_act_joint_positions())
225 |
226 | cost = np.square(np.linalg.norm(reference_pose-current_pose))
227 | return self.wp.ref_traj_weight*cost
228 |
229 | ##############################
230 | ##############################
231 | # Define utility functions
232 | ##############################
233 | ##############################
234 |
235 | def estimate_zmp(self):
236 | Gv = 9.80665
237 | Mg = self._mass * Gv
238 |
239 | com_pos = self._sim.data.subtree_com[1].copy()
240 | lin_mom = self._sim.data.subtree_linvel[1].copy()*self._mass
241 | ang_mom = self._sim.data.subtree_angmom[1].copy() + np.cross(com_pos, lin_mom)
242 |
243 | d_lin_mom = (lin_mom - self._prev_lin_mom)/self._control_dt
244 | d_ang_mom = (ang_mom - self._prev_ang_mom)/self._control_dt
245 |
246 | Fgz = d_lin_mom[2] + Mg
247 |
248 | # check contact with floor
249 | contacts = [self._sim.data.contact[i] for i in range(self._sim.data.ncon)]
250 | contact_flag = [(c.geom1==0 or c.geom2==0) for c in contacts]
251 |
252 | if (True in contact_flag) and Fgz > 20:
253 | zmp_x = (Mg*com_pos[0] - d_ang_mom[1])/Fgz
254 | zmp_y = (Mg*com_pos[1] + d_ang_mom[0])/Fgz
255 | else:
256 | zmp_x = com_pos[0]
257 | zmp_y = com_pos[1]
258 |
259 | self._prev_lin_mom = lin_mom
260 | self._prev_ang_mom = ang_mom
261 | return np.array([zmp_x, zmp_y])
262 |
263 | ##############################
264 | ##############################
265 | # Based on apex
266 | ##############################
267 | ##############################
268 |
269 | def create_phase_reward(swing_duration, stance_duration, strict_relaxer, stance_mode, FREQ=40):
270 |
271 | from scipy.interpolate import PchipInterpolator
272 |
273 | # NOTE: these times are being converted from time in seconds to phaselength
274 | right_swing = np.array([0.0, swing_duration]) * FREQ
275 | first_dblstance = np.array([swing_duration, swing_duration + stance_duration]) * FREQ
276 | left_swing = np.array([swing_duration + stance_duration, 2 * swing_duration + stance_duration]) * FREQ
277 | second_dblstance = np.array([2 * swing_duration + stance_duration, 2 * (swing_duration + stance_duration)]) * FREQ
278 |
279 | r_frc_phase_points = np.zeros((2, 8))
280 | r_vel_phase_points = np.zeros((2, 8))
281 | l_frc_phase_points = np.zeros((2, 8))
282 | l_vel_phase_points = np.zeros((2, 8))
283 |
284 | right_swing_relax_offset = (right_swing[1] - right_swing[0]) * strict_relaxer
285 | l_frc_phase_points[0,0] = r_frc_phase_points[0,0] = right_swing[0] + right_swing_relax_offset
286 | l_frc_phase_points[0,1] = r_frc_phase_points[0,1] = right_swing[1] - right_swing_relax_offset
287 | l_vel_phase_points[0,0] = r_vel_phase_points[0,0] = right_swing[0] + right_swing_relax_offset
288 | l_vel_phase_points[0,1] = r_vel_phase_points[0,1] = right_swing[1] - right_swing_relax_offset
289 | # During right swing we want foot velocities and don't want foot forces
290 | l_vel_phase_points[1,:2] = r_frc_phase_points[1,:2] = np.negative(np.ones(2)) # penalize l vel and r force
291 | l_frc_phase_points[1,:2] = r_vel_phase_points[1,:2] = np.ones(2) # incentivize l force and r vel
292 |
293 | dbl_stance_relax_offset = (first_dblstance[1] - first_dblstance[0]) * strict_relaxer
294 | l_frc_phase_points[0,2] = r_frc_phase_points[0,2] = first_dblstance[0] + dbl_stance_relax_offset
295 | l_frc_phase_points[0,3] = r_frc_phase_points[0,3] = first_dblstance[1] - dbl_stance_relax_offset
296 | l_vel_phase_points[0,2] = r_vel_phase_points[0,2] = first_dblstance[0] + dbl_stance_relax_offset
297 | l_vel_phase_points[0,3] = r_vel_phase_points[0,3] = first_dblstance[1] - dbl_stance_relax_offset
298 | if stance_mode == "aerial":
299 | # During aerial we want foot velocities and don't want foot forces
300 | # During grounded walking we want foot forces and don't want velocities
301 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.negative(np.ones(2)) # penalize l and r foot force
302 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.ones(2) # incentivize l and r foot velocity
303 | elif stance_mode == "zero":
304 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.zeros(2)
305 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.zeros(2)
306 | else:
307 | # During grounded walking we want foot forces and don't want velocities
308 | l_frc_phase_points[1,2:4] = r_frc_phase_points[1,2:4] = np.ones(2) # incentivize l and r foot force
309 | l_vel_phase_points[1,2:4] = r_vel_phase_points[1,2:4] = np.negative(np.ones(2)) # penalize l and r foot velocity
310 |
311 | left_swing_relax_offset = (left_swing[1] - left_swing[0]) * strict_relaxer
312 | l_frc_phase_points[0,4] = r_frc_phase_points[0,4] = left_swing[0] + left_swing_relax_offset
313 | l_frc_phase_points[0,5] = r_frc_phase_points[0,5] = left_swing[1] - left_swing_relax_offset
314 | l_vel_phase_points[0,4] = r_vel_phase_points[0,4] = left_swing[0] + left_swing_relax_offset
315 | l_vel_phase_points[0,5] = r_vel_phase_points[0,5] = left_swing[1] - left_swing_relax_offset
316 | # During left swing we want foot forces and don't want foot velocities (from perspective of right foot)
317 | l_vel_phase_points[1,4:6] = r_frc_phase_points[1,4:6] = np.ones(2) # incentivize l vel and r force
318 | l_frc_phase_points[1,4:6] = r_vel_phase_points[1,4:6] = np.negative(np.ones(2)) # penalize l force and r vel
319 |
320 | dbl_stance_relax_offset = (second_dblstance[1] - second_dblstance[0]) * strict_relaxer
321 | l_frc_phase_points[0,6] = r_frc_phase_points[0,6] = second_dblstance[0] + dbl_stance_relax_offset
322 | l_frc_phase_points[0,7] = r_frc_phase_points[0,7] = second_dblstance[1] - dbl_stance_relax_offset
323 | l_vel_phase_points[0,6] = r_vel_phase_points[0,6] = second_dblstance[0] + dbl_stance_relax_offset
324 | l_vel_phase_points[0,7] = r_vel_phase_points[0,7] = second_dblstance[1] - dbl_stance_relax_offset
325 | if stance_mode == "aerial":
326 | # During aerial we want foot velocities and don't want foot forces
327 | # During grounded walking we want foot forces and don't want velocities
328 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.negative(np.ones(2)) # penalize l and r foot force
329 | l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.ones(2) # incentivize l and r foot velocity
330 | elif stance_mode == "zero":
331 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.zeros(2)
332 | l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.zeros(2)
333 | else:
334 | # During grounded walking we want foot forces and don't want velocities
335 | l_frc_phase_points[1,6:] = r_frc_phase_points[1,6:] = np.ones(2) # incentivize l and r foot force
336 | l_vel_phase_points[1,6:] = r_vel_phase_points[1,6:] = np.negative(np.ones(2)) # penalize l and r foot velocity
337 |
338 | ## extend the data to three cycles : one before and one after : this ensures continuity
339 |
340 | r_frc_prev_cycle = np.copy(r_frc_phase_points)
341 | r_vel_prev_cycle = np.copy(r_vel_phase_points)
342 | l_frc_prev_cycle = np.copy(l_frc_phase_points)
343 | l_vel_prev_cycle = np.copy(l_vel_phase_points)
344 | l_frc_prev_cycle[0] = r_frc_prev_cycle[0] = r_frc_phase_points[0] - r_frc_phase_points[0, -1] - dbl_stance_relax_offset
345 | l_vel_prev_cycle[0] = r_vel_prev_cycle[0] = r_vel_phase_points[0] - r_vel_phase_points[0, -1] - dbl_stance_relax_offset
346 |
347 | r_frc_second_cycle = np.copy(r_frc_phase_points)
348 | r_vel_second_cycle = np.copy(r_vel_phase_points)
349 | l_frc_second_cycle = np.copy(l_frc_phase_points)
350 | l_vel_second_cycle = np.copy(l_vel_phase_points)
351 | l_frc_second_cycle[0] = r_frc_second_cycle[0] = r_frc_phase_points[0] + r_frc_phase_points[0, -1] + dbl_stance_relax_offset
352 | l_vel_second_cycle[0] = r_vel_second_cycle[0] = r_vel_phase_points[0] + r_vel_phase_points[0, -1] + dbl_stance_relax_offset
353 |
354 | r_frc_phase_points_repeated = np.hstack((r_frc_prev_cycle, r_frc_phase_points, r_frc_second_cycle))
355 | r_vel_phase_points_repeated = np.hstack((r_vel_prev_cycle, r_vel_phase_points, r_vel_second_cycle))
356 | l_frc_phase_points_repeated = np.hstack((l_frc_prev_cycle, l_frc_phase_points, l_frc_second_cycle))
357 | l_vel_phase_points_repeated = np.hstack((l_vel_prev_cycle, l_vel_phase_points, l_vel_second_cycle))
358 |
359 | ## Smooth the phase points with a monotone cubic (PCHIP) interpolant; values stay within the [-1, 1] limits
360 | r_frc_phase_spline = PchipInterpolator(r_frc_phase_points_repeated[0], r_frc_phase_points_repeated[1])
361 | r_vel_phase_spline = PchipInterpolator(r_vel_phase_points_repeated[0], r_vel_phase_points_repeated[1])
362 | l_frc_phase_spline = PchipInterpolator(l_frc_phase_points_repeated[0], l_frc_phase_points_repeated[1])
363 | l_vel_phase_spline = PchipInterpolator(l_vel_phase_points_repeated[0], l_vel_phase_points_repeated[1])
364 |
365 | return [r_frc_phase_spline, r_vel_phase_spline], [l_frc_phase_spline, l_vel_phase_spline]
366 |
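367 | # Usage sketch (illustrative; the 0.75 s swing / 0.35 s stance durations are made-up values):
368 | #   right_clock, left_clock = create_phase_reward(0.75, 0.35, 0.1, "grounded", FREQ=40)
369 | #   phase = 10                            # control step within the gait cycle
370 | #   frc_w = right_clock[0](phase)         # in [-1, 1]; +1 rewards right-foot ground force
371 | #   vel_w = right_clock[1](phase)         # in [-1, 1]; +1 rewards right-foot swing velocity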
--------------------------------------------------------------------------------
/tasks/stepping_task.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import transforms3d as tf3
4 | from tasks import rewards
5 | from enum import Enum, auto
6 |
7 | class WalkModes(Enum):
8 | STANDING = auto()
9 | CURVED = auto()
10 | FORWARD = auto()
11 | BACKWARD = auto()
12 | INPLACE = auto()
13 | LATERAL = auto()
14 |
15 | class SteppingTask(object):
16 | """Bipedal locomotion by stepping on targets."""
17 |
18 | def __init__(self,
19 | client=None,
20 | dt=0.025,
21 | neutral_foot_orient=[],
22 | root_body='pelvis',
23 | lfoot_body='lfoot',
24 | rfoot_body='rfoot',
25 | head_body='head',
26 | ):
27 |
28 | self._client = client
29 | self._control_dt = dt
30 |
31 | self._mass = self._client.get_robot_mass()
32 |
33 | self._goal_speed_ref = 0
34 | self._goal_height_ref = []
35 | self._swing_duration = []
36 | self._stance_duration = []
37 | self._total_duration = []
38 |
39 | self._head_body_name = head_body
40 | self._root_body_name = root_body
41 | self._lfoot_body_name = lfoot_body
42 | self._rfoot_body_name = rfoot_body
43 |
44 | # read previously generated footstep plans
45 | with open('utils/footstep_plans.txt', 'r') as fn:
46 | lines = [l.strip() for l in fn.readlines()]
47 | self.plans = []
48 | sequence = []
49 | for line in lines:
50 | if line=='---':
51 | if len(sequence):
52 | self.plans.append(sequence)
53 | sequence=[]
54 | continue
55 | else:
56 | sequence.append(np.array([float(l) for l in line.split(',')]))
57 |
58 | def step_reward(self):
59 | target_pos = self.sequence[self.t1][0:3]
60 | foot_dist_to_target = min([np.linalg.norm(ft-target_pos) for ft in [self.l_foot_pos,
61 | self.r_foot_pos]])
62 | hit_reward = 0
63 | if self.target_reached:
64 | hit_reward = np.exp(-foot_dist_to_target/0.25)
65 |
66 | target_mp = (self.sequence[self.t1][0:2] + self.sequence[self.t2][0:2])/2
67 | root_xy_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[0:2]
68 | root_dist_to_target = np.linalg.norm(root_xy_pos-target_mp)
69 | progress_reward = np.exp(-root_dist_to_target/2)
70 | return (0.8*hit_reward + 0.2*progress_reward)
71 |
72 | def calc_reward(self, prev_torque, prev_action, action):
73 | orient = tf3.euler.euler2quat(0, 0, self.sequence[self.t1][3])
74 | r_frc = self.right_clock[0]
75 | l_frc = self.left_clock[0]
76 | r_vel = self.right_clock[1]
77 | l_vel = self.left_clock[1]
78 | if self.mode == WalkModes.STANDING:
79 | r_frc = (lambda _:1)
80 | l_frc = (lambda _:1)
81 | r_vel = (lambda _:-1)
82 | l_vel = (lambda _:-1)
83 | head_pos = self._client.get_object_xpos_by_name(self._head_body_name, 'OBJ_BODY')[0:2]
84 | root_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')[0:2]
85 | reward = dict(foot_frc_score=0.150 * rewards._calc_foot_frc_clock_reward(self, l_frc, r_frc),
86 | foot_vel_score=0.150 * rewards._calc_foot_vel_clock_reward(self, l_vel, r_vel),
87 | orient_cost=0.050 * rewards._calc_body_orient_reward(self,
88 | self._root_body_name,
89 | quat_ref=orient),
90 | height_error=0.050 * rewards._calc_height_reward(self),
91 | #torque_penalty=0.050 * rewards._calc_torque_reward(self, prev_torque),
92 | #action_penalty=0.050 * rewards._calc_action_reward(self, prev_action),
93 | step_reward=0.450 * self.step_reward(),
94 | upper_body_reward=0.050 * np.exp(-10*np.square(np.linalg.norm(head_pos-root_pos)))
95 | )
96 | return reward
97 |
98 | def transform_sequence(self, sequence):
99 | lfoot_pos = self._client.get_lfoot_body_pos()
100 | rfoot_pos = self._client.get_rfoot_body_pos()
101 | root_yaw = tf3.euler.quat2euler(self._client.get_object_xquat_by_name(self._root_body_name, 'OBJ_BODY'))[2]
102 | mid_pt = (lfoot_pos + rfoot_pos)/2
103 | sequence_rel = []
104 | for x, y, z, theta in sequence:
105 | x_ = mid_pt[0] + x*np.cos(root_yaw) - y*np.sin(root_yaw)
106 | y_ = mid_pt[1] + x*np.sin(root_yaw) + y*np.cos(root_yaw)
107 | theta_ = root_yaw + theta
108 | step = np.array([x_, y_, z, theta_])
109 | sequence_rel.append(step)
110 | return sequence_rel
111 |
112 | def generate_step_sequence(self, **kwargs):
113 | step_size, step_gap, step_height, num_steps, curved, lateral = (kwargs[k] for k in ('step_size', 'step_gap', 'step_height', 'num_steps', 'curved', 'lateral'))
114 | if curved:
115 | # set 0 height for curved sequences
116 | plan = random.choice(self.plans)
117 | sequence = [[s[0], s[1], 0, s[2]] for s in plan]
118 | return np.array(sequence)
119 |
120 | if lateral:
121 | sequence = []
122 | y = 0
123 | c = np.random.choice([-1, 1])
124 | for i in range(1, num_steps):
125 | if i%2:
126 | y += step_size
127 | else:
128 | y -= (2/3)*step_size
129 | step = np.array([0, c*y, 0, 0])
130 | sequence.append(step)
131 | return sequence
132 |
133 | sequence = []
134 | if self._phase==(0.5*self._period):
135 | first_step = np.array([0, -1*np.random.uniform(0.095, 0.105), 0, 0])
136 | y = -step_gap
137 | else:
138 | first_step = np.array([0, 1*np.random.uniform(0.095, 0.105), 0, 0])
139 | y = step_gap
140 | sequence.append(first_step)
141 | x, z = 0, 0
142 | c = np.random.randint(2, 4)
143 | for i in range(1, num_steps-1):
144 | x += step_size
145 | y *= -1
146 | if i > c: # keep the height of the first few steps at 0
147 | z += step_height
148 | step = np.array([x, y, z, 0])
149 | sequence.append(step)
150 | final_step = np.array([x+step_size, -y, z, 0])
151 | sequence.append(final_step)
152 | return sequence
153 |
154 | def update_goal_steps(self):
155 | self._goal_steps_x[:] = np.zeros(2)
156 | self._goal_steps_y[:] = np.zeros(2)
157 | self._goal_steps_z[:] = np.zeros(2)
158 | self._goal_steps_theta[:] = np.zeros(2)
159 | root_pos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')
160 | root_quat = self._client.get_object_xquat_by_name(self._root_body_name, 'OBJ_BODY')
161 | for idx, t in enumerate([self.t1, self.t2]):
162 | ref_frame = tf3.affines.compose(root_pos, tf3.quaternions.quat2mat(root_quat), np.ones(3))
163 | abs_goal_pos = self.sequence[t][0:3]
164 | abs_goal_rot = tf3.euler.euler2mat(0, 0, self.sequence[t][3])
165 | absolute_target = tf3.affines.compose(abs_goal_pos, abs_goal_rot, np.ones(3))
166 | relative_target = np.linalg.inv(ref_frame).dot(absolute_target)
167 | if self.mode != WalkModes.STANDING:
168 | self._goal_steps_x[idx] = relative_target[0, 3]
169 | self._goal_steps_y[idx] = relative_target[1, 3]
170 | self._goal_steps_z[idx] = relative_target[2, 3]
171 | self._goal_steps_theta[idx] = tf3.euler.mat2euler(relative_target[:3, :3])[2]
172 | return
173 |
174 | def update_target_steps(self):
175 | assert len(self.sequence)>0
176 | self.t1 = self.t2
177 | self.t2+=1
178 | if self.t2==len(self.sequence):
179 | self.t2 = len(self.sequence)-1
180 | return
181 |
182 | def step(self):
183 | # increment phase
184 | self._phase+=1
185 | if self._phase>=self._period:
186 | self._phase=0
187 |
188 | self.l_foot_quat = self._client.get_object_xquat_by_name('lf_force', 'OBJ_SITE')
189 | self.r_foot_quat = self._client.get_object_xquat_by_name('rf_force', 'OBJ_SITE')
190 | self.l_foot_pos = self._client.get_object_xpos_by_name('lf_force', 'OBJ_SITE')
191 | self.r_foot_pos = self._client.get_object_xpos_by_name('rf_force', 'OBJ_SITE')
192 | self.l_foot_vel = self._client.get_lfoot_body_vel()[0]
193 | self.r_foot_vel = self._client.get_rfoot_body_vel()[0]
194 | self.l_foot_frc = self._client.get_lfoot_grf()
195 | self.r_foot_frc = self._client.get_rfoot_grf()
196 |
197 | # check if target reached
198 | target_pos = self.sequence[self.t1][0:3]
199 | foot_dist_to_target = min([np.linalg.norm(ft-target_pos) for ft in [self.l_foot_pos,
200 | self.r_foot_pos]])
201 |
202 |
203 | lfoot_in_target = (np.linalg.norm(self.l_foot_pos-target_pos) < self.target_radius)
204 | rfoot_in_target = (np.linalg.norm(self.r_foot_pos-target_pos) < self.target_radius)
205 | if lfoot_in_target or rfoot_in_target:
206 | self.target_reached = True
207 | self.target_reached_frames+=1
208 | else:
209 | self.target_reached = False
210 | self.target_reached_frames=0
211 |
212 | # update target steps if needed
213 | if self.target_reached and (self.target_reached_frames>=self.delay_frames):
214 | self.update_target_steps()
215 | self.target_reached = False
216 | self.target_reached_frames = 0
217 |
218 | # update goal
219 | self.update_goal_steps()
220 | return
221 |
222 | def substep(self):
223 | pass
224 |
225 | def done(self):
226 | contact_flag = self._client.check_self_collisions()
227 |
228 | qpos = self._client.get_object_xpos_by_name(self._root_body_name, 'OBJ_BODY')
229 | foot_pos = min([c[2] for c in (self.l_foot_pos, self.r_foot_pos)])
230 | root_rel_height = qpos[2] - foot_pos
231 | terminate_conditions = {"qpos[2]_ll":(root_rel_height < 0.6),
232 | "contact_flag":contact_flag,
233 | }
234 |
235 | done = True in terminate_conditions.values()
236 | return done
237 |
238 | def reset(self, iter_count=0):
239 | # training iteration
240 | self.iteration_count = iter_count
241 |
242 | # for steps
243 | self._goal_steps_x = [0, 0]
244 | self._goal_steps_y = [0, 0]
245 | self._goal_steps_z = [0, 0]
246 | self._goal_steps_theta = [0, 0]
247 |
248 | self.target_radius = 0.20
249 | self.delay_frames = int(np.floor(self._swing_duration/self._control_dt))
250 | self.target_reached = False
251 | self.target_reached_frames = 0
252 | self.t1 = 0
253 | self.t2 = 0
254 |
255 | self.right_clock, self.left_clock = rewards.create_phase_reward(self._swing_duration,
256 | self._stance_duration,
257 | 0.1,
258 | "grounded",
259 | 1/self._control_dt)
260 |
261 | # number of control steps in one full cycle
262 | # (one full cycle includes left swing + right swing)
263 | self._period = np.floor(2*self._total_duration*(1/self._control_dt))
264 | # randomize phase during initialization
265 | self._phase = int(np.random.choice([0, self._period/2]))
266 |
267 | ## GENERATE STEP SEQUENCE
268 | # select a walking 'mode'
269 | self.mode = np.random.choice(
270 | [WalkModes.CURVED, WalkModes.STANDING, WalkModes.BACKWARD, WalkModes.LATERAL, WalkModes.FORWARD],
271 | p=[0.15, 0.05, 0.2, 0.3, 0.3])
272 |
273 | d = {'step_size':0.3, 'step_gap':0.15, 'step_height':0, 'num_steps':20, 'curved':False, 'lateral':False}
274 | # generate sequence according to mode
275 | if self.mode == WalkModes.CURVED:
276 | d['curved'] = True
277 | elif self.mode == WalkModes.STANDING:
278 | d['num_steps'] = 1
279 | elif self.mode == WalkModes.BACKWARD:
280 | d['step_size'] = -0.1
281 | elif self.mode == WalkModes.INPLACE:
282 | ss = np.random.uniform(-0.05, 0.05)
283 | d['step_size']=ss
284 | elif self.mode == WalkModes.LATERAL:
285 | d['step_size'] = 0.4
286 | d['lateral'] = True
287 | elif self.mode == WalkModes.FORWARD:
288 | h = np.clip((self.iteration_count-3000)/8000, 0, 1)*0.1
289 | d['step_height']=np.random.choice([-h, h])
290 | else:
291 | raise Exception("Invalid WalkModes")
292 | sequence = self.generate_step_sequence(**d)
293 | self.sequence = self.transform_sequence(sequence)
294 | self.update_target_steps()
295 |
296 | ## CREATE TERRAIN USING GEOMS
297 | nboxes = 20
298 | boxes = ["box"+repr(i+1).zfill(2) for i in range(nboxes)]
299 | sequence = [np.array([0, 0, -1, 0]) for i in range(nboxes)]
300 | sequence[:len(self.sequence)] = self.sequence
301 | for box, step in zip(boxes, sequence):
302 | box_h = self._client.model.geom(box).size[2]
303 | self._client.model.body(box).pos[:] = step[0:3] - np.array([0, 0, box_h])
304 | self._client.model.body(box).quat[:] = tf3.euler.euler2quat(0, 0, step[3])
305 | self._client.model.geom(box).size[:] = np.array([0.15, 1, box_h])
306 | self._client.model.geom(box).rgba[:] = np.array([0.8, 0.8, 0.8, 1])
307 |
308 | self._client.model.body('floor').pos[:] = np.array([0, 0, 0])
309 | if self.mode == WalkModes.FORWARD:
310 | self._client.model.body('floor').pos[:] = np.array([0, 0, -2])
311 |
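312 | # Footstep plan file format (as parsed in __init__ above; the numbers are illustrative):
313 | #   each line is "x,y,theta" (metres, metres, radians), and plans are separated by '---':
314 | #     0.0,0.1,0.0
315 | #     0.3,-0.1,0.0
316 | #     ---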
--------------------------------------------------------------------------------
/tasks/walking_task.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import transforms3d as tf3
3 | from tasks import rewards
4 |
5 | class WalkingTask(object):
6 | """Dynamically stable walking on biped."""
7 |
8 | def __init__(self,
9 | client=None,
10 | dt=0.025,
11 | neutral_foot_orient=[],
12 | root_body='pelvis',
13 | lfoot_body='lfoot',
14 | rfoot_body='rfoot',
15 | waist_r_joint='waist_r',
16 | waist_p_joint='waist_p',
17 | ):
18 |
19 | self._client = client
20 | self._control_dt = dt
21 | self._neutral_foot_orient=neutral_foot_orient
22 |
23 | self._mass = self._client.get_robot_mass()
24 |
25 | # These depend on the robot, hardcoded for now
26 | # Ideally, they should be arguments to __init__
27 | self._goal_speed_ref = []
28 | self._goal_height_ref = []
29 | self._swing_duration = []
30 | self._stance_duration = []
31 | self._total_duration = []
32 |
33 | self._root_body_name = root_body
34 | self._lfoot_body_name = lfoot_body
35 | self._rfoot_body_name = rfoot_body
36 |
37 | def calc_reward(self, prev_torque, prev_action, action):
38 | self.l_foot_vel = self._client.get_lfoot_body_vel()[0]
39 | self.r_foot_vel = self._client.get_rfoot_body_vel()[0]
40 | self.l_foot_frc = self._client.get_lfoot_grf()
41 | self.r_foot_frc = self._client.get_rfoot_grf()
42 | r_frc = self.right_clock[0]
43 | l_frc = self.left_clock[0]
44 | r_vel = self.right_clock[1]
45 | l_vel = self.left_clock[1]
46 | reward = dict(foot_frc_score=0.150 * rewards._calc_foot_frc_clock_reward(self, l_frc, r_frc),
47 | foot_vel_score=0.150 * rewards._calc_foot_vel_clock_reward(self, l_vel, r_vel),
48 | orient_cost=0.050 * (rewards._calc_body_orient_reward(self, self._lfoot_body_name) +
49 | rewards._calc_body_orient_reward(self, self._rfoot_body_name) +
50 | rewards._calc_body_orient_reward(self, self._root_body_name))/3,
51 | root_accel=0.050 * rewards._calc_root_accel_reward(self),
52 | height_error=0.050 * rewards._calc_height_reward(self),
53 | com_vel_error=0.200 * rewards._calc_fwd_vel_reward(self),
54 | torque_penalty=0.050 * rewards._calc_torque_reward(self, prev_torque),
55 | action_penalty=0.050 * rewards._calc_action_reward(self, action, prev_action),
56 | )
57 | return reward
58 |
59 | def step(self):
60 | if self._phase>self._period:
61 | self._phase=0
62 | self._phase+=1
63 | return
64 |
65 | def done(self):
66 | contact_flag = self._client.check_self_collisions()
67 | qpos = self._client.get_qpos()
68 | terminate_conditions = {"qpos[2]_ll":(qpos[2] < 0.6),
69 | "qpos[2]_ul":(qpos[2] > 1.4),
70 | "contact_flag":contact_flag,
71 | }
72 |
73 | done = True in terminate_conditions.values()
74 | return done
75 |
76 | def reset(self, iter_count=0):
77 | self._goal_speed_ref = np.random.choice([0, np.random.uniform(0.3, 0.4)])
78 | self.right_clock, self.left_clock = rewards.create_phase_reward(self._swing_duration,
79 | self._stance_duration,
80 | 0.1,
81 | "grounded",
82 | 1/self._control_dt)
83 |
84 | # number of control steps in one full cycle
85 | # (one full cycle includes left swing + right swing)
86 | self._period = np.floor(2*self._total_duration*(1/self._control_dt))
87 | # randomize phase during initialization
88 | self._phase = np.random.randint(0, self._period)
89 |
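90 | # Example (illustrative): assuming _total_duration = swing + stance = 0.75 s + 0.35 s
91 | # and control_dt = 0.025 s, one full cycle spans
92 | #   period = floor(2 * 1.10 / 0.025) = 88 control steps.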
--------------------------------------------------------------------------------