├── floyd.yml
├── model_best.pth.tar
├── openaigym.video.0.55018.video000000.mp4
├── install.sh
├── README.md
└── Landing A Rocket With Simple Reinforcement Learning.ipynb

/floyd.yml:
--------------------------------------------------------------------------------
env: pytorch-0.4
machine: cpu

--------------------------------------------------------------------------------
/model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djbyrne/Landing-A-Rocket-With-Simple-Reinforcement-Learning/HEAD/model_best.pth.tar

--------------------------------------------------------------------------------
/openaigym.video.0.55018.video000000.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djbyrne/Landing-A-Rocket-With-Simple-Reinforcement-Learning/HEAD/openaigym.video.0.55018.video000000.mp4

--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
echo "Installing deps..."

# Install packages for running gym-retro
#apt-get update && apt-get install -y lua5.1 libav-tools

# Ubuntu
apt-get update && apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig

# OSX
# brew install cmake boost boost-python sdl2 swig wget

pip install 'gym[all]'

pip install box2d box2d-kengz

echo "- Done!"

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Landing A Rocket With Simple Reinforcement Learning

This repo gives an example of using a simple reinforcement learning method to beat OpenAI Gym's LunarLander-v2 environment. The agent combines the cross-entropy method (CEM) with a small neural network, built using the PyTorch library.
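
At a high level, the agent repeats three steps until the environment is solved. The sketch below is illustrative rather than runnable as-is: `generate_batch` and `filter_batch` are the functions defined in the notebook, and `train_step` is a hypothetical stand-in for the inline optimiser update.

```python
# one run of the cross-entropy method (illustrative sketch)
for iteration in range(n_iterations):
    # 1. play a batch of episodes with the current policy network
    states, actions, rewards = generate_batch(env, batch_size)

    # 2. keep only transitions from "elite" episodes above a reward percentile
    elite_states, elite_actions = filter_batch(states, actions, rewards, percentile)

    # 3. train the network to imitate the elite actions (cross-entropy loss)
    train_step(net, elite_states, elite_actions)
```
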
bash install.sh" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 43, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import gym\n", 30 | "from collections import namedtuple\n", 31 | "import numpy as np\n", 32 | "\n", 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "import torch.optim as optim\n", 36 | "import torch.nn.functional as F" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Neural Network" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 44, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "class Net(nn.Module):\n", 53 | " def __init__(self, obs_size, hidden_size, n_actions):\n", 54 | " super(Net, self).__init__()\n", 55 | " self.fc1 = nn.Linear(obs_size, hidden_size)\n", 56 | " self.fc2 = nn.Linear(hidden_size, n_actions)\n", 57 | " \n", 58 | " def forward(self, x):\n", 59 | " x = F.relu(self.fc1(x))\n", 60 | " return self.fc2(x)\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "# Generate Sessions" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 45, 73 | "metadata": { 74 | "collapsed": true 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "def generate_batch(env,batch_size, t_max=5000):\n", 79 | " \n", 80 | " activation = nn.Softmax(dim=1)\n", 81 | " batch_actions,batch_states, batch_rewards = [],[],[]\n", 82 | " \n", 83 | " for b in range(batch_size):\n", 84 | " states,actions = [],[]\n", 85 | " total_reward = 0\n", 86 | " s = env.reset()\n", 87 | " for t in range(t_max):\n", 88 | " \n", 89 | " s_v = torch.FloatTensor([s])\n", 90 | " act_probs_v = activation(net(s_v))\n", 91 | " act_probs = act_probs_v.data.numpy()[0]\n", 92 | " a = np.random.choice(len(act_probs), p=act_probs)\n", 93 | "\n", 94 | " new_s, r, done, info = env.step(a)\n", 95 | "\n", 96 | " #record sessions like you did before\n", 97 | " states.append(s)\n", 98 | " actions.append(a)\n", 99 | " total_reward += r\n", 100 | "\n", 101 | " s = new_s\n", 102 | " if done:\n", 103 | " batch_actions.append(actions)\n", 104 | " batch_states.append(states)\n", 105 | " batch_rewards.append(total_reward)\n", 106 | " break\n", 107 | " \n", 108 | " return batch_states, batch_actions, batch_rewards" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "# Filter Elite Episodes" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 46, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "def filter_batch(states_batch,actions_batch,rewards_batch,percentile=50):\n", 127 | " \n", 128 | " reward_threshold = np.percentile(rewards_batch, percentile)\n", 129 | " \n", 130 | " elite_states = []\n", 131 | " elite_actions = []\n", 132 | " \n", 133 | " \n", 134 | " for i in range(len(rewards_batch)):\n", 135 | " if rewards_batch[i] > reward_threshold:\n", 136 | " for j in range(len(states_batch[i])):\n", 137 | " elite_states.append(states_batch[i][j])\n", 138 | " elite_actions.append(actions_batch[i][j])\n", 139 | " \n", 140 | " return elite_states,elite_actions\n", 141 | " " 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# Carry Out Training" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 47, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | 
"text": [ 160 | "[2018-07-27 14:42:26,697] Making new env: LunarLander-v2\n" 161 | ] 162 | }, 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "0: loss=1.384, reward_mean=-218.9, reward_threshold=-147.4\n", 168 | "1: loss=1.360, reward_mean=-240.0, reward_threshold=-137.2\n", 169 | "2: loss=1.357, reward_mean=-215.3, reward_threshold=-129.1\n" 170 | ] 171 | }, 172 | { 173 | "ename": "KeyboardInterrupt", 174 | "evalue": "", 175 | "output_type": "error", 176 | "traceback": [ 177 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 178 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 179 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m#generate new sessions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mbatch_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_actions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_rewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgenerate_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_max\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0melite_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0melite_actions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfilter_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_actions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_rewards\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mpercentile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 180 | "\u001b[0;32m\u001b[0m in \u001b[0;36mgenerate_batch\u001b[0;34m(env, batch_size, t_max)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0ms_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mact_probs_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactivation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms_v\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mact_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mact_probs_v\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0ma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mact_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mact_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 181 | "\u001b[0;32m~/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 182 | "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 183 | "\u001b[0;32m~/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 184 | "\u001b[0;32m~/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 185 | "\u001b[0;32m~/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 990\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 991\u001b[0m \u001b[0;31m# fused op is marginally faster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 992\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 993\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 994\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 186 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "batch_size = 100\n", 192 | "session_size = 500\n", 193 | "percentile = 80\n", 194 | "hidden_size = 200\n", 195 | 
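    "#LunarLander-v2 counts as solved at an average reward of 200 over 100 episodes\n",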
"completion_score = 200\n", 196 | "learning_rate = 0.01\n", 197 | "\n", 198 | "env = gym.make(\"LunarLander-v2\")\n", 199 | "n_states = env.observation_space.shape[0]\n", 200 | "n_actions = env.action_space.n\n", 201 | "\n", 202 | "#neural network\n", 203 | "net = Net(n_states, hidden_size, n_actions)\n", 204 | "#loss function\n", 205 | "objective = nn.CrossEntropyLoss()\n", 206 | "#optimisation function\n", 207 | "optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)\n", 208 | "\n", 209 | "for i in range(session_size):\n", 210 | " #generate new sessions\n", 211 | " batch_states,batch_actions,batch_rewards = generate_batch(env, batch_size, t_max=5000)\n", 212 | "\n", 213 | " elite_states, elite_actions = filter_batch(batch_states,batch_actions,batch_rewards,percentile)\n", 214 | " \n", 215 | " optimizer.zero_grad()\n", 216 | " tensor_states = torch.FloatTensor(elite_states)\n", 217 | " tensor_actions = torch.LongTensor(elite_actions)\n", 218 | " action_scores_v = net(tensor_states)\n", 219 | " loss_v = objective(action_scores_v, tensor_actions)\n", 220 | " loss_v.backward()\n", 221 | " optimizer.step()\n", 222 | "\n", 223 | " #show results\n", 224 | " mean_reward, threshold = np.mean(batch_rewards), np.percentile(batch_rewards, percentile)\n", 225 | " print(\"%d: loss=%.3f, reward_mean=%.1f, reward_threshold=%.1f\" % (\n", 226 | " i, loss_v.item(), mean_reward, threshold))\n", 227 | " \n", 228 | " #check if \n", 229 | " if np.mean(batch_rewards)> completion_score:\n", 230 | " print(\"Environment has been successfullly completed!\")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "# Results" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 15, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stderr", 247 | "output_type": "stream", 248 | "text": [ 249 | "[2018-07-25 21:56:26,476] Making new env: LunarLander-v2\n", 250 | "[2018-07-25 21:56:26,490] Clearing 4 monitor files from previous run (because force=True was provided)\n", 251 | "[2018-07-25 21:56:26,496] Starting new video recorder writing to /Users/donalbyrne/Workspace/DeeplyLearnDeepLearning/videos/openaigym.video.0.55018.video000000.mp4\n", 252 | "[2018-07-25 21:56:31,936] Finished writing results. You can upload them to the scoreboard via gym.upload('/Users/donalbyrne/Workspace/DeeplyLearnDeepLearning/videos')\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "#record sessions\n", 258 | "import gym.wrappers\n", 259 | "env = gym.wrappers.Monitor(gym.make(\"LunarLander-v2\"), directory=\"videos\", force=True)\n", 260 | "generate_batcXh(env, 1, t_max=5000)\n", 261 | "env.close()" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": { 267 | "collapsed": true 268 | }, 269 | "source": [ 270 | "# Save our model" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 17, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stderr", 280 | "output_type": "stream", 281 | "text": [ 282 | "/Users/donalbyrne/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.\n", 283 | " \"type \" + obj.__name__ + \". 
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Save our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/donalbyrne/anaconda/envs/DQN_35/lib/python3.5/site-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "# save the model\n",
    "torch.save(net, 'model_best.pth.tar')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------