├── .gitignore ├── README.md ├── atacom ├── __init__.py ├── atacom.py ├── constraints.py ├── environments │ ├── __init__.py │ ├── circular_motion │ │ ├── __init__.py │ │ ├── circle_atacom.py │ │ ├── circle_base.py │ │ ├── circle_error_correction.py │ │ └── circle_terminated.py │ ├── collision_avoidance │ │ ├── __init__.py │ │ ├── collision_avoidance_atacom.py │ │ └── collision_avoidance_base.py │ ├── iiwa_air_hockey │ │ ├── __init__.py │ │ ├── env_base.py │ │ ├── env_hitting.py │ │ ├── env_single.py │ │ ├── iiwa_air_hockey_rmp.py │ │ ├── iiwa_hit_atacom.py │ │ ├── kinematics.py │ │ └── urdf │ │ │ ├── iiwa_1.urdf │ │ │ ├── iiwa_2.urdf │ │ │ └── meshes │ │ │ ├── collision │ │ │ ├── link_0.stl │ │ │ ├── link_1.stl │ │ │ ├── link_2.stl │ │ │ ├── link_3.stl │ │ │ ├── link_4.stl │ │ │ ├── link_5.stl │ │ │ ├── link_6.stl │ │ │ ├── link_7.stl │ │ │ └── link_7_old.stl │ │ │ ├── striker │ │ │ ├── collision │ │ │ │ ├── EE_arm_collision.stl │ │ │ │ ├── EE_mallet_collision.stl │ │ │ │ └── EE_mallet_short_collision.stl │ │ │ └── visual │ │ │ │ ├── EE_arm.stl │ │ │ │ ├── EE_mallet.stl │ │ │ │ └── EE_mallet_short.stl │ │ │ └── visual │ │ │ ├── link_0.stl │ │ │ ├── link_1.stl │ │ │ ├── link_2.stl │ │ │ ├── link_3.stl │ │ │ ├── link_4.stl │ │ │ ├── link_5.stl │ │ │ ├── link_6.stl │ │ │ ├── link_7.stl │ │ │ └── link_7_old.stl │ └── planar_air_hockey │ │ ├── __init__.py │ │ ├── atacom_air_hockey.py │ │ └── unconstrained_air_hockey.py ├── error_correction_wrapper.py └── utils │ ├── __init__.py │ ├── null_space_coordinate.py │ └── plot_utils.py ├── examples ├── __init__.py ├── circle_exp.py ├── collision_avoidance_exp.py ├── iiwa_air_hockey_exp.py ├── network.py └── planar_air_hockey_exp.py ├── fig └── manifold.gif ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | *.egg-info/ 4 | logs/ 5 | result/ 6 | *runs/ 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Acting on the Tangent Space of the Constraint Manifold 2 |

3 | <p align="center"><img src="fig/manifold.gif" alt="Acting on the tangent space of the constraint manifold"></p> 4 |

5 | Implementation of "Robot Reinforcement Learning on the Constraint Manifold" 6 | 7 | [[paper]](https://www.ias.informatik.tu-darmstadt.de/uploads/Team/PuzeLiu/CORL_2021_Learning_on_the_Manifold.pdf) 8 | [[website]](https://sites.google.com/view/robot-air-hockey/atacom) 9 | 10 | ## Install 11 | ```bash 12 | pip install -e . 13 | ``` 14 | 15 | ## Run Examples 16 | ```bash 17 | cd examples 18 | ``` 19 | ### CircularMotion Environment 20 | Environment options [A, E, T], algorithm options [TRPO, PPO, SAC, DDPG, TD3] 21 | ```bash 22 | python circle_exp.py --render --env A --alg TRPO 23 | ``` 24 | 25 | ### PlanarAirHockey Environment 26 | Environment options [H, D, UH, UD], algorithm options [TRPO, PPO, SAC, DDPG, TD3] 27 | ```bash 28 | python planar_air_hockey_exp.py --debug-gui --env H --alg SAC 29 | ``` 30 | 31 | ### IiwaAirHockey Environment 32 | Environment options [7H, RMP], algorithm options [TRPO, PPO, SAC, DDPG, TD3] 33 | ```bash 34 | python iiwa_air_hockey_exp.py --debug-gui --env 7H --alg SAC 35 | ``` 36 | 37 | ### CollisionAvoidance Environment 38 | Environment options [C], algorithm options [TRPO, PPO, SAC, DDPG, TD3] 39 | ```bash 40 | python collision_avoidance_exp.py --render --env C --alg SAC 41 | ``` 42 | 43 | 44 | ## BibTeX 45 | ```bibtex 46 | @inproceedings{CORL_2021_Learning_on_the_Manifold, 47 | author = "Liu, P. and Tateo, D. and Bou-Ammar, H. and Peters, J.", 48 | year = "2021", 49 | title = "Robot Reinforcement Learning on the Constraint Manifold", 50 | booktitle = "Proceedings of the Conference on Robot Learning (CoRL)", 51 | key = "robot learning, constrained reinforcement learning, safe exploration", 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /atacom/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.0' 2 | -------------------------------------------------------------------------------- /atacom/atacom.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.utils.spaces import * 2 | from atacom.utils.null_space_coordinate import rref, gram_schmidt, pinv_null 3 | 4 | 5 | class AtacomEnvWrapper: 6 | """ 7 | Environment wrapper of ATACOM 8 | """ 9 | 10 | def __init__(self, base_env, dim_q, vel_max, acc_max, f=None, g=None, Kc=1., Kq=10., time_step=0.01): 11 | """ 12 | Constructor 13 | Args: 14 | base_env (mushroom_rl.core.Environment): the base environment to be wrapped 15 | dim_q (int): dimension of the directly controllable variable 16 | vel_max (array, float): the maximum velocity of the directly controllable variable 17 | acc_max (array, float): the maximum acceleration of the directly controllable variable 18 | f (ViabilityConstraint, ConstraintsSet): the equality constraint f(q) = 0 19 | g (ViabilityConstraint, ConstraintsSet): the inequality constraint g(q) <= 0 20 | Kc (array, float): the scaling factor for error correction 21 | Kq (array, float): the scaling factor for the velocity bound correction 22 | time_step (float): the step size for time discretization 23 | """ 24 | self.env = base_env 25 | self.dims = {'q': dim_q, 'f': 0, 'g': 0} 26 | self.f = f 27 | self.g = g 28 | self.time_step = time_step 29 | self._logger = None 30 | 31 | if self.f is not None: 32 | assert self.dims['q'] == self.f.dim_q, "Input dimension is different in f" 33 | self.dims['f'] = self.f.dim_out 34 | if self.g is not None: 35 | assert self.dims['q'] == self.g.dim_q, "Input dimension is 
different in g" 36 | self.dims['g'] = self.g.dim_out 37 | self.s = np.zeros(self.dims['g']) 38 | 39 | self.dims['null'] = self.dims['q'] - self.dims['f'] 40 | self.dims['c'] = self.dims['f'] + self.dims['g'] 41 | 42 | if np.isscalar(Kc): 43 | self.K_c = np.ones(self.dims['c']) * Kc 44 | else: 45 | self.K_c = Kc 46 | 47 | self.q = np.zeros(self.dims['q']) 48 | self.dq = np.zeros(self.dims['q']) 49 | 50 | self._mdp_info = self.env.info.copy() 51 | self._mdp_info.action_space = Box(low=-np.ones(self.dims['null']), high=np.ones(self.dims['null'])) 52 | 53 | if np.isscalar(vel_max): 54 | self.vel_max = np.ones(self.dims['q']) * vel_max 55 | else: 56 | self.vel_max = vel_max 57 | assert np.shape(self.vel_max)[0] == self.dims['q'] 58 | 59 | if np.isscalar(acc_max): 60 | self.acc_max = np.ones(self.dims['q']) * acc_max 61 | else: 62 | self.acc_max = acc_max 63 | assert np.shape(self.acc_max)[0] == self.dims['q'] 64 | 65 | if np.isscalar(Kq): 66 | self.K_q = np.ones(self.dims['q']) * Kq 67 | else: 68 | self.K_q = Kq 69 | assert np.shape(self.K_q)[0] == self.dims['q'] 70 | 71 | self.alpha_max = np.ones(self.dims['null']) * self.acc_max.max() 72 | 73 | self.state = self.env.reset() 74 | self._act_a = None 75 | self._act_b = None 76 | self._act_err = None 77 | 78 | self.constr_logs = list() 79 | self.env.step_action_function = self.step_action_function 80 | 81 | def _get_q(self, state): 82 | raise NotImplementedError 83 | 84 | def _get_dq(self, state): 85 | raise NotImplementedError 86 | 87 | def acc_to_ctrl_action(self, ddq): 88 | raise NotImplementedError 89 | 90 | def seed(self, seed): 91 | self.env.seed(seed) 92 | 93 | def reset(self, state=None): 94 | self.state = self.env.reset(state) 95 | self.q = self._get_q(self.state) 96 | self.dq = self._get_dq(self.state) 97 | self._compute_slack_variables() 98 | return self.state 99 | 100 | def render(self): 101 | self.env.render() 102 | 103 | def stop(self): 104 | self.env.stop() 105 | 106 | def step(self, action): 107 | alpha = np.clip(action, self.info.action_space.low, self.info.action_space.high) 108 | alpha = alpha * self.alpha_max 109 | 110 | self.state, reward, absorb, info = self.env.step(alpha) 111 | self.q = self._get_q(self.state) 112 | self.dq = self._get_dq(self.state) 113 | if not hasattr(self.env, "get_constraints_logs"): 114 | self._update_constraint_stats(self.q, self.dq) 115 | return self.state.copy(), reward, absorb, info 116 | 117 | def acc_truncation(self, dq, ddq): 118 | acc_u = np.maximum(np.minimum(self.acc_max, -self.K_q * (dq - self.vel_max)), -self.acc_max) 119 | acc_l = np.minimum(np.maximum(-self.acc_max, -self.K_q * (dq + self.vel_max)), self.acc_max) 120 | ddq = np.clip(ddq, acc_l, acc_u) 121 | return ddq 122 | 123 | def step_action_function(self, sim_state, alpha): 124 | self.state = self.env._create_observation(sim_state) 125 | 126 | Jc, psi = self._construct_Jc_psi(self.q, self.s, self.dq) 127 | Jc_inv, Nc = pinv_null(Jc) 128 | Nc = rref(Nc[:, :self.dims['null']], row_vectors=False, tol=0.05) 129 | 130 | self._act_a = -Jc_inv @ psi 131 | self._act_b = Nc @ alpha 132 | self._act_err = self._compute_error_correction(self.q, self.dq, self.s, Jc_inv) 133 | ddq_ds = self._act_a + self._act_b + self._act_err 134 | 135 | self.s += ddq_ds[self.dims['q']:(self.dims['q'] + self.dims['g'])] * self.time_step 136 | 137 | ddq = self.acc_truncation(self.dq, ddq_ds[:self.dims['q']]) 138 | ctrl_action = self.acc_to_ctrl_action(ddq) 139 | return ctrl_action 140 | 141 | @property 142 | def info(self): 143 | return self._mdp_info 144 | 145 
| def _compute_slack_variables(self): 146 | self.s = None 147 | if self.dims['g'] > 0: 148 | s_2 = np.maximum(-2 * self.g.fun(self.q, self.dq, origin_constr=False), 0) 149 | self.s = np.sqrt(s_2) 150 | 151 | def _construct_Jc_psi(self, q, s, dq): 152 | Jc = np.zeros((self.dims['f'] + self.dims['g'], self.dims['q'] + self.dims['g'])) 153 | psi = np.zeros(self.dims['c']) 154 | if self.dims['f'] > 0: 155 | idx_0 = 0 156 | idx_1 = self.dims['f'] 157 | Jc[idx_0:idx_1, :self.dims['q']] = self.f.K_J(q) 158 | psi[idx_0:idx_1] = self.f.b(q, dq) 159 | if self.dims['g'] > 0: 160 | idx_0 = self.dims['f'] 161 | idx_1 = self.dims['f'] + self.dims['g'] 162 | Jc[idx_0:idx_1, :self.dims['q']] = self.g.K_J(q) 163 | Jc[idx_0:idx_1, self.dims['q']:(self.dims['q'] + self.dims['g'])] = np.diag(s) 164 | psi[idx_0:idx_1] = self.g.b(q, dq) 165 | return Jc, psi 166 | 167 | def _compute_error_correction(self, q, dq, s, Jc_inv, act_null=None): 168 | q_tmp = q.copy() 169 | dq_tmp = dq.copy() 170 | s_tmp = None 171 | 172 | if self.dims['g'] > 0: 173 | s_tmp = s.copy() 174 | 175 | if act_null is not None: 176 | q_tmp += dq_tmp * self.time_step + act_null[:self.dims['q']] * self.time_step ** 2 / 2 177 | dq_tmp += act_null[:self.dims['q']] * self.time_step 178 | if self.dims['g'] > 0: 179 | s_tmp += act_null[self.dims['q']:self.dims['q'] + self.dims['g']] * self.time_step 180 | 181 | return -Jc_inv @ (self.K_c * self._compute_c(q_tmp, dq_tmp, s_tmp, origin_constr=False)) 182 | 183 | def _compute_c(self, q, dq, s, origin_constr=False): 184 | c = np.zeros(self.dims['f'] + self.dims['g']) 185 | if self.dims['f'] > 0: 186 | idx_0 = 0 187 | idx_1 = self.dims['f'] 188 | c[idx_0:idx_1] = self.f.fun(q, dq, origin_constr) 189 | if self.dims['g'] > 0: 190 | idx_0 = self.dims['f'] 191 | idx_1 = self.dims['f'] + self.dims['g'] 192 | if origin_constr: 193 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) 194 | else: 195 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) + 1 / 2 * s ** 2 196 | return c 197 | 198 | def set_logger(self, logger): 199 | self._logger = logger 200 | 201 | def _update_constraint_stats(self, q, dq): 202 | c_i = self._compute_c(q, dq, 0., origin_constr=True) 203 | c_i[:self.dims['f']] = np.abs(c_i[:self.dims['f']]) 204 | c_dq_i = np.abs(dq) - self.vel_max 205 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)]) 206 | 207 | def get_constraints_logs(self): 208 | if not hasattr(self.env, "get_constraints_logs"): 209 | constr_logs = np.array(self.constr_logs) 210 | c_avg = np.mean(constr_logs[:, 0]) 211 | c_max = np.max(constr_logs[:, 0]) 212 | c_dq_max = np.max(constr_logs[:, 1]) 213 | self.constr_logs.clear() 214 | return c_avg, c_max, c_dq_max 215 | else: 216 | return self.env.get_constraints_logs() 217 | -------------------------------------------------------------------------------- /atacom/constraints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ViabilityConstraint: 5 | """ 6 | Class of viability constraint 7 | f(q) + K_f df(q, dq) = 0 8 | g(q) + K_g dg(q, dq) <= 0 9 | """ 10 | 11 | def __init__(self, dim_q, dim_out, fun, J, b, K): 12 | """ 13 | Constructor of the viability constraint 14 | 15 | Args 16 | dim_q (int): Dimension of the controllable variable 17 | dim_out (int): Dimension of the constraint 18 | fun (function): The constraint function f(q) or g(q) 19 | J (function): The Jacobian matrix of J_f(q) or J_g(q) 20 | b (function): The term: dJ(q, dq) dq 21 | K (scalar or array): The scale variable K_f or K_g 22 | """ 
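# Example (annotation, not part of the original file): the unit-circle
# equality constraint used by CircleEnvAtacom later in this repository
# would be constructed as
#     ViabilityConstraint(dim_q=2, dim_out=1,
#                         fun=lambda q: np.array([q[0] ** 2 + q[1] ** 2 - 1]),
#                         J=lambda q: np.array([[2 * q[0], 2 * q[1]]]),
#                         b=lambda q, dq: np.array([[2 * dq[0], 2 * dq[1]]]) @ dq,
#                         K=0.1)
# so that fun(q, dq) evaluates f(q) + K_f * J_f(q) @ dq.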
23 | self.dim_q = dim_q 24 | self.dim_out = dim_out 25 | self.fun_origin = fun 26 | if np.isscalar(K): 27 | self.K = np.ones(dim_out) * K 28 | else: 29 | self.K = K 30 | self.J = J 31 | self.b_state = b 32 | 33 | def fun(self, q, dq, origin_constr=False): 34 | if origin_constr: 35 | return self.fun_origin(q) 36 | else: 37 | return self.fun_origin(q) + self.K * (self.J(q) @ dq) 38 | 39 | def K_J(self, q): 40 | return np.diag(self.K) @ self.J(q) 41 | 42 | def b(self, q, dq): 43 | return self.J(q) @ dq + self.K * self.b_state(q, dq) 44 | 45 | 46 | class ConstraintsSet: 47 | """ 48 | The class to gather multiple constraints 49 | """ 50 | 51 | def __init__(self, dim_q): 52 | self.dim_q = dim_q 53 | self.constraints_list = list() 54 | self.dim_out = 0 55 | 56 | def add_constraint(self, c: ViabilityConstraint): 57 | self.dim_out += c.dim_out 58 | self.constraints_list.append(c) 59 | 60 | def fun(self, q, dq, origin_constr=False): 61 | ret = np.zeros(self.dim_out) 62 | i = 0 63 | for c in self.constraints_list: 64 | ret[i:i + c.dim_out] = c.fun(q, dq, origin_constr) 65 | i += c.dim_out 66 | return ret 67 | 68 | def K_J(self, q): 69 | ret = np.zeros((self.dim_out, self.dim_q)) 70 | i = 0 71 | for c in self.constraints_list: 72 | ret[i:i + c.dim_out] = c.K_J(q) 73 | i += c.dim_out 74 | return ret 75 | 76 | def b(self, q, dq): 77 | ret = np.zeros(self.dim_out) 78 | i = 0 79 | for c in self.constraints_list: 80 | ret[i:i + c.dim_out] = c.b(q, dq) 81 | i += c.dim_out 82 | return ret 83 | -------------------------------------------------------------------------------- /atacom/environments/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /atacom/environments/circular_motion/__init__.py: -------------------------------------------------------------------------------- 1 | from .circle_error_correction import CircleEnvErrorCorrection 2 | from .circle_terminated import CircleEnvTerminated 3 | from .circle_atacom import CircleEnvAtacom 4 | -------------------------------------------------------------------------------- /atacom/environments/circular_motion/circle_atacom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from atacom.environments.circular_motion.circle_base import CircularMotion 3 | from atacom.atacom import AtacomEnvWrapper 4 | from atacom.constraints import ViabilityConstraint, ConstraintsSet 5 | 6 | class CircleEnvAtacom(AtacomEnvWrapper): 7 | def __init__(self, horizon=500, gamma=0.99, random_init=False, Kc=100, time_step=0.01): 8 | base_env = CircularMotion(random_init=random_init, horizon=horizon, gamma=gamma) 9 | circle_constr = ViabilityConstraint(2, 1, fun=self.circle_fun, J=self.circle_J, b=self.circle_b, K=0.1) 10 | 11 | height_constr = ViabilityConstraint(2, 1, fun=self.height_fun, J=self.height_J, b=self.height_b, K=2) 12 | 13 | f = ConstraintsSet(2) 14 | f.add_constraint(circle_constr) 15 | g = ConstraintsSet(2) 16 | g.add_constraint(height_constr) 17 | 18 | super().__init__(base_env=base_env, dim_q=2, f=f, g=g, Kc=Kc, acc_max=10, vel_max=1, Kq=20, time_step=time_step) 19 | 20 | def _get_q(self, state): 21 | return state[:2] 22 | 23 | def _get_dq(self, state): 24 | return state[2:4] 25 | 26 | def acc_to_ctrl_action(self, ddq): 27 | return ddq / self.acc_max 28 | 29 | def render(self): 30 | offset = np.array([1.25, 1.25]) 31 | pos = self.state[:2] + offset 32 | 33 | act_a = self._act_a[:2] 
34 | act_b = self._act_b[:2] 35 | 36 | self.env._viewer.force_arrow(center=pos, direction=act_a, 37 | force=np.linalg.norm(act_a), 38 | max_force=3, width=5, 39 | max_length=0.3, color=(255, 165, 0)) 40 | 41 | self.env._viewer.force_arrow(center=pos, direction=act_b, 42 | force=np.linalg.norm(act_b), 43 | max_force=10, width=5, 44 | max_length=0.3, color=(0, 255, 255)) 45 | super().render() 46 | 47 | @staticmethod 48 | def circle_fun(q): 49 | return np.array([q[0] ** 2 + q[1] ** 2 - 1]) 50 | 51 | @staticmethod 52 | def circle_J(q): 53 | return np.array([[2 * q[0], 2 * q[1]]]) 54 | 55 | @staticmethod 56 | def circle_b(q, dq): 57 | return np.array([[2 * dq[0], 2 * dq[1]]]) @ dq 58 | 59 | @staticmethod 60 | def height_fun(q): 61 | return np.array([-q[1] - 0.5]) 62 | 63 | @staticmethod 64 | def height_J(q): 65 | return np.array([[0, -1]]) 66 | 67 | @staticmethod 68 | def height_b(q, dq): 69 | return np.array([0]) 70 | 71 | @staticmethod 72 | def vel_fun(q, dq): 73 | return np.array([dq[0] ** 2 + dq[1] ** 2 - 1]) 74 | 75 | @staticmethod 76 | def vel_A(q, dq): 77 | return np.array([[2 * dq[0], 2 * dq[1]]]) 78 | 79 | @staticmethod 80 | def vel_b(q, dq): 81 | return np.array([0.]) 82 | -------------------------------------------------------------------------------- /atacom/environments/circular_motion/circle_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import FormatStrFormatter 5 | 6 | from mushroom_rl.core import Environment, MDPInfo 7 | from mushroom_rl.utils import spaces 8 | from mushroom_rl.utils.viewer import Viewer 9 | 10 | 11 | class CircularMotion(Environment): 12 | """ 13 | Base class of the CircularMotion environment. 14 | A point is moving on the 2D unit circle. 
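The admissible states satisfy c_1(q) = x^2 + y^2 - 1 = 0 (stay on the circle), c_2(q) = -y - 0.5 <= 0 (stay above the line y = -0.5) and c_3(q, dq) = |dq| - 1 <= 0 (element-wise velocity bound); see c_1, c_2 and c_3 below.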
15 | Control actions are 2d acceleration along each direction 16 | """ 17 | 18 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, random_init=False): 19 | self.time_step = time_step 20 | self.random_init = random_init 21 | inf_array = np.ones(4) * np.inf 22 | observation_space = spaces.Box(low=-inf_array, high=inf_array) 23 | action_space = spaces.Box(low=-np.ones(2) * 1, high=np.ones(2) * 1) 24 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) 25 | self._viewer = Viewer(env_width=2.5, env_height=2.5, background=(255, 255, 255)) 26 | self.step_action_function = None 27 | self.action_scale = np.array([10., 10.]) 28 | super().__init__(mdp_info) 29 | 30 | self.c = np.zeros(3) 31 | self.constr_logs = list() 32 | 33 | def reset(self, state=None): 34 | if state is None: 35 | if self.random_init: 36 | y = np.random.uniform(-0.5, 1) 37 | x = np.sqrt(1 - y ** 2) * np.sign(np.random.uniform(-1, 1)) 38 | dx = np.random.uniform(-1, 1) 39 | dy = -x * dx / y 40 | v = np.array([dx, dy]) 41 | v = v / np.linalg.norm(v) * np.random.uniform(0, 1) 42 | self._state = np.array([x, y, v[0], v[1]]) 43 | else: 44 | self._state = np.array([-1., 0., 0., 0.0]) 45 | else: 46 | if abs(state[0] ** 2 + state[1] ** 2 - 1) < 1e-6 and abs(state[0] * state[2] - state[1] * state[3]) < 1e-6: 47 | self._state = state 48 | else: 49 | raise ValueError("Can not reset to the state: ", state) 50 | 51 | return self._state 52 | 53 | def step(self, action): 54 | self.check_constraint() 55 | 56 | if self.step_action_function is not None: 57 | action = self.step_action_function(self._state, action) 58 | 59 | self._action = np.clip(action, self.info.action_space.low, self.info.action_space.high) 60 | self._action = self._action * self.action_scale 61 | 62 | self._state[:2] += self._state[2:4] * self.time_step + self._action * self.time_step ** 2 / 2 63 | self._state[2:4] += self._action * self.time_step 64 | 65 | reward = np.exp(-np.linalg.norm(np.array([1., 0.]) - self._state[:2])) 66 | 67 | return self._state, reward, False, dict() 68 | 69 | def render(self): 70 | offset = np.array([1.25, 1.25]) 71 | self._viewer.circle(center=offset, radius=1., color=(0., 0., 255), width=5) 72 | self._viewer.line(start=np.array([-1.25, -0.5]) + offset, end=np.array([1.25, -0.5]) + offset, 73 | color=(255, 20, 147), width=3) 74 | self._viewer.square(center=np.array([1.0, 0.0]) + offset, angle=0, edge=0.05, color=(50, 205, 50)) 75 | 76 | pos = self._state[:2] + offset 77 | self._viewer.circle(center=pos, radius=0.03, color=(255, 0, 0)) 78 | self._viewer.display(self.time_step) 79 | 80 | def _create_sim_state(self): 81 | return self._state 82 | 83 | def _create_observation(self, state): 84 | return state 85 | 86 | def check_constraint(self): 87 | q = self._state[:2] 88 | dq = self._state[2:4] 89 | 90 | self.c = self.get_c(q, dq) 91 | self.c[0] = np.abs(self.c[0]) 92 | self.constr_logs.append(self.c) 93 | 94 | def get_c(self, q, dq): 95 | return np.concatenate([self.c_1(q), self.c_2(q), self.c_3(q, dq)]) 96 | 97 | @staticmethod 98 | def c_1(q): 99 | return np.array([q[0] ** 2 + q[1] ** 2 - 1]) 100 | 101 | @staticmethod 102 | def c_2(q): 103 | return np.array([-q[1] - 0.5]) 104 | 105 | @staticmethod 106 | def c_3(q, dq): 107 | return np.abs(dq) - 1 108 | 109 | def get_constraints_logs(self): 110 | constr_logs = np.array(self.constr_logs) 111 | c_avg = np.mean(np.max(constr_logs[:, :2], axis=1)) 112 | c_max = np.max(constr_logs[:, :2]) 113 | c_dq_max = np.max(constr_logs[:, 2:]) 114 | self.constr_logs.clear() 115 | return 
c_avg, c_max, c_dq_max 116 | 117 | 118 | def plot_2d_circle(dataset, save_dir="", suffix=""): 119 | plt.rcParams.update({ 120 | "text.usetex": True, 121 | "font.serif": ["Times"]}) 122 | state_list = list() 123 | 124 | if suffix != '': 125 | suffix = suffix + "-" 126 | 127 | i = 0 128 | for data in dataset: 129 | state = data[0] 130 | state_list.append(state) 131 | if data[-1]: 132 | i += 1 133 | state_hist = np.array(state_list) 134 | 135 | fig = plt.figure() 136 | axes_circle = plt.subplot2grid((3, 2), (0, 0), rowspan=2, colspan=1) 137 | axes_c1 = plt.subplot2grid((3, 2), (0, 1)) 138 | axes_c2 = plt.subplot2grid((3, 2), (1, 1)) 139 | axes_c3 = plt.subplot2grid((3, 2), (2, 1)) 140 | fig.subplots_adjust(hspace=.5) 141 | fig.subplots_adjust(wspace=.1) 142 | 143 | axes_circle.plot(state_hist[:, 0], state_hist[:, 1]) 144 | # reference circle 145 | x_1 = np.linspace(-1, 1, 100) 146 | y_1 = np.sqrt(1 - x_1 ** 2) 147 | x_2 = np.linspace(1, -1, 100) 148 | y_2 = -np.sqrt(1 - x_2 ** 2) 149 | x = np.concatenate([x_1, x_2]) 150 | y = np.concatenate([y_1, y_2]) 151 | axes_circle.plot(x, y, ls=':', label="$c_1$") 152 | 153 | # line 154 | # x_3 = np.linspace(-1.2, 1.2, 100) 155 | # y_3 = np.ones_like(x_3) * -0.5 156 | # axes_circle.plot(x_3, y_3, ls=':', c='tab:red', label="c2") 157 | axes_circle.fill_between([-1.2, 1.2], [-0.5, -0.5], [-1.2, -1.2], color='tab:red', alpha=0.3, label="$c_2$") 158 | 159 | axes_circle.set_aspect(1.0) 160 | axes_circle.set_xlim(-1.2, 1.2) 161 | axes_circle.set_ylim(-1.2, 1.2) 162 | axes_circle.set_ylim(-1.2, 1.2) 163 | axes_circle.legend(loc='upper right') 164 | 165 | # c1 166 | axes_c1.plot(state_hist[:, 0] ** 2 + state_hist[:, 1] ** 2 - 1) 167 | max_c = np.max(np.abs(state_hist[:, 0] ** 2 + state_hist[:, 1] ** 2 - 1)) 168 | axes_c1.plot(np.zeros_like(state_hist[:, 2]), c='tab:red', label="c1") 169 | axes_c1.yaxis.tick_right() 170 | axes_c1.yaxis.set_label_position("right") 171 | axes_c1.set_title("$c_1$") 172 | axes_c1.set_ylim(-max_c, max_c) 173 | axes_c1.yaxis.set_major_formatter(FormatStrFormatter('%.4f')) 174 | 175 | # c2 176 | axes_c2.plot(-state_hist[:, 1] - 0.5) 177 | axes_c2.plot(np.zeros_like(state_hist[:, 2]) * 1, c='tab:red', ls='--', label="c2") 178 | axes_c2.yaxis.tick_right() 179 | axes_c2.yaxis.set_label_position("right") 180 | axes_c2.set_title("$c_2$") 181 | axes_c2.yaxis.set_major_formatter(FormatStrFormatter('%.4f')) 182 | 183 | # c3 184 | axes_c3.plot(state_hist[:, 2] - 1, label="c3") 185 | axes_c3.plot(state_hist[:, 3] - 1, label="c4") 186 | axes_c3.plot(np.zeros_like(state_hist[:, 2]) * 1, c='tab:red', ls='--', label="c3") 187 | axes_c3.yaxis.tick_right() 188 | axes_c3.yaxis.set_label_position("right") 189 | axes_c3.yaxis.set_major_formatter(FormatStrFormatter('%.4f')) 190 | axes_c3.set_title(r"$c_3 \& c_4$") 191 | 192 | textstr = '\n'.join(('$c_1: x^2 + y^2 - 1 = 0$', 193 | '$c_2: - y - 0.5 < 0$', 194 | r'$c_3: \dot{x} - 1 < 0$', 195 | r'$c_4: \dot{y} - 1 < 0$')) 196 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) 197 | axes_circle.text(0.5, -0.2, textstr, transform=axes_circle.transAxes, fontsize=14, 198 | verticalalignment='top', horizontalalignment='center', bbox=props) 199 | 200 | filename = "circular_motion-" + suffix + str(i) + ".pdf" 201 | 202 | if not os.path.exists(save_dir): 203 | os.makedirs(save_dir) 204 | 205 | plt.savefig(os.path.join(save_dir, filename)) 206 | plt.close(fig) 207 | state_list = list() 208 | -------------------------------------------------------------------------------- 
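A minimal usage sketch (not part of the repository): the ATACOM-wrapped circular-motion environment plugs into mushroom_rl's `Core` loop in the same way as the `env_test()` function in `circle_terminated.py` further below; the `RandomAgent` here is a hypothetical stand-in for a learning agent.

```python
import numpy as np
from mushroom_rl.core import Core, Agent
from mushroom_rl.utils.dataset import compute_J
from atacom.environments.circular_motion import CircleEnvAtacom


class RandomAgent(Agent):
    """Illustrative agent that samples tangent-space actions in [-1, 1]."""

    def __init__(self, mdp_info):
        self.mdp_info = mdp_info

    def fit(self, dataset):
        pass

    def episode_start(self):
        pass

    def draw_action(self, state):
        return np.random.uniform(-1, 1, self.mdp_info.action_space.shape)


env = CircleEnvAtacom(random_init=True)
core = Core(RandomAgent(env.info), env)
dataset = core.evaluate(n_episodes=1, render=False)

print("J:", np.mean(compute_J(dataset, env.info.gamma)))
print("c_avg, c_max, c_dq_max:", env.get_constraints_logs())
```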
/atacom/environments/circular_motion/circle_error_correction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from atacom.constraints import ViabilityConstraint, ConstraintsSet 3 | from atacom.environments.circular_motion.circle_base import CircularMotion 4 | from atacom.error_correction_wrapper import ErrorCorrectionEnvWrapper 5 | 6 | 7 | class CircleEnvErrorCorrection(ErrorCorrectionEnvWrapper): 8 | def __init__(self, horizon=500, gamma=0.99, random_init=False, Kc=100., time_step=0.01): 9 | base_env = CircularMotion(random_init=random_init, horizon=horizon, gamma=gamma) 10 | 11 | circle_constr = ViabilityConstraint(dim_q=2, dim_out=1, fun=self.circle_fun, J=self.circle_J, b=self.circle_b, 12 | K=0.1) 13 | height_constr = ViabilityConstraint(dim_q=2, dim_out=1, fun=self.height_fun, J=self.height_J, b=self.height_b, 14 | K=2) 15 | 16 | f = ConstraintsSet(2) 17 | f.add_constraint(circle_constr) 18 | g = ConstraintsSet(2) 19 | g.add_constraint(height_constr) 20 | 21 | super().__init__(base_env=base_env, dim_q=2, f=f, g=g, Kc=Kc, acc_max=10, vel_max=1, Kq=20, time_step=time_step) 22 | 23 | def _get_q(self, state): 24 | return state[:2] 25 | 26 | def _get_dq(self, state): 27 | return state[2:4] 28 | 29 | def acc_to_ctrl_action(self, ddq): 30 | return ddq / self.acc_max 31 | 32 | def render(self): 33 | offset = np.array([1.25, 1.25]) 34 | pos = self.state[:2] + offset 35 | 36 | act_a = self._act_a[:2] 37 | act_b = self._act_b[:2] 38 | act_err = self._act_err[:2] 39 | 40 | self.env._viewer.force_arrow(center=pos, direction=act_b, 41 | force=np.linalg.norm(act_b), 42 | max_force=10, width=5, 43 | max_length=0.3, color=(0, 255, 255)) 44 | super().render() 45 | 46 | @staticmethod 47 | def circle_fun(q): 48 | return np.array([q[0] ** 2 + q[1] ** 2 - 1]) 49 | 50 | @staticmethod 51 | def circle_J(q): 52 | return np.array([[2 * q[0], 2 * q[1]]]) 53 | 54 | @staticmethod 55 | def circle_b(q, dq): 56 | return np.array([[2 * dq[0], 2 * dq[1]]]) @ dq 57 | 58 | @staticmethod 59 | def height_fun(q): 60 | return np.array([-q[1] - 0.5]) 61 | 62 | @staticmethod 63 | def height_J(q): 64 | return np.array([[0, -1]]) 65 | 66 | @staticmethod 67 | def height_b(q, dq): 68 | return np.array([0]) 69 | 70 | @staticmethod 71 | def vel_fun(q, dq): 72 | return np.array([dq[0] ** 2 + dq[1] ** 2 - 1]) 73 | 74 | @staticmethod 75 | def vel_A(q, dq): 76 | return np.array([[2 * dq[0], 2 * dq[1]]]) 77 | 78 | @staticmethod 79 | def vel_b(q, dq): 80 | return np.array([0.]) 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /atacom/environments/circular_motion/circle_terminated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from atacom.environments.circular_motion.circle_base import CircularMotion 3 | 4 | from mushroom_rl.core import Core, Agent 5 | from mushroom_rl.utils.dataset import compute_J 6 | 7 | 8 | class CircleEnvTerminated(CircularMotion): 9 | """ 10 | CircularMotion with termination when constraint > tolerance 11 | """ 12 | 13 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, random_init=False, tol=0.1): 14 | super().__init__(time_step=time_step, horizon=horizon, gamma=gamma, random_init=random_init) 15 | self._tol = tol 16 | 17 | def step(self, action): 18 | state, reward, absorbing, info = super().step(action) 19 | 20 | absorbing = absorbing or self._terminate() 21 | if absorbing: 22 | reward = -100 23 | return state, reward, 
absorbing, info 24 | 25 | def _terminate(self): 26 | if np.any(self.c > self._tol): 27 | return True 28 | else: 29 | return False 30 | 31 | def render(self): 32 | offset = np.array([1.25, 1.25]) 33 | pos = self._state[:2] + offset 34 | 35 | act_b = self._action[:2] 36 | self._viewer.force_arrow(center=pos, direction=act_b, 37 | force=np.linalg.norm(act_b), 38 | max_force=10, width=5, 39 | max_length=0.3, color=(0, 255, 255)) 40 | super().render() 41 | 42 | 43 | def env_test(): 44 | env = CircleEnvTerminated(tol=0.1) 45 | 46 | class DummyAgent(Agent): 47 | def __init__(self, mdp_info): 48 | self.mdp_info = mdp_info 49 | 50 | def fit(self, dataset): 51 | pass 52 | 53 | def episode_start(self): 54 | pass 55 | 56 | def draw_action(self, state): 57 | return np.random.randn(self.mdp_info.action_space.shape[0]) * 1 58 | 59 | agent = DummyAgent(env.info) 60 | 61 | core = Core(agent, env) 62 | 63 | dataset = core.evaluate(n_steps=1000, render=True) 64 | 65 | J = np.mean(compute_J(dataset, core.mdp.info.gamma)) 66 | R = np.mean(compute_J(dataset)) 67 | c_avg, c_max, c_dq_max = env.get_constraints_logs() 68 | print("J: {}, R:{}, c_avg:{}, c_max:{}, c_dq_max:{}".format(J, R, c_avg, c_max, c_dq_max)) 69 | 70 | 71 | if __name__ == '__main__': 72 | env_test() 73 | -------------------------------------------------------------------------------- /atacom/environments/collision_avoidance/__init__.py: -------------------------------------------------------------------------------- 1 | from .collision_avoidance_atacom import PointReachAtacom -------------------------------------------------------------------------------- /atacom/environments/collision_avoidance/collision_avoidance_atacom.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.core import Core, Agent 2 | from mushroom_rl.utils.spaces import * 3 | from mushroom_rl.utils.dataset import compute_J 4 | from atacom.environments.collision_avoidance.collision_avoidance_base import PointGoalReach 5 | from atacom.utils.null_space_coordinate import pinv_null, rref 6 | 7 | 8 | class PointReachAtacom(PointGoalReach): 9 | def __init__(self, time_step=0.01, horizon=1000, gamma=0.99, n_objects=4, random_walk=False): 10 | super().__init__(time_step=time_step, horizon=horizon, gamma=gamma, n_objects=n_objects, 11 | random_walk=random_walk) 12 | self.s = np.zeros(self.n_objects) 13 | self.Kc = np.diag(np.ones(self.n_objects) * 100) 14 | self.K = np.ones(self.n_objects) * 0.5 15 | 16 | self.constr_logs = list() 17 | 18 | def reset(self, state=None): 19 | super(PointReachAtacom, self).reset(state) 20 | self.q = self.get_q(self._state) 21 | self.dq = self.get_dq(self._state) 22 | self.p = self.get_p(self._state) 23 | self.dp = self.get_dp(self._state) 24 | 25 | self.s = np.sqrt(np.maximum(- 2 * self.get_c(self.q, self.p), 0.)) 26 | # self.s = np.maximum(- self.get_c(self.q, self.p), 0.) 27 | return self._state 28 | 29 | def step(self, action): 30 | self.q = self.get_q(self._state) 31 | self.dq = self.get_dq(self._state) 32 | self.p = self.get_p(self._state) 33 | self.dp = self.get_dp(self._state) 34 | 35 | Jc_q = self.get_Jc_q(self.q, self.p) 36 | Jc_q_inv, Nc_q = pinv_null(Jc_q) 37 | 38 | c_origin = self.get_c(self.q, self.p) 39 | c_dq_i = 0. 
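# Annotation (not in the original source): each inequality constraint
# c_origin(q) <= 0 is converted into an equality with a slack variable s,
#     c(q, s) = c_origin(q) + s^2 / 2,
# and augmented with the velocity-level term K * dc. The action computed
# below combines the drift/error compensation -Jc_q_inv @ (psi + Kc @ c)
# with the agent's action projected through the null-space basis Nc; the
# trailing components of `action` integrate the slack variables.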
40 | self.constr_logs.append([np.max(c_origin), np.max(c_dq_i)]) 41 | c = c_origin + 1 / 2 * self.s ** 2 + self.K * (self.get_dc(self.q, self.p, self.dq, self.dp)) 42 | 43 | Nc = rref(Nc_q, row_vectors=False) 44 | 45 | psi = self.get_psi(self.q, self.p, self.dq, self.dp) 46 | action = - Jc_q_inv @ (psi + self.Kc @ c) + Nc @ action 47 | self.s += action[2:] * self.time_step 48 | return super().step(action[:2]) 49 | 50 | @staticmethod 51 | def get_q(state): 52 | return state[:2] 53 | 54 | @staticmethod 55 | def get_dq(state): 56 | return state[2:4] 57 | 58 | def get_p(self, state): 59 | p = np.zeros(2 * self.n_objects) 60 | for i in range(self.n_objects): 61 | idx = 4 * (i + 1) 62 | p[2 * i: 2 * i + 2] = state[idx: idx + 2] 63 | return p 64 | 65 | def get_dp(self, state): 66 | dp = np.zeros(2 * self.n_objects) 67 | for i in range(self.n_objects): 68 | idx = 4 * (i + 1) 69 | dp[2 * i: 2 * i + 2] = state[idx + 2: idx + 4] 70 | return dp 71 | 72 | def get_c(self, q, p): 73 | c_out = np.zeros(self.n_objects) 74 | for i in range(self.n_objects): 75 | c_out[i] = 0.6 ** 2 - np.linalg.norm(q - p[2 * i: 2 * i + 2]) ** 2 76 | return c_out 77 | 78 | def get_dc(self, q, p, dq, dp): 79 | dc_out = np.zeros(self.n_objects) 80 | for i in range(self.n_objects): 81 | p_i = p[2 * i:2 * i + 2] 82 | dp_i = dp[2 * i:2 * i + 2] 83 | dc_out[i] = self.get_Jp(q, p_i) @ dp_i + self.get_Jq(q, p_i) @ dq 84 | return dc_out 85 | 86 | @staticmethod 87 | def get_Jq(q, p_i): 88 | return -2 * (q - p_i) 89 | 90 | @staticmethod 91 | def get_Jp(q, p_i): 92 | return 2 * (q - p_i) 93 | 94 | @staticmethod 95 | def get_Hqq(q, p_i): 96 | return -2 * np.eye(2) 97 | 98 | @staticmethod 99 | def get_Hqp(q, p_i): 100 | return 2 * np.eye(2) 101 | 102 | @staticmethod 103 | def get_Hpp(q, p_i): 104 | return -2 * np.eye(2) 105 | 106 | def get_bp(self, q, p_i, dq, dp_i): 107 | return (p_i @ self.get_Hpp(q, p_i) + q @ self.get_Hqp(q, p_i)) @ p_i 108 | 109 | def get_bq(self, q, p_i, dq, dp_i): 110 | return (q @ self.get_Hqq(q, p_i) + p_i @ self.get_Hqp(q, p_i)) @ q 111 | 112 | def get_Jc_q(self, q, p): 113 | Jc_q = np.zeros((self.n_objects, self.n_objects + 2)) 114 | for i in range(self.n_objects): 115 | p_i = p[2 * i:2 * i + 2] 116 | Jc_q[i, :2] = self.get_Jq(q, p_i) 117 | Jc_q[:, 2:] = np.diag(self.s) 118 | return Jc_q 119 | 120 | def get_psi(self, q, p, dq, dp): 121 | psi = np.zeros(self.n_objects) 122 | for i in range(self.n_objects): 123 | p_i = p[2 * i:2 * i + 2] 124 | dp_i = dp[2 * i:2 * i + 2] 125 | psi[i] = self.get_Jp(q, p_i) @ dp_i + self.get_Jq(q, p_i) @ dq + \ 126 | self.K[i] * (self.get_bp(q, p_i, dq, dp_i) + self.get_bq(q, p_i, dq, dp_i)) 127 | return psi 128 | 129 | def get_constraints_logs(self): 130 | constr_logs = np.array(self.constr_logs) 131 | c_avg = np.mean(constr_logs[:, 0]) 132 | c_max = np.max(constr_logs[:, 0]) 133 | c_dq_max = np.max(constr_logs[:, 1]) 134 | self.constr_logs.clear() 135 | return c_avg, c_max, c_dq_max 136 | 137 | 138 | -------------------------------------------------------------------------------- /atacom/environments/collision_avoidance/collision_avoidance_base.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.core import Environment, MDPInfo 2 | from mushroom_rl.utils.spaces import * 3 | from mushroom_rl.utils.viewer import Viewer 4 | 5 | 6 | class PointGoalReach(Environment): 7 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, n_objects=2, random_walk=False): 8 | self.time_step = time_step 9 | self.n_objects = n_objects 10 | 
self.random_walk = random_walk 11 | 12 | self.state_dim = 4 * (1 + n_objects) 13 | observation_space = Box(low=-np.ones(self.state_dim) * 10, high=np.ones(self.state_dim) * 10) 14 | action_space = Box(low=-np.ones(2) * 1, high=np.ones(2) * 1) 15 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) 16 | 17 | self._viewer = Viewer(env_width=10, env_height=10, background=(255, 255, 255)) 18 | self.step_action_function = None 19 | self.action_scale = np.array([10., 10.]) 20 | self._obj_circle_center = [] 21 | self._obj_radius = 2 22 | self._time = 0. 23 | super().__init__(mdp_info) 24 | 25 | def reset(self, state=None): 26 | self._time = 0. 27 | self._state = np.zeros(self.state_dim) 28 | self._obj_circle_center = []  # clear stale centers so step() uses this episode's circles 29 | self._state[:2] = np.array([1., 1.]) 30 | self._state[2:4] = np.array([0., 0.]) 31 | 32 | for i in range(self.n_objects): 33 | obj_idx = 4 * (i + 1) 34 | self._state[obj_idx:obj_idx + 2] = np.random.uniform(2, 8, 2) 35 | self._obj_circle_center.append(self._state[obj_idx:obj_idx + 2] - np.array([self._obj_radius, 0.])) 36 | self._state[obj_idx + 2:obj_idx + 4] = np.zeros(2) 37 | 38 | return self._state 39 | 40 | def step(self, action): 41 | if self.step_action_function is not None: 42 | action = self.step_action_function(self._state, action) 43 | 44 | self._action = np.clip(action, self.info.action_space.low, self.info.action_space.high) 45 | self._action = self._action * self.action_scale 46 | 47 | self._state[:2] += self._state[2:4] * self.time_step 48 | self._state[2:4] += self._action * self.time_step 49 | 50 | change_sign = np.logical_or(self._state[0:2] <= 0, self._state[0:2] >= 10) 51 | if np.any(change_sign): 52 | self._state[2:4][change_sign] = - self._state[2:4][change_sign] 53 | 54 | for i in range(self.n_objects): 55 | obj_idx = 4 * (i + 1) 56 | 57 | if self.random_walk: 58 | self._state[obj_idx:obj_idx + 2] += self._state[obj_idx + 2:obj_idx + 4] * self.time_step 59 | self._state[obj_idx:obj_idx + 2] = np.clip(self._state[obj_idx:obj_idx + 2], 2, 10) 60 | obj_action = np.random.uniform(-1, 1, 2) * 10 61 | change_sign = np.logical_or(self._state[obj_idx:obj_idx + 2] <= 2, 62 | self._state[obj_idx:obj_idx + 2] >= 10) 63 | if np.any(change_sign): 64 | self._state[obj_idx + 2:obj_idx + 4][change_sign] = - self._state[obj_idx + 2:obj_idx + 4][ 65 | change_sign] 66 | self._state[obj_idx + 2:obj_idx + 4] += obj_action * self.time_step 67 | self._state[obj_idx + 2:obj_idx + 4] = np.clip(self._state[obj_idx + 2:obj_idx + 4], -1, 1) 68 | else: 69 | self._state[obj_idx] = self._obj_circle_center[i][0] + self._obj_radius * np.cos(self._time * 2 * np.pi) 70 | self._state[obj_idx + 1] = self._obj_circle_center[i][1] + self._obj_radius * np.sin( 71 | self._time * 2 * np.pi) 72 | self._state[obj_idx + 2] = -2 * self._obj_radius * np.pi * np.sin(self._time * 2 * np.pi) 73 | self._state[obj_idx + 3] = 2 * self._obj_radius * np.pi * np.cos(self._time * 2 * np.pi) 74 | 75 | self._time += self.time_step 76 | reward = - np.linalg.norm(np.array([9., 9.]) - self._state[:2]) / (8 * np.sqrt(2)) 77 | return self._state, reward, False, dict() 78 | 79 | def render(self): 80 | self._viewer.circle(center=self._state[:2], radius=0.3, color=(0., 0., 255), width=0) 81 | 82 | for i in range(self.n_objects): 83 | obj_idx = 4 * (i + 1) 84 | self._viewer.circle(center=self._state[obj_idx:obj_idx + 2], radius=0.3, color=(255., 0., 0), width=0) 85 | 86 | self._viewer.square(center=np.array([9., 9.]), angle=0, edge=0.6, color=(0., 255., 0)) 87 | 88 | self._viewer.display(self.time_step * 0.4) 89 | 90 | def 
_create_sim_state(self): 91 | return self._state 92 | 93 | def _create_observation(self, state): 94 | return state -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/__init__.py: -------------------------------------------------------------------------------- 1 | from .iiwa_hit_atacom import AirHockeyIiwaAtacom 2 | from .iiwa_air_hockey_rmp import AirHockeyIiwaRmp 3 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/env_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pinocchio as pino 4 | import pybullet 5 | import pybullet_utils.transformations as transformations 6 | from mushroom_rl.environments.pybullet import PyBullet, PyBulletObservationType 7 | from mushroom_rl.environments.pybullet_envs import __file__ as env_path 8 | 9 | 10 | class AirHockeyBase(PyBullet): 11 | def __init__(self, gamma=0.99, horizon=500, timestep=1 / 240., n_intermediate_steps=1, debug_gui=False, 12 | n_agents=1, env_noise=False, obs_noise=False, obs_delay=False, torque_control=True, 13 | step_action_function=None, isolated_joint_7=False): 14 | self.n_agents = n_agents 15 | self.env_noise = env_noise 16 | self.obs_noise = obs_noise 17 | self.obs_delay = obs_delay 18 | self.step_action_function = step_action_function 19 | self.isolated_joint_7 = isolated_joint_7 20 | 21 | puck_file = os.path.join(os.path.dirname(os.path.abspath(env_path)), 22 | "data", "air_hockey", "puck.urdf") 23 | table_file = os.path.join(os.path.dirname(os.path.abspath(env_path)), 24 | "data", "air_hockey", "air_hockey_table.urdf") 25 | robot_file_1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urdf", "iiwa_1.urdf") 26 | robot_file_2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urdf", "iiwa_2.urdf") 27 | 28 | model_files = dict() 29 | model_files[puck_file] = dict(flags=pybullet.URDF_USE_IMPLICIT_CYLINDER, 30 | basePosition=[0.0, 0, 0], baseOrientation=[0, 0, 0.0, 1.0]) 31 | model_files[table_file] = dict(useFixedBase=True, basePosition=[0.0, 0, 0], 32 | baseOrientation=[0, 0, 0.0, 1.0]) 33 | 34 | actuation_spec = list() 35 | observation_spec = [("puck", PyBulletObservationType.BODY_POS), 36 | ("puck", PyBulletObservationType.BODY_LIN_VEL), 37 | ("puck", PyBulletObservationType.BODY_ANG_VEL)] 38 | self.agents = [] 39 | 40 | if torque_control: 41 | control = pybullet.TORQUE_CONTROL 42 | else: 43 | control = pybullet.POSITION_CONTROL 44 | 45 | if 1 <= self.n_agents <= 2: 46 | agent_spec = dict() 47 | agent_spec['name'] = "iiwa_1" 48 | agent_spec.update({"urdf": os.path.join(os.path.dirname(os.path.abspath(__file__)), 49 | "urdf", "iiwa_1.urdf")}) 50 | translate = [-1.51, 0, -0.1] 51 | quaternion = [0.0, 0.0, 0.0, 1.0] 52 | agent_spec['frame'] = transformations.translation_matrix(translate) 53 | agent_spec['frame'] = agent_spec['frame'] @ transformations.quaternion_matrix(quaternion) 54 | model_files[robot_file_1] = dict( 55 | flags=pybullet.URDF_USE_IMPLICIT_CYLINDER | pybullet.URDF_USE_INERTIA_FROM_FILE, 56 | basePosition=translate, baseOrientation=quaternion) 57 | 58 | self.agents.append(agent_spec) 59 | actuation_spec += [("iiwa_1/joint_1", control), 60 | ("iiwa_1/joint_2", control), 61 | ("iiwa_1/joint_3", control), 62 | ("iiwa_1/joint_4", control), 63 | ("iiwa_1/joint_5", control), 64 | ("iiwa_1/joint_6", control)] 65 | if self.isolated_joint_7: 66 | actuation_spec += [("iiwa_1/joint_7", 
pybullet.POSITION_CONTROL)] 67 | else: 68 | actuation_spec += [("iiwa_1/joint_7", control)] 69 | actuation_spec += [("iiwa_1/striker_joint_1", pybullet.POSITION_CONTROL), 70 | ("iiwa_1/striker_joint_2", pybullet.POSITION_CONTROL)] 71 | 72 | observation_spec += [("iiwa_1/joint_1", PyBulletObservationType.JOINT_POS), 73 | ("iiwa_1/joint_2", PyBulletObservationType.JOINT_POS), 74 | ("iiwa_1/joint_3", PyBulletObservationType.JOINT_POS), 75 | ("iiwa_1/joint_4", PyBulletObservationType.JOINT_POS), 76 | ("iiwa_1/joint_5", PyBulletObservationType.JOINT_POS), 77 | ("iiwa_1/joint_6", PyBulletObservationType.JOINT_POS), 78 | ("iiwa_1/joint_7", PyBulletObservationType.JOINT_POS), 79 | ("iiwa_1/striker_joint_1", PyBulletObservationType.JOINT_POS), 80 | ("iiwa_1/striker_joint_2", PyBulletObservationType.JOINT_POS), 81 | ("iiwa_1/joint_1", PyBulletObservationType.JOINT_VEL), 82 | ("iiwa_1/joint_2", PyBulletObservationType.JOINT_VEL), 83 | ("iiwa_1/joint_3", PyBulletObservationType.JOINT_VEL), 84 | ("iiwa_1/joint_4", PyBulletObservationType.JOINT_VEL), 85 | ("iiwa_1/joint_5", PyBulletObservationType.JOINT_VEL), 86 | ("iiwa_1/joint_6", PyBulletObservationType.JOINT_VEL), 87 | ("iiwa_1/joint_7", PyBulletObservationType.JOINT_VEL), 88 | ("iiwa_1/striker_joint_1", PyBulletObservationType.JOINT_VEL), 89 | ("iiwa_1/striker_joint_2", PyBulletObservationType.JOINT_VEL), 90 | ("iiwa_1/striker_mallet_tip", PyBulletObservationType.LINK_POS), 91 | ("iiwa_1/striker_mallet_tip", PyBulletObservationType.LINK_LIN_VEL)] 92 | 93 | if self.n_agents == 2: 94 | agent_spec = dict() 95 | agent_spec['name'] = "iiwa_2" 96 | agent_spec.update({"urdf": os.path.join(os.path.dirname(os.path.abspath(__file__)), 97 | "urdf", "iiwa_pino.urdf")}) 98 | translate = [1.51, 0, -0.1] 99 | quaternion = [0.0, 0.0, 1.0, 0.0] 100 | agent_spec['frame'] = transformations.translation_matrix(translate) 101 | agent_spec['frame'] = agent_spec['frame'] @ transformations.quaternion_matrix(quaternion) 102 | model_files[robot_file_2] = dict( 103 | flags=pybullet.URDF_USE_IMPLICIT_CYLINDER | pybullet.URDF_USE_INERTIA_FROM_FILE, 104 | basePosition=translate, baseOrientation=quaternion) 105 | self.agents.append(agent_spec) 106 | 107 | actuation_spec += [("iiwa_2/joint_1", control), 108 | ("iiwa_2/joint_2", control), 109 | ("iiwa_2/joint_3", control), 110 | ("iiwa_2/joint_4", control), 111 | ("iiwa_2/joint_5", control), 112 | ("iiwa_2/joint_6", control)] 113 | if self.isolated_joint_7: 114 | actuation_spec += [("iiwa_2/joint_7", pybullet.POSITION_CONTROL)] 115 | else: 116 | actuation_spec += [("iiwa_2/joint_7", control)] 117 | actuation_spec += [("iiwa_2/striker_joint_1", pybullet.POSITION_CONTROL), 118 | ("iiwa_2/striker_joint_2", pybullet.POSITION_CONTROL)] 119 | 120 | observation_spec += [("iiwa_2/joint_1", PyBulletObservationType.JOINT_POS), 121 | ("iiwa_2/joint_2", PyBulletObservationType.JOINT_POS), 122 | ("iiwa_2/joint_3", PyBulletObservationType.JOINT_POS), 123 | ("iiwa_2/joint_4", PyBulletObservationType.JOINT_POS), 124 | ("iiwa_2/joint_5", PyBulletObservationType.JOINT_POS), 125 | ("iiwa_2/joint_6", PyBulletObservationType.JOINT_POS), 126 | ("iiwa_2/joint_7", PyBulletObservationType.JOINT_POS), 127 | ("iiwa_2/striker_joint_1", PyBulletObservationType.JOINT_POS), 128 | ("iiwa_2/striker_joint_2", PyBulletObservationType.JOINT_POS), 129 | ("iiwa_2/joint_1", PyBulletObservationType.JOINT_VEL), 130 | ("iiwa_2/joint_2", PyBulletObservationType.JOINT_VEL), 131 | ("iiwa_2/joint_3", PyBulletObservationType.JOINT_VEL), 132 | ("iiwa_2/joint_4", 
PyBulletObservationType.JOINT_VEL), 133 | ("iiwa_2/joint_5", PyBulletObservationType.JOINT_VEL), 134 | ("iiwa_2/joint_6", PyBulletObservationType.JOINT_VEL), 135 | ("iiwa_2/joint_7", PyBulletObservationType.JOINT_VEL), 136 | ("iiwa_2/striker_joint_1", PyBulletObservationType.JOINT_VEL), 137 | ("iiwa_2/striker_joint_2", PyBulletObservationType.JOINT_VEL), 138 | ("iiwa_2/striker_mallet_tip", PyBulletObservationType.LINK_POS), 139 | ("iiwa_2/striker_mallet_tip", PyBulletObservationType.LINK_LIN_VEL)] 140 | else: 141 | raise ValueError('n_agents should be 1 or 2') 142 | 143 | super().__init__(model_files, actuation_spec, observation_spec, gamma, 144 | horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps, 145 | debug_gui=debug_gui, size=(500, 500), distance=1.8) 146 | 147 | self.pino_model = pino.buildModelFromUrdf(self.agents[0]['urdf']) 148 | se_tip = pino.SE3(np.eye(3), np.array([0., 0., 0.585])) 149 | self.pino_model.addBodyFrame('striker_rod_tip', 7, se_tip, self.pino_model.nframes - 1) 150 | self.pino_data = self.pino_model.createData() 151 | self.frame_idx = self.pino_model.nframes - 1 152 | 153 | self._client.resetDebugVisualizerCamera(cameraDistance=2, cameraYaw=0.0, cameraPitch=-89.9, 154 | cameraTargetPosition=[0., 0., 0.]) 155 | self.env_spec = dict() 156 | self.env_spec['table'] = {"length": 1.96, "width": 1.02, "height": 0.0, "goal": 0.25, "urdf": table_file} 157 | self.env_spec['puck'] = {"radius": 0.03165, "urdf": puck_file} 158 | self.env_spec['mallet'] = {"radius": 0.05} 159 | self.env_spec['universal_height'] = 0.1505 160 | 161 | def _compute_action(self, state, action): 162 | if self.step_action_function is None: 163 | ctrl_action = action 164 | else: 165 | ctrl_action = self.step_action_function(state, action) 166 | 167 | joint_state = self.joints.positions(state)[:9] 168 | 169 | if self.isolated_joint_7: 170 | joint_7_des_pos = self._compute_joint_7(joint_state) 171 | ctrl_action = np.concatenate([ctrl_action, joint_7_des_pos]) 172 | 173 | joint_universal_pos = self._compute_universal_joint(joint_state) 174 | return np.concatenate([ctrl_action, joint_universal_pos]) 175 | 176 | def _simulation_pre_step(self): 177 | if self.env_noise: 178 | force = np.concatenate([np.random.randn(2), [0]]) * 0.0005 179 | self._client.applyExternalForce(self._model_map['puck']['id'], -1, force, [0., 0., 0.], 180 | self._client.WORLD_FRAME) 181 | 182 | def is_absorbing(self, state): 183 | boundary = np.array([self.env_spec['table']['length'], self.env_spec['table']['width']]) / 2 184 | puck_pos = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_POS)[:3] 185 | if np.any(np.abs(puck_pos[:2]) > boundary) or abs(puck_pos[2] - self.env_spec['table']['height']) > 0.1: 186 | return True 187 | 188 | boundary_mallet = boundary 189 | for agent in self.agents: 190 | mallet_pose = self.get_sim_state(state, agent['name'] + "/striker_mallet_tip", 191 | PyBulletObservationType.LINK_POS) 192 | if np.any(np.abs(mallet_pose[:2]) - boundary_mallet > 0.02): 193 | return True 194 | return False 195 | 196 | def _compute_joint_7(self, state): 197 | raise NotImplementedError 198 | 199 | def _compute_universal_joint(self, state): 200 | raise NotImplementedError 201 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/env_hitting.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from mushroom_rl.environments.pybullet import 
PyBulletObservationType 4 | from atacom.environments.iiwa_air_hockey.env_single import AirHockeySingle 5 | 6 | 7 | class AirHockeyHit(AirHockeySingle): 8 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1, 9 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control=True, 10 | random_init=False, step_action_function=None, action_penalty=1e-3, isolated_joint_7=False): 11 | self.hit_range = np.array([[-0.6, -0.2], [-0.4, 0.4]]) 12 | self.goal = np.array([0.98, 0]) 13 | self.has_hit = False 14 | self.r_hit = 0. 15 | self.vel_hit_x = 0. 16 | self.random_init = random_init 17 | self.action_penalty = action_penalty 18 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, 19 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 20 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, torque_control=torque_control, 21 | step_action_function=step_action_function, isolated_joint_7=isolated_joint_7) 22 | 23 | def setup(self, state): 24 | if self.random_init: 25 | puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0] 26 | else: 27 | puck_pos = np.mean(self.hit_range, axis=1) 28 | 29 | puck_pos = np.concatenate([puck_pos, [0.0]]) 30 | self.client.resetBasePositionAndOrientation(self._model_map['puck'], puck_pos, [0, 0, 0, 1.0]) 31 | 32 | for i, (model_id, joint_id, _) in enumerate(self._indexer.action_data[:7]): 33 | self.client.resetJointState(model_id, joint_id, self.init_state[i]) 34 | 35 | self.has_hit = False 36 | self.r_hit = 0. 37 | self.vel_hit_x = 0. 38 | 39 | def reward(self, state, action, next_state, absorbing): 40 | r = 0 41 | puck_pos = self.get_sim_state(next_state, "puck", PyBulletObservationType.BODY_POS)[:2] 42 | puck_vel = self.get_sim_state(next_state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2] 43 | if absorbing: 44 | if puck_pos[0] - self.env_spec['table']['length'] / 2 > 0 and \ 45 | np.abs(puck_pos[1]) - self.env_spec['table']['goal'] < 0: 46 | r = 80 47 | else: 48 | if not self.has_hit: 49 | ee_pos = self.get_sim_state(next_state, 50 | self.agents[0]['name'] + "/striker_mallet_tip", 51 | PyBulletObservationType.LINK_POS)[:2] 52 | dist_ee_puck = np.linalg.norm(puck_pos - ee_pos) 53 | # r = np.exp(-8 * (dist_ee_puck - 0.08)) 54 | 55 | vec_ee_puck = (puck_pos - ee_pos) / dist_ee_puck 56 | vec_puck_goal = (self.goal - puck_pos) / np.linalg.norm(self.goal - puck_pos) 57 | cos_ang = np.clip(vec_puck_goal @ vec_ee_puck, 0, 1) 58 | r = np.exp(-8 * (dist_ee_puck - 0.08)) * cos_ang 59 | self.r_hit = r 60 | else: 61 | # dist = np.linalg.norm(self.goal - puck_pos) 62 | # if puck_vel[0] > 0: 63 | # r_vel = np.abs(puck_vel[0]) 64 | # r = 0.5 * (np.exp(-10 * dist) + r_vel) + 0.5 65 | 66 | r = 1 + self.r_hit + self.vel_hit_x * 0.1 67 | 68 | r -= self.action_penalty * np.linalg.norm(action) 69 | return r 70 | 71 | def is_absorbing(self, state): 72 | if super().is_absorbing(state): 73 | return True 74 | if self.has_hit: 75 | puck_vel = self.get_sim_state(self._state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2] 76 | if np.linalg.norm(puck_vel) < 0.01: 77 | return True 78 | return False 79 | 80 | def _simulation_post_step(self): 81 | if not self.has_hit: 82 | puck_vel = self.get_sim_state(self._state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2] 83 | if np.linalg.norm(puck_vel) > 0.1: 84 | self.has_hit = True 85 | self.vel_hit_x = puck_vel[0] 86 | 87 | 88 | if __name__ == '__main__': 89 | env = AirHockeyHit(debug_gui=True, 
env_noise=False, obs_noise=False, obs_delay=False, n_intermediate_steps=4) 90 | 91 | R = 0. 92 | J = 0. 93 | gamma = 1. 94 | steps = 0 95 | while True: 96 | action = np.random.uniform(-1, 1, env.info.action_space.low.shape) * 8 97 | observation, reward, done, info = env.step(action) 98 | gamma *= env.info.gamma 99 | J += gamma * reward 100 | R += reward 101 | steps += 1 102 | if done or steps > env.info.horizon: 103 | print("J: ", J, " R: ", R) 104 | R = 0. 105 | J = 0. 106 | gamma = 1. 107 | steps = 0 108 | env.reset() 109 | time.sleep(1/60.) 110 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/env_single.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pinocchio as pino 3 | import pybullet_utils.transformations as transformations 4 | from mushroom_rl.core import MDPInfo 5 | from mushroom_rl.environments.pybullet import PyBulletObservationType 6 | from mushroom_rl.utils.spaces import Box 7 | from atacom.environments.iiwa_air_hockey.env_base import AirHockeyBase 8 | from atacom.environments.iiwa_air_hockey.kinematics import clik, fk 9 | 10 | 11 | class AirHockeySingle(AirHockeyBase): 12 | def __init__(self, gamma=0.99, horizon=500, timestep=1 / 240., n_intermediate_steps=1, debug_gui=False, 13 | env_noise=False, obs_noise=False, obs_delay=False, torque_control=True, step_action_function=None, 14 | isolated_joint_7=False): 15 | self.obs_prev = None 16 | 17 | if isolated_joint_7: 18 | self.n_ctrl_joints = 6 19 | else: 20 | self.n_ctrl_joints = 7 21 | 22 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, 23 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 24 | env_noise=env_noise, n_agents=1, obs_noise=obs_noise, obs_delay=obs_delay, 25 | torque_control=torque_control, step_action_function=step_action_function, 26 | isolated_joint_7=isolated_joint_7) 27 | 28 | self._compute_init_state() 29 | 30 | self._client.resetDebugVisualizerCamera(cameraDistance=1.5, cameraYaw=-00.0, cameraPitch=-45.0, 31 | cameraTargetPosition=[-0.5, 0., 0.]) 32 | 33 | self._change_dynamics() 34 | 35 | self._disable_collision() 36 | 37 | self.reset() 38 | 39 | def _compute_init_state(self): 40 | q = np.zeros(9) 41 | des_pos = pino.SE3(np.diag([-1., 1., -1.]), np.array([0.65, 0., self.env_spec['universal_height']])) 42 | 43 | success, self.init_state = clik(self.pino_model, self.pino_data, des_pos, q, self.frame_idx) 44 | assert success is True 45 | 46 | def _disable_collision(self): 47 | # disable the collision with left and right rim Because of the improper collision shape 48 | iiwa_links = ['iiwa_1/link_1', 'iiwa_1/link_2', 'iiwa_1/link_3', 'iiwa_1/link_4', 'iiwa_1/link_5', 49 | 'iiwa_1/link_6', 'iiwa_1/link_7', 'iiwa_1/link_ee', 'iiwa_1/striker_base', 50 | 'iiwa_1/striker_joint_link', 'iiwa_1/striker_mallet', 'iiwa_1/striker_mallet_tip'] 51 | table_rims = ['t_down_rim_l', 't_down_rim_r', 't_up_rim_r', 't_up_rim_l', 52 | 't_left_rim', 't_right_rim', 't_base', 't_up_rim_top', 't_down_rim_top', 't_base'] 53 | for iiwa_l in iiwa_links: 54 | for table_r in table_rims: 55 | self.client.setCollisionFilterPair(self._indexer.link_map[iiwa_l][0], 56 | self._indexer.link_map[table_r][0], 57 | self._indexer.link_map[iiwa_l][1], 58 | self._indexer.link_map[table_r][1], 0) 59 | 60 | self.client.setCollisionFilterPair(self._model_map['puck'], self._indexer.link_map['t_down_rim_top'][0], 61 | -1, self._indexer.link_map['t_down_rim_top'][1], 0) 62 | 
self.client.setCollisionFilterPair(self._model_map['puck'], self._indexer.link_map['t_up_rim_top'][0], 63 | -1, self._indexer.link_map['t_up_rim_top'][1], 0) 64 | 65 | def _change_dynamics(self): 66 | for i in range(12): 67 | self.client.changeDynamics(self._model_map['iiwa_1'], i, linearDamping=0., angularDamping=0.) 68 | 69 | def _modify_mdp_info(self, mdp_info): 70 | obs_idx = [0, 1, 2, 7, 8, 9, 13, 14, 15, 16, 17, 18, 22, 23, 24, 25, 26, 27] 71 | obs_low = mdp_info.observation_space.low[obs_idx] 72 | obs_high = mdp_info.observation_space.high[obs_idx] 73 | obs_low[0:3] = [-1, -0.5, -np.pi] 74 | obs_high[0:3] = [1, 0.5, np.pi] 75 | observation_space = Box(low=obs_low, high=obs_high) 76 | 77 | act_low = mdp_info.action_space.low[:self.n_ctrl_joints] 78 | act_high = mdp_info.action_space.high[:self.n_ctrl_joints] 79 | action_space = Box(low=act_low, high=act_high) 80 | return MDPInfo(observation_space, action_space, mdp_info.gamma, mdp_info.horizon) 81 | 82 | def _create_observation(self, state): 83 | puck_pose = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_POS) 84 | puck_pose_2d = self._puck_2d_in_robot_frame(puck_pose, self.agents[0]['frame'], type='pose') 85 | 86 | robot_pos = list() 87 | robot_vel = list() 88 | for i in range(6): 89 | robot_pos.append(self.get_sim_state(state, 90 | self.agents[0]['name'] + "/joint_"+str(i+1), 91 | PyBulletObservationType.JOINT_POS)) 92 | robot_vel.append(self.get_sim_state(state, 93 | self.agents[0]['name'] + "/joint_" + str(i + 1), 94 | PyBulletObservationType.JOINT_VEL)) 95 | if not self.isolated_joint_7: 96 | robot_pos.append(self.get_sim_state(state, 97 | self.agents[0]['name'] + "/joint_" + str(7), 98 | PyBulletObservationType.JOINT_POS)) 99 | robot_vel.append(self.get_sim_state(state, 100 | self.agents[0]['name'] + "/joint_" + str(7), 101 | PyBulletObservationType.JOINT_VEL)) 102 | robot_pos = np.asarray(robot_pos).flatten() 103 | robot_vel = np.asarray(robot_vel).flatten() 104 | 105 | if self.obs_noise: 106 | puck_pose_2d[:2] += np.random.randn(2) * 0.001 107 | puck_pose_2d[2] += np.random.randn(1) * 0.001 108 | 109 | puck_lin_vel = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_LIN_VEL) 110 | puck_ang_vel = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_ANG_VEL) 111 | puck_vel_2d = self._puck_2d_in_robot_frame(np.concatenate([puck_lin_vel, puck_ang_vel]), 112 | self.agents[0]['frame'], type='vel') 113 | 114 | if self.obs_delay: 115 | alpha = 0.5 116 | puck_vel_2d = alpha * puck_vel_2d + (1 - alpha) * self.obs_prev[3:6] 117 | robot_vel = alpha * robot_vel + (1 - alpha) * self.obs_prev[9:12] 118 | 119 | self.obs_prev = np.concatenate([puck_pose_2d, puck_vel_2d, robot_pos, robot_vel]) 120 | return self.obs_prev 121 | 122 | def _puck_2d_in_robot_frame(self, puck_in, robot_frame, type='pose'): 123 | if type == 'pose': 124 | puck_frame = transformations.translation_matrix(puck_in[:3]) 125 | puck_frame = puck_frame @ transformations.quaternion_matrix(puck_in[3:]) 126 | 127 | frame_target = transformations.inverse_matrix(robot_frame) @ puck_frame 128 | puck_translate = transformations.translation_from_matrix(frame_target) 129 | _, _, puck_euler_yaw = transformations.euler_from_matrix(frame_target) 130 | 131 | return np.concatenate([puck_translate[:2], [puck_euler_yaw]]) 132 | if type == 'vel': 133 | rot_mat = robot_frame[:3, :3] 134 | vec_lin = rot_mat.T @ puck_in[:3] 135 | return np.concatenate([vec_lin[:2], puck_in[5:6]]) 136 | 137 | def _compute_joint_7(self, joint_state): 138 | q_cur = 
joint_state.copy() 139 | q_cur_7 = q_cur[6] 140 | q_cur[6] = 0. 141 | 142 | f_cur = fk(self.pino_model, self.pino_data, q_cur, self.frame_idx) 143 | z_axis = np.array([0., 0., -1.]) 144 | 145 | y_des = np.cross(z_axis, f_cur.rotation[:, 2]) 146 | y_des_norm = np.linalg.norm(y_des) 147 | if y_des_norm > 1e-2: 148 | y_des = y_des / y_des_norm 149 | else: 150 | y_des = f_cur.rotation[:, 2] 151 | 152 | target = np.arccos(f_cur.rotation[:, 1].dot(y_des)) 153 | 154 | axis = np.cross(f_cur.rotation[:, 1], y_des) 155 | axis_norm = np.linalg.norm(axis) 156 | if axis_norm > 1e-2: 157 | axis = axis / axis_norm 158 | else: 159 | axis = np.array([0., 0., 1.]) 160 | 161 | target = target * axis.dot(f_cur.rotation[:, 2]) 162 | 163 | if target - q_cur_7 > np.pi / 2: 164 | target -= np.pi 165 | elif target - q_cur_7 < -np.pi / 2: 166 | target += np.pi 167 | 168 | return np.atleast_1d(target) 169 | 170 | def _compute_universal_joint(self, joint_state): 171 | rot_mat = transformations.quaternion_matrix( 172 | self.client.getLinkState(*self._indexer.link_map['iiwa_1/link_ee'])[1]) 173 | 174 | q1 = np.arccos(rot_mat[:3, 2].dot(np.array((0., 0., -1)))) 175 | q2 = 0 176 | 177 | axis = np.cross(rot_mat[:3, 2], np.array([0., 0., -1.])) 178 | axis_norm = np.linalg.norm(axis) 179 | if axis_norm > 1e-2: 180 | axis = axis / axis_norm 181 | else: 182 | axis = np.array([0., 0., 1.]) 183 | q1 = q1 * axis.dot(rot_mat[:3, 1]) 184 | 185 | return np.array([q1, q2]) 186 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/iiwa_air_hockey_rmp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pinocchio as pino 3 | from mushroom_rl.utils.spaces import * 4 | from scipy.linalg import solve 5 | from atacom.environments.iiwa_air_hockey.env_hitting import AirHockeyHit 6 | 7 | 8 | class AirHockeyIiwaRmp: 9 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4, 10 | acc_max=10, debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, random_init=False, 11 | action_penalty=1e-3, Kq=10.): 12 | if task == 'H': 13 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep, 14 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 15 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 16 | torque_control='torque', random_init=random_init, 17 | step_action_function=self.step_action_function, 18 | action_penalty=action_penalty, isolated_joint_7=True) 19 | if task == 'D': 20 | raise NotImplementedError 21 | 22 | self.observation = None 23 | 24 | self.env = base_env 25 | self.env.agents[0]['constr_a'] = 1.0 26 | self.env.agents[0]['constr_b'] = 1.0 27 | self.env.agents[0]['constr_bound_x_l'] = 0.58 28 | self.env.agents[0]['constr_bound_y_l'] = -0.46 29 | self.env.agents[0]['constr_bound_y_u'] = 0.46 30 | 31 | self.dim_q = 6 32 | self._mdp_info = self.env.info.copy() 33 | self._mdp_info.action_space = Box(low=-np.ones(self.dim_q), high=np.ones(self.dim_q)) 34 | 35 | self.constr_logs = list() 36 | 37 | if np.isscalar(acc_max): 38 | self.acc_max = np.ones(self.dim_q) * acc_max 39 | else: 40 | self.acc_max = acc_max 41 | assert np.shape(self.acc_max)[0] == self.dim_q 42 | 43 | if np.isscalar(Kq): 44 | self.K_q = np.ones(self.dim_q) * Kq 45 | else: 46 | self.K_q = Kq 47 | assert np.shape(self.K_q)[0] == self.dim_q 48 | 49 | @property 50 | def info(self): 51 | """ 52 | Returns: 53 | An object containing the info of the environment. 
54 | """ 55 | return self._mdp_info 56 | 57 | def reset(self, state=None): 58 | return self.env.reset(state) 59 | 60 | def step(self, action): 61 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high) 62 | self.observation, reward, absorbing, _ = self.env.step(action) 63 | q = self.env.joints.positions(self.env._state) 64 | dq = self.env.joints.positions(self.env._state) 65 | self._update_constraint_stats(q, dq) 66 | return self.observation.copy(), reward, absorbing, _ 67 | 68 | def stop(self): 69 | self.env.stop() 70 | 71 | def step_action_function(self, state, action): 72 | q = self.env.joints.positions(state) 73 | dq = self.env.joints.velocities(state) 74 | pino.forwardKinematics(self.env.pino_model, self.env.pino_data, q, dq) 75 | pino.computeJointJacobians(self.env.pino_model, self.env.pino_data, q) 76 | pino.updateFramePlacements(self.env.pino_model, self.env.pino_data) 77 | 78 | ddq_ee = self.rmp_ddq_ee() 79 | ddq_elbow = self.rmp_ddq_elbow() 80 | ddq_wrist = self.rmp_ddq_wrist() 81 | 82 | ddq_joints = self.rmp_joint_limit(q, dq) 83 | 84 | ddq_total = np.zeros(self.env.pino_model.nq) 85 | ddq_total[:self.dim_q] = ddq_ee + ddq_elbow + ddq_wrist + ddq_joints + action * self.acc_max 86 | ddq = self.acc_truncation(dq, ddq_total) 87 | tau = pino.rnea(self.env.pino_model, self.env.pino_data, q, dq, ddq) 88 | return tau[:6] 89 | 90 | def rmp_ddq_ee(self): 91 | frame_id = self.env.frame_idx 92 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id, 93 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED) 94 | J = J_frame[:3, :6] 95 | link_pos = self.env.pino_data.oMf[frame_id].translation 96 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id, 97 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear 98 | cart_acc = self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_x_l'], idx=0, b_type='l', 99 | eta_rep=0.1, v_rep=10) + \ 100 | self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_y_l'], idx=1, b_type='l', 101 | eta_rep=0.1, v_rep=10) + \ 102 | self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_y_u'], idx=1, b_type='u', 103 | eta_rep=0.1, v_rep=10) + \ 104 | self.f_plane(link_pos, link_vel, self.env.env_spec['universal_height']) 105 | 106 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc) 107 | 108 | def rmp_ddq_elbow(self): 109 | frame_id = self.env.pino_model.getFrameId("iiwa_1/link_4") 110 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id, 111 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED) 112 | J = J_frame[:3, :6] 113 | link_pos = self.env.pino_data.oMf[frame_id].translation 114 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id, 115 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear 116 | cart_acc = self.f_bound(link_pos, link_vel, 0.36, idx=2, b_type='l', 117 | eta_rep=0.5, v_rep=10) 118 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc) 119 | 120 | def rmp_ddq_wrist(self): 121 | frame_id = self.env.pino_model.getFrameId("iiwa_1/link_6") 122 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id, 123 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED) 124 | J = J_frame[:3, :6] 125 | link_pos = self.env.pino_data.oMf[frame_id].translation 126 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id, 127 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear 128 | cart_acc = self.f_bound(link_pos, link_vel, 0.25, idx=2, b_type='l', 129 | eta_rep=0.01, 
v_rep=10) 130 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc) 131 | 132 | def rmp_joint_limit(self, q, dq): 133 | sigma = 10 134 | 135 | s = (q - self.env.pino_model.lowerPositionLimit) / \ 136 | (self.env.pino_model.upperPositionLimit - self.env.pino_model.lowerPositionLimit) 137 | 138 | d = 4 * s * (1 - s) 139 | 140 | alpha_u = 1 - np.exp(- (np.maximum(dq, 0.) / sigma) ** 2 / 2) 141 | alpha_l = 1 - np.exp(- (np.minimum(dq, 0.) / sigma) ** 2 / 2) 142 | 143 | b = s * (alpha_u * d + (1 - alpha_u)) + (1 - s) * (alpha_l * d + (1 - alpha_l)) 144 | 145 | a = b ** (-2) 146 | return a[:6] 147 | 148 | def f_bound(self, x, dx, bound, idx, b_type='u', eta_rep=5., v_rep=1., eta_damp=None): 149 | ddx = np.zeros_like(x) 150 | if eta_damp is None: 151 | eta_damp = np.sqrt(eta_rep) 152 | if b_type == 'u': 153 | d = np.maximum(bound - x[idx], 0.) ** 2 154 | ddx[idx] = -eta_rep * np.exp(-d / v_rep) - eta_damp / (d + 1e-6) * np.maximum(dx[idx], 0) 155 | elif b_type == 'l': 156 | d = np.maximum(x[idx] - bound, 0.) ** 2 157 | ddx[idx] = eta_rep * np.exp(-d / v_rep) - eta_damp / (d + 1e-6) * np.minimum(dx[idx], 0) 158 | return ddx 159 | 160 | def f_plane(self, x, dx, height): 161 | ddx = np.zeros_like(x) 162 | k = 1000 163 | d = np.sqrt(k) 164 | ddx[2] = k * (height - x[2]) - d * dx[2] 165 | return ddx 166 | 167 | def acc_truncation(self, dq, ddq): 168 | acc_u = np.maximum(np.minimum(self.acc_max, 169 | -self.K_q * (dq[:self.dim_q] - self.env.pino_model.velocityLimit[:self.dim_q])), 170 | -self.acc_max) 171 | acc_l = np.minimum(np.maximum(-self.acc_max, 172 | -self.K_q * (dq[:self.dim_q] + self.env.pino_model.velocityLimit[:self.dim_q])), 173 | self.acc_max) 174 | ddq[:self.dim_q] = np.clip(ddq[:self.dim_q], acc_l, acc_u) 175 | return ddq 176 | 177 | def _update_constraint_stats(self, q, dq): 178 | c_i = self._compute_c(q, dq) 179 | c_dq_i = (np.abs(dq) - self.env.pino_model.velocityLimit)[:self.dim_q] 180 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)]) 181 | 182 | def get_constraints_logs(self): 183 | constr_logs = np.array(self.constr_logs) 184 | c_avg = np.mean(constr_logs[:, 0]) 185 | c_max = np.max(constr_logs[:, 0]) 186 | c_dq_max = np.max(constr_logs[:, 1]) 187 | self.constr_logs.clear() 188 | return c_avg, c_max, c_dq_max 189 | 190 | def _compute_c(self, q, dq): 191 | pino.forwardKinematics(self.env.pino_model, self.env.pino_data, q, dq) 192 | pino.computeJointJacobians(self.env.pino_model, self.env.pino_data, q) 193 | pino.updateFramePlacements(self.env.pino_model, self.env.pino_data) 194 | 195 | ee_pos = self.env.pino_data.oMf[self.env.frame_idx].translation 196 | elbow_pos = self.env.pino_data.oMf[self.env.pino_model.getFrameId("iiwa_1/link_4")].translation 197 | wrist_pos = self.env.pino_data.oMf[self.env.pino_model.getFrameId("iiwa_1/link_6")].translation 198 | 199 | c = [] 200 | c.append(np.abs(ee_pos[2] - self.env.env_spec['universal_height'])) 201 | c.append(-ee_pos[0] + self.env.agents[0]['constr_bound_x_l']) 202 | c.append(-ee_pos[1] + self.env.agents[0]['constr_bound_y_l']) 203 | c.append(ee_pos[1] - self.env.agents[0]['constr_bound_y_u']) 204 | c.append(- elbow_pos[2] + 0.36) 205 | c.append(- wrist_pos[2] + 0.25) 206 | c.extend(- q[:self.env.n_ctrl_joints] + self.env.pino_model.lowerPositionLimit[:self.env.n_ctrl_joints]) 207 | c.extend(q[:self.env.n_ctrl_joints] - self.env.pino_model.upperPositionLimit[:self.env.n_ctrl_joints]) 208 | return np.array(c) 209 | 210 | 211 | if __name__ == '__main__': 212 | env = AirHockeyIiwaRmp(debug_gui=True) 213 | 214 | env.reset() 215 | 
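# smoke test with a random policy: the RMP terms injected through step_action_function
# are expected to keep the end-effector, elbow, wrist and joints within their bounds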
for i in range(10000): 216 | action = np.random.randn(env.dim_q) 217 | _, _, absorb, _ = env.step(action) 218 | if absorb: 219 | env.reset() 220 | time.sleep(1 / 240.) 221 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/iiwa_hit_atacom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pinocchio as pino 4 | import matplotlib.pyplot as plt 5 | from atacom.environments.iiwa_air_hockey.env_hitting import AirHockeyHit 6 | from atacom.atacom import AtacomEnvWrapper 7 | from atacom.constraints import ViabilityConstraint, ConstraintsSet 8 | 9 | 10 | class AirHockeyIiwaAtacom(AtacomEnvWrapper): 11 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4, 12 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, Kc=240., random_init=False, 13 | action_penalty=1e-3): 14 | if task == 'H': 15 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep, 16 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 17 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 18 | torque_control='torque', random_init=random_init, 19 | action_penalty=action_penalty, isolated_joint_7=True) 20 | if task == 'D': 21 | raise NotImplementedError 22 | 23 | dim_q = base_env.n_ctrl_joints 24 | ee_pos_f = ViabilityConstraint(dim_q=dim_q, dim_out=1, fun=self.ee_pose_f, J=self.ee_pose_J_f, 25 | b=self.ee_pos_b_f, K=0.1) 26 | f = ConstraintsSet(dim_q) 27 | f.add_constraint(ee_pos_f) 28 | 29 | cart_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=5, fun=self.ee_pos_g, J=self.ee_pos_J_g, 30 | b=self.ee_pos_b_g, K=0.5) 31 | joint_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=dim_q, fun=self.joint_pos_g, J=self.joint_pos_J_g, 32 | b=self.joint_pos_b_g, K=1) 33 | g = ConstraintsSet(dim_q) 34 | g.add_constraint(cart_pos_g) 35 | g.add_constraint(joint_pos_g) 36 | 37 | acc_max = np.ones(base_env.n_ctrl_joints) * 10 38 | vel_max = base_env.joints.velocity_limits()[:base_env.n_ctrl_joints] 39 | super().__init__(base_env, dim_q, f=f, g=g, Kc=Kc, vel_max=vel_max, acc_max=acc_max, Kq=4 * acc_max / vel_max, 40 | time_step=timestep) 41 | 42 | self.pino_model = self.env.pino_model 43 | self.pino_data = self.env.pino_data 44 | self.frame_idx = self.env.frame_idx 45 | self.frame_idx_4 = 12 46 | self.frame_idx_7 = 18 47 | 48 | for i in range(self.pino_model.nq): 49 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[i+1]], 50 | maxJointVelocity=self.pino_model.velocityLimit[i] * 1.5) 51 | 52 | def _get_q(self, state): 53 | return state[-2 * self.env.n_ctrl_joints:-self.env.n_ctrl_joints] 54 | 55 | def _get_dq(self, state): 56 | return state[-self.env.n_ctrl_joints:] 57 | 58 | def acc_to_ctrl_action(self, ddq): 59 | ddq = self._get_pino_value(ddq).tolist() 60 | sim_state = self.env._indexer.create_sim_state() 61 | q = self.env.joints.positions(sim_state).tolist() 62 | dq = self.env.joints.velocities(sim_state).tolist() 63 | return self.env.client.calculateInverseDynamics(2, q, dq, ddq)[:self.env.n_ctrl_joints] 64 | 65 | def _get_pino_value(self, q): 66 | ret = np.zeros(9) 67 | ret[:q.shape[0]] = q 68 | return ret 69 | 70 | def ee_pose_f(self, q): 71 | q = self._get_pino_value(q) 72 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q) 73 | ee_pos_z = self.pino_data.oMf[self.frame_idx].translation[2] 74 | return np.atleast_1d(ee_pos_z - 
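# equality constraint f(q) = 0: the mallet tip stays on the table plane z = universal_height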
self.env.env_spec['universal_height']) 75 | 76 | def ee_pose_J_f(self, q): 77 | q = self._get_pino_value(q) 78 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q) 79 | ee_jac = pino.computeFrameJacobian(self.pino_model, self.pino_data, q, 80 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints] 81 | J_pos = ee_jac[2] 82 | return np.atleast_2d(J_pos) 83 | 84 | def ee_pos_b_f(self, q, dq): 85 | q = self._get_pino_value(q) 86 | dq = self._get_pino_value(dq) 87 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq) 88 | acc = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx, 89 | pino.LOCAL_WORLD_ALIGNED).vector 90 | b_pos = acc[2] 91 | return np.atleast_1d(b_pos) 92 | 93 | def ee_pos_g(self, q): 94 | q = self._get_pino_value(q) 95 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q) 96 | ee_pos = self.pino_data.oMf[self.frame_idx].translation[:2] 97 | ee_pos_world = ee_pos + self.env.agents[0]['frame'][:2, 3] 98 | g_1 = - ee_pos_world[0] - (self.env.env_spec['table']['length'] / 2 - self.env.env_spec['mallet']['radius']) 99 | g_2 = - ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius']) 100 | g_3 = ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius']) 101 | 102 | ee_pos_4 = self.pino_data.oMf[self.frame_idx_4].translation 103 | ee_pos_7 = self.pino_data.oMf[self.frame_idx_7].translation 104 | g_4 = -ee_pos_4[2] + 0.36 105 | g_5 = -ee_pos_7[2] + 0.25 106 | return np.array([g_1, g_2, g_3, g_4, g_5]) 107 | 108 | def ee_pos_J_g(self, q): 109 | q = self._get_pino_value(q) 110 | pino.computeJointJacobians(self.pino_model, self.pino_data, q) 111 | jac_ee = pino.getFrameJacobian(self.pino_model, self.pino_data, 112 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints] 113 | jac_4 = pino.getFrameJacobian(self.pino_model, self.pino_data, 114 | self.frame_idx_4, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints] 115 | jac_7 = pino.getFrameJacobian(self.pino_model, self.pino_data, 116 | self.frame_idx_7, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints] 117 | return np.vstack([-jac_ee[0], -jac_ee[1], jac_ee[1], -jac_4[2], -jac_7[2]]) 118 | 119 | def ee_pos_b_g(self, q, dq): 120 | q = self._get_pino_value(q) 121 | dq = self._get_pino_value(dq) 122 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq) 123 | acc_ee = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx, 124 | pino.LOCAL_WORLD_ALIGNED).vector 125 | acc_4 = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx_4, 126 | pino.LOCAL_WORLD_ALIGNED).vector 127 | acc_7 = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx_7, 128 | pino.LOCAL_WORLD_ALIGNED).vector 129 | 130 | return np.array([-acc_ee[0], -acc_ee[1], acc_ee[1], -acc_4[2], -acc_7[2]]) 131 | 132 | def joint_pos_g(self, q): 133 | return np.array(q ** 2 - self.pino_model.upperPositionLimit[:self.env.n_ctrl_joints] ** 2) 134 | 135 | def joint_pos_J_g(self, q): 136 | return 2 * np.diag(q) 137 | 138 | def joint_pos_b_g(self, q, dq): 139 | return 2 * dq ** 2 140 | 141 | def plot_constraints(self, dataset, save_dir="", suffix="", state_norm_processor=None): 142 | state_list = list() 143 | i = 0 144 | 145 | if suffix != '': 146 | suffix = suffix + "_" 147 | 148 | for data in dataset: 149 | state = data[0] 150 | if state_norm_processor is not None: 151 | state[state_norm_processor._obs_mask] 
= (state * state_norm_processor._obs_delta + \ 152 | state_norm_processor._obs_mean)[state_norm_processor._obs_mask] 153 | state_list.append(state) 154 | if data[-1]: 155 | i += 1 156 | state_hist = np.array(state_list) 157 | 158 | if not os.path.exists(save_dir): 159 | os.makedirs(save_dir) 160 | 161 | ee_pos_list = list() 162 | for state_i in state_hist: 163 | q = np.zeros(9) 164 | q[:6] = state_i[6:12] 165 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q) 166 | ee_pos_list.append(self.pino_data.oMf[-1].translation[:2] + self.env.agents[0]['frame'][:2, 3]) 167 | 168 | ee_pos_list = np.array(ee_pos_list) 169 | fig1, axes1 = plt.subplots(1, figsize=(10, 10)) 170 | axes1.plot(ee_pos_list[:, 0], ee_pos_list[:, 1], label='position') 171 | axes1.plot([0.0, -0.91, -0.91, 0.0], [-0.45, -0.45, 0.45, 0.45], label='boundary', c='k', lw='5') 172 | axes1.set_aspect(1.0) 173 | axes1.set_xlim(-1, 0) 174 | axes1.set_ylim(-0.5, 0.5) 175 | axes1.legend(loc='upper right') 176 | axes1.set_title('EndEffector') 177 | axes1.legend(loc='center right') 178 | file1 = "EndEffector_" + suffix + str(i) + ".pdf" 179 | plt.savefig(os.path.join(save_dir, file1)) 180 | plt.close(fig1) 181 | 182 | fig2, axes2 = plt.subplots(2, 3, figsize=(21, 8), sharex=True, sharey=True) 183 | for j in range(6): 184 | axes2[j // 3, j % 3].plot(state_hist[:, 6 + j], lw=3, color='tab:blue') 185 | axes2[j // 3, j % 3].plot(state_hist[:, 12 + j], lw=3, color='tab:orange') 186 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[j]] * 2, 187 | lw=3, c='tab:red', ls='--') 188 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[j]] * 2, 189 | lw=3, c='tab:red', ls='--') 190 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[j]] * 2, 191 | lw=3, c='tab:pink', ls=':') 192 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[j]] * 2, 193 | lw=3, c='tab:pink', ls=':') 194 | 195 | axes2[j // 3, j % 3].set_title('Joint ' + str(j + 1)) 196 | 197 | axes2[0, 0].plot([], lw=3, color='tab:blue', label='position') 198 | axes2[0, 0].plot([], lw=3, color='tab:red', ls='--', label='position limit') 199 | axes2[0, 0].plot([], lw=3, color='tab:orange', label='velocity') 200 | axes2[0, 0].plot([], lw=3, color='tab:pink', ls=':', label='velocity limit') 201 | fig2.legend(ncol=4, loc='lower center') 202 | 203 | file2 = "JointProfile_" + suffix + str(i) + ".pdf" 204 | plt.savefig(os.path.join(save_dir, file2)) 205 | plt.close(fig2) 206 | 207 | state_list = list() 208 | 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/kinematics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pinocchio as pino 3 | from numpy.linalg import norm, solve 4 | 5 | IT_MAX = 1000 6 | eps = 1e-4 7 | DT = 1e-1 8 | damp = 1e-12 9 | 10 | 11 | def clik(model, data, fDes, q, idx): 12 | q_cur = q.copy() 13 | i = 0 14 | while True: 15 | pino.forwardKinematics(model, data, q_cur) 16 | pino.updateFramePlacements(model, data) 17 | 18 | dMi = fDes.actInv(data.oMf[idx]) 19 | lin_err = -dMi.translation 20 | ang_err = pino.log3(dMi.rotation) 21 | err = np.concatenate([lin_err, ang_err]) 22 | if norm(err) < eps: 23 | success = True 24 | break 25 | if i >= IT_MAX: 26 | success = False 27 | break 28 | J = pino.computeFrameJacobian(model, data, q_cur, idx, pino.LOCAL_WORLD_ALIGNED) 29 | v = 
- J.T.dot(solve(J.dot(J.T) + damp * np.eye(6), err)) 30 | q_cur = pino.integrate(model, q_cur, v*DT) 31 | i += 1 32 | idx = np.where(q_cur > model.upperPositionLimit) 33 | q_cur[idx] -= np.pi * 2 34 | idx = np.where(q_cur < model.lowerPositionLimit) 35 | q_cur[idx] += np.pi * 2 36 | if not (np.all(model.lowerPositionLimit < q_cur) and np.all(q_cur < model.upperPositionLimit)): 37 | return False, q_cur 38 | return success, q_cur 39 | 40 | 41 | def fk(model, data, q, idx): 42 | pino.forwardKinematics(model, data, q) 43 | pino.updateFramePlacement(model, data, idx) 44 | return data.oMf[idx] 45 | -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/iiwa_1.urdf: -------------------------------------------------------------------------------- (URDF XML content not preserved in this text export) -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/iiwa_2.urdf: -------------------------------------------------------------------------------- (URDF XML content not preserved in this text export) -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_0.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_1.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_2.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_3.stl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_3.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_4.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_5.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_6.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7_old.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7_old.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_arm_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_arm_collision.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_collision.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_short_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_short_collision.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_arm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_arm.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet_short.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet_short.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_0.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_1.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_2.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_3.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_4.stl -------------------------------------------------------------------------------- 
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_5.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_6.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7.stl -------------------------------------------------------------------------------- /atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7_old.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7_old.stl -------------------------------------------------------------------------------- /atacom/environments/planar_air_hockey/__init__.py: -------------------------------------------------------------------------------- 1 | from .atacom_air_hockey import AirHockeyPlanarAtacom 2 | from .unconstrained_air_hockey import AirHockeyHitUnconstrained, AirHockeyDefendUnconstrained 3 | -------------------------------------------------------------------------------- /atacom/environments/planar_air_hockey/atacom_air_hockey.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pinocchio as pino 4 | import matplotlib.pyplot as plt 5 | 6 | from atacom.atacom import AtacomEnvWrapper 7 | from atacom.constraints import ViabilityConstraint, ConstraintsSet 8 | from mushroom_rl.environments.pybullet_envs.air_hockey import AirHockeyHit, AirHockeyDefend 9 | 10 | 11 | class AirHockeyPlanarAtacom(AtacomEnvWrapper): 12 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4, 13 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, Kc=240., random_init=False, 14 | action_penalty=1e-3): 15 | if task == 'H': 16 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep, 17 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 18 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 19 | torque_control=True, random_init=random_init, 20 | action_penalty=action_penalty) 21 | if task == 'D': 22 | base_env = AirHockeyDefend(gamma=gamma, horizon=horizon, timestep=timestep, 23 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 24 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 25 | torque_control=True, random_init=random_init, 26 | action_penalty=action_penalty) 27 | 28 | dim_q = 3 29 | cart_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=3, fun=self.cart_pos_g, J=self.cart_pos_J_g, 30 | b=self.cart_pos_b_g, K=0.5) 31 | joint_pos_g = 
ViabilityConstraint(dim_q=dim_q, dim_out=3, fun=self.joint_pos_g, J=self.joint_pos_J_g, 32 | b=self.joint_pos_b_g, K=1.0) 33 | # joint_vel_g = StateVelocityConstraint(dim_q=dim_q, dim_out=3, fun=self.joint_vel_g, A=self.joint_vel_A_g, 34 | # b=self.joint_vel_b_g, margin=0.0) 35 | g = ConstraintsSet(dim_q) 36 | g.add_constraint(cart_pos_g) 37 | g.add_constraint(joint_pos_g) 38 | # g.add_constraint(joint_vel_g) 39 | 40 | acc_max = np.ones(3) * 10 41 | vel_max = base_env.joints.velocity_limits() 42 | super().__init__(base_env, 3, f=None, g=g, Kc=Kc, vel_max=vel_max, acc_max=acc_max, Kq=2 * acc_max / vel_max, 43 | time_step=timestep) 44 | 45 | self.pino_model = pino.buildModelFromUrdf(self.env.agents[0]['urdf']) 46 | self.pino_data = self.pino_model.createData() 47 | self.frame_idx = self.pino_model.nframes - 1 48 | 49 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[1]], 50 | maxJointVelocity=self.pino_model.velocityLimit[0] * 1.5) 51 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[2]], 52 | maxJointVelocity=self.pino_model.velocityLimit[1] * 1.5) 53 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[3]], 54 | maxJointVelocity=self.pino_model.velocityLimit[2] * 1.5) 55 | 56 | robot_links = ['planar_robot_1/link_striker_hand', 'planar_robot_1/link_striker_ee'] 57 | table_rims = ['t_down_rim_l', 't_down_rim_r', 't_up_rim_r', 't_up_rim_l', 58 | 't_left_rim', 't_right_rim', 't_base', 't_up_rim_top', 't_down_rim_top', 't_base'] 59 | for iiwa_l in robot_links: 60 | for table_r in table_rims: 61 | self.env.client.setCollisionFilterPair(self.env._indexer.link_map[iiwa_l][0], 62 | self.env._indexer.link_map[table_r][0], 63 | self.env._indexer.link_map[iiwa_l][1], 64 | self.env._indexer.link_map[table_r][1], 0) 65 | 66 | def _get_q(self, state): 67 | return state[6:9] 68 | 69 | def _get_dq(self, state): 70 | return state[9:12] 71 | 72 | def acc_to_ctrl_action(self, ddq): 73 | q = self.q.tolist() 74 | dq = self.dq.tolist() 75 | ddq = ddq.tolist() 76 | return self.env.client.calculateInverseDynamics(self.env._model_map['planar_robot_1'], q, dq, ddq) 77 | 78 | def cart_pos_g(self, q): 79 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q) 80 | ee_pos = self.pino_data.oMf[-1].translation[:2] 81 | ee_pos_world = ee_pos + self.env.agents[0]['frame'][:2, 3] 82 | g_1 = - ee_pos_world[0] - (self.env.env_spec['table']['length'] / 2 - self.env.env_spec['mallet']['radius']) 83 | g_2 = - ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius']) 84 | g_3 = ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius']) 85 | return np.array([g_1, g_2, g_3]) 86 | 87 | def cart_pos_J_g(self, q): 88 | ee_jac = pino.computeFrameJacobian(self.pino_model, self.pino_data, q, 89 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:2] 90 | J_c = np.array([[-1., 0.], [0., -1.], [0., 1.]]) 91 | return J_c @ ee_jac 92 | 93 | def cart_pos_b_g(self, q, dq): 94 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq) 95 | acc = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.pino_model.nframes - 1, 96 | pino.LOCAL_WORLD_ALIGNED).vector 97 | J_c = np.array([[-1., 0.], [0., -1.], [0., 1.]]) 98 | return J_c @ acc[:2] 99 | 100 | def joint_pos_g(self, q): 101 | return np.array(q ** 2 - self.pino_model.upperPositionLimit ** 2) 102 | 103 | def joint_pos_J_g(self, q): 104 | return 2 * np.diag(q) 105 | 106 | def 
joint_pos_b_g(self, q, dq): 107 | return 2 * dq ** 2 108 | 109 | def joint_vel_g(self, q, dq): 110 | return np.array([dq ** 2 - self.pino_model.velocityLimit ** 2]) 111 | 112 | def joint_vel_A_g(self, q, dq): 113 | return 2 * np.diag(dq) 114 | 115 | def joint_vel_b_g(self, q, dq): 116 | return np.zeros(3) 117 | 118 | def plot_constraints(self, dataset, save_dir="", suffix="", state_norm_processor=None): 119 | state_list = list() 120 | i = 0 121 | 122 | if suffix != '': 123 | suffix = suffix + "_" 124 | 125 | for data in dataset: 126 | state = data[0] 127 | if state_norm_processor is not None: 128 | state[state_norm_processor._obs_mask] = (state * state_norm_processor._obs_delta + \ 129 | state_norm_processor._obs_mean)[state_norm_processor._obs_mask] 130 | state_list.append(state) 131 | if data[-1]: 132 | i += 1 133 | state_hist = np.array(state_list) 134 | 135 | if not os.path.exists(save_dir): 136 | os.makedirs(save_dir) 137 | 138 | ee_pos_list = list() 139 | for state_i in state_hist: 140 | pino.framesForwardKinematics(self.pino_model, self.pino_data, state_i[6:9]) 141 | ee_pos_list.append(self.pino_data.oMf[-1].translation[:2] + self.env.agents[0]['frame'][:2, 3]) 142 | 143 | ee_pos_list = np.array(ee_pos_list) 144 | fig1, axes1 = plt.subplots(1, figsize=(10, 10)) 145 | axes1.plot(ee_pos_list[:, 0], ee_pos_list[:, 1], label='position') 146 | axes1.plot([0.0, -0.91, -0.91, 0.0], [-0.45, -0.45, 0.45, 0.45], label='boundary', c='k', lw='5') 147 | axes1.set_aspect(1.0) 148 | axes1.set_xlim(-1, 0) 149 | axes1.set_ylim(-0.5, 0.5) 150 | axes1.legend(loc='upper right') 151 | axes1.set_title('EndEffector') 152 | axes1.legend(loc='center right') 153 | file1 = "EndEffector_" + suffix + str(i) + ".pdf" 154 | plt.savefig(os.path.join(save_dir, file1)) 155 | plt.close(fig1) 156 | 157 | fig2, axes2 = plt.subplots(1, 3, sharey=True, figsize=(21, 8)) 158 | axes2[0].plot(state_hist[:, 6], label='position', c='tab:blue') 159 | axes2[1].plot(state_hist[:, 7], c='tab:blue') 160 | axes2[2].plot(state_hist[:, 8], c='tab:blue') 161 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[0]] * 2, 162 | label='position limit', c='tab:red', ls='--') 163 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[1]] * 2, c='tab:red', 164 | ls='--') 165 | axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[2]] * 2, c='tab:red', 166 | ls='--') 167 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[0]] * 2, c='tab:red', 168 | ls='--') 169 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[1]] * 2, c='tab:red', 170 | ls='--') 171 | axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[2]] * 2, c='tab:red', 172 | ls='--') 173 | 174 | axes2[0].plot(state_hist[:, 9], label='velocity', c='tab:orange') 175 | axes2[1].plot(state_hist[:, 10], c='tab:orange') 176 | axes2[2].plot(state_hist[:, 11], c='tab:orange') 177 | axes2[0].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[0]] * 2, 178 | label='velocity limit', c='tab:pink', ls=':') 179 | axes2[1].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[1]] * 2, c='tab:pink', ls=':') 180 | axes2[2].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[2]] * 2, c='tab:pink', ls=':') 181 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[0]] * 2, c='tab:pink', ls=':') 182 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[1]] * 2, c='tab:pink', ls=':') 183 | 
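# plotting convention (see the legend entries below): dashed red lines mark position limits, dotted pink lines mark velocity limits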
axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[2]] * 2, c='tab:pink', ls=':') 184 | 185 | axes2[0].set_title('Joint 1') 186 | axes2[1].set_title('Joint 2') 187 | axes2[2].set_title('Joint 3') 188 | fig2.legend(ncol=4, loc='lower center') 189 | 190 | file2 = "JointProfile_" + suffix + str(i) + ".pdf" 191 | plt.savefig(os.path.join(save_dir, file2)) 192 | plt.close(fig2) 193 | 194 | state_list = list() -------------------------------------------------------------------------------- /atacom/environments/planar_air_hockey/unconstrained_air_hockey.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mushroom_rl.environments.pybullet import PyBulletObservationType 3 | from mushroom_rl.environments.pybullet_envs.air_hockey import AirHockeyHit, AirHockeyDefend 4 | 5 | 6 | class AirHockeyHitUnconstrained(AirHockeyHit): 7 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1, 8 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control="torque", 9 | random_init=False, action_penalty=1e-3): 10 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, 11 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 12 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 13 | torque_control=torque_control, random_init=random_init, 14 | step_action_function=self._step_action_function, 15 | action_penalty=action_penalty) 16 | self.constr_logs = list() 17 | 18 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_1"], 19 | maxJointVelocity=self.joints.velocity_limits()[0] * 1.5) 20 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_2"], 21 | maxJointVelocity=self.joints.velocity_limits()[1] * 1.5) 22 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_3"], 23 | maxJointVelocity=self.joints.velocity_limits()[2] * 1.5) 24 | self.acc_max = np.ones(3) * 10 25 | 26 | def setup(self, state): 27 | super().setup(state) 28 | self.constr_logs.clear() 29 | 30 | def _step_action_function(self, state, action): 31 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high) 32 | self._update_constraint_stats(state) 33 | return action 34 | 35 | def _update_constraint_stats(self, state): 36 | q = self.joints.positions(state) 37 | dq = self.joints.velocities(state) 38 | mallet_pose = self.get_sim_state(state, "planar_robot_1/link_striker_ee", PyBulletObservationType.LINK_POS) 39 | c_ee_i = np.array([-mallet_pose[0] - self.env_spec['table']['length'] / 2, 40 | -mallet_pose[1] - self.env_spec['table']['width'] / 2, 41 | mallet_pose[1] - self.env_spec['table']['width'] / 2]) 42 | c_q_i = q ** 2 - self.joints.limits()[1] ** 2 43 | c_dq_i = dq ** 2 - self.joints.velocity_limits() ** 2 44 | c_i = np.concatenate([c_ee_i, c_q_i]) 45 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)]) 46 | 47 | def get_constraints_logs(self): 48 | constr_logs = np.array(self.constr_logs) 49 | c_avg = np.mean(constr_logs[:, 0]) 50 | c_max = np.max(constr_logs[:, 0]) 51 | c_dq_max = np.max(constr_logs[:, 1]) 52 | self.constr_logs.clear() 53 | return c_avg, c_max, c_dq_max 54 | 55 | 56 | class AirHockeyDefendUnconstrained(AirHockeyDefend): 57 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1, 58 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control="torque", 59 | random_init=False, action_penalty=1e-3): 60 | 
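# same setup as AirHockeyHitUnconstrained: clip the action, log constraint values each step, and raise PyBullet's joint-velocity caps to 1.5x the model limits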
super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, 61 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui, 62 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, 63 | torque_control=torque_control, random_init=random_init, 64 | step_action_function=self._step_action_function, 65 | action_penalty=action_penalty) 66 | self.constr_logs = list() 67 | 68 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_1"], 69 | maxJointVelocity=self.joints.velocity_limits()[0] * 1.5) 70 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_2"], 71 | maxJointVelocity=self.joints.velocity_limits()[1] * 1.5) 72 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_3"], 73 | maxJointVelocity=self.joints.velocity_limits()[2] * 1.5) 74 | self.acc_max = np.ones(3) * 10 75 | 76 | def setup(self, state): 77 | super().setup(state) 78 | self.constr_logs.clear() 79 | 80 | def _step_action_function(self, state, action): 81 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high) 82 | self._update_constraint_stats(state) 83 | return action 84 | 85 | def _update_constraint_stats(self, state): 86 | q = self.joints.positions(state) 87 | dq = self.joints.velocities(state) 88 | mallet_pose = self.get_sim_state(state, "planar_robot_1/link_striker_ee", PyBulletObservationType.LINK_POS) 89 | c_ee_i = np.array([-mallet_pose[0] - self.env_spec['table']['length'] / 2, 90 | -mallet_pose[1] - self.env_spec['table']['width'] / 2, 91 | mallet_pose[1] - self.env_spec['table']['width'] / 2]) 92 | c_q_i = q ** 2 - self.joints.limits()[1] ** 2 93 | c_dq_i = dq ** 2 - self.joints.velocity_limits() ** 2 94 | c_i = np.concatenate([c_ee_i, c_q_i]) 95 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)]) 96 | 97 | def get_constraints_logs(self): 98 | constr_logs = np.array(self.constr_logs) 99 | c_avg = np.mean(constr_logs[:, 0]) 100 | c_max = np.max(constr_logs[:, 0]) 101 | c_dq_max = np.max(constr_logs[:, 1]) 102 | self.constr_logs.clear() 103 | return c_avg, c_max, c_dq_max -------------------------------------------------------------------------------- /atacom/error_correction_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.utils import spaces 4 | from atacom.utils import pinv_null 5 | 6 | 7 | class ErrorCorrectionEnvWrapper: 8 | """ 9 | Environment wrapper of the Error Correction Method 10 | 11 | """ 12 | def __init__(self, base_env, dim_q, vel_max, acc_max, f=None, g=None, Kc=100., Kq=10., time_step=0.01): 13 | """ 14 | Constructor 15 | Args: 16 | base_env (mushroom_rl.core.Environment): the base environment to be wrapped 17 | dim_q (int): dimension of the directly controllable variable 18 | vel_max (array, float): the maximum velocity of the directly controllable variable 19 | acc_max (array, float): the maximum acceleration of the directly controllable variable 20 | f (ViabilityConstraint, ConstraintsSet): the equality constraint f(q) = 0 21 | g (ViabilityConstraint, ConstraintsSet): the inequality constraint g(q) <= 0 22 | Kc (array, float): the scaling factor for error correction 23 | Kq (array, float): the scaling factor for the viability acceleration bound 24 | time_step (float): the step size for time discretization 25 | """ 26 | self.env = base_env 27 | self.dims = {'q': dim_q, 'f': 0, 'g': 0} 28 | self.f = f 29 | self.g = g 30 | self.time_step = time_step 31 | self._logger = None 32 | 33 | if self.f is
not None: 34 | assert self.dims['q'] == self.f.dim_q, "Input dimension is different in f" 35 | self.dims['f'] = self.f.dim_out 36 | if self.g is not None: 37 | assert self.dims['q'] == self.g.dim_q, "Input dimension is different in g" 38 | self.dims['g'] = self.g.dim_out 39 | self.s = np.zeros(self.dims['g']) 40 | 41 | self.dims['c'] = self.dims['f'] + self.dims['g'] 42 | 43 | if np.isscalar(Kc): 44 | self.K_c = np.ones(self.dims['c']) * Kc 45 | else: 46 | self.K_c = Kc 47 | 48 | self.q = np.zeros(self.dims['q']) 49 | self.dq = np.zeros(self.dims['q']) 50 | 51 | self._mdp_info = self.env.info.copy() 52 | self._mdp_info.action_space = spaces.Box(low=-np.ones(self.dims['q']), high=np.ones(self.dims['q'])) 53 | 54 | if np.isscalar(vel_max): 55 | self.vel_max = np.ones(self.dims['q']) * vel_max 56 | else: 57 | self.vel_max = vel_max 58 | assert np.shape(self.vel_max)[0] == self.dims['q'] 59 | 60 | if np.isscalar(acc_max): 61 | self.acc_max = np.ones(self.dims['q']) * acc_max 62 | else: 63 | self.acc_max = acc_max 64 | assert np.shape(self.acc_max)[0] == self.dims['q'] 65 | 66 | if np.isscalar(Kq): 67 | self.K_q = np.ones(self.dims['q']) * Kq 68 | else: 69 | self.K_q = Kq 70 | assert np.shape(self.K_q)[0] == self.dims['q'] 71 | 72 | self.state = self.env.reset() 73 | self._act_a = None 74 | self._act_b = None 75 | self._act_err = None 76 | 77 | self.env.step_action_function = self.step_action_function 78 | 79 | def acc_to_ctrl_action(self, ddq): 80 | raise NotImplementedError 81 | 82 | def _get_q(self, state): 83 | raise NotImplementedError 84 | 85 | def _get_dq(self, state): 86 | raise NotImplementedError 87 | 88 | def seed(self, seed): 89 | self.env.seed(seed) 90 | 91 | def reset(self, state=None): 92 | self.state = self.env.reset(state) 93 | self.q = self._get_q(self.state) 94 | self.dq = self._get_dq(self.state) 95 | self._compute_slack_variables() 96 | return self.state 97 | 98 | def render(self): 99 | self.env.render() 100 | 101 | def stop(self): 102 | self.env.stop() 103 | 104 | def step(self, action): 105 | alpha = np.clip(action, self.info.action_space.low, self.info.action_space.high) 106 | alpha = alpha * self.acc_max 107 | 108 | self.state, reward, absorb, info = self.env.step(alpha) 109 | return self.state.copy(), reward, absorb, info 110 | 111 | def acc_truncation(self, dq, ddq): 112 | acc_u = np.maximum(np.minimum(self.acc_max, -self.K_q * (dq - self.vel_max)), -self.acc_max) 113 | acc_l = np.minimum(np.maximum(-self.acc_max, -self.K_q * (dq + self.vel_max)), self.acc_max) 114 | ddq = np.clip(ddq, acc_l, acc_u) 115 | return ddq 116 | 117 | def step_action_function(self, sim_state, alpha): 118 | self.state = self.env._create_observation(sim_state) 119 | self.q = self._get_q(self.state) 120 | self.dq = self._get_dq(self.state) 121 | 122 | Jc, psi = self._construct_Jc_psi(self.q, self.s, self.dq) 123 | Jc_inv, Nc = pinv_null(Jc) 124 | 125 | self._act_a = np.zeros(self.dims['q'] + self.dims['g']) 126 | self._act_b = np.concatenate([alpha, np.zeros(self.dims['g'])]) 127 | self._act_err = self._compute_error_correction(self.q, self.dq, self.s, Jc_inv) 128 | ddq_ds = self._act_a + self._act_b + self._act_err 129 | 130 | self.s += ddq_ds[self.dims['q']:(self.dims['q'] + self.dims['g'])] * self.time_step 131 | 132 | ddq = self.acc_truncation(self.dq, ddq_ds[:self.dims['q']]) 133 | ctrl_action = self.acc_to_ctrl_action(ddq) 134 | return ctrl_action 135 | 136 | @property 137 | def info(self): 138 | return self._mdp_info 139 | 140 | def _compute_slack_variables(self): 141 | self.s = None 142 
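# pick s such that g(q, dq) + s**2 / 2 = 0 holds at the current state (cf. _compute_c)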
| if self.dims['g'] > 0: 143 | s_2 = np.maximum(-2 * self.g.fun(self.q, self.dq, origin_constr=False), 0) 144 | self.s = np.sqrt(s_2) 145 | 146 | def _construct_Jc_psi(self, q, s, dq): 147 | Jc = np.zeros((self.dims['f'] + self.dims['g'], self.dims['q'] + self.dims['g'])) 148 | psi = np.zeros(self.dims['c']) 149 | if self.dims['f'] > 0: 150 | idx_0 = 0 151 | idx_1 = self.dims['f'] 152 | Jc[idx_0:idx_1, :self.dims['q']] = self.f.K_J(q) 153 | psi[idx_0:idx_1] = self.f.b(q, dq) 154 | if self.dims['g'] > 0: 155 | idx_0 = self.dims['f'] 156 | idx_1 = self.dims['f'] + self.dims['g'] 157 | Jc[idx_0:idx_1, :self.dims['q']] = self.g.K_J(q) 158 | Jc[idx_0:idx_1, self.dims['q']:(self.dims['q'] + self.dims['g'])] = np.diag(s) 159 | psi[idx_0:idx_1] = self.g.b(q, dq) 160 | return Jc, psi 161 | 162 | def _compute_error_correction(self, q, dq, s, Jc_inv, act_null=None): 163 | q_tmp = q.copy() 164 | dq_tmp = dq.copy() 165 | s_tmp = None 166 | 167 | if self.dims['g'] > 0: 168 | s_tmp = s.copy() 169 | 170 | if act_null is not None: 171 | q_tmp += dq_tmp * self.time_step + act_null[:self.dims['q']] * self.time_step ** 2 / 2 172 | dq_tmp += act_null[:self.dims['q']] * self.time_step 173 | if self.dims['g'] > 0: 174 | s_tmp += act_null[self.dims['q']:self.dims['q'] + self.dims['g']] * self.time_step 175 | 176 | return -Jc_inv @ (self.K_c * self._compute_c(q_tmp, dq_tmp, s_tmp, origin_constr=False)) 177 | 178 | def _compute_c(self, q, dq, s, origin_constr=False): 179 | c = np.zeros(self.dims['f'] + self.dims['g']) 180 | if self.dims['f'] > 0: 181 | idx_0 = 0 182 | idx_1 = self.dims['f'] 183 | c[idx_0:idx_1] = self.f.fun(q, dq, origin_constr) 184 | if self.dims['g'] > 0: 185 | idx_0 = self.dims['f'] 186 | idx_1 = self.dims['f'] + self.dims['g'] 187 | if origin_constr: 188 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) 189 | else: 190 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) + 1 / 2 * s ** 2 191 | return c 192 | 193 | def set_logger(self, logger): 194 | self._logger = logger 195 | 196 | def get_constraints_logs(self): 197 | return self.env.get_constraints_logs() 198 | -------------------------------------------------------------------------------- /atacom/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .null_space_coordinate import rref, gram_schmidt, pinv_null 2 | from .plot_utils import * -------------------------------------------------------------------------------- /atacom/utils/null_space_coordinate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse 3 | from scipy import linalg 4 | import sympy 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def pinv_null(a, rcond=None, return_rank=False): 9 | u, s, vh = linalg.svd(a, full_matrices=True, check_finite=False) 10 | M, N = u.shape[0], vh.shape[1] 11 | 12 | if rcond is None: 13 | rcond = np.finfo(s.dtype).eps * max(M, N) 14 | tol = np.amax(s) * rcond 15 | 16 | rank = np.sum(s > tol) 17 | Q = vh[rank:, :].T.conj() 18 | 19 | u = u[:, :rank] 20 | u /= s[:rank] 21 | B = np.transpose(np.conjugate(np.dot(u, vh[:rank]))) 22 | 23 | if return_rank: 24 | return B, Q, rank 25 | else: 26 | return B, Q 27 | 28 | 29 | def rref_sympy(V, row_vectors=True): 30 | if not row_vectors: 31 | V = V.T 32 | sym_V = sympy.Matrix(V) 33 | basis = sympy.matrices.matrix2numpy(sym_V.rref()[0], dtype=float) 34 | if not row_vectors: 35 | return basis.T 36 | else: 37 | return basis 38 | 39 | 40 | def rref(A, row_vectors=True, tol=None): 41 
| # Follow the implementation: 42 | # https://www.mathworks.com/matlabcentral/fileexchange/21583-fast-reduced-row-echelon-form 43 | V = A.copy() 44 | if not row_vectors: 45 | V = V.T 46 | m, n = V.shape 47 | 48 | if tol is None: 49 | tol = max(m, n) * np.finfo(V.dtype).eps * linalg.norm(V, np.inf) 50 | 51 | i = 0 52 | j = 0 53 | jb = list() 54 | while i < m and j < n: 55 | # Find value and index of largest element in the remainder of column j. 56 | k = np.argmax(np.abs(V[i:m, j])) 57 | k += i 58 | p = abs(V[k, j]) 59 | 60 | if p <= tol: 61 | # The column is negligible, zero it out. 62 | V[i:m, j] = 0 63 | j += 1 64 | else: 65 | # Remember column index 66 | jb += [j] 67 | # Swap i-th and k-th rows. 68 | V[[i, k], j:n] = V[[k, i], j:n] 69 | # Divide the pivot row by the pivot element. 70 | Vi = V[i, j:n] / V[i, j] 71 | # Subtract multiples of the pivot row from all the other rows. 72 | V[:, j:n] = V[:, j:n] - np.outer(V[:, j], Vi) 73 | V[i, j:n] = Vi 74 | i += 1 75 | j += 1 76 | 77 | if not row_vectors: 78 | return V.T 79 | return V 80 | 81 | 82 | def gram_schmidt(A, row_vectors=True, norm_dim=None): 83 | V = A.copy() 84 | if not row_vectors: 85 | V = V.T 86 | for i, v in enumerate(V): 87 | prev_basis = V[0:i] 88 | coeff_vec = prev_basis @ v 89 | v -= coeff_vec @ prev_basis 90 | v_norm = np.linalg.norm(v) 91 | if v_norm > 1e-10: 92 | v /= v_norm 93 | else: 94 | v[v<1e-10] = 0 95 | 96 | if norm_dim is not None: 97 | for i, v in enumerate(V): 98 | bn = np.linalg.norm(v[:norm_dim]) 99 | if bn > 0.01: 100 | V[i] = v / bn 101 | 102 | if not row_vectors: 103 | return np.array(V).T 104 | else: 105 | return np.array(V) 106 | 107 | 108 | def orthogonalization_test(): 109 | m = 3 110 | n = 2 111 | V_1 = np.random.randn(m, n) 112 | z_1 = np.cross(V_1[:, 0], V_1[:, 1]) 113 | print("original: \n", V_1) 114 | 115 | V_2 = gram_schmidt(V_1, row_vectors=False) 116 | z_2 = np.cross(V_2[:, 0], V_2[:, 1]) 117 | print("gs: \n", V_2) 118 | 119 | V_3 = rref(V_1, row_vectors=False) 120 | z_3 = np.cross(V_3[:, 0], V_3[:, 1]) 121 | print("rref: \n", V_3) 122 | 123 | V_4 = gram_schmidt(V_3, row_vectors=False) 124 | z_4 = np.cross(V_4[:, 0], V_4[:, 1]) 125 | print("gs + rref:\n", V_4) 126 | 127 | V_q, v_r = np.linalg.qr(z_1[:, np.newaxis], mode='complete') 128 | V_q = V_q[:, np.where(v_r.flatten() == 0)[0]] 129 | z_q = np.cross(V_q[:, 0], V_q[:, 1]) 130 | print("qr: \n", V_q) 131 | 132 | V_svd = scipy.linalg.null_space(z_1[np.newaxis, :]) 133 | z_svd = np.cross(V_svd[:, 0], V_svd[:, 1]) 134 | print("svd: \n", V_svd) 135 | 136 | fig = plt.figure() 137 | ax = fig.add_subplot(projection='3d') 138 | ax.quiver(0, 0, 0, V_1[0, 0], V_1[1, 0], V_1[2, 0], color='tab:blue', label='original') 139 | ax.quiver(0, 0, 0, V_1[0, 1], V_1[1, 1], V_1[2, 1], color='tab:blue') 140 | ax.quiver(0, 0, 0, z_1[0], z_1[1], z_1[2], color='tab:blue', linestyle='--') 141 | 142 | ax.quiver(0, 0, 0, V_2[0, 0], V_2[1, 0], V_2[2, 0], color='tab:orange', label='GS') 143 | ax.quiver(0, 0, 0, V_2[0, 1], V_2[1, 1], V_2[2, 1], color='tab:orange') 144 | ax.quiver(0, 0, 0, z_2[0], z_2[1], z_2[2], color='tab:orange', linestyle='--') 145 | 146 | ax.quiver(0, 0, 0, V_3[0, 0], V_3[1, 0], V_3[2, 0], color='tab:red', label='RREF') 147 | ax.quiver(0, 0, 0, V_3[0, 1], V_3[1, 1], V_3[2, 1], color='tab:red') 148 | ax.quiver(0, 0, 0, z_3[0], z_3[1], z_3[2], color='tab:red', linestyle='--') 149 | 150 | ax.quiver(0, 0, 0, V_4[0, 0], V_4[1, 0], V_4[2, 0], color='tab:pink', label='RREF + GS') 151 | ax.quiver(0, 0, 0, V_4[0, 1], V_4[1, 1], V_4[2, 1], color='tab:pink') 152 | 
ax.quiver(0, 0, 0, z_4[0], z_4[1], z_4[2], color='tab:pink', linestyle='--') 153 | 154 | ax.quiver(0, 0, 0, V_q[0, 0], V_q[1, 0], V_q[2, 0], color='tab:brown', label='QR') 155 | ax.quiver(0, 0, 0, V_q[0, 1], V_q[1, 1], V_q[2, 1], color='tab:brown') 156 | ax.quiver(0, 0, 0, z_q[0], z_q[1], z_q[2], color='tab:brown', linestyle='--') 157 | 158 | ax.quiver(0, 0, 0, V_svd[0, 0], V_svd[1, 0], V_svd[2, 0], color='tab:purple', label='SVD') 159 | ax.quiver(0, 0, 0, V_svd[0, 1], V_svd[1, 1], V_svd[2, 1], color='tab:purple') 160 | ax.quiver(0, 0, 0, z_svd[0], z_svd[1], z_svd[2], color='tab:purple', linestyle='--') 161 | 162 | ax.set_xlim(-3, 3) 163 | ax.set_ylim(-3, 3) 164 | ax.set_zlim(-3, 3) 165 | ax.set_xlabel("x") 166 | ax.set_ylabel("y") 167 | ax.set_zlabel("z") 168 | ax.legend() 169 | plt.show() 170 | 171 | 172 | def rref_test(): 173 | for i in range(100): 174 | m, n = np.random.randint(1, 10, 2) 175 | V = np.random.randn(m, n) 176 | 177 | V_1 = rref(V) 178 | V_2 = rref_sympy(V) 179 | print(np.isclose(V_1, V_2).all()) 180 | 181 | 182 | def gram_schmidt_test(): 183 | for i in range(100): 184 | m = 7 185 | n = 5 186 | 187 | V = np.random.randn(m, n) 188 | V_rref = rref(V, row_vectors=False) 189 | V_orth = gram_schmidt(V_rref, row_vectors=False) 190 | print(V_orth.max(), np.abs(V_orth).min()) 191 | 192 | 193 | 194 | if __name__ == '__main__': 195 | # rref_test() 196 | # gram_schmidt_test() 197 | orthogonalization_test() 198 | 199 | 200 | -------------------------------------------------------------------------------- /atacom/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import scipy.stats as st 5 | 6 | 7 | def get_mean_and_confidence(data): 8 | """ 9 | Compute the mean and 95% confidence interval 10 | Args: 11 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs). 12 | Returns: 13 | The mean of the dataset at each epoch along with the confidence interval. 
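        (Note: the returned interval is the first element of the st.t.interval tuple, i.e. a signed, negative half-width; the plotting utilities below draw the band between mean + interval and mean - interval.)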
14 | """ 15 | mean = np.mean(data, axis=0) 16 | se = st.sem(data, axis=0) 17 | n = len(data) 18 | interval, _ = st.t.interval(0.95, n - 1, scale=se) 19 | return mean, interval 20 | 21 | 22 | def get_data_list(file_dir, prefix): 23 | data = list() 24 | for filename in os.listdir(file_dir): 25 | if filename.startswith(prefix+'-'): 26 | data.append(np.load(os.path.join(file_dir, filename), allow_pickle=True)) 27 | return np.array(data) 28 | 29 | 30 | def plot_learning_curve(log_dir, exp_list): 31 | cm = plt.get_cmap('tab10') 32 | 33 | fig1, ax1 = plt.subplots(3, 1, figsize=(10, 10)) 34 | fig1.subplots_adjust(hspace=.5) 35 | fig2, ax2 = plt.subplots(3, 1, figsize=(10, 10)) 36 | fig2.subplots_adjust(hspace=.5) 37 | 38 | for i, exp_suffix in enumerate(exp_list): 39 | color_i = cm(i) 40 | exp_dir = os.path.join(log_dir, exp_suffix) 41 | _, seeds, _ = next(os.walk(exp_dir)) 42 | 43 | exp_J_list = get_data_list(exp_dir, "J") 44 | exp_R_list = get_data_list(exp_dir, "R") 45 | exp_E_list = get_data_list(exp_dir, "E") 46 | c_max_list = get_data_list(exp_dir, "c_max") 47 | c_avg_list = get_data_list(exp_dir, "c_avg") 48 | c_dq_max_list = get_data_list(exp_dir, "c_dq_max") 49 | 50 | mean_J, conf_J = get_mean_and_confidence(exp_J_list) 51 | mean_R, conf_R = get_mean_and_confidence(exp_R_list) 52 | 53 | ax1[0].plot(mean_J, label=exp_suffix, color=color_i) 54 | ax1[0].fill_between(np.arange(np.size(mean_J)), mean_J + conf_J, mean_J - conf_J, alpha=0.2, color=color_i) 55 | ax1[0].set_title("J") 56 | ax1[1].plot(mean_R, label=exp_suffix, color=color_i) 57 | ax1[1].fill_between(np.arange(np.size(mean_R)), mean_R + conf_R, mean_R - conf_R, alpha=0.2, color=color_i) 58 | ax1[1].set_title("R") 59 | ax1[1].legend() 60 | 61 | if np.all(exp_E_list!=None): 62 | mean_E, conf_E = get_mean_and_confidence(exp_E_list[:, 1:]) 63 | ax1[2].plot(mean_E, label=exp_suffix, color=color_i) 64 | ax1[2].fill_between(np.arange(np.size(mean_E)), mean_E + conf_E, mean_E - conf_E, alpha=0.2, color=color_i) 65 | ax1[2].set_title("E") 66 | 67 | mean_c_max, conf_c_max = get_mean_and_confidence(c_max_list) 68 | mean_c_avg, conf_c_avg = get_mean_and_confidence(c_avg_list) 69 | mean_c_dq_max, conf_c_dq_max = get_mean_and_confidence(c_dq_max_list) 70 | ax2[0].plot(mean_c_max, label=exp_suffix, color=color_i) 71 | ax2[0].fill_between(np.arange(np.size(mean_c_max)), mean_c_max + conf_c_max, 72 | mean_c_max - conf_c_max, alpha=0.2, color=color_i) 73 | ax2[0].set_title("c_max") 74 | ax2[1].plot(mean_c_avg, label=exp_suffix, color=color_i) 75 | ax2[1].fill_between(np.arange(np.size(mean_c_avg)), mean_c_avg + conf_c_avg, 76 | mean_c_avg - conf_c_avg, alpha=0.2, color=color_i) 77 | ax2[1].set_title("c_avg") 78 | ax2[2].plot(mean_c_dq_max, label=exp_suffix, color=color_i) 79 | ax2[2].fill_between(np.arange(np.size(mean_c_dq_max)), mean_c_dq_max + conf_c_dq_max, 80 | mean_c_dq_max - conf_c_dq_max, alpha=0.2, color=color_i) 81 | ax2[2].set_title("c_dq_max") 82 | ax2[2].legend() 83 | 84 | fig1.savefig(os.path.join(log_dir, "Reward.pdf")) 85 | fig2.savefig(os.path.join(log_dir, "Constraints.pdf")) 86 | plt.show() 87 | 88 | 89 | def plot_learning_metric(log_dir, exp_list, metric, label_list, title, y_scale='linear', file_name=None): 90 | cm = plt.get_cmap('tab10') 91 | 92 | fig = plt.figure(figsize=(12, 9)) 93 | ax = plt.gca() 94 | 95 | for i, exp_suffix in enumerate(exp_list): 96 | color_i = cm(i) 97 | exp_dir = os.path.join(log_dir, exp_suffix) 98 | _, seeds, _ = next(os.walk(exp_dir)) 99 | 100 | metric_list = get_data_list(exp_dir, str(metric)) 101 | 
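        # Aggregate the chosen metric across seeds: mean learning curve with a 95% confidence band.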
102 | mean_metric, conf_metric = get_mean_and_confidence(metric_list) 103 | 104 | ax.plot(mean_metric, label=label_list[i], color=color_i) 105 | ax.fill_between(np.arange(np.size(mean_metric)), mean_metric + conf_metric, mean_metric - conf_metric, 106 | alpha=0.1, color=color_i) 107 | ax.legend(fontsize=30) 108 | ax.set_yscale(y_scale) 109 | 110 | ax.set_title(title, fontsize=40) 111 | ax.tick_params('both', labelsize=30) 112 | if file_name is None: 113 | file_name = title + ".pdf" 114 | else: 115 | file_name += ".pdf" 116 | fig.savefig(os.path.join(log_dir, file_name)) 117 | plt.show() 118 | 119 | def plot_learning_curve_single(log_dir, exp_name, seeds): 120 | fig1, ax1 = plt.subplots(3, 1) 121 | fig1.subplots_adjust(hspace=.5) 122 | fig2, ax2 = plt.subplots(3, 1) 123 | fig2.subplots_adjust(hspace=.5) 124 | 125 | exp_dir = os.path.join(log_dir, exp_name) 126 | 127 | for seed in seeds: 128 | postfix = "-" + str(seed) + ".npy" 129 | J = np.load(os.path.join(exp_dir, "J" + postfix)) 130 | R = np.load(os.path.join(exp_dir, "R" + postfix)) 131 | E = np.load(os.path.join(exp_dir, "E" + postfix), allow_pickle=True) 132 | c_max = np.load(os.path.join(exp_dir, "c_max" + postfix)) 133 | c_avg = np.load(os.path.join(exp_dir, "c_avg" + postfix)) 134 | c_dq_max = np.load(os.path.join(exp_dir, "c_dq_max" + postfix)) 135 | 136 | ax1[0].plot(J) 137 | ax1[0].set_title("J") 138 | ax1[1].plot(R) 139 | ax1[1].set_title("R") 140 | 141 | if np.all(E!=None): 142 | ax1[2].plot(E) 143 | ax1[2].set_title("E") 144 | 145 | ax2[0].plot(c_max) 146 | ax2[0].plot(np.zeros_like(c_max), c='tab:red', lw=2) 147 | ax2[0].set_title("c_max") 148 | ax2[1].plot(c_avg) 149 | ax2[1].plot(np.zeros_like(c_avg), c='tab:red', lw=2) 150 | ax2[1].set_title("c_avg") 151 | ax2[2].plot(c_dq_max) 152 | ax2[2].plot(np.zeros_like(c_dq_max), c='tab:red', lw=2) 153 | ax2[2].set_title("c_dq_max") 154 | 155 | plt.show() -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/examples/__init__.py -------------------------------------------------------------------------------- /examples/circle_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from copy import deepcopy 4 | from tqdm import trange 5 | import pandas as pd 6 | 7 | from mushroom_rl.algorithms.actor_critic import PPO, TRPO, DDPG, TD3, SAC 8 | from mushroom_rl.core import Core, Logger 9 | from mushroom_rl.policy import GaussianTorchPolicy, OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy 10 | from mushroom_rl.utils.dataset import compute_J, parse_dataset 11 | from atacom.environments.circular_motion import CircleEnvAtacom, CircleEnvTerminated, CircleEnvErrorCorrection 12 | from network import * 13 | 14 | 15 | def experiment(mdp, agent, seed, results_dir, n_epochs, n_steps, n_steps_per_fit, n_episodes_test, 16 | quiet, render, **kwargs): 17 | build_params = kwargs['build_params'] 18 | logger = Logger(results_dir=results_dir, seed=seed, log_name="exp") 19 | 20 | logger.strong_line() 21 | logger.info('Experiment Algorithm: ' + type(agent).__name__) 22 | logger.info('Environment: ' + type(mdp).__name__ + " seed: " + str(seed)) 23 | 24 | best_agent = deepcopy(agent) 25 | 26 | core = Core(agent, mdp) 27 | 28 | eval_params = dict( 29 | n_episodes=n_episodes_test, 30 | render=render, 31 | quiet=quiet 
32 | ) 33 | 34 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params) 35 | best_J, best_R, best_E, best_c_avg, best_c_max, best_c_dq_max = J, R, E, c_avg, c_max, c_dq_max 36 | 37 | logger.epoch_info(0, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 38 | logger.weak_line() 39 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 40 | 41 | for it in trange(n_epochs, leave=False, disable=quiet): 42 | core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=quiet, render=render) 43 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params) 44 | 45 | logger.epoch_info(it + 1, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 46 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 47 | 48 | if J > best_J: 49 | best_J = J 50 | best_R = R 51 | best_E = E 52 | best_c_avg = c_avg 53 | best_c_max = c_max 54 | best_c_dq_max = c_dq_max 55 | best_agent = deepcopy(agent) 56 | 57 | if it % 10 == 0: 58 | logger.log_agent(agent, epoch=it) 59 | 60 | logger.info("Best result | J: {}, R: {}, E:{}, c_avg:{}, c_max:{}, c_dq_max{}.".format(best_J, best_R, best_E, 61 | best_c_avg, best_c_max, 62 | best_c_dq_max)) 63 | logger.strong_line() 64 | logger.log_agent(best_agent) 65 | best_res = {"best_J": best_J, "best_R": best_R, "best_E": best_E, 66 | "best_c_avg": best_c_avg, "best_c_max": best_c_max, "best_c_dq_max": best_c_dq_max} 67 | best_res = pd.DataFrame.from_dict(best_res, orient="index") 68 | best_res.to_csv(os.path.join(logger.path, "best_result.csv")) 69 | 70 | 71 | def compute_metrics(core, eval_params, build_params): 72 | dataset = core.evaluate(**eval_params) 73 | c_avg, c_max, c_dq_max = core.mdp.get_constraints_logs() 74 | J = np.mean(compute_J(dataset, core.mdp.info.gamma)) 75 | R = np.mean(compute_J(dataset)) 76 | E = None 77 | if build_params['compute_policy_entropy']: 78 | if build_params['compute_entropy_with_states']: 79 | E = core.agent.policy.entropy(parse_dataset(dataset)[0]) 80 | else: 81 | E = core.agent.policy.entropy() 82 | return J, R, E, c_avg, c_max, c_dq_max 83 | 84 | 85 | def build_env(env, horizon, gamma, random_init, **kwargs): 86 | if env == 'A': 87 | mdp = CircleEnvAtacom(horizon=horizon, gamma=gamma, random_init=random_init, Kc=100) 88 | elif env == 'T': 89 | mdp = CircleEnvTerminated(horizon=horizon, gamma=gamma, random_init=random_init, tol=kwargs['termination_tol']) 90 | elif env == 'E': 91 | mdp = CircleEnvErrorCorrection(horizon=horizon, gamma=gamma, random_init=random_init) 92 | else: 93 | raise NotImplementedError 94 | return mdp 95 | 96 | 97 | def build_agent(alg, mdp_info, **kwargs): 98 | alg = alg.upper() 99 | if alg == 'PPO': 100 | agent, build_params = build_agent_PPO(mdp_info, **kwargs) 101 | elif alg == 'TRPO': 102 | agent, build_params = build_agent_TRPO(mdp_info, **kwargs) 103 | elif alg == 'DDPG': 104 | agent, build_params = build_agent_DDPG(mdp_info, **kwargs) 105 | elif alg == 'TD3': 106 | agent, build_params = build_agent_TD3(mdp_info, **kwargs) 107 | elif alg == 'SAC': 108 | agent, build_params = build_agent_SAC(mdp_info, **kwargs) 109 | else: 110 | raise NotImplementedError 111 | return agent, build_params 112 | 113 | 114 | def build_agent_PPO(mdp_info, actor_lr, critic_lr, n_features, batch_size, eps_ppo, lam, ent_coeff, **kwargs): 115 | policy_params = dict( 116 | std_0=0.5, 117 | n_features=n_features, 118 | use_cuda=torch.cuda.is_available() 119 | ) 120 | policy = GaussianTorchPolicy(PPONetwork, 121 | 
mdp_info.observation_space.shape, 122 | mdp_info.action_space.shape, 123 | **policy_params) 124 | 125 | critic_params = dict(network=PPONetwork, 126 | optimizer={'class': optim.Adam, 127 | 'params': {'lr': critic_lr}}, 128 | loss=F.mse_loss, 129 | n_features=n_features, 130 | batch_size=batch_size, 131 | input_shape=mdp_info.observation_space.shape, 132 | output_shape=(1,)) 133 | 134 | ppo_params = dict(actor_optimizer={'class': optim.Adam, 135 | 'params': {'lr': actor_lr}}, 136 | n_epochs_policy=4, 137 | batch_size=batch_size, 138 | eps_ppo=eps_ppo, 139 | lam=lam, 140 | ent_coeff=ent_coeff, 141 | critic_params=critic_params) 142 | 143 | build_params = dict(compute_entropy_with_states=False, 144 | compute_policy_entropy=True) 145 | 146 | return PPO(mdp_info, policy, **ppo_params), build_params 147 | 148 | 149 | def build_agent_TRPO(mdp_info, critic_lr, n_features, batch_size, lam, ent_coeff, 150 | max_kl, n_epochs_line_search, n_epochs_cg, cg_damping, cg_residual_tol, critic_fit_params, 151 | **kwargs): 152 | policy_params = dict( 153 | std_0=0.5, 154 | n_features=n_features, 155 | use_cuda=torch.cuda.is_available() 156 | ) 157 | 158 | critic_params = dict( 159 | network=TRPONetwork, 160 | optimizer={'class': optim.Adam, 161 | 'params': {'lr': critic_lr}}, 162 | loss=F.mse_loss, 163 | n_features=n_features, 164 | batch_size=batch_size, 165 | input_shape=mdp_info.observation_space.shape, 166 | output_shape=(1,)) 167 | 168 | trpo_params = dict( 169 | ent_coeff=ent_coeff, 170 | max_kl=max_kl, 171 | lam=lam, 172 | n_epochs_line_search=n_epochs_line_search, 173 | n_epochs_cg=n_epochs_cg, 174 | cg_damping=cg_damping, 175 | cg_residual_tol=cg_residual_tol, 176 | critic_fit_params=critic_fit_params) 177 | 178 | policy = GaussianTorchPolicy(TRPONetwork, 179 | mdp_info.observation_space.shape, 180 | mdp_info.action_space.shape, 181 | **policy_params) 182 | 183 | build_params = dict(compute_entropy_with_states=False, 184 | compute_policy_entropy=True) 185 | 186 | return TRPO(mdp_info, policy, critic_params, **trpo_params), build_params 187 | 188 | 189 | def build_agent_DDPG(mdp_info, actor_lr, critic_lr, n_features, batch_size, 190 | initial_replay_size, max_replay_size, tau, **kwargs): 191 | policy_params = dict( 192 | sigma=np.ones(1) * .2, 193 | theta=0.15, 194 | dt=1e-2) 195 | 196 | actor_params = dict( 197 | network=DDPGActorNetwork, 198 | input_shape=mdp_info.observation_space.shape, 199 | output_shape=mdp_info.action_space.shape, 200 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 201 | n_features=n_features, 202 | use_cuda=torch.cuda.is_available()) 203 | 204 | actor_optimizer = { 205 | 'class': optim.Adam, 206 | 'params': {'lr': actor_lr}} 207 | 208 | critic_params = dict( 209 | network=DDPGCriticNetwork, 210 | optimizer={'class': optim.Adam, 211 | 'params': {'lr': critic_lr}}, 212 | loss=F.mse_loss, 213 | n_features=n_features, 214 | batch_size=batch_size, 215 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 216 | action_shape=mdp_info.action_space.shape, 217 | output_shape=(1,), 218 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 219 | use_cuda=torch.cuda.is_available()) 220 | 221 | alg_params = dict( 222 | initial_replay_size=initial_replay_size, 223 | max_replay_size=max_replay_size, 224 | batch_size=batch_size, 225 | tau=tau) 226 | 227 | build_params = dict(compute_entropy_with_states=False, 228 | compute_policy_entropy=False) 229 | 230 | return DDPG(mdp_info, OrnsteinUhlenbeckPolicy, 
policy_params, actor_params, actor_optimizer, critic_params, 231 | **alg_params), build_params 232 | 233 | 234 | def build_agent_TD3(mdp_info, actor_lr, critic_lr, n_features, batch_size, 235 | initial_replay_size, max_replay_size, tau, sigma, **kwargs): 236 | policy_params = dict( 237 | sigma=np.eye(mdp_info.action_space.shape[0]) * sigma, 238 | low=mdp_info.action_space.low, 239 | high=mdp_info.action_space.high) 240 | 241 | actor_params = dict( 242 | network=TD3ActorNetwork, 243 | input_shape=mdp_info.observation_space.shape, 244 | output_shape=mdp_info.action_space.shape, 245 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 246 | n_features=n_features, 247 | use_cuda=torch.cuda.is_available()) 248 | 249 | actor_optimizer = { 250 | 'class': optim.Adam, 251 | 'params': {'lr': actor_lr}} 252 | 253 | critic_params = dict( 254 | network=TD3CriticNetwork, 255 | optimizer={'class': optim.Adam, 256 | 'params': {'lr': critic_lr}}, 257 | loss=F.mse_loss, 258 | n_features=n_features, 259 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 260 | action_shape=mdp_info.action_space.shape, 261 | output_shape=(1,), 262 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 263 | use_cuda=torch.cuda.is_available()) 264 | 265 | alg_params = dict( 266 | initial_replay_size=initial_replay_size, 267 | max_replay_size=max_replay_size, 268 | batch_size=batch_size, 269 | tau=tau) 270 | 271 | build_params = dict(compute_entropy_with_states=False, 272 | compute_policy_entropy=False) 273 | 274 | return TD3(mdp_info, ClippedGaussianPolicy, policy_params, actor_params, actor_optimizer, critic_params, 275 | **alg_params), build_params 276 | 277 | 278 | def build_agent_SAC(mdp_info, actor_lr, critic_lr, n_features, batch_size, 279 | initial_replay_size, max_replay_size, tau, 280 | warmup_transitions, lr_alpha, target_entropy, 281 | **kwargs): 282 | actor_mu_params = dict(network=SACActorNetwork, 283 | input_shape=mdp_info.observation_space.shape, 284 | output_shape=mdp_info.action_space.shape, 285 | n_features=n_features, 286 | use_cuda=torch.cuda.is_available()) 287 | actor_sigma_params = dict(network=SACActorNetwork, 288 | input_shape=mdp_info.observation_space.shape, 289 | output_shape=mdp_info.action_space.shape, 290 | n_features=n_features, 291 | use_cuda=torch.cuda.is_available()) 292 | 293 | actor_optimizer = {'class': optim.Adam, 294 | 'params': {'lr': actor_lr}} 295 | critic_params = dict(network=SACCriticNetwork, 296 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 297 | optimizer={'class': optim.Adam, 298 | 'params': {'lr': critic_lr}}, 299 | loss=F.mse_loss, 300 | n_features=n_features, 301 | output_shape=(1,), 302 | use_cuda=torch.cuda.is_available()) 303 | 304 | alg_params = dict(initial_replay_size=initial_replay_size, 305 | max_replay_size=max_replay_size, 306 | batch_size=batch_size, 307 | warmup_transitions=warmup_transitions, 308 | tau=tau, 309 | lr_alpha=lr_alpha, 310 | critic_fit_params=None, 311 | target_entropy=target_entropy) 312 | 313 | build_params = dict(compute_entropy_with_states=True, 314 | compute_policy_entropy=True) 315 | 316 | return SAC(mdp_info, actor_mu_params, actor_sigma_params, actor_optimizer, critic_params, 317 | **alg_params), build_params 318 | 319 | 320 | def default_params(): 321 | defaults = dict(env='A', alg='TRPO', seed=1, 322 | horizon=500, gamma=0.99, random_init=False, quiet=False, termination_tol=0.4, render=False, 323 | 
results_dir='../logs/circular_motion') 324 | training_params = dict(n_epochs=50, n_steps=5000, n_steps_per_fit=1000, n_episodes_test=25) 325 | 326 | network_params = dict(actor_lr=3e-4, critic_lr=3e-4, n_features=[32, 32], batch_size=64) 327 | 328 | trpo_ppo_params = dict(lam=0.95, ent_coeff=5e-5) 329 | ppo_params = dict(eps_ppo=0.1) 330 | trpo_params = dict(max_kl=1e-2, n_epochs_line_search=10, n_epochs_cg=10, cg_damping=1e-2, cg_residual_tol=1e-10, 331 | critic_fit_params=None) 332 | 333 | ddpg_td3_sac_params = dict(initial_replay_size=5000, max_replay_size=200000, tau=1e-3) 334 | td3_params = dict(sigma=1.0) 335 | 336 | sac_params = dict(warmup_transitions=10000, lr_alpha=3e-3, target_entropy=-6) 337 | 338 | defaults.update(training_params) 339 | defaults.update(network_params) 340 | defaults.update(trpo_ppo_params) 341 | defaults.update(ppo_params) 342 | defaults.update(trpo_params) 343 | defaults.update(ddpg_td3_sac_params) 344 | defaults.update(td3_params) 345 | defaults.update(sac_params) 346 | return defaults 347 | 348 | 349 | def parse_args(): 350 | parser = argparse.ArgumentParser() 351 | 352 | arg_test = parser.add_argument_group('Experiment') 353 | arg_test.add_argument('--env', choices=['A', 'T', 'E'], help="Environment argument ['A', 'T', 'E']: " 354 | "'A' for ATACOM, " 355 | "'T' for TerminatedCircle, " 356 | "'E' for ErrorCorrection") 357 | arg_test.add_argument('--alg', choices=['TRPO', 'trpo', 'PPO', 'ppo', 'DDPG', 'ddpg', 'TD3', 'td3', 'SAC', 'sac']) 358 | 359 | arg_test.add_argument('--horizon', type=int) 360 | arg_test.add_argument('--gamma', type=float) 361 | arg_test.add_argument('--random-init', action="store_true") 362 | arg_test.add_argument('--termination-tol', type=float) 363 | arg_test.add_argument('--quiet', action="store_true") 364 | arg_test.add_argument('--render', action="store_true") 365 | 366 | # training parameter 367 | arg_test.add_argument('--n-epochs', type=int) 368 | arg_test.add_argument('--n-steps', type=int) 369 | arg_test.add_argument('--n-steps-per-fit', type=int) 370 | arg_test.add_argument('--n-episodes-test', type=int) 371 | 372 | # network parameter 373 | arg_test.add_argument('--actor-lr', type=float) 374 | arg_test.add_argument('--critic-lr', type=float) 375 | arg_test.add_argument('--n-features', nargs='+') 376 | arg_test.add_argument('--batch-size', type=int) 377 | 378 | # TRPO PPO parameter 379 | arg_test.add_argument('--lam', type=float) 380 | arg_test.add_argument('--ent-coeff', type=float) 381 | 382 | # PPO parameters 383 | arg_test.add_argument('--eps-ppo', type=float) 384 | 385 | # TRPO parameters 386 | arg_test.add_argument('--max-kl', type=float) 387 | arg_test.add_argument('--n-epochs-line-search', type=int) 388 | arg_test.add_argument('--n-epochs-cg', type=int) 389 | arg_test.add_argument('--cg-damping', type=float) 390 | arg_test.add_argument('--cg-residual-tol', type=float) 391 | 392 | # DDPG TD3 parameters 393 | arg_test.add_argument('--initial-replay-size', type=int) 394 | arg_test.add_argument('--max-replay-size', type=int) 395 | arg_test.add_argument('--tau', type=float) 396 | 397 | # TD3 parameters 398 | arg_test.add_argument('--sigma', type=float) 399 | 400 | # SAC parameters 401 | arg_test.add_argument('--warmup-transitions', type=int) 402 | arg_test.add_argument('--lr-alpha', type=float) 403 | arg_test.add_argument('--target-entropy', type=float) 404 | 405 | arg_default = parser.add_argument_group('Default') 406 | arg_default.add_argument('--seed', type=int) 407 | arg_default.add_argument('--results-dir', type=str) 
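    # Any option left unset on the command line falls back to default_params() via parser.set_defaults() below.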
408 | 409 | parser.set_defaults(**default_params()) 410 | args = parser.parse_args() 411 | return vars(args) 412 | 413 | 414 | if __name__ == '__main__': 415 | args_ = parse_args() 416 | env_ = build_env(**args_) 417 | agent_, build_params_ = build_agent(mdp_info=env_.info, **args_) 418 | experiment(env_, agent_, build_params=build_params_, **args_) 419 | -------------------------------------------------------------------------------- /examples/iiwa_air_hockey_exp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | from mushroom_rl.algorithms.actor_critic import PPO, TRPO, DDPG, TD3, SAC 6 | from mushroom_rl.core import Core, Logger 7 | from mushroom_rl.policy import GaussianTorchPolicy, OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy 8 | from mushroom_rl.utils.dataset import compute_J, parse_dataset 9 | from mushroom_rl.utils.preprocessors import MinMaxPreprocessor 10 | from tqdm import trange 11 | 12 | from atacom.environments.iiwa_air_hockey import AirHockeyIiwaAtacom, AirHockeyIiwaRmp 13 | from network import * 14 | 15 | 16 | def experiment(seed, results_dir, n_epochs, n_steps, n_steps_per_fit, n_episodes_test, 17 | quiet, **kwargs): 18 | mdp = build_env(**kwargs) 19 | 20 | agent, build_params = build_agent(mdp_info=mdp.info, **kwargs) 21 | 22 | logger = Logger(results_dir=results_dir, seed=seed, log_name='exp') 23 | 24 | logger.strong_line() 25 | logger.info('Experiment Algorithm: ' + type(agent).__name__) 26 | if hasattr(mdp, "env"): 27 | logger.info('Environment: ' + type(mdp.env).__name__ + " seed: " + str(seed)) 28 | else: 29 | logger.info('Environment: ' + type(mdp).__name__ + " seed: " + str(seed)) 30 | 31 | # normalization callback 32 | prepro = MinMaxPreprocessor(mdp_info=mdp.info) 33 | 34 | core = Core(agent, mdp, preprocessors=[prepro]) 35 | 36 | eval_params = dict( 37 | n_episodes=n_episodes_test, 38 | render=False, 39 | quiet=quiet 40 | ) 41 | 42 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params) 43 | best_J, best_R, best_E, best_c_avg, best_c_max, best_c_dq_max = J, R, E, c_avg, c_max, c_dq_max 44 | 45 | logger.epoch_info(0, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 46 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 47 | logger.log_agent(agent) 48 | prepro.save(os.path.join(logger.path, "state_normalization" + logger._suffix + ".msh")) 49 | 50 | for it in trange(n_epochs, leave=False, disable=quiet): 51 | core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=quiet) 52 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params) 53 | 54 | logger.epoch_info(it + 1, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 55 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max) 56 | 57 | if J > best_J: 58 | best_J = J 59 | best_R = R 60 | best_E = E 61 | best_c_avg = c_avg 62 | best_c_max = c_max 63 | best_c_dq_max = c_dq_max 64 | 65 | logger.log_agent(agent) 66 | prepro.save(os.path.join(logger.path, "state_normalization" + logger._suffix + ".msh")) 67 | 68 | logger.info("Best result | J: {}, R: {}, E:{}, c_avg:{}, c_max:{}, c_dq_max{}.".format(best_J, best_R, best_E, 69 | best_c_avg, best_c_max, 70 | best_c_dq_max)) 71 | logger.strong_line() 72 | best_res = {"best_J": best_J, "best_R": best_R, "best_E": best_E, 73 | "best_c_avg": best_c_avg, "best_c_max": best_c_max, "best_c_dq_max": best_c_dq_max} 74 | best_res = 
pd.DataFrame.from_dict(best_res, orient="index") 75 | best_res.to_csv(os.path.join(logger.path, "best_result.csv")) 76 | 77 | 78 | def compute_metrics(core, eval_params, build_params): 79 | dataset = core.evaluate(**eval_params) 80 | c_avg, c_max, c_dq_max = 0., 0., 0. 81 | if hasattr(core.mdp, "get_constraints_logs"): 82 | c_avg, c_max, c_dq_max = core.mdp.get_constraints_logs() 83 | J = np.mean(compute_J(dataset, core.mdp.info.gamma)) 84 | R = np.mean(compute_J(dataset)) 85 | E = None 86 | if build_params['compute_policy_entropy']: 87 | if build_params['compute_entropy_with_states']: 88 | E = core.agent.policy.entropy(parse_dataset(dataset)[0]) 89 | else: 90 | E = core.agent.policy.entropy() 91 | return J, R, E, c_avg, c_max, c_dq_max 92 | 93 | 94 | def build_env(**kwargs): 95 | env = kwargs['env'] 96 | gamma = kwargs['gamma'] 97 | random_init = kwargs['random_init'] 98 | debug_gui = kwargs['debug_gui'] 99 | 100 | iiwa_hit_horizon = kwargs['iiwa_hit_horizon'] 101 | iiwa_hit_time_step = kwargs['iiwa_hit_time_step'] 102 | iiwa_hit_n_intermediate_steps = kwargs['iiwa_hit_n_intermediate_steps'] 103 | 104 | if env == '7H': 105 | mdp = AirHockeyIiwaAtacom(task='H', horizon=iiwa_hit_horizon, gamma=gamma, random_init=random_init, 106 | timestep=iiwa_hit_time_step, n_intermediate_steps=iiwa_hit_n_intermediate_steps, 107 | debug_gui=debug_gui) 108 | elif env == 'RMP': 109 | mdp = AirHockeyIiwaRmp(task='H', horizon=120, gamma=gamma, random_init=random_init, 110 | timestep=1 / 240., n_intermediate_steps=4, 111 | debug_gui=debug_gui) 112 | else: 113 | raise NotImplementedError 114 | return mdp 115 | 116 | 117 | def build_agent(alg, mdp_info, **kwargs): 118 | if isinstance(kwargs['n_features'], str): 119 | kwargs['n_features'] = kwargs['n_features'].split(' ') 120 | 121 | alg = alg.upper() 122 | if alg == 'PPO': 123 | agent, build_params = build_agent_PPO(mdp_info, **kwargs) 124 | elif alg == 'TRPO': 125 | agent, build_params = build_agent_TRPO(mdp_info, **kwargs) 126 | elif alg == 'DDPG': 127 | agent, build_params = build_agent_DDPG(mdp_info, **kwargs) 128 | elif alg == 'TD3': 129 | agent, build_params = build_agent_TD3(mdp_info, **kwargs) 130 | elif alg == 'SAC': 131 | agent, build_params = build_agent_SAC(mdp_info, **kwargs) 132 | else: 133 | raise NotImplementedError 134 | return agent, build_params 135 | 136 | 137 | def build_agent_PPO(mdp_info, actor_lr, critic_lr, n_features, batch_size, eps_ppo, lam, ent_coeff, use_cuda, **kwargs): 138 | policy_params = dict( 139 | std_0=0.5, 140 | n_features=n_features, 141 | use_cuda=use_cuda 142 | ) 143 | policy = GaussianTorchPolicy(PPONetwork, 144 | mdp_info.observation_space.shape, 145 | mdp_info.action_space.shape, 146 | **policy_params) 147 | 148 | critic_params = dict(network=PPONetwork, 149 | optimizer={'class': optim.Adam, 150 | 'params': {'lr': critic_lr}}, 151 | loss=F.mse_loss, 152 | n_features=n_features, 153 | batch_size=batch_size, 154 | input_shape=mdp_info.observation_space.shape, 155 | output_shape=(1,)) 156 | 157 | ppo_params = dict(actor_optimizer={'class': optim.Adam, 158 | 'params': {'lr': actor_lr}}, 159 | n_epochs_policy=4, 160 | batch_size=batch_size, 161 | eps_ppo=eps_ppo, 162 | lam=lam, 163 | ent_coeff=ent_coeff, 164 | critic_params=critic_params) 165 | 166 | build_params = dict(compute_entropy_with_states=False, 167 | compute_policy_entropy=True) 168 | 169 | return PPO(mdp_info, policy, **ppo_params), build_params 170 | 171 | 172 | def build_agent_TRPO(mdp_info, critic_lr, n_features, batch_size, lam, ent_coeff, use_cuda, 173 | 
max_kl, n_epochs_line_search, n_epochs_cg, cg_damping, cg_residual_tol, critic_fit_params, 174 | **kwargs): 175 | policy_params = dict( 176 | std_0=0.5, 177 | n_features=n_features, 178 | use_cuda=use_cuda 179 | ) 180 | 181 | critic_params = dict( 182 | network=TRPONetwork, 183 | optimizer={'class': optim.Adam, 184 | 'params': {'lr': critic_lr}}, 185 | loss=F.mse_loss, 186 | n_features=n_features, 187 | batch_size=batch_size, 188 | input_shape=mdp_info.observation_space.shape, 189 | output_shape=(1,)) 190 | 191 | trpo_params = dict( 192 | ent_coeff=ent_coeff, 193 | max_kl=max_kl, 194 | lam=lam, 195 | n_epochs_line_search=n_epochs_line_search, 196 | n_epochs_cg=n_epochs_cg, 197 | cg_damping=cg_damping, 198 | cg_residual_tol=cg_residual_tol, 199 | critic_fit_params=critic_fit_params) 200 | 201 | policy = GaussianTorchPolicy(TRPONetwork, 202 | mdp_info.observation_space.shape, 203 | mdp_info.action_space.shape, 204 | **policy_params) 205 | 206 | build_params = dict(compute_entropy_with_states=False, 207 | compute_policy_entropy=True) 208 | 209 | return TRPO(mdp_info, policy, critic_params, **trpo_params), build_params 210 | 211 | 212 | def build_agent_DDPG(mdp_info, actor_lr, critic_lr, n_features, batch_size, 213 | initial_replay_size, max_replay_size, tau, use_cuda, **kwargs): 214 | policy_params = dict( 215 | sigma=np.ones(1) * .2, 216 | theta=0.15, 217 | dt=1e-2) 218 | 219 | actor_params = dict( 220 | network=DDPGActorNetwork, 221 | input_shape=mdp_info.observation_space.shape, 222 | output_shape=mdp_info.action_space.shape, 223 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 224 | n_features=n_features, 225 | use_cuda=use_cuda) 226 | 227 | actor_optimizer = { 228 | 'class': optim.Adam, 229 | 'params': {'lr': actor_lr}} 230 | 231 | critic_params = dict( 232 | network=DDPGCriticNetwork, 233 | optimizer={'class': optim.Adam, 234 | 'params': {'lr': critic_lr}}, 235 | loss=F.mse_loss, 236 | n_features=n_features, 237 | batch_size=batch_size, 238 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 239 | action_shape=mdp_info.action_space.shape, 240 | output_shape=(1,), 241 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 242 | use_cuda=use_cuda) 243 | 244 | alg_params = dict( 245 | initial_replay_size=initial_replay_size, 246 | max_replay_size=max_replay_size, 247 | batch_size=batch_size, 248 | tau=tau) 249 | 250 | build_params = dict(compute_entropy_with_states=False, 251 | compute_policy_entropy=False) 252 | 253 | return DDPG(mdp_info, OrnsteinUhlenbeckPolicy, policy_params, actor_params, actor_optimizer, critic_params, 254 | **alg_params), build_params 255 | 256 | 257 | def build_agent_TD3(mdp_info, actor_lr, critic_lr, n_features, batch_size, use_cuda, 258 | initial_replay_size, max_replay_size, tau, sigma, **kwargs): 259 | policy_params = dict( 260 | sigma=np.eye(mdp_info.action_space.shape[0]) * sigma, 261 | low=mdp_info.action_space.low, 262 | high=mdp_info.action_space.high) 263 | 264 | actor_params = dict( 265 | network=TD3ActorNetwork, 266 | input_shape=mdp_info.observation_space.shape, 267 | output_shape=mdp_info.action_space.shape, 268 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 269 | n_features=n_features, 270 | use_cuda=use_cuda) 271 | 272 | actor_optimizer = { 273 | 'class': optim.Adam, 274 | 'params': {'lr': actor_lr}} 275 | 276 | critic_params = dict( 277 | network=TD3CriticNetwork, 278 | optimizer={'class': optim.Adam, 279 | 'params': {'lr': 
critic_lr}}, 280 | loss=F.mse_loss, 281 | n_features=n_features, 282 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 283 | action_shape=mdp_info.action_space.shape, 284 | output_shape=(1,), 285 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2, 286 | use_cuda=use_cuda) 287 | 288 | alg_params = dict( 289 | initial_replay_size=initial_replay_size, 290 | max_replay_size=max_replay_size, 291 | batch_size=batch_size, 292 | tau=tau) 293 | 294 | build_params = dict(compute_entropy_with_states=False, 295 | compute_policy_entropy=False) 296 | 297 | return TD3(mdp_info, ClippedGaussianPolicy, policy_params, actor_params, actor_optimizer, critic_params, 298 | **alg_params), build_params 299 | 300 | 301 | def build_agent_SAC(mdp_info, actor_lr, critic_lr, n_features, batch_size, 302 | initial_replay_size, max_replay_size, tau, 303 | warmup_transitions, lr_alpha, target_entropy, use_cuda, 304 | **kwargs): 305 | actor_mu_params = dict(network=SACActorNetwork, 306 | input_shape=mdp_info.observation_space.shape, 307 | output_shape=mdp_info.action_space.shape, 308 | n_features=n_features, 309 | use_cuda=use_cuda) 310 | actor_sigma_params = dict(network=SACActorNetwork, 311 | input_shape=mdp_info.observation_space.shape, 312 | output_shape=mdp_info.action_space.shape, 313 | n_features=n_features, 314 | use_cuda=use_cuda) 315 | 316 | actor_optimizer = {'class': optim.Adam, 317 | 'params': {'lr': actor_lr}} 318 | critic_params = dict(network=SACCriticNetwork, 319 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],), 320 | optimizer={'class': optim.Adam, 321 | 'params': {'lr': critic_lr}}, 322 | loss=F.mse_loss, 323 | n_features=n_features, 324 | output_shape=(1,), 325 | use_cuda=use_cuda) 326 | 327 | alg_params = dict(initial_replay_size=initial_replay_size, 328 | max_replay_size=max_replay_size, 329 | batch_size=batch_size, 330 | warmup_transitions=warmup_transitions, 331 | tau=tau, 332 | lr_alpha=lr_alpha, 333 | critic_fit_params=None, 334 | target_entropy=target_entropy) 335 | 336 | build_params = dict(compute_entropy_with_states=True, 337 | compute_policy_entropy=True) 338 | 339 | return SAC(mdp_info, actor_mu_params, actor_sigma_params, actor_optimizer, critic_params, 340 | **alg_params), build_params 341 | 342 | 343 | def default_params(): 344 | defaults = dict(env='RMP', alg='SAC', seed=1, 345 | gamma=0.99, random_init=False, debug_gui=True, quiet=False, use_cuda=False, 346 | iiwa_hit_time_step=1 / 240., iiwa_hit_n_intermediate_steps=4, iiwa_hit_horizon=120, 347 | results_dir='../logs/iiwa_hit') 348 | training_params = dict(n_epochs=100, n_steps=3000, n_steps_per_fit=600, n_episodes_test=25) 349 | 350 | network_params = dict(actor_lr=3e-4, critic_lr=3e-4, n_features=[64, 64], batch_size=64) 351 | 352 | trpo_ppo_params = dict(lam=0.95, ent_coeff=5e-5) 353 | ppo_params = dict(eps_ppo=0.1) 354 | trpo_params = dict(max_kl=1e-2, n_epochs_line_search=10, n_epochs_cg=10, cg_damping=1e-2, cg_residual_tol=1e-10, 355 | critic_fit_params=None) 356 | 357 | ddpg_td3_sac_params = dict(initial_replay_size=5000, max_replay_size=200000, tau=1e-3) 358 | td3_params = dict(sigma=0.25) 359 | 360 | sac_params = dict(warmup_transitions=10000, lr_alpha=3e-4, target_entropy=-6) 361 | 362 | defaults.update(training_params) 363 | defaults.update(network_params) 364 | defaults.update(trpo_ppo_params) 365 | defaults.update(ppo_params) 366 | defaults.update(trpo_params) 367 | defaults.update(ddpg_td3_sac_params) 368 | 
defaults.update(td3_params) 369 | defaults.update(sac_params) 370 | return defaults 371 | 372 | 373 | def parse_args(): 374 | parser = argparse.ArgumentParser() 375 | 376 | arg_test = parser.add_argument_group('Experiment') 377 | arg_test.add_argument('--env', choices=['7H', 'RMP'], 378 | help="Environment argument ['7H', 'RMP']: " 379 | "7H for Iiwa Hitting using ATACOM, " 380 | "RMP for Iiwa Hitting using Riemannian Motion Policies.") 381 | arg_test.add_argument('--alg', choices=['TRPO', 'trpo', 'PPO', 'ppo', 'DDPG', 'ddpg', 'TD3', 'td3', 'SAC', 'sac']) 382 | 383 | arg_test.add_argument('--gamma', type=float) 384 | arg_test.add_argument('--random-init', action="store_true") 385 | arg_test.add_argument('--termination-tol', type=float) 386 | arg_test.add_argument('--debug-gui', action="store_true") 387 | arg_test.add_argument('--quiet', action="store_true") 388 | arg_test.add_argument('--use-cuda', action="store_true") 389 | 390 | arg_test.add_argument('--iiwa-hit-horizon', type=int) 391 | arg_test.add_argument('--iiwa-hit-time-step', type=float) 392 | arg_test.add_argument('--iiwa-hit-n-intermediate-steps', type=int) 393 | 394 | # training parameter 395 | arg_test.add_argument('--n-epochs', type=int) 396 | arg_test.add_argument('--n-steps', type=int) 397 | arg_test.add_argument('--n-steps-per-fit', type=int) 398 | arg_test.add_argument('--n-episodes-test', type=int) 399 | 400 | # network parameter 401 | arg_test.add_argument('--actor-lr', type=float) 402 | arg_test.add_argument('--critic-lr', type=float) 403 | arg_test.add_argument('--n-features', nargs='+') 404 | arg_test.add_argument('--batch-size', type=int) 405 | 406 | # TRPO PPO parameter 407 | arg_test.add_argument('--lam', type=float) 408 | arg_test.add_argument('--ent-coeff', type=float) 409 | 410 | # PPO parameters 411 | arg_test.add_argument('--eps-ppo', type=float) 412 | 413 | # TRPO parameters 414 | arg_test.add_argument('--max-kl', type=float) 415 | arg_test.add_argument('--n-epochs-line-search', type=int) 416 | arg_test.add_argument('--n-epochs-cg', type=int) 417 | arg_test.add_argument('--cg-damping', type=float) 418 | arg_test.add_argument('--cg-residual-tol', type=float) 419 | 420 | # DDPG TD3 parameters 421 | arg_test.add_argument('--initial-replay-size', type=int) 422 | arg_test.add_argument('--max-replay-size', type=int) 423 | arg_test.add_argument('--tau', type=float) 424 | 425 | # TD3 parameters 426 | arg_test.add_argument('--sigma', type=float) 427 | 428 | # SAC parameters 429 | arg_test.add_argument('--warmup-transitions', type=int) 430 | arg_test.add_argument('--lr-alpha', type=float) 431 | arg_test.add_argument('--target-entropy', type=float) 432 | 433 | arg_default = parser.add_argument_group('Default') 434 | arg_default.add_argument('--seed', type=int) 435 | arg_default.add_argument('--results-dir', type=str) 436 | 437 | parser.set_defaults(**default_params()) 438 | args = parser.parse_args() 439 | return vars(args) 440 | 441 | 442 | if __name__ == '__main__': 443 | args_ = parse_args() 444 | experiment(**args_) 445 | -------------------------------------------------------------------------------- /examples/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | class PPONetwork(nn.Module): 9 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 10 | super(PPONetwork, self).__init__() 11 | 12 | n_input = 
input_shape[-1] 13 | n_output = output_shape[0] 14 | 15 | assert len(n_features) == 2, 'PPO network needs 2 hidden layers' 16 | n_features = list(map(int, n_features)) 17 | 18 | self._h1 = nn.Linear(n_input, n_features[0]) 19 | self._h2 = nn.Linear(n_features[0], n_features[1]) 20 | self._h3 = nn.Linear(n_features[1], n_output) 21 | 22 | nn.init.xavier_uniform_(self._h1.weight, 23 | gain=nn.init.calculate_gain('relu')) 24 | nn.init.xavier_uniform_(self._h2.weight, 25 | gain=nn.init.calculate_gain('relu')) 26 | nn.init.xavier_uniform_(self._h3.weight, 27 | gain=nn.init.calculate_gain('linear')) 28 | 29 | def forward(self, state, **kwargs): 30 | features1 = F.relu(self._h1(torch.squeeze(state, 1).float())) 31 | features2 = F.relu(self._h2(features1)) 32 | a = self._h3(features2) 33 | 34 | return a 35 | 36 | 37 | class TRPONetwork(nn.Module): 38 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 39 | super(TRPONetwork, self).__init__() 40 | 41 | n_input = input_shape[-1] 42 | n_output = output_shape[0] 43 | 44 | assert len(n_features) == 2, 'TRPO network needs 2 hidden layers' 45 | n_features = list(map(int, n_features)) 46 | 47 | self._h1 = nn.Linear(n_input, n_features[0]) 48 | self._h2 = nn.Linear(n_features[0], n_features[1]) 49 | self._h3 = nn.Linear(n_features[1], n_output) 50 | 51 | nn.init.xavier_uniform_(self._h1.weight, 52 | gain=nn.init.calculate_gain('relu')) 53 | nn.init.xavier_uniform_(self._h2.weight, 54 | gain=nn.init.calculate_gain('relu')) 55 | nn.init.xavier_uniform_(self._h3.weight, 56 | gain=nn.init.calculate_gain('linear')) 57 | 58 | def forward(self, state, **kwargs): 59 | features1 = F.relu(self._h1(torch.squeeze(state, 1).float())) 60 | features2 = F.relu(self._h2(features1)) 61 | a = self._h3(features2) 62 | 63 | return a 64 | 65 | 66 | class TD3CriticNetwork(nn.Module): 67 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 68 | super().__init__() 69 | 70 | n_input = input_shape[-1] 71 | dim_action = kwargs['action_shape'][0] 72 | dim_state = n_input - dim_action 73 | n_output = output_shape[0] 74 | 75 | self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float32).to( 76 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu')) 77 | 78 | # Assume there are two hidden layers 79 | assert len(n_features) == 2, 'TD3 critic needs 2 hidden layers' 80 | n_features = list(map(int, n_features)) 81 | 82 | self._h1 = nn.Linear(dim_state + dim_action, n_features[0]) 83 | self._h2_s = nn.Linear(n_features[0], n_features[1]) 84 | self._h2_a = nn.Linear(dim_action, n_features[1], bias=False) 85 | self._h3 = nn.Linear(n_features[1], n_output) 86 | 87 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight) 88 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1)) 89 | 90 | fan_in_h2_s, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_s.weight) 91 | nn.init.uniform_(self._h2_s.weight, a=-1 / np.sqrt(fan_in_h2_s), b=1 / np.sqrt(fan_in_h2_s)) 92 | 93 | fan_in_h2_a, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_a.weight) 94 | nn.init.uniform_(self._h2_a.weight, a=-1 / np.sqrt(fan_in_h2_a), b=1 / np.sqrt(fan_in_h2_a)) 95 | 96 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3) 97 | 98 | def forward(self, state, action): 99 | state = state.float() 100 | action = action.float() / self._action_scaling 101 | state_action = torch.cat((state, action), dim=1) 102 | 103 | features1 = F.relu(self._h1(state_action)) 104 | features2_s = self._h2_s(features1) 105 | features2_a 
= self._h2_a(action) 106 | features2 = F.relu(features2_s + features2_a) 107 | 108 | q = self._h3(features2) 109 | return torch.squeeze(q) 110 | 111 | 112 | class TD3ActorNetwork(nn.Module): 113 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 114 | super().__init__() 115 | 116 | dim_state = input_shape[0] 117 | dim_action = output_shape[0] 118 | 119 | self._action_scaling = torch.tensor(kwargs['action_scaling']).to( 120 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu')) 121 | 122 | # Assume there are two hidden layers 123 | assert len(n_features) == 2, 'TD3 actor needs two hidden layers' 124 | n_features = list(map(int, n_features)) 125 | 126 | self._h1 = nn.Linear(dim_state, n_features[0]) 127 | self._h2 = nn.Linear(n_features[0], n_features[1]) 128 | self._h3 = nn.Linear(n_features[1], dim_action) 129 | 130 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight) 131 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1)) 132 | 133 | fan_in_h2, _ = nn.init._calculate_fan_in_and_fan_out(self._h2.weight) 134 | nn.init.uniform_(self._h2.weight, a=-1 / np.sqrt(fan_in_h2), b=1 / np.sqrt(fan_in_h2)) 135 | 136 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3) 137 | 138 | def forward(self, state): 139 | state = state.float() 140 | 141 | features1 = F.relu(self._h1(state)) 142 | features2 = F.relu(self._h2(features1)) 143 | a = self._h3(features2) 144 | 145 | a = self._action_scaling * torch.tanh(a) 146 | 147 | return a 148 | 149 | 150 | class DDPGCriticNetwork(nn.Module): 151 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 152 | super().__init__() 153 | 154 | n_input = input_shape[-1] 155 | dim_action = kwargs['action_shape'][0] 156 | dim_state = n_input - dim_action 157 | 158 | self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float).to( 159 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu')) 160 | 161 | n_output = output_shape[0] 162 | 163 | # Assume there are two hidden layers 164 | assert len(n_features) == 2, 'DDPG critic needs 2 hidden layers' 165 | n_features = list(map(int, n_features)) 166 | 167 | self._h1 = nn.Linear(dim_state, n_features[0]) 168 | self._h2_s = nn.Linear(n_features[0], n_features[1]) 169 | self._h2_a = nn.Linear(dim_action, n_features[1], bias=False) 170 | self._h3 = nn.Linear(n_features[1], n_output) 171 | 172 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight) 173 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1)) 174 | 175 | fan_in_h2_s, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_s.weight) 176 | nn.init.uniform_(self._h2_s.weight, a=-1 / np.sqrt(fan_in_h2_s), b=1 / np.sqrt(fan_in_h2_s)) 177 | 178 | fan_in_h2_a, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_a.weight) 179 | nn.init.uniform_(self._h2_a.weight, a=-1 / np.sqrt(fan_in_h2_a), b=1 / np.sqrt(fan_in_h2_a)) 180 | 181 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3) 182 | 183 | def forward(self, state, action): 184 | state = state.float() 185 | action = action.float() / self._action_scaling 186 | 187 | features1 = F.relu(self._h1(state)) 188 | features2_s = self._h2_s(features1) 189 | features2_a = self._h2_a(action) 190 | features2 = F.relu(features2_s + features2_a) 191 | 192 | q = self._h3(features2) 193 | 194 | return torch.squeeze(q) 195 | 196 | 197 | class DDPGActorNetwork(nn.Module): 198 | def __init__(self, input_shape, output_shape, n_features, **kwargs): 199 | super().__init__() 200 | 
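        # Deterministic DDPG actor: two ReLU hidden layers; forward() squashes the output with tanh and rescales it to the action bounds.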
201 |         dim_state = input_shape[0]
202 |         dim_action = output_shape[0]
203 | 
204 |         self._action_scaling = torch.tensor(kwargs['action_scaling']).to(
205 |             device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu'))
206 | 
207 |         # Assume there are two hidden layers
208 |         assert len(n_features) == 2, 'DDPG actor needs 2 hidden layers'
209 |         n_features = list(map(int, n_features))
210 | 
211 |         self._h1 = nn.Linear(dim_state, n_features[0])
212 |         self._h2 = nn.Linear(n_features[0], n_features[1])
213 |         self._h3 = nn.Linear(n_features[1], dim_action)
214 | 
215 |         fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight)
216 |         nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1))
217 | 
218 |         fan_in_h2, _ = nn.init._calculate_fan_in_and_fan_out(self._h2.weight)
219 |         nn.init.uniform_(self._h2.weight, a=-1 / np.sqrt(fan_in_h2), b=1 / np.sqrt(fan_in_h2))
220 | 
221 |         nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3)
222 | 
223 |     def forward(self, state):
224 |         state = state.float()
225 | 
226 |         features1 = F.relu(self._h1(state))
227 |         features2 = F.relu(self._h2(features1))
228 |         a = self._h3(features2)
229 | 
230 |         a = self._action_scaling * torch.tanh(a)
231 | 
232 |         return a
233 | 
234 | 
235 | class SACCriticNetwork(nn.Module):
236 |     def __init__(self, input_shape, output_shape, n_features, **kwargs):
237 |         super().__init__()
238 | 
239 |         n_input = input_shape[-1]
240 |         n_output = output_shape[0]
241 | 
242 |         # Assume there are two hidden layers
243 |         assert len(n_features) == 2, 'SAC critic needs 2 hidden layers'
244 |         n_features = list(map(int, n_features))
245 | 
246 |         self._h1 = nn.Linear(n_input, n_features[0])
247 |         self._h2 = nn.Linear(n_features[0], n_features[1])
248 |         self._h3 = nn.Linear(n_features[1], n_output)
249 | 
250 |         nn.init.xavier_uniform_(self._h1.weight,
251 |                                 gain=nn.init.calculate_gain('relu'))
252 |         nn.init.xavier_uniform_(self._h2.weight,
253 |                                 gain=nn.init.calculate_gain('relu'))
254 |         nn.init.xavier_uniform_(self._h3.weight,
255 |                                 gain=nn.init.calculate_gain('linear'))
256 | 
257 |     def forward(self, state, action):
258 |         state_action = torch.cat((state.float(), action.float()), dim=1)
259 |         features1 = F.relu(self._h1(state_action))
260 |         features2 = F.relu(self._h2(features1))
261 |         q = self._h3(features2)
262 | 
263 |         return torch.squeeze(q)
264 | 
265 | 
266 | class SACActorNetwork(nn.Module):
267 |     def __init__(self, input_shape, output_shape, n_features, **kwargs):
268 |         super(SACActorNetwork, self).__init__()
269 | 
270 |         n_input = input_shape[-1]
271 |         n_output = output_shape[0]
272 | 
273 |         # Assume there are two hidden layers
274 |         assert len(n_features) == 2, 'SAC actor needs 2 hidden layers'
275 |         n_features = list(map(int, n_features))
276 | 
277 |         self._h1 = nn.Linear(n_input, n_features[0])
278 |         self._h2 = nn.Linear(n_features[0], n_features[1])
279 |         self._h3 = nn.Linear(n_features[1], n_output)
280 | 
281 |         nn.init.xavier_uniform_(self._h1.weight,
282 |                                 gain=nn.init.calculate_gain('relu'))
283 |         nn.init.xavier_uniform_(self._h2.weight,
284 |                                 gain=nn.init.calculate_gain('relu'))
285 |         nn.init.xavier_uniform_(self._h3.weight,
286 |                                 gain=nn.init.calculate_gain('linear'))
287 | 
288 |     def forward(self, state):
289 |         features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
290 |         features2 = F.relu(self._h2(features1))
291 |         a = self._h3(features2)
292 | 
293 |         return a
--------------------------------------------------------------------------------
/fig/manifold.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/fig/manifold.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.17.0
2 | matplotlib>=3.3.4
3 | scipy>=1.5.2
4 | mushroom-rl>=1.7.0
5 | pandas>=1.2.1
6 | pybullet>=3.0.8
7 | torch
8 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from codecs import open
3 | from os import path
4 | 
5 | version = '1.0.0'
6 | 
7 | 
8 | here = path.abspath(path.dirname(__file__))
9 | 
10 | requires_list = []
11 | with open(path.join(here, 'requirements.txt'), encoding='utf-8') as f:
12 |     for line in f:
13 |         requires_list.append(str(line))
14 | 
15 | extras = {}
16 | 
17 | all_deps = []
18 | for group_name in extras:
19 |     all_deps += extras[group_name]
20 | extras['all'] = all_deps
21 | 
22 | long_description = 'TODO.'
23 | 
24 | setup(
25 |     name='atacom',
26 |     version=version,
27 |     description='Acting on the Tangent Space of the Constraint Manifold',
28 |     long_description=long_description,
29 |     author="Puze Liu",
30 |     author_email='puze@robot-learning.de',
31 |     license='MIT',
32 |     packages=[package for package in find_packages()
33 |               if package.startswith('atacom')],
34 |     zip_safe=False,
35 |     install_requires=requires_list,
36 |     extras_require=extras,
37 |     classifiers=["Programming Language :: Python :: 3",
38 |                  "License :: OSI Approved :: MIT License",
39 |                  "Operating System :: OS Independent",
40 |                  ]
41 | )
42 | 
--------------------------------------------------------------------------------
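To close, a quick sanity check of the null-space utilities in atacom/utils/null_space_coordinate.py: the minimal sketch below (not part of the repository; the toy constraint Jacobian Jc is an illustrative assumption) shows what pinv_null returns, namely the Moore-Penrose pseudoinverse together with an orthonormal basis of the null space, the two quantities the ATACOM wrapper extracts from the constraint Jacobian in step_action_function.

```python
import numpy as np
from atacom.utils.null_space_coordinate import pinv_null

# Toy constraint Jacobian: 2 constraints over 3 controllable variables
# (illustrative values, not taken from any environment in this repository).
Jc = np.array([[1., 0., 0.],
               [0., 1., 0.]])

Jc_inv, Nc = pinv_null(Jc)

# Jc_inv is the Moore-Penrose pseudoinverse (a right inverse, since Jc has
# full row rank); the columns of Nc form an orthonormal basis of ker(Jc).
assert np.allclose(Jc @ Jc_inv, np.eye(2))
assert np.allclose(Jc @ Nc, 0.)
print(Jc_inv.shape, Nc.shape)  # -> (3, 2) (3, 1)
```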