├── images ├── A2C.png ├── Reinforce.png └── Actor-Critic.png ├── README.md ├── model.py ├── 3-A2C.py ├── 3.A2C-multiple_action.ipynb └── 2.Actor-Critic.ipynb /images/A2C.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiPatil/Policy-based-RL/HEAD/images/A2C.png -------------------------------------------------------------------------------- /images/Reinforce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiPatil/Policy-based-RL/HEAD/images/Reinforce.png -------------------------------------------------------------------------------- /images/Actor-Critic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiPatil/Policy-based-RL/HEAD/images/Actor-Critic.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Policy Based RL 2 | 3 | ### Algorithms Implemented 4 | - [x] REINFORCE 5 | - [x] Actor-Critic 6 | - [x] Advantage Actor-Critic (A2C) 7 | - [ ] Asynchronous Advantage Actor-Critic (A3C) 8 | 9 | ### Comparison 10 | ##### Reinforce 11 | ![](https://github.com/HiPatil/Policy-based-RL/blob/master/images/Reinforce.png) 12 | 13 | ##### Actor-Critic 14 | ![](https://github.com/HiPatil/Policy-based-RL/blob/master/images/Actor-Critic.png) 15 | 16 | ##### A2C 17 | ![](https://github.com/HiPatil/Policy-based-RL/blob/master/images/A2C.png) 18 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.layers as layers 3 | from tensorflow.keras import optimizers, Model 4 | import gym 5 | import mujoco_py 6 | import numpy as np 7 | import tensorflow_probability as tfp 8 | 9 | gpu_devices = tf.config.experimental.list_physical_devices('GPU') 10 | tf.config.experimental.set_memory_growth(gpu_devices[0], True) 11 | 12 | 13 | print(tf.config.list_physical_devices('GPU')) 14 | env = gym.make('InvertedPendulum-v2') 15 | # env.reset() 16 | # for _ in range(100): 17 | # env.render() 18 | # env.step(env.action_space.sample()) 19 | # env.close() 20 | 21 | N_ACT = env.action_space.shape[0] 22 | N_OBS = env.observation_space.shape[0] 23 | print(N_OBS, N_ACT) 24 | 25 | class reinforce(Model): 26 | def __init__(self, obs, act): 27 | super(reinforce, self).__init__() 28 | 29 | self.l1 = layers.Dense(units = 32, activation = 'relu') 30 | self.l2 = layers.Dense(units = 16, activation = 'relu') 31 | self.l3 = layers.Dense(units = act, activation = None) 32 | 33 | def call(self, x): 34 | layer1 = self.l1(x) 35 | layer2 = self.l2(layer1) 36 | mean = self.l3(layer2) 37 | 38 | return mean 39 | 40 | policy = reinforce(N_OBS, N_ACT) 41 | optimizer = optimizers.Adam(lr = 0.01) 42 | 43 | def choose_action(state, std = 0.0): 44 | m = policy(tf.convert_to_tensor(np.expand_dims(state, axis = 0), dtype=tf.float32)) 45 | dist = tfp.distributions.Normal(loc = m, scale = std) 46 | action = dist.sample() 47 | log_prob = dist.log_prob(action) 48 | 49 | return action, log_prob 50 | 51 | 52 | state = env.reset() 53 | action, log_prob = choose_action(state) 54 | print(action) 55 | -------------------------------------------------------------------------------- /3-A2C.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import random 4 | import gym 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | env = gym.make('InvertedPendulum-v2') 13 | 14 | # import mujoco_py 15 | 16 | 17 | test_env = False 18 | if test_env: 19 | env.reset() 20 | for _ in range(1000): 21 | env.render() 22 | env.step(env.action_space.sample()) # take random actions 23 | env.close() 24 | 25 | print('Observation Shape:', env.observation_space.shape, '\nAction Shape:', env.action_space) 26 | 27 | 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | print("Code is running on:", device) 30 | 31 | ############ PARAMETERS #################### 32 | 33 | N_OBS = env.observation_space.shape[0] 34 | N_ACT = env.action_space.shape[0] 35 | N_EPISODE = 1500 36 | LEARNING_RATE = 0.001 37 | DISCOUNT = 0.99 38 | 39 | ############### Network for A2C ####################3 40 | class ACNet(nn.Module): 41 | def __init__(self, observations, actions): 42 | super(ACNet, self).__init__() 43 | self.actor = nn.Sequential( 44 | nn.Linear(observations, 32), 45 | nn.ReLU(), 46 | nn.Linear(32, 16), 47 | nn.ReLU() 48 | ) 49 | self.mu = nn.Linear(16, actions) 50 | self.sigma = nn.Linear(16, actions) 51 | 52 | 53 | self.critic = nn.Sequential( 54 | nn.Linear(observations, 64), 55 | nn.ReLU(), 56 | nn.Linear(64, 32), 57 | nn.ReLU(), 58 | nn.Linear(32, 1) 59 | ) 60 | 61 | def forward(self, x): 62 | act = self.actor(x) 63 | mean = self.mu(act) 64 | std = F.softplus(self.sigma(act)) 65 | value = self.critic(x) 66 | # value = self.value(crt) 67 | 68 | return mean, std, value 69 | 70 | ac_network = ACNet(N_OBS, N_ACT).to(device) 71 | optimizer = optim.Adam(ac_network.parameters()) 72 | 73 | writer = SummaryWriter('run/using_tensorboard') 74 | 75 | def choose_action(state): 76 | mu, sigma, value = ac_network(state) 77 | m = torch.distributions.Normal(mu, sigma) 78 | action = m.sample() 79 | log_prob = m.log_prob(action) 80 | return action.detach().cpu().numpy(), log_prob, value 81 | 82 | def compute_returns(next_state, rewards, done, discount = DISCOUNT): 83 | next_state = torch.FloatTensor(next_state).to(device) 84 | _, _, next_q_val = ac_network(next_state) 85 | returns = [] 86 | for step in reversed(range(len(rewards))): 87 | next_q_val = rewards[step] + discount*next_q_val*(1-done[step]) 88 | returns.append(next_q_val) 89 | returns.reverse() 90 | return returns 91 | 92 | def ACupdate(log_probs, q_vals, values): 93 | optimizer.zero_grad() 94 | ac_loss = 0 95 | advantage = q_vals - values 96 | # print(-(log_probs*advantage).sum()) 97 | actor_loss = -(log_probs*advantage.detach()).mean() 98 | critic_loss = advantage.pow(2).mean() 99 | 100 | ac_loss = actor_loss+critic_loss 101 | ac_loss.backward() 102 | 103 | optimizer.step() 104 | 105 | return ac_loss.item() 106 | 107 | for i in range(1, N_EPISODE+1): 108 | ep_rewards = [] 109 | log_probs = [] 110 | done_states = [] 111 | 112 | total_reward = 0 113 | done = False 114 | values = [] 115 | state = env.reset() 116 | while not done: 117 | state = torch.FloatTensor(state).to(device) 118 | action, log_prob, value = choose_action(state) 119 | 120 | next_state, reward, done, info = env.step(action) 121 | 122 | done = torch.tensor([done], dtype = torch.float, device = device) 123 | ep_rewards.append(torch.tensor([reward], dtype = torch.float, device = device)) 124 | log_probs.append(log_prob) 125 | done_states.append(done) 126 | values.append(value) 127 | 128 | total_reward += reward 129 | state = next_state 130 | 131 | q_vals = compute_returns(next_state, ep_rewards, done_states) 132 | q_vals = torch.stack(q_vals) 133 | values = torch.stack(values) 134 | log_probs = torch.stack(log_probs) 135 | 136 | loss = ACupdate(log_probs, q_vals, values) 137 | 138 | writer.add_scalar('Attr/Training loss', loss, i) 139 | writer.add_scalar('Attr/Episode reward', total_reward, i) 140 | print('Episode Trained:', i) 141 | 142 | if i%1000 == 0: 143 | torch.save(ac_network.state_dict(), 'Models/ACNet_'+str(i)+'.pth') 144 | print('Model Saved') 145 | 146 | print('Done Training') 147 | 148 | -------------------------------------------------------------------------------- /3.A2C-multiple_action.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Advantage Actor Critic (A2C)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import time\n", 18 | "import random\n", 19 | "import gym\n", 20 | "\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.optim as optim\n", 24 | "import torch.nn.functional as F\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "from IPython.display import clear_output\n", 29 | "%matplotlib inline" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Observation Shape: (8,) \n", 42 | "Action Shape: Box(2,)\n" 43 | ] 44 | }, 45 | { 46 | "name": "stderr", 47 | "output_type": "stream", 48 | "text": [ 49 | "/home/himanshu/anaconda3/envs/rl/lib/python3.7/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", 50 | " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "env = gym.make('LunarLanderContinuous-v2')\n", 56 | "\n", 57 | "print('Observation Shape:', env.observation_space.shape, '\\nAction Shape:', env.action_space)\n", 58 | "# env.reset()\n", 59 | "# for _ in range(100):\n", 60 | "# env.render()\n", 61 | "# time.sleep(0.01)\n", 62 | "# env.step(env.action_space.sample()) # take random actions \n", 63 | "# env.close()" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### Checking for cuda device" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "cuda\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 88 | "print(device)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "### Initialization of Hyperparameters" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "LEARNING_RATE = 0.001\n", 105 | "DISCOUNT = 0.99\n", 106 | "EPS = 1\n", 107 | "EPS_DECAY = 0.9999\n", 108 | "END_EPS = 0.1\n", 109 | "\n", 110 | "N_EPISODE = 20000\n", 111 | "\n", 112 | "\n", 113 | "# Dimensions of input and output of environment\n", 114 | "obs_dim = env.observation_space.shape[0]\n", 115 | "action_dim = env.action_space.shape[0]\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "### Actor Network" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 6, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "class Actor(nn.Module):\n", 132 | " def __init__(self, observations, actions):\n", 133 | " super(Actor, self).__init__()\n", 134 | " self.actor = nn.Sequential(\n", 135 | " nn.Linear(observations, 32),\n", 136 | " nn.ReLU(),\n", 137 | " nn.Linear(32, 16),\n", 138 | " nn.ReLU()\n", 139 | " )\n", 140 | " self.l1 = nn.Linear(16, actions)\n", 141 | " self.l2 = nn.Linear(16, actions)\n", 142 | " \n", 143 | " def forward(self, x):\n", 144 | " x = self.actor(x)\n", 145 | " mean = self.l1(x)\n", 146 | " variance = F.softplus(self.l2(x))\n", 147 | " \n", 148 | " return mean, variance" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "actor = Actor(obs_dim, action_dim).to(device)\n", 158 | "optimizerA = optim.Adam(actor.parameters())" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "### Sampling of action from Normal Distribution" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 8, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "def actors_action(state):\n", 175 | " \n", 176 | " mean, variance = actor(state)\n", 177 | " \n", 178 | " m = torch.distributions.Normal(mean, torch.sqrt(variance))\n", 179 | " action = m.sample()\n", 180 | " log_prob = m.log_prob(action).sum(-1).unsqueeze(-1)\n", 181 | "# print(\"Entropy\", m.entropy())\n", 182 | " \n", 183 | " return action.detach().cpu().numpy(), log_prob" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "### Critic Network" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 9, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "class Critic(nn.Module):\n", 200 | " def __init__(self, observations):\n", 201 | " super(Critic, self).__init__()\n", 202 | " self.network = nn.Sequential(\n", 203 | " nn.Linear(observations, 64),\n", 204 | " nn.ReLU(),\n", 205 | " nn.Linear(64, 32),\n", 206 | " nn.ReLU(),\n", 207 | " nn.Linear(32, 1)\n", 208 | " )\n", 209 | " \n", 210 | " def forward(self, x):\n", 211 | " return self.network(x)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 10, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "critic = Critic(obs_dim).to(device)\n", 221 | "optimizerC = optim.Adam(critic.parameters())\n", 222 | "criterionC = nn.MSELoss().to(device)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "### Computing Q-Values" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 11, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "def compute_returns(next_state, rewards, done, discount = DISCOUNT):\n", 239 | "# q_val = critic(state)\n", 240 | "\n", 241 | " next_state = torch.FloatTensor(next_state).to(device)\n", 242 | " next_q_val = critic(next_state)\n", 243 | " returns = []\n", 244 | " for step in reversed(range(len(rewards))):\n", 245 | "\n", 246 | " next_q_val = rewards[step] + discount*next_q_val*(1-done[step])\n", 247 | " returns.append(next_q_val)\n", 248 | " \n", 249 | " returns.reverse()\n", 250 | " \n", 251 | " return returns\n", 252 | " " 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "### Updating Actor based on Q-values by critic" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "def ACupdate(log_probs, q_vals, values):\n", 269 | " optimizerA.zero_grad()\n", 270 | " optimizerC.zero_grad()\n", 271 | " \n", 272 | " actor_loss = 0\n", 273 | " critic_loss = 0\n", 274 | " \n", 275 | " advantage = q_vals - values\n", 276 | "# print(-(log_probs*advantage).sum())\n", 277 | " actor_loss = -(log_probs*advantage.detach()).mean()\n", 278 | " \n", 279 | " critic_loss = advantage.pow(2).mean()\n", 280 | " \n", 281 | "# for itr in range(len(log_probs)):\n", 282 | "# log_prob = log_probs[itr]\n", 283 | "# retr = ret[itr]\n", 284 | "# value = values[itr]\n", 285 | " \n", 286 | "# actor_loss -= torch.sum(log_prob*retr)\n", 287 | "# # print('AL', actor_loss)\n", 288 | " \n", 289 | "# # print(criterionC(value, retr))\n", 290 | "# critic_loss += criterionC(value, retr.detach())\n", 291 | " \n", 292 | " actor_loss.backward()\n", 293 | " critic_loss.backward()\n", 294 | " \n", 295 | " optimizerA.step()\n", 296 | " optimizerC.step()\n", 297 | " \n", 298 | " return (actor_loss+critic_loss).item()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 13, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "mean_avg = []\n", 308 | "losses = []\n", 309 | "def plot(n_rewards, loss):\n", 310 | " clear_output(True)\n", 311 | " plt.figure(figsize=(30,5))\n", 312 | " mean = np.mean(n_rewards[-20:])\n", 313 | " plt.subplot(131)\n", 314 | " mean_avg.append(mean)\n", 315 | " plt.title('Reward: %s' % (mean))\n", 316 | " plt.plot(n_rewards)\n", 317 | " plt.plot(mean_avg)\n", 318 | " loss = loss/1000\n", 319 | " losses.append(loss)\n", 320 | " plt.subplot(132)\n", 321 | " plt.title('loss %.2f' % loss)\n", 322 | " plt.plot(losses)\n", 323 | " plt.show()" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "### Training QAC" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": { 337 | "scrolled": false 338 | }, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "image/png": "\n", 343 | "text/plain": [ 344 | "
" 345 | ] 346 | }, 347 | "metadata": { 348 | "needs_background": "light" 349 | }, 350 | "output_type": "display_data" 351 | } 352 | ], 353 | "source": [ 354 | "n_rewards = []\n", 355 | "c = 0\n", 356 | "for i in range(1, N_EPISODE+1):\n", 357 | " ep_rewards = []\n", 358 | " log_probs = []\n", 359 | " done_states = []\n", 360 | "# returns = []\n", 361 | " total_reward = 0\n", 362 | " done = False\n", 363 | " values = []\n", 364 | " \n", 365 | " state = env.reset()\n", 366 | " ret = 0\n", 367 | " \n", 368 | " while not done:\n", 369 | " state = torch.FloatTensor(state).to(device)\n", 370 | " action, log_prob = actors_action(state)\n", 371 | "# print(state, action)\n", 372 | " value = critic(state)\n", 373 | " \n", 374 | " next_state, reward, done, _ = env.step(action)\n", 375 | " \n", 376 | " done = torch.tensor([done], dtype = torch.float, device = device)\n", 377 | " ep_rewards.append(torch.tensor([reward], dtype = torch.float, device = device))\n", 378 | " log_probs.append(log_prob)\n", 379 | " done_states.append(done)\n", 380 | " values.append(value)\n", 381 | "\n", 382 | " total_reward += reward\n", 383 | " \n", 384 | "# ret = compute_returns(next_state, reward, done, ret)\n", 385 | "# if i%25 == 0:\n", 386 | "# env.render()\n", 387 | "# time.sleep(0.02)\n", 388 | " state = next_state\n", 389 | "\n", 390 | " q_vals = compute_returns(next_state, ep_rewards, done_states)\n", 391 | " q_vals = torch.stack(q_vals)\n", 392 | " values = torch.stack(values)\n", 393 | " log_probs = torch.stack(log_probs)\n", 394 | "\n", 395 | " \n", 396 | " loss = ACupdate(log_probs, q_vals, values)\n", 397 | " \n", 398 | " n_rewards.append(total_reward)\n", 399 | " plot(n_rewards, loss)\n", 400 | " \n", 401 | "# if np.mean(n_rewards[-20:]) == 1000:\n", 402 | "# torch.save(actor.state_dict(), 'adv-actor.pth')\n", 403 | "# torch.save(critic.state_dict(), 'adv-critic.pth')\n", 404 | " \n", 405 | "# c += 1\n", 406 | " \n", 407 | "# print(\"Model Saved\")" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "state = env.reset()\n", 417 | "\n", 418 | "\n", 419 | "total_reward = 0\n", 420 | "done = False\n", 421 | "\n", 422 | "actor.load_state_dict(torch.load('/home/himanshu/RL/Policy-based-RL/adv-actor.pth'))\n", 423 | "critic.load_state_dict(torch.load('/home/himanshu/RL/Policy-based-RL/adv-critic.pth'))\n", 424 | "print(\"Model Loaded\")\n", 425 | "actor.to(device)\n", 426 | "critic.to(device)\n", 427 | "\n", 428 | "while not done:\n", 429 | " state = torch.FloatTensor(state).to(device)\n", 430 | " action, log_prob = actors_action(state)\n", 431 | " \n", 432 | " next_state, reward, done, _ = env.step(action)\n", 433 | " total_reward += reward\n", 434 | " \n", 435 | " env.render()\n", 436 | " time.sleep(0.01)\n", 437 | " \n", 438 | " state = next_state\n", 439 | " \n", 440 | "print('Duration till which pole is balanced: ', total_reward)\n", 441 | "\n", 442 | "env.close()" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [] 451 | } 452 | ], 453 | "metadata": { 454 | "kernelspec": { 455 | "display_name": "Python 3", 456 | "language": "python", 457 | "name": "python3" 458 | }, 459 | "language_info": { 460 | "codemirror_mode": { 461 | "name": "ipython", 462 | "version": 3 463 | }, 464 | "file_extension": ".py", 465 | "mimetype": "text/x-python", 466 | "name": "python", 467 | "nbconvert_exporter": "python", 468 | "pygments_lexer": "ipython3", 469 | "version": "3.7.7" 470 | } 471 | }, 472 | "nbformat": 4, 473 | "nbformat_minor": 4 474 | } 475 | -------------------------------------------------------------------------------- /2.Actor-Critic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Action Value Actor Critic (QAC)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import time\n", 18 | "import random\n", 19 | "import gym\n", 20 | "\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.optim as optim\n", 24 | "import torch.nn.functional as F\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "from IPython.display import clear_output\n", 29 | "%matplotlib inline" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Observation Shape: (4,) \n", 42 | "Action Shape: Box(1,)\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "env = gym.make('InvertedPendulum-v2')\n", 48 | "\n", 49 | "print('Observation Shape:', env.observation_space.shape, '\\nAction Shape:', env.action_space)\n", 50 | "# env.reset()\n", 51 | "# for _ in range(100):\n", 52 | "# env.render()\n", 53 | "# time.sleep(0.01)\n", 54 | "# env.step(env.action_space.sample()) # take random actions \n", 55 | "# env.close()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "### Checking for cuda device" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "cuda\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 80 | "print(device)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### Initialization of Hyperparameters" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "1\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "LEARNING_RATE = 0.001\n", 105 | "DISCOUNT = 0.99\n", 106 | "EPS = 1\n", 107 | "EPS_DECAY = 0.9999\n", 108 | "END_EPS = 0.1\n", 109 | "\n", 110 | "N_EPISODE = 1500\n", 111 | "\n", 112 | "\n", 113 | "# Dimensions of input and output of environment\n", 114 | "obs_dim = env.observation_space.shape[0]\n", 115 | "action_dim = env.action_space.shape[0]\n", 116 | "print(action_dim)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "### Actor Network" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "class Actor(nn.Module):\n", 133 | " def __init__(self, observations, actions):\n", 134 | " super(Actor, self).__init__()\n", 135 | " self.actor = nn.Sequential(\n", 136 | " nn.Linear(observations, 32),\n", 137 | " nn.ReLU(),\n", 138 | " nn.Linear(32, 16),\n", 139 | " nn.ReLU()\n", 140 | " )\n", 141 | " self.l1 = nn.Linear(16, actions)\n", 142 | " self.l2 = nn.Linear(16, actions)\n", 143 | " \n", 144 | " def forward(self, x):\n", 145 | " x = self.actor(x)\n", 146 | " mean = self.l1(x)\n", 147 | " variance = F.softplus(self.l2(x))\n", 148 | " \n", 149 | " return mean, variance" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 6, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "actor = Actor(obs_dim, action_dim).to(device)\n", 159 | "optimizerA = optim.Adam(actor.parameters(), lr = LEARNING_RATE)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "### Sampling of action from Normal Distribution" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 7, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "def actors_action(state):\n", 176 | " \n", 177 | " mean, variance = actor(state)\n", 178 | " \n", 179 | " dist = torch.distributions.Normal(mean, torch.sqrt(variance))\n", 180 | " action = dist.sample()\n", 181 | " log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)\n", 182 | " \n", 183 | " return action.detach().cpu().numpy(), log_prob" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "### Critic Network" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 8, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "class Critic(nn.Module):\n", 200 | " def __init__(self, observations):\n", 201 | " super(Critic, self).__init__()\n", 202 | " self.network = nn.Sequential(\n", 203 | " nn.Linear(observations, 64),\n", 204 | " nn.ReLU(),\n", 205 | " nn.Linear(64, 32),\n", 206 | " nn.ReLU(),\n", 207 | " nn.Linear(32, 1)\n", 208 | " )\n", 209 | " \n", 210 | " def forward(self, x):\n", 211 | " return self.network(x)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 9, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "critic = Critic(obs_dim).to(device)\n", 221 | "optimizerC = optim.Adam(critic.parameters(), lr = LEARNING_RATE)\n", 222 | "criterionC = nn.MSELoss().to(device)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "### Computing Q-Values" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 10, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "def compute_returns(next_state, rewards, done, discount = DISCOUNT):\n", 239 | "# q_val = critic(state)\n", 240 | "\n", 241 | " next_state = torch.FloatTensor(next_state).to(device)\n", 242 | " next_q_val = critic(next_state)\n", 243 | " returns = []\n", 244 | " \n", 245 | " for step in reversed(range(len(rewards))):\n", 246 | " next_q_val = rewards[step] + discount*next_q_val*(1-done[step])\n", 247 | " returns.append(next_q_val)\n", 248 | " \n", 249 | " returns.reverse()\n", 250 | " return returns\n", 251 | " " 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "### Updating Actor based on Q-values by critic" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 11, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "def ACupdate(log_probs, ret, values):\n", 268 | " optimizerA.zero_grad()\n", 269 | " optimizerC.zero_grad()\n", 270 | " \n", 271 | " actor_loss = 0\n", 272 | " critic_loss = 0\n", 273 | "\n", 274 | " for log_prob, value, retr in zip(log_probs, values, ret):\n", 275 | " \n", 276 | " actor_loss -= log_prob*retr\n", 277 | "# print('AL', actor_loss)\n", 278 | " \n", 279 | "# print(criterionC(value, retr))\n", 280 | " critic_loss += criterionC(retr.detach(), value)\n", 281 | " \n", 282 | " actor_loss.backward()\n", 283 | " critic_loss.backward()\n", 284 | " \n", 285 | " optimizerA.step()\n", 286 | " optimizerC.step()\n", 287 | " \n", 288 | " return (actor_loss+critic_loss).item()" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 12, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "mean_avg = []\n", 298 | "\n", 299 | "def plot(n_rewards, loss):\n", 300 | " clear_output(True)\n", 301 | " plt.figure(figsize=(30,6))\n", 302 | " mean = np.mean(n_rewards[-10:])\n", 303 | " plt.subplot(131)\n", 304 | " mean_avg.append(mean)\n", 305 | " plt.title('Reward: %s' % (mean))\n", 306 | " plt.plot(n_rewards)\n", 307 | " plt.plot(mean_avg)\n", 308 | " \n", 309 | " l = loss[-1]\n", 310 | " plt.subplot(132)\n", 311 | " plt.title('loss %.2f' %l)\n", 312 | " plt.plot(loss)\n", 313 | " plt.show()" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### Training QAC" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 13, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "image/png": "\n", 331 | "text/plain": [ 332 | "
" 333 | ] 334 | }, 335 | "metadata": { 336 | "needs_background": "light" 337 | }, 338 | "output_type": "display_data" 339 | } 340 | ], 341 | "source": [ 342 | "n_rewards = []\n", 343 | "losses = []\n", 344 | "\n", 345 | "for i in range(1, N_EPISODE+1):\n", 346 | " ep_rewards = []\n", 347 | " log_probs = []\n", 348 | " done_states = []\n", 349 | "# returns = []\n", 350 | " total_reward = 0\n", 351 | " done = False\n", 352 | " values = []\n", 353 | " \n", 354 | " state = env.reset()\n", 355 | " ret = 0\n", 356 | " \n", 357 | " while not done:\n", 358 | " state = torch.FloatTensor(state).to(device)\n", 359 | " action, log_prob = actors_action(state)\n", 360 | " \n", 361 | " value = critic(state)\n", 362 | " next_state, reward, done, _ = env.step(action)\n", 363 | " \n", 364 | " \n", 365 | " done = torch.tensor([done], dtype = torch.float, device = device)\n", 366 | " \n", 367 | " ep_rewards.append(torch.tensor([reward], dtype = torch.float, device = device))\n", 368 | " log_probs.append(log_prob)\n", 369 | " done_states.append(done)\n", 370 | " values.append(value)\n", 371 | "\n", 372 | " total_reward += reward\n", 373 | "# ret = compute_returns(next_state, reward, done, ret)\n", 374 | "# if i%5 == 0:\n", 375 | "# env.render()\n", 376 | " state = next_state\n", 377 | " \n", 378 | " q_vals = compute_returns(next_state, ep_rewards, done_states)\n", 379 | " \n", 380 | " loss = ACupdate(log_probs, q_vals, values)\n", 381 | " \n", 382 | " n_rewards.append(total_reward)\n", 383 | " losses.append(loss)\n", 384 | " plot(n_rewards, losses)\n", 385 | " \n", 386 | " if np.mean(n_rewards[-10:]) == 1000:\n", 387 | " torch.save(actor.state_dict(), 'actor.pth')\n", 388 | " torch.save(critic.state_dict(), 'critic.pth')\n", 389 | " print(\"Model Saved\")" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 14, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "ename": "FileNotFoundError", 399 | "evalue": "[Errno 2] No such file or directory: '/home/himanshu/1.RL/Policy-based-RL/actor.pth'", 400 | "output_type": "error", 401 | "traceback": [ 402 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 403 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 404 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mactor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/himanshu/1.RL/Policy-based-RL/actor.pth'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mcritic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/himanshu/1.RL/Policy-based-RL/critic.pth'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Model Loaded\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 405 | "\u001b[0;32m~/anaconda3/envs/rl/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 582\u001b[0m \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 583\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 584\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 585\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 586\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0m_open_zipfile_reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 406 | "\u001b[0;32m~/anaconda3/envs/rl/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 407 | "\u001b[0;32m~/anaconda3/envs/rl/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 408 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/himanshu/1.RL/Policy-based-RL/actor.pth'" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "state = env.reset()\n", 414 | "\n", 415 | "\n", 416 | "total_reward = 0\n", 417 | "done = False\n", 418 | "\n", 419 | "actor.load_state_dict(torch.load('/home/himanshu/1.RL/Policy-based-RL/actor.pth'))\n", 420 | "critic.load_state_dict(torch.load('/home/himanshu/1.RL/Policy-based-RL/critic.pth'))\n", 421 | "print(\"Model Loaded\")\n", 422 | "actor.to(device)\n", 423 | "critic.to(device)\n", 424 | "\n", 425 | "while not done:\n", 426 | " state = torch.FloatTensor(state).to(device)\n", 427 | " action, log_prob = actors_action(state)\n", 428 | " \n", 429 | " next_state, reward, done, _ = env.step(action)\n", 430 | " total_reward += reward\n", 431 | " \n", 432 | " env.render()\n", 433 | " time.sleep(0.01)\n", 434 | " \n", 435 | " state = next_state\n", 436 | " \n", 437 | "print('Duration till which pole is balanced: ', total_reward)\n", 438 | "\n", 439 | "env.close()" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [] 448 | } 449 | ], 450 | "metadata": { 451 | "kernelspec": { 452 | "display_name": "Python 3", 453 | "language": "python", 454 | "name": "python3" 455 | }, 456 | "language_info": { 457 | "codemirror_mode": { 458 | "name": "ipython", 459 | "version": 3 460 | }, 461 | "file_extension": ".py", 462 | "mimetype": "text/x-python", 463 | "name": "python", 464 | "nbconvert_exporter": "python", 465 | "pygments_lexer": "ipython3", 466 | "version": "3.8.1" 467 | } 468 | }, 469 | "nbformat": 4, 470 | "nbformat_minor": 4 471 | } 472 | --------------------------------------------------------------------------------