├── README.md ├── env ├── __init__.py ├── acro3bot │ ├── acro3bot.py │ ├── derive.py │ └── robot.p ├── acrobot │ ├── acrobot.py │ ├── derive.py │ └── robot.p ├── base.py ├── cart2pole │ ├── cart2pole.py │ ├── derive.py │ └── robot.p ├── cart3pole │ ├── cart3pole.py │ ├── derive.py │ └── robot.p ├── cartpole │ ├── cartpole.py │ ├── derive.py │ └── robot.p ├── pendulum │ ├── derive.py │ ├── pendulum.py │ └── robot.p ├── reacher │ ├── derive.py │ ├── reacher.py │ └── robot.p ├── rewards.py └── utils.py ├── mbrl.py ├── models ├── __init__.py ├── mbrl.py └── sac.py └── sac.py /README.md: -------------------------------------------------------------------------------- 1 | # Physics-Informed Model-Based RL 2 | 3 | Published at Learning for Dynamics & Control Conference (L4DC), 2023. 4 | 5 |
6 | 7 |
8 | 9 | ## Abstract 10 | We apply reinforcement learning (RL) to robotics tasks. One of the drawbacks of traditional RL algorithms has been their poor sample efficiency. One approach to improve sample efficiency is model-based RL. In our model-based RL algorithm, we learn a model of the environment, essentially its transition dynamics and reward function, use it to generate imaginary trajectories, and backpropagate through them to update the policy, exploiting the differentiability of the model. 11 | 12 | Intuitively, learning more accurate models should lead to better model-based RL performance. Recently, there has been growing interest in developing better deep neural network-based dynamics models for physical systems by utilizing the structure of the underlying physics. We focus on robotic systems undergoing rigid body motion without contacts. We compare two versions of our model-based RL algorithm: one uses a standard deep neural network-based dynamics model, and the other uses a much more accurate, physics-informed neural network-based dynamics model. 13 | 14 | We show that, in model-based RL, model accuracy mainly matters in environments that are sensitive to initial conditions, where numerical errors accumulate fast. In these environments, the physics-informed version of our algorithm achieves significantly better average return and sample efficiency. In environments that are not sensitive to initial conditions, both versions of our algorithm achieve similar average return, while the physics-informed version achieves better sample efficiency. 15 | 16 | We also show that, in challenging environments, physics-informed model-based RL achieves better average return than state-of-the-art model-free RL algorithms such as Soft Actor-Critic, as it computes the policy gradient analytically. 17 | 18 | For more information, check out: 19 | - [Project Webpage](https://adi3e08.github.io/research/pimbrl) 20 | - [Paper](https://arxiv.org/abs/2212.02179) 21 | 22 | ## Requirements 23 | - Python 24 | - NumPy 25 | - PyTorch 26 | - TensorBoard 27 | - Pygame 28 | 29 | ## Usage 30 | To train MBRL LNN on the Acrobot task, run: 31 | 32 | python mbrl.py --env acrobot --mode train --episodes 500 --seed 0 33 | 34 | The data from this experiment will be stored in the folder "./log/acrobot/mbrl_lnn/seed_0". This folder will contain two subfolders: (i) models, where model checkpoints are stored, and (ii) tensorboard, where TensorBoard logs are stored. 35 | 36 | To evaluate MBRL LNN on the Acrobot task, run: 37 | 38 | python mbrl.py --env acrobot --mode eval --episodes 3 --seed 100 --checkpoint ./log/acrobot/mbrl_lnn/seed_0/models/499.ckpt --render 39 | 40 | ## Citation 41 | If you find this work helpful, please consider starring this repo and citing our paper using the following BibTeX.
42 | ```bibtex 43 | @inproceedings{ramesh2023physics, 44 | title={Physics-Informed Model-Based Reinforcement Learning}, 45 | author={Ramesh, Adithya and Ravindran, Balaraman}, 46 | booktitle={Learning for Dynamics and Control Conference}, 47 | pages={26--37}, 48 | year={2023}, 49 | organization={PMLR} 50 | } 51 | 52 | 53 | -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /env/acro3bot/acro3bot.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class acro3bot(BaseEnv): 10 | def __init__(self): 11 | m1 = 0.1 12 | l1 = 1 13 | r1 = l1/2 14 | I1 = m1 * l1**2 / 12 15 | 16 | m2 = 0.1 17 | l2 = 1 18 | r2 = l2/2 19 | I2 = m2 * l2**2 / 12 20 | 21 | m3 = 0.1 22 | l3 = 1 23 | r3 = l3/2 24 | I3 = m3 * l3**2 / 12 25 | 26 | g = 9.8 27 | 28 | m = [m1,m2,m3] 29 | l = [l1,l2,l3] 30 | r = [r1,r2,r3] 31 | I = [I1,I2,I3] 32 | super(acro3bot, self).__init__(name = "acro3bot", 33 | n = 3, 34 | obs_size = 9, 35 | action_size = 2, 36 | inertials = m+l+r+I+[g], 37 | a_scale = np.array([0.5,2.0])) 38 | 39 | def wrap_state(self): 40 | self.state[:3] = wrap(self.state[:3]) 41 | 42 | def reset_state(self): 43 | self.state = np.array([np.pi + 0.01*np.random.randn(), 44 | 0.01*np.random.randn(), 45 | 0.01*np.random.randn(), 46 | 0,0,0]) 47 | 48 | def get_A(self, a): 49 | a_1, a_3 = np.clip(a, -1.0, 1.0)*self.a_scale 50 | a_2 = 0.0 51 | return np.array([a_1,a_2,a_3]) 52 | 53 | def get_obs(self): 54 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]), 55 | np.cos(self.state[1]),np.sin(self.state[1]), 56 | np.cos(self.state[2]),np.sin(self.state[2]), 57 | self.state[3], 58 | self.state[4], 59 | self.state[5] 60 | ]) 61 | 62 | def get_reward(self): 63 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1]), np.cos(self.state[0]+self.state[1]+self.state[2])])+1)/2 64 | 65 | qdot = self.state[self.n:] 66 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1],qdot[0]+qdot[1]+qdot[2]]) 67 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min() 68 | small_velocity = (1 + small_velocity) / 2 69 | 70 | reward = upright.mean() * small_velocity 71 | 72 | return reward 73 | 74 | def draw(self): 75 | centers, joints, angles = self.geo 76 | 77 | for j in range(self.n): 78 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 79 | pygame.draw.polygon(self.screen, self.link_color, link_points) 80 | 81 | joint_point = pygame_transform(joints[j],self.scaling,self.offset) 82 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 83 | 84 | if __name__ == '__main__': 85 | basic_check("acro3bot",0) 86 | -------------------------------------------------------------------------------- /env/acro3bot/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 
| import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 3 21 | m1,m2,m3 = symbols('m1,m2,m3') 22 | l1,l2,l3 = symbols('l1,l2,l3') 23 | r1,r2,r3 = symbols('r1,r2,r3') 24 | I1,I2,I3 = symbols('I1,I2,I3') 25 | g = symbols('g') 26 | q1,q2,q3 = symbols('q1,q2,q3') 27 | q1dot,q2dot,q3dot = symbols('q1dot,q2dot,q3dot') 28 | 29 | m = [m1,m2,m3] 30 | l = [l1,l2,l3] 31 | r = [r1,r2,r3] 32 | I = [I1,I2,I3] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2,q3]) 36 | qdot = Matrix([q1dot,q2dot,q3dot]) 37 | state = [q1,q2,q3,q1dot,q2dot,q3dot] 38 | 39 | J_w = Matrix([[1,0,0], 40 | [1,1,0], 41 | [1,1,1] 42 | ]) 43 | 44 | angles = J_w * q 45 | 46 | V = 0 47 | M = zeros(n) 48 | J = [] 49 | for i in range(n): 50 | if i == 0: 51 | joint = Matrix([[0, 0]]) 52 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 53 | joints = joint 54 | centers = center 55 | else: 56 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]]) 57 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 58 | joints = Matrix.vstack(joints, joint) 59 | centers = Matrix.vstack(centers, center) 60 | 61 | J_v = center.jacobian(q) 62 | # J.append(J_v) 63 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 64 | M += M_i 65 | 66 | V += m[i]*g*center[0,1] 67 | 68 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 69 | 70 | C,G = get_C_G(n,M,V,q,qdot) 71 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 72 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 73 | 74 | with open("./env/acro3bot/robot.p", "wb") as outf: 75 | pickle.dump(lambda_dict, outf) 76 | 77 | print("Done") 78 | 79 | if __name__ == '__main__': 80 | derive() 81 | -------------------------------------------------------------------------------- /env/acro3bot/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/acro3bot/robot.p -------------------------------------------------------------------------------- /env/acrobot/acrobot.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class acrobot(BaseEnv): 10 | def __init__(self): 11 | m1 = 0.1 12 | l1 = 1 13 | r1 = l1/2 14 | I1 = m1 * l1**2 / 12 15 | 16 | m2 = 0.1 17 | l2 = 1 18 | r2 = l2/2 19 | I2 = m2 * l2**2 / 12 20 | 21 | g = 9.8 22 | 23 | m = [m1,m2] 24 | l = [l1,l2] 25 | r = [r1,r2] 26 | I = [I1,I2] 27 | super(acrobot, self).__init__(name = "acrobot", 28 | n = 2, 29 | obs_size = 6, 30 | action_size = 1, 31 | inertials = m+l+r+I+[g], 32 | a_scale = np.array([1.75])) 33 | 34 | def wrap_state(self): 35 | self.state[:2] = wrap(self.state[:2]) 36 | 37 | def reset_state(self): 38 | self.state = np.array([np.pi + 0.01*np.random.randn(), 39 | 0.01*np.random.randn(), 40 | 0, 41 | 0]) 42 | 43 | def get_A(self, 
a): 44 | a_2, = np.clip(a, -1.0, 1.0)*self.a_scale 45 | a_1 = 0.0 46 | return np.array([a_1,a_2]) 47 | 48 | def get_obs(self): 49 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]), 50 | np.cos(self.state[1]),np.sin(self.state[1]), 51 | self.state[2], 52 | self.state[3] 53 | ]) 54 | 55 | def get_reward(self): 56 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1])])+1)/2 57 | 58 | qdot = self.state[self.n:] 59 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1]]) 60 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min() 61 | small_velocity = (1 + small_velocity) / 2 62 | 63 | reward = upright.mean() * small_velocity 64 | 65 | return reward 66 | 67 | def draw(self): 68 | centers, joints, angles = self.geo 69 | 70 | for j in range(self.n): 71 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 72 | pygame.draw.polygon(self.screen, self.link_color, link_points) 73 | 74 | joint_point = pygame_transform(joints[j],self.scaling,self.offset) 75 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 76 | 77 | if __name__ == '__main__': 78 | basic_check("acrobot",0) 79 | -------------------------------------------------------------------------------- /env/acrobot/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 2 21 | m1,m2 = symbols('m1,m2') 22 | l1,l2 = symbols('l1,l2') 23 | r1,r2 = symbols('r1,r2') 24 | I1,I2 = symbols('I1,I2') 25 | g = symbols('g') 26 | q1,q2 = symbols('q1,q2') 27 | q1dot,q2dot = symbols('q1dot,q2dot') 28 | 29 | m = [m1,m2] 30 | l = [l1,l2] 31 | r = [r1,r2] 32 | I = [I1,I2] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2]) 36 | qdot = Matrix([q1dot,q2dot]) 37 | state = [q1,q2,q1dot,q2dot] 38 | 39 | J_w = Matrix([[1,0], 40 | [1,1] 41 | ]) 42 | 43 | angles = J_w * q 44 | 45 | V = 0 46 | M = zeros(n) 47 | J = [] 48 | for i in range(n): 49 | if i == 0: 50 | joint = Matrix([[0, 0]]) 51 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 52 | joints = joint 53 | centers = center 54 | else: 55 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]]) 56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 57 | joints = Matrix.vstack(joints, joint) 58 | centers = Matrix.vstack(centers, center) 59 | 60 | J_v = center.jacobian(q) 61 | # J.append(J_v) 62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 63 | M += M_i 64 | 65 | V += m[i]*g*center[0,1] 66 | 67 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 68 | 69 | C,G = get_C_G(n,M,V,q,qdot) 70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 72 | 73 | with open("./env/acrobot/robot.p", "wb") as outf: 74 | 
pickle.dump(lambda_dict, outf) 75 | 76 | print("Done") 77 | 78 | if __name__ == '__main__': 79 | derive() 80 | -------------------------------------------------------------------------------- /env/acrobot/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/acrobot/robot.p -------------------------------------------------------------------------------- /env/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import time 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | import pygame 7 | import numpy as np 8 | from env.utils import create_background 9 | from copy import deepcopy 10 | 11 | class BaseEnv: 12 | def __init__(self,name,n,obs_size,action_size,inertials,a_scale): 13 | self.name = name 14 | self.n = n 15 | self.obs_size = obs_size 16 | self.action_size = action_size 17 | self.inertials = inertials 18 | self.a_scale = a_scale 19 | 20 | self.dt = 0.01 # time per simulation step (in seconds) 21 | self.t = 0 # elapsed simulation steps 22 | self.t_max = 1000 # max simulation steps 23 | self.state = np.zeros(self.n) 24 | self.ang_vel_limit = 20.0 25 | 26 | with open("./env/"+self.name+"/robot.p", "rb") as inf: 27 | funcs = pickle.load(inf) 28 | self.kinematics = funcs['kinematics'] 29 | self.dynamics = funcs['dynamics'] 30 | 31 | # For rendering 32 | self.display = False 33 | self.screen_width = 500 34 | self.screen_height = 500 35 | self.offset = [250, 250] 36 | self.scaling = 75 37 | self.x_limit = 2.0 38 | 39 | self.link_length = 1.0 40 | self.link_width = 0.2 41 | self.link_color = (72,209,204) # medium turquoise 42 | 43 | self.joint_radius = self.link_width/1.8 44 | self.joint_color = (205,55,0) # orange red 45 | 46 | self.cart_length = 5*self.link_width 47 | self.cart_width = 2*self.link_width 48 | self.cart_color = (200,255,0) # yellow 49 | 50 | self.rail_length = 2*self.x_limit 51 | self.rail_width = self.link_width/2.5 52 | self.rail_color = (150,150,150) # gray 53 | 54 | def wrap_state(self): 55 | pass 56 | 57 | def reset_state(self): 58 | pass 59 | 60 | def get_A(self, a): 61 | pass 62 | 63 | def get_obs(self): 64 | pass 65 | 66 | def get_reward(self): 67 | pass 68 | 69 | def draw(self): 70 | pass 71 | 72 | def set_state(self,s): 73 | self.state = s 74 | 75 | def reset(self): 76 | self.reset_state() 77 | self.wrap_state() 78 | self.geo = self.kinematics(self.inertials+self.state.tolist()) 79 | self.t = 0 80 | 81 | return self.get_obs(), 0.0, False 82 | 83 | def step(self, a): 84 | self.state = self.rk4(self.state, self.get_A(a)) 85 | self.wrap_state() 86 | self.geo = self.kinematics(self.inertials+self.state.tolist()) 87 | 88 | self.t += 1 89 | if self.t >= self.t_max: # infinite horizon formulation, no terminal state, similar to dm_control 90 | done = True 91 | else: 92 | done = False 93 | 94 | return self.get_obs(), self.get_reward(), done 95 | 96 | def F(self, s, a): 97 | M, C, G = self.dynamics(self.inertials+s.tolist()) 98 | qdot = s[self.n:] 99 | qddot = np.linalg.inv(M+1e-6*np.eye(self.n)) @ (a - C @ qdot - G.flatten()) 100 | 101 | return np.concatenate((qdot,qddot)) 102 | 103 | def rk4(self, s, a): 104 | s0 = deepcopy(s) 105 | k = [] 106 | for l in range(4): 107 | if l > 0: 108 | if l == 1 or l == 2: 109 | dt = self.dt/2 110 | elif l == 3: 111 | dt = self.dt 112 | s = s0 + dt * k[l-1] 113 | 
k.append(self.F(s, a)) 114 | s = s0 + (self.dt/6.0) * (k[0] + 2 * k[1] + 2 * k[2] + k[3]) 115 | 116 | return s 117 | 118 | def render(self): 119 | if self.display: 120 | self.screen.blit(self.background, (0, 0)) 121 | self.draw() 122 | time.sleep(0.006) 123 | pygame.display.flip() 124 | else: 125 | self.display = True 126 | pygame.init() 127 | self.screen = pygame.display.set_mode((self.screen_width, self.screen_height)) 128 | pygame.display.set_caption(self.name) 129 | self.background = create_background(self.screen_width, self.screen_height) 130 | -------------------------------------------------------------------------------- /env/cart2pole/cart2pole.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class cart2pole(BaseEnv): 10 | def __init__(self): 11 | m1 = 1 12 | l1, r1, I1 = 0, 0, 0 # dummy 13 | 14 | m2 = 0.1 15 | l2 = 1 16 | r2 = l2/2 17 | I2 = m2 * l2**2 / 12 18 | 19 | m3 = 0.1 20 | l3 = 1 21 | r3 = l3/2 22 | I3 = m3 * l3**2 / 12 23 | 24 | g = 9.8 25 | 26 | m = [m1,m2,m3] 27 | l = [l1,l2,l3] 28 | r = [r1,r2,r3] 29 | I = [I1,I2,I3] 30 | super(cart2pole, self).__init__(name = "cart2pole", 31 | n = 3, 32 | obs_size = 8, 33 | action_size = 1, 34 | inertials = m+l+r+I+[g], 35 | a_scale = np.array([10.0])) 36 | def wrap_state(self): 37 | self.state[1:3] = wrap(self.state[1:3]) 38 | 39 | def reset_state(self): 40 | self.state = np.array([0.01*np.random.randn(), 41 | np.pi + 0.01*np.random.randn(), 42 | 0.01*np.random.randn(), 43 | 0,0,0]) 44 | 45 | def get_A(self, a): 46 | a_1, = np.clip(a, -1.0, 1.0)*self.a_scale 47 | a_2, a_3 = 0.0, 0.0 48 | return np.array([a_1,a_2,a_3]) 49 | 50 | def get_obs(self): 51 | return np.array([self.state[0], 52 | np.cos(self.state[1]),np.sin(self.state[1]), 53 | np.cos(self.state[2]),np.sin(self.state[2]), 54 | self.state[3], 55 | self.state[4], 56 | self.state[5] 57 | ]) 58 | 59 | def get_reward(self): 60 | upright = (np.array([np.cos(self.state[1]), np.cos(self.state[1]+self.state[2])])+1)/2 61 | 62 | if np.abs(self.state[0]) <= self.x_limit: 63 | centered = rewards.tolerance(self.state[0], margin=self.x_limit) 64 | centered = (1 + centered) / 2 65 | else: 66 | centered = 0.1 67 | 68 | qdot = self.state[self.n:] 69 | ang_vel = np.array([qdot[0],qdot[1],qdot[1]+qdot[2]]) 70 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min() 71 | small_velocity = (1 + small_velocity) / 2 72 | 73 | reward = upright.mean() * small_velocity * centered 74 | 75 | return reward 76 | 77 | def draw(self): 78 | centers, joints, angles = self.geo 79 | 80 | #horizontal rail 81 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset)) 82 | 83 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit 84 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset) 85 | pygame.draw.polygon(self.screen, self.cart_color, link1_points) 86 | 87 | offset = np.array([plot_x-centers[0,0],0]) 88 | for j in range(1,self.n): 89 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 90 | pygame.draw.polygon(self.screen, self.link_color, link_points) 91 | 92 | 
joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset) 93 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 94 | 95 | 96 | if __name__ == '__main__': 97 | basic_check("cart2pole",0) 98 | -------------------------------------------------------------------------------- /env/cart2pole/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 3 21 | m1,m2,m3 = symbols('m1,m2,m3') 22 | l1,l2,l3 = symbols('l1,l2,l3') 23 | r1,r2,r3 = symbols('r1,r2,r3') 24 | I1,I2,I3 = symbols('I1,I2,I3') 25 | g = symbols('g') 26 | q1,q2,q3 = symbols('q1,q2,q3') 27 | q1dot,q2dot,q3dot = symbols('q1dot,q2dot,q3dot') 28 | 29 | m = [m1,m2,m3] 30 | l = [l1,l2,l3] 31 | r = [r1,r2,r3] 32 | I = [I1,I2,I3] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2,q3]) 36 | qdot = Matrix([q1dot,q2dot,q3dot]) 37 | state = [q1,q2,q3,q1dot,q2dot,q3dot] 38 | 39 | J_w = Matrix([[0,0,0], 40 | [0,1,0], 41 | [0,1,1] 42 | ]) 43 | 44 | angles = J_w * q 45 | 46 | V = 0 47 | M = zeros(n) 48 | J = [] 49 | for i in range(n): 50 | if i == 0: 51 | joint = Matrix([[q1, 0]]) 52 | center = joint 53 | joints = joint 54 | centers = center 55 | elif i == 1: 56 | joint = joint 57 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 58 | joints = Matrix.vstack(joints, joint) 59 | centers = Matrix.vstack(centers, center) 60 | else: 61 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]]) 62 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 63 | joints = Matrix.vstack(joints, joint) 64 | centers = Matrix.vstack(centers, center) 65 | 66 | J_v = center.jacobian(q) 67 | # J.append(J_v) 68 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 69 | M += M_i 70 | 71 | V += m[i]*g*center[0,1] 72 | 73 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 74 | 75 | C,G = get_C_G(n,M,V,q,qdot) 76 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 77 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 78 | 79 | with open("./env/cart2pole/robot.p", "wb") as outf: 80 | pickle.dump(lambda_dict, outf) 81 | 82 | print("Done") 83 | 84 | if __name__ == '__main__': 85 | derive() 86 | -------------------------------------------------------------------------------- /env/cart2pole/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cart2pole/robot.p -------------------------------------------------------------------------------- /env/cart3pole/cart3pole.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import 
rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class cart3pole(BaseEnv): 10 | def __init__(self): 11 | m1 = 1 12 | l1, r1, I1 = 0, 0, 0 # dummy 13 | 14 | m2 = 0.1 15 | l2 = 1 16 | r2 = l2/2 17 | I2 = m2 * l2**2 / 12 18 | 19 | m3 = 0.1 20 | l3 = 1 21 | r3 = l3/2 22 | I3 = m3 * l3**2 / 12 23 | 24 | m4 = 0.1 25 | l4 = 1 26 | r4 = l4/2 27 | I4 = m4 * l4**2 / 12 28 | 29 | g = 9.8 30 | 31 | m = [m1,m2,m3,m4] 32 | l = [l1,l2,l3,l4] 33 | r = [r1,r2,r3,r4] 34 | I = [I1,I2,I3,I4] 35 | super(cart3pole, self).__init__(name = "cart3pole", 36 | n = 4, 37 | obs_size = 11, 38 | action_size = 2, 39 | inertials = m+l+r+I+[g], 40 | a_scale = np.array([10.0,1.0])) 41 | 42 | def wrap_state(self): 43 | self.state[1:4] = wrap(self.state[1:4]) 44 | 45 | def reset_state(self): 46 | self.state = np.array([0.01*np.random.randn(), 47 | np.pi + 0.01*np.random.randn(), 48 | 0.01*np.random.randn(), 49 | 0.01*np.random.randn(), 50 | 0,0,0,0]) 51 | 52 | def get_A(self, a): 53 | a_1, a_4 = np.clip(a, -1.0, 1.0)*self.a_scale 54 | a_2, a_3 = 0.0, 0.0 55 | return np.array([a_1,a_2,a_3,a_4]) 56 | 57 | def get_obs(self): 58 | return np.array([self.state[0], 59 | np.cos(self.state[1]),np.sin(self.state[1]), 60 | np.cos(self.state[2]),np.sin(self.state[2]), 61 | np.cos(self.state[3]),np.sin(self.state[3]), 62 | self.state[4], 63 | self.state[5], 64 | self.state[6], 65 | self.state[7] 66 | ]) 67 | 68 | def get_reward(self): 69 | upright = (np.array([np.cos(self.state[1]), np.cos(self.state[1]+self.state[2]), np.cos(self.state[1]+self.state[2]+self.state[3])])+1)/2 70 | 71 | if np.abs(self.state[0]) <= self.x_limit: 72 | centered = rewards.tolerance(self.state[0], margin=self.x_limit) 73 | centered = (1 + centered) / 2 74 | else: 75 | centered = 0.1 76 | 77 | qdot = self.state[self.n:] 78 | ang_vel = np.array([qdot[0],qdot[1],qdot[1]+qdot[2],qdot[1]+qdot[2]+qdot[3]]) 79 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min() 80 | small_velocity = (1 + small_velocity) / 2 81 | 82 | reward = upright.mean() * small_velocity * centered 83 | 84 | return reward 85 | 86 | def draw(self): 87 | centers, joints, angles = self.geo 88 | 89 | #horizontal rail 90 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset)) 91 | 92 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit 93 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset) 94 | pygame.draw.polygon(self.screen, self.cart_color, link1_points) 95 | 96 | offset = np.array([plot_x-centers[0,0],0]) 97 | for j in range(1,self.n): 98 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 99 | pygame.draw.polygon(self.screen, self.link_color, link_points) 100 | 101 | joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset) 102 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 103 | 104 | if __name__ == '__main__': 105 | basic_check("cart3pole",0) 106 | -------------------------------------------------------------------------------- /env/cart3pole/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from 
sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 4 21 | m1,m2,m3,m4 = symbols('m1,m2,m3,m4') 22 | l1,l2,l3,l4 = symbols('l1,l2,l3,l4') 23 | r1,r2,r3,r4 = symbols('r1,r2,r3,r4') 24 | I1,I2,I3,I4 = symbols('I1,I2,I3,I4') 25 | g = symbols('g') 26 | q1,q2,q3,q4 = symbols('q1,q2,q3,q4') 27 | q1dot,q2dot,q3dot,q4dot = symbols('q1dot,q2dot,q3dot,q4dot') 28 | 29 | m = [m1,m2,m3,m4] 30 | l = [l1,l2,l3,l4] 31 | r = [r1,r2,r3,r4] 32 | I = [I1,I2,I3,I4] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2,q3,q4]) 36 | qdot = Matrix([q1dot,q2dot,q3dot,q4dot]) 37 | state = [q1,q2,q3,q4,q1dot,q2dot,q3dot,q4dot] 38 | 39 | J_w = Matrix([[0,0,0,0], 40 | [0,1,0,0], 41 | [0,1,1,0], 42 | [0,1,1,1] 43 | ]) 44 | 45 | angles = J_w * q 46 | 47 | V = 0 48 | M = zeros(n) 49 | J = [] 50 | for i in range(n): 51 | if i == 0: 52 | joint = Matrix([[q1, 0]]) 53 | center = joint 54 | joints = joint 55 | centers = center 56 | elif i == 1: 57 | joint = joint 58 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 59 | joints = Matrix.vstack(joints, joint) 60 | centers = Matrix.vstack(centers, center) 61 | else: 62 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]]) 63 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 64 | joints = Matrix.vstack(joints, joint) 65 | centers = Matrix.vstack(centers, center) 66 | 67 | J_v = center.jacobian(q) 68 | # J.append(J_v) 69 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 70 | M += M_i 71 | 72 | V += m[i]*g*center[0,1] 73 | 74 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 75 | 76 | C,G = get_C_G(n,M,V,q,qdot) 77 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 78 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 79 | 80 | with open("./env/cart3pole/robot.p", "wb") as outf: 81 | pickle.dump(lambda_dict, outf) 82 | 83 | print("Done") 84 | 85 | if __name__ == '__main__': 86 | derive() 87 | -------------------------------------------------------------------------------- /env/cart3pole/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cart3pole/robot.p -------------------------------------------------------------------------------- /env/cartpole/cartpole.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class cartpole(BaseEnv): 10 | def __init__(self): 11 | m1 = 1 12 | l1, r1, I1 = 0, 0, 0 # dummy 13 | 14 | m2 = 0.1 15 | l2 = 1 16 | r2 = l2/2 17 | I2 = m2 * l2**2 / 12 18 | 19 | g = 9.8 20 | 21 | m = [m1,m2] 22 | l = [l1,l2] 23 | r = [r1,r2] 24 | I = [I1,I2] 25 | super(cartpole, self).__init__(name = "cartpole", 26 | n = 2, 27 | obs_size = 5, 28 | action_size = 1, 29 | inertials = 
m+l+r+I+[g], 30 | a_scale = np.array([10.0])) 31 | 32 | def wrap_state(self): 33 | self.state[1] = wrap(self.state[1]) 34 | 35 | def reset_state(self): 36 | self.state = np.array([0.01*np.random.randn(), 37 | np.pi + 0.01*np.random.randn(), 38 | 0,0]) 39 | 40 | def get_A(self, a): 41 | a_1, = np.clip(a, -1.0, 1.0)*self.a_scale 42 | a_2 = 0.0 43 | return np.array([a_1,a_2]) 44 | 45 | def get_obs(self): 46 | return np.array([self.state[0], 47 | np.cos(self.state[1]),np.sin(self.state[1]), 48 | self.state[2], 49 | self.state[3] 50 | ]) 51 | 52 | def get_reward(self): 53 | upright = (np.array([np.cos(self.state[1])])+1)/2 54 | 55 | if np.abs(self.state[0]) <= self.x_limit: 56 | centered = rewards.tolerance(self.state[0], margin=self.x_limit) 57 | centered = (1 + centered) / 2 58 | else: 59 | centered = 0.1 60 | 61 | qdot = self.state[self.n:] 62 | ang_vel = qdot 63 | small_velocity = rewards.tolerance(ang_vel[1:], margin=self.ang_vel_limit).min() 64 | small_velocity = (1 + small_velocity) / 2 65 | 66 | reward = upright.mean() * small_velocity * centered 67 | 68 | return reward 69 | 70 | def draw(self): 71 | centers, joints, angles = self.geo 72 | 73 | #horizontal rail 74 | pygame.draw.polygon(self.screen, self.rail_color, rect_points([0,0], self.rail_length, self.rail_width, np.pi/2, self.scaling, self.offset)) 75 | 76 | plot_x = ((centers[0,0] + self.x_limit) % (2 * self.x_limit)) - self.x_limit 77 | link1_points = rect_points([plot_x,0], self.cart_length, self.cart_width, np.pi/2, self.scaling, self.offset) 78 | pygame.draw.polygon(self.screen, self.cart_color, link1_points) 79 | 80 | offset = np.array([plot_x-centers[0,0],0]) 81 | for j in range(1,self.n): 82 | link_points = rect_points(centers[j]+offset, self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 83 | pygame.draw.polygon(self.screen, self.link_color, link_points) 84 | 85 | joint_point = pygame_transform(joints[j]+offset,self.scaling,self.offset) 86 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 87 | 88 | if __name__ == '__main__': 89 | basic_check("cartpole",0) 90 | -------------------------------------------------------------------------------- /env/cartpole/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 2 21 | m1,m2 = symbols('m1,m2') 22 | l1,l2 = symbols('l1,l2') 23 | r1,r2 = symbols('r1,r2') 24 | I1,I2 = symbols('I1,I2') 25 | g = symbols('g') 26 | q1,q2 = symbols('q1,q2') 27 | q1dot,q2dot = symbols('q1dot,q2dot') 28 | 29 | m = [m1,m2] 30 | l = [l1,l2] 31 | r = [r1,r2] 32 | I = [I1,I2] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2]) 36 | qdot = Matrix([q1dot,q2dot]) 37 | state = [q1,q2,q1dot,q2dot] 38 | 39 | J_w = Matrix([[0,0], 40 | [0,1] 41 | ]) 42 | 43 | angles = J_w * q 44 | 45 | V = 0 46 | M = zeros(n) 47 | J = [] 48 | for i in range(n): 49 | if i == 0: 50 | joint = Matrix([[q1, 0]]) 51 | center = 
joint 52 | joints = joint 53 | centers = center 54 | elif i == 1: 55 | joint = joint 56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 57 | joints = Matrix.vstack(joints, joint) 58 | centers = Matrix.vstack(centers, center) 59 | 60 | J_v = center.jacobian(q) 61 | # J.append(J_v) 62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 63 | M += M_i 64 | 65 | V += m[i]*g*center[0,1] 66 | 67 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 68 | 69 | C,G = get_C_G(n,M,V,q,qdot) 70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 72 | 73 | with open("./env/cartpole/robot.p", "wb") as outf: 74 | pickle.dump(lambda_dict, outf) 75 | 76 | print("Done") 77 | 78 | if __name__ == '__main__': 79 | derive() 80 | -------------------------------------------------------------------------------- /env/cartpole/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/cartpole/robot.p -------------------------------------------------------------------------------- /env/pendulum/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 1 21 | m1 = symbols('m1') 22 | l1 = symbols('l1') 23 | r1 = symbols('r1') 24 | I1 = symbols('I1') 25 | g = symbols('g') 26 | q1 = symbols('q1') 27 | q1dot = symbols('q1dot') 28 | 29 | m = [m1] 30 | l = [l1] 31 | r = [r1] 32 | I = [I1] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1]) 36 | qdot = Matrix([q1dot]) 37 | state = [q1,q1dot] 38 | 39 | J_w = Matrix([[1] 40 | ]) 41 | 42 | angles = J_w * q 43 | 44 | V = 0 45 | M = zeros(n) 46 | J = [] 47 | for i in range(n): 48 | if i == 0: 49 | joint = Matrix([[0, 0]]) 50 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 51 | joints = joint 52 | centers = center 53 | 54 | M_i = I[i] * J_w[i,:].T * J_w[i,:] 55 | M += M_i 56 | 57 | V += m[i]*g*center[0,1] 58 | 59 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 60 | 61 | C,G = get_C_G(n,M,V,q,qdot) 62 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 63 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 64 | 65 | with open("./env/pendulum/robot.p", "wb") as outf: 66 | pickle.dump(lambda_dict, outf) 67 | 68 | print("Done") 69 | 70 | if __name__ == '__main__': 71 | derive() 72 | -------------------------------------------------------------------------------- /env/pendulum/pendulum.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 
6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class pendulum(BaseEnv): 10 | def __init__(self): 11 | m1 = 1 12 | l1 = 1 13 | r1 = l1 14 | I1 = m1 * l1**2 15 | 16 | g = 9.8 17 | 18 | m = [m1] 19 | l = [l1] 20 | r = [r1] 21 | I = [I1] 22 | super(pendulum, self).__init__(name = "pendulum", 23 | n = 1, 24 | obs_size = 3, 25 | action_size = 1, 26 | inertials = m+l+r+I+[g], 27 | a_scale = np.array([2.0])) 28 | self.dt = 0.02 29 | 30 | def wrap_state(self): 31 | self.state[0] = wrap(self.state[0]) 32 | 33 | def reset_state(self): 34 | self.state = np.array([np.pi+0.01*np.random.randn(),0]) 35 | 36 | def get_A(self, a): 37 | return np.clip(a, -1.0, 1.0)*self.a_scale 38 | 39 | def get_obs(self): 40 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]), 41 | self.state[1] 42 | ]) 43 | 44 | def get_reward(self): 45 | upright = (np.array([np.cos(self.state[0])])+1)/2 46 | 47 | qdot = self.state[self.n:] 48 | ang_vel = qdot 49 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min() 50 | small_velocity = (1 + small_velocity) / 2 51 | 52 | reward = upright.mean() * small_velocity 53 | 54 | return reward 55 | 56 | def draw(self): 57 | centers, joints, angles = self.geo 58 | 59 | link1_center = (centers[0]+joints[0])/2 60 | link1_points = rect_points(link1_center, self.link_length, self.link_width/2.5, angles[0,0],self.scaling,self.offset) 61 | pygame.draw.polygon(self.screen, self.link_color, link1_points) 62 | 63 | for j in range(self.n): 64 | center_point = [self.offset[0]+self.scaling*centers[j,0],self.offset[1]-self.scaling*centers[j,1]] 65 | pygame.draw.circle(self.screen, 'slateblue1', center_point, self.scaling*self.joint_radius*1.5) 66 | 67 | joint_point = pygame_transform(joints[j],self.scaling,self.offset) 68 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 69 | 70 | if __name__ == '__main__': 71 | basic_check("pendulum",0) 72 | -------------------------------------------------------------------------------- /env/pendulum/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/pendulum/robot.p -------------------------------------------------------------------------------- /env/reacher/derive.py: -------------------------------------------------------------------------------- 1 | from sympy import symbols,cos,sin,simplify,diff,Matrix,linsolve,expand,nsimplify,zeros,flatten 2 | from sympy.utilities.lambdify import lambdify 3 | from sympy.matrices.dense import matrix_multiply_elementwise 4 | import dill as pickle 5 | pickle.settings['recurse'] = True 6 | 7 | def get_C_G(n,M,V,q,qdot): 8 | C = zeros(n) 9 | for i in range(n): 10 | for j in range(n): 11 | for k in range(n): 12 | C[i,j] += (diff(M[i,j],q[k])+diff(M[i,k],q[j])-diff(M[k,j],q[i]))*qdot[k]/2 13 | 14 | G = Matrix([diff(V,q[i]) for i in range(n)]) 15 | 16 | return C,G 17 | 18 | def derive(): 19 | lambda_dict = {} 20 | n = 2 21 | m1,m2 = symbols('m1,m2') 22 | l1,l2 = symbols('l1,l2') 23 | r1,r2 = symbols('r1,r2') 24 | I1,I2 = symbols('I1,I2') 25 | g = symbols('g') 26 | q1,q2 = symbols('q1,q2') 27 | q1dot,q2dot = symbols('q1dot,q2dot') 28 | 29 | m = [m1,m2] 30 | l = [l1,l2] 31 | r = [r1,r2] 32 | I = [I1,I2] 33 | inertials = m+l+r+I+[g] 34 | 35 | q = Matrix([q1,q2]) 36 | qdot = Matrix([q1dot,q2dot]) 37 | state = [q1,q2,q1dot,q2dot] 38 | 39 | J_w 
= Matrix([[1,0], 40 | [1,1] 41 | ]) 42 | 43 | angles = J_w * q 44 | 45 | V = 0 46 | M = zeros(n) 47 | J = [] 48 | for i in range(n): 49 | if i == 0: 50 | joint = Matrix([[0, 0]]) 51 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 52 | joints = joint 53 | centers = center 54 | else: 55 | joint = joint + l[i-1]*Matrix([[sin(angles[i-1]),cos(angles[i-1])]]) 56 | center = joint + (r[i])*Matrix([[sin(angles[i]),cos(angles[i])]]) 57 | joints = Matrix.vstack(joints, joint) 58 | centers = Matrix.vstack(centers, center) 59 | 60 | J_v = center.jacobian(q) 61 | # J.append(J_v) 62 | M_i = m[i] * J_v.T * J_v + I[i] * J_w[i,:].T * J_w[i,:] 63 | M += M_i 64 | 65 | V += m[i]*g*center[0,1] 66 | 67 | # print(cse([centers,joints,J_w]+J, optimizations='basic')) 68 | 69 | C,G = get_C_G(n,M,V,q,qdot) 70 | lambda_dict['kinematics'] = lambdify([tuple(inertials+state)],[centers,joints,angles],'numpy',cse=True) 71 | lambda_dict['dynamics'] = lambdify([tuple(inertials+state)],[M,C,G],'numpy',cse=True) 72 | 73 | with open("./env/reacher/robot.p", "wb") as outf: 74 | pickle.dump(lambda_dict, outf) 75 | 76 | print("Done") 77 | 78 | if __name__ == '__main__': 79 | derive() 80 | -------------------------------------------------------------------------------- /env/reacher/reacher.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | from env import rewards 6 | from env.utils import rect_points, wrap, basic_check, pygame_transform 7 | from env.base import BaseEnv 8 | 9 | class reacher(BaseEnv): 10 | def __init__(self): 11 | m1 = 0.1 12 | l1 = 1 13 | r1 = l1/2 14 | I1 = m1 * l1**2 / 12 15 | 16 | m2 = 0.1 17 | l2 = 1 18 | r2 = l2/2 19 | I2 = m2 * l2**2 / 12 20 | 21 | g = 0.0 22 | 23 | m = [m1,m2] 24 | l = [l1,l2] 25 | r = [r1,r2] 26 | I = [I1,I2] 27 | super(reacher, self).__init__(name = "reacher", 28 | n = 2, 29 | obs_size = 6, 30 | action_size = 2, 31 | inertials = m+l+r+I+[g], 32 | a_scale = np.array([0.1,0.1])) 33 | self.dt = 0.02 34 | self.goal_position = np.array([0,2]) 35 | 36 | def wrap_state(self): 37 | self.state[:2] = wrap(self.state[:2]) 38 | 39 | def reset_state(self): 40 | self.state = np.array([np.pi + 0.01*np.random.randn(), 41 | 0.01*np.random.randn(), 42 | 0, 43 | 0]) 44 | 45 | def get_A(self, a): 46 | a_1, a_2 = np.clip(a, -1.0, 1.0)*self.a_scale 47 | return np.array([a_1,a_2]) 48 | 49 | def get_obs(self): 50 | return np.array([np.cos(self.state[0]),np.sin(self.state[0]), 51 | np.cos(self.state[1]),np.sin(self.state[1]), 52 | self.state[2], 53 | self.state[3] 54 | ]) 55 | 56 | def get_reward(self): 57 | upright = (np.array([np.cos(self.state[0]), np.cos(self.state[0]+self.state[1])])+1)/2 58 | 59 | qdot = self.state[self.n:] 60 | ang_vel = np.array([qdot[0],qdot[0]+qdot[1]]) 61 | small_velocity = rewards.tolerance(ang_vel, margin=self.ang_vel_limit).min() 62 | small_velocity = (1 + small_velocity) / 2 63 | 64 | reward = upright.mean() * small_velocity 65 | 66 | return reward 67 | 68 | def draw(self): 69 | pygame.draw.circle(self.screen,'slateblue1', [self.offset[0]+self.scaling*self.goal_position[0], 70 | self.offset[1]-self.scaling*self.goal_position[1]], self.scaling*self.link_width) 71 | 72 | centers, joints, angles = self.geo 73 | 74 | for j in range(self.n): 75 | link_points = rect_points(centers[j], self.link_length, self.link_width, angles[j,0],self.scaling,self.offset) 76 | pygame.draw.polygon(self.screen, self.link_color, link_points) 
77 | 78 | joint_point = pygame_transform(joints[j],self.scaling,self.offset) 79 | pygame.draw.circle(self.screen, self.joint_color, joint_point, self.scaling*self.joint_radius) 80 | 81 | 82 | if __name__ == '__main__': 83 | basic_check("reacher",0) -------------------------------------------------------------------------------- /env/reacher/robot.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/env/reacher/robot.p -------------------------------------------------------------------------------- /env/rewards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Soft indicator function evaluating whether a number is within bounds.""" 17 | 18 | import warnings 19 | import numpy as np 20 | 21 | # The value returned by tolerance() at `margin` distance from `bounds` interval. 22 | _DEFAULT_VALUE_AT_MARGIN = 0.1 23 | 24 | 25 | def _sigmoids(x, value_at_1, sigmoid): 26 | """Returns 1 when `x` == 0, between 0 and 1 otherwise. 27 | 28 | Args: 29 | x: A scalar or numpy array. 30 | value_at_1: A float between 0 and 1 specifying the output when `x` == 1. 31 | sigmoid: String, choice of sigmoid type. 32 | 33 | Returns: 34 | A numpy array with values between 0.0 and 1.0. 35 | 36 | Raises: 37 | ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and 38 | `quadratic` sigmoids which allow `value_at_1` == 0. 39 | ValueError: If `sigmoid` is of an unknown type. 
40 | """ 41 | if sigmoid in ('cosine', 'linear', 'quadratic'): 42 | if not 0 <= value_at_1 < 1: 43 | raise ValueError('`value_at_1` must be nonnegative and smaller than 1, ' 44 | 'got {}.'.format(value_at_1)) 45 | else: 46 | if not 0 < value_at_1 < 1: 47 | raise ValueError('`value_at_1` must be strictly between 0 and 1, ' 48 | 'got {}.'.format(value_at_1)) 49 | 50 | if sigmoid == 'gaussian': 51 | scale = np.sqrt(-2 * np.log(value_at_1)) 52 | return np.exp(-0.5 * (x*scale)**2) 53 | 54 | elif sigmoid == 'hyperbolic': 55 | scale = np.arccosh(1/value_at_1) 56 | return 1 / np.cosh(x*scale) 57 | 58 | elif sigmoid == 'long_tail': 59 | scale = np.sqrt(1/value_at_1 - 1) 60 | return 1 / ((x*scale)**2 + 1) 61 | 62 | elif sigmoid == 'reciprocal': 63 | scale = 1/value_at_1 - 1 64 | return 1 / (abs(x)*scale + 1) 65 | 66 | elif sigmoid == 'cosine': 67 | scale = np.arccos(2*value_at_1 - 1) / np.pi 68 | scaled_x = x*scale 69 | with warnings.catch_warnings(): 70 | warnings.filterwarnings( 71 | action='ignore', message='invalid value encountered in cos') 72 | cos_pi_scaled_x = np.cos(np.pi*scaled_x) 73 | return np.where(abs(scaled_x) < 1, (1 + cos_pi_scaled_x)/2, 0.0) 74 | 75 | elif sigmoid == 'linear': 76 | scale = 1-value_at_1 77 | scaled_x = x*scale 78 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0) 79 | 80 | elif sigmoid == 'quadratic': 81 | scale = np.sqrt(1-value_at_1) 82 | scaled_x = x*scale 83 | return np.where(abs(scaled_x) < 1, 1 - scaled_x**2, 0.0) 84 | 85 | elif sigmoid == 'tanh_squared': 86 | scale = np.arctanh(np.sqrt(1-value_at_1)) 87 | return 1 - np.tanh(x*scale)**2 88 | 89 | else: 90 | raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid)) 91 | 92 | 93 | def tolerance(x, bounds=(0.0, 0.0), margin=0.0, sigmoid='gaussian', 94 | value_at_margin=_DEFAULT_VALUE_AT_MARGIN): 95 | """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise. 96 | 97 | Args: 98 | x: A scalar or numpy array. 99 | bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds for 100 | the target interval. These can be infinite if the interval is unbounded 101 | at one or both ends, or they can be equal to one another if the target 102 | value is exact. 103 | margin: Float. Parameter that controls how steeply the output decreases as 104 | `x` moves out-of-bounds. 105 | * If `margin == 0` then the output will be 0 for all values of `x` 106 | outside of `bounds`. 107 | * If `margin > 0` then the output will decrease sigmoidally with 108 | increasing distance from the nearest bound. 109 | sigmoid: String, choice of sigmoid type. Valid values are: 'gaussian', 110 | 'linear', 'hyperbolic', 'long_tail', 'cosine', 'tanh_squared'. 111 | value_at_margin: A float between 0 and 1 specifying the output value when 112 | the distance from `x` to the nearest bound is equal to `margin`. Ignored 113 | if `margin == 0`. 114 | 115 | Returns: 116 | A float or numpy array with values between 0.0 and 1.0. 117 | 118 | Raises: 119 | ValueError: If `bounds[0] > bounds[1]`. 120 | ValueError: If `margin` is negative. 
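Example (illustrative values, assuming the default 'gaussian' sigmoid and value_at_margin=0.1): tolerance(0.0) -> 1.0, since 0 lies inside the default bounds (0.0, 0.0); tolerance(1.0, margin=1.0) -> ~0.1, the value exactly `margin` away from the bounds; tolerance(0.5, margin=1.0) -> ~0.56, partway into the margin.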
121 | """ 122 | lower, upper = bounds 123 | if lower > upper: 124 | raise ValueError('Lower bound must be <= upper bound.') 125 | if margin < 0: 126 | raise ValueError('`margin` must be non-negative.') 127 | 128 | in_bounds = np.logical_and(lower <= x, x <= upper) 129 | if margin == 0: 130 | value = np.where(in_bounds, 1.0, 0.0) 131 | else: 132 | d = np.where(x < lower, lower - x, x - upper) / margin 133 | value = np.where(in_bounds, 1.0, _sigmoids(d, value_at_margin, sigmoid)) 134 | 135 | return float(value) if np.isscalar(x) else value 136 | 137 | -------------------------------------------------------------------------------- /env/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | import numpy as np 5 | 6 | def create_background(length, width): 7 | background = pygame.Surface((length, width)) 8 | pygame.draw.rect(background, (0,0,0), pygame.Rect(0, 0, length, width)) 9 | return background 10 | 11 | def rect_points(center, length, width, ang, scaling, offset): 12 | points = [] 13 | diag = np.sqrt(length**2+width**2)/2 14 | ang1 = 2*np.arctan2(width,length) 15 | ang2 = 2*np.arctan2(length,width) 16 | 17 | points.append((center[0]+np.sin(ang+ang1/2)*diag, center[1]+np.cos(ang+ang1/2)*diag)) 18 | 19 | points.append((center[0]+np.sin(ang+ang1/2+ang2)*diag, center[1]+np.cos(ang+ang1/2+ang2)*diag)) 20 | 21 | points.append((center[0]+np.sin(ang+ang1*1.5+ang2)*diag, center[1]+np.cos(ang+ang1*1.5+ang2)*diag)) 22 | 23 | points.append((center[0]+np.sin(ang+ang1*1.5+2*ang2)*diag, center[1]+np.cos(ang+ang1*1.5+2*ang2)*diag)) 24 | 25 | return [pygame_transform(point, scaling, offset) for point in points] 26 | 27 | def pygame_transform(point, scaling, offset): 28 | # Pygame's y axis points downwards. Hence invert y coordinate alone before offset. 
29 | return (offset[0]+scaling*point[0],offset[1]-scaling*point[1]) 30 | 31 | def wrap(x): 32 | return ((x + np.pi) % (2 * np.pi)) - np.pi 33 | 34 | def make_env(name): 35 | 36 | if name == "pendulum": 37 | from env.pendulum.pendulum import pendulum as Env 38 | 39 | elif name == "reacher": 40 | from env.reacher.reacher import reacher as Env 41 | 42 | elif name == "cartpole": 43 | from env.cartpole.cartpole import cartpole as Env 44 | 45 | elif name == "acrobot": 46 | from env.acrobot.acrobot import acrobot as Env 47 | 48 | elif name == "cart2pole": 49 | from env.cart2pole.cart2pole import cart2pole as Env 50 | 51 | elif name == "acro3bot": 52 | from env.acro3bot.acro3bot import acro3bot as Env 53 | 54 | elif name == "cart3pole": 55 | from env.cart3pole.cart3pole import cart3pole as Env 56 | 57 | env = Env() 58 | 59 | return env 60 | 61 | def basic_check(name, seed): 62 | np.random.seed(seed) 63 | env = make_env(name) 64 | no_episodes = 1 65 | for episode in range(no_episodes): 66 | t = 0 67 | o_t, _, _ = env.reset() 68 | ep_r = 0 69 | while True: 70 | env.render() 71 | a_t = np.random.uniform(-1,1,env.action_size) 72 | o_t_1, r_t, done = env.step(a_t) 73 | ep_r += r_t 74 | t += 1 75 | o_t = o_t_1 76 | if done: 77 | print("Episode finished with total reward ",ep_r,"time steps",t) 78 | break 79 | 80 | 81 | -------------------------------------------------------------------------------- /mbrl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from copy import deepcopy 4 | import random 5 | import numpy as np 6 | import torch 7 | torch.set_default_dtype(torch.float64) 8 | from torch.utils.tensorboard import SummaryWriter 9 | from env.utils import make_env 10 | from models.mbrl import ReplayBuffer, V_FC, Pi_FC, dnn, lnn, reward_model_FC 11 | 12 | # Model-Based RL algorithm 13 | class MBRL: 14 | def __init__(self, arglist): 15 | self.arglist = arglist 16 | 17 | random.seed(self.arglist.seed) 18 | np.random.seed(self.arglist.seed) 19 | torch.manual_seed(self.arglist.seed) 20 | 21 | self.env = make_env(self.arglist.env) 22 | 23 | self.device = torch.device("cpu") 24 | 25 | path = "./log/"+self.env.name+"/mbrl_"+self.arglist.model 26 | self.exp_dir = os.path.join(path, "seed_"+str(self.arglist.seed)) 27 | self.model_dir = os.path.join(self.exp_dir, "models") 28 | self.tensorboard_dir = os.path.join(self.exp_dir, "tensorboard") 29 | 30 | self.actor = Pi_FC(self.env.obs_size,self.env.action_size).to(self.device) 31 | 32 | if self.arglist.mode == "train": 33 | self.critic = V_FC(self.env.obs_size).to(self.device) 34 | self.critic_target = deepcopy(self.critic) 35 | 36 | if self.arglist.model == "lnn": 37 | if self.env.action_size < self.env.n: 38 | a_zeros = torch.zeros(self.arglist.batch_size,self.env.n-self.env.action_size, dtype=torch.float64, device=self.device) 39 | else: 40 | a_zeros = None 41 | self.transition_model = lnn(self.env.name, self.env.n, self.env.obs_size, self.env.action_size, self.env.dt, a_zeros).to(self.device) 42 | 43 | elif self.arglist.model == "dnn": 44 | self.transition_model = dnn(self.env.obs_size, self.env.action_size).to(self.device) 45 | self.transition_loss_fn = torch.nn.L1Loss() 46 | 47 | self.reward_model = reward_model_FC(self.env.obs_size).to(self.device) 48 | self.reward_loss_fn = torch.nn.L1Loss() 49 | 50 | if self.arglist.resume: 51 | checkpoint = torch.load(os.path.join(self.model_dir,"emergency.ckpt")) 52 | self.start_episode = checkpoint['episode'] + 1 53 | 54 | 
self.actor.load_state_dict(checkpoint['actor']) 55 | self.critic.load_state_dict(checkpoint['critic']) 56 | self.critic_target.load_state_dict(checkpoint['critic_target']) 57 | self.transition_model.load_state_dict(checkpoint['transition_model']) 58 | self.reward_model.load_state_dict(checkpoint['reward_model']) 59 | 60 | self.replay_buffer = checkpoint['replay_buffer'] 61 | 62 | else: 63 | self.start_episode = 0 64 | 65 | self.replay_buffer = ReplayBuffer(self.arglist.replay_size, self.device) 66 | 67 | if os.path.exists(path): 68 | pass 69 | else: 70 | os.makedirs(path) 71 | os.mkdir(self.exp_dir) 72 | os.mkdir(os.path.join(self.tensorboard_dir)) 73 | os.mkdir(self.model_dir) 74 | 75 | self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(), lr=self.arglist.lr) 76 | self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=self.arglist.lr) 77 | self.transition_optimizer = torch.optim.AdamW(self.transition_model.parameters(), lr=self.arglist.lr) 78 | self.reward_optimizer = torch.optim.AdamW(self.reward_model.parameters(), lr=self.arglist.lr) 79 | 80 | if self.arglist.resume: 81 | self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer']) 82 | self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer']) 83 | self.transition_optimizer.load_state_dict(checkpoint['transition_optimizer']) 84 | self.reward_optimizer.load_state_dict(checkpoint['reward_optimizer']) 85 | 86 | print("Done loading checkpoint ...") 87 | 88 | self.a_scale = torch.tensor(self.env.a_scale,dtype=torch.float64,device=self.device) 89 | 90 | self.train() 91 | 92 | elif self.arglist.mode == "eval": 93 | checkpoint = torch.load(self.arglist.checkpoint,map_location=self.device) 94 | self.actor.load_state_dict(checkpoint['actor']) 95 | ep_r_list = self.eval(self.arglist.episodes,self.arglist.render) 96 | 97 | def save_checkpoint(self, name): 98 | checkpoint = {'actor' : self.actor.state_dict(),\ 99 | 'critic' : self.critic.state_dict(),\ 100 | 'transition_model' : self.transition_model.state_dict(),\ 101 | 'reward_model' : self.reward_model.state_dict() 102 | } 103 | torch.save(checkpoint, os.path.join(self.model_dir, name)) 104 | 105 | def save_emergency_checkpoint(self, episode): 106 | checkpoint = {'episode' : episode,\ 107 | 'actor' : self.actor.state_dict(),\ 108 | 'actor_optimizer': self.actor_optimizer.state_dict(),\ 109 | 'critic' : self.critic.state_dict(),\ 110 | 'critic_optimizer': self.critic_optimizer.state_dict(),\ 111 | 'critic_target' : self.critic_target.state_dict(),\ 112 | 'transition_model' : self.transition_model.state_dict(),\ 113 | 'transition_optimizer': self.transition_optimizer.state_dict(),\ 114 | 'reward_model' : self.reward_model.state_dict(),\ 115 | 'reward_optimizer': self.reward_optimizer.state_dict(),\ 116 | 'replay_buffer' : self.replay_buffer \ 117 | } 118 | torch.save(checkpoint, os.path.join(self.model_dir, "emergency.ckpt")) 119 | 120 | def hard_update(self, target, source): 121 | with torch.no_grad(): 122 | for target_param, param in zip(target.parameters(), source.parameters()): 123 | target_param.data.copy_(param.data) 124 | 125 | def train(self): 126 | critic_target_updates = 0 127 | 128 | writer = SummaryWriter(log_dir=self.tensorboard_dir) 129 | 130 | if not self.arglist.resume: 131 | # Initialize replay buffer with K random episodes 132 | for episode in range(self.arglist.K): 133 | o,_,_ = self.env.reset() 134 | o_tensor = torch.tensor(o, dtype=torch.float64, device=self.device) 135 | ep_r = 0 136 | while True: 137 | a = 
np.random.uniform(-1.0, 1.0, size=self.env.action_size) 138 | o_1,r,done = self.env.step(a) 139 | a_tensor = torch.tensor(a, dtype=torch.float64, device=self.device) 140 | o_1_tensor = torch.tensor(o_1, dtype=torch.float64, device=self.device) 141 | r_tensor = torch.tensor(r, dtype=torch.float64, device=self.device) 142 | self.replay_buffer.push(o_tensor, a_tensor, r_tensor, o_1_tensor) 143 | ep_r += r 144 | o_tensor = o_1_tensor 145 | if done: 146 | break 147 | 148 | print("Done initialization ...") 149 | print("Started training ...") 150 | 151 | for episode in range(self.start_episode,self.arglist.episodes): 152 | # Model learning 153 | transition_loss_list, reward_loss_list = [], [] 154 | transition_grad_list, reward_grad_list = [], [] 155 | for model_batches in range(self.arglist.model_batches): 156 | O, A, R, O_1 = self.replay_buffer.sample_transitions(self.arglist.batch_size) 157 | 158 | # Dynamics learning 159 | O_1_pred = self.transition_model(O,A*self.a_scale) 160 | transition_loss = self.transition_loss_fn(O_1_pred, O_1) 161 | self.transition_optimizer.zero_grad() 162 | transition_loss.backward() 163 | torch.nn.utils.clip_grad_norm_(self.transition_model.parameters(), self.arglist.clip_term) 164 | self.transition_optimizer.step() 165 | transition_loss_list.append(transition_loss.item()) 166 | transition_grad = [] 167 | for param in self.transition_model.parameters(): 168 | if param.grad is not None: 169 | transition_grad.append(param.grad.flatten()) 170 | transition_grad_list.append(torch.norm(torch.cat(transition_grad)).item()) 171 | 172 | # Reward learning 173 | R_pred = self.reward_model(O_1) 174 | reward_loss = self.reward_loss_fn(R_pred,R) 175 | self.reward_optimizer.zero_grad() 176 | reward_loss.backward() 177 | torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.arglist.clip_term) 178 | self.reward_optimizer.step() 179 | reward_loss_list.append(reward_loss.item()) 180 | reward_grad_list.append(torch.norm(torch.cat([param.grad.flatten() for param in self.reward_model.parameters()])).item()) 181 | 182 | writer.add_scalar('transition_loss', np.mean(transition_loss_list), episode) 183 | writer.add_scalar('reward_loss', np.mean(reward_loss_list), episode) 184 | writer.add_scalar('transition_grad',np.mean(transition_grad_list),episode) 185 | writer.add_scalar('reward_grad',np.mean(reward_grad_list),episode) 186 | 187 | #Behaviour learning 188 | actor_loss_list, critic_loss_list = [], [] 189 | actor_grad_list, critic_grad_list = [], [] 190 | 191 | nan_count = 0 192 | for behaviour_batches in range(self.arglist.behaviour_batches): 193 | O = self.replay_buffer.sample_states(self.arglist.batch_size) 194 | t = 0 195 | values, values_target, values_lambda, R = [], [], [], [] 196 | log_probs = [] 197 | try: 198 | while True: 199 | A, log_prob = self.actor(O, False, True) 200 | log_probs.append(log_prob) 201 | O_1 = self.transition_model(O, A*self.a_scale) 202 | R.append(self.reward_model(O_1)) 203 | values.append(self.critic(O)) 204 | values_target.append(self.critic_target(O)) 205 | t += 1 206 | O = O_1 207 | if t % self.arglist.T == 0: 208 | values_target.append(self.critic_target(O_1)) 209 | break 210 | 211 | # lambda-return calculation 212 | gae = torch.zeros_like(R[0]) 213 | for t_ in reversed(range(self.arglist.T)): 214 | delta = R[t_]+self.arglist.gamma*values_target[t_+1]-values_target[t_] 215 | gae = delta+self.arglist.gamma*self.arglist.Lambda*gae 216 | values_lambda.append(gae+values_target[t_]) 217 | values_lambda = torch.stack(values_lambda) 218 | 
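# Note: values_lambda was appended for t = T-1 down to 0 (reverse time order) in the loop above, so it is flipped along dim 0 below to restore chronological order and line up with `values` and R.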
values_lambda = values_lambda.flip(0) 219 | 220 | values = torch.stack(values) 221 | critic_loss = 0.5*torch.pow(values-values_lambda.detach(),2).sum(0).mean() 222 | 223 | log_probs = torch.stack(log_probs) 224 | actor_loss = - (values_lambda-0.0001*log_probs).sum(0).mean() 225 | 226 | self.critic_optimizer.zero_grad() 227 | critic_loss.backward(inputs=[param for param in self.critic.parameters()]) 228 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.arglist.clip_term) 229 | 230 | self.actor_optimizer.zero_grad() 231 | actor_loss.backward(inputs=[param for param in self.actor.parameters()]) 232 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.arglist.clip_term) 233 | 234 | critic_grad = torch.norm(torch.cat([param.grad.flatten() for param in self.critic.parameters()])) 235 | actor_grad = torch.norm(torch.cat([param.grad.flatten() for param in self.actor.parameters()])) 236 | 237 | if torch.isnan(critic_grad).any().item() or torch.isnan(actor_grad).any().item(): 238 | nan_count += 1 239 | else: 240 | self.critic_optimizer.step() 241 | self.actor_optimizer.step() 242 | 243 | critic_target_updates = (critic_target_updates+1)%100 244 | if critic_target_updates == 0: 245 | self.hard_update(self.critic_target, self.critic) 246 | 247 | actor_loss_list.append(actor_loss.item()) 248 | critic_loss_list.append(critic_loss.item()) 249 | actor_grad_list.append(actor_grad.item()) 250 | critic_grad_list.append(critic_grad.item()) 251 | 252 | except: 253 | nan_count += 1 254 | 255 | if nan_count > 0: 256 | print("episode",episode,"got nan during behaviour learning","nan count",nan_count) 257 | writer.add_scalar('critic_loss',np.mean(critic_loss_list),episode) 258 | writer.add_scalar('actor_loss',np.mean(actor_loss_list),episode) 259 | writer.add_scalar('critic_grad',np.mean(critic_grad_list),episode) 260 | writer.add_scalar('actor_grad',np.mean(actor_grad_list),episode) 261 | 262 | # Environment Interaction 263 | o,_,_ = self.env.reset() 264 | o_tensor = torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0) 265 | ep_r = 0 266 | while True: 267 | with torch.no_grad(): 268 | try: 269 | a_tensor, _ = self.actor(o_tensor) 270 | except: 271 | print("episode",episode,"got nan during environment interaction") 272 | break 273 | o_1,r,done = self.env.step(a_tensor.cpu().numpy()[0]) 274 | o_1_tensor = torch.tensor(o_1, dtype=torch.float64, device=self.device).unsqueeze(0) 275 | r_tensor = torch.tensor(r, dtype=torch.float64, device=self.device) 276 | self.replay_buffer.push(o_tensor[0], a_tensor[0], r_tensor, o_1_tensor[0]) 277 | ep_r += r 278 | o_tensor = o_1_tensor 279 | if done: 280 | writer.add_scalar('ep_r', ep_r, episode) 281 | if episode % self.arglist.eval_every == 0 or episode == self.arglist.episodes-1: 282 | try: 283 | # Evaluate agent performance 284 | eval_ep_r_list = self.eval(self.arglist.eval_over) 285 | writer.add_scalar('eval_ep_r', np.mean(eval_ep_r_list), episode) 286 | self.save_checkpoint(str(episode)+".ckpt") 287 | except: 288 | print("episode",episode,"got nan during eval") 289 | if (episode % 25 == 0 or episode == self.arglist.episodes-1) and episode > self.start_episode: 290 | self.save_emergency_checkpoint(episode) 291 | break 292 | 293 | def eval(self, episodes, render=False): 294 | # Evaluate agent performance over several episodes 295 | ep_r_list = [] 296 | for episode in range(episodes): 297 | o,_,_ = self.env.reset() 298 | ep_r = 0 299 | while True: 300 | with torch.no_grad(): 301 | a, _ = self.actor(torch.tensor(o, dtype=torch.float64, 
device=self.device).unsqueeze(0),True) 302 | a = a.cpu().numpy()[0] 303 | o_1,r,done = self.env.step(a) 304 | if render: 305 | self.env.render() 306 | ep_r += r 307 | o = o_1 308 | if done: 309 | ep_r_list.append(ep_r) 310 | if render: 311 | print("Episode finished with total reward ",ep_r) 312 | break 313 | 314 | if self.arglist.mode == "eval": 315 | print("Average return :",np.mean(ep_r_list)) 316 | 317 | return ep_r_list 318 | 319 | def parse_args(): 320 | parser = argparse.ArgumentParser("Model-Based Reinforcement Learning") 321 | # Common settings 322 | parser.add_argument("--env", type=str, default="acrobot", help="pendulum / reacher / cartpole / acrobot / cart2pole / acro3bot / cart3pole") 323 | parser.add_argument("--mode", type=str, default="train", help="train or eval") 324 | parser.add_argument("--episodes", type=int, default=500, help="number of episodes to run experiment for") 325 | parser.add_argument("--seed", type=int, default=0, help="seed") 326 | # Core training parameters 327 | parser.add_argument("--resume", action="store_true", default=False, help="continue training from checkpoint") 328 | parser.add_argument("--model", type=str, default="lnn", help="lnn / dnn") 329 | parser.add_argument("--T", type=int, default=16, help="imagination horizon") 330 | parser.add_argument("--K", type=int, default=10, help="init replay buffer with K random episodes") 331 | parser.add_argument("--lr", type=float, default=3e-4, help="learning rate") 332 | parser.add_argument("--clip-term", type=float, default=100, help="gradient clipping norm") 333 | parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") 334 | parser.add_argument("--Lambda", type=float, default=0.95, help="GAE lambda") 335 | parser.add_argument("--batch-size", type=int, default=64, help="batch size for model learning, behaviour learning") 336 | parser.add_argument("--model-batches", type=int, default=int(1e4), help="model batches per episode") 337 | parser.add_argument("--behaviour-batches", type=int, default=int(1e3), help="behaviour batches per episode") 338 | parser.add_argument("--replay-size", type=int, default=int(1e5), help="replay buffer size") 339 | parser.add_argument("--eval-every", type=int, default=5, help="eval every _ episodes during training") 340 | parser.add_argument("--eval-over", type=int, default=50, help="each time eval over _ episodes") 341 | # Eval settings 342 | parser.add_argument("--checkpoint", type=str, default="", help="path to checkpoint") 343 | parser.add_argument("--render", action="store_true", default=False, help="render") 344 | return parser.parse_args() 345 | 346 | if __name__ == '__main__': 347 | arglist = parse_args() 348 | mbrl = MBRL(arglist) 349 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adi3e08/Physics_Informed_Model_Based_RL/b630360bfac0e27f3b3d6e2a6b6cb46b1ced5859/models/__init__.py -------------------------------------------------------------------------------- /models/mbrl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | torch.set_default_dtype(torch.float64) 4 | import torch.nn.functional as F 5 | from torch.distributions.normal import Normal 6 | from torch.func import jacrev 7 | from collections import deque 8 | import random 9 | 10 | # Experience replay buffer 11 | class ReplayBuffer: 12 | def 
__init__(self, capacity, device): 13 | self.buffer = deque(maxlen=capacity) 14 | self.device = device 15 | 16 | def push(self, o, a, r, o_1): 17 | self.buffer.append((o, a, r, o_1)) 18 | 19 | def sample_transitions(self, batch_size): 20 | O, A, R, O_1 = zip(*random.sample(self.buffer, batch_size)) 21 | return torch.stack(O), torch.stack(A), torch.stack(R), torch.stack(O_1) 22 | 23 | def sample_states(self, batch_size): 24 | O, A, R, O_1 = zip(*random.sample(self.buffer, batch_size)) 25 | return torch.stack(O) 26 | 27 | def __len__(self): 28 | return len(self.buffer) 29 | 30 | # Critic network 31 | class V_FC(torch.nn.Module): 32 | def __init__(self, obs_size): 33 | super(V_FC, self).__init__() 34 | self.fc1 = torch.nn.Linear(obs_size, 256) 35 | self.fc2 = torch.nn.Linear(256, 256) 36 | self.fc3 = torch.nn.Linear(256, 1) 37 | 38 | def forward(self, x): 39 | y1 = F.relu(self.fc1(x)) 40 | y2 = F.relu(self.fc2(y1)) 41 | y = self.fc3(y2).view(-1) 42 | return y 43 | 44 | # Actor network 45 | class Pi_FC(torch.nn.Module): 46 | def __init__(self, obs_size, action_size): 47 | super(Pi_FC, self).__init__() 48 | self.fc1 = torch.nn.Linear(obs_size, 256) 49 | self.fc2 = torch.nn.Linear(256, 256) 50 | self.mu = torch.nn.Linear(256, action_size) 51 | self.log_sigma = torch.nn.Linear(256, action_size) 52 | 53 | def forward(self, x, deterministic=False, with_logprob=False): 54 | y1 = F.relu(self.fc1(x)) 55 | y2 = F.relu(self.fc2(y1)) 56 | mu = self.mu(y2) 57 | 58 | if deterministic: 59 | # used for evaluating policy 60 | action = torch.tanh(mu) 61 | log_prob = None 62 | else: 63 | log_sigma = self.log_sigma(y2) 64 | log_sigma = torch.clamp(log_sigma,min=-20.0,max=2.0) 65 | sigma = torch.exp(log_sigma) 66 | dist = Normal(mu, sigma) 67 | x_t = dist.rsample() 68 | action = torch.tanh(x_t) 69 | if with_logprob: 70 | log_prob = dist.log_prob(x_t).sum(1) 71 | log_prob -= torch.log(torch.clamp(1-action.pow(2),min=1e-6)).sum(1) 72 | else: 73 | log_prob = None 74 | 75 | return action, log_prob 76 | 77 | 78 | # DNN dynamics model 79 | class dnn(torch.nn.Module): 80 | def __init__(self, obs_size, action_size): 81 | super(dnn, self).__init__() 82 | self.fc1 = torch.nn.Linear(obs_size+action_size, 64) 83 | self.fc2 = torch.nn.Linear(64, 64) 84 | self.fc3 = torch.nn.Linear(64, obs_size) 85 | 86 | def forward(self, x, a): 87 | y1 = F.relu(self.fc1(torch.cat((x,a),1))) 88 | y2 = F.relu(self.fc2(y1)) 89 | y = self.fc3(y2) 90 | return y 91 | 92 | # LNN dynamics model 93 | class lnn(torch.nn.Module): 94 | def __init__(self, env_name, n, obs_size, action_size, dt, a_zeros): 95 | super(lnn, self).__init__() 96 | self.env_name = env_name 97 | self.dt = dt 98 | self.n = n 99 | 100 | input_size = obs_size - self.n 101 | out_L = int(self.n*(self.n+1)/2) 102 | self.fc1_L = torch.nn.Linear(input_size, 64) 103 | self.fc2_L = torch.nn.Linear(64, 64) 104 | self.fc3_L = torch.nn.Linear(64, out_L) 105 | if not self.env_name == "reacher": 106 | self.fc1_V = torch.nn.Linear(input_size, 64) 107 | self.fc2_V = torch.nn.Linear(64, 64) 108 | self.fc3_V = torch.nn.Linear(64, 1) 109 | 110 | self.a_zeros = a_zeros 111 | 112 | def trig_transform_q(self, q): 113 | if self.env_name == "pendulum": 114 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]))) 115 | 116 | elif self.env_name == "reacher" or self.env_name == "acrobot": 117 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\ 118 | torch.cos(q[:,1]),torch.sin(q[:,1]))) 119 | 120 | elif self.env_name == "cartpole": 121 | return torch.column_stack((q[:,0],\ 122 | 
torch.cos(q[:,1]),torch.sin(q[:,1]))) 123 | 124 | elif self.env_name == "cart2pole": 125 | return torch.column_stack((q[:,0],\ 126 | torch.cos(q[:,1]),torch.sin(q[:,1]),\ 127 | torch.cos(q[:,2]),torch.sin(q[:,2]))) 128 | 129 | elif self.env_name == "cart3pole": 130 | return torch.column_stack((q[:,0],\ 131 | torch.cos(q[:,1]),torch.sin(q[:,1]),\ 132 | torch.cos(q[:,2]),torch.sin(q[:,2]),\ 133 | torch.cos(q[:,3]),torch.sin(q[:,3]))) 134 | 135 | elif self.env_name == "acro3bot": 136 | return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\ 137 | torch.cos(q[:,1]),torch.sin(q[:,1]),\ 138 | torch.cos(q[:,2]),torch.sin(q[:,2]))) 139 | 140 | def inverse_trig_transform_model(self, x): 141 | if self.env_name == "pendulum": 142 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),x[:,2:]),1) 143 | 144 | elif self.env_name == "reacher" or self.env_name == "acrobot": 145 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),x[:,4:]),1) 146 | 147 | elif self.env_name == "cartpole": 148 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),x[:,3:]),1) 149 | 150 | elif self.env_name == "cart2pole": 151 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1),x[:,5:]),1) 152 | 153 | elif self.env_name == "cart3pole": 154 | return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1), 155 | torch.atan2(x[:,6],x[:,5]).unsqueeze(1),x[:,7:]),1) 156 | 157 | elif self.env_name == "acro3bot": 158 | return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),torch.atan2(x[:,5],x[:,4]).unsqueeze(1), 159 | x[:,6:]),1) 160 | 161 | def compute_L(self, q): 162 | y1_L = F.softplus(self.fc1_L(q)) 163 | y2_L = F.softplus(self.fc2_L(y1_L)) 164 | y_L = self.fc3_L(y2_L) 165 | device = y_L.device 166 | if self.n == 1: 167 | L = y_L.unsqueeze(1) 168 | 169 | elif self.n == 2: 170 | L11 = y_L[:,0].unsqueeze(1) 171 | L1_zeros = torch.zeros(L11.size(0),1, dtype=torch.float64, device=device) 172 | 173 | L21 = y_L[:,1].unsqueeze(1) 174 | L22 = y_L[:,2].unsqueeze(1) 175 | 176 | L1 = torch.cat((L11,L1_zeros),1) 177 | L2 = torch.cat((L21,L22),1) 178 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1)),1) 179 | 180 | elif self.n == 3: 181 | L11 = y_L[:,0].unsqueeze(1) 182 | L1_zeros = torch.zeros(L11.size(0),2, dtype=torch.float64, device=device) 183 | 184 | L21 = y_L[:,1].unsqueeze(1) 185 | L22 = y_L[:,2].unsqueeze(1) 186 | L2_zeros = torch.zeros(L21.size(0),1, dtype=torch.float64, device=device) 187 | 188 | L31 = y_L[:,3].unsqueeze(1) 189 | L32 = y_L[:,4].unsqueeze(1) 190 | L33 = y_L[:,5].unsqueeze(1) 191 | 192 | L1 = torch.cat((L11,L1_zeros),1) 193 | L2 = torch.cat((L21,L22,L2_zeros),1) 194 | L3 = torch.cat((L31,L32,L33),1) 195 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1)),1) 196 | 197 | elif self.n == 4: 198 | L11 = y_L[:,0].unsqueeze(1) 199 | L1_zeros = torch.zeros(L11.size(0),3, dtype=torch.float64, device=device) 200 | 201 | L21 = y_L[:,1].unsqueeze(1) 202 | L22 = y_L[:,2].unsqueeze(1) 203 | L2_zeros = torch.zeros(L21.size(0),2, dtype=torch.float64, device=device) 204 | 205 | L31 = y_L[:,3].unsqueeze(1) 206 | L32 = y_L[:,4].unsqueeze(1) 207 | L33 = y_L[:,5].unsqueeze(1) 208 | L3_zeros = torch.zeros(L31.size(0),1, dtype=torch.float64, device=device) 209 | 210 | L41 = y_L[:,6].unsqueeze(1) 211 | L42 = y_L[:,7].unsqueeze(1) 212 | L43 = y_L[:,8].unsqueeze(1) 213 | L44 = 
y_L[:,9].unsqueeze(1) 214 | 215 | L1 = torch.cat((L11,L1_zeros),1) 216 | L2 = torch.cat((L21,L22,L2_zeros),1) 217 | L3 = torch.cat((L31,L32,L33,L3_zeros),1) 218 | L4 = torch.cat((L41,L42,L43,L44),1) 219 | L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1),L4.unsqueeze(1)),1) 220 | 221 | return L 222 | 223 | def get_A(self, a): 224 | if self.env_name == "pendulum" or self.env_name == "reacher": 225 | A = a 226 | 227 | elif self.env_name == "acrobot": 228 | A = torch.cat((self.a_zeros,a),1) 229 | 230 | elif self.env_name == "cartpole" or self.env_name == "cart2pole": 231 | A = torch.cat((a,self.a_zeros),1) 232 | 233 | elif self.env_name == "cart3pole" or self.env_name == "acro3bot": 234 | A = torch.cat((a[:,:1],self.a_zeros,a[:,1:]),1) 235 | 236 | return A 237 | 238 | def get_L(self, q): 239 | trig_q = self.trig_transform_q(q) 240 | L = self.compute_L(trig_q) 241 | return L.sum(0), L 242 | 243 | def get_V(self, q): 244 | trig_q = self.trig_transform_q(q) 245 | y1_V = F.softplus(self.fc1_V(trig_q)) 246 | y2_V = F.softplus(self.fc2_V(y1_V)) 247 | V = self.fc3_V(y2_V).squeeze() 248 | return V.sum() 249 | 250 | def get_acc(self, q, qdot, a): 251 | dL_dq, L = jacrev(self.get_L, has_aux=True)(q) 252 | term_1 = torch.einsum('blk,bijk->bijl', L, dL_dq.permute(2,3,0,1)) 253 | dM_dq = term_1 + term_1.transpose(2,3) 254 | c = torch.einsum('bjik,bk,bj->bi', dM_dq, qdot, qdot) - 0.5 * torch.einsum('bikj,bk,bj->bi', dM_dq, qdot, qdot) 255 | Minv = torch.cholesky_inverse(L) 256 | dV_dq = 0 if self.env_name == "reacher" else jacrev(self.get_V)(q) 257 | qddot = torch.matmul(Minv,(self.get_A(a)-c-dV_dq).unsqueeze(2)).squeeze(2) 258 | return qddot 259 | 260 | def derivs(self, s, a): 261 | q, qdot = s[:,:self.n], s[:,self.n:] 262 | qddot = self.get_acc(q, qdot, a) 263 | return torch.cat((qdot,qddot),dim=1) 264 | 265 | def rk2(self, s, a): 266 | alpha = 2.0/3.0 # Ralston's method 267 | k1 = self.derivs(s, a) 268 | k2 = self.derivs(s + alpha * self.dt * k1, a) 269 | s_1 = s + self.dt * ((1.0 - 1.0/(2.0*alpha))*k1 + (1.0/(2.0*alpha))*k2) 270 | return s_1 271 | 272 | def forward(self, o, a): 273 | s_1 = self.rk2(self.inverse_trig_transform_model(o), a) 274 | o_1 = torch.cat((self.trig_transform_q(s_1[:,:self.n]),s_1[:,self.n:]),1) 275 | return o_1 276 | 277 | # Reward model 278 | class reward_model_FC(torch.nn.Module): 279 | def __init__(self, obs_size): 280 | super(reward_model_FC, self).__init__() 281 | self.fc1 = torch.nn.Linear(obs_size, 64) 282 | self.fc2 = torch.nn.Linear(64, 64) 283 | self.fc3 = torch.nn.Linear(64, 1) 284 | 285 | def forward(self, x): 286 | y1 = F.relu(self.fc1(x)) 287 | y2 = F.relu(self.fc2(y1)) 288 | y = self.fc3(y2).view(-1) 289 | return y 290 | -------------------------------------------------------------------------------- /models/sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | torch.set_default_dtype(torch.float64) 4 | import torch.nn.functional as F 5 | from torch.distributions.normal import Normal 6 | from collections import deque 7 | import random 8 | 9 | # Experience replay buffer 10 | class ReplayBuffer: 11 | def __init__(self, capacity, device): 12 | self.buffer = deque(maxlen=capacity) 13 | self.device = device 14 | 15 | def push(self, o, a, r, o_1): 16 | self.buffer.append((o, a, r, o_1)) 17 | 18 | def sample(self, batch_size): 19 | O, A, R, O_1 = zip(*random.sample(self.buffer, batch_size)) 20 | return torch.tensor(np.array(O), dtype=torch.float64, device=self.device),\ 21 | 
torch.tensor(np.array(A), dtype=torch.float64, device=self.device),\ 22 | torch.tensor(np.array(R), dtype=torch.float64, device=self.device),\ 23 | torch.tensor(np.array(O_1), dtype=torch.float64, device=self.device) 24 | 25 | def __len__(self): 26 | return len(self.buffer) 27 | 28 | # Critic network 29 | class Q_FC(torch.nn.Module): 30 | def __init__(self, obs_size, action_size): 31 | super(Q_FC, self).__init__() 32 | self.fc1 = torch.nn.Linear(obs_size+action_size, 256) 33 | self.fc2 = torch.nn.Linear(256, 256) 34 | self.fc3 = torch.nn.Linear(256, 1) 35 | 36 | def forward(self, x, a): 37 | y1 = F.relu(self.fc1(torch.cat((x,a),1))) 38 | y2 = F.relu(self.fc2(y1)) 39 | y = self.fc3(y2).view(-1) 40 | return y 41 | 42 | # Actor network 43 | class Pi_FC(torch.nn.Module): 44 | def __init__(self, obs_size, action_size): 45 | super(Pi_FC, self).__init__() 46 | self.fc1 = torch.nn.Linear(obs_size, 256) 47 | self.fc2 = torch.nn.Linear(256, 256) 48 | self.mu = torch.nn.Linear(256, action_size) 49 | self.log_sigma = torch.nn.Linear(256, action_size) 50 | 51 | def forward(self, x, deterministic=False, with_logprob=False): 52 | y1 = F.relu(self.fc1(x)) 53 | y2 = F.relu(self.fc2(y1)) 54 | mu = self.mu(y2) 55 | 56 | if deterministic: 57 | # used for evaluating policy 58 | action = torch.tanh(mu) 59 | log_prob = None 60 | else: 61 | log_sigma = self.log_sigma(y2) 62 | log_sigma = torch.clamp(log_sigma,min=-20.0,max=2.0) 63 | sigma = torch.exp(log_sigma) 64 | dist = Normal(mu, sigma) 65 | x_t = dist.rsample() 66 | if with_logprob: 67 | log_prob = dist.log_prob(x_t).sum(1) 68 | log_prob -= (2*(np.log(2) - x_t - F.softplus(-2*x_t))).sum(1) 69 | else: 70 | log_prob = None 71 | action = torch.tanh(x_t) 72 | 73 | return action, log_prob 74 | -------------------------------------------------------------------------------- /sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from copy import deepcopy 4 | import random 5 | import numpy as np 6 | import torch 7 | torch.set_default_dtype(torch.float64) 8 | from torch.utils.tensorboard import SummaryWriter 9 | from env.utils import make_env 10 | from models.sac import ReplayBuffer, Q_FC, Pi_FC 11 | 12 | # Soft Actor-Critic algorithm 13 | class SAC: 14 | def __init__(self, arglist): 15 | self.arglist = arglist 16 | 17 | random.seed(self.arglist.seed) 18 | np.random.seed(self.arglist.seed) 19 | torch.manual_seed(self.arglist.seed) 20 | 21 | self.env = make_env(self.arglist.env) 22 | self.obs_size = self.env.obs_size 23 | self.action_size = self.env.action_size 24 | 25 | self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 26 | 27 | self.actor = Pi_FC(self.obs_size,self.action_size).to(self.device) 28 | 29 | if self.arglist.mode == "train": 30 | self.critic_1 = Q_FC(self.obs_size,self.action_size).to(self.device) 31 | self.critic_target_1 = deepcopy(self.critic_1) 32 | self.critic_loss_fn_1 = torch.nn.MSELoss() 33 | 34 | self.critic_2 = Q_FC(self.obs_size,self.action_size).to(self.device) 35 | self.critic_target_2 = deepcopy(self.critic_2) 36 | self.critic_loss_fn_2 = torch.nn.MSELoss() 37 | 38 | # set target entropy to -|A| 39 | self.target_entropy = - self.action_size 40 | 41 | path = "./log/"+self.env.name+"/sac" 42 | self.exp_dir = os.path.join(path, "seed_"+str(self.arglist.seed)) 43 | self.model_dir = os.path.join(self.exp_dir, "models") 44 | self.tensorboard_dir = os.path.join(self.exp_dir, "tensorboard") 45 | 46 | if self.arglist.resume: 47 | 
checkpoint = torch.load(os.path.join(self.model_dir,"backup.ckpt")) 48 | self.start_episode = checkpoint['episode'] + 1 49 | 50 | self.actor.load_state_dict(checkpoint['actor']) 51 | self.critic_1.load_state_dict(checkpoint['critic_1']) 52 | self.critic_target_1.load_state_dict(checkpoint['critic_target_1']) 53 | self.critic_2.load_state_dict(checkpoint['critic_2']) 54 | self.critic_target_2.load_state_dict(checkpoint['critic_target_2']) 55 | self.log_alpha = torch.tensor(checkpoint['log_alpha'].item(), dtype=torch.float64, device=self.device, requires_grad=True) 56 | 57 | self.replay_buffer = checkpoint['replay_buffer'] 58 | 59 | else: 60 | self.start_episode = 0 61 | 62 | self.log_alpha = torch.tensor(np.log(0.2), dtype=torch.float64, device=self.device, requires_grad=True) 63 | 64 | self.replay_buffer = ReplayBuffer(self.arglist.replay_size, self.device) 65 | 66 | if not os.path.exists(path): 67 | os.makedirs(path) 68 | os.mkdir(self.exp_dir) 69 | os.mkdir(self.tensorboard_dir) 70 | os.mkdir(self.model_dir) 71 | 72 | for param in self.critic_target_1.parameters(): 73 | param.requires_grad = False 74 | 75 | for param in self.critic_target_2.parameters(): 76 | param.requires_grad = False 77 | 78 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.arglist.lr) 79 | self.critic_optimizer_1 = torch.optim.Adam(self.critic_1.parameters(), lr=self.arglist.lr) 80 | self.critic_optimizer_2 = torch.optim.Adam(self.critic_2.parameters(), lr=self.arglist.lr) 81 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],lr=self.arglist.lr) 82 | 83 | if self.arglist.resume: 84 | self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer']) 85 | self.critic_optimizer_1.load_state_dict(checkpoint['critic_optimizer_1']) 86 | self.critic_optimizer_2.load_state_dict(checkpoint['critic_optimizer_2']) 87 | self.log_alpha_optimizer.load_state_dict(checkpoint['log_alpha_optimizer']) 88 | 89 | print("Done loading checkpoint ...") 90 | 91 | self.train() 92 | 93 | elif self.arglist.mode == "eval": 94 | checkpoint = torch.load(self.arglist.checkpoint,map_location=self.device) 95 | self.actor.load_state_dict(checkpoint['actor']) 96 | ep_r_list = self.eval(self.arglist.episodes,self.arglist.render) 97 | 98 | def save_checkpoint(self, name): 99 | checkpoint = {'actor' : self.actor.state_dict()} 100 | torch.save(checkpoint, os.path.join(self.model_dir, name)) 101 | 102 | def save_backup(self, episode): 103 | checkpoint = {'episode' : episode,\ 104 | 'actor' : self.actor.state_dict(),\ 105 | 'actor_optimizer': self.actor_optimizer.state_dict(),\ 106 | 'critic_1' : self.critic_1.state_dict(),\ 107 | 'critic_optimizer_1': self.critic_optimizer_1.state_dict(),\ 108 | 'critic_2' : self.critic_2.state_dict(),\ 109 | 'critic_optimizer_2': self.critic_optimizer_2.state_dict(),\ 110 | 'critic_target_1' : self.critic_target_1.state_dict(),\ 111 | 'critic_target_2' : self.critic_target_2.state_dict(),\ 112 | 'log_alpha' : self.log_alpha.detach(),\ 113 | 'log_alpha_optimizer': self.log_alpha_optimizer.state_dict(),\ 114 | 'replay_buffer' : self.replay_buffer \ 115 | } 116 | torch.save(checkpoint, os.path.join(self.model_dir, "backup.ckpt")) 117 | 118 | def soft_update(self, target, source, tau): 119 | with torch.no_grad(): 120 | for target_param, param in zip(target.parameters(), source.parameters()): 121 | target_param.data.copy_((1.0 - tau) * target_param.data + tau * param.data) 122 | 123 | def train(self): 124 | writer = SummaryWriter(log_dir=self.tensorboard_dir) 125 | for episode in 
range(self.start_episode,self.arglist.episodes): 126 | o,_,_ = self.env.reset() 127 | ep_r = 0 128 | while True: 129 | if self.replay_buffer.__len__() >= self.arglist.start_steps: 130 | with torch.no_grad(): 131 | a, _ = self.actor(torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0)) 132 | a = a.cpu().numpy()[0] 133 | else: 134 | a = np.random.uniform(-1.0, 1.0, size=self.action_size) 135 | 136 | o_1, r, done = self.env.step(a) 137 | 138 | self.replay_buffer.push(o, a, r, o_1) 139 | 140 | ep_r += r 141 | o = o_1 142 | 143 | if self.replay_buffer.__len__() >= self.arglist.replay_fill: 144 | O, A, R, O_1 = self.replay_buffer.sample(self.arglist.batch_size) 145 | 146 | q_value_1 = self.critic_1(O, A) 147 | q_value_2 = self.critic_2(O, A) 148 | 149 | with torch.no_grad(): 150 | # Target actions come from *current* policy 151 | A_1, logp_A_1 = self.actor(O_1, False, True) 152 | 153 | next_q_value_1 = self.critic_target_1(O_1, A_1) 154 | next_q_value_2 = self.critic_target_2(O_1, A_1) 155 | next_q_value = torch.min(next_q_value_1, next_q_value_2) 156 | expected_q_value = R + self.arglist.gamma * (next_q_value - torch.exp(self.log_alpha) * logp_A_1) 157 | 158 | critic_loss_1 = self.critic_loss_fn_1(q_value_1, expected_q_value) 159 | self.critic_optimizer_1.zero_grad() 160 | critic_loss_1.backward() 161 | self.critic_optimizer_1.step() 162 | 163 | critic_loss_2 = self.critic_loss_fn_2(q_value_2, expected_q_value) 164 | self.critic_optimizer_2.zero_grad() 165 | critic_loss_2.backward() 166 | self.critic_optimizer_2.step() 167 | 168 | for param_1, param_2 in zip(self.critic_1.parameters(), self.critic_2.parameters()): 169 | param_1.requires_grad = False 170 | param_2.requires_grad = False 171 | 172 | A_pi, logp_A_pi = self.actor(O, False, True) 173 | q_value_pi_1 = self.critic_1(O, A_pi) 174 | q_value_pi_2 = self.critic_2(O, A_pi) 175 | q_value_pi = torch.min(q_value_pi_1, q_value_pi_2) 176 | 177 | actor_loss = - torch.mean(q_value_pi - torch.exp(self.log_alpha).detach() * logp_A_pi) 178 | self.actor_optimizer.zero_grad() 179 | actor_loss.backward() 180 | self.actor_optimizer.step() 181 | 182 | self.log_alpha_optimizer.zero_grad() 183 | alpha_loss = (torch.exp(self.log_alpha) * (-logp_A_pi - self.target_entropy).detach()).mean() 184 | alpha_loss.backward() 185 | self.log_alpha_optimizer.step() 186 | 187 | for param_1, param_2 in zip(self.critic_1.parameters(), self.critic_2.parameters()): 188 | param_1.requires_grad = True 189 | param_2.requires_grad = True 190 | 191 | self.soft_update(self.critic_target_1, self.critic_1, self.arglist.tau) 192 | self.soft_update(self.critic_target_2, self.critic_2, self.arglist.tau) 193 | 194 | if done: 195 | writer.add_scalar('ep_r', ep_r, episode) 196 | with torch.no_grad(): 197 | writer.add_scalar('alpha',torch.exp(self.log_alpha).item(),episode) 198 | if episode % self.arglist.eval_every == 0 or episode == self.arglist.episodes-1: 199 | # Evaluate agent performance 200 | eval_ep_r_list = self.eval(self.arglist.eval_over) 201 | writer.add_scalar('eval_ep_r', np.mean(eval_ep_r_list), episode) 202 | self.save_checkpoint(str(episode)+".ckpt") 203 | if (episode % 250 == 0 or episode == self.arglist.episodes-1) and episode > self.start_episode: 204 | self.save_backup(episode) 205 | break 206 | 207 | def eval(self, episodes, render=False): 208 | # Evaluate agent performance over several episodes 209 | ep_r_list = [] 210 | for episode in range(episodes): 211 | o,_,_ = self.env.reset() 212 | ep_r = 0 213 | while True: 214 | with torch.no_grad(): 215 | 
a, _ = self.actor(torch.tensor(o, dtype=torch.float64, device=self.device).unsqueeze(0),True) 216 | a = a.cpu().numpy()[0] 217 | o_1,r,done = self.env.step(a) 218 | if render: 219 | self.env.render() 220 | ep_r += r 221 | o = o_1 222 | if done: 223 | ep_r_list.append(ep_r) 224 | if render: 225 | print("Episode finished with total reward ",ep_r) 226 | break 227 | 228 | if self.arglist.mode == "eval": 229 | print("Average return :",np.mean(ep_r_list)) 230 | 231 | return ep_r_list 232 | 233 | def parse_args(): 234 | parser = argparse.ArgumentParser("SAC") 235 | # Common settings 236 | parser.add_argument("--env", type=str, default="cart3pole", help="pendulum / reacher / cartpole / acrobot / cart2pole / acro3bot / cart3pole") 237 | parser.add_argument("--mode", type=str, default="train", help="train or eval") 238 | parser.add_argument("--episodes", type=int, default=10000, help="number of episodes") 239 | parser.add_argument("--seed", type=int, default=0, help="seed") 240 | # Core training parameters 241 | parser.add_argument("--resume", action="store_true", default=False, help="resume training") 242 | parser.add_argument("--lr", type=float, default=3e-4, help="actor, critic learning rate") 243 | parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") 244 | parser.add_argument("--batch-size", type=int, default=256, help="batch size") 245 | parser.add_argument("--tau", type=float, default=0.005, help="soft target update parameter") 246 | parser.add_argument("--start-steps", type=int, default=int(1e4), help="start steps") 247 | parser.add_argument("--replay-size", type=int, default=int(1e6), help="replay buffer size") 248 | parser.add_argument("--replay-fill", type=int, default=int(1e4), help="elements in replay buffer before training starts") 249 | parser.add_argument("--eval-every", type=int, default=50, help="eval every _ episodes during training") 250 | parser.add_argument("--eval-over", type=int, default=50, help="each time eval over _ episodes") 251 | # Eval settings 252 | parser.add_argument("--checkpoint", type=str, default="", help="path to checkpoint") 253 | parser.add_argument("--render", action="store_true", default=False, help="render") 254 | 255 | return parser.parse_args() 256 | 257 | if __name__ == '__main__': 258 | arglist = parse_args() 259 | sac = SAC(arglist) 260 | --------------------------------------------------------------------------------
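The SAC baseline (sac.py) is run in the same way as mbrl.py. A possible invocation, using the defaults from its argument parser, to train SAC on the cart3pole task:

    python sac.py --env cart3pole --mode train --episodes 10000 --seed 0

Logs and checkpoints would then be written under "./log/cart3pole/sac/seed_0" (assuming env.name matches the --env flag). To evaluate a saved checkpoint (the path shown is a hypothetical example from such a run), run:

    python sac.py --env cart3pole --mode eval --episodes 3 --seed 100 --checkpoint ./log/cart3pole/sac/seed_0/models/9999.ckpt --render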