├── README.md
├── env
│   ├── __init__.py
│   ├── acro3bot
│   │   ├── acro3bot.py
│   │   ├── derive.py
│   │   └── robot.p
│   ├── acrobot
│   │   ├── acrobot.py
│   │   ├── derive.py
│   │   └── robot.p
│   ├── base.py
│   ├── cart2pole
│   │   ├── cart2pole.py
│   │   ├── derive.py
│   │   └── robot.p
│   ├── cart3pole
│   │   ├── cart3pole.py
│   │   ├── derive.py
│   │   └── robot.p
│   ├── cartpole
│   │   ├── cartpole.py
│   │   ├── derive.py
│   │   └── robot.p
│   ├── pendulum
│   │   ├── derive.py
│   │   ├── pendulum.py
│   │   └── robot.p
│   ├── reacher
│   │   ├── derive.py
│   │   ├── reacher.py
│   │   └── robot.p
│   ├── rewards.py
│   └── utils.py
├── mbrl.py
├── models
│   ├── __init__.py
│   ├── mbrl.py
│   └── sac.py
└── sac.py
/README.md:
--------------------------------------------------------------------------------
1 | # Physics-Informed Model-Based RL
2 |
3 | Published at the Learning for Dynamics & Control Conference (L4DC), 2023.
4 |
5 |
6 |
7 |
8 |
9 | ## Abstract
10 | We apply reinforcement learning (RL) to robotics tasks. One of the drawbacks of traditional RL algorithms has been their poor sample efficiency. One approach to improving sample efficiency is model-based RL. In our model-based RL algorithm, we learn a model of the environment, essentially its transition dynamics and reward function, use it to generate imaginary trajectories, and backpropagate through them to update the policy, exploiting the differentiability of the model.
11 |
12 | Intuitively, learning more accurate models should lead to better model-based RL performance. Recently, there has been growing interest in developing better deep neural network based dynamics models for physical systems, by utilizing the structure of the underlying physics. We focus on robotic systems undergoing rigid body motion without contacts. We compare two versions of our model-based RL algorithm, one which uses a standard deep neural network based dynamics model and the other which uses a much more accurate, physics-informed neural network based dynamics model.
13 |
14 | We show that, in model-based RL, model accuracy mainly matters in environments that are sensitive to initial conditions, where numerical errors accumulate fast. In these environments, the physics-informed version of our algorithm achieves significantly better average-return and sample efficiency. In environments that are not sensitive to initial conditions, both versions of our algorithm achieve similar average-return, while the physics-informed version achieves better sample efficiency.
15 |
16 | We also show that, in challenging environments, physics-informed model-based RL achieves better average-return than state-of-the-art model-free RL algorithms such as Soft Actor-Critic, as it computes the policy-gradient analytically.
17 |
18 | For more information, check out:
19 | - [Project Webpage](https://adi3e08.github.io/research/pimbrl)
20 | - [Paper](https://arxiv.org/abs/2212.02179)
21 |
22 | ## Requirements
23 | - Python
24 | - NumPy
25 | - PyTorch
26 | - TensorBoard
27 | - Pygame
28 |
29 | ## Usage
30 | To train MBRL LNN on the Acrobot task, run:
31 |
32 | python mbrl.py --env acrobot --mode train --episodes 500 --seed 0
33 |
34 | The data from this experiment will be stored in the folder "./log/acrobot/mbrl_lnn/seed_0". This folder will contain two subfolders: (i) models, where model checkpoints are stored, and (ii) tensorboard, where TensorBoard logs are stored.
35 |
36 | To evaluate MBRL LNN on the Acrobot task, run:
37 |
38 | python mbrl.py --env acrobot --mode eval --episodes 3 --seed 100 --checkpoint ./log/acrobot/mbrl_lnn/seed_0/models/499.ckpt --render
39 |
40 | ## Citation
41 | If you find this work helpful, please consider starring this repo and citing our paper using the following BibTeX entry.
42 | ```bibtex
43 | @inproceedings{ramesh2023physics,
44 | title={Physics-Informed Model-Based Reinforcement Learning},
45 | author={Ramesh, Adithya and Ravindran, Balaraman},
46 | booktitle={Learning for Dynamics and Control Conference},
47 | pages={26--37},
48 | year={2023},
49 | organization={PMLR}
50 | }
51 | ```
52 |
53 |
--------------------------------------------------------------------------------
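
The abstract above describes the core policy update: roll the learned, differentiable transition and reward models forward to produce imaginary trajectories, then backpropagate the predicted return through them. Below is a minimal sketch of that idea in PyTorch; the module shapes, names, and horizon are illustrative assumptions, not the repository's actual `models/mbrl.py` code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the learned, differentiable models and the policy.
obs_size, action_size, horizon = 6, 1, 5
transition = nn.Sequential(nn.Linear(obs_size + action_size, 64), nn.Tanh(), nn.Linear(64, obs_size))
reward = nn.Sequential(nn.Linear(obs_size, 64), nn.Tanh(), nn.Linear(64, 1))
policy = nn.Sequential(nn.Linear(obs_size, 64), nn.Tanh(), nn.Linear(64, action_size), nn.Tanh())
policy_opt = torch.optim.AdamW(policy.parameters(), lr=3e-4)

o = torch.randn(32, obs_size)              # batch of start states (e.g. sampled from a replay buffer)
predicted_return = 0.0
for _ in range(horizon):                   # imagined rollout through the learned models
    a = policy(o)
    o = transition(torch.cat([o, a], dim=-1))
    predicted_return = predicted_return + reward(o).mean()

loss = -predicted_return                   # ascend the predicted return by backpropagating through the rollout
policy_opt.zero_grad()
loss.backward()
policy_opt.step()                          # only the policy is updated here; the models are trained separately
```
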
/env/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/env/acro3bot/acro3bot.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class acro3bot(BaseEnv):
10 | def __init__(self):
11 | m1 = 0.1
12 | l1 = 1
13 | r1 = l1/2
14 | I1 = m1 * l1**2 / 12
15 |
16 | m2 = 0.1
17 | l2 = 1
18 | r2 = l2/2
19 | I2 = m2 * l2**2 / 12
20 |
21 | m3 = 0.1
22 | l3 = 1
23 | r3 = l3/2
24 | I3 = m3 * l3**2 / 12
25 |
26 | g = 9.8
27 |
28 | m = [m1,m2,m3]
29 | l = [l1,l2,l3]
30 | r = [r1,r2,r3]
31 | I = [I1,I2,I3]
32 | super(acro3bot, self).__init__(name = "acro3bot",
33 | n = 3,
34 | obs_size = 9,
35 | action_size = 2,
36 | inertials = m+l+r+I+[g],
37 | a_scale = np.array([0.5,2.0]))
38 |
39 | def wrap_state(self):
40 | self.state[:3] = wrap(self.state[:3])
41 |
42 | def reset_state(self):
43 | self.state = np.array([np.pi + 0.01*np.random.randn(),
44 | 0.01*np.random.randn(),
45 | 0.01*np.random.randn(),
46 | 0,0,0])
47 |
48 | def get_A(self, a):
49 | a_1, a_3 = np.clip(a, -1.0, 1.0)*self.a_scale
50 | a_2 = 0.0
51 | return np.array([a_1,a_2,a_3])
52 |
53 | def get_obs(self):
54 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]),
55 | np.cos(self.state[1]),np.sin(self.state[1]),
56 | np.cos(self.state[2]),np.sin(self.state[2]),
57 | self.state[3],
58 | self.state[4],
59 | self.state[5]
60 | ])
61 |
62 | def get_reward(self):
63 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1]), np.cos(self.state[0]+self.state[1]+self.state[2])])+1)/2
64 |
65 | qdot = self.state[self.n:]
66 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1],qdot[0]+qdot[1]+qdot[2]])
67 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min()
68 | small_velocity = (1 + small_velocity) / 2
69 |
70 | reward = upright.mean() * small_velocity
71 |
72 | return reward
73 |
74 | def draw(self):
75 | centers, joints, angles = self.geo
76 |
77 | for j in range(self.n):
78 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
79 | pygame.draw.polygon(self.screen, self.link_color, link_points)
80 |
81 | joint_point = pygame_transform(joints[j],self.scaling,self.offset)
82 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
83 |
84 | if __name__ == '__main__':
85 | basic_check("acro3bot",0)
86 |
--------------------------------------------------------------------------------
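
Note that get_obs encodes each joint angle as a (cos, sin) pair followed by the joint velocities. If the wrapped angle itself is needed, e.g. for logging or debugging, it can be recovered from that pair; a small illustrative snippet (the numeric values are made up):

```python
import numpy as np

# Recover a joint angle from its (cos, sin) observation pair.
cos_q1, sin_q1 = 0.0, 1.0           # first two entries of an acro3bot observation (illustrative values)
q1 = np.arctan2(sin_q1, cos_q1)     # angle in [-pi, pi], consistent with env.utils.wrap
print(q1)                           # 1.5707963...
```
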
/env/acro3bot/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 3
21 | m1,m2,m3 = symbols('m1,m2,m3')
22 | l1,l2,l3 = symbols('l1,l2,l3')
23 | r1,r2,r3 = symbols('r1,r2,r3')
24 | I1,I2,I3 = symbols('I1,I2,I3')
25 | g = symbols('g')
26 | q1,q2,q3 = symbols('q1,q2,q3')
27 | q1dot,q2dot,q3dot = symbols('q1dot,q2dot,q3dot')
28 |
29 | m = [m1,m2,m3]
30 | l = [l1,l2,l3]
31 | r = [r1,r2,r3]
32 | I = [I1,I2,I3]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2,q3])
36 | qdot = Matrix([q1dot,q2dot,q3dot])
37 | state = [q1,q2,q3,q1dot,q2dot,q3dot]
38 |
39 | J_w = Matrix([[1,0,0],
40 | [1,1,0],
41 | [1,1,1]
42 | ])
43 |
44 | angles = J_w * q
45 |
46 | V = 0
47 | M = zeros(n)
48 | J = []
49 | for i in range(n):
50 | if i == 0:
51 | joint = Matrix([[0, 0]])
52 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
53 | joints = joint
54 | centers = center
55 | else:
56 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]])
57 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
58 | joints = Matrix.vstack(joints, joint)
59 | centers = Matrix.vstack(centers, center)
60 |
61 | J_v = center.jacobian(q)
62 | # J.append(J_v)
63 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
64 | M += M_i
65 |
66 | V += m[i]*g*center[0,1]
67 |
68 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
69 |
70 | C,G = get_C_G(n,M,V,q,qdot)
71 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
72 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
73 |
74 | with open("./env/acro3bot/robot.p", "wb") as outf:
75 | pickle.dump(lambda_dict, outf)
76 |
77 | print("Done")
78 |
79 | if __name__ == '__main__':
80 | derive()
81 |
--------------------------------------------------------------------------------
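
For reference, the symbolic quantities that derive.py assembles and pickles into robot.p are the terms of the standard manipulator equation. The mass matrix is built from the link Jacobians, get_C_G forms the Coriolis matrix from Christoffel symbols of the first kind, and the gravity vector is the gradient of the potential energy:

```latex
M(q)\,\ddot{q} + C(q,\dot{q})\,\dot{q} + G(q) = \tau,
\qquad
M(q) = \sum_{i} \left( m_i\, J_{v,i}^{\top} J_{v,i} + I_i\, J_{\omega,i}^{\top} J_{\omega,i} \right),

C_{ij}(q,\dot{q}) = \frac{1}{2} \sum_{k} \left(
    \frac{\partial M_{ij}}{\partial q_k}
  + \frac{\partial M_{ik}}{\partial q_j}
  - \frac{\partial M_{kj}}{\partial q_i} \right) \dot{q}_k,
\qquad
G_i(q) = \frac{\partial V}{\partial q_i}.
```
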
/env/acro3bot/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/acro3bot/robot.p
--------------------------------------------------------------------------------
/env/acrobot/acrobot.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class acrobot(BaseEnv):
10 | def __init__(self):
11 | m1 = 0.1
12 | l1 = 1
13 | r1 = l1/2
14 | I1 = m1 * l1**2 / 12
15 |
16 | m2 = 0.1
17 | l2 = 1
18 | r2 = l2/2
19 | I2 = m2 * l2**2 / 12
20 |
21 | g = 9.8
22 |
23 | m = [m1,m2]
24 | l = [l1,l2]
25 | r = [r1,r2]
26 | I = [I1,I2]
27 | super(acrobot, self).__init__(name = "acrobot",
28 | n = 2,
29 | obs_size = 6,
30 | action_size = 1,
31 | inertials = m+l+r+I+[g],
32 | a_scale = np.array([1.75]))
33 |
34 | def wrap_state(self):
35 | self.state[:2] = wrap(self.state[:2])
36 |
37 | def reset_state(self):
38 | self.state = np.array([np.pi + 0.01*np.random.randn(),
39 | 0.01*np.random.randn(),
40 | 0,
41 | 0])
42 |
43 | def get_A(self, a):
44 | a_2, = np.clip(a, -1.0, 1.0)*self.a_scale
45 | a_1 = 0.0
46 | return np.array([a_1,a_2])
47 |
48 | def get_obs(self):
49 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]),
50 | np.cos(self.state[1]),np.sin(self.state[1]),
51 | self.state[2],
52 | self.state[3]
53 | ])
54 |
55 | def get_reward(self):
56 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1])])+1)/2
57 |
58 | qdot = self.state[self.n:]
59 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1]])
60 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min()
61 | small_velocity = (1 + small_velocity) / 2
62 |
63 | reward = upright.mean() * small_velocity
64 |
65 | return reward
66 |
67 | def draw(self):
68 | centers, joints, angles = self.geo
69 |
70 | for j in range(self.n):
71 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
72 | pygame.draw.polygon(self.screen, self.link_color, link_points)
73 |
74 | joint_point = pygame_transform(joints[j],self.scaling,self.offset)
75 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
76 |
77 | if __name__ == '__main__':
78 | basic_check("acrobot",0)
79 |
--------------------------------------------------------------------------------
/env/acrobot/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 2
21 | m1,m2 = symbols('m1,m2')
22 | l1,l2 = symbols('l1,l2')
23 | r1,r2 = symbols('r1,r2')
24 | I1,I2 = symbols('I1,I2')
25 | g = symbols('g')
26 | q1,q2 = symbols('q1,q2')
27 | q1dot,q2dot = symbols('q1dot,q2dot')
28 |
29 | m = [m1,m2]
30 | l = [l1,l2]
31 | r = [r1,r2]
32 | I = [I1,I2]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2])
36 | qdot = Matrix([q1dot,q2dot])
37 | state = [q1,q2,q1dot,q2dot]
38 |
39 | J_w = Matrix([[1,0],
40 | [1,1]
41 | ])
42 |
43 | angles = J_w * q
44 |
45 | V = 0
46 | M = zeros(n)
47 | J = []
48 | for i in range(n):
49 | if i == 0:
50 | joint = Matrix([[0, 0]])
51 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
52 | joints = joint
53 | centers = center
54 | else:
55 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]])
56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
57 | joints = Matrix.vstack(joints, joint)
58 | centers = Matrix.vstack(centers, center)
59 |
60 | J_v = center.jacobian(q)
61 | # J.append(J_v)
62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
63 | M += M_i
64 |
65 | V += m[i]*g*center[0,1]
66 |
67 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
68 |
69 | C,G = get_C_G(n,M,V,q,qdot)
70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
72 |
73 | with open("./env/acrobot/robot.p", "wb") as outf:
74 | pickle.dump(lambda_dict, outf)
75 |
76 | print("Done")
77 |
78 | if __name__ == '__main__':
79 | derive()
80 |
--------------------------------------------------------------------------------
/env/acrobot/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/acrobot/robot.p
--------------------------------------------------------------------------------
/env/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import time
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 | import pygame
7 | import numpy as np
8 | from env.utils import create_background
9 | from copy import deepcopy
10 |
11 | class BaseEnv:
12 | def __init__(self,name,n,obs_size,action_size,inertials,a_scale):
13 | self.name = name
14 | self.n = n
15 | self.obs_size = obs_size
16 | self.action_size = action_size
17 | self.inertials = inertials
18 | self.a_scale = a_scale
19 |
20 | self.dt = 0.01 # time per simulation step (in seconds)
21 | self.t = 0 # elapsed simulation steps
22 | self.t_max = 1000 # max simulation steps
23 | self.state = np.zeros(2*self.n) # state is [q, qdot]; overwritten by reset_state()
24 | self.ang_vel_limit = 20.0
25 |
26 | with open("./env/"+self.name+"/robot.p", "rb") as inf:
27 | funcs = pickle.load(inf)
28 | self.kinematics = funcs['kinematics']
29 | self.dynamics = funcs['dynamics']
30 |
31 | # For rendering
32 | self.display = False
33 | self.screen_width = 500
34 | self.screen_height = 500
35 | self.offset = [250, 250]
36 | self.scaling = 75
37 | self.x_limit = 2.0
38 |
39 | self.link_length = 1.0
40 | self.link_width = 0.2
41 | self.link_color = (72,209,204) # medium turquoise
42 |
43 | self.joint_radius = self.link_width/1.8
44 | self.joint_color = (205,55,0) # orange red
45 |
46 | self.cart_length = 5*self.link_width
47 | self.cart_width = 2*self.link_width
48 | self.cart_color = (200,255,0) # yellow
49 |
50 | self.rail_length = 2*self.x_limit
51 | self.rail_width = self.link_width/2.5
52 | self.rail_color = (150,150,150) # gray
53 |
54 | def wrap_state(self):
55 | pass
56 |
57 | def reset_state(self):
58 | pass
59 |
60 | def get_A(self, a):
61 | pass
62 |
63 | def get_obs(self):
64 | pass
65 |
66 | def get_reward(self):
67 | pass
68 |
69 | def draw(self):
70 | pass
71 |
72 | def set_state(self,s):
73 | self.state = s
74 |
75 | def reset(self):
76 | self.reset_state()
77 | self.wrap_state()
78 | self.geo = self.kinematics(self.inertials+self.state.tolist())
79 | self.t = 0
80 |
81 | return self.get_obs(), 0.0, False
82 |
83 | def step(self, a):
84 | self.state = self.rk4(self.state, self.get_A(a))
85 | self.wrap_state()
86 | self.geo = self.kinematics(self.inertials+self.state.tolist())
87 |
88 | self.t += 1
89 | if self.t >= self.t_max: # infinite horizon formulation, no terminal state, similar to dm_control
90 | done = True
91 | else:
92 | done = False
93 |
94 | return self.get_obs(), self.get_reward(), done
95 |
96 | def F(self, s, a):
97 | M, C, G = self.dynamics(self.inertials+s.tolist())
98 | qdot = s[self.n:]
99 | qddot = np.linalg.inv(M+1e-6*np.eye(self.n)) @ (a - C @ qdot - G.flatten())
100 |
101 | return np.concatenate((qdot,qddot))
102 |
103 | def rk4(self, s, a):
104 | s0 = deepcopy(s)
105 | k = []
106 | for l in range(4):
107 | if l > 0:
108 | if l == 1 or l == 2:
109 | dt = self.dt/2
110 | elif l == 3:
111 | dt = self.dt
112 | s = s0 + dt * k[l-1]
113 | k.append(self.F(s, a))
114 | s = s0 + (self.dt/6.0) * (k[0] + 2 * k[1] + 2 * k[2] + k[3])
115 |
116 | return s
117 |
118 | def render(self):
119 | if self.display:
120 | self.screen.blit(self.background, (0, 0))
121 | self.draw()
122 | time.sleep(0.006)
123 | pygame.display.flip()
124 | else:
125 | self.display = True
126 | pygame.init()
127 | self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
128 | pygame.display.set_caption(self.name)
129 | self.background = create_background(self.screen_width, self.screen_height)
130 |
--------------------------------------------------------------------------------
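
BaseEnv.F inverts the manipulator equation to obtain the accelerations (with a small diagonal regularizer added to M before inversion), and BaseEnv.rk4 integrates the resulting ODE with the classic fourth-order Runge-Kutta scheme. Written out, one simulation step of length dt computes:

```latex
\ddot{q} = M(q)^{-1}\left(\tau - C(q,\dot{q})\,\dot{q} - G(q)\right),
\qquad
F(s,\tau) = \left(\dot{q},\ \ddot{q}\right), \quad s = (q,\dot{q}),

k_1 = F(s_t,\tau), \quad
k_2 = F\!\left(s_t + \tfrac{\Delta t}{2}\,k_1,\ \tau\right), \quad
k_3 = F\!\left(s_t + \tfrac{\Delta t}{2}\,k_2,\ \tau\right), \quad
k_4 = F\!\left(s_t + \Delta t\,k_3,\ \tau\right),

s_{t+1} = s_t + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right).
```
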
/env/cart2pole/cart2pole.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class cart2pole(BaseEnv):
10 | def __init__(self):
11 | m1 = 1
12 | l1, r1, I1 = 0, 0, 0 # dummy
13 |
14 | m2 = 0.1
15 | l2 = 1
16 | r2 = l2/2
17 | I2 = m2 * l2**2 / 12
18 |
19 | m3 = 0.1
20 | l3 = 1
21 | r3 = l3/2
22 | I3 = m3 * l3**2 / 12
23 |
24 | g = 9.8
25 |
26 | m = [m1,m2,m3]
27 | l = [l1,l2,l3]
28 | r = [r1,r2,r3]
29 | I = [I1,I2,I3]
30 | super(cart2pole, self).__init__(name = "cart2pole",
31 | n = 3,
32 | obs_size = 8,
33 | action_size = 1,
34 | inertials = m+l+r+I+[g],
35 | a_scale = np.array([10.0]))
36 | def wrap_state(self):
37 | self.state[1:3] = wrap(self.state[1:3])
38 |
39 | def reset_state(self):
40 | self.state = np.array([0.01*np.random.randn(),
41 | np.pi + 0.01*np.random.randn(),
42 | 0.01*np.random.randn(),
43 | 0,0,0])
44 |
45 | def get_A(self, a):
46 | a_1, = np.clip(a, -1.0, 1.0)*self.a_scale
47 | a_2, a_3 = 0.0, 0.0
48 | return np.array([a_1,a_2,a_3])
49 |
50 | def get_obs(self):
51 | return np.array([self.state[0],
52 | np.cos(self.state[1]),np.sin(self.state[1]),
53 | np.cos(self.state[2]),np.sin(self.state[2]),
54 | self.state[3],
55 | self.state[4],
56 | self.state[5]
57 | ])
58 |
59 | def get_reward(self):
60 | upright = (np.array([np.cos(self.state[1]), np.cos(self.state[1]+self.state[2])])+1)/2
61 |
62 | if np.abs(self.state[0]) <= self.x_limit:
63 | centered = rewards.tolerance(self.state[0], margin=self.x_limit)
64 | centered = (1 + centered) / 2
65 | else:
66 | centered = 0.1
67 |
68 | qdot = self.state[self.n:]
69 | ang_vel = np.array([qdot[0],qdot[1],qdot[1]+qdot[2]])
70 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min()
71 | small_velocity = (1 + small_velocity) / 2
72 |
73 | reward = upright.mean() * small_velocity * centered
74 |
75 | return reward
76 |
77 | def draw(self):
78 | centers, joints, angles = self.geo
79 |
80 | #horizontal rail
81 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset))
82 |
83 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit
84 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset)
85 | pygame.draw.polygon(self.screen, self.cart_color, link1_points)
86 |
87 | offset = np.array([plot_x-centers[0,0],0])
88 | for j in range(1,self.n):
89 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
90 | pygame.draw.polygon(self.screen, self.link_color, link_points)
91 |
92 | joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset)
93 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
94 |
95 |
96 | if __name__ == '__main__':
97 | basic_check("cart2pole",0)
98 |
--------------------------------------------------------------------------------
/env/cart2pole/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 3
21 | m1,m2,m3 = symbols('m1,m2,m3')
22 | l1,l2,l3 = symbols('l1,l2,l3')
23 | r1,r2,r3 = symbols('r1,r2,r3')
24 | I1,I2,I3 = symbols('I1,I2,I3')
25 | g = symbols('g')
26 | q1,q2,q3 = symbols('q1,q2,q3')
27 | q1dot,q2dot,q3dot = symbols('q1dot,q2dot,q3dot')
28 |
29 | m = [m1,m2,m3]
30 | l = [l1,l2,l3]
31 | r = [r1,r2,r3]
32 | I = [I1,I2,I3]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2,q3])
36 | qdot = Matrix([q1dot,q2dot,q3dot])
37 | state = [q1,q2,q3,q1dot,q2dot,q3dot]
38 |
39 | J_w = Matrix([[0,0,0],
40 | [0,1,0],
41 | [0,1,1]
42 | ])
43 |
44 | angles = J_w * q
45 |
46 | V = 0
47 | M = zeros(n)
48 | J = []
49 | for i in range(n):
50 | if i == 0:
51 | joint = Matrix([[q1, 0]])
52 | center = joint
53 | joints = joint
54 | centers = center
55 | elif i == 1:
56 | joint = joint
57 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
58 | joints = Matrix.vstack(joints, joint)
59 | centers = Matrix.vstack(centers, center)
60 | else:
61 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]])
62 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
63 | joints = Matrix.vstack(joints, joint)
64 | centers = Matrix.vstack(centers, center)
65 |
66 | J_v = center.jacobian(q)
67 | # J.append(J_v)
68 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
69 | M += M_i
70 |
71 | V += m[i]*g*center[0,1]
72 |
73 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
74 |
75 | C,G = get_C_G(n,M,V,q,qdot)
76 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
77 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
78 |
79 | with open("./env/cart2pole/robot.p", "wb") as outf:
80 | pickle.dump(lambda_dict, outf)
81 |
82 | print("Done")
83 |
84 | if __name__ == '__main__':
85 | derive()
86 |
--------------------------------------------------------------------------------
/env/cart2pole/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cart2pole/robot.p
--------------------------------------------------------------------------------
/env/cart3pole/cart3pole.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class cart3pole(BaseEnv):
10 | def __init__(self):
11 | m1 = 1
12 | l1, r1, I1 = 0, 0, 0 # dummy
13 |
14 | m2 = 0.1
15 | l2 = 1
16 | r2 = l2/2
17 | I2 = m2 * l2**2 / 12
18 |
19 | m3 = 0.1
20 | l3 = 1
21 | r3 = l3/2
22 | I3 = m3 * l3**2 / 12
23 |
24 | m4 = 0.1
25 | l4 = 1
26 | r4 = l4/2
27 | I4 = m4 * l4**2 / 12
28 |
29 | g = 9.8
30 |
31 | m = [m1,m2,m3,m4]
32 | l = [l1,l2,l3,l4]
33 | r = [r1,r2,r3,r4]
34 | I = [I1,I2,I3,I4]
35 | super(cart3pole, self).__init__(name = "cart3pole",
36 | n = 4,
37 | obs_size = 11,
38 | action_size = 2,
39 | inertials = m+l+r+I+[g],
40 | a_scale = np.array([10.0,1.0]))
41 |
42 | def wrap_state(self):
43 | self.state[1:4] = wrap(self.state[1:4])
44 |
45 | def reset_state(self):
46 | self.state = np.array([0.01*np.random.randn(),
47 | np.pi + 0.01*np.random.randn(),
48 | 0.01*np.random.randn(),
49 | 0.01*np.random.randn(),
50 | 0,0,0,0])
51 |
52 | def get_A(self, a):
53 | a_1, a_4 = np.clip(a, -1.0, 1.0)*self.a_scale
54 | a_2, a_3 = 0.0, 0.0
55 | return np.array([a_1,a_2,a_3,a_4])
56 |
57 | def get_obs(self):
58 | return np.array([self.state[0],
59 | np.cos(self.state[1]),np.sin(self.state[1]),
60 | np.cos(self.state[2]),np.sin(self.state[2]),
61 | np.cos(self.state[3]),np.sin(self.state[3]),
62 | self.state[4],
63 | self.state[5],
64 | self.state[6],
65 | self.state[7]
66 | ])
67 |
68 | def get_reward(self):
69 | upright = (np.array([np.cos(self.state[1]), np.cos(self.state[1]+self.state[2]), np.cos(self.state[1]+self.state[2]+self.state[3])])+1)/2
70 |
71 | if np.abs(self.state[0]) <= self.x_limit:
72 | centered = rewards.tolerance(self.state[0], margin=self.x_limit)
73 | centered = (1 + centered) / 2
74 | else:
75 | centered = 0.1
76 |
77 | qdot = self.state[self.n:]
78 | ang_vel = np.array([qdot[0],qdot[1],qdot[1]+qdot[2],qdot[1]+qdot[2]+qdot[3]])
79 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min()
80 | small_velocity = (1 + small_velocity) / 2
81 |
82 | reward = upright.mean() * small_velocity * centered
83 |
84 | return reward
85 |
86 | def draw(self):
87 | centers, joints, angles = self.geo
88 |
89 | #horizontal rail
90 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset))
91 |
92 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit
93 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset)
94 | pygame.draw.polygon(self.screen, self.cart_color, link1_points)
95 |
96 | offset = np.array([plot_x-centers[0,0],0])
97 | for j in range(1,self.n):
98 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
99 | pygame.draw.polygon(self.screen, self.link_color, link_points)
100 |
101 | joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset)
102 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
103 |
104 | if __name__ == '__main__':
105 | basic_check("cart3pole",0)
106 |
--------------------------------------------------------------------------------
/env/cart3pole/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 4
21 | m1,m2,m3,m4 = symbols('m1,m2,m3,m4')
22 | l1,l2,l3,l4 = symbols('l1,l2,l3,l4')
23 | r1,r2,r3,r4 = symbols('r1,r2,r3,r4')
24 | I1,I2,I3,I4 = symbols('I1,I2,I3,I4')
25 | g = symbols('g')
26 | q1,q2,q3,q4 = symbols('q1,q2,q3,q4')
27 | q1dot,q2dot,q3dot,q4dot = symbols('q1dot,q2dot,q3dot,q4dot')
28 |
29 | m = [m1,m2,m3,m4]
30 | l = [l1,l2,l3,l4]
31 | r = [r1,r2,r3,r4]
32 | I = [I1,I2,I3,I4]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2,q3,q4])
36 | qdot = Matrix([q1dot,q2dot,q3dot,q4dot])
37 | state = [q1,q2,q3,q4,q1dot,q2dot,q3dot,q4dot]
38 |
39 | J_w = Matrix([[0,0,0,0],
40 | [0,1,0,0],
41 | [0,1,1,0],
42 | [0,1,1,1]
43 | ])
44 |
45 | angles = J_w * q
46 |
47 | V = 0
48 | M = zeros(n)
49 | J = []
50 | for i in range(n):
51 | if i == 0:
52 | joint = Matrix([[q1, 0]])
53 | center = joint
54 | joints = joint
55 | centers = center
56 | elif i == 1:
57 | joint = joint
58 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
59 | joints = Matrix.vstack(joints, joint)
60 | centers = Matrix.vstack(centers, center)
61 | else:
62 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]])
63 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
64 | joints = Matrix.vstack(joints, joint)
65 | centers = Matrix.vstack(centers, center)
66 |
67 | J_v = center.jacobian(q)
68 | # J.append(J_v)
69 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
70 | M += M_i
71 |
72 | V += m[i]*g*center[0,1]
73 |
74 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
75 |
76 | C,G = get_C_G(n,M,V,q,qdot)
77 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
78 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
79 |
80 | with open("./env/cart3pole/robot.p", "wb") as outf:
81 | pickle.dump(lambda_dict, outf)
82 |
83 | print("Done")
84 |
85 | if __name__ == '__main__':
86 | derive()
87 |
--------------------------------------------------------------------------------
/env/cart3pole/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cart3pole/robot.p
--------------------------------------------------------------------------------
/env/cartpole/cartpole.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class cartpole(BaseEnv):
10 | def __init__(self):
11 | m1 = 1
12 | l1, r1, I1 = 0, 0, 0 # dummy
13 |
14 | m2 = 0.1
15 | l2 = 1
16 | r2 = l2/2
17 | I2 = m2 * l2**2 / 12
18 |
19 | g = 9.8
20 |
21 | m = [m1,m2]
22 | l = [l1,l2]
23 | r = [r1,r2]
24 | I = [I1,I2]
25 | super(cartpole, self).__init__(name = "cartpole",
26 | n = 2,
27 | obs_size = 5,
28 | action_size = 1,
29 | inertials = m+l+r+I+[g],
30 | a_scale = np.array([10.0]))
31 |
32 | def wrap_state(self):
33 | self.state[1] = wrap(self.state[1])
34 |
35 | def reset_state(self):
36 | self.state = np.array([0.01*np.random.randn(),
37 | np.pi + 0.01*np.random.randn(),
38 | 0,0])
39 |
40 | def get_A(self, a):
41 | a_1, = np.clip(a, -1.0, 1.0)*self.a_scale
42 | a_2 = 0.0
43 | return np.array([a_1,a_2])
44 |
45 | def get_obs(self):
46 | return np.array([self.state[0],
47 | np.cos(self.state[1]),np.sin(self.state[1]),
48 | self.state[2],
49 | self.state[3]
50 | ])
51 |
52 | def get_reward(self):
53 | upright = (np.array([np.cos(self.state[1])])+1)/2
54 |
55 | if np.abs(self.state[0]) <= self.x_limit:
56 | centered = rewards.tolerance(self.state[0], margin=self.x_limit)
57 | centered = (1 + centered) / 2
58 | else:
59 | centered = 0.1
60 |
61 | qdot = self.state[self.n:]
62 | ang_vel = qdot
63 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min()
64 | small_velocity = (1 + small_velocity) / 2
65 |
66 | reward = upright.mean() * small_velocity * centered
67 |
68 | return reward
69 |
70 | def draw(self):
71 | centers, joints, angles = self.geo
72 |
73 | #horizontal rail
74 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset))
75 |
76 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit
77 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset)
78 | pygame.draw.polygon(self.screen, self.cart_color, link1_points)
79 |
80 | offset = np.array([plot_x-centers[0,0],0])
81 | for j in range(1,self.n):
82 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
83 | pygame.draw.polygon(self.screen, self.link_color, link_points)
84 |
85 | joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset)
86 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
87 |
88 | if __name__ == '__main__':
89 | basic_check("cartpole",0)
90 |
--------------------------------------------------------------------------------
/env/cartpole/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 2
21 | m1,m2 = symbols('m1,m2')
22 | l1,l2 = symbols('l1,l2')
23 | r1,r2 = symbols('r1,r2')
24 | I1,I2 = symbols('I1,I2')
25 | g = symbols('g')
26 | q1,q2 = symbols('q1,q2')
27 | q1dot,q2dot = symbols('q1dot,q2dot')
28 |
29 | m = [m1,m2]
30 | l = [l1,l2]
31 | r = [r1,r2]
32 | I = [I1,I2]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2])
36 | qdot = Matrix([q1dot,q2dot])
37 | state = [q1,q2,q1dot,q2dot]
38 |
39 | J_w = Matrix([[0,0],
40 | [0,1]
41 | ])
42 |
43 | angles = J_w * q
44 |
45 | V = 0
46 | M = zeros(n)
47 | J = []
48 | for i in range(n):
49 | if i == 0:
50 | joint = Matrix([[q1, 0]])
51 | center = joint
52 | joints = joint
53 | centers = center
54 | elif i == 1:
55 | joint = joint
56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
57 | joints = Matrix.vstack(joints, joint)
58 | centers = Matrix.vstack(centers, center)
59 |
60 | J_v = center.jacobian(q)
61 | # J.append(J_v)
62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
63 | M += M_i
64 |
65 | V += m[i]*g*center[0,1]
66 |
67 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
68 |
69 | C,G = get_C_G(n,M,V,q,qdot)
70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
72 |
73 | with open("./env/cartpole/robot.p", "wb") as outf:
74 | pickle.dump(lambda_dict, outf)
75 |
76 | print("Done")
77 |
78 | if __name__ == '__main__':
79 | derive()
80 |
--------------------------------------------------------------------------------
/env/cartpole/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cartpole/robot.p
--------------------------------------------------------------------------------
/env/pendulum/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 1
21 | m1 = symbols('m1')
22 | l1 = symbols('l1')
23 | r1 = symbols('r1')
24 | I1 = symbols('I1')
25 | g = symbols('g')
26 | q1 = symbols('q1')
27 | q1dot = symbols('q1dot')
28 |
29 | m = [m1]
30 | l = [l1]
31 | r = [r1]
32 | I = [I1]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1])
36 | qdot = Matrix([q1dot])
37 | state = [q1,q1dot]
38 |
39 | J_w = Matrix([[1]
40 | ])
41 |
42 | angles = J_w * q
43 |
44 | V = 0
45 | M = zeros(n)
46 | J = []
47 | for i in range(n):
48 | if i == 0:
49 | joint = Matrix([[0, 0]])
50 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
51 | joints = joint
52 | centers = center
53 |
54 | M_i = I[i] * J_w[i,:].T * J_w[i,:]
55 | M += M_i
56 |
57 | V += m[i]*g*center[0,1]
58 |
59 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
60 |
61 | C,G = get_C_G(n,M,V,q,qdot)
62 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
63 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
64 |
65 | with open("./env/pendulum/robot.p", "wb") as outf:
66 | pickle.dump(lambda_dict, outf)
67 |
68 | print("Done")
69 |
70 | if __name__ == '__main__':
71 | derive()
72 |
--------------------------------------------------------------------------------
/env/pendulum/pendulum.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class pendulum(BaseEnv):
10 | def __init__(self):
11 | m1 = 1
12 | l1 = 1
13 | r1 = l1
14 | I1 = m1 * l1**2
15 |
16 | g = 9.8
17 |
18 | m = [m1]
19 | l = [l1]
20 | r = [r1]
21 | I = [I1]
22 | super(pendulum, self).__init__(name = "pendulum",
23 | n = 1,
24 | obs_size = 3,
25 | action_size = 1,
26 | inertials = m+l+r+I+[g],
27 | a_scale = np.array([2.0]))
28 | self.dt = 0.02
29 |
30 | def wrap_state(self):
31 | self.state[0] = wrap(self.state[0])
32 |
33 | def reset_state(self):
34 | self.state = np.array([np.pi+0.01*np.random.randn(),0])
35 |
36 | def get_A(self, a):
37 | return np.clip(a, -1.0, 1.0)*self.a_scale
38 |
39 | def get_obs(self):
40 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]),
41 | self.state[1]
42 | ])
43 |
44 | def get_reward(self):
45 | upright = (np.array([np.cos(self.state[0])])+1)/2
46 |
47 | qdot = self.state[self.n:]
48 | ang_vel = qdot
49 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min()
50 | small_velocity = (1 + small_velocity) / 2
51 |
52 | reward = upright.mean() * small_velocity
53 |
54 | return reward
55 |
56 | def draw(self):
57 | centers, joints, angles = self.geo
58 |
59 | link1_center = (centers[0]+joints[0])/2
60 | link1_points = rect_points(link1_center, self.link_length, self.link_width/2.5, angles[0,0],self.scaling,self.offset)
61 | pygame.draw.polygon(self.screen, self.link_color, link1_points)
62 |
63 | for j in range(self.n):
64 | center_point = [self.offset[0]+self.scaling*centers[j,0],self.offset[1]-self.scaling*centers[j,1]]
65 | pygame.draw.circle(self.screen, 'slateblue1', center_point, self.scaling*self.joint_radius*1.5)
66 |
67 | joint_point = pygame_transform(joints[j],self.scaling,self.offset)
68 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
69 |
70 | if __name__ == '__main__':
71 | basic_check("pendulum",0)
72 |
--------------------------------------------------------------------------------
/env/pendulum/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/pendulum/robot.p
--------------------------------------------------------------------------------
/env/reacher/derive.py:
--------------------------------------------------------------------------------
1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten
2 | from sympy.utilities.lambdify import lambdify
3 | from sympy.matrices.dense import matrix_multiply_elementwise
4 | import dill as pickle
5 | pickle.settings['recurse'] = True
6 |
7 | def get_C_G(n,M,V,q,qdot):
8 | C = zeros(n)
9 | for i in range(n):
10 | for j in range(n):
11 | for k in range(n):
12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2
13 |
14 | G = Matrix([diff(V,q[i]) for i in range(n)])
15 |
16 | return C,G
17 |
18 | def derive():
19 | lambda_dict = {}
20 | n = 2
21 | m1,m2 = symbols('m1,m2')
22 | l1,l2 = symbols('l1,l2')
23 | r1,r2 = symbols('r1,r2')
24 | I1,I2 = symbols('I1,I2')
25 | g = symbols('g')
26 | q1,q2 = symbols('q1,q2')
27 | q1dot,q2dot = symbols('q1dot,q2dot')
28 |
29 | m = [m1,m2]
30 | l = [l1,l2]
31 | r = [r1,r2]
32 | I = [I1,I2]
33 | inertials = m+l+r+I+[g]
34 |
35 | q = Matrix([q1,q2])
36 | qdot = Matrix([q1dot,q2dot])
37 | state = [q1,q2,q1dot,q2dot]
38 |
39 | J_w = Matrix([[1,0],
40 | [1,1]
41 | ])
42 |
43 | angles = J_w * q
44 |
45 | V = 0
46 | M = zeros(n)
47 | J = []
48 | for i in range(n):
49 | if i == 0:
50 | joint = Matrix([[0, 0]])
51 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
52 | joints = joint
53 | centers = center
54 | else:
55 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]])
56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]])
57 | joints = Matrix.vstack(joints, joint)
58 | centers = Matrix.vstack(centers, center)
59 |
60 | J_v = center.jacobian(q)
61 | # J.append(J_v)
62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:]
63 | M += M_i
64 |
65 | V += m[i]*g*center[0,1]
66 |
67 | # print(cse([centers,joints,J_w]+J, optimizations='basic'))
68 |
69 | C,G = get_C_G(n,M,V,q,qdot)
70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True)
71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True)
72 |
73 | with open("./env/reacher/robot.p", "wb") as outf:
74 | pickle.dump(lambda_dict, outf)
75 |
76 | print("Done")
77 |
78 | if __name__ == '__main__':
79 | derive()
80 |
--------------------------------------------------------------------------------
/env/reacher/reacher.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 | from env import rewards
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform
7 | from env.base import BaseEnv
8 |
9 | class reacher(BaseEnv):
10 | def __init__(self):
11 | m1 = 0.1
12 | l1 = 1
13 | r1 = l1/2
14 | I1 = m1 * l1**2 / 12
15 |
16 | m2 = 0.1
17 | l2 = 1
18 | r2 = l2/2
19 | I2 = m2 * l2**2 / 12
20 |
21 | g = 0.0
22 |
23 | m = [m1,m2]
24 | l = [l1,l2]
25 | r = [r1,r2]
26 | I = [I1,I2]
27 | super(reacher, self).__init__(name = "reacher",
28 | n = 2,
29 | obs_size = 6,
30 | action_size = 2,
31 | inertials = m+l+r+I+[g],
32 | a_scale = np.array([0.1,0.1]))
33 | self.dt = 0.02
34 | self.goal_position = np.array([0,2])
35 |
36 | def wrap_state(self):
37 | self.state[:2] = wrap(self.state[:2])
38 |
39 | def reset_state(self):
40 | self.state = np.array([np.pi + 0.01*np.random.randn(),
41 | 0.01*np.random.randn(),
42 | 0,
43 | 0])
44 |
45 | def get_A(self, a):
46 | a_1, a_2 = np.clip(a, -1.0, 1.0)*self.a_scale
47 | return np.array([a_1,a_2])
48 |
49 | def get_obs(self):
50 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]),
51 | np.cos(self.state[1]),np.sin(self.state[1]),
52 | self.state[2],
53 | self.state[3]
54 | ])
55 |
56 | def get_reward(self):
57 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1])])+1)/2
58 |
59 | qdot = self.state[self.n:]
60 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1]])
61 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min()
62 | small_velocity = (1 + small_velocity) / 2
63 |
64 | reward = upright.mean() * small_velocity
65 |
66 | return reward
67 |
68 | def draw(self):
69 | pygame.draw.circle(self.screen,'slateblue1', [self.offset[0]+self.scaling*self.goal_position[0],
70 | self.offset[1]-self.scaling*self.goal_position[1]], self.scaling*self.link_width)
71 |
72 | centers, joints, angles = self.geo
73 |
74 | for j in range(self.n):
75 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset)
76 | pygame.draw.polygon(self.screen, self.link_color, link_points)
77 |
78 | joint_point = pygame_transform(joints[j],self.scaling,self.offset)
79 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius)
80 |
81 |
82 | if __name__ == '__main__':
83 | basic_check("reacher",0)
--------------------------------------------------------------------------------
/env/reacher/robot.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/reacher/robot.p
--------------------------------------------------------------------------------
/env/rewards.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Soft indicator function evaluating whether a number is within bounds."""
17 |
18 | import warnings
19 | import numpy as np
20 |
21 | # The value returned by tolerance() at `margin` distance from `bounds` interval.
22 | _DEFAULT_VALUE_AT_MARGIN = 0.1
23 |
24 |
25 | def _sigmoids(x, value_at_1, sigmoid):
26 | """Returns 1 when `x` == 0, between 0 and 1 otherwise.
27 |
28 | Args:
29 | x: A scalar or numpy array.
30 | value_at_1: A float between 0 and 1 specifying the output when `x` == 1.
31 | sigmoid: String, choice of sigmoid type.
32 |
33 | Returns:
34 | A numpy array with values between 0.0 and 1.0.
35 |
36 | Raises:
37 | ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and
38 | `quadratic` sigmoids which allow `value_at_1` == 0.
39 | ValueError: If `sigmoid` is of an unknown type.
40 | """
41 | if sigmoid in ('cosine', 'linear', 'quadratic'):
42 | if not 0 <= value_at_1 < 1:
43 | raise ValueError('`value_at_1` must be nonnegative and smaller than 1, '
44 | 'got {}.'.format(value_at_1))
45 | else:
46 | if not 0 < value_at_1 < 1:
47 | raise ValueError('`value_at_1` must be strictly between 0 and 1, '
48 | 'got {}.'.format(value_at_1))
49 |
50 | if sigmoid == 'gaussian':
51 | scale = np.sqrt(-2 * np.log(value_at_1))
52 | return np.exp(-0.5 * (x*scale)**2)
53 |
54 | elif sigmoid == 'hyperbolic':
55 | scale = np.arccosh(1/value_at_1)
56 | return 1 / np.cosh(x*scale)
57 |
58 | elif sigmoid == 'long_tail':
59 | scale = np.sqrt(1/value_at_1 - 1)
60 | return 1 / ((x*scale)**2 + 1)
61 |
62 | elif sigmoid == 'reciprocal':
63 | scale = 1/value_at_1 - 1
64 | return 1 / (abs(x)*scale + 1)
65 |
66 | elif sigmoid == 'cosine':
67 | scale = np.arccos(2*value_at_1 - 1) / np.pi
68 | scaled_x = x*scale
69 | with warnings.catch_warnings():
70 | warnings.filterwarnings(
71 | action='ignore', message='invalid value encountered in cos')
72 | cos_pi_scaled_x = np.cos(np.pi*scaled_x)
73 | return np.where(abs(scaled_x) < 1, (1 + cos_pi_scaled_x)/2, 0.0)
74 |
75 | elif sigmoid == 'linear':
76 | scale = 1-value_at_1
77 | scaled_x = x*scale
78 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)
79 |
80 | elif sigmoid == 'quadratic':
81 | scale = np.sqrt(1-value_at_1)
82 | scaled_x = x*scale
83 | return np.where(abs(scaled_x) < 1, 1 - scaled_x**2, 0.0)
84 |
85 | elif sigmoid == 'tanh_squared':
86 | scale = np.arctanh(np.sqrt(1-value_at_1))
87 | return 1 - np.tanh(x*scale)**2
88 |
89 | else:
90 | raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid))
91 |
92 |
93 | def tolerance(x, bounds=(0.0, 0.0), margin=0.0, sigmoid='gaussian',
94 | value_at_margin=_DEFAULT_VALUE_AT_MARGIN):
95 | """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.
96 |
97 | Args:
98 | x: A scalar or numpy array.
99 | bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds for
100 | the target interval. These can be infinite if the interval is unbounded
101 | at one or both ends, or they can be equal to one another if the target
102 | value is exact.
103 | margin: Float. Parameter that controls how steeply the output decreases as
104 | `x` moves out-of-bounds.
105 | * If `margin == 0` then the output will be 0 for all values of `x`
106 | outside of `bounds`.
107 | * If `margin > 0` then the output will decrease sigmoidally with
108 | increasing distance from the nearest bound.
109 | sigmoid: String, choice of sigmoid type. Valid values are: 'gaussian',
110 | 'linear', 'hyperbolic', 'long_tail', 'cosine', 'tanh_squared'.
111 | value_at_margin: A float between 0 and 1 specifying the output value when
112 | the distance from `x` to the nearest bound is equal to `margin`. Ignored
113 | if `margin == 0`.
114 |
115 | Returns:
116 | A float or numpy array with values between 0.0 and 1.0.
117 |
118 | Raises:
119 | ValueError: If `bounds[0] > bounds[1]`.
120 | ValueError: If `margin` is negative.
121 | """
122 | lower, upper = bounds
123 | if lower > upper:
124 | raise ValueError('Lower bound must be <= upper bound.')
125 | if margin < 0:
126 | raise ValueError('`margin` must be non-negative.')
127 |
128 | in_bounds = np.logical_and(lower <= x, x <= upper)
129 | if margin == 0:
130 | value = np.where(in_bounds, 1.0, 0.0)
131 | else:
132 | d = np.where(x < lower, lower - x, x - upper) / margin
133 | value = np.where(in_bounds, 1.0, _sigmoids(d, value_at_margin, sigmoid))
134 |
135 | return float(value) if np.isscalar(x) else value
136 |
137 |
--------------------------------------------------------------------------------
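
A small usage sketch of the tolerance() shaping term, mirroring how the environments call it inside get_reward (the velocity values below are illustrative):

```python
import numpy as np
from env import rewards

# tolerance() returns 1.0 inside `bounds`, decaying to value_at_margin (default 0.1)
# at a distance of `margin` from the nearest bound; bounds default to the exact target (0, 0).
ang_vel = np.array([0.0, 5.0, 25.0])                            # illustrative joint velocities
small_velocity = rewards.tolerance(ang_vel, margin=20.0).min()  # the worst joint dominates
small_velocity = (1 + small_velocity) / 2                       # rescale to [0.5, 1] as in get_reward
print(small_velocity)                                           # ~0.51 for these values
```
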
/env/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
3 | import pygame
4 | import numpy as np
5 |
6 | def create_background(length, width):
7 | background = pygame.Surface((length, width))
8 | pygame.draw.rect(background, (0,0,0), pygame.Rect(0, 0, length, width))
9 | return background
10 |
11 | def rect_points(center, length, width, ang, scaling, offset):
12 | points = []
13 | diag = np.sqrt(length**2+width**2)/2
14 | ang1 = 2*np.arctan2(width,length)
15 | ang2 = 2*np.arctan2(length,width)
16 |
17 | points.append((center[0]+np.sin(ang+ang1/2)*diag, center[1]+np.cos(ang+ang1/2)*diag))
18 |
19 | points.append((center[0]+np.sin(ang+ang1/2+ang2)*diag, center[1]+np.cos(ang+ang1/2+ang2)*diag))
20 |
21 | points.append((center[0]+np.sin(ang+ang1*1.5+ang2)*diag, center[1]+np.cos(ang+ang1*1.5+ang2)*diag))
22 |
23 | points.append((center[0]+np.sin(ang+ang1*1.5+2*ang2)*diag, center[1]+np.cos(ang+ang1*1.5+2*ang2)*diag))
24 |
25 | return [pygame_transform(point, scaling, offset) for point in points]
26 |
27 | def pygame_transform(point, scaling, offset):
28 | # Pygame's y axis points downwards. Hence invert y coordinate alone before offset.
29 | return (offset[0]+scaling*point[0],offset[1]-scaling*point[1])
30 |
31 | def wrap(x):
32 | return ((x + np.pi) % (2 * np.pi)) - np.pi
33 |
34 | def make_env(name):
35 |
36 | if name == "pendulum":
37 | from env.pendulum.pendulum import pendulum as Env
38 |
39 | elif name == "reacher":
40 | from env.reacher.reacher import reacher as Env
41 |
42 | elif name == "cartpole":
43 | from env.cartpole.cartpole import cartpole as Env
44 |
45 | elif name == "acrobot":
46 | from env.acrobot.acrobot import acrobot as Env
47 |
48 | elif name == "cart2pole":
49 | from env.cart2pole.cart2pole import cart2pole as Env
50 |
51 | elif name == "acro3bot":
52 | from env.acro3bot.acro3bot import acro3bot as Env
53 |
54 | elif name == "cart3pole":
55 | from env.cart3pole.cart3pole import cart3pole as Env
56 |
57 | env = Env()
58 |
59 | return env
60 |
61 | def basic_check(name, seed):
62 | np.random.seed(seed)
63 | env = make_env(name)
64 | no_episodes = 1
65 | for episode in range(no_episodes):
66 | t = 0
67 | o_t, _, _ = env.reset()
68 | ep_r = 0
69 | while True:
70 | env.render()
71 | a_t = np.random.uniform(-1,1,env.action_size)
72 | o_t_1, r_t, done = env.step(a_t)
73 | ep_r += r_t
74 | t += 1
75 | o_t = o_t_1
76 | if done:
77 | print("Episode finished with total reward ",ep_r,"time steps",t)
78 | break
79 |
80 |
81 |
--------------------------------------------------------------------------------
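
A minimal sketch of driving one of the environments directly through make_env, essentially basic_check without rendering. It assumes the interpreter is started from the repository root so that the ./env/<name>/robot.p pickles resolve:

```python
import numpy as np
from env.utils import make_env

env = make_env("cartpole")
obs, reward, done = env.reset()     # reset() returns (observation, 0.0, False)
ep_r = 0.0
while not done:                     # episodes end after env.t_max simulation steps
    action = np.random.uniform(-1, 1, env.action_size)
    obs, reward, done = env.step(action)
    ep_r += reward
print("episode return:", ep_r)
```
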
/mbrl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from copy import deepcopy
4 | import random
5 | import numpy as np
6 | import torch
7 | torch.set_default_dtype(torch.float64)
8 | from torch.utils.tensorboard import SummaryWriter
9 | from env.utils import make_env
10 | from models.mbrl import ReplayBuffer, V_FC, Pi_FC, dnn, lnn, reward_model_FC
11 |
12 | # Model-Based RL algorithm
13 | class MBRL:
14 | def __init__(self, arglist):
15 | self.arglist = arglist
16 |
17 | random.seed(self.arglist.seed)
18 | np.random.seed(self.arglist.seed)
19 | torch.manual_seed(self.arglist.seed)
20 |
21 | self.env = make_env(self.arglist.env)
22 |
23 | self.device = torch.device("cpu")
24 |
25 | path = "./log/"+self.env.name+"/mbrl_"+self.arglist.model
26 | self.exp_dir = os.path.join(path, "seed_"+str(self.arglist.seed))
27 | self.model_dir = os.path.join(self.exp_dir, "models")
28 | self.tensorboard_dir = os.path.join(self.exp_dir, "tensorboard")
29 |
30 | self.actor = Pi_FC(self.env.obs_size,self.env.action_size).to(self.device)
31 |
32 | if self.arglist.mode == "train":
33 | self.critic = V_FC(self.env.obs_size).to(self.device)
34 | self.critic_target = deepcopy(self.critic)
35 |
36 | if self.arglist.model == "lnn":
37 | if self.env.action_size < self.env.n:
38 | a_zeros = torch.zeros(self.arglist.batch_size,self.env.n-self.env.action_size, dtype=torch.float64, device=self.device)
39 | else:
40 | a_zeros = None
41 | self.transition_model = lnn(self.env.name, self.env.n, self.env.obs_size, self.env.action_size, self.env.dt, a_zeros).to(self.device)
42 |
43 | elif self.arglist.model == "dnn":
44 | self.transition_model = dnn(self.env.obs_size, self.env.action_size).to(self.device)
45 | self.transition_loss_fn = torch.nn.L1Loss()
46 |
47 | self.reward_model = reward_model_FC(self.env.obs_size).to(self.device)
48 | self.reward_loss_fn = torch.nn.L1Loss()
49 |
50 | if self.arglist.resume:
51 | checkpoint = torch.load(os.path.join(self.model_dir,"emergency.ckpt"))
52 | self.start_episode = checkpoint['episode'] + 1
53 |
54 | self.actor.load_state_dict(checkpoint['actor'])
55 | self.critic.load_state_dict(checkpoint['critic'])
56 | self.critic_target.load_state_dict(checkpoint['critic_target'])
57 | self.transition_model.load_state_dict(checkpoint['transition_model'])
58 | self.reward_model.load_state_dict(checkpoint['reward_model'])
59 |
60 | self.replay_buffer = checkpoint['replay_buffer']
61 |
62 | else:
63 | self.start_episode = 0
64 |
65 | self.replay_buffer = ReplayBuffer(self.arglist.replay_size, self.device)
66 |
67 | # Create the experiment directories if they do not already exist
68 | if not os.path.exists(path):
69 | os.makedirs(path)
70 | os.mkdir(self.exp_dir)
71 | os.mkdir(self.tensorboard_dir)
72 | os.mkdir(self.model_dir)
73 |
74 |
75 | self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(), lr=self.arglist.lr)
76 | self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=self.arglist.lr)
77 | self.transition_optimizer = torch.optim.AdamW(self.transition_model.parameters(), lr=self.arglist.lr)
78 | self.reward_optimizer = torch.optim.AdamW(self.reward_model.parameters(), lr=self.arglist.lr)
79 |
80 | if self.arglist.resume:
81 | self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
82 | self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
83 | self.transition_optimizer.load_state_dict(checkpoint['transition_optimizer'])
84 | self.reward_optimizer.load_state_dict(checkpoint['reward_optimizer'])
85 |
86 | print("Done loading checkpoint ...")
87 |
88 | self.a_scale = torch.tensor(self.env.a_scale,dtype=torch.float64,device=self.device)
89 |
90 | self.train()
91 |
92 | elif self.arglist.mode == "eval":
93 | checkpoint = torch.load(self.arglist.checkpoint,map_location=self.device)
94 | self.actor.load_state_dict(checkpoint['actor'])
95 | ep_r_list = self.eval(self.arglist.episodes,self.arglist.render)
96 |
97 | def save_checkpoint(self, name):
98 | checkpoint = {'actor' : self.actor.state_dict(),\
99 | 'critic' : self.critic.state_dict(),\
100 | 'transition_model' : self.transition_model.state_dict(),\
101 | 'reward_model' : self.reward_model.state_dict()
102 | }
103 | torch.save(checkpoint, os.path.join(self.model_dir, name))
104 |
105 | def save_emergency_checkpoint(self, episode):
106 | checkpoint = {'episode' : episode,\
107 | 'actor' : self.actor.state_dict(),\
108 | 'actor_optimizer': self.actor_optimizer.state_dict(),\
109 | 'critic' : self.critic.state_dict(),\
110 | 'critic_optimizer': self.critic_optimizer.state_dict(),\
111 | 'critic_target' : self.critic_target.state_dict(),\
112 | 'transition_model' : self.transition_model.state_dict(),\
113 | 'transition_optimizer': self.transition_optimizer.state_dict(),\
114 | 'reward_model' : self.reward_model.state_dict(),\
115 | 'reward_optimizer': self.reward_optimizer.state_dict(),\
116 | 'replay_buffer' : self.replay_buffer \
117 | }
118 | torch.save(checkpoint, os.path.join(self.model_dir, "emergency.ckpt"))
119 |
120 | def hard_update(self, target, source):
121 | with torch.no_grad():
122 | for target_param, param in zip(target.parameters(), source.parameters()):
123 | target_param.data.copy_(param.data)
124 |
125 | def train(self):
126 | critic_target_updates = 0
127 |
128 | writer = SummaryWriter(log_dir=self.tensorboard_dir)
129 |
130 | if not self.arglist.resume:
131 | # Initialize replay buffer with K random episodes
132 | for episode in range(self.arglist.K):
133 | o,_,_ = self.env.reset()
134 | o_tensor = torch.tensor(o, dtype=torch.float64, device=self.device)
135 | ep_r = 0
136 | while True:
137 | a = np.random.uniform(-1.0, 1.0, size=self.env.action_size)
138 | o_1,r,done = self.env.step(a)
139 | a_tensor = torch.tensor(a, dtype=torch.float64, device=self.device)
140 | o_1_tensor = torch.tensor(o_1, dtype=torch.float64, device=self.device)
141 | r_tensor = torch.tensor(r, dtype=torch.float64, device=self.device)
142 | self.replay_buffer.push(o_tensor, a_tensor, r_tensor, o_1_tensor)
143 | ep_r += r
144 | o_tensor = o_1_tensor
145 | if done:
146 | break
147 |
148 | print("Done initialization ...")
149 | print("Started training ...")
150 |
151 | for episode in range(self.start_episode,self.arglist.episodes):
152 | # Model learning
153 | transition_loss_list, reward_loss_list = [], []
154 | transition_grad_list, reward_grad_list = [], []
155 | for model_batches in range(self.arglist.model_batches):
156 | O, A, R, O_1 = self.replay_buffer.sample_transitions(self.arglist.batch_size)
157 |
158 | # Dynamics learning
159 | O_1_pred = self.transition_model(O,A*self.a_scale)
160 | transition_loss = self.transition_loss_fn(O_1_pred, O_1)
161 | self.transition_optimizer.zero_grad()
162 | transition_loss.backward()
163 | torch.nn.utils.clip_grad_norm_(self.transition_model.parameters(), self.arglist.clip_term)
164 | self.transition_optimizer.step()
165 | transition_loss_list.append(transition_loss.item())
166 | transition_grad = []
167 | for param in self.transition_model.parameters():
168 | if param.grad is not None:
169 | transition_grad.append(param.grad.flatten())
170 | transition_grad_list.append(torch.norm(torch.cat(transition_grad)).item())
171 |
172 | # Reward learning
173 | R_pred = self.reward_model(O_1)
174 | reward_loss = self.reward_loss_fn(R_pred,R)
175 | self.reward_optimizer.zero_grad()
176 | reward_loss.backward()
177 | torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.arglist.clip_term)
178 | self.reward_optimizer.step()
179 | reward_loss_list.append(reward_loss.item())
180 | reward_grad_list.append(torch.norm(torch.cat([param.grad.flatten() for param in self.reward_model.parameters()])).item())
181 |
182 | writer.add_scalar('transition_loss', np.mean(transition_loss_list), episode)
183 | writer.add_scalar('reward_loss', np.mean(reward_loss_list), episode)
184 | writer.add_scalar('transition_grad',np.mean(transition_grad_list),episode)
185 | writer.add_scalar('reward_grad',np.mean(reward_grad_list),episode)
186 |
187 | # Behaviour learning
188 | actor_loss_list, critic_loss_list = [], []
189 | actor_grad_list, critic_grad_list = [], []
190 |
191 | nan_count = 0
192 | for behaviour_batches in range(self.arglist.behaviour_batches):
193 | O = self.replay_buffer.sample_states(self.arglist.batch_size)
194 | t = 0
195 | values, values_target, values_lambda, R = [], [], [], []
196 | log_probs = []
197 | try:
198 | while True:
199 | A, log_prob = self.actor(O, False, True)
200 | log_probs.append(log_prob)
201 | O_1 = self.transition_model(O, A*self.a_scale)
202 | R.append(self.reward_model(O_1))
203 | values.append(self.critic(O))
204 | values_target.append(self.critic_target(O))
205 | t += 1
206 | O = O_1
207 | if t % self.arglist.T == 0:
208 | values_target.append(self.critic_target(O_1))
209 | break
210 |
211 | # lambda-return calculation
212 | gae = torch.zeros_like(R[0])
213 | for t_ in reversed(range(self.arglist.T)):
214 | delta = R[t_]+self.arglist.gamma*values_target[t_+1]-values_target[t_]
215 | gae = delta+self.arglist.gamma*self.arglist.Lambda*gae
216 | values_lambda.append(gae+values_target[t_])
217 | values_lambda = torch.stack(values_lambda)
218 | values_lambda = values_lambda.flip(0)
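# Note: the backward recursion above computes GAE-style lambda-returns over the
# imagined rollout. With delta_t = r_t + gamma*V'(o_{t+1}) - V'(o_t) and
# A_t = delta_t + gamma*lambda*A_{t+1}, the target is
# V^lambda_t = A_t + V'(o_t) = r_t + gamma*((1-lambda)*V'(o_{t+1}) + lambda*V^lambda_{t+1}),
# where V' is the target critic. The list is filled from t = T-1 down to 0,
# hence the flip(0) to restore time order.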
219 |
220 | values = torch.stack(values)
221 | critic_loss = 0.5*torch.pow(values-values_lambda.detach(),2).sum(0).mean()
222 |
223 | log_probs = torch.stack(log_probs)
224 | actor_loss = - (values_lambda-0.0001*log_probs).sum(0).mean()
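# Note: the critic regresses V(o_t) onto the detached lambda-returns, while the
# actor maximizes the lambda-returns plus a small entropy-style bonus
# (-1e-4 * log_prob) by backpropagating directly through the learned,
# differentiable transition and reward models, i.e. an analytic policy gradient
# over imagined trajectories.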
225 |
226 | self.critic_optimizer.zero_grad()
227 | critic_loss.backward(inputs=[param for param in self.critic.parameters()])
228 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.arglist.clip_term)
229 |
230 | self.actor_optimizer.zero_grad()
231 | actor_loss.backward(inputs=[param for param in self.actor.parameters()])
232 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.arglist.clip_term)
233 |
234 | critic_grad = torch.norm(torch.cat([param.grad.flatten() for param in self.critic.parameters()]))
235 | actor_grad = torch.norm(torch.cat([param.grad.flatten() for param in self.actor.parameters()]))
236 |
237 | if torch.isnan(critic_grad).any().item() or torch.isnan(actor_grad).any().item():
238 | nan_count += 1
239 | else:
240 | self.critic_optimizer.step()
241 | self.actor_optimizer.step()
242 |
243 | critic_target_updates = (critic_target_updates+1)%100
244 | if critic_target_updates == 0:
245 | self.hard_update(self.critic_target, self.critic)
246 |
247 | actor_loss_list.append(actor_loss.item())
248 | critic_loss_list.append(critic_loss.item())
249 | actor_grad_list.append(actor_grad.item())
250 | critic_grad_list.append(critic_grad.item())
251 |
252 | except Exception:  # numerical failures (e.g. NaNs) during imagined rollouts
253 | nan_count += 1
254 |
255 | if nan_count > 0:
256 | print("episode",episode,"got nan during behaviour learning","nan count",nan_count)
257 | writer.add_scalar('critic_loss',np.mean(critic_loss_list),episode)
258 | writer.add_scalar('actor_loss',np.mean(actor_loss_list),episode)
259 | writer.add_scalar('critic_grad',np.mean(critic_grad_list),episode)
260 | writer.add_scalar('actor_grad',np.mean(actor_grad_list),episode)
261 |
262 | # Environment Interaction
263 | o,_,_ = self.env.reset()
264 | o_tensor = torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0)
265 | ep_r = 0
266 | while True:
267 | with torch.no_grad():
268 | try:
269 | a_tensor, _ = self.actor(o_tensor)
270 | except Exception:
271 | print("episode",episode,"got nan during environment interaction")
272 | break
273 | o_1,r,done = self.env.step(a_tensor.cpu().numpy()[0])
274 | o_1_tensor = torch.tensor(o_1, dtype=torch.float64, device=self.device).unsqueeze(0)
275 | r_tensor = torch.tensor(r, dtype=torch.float64, device=self.device)
276 | self.replay_buffer.push(o_tensor[0], a_tensor[0], r_tensor, o_1_tensor[0])
277 | ep_r += r
278 | o_tensor = o_1_tensor
279 | if done:
280 | writer.add_scalar('ep_r', ep_r, episode)
281 | if episode % self.arglist.eval_every == 0 or episode == self.arglist.episodes-1:
282 | try:
283 | # Evaluate agent performance
284 | eval_ep_r_list = self.eval(self.arglist.eval_over)
285 | writer.add_scalar('eval_ep_r', np.mean(eval_ep_r_list), episode)
286 | self.save_checkpoint(str(episode)+".ckpt")
287 | except Exception:
288 | print("episode",episode,"got nan during eval")
289 | if (episode % 25 == 0 or episode == self.arglist.episodes-1) and episode > self.start_episode:
290 | self.save_emergency_checkpoint(episode)
291 | break
292 |
293 | def eval(self, episodes, render=False):
294 | # Evaluate agent performance over several episodes
295 | ep_r_list = []
296 | for episode in range(episodes):
297 | o,_,_ = self.env.reset()
298 | ep_r = 0
299 | while True:
300 | with torch.no_grad():
301 | a, _ = self.actor(torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0),True)
302 | a = a.cpu().numpy()[0]
303 | o_1,r,done = self.env.step(a)
304 | if render:
305 | self.env.render()
306 | ep_r += r
307 | o = o_1
308 | if done:
309 | ep_r_list.append(ep_r)
310 | if render:
311 | print("Episode finished with total reward ",ep_r)
312 | break
313 |
314 | if self.arglist.mode == "eval":
315 | print("Average return :",np.mean(ep_r_list))
316 |
317 | return ep_r_list
318 |
319 | def parse_args():
320 | parser = argparse.ArgumentParser("Model-Based Reinforcement Learning")
321 | # Common settings
322 | parser.add_argument("--env", type=str, default="acrobot", help="pendulum / reacher / cartpole / acrobot / cart2pole / acro3bot / cart3pole")
323 | parser.add_argument("--mode", type=str, default="train", help="train or eval")
324 | parser.add_argument("--episodes", type=int, default=500, help="number of episodes to run experiment for")
325 | parser.add_argument("--seed", type=int, default=0, help="seed")
326 | # Core training parameters
327 | parser.add_argument("--resume", action="store_true", default=False, help="continue training from checkpoint")
328 | parser.add_argument("--model", type=str, default="lnn", help="lnn / dnn")
329 | parser.add_argument("--T", type=int, default=16, help="imagination horizon")
330 | parser.add_argument("--K", type=int, default=10, help="init replay buffer with K random episodes")
331 | parser.add_argument("--lr", type=float, default=3e-4, help="learning rate")
332 | parser.add_argument("--clip-term", type=float, default=100, help="gradient clipping norm")
333 | parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
334 | parser.add_argument("--Lambda", type=float, default=0.95, help="GAE lambda")
335 | parser.add_argument("--batch-size", type=int, default=64, help="batch size for model learning, behaviour learning")
336 | parser.add_argument("--model-batches", type=int, default=int(1e4), help="model batches per episode")
337 | parser.add_argument("--behaviour-batches", type=int, default=int(1e3), help="behaviour batches per episode")
338 | parser.add_argument("--replay-size", type=int, default=int(1e5), help="replay buffer size")
339 | parser.add_argument("--eval-every", type=int, default=5, help="eval every _ episodes during training")
340 | parser.add_argument("--eval-over", type=int, default=50, help="each time eval over _ episodes")
341 | # Eval settings
342 | parser.add_argument("--checkpoint", type=str, default="", help="path to checkpoint")
343 | parser.add_argument("--render", action="store_true", default=False, help="render")
344 | return parser.parse_args()
345 |
346 | if __name__ == '__main__':
347 | arglist = parse_args()
348 | mbrl = MBRL(arglist)
349 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/models/__init__.py
--------------------------------------------------------------------------------
/models/mbrl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | torch.set_default_dtype(torch.float64)
4 | import torch.nn.functional as F
5 | from torch.distributions.normal import Normal
6 | from torch.func import jacrev
7 | from collections import deque
8 | import random
9 |
10 | # Experience replay buffer
11 | class ReplayBuffer:
12 | def __init__(self, capacity, device):
13 | self.buffer = deque(maxlen=capacity)
14 | self.device = device
15 |
16 | def push(self, o, a, r, o_1):
17 | self.buffer.append((o, a, r, o_1))
18 |
19 | def sample_transitions(self, batch_size):
20 | O, A, R, O_1 = zip(*random.sample(self.buffer, batch_size))
21 | return torch.stack(O), torch.stack(A), torch.stack(R), torch.stack(O_1)
22 |
23 | def sample_states(self, batch_size):
24 | O, _, _, _ = zip(*random.sample(self.buffer, batch_size))
25 | return torch.stack(O)
26 |
27 | def __len__(self):
28 | return len(self.buffer)
29 |
30 | # Critic network
31 | class V_FC(torch.nn.Module):
32 | def __init__(self, obs_size):
33 | super(V_FC, self).__init__()
34 | self.fc1 = torch.nn.Linear(obs_size, 256)
35 | self.fc2 = torch.nn.Linear(256, 256)
36 | self.fc3 = torch.nn.Linear(256, 1)
37 |
38 | def forward(self, x):
39 | y1 = F.relu(self.fc1(x))
40 | y2 = F.relu(self.fc2(y1))
41 | y = self.fc3(y2).view(-1)
42 | return y
43 |
44 | # Actor network
45 | class Pi_FC(torch.nn.Module):
46 | def __init__(self, obs_size, action_size):
47 | super(Pi_FC, self).__init__()
48 | self.fc1 = torch.nn.Linear(obs_size, 256)
49 | self.fc2 = torch.nn.Linear(256, 256)
50 | self.mu = torch.nn.Linear(256, action_size)
51 | self.log_sigma = torch.nn.Linear(256, action_size)
52 |
53 | def forward(self, x, deterministic=False, with_logprob=False):
54 | y1 = F.relu(self.fc1(x))
55 | y2 = F.relu(self.fc2(y1))
56 | mu = self.mu(y2)
57 |
58 | if deterministic:
59 | # used for evaluating policy
60 | action = torch.tanh(mu)
61 | log_prob = None
62 | else:
63 | log_sigma = self.log_sigma(y2)
64 | log_sigma = torch.clamp(log_sigma,min=-20.0,max=2.0)
65 | sigma = torch.exp(log_sigma)
66 | dist = Normal(mu, sigma)
67 | x_t = dist.rsample()
68 | action = torch.tanh(x_t)
69 | if with_logprob:
70 | log_prob = dist.log_prob(x_t).sum(1)
71 | log_prob -= torch.log(torch.clamp(1-action.pow(2),min=1e-6)).sum(1)
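# Note: this is the change-of-variables correction for the tanh squashing: if
# a = tanh(x) with x ~ Normal(mu, sigma), then
# log pi(a) = log N(x; mu, sigma) - sum_i log(1 - tanh(x_i)^2);
# the clamp guards against log(0) when the action saturates.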
72 | else:
73 | log_prob = None
74 |
75 | return action, log_prob
76 |
77 |
78 | # DNN dynamics model
79 | class dnn(torch.nn.Module):
80 | def __init__(self, obs_size, action_size):
81 | super(dnn, self).__init__()
82 | self.fc1 = torch.nn.Linear(obs_size+action_size, 64)
83 | self.fc2 = torch.nn.Linear(64, 64)
84 | self.fc3 = torch.nn.Linear(64, obs_size)
85 |
86 | def forward(self, x, a):
87 | y1 = F.relu(self.fc1(torch.cat((x,a),1)))
88 | y2 = F.relu(self.fc2(y1))
89 | y = self.fc3(y2)
90 | return y
91 |
92 | # LNN dynamics model
93 | class lnn(torch.nn.Module):
94 | def __init__(self, env_name, n, obs_size, action_size, dt, a_zeros):
95 | super(lnn, self).__init__()
96 | self.env_name = env_name
97 | self.dt = dt
98 | self.n = n
99 |
100 | input_size = obs_size - self.n
101 | out_L = int(self.n*(self.n+1)/2)
102 | self.fc1_L = torch.nn.Linear(input_size, 64)
103 | self.fc2_L = torch.nn.Linear(64, 64)
104 | self.fc3_L = torch.nn.Linear(64, out_L)
105 | if not self.env_name == "reacher":
106 | self.fc1_V = torch.nn.Linear(input_size, 64)
107 | self.fc2_V = torch.nn.Linear(64, 64)
108 | self.fc3_V = torch.nn.Linear(64, 1)
109 |
110 | self.a_zeros = a_zeros
111 |
112 | def trig_transform_q(self, q):
113 | if self.env_name == "pendulum":
114 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0])))
115 |
116 | elif self.env_name == "reacher" or self.env_name == "acrobot":
117 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\
118 | torch.cos(q[:,1]),torch.sin(q[:,1])))
119 |
120 | elif self.env_name == "cartpole":
121 | return torch.column_stack((q[:,0],\
122 | torch.cos(q[:,1]),torch.sin(q[:,1])))
123 |
124 | elif self.env_name == "cart2pole":
125 | return torch.column_stack((q[:,0],\
126 | torch.cos(q[:,1]),torch.sin(q[:,1]),\
127 | torch.cos(q[:,2]),torch.sin(q[:,2])))
128 |
129 | elif self.env_name == "cart3pole":
130 | return torch.column_stack((q[:,0],\
131 | torch.cos(q[:,1]),torch.sin(q[:,1]),\
132 | torch.cos(q[:,2]),torch.sin(q[:,2]),\
133 | torch.cos(q[:,3]),torch.sin(q[:,3])))
134 |
135 | elif self.env_name == "acro3bot":
136 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\
137 | torch.cos(q[:,1]),torch.sin(q[:,1]),\
138 | torch.cos(q[:,2]),torch.sin(q[:,2])))
139 |
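# Note: trig_transform_q maps generalized coordinates q to the observation
# encoding used throughout: each revolute joint angle becomes a (cos q, sin q)
# pair while prismatic coordinates (cart position) pass through unchanged, e.g.
# for cartpole the encoded q is (x, cos theta, sin theta). The method below
# inverts this with atan2 so the model can integrate in (q, qdot) coordinates.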
140 | def inverse_trig_transform_model(self, x):
141 | if self.env_name == "pendulum":
142 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),x[:,2:]),1)
143 |
144 | elif self.env_name == "reacher" or self.env_name == "acrobot":
145 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),x[:,4:]),1)
146 |
147 | elif self.env_name == "cartpole":
148 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),x[:,3:]),1)
149 |
150 | elif self.env_name == "cart2pole":
151 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1),x[:,5:]),1)
152 |
153 | elif self.env_name == "cart3pole":
154 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1),
155 | torch.atan2(x[:,6],x[:,5]).unsqueeze(1),x[:,7:]),1)
156 |
157 | elif self.env_name == "acro3bot":
158 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),torch.atan2(x[:,5],x[:,4]).unsqueeze(1),
159 | x[:,6:]),1)
160 |
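# Note: compute_L below predicts the n*(n+1)/2 entries of a lower-triangular
# matrix L(q) and assembles it so that the mass matrix is M(q) = L @ L.T, which
# is symmetric positive semi-definite by construction. For example, with n = 2
# the 3 outputs fill [[L11, 0], [L21, L22]].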
161 | def compute_L(self, q):
162 | y1_L = F.softplus(self.fc1_L(q))
163 | y2_L = F.softplus(self.fc2_L(y1_L))
164 | y_L = self.fc3_L(y2_L)
165 | device = y_L.device
166 | if self.n == 1:
167 | L = y_L.unsqueeze(1)
168 |
169 | elif self.n == 2:
170 | L11 = y_L[:,0].unsqueeze(1)
171 | L1_zeros = torch.zeros(L11.size(0),1, dtype=torch.float64, device=device)
172 |
173 | L21 = y_L[:,1].unsqueeze(1)
174 | L22 = y_L[:,2].unsqueeze(1)
175 |
176 | L1 = torch.cat((L11,L1_zeros),1)
177 | L2 = torch.cat((L21,L22),1)
178 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1)),1)
179 |
180 | elif self.n == 3:
181 | L11 = y_L[:,0].unsqueeze(1)
182 | L1_zeros = torch.zeros(L11.size(0),2, dtype=torch.float64, device=device)
183 |
184 | L21 = y_L[:,1].unsqueeze(1)
185 | L22 = y_L[:,2].unsqueeze(1)
186 | L2_zeros = torch.zeros(L21.size(0),1, dtype=torch.float64, device=device)
187 |
188 | L31 = y_L[:,3].unsqueeze(1)
189 | L32 = y_L[:,4].unsqueeze(1)
190 | L33 = y_L[:,5].unsqueeze(1)
191 |
192 | L1 = torch.cat((L11,L1_zeros),1)
193 | L2 = torch.cat((L21,L22,L2_zeros),1)
194 | L3 = torch.cat((L31,L32,L33),1)
195 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1)),1)
196 |
197 | elif self.n == 4:
198 | L11 = y_L[:,0].unsqueeze(1)
199 | L1_zeros = torch.zeros(L11.size(0),3, dtype=torch.float64, device=device)
200 |
201 | L21 = y_L[:,1].unsqueeze(1)
202 | L22 = y_L[:,2].unsqueeze(1)
203 | L2_zeros = torch.zeros(L21.size(0),2, dtype=torch.float64, device=device)
204 |
205 | L31 = y_L[:,3].unsqueeze(1)
206 | L32 = y_L[:,4].unsqueeze(1)
207 | L33 = y_L[:,5].unsqueeze(1)
208 | L3_zeros = torch.zeros(L31.size(0),1, dtype=torch.float64, device=device)
209 |
210 | L41 = y_L[:,6].unsqueeze(1)
211 | L42 = y_L[:,7].unsqueeze(1)
212 | L43 = y_L[:,8].unsqueeze(1)
213 | L44 = y_L[:,9].unsqueeze(1)
214 |
215 | L1 = torch.cat((L11,L1_zeros),1)
216 | L2 = torch.cat((L21,L22,L2_zeros),1)
217 | L3 = torch.cat((L31,L32,L33,L3_zeros),1)
218 | L4 = torch.cat((L41,L42,L43,L44),1)
219 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1),L4.unsqueeze(1)),1)
220 |
221 | return L
222 |
223 | def get_A(self, a):
224 | if self.env_name == "pendulum" or self.env_name == "reacher":
225 | A = a
226 |
227 | elif self.env_name == "acrobot":
228 | A = torch.cat((self.a_zeros,a),1)
229 |
230 | elif self.env_name == "cartpole" or self.env_name == "cart2pole":
231 | A = torch.cat((a,self.a_zeros),1)
232 |
233 | elif self.env_name == "cart3pole" or self.env_name == "acro3bot":
234 | A = torch.cat((a[:,:1],self.a_zeros,a[:,1:]),1)
235 |
236 | return A
237 |
238 | def get_L(self, q):
239 | trig_q = self.trig_transform_q(q)
240 | L = self.compute_L(trig_q)
241 | return L.sum(0), L
242 |
243 | def get_V(self, q):
244 | trig_q = self.trig_transform_q(q)
245 | y1_V = F.softplus(self.fc1_V(trig_q))
246 | y2_V = F.softplus(self.fc2_V(y1_V))
247 | V = self.fc3_V(y2_V).squeeze()
248 | return V.sum()
249 |
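# Note: get_acc below implements the manipulator (Euler-Lagrange) equation
# M(q)*qddot + c(q, qdot) + dV/dq = A, i.e. qddot = M(q)^-1 * (A - c - dV/dq),
# with M = L @ L.T, M^-1 obtained via torch.cholesky_inverse(L), and the
# Coriolis/centrifugal term
# c_i = sum_{j,k} (dM_ik/dq_j - 0.5*dM_jk/dq_i) * qdot_j * qdot_k.
# The potential-energy gradient dV/dq is omitted for the reacher, which has no
# V network.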
250 | def get_acc(self, q, qdot, a):
251 | dL_dq, L = jacrev(self.get_L, has_aux=True)(q)
252 | term_1 = torch.einsum('blk,bijk->bijl', L, dL_dq.permute(2,3,0,1))
253 | dM_dq = term_1 + term_1.transpose(2,3)
254 | c = torch.einsum('bjik,bk,bj->bi', dM_dq, qdot, qdot) - 0.5 * torch.einsum('bikj,bk,bj->bi', dM_dq, qdot, qdot)
255 | Minv = torch.cholesky_inverse(L)
256 | dV_dq = 0 if self.env_name == "reacher" else jacrev(self.get_V)(q)
257 | qddot = torch.matmul(Minv,(self.get_A(a)-c-dV_dq).unsqueeze(2)).squeeze(2)
258 | return qddot
259 |
260 | def derivs(self, s, a):
261 | q, qdot = s[:,:self.n], s[:,self.n:]
262 | qddot = self.get_acc(q, qdot, a)
263 | return torch.cat((qdot,qddot),dim=1)
264 |
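# Note: rk2 below is one step of Ralston's second-order Runge-Kutta method:
# k1 = f(s), k2 = f(s + alpha*dt*k1),
# s' = s + dt*((1 - 1/(2*alpha))*k1 + (1/(2*alpha))*k2),
# which for alpha = 2/3 gives s' = s + dt*(k1/4 + 3*k2/4).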
265 | def rk2(self, s, a):
266 | alpha = 2.0/3.0 # Ralston's method
267 | k1 = self.derivs(s, a)
268 | k2 = self.derivs(s + alpha * self.dt * k1, a)
269 | s_1 = s + self.dt * ((1.0 - 1.0/(2.0*alpha))*k1 + (1.0/(2.0*alpha))*k2)
270 | return s_1
271 |
272 | def forward(self, o, a):
273 | s_1 = self.rk2(self.inverse_trig_transform_model(o), a)
274 | o_1 = torch.cat((self.trig_transform_q(s_1[:,:self.n]),s_1[:,self.n:]),1)
275 | return o_1
276 |
277 | # Reward model
278 | class reward_model_FC(torch.nn.Module):
279 | def __init__(self, obs_size):
280 | super(reward_model_FC, self).__init__()
281 | self.fc1 = torch.nn.Linear(obs_size, 64)
282 | self.fc2 = torch.nn.Linear(64, 64)
283 | self.fc3 = torch.nn.Linear(64, 1)
284 |
285 | def forward(self, x):
286 | y1 = F.relu(self.fc1(x))
287 | y2 = F.relu(self.fc2(y1))
288 | y = self.fc3(y2).view(-1)
289 | return y
290 |
--------------------------------------------------------------------------------
/models/sac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | torch.set_default_dtype(torch.float64)
4 | import torch.nn.functional as F
5 | from torch.distributions.normal import Normal
6 | from collections import deque
7 | import random
8 |
9 | # Experience replay buffer
10 | class ReplayBuffer:
11 | def __init__(self, capacity, device):
12 | self.buffer = deque(maxlen=capacity)
13 | self.device = device
14 |
15 | def push(self, o, a, r, o_1):
16 | self.buffer.append((o, a, r, o_1))
17 |
18 | def sample(self, batch_size):
19 | O, A, R, O_1 = zip(*random.sample(self.buffer, batch_size))
20 | return torch.tensor(np.array(O), dtype=torch.float64, device=self.device),\
21 | torch.tensor(np.array(A), dtype=torch.float64, device=self.device),\
22 | torch.tensor(np.array(R), dtype=torch.float64, device=self.device),\
23 | torch.tensor(np.array(O_1), dtype=torch.float64, device=self.device)
24 |
25 | def __len__(self):
26 | return len(self.buffer)
27 |
28 | # Critic network
29 | class Q_FC(torch.nn.Module):
30 | def __init__(self, obs_size, action_size):
31 | super(Q_FC, self).__init__()
32 | self.fc1 = torch.nn.Linear(obs_size+action_size, 256)
33 | self.fc2 = torch.nn.Linear(256, 256)
34 | self.fc3 = torch.nn.Linear(256, 1)
35 |
36 | def forward(self, x, a):
37 | y1 = F.relu(self.fc1(torch.cat((x,a),1)))
38 | y2 = F.relu(self.fc2(y1))
39 | y = self.fc3(y2).view(-1)
40 | return y
41 |
42 | # Actor network
43 | class Pi_FC(torch.nn.Module):
44 | def __init__(self, obs_size, action_size):
45 | super(Pi_FC, self).__init__()
46 | self.fc1 = torch.nn.Linear(obs_size, 256)
47 | self.fc2 = torch.nn.Linear(256, 256)
48 | self.mu = torch.nn.Linear(256, action_size)
49 | self.log_sigma = torch.nn.Linear(256, action_size)
50 |
51 | def forward(self, x, deterministic=False, with_logprob=False):
52 | y1 = F.relu(self.fc1(x))
53 | y2 = F.relu(self.fc2(y1))
54 | mu = self.mu(y2)
55 |
56 | if deterministic:
57 | # used for evaluating policy
58 | action = torch.tanh(mu)
59 | log_prob = None
60 | else:
61 | log_sigma = self.log_sigma(y2)
62 | log_sigma = torch.clamp(log_sigma,min=-20.0,max=2.0)
63 | sigma = torch.exp(log_sigma)
64 | dist = Normal(mu, sigma)
65 | x_t = dist.rsample()
66 | if with_logprob:
67 | log_prob = dist.log_prob(x_t).sum(1)
68 | log_prob -= (2*(np.log(2) - x_t - F.softplus(-2*x_t))).sum(1)
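# Note: this uses the numerically stable identity
# log(1 - tanh(x)^2) = 2*(log(2) - x - softplus(-2x)),
# i.e. the same tanh change-of-variables correction as in models/mbrl.py.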
69 | else:
70 | log_prob = None
71 | action = torch.tanh(x_t)
72 |
73 | return action, log_prob
74 |
--------------------------------------------------------------------------------
/sac.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from copy import deepcopy
4 | import random
5 | import numpy as np
6 | import torch
7 | torch.set_default_dtype(torch.float64)
8 | from torch.utils.tensorboard import SummaryWriter
9 | from env.utils import make_env
10 | from models.sac import ReplayBuffer, Q_FC, Pi_FC
11 |
12 | # Soft Actor-Critic algorithm
13 | class SAC:
14 | def __init__(self, arglist):
15 | self.arglist = arglist
16 |
17 | random.seed(self.arglist.seed)
18 | np.random.seed(self.arglist.seed)
19 | torch.manual_seed(self.arglist.seed)
20 |
21 | self.env = make_env(self.arglist.env)
22 | self.obs_size = self.env.obs_size
23 | self.action_size = self.env.action_size
24 |
25 | self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
26 |
27 | self.actor = Pi_FC(self.obs_size,self.action_size).to(self.device)
28 |
29 | if self.arglist.mode == "train":
30 | self.critic_1 = Q_FC(self.obs_size,self.action_size).to(self.device)
31 | self.critic_target_1 = deepcopy(self.critic_1)
32 | self.critic_loss_fn_1 = torch.nn.MSELoss()
33 |
34 | self.critic_2 = Q_FC(self.obs_size,self.action_size).to(self.device)
35 | self.critic_target_2 = deepcopy(self.critic_2)
36 | self.critic_loss_fn_2 = torch.nn.MSELoss()
37 |
38 | # set target entropy to -|A|
39 | self.target_entropy = - self.action_size
40 |
41 | path = "./log/"+self.env.name+"/sac"
42 | self.exp_dir = os.path.join(path, "seed_"+str(self.arglist.seed))
43 | self.model_dir = os.path.join(self.exp_dir, "models")
44 | self.tensorboard_dir = os.path.join(self.exp_dir, "tensorboard")
45 |
46 | if self.arglist.resume:
47 | checkpoint = torch.load(os.path.join(self.model_dir,"backup.ckpt"))
48 | self.start_episode = checkpoint['episode'] + 1
49 |
50 | self.actor.load_state_dict(checkpoint['actor'])
51 | self.critic_1.load_state_dict(checkpoint['critic_1'])
52 | self.critic_target_1.load_state_dict(checkpoint['critic_target_1'])
53 | self.critic_2.load_state_dict(checkpoint['critic_2'])
54 | self.critic_target_2.load_state_dict(checkpoint['critic_target_2'])
55 | self.log_alpha = torch.tensor(checkpoint['log_alpha'].item(), dtype=torch.float64, device=self.device, requires_grad=True)
56 |
57 | self.replay_buffer = checkpoint['replay_buffer']
58 |
59 | else:
60 | self.start_episode = 0
61 |
62 | self.log_alpha = torch.tensor(np.log(0.2), dtype=torch.float64, device=self.device, requires_grad=True)
63 |
64 | self.replay_buffer = ReplayBuffer(self.arglist.replay_size, self.device)
65 |
66 | if not os.path.exists(path):
67 | os.makedirs(path)
68 | os.mkdir(self.exp_dir)
69 | os.mkdir(self.tensorboard_dir)
70 | os.mkdir(self.model_dir)
71 |
72 | for param in self.critic_target_1.parameters():
73 | param.requires_grad = False
74 |
75 | for param in self.critic_target_2.parameters():
76 | param.requires_grad = False
77 |
78 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.arglist.lr)
79 | self.critic_optimizer_1 = torch.optim.Adam(self.critic_1.parameters(), lr=self.arglist.lr)
80 | self.critic_optimizer_2 = torch.optim.Adam(self.critic_2.parameters(), lr=self.arglist.lr)
81 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],lr=self.arglist.lr)
82 |
83 | if self.arglist.resume:
84 | self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
85 | self.critic_optimizer_1.load_state_dict(checkpoint['critic_optimizer_1'])
86 | self.critic_optimizer_2.load_state_dict(checkpoint['critic_optimizer_2'])
87 | self.log_alpha_optimizer.load_state_dict(checkpoint['log_alpha_optimizer'])
88 |
89 | print("Done loading checkpoint ...")
90 |
91 | self.train()
92 |
93 | elif self.arglist.mode == "eval":
94 | checkpoint = torch.load(self.arglist.checkpoint,map_location=self.device)
95 | self.actor.load_state_dict(checkpoint['actor'])
96 | ep_r_list = self.eval(self.arglist.episodes,self.arglist.render)
97 |
98 | def save_checkpoint(self, name):
99 | checkpoint = {'actor' : self.actor.state_dict()}
100 | torch.save(checkpoint, os.path.join(self.model_dir, name))
101 |
102 | def save_backup(self, episode):
103 | checkpoint = {'episode' : episode,\
104 | 'actor' : self.actor.state_dict(),\
105 | 'actor_optimizer': self.actor_optimizer.state_dict(),\
106 | 'critic_1' : self.critic_1.state_dict(),\
107 | 'critic_optimizer_1': self.critic_optimizer_1.state_dict(),\
108 | 'critic_2' : self.critic_2.state_dict(),\
109 | 'critic_optimizer_2': self.critic_optimizer_2.state_dict(),\
110 | 'critic_target_1' : self.critic_target_1.state_dict(),\
111 | 'critic_target_2' : self.critic_target_2.state_dict(),\
112 | 'log_alpha' : self.log_alpha.detach(),\
113 | 'log_alpha_optimizer': self.log_alpha_optimizer.state_dict(),\
114 | 'replay_buffer' : self.replay_buffer \
115 | }
116 | torch.save(checkpoint, os.path.join(self.model_dir, "backup.ckpt"))
117 |
118 | def soft_update(self, target, source, tau):
119 | with torch.no_grad():
120 | for target_param, param in zip(target.parameters(), source.parameters()):
121 | target_param.data.copy_((1.0 - tau) * target_param.data + tau * param.data)
122 |
123 | def train(self):
124 | writer = SummaryWriter(log_dir=self.tensorboard_dir)
125 | for episode in range(self.start_episode,self.arglist.episodes):
126 | o,_,_ = self.env.reset()
127 | ep_r = 0
128 | while True:
129 | if len(self.replay_buffer) >= self.arglist.start_steps:
130 | with torch.no_grad():
131 | a, _ = self.actor(torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0))
132 | a = a.cpu().numpy()[0]
133 | else:
134 | a = np.random.uniform(-1.0, 1.0, size=self.action_size)
135 |
136 | o_1, r, done = self.env.step(a)
137 |
138 | self.replay_buffer.push(o, a, r, o_1)
139 |
140 | ep_r += r
141 | o = o_1
142 |
143 | if len(self.replay_buffer) >= self.arglist.replay_fill:
144 | O, A, R, O_1 = self.replay_buffer.sample(self.arglist.batch_size)
145 |
146 | q_value_1 = self.critic_1(O, A)
147 | q_value_2 = self.critic_2(O, A)
148 |
149 | with torch.no_grad():
150 | # Target actions come from *current* policy
151 | A_1, logp_A_1 = self.actor(O_1, False, True)
152 |
153 | next_q_value_1 = self.critic_target_1(O_1, A_1)
154 | next_q_value_2 = self.critic_target_2(O_1, A_1)
155 | next_q_value = torch.min(next_q_value_1, next_q_value_2)
156 | expected_q_value = R + self.arglist.gamma * (next_q_value - torch.exp(self.log_alpha) * logp_A_1)
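# Note: this is the clipped-double-Q soft Bellman target
# y = r + gamma*(min(Q1', Q2')(o', a') - alpha*log pi(a'|o')), a' ~ pi(.|o'),
# with alpha = exp(log_alpha); it is computed under no_grad so only the online
# critics receive gradients.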
157 |
158 | critic_loss_1 = self.critic_loss_fn_1(q_value_1, expected_q_value)
159 | self.critic_optimizer_1.zero_grad()
160 | critic_loss_1.backward()
161 | self.critic_optimizer_1.step()
162 |
163 | critic_loss_2 = self.critic_loss_fn_2(q_value_2, expected_q_value)
164 | self.critic_optimizer_2.zero_grad()
165 | critic_loss_2.backward()
166 | self.critic_optimizer_2.step()
167 |
168 | for param_1, param_2 in zip(self.critic_1.parameters(), self.critic_2.parameters()):
169 | param_1.requires_grad = False
170 | param_2.requires_grad = False
171 |
172 | A_pi, logp_A_pi = self.actor(O, False, True)
173 | q_value_pi_1 = self.critic_1(O, A_pi)
174 | q_value_pi_2 = self.critic_2(O, A_pi)
175 | q_value_pi = torch.min(q_value_pi_1, q_value_pi_2)
176 |
177 | actor_loss = - torch.mean(q_value_pi - torch.exp(self.log_alpha).detach() * logp_A_pi)
178 | self.actor_optimizer.zero_grad()
179 | actor_loss.backward()
180 | self.actor_optimizer.step()
181 |
182 | self.log_alpha_optimizer.zero_grad()
183 | alpha_loss = (torch.exp(self.log_alpha) * (-logp_A_pi - self.target_entropy).detach()).mean()
184 | alpha_loss.backward()
185 | self.log_alpha_optimizer.step()
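# Note: the temperature objective is
# J(alpha) = E[ -alpha*(log pi(a|o) + target_entropy) ] with target_entropy = -|A|;
# gradient steps on log_alpha increase alpha when the policy entropy falls below
# the target and decrease it otherwise.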
186 |
187 | for param_1, param_2 in zip(self.critic_1.parameters(), self.critic_2.parameters()):
188 | param_1.requires_grad = True
189 | param_2.requires_grad = True
190 |
191 | self.soft_update(self.critic_target_1, self.critic_1, self.arglist.tau)
192 | self.soft_update(self.critic_target_2, self.critic_2, self.arglist.tau)
193 |
194 | if done:
195 | writer.add_scalar('ep_r', ep_r, episode)
196 | with torch.no_grad():
197 | writer.add_scalar('alpha',torch.exp(self.log_alpha).item(),episode)
198 | if episode % self.arglist.eval_every == 0 or episode == self.arglist.episodes-1:
199 | # Evaluate agent performance
200 | eval_ep_r_list = self.eval(self.arglist.eval_over)
201 | writer.add_scalar('eval_ep_r', np.mean(eval_ep_r_list), episode)
202 | self.save_checkpoint(str(episode)+".ckpt")
203 | if (episode % 250 == 0 or episode == self.arglist.episodes-1) and episode > self.start_episode:
204 | self.save_backup(episode)
205 | break
206 |
207 | def eval(self, episodes, render=False):
208 | # Evaluate agent performance over several episodes
209 | ep_r_list = []
210 | for episode in range(episodes):
211 | o,_,_ = self.env.reset()
212 | ep_r = 0
213 | while True:
214 | with torch.no_grad():
215 | a, _ = self.actor(torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0),True)
216 | a = a.cpu().numpy()[0]
217 | o_1,r,done = self.env.step(a)
218 | if render:
219 | self.env.render()
220 | ep_r += r
221 | o = o_1
222 | if done:
223 | ep_r_list.append(ep_r)
224 | if render:
225 | print("Episode finished with total reward ",ep_r)
226 | break
227 |
228 | if self.arglist.mode == "eval":
229 | print("Average return :",np.mean(ep_r_list))
230 |
231 | return ep_r_list
232 |
233 | def parse_args():
234 | parser = argparse.ArgumentParser("SAC")
235 | # Common settings
236 | parser.add_argument("--env", type=str, default="cart3pole", help="pendulum / reacher / cartpole / acrobot / cart2pole / acro3bot / cart3pole")
237 | parser.add_argument("--mode", type=str, default="train", help="train or eval")
238 | parser.add_argument("--episodes", type=int, default=10000, help="number of episodes")
239 | parser.add_argument("--seed", type=int, default=0, help="seed")
240 | # Core training parameters
241 | parser.add_argument("--resume", action="store_true", default=False, help="resume training")
242 | parser.add_argument("--lr", type=float, default=3e-4, help="actor, critic learning rate")
243 | parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
244 | parser.add_argument("--batch-size", type=int, default=256, help="batch size")
245 | parser.add_argument("--tau", type=float, default=0.005, help="soft target update parameter")
246 | parser.add_argument("--start-steps", type=int, default=int(1e4), help="start steps")
247 | parser.add_argument("--replay-size", type=int, default=int(1e6), help="replay buffer size")
248 | parser.add_argument("--replay-fill", type=int, default=int(1e4), help="elements in replay buffer before training starts")
249 | parser.add_argument("--eval-every", type=int, default=50, help="eval every _ episodes during training")
250 | parser.add_argument("--eval-over", type=int, default=50, help="each time eval over _ episodes")
251 | # Eval settings
252 | parser.add_argument("--checkpoint", type=str, default="", help="path to checkpoint")
253 | parser.add_argument("--render", action="store_true", default=False, help="render")
254 |
255 | return parser.parse_args()
256 |
257 | if __name__ == '__main__':
258 | arglist = parse_args()
259 | sac = SAC(arglist)
260 |
--------------------------------------------------------------------------------