├── .gitignore
├── README.md
├── atacom
│   ├── __init__.py
│   ├── atacom.py
│   ├── constraints.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── circular_motion
│   │   │   ├── __init__.py
│   │   │   ├── circle_atacom.py
│   │   │   ├── circle_base.py
│   │   │   ├── circle_error_correction.py
│   │   │   └── circle_terminated.py
│   │   ├── collision_avoidance
│   │   │   ├── __init__.py
│   │   │   ├── collision_avoidance_atacom.py
│   │   │   └── collision_avoidance_base.py
│   │   ├── iiwa_air_hockey
│   │   │   ├── __init__.py
│   │   │   ├── env_base.py
│   │   │   ├── env_hitting.py
│   │   │   ├── env_single.py
│   │   │   ├── iiwa_air_hockey_rmp.py
│   │   │   ├── iiwa_hit_atacom.py
│   │   │   ├── kinematics.py
│   │   │   └── urdf
│   │   │       ├── iiwa_1.urdf
│   │   │       ├── iiwa_2.urdf
│   │   │       └── meshes
│   │   │           ├── collision
│   │   │           │   ├── link_0.stl
│   │   │           │   ├── link_1.stl
│   │   │           │   ├── link_2.stl
│   │   │           │   ├── link_3.stl
│   │   │           │   ├── link_4.stl
│   │   │           │   ├── link_5.stl
│   │   │           │   ├── link_6.stl
│   │   │           │   ├── link_7.stl
│   │   │           │   └── link_7_old.stl
│   │   │           ├── striker
│   │   │           │   ├── collision
│   │   │           │   │   ├── EE_arm_collision.stl
│   │   │           │   │   ├── EE_mallet_collision.stl
│   │   │           │   │   └── EE_mallet_short_collision.stl
│   │   │           │   └── visual
│   │   │           │       ├── EE_arm.stl
│   │   │           │       ├── EE_mallet.stl
│   │   │           │       └── EE_mallet_short.stl
│   │   │           └── visual
│   │   │               ├── link_0.stl
│   │   │               ├── link_1.stl
│   │   │               ├── link_2.stl
│   │   │               ├── link_3.stl
│   │   │               ├── link_4.stl
│   │   │               ├── link_5.stl
│   │   │               ├── link_6.stl
│   │   │               ├── link_7.stl
│   │   │               └── link_7_old.stl
│   │   └── planar_air_hockey
│   │       ├── __init__.py
│   │       ├── atacom_air_hockey.py
│   │       └── unconstrained_air_hockey.py
│   ├── error_correction_wrapper.py
│   └── utils
│       ├── __init__.py
│       ├── null_space_coordinate.py
│       └── plot_utils.py
├── examples
│   ├── __init__.py
│   ├── circle_exp.py
│   ├── collision_avoidance_exp.py
│   ├── iiwa_air_hockey_exp.py
│   ├── network.py
│   └── planar_air_hockey_exp.py
├── fig
│   └── manifold.gif
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | *.egg-info/
4 | logs/
5 | result/
6 | *runs/
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Acting on the Tangent Space of the Constraint Manifold
2 |
3 | ![manifold](fig/manifold.gif)
4 |
5 | Implementation of "Robot Reinforcement Learning on the Constraint Manifold"
6 |
7 | [[paper]](https://www.ias.informatik.tu-darmstadt.de/uploads/Team/PuzeLiu/CORL_2021_Learning_on_the_Manifold.pdf)
8 | [[website]](https://sites.google.com/view/robot-air-hockey/atacom)
9 |
10 | ## Install
11 | ```bash
12 | pip install -e .
13 | ```
14 |
15 | ## Run Examples
16 | ```bash
17 | cd examples
18 | ```
19 | ### CircularMotion Environment.
20 | Environment options: [A, E, T] (ATACOM, ErrorCorrection, Terminated); algorithm options: [TRPO, PPO, SAC, DDPG, TD3]
21 | ```bash
22 | python circle_exp.py --render --env A --alg TRPO
23 | ```
24 |
25 | ### PlanarAirHockey Environment.
26 | Environment options: [H, D, UH, UD]; algorithm options: [TRPO, PPO, SAC, DDPG, TD3]
27 | ```bash
28 | python planar_air_hockey_exp.py --debug-gui --env H --alg SAC
29 | ```
30 |
31 | ### IiwaAirHockey Environment.
32 | Environment options: [7H, RMP]; algorithm options: [TRPO, PPO, SAC, DDPG, TD3]
33 | ```bash
34 | python iiwa_air_hockey_exp.py --debug-gui --env 7H --alg SAC
35 | ```
36 |
37 | ### CollisionAvoidance Environment.
38 | Environment options: [C]; algorithm options: [TRPO, PPO, SAC, DDPG, TD3]
39 | ```bash
40 | python collision_avoidance_exp.py --render --env C --alg SAC
41 | ```
42 |
43 |
44 | ## Bibtex
45 | ```bibtex
46 | @inproceedings{CORL_2021_Learning_on_the_Manifold,
47 | author = "Liu, P. and Tateo D. and Bou-Ammar, H. and Peters, J.",
48 | year = "2021",
49 | title = "Robot Reinforcement Learning on the Constraint Manifold",
50 | booktitle = "Proceedings of the Conference on Robot Learning (CoRL)",
51 | key = "robot learning, constrained reinforcement learning, safe exploration",
52 | }
53 | ```
54 |
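55 | ## Usage Sketch
56 | A minimal sketch of interacting with one of the wrapped environments in Python
57 | (names and defaults taken from `atacom/environments/circular_motion/circle_atacom.py`;
58 | the sampled action lives on the tangent space of the constraint manifold):
59 | ```python
60 | import numpy as np
61 | from atacom.environments.circular_motion import CircleEnvAtacom
62 |
63 | env = CircleEnvAtacom(horizon=500, gamma=0.99, Kc=100, time_step=0.01)
64 | state = env.reset()
65 | action = np.random.uniform(env.info.action_space.low, env.info.action_space.high)
66 | state, reward, absorbing, info = env.step(action)
67 | ```
68 |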
--------------------------------------------------------------------------------
/atacom/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.0'
2 |
--------------------------------------------------------------------------------
/atacom/atacom.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mushroom_rl.utils.spaces import Box
3 | from atacom.utils.null_space_coordinate import rref, pinv_null
4 |
5 | class AtacomEnvWrapper:
6 | """
7 | Environment wrapper of ATACOM
8 | """
9 |
10 | def __init__(self, base_env, dim_q, vel_max, acc_max, f=None, g=None, Kc=1., Kq=10., time_step=0.01):
11 | """
12 | Constructor
13 | Args:
14 |             base_env (mushroom_rl.core.Environment): the base environment to be wrapped
15 |             dim_q (int): dimension of the directly controllable variable
16 |             vel_max (array, float): the maximum velocity of the directly controllable variable
17 |             acc_max (array, float): the maximum acceleration of the directly controllable variable
18 |             f (ViabilityConstraint, ConstraintsSet): the equality constraint f(q) = 0
19 |             g (ViabilityConstraint, ConstraintsSet): the inequality constraint g(q) <= 0
20 |             Kc (array, float): the scaling factor for error correction
21 |             Kq (array, float): the gain bounding the acceleration when approaching the velocity limits
22 |             time_step (float): the step size for time discretization
23 |         """
24 | self.env = base_env
25 | self.dims = {'q': dim_q, 'f': 0, 'g': 0}
26 | self.f = f
27 | self.g = g
28 | self.time_step = time_step
29 | self._logger = None
30 |
31 | if self.f is not None:
32 | assert self.dims['q'] == self.f.dim_q, "Input dimension is different in f"
33 | self.dims['f'] = self.f.dim_out
34 | if self.g is not None:
35 | assert self.dims['q'] == self.g.dim_q, "Input dimension is different in g"
36 | self.dims['g'] = self.g.dim_out
37 | self.s = np.zeros(self.dims['g'])
38 |
39 | self.dims['null'] = self.dims['q'] - self.dims['f']
40 | self.dims['c'] = self.dims['f'] + self.dims['g']
41 |
42 | if np.isscalar(Kc):
43 | self.K_c = np.ones(self.dims['c']) * Kc
44 | else:
45 | self.K_c = Kc
46 |
47 | self.q = np.zeros(self.dims['q'])
48 | self.dq = np.zeros(self.dims['q'])
49 |
50 | self._mdp_info = self.env.info.copy()
51 | self._mdp_info.action_space = Box(low=-np.ones(self.dims['null']), high=np.ones(self.dims['null']))
52 |
53 | if np.isscalar(vel_max):
54 | self.vel_max = np.ones(self.dims['q']) * vel_max
55 | else:
56 | self.vel_max = vel_max
57 | assert np.shape(self.vel_max)[0] == self.dims['q']
58 |
59 | if np.isscalar(acc_max):
60 | self.acc_max = np.ones(self.dims['q']) * acc_max
61 | else:
62 | self.acc_max = acc_max
63 | assert np.shape(self.acc_max)[0] == self.dims['q']
64 |
65 | if np.isscalar(Kq):
66 | self.K_q = np.ones(self.dims['q']) * Kq
67 | else:
68 | self.K_q = Kq
69 | assert np.shape(self.K_q)[0] == self.dims['q']
70 |
71 | self.alpha_max = np.ones(self.dims['null']) * self.acc_max.max()
72 |
73 | self.state = self.env.reset()
74 | self._act_a = None
75 | self._act_b = None
76 | self._act_err = None
77 |
78 | self.constr_logs = list()
79 | self.env.step_action_function = self.step_action_function
80 |
81 | def _get_q(self, state):
82 | raise NotImplementedError
83 |
84 | def _get_dq(self, state):
85 | raise NotImplementedError
86 |
87 | def acc_to_ctrl_action(self, ddq):
88 | raise NotImplementedError
89 |
90 | def seed(self, seed):
91 | self.env.seed(seed)
92 |
93 | def reset(self, state=None):
94 | self.state = self.env.reset(state)
95 | self.q = self._get_q(self.state)
96 | self.dq = self._get_dq(self.state)
97 | self._compute_slack_variables()
98 | return self.state
99 |
100 | def render(self):
101 | self.env.render()
102 |
103 | def stop(self):
104 | self.env.stop()
105 |
106 | def step(self, action):
107 | alpha = np.clip(action, self.info.action_space.low, self.info.action_space.high)
108 | alpha = alpha * self.alpha_max
109 |
110 | self.state, reward, absorb, info = self.env.step(alpha)
111 | self.q = self._get_q(self.state)
112 | self.dq = self._get_dq(self.state)
113 | if not hasattr(self.env, "get_constraints_logs"):
114 | self._update_constraint_stats(self.q, self.dq)
115 | return self.state.copy(), reward, absorb, info
116 |
117 | def acc_truncation(self, dq, ddq):
118 | acc_u = np.maximum(np.minimum(self.acc_max, -self.K_q * (dq - self.vel_max)), -self.acc_max)
119 | acc_l = np.minimum(np.maximum(-self.acc_max, -self.K_q * (dq + self.vel_max)), self.acc_max)
120 | ddq = np.clip(ddq, acc_l, acc_u)
121 | return ddq
122 |
123 | def step_action_function(self, sim_state, alpha):
124 | self.state = self.env._create_observation(sim_state)
125 |
126 | Jc, psi = self._construct_Jc_psi(self.q, self.s, self.dq)
127 | Jc_inv, Nc = pinv_null(Jc)
128 | Nc = rref(Nc[:, :self.dims['null']], row_vectors=False, tol=0.05)
129 |
130 | self._act_a = -Jc_inv @ psi
131 | self._act_b = Nc @ alpha
132 | self._act_err = self._compute_error_correction(self.q, self.dq, self.s, Jc_inv)
133 | ddq_ds = self._act_a + self._act_b + self._act_err
134 |
135 | self.s += ddq_ds[self.dims['q']:(self.dims['q'] + self.dims['g'])] * self.time_step
136 |
137 | ddq = self.acc_truncation(self.dq, ddq_ds[:self.dims['q']])
138 | ctrl_action = self.acc_to_ctrl_action(ddq)
139 | return ctrl_action
140 |
141 | @property
142 | def info(self):
143 | return self._mdp_info
144 |
145 | def _compute_slack_variables(self):
146 | self.s = None
147 | if self.dims['g'] > 0:
148 | s_2 = np.maximum(-2 * self.g.fun(self.q, self.dq, origin_constr=False), 0)
149 | self.s = np.sqrt(s_2)
150 |
151 | def _construct_Jc_psi(self, q, s, dq):
152 | Jc = np.zeros((self.dims['f'] + self.dims['g'], self.dims['q'] + self.dims['g']))
153 | psi = np.zeros(self.dims['c'])
154 | if self.dims['f'] > 0:
155 | idx_0 = 0
156 | idx_1 = self.dims['f']
157 | Jc[idx_0:idx_1, :self.dims['q']] = self.f.K_J(q)
158 | psi[idx_0:idx_1] = self.f.b(q, dq)
159 | if self.dims['g'] > 0:
160 | idx_0 = self.dims['f']
161 | idx_1 = self.dims['f'] + self.dims['g']
162 | Jc[idx_0:idx_1, :self.dims['q']] = self.g.K_J(q)
163 | Jc[idx_0:idx_1, self.dims['q']:(self.dims['q'] + self.dims['g'])] = np.diag(s)
164 | psi[idx_0:idx_1] = self.g.b(q, dq)
165 | return Jc, psi
166 |
167 | def _compute_error_correction(self, q, dq, s, Jc_inv, act_null=None):
168 | q_tmp = q.copy()
169 | dq_tmp = dq.copy()
170 | s_tmp = None
171 |
172 | if self.dims['g'] > 0:
173 | s_tmp = s.copy()
174 |
175 | if act_null is not None:
176 | q_tmp += dq_tmp * self.time_step + act_null[:self.dims['q']] * self.time_step ** 2 / 2
177 | dq_tmp += act_null[:self.dims['q']] * self.time_step
178 | if self.dims['g'] > 0:
179 | s_tmp += act_null[self.dims['q']:self.dims['q'] + self.dims['g']] * self.time_step
180 |
181 | return -Jc_inv @ (self.K_c * self._compute_c(q_tmp, dq_tmp, s_tmp, origin_constr=False))
182 |
183 | def _compute_c(self, q, dq, s, origin_constr=False):
184 | c = np.zeros(self.dims['f'] + self.dims['g'])
185 | if self.dims['f'] > 0:
186 | idx_0 = 0
187 | idx_1 = self.dims['f']
188 | c[idx_0:idx_1] = self.f.fun(q, dq, origin_constr)
189 | if self.dims['g'] > 0:
190 | idx_0 = self.dims['f']
191 | idx_1 = self.dims['f'] + self.dims['g']
192 | if origin_constr:
193 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr)
194 | else:
195 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) + 1 / 2 * s ** 2
196 | return c
197 |
198 | def set_logger(self, logger):
199 | self._logger = logger
200 |
201 | def _update_constraint_stats(self, q, dq):
202 | c_i = self._compute_c(q, dq, 0., origin_constr=True)
203 | c_i[:self.dims['f']] = np.abs(c_i[:self.dims['f']])
204 | c_dq_i = np.abs(dq) - self.vel_max
205 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)])
206 |
207 | def get_constraints_logs(self):
208 | if not hasattr(self.env, "get_constraints_logs"):
209 | constr_logs = np.array(self.constr_logs)
210 | c_avg = np.mean(constr_logs[:, 0])
211 | c_max = np.max(constr_logs[:, 0])
212 | c_dq_max = np.max(constr_logs[:, 1])
213 | self.constr_logs.clear()
214 | return c_avg, c_max, c_dq_max
215 | else:
216 | return self.env.get_constraints_logs()
217 |
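218 |
219 | # Note on step_action_function: inequality constraints g(q) <= 0 are handled via
220 | # slack variables s with g(q) + s^2 / 2 = 0, whose time derivative
221 | # J_g(q) dq + s ds = 0 explains the np.diag(s) block in _construct_Jc_psi.
222 | # The applied update decomposes as
223 | #     [ddq, ds] = -Jc^+ psi + Nc alpha - Jc^+ (Kc * c),
224 | # i.e. drift compensation (_act_a), the agent's action mapped onto the tangent
225 | # space of the constraint manifold (_act_b), and error correction (_act_err).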
--------------------------------------------------------------------------------
/atacom/constraints.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class ViabilityConstraint:
5 | """
6 | Class of viability constraint
7 | f(q) + K_f df(q, dq) = 0
8 | g(q) + K_g dg(q, dq) <= 0
9 | """
10 |
11 | def __init__(self, dim_q, dim_out, fun, J, b, K):
12 | """
13 | Constructor of the viability constraint
14 |
15 |         Args:
16 |             dim_q (int): Dimension of the controllable variable
17 |             dim_out (int): Dimension of the constraint
18 |             fun (function): The constraint function f(q) or g(q)
19 |             J (function): The Jacobian function J_f(q) or J_g(q)
20 |             b (function): The term dJ(q, dq) dq
21 |             K (scalar or array): The scale factor K_f or K_g
22 | """
23 | self.dim_q = dim_q
24 | self.dim_out = dim_out
25 | self.fun_origin = fun
26 | if np.isscalar(K):
27 | self.K = np.ones(dim_out) * K
28 | else:
29 | self.K = K
30 | self.J = J
31 | self.b_state = b
32 |
33 | def fun(self, q, dq, origin_constr=False):
34 | if origin_constr:
35 | return self.fun_origin(q)
36 | else:
37 | return self.fun_origin(q) + self.K * (self.J(q) @ dq)
38 |
39 | def K_J(self, q):
40 | return np.diag(self.K) @ self.J(q)
41 |
42 | def b(self, q, dq):
43 | return self.J(q) @ dq + self.K * self.b_state(q, dq)
44 |
45 |
46 | class ConstraintsSet:
47 | """
48 | The class to gather multiple constraints
49 | """
50 |
51 | def __init__(self, dim_q):
52 | self.dim_q = dim_q
53 | self.constraints_list = list()
54 | self.dim_out = 0
55 |
56 | def add_constraint(self, c: ViabilityConstraint):
57 | self.dim_out += c.dim_out
58 | self.constraints_list.append(c)
59 |
60 | def fun(self, q, dq, origin_constr=False):
61 | ret = np.zeros(self.dim_out)
62 | i = 0
63 | for c in self.constraints_list:
64 | ret[i:i + c.dim_out] = c.fun(q, dq, origin_constr)
65 | i += c.dim_out
66 | return ret
67 |
68 | def K_J(self, q):
69 | ret = np.zeros((self.dim_out, self.dim_q))
70 | i = 0
71 | for c in self.constraints_list:
72 | ret[i:i + c.dim_out] = c.K_J(q)
73 | i += c.dim_out
74 | return ret
75 |
76 | def b(self, q, dq):
77 | ret = np.zeros(self.dim_out)
78 | i = 0
79 | for c in self.constraints_list:
80 | ret[i:i + c.dim_out] = c.b(q, dq)
81 | i += c.dim_out
82 | return ret
83 |
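84 |
85 | if __name__ == '__main__':
86 |     # Minimal usage sketch: build the unit-circle equality constraint used by
87 |     # CircleEnvAtacom and evaluate it at a point on the manifold.
88 |     f = ConstraintsSet(dim_q=2)
89 |     f.add_constraint(ViabilityConstraint(
90 |         dim_q=2, dim_out=1,
91 |         fun=lambda q: np.array([q[0] ** 2 + q[1] ** 2 - 1]),
92 |         J=lambda q: np.array([[2 * q[0], 2 * q[1]]]),
93 |         b=lambda q, dq: np.array([[2 * dq[0], 2 * dq[1]]]) @ dq,
94 |         K=0.1))
95 |     q, dq = np.array([1., 0.]), np.array([0., 1.])
96 |     print(f.fun(q, dq))    # f(q) + K * J(q) @ dq = [0.]
97 |     print(f.K_J(q))        # scaled constraint Jacobian [[0.2, 0.]]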
--------------------------------------------------------------------------------
/atacom/environments/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/atacom/environments/circular_motion/__init__.py:
--------------------------------------------------------------------------------
1 | from .circle_error_correction import CircleEnvErrorCorrection
2 | from .circle_terminated import CircleEnvTerminated
3 | from .circle_atacom import CircleEnvAtacom
4 |
--------------------------------------------------------------------------------
/atacom/environments/circular_motion/circle_atacom.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from atacom.environments.circular_motion.circle_base import CircularMotion
3 | from atacom.atacom import AtacomEnvWrapper
4 | from atacom.constraints import ViabilityConstraint, ConstraintsSet
5 |
6 | class CircleEnvAtacom(AtacomEnvWrapper):
7 | def __init__(self, horizon=500, gamma=0.99, random_init=False, Kc=100, time_step=0.01):
8 | base_env = CircularMotion(random_init=random_init, horizon=horizon, gamma=gamma)
9 | circle_constr = ViabilityConstraint(2, 1, fun=self.circle_fun, J=self.circle_J, b=self.circle_b, K=0.1)
10 |
11 | height_constr = ViabilityConstraint(2, 1, fun=self.height_fun, J=self.height_J, b=self.height_b, K=2)
12 |
13 | f = ConstraintsSet(2)
14 | f.add_constraint(circle_constr)
15 | g = ConstraintsSet(2)
16 | g.add_constraint(height_constr)
17 |
18 | super().__init__(base_env=base_env, dim_q=2, f=f, g=g, Kc=Kc, acc_max=10, vel_max=1, Kq=20, time_step=time_step)
19 |
20 | def _get_q(self, state):
21 | return state[:2]
22 |
23 | def _get_dq(self, state):
24 | return state[2:4]
25 |
26 | def acc_to_ctrl_action(self, ddq):
27 | return ddq / self.acc_max
28 |
29 | def render(self):
30 | offset = np.array([1.25, 1.25])
31 | pos = self.state[:2] + offset
32 |
33 | act_a = self._act_a[:2]
34 | act_b = self._act_b[:2]
35 |
36 | self.env._viewer.force_arrow(center=pos, direction=act_a,
37 | force=np.linalg.norm(act_a),
38 | max_force=3, width=5,
39 | max_length=0.3, color=(255, 165, 0))
40 |
41 | self.env._viewer.force_arrow(center=pos, direction=act_b,
42 | force=np.linalg.norm(act_b),
43 | max_force=10, width=5,
44 | max_length=0.3, color=(0, 255, 255))
45 | super().render()
46 |
47 | @staticmethod
48 | def circle_fun(q):
49 | return np.array([q[0] ** 2 + q[1] ** 2 - 1])
50 |
51 | @staticmethod
52 | def circle_J(q):
53 | return np.array([[2 * q[0], 2 * q[1]]])
54 |
55 | @staticmethod
56 | def circle_b(q, dq):
57 | return np.array([[2 * dq[0], 2 * dq[1]]]) @ dq
58 |
59 | @staticmethod
60 | def height_fun(q):
61 | return np.array([-q[1] - 0.5])
62 |
63 | @staticmethod
64 | def height_J(q):
65 | return np.array([[0, -1]])
66 |
67 | @staticmethod
68 | def height_b(q, dq):
69 | return np.array([0])
70 |
71 | @staticmethod
72 | def vel_fun(q, dq):
73 | return np.array([dq[0] ** 2 + dq[1] ** 2 - 1])
74 |
75 | @staticmethod
76 | def vel_A(q, dq):
77 | return np.array([[2 * dq[0], 2 * dq[1]]])
78 |
79 | @staticmethod
80 | def vel_b(q, dq):
81 | return np.array([0.])
82 |
--------------------------------------------------------------------------------
/atacom/environments/circular_motion/circle_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import FormatStrFormatter
5 |
6 | from mushroom_rl.core import Environment, MDPInfo
7 | from mushroom_rl.utils import spaces
8 | from mushroom_rl.utils.viewer import Viewer
9 |
10 |
11 | class CircularMotion(Environment):
12 | """
13 | Base environment of CircularMotion Environment
14 | A point is moving on the 2D unit circular_motion.
15 | Control actions are 2d acceleration along each direction
16 | """
17 |
18 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, random_init=False):
19 | self.time_step = time_step
20 | self.random_init = random_init
21 | inf_array = np.ones(4) * np.inf
22 | observation_space = spaces.Box(low=-inf_array, high=inf_array)
23 | action_space = spaces.Box(low=-np.ones(2) * 1, high=np.ones(2) * 1)
24 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)
25 | self._viewer = Viewer(env_width=2.5, env_height=2.5, background=(255, 255, 255))
26 | self.step_action_function = None
27 | self.action_scale = np.array([10., 10.])
28 | super().__init__(mdp_info)
29 |
30 |         self.c = np.zeros(4)
31 | self.constr_logs = list()
32 |
33 | def reset(self, state=None):
34 | if state is None:
35 | if self.random_init:
36 | y = np.random.uniform(-0.5, 1)
37 | x = np.sqrt(1 - y ** 2) * np.sign(np.random.uniform(-1, 1))
38 | dx = np.random.uniform(-1, 1)
39 | dy = -x * dx / y
40 | v = np.array([dx, dy])
41 | v = v / np.linalg.norm(v) * np.random.uniform(0, 1)
42 | self._state = np.array([x, y, v[0], v[1]])
43 | else:
44 | self._state = np.array([-1., 0., 0., 0.0])
45 | else:
46 |             if abs(state[0] ** 2 + state[1] ** 2 - 1) < 1e-6 and abs(state[0] * state[2] + state[1] * state[3]) < 1e-6:
47 |                 self._state = state
48 |             else:
49 |                 raise ValueError("Cannot reset to the state: {}".format(state))
50 |
51 | return self._state
52 |
53 | def step(self, action):
54 | self.check_constraint()
55 |
56 | if self.step_action_function is not None:
57 | action = self.step_action_function(self._state, action)
58 |
59 | self._action = np.clip(action, self.info.action_space.low, self.info.action_space.high)
60 | self._action = self._action * self.action_scale
61 |
62 | self._state[:2] += self._state[2:4] * self.time_step + self._action * self.time_step ** 2 / 2
63 | self._state[2:4] += self._action * self.time_step
64 |
65 | reward = np.exp(-np.linalg.norm(np.array([1., 0.]) - self._state[:2]))
66 |
67 | return self._state, reward, False, dict()
68 |
69 | def render(self):
70 | offset = np.array([1.25, 1.25])
71 | self._viewer.circle(center=offset, radius=1., color=(0., 0., 255), width=5)
72 | self._viewer.line(start=np.array([-1.25, -0.5]) + offset, end=np.array([1.25, -0.5]) + offset,
73 | color=(255, 20, 147), width=3)
74 | self._viewer.square(center=np.array([1.0, 0.0]) + offset, angle=0, edge=0.05, color=(50, 205, 50))
75 |
76 | pos = self._state[:2] + offset
77 | self._viewer.circle(center=pos, radius=0.03, color=(255, 0, 0))
78 | self._viewer.display(self.time_step)
79 |
80 | def _create_sim_state(self):
81 | return self._state
82 |
83 | def _create_observation(self, state):
84 | return state
85 |
86 | def check_constraint(self):
87 | q = self._state[:2]
88 | dq = self._state[2:4]
89 |
90 | self.c = self.get_c(q, dq)
91 | self.c[0] = np.abs(self.c[0])
92 | self.constr_logs.append(self.c)
93 |
94 | def get_c(self, q, dq):
95 | return np.concatenate([self.c_1(q), self.c_2(q), self.c_3(q, dq)])
96 |
97 | @staticmethod
98 | def c_1(q):
99 | return np.array([q[0] ** 2 + q[1] ** 2 - 1])
100 |
101 | @staticmethod
102 | def c_2(q):
103 | return np.array([-q[1] - 0.5])
104 |
105 | @staticmethod
106 | def c_3(q, dq):
107 | return np.abs(dq) - 1
108 |
109 | def get_constraints_logs(self):
110 | constr_logs = np.array(self.constr_logs)
111 | c_avg = np.mean(np.max(constr_logs[:, :2], axis=1))
112 | c_max = np.max(constr_logs[:, :2])
113 | c_dq_max = np.max(constr_logs[:, 2:])
114 | self.constr_logs.clear()
115 | return c_avg, c_max, c_dq_max
116 |
117 |
118 | def plot_2d_circle(dataset, save_dir="", suffix=""):
119 | plt.rcParams.update({
120 | "text.usetex": True,
121 | "font.serif": ["Times"]})
122 | state_list = list()
123 |
124 | if suffix != '':
125 | suffix = suffix + "-"
126 |
127 | i = 0
128 | for data in dataset:
129 | state = data[0]
130 | state_list.append(state)
131 | if data[-1]:
132 | i += 1
133 | state_hist = np.array(state_list)
134 |
135 | fig = plt.figure()
136 | axes_circle = plt.subplot2grid((3, 2), (0, 0), rowspan=2, colspan=1)
137 | axes_c1 = plt.subplot2grid((3, 2), (0, 1))
138 | axes_c2 = plt.subplot2grid((3, 2), (1, 1))
139 | axes_c3 = plt.subplot2grid((3, 2), (2, 1))
140 | fig.subplots_adjust(hspace=.5)
141 | fig.subplots_adjust(wspace=.1)
142 |
143 | axes_circle.plot(state_hist[:, 0], state_hist[:, 1])
144 | # reference circular_motion
145 | x_1 = np.linspace(-1, 1, 100)
146 | y_1 = np.sqrt(1 - x_1 ** 2)
147 | x_2 = np.linspace(1, -1, 100)
148 | y_2 = -np.sqrt(1 - x_2 ** 2)
149 | x = np.concatenate([x_1, x_2])
150 | y = np.concatenate([y_1, y_2])
151 | axes_circle.plot(x, y, ls=':', label="$c_1$")
152 |
153 | # line
154 | # x_3 = np.linspace(-1.2, 1.2, 100)
155 | # y_3 = np.ones_like(x_3) * -0.5
156 | # axes_circle.plot(x_3, y_3, ls=':', c='tab:red', label="c2")
157 | axes_circle.fill_between([-1.2, 1.2], [-0.5, -0.5], [-1.2, -1.2], color='tab:red', alpha=0.3, label="$c_2$")
158 |
159 | axes_circle.set_aspect(1.0)
160 | axes_circle.set_xlim(-1.2, 1.2)
161 | axes_circle.set_ylim(-1.2, 1.2)
162 |
163 | axes_circle.legend(loc='upper right')
164 |
165 | # c1
166 | axes_c1.plot(state_hist[:, 0] ** 2 + state_hist[:, 1] ** 2 - 1)
167 | max_c = np.max(np.abs(state_hist[:, 0] ** 2 + state_hist[:, 1] ** 2 - 1))
168 | axes_c1.plot(np.zeros_like(state_hist[:, 2]), c='tab:red', label="c1")
169 | axes_c1.yaxis.tick_right()
170 | axes_c1.yaxis.set_label_position("right")
171 | axes_c1.set_title("$c_1$")
172 | axes_c1.set_ylim(-max_c, max_c)
173 | axes_c1.yaxis.set_major_formatter(FormatStrFormatter('%.4f'))
174 |
175 | # c2
176 | axes_c2.plot(-state_hist[:, 1] - 0.5)
177 |     axes_c2.plot(np.zeros_like(state_hist[:, 2]), c='tab:red', ls='--', label="c2")
178 | axes_c2.yaxis.tick_right()
179 | axes_c2.yaxis.set_label_position("right")
180 | axes_c2.set_title("$c_2$")
181 | axes_c2.yaxis.set_major_formatter(FormatStrFormatter('%.4f'))
182 |
183 | # c3
184 | axes_c3.plot(state_hist[:, 2] - 1, label="c3")
185 | axes_c3.plot(state_hist[:, 3] - 1, label="c4")
186 |     axes_c3.plot(np.zeros_like(state_hist[:, 2]), c='tab:red', ls='--', label="c3")
187 | axes_c3.yaxis.tick_right()
188 | axes_c3.yaxis.set_label_position("right")
189 | axes_c3.yaxis.set_major_formatter(FormatStrFormatter('%.4f'))
190 |     axes_c3.set_title(r"$c_3 \& c_4$")
191 |
192 | textstr = '\n'.join(('$c_1: x^2 + y^2 - 1 = 0$',
193 | '$c_2: - y - 0.5 < 0$',
194 |                          r'$c_3: \dot{x} - 1 < 0$',
195 |                          r'$c_4: \dot{y} - 1 < 0$'))
196 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
197 | axes_circle.text(0.5, -0.2, textstr, transform=axes_circle.transAxes, fontsize=14,
198 | verticalalignment='top', horizontalalignment='center', bbox=props)
199 |
200 | filename = "circular_motion-" + suffix + str(i) + ".pdf"
201 |
202 | if not os.path.exists(save_dir):
203 | os.makedirs(save_dir)
204 |
205 | plt.savefig(os.path.join(save_dir, filename))
206 | plt.close(fig)
207 | state_list = list()
208 |
--------------------------------------------------------------------------------
/atacom/environments/circular_motion/circle_error_correction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from atacom.constraints import ViabilityConstraint, ConstraintsSet
3 | from atacom.environments.circular_motion.circle_base import CircularMotion
4 | from atacom.error_correction_wrapper import ErrorCorrectionEnvWrapper
5 |
6 |
7 | class CircleEnvErrorCorrection(ErrorCorrectionEnvWrapper):
8 | def __init__(self, horizon=500, gamma=0.99, random_init=False, Kc=100., time_step=0.01):
9 | base_env = CircularMotion(random_init=random_init, horizon=horizon, gamma=gamma)
10 |
11 | circle_constr = ViabilityConstraint(dim_q=2, dim_out=1, fun=self.circle_fun, J=self.circle_J, b=self.circle_b,
12 | K=0.1)
13 | height_constr = ViabilityConstraint(dim_q=2, dim_out=1, fun=self.height_fun, J=self.height_J, b=self.height_b,
14 | K=2)
15 |
16 | f = ConstraintsSet(2)
17 | f.add_constraint(circle_constr)
18 | g = ConstraintsSet(2)
19 | g.add_constraint(height_constr)
20 |
21 | super().__init__(base_env=base_env, dim_q=2, f=f, g=g, Kc=Kc, acc_max=10, vel_max=1, Kq=20, time_step=time_step)
22 |
23 | def _get_q(self, state):
24 | return state[:2]
25 |
26 | def _get_dq(self, state):
27 | return state[2:4]
28 |
29 | def acc_to_ctrl_action(self, ddq):
30 | return ddq / self.acc_max
31 |
32 | def render(self):
33 | offset = np.array([1.25, 1.25])
34 | pos = self.state[:2] + offset
35 |
36 | act_a = self._act_a[:2]
37 | act_b = self._act_b[:2]
38 | act_err = self._act_err[:2]
39 |
40 | self.env._viewer.force_arrow(center=pos, direction=act_b,
41 | force=np.linalg.norm(act_b),
42 | max_force=10, width=5,
43 | max_length=0.3, color=(0, 255, 255))
44 | super().render()
45 |
46 | @staticmethod
47 | def circle_fun(q):
48 | return np.array([q[0] ** 2 + q[1] ** 2 - 1])
49 |
50 | @staticmethod
51 | def circle_J(q):
52 | return np.array([[2 * q[0], 2 * q[1]]])
53 |
54 | @staticmethod
55 | def circle_b(q, dq):
56 | return np.array([[2 * dq[0], 2 * dq[1]]]) @ dq
57 |
58 | @staticmethod
59 | def height_fun(q):
60 | return np.array([-q[1] - 0.5])
61 |
62 | @staticmethod
63 | def height_J(q):
64 | return np.array([[0, -1]])
65 |
66 | @staticmethod
67 | def height_b(q, dq):
68 | return np.array([0])
69 |
70 | @staticmethod
71 | def vel_fun(q, dq):
72 | return np.array([dq[0] ** 2 + dq[1] ** 2 - 1])
73 |
74 | @staticmethod
75 | def vel_A(q, dq):
76 | return np.array([[2 * dq[0], 2 * dq[1]]])
77 |
78 | @staticmethod
79 | def vel_b(q, dq):
80 | return np.array([0.])
81 |
--------------------------------------------------------------------------------
/atacom/environments/circular_motion/circle_terminated.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from atacom.environments.circular_motion.circle_base import CircularMotion
3 |
4 | from mushroom_rl.core import Core, Agent
5 | from mushroom_rl.utils.dataset import compute_J
6 |
7 |
8 | class CircleEnvTerminated(CircularMotion):
9 | """
10 | CircularMotion with termination when constraint > tolerance
11 | """
12 |
13 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, random_init=False, tol=0.1):
14 | super().__init__(time_step=time_step, horizon=horizon, gamma=gamma, random_init=random_init)
15 | self._tol = tol
16 |
17 | def step(self, action):
18 | state, reward, absorbing, info = super().step(action)
19 |
20 | absorbing = absorbing or self._terminate()
21 | if absorbing:
22 | reward = -100
23 | return state, reward, absorbing, info
24 |
25 | def _terminate(self):
26 | if np.any(self.c > self._tol):
27 | return True
28 | else:
29 | return False
30 |
31 | def render(self):
32 | offset = np.array([1.25, 1.25])
33 | pos = self._state[:2] + offset
34 |
35 | act_b = self._action[:2]
36 | self._viewer.force_arrow(center=pos, direction=act_b,
37 | force=np.linalg.norm(act_b),
38 | max_force=10, width=5,
39 | max_length=0.3, color=(0, 255, 255))
40 | super().render()
41 |
42 |
43 | def env_test():
44 | env = CircleEnvTerminated(tol=0.1)
45 |
46 | class DummyAgent(Agent):
47 | def __init__(self, mdp_info):
48 | self.mdp_info = mdp_info
49 |
50 | def fit(self, dataset):
51 | pass
52 |
53 | def episode_start(self):
54 | pass
55 |
56 | def draw_action(self, state):
57 |             return np.random.randn(self.mdp_info.action_space.shape[0])
58 |
59 | agent = DummyAgent(env.info)
60 |
61 | core = Core(agent, env)
62 |
63 | dataset = core.evaluate(n_steps=1000, render=True)
64 |
65 | J = np.mean(compute_J(dataset, core.mdp.info.gamma))
66 | R = np.mean(compute_J(dataset))
67 | c_avg, c_max, c_dq_max = env.get_constraints_logs()
68 | print("J: {}, R:{}, c_avg:{}, c_max:{}, c_dq_max:{}".format(J, R, c_avg, c_max, c_dq_max))
69 |
70 |
71 | if __name__ == '__main__':
72 | env_test()
73 |
--------------------------------------------------------------------------------
/atacom/environments/collision_avoidance/__init__.py:
--------------------------------------------------------------------------------
1 | from .collision_avoidance_atacom import PointReachAtacom
--------------------------------------------------------------------------------
/atacom/environments/collision_avoidance/collision_avoidance_atacom.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from atacom.environments.collision_avoidance.collision_avoidance_base import PointGoalReach
4 | from atacom.utils.null_space_coordinate import pinv_null, rref
5 |
6 |
7 |
8 | class PointReachAtacom(PointGoalReach):
9 | def __init__(self, time_step=0.01, horizon=1000, gamma=0.99, n_objects=4, random_walk=False):
10 | super().__init__(time_step=time_step, horizon=horizon, gamma=gamma, n_objects=n_objects,
11 | random_walk=random_walk)
12 | self.s = np.zeros(self.n_objects)
13 | self.Kc = np.diag(np.ones(self.n_objects) * 100)
14 | self.K = np.ones(self.n_objects) * 0.5
15 |
16 | self.constr_logs = list()
17 |
18 | def reset(self, state=None):
19 | super(PointReachAtacom, self).reset(state)
20 | self.q = self.get_q(self._state)
21 | self.dq = self.get_dq(self._state)
22 | self.p = self.get_p(self._state)
23 | self.dp = self.get_dp(self._state)
24 |
25 | self.s = np.sqrt(np.maximum(- 2 * self.get_c(self.q, self.p), 0.))
26 | # self.s = np.maximum(- self.get_c(self.q, self.p), 0.)
27 | return self._state
28 |
29 | def step(self, action):
30 | self.q = self.get_q(self._state)
31 | self.dq = self.get_dq(self._state)
32 | self.p = self.get_p(self._state)
33 | self.dp = self.get_dp(self._state)
34 |
35 | Jc_q = self.get_Jc_q(self.q, self.p)
36 | Jc_q_inv, Nc_q = pinv_null(Jc_q)
37 |
38 | c_origin = self.get_c(self.q, self.p)
39 | c_dq_i = 0.
40 | self.constr_logs.append([np.max(c_origin), np.max(c_dq_i)])
41 | c = c_origin + 1 / 2 * self.s ** 2 + self.K * (self.get_dc(self.q, self.p, self.dq, self.dp))
42 |
43 | Nc = rref(Nc_q, row_vectors=False)
44 |
45 | psi = self.get_psi(self.q, self.p, self.dq, self.dp)
46 | action = - Jc_q_inv @ (psi + self.Kc @ c) + Nc @ action
47 | self.s += action[2:] * self.time_step
48 | return super().step(action[:2])
49 |
50 | @staticmethod
51 | def get_q(state):
52 | return state[:2]
53 |
54 | @staticmethod
55 | def get_dq(state):
56 | return state[2:4]
57 |
58 | def get_p(self, state):
59 | p = np.zeros(2 * self.n_objects)
60 | for i in range(self.n_objects):
61 | idx = 4 * (i + 1)
62 | p[2 * i: 2 * i + 2] = state[idx: idx + 2]
63 | return p
64 |
65 | def get_dp(self, state):
66 | dp = np.zeros(2 * self.n_objects)
67 | for i in range(self.n_objects):
68 | idx = 4 * (i + 1)
69 | dp[2 * i: 2 * i + 2] = state[idx + 2: idx + 4]
70 | return dp
71 |
72 | def get_c(self, q, p):
73 | c_out = np.zeros(self.n_objects)
74 | for i in range(self.n_objects):
75 | c_out[i] = 0.6 ** 2 - np.linalg.norm(q - p[2 * i: 2 * i + 2]) ** 2
76 | return c_out
77 |
78 | def get_dc(self, q, p, dq, dp):
79 | dc_out = np.zeros(self.n_objects)
80 | for i in range(self.n_objects):
81 | p_i = p[2 * i:2 * i + 2]
82 | dp_i = dp[2 * i:2 * i + 2]
83 | dc_out[i] = self.get_Jp(q, p_i) @ dp_i + self.get_Jq(q, p_i) @ dq
84 | return dc_out
85 |
86 | @staticmethod
87 | def get_Jq(q, p_i):
88 | return -2 * (q - p_i)
89 |
90 | @staticmethod
91 | def get_Jp(q, p_i):
92 | return 2 * (q - p_i)
93 |
94 | @staticmethod
95 | def get_Hqq(q, p_i):
96 | return -2 * np.eye(2)
97 |
98 | @staticmethod
99 | def get_Hqp(q, p_i):
100 | return 2 * np.eye(2)
101 |
102 | @staticmethod
103 | def get_Hpp(q, p_i):
104 | return -2 * np.eye(2)
105 |
106 |     def get_bp(self, q, p_i, dq, dp_i):
107 |         return (dp_i @ self.get_Hpp(q, p_i) + dq @ self.get_Hqp(q, p_i)) @ dp_i
108 |
109 |     def get_bq(self, q, p_i, dq, dp_i):
110 |         return (dq @ self.get_Hqq(q, p_i) + dp_i @ self.get_Hqp(q, p_i)) @ dq
111 |
112 | def get_Jc_q(self, q, p):
113 | Jc_q = np.zeros((self.n_objects, self.n_objects + 2))
114 | for i in range(self.n_objects):
115 | p_i = p[2 * i:2 * i + 2]
116 | Jc_q[i, :2] = self.get_Jq(q, p_i)
117 | Jc_q[:, 2:] = np.diag(self.s)
118 | return Jc_q
119 |
120 | def get_psi(self, q, p, dq, dp):
121 | psi = np.zeros(self.n_objects)
122 | for i in range(self.n_objects):
123 | p_i = p[2 * i:2 * i + 2]
124 | dp_i = dp[2 * i:2 * i + 2]
125 | psi[i] = self.get_Jp(q, p_i) @ dp_i + self.get_Jq(q, p_i) @ dq + \
126 | self.K[i] * (self.get_bp(q, p_i, dq, dp_i) + self.get_bq(q, p_i, dq, dp_i))
127 | return psi
128 |
129 | def get_constraints_logs(self):
130 | constr_logs = np.array(self.constr_logs)
131 | c_avg = np.mean(constr_logs[:, 0])
132 | c_max = np.max(constr_logs[:, 0])
133 | c_dq_max = np.max(constr_logs[:, 1])
134 | self.constr_logs.clear()
135 | return c_avg, c_max, c_dq_max
136 |
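137 | # Note: unlike AtacomEnvWrapper, this environment embeds ATACOM directly in
138 | # step(): each inequality constraint c_i(q, p) <= 0 becomes an equality
139 | # c_i + s_i^2 / 2 = 0 through a slack variable s_i, and the applied action is
140 | #     u = -Jc^+ (psi + Kc c) + Nc alpha,
141 | # where Nc spans the null space of the constraint Jacobian Jc and the moving
142 | # obstacles' velocities enter through psi.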
--------------------------------------------------------------------------------
/atacom/environments/collision_avoidance/collision_avoidance_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mushroom_rl.core import Environment, MDPInfo
3 | from mushroom_rl.utils.spaces import Box
4 | from mushroom_rl.utils.viewer import Viewer
5 |
6 | class PointGoalReach(Environment):
7 | def __init__(self, time_step=0.01, horizon=500, gamma=0.99, n_objects=2, random_walk=False):
8 | self.time_step = time_step
9 | self.n_objects = n_objects
10 | self.random_walk = random_walk
11 |
12 | self.state_dim = 4 * (1 + n_objects)
13 | observation_space = Box(low=-np.ones(self.state_dim) * 10, high=np.ones(self.state_dim) * 10)
14 | action_space = Box(low=-np.ones(2) * 1, high=np.ones(2) * 1)
15 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)
16 |
17 | self._viewer = Viewer(env_width=10, env_height=10, background=(255, 255, 255))
18 | self.step_action_function = None
19 | self.action_scale = np.array([10., 10.])
20 | self._obj_circle_center = []
21 | self._obj_radius = 2
22 | self._time = 0.
23 | super().__init__(mdp_info)
24 |
25 | def reset(self, state=None):
26 |         self._time = 0.
27 |         self._obj_circle_center = []
28 |         self._state = np.zeros(self.state_dim)
29 | self._state[:2] = np.array([1., 1.])
30 | self._state[2:4] = np.array([0., 0.])
31 |
32 | for i in range(self.n_objects):
33 | obj_idx = 4 * (i + 1)
34 | self._state[obj_idx:obj_idx + 2] = np.random.uniform(2, 8, 2)
35 | self._obj_circle_center.append(self._state[obj_idx:obj_idx + 2] - np.array([self._obj_radius, 0.]))
36 | self._state[obj_idx + 2:obj_idx + 4] = np.zeros(2)
37 |
38 | return self._state
39 |
40 | def step(self, action):
41 | if self.step_action_function is not None:
42 | action = self.step_action_function(self._state, action)
43 |
44 | self._action = np.clip(action, self.info.action_space.low, self.info.action_space.high)
45 | self._action = self._action * self.action_scale
46 |
47 | self._state[:2] += self._state[2:4] * self.time_step
48 | self._state[2:4] += self._action * self.time_step
49 |
50 | change_sign = np.logical_or(self._state[0:2] <= 0, self._state[0:2] >= 10)
51 | if np.any(change_sign):
52 | self._state[2:4][change_sign] = - self._state[2:4][change_sign]
53 |
54 | for i in range(self.n_objects):
55 | obj_idx = 4 * (i + 1)
56 |
57 | if self.random_walk:
58 | self._state[obj_idx:obj_idx + 2] += self._state[obj_idx + 2:obj_idx + 4] * self.time_step
59 | self._state[obj_idx:obj_idx + 2] = np.clip(self._state[obj_idx:obj_idx + 2], 2, 10)
60 | obj_action = np.random.uniform(-1, 1, 2) * 10
61 | change_sign = np.logical_or(self._state[obj_idx:obj_idx + 2] <= 2,
62 | self._state[obj_idx:obj_idx + 2] >= 10)
63 | if np.any(change_sign):
64 | self._state[obj_idx + 2:obj_idx + 4][change_sign] = - self._state[obj_idx + 2:obj_idx + 4][
65 | change_sign]
66 | self._state[obj_idx + 2:obj_idx + 4] += obj_action * self.time_step
67 | self._state[obj_idx + 2:obj_idx + 4] = np.clip(self._state[obj_idx + 2:obj_idx + 4], -1, 1)
68 | else:
69 | self._state[obj_idx] = self._obj_circle_center[i][0] + self._obj_radius * np.cos(self._time * 2 * np.pi)
70 | self._state[obj_idx + 1] = self._obj_circle_center[i][1] + self._obj_radius * np.sin(
71 | self._time * 2 * np.pi)
72 | self._state[obj_idx + 2] = -2 * self._obj_radius * np.pi * np.sin(self._time * 2 * np.pi)
73 | self._state[obj_idx + 3] = 2 * self._obj_radius * np.pi * np.cos(self._time * 2 * np.pi)
74 |
75 | self._time += self.time_step
76 | reward = - np.linalg.norm(np.array([9., 9.]) - self._state[:2]) / (8 * np.sqrt(2))
77 | return self._state, reward, False, dict()
78 |
79 | def render(self):
80 | self._viewer.circle(center=self._state[:2], radius=0.3, color=(0., 0., 255), width=0)
81 |
82 | for i in range(self.n_objects):
83 | obj_idx = 4 * (i + 1)
84 | self._viewer.circle(center=self._state[obj_idx:obj_idx + 2], radius=0.3, color=(255., 0., 0), width=0)
85 |
86 | self._viewer.square(center=np.array([9., 9.]), angle=0, edge=0.6, color=(0., 255., 0))
87 |
88 | self._viewer.display(self.time_step * 0.4)
89 |
90 | def _create_sim_state(self):
91 | return self._state
92 |
93 | def _create_observation(self, state):
94 | return state
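95 |
96 |
97 | if __name__ == '__main__':
98 |     # Minimal usage sketch: drive the point with random accelerations while the
99 |     # obstacles move along their circular paths.
100 |     env = PointGoalReach(n_objects=4)
101 |     env.reset()
102 |     for _ in range(500):
103 |         action = np.random.uniform(-1, 1, 2)
104 |         state, reward, absorbing, info = env.step(action)
105 |         env.render()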
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/__init__.py:
--------------------------------------------------------------------------------
1 | from .iiwa_hit_atacom import AirHockeyIiwaAtacom
2 | from .iiwa_air_hockey_rmp import AirHockeyIiwaRmp
3 |
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/env_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pinocchio as pino
4 | import pybullet
5 | import pybullet_utils.transformations as transformations
6 | from mushroom_rl.environments.pybullet import PyBullet, PyBulletObservationType
7 | from mushroom_rl.environments.pybullet_envs import __file__ as env_path
8 |
9 |
10 | class AirHockeyBase(PyBullet):
11 | def __init__(self, gamma=0.99, horizon=500, timestep=1 / 240., n_intermediate_steps=1, debug_gui=False,
12 | n_agents=1, env_noise=False, obs_noise=False, obs_delay=False, torque_control=True,
13 | step_action_function=None, isolated_joint_7=False):
14 | self.n_agents = n_agents
15 | self.env_noise = env_noise
16 | self.obs_noise = obs_noise
17 | self.obs_delay = obs_delay
18 | self.step_action_function = step_action_function
19 | self.isolated_joint_7 = isolated_joint_7
20 |
21 | puck_file = os.path.join(os.path.dirname(os.path.abspath(env_path)),
22 | "data", "air_hockey", "puck.urdf")
23 | table_file = os.path.join(os.path.dirname(os.path.abspath(env_path)),
24 | "data", "air_hockey", "air_hockey_table.urdf")
25 | robot_file_1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urdf", "iiwa_1.urdf")
26 | robot_file_2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urdf", "iiwa_2.urdf")
27 |
28 | model_files = dict()
29 | model_files[puck_file] = dict(flags=pybullet.URDF_USE_IMPLICIT_CYLINDER,
30 | basePosition=[0.0, 0, 0], baseOrientation=[0, 0, 0.0, 1.0])
31 | model_files[table_file] = dict(useFixedBase=True, basePosition=[0.0, 0, 0],
32 | baseOrientation=[0, 0, 0.0, 1.0])
33 |
34 | actuation_spec = list()
35 | observation_spec = [("puck", PyBulletObservationType.BODY_POS),
36 | ("puck", PyBulletObservationType.BODY_LIN_VEL),
37 | ("puck", PyBulletObservationType.BODY_ANG_VEL)]
38 | self.agents = []
39 |
40 | if torque_control:
41 | control = pybullet.TORQUE_CONTROL
42 | else:
43 | control = pybullet.POSITION_CONTROL
44 |
45 | if 1 <= self.n_agents <= 2:
46 | agent_spec = dict()
47 | agent_spec['name'] = "iiwa_1"
48 | agent_spec.update({"urdf": os.path.join(os.path.dirname(os.path.abspath(__file__)),
49 | "urdf", "iiwa_1.urdf")})
50 | translate = [-1.51, 0, -0.1]
51 | quaternion = [0.0, 0.0, 0.0, 1.0]
52 | agent_spec['frame'] = transformations.translation_matrix(translate)
53 | agent_spec['frame'] = agent_spec['frame'] @ transformations.quaternion_matrix(quaternion)
54 | model_files[robot_file_1] = dict(
55 | flags=pybullet.URDF_USE_IMPLICIT_CYLINDER | pybullet.URDF_USE_INERTIA_FROM_FILE,
56 | basePosition=translate, baseOrientation=quaternion)
57 |
58 | self.agents.append(agent_spec)
59 | actuation_spec += [("iiwa_1/joint_1", control),
60 | ("iiwa_1/joint_2", control),
61 | ("iiwa_1/joint_3", control),
62 | ("iiwa_1/joint_4", control),
63 | ("iiwa_1/joint_5", control),
64 | ("iiwa_1/joint_6", control)]
65 | if self.isolated_joint_7:
66 | actuation_spec += [("iiwa_1/joint_7", pybullet.POSITION_CONTROL)]
67 | else:
68 | actuation_spec += [("iiwa_1/joint_7", control)]
69 | actuation_spec += [("iiwa_1/striker_joint_1", pybullet.POSITION_CONTROL),
70 | ("iiwa_1/striker_joint_2", pybullet.POSITION_CONTROL)]
71 |
72 | observation_spec += [("iiwa_1/joint_1", PyBulletObservationType.JOINT_POS),
73 | ("iiwa_1/joint_2", PyBulletObservationType.JOINT_POS),
74 | ("iiwa_1/joint_3", PyBulletObservationType.JOINT_POS),
75 | ("iiwa_1/joint_4", PyBulletObservationType.JOINT_POS),
76 | ("iiwa_1/joint_5", PyBulletObservationType.JOINT_POS),
77 | ("iiwa_1/joint_6", PyBulletObservationType.JOINT_POS),
78 | ("iiwa_1/joint_7", PyBulletObservationType.JOINT_POS),
79 | ("iiwa_1/striker_joint_1", PyBulletObservationType.JOINT_POS),
80 | ("iiwa_1/striker_joint_2", PyBulletObservationType.JOINT_POS),
81 | ("iiwa_1/joint_1", PyBulletObservationType.JOINT_VEL),
82 | ("iiwa_1/joint_2", PyBulletObservationType.JOINT_VEL),
83 | ("iiwa_1/joint_3", PyBulletObservationType.JOINT_VEL),
84 | ("iiwa_1/joint_4", PyBulletObservationType.JOINT_VEL),
85 | ("iiwa_1/joint_5", PyBulletObservationType.JOINT_VEL),
86 | ("iiwa_1/joint_6", PyBulletObservationType.JOINT_VEL),
87 | ("iiwa_1/joint_7", PyBulletObservationType.JOINT_VEL),
88 | ("iiwa_1/striker_joint_1", PyBulletObservationType.JOINT_VEL),
89 | ("iiwa_1/striker_joint_2", PyBulletObservationType.JOINT_VEL),
90 | ("iiwa_1/striker_mallet_tip", PyBulletObservationType.LINK_POS),
91 | ("iiwa_1/striker_mallet_tip", PyBulletObservationType.LINK_LIN_VEL)]
92 |
93 | if self.n_agents == 2:
94 | agent_spec = dict()
95 | agent_spec['name'] = "iiwa_2"
96 | agent_spec.update({"urdf": os.path.join(os.path.dirname(os.path.abspath(__file__)),
97 | "urdf", "iiwa_pino.urdf")})
98 | translate = [1.51, 0, -0.1]
99 | quaternion = [0.0, 0.0, 1.0, 0.0]
100 | agent_spec['frame'] = transformations.translation_matrix(translate)
101 | agent_spec['frame'] = agent_spec['frame'] @ transformations.quaternion_matrix(quaternion)
102 | model_files[robot_file_2] = dict(
103 | flags=pybullet.URDF_USE_IMPLICIT_CYLINDER | pybullet.URDF_USE_INERTIA_FROM_FILE,
104 | basePosition=translate, baseOrientation=quaternion)
105 | self.agents.append(agent_spec)
106 |
107 | actuation_spec += [("iiwa_2/joint_1", control),
108 | ("iiwa_2/joint_2", control),
109 | ("iiwa_2/joint_3", control),
110 | ("iiwa_2/joint_4", control),
111 | ("iiwa_2/joint_5", control),
112 | ("iiwa_2/joint_6", control)]
113 | if self.isolated_joint_7:
114 | actuation_spec += [("iiwa_2/joint_7", pybullet.POSITION_CONTROL)]
115 | else:
116 | actuation_spec += [("iiwa_2/joint_7", control)]
117 | actuation_spec += [("iiwa_2/striker_joint_1", pybullet.POSITION_CONTROL),
118 | ("iiwa_2/striker_joint_2", pybullet.POSITION_CONTROL)]
119 |
120 | observation_spec += [("iiwa_2/joint_1", PyBulletObservationType.JOINT_POS),
121 | ("iiwa_2/joint_2", PyBulletObservationType.JOINT_POS),
122 | ("iiwa_2/joint_3", PyBulletObservationType.JOINT_POS),
123 | ("iiwa_2/joint_4", PyBulletObservationType.JOINT_POS),
124 | ("iiwa_2/joint_5", PyBulletObservationType.JOINT_POS),
125 | ("iiwa_2/joint_6", PyBulletObservationType.JOINT_POS),
126 | ("iiwa_2/joint_7", PyBulletObservationType.JOINT_POS),
127 | ("iiwa_2/striker_joint_1", PyBulletObservationType.JOINT_POS),
128 | ("iiwa_2/striker_joint_2", PyBulletObservationType.JOINT_POS),
129 | ("iiwa_2/joint_1", PyBulletObservationType.JOINT_VEL),
130 | ("iiwa_2/joint_2", PyBulletObservationType.JOINT_VEL),
131 | ("iiwa_2/joint_3", PyBulletObservationType.JOINT_VEL),
132 | ("iiwa_2/joint_4", PyBulletObservationType.JOINT_VEL),
133 | ("iiwa_2/joint_5", PyBulletObservationType.JOINT_VEL),
134 | ("iiwa_2/joint_6", PyBulletObservationType.JOINT_VEL),
135 | ("iiwa_2/joint_7", PyBulletObservationType.JOINT_VEL),
136 | ("iiwa_2/striker_joint_1", PyBulletObservationType.JOINT_VEL),
137 | ("iiwa_2/striker_joint_2", PyBulletObservationType.JOINT_VEL),
138 | ("iiwa_2/striker_mallet_tip", PyBulletObservationType.LINK_POS),
139 | ("iiwa_2/striker_mallet_tip", PyBulletObservationType.LINK_LIN_VEL)]
140 | else:
141 | raise ValueError('n_agents should be 1 or 2')
142 |
143 | super().__init__(model_files, actuation_spec, observation_spec, gamma,
144 | horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps,
145 | debug_gui=debug_gui, size=(500, 500), distance=1.8)
146 |
147 | self.pino_model = pino.buildModelFromUrdf(self.agents[0]['urdf'])
148 | se_tip = pino.SE3(np.eye(3), np.array([0., 0., 0.585]))
149 | self.pino_model.addBodyFrame('striker_rod_tip', 7, se_tip, self.pino_model.nframes - 1)
150 | self.pino_data = self.pino_model.createData()
151 | self.frame_idx = self.pino_model.nframes - 1
152 |
153 | self._client.resetDebugVisualizerCamera(cameraDistance=2, cameraYaw=0.0, cameraPitch=-89.9,
154 | cameraTargetPosition=[0., 0., 0.])
155 | self.env_spec = dict()
156 | self.env_spec['table'] = {"length": 1.96, "width": 1.02, "height": 0.0, "goal": 0.25, "urdf": table_file}
157 | self.env_spec['puck'] = {"radius": 0.03165, "urdf": puck_file}
158 | self.env_spec['mallet'] = {"radius": 0.05}
159 | self.env_spec['universal_height'] = 0.1505
160 |
161 | def _compute_action(self, state, action):
162 | if self.step_action_function is None:
163 | ctrl_action = action
164 | else:
165 | ctrl_action = self.step_action_function(state, action)
166 |
167 | joint_state = self.joints.positions(state)[:9]
168 |
169 | if self.isolated_joint_7:
170 | joint_7_des_pos = self._compute_joint_7(joint_state)
171 | ctrl_action = np.concatenate([ctrl_action, joint_7_des_pos])
172 |
173 | joint_universal_pos = self._compute_universal_joint(joint_state)
174 | return np.concatenate([ctrl_action, joint_universal_pos])
175 |
176 | def _simulation_pre_step(self):
177 | if self.env_noise:
178 | force = np.concatenate([np.random.randn(2), [0]]) * 0.0005
179 | self._client.applyExternalForce(self._model_map['puck']['id'], -1, force, [0., 0., 0.],
180 | self._client.WORLD_FRAME)
181 |
182 | def is_absorbing(self, state):
183 | boundary = np.array([self.env_spec['table']['length'], self.env_spec['table']['width']]) / 2
184 | puck_pos = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_POS)[:3]
185 | if np.any(np.abs(puck_pos[:2]) > boundary) or abs(puck_pos[2] - self.env_spec['table']['height']) > 0.1:
186 | return True
187 |
188 | boundary_mallet = boundary
189 | for agent in self.agents:
190 | mallet_pose = self.get_sim_state(state, agent['name'] + "/striker_mallet_tip",
191 | PyBulletObservationType.LINK_POS)
192 | if np.any(np.abs(mallet_pose[:2]) - boundary_mallet > 0.02):
193 | return True
194 | return False
195 |
196 | def _compute_joint_7(self, state):
197 | raise NotImplementedError
198 |
199 | def _compute_universal_joint(self, state):
200 | raise NotImplementedError
201 |
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/env_hitting.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from mushroom_rl.environments.pybullet import PyBulletObservationType
4 | from atacom.environments.iiwa_air_hockey.env_single import AirHockeySingle
5 |
6 |
7 | class AirHockeyHit(AirHockeySingle):
8 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1,
9 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control=True,
10 | random_init=False, step_action_function=None, action_penalty=1e-3, isolated_joint_7=False):
11 | self.hit_range = np.array([[-0.6, -0.2], [-0.4, 0.4]])
12 | self.goal = np.array([0.98, 0])
13 | self.has_hit = False
14 | self.r_hit = 0.
15 | self.vel_hit_x = 0.
16 | self.random_init = random_init
17 | self.action_penalty = action_penalty
18 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep,
19 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
20 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay, torque_control=torque_control,
21 | step_action_function=step_action_function, isolated_joint_7=isolated_joint_7)
22 |
23 | def setup(self, state):
24 | if self.random_init:
25 | puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0]
26 | else:
27 | puck_pos = np.mean(self.hit_range, axis=1)
28 |
29 | puck_pos = np.concatenate([puck_pos, [0.0]])
30 | self.client.resetBasePositionAndOrientation(self._model_map['puck'], puck_pos, [0, 0, 0, 1.0])
31 |
32 | for i, (model_id, joint_id, _) in enumerate(self._indexer.action_data[:7]):
33 | self.client.resetJointState(model_id, joint_id, self.init_state[i])
34 |
35 | self.has_hit = False
36 | self.r_hit = 0.
37 | self.vel_hit_x = 0.
38 |
39 | def reward(self, state, action, next_state, absorbing):
40 | r = 0
41 | puck_pos = self.get_sim_state(next_state, "puck", PyBulletObservationType.BODY_POS)[:2]
42 | puck_vel = self.get_sim_state(next_state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2]
43 | if absorbing:
44 | if puck_pos[0] - self.env_spec['table']['length'] / 2 > 0 and \
45 | np.abs(puck_pos[1]) - self.env_spec['table']['goal'] < 0:
46 | r = 80
47 | else:
48 | if not self.has_hit:
49 | ee_pos = self.get_sim_state(next_state,
50 | self.agents[0]['name'] + "/striker_mallet_tip",
51 | PyBulletObservationType.LINK_POS)[:2]
52 | dist_ee_puck = np.linalg.norm(puck_pos - ee_pos)
53 | # r = np.exp(-8 * (dist_ee_puck - 0.08))
54 |
55 | vec_ee_puck = (puck_pos - ee_pos) / dist_ee_puck
56 | vec_puck_goal = (self.goal - puck_pos) / np.linalg.norm(self.goal - puck_pos)
57 | cos_ang = np.clip(vec_puck_goal @ vec_ee_puck, 0, 1)
58 | r = np.exp(-8 * (dist_ee_puck - 0.08)) * cos_ang
59 | self.r_hit = r
60 | else:
61 | # dist = np.linalg.norm(self.goal - puck_pos)
62 | # if puck_vel[0] > 0:
63 | # r_vel = np.abs(puck_vel[0])
64 | # r = 0.5 * (np.exp(-10 * dist) + r_vel) + 0.5
65 |
66 | r = 1 + self.r_hit + self.vel_hit_x * 0.1
67 |
68 | r -= self.action_penalty * np.linalg.norm(action)
69 | return r
70 |
71 | def is_absorbing(self, state):
72 | if super().is_absorbing(state):
73 | return True
74 | if self.has_hit:
75 | puck_vel = self.get_sim_state(self._state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2]
76 | if np.linalg.norm(puck_vel) < 0.01:
77 | return True
78 | return False
79 |
80 | def _simulation_post_step(self):
81 | if not self.has_hit:
82 | puck_vel = self.get_sim_state(self._state, "puck", PyBulletObservationType.BODY_LIN_VEL)[:2]
83 | if np.linalg.norm(puck_vel) > 0.1:
84 | self.has_hit = True
85 | self.vel_hit_x = puck_vel[0]
86 |
87 |
88 | if __name__ == '__main__':
89 | env = AirHockeyHit(debug_gui=True, env_noise=False, obs_noise=False, obs_delay=False, n_intermediate_steps=4)
90 |
91 | R = 0.
92 | J = 0.
93 | gamma = 1.
94 | steps = 0
95 | while True:
96 | action = np.random.uniform(-1, 1, env.info.action_space.low.shape) * 8
97 | observation, reward, done, info = env.step(action)
98 |         J += gamma * reward
99 |         gamma *= env.info.gamma
100 | R += reward
101 | steps += 1
102 | if done or steps > env.info.horizon:
103 | print("J: ", J, " R: ", R)
104 | R = 0.
105 | J = 0.
106 | gamma = 1.
107 | steps = 0
108 | env.reset()
109 | time.sleep(1/60.)
110 |
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/env_single.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pinocchio as pino
3 | import pybullet_utils.transformations as transformations
4 | from mushroom_rl.core import MDPInfo
5 | from mushroom_rl.environments.pybullet import PyBulletObservationType
6 | from mushroom_rl.utils.spaces import Box
7 | from atacom.environments.iiwa_air_hockey.env_base import AirHockeyBase
8 | from atacom.environments.iiwa_air_hockey.kinematics import clik, fk
9 |
10 |
11 | class AirHockeySingle(AirHockeyBase):
12 | def __init__(self, gamma=0.99, horizon=500, timestep=1 / 240., n_intermediate_steps=1, debug_gui=False,
13 | env_noise=False, obs_noise=False, obs_delay=False, torque_control=True, step_action_function=None,
14 | isolated_joint_7=False):
15 | self.obs_prev = None
16 |
17 | if isolated_joint_7:
18 | self.n_ctrl_joints = 6
19 | else:
20 | self.n_ctrl_joints = 7
21 |
22 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep,
23 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
24 | env_noise=env_noise, n_agents=1, obs_noise=obs_noise, obs_delay=obs_delay,
25 | torque_control=torque_control, step_action_function=step_action_function,
26 | isolated_joint_7=isolated_joint_7)
27 |
28 | self._compute_init_state()
29 |
30 |         self._client.resetDebugVisualizerCamera(cameraDistance=1.5, cameraYaw=0.0, cameraPitch=-45.0,
31 | cameraTargetPosition=[-0.5, 0., 0.])
32 |
33 | self._change_dynamics()
34 |
35 | self._disable_collision()
36 |
37 | self.reset()
38 |
39 | def _compute_init_state(self):
40 | q = np.zeros(9)
41 | des_pos = pino.SE3(np.diag([-1., 1., -1.]), np.array([0.65, 0., self.env_spec['universal_height']]))
42 |
43 | success, self.init_state = clik(self.pino_model, self.pino_data, des_pos, q, self.frame_idx)
44 | assert success is True
45 |
46 | def _disable_collision(self):
47 |         # disable collisions between the robot and the table rims because of the improper collision shapes
48 | iiwa_links = ['iiwa_1/link_1', 'iiwa_1/link_2', 'iiwa_1/link_3', 'iiwa_1/link_4', 'iiwa_1/link_5',
49 | 'iiwa_1/link_6', 'iiwa_1/link_7', 'iiwa_1/link_ee', 'iiwa_1/striker_base',
50 | 'iiwa_1/striker_joint_link', 'iiwa_1/striker_mallet', 'iiwa_1/striker_mallet_tip']
51 | table_rims = ['t_down_rim_l', 't_down_rim_r', 't_up_rim_r', 't_up_rim_l',
52 |                       't_left_rim', 't_right_rim', 't_base', 't_up_rim_top', 't_down_rim_top']
53 | for iiwa_l in iiwa_links:
54 | for table_r in table_rims:
55 | self.client.setCollisionFilterPair(self._indexer.link_map[iiwa_l][0],
56 | self._indexer.link_map[table_r][0],
57 | self._indexer.link_map[iiwa_l][1],
58 | self._indexer.link_map[table_r][1], 0)
59 |
60 | self.client.setCollisionFilterPair(self._model_map['puck'], self._indexer.link_map['t_down_rim_top'][0],
61 | -1, self._indexer.link_map['t_down_rim_top'][1], 0)
62 | self.client.setCollisionFilterPair(self._model_map['puck'], self._indexer.link_map['t_up_rim_top'][0],
63 | -1, self._indexer.link_map['t_up_rim_top'][1], 0)
64 |
65 | def _change_dynamics(self):
66 | for i in range(12):
67 | self.client.changeDynamics(self._model_map['iiwa_1'], i, linearDamping=0., angularDamping=0.)
68 |
69 | def _modify_mdp_info(self, mdp_info):
70 | obs_idx = [0, 1, 2, 7, 8, 9, 13, 14, 15, 16, 17, 18, 22, 23, 24, 25, 26, 27]
71 | obs_low = mdp_info.observation_space.low[obs_idx]
72 | obs_high = mdp_info.observation_space.high[obs_idx]
73 | obs_low[0:3] = [-1, -0.5, -np.pi]
74 | obs_high[0:3] = [1, 0.5, np.pi]
75 | observation_space = Box(low=obs_low, high=obs_high)
76 |
77 | act_low = mdp_info.action_space.low[:self.n_ctrl_joints]
78 | act_high = mdp_info.action_space.high[:self.n_ctrl_joints]
79 | action_space = Box(low=act_low, high=act_high)
80 | return MDPInfo(observation_space, action_space, mdp_info.gamma, mdp_info.horizon)
81 |
82 | def _create_observation(self, state):
83 | puck_pose = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_POS)
84 | puck_pose_2d = self._puck_2d_in_robot_frame(puck_pose, self.agents[0]['frame'], type='pose')
85 |
86 | robot_pos = list()
87 | robot_vel = list()
88 | for i in range(6):
89 | robot_pos.append(self.get_sim_state(state,
90 | self.agents[0]['name'] + "/joint_"+str(i+1),
91 | PyBulletObservationType.JOINT_POS))
92 | robot_vel.append(self.get_sim_state(state,
93 | self.agents[0]['name'] + "/joint_" + str(i + 1),
94 | PyBulletObservationType.JOINT_VEL))
95 | if not self.isolated_joint_7:
96 | robot_pos.append(self.get_sim_state(state,
97 | self.agents[0]['name'] + "/joint_" + str(7),
98 | PyBulletObservationType.JOINT_POS))
99 | robot_vel.append(self.get_sim_state(state,
100 | self.agents[0]['name'] + "/joint_" + str(7),
101 | PyBulletObservationType.JOINT_VEL))
102 | robot_pos = np.asarray(robot_pos).flatten()
103 | robot_vel = np.asarray(robot_vel).flatten()
104 |
105 | if self.obs_noise:
106 | puck_pose_2d[:2] += np.random.randn(2) * 0.001
107 |             puck_pose_2d[2] += np.random.randn() * 0.001
108 |
109 | puck_lin_vel = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_LIN_VEL)
110 | puck_ang_vel = self.get_sim_state(state, "puck", PyBulletObservationType.BODY_ANG_VEL)
111 | puck_vel_2d = self._puck_2d_in_robot_frame(np.concatenate([puck_lin_vel, puck_ang_vel]),
112 | self.agents[0]['frame'], type='vel')
113 |
114 | if self.obs_delay:
115 | alpha = 0.5
116 | puck_vel_2d = alpha * puck_vel_2d + (1 - alpha) * self.obs_prev[3:6]
117 | robot_vel = alpha * robot_vel + (1 - alpha) * self.obs_prev[9:12]
118 |
119 | self.obs_prev = np.concatenate([puck_pose_2d, puck_vel_2d, robot_pos, robot_vel])
120 | return self.obs_prev
121 |
122 | def _puck_2d_in_robot_frame(self, puck_in, robot_frame, type='pose'):
123 | if type == 'pose':
124 | puck_frame = transformations.translation_matrix(puck_in[:3])
125 | puck_frame = puck_frame @ transformations.quaternion_matrix(puck_in[3:])
126 |
127 | frame_target = transformations.inverse_matrix(robot_frame) @ puck_frame
128 | puck_translate = transformations.translation_from_matrix(frame_target)
129 | _, _, puck_euler_yaw = transformations.euler_from_matrix(frame_target)
130 |
131 | return np.concatenate([puck_translate[:2], [puck_euler_yaw]])
132 | if type == 'vel':
133 | rot_mat = robot_frame[:3, :3]
134 | vec_lin = rot_mat.T @ puck_in[:3]
135 | return np.concatenate([vec_lin[:2], puck_in[5:6]])
136 |
137 | def _compute_joint_7(self, joint_state):
138 | q_cur = joint_state.copy()
139 | q_cur_7 = q_cur[6]
140 | q_cur[6] = 0.
141 |
142 | f_cur = fk(self.pino_model, self.pino_data, q_cur, self.frame_idx)
143 | z_axis = np.array([0., 0., -1.])
144 |
145 | y_des = np.cross(z_axis, f_cur.rotation[:, 2])
146 | y_des_norm = np.linalg.norm(y_des)
147 | if y_des_norm > 1e-2:
148 | y_des = y_des / y_des_norm
149 | else:
150 | y_des = f_cur.rotation[:, 2]
151 |
152 | target = np.arccos(f_cur.rotation[:, 1].dot(y_des))
153 |
154 | axis = np.cross(f_cur.rotation[:, 1], y_des)
155 | axis_norm = np.linalg.norm(axis)
156 | if axis_norm > 1e-2:
157 | axis = axis / axis_norm
158 | else:
159 | axis = np.array([0., 0., 1.])
160 |
161 | target = target * axis.dot(f_cur.rotation[:, 2])
162 |
163 | if target - q_cur_7 > np.pi / 2:
164 | target -= np.pi
165 | elif target - q_cur_7 < -np.pi / 2:
166 | target += np.pi
167 |
168 | return np.atleast_1d(target)
169 |
170 | def _compute_universal_joint(self, joint_state):
171 | rot_mat = transformations.quaternion_matrix(
172 | self.client.getLinkState(*self._indexer.link_map['iiwa_1/link_ee'])[1])
173 |
174 | q1 = np.arccos(rot_mat[:3, 2].dot(np.array((0., 0., -1))))
175 | q2 = 0
176 |
177 | axis = np.cross(rot_mat[:3, 2], np.array([0., 0., -1.]))
178 | axis_norm = np.linalg.norm(axis)
179 | if axis_norm > 1e-2:
180 | axis = axis / axis_norm
181 | else:
182 | axis = np.array([0., 0., 1.])
183 | q1 = q1 * axis.dot(rot_mat[:3, 1])
184 |
185 | return np.array([q1, q2])
186 |
--------------------------------------------------------------------------------
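In `_create_observation` above, observation delay is emulated by exponentially smoothing the velocity channels: `v_obs = alpha * v + (1 - alpha) * v_obs_prev` with `alpha = 0.5`. A standalone sketch of the filter's step response (illustrative only, not repo code):

```python
import numpy as np

# First-order filter used above to emulate delayed velocity observations.
def delayed(signal, alpha=0.5):
    out, prev = [], np.zeros_like(signal[0])
    for v in signal:
        prev = alpha * v + (1 - alpha) * prev   # exponential smoothing
        out.append(prev)
    return np.array(out)

step = np.array([[0.0], [1.0], [1.0], [1.0]])   # unit step on one channel
print(delayed(step).ravel())                     # [0.  0.5  0.75  0.875]
```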
/atacom/environments/iiwa_air_hockey/iiwa_air_hockey_rmp.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pinocchio as pino
3 | from mushroom_rl.utils.spaces import *
4 | from scipy.linalg import solve
5 | from atacom.environments.iiwa_air_hockey.env_hitting import AirHockeyHit
6 |
7 |
8 | class AirHockeyIiwaRmp:
9 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4,
10 | acc_max=10, debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, random_init=False,
11 | action_penalty=1e-3, Kq=10.):
12 | if task == 'H':
13 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep,
14 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
15 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
16 | torque_control='torque', random_init=random_init,
17 | step_action_function=self.step_action_function,
18 | action_penalty=action_penalty, isolated_joint_7=True)
19 | if task == 'D':
20 | raise NotImplementedError
21 |
22 | self.observation = None
23 |
24 | self.env = base_env
25 | self.env.agents[0]['constr_a'] = 1.0
26 | self.env.agents[0]['constr_b'] = 1.0
27 | self.env.agents[0]['constr_bound_x_l'] = 0.58
28 | self.env.agents[0]['constr_bound_y_l'] = -0.46
29 | self.env.agents[0]['constr_bound_y_u'] = 0.46
30 |
31 | self.dim_q = 6
32 | self._mdp_info = self.env.info.copy()
33 | self._mdp_info.action_space = Box(low=-np.ones(self.dim_q), high=np.ones(self.dim_q))
34 |
35 | self.constr_logs = list()
36 |
37 | if np.isscalar(acc_max):
38 | self.acc_max = np.ones(self.dim_q) * acc_max
39 | else:
40 | self.acc_max = acc_max
41 | assert np.shape(self.acc_max)[0] == self.dim_q
42 |
43 | if np.isscalar(Kq):
44 | self.K_q = np.ones(self.dim_q) * Kq
45 | else:
46 | self.K_q = Kq
47 | assert np.shape(self.K_q)[0] == self.dim_q
48 |
49 | @property
50 | def info(self):
51 | """
52 | Returns:
53 | An object containing the info of the environment.
54 | """
55 | return self._mdp_info
56 |
57 | def reset(self, state=None):
58 | return self.env.reset(state)
59 |
60 | def step(self, action):
61 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high)
62 |         self.observation, reward, absorbing, info = self.env.step(action)
63 |         q = self.env.joints.positions(self.env._state)
64 |         dq = self.env.joints.velocities(self.env._state)
65 |         self._update_constraint_stats(q, dq)
66 |         return self.observation.copy(), reward, absorbing, info
67 |
68 | def stop(self):
69 | self.env.stop()
70 |
71 | def step_action_function(self, state, action):
72 | q = self.env.joints.positions(state)
73 | dq = self.env.joints.velocities(state)
74 | pino.forwardKinematics(self.env.pino_model, self.env.pino_data, q, dq)
75 | pino.computeJointJacobians(self.env.pino_model, self.env.pino_data, q)
76 | pino.updateFramePlacements(self.env.pino_model, self.env.pino_data)
77 |
78 | ddq_ee = self.rmp_ddq_ee()
79 | ddq_elbow = self.rmp_ddq_elbow()
80 | ddq_wrist = self.rmp_ddq_wrist()
81 |
82 | ddq_joints = self.rmp_joint_limit(q, dq)
83 |
84 | ddq_total = np.zeros(self.env.pino_model.nq)
85 | ddq_total[:self.dim_q] = ddq_ee + ddq_elbow + ddq_wrist + ddq_joints + action * self.acc_max
86 | ddq = self.acc_truncation(dq, ddq_total)
87 | tau = pino.rnea(self.env.pino_model, self.env.pino_data, q, dq, ddq)
88 | return tau[:6]
89 |
90 | def rmp_ddq_ee(self):
91 | frame_id = self.env.frame_idx
92 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id,
93 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED)
94 | J = J_frame[:3, :6]
95 | link_pos = self.env.pino_data.oMf[frame_id].translation
96 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id,
97 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear
98 | cart_acc = self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_x_l'], idx=0, b_type='l',
99 | eta_rep=0.1, v_rep=10) + \
100 | self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_y_l'], idx=1, b_type='l',
101 | eta_rep=0.1, v_rep=10) + \
102 | self.f_bound(link_pos, link_vel, self.env.agents[0]['constr_bound_y_u'], idx=1, b_type='u',
103 | eta_rep=0.1, v_rep=10) + \
104 | self.f_plane(link_pos, link_vel, self.env.env_spec['universal_height'])
105 |
106 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc)
107 |
108 | def rmp_ddq_elbow(self):
109 | frame_id = self.env.pino_model.getFrameId("iiwa_1/link_4")
110 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id,
111 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED)
112 | J = J_frame[:3, :6]
113 | link_pos = self.env.pino_data.oMf[frame_id].translation
114 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id,
115 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear
116 | cart_acc = self.f_bound(link_pos, link_vel, 0.36, idx=2, b_type='l',
117 | eta_rep=0.5, v_rep=10)
118 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc)
119 |
120 | def rmp_ddq_wrist(self):
121 | frame_id = self.env.pino_model.getFrameId("iiwa_1/link_6")
122 | J_frame = pino.getFrameJacobian(self.env.pino_model, self.env.pino_data, frame_id,
123 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED)
124 | J = J_frame[:3, :6]
125 | link_pos = self.env.pino_data.oMf[frame_id].translation
126 | link_vel = pino.getFrameVelocity(self.env.pino_model, self.env.pino_data, frame_id,
127 | pino.ReferenceFrame.LOCAL_WORLD_ALIGNED).linear
128 | cart_acc = self.f_bound(link_pos, link_vel, 0.25, idx=2, b_type='l',
129 | eta_rep=0.01, v_rep=10)
130 | return J.T @ solve(J @ J.T + 1e-6 * np.eye(3), cart_acc)
131 |
132 | def rmp_joint_limit(self, q, dq):
133 | sigma = 10
134 |
135 | s = (q - self.env.pino_model.lowerPositionLimit) / \
136 | (self.env.pino_model.upperPositionLimit - self.env.pino_model.lowerPositionLimit)
137 |
138 | d = 4 * s * (1 - s)
139 |
140 | alpha_u = 1 - np.exp(- (np.maximum(dq, 0.) / sigma) ** 2 / 2)
141 | alpha_l = 1 - np.exp(- (np.minimum(dq, 0.) / sigma) ** 2 / 2)
142 |
143 | b = s * (alpha_u * d + (1 - alpha_u)) + (1 - s) * (alpha_l * d + (1 - alpha_l))
144 |
145 | a = b ** (-2)
146 | return a[:6]
147 |
148 | def f_bound(self, x, dx, bound, idx, b_type='u', eta_rep=5., v_rep=1., eta_damp=None):
149 | ddx = np.zeros_like(x)
150 | if eta_damp is None:
151 | eta_damp = np.sqrt(eta_rep)
152 | if b_type == 'u':
153 | d = np.maximum(bound - x[idx], 0.) ** 2
154 | ddx[idx] = -eta_rep * np.exp(-d / v_rep) - eta_damp / (d + 1e-6) * np.maximum(dx[idx], 0)
155 | elif b_type == 'l':
156 | d = np.maximum(x[idx] - bound, 0.) ** 2
157 | ddx[idx] = eta_rep * np.exp(-d / v_rep) - eta_damp / (d + 1e-6) * np.minimum(dx[idx], 0)
158 | return ddx
159 |
160 | def f_plane(self, x, dx, height):
161 | ddx = np.zeros_like(x)
162 | k = 1000
163 | d = np.sqrt(k)
164 | ddx[2] = k * (height - x[2]) - d * dx[2]
165 | return ddx
166 |
167 | def acc_truncation(self, dq, ddq):
168 | acc_u = np.maximum(np.minimum(self.acc_max,
169 | -self.K_q * (dq[:self.dim_q] - self.env.pino_model.velocityLimit[:self.dim_q])),
170 | -self.acc_max)
171 | acc_l = np.minimum(np.maximum(-self.acc_max,
172 | -self.K_q * (dq[:self.dim_q] + self.env.pino_model.velocityLimit[:self.dim_q])),
173 | self.acc_max)
174 | ddq[:self.dim_q] = np.clip(ddq[:self.dim_q], acc_l, acc_u)
175 | return ddq
176 |
177 | def _update_constraint_stats(self, q, dq):
178 | c_i = self._compute_c(q, dq)
179 | c_dq_i = (np.abs(dq) - self.env.pino_model.velocityLimit)[:self.dim_q]
180 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)])
181 |
182 | def get_constraints_logs(self):
183 | constr_logs = np.array(self.constr_logs)
184 | c_avg = np.mean(constr_logs[:, 0])
185 | c_max = np.max(constr_logs[:, 0])
186 | c_dq_max = np.max(constr_logs[:, 1])
187 | self.constr_logs.clear()
188 | return c_avg, c_max, c_dq_max
189 |
190 | def _compute_c(self, q, dq):
191 | pino.forwardKinematics(self.env.pino_model, self.env.pino_data, q, dq)
192 | pino.computeJointJacobians(self.env.pino_model, self.env.pino_data, q)
193 | pino.updateFramePlacements(self.env.pino_model, self.env.pino_data)
194 |
195 | ee_pos = self.env.pino_data.oMf[self.env.frame_idx].translation
196 | elbow_pos = self.env.pino_data.oMf[self.env.pino_model.getFrameId("iiwa_1/link_4")].translation
197 | wrist_pos = self.env.pino_data.oMf[self.env.pino_model.getFrameId("iiwa_1/link_6")].translation
198 |
199 | c = []
200 | c.append(np.abs(ee_pos[2] - self.env.env_spec['universal_height']))
201 | c.append(-ee_pos[0] + self.env.agents[0]['constr_bound_x_l'])
202 | c.append(-ee_pos[1] + self.env.agents[0]['constr_bound_y_l'])
203 | c.append(ee_pos[1] - self.env.agents[0]['constr_bound_y_u'])
204 | c.append(- elbow_pos[2] + 0.36)
205 | c.append(- wrist_pos[2] + 0.25)
206 | c.extend(- q[:self.env.n_ctrl_joints] + self.env.pino_model.lowerPositionLimit[:self.env.n_ctrl_joints])
207 | c.extend(q[:self.env.n_ctrl_joints] - self.env.pino_model.upperPositionLimit[:self.env.n_ctrl_joints])
208 | return np.array(c)
209 |
210 |
211 | if __name__ == '__main__':
212 | env = AirHockeyIiwaRmp(debug_gui=True)
213 |
214 | env.reset()
215 | for i in range(10000):
216 | action = np.random.randn(env.dim_q)
217 | _, _, absorb, _ = env.step(action)
218 | if absorb:
219 | env.reset()
220 | time.sleep(1 / 240.)
221 |
--------------------------------------------------------------------------------
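`acc_truncation` above builds a velocity-aware acceleration box: as a joint velocity `dq` approaches its limit, the admissible acceleration in that direction shrinks to zero and then becomes a braking command. A small numeric sketch with illustrative parameter values:

```python
import numpy as np

# Velocity-aware acceleration box from acc_truncation above: near the
# velocity limit the admissible acceleration shrinks, then actively brakes.
def acc_bounds(dq, dq_max, acc_max, K_q):
    acc_u = np.maximum(np.minimum(acc_max, -K_q * (dq - dq_max)), -acc_max)
    acc_l = np.minimum(np.maximum(-acc_max, -K_q * (dq + dq_max)), acc_max)
    return acc_l, acc_u

acc_max, dq_max, K_q = 10.0, 2.0, 10.0
for dq in (0.0, 1.9, 2.1):
    lo, hi = acc_bounds(np.array([dq]), dq_max, acc_max, K_q)
    print(f"dq={dq:4.1f} -> ddq in [{lo[0]:6.1f}, {hi[0]:6.1f}]")
# dq= 0.0 -> ddq in [ -10.0,  10.0]
# dq= 1.9 -> ddq in [ -10.0,   1.0]
# dq= 2.1 -> ddq in [ -10.0,  -1.0]
```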
/atacom/environments/iiwa_air_hockey/iiwa_hit_atacom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pinocchio as pino
4 | import matplotlib.pyplot as plt
5 | from atacom.environments.iiwa_air_hockey.env_hitting import AirHockeyHit
6 | from atacom.atacom import AtacomEnvWrapper
7 | from atacom.constraints import ViabilityConstraint, ConstraintsSet
8 |
9 |
10 | class AirHockeyIiwaAtacom(AtacomEnvWrapper):
11 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4,
12 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, Kc=240., random_init=False,
13 | action_penalty=1e-3):
14 | if task == 'H':
15 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep,
16 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
17 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
18 | torque_control='torque', random_init=random_init,
19 | action_penalty=action_penalty, isolated_joint_7=True)
20 | if task == 'D':
21 | raise NotImplementedError
22 |
23 | dim_q = base_env.n_ctrl_joints
24 | ee_pos_f = ViabilityConstraint(dim_q=dim_q, dim_out=1, fun=self.ee_pose_f, J=self.ee_pose_J_f,
25 | b=self.ee_pos_b_f, K=0.1)
26 | f = ConstraintsSet(dim_q)
27 | f.add_constraint(ee_pos_f)
28 |
29 | cart_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=5, fun=self.ee_pos_g, J=self.ee_pos_J_g,
30 | b=self.ee_pos_b_g, K=0.5)
31 | joint_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=dim_q, fun=self.joint_pos_g, J=self.joint_pos_J_g,
32 | b=self.joint_pos_b_g, K=1)
33 | g = ConstraintsSet(dim_q)
34 | g.add_constraint(cart_pos_g)
35 | g.add_constraint(joint_pos_g)
36 |
37 | acc_max = np.ones(base_env.n_ctrl_joints) * 10
38 | vel_max = base_env.joints.velocity_limits()[:base_env.n_ctrl_joints]
39 | super().__init__(base_env, dim_q, f=f, g=g, Kc=Kc, vel_max=vel_max, acc_max=acc_max, Kq=4 * acc_max / vel_max,
40 | time_step=timestep)
41 |
42 | self.pino_model = self.env.pino_model
43 | self.pino_data = self.env.pino_data
44 | self.frame_idx = self.env.frame_idx
45 | self.frame_idx_4 = 12
46 | self.frame_idx_7 = 18
47 |
48 | for i in range(self.pino_model.nq):
49 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[i+1]],
50 | maxJointVelocity=self.pino_model.velocityLimit[i] * 1.5)
51 |
52 | def _get_q(self, state):
53 | return state[-2 * self.env.n_ctrl_joints:-self.env.n_ctrl_joints]
54 |
55 | def _get_dq(self, state):
56 | return state[-self.env.n_ctrl_joints:]
57 |
58 | def acc_to_ctrl_action(self, ddq):
59 | ddq = self._get_pino_value(ddq).tolist()
60 | sim_state = self.env._indexer.create_sim_state()
61 | q = self.env.joints.positions(sim_state).tolist()
62 | dq = self.env.joints.velocities(sim_state).tolist()
63 | return self.env.client.calculateInverseDynamics(2, q, dq, ddq)[:self.env.n_ctrl_joints]
64 |
65 | def _get_pino_value(self, q):
66 | ret = np.zeros(9)
67 | ret[:q.shape[0]] = q
68 | return ret
69 |
70 | def ee_pose_f(self, q):
71 | q = self._get_pino_value(q)
72 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q)
73 | ee_pos_z = self.pino_data.oMf[self.frame_idx].translation[2]
74 | return np.atleast_1d(ee_pos_z - self.env.env_spec['universal_height'])
75 |
76 | def ee_pose_J_f(self, q):
77 | q = self._get_pino_value(q)
78 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q)
79 | ee_jac = pino.computeFrameJacobian(self.pino_model, self.pino_data, q,
80 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints]
81 | J_pos = ee_jac[2]
82 | return np.atleast_2d(J_pos)
83 |
84 | def ee_pos_b_f(self, q, dq):
85 | q = self._get_pino_value(q)
86 | dq = self._get_pino_value(dq)
87 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq)
88 | acc = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx,
89 | pino.LOCAL_WORLD_ALIGNED).vector
90 | b_pos = acc[2]
91 | return np.atleast_1d(b_pos)
92 |
93 | def ee_pos_g(self, q):
94 | q = self._get_pino_value(q)
95 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q)
96 | ee_pos = self.pino_data.oMf[self.frame_idx].translation[:2]
97 | ee_pos_world = ee_pos + self.env.agents[0]['frame'][:2, 3]
98 | g_1 = - ee_pos_world[0] - (self.env.env_spec['table']['length'] / 2 - self.env.env_spec['mallet']['radius'])
99 | g_2 = - ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius'])
100 | g_3 = ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius'])
101 |
102 | ee_pos_4 = self.pino_data.oMf[self.frame_idx_4].translation
103 | ee_pos_7 = self.pino_data.oMf[self.frame_idx_7].translation
104 | g_4 = -ee_pos_4[2] + 0.36
105 | g_5 = -ee_pos_7[2] + 0.25
106 | return np.array([g_1, g_2, g_3, g_4, g_5])
107 |
108 | def ee_pos_J_g(self, q):
109 | q = self._get_pino_value(q)
110 | pino.computeJointJacobians(self.pino_model, self.pino_data, q)
111 | jac_ee = pino.getFrameJacobian(self.pino_model, self.pino_data,
112 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints]
113 | jac_4 = pino.getFrameJacobian(self.pino_model, self.pino_data,
114 | self.frame_idx_4, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints]
115 | jac_7 = pino.getFrameJacobian(self.pino_model, self.pino_data,
116 | self.frame_idx_7, pino.LOCAL_WORLD_ALIGNED)[:, :self.env.n_ctrl_joints]
117 | return np.vstack([-jac_ee[0], -jac_ee[1], jac_ee[1], -jac_4[2], -jac_7[2]])
118 |
119 | def ee_pos_b_g(self, q, dq):
120 | q = self._get_pino_value(q)
121 | dq = self._get_pino_value(dq)
122 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq)
123 | acc_ee = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx,
124 | pino.LOCAL_WORLD_ALIGNED).vector
125 | acc_4 = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx_4,
126 | pino.LOCAL_WORLD_ALIGNED).vector
127 | acc_7 = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.frame_idx_7,
128 | pino.LOCAL_WORLD_ALIGNED).vector
129 |
130 | return np.array([-acc_ee[0], -acc_ee[1], acc_ee[1], -acc_4[2], -acc_7[2]])
131 |
132 | def joint_pos_g(self, q):
133 | return np.array(q ** 2 - self.pino_model.upperPositionLimit[:self.env.n_ctrl_joints] ** 2)
134 |
135 | def joint_pos_J_g(self, q):
136 | return 2 * np.diag(q)
137 |
138 | def joint_pos_b_g(self, q, dq):
139 | return 2 * dq ** 2
140 |
141 | def plot_constraints(self, dataset, save_dir="", suffix="", state_norm_processor=None):
142 | state_list = list()
143 | i = 0
144 |
145 | if suffix != '':
146 | suffix = suffix + "_"
147 |
148 | for data in dataset:
149 | state = data[0]
150 | if state_norm_processor is not None:
151 | state[state_norm_processor._obs_mask] = (state * state_norm_processor._obs_delta + \
152 | state_norm_processor._obs_mean)[state_norm_processor._obs_mask]
153 | state_list.append(state)
154 | if data[-1]:
155 | i += 1
156 | state_hist = np.array(state_list)
157 |
158 | if not os.path.exists(save_dir):
159 | os.makedirs(save_dir)
160 |
161 | ee_pos_list = list()
162 | for state_i in state_hist:
163 | q = np.zeros(9)
164 | q[:6] = state_i[6:12]
165 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q)
166 | ee_pos_list.append(self.pino_data.oMf[-1].translation[:2] + self.env.agents[0]['frame'][:2, 3])
167 |
168 | ee_pos_list = np.array(ee_pos_list)
169 | fig1, axes1 = plt.subplots(1, figsize=(10, 10))
170 | axes1.plot(ee_pos_list[:, 0], ee_pos_list[:, 1], label='position')
171 | axes1.plot([0.0, -0.91, -0.91, 0.0], [-0.45, -0.45, 0.45, 0.45], label='boundary', c='k', lw='5')
172 | axes1.set_aspect(1.0)
173 | axes1.set_xlim(-1, 0)
174 | axes1.set_ylim(-0.5, 0.5)
175 | axes1.legend(loc='upper right')
176 | axes1.set_title('EndEffector')
177 | axes1.legend(loc='center right')
178 | file1 = "EndEffector_" + suffix + str(i) + ".pdf"
179 | plt.savefig(os.path.join(save_dir, file1))
180 | plt.close(fig1)
181 |
182 | fig2, axes2 = plt.subplots(2, 3, figsize=(21, 8), sharex=True, sharey=True)
183 | for j in range(6):
184 | axes2[j // 3, j % 3].plot(state_hist[:, 6 + j], lw=3, color='tab:blue')
185 | axes2[j // 3, j % 3].plot(state_hist[:, 12 + j], lw=3, color='tab:orange')
186 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[j]] * 2,
187 | lw=3, c='tab:red', ls='--')
188 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[j]] * 2,
189 | lw=3, c='tab:red', ls='--')
190 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[j]] * 2,
191 | lw=3, c='tab:pink', ls=':')
192 | axes2[j // 3, j % 3].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[j]] * 2,
193 | lw=3, c='tab:pink', ls=':')
194 |
195 | axes2[j // 3, j % 3].set_title('Joint ' + str(j + 1))
196 |
197 | axes2[0, 0].plot([], lw=3, color='tab:blue', label='position')
198 | axes2[0, 0].plot([], lw=3, color='tab:red', ls='--', label='position limit')
199 | axes2[0, 0].plot([], lw=3, color='tab:orange', label='velocity')
200 | axes2[0, 0].plot([], lw=3, color='tab:pink', ls=':', label='velocity limit')
201 | fig2.legend(ncol=4, loc='lower center')
202 |
203 | file2 = "JointProfile_" + suffix + str(i) + ".pdf"
204 | plt.savefig(os.path.join(save_dir, file2))
205 | plt.close(fig2)
206 |
207 | state_list = list()
208 |
--------------------------------------------------------------------------------
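The joint-limit constraint above is `g(q) = q^2 - q_max^2`, so its Jacobian is `J = 2 * diag(q)` and the drift term is `b = J_dot @ dq = 2 * dq**2`, matching `joint_pos_J_g` and `joint_pos_b_g`. A quick standalone finite-difference check of the Jacobian (the numbers are arbitrary):

```python
import numpy as np

# Central finite differences on g(q) = q^2 - q_max^2 should recover 2*diag(q).
q_max = np.array([2.9, 2.0, 2.9])
g = lambda q: q ** 2 - q_max ** 2

q = np.array([0.3, -1.2, 0.7])
eps = 1e-6
J_fd = np.stack([(g(q + eps * e) - g(q - eps * e)) / (2 * eps)
                 for e in np.eye(3)], axis=1)   # column i = dg/dq_i
assert np.allclose(J_fd, 2 * np.diag(q), atol=1e-5)
```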
/atacom/environments/iiwa_air_hockey/kinematics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pinocchio as pino
3 | from numpy.linalg import norm, solve
4 |
5 | IT_MAX = 1000
6 | eps = 1e-4
7 | DT = 1e-1
8 | damp = 1e-12
9 |
10 |
11 | def clik(model, data, fDes, q, idx):
12 | q_cur = q.copy()
13 | i = 0
14 | while True:
15 | pino.forwardKinematics(model, data, q_cur)
16 | pino.updateFramePlacements(model, data)
17 |
18 | dMi = fDes.actInv(data.oMf[idx])
19 | lin_err = -dMi.translation
20 | ang_err = pino.log3(dMi.rotation)
21 | err = np.concatenate([lin_err, ang_err])
22 | if norm(err) < eps:
23 | success = True
24 | break
25 | if i >= IT_MAX:
26 | success = False
27 | break
28 | J = pino.computeFrameJacobian(model, data, q_cur, idx, pino.LOCAL_WORLD_ALIGNED)
29 | v = - J.T.dot(solve(J.dot(J.T) + damp * np.eye(6), err))
30 | q_cur = pino.integrate(model, q_cur, v*DT)
31 | i += 1
32 |     over = np.where(q_cur > model.upperPositionLimit)
33 |     q_cur[over] -= np.pi * 2
34 |     under = np.where(q_cur < model.lowerPositionLimit)
35 |     q_cur[under] += np.pi * 2
36 | if not (np.all(model.lowerPositionLimit < q_cur) and np.all(q_cur < model.upperPositionLimit)):
37 | return False, q_cur
38 | return success, q_cur
39 |
40 |
41 | def fk(model, data, q, idx):
42 | pino.forwardKinematics(model, data, q)
43 | pino.updateFramePlacement(model, data, idx)
44 | return data.oMf[idx]
45 |
--------------------------------------------------------------------------------
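A usage sketch for the `clik`/`fk` helpers above. To stay self-contained it uses pinocchio's bundled sample manipulator instead of the iiwa URDF; the target pose is an arbitrary assumption, and whether the solve converges depends on reachability and the sample model's joint limits:

```python
import numpy as np
import pinocchio as pino
from atacom.environments.iiwa_air_hockey.kinematics import clik, fk

# Solve IK for the last frame of pinocchio's sample manipulator.
model = pino.buildSampleModelManipulator()
data = model.createData()
frame_idx = model.nframes - 1            # last frame, as the envs do

q0 = pino.neutral(model)
des_pose = pino.SE3(np.eye(3), np.array([0.3, 0.2, 0.5]))  # assumed target
success, q_sol = clik(model, data, des_pose, q0, frame_idx)
print("converged:", success)
if success:
    print("reached:", fk(model, data, q_sol, frame_idx).translation)
```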
/atacom/environments/iiwa_air_hockey/urdf/iiwa_1.urdf:
--------------------------------------------------------------------------------
[URDF XML content was stripped when this dump was generated; see the repository file for the full robot description.]
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/iiwa_2.urdf:
--------------------------------------------------------------------------------
[URDF XML content was stripped when this dump was generated; see the repository file for the full robot description.]
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_0.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_0.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_1.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_2.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_3.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_4.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_5.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_6.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7_old.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/collision/link_7_old.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_arm_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_arm_collision.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_collision.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_short_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/collision/EE_mallet_short_collision.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_arm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_arm.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet_short.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/striker/visual/EE_mallet_short.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_0.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_0.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_1.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_2.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_3.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_4.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_5.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_6.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7.stl
--------------------------------------------------------------------------------
/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7_old.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/atacom/environments/iiwa_air_hockey/urdf/meshes/visual/link_7_old.stl
--------------------------------------------------------------------------------
/atacom/environments/planar_air_hockey/__init__.py:
--------------------------------------------------------------------------------
1 | from .atacom_air_hockey import AirHockeyPlanarAtacom
2 | from .unconstrained_air_hockey import AirHockeyHitUnconstrained, AirHockeyDefendUnconstrained
3 |
--------------------------------------------------------------------------------
/atacom/environments/planar_air_hockey/atacom_air_hockey.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pinocchio as pino
4 | import matplotlib.pyplot as plt
5 |
6 | from atacom.atacom import AtacomEnvWrapper
7 | from atacom.constraints import ViabilityConstraint, ConstraintsSet
8 | from mushroom_rl.environments.pybullet_envs.air_hockey import AirHockeyHit, AirHockeyDefend
9 |
10 |
11 | class AirHockeyPlanarAtacom(AtacomEnvWrapper):
12 | def __init__(self, task='H', gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=4,
13 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, Kc=240., random_init=False,
14 | action_penalty=1e-3):
15 | if task == 'H':
16 | base_env = AirHockeyHit(gamma=gamma, horizon=horizon, timestep=timestep,
17 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
18 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
19 | torque_control=True, random_init=random_init,
20 | action_penalty=action_penalty)
21 | if task == 'D':
22 | base_env = AirHockeyDefend(gamma=gamma, horizon=horizon, timestep=timestep,
23 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
24 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
25 | torque_control=True, random_init=random_init,
26 | action_penalty=action_penalty)
27 |
28 | dim_q = 3
29 | cart_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=3, fun=self.cart_pos_g, J=self.cart_pos_J_g,
30 | b=self.cart_pos_b_g, K=0.5)
31 | joint_pos_g = ViabilityConstraint(dim_q=dim_q, dim_out=3, fun=self.joint_pos_g, J=self.joint_pos_J_g,
32 | b=self.joint_pos_b_g, K=1.0)
33 | # joint_vel_g = StateVelocityConstraint(dim_q=dim_q, dim_out=3, fun=self.joint_vel_g, A=self.joint_vel_A_g,
34 | # b=self.joint_vel_b_g, margin=0.0)
35 | g = ConstraintsSet(dim_q)
36 | g.add_constraint(cart_pos_g)
37 | g.add_constraint(joint_pos_g)
38 | # g.add_constraint(joint_vel_g)
39 |
40 | acc_max = np.ones(3) * 10
41 | vel_max = base_env.joints.velocity_limits()
42 | super().__init__(base_env, 3, f=None, g=g, Kc=Kc, vel_max=vel_max, acc_max=acc_max, Kq=2 * acc_max / vel_max,
43 | time_step=timestep)
44 |
45 | self.pino_model = pino.buildModelFromUrdf(self.env.agents[0]['urdf'])
46 | self.pino_data = self.pino_model.createData()
47 | self.frame_idx = self.pino_model.nframes - 1
48 |
49 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[1]],
50 | maxJointVelocity=self.pino_model.velocityLimit[0] * 1.5)
51 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[2]],
52 | maxJointVelocity=self.pino_model.velocityLimit[1] * 1.5)
53 | self.env.client.changeDynamics(*base_env._indexer.joint_map[self.pino_model.names[3]],
54 | maxJointVelocity=self.pino_model.velocityLimit[2] * 1.5)
55 |
56 | robot_links = ['planar_robot_1/link_striker_hand', 'planar_robot_1/link_striker_ee']
57 |         table_rims = ['t_down_rim_l', 't_down_rim_r', 't_up_rim_r', 't_up_rim_l',
58 |                       't_left_rim', 't_right_rim', 't_base', 't_up_rim_top', 't_down_rim_top']
59 | for iiwa_l in robot_links:
60 | for table_r in table_rims:
61 | self.env.client.setCollisionFilterPair(self.env._indexer.link_map[iiwa_l][0],
62 | self.env._indexer.link_map[table_r][0],
63 | self.env._indexer.link_map[iiwa_l][1],
64 | self.env._indexer.link_map[table_r][1], 0)
65 |
66 | def _get_q(self, state):
67 | return state[6:9]
68 |
69 | def _get_dq(self, state):
70 | return state[9:12]
71 |
72 | def acc_to_ctrl_action(self, ddq):
73 | q = self.q.tolist()
74 | dq = self.dq.tolist()
75 | ddq = ddq.tolist()
76 | return self.env.client.calculateInverseDynamics(self.env._model_map['planar_robot_1'], q, dq, ddq)
77 |
78 | def cart_pos_g(self, q):
79 | pino.framesForwardKinematics(self.pino_model, self.pino_data, q)
80 | ee_pos = self.pino_data.oMf[-1].translation[:2]
81 | ee_pos_world = ee_pos + self.env.agents[0]['frame'][:2, 3]
82 | g_1 = - ee_pos_world[0] - (self.env.env_spec['table']['length'] / 2 - self.env.env_spec['mallet']['radius'])
83 | g_2 = - ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius'])
84 | g_3 = ee_pos_world[1] - (self.env.env_spec['table']['width'] / 2 - self.env.env_spec['mallet']['radius'])
85 | return np.array([g_1, g_2, g_3])
86 |
87 | def cart_pos_J_g(self, q):
88 | ee_jac = pino.computeFrameJacobian(self.pino_model, self.pino_data, q,
89 | self.frame_idx, pino.LOCAL_WORLD_ALIGNED)[:2]
90 | J_c = np.array([[-1., 0.], [0., -1.], [0., 1.]])
91 | return J_c @ ee_jac
92 |
93 | def cart_pos_b_g(self, q, dq):
94 | pino.forwardKinematics(self.pino_model, self.pino_data, q, dq)
95 | acc = pino.getFrameClassicalAcceleration(self.pino_model, self.pino_data, self.pino_model.nframes - 1,
96 | pino.LOCAL_WORLD_ALIGNED).vector
97 | J_c = np.array([[-1., 0.], [0., -1.], [0., 1.]])
98 | return J_c @ acc[:2]
99 |
100 | def joint_pos_g(self, q):
101 | return np.array(q ** 2 - self.pino_model.upperPositionLimit ** 2)
102 |
103 | def joint_pos_J_g(self, q):
104 | return 2 * np.diag(q)
105 |
106 | def joint_pos_b_g(self, q, dq):
107 | return 2 * dq ** 2
108 |
109 | def joint_vel_g(self, q, dq):
110 |         return np.array(dq ** 2 - self.pino_model.velocityLimit ** 2)
111 |
112 | def joint_vel_A_g(self, q, dq):
113 | return 2 * np.diag(dq)
114 |
115 | def joint_vel_b_g(self, q, dq):
116 | return np.zeros(3)
117 |
118 | def plot_constraints(self, dataset, save_dir="", suffix="", state_norm_processor=None):
119 | state_list = list()
120 | i = 0
121 |
122 | if suffix != '':
123 | suffix = suffix + "_"
124 |
125 | for data in dataset:
126 | state = data[0]
127 | if state_norm_processor is not None:
128 | state[state_norm_processor._obs_mask] = (state * state_norm_processor._obs_delta + \
129 | state_norm_processor._obs_mean)[state_norm_processor._obs_mask]
130 | state_list.append(state)
131 | if data[-1]:
132 | i += 1
133 | state_hist = np.array(state_list)
134 |
135 | if not os.path.exists(save_dir):
136 | os.makedirs(save_dir)
137 |
138 | ee_pos_list = list()
139 | for state_i in state_hist:
140 | pino.framesForwardKinematics(self.pino_model, self.pino_data, state_i[6:9])
141 | ee_pos_list.append(self.pino_data.oMf[-1].translation[:2] + self.env.agents[0]['frame'][:2, 3])
142 |
143 | ee_pos_list = np.array(ee_pos_list)
144 | fig1, axes1 = plt.subplots(1, figsize=(10, 10))
145 | axes1.plot(ee_pos_list[:, 0], ee_pos_list[:, 1], label='position')
146 | axes1.plot([0.0, -0.91, -0.91, 0.0], [-0.45, -0.45, 0.45, 0.45], label='boundary', c='k', lw='5')
147 | axes1.set_aspect(1.0)
148 | axes1.set_xlim(-1, 0)
149 | axes1.set_ylim(-0.5, 0.5)
150 | axes1.legend(loc='upper right')
151 | axes1.set_title('EndEffector')
152 | axes1.legend(loc='center right')
153 | file1 = "EndEffector_" + suffix + str(i) + ".pdf"
154 | plt.savefig(os.path.join(save_dir, file1))
155 | plt.close(fig1)
156 |
157 | fig2, axes2 = plt.subplots(1, 3, sharey=True, figsize=(21, 8))
158 | axes2[0].plot(state_hist[:, 6], label='position', c='tab:blue')
159 | axes2[1].plot(state_hist[:, 7], c='tab:blue')
160 | axes2[2].plot(state_hist[:, 8], c='tab:blue')
161 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[0]] * 2,
162 | label='position limit', c='tab:red', ls='--')
163 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[1]] * 2, c='tab:red',
164 | ls='--')
165 | axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.lowerPositionLimit[2]] * 2, c='tab:red',
166 | ls='--')
167 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[0]] * 2, c='tab:red',
168 | ls='--')
169 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[1]] * 2, c='tab:red',
170 | ls='--')
171 | axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.upperPositionLimit[2]] * 2, c='tab:red',
172 | ls='--')
173 |
174 | axes2[0].plot(state_hist[:, 9], label='velocity', c='tab:orange')
175 | axes2[1].plot(state_hist[:, 10], c='tab:orange')
176 | axes2[2].plot(state_hist[:, 11], c='tab:orange')
177 | axes2[0].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[0]] * 2,
178 | label='velocity limit', c='tab:pink', ls=':')
179 | axes2[1].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[1]] * 2, c='tab:pink', ls=':')
180 | axes2[2].plot([0, state_hist.shape[0]], [-self.pino_model.velocityLimit[2]] * 2, c='tab:pink', ls=':')
181 | axes2[0].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[0]] * 2, c='tab:pink', ls=':')
182 | axes2[1].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[1]] * 2, c='tab:pink', ls=':')
183 | axes2[2].plot([0, state_hist.shape[0]], [self.pino_model.velocityLimit[2]] * 2, c='tab:pink', ls=':')
184 |
185 | axes2[0].set_title('Joint 1')
186 | axes2[1].set_title('Joint 2')
187 | axes2[2].set_title('Joint 3')
188 | fig2.legend(ncol=4, loc='lower center')
189 |
190 | file2 = "JointProfile_" + suffix + str(i) + ".pdf"
191 | plt.savefig(os.path.join(save_dir, file2))
192 | plt.close(fig2)
193 |
194 | state_list = list()
--------------------------------------------------------------------------------
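For intuition, `cart_pos_g` above encodes the table boundary as three inequalities that stay non-positive while the mallet remains inside the playable half. A standalone numeric illustration (the table and mallet dimensions below are assumptions, not read from `env_spec`):

```python
import numpy as np

# Table-boundary inequalities as in cart_pos_g: g(q) <= 0 means "inside".
length, width, radius = 1.96, 1.02, 0.05     # assumed table/mallet sizes

def cart_g(ee_xy):
    x, y = ee_xy
    return np.array([-x - (length / 2 - radius),   # back rim
                     -y - (width / 2 - radius),    # one side rim
                      y - (width / 2 - radius)])   # other side rim

print(cart_g((-0.5, 0.0)))    # all entries negative: inside the boundary
print(cart_g((-0.5, 0.55)))   # third entry positive: past the side rim
```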
/atacom/environments/planar_air_hockey/unconstrained_air_hockey.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mushroom_rl.environments.pybullet import PyBulletObservationType
3 | from mushroom_rl.environments.pybullet_envs.air_hockey import AirHockeyHit, AirHockeyDefend
4 |
5 |
6 | class AirHockeyHitUnconstrained(AirHockeyHit):
7 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1,
8 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control="torque",
9 | random_init=False, action_penalty=1e-3):
10 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep,
11 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
12 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
13 | torque_control=torque_control, random_init=random_init,
14 | step_action_function=self._step_action_function,
15 | action_penalty=action_penalty)
16 | self.constr_logs = list()
17 |
18 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_1"],
19 | maxJointVelocity=self.joints.velocity_limits()[0] * 1.5)
20 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_2"],
21 | maxJointVelocity=self.joints.velocity_limits()[1] * 1.5)
22 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_3"],
23 | maxJointVelocity=self.joints.velocity_limits()[2] * 1.5)
24 | self.acc_max = np.ones(3) * 10
25 |
26 | def setup(self, state):
27 | super().setup(state)
28 | self.constr_logs.clear()
29 |
30 | def _step_action_function(self, state, action):
31 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high)
32 | self._update_constraint_stats(state)
33 | return action
34 |
35 | def _update_constraint_stats(self, state):
36 | q = self.joints.positions(state)
37 | dq = self.joints.velocities(state)
38 | mallet_pose = self.get_sim_state(state, "planar_robot_1/link_striker_ee", PyBulletObservationType.LINK_POS)
39 | c_ee_i = np.array([-mallet_pose[0] - self.env_spec['table']['length'] / 2,
40 | -mallet_pose[1] - self.env_spec['table']['width'] / 2,
41 | mallet_pose[1] - self.env_spec['table']['width'] / 2])
42 | c_q_i = q ** 2 - self.joints.limits()[1] ** 2
43 | c_dq_i = dq ** 2 - self.joints.velocity_limits() ** 2
44 | c_i = np.concatenate([c_ee_i, c_q_i])
45 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)])
46 |
47 | def get_constraints_logs(self):
48 | constr_logs = np.array(self.constr_logs)
49 | c_avg = np.mean(constr_logs[:, 0])
50 | c_max = np.max(constr_logs[:, 0])
51 | c_dq_max = np.max(constr_logs[:, 1])
52 | self.constr_logs.clear()
53 | return c_avg, c_max, c_dq_max
54 |
55 |
56 | class AirHockeyDefendUnconstrained(AirHockeyDefend):
57 | def __init__(self, gamma=0.99, horizon=120, timestep=1 / 240., n_intermediate_steps=1,
58 | debug_gui=False, env_noise=False, obs_noise=False, obs_delay=False, torque_control="torque",
59 | random_init=False, action_penalty=1e-3):
60 | super().__init__(gamma=gamma, horizon=horizon, timestep=timestep,
61 | n_intermediate_steps=n_intermediate_steps, debug_gui=debug_gui,
62 | env_noise=env_noise, obs_noise=obs_noise, obs_delay=obs_delay,
63 | torque_control=torque_control, random_init=random_init,
64 | step_action_function=self._step_action_function,
65 | action_penalty=action_penalty)
66 | self.constr_logs = list()
67 |
68 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_1"],
69 | maxJointVelocity=self.joints.velocity_limits()[0] * 1.5)
70 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_2"],
71 | maxJointVelocity=self.joints.velocity_limits()[1] * 1.5)
72 | self.client.changeDynamics(*self._indexer.link_map["planar_robot_1/link_3"],
73 | maxJointVelocity=self.joints.velocity_limits()[2] * 1.5)
74 | self.acc_max = np.ones(3) * 10
75 |
76 | def setup(self, state):
77 | super().setup(state)
78 | self.constr_logs.clear()
79 |
80 | def _step_action_function(self, state, action):
81 | action = np.clip(action, self.info.action_space.low, self.info.action_space.high)
82 | self._update_constraint_stats(state)
83 | return action
84 |
85 | def _update_constraint_stats(self, state):
86 | q = self.joints.positions(state)
87 | dq = self.joints.velocities(state)
88 | mallet_pose = self.get_sim_state(state, "planar_robot_1/link_striker_ee", PyBulletObservationType.LINK_POS)
89 | c_ee_i = np.array([-mallet_pose[0] - self.env_spec['table']['length'] / 2,
90 | -mallet_pose[1] - self.env_spec['table']['width'] / 2,
91 | mallet_pose[1] - self.env_spec['table']['width'] / 2])
92 | c_q_i = q ** 2 - self.joints.limits()[1] ** 2
93 | c_dq_i = dq ** 2 - self.joints.velocity_limits() ** 2
94 | c_i = np.concatenate([c_ee_i, c_q_i])
95 | self.constr_logs.append([np.max(c_i), np.max(c_dq_i)])
96 |
97 | def get_constraints_logs(self):
98 | constr_logs = np.array(self.constr_logs)
99 | c_avg = np.mean(constr_logs[:, 0])
100 | c_max = np.max(constr_logs[:, 0])
101 | c_dq_max = np.max(constr_logs[:, 1])
102 | self.constr_logs.clear()
103 | return c_avg, c_max, c_dq_max
--------------------------------------------------------------------------------
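The per-step logs above reduce to three scalars in `get_constraints_logs()`: the mean and max of the largest state-constraint value, plus the max of the largest velocity-constraint value. A tiny sketch of that aggregation on made-up log entries:

```python
import numpy as np

# Aggregation performed by get_constraints_logs() on the per-step logs:
# each entry is [max state-constraint value, max velocity-constraint value].
constr_logs = [[-0.40, -3.0], [-0.10, -2.5], [0.02, -1.0]]
logs = np.array(constr_logs)
c_avg, c_max, c_dq_max = logs[:, 0].mean(), logs[:, 0].max(), logs[:, 1].max()
print(c_avg, c_max, c_dq_max)   # a positive c_max flags a constraint violation
```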
/atacom/error_correction_wrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.utils import spaces
4 | from atacom.utils import pinv_null
5 |
6 |
7 | class ErrorCorrectionEnvWrapper:
8 | """
9 | Environment wrapper of the Error Correction Method
10 |
11 | """
12 | def __init__(self, base_env, dim_q, vel_max, acc_max, f=None, g=None, Kc=100., Kq=10., time_step=0.01):
13 | """
14 | Constructor
15 | Args:
16 |             base_env (mushroom_rl.core.Environment): the base environment to be wrapped
17 |             dim_q (int): dimension of the directly controllable variable
18 |             vel_max (array, float): the maximum velocity of the directly controllable variable
19 |             acc_max (array, float): the maximum acceleration of the directly controllable variable
20 |             f (ViabilityConstraint, ConstraintsSet): the equality constraint f(q) = 0
21 |             g (ViabilityConstraint, ConstraintsSet): the inequality constraint g(q) <= 0
22 |             Kc (array, float): the scaling factor for error correction
23 |             Kq (array, float): the scaling factor for the viability acceleration bound
24 |             time_step (float): the step size for time discretization
25 | """
26 | self.env = base_env
27 | self.dims = {'q': dim_q, 'f': 0, 'g': 0}
28 | self.f = f
29 | self.g = g
30 | self.time_step = time_step
31 | self._logger = None
32 |
33 | if self.f is not None:
34 | assert self.dims['q'] == self.f.dim_q, "Input dimension is different in f"
35 | self.dims['f'] = self.f.dim_out
36 | if self.g is not None:
37 | assert self.dims['q'] == self.g.dim_q, "Input dimension is different in g"
38 | self.dims['g'] = self.g.dim_out
39 | self.s = np.zeros(self.dims['g'])
40 |
41 | self.dims['c'] = self.dims['f'] + self.dims['g']
42 |
43 | if np.isscalar(Kc):
44 | self.K_c = np.ones(self.dims['c']) * Kc
45 | else:
46 | self.K_c = Kc
47 |
48 | self.q = np.zeros(self.dims['q'])
49 | self.dq = np.zeros(self.dims['q'])
50 |
51 | self._mdp_info = self.env.info.copy()
52 | self._mdp_info.action_space = spaces.Box(low=-np.ones(self.dims['q']), high=np.ones(self.dims['q']))
53 |
54 | if np.isscalar(vel_max):
55 | self.vel_max = np.ones(self.dims['q']) * vel_max
56 | else:
57 | self.vel_max = vel_max
58 | assert np.shape(self.vel_max)[0] == self.dims['q']
59 |
60 | if np.isscalar(acc_max):
61 | self.acc_max = np.ones(self.dims['q']) * acc_max
62 | else:
63 | self.acc_max = acc_max
64 | assert np.shape(self.acc_max)[0] == self.dims['q']
65 |
66 | if np.isscalar(Kq):
67 | self.K_q = np.ones(self.dims['q']) * Kq
68 | else:
69 | self.K_q = Kq
70 | assert np.shape(self.K_q)[0] == self.dims['q']
71 |
72 | self.state = self.env.reset()
73 | self._act_a = None
74 | self._act_b = None
75 | self._act_err = None
76 |
77 | self.env.step_action_function = self.step_action_function
78 |
79 | def acc_to_ctrl_action(self, ddq):
80 | raise NotImplementedError
81 |
82 | def _get_q(self, state):
83 | raise NotImplementedError
84 |
85 | def _get_dq(self, state):
86 | raise NotImplementedError
87 |
88 | def seed(self, seed):
89 | self.env.seed(seed)
90 |
91 | def reset(self, state=None):
92 | self.state = self.env.reset(state)
93 | self.q = self._get_q(self.state)
94 | self.dq = self._get_dq(self.state)
95 | self._compute_slack_variables()
96 | return self.state
97 |
98 | def render(self):
99 | self.env.render()
100 |
101 | def stop(self):
102 | self.env.stop()
103 |
104 | def step(self, action):
105 | alpha = np.clip(action, self.info.action_space.low, self.info.action_space.high)
106 | alpha = alpha * self.acc_max
107 |
108 | self.state, reward, absorb, info = self.env.step(alpha)
109 | return self.state.copy(), reward, absorb, info
110 |
111 | def acc_truncation(self, dq, ddq):
112 | acc_u = np.maximum(np.minimum(self.acc_max, -self.K_q * (dq - self.vel_max)), -self.acc_max)
113 | acc_l = np.minimum(np.maximum(-self.acc_max, -self.K_q * (dq + self.vel_max)), self.acc_max)
114 | ddq = np.clip(ddq, acc_l, acc_u)
115 | return ddq
116 |
117 | def step_action_function(self, sim_state, alpha):
118 | self.state = self.env._create_observation(sim_state)
119 | self.q = self._get_q(self.state)
120 | self.dq = self._get_dq(self.state)
121 |
122 | Jc, psi = self._construct_Jc_psi(self.q, self.s, self.dq)
123 | Jc_inv, Nc = pinv_null(Jc)
124 |
125 | self._act_a = np.zeros(self.dims['q'] + self.dims['g'])
126 | self._act_b = np.concatenate([alpha, np.zeros(self.dims['g'])])
127 | self._act_err = self._compute_error_correction(self.q, self.dq, self.s, Jc_inv)
128 | ddq_ds = self._act_a + self._act_b + self._act_err
129 |
130 | self.s += ddq_ds[self.dims['q']:(self.dims['q'] + self.dims['g'])] * self.time_step
131 |
132 | ddq = self.acc_truncation(self.dq, ddq_ds[:self.dims['q']])
133 | ctrl_action = self.acc_to_ctrl_action(ddq)
134 | return ctrl_action
135 |
136 | @property
137 | def info(self):
138 | return self._mdp_info
139 |
140 | def _compute_slack_variables(self):
141 | self.s = None
142 | if self.dims['g'] > 0:
143 | s_2 = np.maximum(-2 * self.g.fun(self.q, self.dq, origin_constr=False), 0)
144 | self.s = np.sqrt(s_2)
145 |
146 | def _construct_Jc_psi(self, q, s, dq):
147 | Jc = np.zeros((self.dims['f'] + self.dims['g'], self.dims['q'] + self.dims['g']))
148 | psi = np.zeros(self.dims['c'])
149 | if self.dims['f'] > 0:
150 | idx_0 = 0
151 | idx_1 = self.dims['f']
152 | Jc[idx_0:idx_1, :self.dims['q']] = self.f.K_J(q)
153 | psi[idx_0:idx_1] = self.f.b(q, dq)
154 | if self.dims['g'] > 0:
155 | idx_0 = self.dims['f']
156 | idx_1 = self.dims['f'] + self.dims['g']
157 | Jc[idx_0:idx_1, :self.dims['q']] = self.g.K_J(q)
158 | Jc[idx_0:idx_1, self.dims['q']:(self.dims['q'] + self.dims['g'])] = np.diag(s)
159 | psi[idx_0:idx_1] = self.g.b(q, dq)
160 | return Jc, psi
161 |
162 | def _compute_error_correction(self, q, dq, s, Jc_inv, act_null=None):
163 | q_tmp = q.copy()
164 | dq_tmp = dq.copy()
165 | s_tmp = None
166 |
167 | if self.dims['g'] > 0:
168 | s_tmp = s.copy()
169 |
170 | if act_null is not None:
171 | q_tmp += dq_tmp * self.time_step + act_null[:self.dims['q']] * self.time_step ** 2 / 2
172 | dq_tmp += act_null[:self.dims['q']] * self.time_step
173 | if self.dims['g'] > 0:
174 | s_tmp += act_null[self.dims['q']:self.dims['q'] + self.dims['g']] * self.time_step
175 |
176 | return -Jc_inv @ (self.K_c * self._compute_c(q_tmp, dq_tmp, s_tmp, origin_constr=False))
177 |
178 | def _compute_c(self, q, dq, s, origin_constr=False):
179 | c = np.zeros(self.dims['f'] + self.dims['g'])
180 | if self.dims['f'] > 0:
181 | idx_0 = 0
182 | idx_1 = self.dims['f']
183 | c[idx_0:idx_1] = self.f.fun(q, dq, origin_constr)
184 | if self.dims['g'] > 0:
185 | idx_0 = self.dims['f']
186 | idx_1 = self.dims['f'] + self.dims['g']
187 | if origin_constr:
188 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr)
189 | else:
190 | c[idx_0:idx_1] = self.g.fun(q, dq, origin_constr) + 1 / 2 * s ** 2
191 | return c
192 |
193 | def set_logger(self, logger):
194 | self._logger = logger
195 |
196 | def get_constraints_logs(self):
197 | return self.env.get_constraints_logs()
198 |
--------------------------------------------------------------------------------
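`ErrorCorrectionEnvWrapper` is abstract in three places: `acc_to_ctrl_action`, `_get_q` and `_get_dq`. A minimal sketch of a concrete subclass, assuming a hypothetical base environment whose observation is the stacked vector `[q, dq]` and whose low-level action is the joint acceleration (the class and conventions below are illustrative, not from the repo):

```python
import numpy as np
from atacom.error_correction_wrapper import ErrorCorrectionEnvWrapper


class PointMassErrorCorrection(ErrorCorrectionEnvWrapper):
    """Hypothetical wrapper for a base env observing [q, dq]."""

    def acc_to_ctrl_action(self, ddq):
        # Assumption: the base environment consumes accelerations directly.
        return ddq

    def _get_q(self, state):
        return state[:self.dims['q']]

    def _get_dq(self, state):
        return state[self.dims['q']:2 * self.dims['q']]
```

At every control step the wrapper then adds the correction term $-J_c^{+} K_c\, c(q, \dot q, s)$ from `_compute_error_correction` to the agent's acceleration and truncates the result to the viable range.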
/atacom/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .null_space_coordinate import rref, gram_schmidt, pinv_null
2 | from .plot_utils import *
--------------------------------------------------------------------------------
/atacom/utils/null_space_coordinate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse
3 | from scipy import linalg
4 | import sympy
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def pinv_null(a, rcond=None, return_rank=False):
9 | u, s, vh = linalg.svd(a, full_matrices=True, check_finite=False)
10 | M, N = u.shape[0], vh.shape[1]
11 |
12 | if rcond is None:
13 | rcond = np.finfo(s.dtype).eps * max(M, N)
14 | tol = np.amax(s) * rcond
15 |
16 | rank = np.sum(s > tol)
17 | Q = vh[rank:, :].T.conj()
18 |
19 | u = u[:, :rank]
20 | u /= s[:rank]
21 | B = np.transpose(np.conjugate(np.dot(u, vh[:rank])))
22 |
23 | if return_rank:
24 | return B, Q, rank
25 | else:
26 | return B, Q
27 |
28 |
29 | def rref_sympy(V, row_vectors=True):
30 | if not row_vectors:
31 | V = V.T
32 | sym_V = sympy.Matrix(V)
33 | basis = sympy.matrices.matrix2numpy(sym_V.rref()[0], dtype=float)
34 | if not row_vectors:
35 | return basis.T
36 | else:
37 | return basis
38 |
39 |
40 | def rref(A, row_vectors=True, tol=None):
41 | # Follow the implementation:
42 | # https://www.mathworks.com/matlabcentral/fileexchange/21583-fast-reduced-row-echelon-form
43 | V = A.copy()
44 | if not row_vectors:
45 | V = V.T
46 | m, n = V.shape
47 |
48 | if tol is None:
49 | tol = max(m, n) * np.finfo(V.dtype).eps * linalg.norm(V, np.inf)
50 |
51 | i = 0
52 | j = 0
53 | jb = list()
54 | while i < m and j < n:
55 | # Find value and index of largest element in the remainder of column j.
56 | k = np.argmax(np.abs(V[i:m, j]))
57 | k += i
58 | p = abs(V[k, j])
59 |
60 | if p <= tol:
61 | # The column is negligible, zero it out.
62 | V[i:m, j] = 0
63 | j += 1
64 | else:
65 | # Remember column index
66 | jb += [j]
67 | # Swap i-th and k-th rows.
68 | V[[i, k], j:n] = V[[k, i], j:n]
69 | # Divide the pivot row by the pivot element.
70 | Vi = V[i, j:n] / V[i, j]
71 | # Subtract multiples of the pivot row from all the other rows.
72 | V[:, j:n] = V[:, j:n] - np.outer(V[:, j], Vi)
73 | V[i, j:n] = Vi
74 | i += 1
75 | j += 1
76 |
77 | if not row_vectors:
78 | return V.T
79 | return V
80 |
81 |
82 | def gram_schmidt(A, row_vectors=True, norm_dim=None):
83 | V = A.copy()
84 | if not row_vectors:
85 | V = V.T
86 | for i, v in enumerate(V):
87 | prev_basis = V[0:i]
88 | coeff_vec = prev_basis @ v
89 | v -= coeff_vec @ prev_basis
90 | v_norm = np.linalg.norm(v)
91 | if v_norm > 1e-10:
92 | v /= v_norm
93 | else:
94 |                 v[np.abs(v) < 1e-10] = 0  # vector is numerically zero: clear residual entries
95 |
96 | if norm_dim is not None:
97 | for i, v in enumerate(V):
98 | bn = np.linalg.norm(v[:norm_dim])
99 | if bn > 0.01:
100 | V[i] = v / bn
101 |
102 | if not row_vectors:
103 | return np.array(V).T
104 | else:
105 | return np.array(V)
106 |
107 |
108 | def orthogonalization_test():
109 | m = 3
110 | n = 2
111 | V_1 = np.random.randn(m, n)
112 | z_1 = np.cross(V_1[:, 0], V_1[:, 1])
113 | print("original: \n", V_1)
114 |
115 | V_2 = gram_schmidt(V_1, row_vectors=False)
116 | z_2 = np.cross(V_2[:, 0], V_2[:, 1])
117 | print("gs: \n", V_2)
118 |
119 | V_3 = rref(V_1, row_vectors=False)
120 | z_3 = np.cross(V_3[:, 0], V_3[:, 1])
121 | print("rref: \n", V_3)
122 |
123 | V_4 = gram_schmidt(V_3, row_vectors=False)
124 | z_4 = np.cross(V_4[:, 0], V_4[:, 1])
125 | print("gs + rref:\n", V_4)
126 |
127 | V_q, v_r = np.linalg.qr(z_1[:, np.newaxis], mode='complete')
128 | V_q = V_q[:, np.where(v_r.flatten() == 0)[0]]
129 | z_q = np.cross(V_q[:, 0], V_q[:, 1])
130 | print("qr: \n", V_q)
131 |
132 | V_svd = scipy.linalg.null_space(z_1[np.newaxis, :])
133 | z_svd = np.cross(V_svd[:, 0], V_svd[:, 1])
134 | print("svd: \n", V_svd)
135 |
136 | fig = plt.figure()
137 | ax = fig.add_subplot(projection='3d')
138 | ax.quiver(0, 0, 0, V_1[0, 0], V_1[1, 0], V_1[2, 0], color='tab:blue', label='original')
139 | ax.quiver(0, 0, 0, V_1[0, 1], V_1[1, 1], V_1[2, 1], color='tab:blue')
140 | ax.quiver(0, 0, 0, z_1[0], z_1[1], z_1[2], color='tab:blue', linestyle='--')
141 |
142 | ax.quiver(0, 0, 0, V_2[0, 0], V_2[1, 0], V_2[2, 0], color='tab:orange', label='GS')
143 | ax.quiver(0, 0, 0, V_2[0, 1], V_2[1, 1], V_2[2, 1], color='tab:orange')
144 | ax.quiver(0, 0, 0, z_2[0], z_2[1], z_2[2], color='tab:orange', linestyle='--')
145 |
146 | ax.quiver(0, 0, 0, V_3[0, 0], V_3[1, 0], V_3[2, 0], color='tab:red', label='RREF')
147 | ax.quiver(0, 0, 0, V_3[0, 1], V_3[1, 1], V_3[2, 1], color='tab:red')
148 | ax.quiver(0, 0, 0, z_3[0], z_3[1], z_3[2], color='tab:red', linestyle='--')
149 |
150 | ax.quiver(0, 0, 0, V_4[0, 0], V_4[1, 0], V_4[2, 0], color='tab:pink', label='RREF + GS')
151 | ax.quiver(0, 0, 0, V_4[0, 1], V_4[1, 1], V_4[2, 1], color='tab:pink')
152 | ax.quiver(0, 0, 0, z_4[0], z_4[1], z_4[2], color='tab:pink', linestyle='--')
153 |
154 | ax.quiver(0, 0, 0, V_q[0, 0], V_q[1, 0], V_q[2, 0], color='tab:brown', label='QR')
155 | ax.quiver(0, 0, 0, V_q[0, 1], V_q[1, 1], V_q[2, 1], color='tab:brown')
156 | ax.quiver(0, 0, 0, z_q[0], z_q[1], z_q[2], color='tab:brown', linestyle='--')
157 |
158 | ax.quiver(0, 0, 0, V_svd[0, 0], V_svd[1, 0], V_svd[2, 0], color='tab:purple', label='SVD')
159 | ax.quiver(0, 0, 0, V_svd[0, 1], V_svd[1, 1], V_svd[2, 1], color='tab:purple')
160 | ax.quiver(0, 0, 0, z_svd[0], z_svd[1], z_svd[2], color='tab:purple', linestyle='--')
161 |
162 | ax.set_xlim(-3, 3)
163 | ax.set_ylim(-3, 3)
164 | ax.set_zlim(-3, 3)
165 | ax.set_xlabel("x")
166 | ax.set_ylabel("y")
167 | ax.set_zlabel("z")
168 | ax.legend()
169 | plt.show()
170 |
171 |
172 | def rref_test():
173 | for i in range(100):
174 | m, n = np.random.randint(1, 10, 2)
175 | V = np.random.randn(m, n)
176 |
177 | V_1 = rref(V)
178 | V_2 = rref_sympy(V)
179 | print(np.isclose(V_1, V_2).all())
180 |
181 |
182 | def gram_schmidt_test():
183 | for i in range(100):
184 | m = 7
185 | n = 5
186 |
187 | V = np.random.randn(m, n)
188 | V_rref = rref(V, row_vectors=False)
189 | V_orth = gram_schmidt(V_rref, row_vectors=False)
190 | print(V_orth.max(), np.abs(V_orth).min())
191 |
192 |
193 |
194 | if __name__ == '__main__':
195 | # rref_test()
196 | # gram_schmidt_test()
197 | orthogonalization_test()
198 |
199 |
200 |
--------------------------------------------------------------------------------
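As a quick sanity check (not part of the repo), `pinv_null` returns both the Moore-Penrose pseudo-inverse and an orthonormal basis of the right null space:

```python
import numpy as np
from atacom.utils import pinv_null

A = np.random.randn(2, 5)   # wide matrix, rank 2 almost surely
B, Q = pinv_null(A)

assert B.shape == (5, 2) and Q.shape == (5, 3)
assert np.allclose(A @ B @ A, A)                 # pseudo-inverse property
assert np.allclose(A @ Q, 0)                     # columns of Q lie in the null space
assert np.allclose(Q.T @ Q, np.eye(Q.shape[1]))  # orthonormal basis
```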
/atacom/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import scipy.stats as st
5 |
6 |
7 | def get_mean_and_confidence(data):
8 | """
9 | Compute the mean and 95% confidence interval
10 | Args:
11 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs).
12 | Returns:
13 |         The mean at each epoch and the lower offset of the 95% confidence interval.
14 | """
15 | mean = np.mean(data, axis=0)
16 | se = st.sem(data, axis=0)
17 | n = len(data)
18 |     interval, _ = st.t.interval(0.95, n - 1, scale=se)  # keep the lower offset; plots use mean +/- interval
19 | return mean, interval
20 |
21 |
22 | def get_data_list(file_dir, prefix):
23 | data = list()
24 | for filename in os.listdir(file_dir):
25 | if filename.startswith(prefix+'-'):
26 | data.append(np.load(os.path.join(file_dir, filename), allow_pickle=True))
27 | return np.array(data)
28 |
29 |
30 | def plot_learning_curve(log_dir, exp_list):
31 | cm = plt.get_cmap('tab10')
32 |
33 | fig1, ax1 = plt.subplots(3, 1, figsize=(10, 10))
34 | fig1.subplots_adjust(hspace=.5)
35 | fig2, ax2 = plt.subplots(3, 1, figsize=(10, 10))
36 | fig2.subplots_adjust(hspace=.5)
37 |
38 | for i, exp_suffix in enumerate(exp_list):
39 | color_i = cm(i)
40 | exp_dir = os.path.join(log_dir, exp_suffix)
41 | _, seeds, _ = next(os.walk(exp_dir))
42 |
43 | exp_J_list = get_data_list(exp_dir, "J")
44 | exp_R_list = get_data_list(exp_dir, "R")
45 | exp_E_list = get_data_list(exp_dir, "E")
46 | c_max_list = get_data_list(exp_dir, "c_max")
47 | c_avg_list = get_data_list(exp_dir, "c_avg")
48 | c_dq_max_list = get_data_list(exp_dir, "c_dq_max")
49 |
50 | mean_J, conf_J = get_mean_and_confidence(exp_J_list)
51 | mean_R, conf_R = get_mean_and_confidence(exp_R_list)
52 |
53 | ax1[0].plot(mean_J, label=exp_suffix, color=color_i)
54 | ax1[0].fill_between(np.arange(np.size(mean_J)), mean_J + conf_J, mean_J - conf_J, alpha=0.2, color=color_i)
55 | ax1[0].set_title("J")
56 | ax1[1].plot(mean_R, label=exp_suffix, color=color_i)
57 | ax1[1].fill_between(np.arange(np.size(mean_R)), mean_R + conf_R, mean_R - conf_R, alpha=0.2, color=color_i)
58 | ax1[1].set_title("R")
59 | ax1[1].legend()
60 |
61 |         if np.all(exp_E_list != None):  # elementwise check: E entries may be None
62 | mean_E, conf_E = get_mean_and_confidence(exp_E_list[:, 1:])
63 | ax1[2].plot(mean_E, label=exp_suffix, color=color_i)
64 | ax1[2].fill_between(np.arange(np.size(mean_E)), mean_E + conf_E, mean_E - conf_E, alpha=0.2, color=color_i)
65 | ax1[2].set_title("E")
66 |
67 | mean_c_max, conf_c_max = get_mean_and_confidence(c_max_list)
68 | mean_c_avg, conf_c_avg = get_mean_and_confidence(c_avg_list)
69 | mean_c_dq_max, conf_c_dq_max = get_mean_and_confidence(c_dq_max_list)
70 | ax2[0].plot(mean_c_max, label=exp_suffix, color=color_i)
71 | ax2[0].fill_between(np.arange(np.size(mean_c_max)), mean_c_max + conf_c_max,
72 | mean_c_max - conf_c_max, alpha=0.2, color=color_i)
73 | ax2[0].set_title("c_max")
74 | ax2[1].plot(mean_c_avg, label=exp_suffix, color=color_i)
75 | ax2[1].fill_between(np.arange(np.size(mean_c_avg)), mean_c_avg + conf_c_avg,
76 | mean_c_avg - conf_c_avg, alpha=0.2, color=color_i)
77 | ax2[1].set_title("c_avg")
78 | ax2[2].plot(mean_c_dq_max, label=exp_suffix, color=color_i)
79 | ax2[2].fill_between(np.arange(np.size(mean_c_dq_max)), mean_c_dq_max + conf_c_dq_max,
80 | mean_c_dq_max - conf_c_dq_max, alpha=0.2, color=color_i)
81 | ax2[2].set_title("c_dq_max")
82 | ax2[2].legend()
83 |
84 | fig1.savefig(os.path.join(log_dir, "Reward.pdf"))
85 | fig2.savefig(os.path.join(log_dir, "Constraints.pdf"))
86 | plt.show()
87 |
88 |
89 | def plot_learning_metric(log_dir, exp_list, metric, label_list, title, y_scale='linear', file_name=None):
90 | cm = plt.get_cmap('tab10')
91 |
92 | fig = plt.figure(figsize=(12, 9))
93 | ax = plt.gca()
94 |
95 | for i, exp_suffix in enumerate(exp_list):
96 | color_i = cm(i)
97 | exp_dir = os.path.join(log_dir, exp_suffix)
98 | _, seeds, _ = next(os.walk(exp_dir))
99 |
100 | metric_list = get_data_list(exp_dir, str(metric))
101 |
102 | mean_metric, conf_metric = get_mean_and_confidence(metric_list)
103 |
104 | ax.plot(mean_metric, label=label_list[i], color=color_i)
105 | ax.fill_between(np.arange(np.size(mean_metric)), mean_metric + conf_metric, mean_metric - conf_metric,
106 | alpha=0.1, color=color_i)
107 | ax.legend(fontsize=30)
108 | ax.set_yscale(y_scale)
109 |
110 | ax.set_title(title, fontsize=40)
111 | ax.tick_params('both', labelsize=30)
112 | if file_name is None:
113 | file_name = title + ".pdf"
114 | else:
115 | file_name += ".pdf"
116 | fig.savefig(os.path.join(log_dir, file_name))
117 | plt.show()
118 |
119 | def plot_learning_curve_single(log_dir, exp_name, seeds):
120 | fig1, ax1 = plt.subplots(3, 1)
121 | fig1.subplots_adjust(hspace=.5)
122 | fig2, ax2 = plt.subplots(3, 1)
123 | fig2.subplots_adjust(hspace=.5)
124 |
125 | exp_dir = os.path.join(log_dir, exp_name)
126 |
127 | for seed in seeds:
128 | postfix = "-" + str(seed) + ".npy"
129 | J = np.load(os.path.join(exp_dir, "J" + postfix))
130 | R = np.load(os.path.join(exp_dir, "R" + postfix))
131 | E = np.load(os.path.join(exp_dir, "E" + postfix), allow_pickle=True)
132 | c_max = np.load(os.path.join(exp_dir, "c_max" + postfix))
133 | c_avg = np.load(os.path.join(exp_dir, "c_avg" + postfix))
134 | c_dq_max = np.load(os.path.join(exp_dir, "c_dq_max" + postfix))
135 |
136 | ax1[0].plot(J)
137 | ax1[0].set_title("J")
138 | ax1[1].plot(R)
139 | ax1[1].set_title("R")
140 |
141 |         if np.all(E != None):
142 | ax1[2].plot(E)
143 | ax1[2].set_title("E")
144 |
145 | ax2[0].plot(c_max)
146 | ax2[0].plot(np.zeros_like(c_max), c='tab:red', lw=2)
147 | ax2[0].set_title("c_max")
148 | ax2[1].plot(c_avg)
149 | ax2[1].plot(np.zeros_like(c_avg), c='tab:red', lw=2)
150 | ax2[1].set_title("c_avg")
151 | ax2[2].plot(c_dq_max)
152 | ax2[2].plot(np.zeros_like(c_dq_max), c='tab:red', lw=2)
153 | ax2[2].set_title("c_dq_max")
154 |
155 | plt.show()
--------------------------------------------------------------------------------
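A minimal illustration of `get_mean_and_confidence` on synthetic data (values are made up, only the shapes matter):

```python
import numpy as np
from atacom.utils.plot_utils import get_mean_and_confidence

data = np.random.randn(10, 50) + np.linspace(0.0, 5.0, 50)  # 10 runs, 50 epochs
mean, conf = get_mean_and_confidence(data)

print(mean.shape, conf.shape)  # (50,) (50,)
# conf is the (negative) lower offset of the t-interval, so the shaded
# band in the plotting helpers is [mean + conf, mean - conf].
```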
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/examples/__init__.py
--------------------------------------------------------------------------------
/examples/circle_exp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from copy import deepcopy
4 | from tqdm import trange
5 | import pandas as pd
6 |
7 | from mushroom_rl.algorithms.actor_critic import PPO, TRPO, DDPG, TD3, SAC
8 | from mushroom_rl.core import Core, Logger
9 | from mushroom_rl.policy import GaussianTorchPolicy, OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy
10 | from mushroom_rl.utils.dataset import compute_J, parse_dataset
11 | from atacom.environments.circular_motion import CircleEnvAtacom, CircleEnvTerminated, CircleEnvErrorCorrection
12 | from network import *
13 |
14 |
15 | def experiment(mdp, agent, seed, results_dir, n_epochs, n_steps, n_steps_per_fit, n_episodes_test,
16 | quiet, render, **kwargs):
17 | build_params = kwargs['build_params']
18 | logger = Logger(results_dir=results_dir, seed=seed, log_name="exp")
19 |
20 | logger.strong_line()
21 | logger.info('Experiment Algorithm: ' + type(agent).__name__)
22 | logger.info('Environment: ' + type(mdp).__name__ + " seed: " + str(seed))
23 |
24 | best_agent = deepcopy(agent)
25 |
26 | core = Core(agent, mdp)
27 |
28 | eval_params = dict(
29 | n_episodes=n_episodes_test,
30 | render=render,
31 | quiet=quiet
32 | )
33 |
34 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params)
35 | best_J, best_R, best_E, best_c_avg, best_c_max, best_c_dq_max = J, R, E, c_avg, c_max, c_dq_max
36 |
37 | logger.epoch_info(0, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
38 | logger.weak_line()
39 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
40 |
41 | for it in trange(n_epochs, leave=False, disable=quiet):
42 | core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=quiet, render=render)
43 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params)
44 |
45 | logger.epoch_info(it + 1, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
46 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
47 |
48 | if J > best_J:
49 | best_J = J
50 | best_R = R
51 | best_E = E
52 | best_c_avg = c_avg
53 | best_c_max = c_max
54 | best_c_dq_max = c_dq_max
55 | best_agent = deepcopy(agent)
56 |
57 | if it % 10 == 0:
58 | logger.log_agent(agent, epoch=it)
59 |
60 |     logger.info("Best result | J: {}, R: {}, E: {}, c_avg: {}, c_max: {}, c_dq_max: {}.".format(best_J, best_R, best_E,
61 |                                                                                                 best_c_avg, best_c_max,
62 |                                                                                                 best_c_dq_max))
63 | logger.strong_line()
64 | logger.log_agent(best_agent)
65 | best_res = {"best_J": best_J, "best_R": best_R, "best_E": best_E,
66 | "best_c_avg": best_c_avg, "best_c_max": best_c_max, "best_c_dq_max": best_c_dq_max}
67 | best_res = pd.DataFrame.from_dict(best_res, orient="index")
68 | best_res.to_csv(os.path.join(logger.path, "best_result.csv"))
69 |
70 |
71 | def compute_metrics(core, eval_params, build_params):
72 | dataset = core.evaluate(**eval_params)
73 | c_avg, c_max, c_dq_max = core.mdp.get_constraints_logs()
74 | J = np.mean(compute_J(dataset, core.mdp.info.gamma))
75 | R = np.mean(compute_J(dataset))
76 | E = None
77 | if build_params['compute_policy_entropy']:
78 | if build_params['compute_entropy_with_states']:
79 | E = core.agent.policy.entropy(parse_dataset(dataset)[0])
80 | else:
81 | E = core.agent.policy.entropy()
82 | return J, R, E, c_avg, c_max, c_dq_max
83 |
84 |
85 | def build_env(env, horizon, gamma, random_init, **kwargs):
86 | if env == 'A':
87 | mdp = CircleEnvAtacom(horizon=horizon, gamma=gamma, random_init=random_init, Kc=100)
88 | elif env == 'T':
89 | mdp = CircleEnvTerminated(horizon=horizon, gamma=gamma, random_init=random_init, tol=kwargs['termination_tol'])
90 | elif env == 'E':
91 | mdp = CircleEnvErrorCorrection(horizon=horizon, gamma=gamma, random_init=random_init)
92 | else:
93 | raise NotImplementedError
94 | return mdp
95 |
96 |
97 | def build_agent(alg, mdp_info, **kwargs):
98 | alg = alg.upper()
99 | if alg == 'PPO':
100 | agent, build_params = build_agent_PPO(mdp_info, **kwargs)
101 | elif alg == 'TRPO':
102 | agent, build_params = build_agent_TRPO(mdp_info, **kwargs)
103 | elif alg == 'DDPG':
104 | agent, build_params = build_agent_DDPG(mdp_info, **kwargs)
105 | elif alg == 'TD3':
106 | agent, build_params = build_agent_TD3(mdp_info, **kwargs)
107 | elif alg == 'SAC':
108 | agent, build_params = build_agent_SAC(mdp_info, **kwargs)
109 | else:
110 | raise NotImplementedError
111 | return agent, build_params
112 |
113 |
114 | def build_agent_PPO(mdp_info, actor_lr, critic_lr, n_features, batch_size, eps_ppo, lam, ent_coeff, **kwargs):
115 | policy_params = dict(
116 | std_0=0.5,
117 | n_features=n_features,
118 | use_cuda=torch.cuda.is_available()
119 | )
120 | policy = GaussianTorchPolicy(PPONetwork,
121 | mdp_info.observation_space.shape,
122 | mdp_info.action_space.shape,
123 | **policy_params)
124 |
125 | critic_params = dict(network=PPONetwork,
126 | optimizer={'class': optim.Adam,
127 | 'params': {'lr': critic_lr}},
128 | loss=F.mse_loss,
129 | n_features=n_features,
130 | batch_size=batch_size,
131 | input_shape=mdp_info.observation_space.shape,
132 | output_shape=(1,))
133 |
134 | ppo_params = dict(actor_optimizer={'class': optim.Adam,
135 | 'params': {'lr': actor_lr}},
136 | n_epochs_policy=4,
137 | batch_size=batch_size,
138 | eps_ppo=eps_ppo,
139 | lam=lam,
140 | ent_coeff=ent_coeff,
141 | critic_params=critic_params)
142 |
143 | build_params = dict(compute_entropy_with_states=False,
144 | compute_policy_entropy=True)
145 |
146 | return PPO(mdp_info, policy, **ppo_params), build_params
147 |
148 |
149 | def build_agent_TRPO(mdp_info, critic_lr, n_features, batch_size, lam, ent_coeff,
150 | max_kl, n_epochs_line_search, n_epochs_cg, cg_damping, cg_residual_tol, critic_fit_params,
151 | **kwargs):
152 | policy_params = dict(
153 | std_0=0.5,
154 | n_features=n_features,
155 | use_cuda=torch.cuda.is_available()
156 | )
157 |
158 | critic_params = dict(
159 | network=TRPONetwork,
160 | optimizer={'class': optim.Adam,
161 | 'params': {'lr': critic_lr}},
162 | loss=F.mse_loss,
163 | n_features=n_features,
164 | batch_size=batch_size,
165 | input_shape=mdp_info.observation_space.shape,
166 | output_shape=(1,))
167 |
168 | trpo_params = dict(
169 | ent_coeff=ent_coeff,
170 | max_kl=max_kl,
171 | lam=lam,
172 | n_epochs_line_search=n_epochs_line_search,
173 | n_epochs_cg=n_epochs_cg,
174 | cg_damping=cg_damping,
175 | cg_residual_tol=cg_residual_tol,
176 | critic_fit_params=critic_fit_params)
177 |
178 | policy = GaussianTorchPolicy(TRPONetwork,
179 | mdp_info.observation_space.shape,
180 | mdp_info.action_space.shape,
181 | **policy_params)
182 |
183 | build_params = dict(compute_entropy_with_states=False,
184 | compute_policy_entropy=True)
185 |
186 | return TRPO(mdp_info, policy, critic_params, **trpo_params), build_params
187 |
188 |
189 | def build_agent_DDPG(mdp_info, actor_lr, critic_lr, n_features, batch_size,
190 | initial_replay_size, max_replay_size, tau, **kwargs):
191 | policy_params = dict(
192 | sigma=np.ones(1) * .2,
193 | theta=0.15,
194 | dt=1e-2)
195 |
196 | actor_params = dict(
197 | network=DDPGActorNetwork,
198 | input_shape=mdp_info.observation_space.shape,
199 | output_shape=mdp_info.action_space.shape,
200 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
201 | n_features=n_features,
202 | use_cuda=torch.cuda.is_available())
203 |
204 | actor_optimizer = {
205 | 'class': optim.Adam,
206 | 'params': {'lr': actor_lr}}
207 |
208 | critic_params = dict(
209 | network=DDPGCriticNetwork,
210 | optimizer={'class': optim.Adam,
211 | 'params': {'lr': critic_lr}},
212 | loss=F.mse_loss,
213 | n_features=n_features,
214 | batch_size=batch_size,
215 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
216 | action_shape=mdp_info.action_space.shape,
217 | output_shape=(1,),
218 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
219 | use_cuda=torch.cuda.is_available())
220 |
221 | alg_params = dict(
222 | initial_replay_size=initial_replay_size,
223 | max_replay_size=max_replay_size,
224 | batch_size=batch_size,
225 | tau=tau)
226 |
227 | build_params = dict(compute_entropy_with_states=False,
228 | compute_policy_entropy=False)
229 |
230 | return DDPG(mdp_info, OrnsteinUhlenbeckPolicy, policy_params, actor_params, actor_optimizer, critic_params,
231 | **alg_params), build_params
232 |
233 |
234 | def build_agent_TD3(mdp_info, actor_lr, critic_lr, n_features, batch_size,
235 | initial_replay_size, max_replay_size, tau, sigma, **kwargs):
236 | policy_params = dict(
237 | sigma=np.eye(mdp_info.action_space.shape[0]) * sigma,
238 | low=mdp_info.action_space.low,
239 | high=mdp_info.action_space.high)
240 |
241 | actor_params = dict(
242 | network=TD3ActorNetwork,
243 | input_shape=mdp_info.observation_space.shape,
244 | output_shape=mdp_info.action_space.shape,
245 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
246 | n_features=n_features,
247 | use_cuda=torch.cuda.is_available())
248 |
249 | actor_optimizer = {
250 | 'class': optim.Adam,
251 | 'params': {'lr': actor_lr}}
252 |
253 | critic_params = dict(
254 | network=TD3CriticNetwork,
255 | optimizer={'class': optim.Adam,
256 | 'params': {'lr': critic_lr}},
257 | loss=F.mse_loss,
258 | n_features=n_features,
259 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
260 | action_shape=mdp_info.action_space.shape,
261 | output_shape=(1,),
262 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
263 | use_cuda=torch.cuda.is_available())
264 |
265 | alg_params = dict(
266 | initial_replay_size=initial_replay_size,
267 | max_replay_size=max_replay_size,
268 | batch_size=batch_size,
269 | tau=tau)
270 |
271 | build_params = dict(compute_entropy_with_states=False,
272 | compute_policy_entropy=False)
273 |
274 | return TD3(mdp_info, ClippedGaussianPolicy, policy_params, actor_params, actor_optimizer, critic_params,
275 | **alg_params), build_params
276 |
277 |
278 | def build_agent_SAC(mdp_info, actor_lr, critic_lr, n_features, batch_size,
279 | initial_replay_size, max_replay_size, tau,
280 | warmup_transitions, lr_alpha, target_entropy,
281 | **kwargs):
282 | actor_mu_params = dict(network=SACActorNetwork,
283 | input_shape=mdp_info.observation_space.shape,
284 | output_shape=mdp_info.action_space.shape,
285 | n_features=n_features,
286 | use_cuda=torch.cuda.is_available())
287 | actor_sigma_params = dict(network=SACActorNetwork,
288 | input_shape=mdp_info.observation_space.shape,
289 | output_shape=mdp_info.action_space.shape,
290 | n_features=n_features,
291 | use_cuda=torch.cuda.is_available())
292 |
293 | actor_optimizer = {'class': optim.Adam,
294 | 'params': {'lr': actor_lr}}
295 | critic_params = dict(network=SACCriticNetwork,
296 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
297 | optimizer={'class': optim.Adam,
298 | 'params': {'lr': critic_lr}},
299 | loss=F.mse_loss,
300 | n_features=n_features,
301 | output_shape=(1,),
302 | use_cuda=torch.cuda.is_available())
303 |
304 | alg_params = dict(initial_replay_size=initial_replay_size,
305 | max_replay_size=max_replay_size,
306 | batch_size=batch_size,
307 | warmup_transitions=warmup_transitions,
308 | tau=tau,
309 | lr_alpha=lr_alpha,
310 | critic_fit_params=None,
311 | target_entropy=target_entropy)
312 |
313 | build_params = dict(compute_entropy_with_states=True,
314 | compute_policy_entropy=True)
315 |
316 | return SAC(mdp_info, actor_mu_params, actor_sigma_params, actor_optimizer, critic_params,
317 | **alg_params), build_params
318 |
319 |
320 | def default_params():
321 | defaults = dict(env='A', alg='TRPO', seed=1,
322 | horizon=500, gamma=0.99, random_init=False, quiet=False, termination_tol=0.4, render=False,
323 | results_dir='../logs/circular_motion')
324 | training_params = dict(n_epochs=50, n_steps=5000, n_steps_per_fit=1000, n_episodes_test=25)
325 |
326 | network_params = dict(actor_lr=3e-4, critic_lr=3e-4, n_features=[32, 32], batch_size=64)
327 |
328 | trpo_ppo_params = dict(lam=0.95, ent_coeff=5e-5)
329 | ppo_params = dict(eps_ppo=0.1)
330 | trpo_params = dict(max_kl=1e-2, n_epochs_line_search=10, n_epochs_cg=10, cg_damping=1e-2, cg_residual_tol=1e-10,
331 | critic_fit_params=None)
332 |
333 | ddpg_td3_sac_params = dict(initial_replay_size=5000, max_replay_size=200000, tau=1e-3)
334 | td3_params = dict(sigma=1.0)
335 |
336 | sac_params = dict(warmup_transitions=10000, lr_alpha=3e-3, target_entropy=-6)
337 |
338 | defaults.update(training_params)
339 | defaults.update(network_params)
340 | defaults.update(trpo_ppo_params)
341 | defaults.update(ppo_params)
342 | defaults.update(trpo_params)
343 | defaults.update(ddpg_td3_sac_params)
344 | defaults.update(td3_params)
345 | defaults.update(sac_params)
346 | return defaults
347 |
348 |
349 | def parse_args():
350 | parser = argparse.ArgumentParser()
351 |
352 | arg_test = parser.add_argument_group('Experiment')
353 | arg_test.add_argument('--env', choices=['A', 'T', 'E'], help="Environment argument ['A', 'T', 'E']: "
354 | "'A' for ATACOM, "
355 | "'T' for TerminatedCircle, "
356 | "'E' for ErrorCorrection")
357 | arg_test.add_argument('--alg', choices=['TRPO', 'trpo', 'PPO', 'ppo', 'DDPG', 'ddpg', 'TD3', 'td3', 'SAC', 'sac'])
358 |
359 | arg_test.add_argument('--horizon', type=int)
360 | arg_test.add_argument('--gamma', type=float)
361 | arg_test.add_argument('--random-init', action="store_true")
362 | arg_test.add_argument('--termination-tol', type=float)
363 | arg_test.add_argument('--quiet', action="store_true")
364 | arg_test.add_argument('--render', action="store_true")
365 |
366 |     # training parameters
367 | arg_test.add_argument('--n-epochs', type=int)
368 | arg_test.add_argument('--n-steps', type=int)
369 | arg_test.add_argument('--n-steps-per-fit', type=int)
370 | arg_test.add_argument('--n-episodes-test', type=int)
371 |
372 |     # network parameters
373 | arg_test.add_argument('--actor-lr', type=float)
374 | arg_test.add_argument('--critic-lr', type=float)
375 | arg_test.add_argument('--n-features', nargs='+')
376 | arg_test.add_argument('--batch-size', type=int)
377 |
378 |     # TRPO / PPO parameters
379 | arg_test.add_argument('--lam', type=float)
380 | arg_test.add_argument('--ent-coeff', type=float)
381 |
382 | # PPO parameters
383 | arg_test.add_argument('--eps-ppo', type=float)
384 |
385 | # TRPO parameters
386 | arg_test.add_argument('--max-kl', type=float)
387 | arg_test.add_argument('--n-epochs-line-search', type=int)
388 | arg_test.add_argument('--n-epochs-cg', type=int)
389 | arg_test.add_argument('--cg-damping', type=float)
390 | arg_test.add_argument('--cg-residual-tol', type=float)
391 |
392 | # DDPG TD3 parameters
393 | arg_test.add_argument('--initial-replay-size', type=int)
394 | arg_test.add_argument('--max-replay-size', type=int)
395 | arg_test.add_argument('--tau', type=float)
396 |
397 | # TD3 parameters
398 | arg_test.add_argument('--sigma', type=float)
399 |
400 | # SAC parameters
401 | arg_test.add_argument('--warmup-transitions', type=int)
402 | arg_test.add_argument('--lr-alpha', type=float)
403 | arg_test.add_argument('--target-entropy', type=float)
404 |
405 | arg_default = parser.add_argument_group('Default')
406 | arg_default.add_argument('--seed', type=int)
407 | arg_default.add_argument('--results-dir', type=str)
408 |
409 | parser.set_defaults(**default_params())
410 | args = parser.parse_args()
411 | return vars(args)
412 |
413 |
414 | if __name__ == '__main__':
415 | args_ = parse_args()
416 | env_ = build_env(**args_)
417 | agent_, build_params_ = build_agent(mdp_info=env_.info, **args_)
418 | experiment(env_, agent_, build_params=build_params_, **args_)
419 |
--------------------------------------------------------------------------------
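The script can also be driven programmatically by reusing `default_params` instead of the command line; a hypothetical sketch mirroring the `__main__` block:

```python
from circle_exp import default_params, build_env, build_agent, experiment

# Start from the command-line defaults, then override selected fields.
args = default_params()
args.update(env='A', alg='SAC', n_epochs=5, render=False)

mdp = build_env(**args)
agent, build_params = build_agent(mdp_info=mdp.info, **args)
experiment(mdp, agent, build_params=build_params, **args)
```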
/examples/iiwa_air_hockey_exp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import pandas as pd
5 | from mushroom_rl.algorithms.actor_critic import PPO, TRPO, DDPG, TD3, SAC
6 | from mushroom_rl.core import Core, Logger
7 | from mushroom_rl.policy import GaussianTorchPolicy, OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy
8 | from mushroom_rl.utils.dataset import compute_J, parse_dataset
9 | from mushroom_rl.utils.preprocessors import MinMaxPreprocessor
10 | from tqdm import trange
11 |
12 | from atacom.environments.iiwa_air_hockey import AirHockeyIiwaAtacom, AirHockeyIiwaRmp
13 | from network import *
14 |
15 |
16 | def experiment(seed, results_dir, n_epochs, n_steps, n_steps_per_fit, n_episodes_test,
17 | quiet, **kwargs):
18 | mdp = build_env(**kwargs)
19 |
20 | agent, build_params = build_agent(mdp_info=mdp.info, **kwargs)
21 |
22 | logger = Logger(results_dir=results_dir, seed=seed, log_name='exp')
23 |
24 | logger.strong_line()
25 | logger.info('Experiment Algorithm: ' + type(agent).__name__)
26 | if hasattr(mdp, "env"):
27 | logger.info('Environment: ' + type(mdp.env).__name__ + " seed: " + str(seed))
28 | else:
29 | logger.info('Environment: ' + type(mdp).__name__ + " seed: " + str(seed))
30 |
31 |     # state normalization preprocessor
32 | prepro = MinMaxPreprocessor(mdp_info=mdp.info)
33 |
34 | core = Core(agent, mdp, preprocessors=[prepro])
35 |
36 | eval_params = dict(
37 | n_episodes=n_episodes_test,
38 | render=False,
39 | quiet=quiet
40 | )
41 |
42 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params)
43 | best_J, best_R, best_E, best_c_avg, best_c_max, best_c_dq_max = J, R, E, c_avg, c_max, c_dq_max
44 |
45 | logger.epoch_info(0, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
46 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
47 | logger.log_agent(agent)
48 | prepro.save(os.path.join(logger.path, "state_normalization" + logger._suffix + ".msh"))
49 |
50 | for it in trange(n_epochs, leave=False, disable=quiet):
51 | core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=quiet)
52 | J, R, E, c_avg, c_max, c_dq_max = compute_metrics(core, eval_params, build_params)
53 |
54 | logger.epoch_info(it + 1, J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
55 | logger.log_numpy(J=J, R=R, E=E, c_avg=c_avg, c_max=c_max, c_dq_max=c_dq_max)
56 |
57 | if J > best_J:
58 | best_J = J
59 | best_R = R
60 | best_E = E
61 | best_c_avg = c_avg
62 | best_c_max = c_max
63 | best_c_dq_max = c_dq_max
64 |
65 | logger.log_agent(agent)
66 | prepro.save(os.path.join(logger.path, "state_normalization" + logger._suffix + ".msh"))
67 |
68 |     logger.info("Best result | J: {}, R: {}, E: {}, c_avg: {}, c_max: {}, c_dq_max: {}.".format(best_J, best_R, best_E,
69 |                                                                                                 best_c_avg, best_c_max,
70 |                                                                                                 best_c_dq_max))
71 | logger.strong_line()
72 | best_res = {"best_J": best_J, "best_R": best_R, "best_E": best_E,
73 | "best_c_avg": best_c_avg, "best_c_max": best_c_max, "best_c_dq_max": best_c_dq_max}
74 | best_res = pd.DataFrame.from_dict(best_res, orient="index")
75 | best_res.to_csv(os.path.join(logger.path, "best_result.csv"))
76 |
77 |
78 | def compute_metrics(core, eval_params, build_params):
79 | dataset = core.evaluate(**eval_params)
80 | c_avg, c_max, c_dq_max = 0., 0., 0.
81 | if hasattr(core.mdp, "get_constraints_logs"):
82 | c_avg, c_max, c_dq_max = core.mdp.get_constraints_logs()
83 | J = np.mean(compute_J(dataset, core.mdp.info.gamma))
84 | R = np.mean(compute_J(dataset))
85 | E = None
86 | if build_params['compute_policy_entropy']:
87 | if build_params['compute_entropy_with_states']:
88 | E = core.agent.policy.entropy(parse_dataset(dataset)[0])
89 | else:
90 | E = core.agent.policy.entropy()
91 | return J, R, E, c_avg, c_max, c_dq_max
92 |
93 |
94 | def build_env(**kwargs):
95 | env = kwargs['env']
96 | gamma = kwargs['gamma']
97 | random_init = kwargs['random_init']
98 | debug_gui = kwargs['debug_gui']
99 |
100 | iiwa_hit_horizon = kwargs['iiwa_hit_horizon']
101 | iiwa_hit_time_step = kwargs['iiwa_hit_time_step']
102 | iiwa_hit_n_intermediate_steps = kwargs['iiwa_hit_n_intermediate_steps']
103 |
104 | if env == '7H':
105 | mdp = AirHockeyIiwaAtacom(task='H', horizon=iiwa_hit_horizon, gamma=gamma, random_init=random_init,
106 | timestep=iiwa_hit_time_step, n_intermediate_steps=iiwa_hit_n_intermediate_steps,
107 | debug_gui=debug_gui)
108 | elif env == 'RMP':
109 | mdp = AirHockeyIiwaRmp(task='H', horizon=120, gamma=gamma, random_init=random_init,
110 | timestep=1 / 240., n_intermediate_steps=4,
111 | debug_gui=debug_gui)
112 | else:
113 | raise NotImplementedError
114 | return mdp
115 |
116 |
117 | def build_agent(alg, mdp_info, **kwargs):
118 | if isinstance(kwargs['n_features'], str):
119 | kwargs['n_features'] = kwargs['n_features'].split(' ')
120 |
121 | alg = alg.upper()
122 | if alg == 'PPO':
123 | agent, build_params = build_agent_PPO(mdp_info, **kwargs)
124 | elif alg == 'TRPO':
125 | agent, build_params = build_agent_TRPO(mdp_info, **kwargs)
126 | elif alg == 'DDPG':
127 | agent, build_params = build_agent_DDPG(mdp_info, **kwargs)
128 | elif alg == 'TD3':
129 | agent, build_params = build_agent_TD3(mdp_info, **kwargs)
130 | elif alg == 'SAC':
131 | agent, build_params = build_agent_SAC(mdp_info, **kwargs)
132 | else:
133 | raise NotImplementedError
134 | return agent, build_params
135 |
136 |
137 | def build_agent_PPO(mdp_info, actor_lr, critic_lr, n_features, batch_size, eps_ppo, lam, ent_coeff, use_cuda, **kwargs):
138 | policy_params = dict(
139 | std_0=0.5,
140 | n_features=n_features,
141 | use_cuda=use_cuda
142 | )
143 | policy = GaussianTorchPolicy(PPONetwork,
144 | mdp_info.observation_space.shape,
145 | mdp_info.action_space.shape,
146 | **policy_params)
147 |
148 | critic_params = dict(network=PPONetwork,
149 | optimizer={'class': optim.Adam,
150 | 'params': {'lr': critic_lr}},
151 | loss=F.mse_loss,
152 | n_features=n_features,
153 | batch_size=batch_size,
154 | input_shape=mdp_info.observation_space.shape,
155 | output_shape=(1,))
156 |
157 | ppo_params = dict(actor_optimizer={'class': optim.Adam,
158 | 'params': {'lr': actor_lr}},
159 | n_epochs_policy=4,
160 | batch_size=batch_size,
161 | eps_ppo=eps_ppo,
162 | lam=lam,
163 | ent_coeff=ent_coeff,
164 | critic_params=critic_params)
165 |
166 | build_params = dict(compute_entropy_with_states=False,
167 | compute_policy_entropy=True)
168 |
169 | return PPO(mdp_info, policy, **ppo_params), build_params
170 |
171 |
172 | def build_agent_TRPO(mdp_info, critic_lr, n_features, batch_size, lam, ent_coeff, use_cuda,
173 | max_kl, n_epochs_line_search, n_epochs_cg, cg_damping, cg_residual_tol, critic_fit_params,
174 | **kwargs):
175 | policy_params = dict(
176 | std_0=0.5,
177 | n_features=n_features,
178 | use_cuda=use_cuda
179 | )
180 |
181 | critic_params = dict(
182 | network=TRPONetwork,
183 | optimizer={'class': optim.Adam,
184 | 'params': {'lr': critic_lr}},
185 | loss=F.mse_loss,
186 | n_features=n_features,
187 | batch_size=batch_size,
188 | input_shape=mdp_info.observation_space.shape,
189 | output_shape=(1,))
190 |
191 | trpo_params = dict(
192 | ent_coeff=ent_coeff,
193 | max_kl=max_kl,
194 | lam=lam,
195 | n_epochs_line_search=n_epochs_line_search,
196 | n_epochs_cg=n_epochs_cg,
197 | cg_damping=cg_damping,
198 | cg_residual_tol=cg_residual_tol,
199 | critic_fit_params=critic_fit_params)
200 |
201 | policy = GaussianTorchPolicy(TRPONetwork,
202 | mdp_info.observation_space.shape,
203 | mdp_info.action_space.shape,
204 | **policy_params)
205 |
206 | build_params = dict(compute_entropy_with_states=False,
207 | compute_policy_entropy=True)
208 |
209 | return TRPO(mdp_info, policy, critic_params, **trpo_params), build_params
210 |
211 |
212 | def build_agent_DDPG(mdp_info, actor_lr, critic_lr, n_features, batch_size,
213 | initial_replay_size, max_replay_size, tau, use_cuda, **kwargs):
214 | policy_params = dict(
215 | sigma=np.ones(1) * .2,
216 | theta=0.15,
217 | dt=1e-2)
218 |
219 | actor_params = dict(
220 | network=DDPGActorNetwork,
221 | input_shape=mdp_info.observation_space.shape,
222 | output_shape=mdp_info.action_space.shape,
223 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
224 | n_features=n_features,
225 | use_cuda=use_cuda)
226 |
227 | actor_optimizer = {
228 | 'class': optim.Adam,
229 | 'params': {'lr': actor_lr}}
230 |
231 | critic_params = dict(
232 | network=DDPGCriticNetwork,
233 | optimizer={'class': optim.Adam,
234 | 'params': {'lr': critic_lr}},
235 | loss=F.mse_loss,
236 | n_features=n_features,
237 | batch_size=batch_size,
238 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
239 | action_shape=mdp_info.action_space.shape,
240 | output_shape=(1,),
241 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
242 | use_cuda=use_cuda)
243 |
244 | alg_params = dict(
245 | initial_replay_size=initial_replay_size,
246 | max_replay_size=max_replay_size,
247 | batch_size=batch_size,
248 | tau=tau)
249 |
250 | build_params = dict(compute_entropy_with_states=False,
251 | compute_policy_entropy=False)
252 |
253 | return DDPG(mdp_info, OrnsteinUhlenbeckPolicy, policy_params, actor_params, actor_optimizer, critic_params,
254 | **alg_params), build_params
255 |
256 |
257 | def build_agent_TD3(mdp_info, actor_lr, critic_lr, n_features, batch_size, use_cuda,
258 | initial_replay_size, max_replay_size, tau, sigma, **kwargs):
259 | policy_params = dict(
260 | sigma=np.eye(mdp_info.action_space.shape[0]) * sigma,
261 | low=mdp_info.action_space.low,
262 | high=mdp_info.action_space.high)
263 |
264 | actor_params = dict(
265 | network=TD3ActorNetwork,
266 | input_shape=mdp_info.observation_space.shape,
267 | output_shape=mdp_info.action_space.shape,
268 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
269 | n_features=n_features,
270 | use_cuda=use_cuda)
271 |
272 | actor_optimizer = {
273 | 'class': optim.Adam,
274 | 'params': {'lr': actor_lr}}
275 |
276 | critic_params = dict(
277 | network=TD3CriticNetwork,
278 | optimizer={'class': optim.Adam,
279 | 'params': {'lr': critic_lr}},
280 | loss=F.mse_loss,
281 | n_features=n_features,
282 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
283 | action_shape=mdp_info.action_space.shape,
284 | output_shape=(1,),
285 | action_scaling=(mdp_info.action_space.high - mdp_info.action_space.low) / 2,
286 | use_cuda=use_cuda)
287 |
288 | alg_params = dict(
289 | initial_replay_size=initial_replay_size,
290 | max_replay_size=max_replay_size,
291 | batch_size=batch_size,
292 | tau=tau)
293 |
294 | build_params = dict(compute_entropy_with_states=False,
295 | compute_policy_entropy=False)
296 |
297 | return TD3(mdp_info, ClippedGaussianPolicy, policy_params, actor_params, actor_optimizer, critic_params,
298 | **alg_params), build_params
299 |
300 |
301 | def build_agent_SAC(mdp_info, actor_lr, critic_lr, n_features, batch_size,
302 | initial_replay_size, max_replay_size, tau,
303 | warmup_transitions, lr_alpha, target_entropy, use_cuda,
304 | **kwargs):
305 | actor_mu_params = dict(network=SACActorNetwork,
306 | input_shape=mdp_info.observation_space.shape,
307 | output_shape=mdp_info.action_space.shape,
308 | n_features=n_features,
309 | use_cuda=use_cuda)
310 | actor_sigma_params = dict(network=SACActorNetwork,
311 | input_shape=mdp_info.observation_space.shape,
312 | output_shape=mdp_info.action_space.shape,
313 | n_features=n_features,
314 | use_cuda=use_cuda)
315 |
316 | actor_optimizer = {'class': optim.Adam,
317 | 'params': {'lr': actor_lr}}
318 | critic_params = dict(network=SACCriticNetwork,
319 | input_shape=(mdp_info.observation_space.shape[0] + mdp_info.action_space.shape[0],),
320 | optimizer={'class': optim.Adam,
321 | 'params': {'lr': critic_lr}},
322 | loss=F.mse_loss,
323 | n_features=n_features,
324 | output_shape=(1,),
325 | use_cuda=use_cuda)
326 |
327 | alg_params = dict(initial_replay_size=initial_replay_size,
328 | max_replay_size=max_replay_size,
329 | batch_size=batch_size,
330 | warmup_transitions=warmup_transitions,
331 | tau=tau,
332 | lr_alpha=lr_alpha,
333 | critic_fit_params=None,
334 | target_entropy=target_entropy)
335 |
336 | build_params = dict(compute_entropy_with_states=True,
337 | compute_policy_entropy=True)
338 |
339 | return SAC(mdp_info, actor_mu_params, actor_sigma_params, actor_optimizer, critic_params,
340 | **alg_params), build_params
341 |
342 |
343 | def default_params():
344 | defaults = dict(env='RMP', alg='SAC', seed=1,
345 | gamma=0.99, random_init=False, debug_gui=True, quiet=False, use_cuda=False,
346 | iiwa_hit_time_step=1 / 240., iiwa_hit_n_intermediate_steps=4, iiwa_hit_horizon=120,
347 | results_dir='../logs/iiwa_hit')
348 | training_params = dict(n_epochs=100, n_steps=3000, n_steps_per_fit=600, n_episodes_test=25)
349 |
350 | network_params = dict(actor_lr=3e-4, critic_lr=3e-4, n_features=[64, 64], batch_size=64)
351 |
352 | trpo_ppo_params = dict(lam=0.95, ent_coeff=5e-5)
353 | ppo_params = dict(eps_ppo=0.1)
354 | trpo_params = dict(max_kl=1e-2, n_epochs_line_search=10, n_epochs_cg=10, cg_damping=1e-2, cg_residual_tol=1e-10,
355 | critic_fit_params=None)
356 |
357 | ddpg_td3_sac_params = dict(initial_replay_size=5000, max_replay_size=200000, tau=1e-3)
358 | td3_params = dict(sigma=0.25)
359 |
360 | sac_params = dict(warmup_transitions=10000, lr_alpha=3e-4, target_entropy=-6)
361 |
362 | defaults.update(training_params)
363 | defaults.update(network_params)
364 | defaults.update(trpo_ppo_params)
365 | defaults.update(ppo_params)
366 | defaults.update(trpo_params)
367 | defaults.update(ddpg_td3_sac_params)
368 | defaults.update(td3_params)
369 | defaults.update(sac_params)
370 | return defaults
371 |
372 |
373 | def parse_args():
374 | parser = argparse.ArgumentParser()
375 |
376 | arg_test = parser.add_argument_group('Experiment')
377 | arg_test.add_argument('--env', choices=['7H', 'RMP'],
378 | help="Environment argument ['7H', 'RMP']: "
379 | "7H for Iiwa Hitting using ATACOM, "
380 | "RMP for Iiwa Hitting using Riemannian Motion Policies.")
381 | arg_test.add_argument('--alg', choices=['TRPO', 'trpo', 'PPO', 'ppo', 'DDPG', 'ddpg', 'TD3', 'td3', 'SAC', 'sac'])
382 |
383 | arg_test.add_argument('--gamma', type=float)
384 | arg_test.add_argument('--random-init', action="store_true")
385 | arg_test.add_argument('--termination-tol', type=float)
386 | arg_test.add_argument('--debug-gui', action="store_true")
387 | arg_test.add_argument('--quiet', action="store_true")
388 | arg_test.add_argument('--use-cuda', action="store_true")
389 |
390 | arg_test.add_argument('--iiwa-hit-horizon', type=int)
391 | arg_test.add_argument('--iiwa-hit-time-step', type=float)
392 | arg_test.add_argument('--iiwa-hit-n-intermediate-steps', type=int)
393 |
394 |     # training parameters
395 | arg_test.add_argument('--n-epochs', type=int)
396 | arg_test.add_argument('--n-steps', type=int)
397 | arg_test.add_argument('--n-steps-per-fit', type=int)
398 | arg_test.add_argument('--n-episodes-test', type=int)
399 |
400 |     # network parameters
401 | arg_test.add_argument('--actor-lr', type=float)
402 | arg_test.add_argument('--critic-lr', type=float)
403 | arg_test.add_argument('--n-features', nargs='+')
404 | arg_test.add_argument('--batch-size', type=int)
405 |
406 |     # TRPO / PPO parameters
407 | arg_test.add_argument('--lam', type=float)
408 | arg_test.add_argument('--ent-coeff', type=float)
409 |
410 | # PPO parameters
411 | arg_test.add_argument('--eps-ppo', type=float)
412 |
413 | # TRPO parameters
414 | arg_test.add_argument('--max-kl', type=float)
415 | arg_test.add_argument('--n-epochs-line-search', type=int)
416 | arg_test.add_argument('--n-epochs-cg', type=int)
417 | arg_test.add_argument('--cg-damping', type=float)
418 | arg_test.add_argument('--cg-residual-tol', type=float)
419 |
420 | # DDPG TD3 parameters
421 | arg_test.add_argument('--initial-replay-size', type=int)
422 | arg_test.add_argument('--max-replay-size', type=int)
423 | arg_test.add_argument('--tau', type=float)
424 |
425 | # TD3 parameters
426 | arg_test.add_argument('--sigma', type=float)
427 |
428 | # SAC parameters
429 | arg_test.add_argument('--warmup-transitions', type=int)
430 | arg_test.add_argument('--lr-alpha', type=float)
431 | arg_test.add_argument('--target-entropy', type=float)
432 |
433 | arg_default = parser.add_argument_group('Default')
434 | arg_default.add_argument('--seed', type=int)
435 | arg_default.add_argument('--results-dir', type=str)
436 |
437 | parser.set_defaults(**default_params())
438 | args = parser.parse_args()
439 | return vars(args)
440 |
441 |
442 | if __name__ == '__main__':
443 | args_ = parse_args()
444 | experiment(**args_)
445 |
--------------------------------------------------------------------------------
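`compute_metrics` relies on mushroom_rl's `compute_J`, which sums per-episode rewards, optionally discounted. A minimal illustration with a hand-built two-step episode in the `(state, action, reward, next_state, absorbing, last)` dataset format:

```python
import numpy as np
from mushroom_rl.utils.dataset import compute_J

dataset = [
    (np.zeros(2), np.zeros(1), 1.0, np.zeros(2), 0, 0),
    (np.zeros(2), np.zeros(1), 1.0, np.zeros(2), 0, 1),  # last step of the episode
]

print(compute_J(dataset))        # undiscounted return R: [2.0]
print(compute_J(dataset, 0.99))  # discounted return J:  [1.99]
```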
/examples/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 |
8 | class PPONetwork(nn.Module):
9 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
10 | super(PPONetwork, self).__init__()
11 |
12 | n_input = input_shape[-1]
13 | n_output = output_shape[0]
14 |
15 | assert len(n_features) == 2, 'PPO network needs 2 hidden layers'
16 | n_features = list(map(int, n_features))
17 |
18 | self._h1 = nn.Linear(n_input, n_features[0])
19 | self._h2 = nn.Linear(n_features[0], n_features[1])
20 | self._h3 = nn.Linear(n_features[1], n_output)
21 |
22 | nn.init.xavier_uniform_(self._h1.weight,
23 | gain=nn.init.calculate_gain('relu'))
24 | nn.init.xavier_uniform_(self._h2.weight,
25 | gain=nn.init.calculate_gain('relu'))
26 | nn.init.xavier_uniform_(self._h3.weight,
27 | gain=nn.init.calculate_gain('linear'))
28 |
29 | def forward(self, state, **kwargs):
30 | features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
31 | features2 = F.relu(self._h2(features1))
32 | a = self._h3(features2)
33 |
34 | return a
35 |
36 |
37 | class TRPONetwork(nn.Module):
38 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
39 | super(TRPONetwork, self).__init__()
40 |
41 | n_input = input_shape[-1]
42 | n_output = output_shape[0]
43 |
44 | assert len(n_features) == 2, 'TRPO network needs 2 hidden layers'
45 | n_features = list(map(int, n_features))
46 |
47 | self._h1 = nn.Linear(n_input, n_features[0])
48 | self._h2 = nn.Linear(n_features[0], n_features[1])
49 | self._h3 = nn.Linear(n_features[1], n_output)
50 |
51 | nn.init.xavier_uniform_(self._h1.weight,
52 | gain=nn.init.calculate_gain('relu'))
53 | nn.init.xavier_uniform_(self._h2.weight,
54 | gain=nn.init.calculate_gain('relu'))
55 | nn.init.xavier_uniform_(self._h3.weight,
56 | gain=nn.init.calculate_gain('linear'))
57 |
58 | def forward(self, state, **kwargs):
59 | features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
60 | features2 = F.relu(self._h2(features1))
61 | a = self._h3(features2)
62 |
63 | return a
64 |
65 |
66 | class TD3CriticNetwork(nn.Module):
67 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
68 | super().__init__()
69 |
70 | n_input = input_shape[-1]
71 | dim_action = kwargs['action_shape'][0]
72 | dim_state = n_input - dim_action
73 | n_output = output_shape[0]
74 |
75 | self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float32).to(
76 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu'))
77 |
78 | # Assume there are two hidden layers
79 | assert len(n_features) == 2, 'TD3 critic needs 2 hidden layers'
80 | n_features = list(map(int, n_features))
81 |
82 | self._h1 = nn.Linear(dim_state + dim_action, n_features[0])
83 | self._h2_s = nn.Linear(n_features[0], n_features[1])
84 | self._h2_a = nn.Linear(dim_action, n_features[1], bias=False)
85 | self._h3 = nn.Linear(n_features[1], n_output)
86 |
87 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight)
88 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1))
89 |
90 | fan_in_h2_s, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_s.weight)
91 | nn.init.uniform_(self._h2_s.weight, a=-1 / np.sqrt(fan_in_h2_s), b=1 / np.sqrt(fan_in_h2_s))
92 |
93 | fan_in_h2_a, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_a.weight)
94 | nn.init.uniform_(self._h2_a.weight, a=-1 / np.sqrt(fan_in_h2_a), b=1 / np.sqrt(fan_in_h2_a))
95 |
96 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3)
97 |
98 | def forward(self, state, action):
99 | state = state.float()
100 | action = action.float() / self._action_scaling
101 | state_action = torch.cat((state, action), dim=1)
102 |
103 | features1 = F.relu(self._h1(state_action))
104 | features2_s = self._h2_s(features1)
105 | features2_a = self._h2_a(action)
106 | features2 = F.relu(features2_s + features2_a)
107 |
108 | q = self._h3(features2)
109 | return torch.squeeze(q)
110 |
111 |
112 | class TD3ActorNetwork(nn.Module):
113 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
114 | super().__init__()
115 |
116 | dim_state = input_shape[0]
117 | dim_action = output_shape[0]
118 |
119 |         self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float32).to(
120 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu'))
121 |
122 | # Assume there are two hidden layers
123 |         assert len(n_features) == 2, 'TD3 actor needs 2 hidden layers'
124 | n_features = list(map(int, n_features))
125 |
126 | self._h1 = nn.Linear(dim_state, n_features[0])
127 | self._h2 = nn.Linear(n_features[0], n_features[1])
128 | self._h3 = nn.Linear(n_features[1], dim_action)
129 |
130 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight)
131 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1))
132 |
133 | fan_in_h2, _ = nn.init._calculate_fan_in_and_fan_out(self._h2.weight)
134 | nn.init.uniform_(self._h2.weight, a=-1 / np.sqrt(fan_in_h2), b=1 / np.sqrt(fan_in_h2))
135 |
136 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3)
137 |
138 | def forward(self, state):
139 | state = state.float()
140 |
141 | features1 = F.relu(self._h1(state))
142 | features2 = F.relu(self._h2(features1))
143 | a = self._h3(features2)
144 |
145 | a = self._action_scaling * torch.tanh(a)
146 |
147 | return a
148 |
149 |
150 | class DDPGCriticNetwork(nn.Module):
151 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
152 | super().__init__()
153 |
154 | n_input = input_shape[-1]
155 | dim_action = kwargs['action_shape'][0]
156 | dim_state = n_input - dim_action
157 |
158 | self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float).to(
159 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu'))
160 |
161 | n_output = output_shape[0]
162 |
163 | # Assume there are two hidden layers
164 | assert len(n_features) == 2, 'DDPG critic needs 2 hidden layers'
165 | n_features = list(map(int, n_features))
166 |         # Q(s, a): the state feeds the first layer; the action is injected only at the second
167 | self._h1 = nn.Linear(dim_state, n_features[0])
168 | self._h2_s = nn.Linear(n_features[0], n_features[1])
169 | self._h2_a = nn.Linear(dim_action, n_features[1], bias=False)
170 | self._h3 = nn.Linear(n_features[1], n_output)
171 |
172 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight)
173 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1))
174 |
175 | fan_in_h2_s, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_s.weight)
176 | nn.init.uniform_(self._h2_s.weight, a=-1 / np.sqrt(fan_in_h2_s), b=1 / np.sqrt(fan_in_h2_s))
177 |
178 | fan_in_h2_a, _ = nn.init._calculate_fan_in_and_fan_out(self._h2_a.weight)
179 | nn.init.uniform_(self._h2_a.weight, a=-1 / np.sqrt(fan_in_h2_a), b=1 / np.sqrt(fan_in_h2_a))
180 |
181 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3)
182 |
183 | def forward(self, state, action):
184 | state = state.float()
185 | action = action.float() / self._action_scaling
186 |
187 | features1 = F.relu(self._h1(state))
188 | features2_s = self._h2_s(features1)
189 | features2_a = self._h2_a(action)
190 | features2 = F.relu(features2_s + features2_a)
191 |
192 | q = self._h3(features2)
193 |
194 | return torch.squeeze(q)
195 |
196 |
197 | class DDPGActorNetwork(nn.Module):
198 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
199 | super().__init__()
200 |
201 | dim_state = input_shape[0]
202 | dim_action = output_shape[0]
203 |
204 |         self._action_scaling = torch.tensor(kwargs['action_scaling'], dtype=torch.float32).to(
205 | device=torch.device('cuda' if kwargs['use_cuda'] else 'cpu'))
206 |
207 | # Assume there are two hidden layers
208 | assert len(n_features) == 2, 'DDPG actor needs 2 hidden layers'
209 | n_features = list(map(int, n_features))
210 |
211 | self._h1 = nn.Linear(dim_state, n_features[0])
212 | self._h2 = nn.Linear(n_features[0], n_features[1])
213 | self._h3 = nn.Linear(n_features[1], dim_action)
214 |
215 | fan_in_h1, _ = nn.init._calculate_fan_in_and_fan_out(self._h1.weight)
216 | nn.init.uniform_(self._h1.weight, a=-1 / np.sqrt(fan_in_h1), b=1 / np.sqrt(fan_in_h1))
217 |
218 | fan_in_h2, _ = nn.init._calculate_fan_in_and_fan_out(self._h2.weight)
219 | nn.init.uniform_(self._h2.weight, a=-1 / np.sqrt(fan_in_h2), b=1 / np.sqrt(fan_in_h2))
220 |
221 | nn.init.uniform_(self._h3.weight, a=-3e-3, b=3e-3)
222 |
223 | def forward(self, state):
224 | state = state.float()
225 |
226 | features1 = F.relu(self._h1(state))
227 | features2 = F.relu(self._h2(features1))
228 | a = self._h3(features2)
229 |
230 | a = self._action_scaling * torch.tanh(a)
231 |
232 | return a
233 |
234 |
235 | class SACCriticNetwork(nn.Module):
236 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
237 | super().__init__()
238 |
239 | n_input = input_shape[-1]
240 | n_output = output_shape[0]
241 |
242 | # Assume there are two hidden layers
243 |         assert len(n_features) == 2, 'SAC critic needs 2 hidden layers'
244 | n_features = list(map(int, n_features))
245 |
246 | self._h1 = nn.Linear(n_input, n_features[0])
247 | self._h2 = nn.Linear(n_features[0], n_features[1])
248 | self._h3 = nn.Linear(n_features[1], n_output)
249 |
250 | nn.init.xavier_uniform_(self._h1.weight,
251 | gain=nn.init.calculate_gain('relu'))
252 | nn.init.xavier_uniform_(self._h2.weight,
253 | gain=nn.init.calculate_gain('relu'))
254 | nn.init.xavier_uniform_(self._h3.weight,
255 | gain=nn.init.calculate_gain('linear'))
256 |
257 | def forward(self, state, action):
258 | state_action = torch.cat((state.float(), action.float()), dim=1)
259 | features1 = F.relu(self._h1(state_action))
260 | features2 = F.relu(self._h2(features1))
261 | q = self._h3(features2)
262 |
263 | return torch.squeeze(q)
264 |
265 |
266 | class SACActorNetwork(nn.Module):
267 | def __init__(self, input_shape, output_shape, n_features, **kwargs):
268 |         super().__init__()
269 |
270 | n_input = input_shape[-1]
271 | n_output = output_shape[0]
272 |
273 | # Assume there are two hidden layers
274 | assert len(n_features) == 2, 'SAC actor needs 2 hidden layers'
275 | n_features = list(map(int, n_features))
276 |
277 | self._h1 = nn.Linear(n_input, n_features[0])
278 | self._h2 = nn.Linear(n_features[0], n_features[1])
279 | self._h3 = nn.Linear(n_features[1], n_output)
280 |
281 | nn.init.xavier_uniform_(self._h1.weight,
282 | gain=nn.init.calculate_gain('relu'))
283 | nn.init.xavier_uniform_(self._h2.weight,
284 | gain=nn.init.calculate_gain('relu'))
285 | nn.init.xavier_uniform_(self._h3.weight,
286 | gain=nn.init.calculate_gain('linear'))
287 |
288 | def forward(self, state):
289 | features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
290 | features2 = F.relu(self._h2(features1))
291 | a = self._h3(features2)
292 |
293 | return a
--------------------------------------------------------------------------------
/fig/manifold.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PuzeLiu/rl_on_manifold/7ebd4125222f2bc4b1171c7069f6b85e6fe6019a/fig/manifold.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.17.0
2 | matplotlib>=3.3.4
3 | scipy>=1.5.2
4 | mushroom-rl>=1.7.0
5 | pandas>=1.2.1
6 | pybullet>=3.0.8
7 | torch
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from codecs import open
3 | from os import path
4 |
5 | version = '1.0.0'
6 |
7 |
8 | here = path.abspath(path.dirname(__file__))
9 |
10 | # Collect one requirement per non-empty line, stripping trailing newlines
11 | with open(path.join(here, 'requirements.txt'), encoding='utf-8') as f:
12 |     requires_list = [line.strip() for line in f if line.strip()]
13 | 
14 |
15 | extras = {}
16 |
17 | all_deps = []
18 | for group_name in extras:
19 | all_deps += extras[group_name]
20 | extras['all'] = all_deps
21 |
22 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
23 |     long_description = f.read()
24 | setup(
25 | name='atacom',
26 | version=version,
27 | description='Acting on the Tangent Space of the Constraint Manifold',
28 | long_description=long_description,
29 | author="Puze Liu",
30 | author_email='puze@robot-learning.de',
31 | license='MIT',
32 | packages=[package for package in find_packages()
33 |               if package.startswith('atacom')],
34 | zip_safe=False,
35 | install_requires=requires_list,
36 | extras_require=extras,
37 | classifiers=["Programming Language :: Python :: 3",
38 | "License :: OSI Approved :: MIT License",
39 | "Operating System :: OS Independent",
40 | ]
41 | )
42 |
--------------------------------------------------------------------------------