├── .ipynb_checkpoints
├── Four Observations-checkpoint.ipynb
├── LSTM, BPTT=8-checkpoint.ipynb
├── MDP_Size_9-checkpoint.ipynb
├── Single Observation-checkpoint.ipynb
└── Two Observations-checkpoint.ipynb
├── Four Observations.ipynb
├── LSTM, BPTT=8.ipynb
├── MDP_Size_9.ipynb
├── README.md
├── Single Observation.ipynb
├── Two Observations.ipynb
├── __pycache__
├── gridworld.cpython-35.pyc
├── gridworld.cpython-36.pyc
└── helper.cpython-36.pyc
├── data
├── .ipynb_checkpoints
│ └── Check Performance-checkpoint.ipynb
├── Check Performance.ipynb
├── FOUR_OBSERV_NINE.pkl
├── FOUR_OBSERV_NINE_WEIGHTS.torch
├── GIFs
│ ├── LSTM_SIZE_9.gif
│ ├── LSTM_SIZE_9_frames.gif
│ ├── LSTM_SIZE_9_local.gif
│ ├── MDP_SIZE_9.gif
│ ├── SINGEL_SIZE_9_local.gif
│ ├── SINGLE_OBSERV_9.gif
│ ├── SINGLE_SIZE_9_frames.gif
│ └── perf.png
├── LSTM_POMDP_V4.pkl
├── LSTM_POMDP_V4_TEST.pkl
├── LSTM_POMDP_V4_WEIGHTS.torch
├── MDP_ENV_SIZE_NINE.pkl
├── MDP_ENV_SIZE_NINE_WEIGHTS.torch
├── SINGLE_OBSERV_NINE.pkl
├── SINGLE_OBSERV_NINE_WEIGHTS.torch
├── TWO_OBSERV_NINE.pkl
├── TWO_OBSERV_NINE_WEIGHTS.torch
├── algo.png
├── download (1).png
└── download.png
└── gridworld.py
/.ipynb_checkpoints/Four Observations-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from gridworld import gameEnv\n",
12 | "import numpy as np\n",
13 | "%matplotlib inline\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "from collections import deque\n",
16 | "import pickle\n",
17 | "from skimage.color import rgb2gray\n",
18 | "import random\n",
19 | "import torch\n",
20 | "import torch.nn as nn"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
  27 |       " Define Environment Object "
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "data": {
37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB6
5NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k
9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n",
38 | "text/plain": [
39 | ""
40 | ]
41 | },
42 | "metadata": {},
43 | "output_type": "display_data"
44 | }
45 | ],
46 | "source": [
47 | "env = gameEnv(partial=True,size=9)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "execution_count": 3,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | },
65 | {
66 | "data": {
67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLlJREFUeJzt3X/oXfV9x/Hna4nW1m41cS5kxs2UiiIDowlOsYxNzWZd0f1RRCmjDMF/uk3XQqvbH6WwP1oYbf1jFETbyXD+qNU1hGLn0pQyGKlff6zVRJtoY01QEzudnYNtad/745xsX0OynG++997v9/h5PuBy7znn3pzP4fC659yT832/U1VIassvLPUAJM2ewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSvJckj1Jbp3UoCRNV070Bp4kK4AfApuBfcBjwA1VtXNyw5M0DSsX8dmLgT1V9QJAkvuAa4FjBj+JtwlqUTZu3LjUQ1jW9u7dy2uvvZbjvW8xwT8TeGne9D7gNxfx70nHNTc3t9RDWNY2bdo06H2LCf4gSW4Cbpr2eiQNt5jg7wfOmje9rp/3NlV1B3AHeKovLReLuar/GHBOkvVJTgauB7ZMZliSpumEj/hVdSjJHwPfAlYAX6mqZyY2MklTs6jf+FX1TeCbExqLpBnxzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQccNfpKvJDmQ5Ol581YneTTJ7v551XSHKWmShhzx/wa46oh5twLbquocYFs/LWkkjhv8qvou8K9HzL4WuLt/fTfwBxMel6QpOtHf+Guq6uX+9SvAmgmNR9IMLLqTTlXV/9cow0460vJzokf8V5OsBeifDxzrjVV1R1VtqqphTb0kTd2JBn8L8LH+9ceAb0xmOJJmYch/590L/DNwbpJ9SW4EPgdsTrIbuLKfljQSx/2NX1U3HGPRFRMei6QZ8c49qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFDSm+dlWR7kp1Jnklycz/fbjrSSA054h8CPllV5wOXAB9Pcj5205FGa0gnnZer6on+9U+BXcCZ2E1HGq0FNdRIcjZwIbCDgd10bKghLT+DL+4leS/wdeCWqnpz/rKqKuCo3XRsqCEtP4OCn+QkutDfU1UP9bMHd9ORtLwMuaof4C5gV1V9Yd4iu+lIIzXkN/5lwB8CP0jyVD/vz+m65zzQd9Z5EbhuOkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw1aUM29RdsIzM10jeOTo1YwkybKI77UIIMvNWhIzb1Tknwvyb/0nXQ+289fn2RHkj1J7k9y8vSHK2kShhzx/xO4vKouADYAVyW5BPg88MWq+gDwOnDj9IYpaZKGdNKpqvr3fvKk/lHA5cCD/Xw76UgjMrSu/oq+wu4B4FHgeeCNqjrUv2UfXVuto332piRzSeY4OIkhS1qsQcGvqp9V1QZgHXAxcN7QFbytk84ZJzhKSRO1oKv6VfUGsB24FDgtyeH7ANYB+yc8NklTMuSq/hlJTutfvxvYTNcxdzvwkf5tdtKRRmTInXtrgbuTrKD7onigqrYm2Qncl+QvgSfp2mxJGoEhnXS+T9ca+8j5L9D93pc0Mt65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVocPD7EttPJtnaT9tJRxqphRzxb6YrsnmYnXSkkRraUGMd8PvAnf10sJOONFpDj/hfAj4F/LyfPh076UijNaSu/oeBA1X1+Imsw
E460vIzpK7+ZcA1Sa4GTgF+CbidvpNOf9S3k440IkO65d5WVeuq6mzgeuDbVfVR7KQjjdZi/h//08Ankuyh+81vJx1pJIac6v+vqvoO8J3+tZ10pJHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGDSrEkWQv8FPgZ8ChqtqUZDVwP3A2sBe4rqpen84wJU3SQo74v1NVG6pqUz99K7Ctqs4BtvXTkkZgMaf619I10gAbakijMjT4BfxDkseT3NTPW1NVL/evXwHWTHx0kqZiaLHND1bV/iS/Ajya5Nn5C6uqktTRPth/UXRfFr+2mKFKmpRBR/yq2t8/HwAepquu+2qStQD984FjfNZOOtIyM6SF1qlJfvHwa+B3gaeBLXSNNMCGGtKoDDnVXwM83DXIZSXwd1X1SJLHgAeS3Ai8CFw3vWFKmqTjBr9vnHHBUeb/BLhiGoOSNF3euSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aOhf503ERjYyx9wsVzk+R/0bR2myPOJLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgQcFPclqSB5M8m2RXkkuTrE7yaJLd/fOqaQ9W0mQMPeLfDjxSVefRleHahZ10pNEaUmX3fcBvAXcBVNV/VdUb2ElHGq0hR/z1wEHgq0meTHJnX2bbTjrSSA0J/krgIuDLVXUh8BZHnNZXVXGMu8yT3JRkLsncwYMHFzteSRMwJPj7gH1VtaOffpDui2DBnXTOOMNWOtJycNzgV9UrwEtJzu1nXQHsxE460mgN/bPcPwHuSXIy8ALwR3RfGnbSkUZoUPCr6ilg01EW2UlHGiHv3JMaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaNKSu/rlJnpr3eDPJLXbSkcZrSLHN56pqQ1VtADYC/wE8jJ10pNFa6Kn+FcDzVfUidtKRRmuhwb8euLd/bScdaaQGB78vrX0N8LUjl9lJRxqXhRzxPwQ8UVWv9tN20pFGaiHBv4H/O80HO+lIozUo+H133M3AQ/Nmfw7YnGQ3cGU/LWkEhnbSeQs4/Yh5P8FOOtIoeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KChpbf+LMkzSZ5Ocm+SU5KsT7IjyZ4k9/dVeCWNwJAWWmcCfwpsqqrfAFbQ1df/PPDFqvoA8Dpw4zQHKmlyhp7qrwTenWQl8B7gZeBy4MF+uZ10pBEZ0jtvP/BXwI/pAv9vwOPAG1V1qH/bPuDMaQ1S0mQNOdVfRdcnbz3wq8CpwFVDV2AnHWn5GXKqfyXwo6o6WFX/TVdb/zLgtP7UH2AdsP9oH7aTjrT8DAn+j4FLkrwnSehq6e8EtgMf6d9jJx1pRIb8xt9BdxHvCeAH/WfuAD4NfCLJHrpmG3dNcZySJmhoJ53PAJ85YvYLwMUTH5GkqfPOPalBBl9qkMGXGmTwpQalqma3suQg8Bbw2sxWOn2/jNuzXL2TtgWGbc+vV9Vxb5iZafABksxV1aaZrnSK3J7l6520LTDZ7fFUX2qQwZcatBTBv2MJ1jlNbs/y9U7aFpjg9sz8N76kpeepvtSgmQY/yVVJnuvr9N06y3UvVpKzkmxPsrOvP3hzP391kkeT7O6fVy31WBciyYokTybZ2k+PtpZiktOSPJjk2SS7klw65v0zzVqXMwt+khXAXwMfAs4Hbkhy/qzWPwGHgE9W1fnAJcDH+/HfCmyrqnOAbf30mNwM7Jo3PeZaircDj1TVecAFdNs1yv0z9VqXVTWTB3Ap8K1507cBt81q/VPYnm8Am4HngLX9vLXAc0s9tgVswzq6MFwObAVCd4PIyqPts+X8AN4H/Ij+utW8+
aPcP3Sl7F4CVtP9Fe1W4PcmtX9meap/eEMOG22dviRnAxcCO4A1VfVyv+gVYM0SDetEfAn4FPDzfvp0xltLcT1wEPhq/9PlziSnMtL9U1OudenFvQVK8l7g68AtVfXm/GXVfQ2P4r9JknwYOFBVjy/1WCZkJXAR8OWqupDu1vC3ndaPbP8sqtbl8cwy+PuBs+ZNH7NO33KV5CS60N9TVQ/1s19NsrZfvhY4sFTjW6DLgGuS7AXuozvdv52BtRSXoX3AvuoqRkFXNeoixrt/FlXr8nhmGfzHgHP6q5In012o2DLD9S9KX2/wLmBXVX1h3qItdDUHYUS1B6vqtqpaV1Vn0+2Lb1fVRxlpLcWqegV4Kcm5/azDtSFHuX+Ydq3LGV+wuBr4IfA88BdLfQFlgWP/IN1p4veBp/rH1XS/i7cBu4F/BFYv9VhPYNt+G9jav34/8D1gD/A14F1LPb4FbMcGYK7fR38PrBrz/gE+CzwLPA38LfCuSe0f79yTGuTFPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb9DzfZzH2Ei/mJAAAAAElFTkSuQmCC\n",
68 | "text/plain": [
69 | ""
70 | ]
71 | },
72 | "metadata": {},
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {
99 | "collapsed": true
100 | },
101 | "outputs": [],
102 | "source": [
103 | "BATCH_SIZE = 32\n",
104 | "FREEZE_INTERVAL = 20000 # steps\n",
105 | "MEMORY_SIZE = 60000 \n",
106 | "OUTPUT_SIZE = 4\n",
107 | "TOTAL_EPISODES = 10000\n",
108 | "MAX_STEPS = 50\n",
109 | "INITIAL_EPSILON = 1.0\n",
110 | "FINAL_EPSILON = 0.1\n",
111 | "GAMMA = 0.99\n",
112 | "INPUT_IMAGE_DIM = 84\n",
113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
 120 |       " Save Dictionary Function "
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 5,
126 | "metadata": {
127 | "collapsed": true
128 | },
129 | "outputs": [],
130 | "source": [
131 | "def save_obj(obj, name ):\n",
132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | " Experience Replay "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 6,
146 | "metadata": {
147 | "collapsed": true
148 | },
149 | "outputs": [],
150 | "source": [
151 | "class Memory():\n",
152 | " \n",
153 | " def __init__(self,memsize):\n",
154 | " self.memsize = memsize\n",
155 | " self.memory = deque(maxlen=self.memsize)\n",
156 | " \n",
157 | " def add_sample(self,sample):\n",
158 | " self.memory.append(sample)\n",
159 | " \n",
160 | " def get_batch(self,size):\n",
161 | " return random.sample(self.memory,k=size)"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | " Frame Collector "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 7,
174 | "metadata": {
175 | "collapsed": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "class FrameCollector():\n",
180 | " \n",
181 | " def __init__(self,num_frames,img_dim):\n",
182 | " self.num_frames = num_frames\n",
183 | " self.img_dim = img_dim\n",
184 | " self.frames = deque(maxlen=self.num_frames)\n",
185 | " \n",
186 | " def reset(self):\n",
187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n",
188 | " for i in range(0,self.num_frames):\n",
189 | " self.frames.append(tmp)\n",
190 | " \n",
191 | " def add_frame(self,frame):\n",
192 | " self.frames.append(frame)\n",
193 | " \n",
194 | " def get_state(self):\n",
195 | " return np.array(self.frames)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | " Preprocess Images "
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 8,
208 | "metadata": {
209 | "collapsed": true
210 | },
211 | "outputs": [],
212 | "source": [
213 | "def preprocess_image(image):\n",
214 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n",
215 | " return np.copy(image)"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 9,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | ""
227 | ]
228 | },
229 | "execution_count": 9,
230 | "metadata": {},
231 | "output_type": "execute_result"
232 | },
233 | {
234 | "data": {
235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADM1JREFUeJzt3X/oXfV9x/Hna4nW1k5jnIbMyMxoUGRg1OAUy9jUbKktuj+KKDLKEPJPt+laaHX7oxT2RwujrcIoiLaT4fxRq2sIxc6lljEYqfHHWk1ME22sCWqi80fnYFva9/64J9u3WbKcb+6P7/f4eT7gcu85596cz+Hwuufc8z15v1NVSGrLLy30ACTNnsGXGmTwpQYZfKlBBl9qkMGXGmTwpQaNFfwkG5LsTLI7ya2TGpSk6crx3sCTZAnwI2A9sBd4ArihqrZPbniSpmHpGJ+9BNhdVS8CJLkfuBY4avCTeJugxnLxxRcv9BAWtT179vD666/nWO8bJ/hnAS/Pmd4L/OYY/550TNu2bVvoISxq69at6/W+cYLfS5KNwMZpr0dSf+MEfx9w9pzpVd28X1BVdwJ3gqf60mIxzlX9J4A1SVYnORG4Htg0mWFJmqbjPuJX1cEkfwR8B1gCfK2qnpvYyCRNzVi/8avq28C3JzQWSTPinXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg44Z/CRfS7I/ybNz5i1P8liSXd3zadMdpqRJ6nPE/2tgw2HzbgW2VNUaYEs3LWkgjhn8qvpH4F8Pm30tcE/3+h7g9yc8LklTdLy/8VdU1Svd61eBFRMaj6QZGLuTTlXV/9cow0460uJzvEf815KsBOie9x/tjVV1Z1Wtq6p+Tb0kTd3xBn8T8Inu9SeAb01mOJJmoc+f8+4D/hk4N8neJDcBXwDWJ9kFXNVNSxqIY/7Gr6objrLoygmPRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KA+pbfOTvJ4ku1JnktyczffbjrSQPU54h8EPl1V5wOXAp9Mcj5205EGq08nnVeq6qnu9U+BHcBZ2E1HGqx5NdRIcg5wIbCVnt10bKghLT69L+4l+SDwTeCWqnpn7rKqKuCI3XRsqCEtPr2Cn+QERqG/t6oe7mb37qYjaXHpc1U/wN3Ajqr60pxFdtORBqrPb/zLgT8AfpjkmW7enzHqnvNg11nnJeC66QxR0qT16aTzT0COsthuOtIAeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVoXjX3xrVmzRruuOOOWa5ycG688caFHoIa4BFfapDBlxrUp+beSUm+n+Rfuk46n+/mr06yNcnuJA8kOXH6w5U0CX2O+P8BXFFVFwBrgQ1JLgW+CHy5qj4EvAncNL1hSpqkPp10qqr+rZs8oXsUcAXwUDffTjrSgPStq7+kq7C7H3gMeAF4q6oOdm/Zy6it1pE+uzHJtiTb3n777UmMWdKYegW/qn5WVWuBVcAlwHl9VzC3k86pp556nMOUNEnzuqpfVW8BjwOXAcuSHLoPYBWwb8JjkzQlfa7qn5FkWff6/cB6Rh1zHwc+3r3NTjrSgPS5c28lcE+SJYy+KB6sqs1JtgP3J/kL4GlGbbYkDUCfTjo/YNQa+/D5LzL6vS9pYLxzT2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ7+B3JbafTrK5m7aTjjRQ8zni38yoyOYhdtKRBqpvQ41VwEeBu7rpYCcdabD6HvG/AnwG+Hk3fTp20pEGq09d
/Y8B+6vqyeNZgZ10pMWnT139y4FrklwNnAScAtxO10mnO+rbSUcakD7dcm+rqlVVdQ5wPfDdqroRO+lIgzXO3/E/C3wqyW5Gv/ntpCMNRJ9T/f9RVd8Dvte9tpOONFDeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgXoU4kuwBfgr8DDhYVeuSLAceAM4B9gDXVdWb0xmmpEmazxH/d6pqbVWt66ZvBbZU1RpgSzctaQDGOdW/llEjDbChhjQofYNfwN8neTLJxm7eiqp6pXv9KrBi4qOTNBV9i21+uKr2JTkTeCzJ83MXVlUlqSN9sPui2Ahw5plnjjVYSZPR64hfVfu65/3AI4yq676WZCVA97z/KJ+1k460yPRpoXVykl8+9Br4XeBZYBOjRhpgQw1pUPqc6q8AHhk1yGUp8LdV9WiSJ4AHk9wEvARcN71hSpqkYwa/a5xxwRHmvwFcOY1BSZou79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtT3f+dNxCmnnMKGDRtmucrBeeONNxZ6CGqAR3ypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUK/hJliV5KMnzSXYkuSzJ8iSPJdnVPZ827cFKmoy+R/zbgUer6jxGZbh2YCcdabD6VNk9Ffgt4G6AqvrPqnoLO+lIg9XniL8aOAB8PcnTSe7qymzbSUcaqD7BXwpcBHy1qi4E3uWw0/qqKkZttv6PJBuTbEuy7cCBA+OOV9IE9An+XmBvVW3tph9i9EUw7046Z5xxxiTGLGlMxwx+Vb0KvJzk3G7WlcB27KQjDVbf/5b7x8C9SU4EXgT+kNGXhp10pAHqFfyqegZYd4RFdtKRBsg796QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9amrf26SZ+Y83klyi510pOHqU2xzZ1Wtraq1wMXAvwOPYCcdabDme6p/JfBCVb2EnXSkwZpv8K8H7ute20lHGqjewe9Ka18DfOPwZXbSkYZlPkf8jwBPVdVr3bSddKSBmk/wb+B/T/PBTjrSYPUKftcddz3w8JzZXwDWJ9kFXNVNSxqAvp103gVOP2zeG9hJRxok79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtS39NafJnkuybNJ7ktyUpLVSbYm2Z3kga4Kr6QB6NNC6yzgT4B1VfUbwBJG9fW/CHy5qj4EvAncNM2BSpqcvqf6S4H3J1kKfAB4BbgCeKhbbicdaUD69M7bB/wl8BNGgX8beBJ4q6oOdm/bC5w1rUFKmqw+p/qnMeqTtxr4VeBkYEPfFdhJR1p8+pzqXwX8uKoOVNV/MaqtfzmwrDv1B1gF7DvSh+2kIy0+fYL/E+DSJB9IEka19LcDjwMf795jJx1pQPr8xt/K6CLeU8APu8/cCXwW+FSS3Yyabdw9xXFKmqC+nXQ+B3zusNkvApdMfESSps4796QGGXypQQZfapDBlxqUqprdypIDwLvA6zNb6fT9Cm7PYvVe2hbotz2/VlXHvGFmpsEHSLKtqtbNdKVT5PYsXu+lbYHJbo+n+lKDDL7UoIUI/p0LsM5pcnsWr/fStsAEt2fmv/ElLTxP9aUGzTT4STYk2dnV6bt1luseV5KzkzyeZHtXf/Dmbv7yJI8l2dU9n7bQY52PJEuSPJ1kczc92FqKSZYleSjJ80l2JLlsyPtnmrUuZxb8JEuAvwI+ApwP3JDk/FmtfwIOAp+uqvOBS4FPduO/FdhSVWuALd30kNwM7JgzPeRaircDj1bVecAFjLZrkPtn6rUuq2omD+Ay4Dtzpm8DbpvV+qewPd8C1gM7gZXdvJXAzoUe2zy2YRWjMFwBbAbC
6AaRpUfaZ4v5AZwK/JjuutWc+YPcP4xK2b0MLGf0v2g3A783qf0zy1P9QxtyyGDr9CU5B7gQ2AqsqKpXukWvAisWaFjH4yvAZ4Cfd9OnM9xaiquBA8DXu58udyU5mYHun5pyrUsv7s1Tkg8C3wRuqap35i6r0dfwIP5MkuRjwP6qenKhxzIhS4GLgK9W1YWMbg3/hdP6ge2fsWpdHsssg78POHvO9FHr9C1WSU5gFPp7q+rhbvZrSVZ2y1cC+xdqfPN0OXBNkj3A/YxO92+nZy3FRWgvsLdGFaNgVDXqIoa7f8aqdXksswz+E8Ca7qrkiYwuVGya4frH0tUbvBvYUVVfmrNoE6OagzCg2oNVdVtVraqqcxjti+9W1Y0MtJZiVb0KvJzk3G7WodqQg9w/TLvW5YwvWFwN/Ah4Afjzhb6AMs+xf5jRaeIPgGe6x9WMfhdvAXYB/wAsX+ixHse2/TawuXv968D3gd3AN4D3LfT45rEda4Ft3T76O+C0Ie8f4PPA88CzwN8A75vU/vHOPalBXtyTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9q0H8DxmXSvDxTWXAAAAAASUVORK5CYII=\n",
236 | "text/plain": [
237 | ""
238 | ]
239 | },
240 | "metadata": {},
241 | "output_type": "display_data"
242 | }
243 | ],
244 | "source": [
245 | "processed_prev_state = preprocess_image(prev_state)\n",
246 | "plt.imshow(processed_prev_state,cmap='gray')"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | " Build Model "
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 10,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "Network(\n",
266 | " (conv_layer1): Conv2d(4, 64, kernel_size=(8, 8), stride=(4, 4))\n",
267 | " (conv_layer2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))\n",
268 | " (conv_layer3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n",
269 | " (fc1): Linear(in_features=6272, out_features=512, bias=True)\n",
270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
271 | " (relu): ReLU()\n",
272 | ")\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "import torch.nn as nn\n",
278 | "import torch\n",
279 | "\n",
280 | "class Network(nn.Module):\n",
281 | " \n",
282 | " def __init__(self,image_input_size,out_size):\n",
283 | " super(Network,self).__init__()\n",
284 | " self.image_input_size = image_input_size\n",
285 | " self.out_size = out_size\n",
286 | "\n",
287 | " self.conv_layer1 = nn.Conv2d(in_channels=4,out_channels=64,kernel_size=8,stride=4) # GRAY - 1\n",
288 | " self.conv_layer2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2)\n",
289 | " self.conv_layer3 = nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,stride=1)\n",
290 | " self.fc1 = nn.Linear(in_features=7*7*128,out_features=512)\n",
291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
292 | " self.relu = nn.ReLU()\n",
293 | "\n",
294 | " def forward(self,x,bsize):\n",
295 | " x = x.view(bsize,4,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
296 | " conv_out = self.conv_layer1(x)\n",
297 | " conv_out = self.relu(conv_out)\n",
298 | " conv_out = self.conv_layer2(conv_out)\n",
299 | " conv_out = self.relu(conv_out)\n",
300 | " conv_out = self.conv_layer3(conv_out)\n",
301 | " conv_out = self.relu(conv_out)\n",
302 | " out = self.fc1(conv_out.view(bsize,7*7*128))\n",
303 | " out = self.relu(out)\n",
304 | " out = self.fc2(out)\n",
305 | " return out\n",
306 | "\n",
307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n",
308 | "print(main_model)"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {},
314 | "source": [
315 | " Deep Q Learning with Freeze Network "
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "Populated 60000 Samples in Episodes : 1200\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "mem = Memory(memsize=MEMORY_SIZE)\n",
333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n",
334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n",
335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=4)\n",
336 | "\n",
337 | "target_model.load_state_dict(main_model.state_dict())\n",
338 | "criterion = nn.SmoothL1Loss()\n",
339 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
340 | "\n",
341 | "# filling memory with transitions\n",
342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
343 | " \n",
344 | " prev_state = env.reset()\n",
345 | " frameObj.reset()\n",
346 | " processed_prev_state = preprocess_image(prev_state)\n",
347 | " frameObj.add_frame(processed_prev_state)\n",
348 | " prev_frames = frameObj.get_state()\n",
349 | " step_count = 0\n",
350 | " game_over = False\n",
351 | " \n",
352 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
353 | " \n",
354 | " step_count +=1\n",
355 | " action = np.random.randint(0,4)\n",
356 | " next_state,reward, game_over = env.step(action)\n",
357 | " processed_next_state = preprocess_image(next_state)\n",
358 | " frameObj.add_frame(processed_next_state)\n",
359 | " next_frames = frameObj.get_state()\n",
360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
361 | " \n",
362 | " prev_state = next_state\n",
363 | " processed_prev_state = processed_next_state\n",
364 | " prev_frames = next_frames\n",
365 | "\n",
366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n",
367 | "\n",
368 | "\n",
369 | "# Algorithm Starts\n",
370 | "total_steps = 0\n",
371 | "epsilon = INITIAL_EPSILON\n",
372 | "loss_stat = []\n",
373 | "total_reward_stat = []\n",
374 | "\n",
375 | "for episode in range(0,TOTAL_EPISODES):\n",
376 | " \n",
377 | " prev_state = env.reset()\n",
378 | " frameObj.reset()\n",
379 | " processed_prev_state = preprocess_image(prev_state)\n",
380 | " frameObj.add_frame(processed_prev_state)\n",
381 | " prev_frames = frameObj.get_state()\n",
382 | " game_over = False\n",
383 | " step_count = 0\n",
384 | " total_reward = 0\n",
385 | " \n",
386 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
387 | " \n",
388 | " step_count +=1\n",
389 | " total_steps +=1\n",
390 | " \n",
391 | " if np.random.rand() <= epsilon:\n",
392 | " action = np.random.randint(0,4)\n",
393 | " else:\n",
394 | " with torch.no_grad():\n",
395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
396 | "\n",
397 | " model_out = main_model.forward(torch_x,bsize=1)\n",
398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
399 | " \n",
400 | " next_state, reward, game_over = env.step(action)\n",
401 | " processed_next_state = preprocess_image(next_state)\n",
402 | " frameObj.add_frame(processed_next_state)\n",
403 | " next_frames = frameObj.get_state()\n",
404 | " total_reward += reward\n",
405 | " \n",
406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
407 | " \n",
408 | " prev_state = next_state\n",
409 | " processed_prev_state = processed_next_state\n",
410 | " prev_frames = next_frames\n",
411 | " \n",
412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
413 | " target_model.load_state_dict(main_model.state_dict())\n",
414 | " \n",
415 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
416 | " current_states = []\n",
417 | " next_states = []\n",
418 | " acts = []\n",
419 | " rewards = []\n",
420 | " game_status = []\n",
421 | " \n",
422 | " for element in batch:\n",
423 | " current_states.append(element[0])\n",
424 | " acts.append(element[1])\n",
425 | " rewards.append(element[2])\n",
426 | " next_states.append(element[3])\n",
427 | " game_status.append(element[4])\n",
428 | " \n",
429 | " current_states = np.array(current_states)\n",
430 | " next_states = np.array(next_states)\n",
431 | " rewards = np.array(rewards)\n",
432 | " game_status = [not b for b in game_status]\n",
433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n",
434 | " torch_acts = torch.tensor(acts)\n",
435 | " \n",
436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
439 | " Q_max_next = Q_max_next.double()\n",
440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
441 | " \n",
442 | " target_values = (rewards + (GAMMA * Q_max_next))\n",
443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
444 | " \n",
445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n",
446 | " \n",
447 | " # save performance measure\n",
448 | " loss_stat.append(loss.item())\n",
449 | " \n",
450 | " # make previous grad zero\n",
451 | " optimizer.zero_grad()\n",
452 | " \n",
 453 |       "        # back - propagate \n",
454 | " loss.backward()\n",
455 | " \n",
456 | " # update params\n",
457 | " optimizer.step()\n",
458 | " \n",
459 | " # save performance measure\n",
460 | " total_reward_stat.append(total_reward)\n",
461 | " \n",
462 | " if epsilon > FINAL_EPSILON:\n",
463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
464 | " \n",
465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
466 | " perf = {}\n",
467 | " perf['loss'] = loss_stat\n",
468 | " perf['total_reward'] = total_reward_stat\n",
469 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n",
470 | " \n",
471 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | " Save Primary Network Weights "
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 12,
484 | "metadata": {
485 | "collapsed": true
486 | },
487 | "outputs": [],
488 | "source": [
489 | "torch.save(main_model.state_dict(),'data/FOUR_OBSERV_NINE_WEIGHTS.torch')"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {},
495 | "source": [
496 | " Testing Policy "
497 | ]
498 | },
499 | {
500 | "cell_type": "markdown",
501 | "metadata": {},
502 | "source": [
503 | " Load Primary Network Weights "
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 13,
509 | "metadata": {
510 | "collapsed": true
511 | },
512 | "outputs": [],
513 | "source": [
514 | "weights = torch.load('data/FOUR_OBSERV_NINE_WEIGHTS.torch')\n",
515 | "main_model.load_state_dict(weights)"
516 | ]
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "metadata": {},
521 | "source": [
522 | " Testing Policy "
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "metadata": {
529 | "collapsed": true
530 | },
531 | "outputs": [],
532 | "source": [
533 | "# Algorithm Starts\n",
534 | "epsilon = INITIAL_EPSILON\n",
535 | "FINAL_EPSILON = 0.01\n",
536 | "total_reward_stat = []\n",
537 | "\n",
538 | "for episode in range(0,TOTAL_EPISODES):\n",
539 | " \n",
540 | " prev_state = env.reset()\n",
541 | " processed_prev_state = preprocess_image(prev_state)\n",
542 | " frameObj.reset()\n",
543 | " frameObj.add_frame(processed_prev_state)\n",
544 | " prev_frames = frameObj.get_state()\n",
545 | " game_over = False\n",
546 | " step_count = 0\n",
547 | " total_reward = 0\n",
548 | " \n",
549 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
550 | " \n",
551 | " step_count +=1\n",
552 | " \n",
553 | " if np.random.rand() <= epsilon:\n",
554 | " action = np.random.randint(0,4)\n",
555 | " else:\n",
556 | " with torch.no_grad():\n",
557 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
558 | "\n",
559 | " model_out = main_model.forward(torch_x,bsize=1)\n",
560 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
561 | " \n",
562 | " next_state, reward, game_over = env.step(action)\n",
563 | " processed_next_state = preprocess_image(next_state)\n",
564 | " frameObj.add_frame(processed_next_state)\n",
565 | " next_frames = frameObj.get_state()\n",
566 | " \n",
567 | " total_reward += reward\n",
568 | " \n",
569 | " prev_state = next_state\n",
570 | " processed_prev_state = processed_next_state\n",
571 | " prev_frames = next_frames\n",
572 | " \n",
573 | " # save performance measure\n",
574 | " total_reward_stat.append(total_reward)\n",
575 | " \n",
576 | " if epsilon > FINAL_EPSILON:\n",
577 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
578 | " \n",
579 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
580 | " perf = {}\n",
581 | " perf['total_reward'] = total_reward_stat\n",
582 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n",
583 | " \n",
584 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
585 | ]
586 | }
587 | ],
588 | "metadata": {
589 | "kernelspec": {
590 | "display_name": "Python [conda env:myenv]",
591 | "language": "python",
592 | "name": "conda-env-myenv-py"
593 | },
594 | "language_info": {
595 | "codemirror_mode": {
596 | "name": "ipython",
597 | "version": 3
598 | },
599 | "file_extension": ".py",
600 | "mimetype": "text/x-python",
601 | "name": "python",
602 | "nbconvert_exporter": "python",
603 | "pygments_lexer": "ipython3",
604 | "version": "3.6.5"
605 | }
606 | },
607 | "nbformat": 4,
608 | "nbformat_minor": 2
609 | }
610 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/MDP_Size_9-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from gridworld import gameEnv\n",
10 | "import numpy as np\n",
11 | "%matplotlib inline\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "from collections import deque\n",
14 | "import pickle\n",
15 | "from skimage.color import rgb2gray\n",
16 | "import random"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | " Define Environment Object "
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 3,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "data": {
33 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADONJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppIiWkvoki9SKQqCRdVJARJUUX5EwKNZUWk1CGqKlUO5k8TsCE2xARbBpsUSkqltk7eXuy4PXFtzhyf3T07/n0/0urszOye+Y1Hz5nZ2fH7pqqQ1JZfWO4BSJo+gy81yOBLDTL4UoMMvtQggy81yOBLDVpS8JNck+SFJLuTbBrXoCRNVo73Bp4kK4AfABuBvcATwE1VtWN8w5M0CSuX8N5Lgd1V9RJAkvuB64FjBv/MM8+subm5JaxS0jvZs2cPr7/+ehZ63VKCfzbwyrzpvcBvvtMb5ubm2L59+xJWKemdbNiwodfrJn5xL8ktSbYn2X7w4MFJr05SD0sJ/j7gnHnTa7t5P6eq7qyqDVW14ayzzlrC6iSNy1KC/wRwXpJ1SU4GbgQ2j2dYkibpuD/jV9WhJH8EfAtYAXylqp4b28gkTcxSLu5RVd8EvjmmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGLem/5c6CZMG6gtLMWq429R7xpQYZfKlBCwY/yVeSHEjy7Lx5q5I8lmRX9/P0yQ5T0jj1OeL/NXDNEfM2AVur6jxgazctaSAWDH5V/SPwr0fMvh64p3t+D/D7Yx6XpAk63s/4q6tqf/f8VWD1mMYjaQqWfHGvRt9HHPM7CTvpSLPneIP/WpI1AN3PA8d6oZ10pNlzvMHfDHyse/4x4BvjGY6kaejzdd59wD8D5yfZm+Rm4HPAxiS7gKu7aUkDseAtu1V10zEWXTXmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUp/TWOUkeT7IjyXNJbu3m201HGqg+R/xDwCer6kLgMuDjSS7EbjrSYPXppLO/qp7qnv8E2Amcjd10pMFa1Gf8JHPAxcA2enbTsaGGNHt6Bz/Je4GvA7dV1Vvzl71TNx0bakizp1fwk5zEKPT3VtXD3eze3XQkzZY+V/UD3A3srKovzFtkNx1poBZsqAFcAfwB8P0kz3Tz/oxR95wHu846LwM3TGaIksatTyedfwJyjMV205EGyDv3pAb1OdXXCeioX8EsQcb5C491fjlDxv3vN20e8aUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfalCfmnunJPlukn/pOul8tpu/Lsm2JLuTPJDk5MkPV9I49Dni/ydwZVVdBKwHrklyGfB54ItV9X7gDeDmyQ1T0jj16aRTVfXv3eRJ3aOAK4GHuvl20pEGpG9d/RVdhd0DwGPAi8CbVXWoe8leRm21jvZeO+lIM6ZX8Kvqp1W1HlgLXApc0HcFdtKZTRnzY7y/bPYNfVMXdVW/qt4EHgcuB05LcrhY51pg35jHJmlC+lzVPyvJad3zdwMbGXXMfRz4SPcyO+lIA9KnvPYa4J4kKxj9oXiwqrYk2QHcn+QvgKcZtdmSNAB9Oul8j1Fr7CPnv8To876kgfHOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYPfldh+OsmWbtpOOtJALeaIfyujIpuH2UlHGqi+DTXWAh8C7uqmg510pMHqe8T/EvAp4Gfd9
BnYSUcarD519T8MHKiqJ49nBXbSkWZPn7r6VwDXJbkWOAX4JeAOuk463VHfTjrSgPTplnt7Va2tqjngRuDbVfVR7KQjDdZSvsf/NPCJJLsZfea3k440EH1O9f9XVX0H+E733E460kB5557UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDFnUDjxapxvz7Mubfp2Z5xJcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9vsdPsgf4CfBT4FBVbUiyCngAmAP2ADdU1RuTGaakcVrMEf93qmp9VW3opjcBW6vqPGBrNy1pAJZyqn89o0YaYEMNaVD6Br+Av0/yZJJbunmrq2p/9/xVYPXYRydpIvreq/+BqtqX5JeBx5I8P39hVVWSo96Z3v2huAXg3HPPXdJgJY1HryN+Ve3rfh4AHmFUXfe1JGsAup8HjvFeO+lIM6ZPC61Tk/zi4efA7wLPApsZNdIAG2pIg9LnVH818MioQS4rgb+tqkeTPAE8mORm4GXghskNU9I4LRj8rnHGRUeZ/2PgqkkMStJkeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yWpKHkjyfZGeSy5OsSvJYkl3dz9MnPVhJ49H3iH8H8GhVXcCoDNdO7KQjDVafKrvvA34LuBugqv6rqt7ETjrSYPU54q8DDgJfTfJ0kru6Mtt20pEGqk/wVwKXAF+uqouBtznitL6qilGbrf8nyS1JtifZfvDgwaWOV9IY9An+XmBvVW3rph9i9IfATjoLyZgfs6zG+NDELRj8qnoVeCXJ+d2sq4Ad2ElHGqy+TTP/GLg3ycnAS8AfMvqjYScdaYB6Bb+qngE2HGWRnXSkAfLOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBferqn5/kmXmPt5LcduJ10hlntcgGq0a2UlT0BNGn2OYLVbW+qtYDvwH8B/AIdtKRBmuxp/pXAS9W1cvYSUcarMUG/0bgvu65nXSkgeod/K609nXA145cZicdaVgWc8T/IPBUVb3WTdtJRxqoxQT/Jv7vNB/spCMNVq/gd91xNwIPz5v9OWBjkl3A1d20pAHo20nnbeCMI+b9GDvpSIPknXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg3rduTfLRv8xcFbN8tjUMo/4UoMMvtQggy81yOBLDTL4UoMMvtQggy81qG/prT9N8lySZ5Pcl+SUJOuSbEuyO8kDXRVeSQPQp4XW2cCfABuq6teBFYzq638e+GJVvR94A7h5kgOVND59T/VXAu9OshJ4D7AfuBJ4qFtuJx1pQPr0ztsH/CXwI0aB/zfgSeDNqjrUvWwvcPakBilpvPqc6p/OqE/eOuBXgFOBa/quwE460uzpc6p/NfDDqjpYVf/NqLb+FcBp3ak/wFpg39HebCcdafb0Cf6PgMuSvCdJGNXS3wE8Dnyke42ddKQB6fMZfxuji3hPAd/v3nMn8GngE0l2M2q2cfcExylpjPp20vkM8JkjZr8EXDr2EUmaOO/ckxpk8KUGGXypQQZfalCmWawyyUHgbeD1qa108s7E7ZlVJ9K2QL/t+dWqWvCGmakGHyDJ9qraMNWVTpDbM7tOpG2B8W6Pp/pSgwy+1KDlCP6dy7DOSXJ7ZteJtC0wxu2Z+md8ScvPU32pQVMNfpJrkrzQ1enbNM11L1WSc5I8nmRHV3/w1m7+qiSPJdnV/Tx9uce6GElWJHk6yZZuerC1FJOcluShJM8n2Znk8iHvn0nWupxa8JOsAP4K+CBwIXBTkguntf4xOAR8sqouBC4DPt6NfxOwtarOA7Z200NyK7Bz3vSQayneATxaVRcAFzHarkHun4nXuqyqqTyAy4FvzZu+Hbh9WuufwPZ8A9gIv
ACs6eatAV5Y7rEtYhvWMgrDlcAWIIxuEFl5tH02yw/gfcAP6a5bzZs/yP3DqJTdK8AqRv+Ldgvwe+PaP9M81T+8IYcNtk5fkjngYmAbsLqq9neLXgVWL9OwjseXgE8BP+umz2C4tRTXAQeBr3YfXe5KcioD3T814VqXXtxbpCTvBb4O3FZVb81fVqM/w4P4miTJh4EDVfXkco9lTFYClwBfrqqLGd0a/nOn9QPbP0uqdbmQaQZ/H3DOvOlj1umbVUlOYhT6e6vq4W72a0nWdMvXAAeWa3yLdAVwXZI9wP2MTvfvoGctxRm0F9hbo4pRMKoadQnD3T9LqnW5kGkG/wngvO6q5MmMLlRsnuL6l6SrN3g3sLOqvjBv0WZGNQdhQLUHq+r2qlpbVXOM9sW3q+qjDLSWYlW9CryS5Pxu1uHakIPcP0y61uWUL1hcC/wAeBH48+W+gLLIsX+A0Wni94Bnuse1jD4XbwV2Af8ArFrusR7Htv02sKV7/mvAd4HdwNeAdy33+BaxHeuB7d0++jvg9CHvH+CzwPPAs8DfAO8a1/7xzj2pQV7ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfatD/ADPvEPBps6hQAAAAAElFTkSuQmCC\n",
34 | "text/plain": [
35 | ""
36 | ]
37 | },
38 | "metadata": {
39 | "needs_background": "light"
40 | },
41 | "output_type": "display_data"
42 | }
43 | ],
44 | "source": [
45 | "env = gameEnv(partial=False,size=9)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "data": {
55 | "text/plain": [
56 | ""
57 | ]
58 | },
59 | "execution_count": 4,
60 | "metadata": {},
61 | "output_type": "execute_result"
62 | },
63 | {
64 | "data": {
65 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADPlJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppICbQXUaReJFKVhIsqEoKkqKL8CYEGWREpdYiiSJWD+dMEbIgNMcEWYJNCSanU1snbix23B9fGc3xmz9nx7/uRVrszu3vmN2f0nJmdnfO+qSokteWXlnsAkpaewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSPJckp1Jbh5qUJKmK0d7AU+SFcCPgI3AbuAx4Lqq2jbc8CRNw8pFvPdCYGdVvQCQ5B7gauCwwT/11FNrbm5uEYuU9E527drFa6+9liO9bjHBPx14ad70buC33+kNc3NzbN26dRGLlPRONmzY0Ot1Uz+5l+SGJFuTbN23b9+0Fyeph8UEfw9wxrzptd28t6mq26pqQ1VtOO200xaxOElDWUzwHwPOSrIuyfHAtcBDwwxL0jQd9Wf8qtqf5E+AbwErgK9U1TODjUzS1Czm5B5V9U3gmwONRdIS8co9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2rQov4tdxYkR6wr2M80uoUPNDQdu5arTb17fKlBBl9q0BGDn+QrSfYmeXrevFVJHkmyo7s/ebrDlDSkPnv8vwWuOGjezcDmqjoL2NxNSxqJIwa/qr4L/OtBs68G7uwe3wn84cDjkjRFR/sZf3VVvdw9fgVYPdB4JC2BRZ/cq8n3EYf9TsJOOtLsOdrgv5pkDUB3v/dwL7STjjR7jjb4DwEf6x5/DPjGMMORtBT6fJ13N/DPwNlJdie5HvgcsDHJDuDyblrSSBzxkt2quu4wT1028FgkLRGv3JMaZPClBhl8qUEGX2qQwZcaZPClBo2+As9grJajhrjHlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Sm9dUaSR5NsS/JMkhu7+XbTkUaqzx5/P/DJqjoXuAj4eJJzsZuONFp9Oum8XFVPdI9/BmwHTsduOtJoLegzfpI54HxgCz276dhQQ5o9vYOf5L3A14GbqurN+c+9UzcdG2pIs6dX8JMcxyT0d1XVA93s3t10JM2WPmf1A9wBbK+qL8x7ym460kj1qcBzCfBHwA+TPNXN+wsm3XPu6zrrvAhcM50hShpan0463+PwhanspiONkFfuSQ2y2OaoHPKLk6PUUHXRIX9tB4z81+ceX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBvWpuXdCku8n+Zeuk85nu/nrkmxJsjPJvUmOn/5wJQ2hzx7/P4FLq+o8YD1wRZKLgM8DX6yq9wOvA9dPb5iShtSnk05V1b93k8d1twIuBe7v5ttJRxqRvnX1V3QVdvcCjwDPA29U1f7uJbuZtNU61HvtpCPNmF7Br6qfV9V6YC1wIXBO3wXYSWdIGfDWkCF/bcfIr29BZ/Wr6g3gUeBi4KQkB4p1rgX2DDw2SVPS56z+aUlO6h6/G9jIpGPuo8BHupfZSUcakT7ltdcAdyZZweQPxX1VtSnJNuCeJH8FPMmkzZakEejTSecHTFpjHzz/BSaf9yWNjFfuSQ0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3qHfyuxPaTSTZ103bSkUZqIXv8G5kU2TzATjrSSPVtqLEW+BBwezcd7KQjjVbfP
f6XgE8Bv+imT8FOOtJo9amr/2Fgb1U9fjQLsJOONHv61NW/BLgqyZXACcCvALfSddLp9vp20pFGpE+33Fuqam1VzQHXAt+uqo9iJx1ptBbzPf6ngU8k2cnkM7+ddKSR6HOo/7+q6jvAd7rHdtKRRsor96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUqxBHkl3Az4CfA/urakOSVcC9wBywC7imql6fzjAlDWkhe/zfq6r1VbWhm74Z2FxVZwGbu2lJI7CYQ/2rmTTSABtqSKPSN/gF/GOSx5Pc0M1bXVUvd49fAVYPPjpJU9G32OYHqmpPkl8FHkny7Pwnq6qS1KHe2P2huAHgzDPPXNRgJQ2j1x6/qvZ093uBB5lU1301yRqA7n7vYd5rJx1pxvRpoXVikl8+8Bj4feBp4CEmjTTAhhrSqPQ51F8NPDhpkMtK4O+r6uEkjwH3JbkeeBG4ZnrDlDSkIwa/a5xx3iHm/xS4bBqDkjRdXrknNcjgSw1aUO+8mXTILxGP4sdkmJ8z3xR+pDQI9/hSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yUpL7kzybZHuSi5OsSvJIkh3d/cnTHqykYfTd498KPFxV5zApw7UdO+lIo9Wnyu77gN8B7gCoqv+qqjewk440Wn32+OuAfcBXkzyZ5PauzLaddKSR6hP8lcAFwJer6nzgLQ46rK+q4jBFsJLckGRrkq379u1b7HglDaBP8HcDu6tqSzd9P5M/BLPRSSfD3Ab6MW+76dhVA92WyxGDX1WvAC8lObubdRmwDTvpSKPVt8runwJ3JTkeeAH4YyZ/NOykI41Qr+BX1VPAhkM8ZScdaYS8ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUJ+6+mcneWre7c0kN9lJp4ehKjIud2VG/T9jL8jap9jmc1W1vqrWA78F/AfwIHbSkUZroYf6lwHPV9WL2ElHGq2FBv9a4O7usZ10pJHqHfyutPZVwNcOfs5OOtK4LGSP/0Hgiap6tZuejU46khZsIcG/jv87zAc76Uij1Sv4XXfcjcAD82Z/DtiYZAdweTctaQT6dtJ5CzjloHk/xU460ih55Z7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoF5X7s2yyT8GNqKhVdV0uceXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBfUtv/XmSZ5I8neTuJCckWZdkS5KdSe7tqvBKGoE+LbROB/4M2FBVvwmsYFJf//PAF6vq/cDrwPXTHKik4fQ91F8JvDvJSuA9wMvApcD93fN20pFGpE/vvD3AXwM/YRL4fwMeB96oqv3dy3YDp09rkJKG1edQ/2QmffLWAb8GnAhc0XcBdtKRZk+fQ/3LgR9X1b6q+m8mtfUvAU7qDv0B1gJ7DvVmO+lIs6dP8H8CXJTkPUnCpJb+NuBR4CPda+ykI41In8/4W5icxHsC+GH3ntuATwOfSLKTSbONO6Y4TkkD6ttJ5zPAZw6a/QJw4eAjkjR1XrknNcjgSw0y+FKDDL7UoCxlscok+4C3gNeWbKHTdyquz6w6ltYF+q3Pr1fVES+YWdLgAyTZWlUblnShU+T6zK5jaV1g2PXxUF9qkMGXGrQcwb9tGZY5Ta7P7DqW1gUGXJ8l/4wvafl5qC81aEmDn+SKJM91dfpuXsplL1aSM5I8mmRbV3/wxm7+qiSPJNnR3Z+83GNdiCQrkjyZZFM3PdpaiklOSnJ/kmeTbE9y8Zi3zzRrXS5Z8JOsAP4G+CBwLnBdknOXavkD2A98sqrOBS4CPt6N/2Zgc1WdBWzupsfkRmD7vOkx11K8FXi4qs4BzmOyXqPcPlOvdVlVS
3IDLga+NW/6FuCWpVr+FNbnG8BG4DlgTTdvDfDcco9tAeuwlkkYLgU2AWFygcjKQ22zWb4B7wN+THfeat78UW4fJqXsXgJWMfkv2k3AHwy1fZbyUP/Aihww2jp9SeaA84EtwOqqerl76hVg9TIN62h8CfgU8Itu+hTGW0txHbAP+Gr30eX2JCcy0u1TU6516cm9BUryXuDrwE1V9eb852ryZ3gUX5Mk+TCwt6oeX+6xDGQlcAHw5ao6n8ml4W87rB/Z9llUrcsjWcrg7wHOmDd92Dp9syrJcUxCf1dVPdDNfjXJmu75NcDe5RrfAl0CXJVkF3APk8P9W+lZS3EG7QZ216RiFEyqRl3AeLfPompdHslSBv8x4KzurOTxTE5UPLSEy1+Urt7gHcD2qvrCvKceYlJzEEZUe7CqbqmqtVU1x2RbfLuqPspIaylW1SvAS0nO7mYdqA05yu3DtGtdLvEJiyuBHwHPA3+53CdQFjj2DzA5TPwB8FR3u5LJ5+LNwA7gn4BVyz3Wo1i33wU2dY9/A/g+sBP4GvCu5R7fAtZjPbC120b/AJw85u0DfBZ4Fnga+DvgXUNtH6/ckxrkyT2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG/Q++8g/3YgFKrwAAAABJRU5ErkJggg==\n",
66 | "text/plain": [
67 | ""
68 | ]
69 | },
70 | "metadata": {
71 | "needs_background": "light"
72 | },
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 7,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "BATCH_SIZE = 64\n",
102 | "FREEZE_INTERVAL = 20000 # steps\n",
103 | "MEMORY_SIZE = 60000 \n",
104 | "OUTPUT_SIZE = 4\n",
105 | "TOTAL_EPISODES = 10000\n",
106 | "MAX_STEPS = 50\n",
107 | "INITIAL_EPSILON = 1.0\n",
108 | "FINAL_EPSILON = 0.01\n",
109 | "GAMMA = 0.99\n",
110 | "INPUT_IMAGE_DIM = 84\n",
111 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | " Save Dictionary Function "
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 5,
124 | "metadata": {
125 | "collapsed": true
126 | },
127 | "outputs": [],
128 | "source": [
129 | "def save_obj(obj, name ):\n",
130 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
131 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | " Experience Replay "
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "metadata": {
145 | "collapsed": true
146 | },
147 | "outputs": [],
148 | "source": [
149 | "class Memory():\n",
150 | " \n",
151 | " def __init__(self,memsize):\n",
152 | " self.memsize = memsize\n",
153 | " self.memory = deque(maxlen=self.memsize)\n",
154 | " \n",
155 | " def add_sample(self,sample):\n",
156 | " self.memory.append(sample)\n",
157 | " \n",
158 | " def get_batch(self,size):\n",
159 | " return random.sample(self.memory,k=size)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | " Preprocess Images "
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 5,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "def preprocess_image(image):\n",
176 | " image = rgb2gray(image) # rgb2gray also scales pixel values into the range 0 - 1\n",
177 | " return np.copy(image)"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 7,
183 | "metadata": {},
184 | "outputs": [
185 | {
186 | "data": {
187 | "text/plain": [
188 | ""
189 | ]
190 | },
191 | "execution_count": 7,
192 | "metadata": {},
193 | "output_type": "execute_result"
194 | },
195 | {
196 | "data": {
197 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADVNJREFUeJzt3X+oX/V9x/Hna4nW1m71VxYyo4ulosjA6C6ZYhmbmi1xRfdHEUVGGYL/dJtdC61usFLYHy2Mtv4xCkHbueH8UVtXCY2dSy1jMKLxx1o1WqONNUGTWHXtHGxL+94f3yO7zRLvubnne+89fp4PuHy/53x/nM/h8Pqe8z3fc9/vVBWS2vILSz0ASYvP4EsNMvhSgwy+1CCDLzXI4EsNMvhSgxYU/CSbkjybZHeSm4YalKTpyrFewJNkBfB9YCOwF3gEuLaqnh5ueJKmYeUCXrsB2F1VLwAkuQu4Cjhq8E877bRat27dAhYp6e3s2bOHV199NXM9byHBPx14adb0XuA33u4F69atY+fOnQtYpKS3MzMz0+t5Uz+5l+SGJDuT7Dx48OC0Fyeph4UEfx9wxqzptd28n1NVW6pqpqpmVq1atYDFSRrKQoL/CHB2krOSHA9cA9w/zLAkTdMxf8evqkNJ/gj4FrAC+HJVPTXYyCRNzUJO7lFV3wS+OdBYJC0Sr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYt6N9yl4NkzrqC0rK1VG3q3eNLDTL4UoPmDH6SLyc5kOTJWfNOSfJgkue625OnO0xJQ+qzx/8bYNNh824CtlfV2cD2blrSSMwZ/Kr6Z+C1w2ZfBdze3b8d+P2BxyVpio71O/7qqnq5u/8KsHqg8UhaBAs+uVeT3yOO+puEnXSk5edYg78/yRqA7vbA0Z5oJx1p+TnW4N8PfKS7/xHgG8MMR9Ji6PNz3p3AvwLnJNmb5Hrgs8DGJM8Bl3fTkkZizkt2q+raozx02cBjkbRIvHJPapDBlxpk8KUGGXypQQZfapDBlxo0+go8LdmwYcNg7/Xwww8P9l4aH/f4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPqW3zkjyUJKnkzyV5MZuvt10pJHqs8c/BHyiqs4DLgI+muQ87KYjjVafTjovV9Vj3f2fALuA07GbjjRa8/qOn2QdcAGwg57ddGyoIS0/vYOf5L3A14CPVdWPZz/2dt10bKghLT+9gp/kOCahv6Oqvt7N7t1NR9Ly0uesfoDbgF1V9flZD9lNRxqpPhV4LgH+APhekie6eX/GpHvOPV1nnReBq6czRElD69NJ51+AHOVhu+lII+SVe1KDRl9sc9u2bYO8z+bNmwd5n2myQKaG4h5fapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9am5d0KSh5P8W9dJ5zPd/LOS7EiyO8ndSY6f/nAlDaHPHv+/gEur6nxgPbApyUXA54AvVNUHgNeB66c3TElD6tNJp6rqP7rJ47q/Ai4F7u3m20lHGpG+dfVXdBV2DwAPAs8Db1TVoe4pe5m01TrSa+2kIy0zvWruVdVPgfVJTgLuA87tu4Cq2gJsAZiZmTlit52FGEOtvKFs2LBhsPeyfl/b5nVWv6reAB4CLgZOSvLWB8daYN/AY5M0JX3O6q/q9vQkeTewkUnH3IeAD3dPs5OONCJ9DvXXALcnWcHkg+Keqtqa5GngriR/CTzOpM2WpBHo00nnu0xaYx8+/wVguC+dkhaNV+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6ld7S8mC5LA3FPb7UIIMvNah38LsS248n2dpN20lHGqn57PFv
ZFJk8y120pFGqm9DjbXA7wG3dtPBTjrSaPXd438R+CTws276VOykI41Wn7r6HwIOVNWjx7KAqtpSVTNVNbNq1apjeQtJA+vzO/4lwJVJrgBOAH4JuIWuk06317eTjjQifbrl3lxVa6tqHXAN8O2qug476UijtZDf8T8FfDzJbibf+e2kI43EvC7ZrarvAN/p7ttJRxopr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkDX3NIht27YN9l6bN28e7L10ZO7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/f8ZPsAX4C/BQ4VFUzSU4B7gbWAXuAq6vq9ekMU9KQ5rPH/+2qWl9VM930TcD2qjob2N5NSxqBhRzqX8WkkQbYUEMalb7BL+Afkzya5IZu3uqqerm7/wqwevDRSZqKvtfqf7Cq9iX5ZeDBJM/MfrCqKkkd6YXdB8UNAGeeeeaCBitpGL32+FW1r7s9ANzHpLru/iRrALrbA0d5rZ10pGWmTwutE5P84lv3gd8BngTuZ9JIA2yoIY1Kn0P91cB9kwa5rAT+vqoeSPIIcE+S64EXgaunN0xJQ5oz+F3jjPOPMP9HwGXTGJSk6fLKPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYKf5KQk9yZ5JsmuJBcnOSXJg0me625PnvZgJQ2j7x7/FuCBqjqXSRmuXdhJRxqtPlV23wf8JnAbQFX9d1W9gZ10pNHqs8c/CzgIfCXJ40lu7cps20lHGqk+wV8JXAh8qaouAN7ksMP6qiombbb+nyQ3JNmZZOfBgwcXOl5JA+hTV38vsLeqdnTT9zIJ/v4ka6rq5bk66QBbAGZmZo744aDxu+6665Z6CJqHOff4VfUK8FKSc7pZlwFPYycdabT6Ns38Y+COJMcDLwB/yORDw0460gj1Cn5VPQHMHOEhO+lII+SVe1KDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKD5qzA09Xau3vWrPcDfwH8bTd/HbAHuLqqXh9+iG9v27Ztg7zP5s2bB3mfVr322mtLPQTNQ59im89W1fqqWg/8OvCfwH3YSUcarfke6l8GPF9VL2InHWm05hv8a4A7u/t20pFGqnfwu9LaVwJfPfwxO+lI4zKfPf5m4LGq2t9N7+866DBXJ52qmqmqmVWrVi1stJIGMZ/gX8v/HeaDnXSk0eoV/K477kbg67NmfxbYmOQ54PJuWtII9O2k8yZw6mHzfoSddKRR8so9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9rtxbzjZt2jTI+0z+wVBqg3t8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZca1Lf01p8meSrJk0nuTHJCkrOS7EiyO8ndXRVeSSMwZ/CTnA78CTBTVb8GrGBSX/9zwBeq6gPA68D10xyopOH0PdRfCbw7yUrgPcDLwKXAvd3jdtKRRqRP77x9wF8BP2QS+H8HHgXeqKpD3dP2AqdPa5CShtXnUP9kJn3yzgJ+BTgR6H2BvJ10pOWnz6H+5cAPqupgVf0Pk9r6lwAndYf+AGuBfUd6sZ10pOWnT/B/CFyU5D1JwqSW/tPAQ8CHu+fYSUcakT7f8XcwOYn3GPC97jVbgE8BH0+ym0mzjdumOE5JA+rbSefTwKcPm/0CsGHwEUmaOq/ckxpk8KUGGXypQQZfalAWs8hkkoPAm8Cri7bQ6TsN12e5eietC/Rbn1+tqjkvmFnU4AMk2VlVM4u60ClyfZavd9K6wLDr46G+1CCDLzVoKYK/ZQmWOU2uz/L1TloXGHB9Fv07vqSl56G+1KBFDX6STUme7er03bSYy16oJGckeSjJ0139wRu7+ackeTDJc93tyUs91vlIsiLJ40m2dtOjraWY
5KQk9yZ5JsmuJBePeftMs9blogU/yQrgr4HNwHnAtUnOW6zlD+AQ8ImqOg+4CPhoN/6bgO1VdTawvZsekxuBXbOmx1xL8Rbggao6FzifyXqNcvtMvdZlVS3KH3Ax8K1Z0zcDNy/W8qewPt8ANgLPAmu6eWuAZ5d6bPNYh7VMwnApsBUIkwtEVh5pmy3nP+B9wA/ozlvNmj/K7cOklN1LwClM/ot2K/C7Q22fxTzUf2tF3jLaOn1J1gEXADuA1VX1cvfQK8DqJRrWsfgi8EngZ930qYy3luJZwEHgK91Xl1uTnMhIt09NudalJ/fmKcl7ga8BH6uqH89+rCYfw6P4mSTJh4ADVfXoUo9lICuBC4EvVdUFTC4N/7nD+pFtnwXVupzLYgZ/H3DGrOmj1ulbrpIcxyT0d1TV17vZ+5Os6R5fAxxYqvHN0yXAlUn2AHcxOdy/hZ61FJehvcDemlSMgknVqAsZ7/ZZUK3LuSxm8B8Bzu7OSh7P5ETF/Yu4/AXp6g3eBuyqqs/Peuh+JjUHYUS1B6vq5qpaW1XrmGyLb1fVdYy0lmJVvQK8lOScbtZbtSFHuX2Ydq3LRT5hcQXwfeB54M+X+gTKPMf+QSaHid8Fnuj+rmDyvXg78BzwT8ApSz3WY1i33wK2dvffDzwM7Aa+Crxrqcc3j/VYD+zsttE/ACePefsAnwGeAZ4E/g5411Dbxyv3pAZ5ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlB/wsPrvc22jHaHAAAAABJRU5ErkJggg==\n",
198 | "text/plain": [
199 | ""
200 | ]
201 | },
202 | "metadata": {},
203 | "output_type": "display_data"
204 | }
205 | ],
206 | "source": [
207 | "processed_prev_state = preprocess_image(prev_state)\n",
208 | "plt.imshow(processed_prev_state,cmap='gray')"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | " Build Model "
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 8,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "Network(\n",
228 | " (conv_layer1): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))\n",
229 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
230 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
231 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n",
232 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
233 | " (relu): ReLU()\n",
234 | ")\n"
235 | ]
236 | }
237 | ],
238 | "source": [
239 | "import torch.nn as nn\n",
240 | "import torch\n",
241 | "\n",
242 | "class Network(nn.Module):\n",
243 | " \n",
244 | " def __init__(self,image_input_size,out_size):\n",
245 | " super(Network,self).__init__()\n",
246 | " self.image_input_size = image_input_size\n",
247 | " self.out_size = out_size\n",
248 | "\n",
249 | " self.conv_layer1 = nn.Conv2d(in_channels=1,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n",
250 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n",
251 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n",
252 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n",
253 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
254 | " self.relu = nn.ReLU()\n",
255 | "\n",
256 | " def forward(self,x,bsize):\n",
257 | " x = x.view(bsize,1,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
258 | " conv_out = self.conv_layer1(x)\n",
259 | " conv_out = self.relu(conv_out)\n",
260 | " conv_out = self.conv_layer2(conv_out)\n",
261 | " conv_out = self.relu(conv_out)\n",
262 | " conv_out = self.conv_layer3(conv_out)\n",
263 | " conv_out = self.relu(conv_out)\n",
264 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n",
265 | " out = self.relu(out)\n",
266 | " out = self.fc2(out)\n",
267 | " return out\n",
268 | "\n",
269 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE)\n",
270 | "print(main_model)"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | " Deep Q Learning with Target Freeze "
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [
285 | {
286 | "name": "stdout",
287 | "output_type": "stream",
288 | "text": [
289 | "Populated 200 Samples\n"
290 | ]
291 | }
292 | ],
293 | "source": [
294 | "mem = Memory(memsize=MEMORY_SIZE)\n",
295 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n",
296 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n",
297 | "\n",
298 | "target_model.load_state_dict(main_model.state_dict())\n",
299 | "criterion = nn.SmoothL1Loss()\n",
300 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
301 | "\n",
302 | "# filling memory with transitions\n",
303 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
304 | " \n",
305 | " prev_state = env.reset()\n",
306 | " processed_prev_state = preprocess_image(prev_state)\n",
307 | " step_count = 0\n",
308 | " game_over = False\n",
309 | " \n",
310 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
311 | " \n",
312 | " step_count +=1\n",
313 | " action = np.random.randint(0,4)\n",
314 | " next_state,reward, game_over = env.step(action)\n",
315 | " processed_next_state = preprocess_image(next_state)\n",
316 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n",
317 | " \n",
318 | " prev_state = next_state\n",
319 | " processed_prev_state = processed_next_state\n",
320 | "\n",
321 | "print('Populated %d Samples'%(len(mem.memory)))\n",
322 | "\n",
323 | "# Algorithm Starts\n",
324 | "total_steps = 0\n",
325 | "epsilon = INITIAL_EPSILON\n",
326 | "loss_stat = []\n",
327 | "total_reward_stat = []\n",
328 | "\n",
329 | "for episode in range(0,TOTAL_EPISODES):\n",
330 | " \n",
331 | " prev_state = env.reset()\n",
332 | " processed_prev_state = preprocess_image(prev_state)\n",
333 | " game_over = False\n",
334 | " step_count = 0\n",
335 | " total_reward = 0\n",
336 | " \n",
337 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
338 | " \n",
339 | " step_count +=1\n",
340 | " total_steps +=1\n",
341 | " \n",
342 | " if np.random.rand() <= epsilon:\n",
343 | " action = np.random.randint(0,4)\n",
344 | " else:\n",
345 | " with torch.no_grad():\n",
346 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n",
347 | "\n",
348 | " model_out = main_model.forward(torch_x,bsize=1)\n",
349 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
350 | " \n",
351 | " next_state, reward, game_over = env.step(action)\n",
352 | " processed_next_state = preprocess_image(next_state)\n",
353 | " total_reward += reward\n",
354 | " \n",
355 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n",
356 | " \n",
357 | " prev_state = next_state\n",
358 | " processed_prev_state = processed_next_state\n",
359 | " \n",
360 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
361 | " target_model.load_state_dict(main_model.state_dict())\n",
362 | " \n",
363 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
364 | " current_states = []\n",
365 | " next_states = []\n",
366 | " acts = []\n",
367 | " rewards = []\n",
368 | " game_status = []\n",
369 | " \n",
370 | " for element in batch:\n",
371 | " current_states.append(element[0])\n",
372 | " acts.append(element[1])\n",
373 | " rewards.append(element[2])\n",
374 | " next_states.append(element[3])\n",
375 | " game_status.append(element[4])\n",
376 | " \n",
377 | " current_states = np.array(current_states)\n",
378 | " next_states = np.array(next_states)\n",
379 | " rewards = np.array(rewards)\n",
380 | " game_status = [not b for b in game_status]\n",
381 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n",
382 | " torch_acts = torch.tensor(acts)\n",
383 | " \n",
384 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
385 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
386 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
387 | " Q_max_next = Q_max_next.double()\n",
388 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
389 | " \n",
390 | " target_values = (rewards + (GAMMA * Q_max_next)).cuda()\n",
391 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
392 | " \n",
393 | " loss = criterion(Q_s_a,target_values.float())\n",
394 | " \n",
395 | " # save performance measure\n",
396 | " loss_stat.append(loss.item())\n",
397 | " \n",
398 | " # make previous grad zero\n",
399 | " optimizer.zero_grad()\n",
400 | " \n",
401 | " # back-propagate \n",
402 | " loss.backward()\n",
403 | " \n",
404 | " # update params\n",
405 | " optimizer.step()\n",
406 | " \n",
407 | " # save performance measure\n",
408 | " total_reward_stat.append(total_reward)\n",
409 | " \n",
410 | " if epsilon > FINAL_EPSILON:\n",
411 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
412 | " \n",
413 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
414 | " perf = {}\n",
415 | " perf['loss'] = loss_stat\n",
416 | " perf['total_reward'] = total_reward_stat\n",
417 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n",
418 | " \n",
419 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)"
420 | ]
421 | },
422 | {
423 | "cell_type": "markdown",
424 | "metadata": {},
425 | "source": [
426 | " Save Primary Network Weights "
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 18,
432 | "metadata": {
433 | "collapsed": true
434 | },
435 | "outputs": [],
436 | "source": [
437 | "torch.save(main_model.state_dict(),'data/MDP_ENV_SIZE_NINE_WEIGHTS.torch')"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {},
443 | "source": [
444 | " Testing Policy "
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "metadata": {},
450 | "source": [
451 | " Load Primary Network Weights "
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": 9,
457 | "metadata": {},
458 | "outputs": [],
459 | "source": [
460 | "weights = torch.load('data/MDP_ENV_SIZE_NINE_WEIGHTS.torch', map_location='cpu')\n",
461 | "main_model.load_state_dict(weights)"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | " Test Policy "
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "# Algorithm Starts\n",
478 | "epsilon = INITIAL_EPSILON\n",
479 | "FINAL_EPSILON = 0.01\n",
480 | "total_reward_stat = []\n",
481 | "\n",
482 | "for episode in range(0,TOTAL_EPISODES):\n",
483 | " \n",
484 | " prev_state = env.reset()\n",
485 | " processed_prev_state = preprocess_image(prev_state)\n",
486 | " game_over = False\n",
487 | " step_count = 0\n",
488 | " total_reward = 0\n",
489 | " \n",
490 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
491 | " \n",
492 | " step_count +=1\n",
493 | " \n",
494 | " if np.random.rand() <= epsilon:\n",
495 | " action = np.random.randint(0,4)\n",
496 | " else:\n",
497 | " with torch.no_grad():\n",
498 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n",
499 | "\n",
500 | " model_out = main_model.forward(torch_x,bsize=1)\n",
501 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
502 | " \n",
503 | " next_state, reward, game_over = env.step(action)\n",
504 | " processed_next_state = preprocess_image(next_state)\n",
505 | " total_reward += reward\n",
506 | " \n",
507 | " prev_state = next_state\n",
508 | " processed_prev_state = processed_next_state\n",
509 | " \n",
510 | " # save performance measure\n",
511 | " total_reward_stat.append(total_reward)\n",
512 | " \n",
513 | " if epsilon > FINAL_EPSILON:\n",
514 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
515 | " \n",
516 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
517 | " perf = {}\n",
518 | " perf['total_reward'] = total_reward_stat\n",
519 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n",
520 | " \n",
521 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "metadata": {},
527 | "source": [
528 | " Create Policy GIF "
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {},
534 | "source": [
535 | " Collect Frames Of an Episode Using Trained Network "
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 28,
541 | "metadata": {},
542 | "outputs": [
543 | {
544 | "name": "stdout",
545 | "output_type": "stream",
546 | "text": [
547 | "Total Reward : 14\n"
548 | ]
549 | }
550 | ],
551 | "source": [
552 | "frames = []\n",
553 | "random.seed(110)\n",
554 | "np.random.seed(110)\n",
555 | "\n",
556 | "for episode in range(0,1):\n",
557 | " \n",
558 | " prev_state = env.reset()\n",
559 | " processed_prev_state = preprocess_image(prev_state)\n",
560 | " frames.append(prev_state)\n",
561 | " game_over = False\n",
562 | " step_count = 0\n",
563 | " total_reward = 0\n",
564 | " \n",
565 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
566 | " \n",
567 | " step_count +=1\n",
568 | " \n",
569 | " with torch.no_grad():\n",
570 | " torch_x = torch.from_numpy(processed_prev_state).float()\n",
571 | " model_out = main_model.forward(torch_x,bsize=1)\n",
572 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
573 | " \n",
574 | " next_state, reward, game_over = env.step(action)\n",
575 | " frames.append(next_state)\n",
576 | " processed_next_state = preprocess_image(next_state)\n",
577 | " total_reward += reward\n",
578 | " \n",
579 | " prev_state = next_state\n",
580 | " processed_prev_state = processed_next_state\n",
581 | "\n",
582 | "print('Total Reward : %d'%(total_reward)) # This should output the same value on every run, which verifies the seed is working correctly\n",
583 | " "
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 29,
589 | "metadata": {},
590 | "outputs": [],
591 | "source": [
592 | "from PIL import Image, ImageDraw\n",
593 | "\n",
594 | "for idx, img in enumerate(frames):\n",
595 | " image = Image.fromarray(img)\n",
596 | " drawer = ImageDraw.Draw(image)\n",
597 | " drawer.rectangle([(7,7),(76,76)], outline=(255, 255, 0))\n",
598 | " #plt.imshow(np.array(image))\n",
599 | " frames[idx] = np.array(image)\n",
600 | " "
601 | ]
602 | },
603 | {
604 | "cell_type": "markdown",
605 | "metadata": {},
606 | "source": [
607 | " Frames to GIF "
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 30,
613 | "metadata": {},
614 | "outputs": [
615 | {
616 | "name": "stderr",
617 | "output_type": "stream",
618 | "text": [
619 | "/home/mayank/miniconda3/envs/rdqn/lib/python3.5/site-packages/ipykernel_launcher.py:5: DeprecationWarning: `imresize` is deprecated!\n",
620 | "`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
621 | "Use ``skimage.transform.resize`` instead.\n",
622 | " \"\"\"\n"
623 | ]
624 | }
625 | ],
626 | "source": [
627 | "import imageio\n",
628 | "from scipy.misc import imresize\n",
629 | "resized_frames = []\n",
630 | "for frame in frames:\n",
631 | " resized_frames.append(imresize(frame,(256,256)))\n",
632 | "imageio.mimsave('data/GIFs/MDP_SIZE_9.gif',resized_frames,fps=4)"
633 | ]
634 | }
635 | ],
636 | "metadata": {
637 | "kernelspec": {
638 | "display_name": "Python 3",
639 | "language": "python",
640 | "name": "python3"
641 | },
642 | "language_info": {
643 | "codemirror_mode": {
644 | "name": "ipython",
645 | "version": 3
646 | },
647 | "file_extension": ".py",
648 | "mimetype": "text/x-python",
649 | "name": "python",
650 | "nbconvert_exporter": "python",
651 | "pygments_lexer": "ipython3",
652 | "version": "3.5.6"
653 | }
654 | },
655 | "nbformat": 4,
656 | "nbformat_minor": 2
657 | }
658 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/Two Observations-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from gridworld import gameEnv\n",
12 | "import numpy as np\n",
13 | "%matplotlib inline\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "from collections import deque\n",
16 | "import pickle\n",
17 | "from skimage.color import rgb2gray\n",
18 | "import random\n",
19 | "import torch\n",
20 | "import torch.nn as nn"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | " Define Environment Object "
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "data": {
37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADKdJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUOtFERUtYBbQiLoRYRAURVVSNykLTSREmivIvUikaokXFSRECRFFeVPCDTIikipQ1RVqhyOgSZgQ2wIBFuA3RRKSqW2Tt5ezDg9ce2cOT67e87w+36k1e7M7Gp+o/GzMzue876pKiS15RdWegCSZs/gSw0y+FKDDL7UIIMvNcjgSw0y+FKDlhX8JFcmeS7J3iS3TGpQkqYrx3sDT5I1wPeArcA+4HHg+qraNbnhSZqGtcv47PuAvVX1AkCSe4FrgGMG//TTT6+5ubllrPLtb+fOnSs9BI1cVWWx9ywn+GcCLy+Y3gf85s/7wNzcHPPz88tY5dtfsug+k5Zt6hf3ktyYZD7J/MGDB6e9OkkDLCf4+4GzFkxv7Of9jKq6vaq2VNWWM844YxmrkzQpywn+48A5STYlORG4Dnh4MsOSNE3H/Ru/qg4l+SPgG8Aa4EtV9czERiZpapZzcY+q+jrw9QmNRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBFg5/kS0kOJHl6wbx1SR5Nsqd/PnW6w5Q0SUOO+H8NXHnEvFuA7VV1DrC9n5Y0EosGv6r+Efi3I2ZfA9zVv74L+P0Jj0vSFB3vb/z1VfVK//pVYP2ExiNpBpZ9ca+6rpvH7LxpJx1p9Tne4L+WZANA/3zgWG+0k460+hxv8B8GPtq//ijwtckMR9IsDPnvvHuAfwbOTbIvyQ3AZ4CtSfYAV/TTkkZi0U46VXX9MRZdPuGxSJoR79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGjSk9NZZSR5LsivJM0lu6ufbTUcaqSFH/EPAJ6rqfOBi4GNJzsduOtJoDemk80pVPdG//hGwGzgTu+lIo7Wk3/hJ5oALgR0M7KZjQw1p9Rkc/CTvBr4K3FxVby5c9vO66dhQQ1p9BgU/yQl0ob+7qh7sZw/upiNpdRlyVT/AncDuqvrcgkV205FGatGGGsClwB8A303yVD/vz+i659zfd9Z5Cbh2OkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0a8vf4mqmjVjDTTx3rL8S1FB7xpQYZfKlBQ2runZTk20n+pe+k8+l+/qYkO5LsTXJfkhOnP1xJkzDkiP9fwGVVdQGwGbgyycXAZ4HPV9V7gdeBG6Y3TEmTNKSTTlXVf/STJ/SPAi4DHujn20lHGpGhdfXX9BV2DwCPAs8Db1TVof4t++jaah3ts3bSkVaZQcGvqh9X1WZgI/A+4LyhK7CTjrT6LOmqflW9ATwGXAKckuTwfQAbgf0THpukKRlyVf+MJKf0r98JbKXrmPsY8OH+bXbSkUZkyJ17G4C7kqyh+6K4v6q2JdkF3JvkL4An6dpsSRqBIZ10vkPXGvvI+S/Q/d6XNDLeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aHDw+xLbTybZ1k/bSUcaqaUc8W+iK7J5mJ10pJEa2lBjI/BB4I5+OthJRxqtoUf8LwCfBH7ST5+GnXSk0RpSV/9DwIGq2nk8K7CTjrT6DKmrfylwd
ZKrgJOAXwJuo++k0x/17aQjjciQbrm3VtXGqpoDrgO+WVUfwU460mgt5//xPwV8PMleut/8dtKRRmLIqf5PVdW3gG/1r+2kI42Ud+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aFAhjiQvAj8CfgwcqqotSdYB9wFzwIvAtVX1+nSGKWmSlnLE/52q2lxVW/rpW4DtVXUOsL2fljQCyznVv4aukQbYUEMalaHBL+Dvk+xMcmM/b31VvdK/fhVYP/HRSZqKocU2319V+5P8MvBokmcXLqyqSlJH+2D/RXEjwNlnn72swUqajEFH/Kra3z8fAB6iq677WpINAP3zgWN81k460iozpIXWyUl+8fBr4HeBp4GH6RppgA01pFEZcqq/Hnioa5DLWuBvq+qRJI8D9ye5AXgJuHZ6w5Q0SYsGv2+cccFR5v8QuHwag5I0Xd65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVo6F/naWay0gNQAzziSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEHBT3JKkgeSPJtkd5JLkqxL8miSPf3zqdMerKTJGHrEvw14pKrOoyvDtRs76UijNaTK7nuA3wLuBKiq/66qN7CTjjRaQ474m4CDwJeTPJnkjr7Mtp10pJEaEvy1wEXAF6vqQuAtjjitr6qia7P1/yS5Mcl8kvmDBw8ud7ySJmBI8PcB+6pqRz/9AN0XgZ10pJFaNPhV9SrwcpJz+1mXA7uwk440WkP/LPePgbuTnAi8APwh3ZeGnXSkERoU/Kp6CthylEV20pFGyDv3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYNqat/bpKnFjzeTHKznXSk8RpSbPO5qtpcVZuB3wD+E3gIO+lIo7XUU/3Lgeer6iXspCON1lKDfx1wT//aTjrSSA0Ofl9a+2rgK0cus5OONC5LOeJ/AHiiql7rp+2kI43UUoJ/Pf93mg920pFGa1Dw++64W4EHF8z+DLA1yR7gin5a0ggM7aTzFnDaEfN+iJ10pFHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQUNLb/1pkmeSPJ3kniQnJdmUZEeSvUnu66vwShqBIS20zgT+BNhSVb8OrKGrr/9Z4PNV9V7gdeCGaQ5U0uQMPdVfC7wzyVrgXcArwGXAA/1yO+lIIzKkd95+4C+BH9AF/t+BncAbVXWof9s+4MxpDVLSZA051T+Vrk/eJuBXgJOBK4euwE460uoz5FT/CuD7VXWwqv6Hrrb+pcAp/ak/wEZg/9E+bCcdafUZEvwfABcneVeS0NXS3wU8Bny4f4+ddKQRGfIbfwfdRbwngO/2n7kd+BTw8SR76Zpt3DnFcUqaoHSNbmdjy5YtNT8/P7P1jVF3UiUdv6pa9B+Rd+5JDTL4UoMMvtQggy81aKYX95IcBN4C/nVmK52+03F7Vqu307bAsO351apa9IaZmQYfIMl8VW2Z6UqnyO1Zvd5O2wKT3R5P9aUGGXypQSsR/NtXYJ3T5PasXm+nbYEJbs/Mf+NLWnme6ksNmmnwk1yZ5Lm+Tt8ts1z3ciU5K8ljSXb19Qdv6uevS/Jokj3986krPdalSLImyZNJtvXTo62lmOSUJA8keTbJ7iSXjHn/TLPW5cyCn2QN8FfAB4DzgeuTnD+r9U/AIeATVXU+cDHwsX78twDbq+ocYHs/PSY3AbsXTI+5luJtwCNVdR5wAd12jXL/TL3WZVXN5AFcAnxjwfStwK2zWv8UtudrwFbgOWBDP28D8NxKj20J27CRLgyXAduA0N0gsvZo+2w1P4D3AN+nv261YP4o9w9dKbuXgXV0NS+3Ab83q
f0zy1P9wxty2Gjr9CWZAy4EdgDrq+qVftGrwPoVGtbx+ALwSeAn/fRpjLeW4ibgIPDl/qfLHUlOZqT7p6Zc69KLe0uU5N3AV4Gbq+rNhcuq+xoexX+TJPkQcKCqdq70WCZkLXAR8MWqupDu1vCfOa0f2f5ZVq3Lxcwy+PuBsxZMH7NO32qV5AS60N9dVQ/2s19LsqFfvgE4sFLjW6JLgauTvAjcS3e6fxsDaymuQvuAfdVVjIKuatRFjHf/LKvW5WJmGfzHgXP6q5In0l2oeHiG61+Wvt7gncDuqvrcgkUP09UchBHVHqyqW6tqY1XN0e2Lb1bVRxhpLcWqehV4Ocm5/azDtSFHuX+Ydq3LGV+wuAr4HvA88OcrfQFliWN/P91p4neAp/rHVXS/i7cDe4B/ANat9FiPY9t+G9jWv/414NvAXuArwDtWenxL2I7NwHy/j/4OOHXM+wf4NPAs8DTwN8A7JrV/vHNPapAX96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxr0v95G4v3/3GYuAAAAAElFTkSuQmCC\n",
38 | "text/plain": [
39 | ""
40 | ]
41 | },
42 | "metadata": {},
43 | "output_type": "display_data"
44 | }
45 | ],
46 | "source": [
47 | "env = gameEnv(partial=True,size=9)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "execution_count": 3,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | },
65 | {
66 | "data": {
67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB6
5NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k
9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n",
68 | "text/plain": [
69 | ""
70 | ]
71 | },
72 | "metadata": {},
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {
99 | "collapsed": true
100 | },
101 | "outputs": [],
102 | "source": [
103 | "BATCH_SIZE = 32\n",
104 | "FREEZE_INTERVAL = 20000 # steps\n",
105 | "MEMORY_SIZE = 60000 \n",
106 | "OUTPUT_SIZE = 4\n",
107 | "TOTAL_EPISODES = 10000\n",
108 | "MAX_STEPS = 50\n",
109 | "INITIAL_EPSILON = 1.0\n",
110 | "FINAL_EPSILON = 0.1\n",
111 | "GAMMA = 0.99\n",
112 | "INPUT_IMAGE_DIM = 84\n",
113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | " Save Dictionary Function "
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 6,
126 | "metadata": {
127 | "collapsed": true
128 | },
129 | "outputs": [],
130 | "source": [
131 | "def save_obj(obj, name ):\n",
132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | " Experience Replay "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 7,
146 | "metadata": {
147 | "collapsed": true
148 | },
149 | "outputs": [],
150 | "source": [
151 | "class Memory():\n",
152 | " \n",
153 | " def __init__(self,memsize):\n",
154 | " self.memsize = memsize\n",
155 | " self.memory = deque(maxlen=self.memsize)\n",
156 | " \n",
157 | " def add_sample(self,sample):\n",
158 | " self.memory.append(sample)\n",
159 | " \n",
160 | " def get_batch(self,size):\n",
161 | " return random.sample(self.memory,k=size)"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | " Frame Collector "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 5,
174 | "metadata": {
175 | "collapsed": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "class FrameCollector():\n",
180 | " \n",
181 | " def __init__(self,num_frames,img_dim):\n",
182 | " self.num_frames = num_frames\n",
183 | " self.img_dim = img_dim\n",
184 | " self.frames = deque(maxlen=self.num_frames)\n",
185 | " \n",
186 | " def reset(self):\n",
187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n",
188 | " for i in range(0,self.num_frames):\n",
189 | " self.frames.append(tmp)\n",
190 | " \n",
191 | " def add_frame(self,frame):\n",
192 | " self.frames.append(frame)\n",
193 | " \n",
194 | " def get_state(self):\n",
195 | " return np.array(self.frames)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | " Preprocess Images "
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 6,
208 | "metadata": {
209 | "collapsed": true
210 | },
211 | "outputs": [],
212 | "source": [
213 | "def preprocess_image(image):\n",
214 | "    image = rgb2gray(image) # rgb2gray also scales pixel values to the range 0 - 1\n",
215 | " return np.copy(image)"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 7,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | ""
227 | ]
228 | },
229 | "execution_count": 7,
230 | "metadata": {},
231 | "output_type": "execute_result"
232 | },
233 | {
234 | "data": {
235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3V2MXPV5x/Hvr14ICUljm7SWi0kxigVCVTGRlYLggpLSEhpBLqIUlEhpldY3qUraSsG0Fy2VIiVSlYSLKpIFSVGV8hKHJhYXSV2HpL1ysDFtwcbBJBBs+YUKyNsFqsPTizluF7p4zu7O7O7h//1Iq5lz5uX8j45+c15m9nlSVUhqyy8s9wAkLT2DLzXI4EsNMvhSgwy+1CCDLzXI4EsNWlTwk1yf5FCSw0m2TWpQkqYrC/0BT5JVwPeA64AjwCPALVV1YHLDkzQNM4t47XuAw1X1fYAk9wE3Aa8b/CT+TFCasqrKuOcs5lD/fOC5WdNHunmSVrjF7PF7SbIV2Drt5UjqbzHBPwpcMGt6QzfvVapqO7AdPNSXVorFHOo/AmxKsjHJ2cDNwM7JDEvSNC14j19Vp5L8MfBNYBXwxap6YmIjkzQ1C/46b0EL81BfmrppX9WXNFAGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFjg5/ki0lOJnl81ry1SXYleaq7XTPdYUqapD57/L8Hrn/NvG3A7qraBOzupiUNxNjgV9W/Ai+8ZvZNwD3d/XuAD0x4XJKmaKHn+Ouq6lh3/ziwbkLjkbQEFt1Jp6rqTNVz7aQjrTwL3eOfSLIeoLs9+XpPrKrtVbWlqrYscFmSJmyhwd8JfLS7/1Hg65MZjqSlMLahRpJ7gWuAdwAngL8CvgY8ALwTeBb4UFW99gLgXO9lQw1pyvo01LCTjvQGYycdSXMy+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3q00nngiQPJzmQ5Ikkt3bz7aYjDVSfmnvrgfVV9WiStwH7GDXQ+H3ghar6dJJtwJqqum3Me1l6S5qyiZTeqqpjVfVod/8nwEHgfOymIw3WvBpqJLkQuBzYQ89uOjbUkFae3lV2k7wV+A7wqap6MMlLVbV61uMvVtUZz/M91Jemb2JVdpOcBXwV+HJVPdjN7t1NR9LK0ueqfoC7gYNV9dlZD9lNRxqoPlf1rwb+DfhP4JVu9l8wOs+fVzcdD/Wl6bOTjtQgO+lImpPBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxo0r5p7mr6l/DfpIRrVhdFiuceXGmTwpQb1qbl3TpLvJvn3rpPOHd38jUn2JDmc5P4kZ09/uJImoc8e/2Xg2qq6DNgMXJ/kCuAzwOeq6l3Ai8DHpjdMSZPUp5NOVdVPu8mzur8CrgV2dPPtpCMNSN+6+quSPMaodv4u4Gngpao61T3lCKO2WnO9dmuSvUn2TmLAkhavV/Cr6udVtRnYALwHuKTvAqpqe1VtqaotCxyjpAmb11X9qnoJeBi4Elid5PTvADYARyc8NklT0ueq/i8lWd3dfzNwHaOOuQ8DH+yeZicdaUD6dNL5dUYX71Yx+qB4oKr+JslFwH3AWmA/8JGqennMe/mztDH85d6Z+cu98eykM0AG/8wM/nh20pE0J4MvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoN7B70ps70/yUDdtJx1poOazx7+VUZHN0+ykIw1U34YaG4DfBe7qpoOddKTB6rvH/zzwSeCVbvo87KQjDVafuvrvB05W1b6FLMBOOtLKMzP+KVwF
3JjkBuAc4BeBO+k66XR7fTvpSAPSp1vu7VW1oaouBG4GvlVVH8ZOOtJgLeZ7/NuAP0tymNE5/92TGZKkabOTzgpjJ50zs5POeHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoP61NwjyTPAT4CfA6eqakuStcD9wIXAM8CHqurF6QxT0iTNZ4//m1W1eVa13G3A7qraBOzupiUNwGIO9W9i1EgDbKghDUrf4Bfwz0n2JdnazVtXVce6+8eBdRMfnaSp6HWOD1xdVUeT/DKwK8mTsx+sqnq9QprdB8XWuR6TtDzmXWU3yV8DPwX+CLimqo4lWQ98u6ouHvNaS8iOYZXdM7PK7ngTqbKb5Nwkbzt9H/ht4HFgJ6NGGmBDDWlQxu7xk1wE/FM3OQP8Y1V9Ksl5wAPAO4FnGX2d98KY93J3NoZ7/DNzjz9enz2+DTVWGIN/ZgZ/PBtqSJqTwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ3//O0xLxl2laCu7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mdZEeSJ5McTHJlkrVJdiV5qrtdM+3BSpqMvnv8O4FvVNUlwGXAQeykIw1Wn2KbbwceAy6qWU9OcgjLa0srzqRq7m0Enge+lGR/kru6Mtt20pEGqk/wZ4B3A1+oqsuBn/Gaw/ruSOB1O+kk2Ztk72IHK2ky+gT/CHCkqvZ00zsYfRCc6A7x6W5PzvXiqtpeVVtmddmVtMzGBr+qjgPPJTl9/v5e4AB20pEGq1dDjSSbgbuAs4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9ycZLHZv39OMkn7KQjDde8Sm8lWQUcBX4D+DjwQlV9Osk2YE1V3Tbm9ZbekqZsGqW33gs8XVXPAjcB93Tz7wE+MM/3krRM5hv8m4F7u/t20pEGqnfwk5wN3Ah85bWP2UlHGpb57PHfBzxaVSe6aTvpSAM1n+Dfwv8d5oOddKTB6ttJ51zgh4xaZf+om3cedtKRVhw76UgNspOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb1Cn6SP03yRJLHk9yb5JwkG5PsSXI4yf1dFV5JA9Cnhdb5wJ8AW6rq14BVjOrrfwb4XFW9C3gR+Ng0Byppcvoe6s8Ab04yA7wFOAZcC+zoHreTjjQgY4NfVUeBv2VUZfcY8CNgH/BSVZ3qnnYEOH9ag5Q0WX0O9dcw6pO3EfgV4Fzg+r4LsJOOtPLM9HjObwE/qKrnAZI8CFwFrE4y0+31NzDqovv/VNV2YHv3WstrSytAn3P8HwJXJHlLkjDqmHsAeBj4YPccO+lIA9K3k84dwO8Bp4D9wB8yOqe/D1jbzftIVb085n3c40tTZicdqUF20pE0J4MvNcjgSw0y+FKD+nyPP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akr1VtWVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7gb1+GZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOf5Pokh7o6fduWctmLleSCJA8nOdDVH7y1m782ya4kT3W3a5Z7rPORZFWS/Uke6qYHW0sxyeokO5I8meRgkiuHvH2mWetyyYKfZBXwd8D7gEuBW5JculTLn4BTwJ9X1aXAFcDHu/FvA3ZX1SZgdzc9JLcCB2dND7mW4p3AN6rqEuAyRus1yO0z9VqXVbUkf8CVwDdnTd8O3L5Uy5/C+nwduA44BKzv5q0HDi332OaxDhsYheFa4CEgjH4gMjPXNlvJf8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7
k9o+S3mof3pFThtsnb4kFwKXA3uAdVV1rHvoOLBumYa1EJ8HPgm80k2fx3BrKW4Enge+1J263JXkXAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsTXJEneD5ysqn3LPZYJmQHeDXyhqi5n9NPwVx3WD2z7LKrW5ThLGfyjwAWzpl+3Tt9KleQsRqH/clU92M0+kWR99/h64ORyjW+ergJuTPIMo0pK1zI6R17dlVGHYW2jI8CRqtrTTe9g9EEw1O3zv7Uuq+q/gVfVuuyes+Dts5TBfwTY1F2VPJvRhYqdS7j8RenqDd4NHKyqz856aCejmoMwoNqDVXV7VW2oqgsZbYtvVdWHGWgtxao6DjyX5OJu1unakIPcPky71uUSX7C4Afge8DTwl8t9AWWeY7+a0WHifwCPdX83MDov3g08BfwLsHa5x7qAdbsGeKi7fxHwXeAw8BXgTcs9vnmsx2Zgb7eNvgasGfL2Ae4AngQeB/4BeNOkto+/3JMa5MU9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBv0PANIhuBFMr1IAAAAASUVORK5CYII=\n",
236 | "text/plain": [
237 | ""
238 | ]
239 | },
240 | "metadata": {},
241 | "output_type": "display_data"
242 | }
243 | ],
244 | "source": [
245 | "processed_prev_state = preprocess_image(prev_state)\n",
246 | "plt.imshow(processed_prev_state,cmap='gray')"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | " Build Model "
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 8,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "Network(\n",
266 | " (conv_layer1): Conv2d(2, 32, kernel_size=(8, 8), stride=(4, 4))\n",
267 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
268 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
269 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n",
270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
271 | " (relu): ReLU()\n",
272 | ")\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "import torch.nn as nn\n",
278 | "import torch\n",
279 | "\n",
280 | "class Network(nn.Module):\n",
281 | " \n",
282 | " def __init__(self,image_input_size,out_size):\n",
283 | " super(Network,self).__init__()\n",
284 | " self.image_input_size = image_input_size\n",
285 | " self.out_size = out_size\n",
286 | "\n",
287 | " self.conv_layer1 = nn.Conv2d(in_channels=2,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n",
288 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n",
289 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n",
290 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n",
291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
292 | " self.relu = nn.ReLU()\n",
293 | "\n",
294 | " def forward(self,x,bsize):\n",
295 | " x = x.view(bsize,2,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
296 | " conv_out = self.conv_layer1(x)\n",
297 | " conv_out = self.relu(conv_out)\n",
298 | " conv_out = self.conv_layer2(conv_out)\n",
299 | " conv_out = self.relu(conv_out)\n",
300 | " conv_out = self.conv_layer3(conv_out)\n",
301 | " conv_out = self.relu(conv_out)\n",
302 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n",
303 | " out = self.relu(out)\n",
304 | " out = self.fc2(out)\n",
305 | " return out\n",
306 | "\n",
307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n",
308 | "print(main_model)"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {},
314 | "source": [
315 | " Deep Q Learning with Freeze Network "
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "Populated 60000 Samples in Episodes : 1200\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "mem = Memory(memsize=MEMORY_SIZE)\n",
333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n",
334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n",
335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=2)\n",
336 | "\n",
337 | "target_model.load_state_dict(main_model.state_dict())\n",
338 | "criterion = nn.SmoothL1Loss()\n",
339 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
340 | "\n",
341 | "# filling memory with transitions\n",
342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
343 | " \n",
344 | " prev_state = env.reset()\n",
345 | " frameObj.reset()\n",
346 | " processed_prev_state = preprocess_image(prev_state)\n",
347 | " frameObj.add_frame(processed_prev_state)\n",
348 | " prev_frames = frameObj.get_state()\n",
349 | " step_count = 0\n",
350 | " game_over = False\n",
351 | " \n",
352 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
353 | " \n",
354 | " step_count +=1\n",
355 | " action = np.random.randint(0,4)\n",
356 | " next_state,reward, game_over = env.step(action)\n",
357 | " processed_next_state = preprocess_image(next_state)\n",
358 | " frameObj.add_frame(processed_next_state)\n",
359 | " next_frames = frameObj.get_state()\n",
360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
361 | " \n",
362 | " prev_state = next_state\n",
363 | " processed_prev_state = processed_next_state\n",
364 | " prev_frames = next_frames\n",
365 | "\n",
366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n",
367 | "\n",
368 | "\n",
369 | "# Algorithm Starts\n",
370 | "total_steps = 0\n",
371 | "epsilon = INITIAL_EPSILON\n",
372 | "loss_stat = []\n",
373 | "total_reward_stat = []\n",
374 | "\n",
375 | "for episode in range(0,TOTAL_EPISODES):\n",
376 | " \n",
377 | " prev_state = env.reset()\n",
378 | " frameObj.reset()\n",
379 | " processed_prev_state = preprocess_image(prev_state)\n",
380 | " frameObj.add_frame(processed_prev_state)\n",
381 | " prev_frames = frameObj.get_state()\n",
382 | " game_over = False\n",
383 | " step_count = 0\n",
384 | " total_reward = 0\n",
385 | " \n",
386 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
387 | " \n",
388 | " step_count +=1\n",
389 | " total_steps +=1\n",
390 | " \n",
391 | " if np.random.rand() <= epsilon:\n",
392 | " action = np.random.randint(0,4)\n",
393 | " else:\n",
394 | " with torch.no_grad():\n",
395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
396 | "\n",
397 | " model_out = main_model.forward(torch_x,bsize=1)\n",
398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
399 | " \n",
400 | " next_state, reward, game_over = env.step(action)\n",
401 | " processed_next_state = preprocess_image(next_state)\n",
402 | " frameObj.add_frame(processed_next_state)\n",
403 | " next_frames = frameObj.get_state()\n",
404 | " total_reward += reward\n",
405 | " \n",
406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
407 | " \n",
408 | " prev_state = next_state\n",
409 | " processed_prev_state = processed_next_state\n",
410 | " prev_frames = next_frames\n",
411 | " \n",
412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
413 | " target_model.load_state_dict(main_model.state_dict())\n",
414 | " \n",
415 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
416 | " current_states = []\n",
417 | " next_states = []\n",
418 | " acts = []\n",
419 | " rewards = []\n",
420 | " game_status = []\n",
421 | " \n",
422 | " for element in batch:\n",
423 | " current_states.append(element[0])\n",
424 | " acts.append(element[1])\n",
425 | " rewards.append(element[2])\n",
426 | " next_states.append(element[3])\n",
427 | " game_status.append(element[4])\n",
428 | " \n",
429 | " current_states = np.array(current_states)\n",
430 | " next_states = np.array(next_states)\n",
431 | " rewards = np.array(rewards)\n",
432 | " game_status = [not b for b in game_status]\n",
433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n",
434 | " torch_acts = torch.tensor(acts)\n",
435 | " \n",
436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
439 | " Q_max_next = Q_max_next.double()\n",
440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
441 | " \n",
442 | " target_values = (rewards + (GAMMA * Q_max_next))\n",
443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
444 | " \n",
445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n",
446 | " \n",
447 | " # save performance measure\n",
448 | " loss_stat.append(loss.item())\n",
449 | " \n",
450 | " # make previous grad zero\n",
451 | " optimizer.zero_grad()\n",
452 | " \n",
453 | "        # back-propagate \n",
454 | " loss.backward()\n",
455 | " \n",
456 | " # update params\n",
457 | " optimizer.step()\n",
458 | " \n",
459 | " # save performance measure\n",
460 | " total_reward_stat.append(total_reward)\n",
461 | " \n",
462 | " if epsilon > FINAL_EPSILON:\n",
463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
464 | " \n",
465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
466 | " perf = {}\n",
467 | " perf['loss'] = loss_stat\n",
468 | " perf['total_reward'] = total_reward_stat\n",
469 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n",
470 | " \n",
471 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | " Save Primary Network Weights "
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 19,
484 | "metadata": {
485 | "collapsed": true
486 | },
487 | "outputs": [],
488 | "source": [
489 | "torch.save(main_model.state_dict(),'data/TWO_OBSERV_NINE_WEIGHTS.torch')"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {},
495 | "source": [
496 | " Testing Policy "
497 | ]
498 | },
499 | {
500 | "cell_type": "markdown",
501 | "metadata": {},
502 | "source": [
503 | " Load Primary Network Weights "
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 10,
509 | "metadata": {},
510 | "outputs": [],
511 | "source": [
512 | "weights = torch.load('data/TWO_OBSERV_NINE_WEIGHTS.torch')\n",
513 | "main_model.load_state_dict(weights)"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | " Testing Policy "
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": null,
526 | "metadata": {
527 | "collapsed": true
528 | },
529 | "outputs": [],
530 | "source": [
531 | "# Algorithm Starts\n",
532 | "epsilon = INITIAL_EPSILON\n",
533 | "FINAL_EPSILON = 0.01\n",
534 | "total_reward_stat = []\n",
535 | "\n",
536 | "for episode in range(0,TOTAL_EPISODES):\n",
537 | " \n",
538 | " prev_state = env.reset()\n",
539 | " processed_prev_state = preprocess_image(prev_state)\n",
540 | " frameObj.reset()\n",
541 | " frameObj.add_frame(processed_prev_state)\n",
542 | " prev_frames = frameObj.get_state()\n",
543 | " game_over = False\n",
544 | " step_count = 0\n",
545 | " total_reward = 0\n",
546 | " \n",
547 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
548 | " \n",
549 | " step_count +=1\n",
550 | " \n",
551 | " if np.random.rand() <= epsilon:\n",
552 | " action = np.random.randint(0,4)\n",
553 | " else:\n",
554 | " with torch.no_grad():\n",
555 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
556 | "\n",
557 | " model_out = main_model.forward(torch_x,bsize=1)\n",
558 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
559 | " \n",
560 | " next_state, reward, game_over = env.step(action)\n",
561 | " processed_next_state = preprocess_image(next_state)\n",
562 | " frameObj.add_frame(processed_next_state)\n",
563 | " next_frames = frameObj.get_state()\n",
564 | " \n",
565 | " total_reward += reward\n",
566 | " \n",
567 | " prev_state = next_state\n",
568 | " processed_prev_state = processed_next_state\n",
569 | " prev_frames = next_frames\n",
570 | " \n",
571 | " # save performance measure\n",
572 | " total_reward_stat.append(total_reward)\n",
573 | " \n",
574 | " if epsilon > FINAL_EPSILON:\n",
575 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
576 | " \n",
577 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
578 | " perf = {}\n",
579 | " perf['total_reward'] = total_reward_stat\n",
580 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n",
581 | " \n",
582 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
583 | ]
584 | }
585 | ],
586 | "metadata": {
587 | "kernelspec": {
588 | "display_name": "Python [conda env:myenv]",
589 | "language": "python",
590 | "name": "conda-env-myenv-py"
591 | },
592 | "language_info": {
593 | "codemirror_mode": {
594 | "name": "ipython",
595 | "version": 3
596 | },
597 | "file_extension": ".py",
598 | "mimetype": "text/x-python",
599 | "name": "python",
600 | "nbconvert_exporter": "python",
601 | "pygments_lexer": "ipython3",
602 | "version": "3.6.5"
603 | }
604 | },
605 | "nbformat": 4,
606 | "nbformat_minor": 2
607 | }
608 |
--------------------------------------------------------------------------------
/Four Observations.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from gridworld import gameEnv\n",
12 | "import numpy as np\n",
13 | "%matplotlib inline\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "from collections import deque\n",
16 | "import pickle\n",
17 | "from skimage.color import rgb2gray\n",
18 | "import random\n",
19 | "import torch\n",
20 | "import torch.nn as nn"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | " Define Environment Object "
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "data": {
37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB6
5NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k
9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n",
38 | "text/plain": [
39 | ""
40 | ]
41 | },
42 | "metadata": {},
43 | "output_type": "display_data"
44 | }
45 | ],
46 | "source": [
47 | "env = gameEnv(partial=True,size=9)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "execution_count": 3,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | },
65 | {
66 | "data": {
67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLlJREFUeJzt3X/oXfV9x/Hna4nW1m41cS5kxs2UiiIDowlOsYxNzWZd0f1RRCmjDMF/uk3XQqvbH6WwP1oYbf1jFETbyXD+qNU1hGLn0pQyGKlff6zVRJtoY01QEzudnYNtad/745xsX0OynG++997v9/h5PuBy7znn3pzP4fC659yT832/U1VIassvLPUAJM2ewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSvJckj1Jbp3UoCRNV070Bp4kK4AfApuBfcBjwA1VtXNyw5M0DSsX8dmLgT1V9QJAkvuAa4FjBj+JtwlqUTZu3LjUQ1jW9u7dy2uvvZbjvW8xwT8TeGne9D7gNxfx70nHNTc3t9RDWNY2bdo06H2LCf4gSW4Cbpr2eiQNt5jg7wfOmje9rp/3NlV1B3AHeKovLReLuar/GHBOkvVJTgauB7ZMZliSpumEj/hVdSjJHwPfAlYAX6mqZyY2MklTs6jf+FX1TeCbExqLpBnxzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQccNfpKvJDmQ5Ol581YneTTJ7v551XSHKWmShhzx/wa46oh5twLbquocYFs/LWkkjhv8qvou8K9HzL4WuLt/fTfwBxMel6QpOtHf+Guq6uX+9SvAmgmNR9IMLLqTTlXV/9cow0460vJzokf8V5OsBeifDxzrjVV1R1VtqqphTb0kTd2JBn8L8LH+9ceAb0xmOJJmYch/590L/DNwbpJ9SW4EPgdsTrIbuLKfljQSx/2NX1U3HGPRFRMei6QZ8c49qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFDSm+dlWR7kp1Jnklycz/fbjrSSA054h8CPllV5wOXAB9Pcj5205FGa0gnnZer6on+9U+BXcCZ2E1HGq0FNdRIcjZwIbCDgd10bKghLT+DL+4leS/wdeCWqnpz/rKqKuCo3XRsqCEtP4OCn+QkutDfU1UP9bMHd9ORtLwMuaof4C5gV1V9Yd4iu+lIIzXkN/5lwB8CP0jyVD/vz+m65zzQd9Z5EbhuOkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw1aUM29RdsIzM10jeOTo1YwkybKI77UIIMvNWhIzb1Tknwvyb/0nXQ+289fn2RHkj1J7k9y8vSHK2kShhzx/xO4vKouADYAVyW5BPg88MWq+gDwOnDj9IYpaZKGdNKpqvr3fvKk/lHA5cCD/Xw76UgjMrSu/oq+wu4B4FHgeeCNqjrUv2UfXVuto332piRzSeY4OIkhS1qsQcGvqp9V1QZgHXAxcN7QFbytk84ZJzhKSRO1oKv6VfUGsB24FDgtyeH7ANYB+yc8NklTMuSq/hlJTutfvxvYTNcxdzvwkf5tdtKRRmTInXtrgbuTrKD7onigqrYm2Qncl+QvgSfp2mxJGoEhnXS+T9ca+8j5L9D93pc0Mt65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVocPD7EttPJtnaT9tJRxqphRzxb6YrsnmYnXSkkRraUGMd8PvAnf10sJOONFpDj/hfAj4F/LyfPh076UijNaSu/oeBA1X1+Imsw
E460vIzpK7+ZcA1Sa4GTgF+CbidvpNOf9S3k440IkO65d5WVeuq6mzgeuDbVfVR7KQjjdZi/h//08Ankuyh+81vJx1pJIac6v+vqvoO8J3+tZ10pJHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGDSrEkWQv8FPgZ8ChqtqUZDVwP3A2sBe4rqpen84wJU3SQo74v1NVG6pqUz99K7Ctqs4BtvXTkkZgMaf619I10gAbakijMjT4BfxDkseT3NTPW1NVL/evXwHWTHx0kqZiaLHND1bV/iS/Ajya5Nn5C6uqktTRPth/UXRfFr+2mKFKmpRBR/yq2t8/HwAepquu+2qStQD984FjfNZOOtIyM6SF1qlJfvHwa+B3gaeBLXSNNMCGGtKoDDnVXwM83DXIZSXwd1X1SJLHgAeS3Ai8CFw3vWFKmqTjBr9vnHHBUeb/BLhiGoOSNF3euSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aOhf503ERjYyx9wsVzk+R/0bR2myPOJLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgQcFPclqSB5M8m2RXkkuTrE7yaJLd/fOqaQ9W0mQMPeLfDjxSVefRleHahZ10pNEaUmX3fcBvAXcBVNV/VdUb2ElHGq0hR/z1wEHgq0meTHJnX2bbTjrSSA0J/krgIuDLVXUh8BZHnNZXVXGMu8yT3JRkLsncwYMHFzteSRMwJPj7gH1VtaOffpDui2DBnXTOOMNWOtJycNzgV9UrwEtJzu1nXQHsxE460mgN/bPcPwHuSXIy8ALwR3RfGnbSkUZoUPCr6ilg01EW2UlHGiHv3JMaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaNKSu/rlJnpr3eDPJLXbSkcZrSLHN56pqQ1VtADYC/wE8jJ10pNFa6Kn+FcDzVfUidtKRRmuhwb8euLd/bScdaaQGB78vrX0N8LUjl9lJRxqXhRzxPwQ8UVWv9tN20pFGaiHBv4H/O80HO+lIozUo+H133M3AQ/Nmfw7YnGQ3cGU/LWkEhnbSeQs4/Yh5P8FOOtIoeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KChpbf+LMkzSZ5Ocm+SU5KsT7IjyZ4k9/dVeCWNwJAWWmcCfwpsqqrfAFbQ1df/PPDFqvoA8Dpw4zQHKmlyhp7qrwTenWQl8B7gZeBy4MF+uZ10pBEZ0jtvP/BXwI/pAv9vwOPAG1V1qH/bPuDMaQ1S0mQNOdVfRdcnbz3wq8CpwFVDV2AnHWn5GXKqfyXwo6o6WFX/TVdb/zLgtP7UH2AdsP9oH7aTjrT8DAn+j4FLkrwnSehq6e8EtgMf6d9jJx1pRIb8xt9BdxHvCeAH/WfuAD4NfCLJHrpmG3dNcZySJmhoJ53PAJ85YvYLwMUTH5GkqfPOPalBBl9qkMGXGmTwpQalqma3suQg8Bbw2sxWOn2/jNuzXL2TtgWGbc+vV9Vxb5iZafABksxV1aaZrnSK3J7l6520LTDZ7fFUX2qQwZcatBTBv2MJ1jlNbs/y9U7aFpjg9sz8N76kpeepvtSgmQY/yVVJnuvr9N06y3UvVpKzkmxPsrOvP3hzP391kkeT7O6fVy31WBciyYokTybZ2k+PtpZiktOSPJjk2SS7klw65v0zzVqXMwt+khXAXwMfAs4Hbkhy/qzWPwGHgE9W1fnAJcDH+/HfCmyrqnOAbf30mNwM7Jo3PeZaircDj1TVecAFdNs1yv0z9VqXVTWTB3Ap8K1507cBt81q/VPYnm8Am4HngLX9vLXAc0s9tgVswzq6MFwObAVCd4PIyqPts+X8AN4H/Ij+utW8+
aPcP3Sl7F4CVtP9Fe1W4PcmtX9meap/eEMOG22dviRnAxcCO4A1VfVyv+gVYM0SDetEfAn4FPDzfvp0xltLcT1wEPhq/9PlziSnMtL9U1OudenFvQVK8l7g68AtVfXm/GXVfQ2P4r9JknwYOFBVjy/1WCZkJXAR8OWqupDu1vC3ndaPbP8sqtbl8cwy+PuBs+ZNH7NO33KV5CS60N9TVQ/1s19NsrZfvhY4sFTjW6DLgGuS7AXuozvdv52BtRSXoX3AvuoqRkFXNeoixrt/FlXr8nhmGfzHgHP6q5In012o2DLD9S9KX2/wLmBXVX1h3qItdDUHYUS1B6vqtqpaV1Vn0+2Lb1fVRxlpLcWqegV4Kcm5/azDtSFHuX+Ydq3LGV+wuBr4IfA88BdLfQFlgWP/IN1p4veBp/rH1XS/i7cBu4F/BFYv9VhPYNt+G9jav34/8D1gD/A14F1LPb4FbMcGYK7fR38PrBrz/gE+CzwLPA38LfCuSe0f79yTGuTFPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb9DzfZzH2Ei/mJAAAAAElFTkSuQmCC\n",
68 | "text/plain": [
69 | ""
70 | ]
71 | },
72 | "metadata": {},
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {
99 | "collapsed": true
100 | },
101 | "outputs": [],
102 | "source": [
103 | "BATCH_SIZE = 32\n",
104 | "FREEZE_INTERVAL = 20000 # steps\n",
105 | "MEMORY_SIZE = 60000 \n",
106 | "OUTPUT_SIZE = 4\n",
107 | "TOTAL_EPISODES = 10000\n",
108 | "MAX_STEPS = 50\n",
109 | "INITIAL_EPSILON = 1.0\n",
110 | "FINAL_EPSILON = 0.1\n",
111 | "GAMMA = 0.99\n",
112 | "INPUT_IMAGE_DIM = 84\n",
113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 |     " Save Dictionary Function "
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 5,
126 | "metadata": {
127 | "collapsed": true
128 | },
129 | "outputs": [],
130 | "source": [
131 | "def save_obj(obj, name ):\n",
132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | " Experience Replay "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 6,
146 | "metadata": {
147 | "collapsed": true
148 | },
149 | "outputs": [],
150 | "source": [
151 | "class Memory():\n",
152 | " \n",
153 | " def __init__(self,memsize):\n",
154 | " self.memsize = memsize\n",
155 | " self.memory = deque(maxlen=self.memsize)\n",
156 | " \n",
157 | " def add_sample(self,sample):\n",
158 | " self.memory.append(sample)\n",
159 | " \n",
160 | " def get_batch(self,size):\n",
161 | " return random.sample(self.memory,k=size)"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | " Frame Collector "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 7,
174 | "metadata": {
175 | "collapsed": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "class FrameCollector():\n",
180 | " \n",
181 | " def __init__(self,num_frames,img_dim):\n",
182 | " self.num_frames = num_frames\n",
183 | " self.img_dim = img_dim\n",
184 | " self.frames = deque(maxlen=self.num_frames)\n",
185 | " \n",
186 | " def reset(self):\n",
187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n",
188 | " for i in range(0,self.num_frames):\n",
189 | " self.frames.append(tmp)\n",
190 | " \n",
191 | " def add_frame(self,frame):\n",
192 | " self.frames.append(frame)\n",
193 | " \n",
194 | " def get_state(self):\n",
195 | " return np.array(self.frames)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | " Preprocess Images "
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 8,
208 | "metadata": {
209 | "collapsed": true
210 | },
211 | "outputs": [],
212 | "source": [
213 | "def preprocess_image(image):\n",
214 |     "def preprocess_image(image):\n",
215 | " return np.copy(image)"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 9,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | ""
227 | ]
228 | },
229 | "execution_count": 9,
230 | "metadata": {},
231 | "output_type": "execute_result"
232 | },
233 | {
234 | "data": {
235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADM1JREFUeJzt3X/oXfV9x/Hna4nW1k5jnIbMyMxoUGRg1OAUy9jUbKktuj+KKDLKEPJPt+laaHX7oxT2RwujrcIoiLaT4fxRq2sIxc6lljEYqfHHWk1ME22sCWqi80fnYFva9/64J9u3WbKcb+6P7/f4eT7gcu85596cz+Hwuufc8z15v1NVSGrLLy30ACTNnsGXGmTwpQYZfKlBBl9qkMGXGmTwpQaNFfwkG5LsTLI7ya2TGpSk6crx3sCTZAnwI2A9sBd4ArihqrZPbniSpmHpGJ+9BNhdVS8CJLkfuBY4avCTeJugxnLxxRcv9BAWtT179vD666/nWO8bJ/hnAS/Pmd4L/OYY/550TNu2bVvoISxq69at6/W+cYLfS5KNwMZpr0dSf+MEfx9w9pzpVd28X1BVdwJ3gqf60mIxzlX9J4A1SVYnORG4Htg0mWFJmqbjPuJX1cEkfwR8B1gCfK2qnpvYyCRNzVi/8avq28C3JzQWSTPinXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg44Z/CRfS7I/ybNz5i1P8liSXd3zadMdpqRJ6nPE/2tgw2HzbgW2VNUaYEs3LWkgjhn8qvpH4F8Pm30tcE/3+h7g9yc8LklTdLy/8VdU1Svd61eBFRMaj6QZGLuTTlXV/9cow0460uJzvEf815KsBOie9x/tjVV1Z1Wtq6p+Tb0kTd3xBn8T8Inu9SeAb01mOJJmoc+f8+4D/hk4N8neJDcBXwDWJ9kFXNVNSxqIY/7Gr6objrLoygmPRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KA+pbfOTvJ4ku1JnktyczffbjrSQPU54h8EPl1V5wOXAp9Mcj5205EGq08nnVeq6qnu9U+BHcBZ2E1HGqx5NdRIcg5wIbCVnt10bKghLT69L+4l+SDwTeCWqnpn7rKqKuCI3XRsqCEtPr2Cn+QERqG/t6oe7mb37qYjaXHpc1U/wN3Ajqr60pxFdtORBqrPb/zLgT8AfpjkmW7enzHqnvNg11nnJeC66QxR0qT16aTzT0COsthuOtIAeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVoXjX3xrVmzRruuOOOWa5ycG688caFHoIa4BFfapDBlxrUp+beSUm+n+Rfuk46n+/mr06yNcnuJA8kOXH6w5U0CX2O+P8BXFFVFwBrgQ1JLgW+CHy5qj4EvAncNL1hSpqkPp10qqr+rZs8oXsUcAXwUDffTjrSgPStq7+kq7C7H3gMeAF4q6oOdm/Zy6it1pE+uzHJtiTb3n777UmMWdKYegW/qn5WVWuBVcAlwHl9VzC3k86pp556nMOUNEnzuqpfVW8BjwOXAcuSHLoPYBWwb8JjkzQlfa7qn5FkWff6/cB6Rh1zHwc+3r3NTjrSgPS5c28lcE+SJYy+KB6sqs1JtgP3J/kL4GlGbbYkDUCfTjo/YNQa+/D5LzL6vS9pYLxzT2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ7+B3JbafTrK5m7aTjjRQ8zni38yoyOYhdtKRBqpvQ41VwEeBu7rpYCcdabD6HvG/AnwG+Hk3fTp20pEGq09d
/Y8B+6vqyeNZgZ10pMWnT139y4FrklwNnAScAtxO10mnO+rbSUcakD7dcm+rqlVVdQ5wPfDdqroRO+lIgzXO3/E/C3wqyW5Gv/ntpCMNRJ9T/f9RVd8Dvte9tpOONFDeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgXoU4kuwBfgr8DDhYVeuSLAceAM4B9gDXVdWb0xmmpEmazxH/d6pqbVWt66ZvBbZU1RpgSzctaQDGOdW/llEjDbChhjQofYNfwN8neTLJxm7eiqp6pXv9KrBi4qOTNBV9i21+uKr2JTkTeCzJ83MXVlUlqSN9sPui2Ahw5plnjjVYSZPR64hfVfu65/3AI4yq676WZCVA97z/KJ+1k460yPRpoXVykl8+9Br4XeBZYBOjRhpgQw1pUPqc6q8AHhk1yGUp8LdV9WiSJ4AHk9wEvARcN71hSpqkYwa/a5xxwRHmvwFcOY1BSZou79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtT3f+dNxCmnnMKGDRtmucrBeeONNxZ6CGqAR3ypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUK/hJliV5KMnzSXYkuSzJ8iSPJdnVPZ827cFKmoy+R/zbgUer6jxGZbh2YCcdabD6VNk9Ffgt4G6AqvrPqnoLO+lIg9XniL8aOAB8PcnTSe7qymzbSUcaqD7BXwpcBHy1qi4E3uWw0/qqKkZttv6PJBuTbEuy7cCBA+OOV9IE9An+XmBvVW3tph9i9EUw7046Z5xxxiTGLGlMxwx+Vb0KvJzk3G7WlcB27KQjDVbf/5b7x8C9SU4EXgT+kNGXhp10pAHqFfyqegZYd4RFdtKRBsg796QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9amrf26SZ+Y83klyi510pOHqU2xzZ1Wtraq1wMXAvwOPYCcdabDme6p/JfBCVb2EnXSkwZpv8K8H7ute20lHGqjewe9Ka18DfOPwZXbSkYZlPkf8jwBPVdVr3bSddKSBmk/wb+B/T/PBTjrSYPUKftcddz3w8JzZXwDWJ9kFXNVNSxqAvp103gVOP2zeG9hJRxok79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtS39NafJnkuybNJ7ktyUpLVSbYm2Z3kga4Kr6QB6NNC6yzgT4B1VfUbwBJG9fW/CHy5qj4EvAncNM2BSpqcvqf6S4H3J1kKfAB4BbgCeKhbbicdaUD69M7bB/wl8BNGgX8beBJ4q6oOdm/bC5w1rUFKmqw+p/qnMeqTtxr4VeBkYEPfFdhJR1p8+pzqXwX8uKoOVNV/MaqtfzmwrDv1B1gF7DvSh+2kIy0+fYL/E+DSJB9IEka19LcDjwMf795jJx1pQPr8xt/K6CLeU8APu8/cCXwW+FSS3Yyabdw9xXFKmqC+nXQ+B3zusNkvApdMfESSps4796QGGXypQQZfapDBlxqUqprdypIDwLvA6zNb6fT9Cm7PYvVe2hbotz2/VlXHvGFmpsEHSLKtqtbNdKVT5PYsXu+lbYHJbo+n+lKDDL7UoIUI/p0LsM5pcnsWr/fStsAEt2fmv/ElLTxP9aUGzTT4STYk2dnV6bt1luseV5KzkzyeZHtXf/Dmbv7yJI8l2dU9n7bQY52PJEuSPJ1kczc92FqKSZYleSjJ80l2JLlsyPtnmrUuZxb8JEuAvwI+ApwP3JDk/FmtfwIOAp+uqvOBS4FPduO/FdhSVWuALd30kNwM7JgzPeRaircDj1bVecAFjLZrkPtn6rUuq2omD+Ay4Dtzpm8DbpvV+qewPd8C1gM7gZXdvJXAzoUe2zy2YRWjMFwBbAbC
6AaRpUfaZ4v5AZwK/JjuutWc+YPcP4xK2b0MLGf0v2g3A783qf0zy1P9QxtyyGDr9CU5B7gQ2AqsqKpXukWvAisWaFjH4yvAZ4Cfd9OnM9xaiquBA8DXu58udyU5mYHun5pyrUsv7s1Tkg8C3wRuqap35i6r0dfwIP5MkuRjwP6qenKhxzIhS4GLgK9W1YWMbg3/hdP6ge2fsWpdHsssg78POHvO9FHr9C1WSU5gFPp7q+rhbvZrSVZ2y1cC+xdqfPN0OXBNkj3A/YxO92+nZy3FRWgvsLdGFaNgVDXqIoa7f8aqdXksswz+E8Ca7qrkiYwuVGya4frH0tUbvBvYUVVfmrNoE6OagzCg2oNVdVtVraqqcxjti+9W1Y0MtJZiVb0KvJzk3G7WodqQg9w/TLvW5YwvWFwN/Ah4Afjzhb6AMs+xf5jRaeIPgGe6x9WMfhdvAXYB/wAsX+ixHse2/TawuXv968D3gd3AN4D3LfT45rEda4Ft3T76O+C0Ie8f4PPA88CzwN8A75vU/vHOPalBXtyTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9q0H8DxmXSvDxTWXAAAAAASUVORK5CYII=\n",
236 | "text/plain": [
237 | ""
238 | ]
239 | },
240 | "metadata": {},
241 | "output_type": "display_data"
242 | }
243 | ],
244 | "source": [
245 | "processed_prev_state = preprocess_image(prev_state)\n",
246 | "plt.imshow(processed_prev_state,cmap='gray')"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | " Build Model "
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 10,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "Network(\n",
266 | " (conv_layer1): Conv2d(4, 64, kernel_size=(8, 8), stride=(4, 4))\n",
267 | " (conv_layer2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))\n",
268 | " (conv_layer3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n",
269 | " (fc1): Linear(in_features=6272, out_features=512, bias=True)\n",
270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
271 | " (relu): ReLU()\n",
272 | ")\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "import torch.nn as nn\n",
278 | "import torch\n",
279 | "\n",
280 | "class Network(nn.Module):\n",
281 | " \n",
282 | " def __init__(self,image_input_size,out_size):\n",
283 | " super(Network,self).__init__()\n",
284 | " self.image_input_size = image_input_size\n",
285 | " self.out_size = out_size\n",
286 | "\n",
287 | " self.conv_layer1 = nn.Conv2d(in_channels=4,out_channels=64,kernel_size=8,stride=4) # GRAY - 1\n",
288 | " self.conv_layer2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2)\n",
289 | " self.conv_layer3 = nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,stride=1)\n",
290 | " self.fc1 = nn.Linear(in_features=7*7*128,out_features=512)\n",
291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
292 | " self.relu = nn.ReLU()\n",
293 | "\n",
294 | " def forward(self,x,bsize):\n",
295 | " x = x.view(bsize,4,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
296 | " conv_out = self.conv_layer1(x)\n",
297 | " conv_out = self.relu(conv_out)\n",
298 | " conv_out = self.conv_layer2(conv_out)\n",
299 | " conv_out = self.relu(conv_out)\n",
300 | " conv_out = self.conv_layer3(conv_out)\n",
301 | " conv_out = self.relu(conv_out)\n",
302 | " out = self.fc1(conv_out.view(bsize,7*7*128))\n",
303 | " out = self.relu(out)\n",
304 | " out = self.fc2(out)\n",
305 | " return out\n",
306 | "\n",
307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n",
308 | "print(main_model)"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {},
314 | "source": [
315 | " Deep Q Learning with Freeze Network "
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "Populated 60000 Samples in Episodes : 1200\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "mem = Memory(memsize=MEMORY_SIZE)\n",
333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n",
334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n",
335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=4)\n",
336 | "\n",
337 | "target_model.load_state_dict(main_model.state_dict())\n",
338 | "criterion = nn.SmoothL1Loss()\n",
339 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
340 | "\n",
341 | "# filling memory with transitions\n",
342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
343 | " \n",
344 | " prev_state = env.reset()\n",
345 | " frameObj.reset()\n",
346 | " processed_prev_state = preprocess_image(prev_state)\n",
347 | " frameObj.add_frame(processed_prev_state)\n",
348 | " prev_frames = frameObj.get_state()\n",
349 | " step_count = 0\n",
350 | " game_over = False\n",
351 | " \n",
352 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
353 | " \n",
354 | " step_count +=1\n",
355 | " action = np.random.randint(0,4)\n",
356 | " next_state,reward, game_over = env.step(action)\n",
357 | " processed_next_state = preprocess_image(next_state)\n",
358 | " frameObj.add_frame(processed_next_state)\n",
359 | " next_frames = frameObj.get_state()\n",
360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
361 | " \n",
362 | " prev_state = next_state\n",
363 | " processed_prev_state = processed_next_state\n",
364 | " prev_frames = next_frames\n",
365 | "\n",
366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n",
367 | "\n",
368 | "\n",
369 | "# Algorithm Starts\n",
370 | "total_steps = 0\n",
371 | "epsilon = INITIAL_EPSILON\n",
372 | "loss_stat = []\n",
373 | "total_reward_stat = []\n",
374 | "\n",
375 | "for episode in range(0,TOTAL_EPISODES):\n",
376 | " \n",
377 | " prev_state = env.reset()\n",
378 | " frameObj.reset()\n",
379 | " processed_prev_state = preprocess_image(prev_state)\n",
380 | " frameObj.add_frame(processed_prev_state)\n",
381 | " prev_frames = frameObj.get_state()\n",
382 | " game_over = False\n",
383 | " step_count = 0\n",
384 | " total_reward = 0\n",
385 | " \n",
386 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
387 | " \n",
388 | " step_count +=1\n",
389 | " total_steps +=1\n",
390 | " \n",
391 | " if np.random.rand() <= epsilon:\n",
392 | " action = np.random.randint(0,4)\n",
393 | " else:\n",
394 | " with torch.no_grad():\n",
395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
396 | "\n",
397 | " model_out = main_model.forward(torch_x,bsize=1)\n",
398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
399 | " \n",
400 | " next_state, reward, game_over = env.step(action)\n",
401 | " processed_next_state = preprocess_image(next_state)\n",
402 | " frameObj.add_frame(processed_next_state)\n",
403 | " next_frames = frameObj.get_state()\n",
404 | " total_reward += reward\n",
405 | " \n",
406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
407 | " \n",
408 | " prev_state = next_state\n",
409 | " processed_prev_state = processed_next_state\n",
410 | " prev_frames = next_frames\n",
411 | " \n",
412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
413 | " target_model.load_state_dict(main_model.state_dict())\n",
414 | " \n",
415 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
416 | " current_states = []\n",
417 | " next_states = []\n",
418 | " acts = []\n",
419 | " rewards = []\n",
420 | " game_status = []\n",
421 | " \n",
422 | " for element in batch:\n",
423 | " current_states.append(element[0])\n",
424 | " acts.append(element[1])\n",
425 | " rewards.append(element[2])\n",
426 | " next_states.append(element[3])\n",
427 | " game_status.append(element[4])\n",
428 | " \n",
429 | " current_states = np.array(current_states)\n",
430 | " next_states = np.array(next_states)\n",
431 | " rewards = np.array(rewards)\n",
432 | " game_status = [not b for b in game_status]\n",
433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n",
434 | " torch_acts = torch.tensor(acts)\n",
435 | " \n",
436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
439 | " Q_max_next = Q_max_next.double()\n",
440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
441 | " \n",
442 | " target_values = (rewards + (GAMMA * Q_max_next))\n",
443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
444 | " \n",
445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n",
446 | " \n",
447 | " # save performance measure\n",
448 | " loss_stat.append(loss.item())\n",
449 | " \n",
450 | " # make previous grad zero\n",
451 | " optimizer.zero_grad()\n",
452 | " \n",
453 |     "        # back-propagate \n",
454 | " loss.backward()\n",
455 | " \n",
456 | " # update params\n",
457 | " optimizer.step()\n",
458 | " \n",
459 | " # save performance measure\n",
460 | " total_reward_stat.append(total_reward)\n",
461 | " \n",
462 | " if epsilon > FINAL_EPSILON:\n",
463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
464 | " \n",
465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
466 | " perf = {}\n",
467 | " perf['loss'] = loss_stat\n",
468 | " perf['total_reward'] = total_reward_stat\n",
469 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n",
470 | " \n",
471 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | " Save Primary Network Weights "
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 12,
484 | "metadata": {
485 | "collapsed": true
486 | },
487 | "outputs": [],
488 | "source": [
489 | "torch.save(main_model.state_dict(),'data/FOUR_OBSERV_NINE_WEIGHTS.torch')"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {},
495 | "source": [
496 | " Testing Policy "
497 | ]
498 | },
499 | {
500 | "cell_type": "markdown",
501 | "metadata": {},
502 | "source": [
503 | " Load Primary Network Weights "
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 13,
509 | "metadata": {
510 | "collapsed": true
511 | },
512 | "outputs": [],
513 | "source": [
514 | "weights = torch.load('data/FOUR_OBSERV_NINE_WEIGHTS.torch')\n",
515 | "main_model.load_state_dict(weights)"
516 | ]
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "metadata": {},
521 | "source": [
522 | " Testing Policy "
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "metadata": {
529 | "collapsed": true
530 | },
531 | "outputs": [],
532 | "source": [
533 | "# Algorithm Starts\n",
534 | "epsilon = INITIAL_EPSILON\n",
535 | "FINAL_EPSILON = 0.01\n",
536 | "total_reward_stat = []\n",
537 | "\n",
538 | "for episode in range(0,TOTAL_EPISODES):\n",
539 | " \n",
540 | " prev_state = env.reset()\n",
541 | " processed_prev_state = preprocess_image(prev_state)\n",
542 | " frameObj.reset()\n",
543 | " frameObj.add_frame(processed_prev_state)\n",
544 | " prev_frames = frameObj.get_state()\n",
545 | " game_over = False\n",
546 | " step_count = 0\n",
547 | " total_reward = 0\n",
548 | " \n",
549 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
550 | " \n",
551 | " step_count +=1\n",
552 | " \n",
553 | " if np.random.rand() <= epsilon:\n",
554 | " action = np.random.randint(0,4)\n",
555 | " else:\n",
556 | " with torch.no_grad():\n",
557 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
558 | "\n",
559 | " model_out = main_model.forward(torch_x,bsize=1)\n",
560 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
561 | " \n",
562 | " next_state, reward, game_over = env.step(action)\n",
563 | " processed_next_state = preprocess_image(next_state)\n",
564 | " frameObj.add_frame(processed_next_state)\n",
565 | " next_frames = frameObj.get_state()\n",
566 | " \n",
567 | " total_reward += reward\n",
568 | " \n",
569 | " prev_state = next_state\n",
570 | " processed_prev_state = processed_next_state\n",
571 | " prev_frames = next_frames\n",
572 | " \n",
573 | " # save performance measure\n",
574 | " total_reward_stat.append(total_reward)\n",
575 | " \n",
576 | " if epsilon > FINAL_EPSILON:\n",
577 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
578 | " \n",
579 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
580 | " perf = {}\n",
581 | " perf['total_reward'] = total_reward_stat\n",
582 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n",
583 | " \n",
584 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
585 | ]
586 | }
587 | ],
588 | "metadata": {
589 | "kernelspec": {
590 | "display_name": "Python [conda env:myenv]",
591 | "language": "python",
592 | "name": "conda-env-myenv-py"
593 | },
594 | "language_info": {
595 | "codemirror_mode": {
596 | "name": "ipython",
597 | "version": 3
598 | },
599 | "file_extension": ".py",
600 | "mimetype": "text/x-python",
601 | "name": "python",
602 | "nbconvert_exporter": "python",
603 | "pygments_lexer": "ipython3",
604 | "version": "3.6.5"
605 | }
606 | },
607 | "nbformat": 4,
608 | "nbformat_minor": 2
609 | }
610 |
--------------------------------------------------------------------------------
/MDP_Size_9.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from gridworld import gameEnv\n",
10 | "import numpy as np\n",
11 | "%matplotlib inline\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "from collections import deque\n",
14 | "import pickle\n",
15 | "from skimage.color import rgb2gray\n",
16 | "import random"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | " Define Environment Object "
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 3,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "data": {
33 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADONJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppIiWkvoki9SKQqCRdVJARJUUX5EwKNZUWk1CGqKlUO5k8TsCE2xARbBpsUSkqltk7eXuy4PXFtzhyf3T07/n0/0urszOye+Y1Hz5nZ2fH7pqqQ1JZfWO4BSJo+gy81yOBLDTL4UoMMvtQggy81yOBLDVpS8JNck+SFJLuTbBrXoCRNVo73Bp4kK4AfABuBvcATwE1VtWN8w5M0CSuX8N5Lgd1V9RJAkvuB64FjBv/MM8+subm5JaxS0jvZs2cPr7/+ehZ63VKCfzbwyrzpvcBvvtMb5ubm2L59+xJWKemdbNiwodfrJn5xL8ktSbYn2X7w4MFJr05SD0sJ/j7gnHnTa7t5P6eq7qyqDVW14ayzzlrC6iSNy1KC/wRwXpJ1SU4GbgQ2j2dYkibpuD/jV9WhJH8EfAtYAXylqp4b28gkTcxSLu5RVd8EvjmmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGLem/5c6CZMG6gtLMWq429R7xpQYZfKlBCwY/yVeSHEjy7Lx5q5I8lmRX9/P0yQ5T0jj1OeL/NXDNEfM2AVur6jxgazctaSAWDH5V/SPwr0fMvh64p3t+D/D7Yx6XpAk63s/4q6tqf/f8VWD1mMYjaQqWfHGvRt9HHPM7CTvpSLPneIP/WpI1AN3PA8d6oZ10pNlzvMHfDHyse/4x4BvjGY6kaejzdd59wD8D5yfZm+Rm4HPAxiS7gKu7aUkDseAtu1V10zEWXTXmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUp/TWOUkeT7IjyXNJbu3m201HGqg+R/xDwCer6kLgMuDjSS7EbjrSYPXppLO/qp7qnv8E2Amcjd10pMFa1Gf8JHPAxcA2enbTsaGGNHt6Bz/Je4GvA7dV1Vvzl71TNx0bakizp1fwk5zEKPT3VtXD3eze3XQkzZY+V/UD3A3srKovzFtkNx1poBZsqAFcAfwB8P0kz3Tz/oxR95wHu846LwM3TGaIksatTyedfwJyjMV205EGyDv3pAb1OdXXCeioX8EsQcb5C491fjlDxv3vN20e8aUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfalCfmnunJPlukn/pOul8tpu/Lsm2JLuTPJDk5MkPV9I49Dni/ydwZVVdBKwHrklyGfB54ItV9X7gDeDmyQ1T0jj16aRTVfXv3eRJ3aOAK4GHuvl20pEGpG9d/RVdhd0DwGPAi8CbVXWoe8leRm21jvZeO+lIM6ZX8Kvqp1W1HlgLXApc0HcFdtKZTRnzY7y/bPYNfVMXdVW/qt4EHgcuB05LcrhY51pg35jHJmlC+lzVPyvJad3zdwMbGXXMfRz4SPcyO+lIA9KnvPYa4J4kKxj9oXiwqrYk2QHcn+QvgKcZtdmSNAB9Oul8j1Fr7CPnv8To876kgfHOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYPfldh+OsmWbtpOOtJALeaIfyujIpuH2UlHGqi+DTXWAh8C7uqmg510pMHqe8T/EvAp4Gfd9
BnYSUcarD519T8MHKiqJ49nBXbSkWZPn7r6VwDXJbkWOAX4JeAOuk463VHfTjrSgPTplnt7Va2tqjngRuDbVfVR7KQjDdZSvsf/NPCJJLsZfea3k440EH1O9f9XVX0H+E733E460kB5557UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDFnUDjxapxvz7Mubfp2Z5xJcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9vsdPsgf4CfBT4FBVbUiyCngAmAP2ADdU1RuTGaakcVrMEf93qmp9VW3opjcBW6vqPGBrNy1pAJZyqn89o0YaYEMNaVD6Br+Av0/yZJJbunmrq2p/9/xVYPXYRydpIvreq/+BqtqX5JeBx5I8P39hVVWSo96Z3v2huAXg3HPPXdJgJY1HryN+Ve3rfh4AHmFUXfe1JGsAup8HjvFeO+lIM6ZPC61Tk/zi4efA7wLPApsZNdIAG2pIg9LnVH818MioQS4rgb+tqkeTPAE8mORm4GXghskNU9I4LRj8rnHGRUeZ/2PgqkkMStJkeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yWpKHkjyfZGeSy5OsSvJYkl3dz9MnPVhJ49H3iH8H8GhVXcCoDNdO7KQjDVafKrvvA34LuBugqv6rqt7ETjrSYPU54q8DDgJfTfJ0kru6Mtt20pEGqk/wVwKXAF+uqouBtznitL6qilGbrf8nyS1JtifZfvDgwaWOV9IY9An+XmBvVW3rph9i9IfATjoLyZgfs6zG+NDELRj8qnoVeCXJ+d2sq4Ad2ElHGqy+TTP/GLg3ycnAS8AfMvqjYScdaYB6Bb+qngE2HGWRnXSkAfLOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBferqn5/kmXmPt5LcduJ10hlntcgGq0a2UlT0BNGn2OYLVbW+qtYDvwH8B/AIdtKRBmuxp/pXAS9W1cvYSUcarMUG/0bgvu65nXSkgeod/K609nXA145cZicdaVgWc8T/IPBUVb3WTdtJRxqoxQT/Jv7vNB/spCMNVq/gd91xNwIPz5v9OWBjkl3A1d20pAHo20nnbeCMI+b9GDvpSIPknXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg3rduTfLRv8xcFbN8tjUMo/4UoMMvtQggy81yOBLDTL4UoMMvtQggy81qG/prT9N8lySZ5Pcl+SUJOuSbEuyO8kDXRVeSQPQp4XW2cCfABuq6teBFYzq638e+GJVvR94A7h5kgOVND59T/VXAu9OshJ4D7AfuBJ4qFtuJx1pQPr0ztsH/CXwI0aB/zfgSeDNqjrUvWwvcPakBilpvPqc6p/OqE/eOuBXgFOBa/quwE460uzpc6p/NfDDqjpYVf/NqLb+FcBp3ak/wFpg39HebCcdafb0Cf6PgMuSvCdJGNXS3wE8Dnyke42ddKQB6fMZfxuji3hPAd/v3nMn8GngE0l2M2q2cfcExylpjPp20vkM8JkjZr8EXDr2EUmaOO/ckxpk8KUGGXypQQZfalCmWawyyUHgbeD1qa108s7E7ZlVJ9K2QL/t+dWqWvCGmakGHyDJ9qraMNWVTpDbM7tOpG2B8W6Pp/pSgwy+1KDlCP6dy7DOSXJ7ZteJtC0wxu2Z+md8ScvPU32pQVMNfpJrkrzQ1enbNM11L1WSc5I8nmRHV3/w1m7+qiSPJdnV/Tx9uce6GElWJHk6yZZuerC1FJOcluShJM8n2Znk8iHvn0nWupxa8JOsAP4K+CBwIXBTkguntf4xOAR8sqouBC4DPt6NfxOwtarOA7Z200NyK7Bz3vSQayneATxaVRcAFzHarkHun4nXuqyqqTyAy4FvzZu+Hbh9WuufwPZ8A9gIv
ACs6eatAV5Y7rEtYhvWMgrDlcAWIIxuEFl5tH02yw/gfcAP6a5bzZs/yP3DqJTdK8AqRv+Ldgvwe+PaP9M81T+8IYcNtk5fkjngYmAbsLqq9neLXgVWL9OwjseXgE8BP+umz2C4tRTXAQeBr3YfXe5KcioD3T814VqXXtxbpCTvBb4O3FZVb81fVqM/w4P4miTJh4EDVfXkco9lTFYClwBfrqqLGd0a/nOn9QPbP0uqdbmQaQZ/H3DOvOlj1umbVUlOYhT6e6vq4W72a0nWdMvXAAeWa3yLdAVwXZI9wP2MTvfvoGctxRm0F9hbo4pRMKoadQnD3T9LqnW5kGkG/wngvO6q5MmMLlRsnuL6l6SrN3g3sLOqvjBv0WZGNQdhQLUHq+r2qlpbVXOM9sW3q+qjDLSWYlW9CryS5Pxu1uHakIPcP0y61uWUL1hcC/wAeBH48+W+gLLIsX+A0Wni94Bnuse1jD4XbwV2Af8ArFrusR7Htv02sKV7/mvAd4HdwNeAdy33+BaxHeuB7d0++jvg9CHvH+CzwPPAs8DfAO8a1/7xzj2pQV7ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfatD/ADPvEPBps6hQAAAAAElFTkSuQmCC\n",
34 | "text/plain": [
35 | ""
36 | ]
37 | },
38 | "metadata": {
39 | "needs_background": "light"
40 | },
41 | "output_type": "display_data"
42 | }
43 | ],
44 | "source": [
45 | "env = gameEnv(partial=False,size=9)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "data": {
55 | "text/plain": [
56 | ""
57 | ]
58 | },
59 | "execution_count": 4,
60 | "metadata": {},
61 | "output_type": "execute_result"
62 | },
63 | {
64 | "data": {
65 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADPlJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppICbQXUaReJFKVhIsqEoKkqKL8CYEGWREpdYiiSJWD+dMEbIgNMcEWYJNCSanU1snbix23B9fGc3xmz9nx7/uRVrszu3vmN2f0nJmdnfO+qSokteWXlnsAkpaewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSPJckp1Jbh5qUJKmK0d7AU+SFcCPgI3AbuAx4Lqq2jbc8CRNw8pFvPdCYGdVvQCQ5B7gauCwwT/11FNrbm5uEYuU9E527drFa6+9liO9bjHBPx14ad70buC33+kNc3NzbN26dRGLlPRONmzY0Ot1Uz+5l+SGJFuTbN23b9+0Fyeph8UEfw9wxrzptd28t6mq26pqQ1VtOO200xaxOElDWUzwHwPOSrIuyfHAtcBDwwxL0jQd9Wf8qtqf5E+AbwErgK9U1TODjUzS1Czm5B5V9U3gmwONRdIS8co9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2rQov4tdxYkR6wr2M80uoUPNDQdu5arTb17fKlBBl9q0BGDn+QrSfYmeXrevFVJHkmyo7s/ebrDlDSkPnv8vwWuOGjezcDmqjoL2NxNSxqJIwa/qr4L/OtBs68G7uwe3wn84cDjkjRFR/sZf3VVvdw9fgVYPdB4JC2BRZ/cq8n3EYf9TsJOOtLsOdrgv5pkDUB3v/dwL7STjjR7jjb4DwEf6x5/DPjGMMORtBT6fJ13N/DPwNlJdie5HvgcsDHJDuDyblrSSBzxkt2quu4wT1028FgkLRGv3JMaZPClBhl8qUEGX2qQwZcaZPClBo2+As9grJajhrjHlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Sm9dUaSR5NsS/JMkhu7+XbTkUaqzx5/P/DJqjoXuAj4eJJzsZuONFp9Oum8XFVPdI9/BmwHTsduOtJoLegzfpI54HxgCz276dhQQ5o9vYOf5L3A14GbqurN+c+9UzcdG2pIs6dX8JMcxyT0d1XVA93s3t10JM2WPmf1A9wBbK+qL8x7ym460kj1qcBzCfBHwA+TPNXN+wsm3XPu6zrrvAhcM50hShpan0463+PwhanspiONkFfuSQ2y2OaoHPKLk6PUUHXRIX9tB4z81+ceX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBvWpuXdCku8n+Zeuk85nu/nrkmxJsjPJvUmOn/5wJQ2hzx7/P4FLq+o8YD1wRZKLgM8DX6yq9wOvA9dPb5iShtSnk05V1b93k8d1twIuBe7v5ttJRxqRvnX1V3QVdvcCjwDPA29U1f7uJbuZtNU61HvtpCPNmF7Br6qfV9V6YC1wIXBO3wXYSWdIGfDWkCF/bcfIr29BZ/Wr6g3gUeBi4KQkB4p1rgX2DDw2SVPS56z+aUlO6h6/G9jIpGPuo8BHupfZSUcakT7ltdcAdyZZweQPxX1VtSnJNuCeJH8FPMmkzZakEejTSecHTFpjHzz/BSaf9yWNjFfuSQ0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3qHfyuxPaTSTZ103bSkUZqIXv8G5kU2TzATjrSSPVtqLEW+BBwezcd7KQjjVbfP
f6XgE8Bv+imT8FOOtJo9amr/2Fgb1U9fjQLsJOONHv61NW/BLgqyZXACcCvALfSddLp9vp20pFGpE+33Fuqam1VzQHXAt+uqo9iJx1ptBbzPf6ngU8k2cnkM7+ddKSR6HOo/7+q6jvAd7rHdtKRRsor96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUqxBHkl3Az4CfA/urakOSVcC9wBywC7imql6fzjAlDWkhe/zfq6r1VbWhm74Z2FxVZwGbu2lJI7CYQ/2rmTTSABtqSKPSN/gF/GOSx5Pc0M1bXVUvd49fAVYPPjpJU9G32OYHqmpPkl8FHkny7Pwnq6qS1KHe2P2huAHgzDPPXNRgJQ2j1x6/qvZ093uBB5lU1301yRqA7n7vYd5rJx1pxvRpoXVikl8+8Bj4feBp4CEmjTTAhhrSqPQ51F8NPDhpkMtK4O+r6uEkjwH3JbkeeBG4ZnrDlDSkIwa/a5xx3iHm/xS4bBqDkjRdXrknNcjgSw1aUO+8mXTILxGP4sdkmJ8z3xR+pDQI9/hSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yUpL7kzybZHuSi5OsSvJIkh3d/cnTHqykYfTd498KPFxV5zApw7UdO+lIo9Wnyu77gN8B7gCoqv+qqjewk440Wn32+OuAfcBXkzyZ5PauzLaddKSR6hP8lcAFwJer6nzgLQ46rK+q4jBFsJLckGRrkq379u1b7HglDaBP8HcDu6tqSzd9P5M/BLPRSSfD3Ab6MW+76dhVA92WyxGDX1WvAC8lObubdRmwDTvpSKPVt8runwJ3JTkeeAH4YyZ/NOykI41Qr+BX1VPAhkM8ZScdaYS8ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUJ+6+mcneWre7c0kN9lJp4ehKjIud2VG/T9jL8jap9jmc1W1vqrWA78F/AfwIHbSkUZroYf6lwHPV9WL2ElHGq2FBv9a4O7usZ10pJHqHfyutPZVwNcOfs5OOtK4LGSP/0Hgiap6tZuejU46khZsIcG/jv87zAc76Uij1Sv4XXfcjcAD82Z/DtiYZAdweTctaQT6dtJ5CzjloHk/xU460ih55Z7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoF5X7s2yyT8GNqKhVdV0uceXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBfUtv/XmSZ5I8neTuJCckWZdkS5KdSe7tqvBKGoE+LbROB/4M2FBVvwmsYFJf//PAF6vq/cDrwPXTHKik4fQ91F8JvDvJSuA9wMvApcD93fN20pFGpE/vvD3AXwM/YRL4fwMeB96oqv3dy3YDp09rkJKG1edQ/2QmffLWAb8GnAhc0XcBdtKRZk+fQ/3LgR9X1b6q+m8mtfUvAU7qDv0B1gJ7DvVmO+lIs6dP8H8CXJTkPUnCpJb+NuBR4CPda+ykI41In8/4W5icxHsC+GH3ntuATwOfSLKTSbONO6Y4TkkD6ttJ5zPAZw6a/QJw4eAjkjR1XrknNcjgSw0y+FKDDL7UoCxlscok+4C3gNeWbKHTdyquz6w6ltYF+q3Pr1fVES+YWdLgAyTZWlUblnShU+T6zK5jaV1g2PXxUF9qkMGXGrQcwb9tGZY5Ta7P7DqW1gUGXJ8l/4wvafl5qC81aEmDn+SKJM91dfpuXsplL1aSM5I8mmRbV3/wxm7+qiSPJNnR3Z+83GNdiCQrkjyZZFM3PdpaiklOSnJ/kmeTbE9y8Zi3zzRrXS5Z8JOsAP4G+CBwLnBdknOXavkD2A98sqrOBS4CPt6N/2Zgc1WdBWzupsfkRmD7vOkx11K8FXi4qs4BzmOyXqPcPlOvdVlVS
3IDLga+NW/6FuCWpVr+FNbnG8BG4DlgTTdvDfDcco9tAeuwlkkYLgU2AWFygcjKQ22zWb4B7wN+THfeat78UW4fJqXsXgJWMfkv2k3AHwy1fZbyUP/Aihww2jp9SeaA84EtwOqqerl76hVg9TIN62h8CfgU8Itu+hTGW0txHbAP+Gr30eX2JCcy0u1TU6516cm9BUryXuDrwE1V9eb852ryZ3gUX5Mk+TCwt6oeX+6xDGQlcAHw5ao6n8ml4W87rB/Z9llUrcsjWcrg7wHOmDd92Dp9syrJcUxCf1dVPdDNfjXJmu75NcDe5RrfAl0CXJVkF3APk8P9W+lZS3EG7QZ216RiFEyqRl3AeLfPompdHslSBv8x4KzurOTxTE5UPLSEy1+Urt7gHcD2qvrCvKceYlJzEEZUe7CqbqmqtVU1x2RbfLuqPspIaylW1SvAS0nO7mYdqA05yu3DtGtdLvEJiyuBHwHPA3+53CdQFjj2DzA5TPwB8FR3u5LJ5+LNwA7gn4BVyz3Wo1i33wU2dY9/A/g+sBP4GvCu5R7fAtZjPbC120b/AJw85u0DfBZ4Fnga+DvgXUNtH6/ckxrkyT2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG/Q++8g/3YgFKrwAAAABJRU5ErkJggg==\n",
66 | "text/plain": [
67 | ""
68 | ]
69 | },
70 | "metadata": {
71 | "needs_background": "light"
72 | },
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 7,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "BATCH_SIZE = 64\n",
102 | "FREEZE_INTERVAL = 20000 # steps\n",
103 | "MEMORY_SIZE = 60000 \n",
104 | "OUTPUT_SIZE = 4\n",
105 | "TOTAL_EPISODES = 10000\n",
106 | "MAX_STEPS = 50\n",
107 | "INITIAL_EPSILON = 1.0\n",
108 | "FINAL_EPSILON = 0.01\n",
109 | "GAMMA = 0.99\n",
110 | "INPUT_IMAGE_DIM = 84\n",
111 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | " Save Dictionay Function "
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 5,
124 | "metadata": {
125 | "collapsed": true
126 | },
127 | "outputs": [],
128 | "source": [
129 | "def save_obj(obj, name ):\n",
130 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
131 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | " Experience Replay "
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "metadata": {
145 | "collapsed": true
146 | },
147 | "outputs": [],
148 | "source": [
149 | "class Memory():\n",
150 | " \n",
151 | " def __init__(self,memsize):\n",
152 | " self.memsize = memsize\n",
153 | " self.memory = deque(maxlen=self.memsize)\n",
154 | " \n",
155 | " def add_sample(self,sample):\n",
156 | " self.memory.append(sample)\n",
157 | " \n",
158 | " def get_batch(self,size):\n",
159 | " return random.sample(self.memory,k=size)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | " Preprocess Images "
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 5,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "def preprocess_image(image):\n",
176 | "    image = rgb2gray(image) # rgb2gray also rescales pixel values to floats in the range 0 - 1\n",
177 | " return np.copy(image)"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 7,
183 | "metadata": {},
184 | "outputs": [
185 | {
186 | "data": {
187 | "text/plain": [
188 | ""
189 | ]
190 | },
191 | "execution_count": 7,
192 | "metadata": {},
193 | "output_type": "execute_result"
194 | },
195 | {
196 | "data": {
197 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADVNJREFUeJzt3X+oX/V9x/Hna4nW1m71VxYyo4ulosjA6C6ZYhmbmi1xRfdHEUVGGYL/dJtdC61usFLYHy2Mtv4xCkHbueH8UVtXCY2dSy1jMKLxx1o1WqONNUGTWHXtHGxL+94f3yO7zRLvubnne+89fp4PuHy/53x/nM/h8Pqe8z3fc9/vVBWS2vILSz0ASYvP4EsNMvhSgwy+1CCDLzXI4EsNMvhSgxYU/CSbkjybZHeSm4YalKTpyrFewJNkBfB9YCOwF3gEuLaqnh5ueJKmYeUCXrsB2F1VLwAkuQu4Cjhq8E877bRat27dAhYp6e3s2bOHV199NXM9byHBPx14adb0XuA33u4F69atY+fOnQtYpKS3MzMz0+t5Uz+5l+SGJDuT7Dx48OC0Fyeph4UEfx9wxqzptd28n1NVW6pqpqpmVq1atYDFSRrKQoL/CHB2krOSHA9cA9w/zLAkTdMxf8evqkNJ/gj4FrAC+HJVPTXYyCRNzUJO7lFV3wS+OdBYJC0Sr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYt6N9yl4NkzrqC0rK1VG3q3eNLDTL4UoPmDH6SLyc5kOTJWfNOSfJgkue625OnO0xJQ+qzx/8bYNNh824CtlfV2cD2blrSSMwZ/Kr6Z+C1w2ZfBdze3b8d+P2BxyVpio71O/7qqnq5u/8KsHqg8UhaBAs+uVeT3yOO+puEnXSk5edYg78/yRqA7vbA0Z5oJx1p+TnW4N8PfKS7/xHgG8MMR9Ji6PNz3p3AvwLnJNmb5Hrgs8DGJM8Bl3fTkkZizkt2q+raozx02cBjkbRIvHJPapDBlxpk8KUGGXypQQZfapDBlxo0+go8LdmwYcNg7/Xwww8P9l4aH/f4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPqW3zkjyUJKnkzyV5MZuvt10pJHqs8c/BHyiqs4DLgI+muQ87KYjjVafTjovV9Vj3f2fALuA07GbjjRa8/qOn2QdcAGwg57ddGyoIS0/vYOf5L3A14CPVdWPZz/2dt10bKghLT+9gp/kOCahv6Oqvt7N7t1NR9Ly0uesfoDbgF1V9flZD9lNRxqpPhV4LgH+APhekie6eX/GpHvOPV1nnReBq6czRElD69NJ51+AHOVhu+lII+SVe1KDRl9sc9u2bYO8z+bNmwd5n2myQKaG4h5fapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9am5d0KSh5P8W9dJ5zPd/LOS7EiyO8ndSY6f/nAlDaHPHv+/gEur6nxgPbApyUXA54AvVNUHgNeB66c3TElD6tNJp6rqP7rJ47q/Ai4F7u3m20lHGpG+dfVXdBV2DwAPAs8Db1TVoe4pe5m01TrSa+2kIy0zvWruVdVPgfVJTgLuA87tu4Cq2gJsAZiZmTlit52FGEOtvKFs2LBhsPeyfl/b5nVWv6reAB4CLgZOSvLWB8daYN/AY5M0JX3O6q/q9vQkeTewkUnH3IeAD3dPs5OONCJ9DvXXALcnWcHkg+Keqtqa5GngriR/CTzOpM2WpBHo00nnu0xaYx8+/wVguC+dkhaNV+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6ld7S8mC5LA3FPb7UIIMvNah38LsS248n2dpN20lHGqn57PFv
ZFJk8y120pFGqm9DjbXA7wG3dtPBTjrSaPXd438R+CTws276VOykI41Wn7r6HwIOVNWjx7KAqtpSVTNVNbNq1apjeQtJA+vzO/4lwJVJrgBOAH4JuIWuk06317eTjjQifbrl3lxVa6tqHXAN8O2qug476UijtZDf8T8FfDzJbibf+e2kI43EvC7ZrarvAN/p7ttJRxopr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkDX3NIht27YN9l6bN28e7L10ZO7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/f8ZPsAX4C/BQ4VFUzSU4B7gbWAXuAq6vq9ekMU9KQ5rPH/+2qWl9VM930TcD2qjob2N5NSxqBhRzqX8WkkQbYUEMalb7BL+Afkzya5IZu3uqqerm7/wqwevDRSZqKvtfqf7Cq9iX5ZeDBJM/MfrCqKkkd6YXdB8UNAGeeeeaCBitpGL32+FW1r7s9ANzHpLru/iRrALrbA0d5rZ10pGWmTwutE5P84lv3gd8BngTuZ9JIA2yoIY1Kn0P91cB9kwa5rAT+vqoeSPIIcE+S64EXgaunN0xJQ5oz+F3jjPOPMP9HwGXTGJSk6fLKPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYKf5KQk9yZ5JsmuJBcnOSXJg0me625PnvZgJQ2j7x7/FuCBqjqXSRmuXdhJRxqtPlV23wf8JnAbQFX9d1W9gZ10pNHqs8c/CzgIfCXJ40lu7cps20lHGqk+wV8JXAh8qaouAN7ksMP6qiombbb+nyQ3JNmZZOfBgwcXOl5JA+hTV38vsLeqdnTT9zIJ/v4ka6rq5bk66QBbAGZmZo744aDxu+6665Z6CJqHOff4VfUK8FKSc7pZlwFPYycdabT6Ns38Y+COJMcDLwB/yORDw0460gj1Cn5VPQHMHOEhO+lII+SVe1KDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKD5qzA09Xau3vWrPcDfwH8bTd/HbAHuLqqXh9+iG9v27Ztg7zP5s2bB3mfVr322mtLPQTNQ59im89W1fqqWg/8OvCfwH3YSUcarfke6l8GPF9VL2InHWm05hv8a4A7u/t20pFGqnfwu9LaVwJfPfwxO+lI4zKfPf5m4LGq2t9N7+866DBXJ52qmqmqmVWrVi1stJIGMZ/gX8v/HeaDnXSk0eoV/K477kbg67NmfxbYmOQ54PJuWtII9O2k8yZw6mHzfoSddKRR8so9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9rtxbzjZt2jTI+0z+wVBqg3t8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZca1Lf01p8meSrJk0nuTHJCkrOS7EiyO8ndXRVeSSMwZ/CTnA78CTBTVb8GrGBSX/9zwBeq6gPA68D10xyopOH0PdRfCbw7yUrgPcDLwKXAvd3jdtKRRqRP77x9wF8BP2QS+H8HHgXeqKpD3dP2AqdPa5CShtXnUP9kJn3yzgJ+BTgR6H2BvJ10pOWnz6H+5cAPqupgVf0Pk9r6lwAndYf+AGuBfUd6sZ10pOWnT/B/CFyU5D1JwqSW/tPAQ8CHu+fYSUcakT7f8XcwOYn3GPC97jVbgE8BH0+ym0mzjdumOE5JA+rbSefTwKcPm/0CsGHwEUmaOq/ckxpk8KUGGXypQQZfalAWs8hkkoPAm8Cri7bQ6TsN12e5eietC/Rbn1+tqjkvmFnU4AMk2VlVM4u60ClyfZavd9K6wLDr46G+1CCDLzVoKYK/ZQmWOU2uz/L1TloXGHB9Fv07vqSl56G+1KBFDX6STUme7er03bSYy16oJGckeSjJ0139wRu7+ackeTDJc93tyUs91vlIsiLJ40m2dtOjraWY
5KQk9yZ5JsmuJBePeftMs9blogU/yQrgr4HNwHnAtUnOW6zlD+AQ8ImqOg+4CPhoN/6bgO1VdTawvZsekxuBXbOmx1xL8Rbggao6FzifyXqNcvtMvdZlVS3KH3Ax8K1Z0zcDNy/W8qewPt8ANgLPAmu6eWuAZ5d6bPNYh7VMwnApsBUIkwtEVh5pmy3nP+B9wA/ozlvNmj/K7cOklN1LwClM/ot2K/C7Q22fxTzUf2tF3jLaOn1J1gEXADuA1VX1cvfQK8DqJRrWsfgi8EngZ930qYy3luJZwEHgK91Xl1uTnMhIt09NudalJ/fmKcl7ga8BH6uqH89+rCYfw6P4mSTJh4ADVfXoUo9lICuBC4EvVdUFTC4N/7nD+pFtnwXVupzLYgZ/H3DGrOmj1ulbrpIcxyT0d1TV17vZ+5Os6R5fAxxYqvHN0yXAlUn2AHcxOdy/hZ61FJehvcDemlSMgknVqAsZ7/ZZUK3LuSxm8B8Bzu7OSh7P5ETF/Yu4/AXp6g3eBuyqqs/Peuh+JjUHYUS1B6vq5qpaW1XrmGyLb1fVdYy0lmJVvQK8lOScbtZbtSFHuX2Ydq3LRT5hcQXwfeB54M+X+gTKPMf+QSaHid8Fnuj+rmDyvXg78BzwT8ApSz3WY1i33wK2dvffDzwM7Aa+Crxrqcc3j/VYD+zsttE/ACePefsAnwGeAZ4E/g5411Dbxyv3pAZ5ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlB/wsPrvc22jHaHAAAAABJRU5ErkJggg==\n",
198 | "text/plain": [
199 | ""
200 | ]
201 | },
202 | "metadata": {},
203 | "output_type": "display_data"
204 | }
205 | ],
206 | "source": [
207 | "processed_prev_state = preprocess_image(prev_state)\n",
208 | "plt.imshow(processed_prev_state,cmap='gray')"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | " Build Model "
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 8,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "Network(\n",
228 | " (conv_layer1): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))\n",
229 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
230 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
231 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n",
232 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
233 | " (relu): ReLU()\n",
234 | ")\n"
235 | ]
236 | }
237 | ],
238 | "source": [
239 | "import torch.nn as nn\n",
240 | "import torch\n",
241 | "\n",
242 | "class Network(nn.Module):\n",
243 | " \n",
244 | " def __init__(self,image_input_size,out_size):\n",
245 | " super(Network,self).__init__()\n",
246 | " self.image_input_size = image_input_size\n",
247 | " self.out_size = out_size\n",
248 | "\n",
249 | " self.conv_layer1 = nn.Conv2d(in_channels=1,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n",
250 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n",
251 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n",
252 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n",
253 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
254 | " self.relu = nn.ReLU()\n",
255 | "\n",
256 | " def forward(self,x,bsize):\n",
257 | " x = x.view(bsize,1,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
258 | " conv_out = self.conv_layer1(x)\n",
259 | " conv_out = self.relu(conv_out)\n",
260 | " conv_out = self.conv_layer2(conv_out)\n",
261 | " conv_out = self.relu(conv_out)\n",
262 | " conv_out = self.conv_layer3(conv_out)\n",
263 | " conv_out = self.relu(conv_out)\n",
264 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n",
265 | " out = self.relu(out)\n",
266 | " out = self.fc2(out)\n",
267 | " return out\n",
268 | "\n",
269 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE)\n",
270 | "print(main_model)"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | " Deep Q Learning with Target Freeze "
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [
285 | {
286 | "name": "stdout",
287 | "output_type": "stream",
288 | "text": [
289 | "Populated 200 Samples\n"
290 | ]
291 | }
292 | ],
293 | "source": [
294 | "mem = Memory(memsize=MEMORY_SIZE)\n",
295 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n",
296 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n",
297 | "\n",
298 | "target_model.load_state_dict(main_model.state_dict())\n",
299 | "criterion = nn.SmoothL1Loss()\n",
300 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
301 | "\n",
302 | "# filling memory with transitions\n",
303 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
304 | " \n",
305 | " prev_state = env.reset()\n",
306 | " processed_prev_state = preprocess_image(prev_state)\n",
307 | " step_count = 0\n",
308 | " game_over = False\n",
309 | " \n",
310 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
311 | " \n",
312 | " step_count +=1\n",
313 | " action = np.random.randint(0,4)\n",
314 | " next_state,reward, game_over = env.step(action)\n",
315 | " processed_next_state = preprocess_image(next_state)\n",
316 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n",
317 | " \n",
318 | " prev_state = next_state\n",
319 | " processed_prev_state = processed_next_state\n",
320 | "\n",
321 | "print('Populated %d Samples'%(len(mem.memory)))\n",
322 | "\n",
323 | "# Algorithm Starts\n",
324 | "total_steps = 0\n",
325 | "epsilon = INITIAL_EPSILON\n",
326 | "loss_stat = []\n",
327 | "total_reward_stat = []\n",
328 | "\n",
329 | "for episode in range(0,TOTAL_EPISODES):\n",
330 | " \n",
331 | " prev_state = env.reset()\n",
332 | " processed_prev_state = preprocess_image(prev_state)\n",
333 | " game_over = False\n",
334 | " step_count = 0\n",
335 | " total_reward = 0\n",
336 | " \n",
337 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
338 | " \n",
339 | " step_count +=1\n",
340 | " total_steps +=1\n",
341 | " \n",
342 | " if np.random.rand() <= epsilon:\n",
343 | " action = np.random.randint(0,4)\n",
344 | " else:\n",
345 | " with torch.no_grad():\n",
346 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n",
347 | "\n",
348 | " model_out = main_model.forward(torch_x,bsize=1)\n",
349 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
350 | " \n",
351 | " next_state, reward, game_over = env.step(action)\n",
352 | " processed_next_state = preprocess_image(next_state)\n",
353 | " total_reward += reward\n",
354 | " \n",
355 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n",
356 | " \n",
357 | " prev_state = next_state\n",
358 | " processed_prev_state = processed_next_state\n",
359 | " \n",
360 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
361 | " target_model.load_state_dict(main_model.state_dict())\n",
362 | " \n",
363 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
364 | " current_states = []\n",
365 | " next_states = []\n",
366 | " acts = []\n",
367 | " rewards = []\n",
368 | " game_status = []\n",
369 | " \n",
370 | " for element in batch:\n",
371 | " current_states.append(element[0])\n",
372 | " acts.append(element[1])\n",
373 | " rewards.append(element[2])\n",
374 | " next_states.append(element[3])\n",
375 | " game_status.append(element[4])\n",
376 | " \n",
377 | " current_states = np.array(current_states)\n",
378 | " next_states = np.array(next_states)\n",
379 | " rewards = np.array(rewards)\n",
380 | " game_status = [not b for b in game_status]\n",
381 | "    game_status_bool = np.array(game_status,dtype='float') # game_over False -> 1.0, True -> 0.0 (zeroes out Q of terminal next states)\n",
382 | " torch_acts = torch.tensor(acts)\n",
383 | " \n",
384 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
385 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
386 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
387 | " Q_max_next = Q_max_next.double()\n",
388 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
389 | " \n",
390 | " target_values = (rewards + (GAMMA * Q_max_next)).cuda()\n",
391 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
392 | " \n",
393 | " loss = criterion(Q_s_a,target_values.float())\n",
394 | " \n",
395 | " # save performance measure\n",
396 | " loss_stat.append(loss.item())\n",
397 | " \n",
398 | " # make previous grad zero\n",
399 | " optimizer.zero_grad()\n",
400 | " \n",
401 | "    # back-propagate\n",
402 | " loss.backward()\n",
403 | " \n",
404 | " # update params\n",
405 | " optimizer.step()\n",
406 | " \n",
407 | " # save performance measure\n",
408 | " total_reward_stat.append(total_reward)\n",
409 | " \n",
410 | " if epsilon > FINAL_EPSILON:\n",
411 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
412 | " \n",
413 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
414 | " perf = {}\n",
415 | " perf['loss'] = loss_stat\n",
416 | " perf['total_reward'] = total_reward_stat\n",
417 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n",
418 | " \n",
419 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)"
420 | ]
421 | },
422 | {
423 | "cell_type": "markdown",
424 | "metadata": {},
425 | "source": [
426 | " Save Primary Network Weights "
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 18,
432 | "metadata": {
433 | "collapsed": true
434 | },
435 | "outputs": [],
436 | "source": [
437 | "torch.save(main_model.state_dict(),'data/MDP_ENV_SIZE_NINE_WEIGHTS.torch')"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {},
443 | "source": [
444 | " Testing Policy "
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "metadata": {},
450 | "source": [
451 | " Load Primary Network Weights "
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": 9,
457 | "metadata": {},
458 | "outputs": [],
459 | "source": [
460 | "weights = torch.load('data/MDP_ENV_SIZE_NINE_WEIGHTS.torch', map_location='cpu')\n",
461 | "main_model.load_state_dict(weights)"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | " Test Policy "
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "# Algorithm Starts\n",
478 | "epsilon = INITIAL_EPSILON\n",
479 | "FINAL_EPSILON = 0.01\n",
480 | "total_reward_stat = []\n",
481 | "\n",
482 | "for episode in range(0,TOTAL_EPISODES):\n",
483 | " \n",
484 | " prev_state = env.reset()\n",
485 | " processed_prev_state = preprocess_image(prev_state)\n",
486 | " game_over = False\n",
487 | " step_count = 0\n",
488 | " total_reward = 0\n",
489 | " \n",
490 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
491 | " \n",
492 | " step_count +=1\n",
493 | " \n",
494 | " if np.random.rand() <= epsilon:\n",
495 | " action = np.random.randint(0,4)\n",
496 | " else:\n",
497 | " with torch.no_grad():\n",
498 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n",
499 | "\n",
500 | " model_out = main_model.forward(torch_x,bsize=1)\n",
501 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
502 | " \n",
503 | " next_state, reward, game_over = env.step(action)\n",
504 | " processed_next_state = preprocess_image(next_state)\n",
505 | " total_reward += reward\n",
506 | " \n",
507 | " prev_state = next_state\n",
508 | " processed_prev_state = processed_next_state\n",
509 | " \n",
510 | " # save performance measure\n",
511 | " total_reward_stat.append(total_reward)\n",
512 | " \n",
513 | " if epsilon > FINAL_EPSILON:\n",
514 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
515 | " \n",
516 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
517 | " perf = {}\n",
518 | " perf['total_reward'] = total_reward_stat\n",
519 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n",
520 | " \n",
521 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "metadata": {},
527 | "source": [
528 | " Create Policy GIF "
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {},
534 | "source": [
535 | " Collect Frames Of an Episode Using Trained Network "
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 28,
541 | "metadata": {},
542 | "outputs": [
543 | {
544 | "name": "stdout",
545 | "output_type": "stream",
546 | "text": [
547 | "Total Reward : 14\n"
548 | ]
549 | }
550 | ],
551 | "source": [
552 | "frames = []\n",
553 | "random.seed(110)\n",
554 | "np.random.seed(110)\n",
555 | "\n",
556 | "for episode in range(0,1):\n",
557 | " \n",
558 | " prev_state = env.reset()\n",
559 | " processed_prev_state = preprocess_image(prev_state)\n",
560 | " frames.append(prev_state)\n",
561 | " game_over = False\n",
562 | " step_count = 0\n",
563 | " total_reward = 0\n",
564 | " \n",
565 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
566 | " \n",
567 | " step_count +=1\n",
568 | " \n",
569 | " with torch.no_grad():\n",
570 | " torch_x = torch.from_numpy(processed_prev_state).float()\n",
571 | " model_out = main_model.forward(torch_x,bsize=1)\n",
572 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
573 | " \n",
574 | " next_state, reward, game_over = env.step(action)\n",
575 | " frames.append(next_state)\n",
576 | " processed_next_state = preprocess_image(next_state)\n",
577 | " total_reward += reward\n",
578 | " \n",
579 | " prev_state = next_state\n",
580 | " processed_prev_state = processed_next_state\n",
581 | "\n",
582 | "print('Total Reward : %d'%(total_reward)) # This should output same value which verifies seed is working correctly\n",
583 | " "
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 29,
589 | "metadata": {},
590 | "outputs": [],
591 | "source": [
592 | "from PIL import Image, ImageDraw\n",
593 | "\n",
594 | "for idx, img in enumerate(frames):\n",
595 | " image = Image.fromarray(img)\n",
596 | " drawer = ImageDraw.Draw(image)\n",
597 | " drawer.rectangle([(7,7),(76,76)], outline=(255, 255, 0))\n",
598 | " #plt.imshow(np.array(image))\n",
599 | " frames[idx] = np.array(image)\n",
600 | " "
601 | ]
602 | },
603 | {
604 | "cell_type": "markdown",
605 | "metadata": {},
606 | "source": [
607 | " Frames to GIF "
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 30,
613 | "metadata": {},
614 | "outputs": [
615 | {
616 | "name": "stderr",
617 | "output_type": "stream",
618 | "text": [
619 | "/home/mayank/miniconda3/envs/rdqn/lib/python3.5/site-packages/ipykernel_launcher.py:5: DeprecationWarning: `imresize` is deprecated!\n",
620 | "`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
621 | "Use ``skimage.transform.resize`` instead.\n",
622 | " \"\"\"\n"
623 | ]
624 | }
625 | ],
626 | "source": [
627 | "import imageio\n",
628 | "from scipy.misc import imresize\n",
629 | "resized_frames = []\n",
630 | "for frame in frames:\n",
631 | " resized_frames.append(imresize(frame,(256,256)))\n",
632 | "imageio.mimsave('data/GIFs/MDP_SIZE_9.gif',resized_frames,fps=4)"
633 | ]
634 | }
635 | ],
636 | "metadata": {
637 | "kernelspec": {
638 | "display_name": "Python 3",
639 | "language": "python",
640 | "name": "python3"
641 | },
642 | "language_info": {
643 | "codemirror_mode": {
644 | "name": "ipython",
645 | "version": 3
646 | },
647 | "file_extension": ".py",
648 | "mimetype": "text/x-python",
649 | "name": "python",
650 | "nbconvert_exporter": "python",
651 | "pygments_lexer": "ipython3",
652 | "version": "3.5.6"
653 | }
654 | },
655 | "nbformat": 4,
656 | "nbformat_minor": 2
657 | }
658 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recurrent-Deep-Q-Learning
2 |
3 | # Introduction
4 | Partially Observable Markov Decision Process (POMDP) is a generalization of the Markov Decision Process where the agent cannot directly observe
5 | the underlying state and only an observation is available. Earlier methods suggest maintaining a belief (a pmf) over all the possible states, which encodes the probability of being in each state. This quickly limits the size of the problem to which we can apply this method. However, the paper [Playing Atari with Deep Reinforcement Learning](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) presented an approach which uses the last 4 observations as input to the learning
6 | algorithm, which can be seen as a 4th-order Markov decision process. Many papers suggest that better performance can be obtained if we use more than the last 4 frames, but this is expensive from a computational and storage point of view (experience replay). Recurrent networks can be used to summarize what the agent has seen in past observations. **In this project, I investigated this using a simple Partially Observable Environment and found that using a single recurrent layer achieves much better performance than using the last k frames.**
7 |
8 | # Environment
9 |
10 | The environment used was a 9x9 grid where *red* colored blocks represent the agent's location in the grid. *Green* colored blocks are the goal points for the agent. *Blue* colored blocks are the blocks which the agent needs to avoid. A reward of +1 is given to the agent when it eats a *green* block. A reward of -1 is given to the agent when it eats a *blue* block. Other movements result in zero reward.
11 | The observation is the RGB image of the neighbouring cells of the agent. The figure below describes the observation.
12 |
13 | Underlying MDP | Observation to Agent
14 | :-------------------------:|:-------------------------:
15 | .png) | 
16 |
17 | # Algorithm
18 |
19 |
20 |
21 | # How to Run ?
22 |
23 | I ran the experiment for the following cases. The corresponding code/jupyter files are linked to each experiment.
24 | * [MDP Case](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/MDP_Size_9.ipynb) - The underlying state was fully visible. The whole grid was given as the input to the agent.
25 | * [Single Observation](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/Single%20Observation.ipynb) - In this case, the most recent observation was used as the input to agent.
26 | * [Last Two Observations](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/Two%20Observations.ipynb) - In this case, the last two most recent observations were used as the input to the agent to encode the temporal information among observations.
27 | * [LSTM Case](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/LSTM%2C%20BPTT%3D8.ipynb) - In this case, an LSTM layer is used to pass the temporal information among observations.
28 |
29 | # Learned Policies
30 |
31 | Fully Observable | Single Observation | LSTM
32 | :-------------------------:|:-------------------------:|:-------------------------:
33 |  |  | 
34 |
35 | # Results
36 |
37 | The figure given below compares the performance of the different cases. The MDP case is the best we can do, as the underlying state is fully visible to the agent. However, the challenge is to perform better given only an observation. The graph clearly shows that the LSTM consistently performed better, as the total reward per episode was much higher than when using the last k frames.
38 |
39 |
40 |
41 | # References
42 | * [Deep Recurrent Q Learning for POMDPs](https://arxiv.org/pdf/1507.06527.pdf)
43 | * [Recurrent Neural Networks for Reinforcement Learning: an Investigation of Relevant Design Choices](https://esc.fnwi.uva.nl/thesis/centraal/files/f499544468.pdf)
44 | * [Grid World POMDP Environment](https://github.com/awjuliani/DeepRL-Agents)
45 |
46 | # Requirements
47 | * Python >= 3.5
48 | * PyTorch >= 0.4.1
49 |
--------------------------------------------------------------------------------
/Two Observations.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from gridworld import gameEnv\n",
12 | "import numpy as np\n",
13 | "%matplotlib inline\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "from collections import deque\n",
16 | "import pickle\n",
17 | "from skimage.color import rgb2gray\n",
18 | "import random\n",
19 | "import torch\n",
20 | "import torch.nn as nn"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | " Define Environment Object "
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "data": {
37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADKdJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUOtFERUtYBbQiLoRYRAURVVSNykLTSREmivIvUikaokXFSRECRFFeVPCDTIikipQ1RVqhyOgSZgQ2wIBFuA3RRKSqW2Tt5ezDg9ce2cOT67e87w+36k1e7M7Gp+o/GzMzue876pKiS15RdWegCSZs/gSw0y+FKDDL7UIIMvNcjgSw0y+FKDlhX8JFcmeS7J3iS3TGpQkqYrx3sDT5I1wPeArcA+4HHg+qraNbnhSZqGtcv47PuAvVX1AkCSe4FrgGMG//TTT6+5ubllrPLtb+fOnSs9BI1cVWWx9ywn+GcCLy+Y3gf85s/7wNzcHPPz88tY5dtfsug+k5Zt6hf3ktyYZD7J/MGDB6e9OkkDLCf4+4GzFkxv7Of9jKq6vaq2VNWWM844YxmrkzQpywn+48A5STYlORG4Dnh4MsOSNE3H/Ru/qg4l+SPgG8Aa4EtV9czERiZpapZzcY+q+jrw9QmNRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBFg5/kS0kOJHl6wbx1SR5Nsqd/PnW6w5Q0SUOO+H8NXHnEvFuA7VV1DrC9n5Y0EosGv6r+Efi3I2ZfA9zVv74L+P0Jj0vSFB3vb/z1VfVK//pVYP2ExiNpBpZ9ca+6rpvH7LxpJx1p9Tne4L+WZANA/3zgWG+0k460+hxv8B8GPtq//ijwtckMR9IsDPnvvHuAfwbOTbIvyQ3AZ4CtSfYAV/TTkkZi0U46VXX9MRZdPuGxSJoR79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGjSk9NZZSR5LsivJM0lu6ufbTUcaqSFH/EPAJ6rqfOBi4GNJzsduOtJoDemk80pVPdG//hGwGzgTu+lIo7Wk3/hJ5oALgR0M7KZjQw1p9Rkc/CTvBr4K3FxVby5c9vO66dhQQ1p9BgU/yQl0ob+7qh7sZw/upiNpdRlyVT/AncDuqvrcgkV205FGatGGGsClwB8A303yVD/vz+i659zfd9Z5Cbh2OkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0a8vf4mqmjVjDTTx3rL8S1FB7xpQYZfKlBQ2runZTk20n+pe+k8+l+/qYkO5LsTXJfkhOnP1xJkzDkiP9fwGVVdQGwGbgyycXAZ4HPV9V7gdeBG6Y3TEmTNKSTTlXVf/STJ/SPAi4DHujn20lHGpGhdfXX9BV2DwCPAs8Db1TVof4t++jaah3ts3bSkVaZQcGvqh9X1WZgI/A+4LyhK7CTjrT6LOmqflW9ATwGXAKckuTwfQAbgf0THpukKRlyVf+MJKf0r98JbKXrmPsY8OH+bXbSkUZkyJ17G4C7kqyh+6K4v6q2JdkF3JvkL4An6dpsSRqBIZ10vkPXGvvI+S/Q/d6XNDLeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aHDw+xLbTybZ1k/bSUcaqaUc8W+iK7J5mJ10pJEa2lBjI/BB4I5+OthJRxqtoUf8LwCfBH7ST5+GnXSk0RpSV/9DwIGq2nk8K7CTjrT6DKmrfylwd
ZKrgJOAXwJuo++k0x/17aQjjciQbrm3VtXGqpoDrgO+WVUfwU460mgt5//xPwV8PMleut/8dtKRRmLIqf5PVdW3gG/1r+2kI42Ud+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aFAhjiQvAj8CfgwcqqotSdYB9wFzwIvAtVX1+nSGKWmSlnLE/52q2lxVW/rpW4DtVXUOsL2fljQCyznVv4aukQbYUEMalaHBL+Dvk+xMcmM/b31VvdK/fhVYP/HRSZqKocU2319V+5P8MvBokmcXLqyqSlJH+2D/RXEjwNlnn72swUqajEFH/Kra3z8fAB6iq677WpINAP3zgWN81k460iozpIXWyUl+8fBr4HeBp4GH6RppgA01pFEZcqq/Hnioa5DLWuBvq+qRJI8D9ye5AXgJuHZ6w5Q0SYsGv2+cccFR5v8QuHwag5I0Xd65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVo6F/naWay0gNQAzziSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEHBT3JKkgeSPJtkd5JLkqxL8miSPf3zqdMerKTJGHrEvw14pKrOoyvDtRs76UijNaTK7nuA3wLuBKiq/66qN7CTjjRaQ474m4CDwJeTPJnkjr7Mtp10pJEaEvy1wEXAF6vqQuAtjjitr6qia7P1/yS5Mcl8kvmDBw8ud7ySJmBI8PcB+6pqRz/9AN0XgZ10pJFaNPhV9SrwcpJz+1mXA7uwk440WkP/LPePgbuTnAi8APwh3ZeGnXSkERoU/Kp6CthylEV20pFGyDv3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYNqat/bpKnFjzeTHKznXSk8RpSbPO5qtpcVZuB3wD+E3gIO+lIo7XUU/3Lgeer6iXspCON1lKDfx1wT//aTjrSSA0Ofl9a+2rgK0cus5OONC5LOeJ/AHiiql7rp+2kI43UUoJ/Pf93mg920pFGa1Dw++64W4EHF8z+DLA1yR7gin5a0ggM7aTzFnDaEfN+iJ10pFHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQUNLb/1pkmeSPJ3kniQnJdmUZEeSvUnu66vwShqBIS20zgT+BNhSVb8OrKGrr/9Z4PNV9V7gdeCGaQ5U0uQMPdVfC7wzyVrgXcArwGXAA/1yO+lIIzKkd95+4C+BH9AF/t+BncAbVXWof9s+4MxpDVLSZA051T+Vrk/eJuBXgJOBK4euwE460uoz5FT/CuD7VXWwqv6Hrrb+pcAp/ak/wEZg/9E+bCcdafUZEvwfABcneVeS0NXS3wU8Bny4f4+ddKQRGfIbfwfdRbwngO/2n7kd+BTw8SR76Zpt3DnFcUqaoHSNbmdjy5YtNT8/P7P1jVF3UiUdv6pa9B+Rd+5JDTL4UoMMvtQggy81aKYX95IcBN4C/nVmK52+03F7Vqu307bAsO351apa9IaZmQYfIMl8VW2Z6UqnyO1Zvd5O2wKT3R5P9aUGGXypQSsR/NtXYJ3T5PasXm+nbYEJbs/Mf+NLWnme6ksNmmnwk1yZ5Lm+Tt8ts1z3ciU5K8ljSXb19Qdv6uevS/Jokj3986krPdalSLImyZNJtvXTo62lmOSUJA8keTbJ7iSXjHn/TLPW5cyCn2QN8FfAB4DzgeuTnD+r9U/AIeATVXU+cDHwsX78twDbq+ocYHs/PSY3AbsXTI+5luJtwCNVdR5wAd12jXL/TL3WZVXN5AFcAnxjwfStwK2zWv8UtudrwFbgOWBDP28D8NxKj20J27CRLgyXAduA0N0gsvZo+2w1P4D3AN+nv261YP4o9w9dKbuXgXV0NS+3Ab83q
f0zy1P9wxty2Gjr9CWZAy4EdgDrq+qVftGrwPoVGtbx+ALwSeAn/fRpjLeW4ibgIPDl/qfLHUlOZqT7p6Zc69KLe0uU5N3AV4Gbq+rNhcuq+xoexX+TJPkQcKCqdq70WCZkLXAR8MWqupDu1vCfOa0f2f5ZVq3Lxcwy+PuBsxZMH7NO32qV5AS60N9dVQ/2s19LsqFfvgE4sFLjW6JLgauTvAjcS3e6fxsDaymuQvuAfdVVjIKuatRFjHf/LKvW5WJmGfzHgXP6q5In0l2oeHiG61+Wvt7gncDuqvrcgkUP09UchBHVHqyqW6tqY1XN0e2Lb1bVRxhpLcWqehV4Ocm5/azDtSFHuX+Ydq3LGV+wuAr4HvA88OcrfQFliWN/P91p4neAp/rHVXS/i7cDe4B/ANat9FiPY9t+G9jWv/414NvAXuArwDtWenxL2I7NwHy/j/4OOHXM+wf4NPAs8DTwN8A7JrV/vHNPapAX96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxr0v95G4v3/3GYuAAAAAElFTkSuQmCC\n",
38 | "text/plain": [
39 | ""
40 | ]
41 | },
42 | "metadata": {},
43 | "output_type": "display_data"
44 | }
45 | ],
46 | "source": [
47 | "env = gameEnv(partial=True,size=9)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "execution_count": 3,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | },
65 | {
66 | "data": {
67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB6
5NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k
9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n",
68 | "text/plain": [
69 | ""
70 | ]
71 | },
72 | "metadata": {},
73 | "output_type": "display_data"
74 | }
75 | ],
76 | "source": [
77 | "prev_state = env.reset()\n",
78 | "plt.imshow(prev_state)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | " Training Q Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | " Hyper-parameters "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {
99 | "collapsed": true
100 | },
101 | "outputs": [],
102 | "source": [
103 | "BATCH_SIZE = 32\n",
104 | "FREEZE_INTERVAL = 20000 # steps\n",
105 | "MEMORY_SIZE = 60000 \n",
106 | "OUTPUT_SIZE = 4\n",
107 | "TOTAL_EPISODES = 10000\n",
108 | "MAX_STEPS = 50\n",
109 | "INITIAL_EPSILON = 1.0\n",
110 | "FINAL_EPSILON = 0.1\n",
111 | "GAMMA = 0.99\n",
112 | "INPUT_IMAGE_DIM = 84\n",
113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | " Save Dictionary Function "
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 6,
126 | "metadata": {
127 | "collapsed": true
128 | },
129 | "outputs": [],
130 | "source": [
131 | "def save_obj(obj, name ):\n",
132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n",
133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | " Experience Replay "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 7,
146 | "metadata": {
147 | "collapsed": true
148 | },
149 | "outputs": [],
150 | "source": [
151 | "class Memory():\n",
152 | " \n",
153 | " def __init__(self,memsize):\n",
154 | " self.memsize = memsize\n",
155 | " self.memory = deque(maxlen=self.memsize)\n",
156 | " \n",
157 | " def add_sample(self,sample):\n",
158 | " self.memory.append(sample)\n",
159 | " \n",
160 | " def get_batch(self,size):\n",
161 | " return random.sample(self.memory,k=size)"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | " Frame Collector "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 5,
174 | "metadata": {
175 | "collapsed": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "class FrameCollector():\n",
180 | " \n",
181 | " def __init__(self,num_frames,img_dim):\n",
182 | " self.num_frames = num_frames\n",
183 | " self.img_dim = img_dim\n",
184 | " self.frames = deque(maxlen=self.num_frames)\n",
185 | " \n",
186 | " def reset(self):\n",
187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n",
188 | " for i in range(0,self.num_frames):\n",
189 | " self.frames.append(tmp)\n",
190 | " \n",
191 | " def add_frame(self,frame):\n",
192 | " self.frames.append(frame)\n",
193 | " \n",
194 | " def get_state(self):\n",
195 | " return np.array(self.frames)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | " Preprocess Images "
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 6,
208 | "metadata": {
209 | "collapsed": true
210 | },
211 | "outputs": [],
212 | "source": [
213 | "def preprocess_image(image):\n",
214 | "    image = rgb2gray(image) # this automatically scales the color values of each block to between 0 and 1\n",
215 | " return np.copy(image)"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 7,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | ""
227 | ]
228 | },
229 | "execution_count": 7,
230 | "metadata": {},
231 | "output_type": "execute_result"
232 | },
233 | {
234 | "data": {
235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3V2MXPV5x/Hvr14ICUljm7SWi0kxigVCVTGRlYLggpLSEhpBLqIUlEhpldY3qUraSsG0Fy2VIiVSlYSLKpIFSVGV8hKHJhYXSV2HpL1ysDFtwcbBJBBs+YUKyNsFqsPTizluF7p4zu7O7O7h//1Iq5lz5uX8j45+c15m9nlSVUhqyy8s9wAkLT2DLzXI4EsNMvhSgwy+1CCDLzXI4EsNWlTwk1yf5FCSw0m2TWpQkqYrC/0BT5JVwPeA64AjwCPALVV1YHLDkzQNM4t47XuAw1X1fYAk9wE3Aa8b/CT+TFCasqrKuOcs5lD/fOC5WdNHunmSVrjF7PF7SbIV2Drt5UjqbzHBPwpcMGt6QzfvVapqO7AdPNSXVorFHOo/AmxKsjHJ2cDNwM7JDEvSNC14j19Vp5L8MfBNYBXwxap6YmIjkzQ1C/46b0EL81BfmrppX9WXNFAGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFjg5/ki0lOJnl81ry1SXYleaq7XTPdYUqapD57/L8Hrn/NvG3A7qraBOzupiUNxNjgV9W/Ai+8ZvZNwD3d/XuAD0x4XJKmaKHn+Ouq6lh3/ziwbkLjkbQEFt1Jp6rqTNVz7aQjrTwL3eOfSLIeoLs9+XpPrKrtVbWlqrYscFmSJmyhwd8JfLS7/1Hg65MZjqSlMLahRpJ7gWuAdwAngL8CvgY8ALwTeBb4UFW99gLgXO9lQw1pyvo01LCTjvQGYycdSXMy+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3q00nngiQPJzmQ5Ikkt3bz7aYjDVSfmnvrgfVV9WiStwH7GDXQ+H3ghar6dJJtwJqqum3Me1l6S5qyiZTeqqpjVfVod/8nwEHgfOymIw3WvBpqJLkQuBzYQ89uOjbUkFae3lV2k7wV+A7wqap6MMlLVbV61uMvVtUZz/M91Jemb2JVdpOcBXwV+HJVPdjN7t1NR9LK0ueqfoC7gYNV9dlZD9lNRxqoPlf1rwb+DfhP4JVu9l8wOs+fVzcdD/Wl6bOTjtQgO+lImpPBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxo0r5p7mr6l/DfpIRrVhdFiuceXGmTwpQb1qbl3TpLvJvn3rpPOHd38jUn2JDmc5P4kZ09/uJImoc8e/2Xg2qq6DNgMXJ/kCuAzwOeq6l3Ai8DHpjdMSZPUp5NOVdVPu8mzur8CrgV2dPPtpCMNSN+6+quSPMaodv4u4Gngpao61T3lCKO2WnO9dmuSvUn2TmLAkhavV/Cr6udVtRnYALwHuKTvAqpqe1VtqaotCxyjpAmb11X9qnoJeBi4Elid5PTvADYARyc8NklT0ueq/i8lWd3dfzNwHaOOuQ8DH+yeZicdaUD6dNL5dUYX71Yx+qB4oKr+JslFwH3AWmA/8JGqennMe/mztDH85d6Z+cu98eykM0AG/8wM/nh20pE0J4MvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoN7B70ps70/yUDdtJx1poOazx7+VUZHN0+ykIw1U34YaG4DfBe7qpoOddKTB6rvH/zzwSeCVbvo87KQjDVafuvrvB05W1b6FLMBOOtLKMzP+KVwF
3JjkBuAc4BeBO+k66XR7fTvpSAPSp1vu7VW1oaouBG4GvlVVH8ZOOtJgLeZ7/NuAP0tymNE5/92TGZKkabOTzgpjJ50zs5POeHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoP61NwjyTPAT4CfA6eqakuStcD9wIXAM8CHqurF6QxT0iTNZ4//m1W1eVa13G3A7qraBOzupiUNwGIO9W9i1EgDbKghDUrf4Bfwz0n2JdnazVtXVce6+8eBdRMfnaSp6HWOD1xdVUeT/DKwK8mTsx+sqnq9QprdB8XWuR6TtDzmXWU3yV8DPwX+CLimqo4lWQ98u6ouHvNaS8iOYZXdM7PK7ngTqbKb5Nwkbzt9H/ht4HFgJ6NGGmBDDWlQxu7xk1wE/FM3OQP8Y1V9Ksl5wAPAO4FnGX2d98KY93J3NoZ7/DNzjz9enz2+DTVWGIN/ZgZ/PBtqSJqTwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ3//O0xLxl2laCu7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mdZEeSJ5McTHJlkrVJdiV5qrtdM+3BSpqMvnv8O4FvVNUlwGXAQeykIw1Wn2KbbwceAy6qWU9OcgjLa0srzqRq7m0Enge+lGR/kru6Mtt20pEGqk/wZ4B3A1+oqsuBn/Gaw/ruSOB1O+kk2Ztk72IHK2ky+gT/CHCkqvZ00zsYfRCc6A7x6W5PzvXiqtpeVVtmddmVtMzGBr+qjgPPJTl9/v5e4AB20pEGq1dDjSSbgbuAs4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9ycZLHZv39OMkn7KQjDde8Sm8lWQUcBX4D+DjwQlV9Osk2YE1V3Tbm9ZbekqZsGqW33gs8XVXPAjcB93Tz7wE+MM/3krRM5hv8m4F7u/t20pEGqnfwk5wN3Ah85bWP2UlHGpb57PHfBzxaVSe6aTvpSAM1n+Dfwv8d5oOddKTB6ttJ51zgh4xaZf+om3cedtKRVhw76UgNspOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb1Cn6SP03yRJLHk9yb5JwkG5PsSXI4yf1dFV5JA9Cnhdb5wJ8AW6rq14BVjOrrfwb4XFW9C3gR+Ng0Byppcvoe6s8Ab04yA7wFOAZcC+zoHreTjjQgY4NfVUeBv2VUZfcY8CNgH/BSVZ3qnnYEOH9ag5Q0WX0O9dcw6pO3EfgV4Fzg+r4LsJOOtPLM9HjObwE/qKrnAZI8CFwFrE4y0+31NzDqovv/VNV2YHv3WstrSytAn3P8HwJXJHlLkjDqmHsAeBj4YPccO+lIA9K3k84dwO8Bp4D9wB8yOqe/D1jbzftIVb085n3c40tTZicdqUF20pE0J4MvNcjgSw0y+FKD+nyPP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akr1VtWVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7gb1+GZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOf5Pokh7o6fduWctmLleSCJA8nOdDVH7y1m782ya4kT3W3a5Z7rPORZFWS/Uke6qYHW0sxyeokO5I8meRgkiuHvH2mWetyyYKfZBXwd8D7gEuBW5JculTLn4BTwJ9X1aXAFcDHu/FvA3ZX1SZgdzc9JLcCB2dND7mW4p3AN6rqEuAyRus1yO0z9VqXVbUkf8CVwDdnTd8O3L5Uy5/C+nwduA44BKzv5q0HDi332OaxDhsYheFa4CEgjH4gMjPXNlvJf8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7
k9o+S3mof3pFThtsnb4kFwKXA3uAdVV1rHvoOLBumYa1EJ8HPgm80k2fx3BrKW4Enge+1J263JXkXAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsTXJEneD5ysqn3LPZYJmQHeDXyhqi5n9NPwVx3WD2z7LKrW5ThLGfyjwAWzpl+3Tt9KleQsRqH/clU92M0+kWR99/h64ORyjW+ergJuTPIMo0pK1zI6R17dlVGHYW2jI8CRqtrTTe9g9EEw1O3zv7Uuq+q/gVfVuuyes+Dts5TBfwTY1F2VPJvRhYqdS7j8RenqDd4NHKyqz856aCejmoMwoNqDVXV7VW2oqgsZbYtvVdWHGWgtxao6DjyX5OJu1unakIPcPky71uUSX7C4Afge8DTwl8t9AWWeY7+a0WHifwCPdX83MDov3g08BfwLsHa5x7qAdbsGeKi7fxHwXeAw8BXgTcs9vnmsx2Zgb7eNvgasGfL2Ae4AngQeB/4BeNOkto+/3JMa5MU9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBv0PANIhuBFMr1IAAAAASUVORK5CYII=\n",
236 | "text/plain": [
237 | ""
238 | ]
239 | },
240 | "metadata": {},
241 | "output_type": "display_data"
242 | }
243 | ],
244 | "source": [
245 | "processed_prev_state = preprocess_image(prev_state)\n",
246 | "plt.imshow(processed_prev_state,cmap='gray')"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | " Build Model "
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 8,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "Network(\n",
266 | " (conv_layer1): Conv2d(2, 32, kernel_size=(8, 8), stride=(4, 4))\n",
267 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
268 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
269 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n",
270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n",
271 | " (relu): ReLU()\n",
272 | ")\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "import torch.nn as nn\n",
278 | "import torch\n",
279 | "\n",
280 | "class Network(nn.Module):\n",
281 | " \n",
282 | " def __init__(self,image_input_size,out_size):\n",
283 | " super(Network,self).__init__()\n",
284 | " self.image_input_size = image_input_size\n",
285 | " self.out_size = out_size\n",
286 | "\n",
287 | " self.conv_layer1 = nn.Conv2d(in_channels=2,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n",
288 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n",
289 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n",
290 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n",
291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n",
292 | " self.relu = nn.ReLU()\n",
293 | "\n",
294 | " def forward(self,x,bsize):\n",
295 | " x = x.view(bsize,2,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n",
296 | " conv_out = self.conv_layer1(x)\n",
297 | " conv_out = self.relu(conv_out)\n",
298 | " conv_out = self.conv_layer2(conv_out)\n",
299 | " conv_out = self.relu(conv_out)\n",
300 | " conv_out = self.conv_layer3(conv_out)\n",
301 | " conv_out = self.relu(conv_out)\n",
302 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n",
303 | " out = self.relu(out)\n",
304 | " out = self.fc2(out)\n",
305 | " return out\n",
306 | "\n",
307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n",
308 | "print(main_model)"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {},
314 | "source": [
315 | " Deep Q Learning with Freeze Network "
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "Populated 60000 Samples in Episodes : 1200\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "mem = Memory(memsize=MEMORY_SIZE)\n",
333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n",
334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n",
335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=2)\n",
336 | "\n",
337 | "target_model.load_state_dict(main_model.state_dict())\n",
338 | "criterion = nn.SmoothL1Loss()\n",
339 | "optimizer = torch.optim.Adam(main_model.parameters())\n",
340 | "\n",
341 | "# filling memory with transitions\n",
342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n",
343 | " \n",
344 | " prev_state = env.reset()\n",
345 | " frameObj.reset()\n",
346 | " processed_prev_state = preprocess_image(prev_state)\n",
347 | " frameObj.add_frame(processed_prev_state)\n",
348 | " prev_frames = frameObj.get_state()\n",
349 | " step_count = 0\n",
350 | " game_over = False\n",
351 | " \n",
352 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
353 | " \n",
354 | " step_count +=1\n",
355 | " action = np.random.randint(0,4)\n",
356 | " next_state,reward, game_over = env.step(action)\n",
357 | " processed_next_state = preprocess_image(next_state)\n",
358 | " frameObj.add_frame(processed_next_state)\n",
359 | " next_frames = frameObj.get_state()\n",
360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
361 | " \n",
362 | " prev_state = next_state\n",
363 | " processed_prev_state = processed_next_state\n",
364 | " prev_frames = next_frames\n",
365 | "\n",
366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n",
367 | "\n",
368 | "\n",
369 | "# Algorithm Starts\n",
370 | "total_steps = 0\n",
371 | "epsilon = INITIAL_EPSILON\n",
372 | "loss_stat = []\n",
373 | "total_reward_stat = []\n",
374 | "\n",
375 | "for episode in range(0,TOTAL_EPISODES):\n",
376 | " \n",
377 | " prev_state = env.reset()\n",
378 | " frameObj.reset()\n",
379 | " processed_prev_state = preprocess_image(prev_state)\n",
380 | " frameObj.add_frame(processed_prev_state)\n",
381 | " prev_frames = frameObj.get_state()\n",
382 | " game_over = False\n",
383 | " step_count = 0\n",
384 | " total_reward = 0\n",
385 | " \n",
386 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
387 | " \n",
388 | " step_count +=1\n",
389 | " total_steps +=1\n",
390 | " \n",
391 | " if np.random.rand() <= epsilon:\n",
392 | " action = np.random.randint(0,4)\n",
393 | " else:\n",
394 | " with torch.no_grad():\n",
395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
396 | "\n",
397 | " model_out = main_model.forward(torch_x,bsize=1)\n",
398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
399 | " \n",
400 | " next_state, reward, game_over = env.step(action)\n",
401 | " processed_next_state = preprocess_image(next_state)\n",
402 | " frameObj.add_frame(processed_next_state)\n",
403 | " next_frames = frameObj.get_state()\n",
404 | " total_reward += reward\n",
405 | " \n",
406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n",
407 | " \n",
408 | " prev_state = next_state\n",
409 | " processed_prev_state = processed_next_state\n",
410 | " prev_frames = next_frames\n",
411 | " \n",
412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n",
413 | " target_model.load_state_dict(main_model.state_dict())\n",
414 | " \n",
415 | " batch = mem.get_batch(size=BATCH_SIZE)\n",
416 | " current_states = []\n",
417 | " next_states = []\n",
418 | " acts = []\n",
419 | " rewards = []\n",
420 | " game_status = []\n",
421 | " \n",
422 | " for element in batch:\n",
423 | " current_states.append(element[0])\n",
424 | " acts.append(element[1])\n",
425 | " rewards.append(element[2])\n",
426 | " next_states.append(element[3])\n",
427 | " game_status.append(element[4])\n",
428 | " \n",
429 | " current_states = np.array(current_states)\n",
430 | " next_states = np.array(next_states)\n",
431 | " rewards = np.array(rewards)\n",
432 | " game_status = [not b for b in game_status]\n",
433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n",
434 | " torch_acts = torch.tensor(acts)\n",
435 | " \n",
436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n",
437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n",
438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n",
439 | " Q_max_next = Q_max_next.double()\n",
440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n",
441 | " \n",
442 | " target_values = (rewards + (GAMMA * Q_max_next))\n",
443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n",
444 | " \n",
445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n",
446 | " \n",
447 | " # save performance measure\n",
448 | " loss_stat.append(loss.item())\n",
449 | " \n",
450 | " # make previous grad zero\n",
451 | " optimizer.zero_grad()\n",
452 | " \n",
453 | "    # back-propagate \n",
454 | " loss.backward()\n",
455 | " \n",
456 | " # update params\n",
457 | " optimizer.step()\n",
458 | " \n",
459 | " # save performance measure\n",
460 | " total_reward_stat.append(total_reward)\n",
461 | " \n",
462 | " if epsilon > FINAL_EPSILON:\n",
463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
464 | " \n",
465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
466 | " perf = {}\n",
467 | " perf['loss'] = loss_stat\n",
468 | " perf['total_reward'] = total_reward_stat\n",
469 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n",
470 | " \n",
471 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | " Save Primary Network Weights "
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 19,
484 | "metadata": {
485 | "collapsed": true
486 | },
487 | "outputs": [],
488 | "source": [
489 | "torch.save(main_model.state_dict(),'data/TWO_OBSERV_NINE_WEIGHTS.torch')"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {},
495 | "source": [
496 | " Testing Policy "
497 | ]
498 | },
499 | {
500 | "cell_type": "markdown",
501 | "metadata": {},
502 | "source": [
503 | " Load Primary Network Weights "
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 10,
509 | "metadata": {},
510 | "outputs": [],
511 | "source": [
512 | "weights = torch.load('data/TWO_OBSERV_NINE_WEIGHTS.torch')\n",
513 | "main_model.load_state_dict(weights)"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | " Testing Policy "
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": null,
526 | "metadata": {
527 | "collapsed": true
528 | },
529 | "outputs": [],
530 | "source": [
531 | "# Algorithm Starts\n",
532 | "epsilon = INITIAL_EPSILON\n",
533 | "FINAL_EPSILON = 0.01\n",
534 | "total_reward_stat = []\n",
535 | "\n",
536 | "for episode in range(0,TOTAL_EPISODES):\n",
537 | " \n",
538 | " prev_state = env.reset()\n",
539 | " processed_prev_state = preprocess_image(prev_state)\n",
540 | " frameObj.reset()\n",
541 | " frameObj.add_frame(processed_prev_state)\n",
542 | " prev_frames = frameObj.get_state()\n",
543 | " game_over = False\n",
544 | " step_count = 0\n",
545 | " total_reward = 0\n",
546 | " \n",
547 | " while (game_over == False) and (step_count < MAX_STEPS):\n",
548 | " \n",
549 | " step_count +=1\n",
550 | " \n",
551 | " if np.random.rand() <= epsilon:\n",
552 | " action = np.random.randint(0,4)\n",
553 | " else:\n",
554 | " with torch.no_grad():\n",
555 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n",
556 | "\n",
557 | " model_out = main_model.forward(torch_x,bsize=1)\n",
558 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n",
559 | " \n",
560 | " next_state, reward, game_over = env.step(action)\n",
561 | " processed_next_state = preprocess_image(next_state)\n",
562 | " frameObj.add_frame(processed_next_state)\n",
563 | " next_frames = frameObj.get_state()\n",
564 | " \n",
565 | " total_reward += reward\n",
566 | " \n",
567 | " prev_state = next_state\n",
568 | " processed_prev_state = processed_next_state\n",
569 | " prev_frames = next_frames\n",
570 | " \n",
571 | " # save performance measure\n",
572 | " total_reward_stat.append(total_reward)\n",
573 | " \n",
574 | " if epsilon > FINAL_EPSILON:\n",
575 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n",
576 | " \n",
577 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n",
578 | " perf = {}\n",
579 | " perf['total_reward'] = total_reward_stat\n",
580 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n",
581 | " \n",
582 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)"
583 | ]
584 | }
585 | ],
586 | "metadata": {
587 | "kernelspec": {
588 | "display_name": "Python [conda env:myenv]",
589 | "language": "python",
590 | "name": "conda-env-myenv-py"
591 | },
592 | "language_info": {
593 | "codemirror_mode": {
594 | "name": "ipython",
595 | "version": 3
596 | },
597 | "file_extension": ".py",
598 | "mimetype": "text/x-python",
599 | "name": "python",
600 | "nbconvert_exporter": "python",
601 | "pygments_lexer": "ipython3",
602 | "version": "3.6.5"
603 | }
604 | },
605 | "nbformat": 4,
606 | "nbformat_minor": 2
607 | }
608 |
--------------------------------------------------------------------------------
/__pycache__/gridworld.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/gridworld.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/gridworld.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/gridworld.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/helper.cpython-36.pyc
--------------------------------------------------------------------------------
/data/FOUR_OBSERV_NINE.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/FOUR_OBSERV_NINE.pkl
--------------------------------------------------------------------------------
/data/FOUR_OBSERV_NINE_WEIGHTS.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/FOUR_OBSERV_NINE_WEIGHTS.torch
--------------------------------------------------------------------------------
/data/GIFs/LSTM_SIZE_9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9.gif
--------------------------------------------------------------------------------
/data/GIFs/LSTM_SIZE_9_frames.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9_frames.gif
--------------------------------------------------------------------------------
/data/GIFs/LSTM_SIZE_9_local.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9_local.gif
--------------------------------------------------------------------------------
/data/GIFs/MDP_SIZE_9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/MDP_SIZE_9.gif
--------------------------------------------------------------------------------
/data/GIFs/SINGEL_SIZE_9_local.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGEL_SIZE_9_local.gif
--------------------------------------------------------------------------------
/data/GIFs/SINGLE_OBSERV_9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGLE_OBSERV_9.gif
--------------------------------------------------------------------------------
/data/GIFs/SINGLE_SIZE_9_frames.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGLE_SIZE_9_frames.gif
--------------------------------------------------------------------------------
/data/GIFs/perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/perf.png
--------------------------------------------------------------------------------
/data/LSTM_POMDP_V4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4.pkl
--------------------------------------------------------------------------------
/data/LSTM_POMDP_V4_TEST.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4_TEST.pkl
--------------------------------------------------------------------------------
/data/LSTM_POMDP_V4_WEIGHTS.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4_WEIGHTS.torch
--------------------------------------------------------------------------------
/data/MDP_ENV_SIZE_NINE.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/MDP_ENV_SIZE_NINE.pkl
--------------------------------------------------------------------------------
/data/MDP_ENV_SIZE_NINE_WEIGHTS.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/MDP_ENV_SIZE_NINE_WEIGHTS.torch
--------------------------------------------------------------------------------
/data/SINGLE_OBSERV_NINE.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/SINGLE_OBSERV_NINE.pkl
--------------------------------------------------------------------------------
/data/SINGLE_OBSERV_NINE_WEIGHTS.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/SINGLE_OBSERV_NINE_WEIGHTS.torch
--------------------------------------------------------------------------------
/data/TWO_OBSERV_NINE.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/TWO_OBSERV_NINE.pkl
--------------------------------------------------------------------------------
/data/TWO_OBSERV_NINE_WEIGHTS.torch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/TWO_OBSERV_NINE_WEIGHTS.torch
--------------------------------------------------------------------------------
/data/algo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/algo.png
--------------------------------------------------------------------------------
/data/download (1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/download (1).png
--------------------------------------------------------------------------------
/data/download.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/download.png
--------------------------------------------------------------------------------
/gridworld.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import itertools
4 | import scipy.misc
5 | import matplotlib.pyplot as plt
6 |
7 |
class gameOb():
    """A single object placed on the grid: the hero, a goal, or a fire.

    Attributes:
        x, y: grid coordinates (taken from the ``coordinates`` pair).
        size: side length of the object in grid cells.
        intensity: pixel intensity used when the object is rendered.
        channel: RGB channel index the object is drawn into.
        reward: reward granted when the hero steps on it (None for the hero).
        name: object kind identifier ('hero', 'goal', or 'fire').
    """

    def __init__(self, coordinates, size, intensity, channel, reward, name):
        # Unpack the (x, y) position pair.
        self.x, self.y = coordinates[0], coordinates[1]
        self.size = size
        self.intensity = intensity
        self.channel = channel
        self.reward = reward
        self.name = name
17 |
class gameEnv():
    """Gridworld environment with one hero, four goals (+1) and two fires (-1).

    The episode never terminates from within the environment itself:
    ``checkGoal`` always reports done=False, so callers bound episodes with a
    step limit. When ``partial`` is True, observations are a 3x3 egocentric
    crop around the hero instead of the full grid.
    """

    def __init__(self, partial, size):
        """Create a size x size grid and render the initial state.

        Args:
            partial: if True, observations are a local 3x3 view of the hero.
            size: number of cells per side of the (square) grid.
        """
        self.sizeX = size
        self.sizeY = size
        self.actions = 4
        self.objects = []
        self.partial = partial
        a = self.reset()
        # Side effect: draws the initial frame on the current matplotlib axes.
        plt.imshow(a, interpolation="nearest")

    def reset(self):
        """Repopulate the grid with a fresh hero, 4 goals and 2 fires.

        Returns the rendered initial observation.
        """
        self.objects = []
        hero = gameOb(self.newPosition(), 1, 1, 2, None, 'hero')
        self.objects.append(hero)
        bug = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug)
        hole = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole)
        bug2 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug2)
        hole2 = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole2)
        bug3 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug3)
        bug4 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug4)
        state = self.renderEnv()
        self.state = state
        return state

    def moveChar(self, direction):
        """Move the hero one cell. Returns the movement penalty (always 0.0).

        Actions: 0 - up, 1 - down, 2 - left, 3 - right. Moves that would
        leave the grid are ignored.
        """
        hero = self.objects[0]
        heroX = hero.x
        heroY = hero.y
        penalize = 0.
        if direction == 0 and hero.y >= 1:
            hero.y -= 1
        if direction == 1 and hero.y <= self.sizeY - 2:
            hero.y += 1
        if direction == 2 and hero.x >= 1:
            hero.x -= 1
        if direction == 3 and hero.x <= self.sizeX - 2:
            hero.x += 1
        # Wall bumps are detectable here, but the bump penalty is disabled
        # (kept at 0.0) in this variant of the environment.
        if hero.x == heroX and hero.y == heroY:
            penalize = 0.0
        self.objects[0] = hero
        return penalize

    def newPosition(self):
        """Return a random (x, y) cell not occupied by any existing object."""
        iterables = [range(self.sizeX), range(self.sizeY)]
        points = []
        for t in itertools.product(*iterables):
            points.append(t)
        currentPositions = []
        for objectA in self.objects:
            if (objectA.x, objectA.y) not in currentPositions:
                currentPositions.append((objectA.x, objectA.y))
        for pos in currentPositions:
            points.remove(pos)
        location = np.random.choice(range(len(points)), replace=False)
        return points[location]

    def checkGoal(self):
        """Check whether the hero sits on a goal or fire.

        On contact the touched object is consumed, a same-type replacement is
        spawned at a free cell, and the object's reward is returned. The
        'done' flag is always False: this environment has no terminal state.

        Returns:
            (reward, done): reward is +1, -1 or 0.0; done is always False.
        """
        others = []
        for obj in self.objects:
            if obj.name == 'hero':
                hero = obj
            else:
                others.append(obj)
        for other in others:
            if hero.x == other.x and hero.y == other.y:
                self.objects.remove(other)
                if other.reward == 1:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 1, 1, 'goal'))
                else:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 0, -1, 'fire'))
                return other.reward, False
        # No contact this step.
        return 0.0, False

    def renderEnv(self):
        """Render the grid to an 84x84x3 image (white border, black interior).

        If ``self.partial`` is True, the image is first cropped to the 3x3
        neighbourhood of the hero before upscaling.
        """
        a = np.ones([self.sizeY + 2, self.sizeX + 2, 3])
        a[1:-1, 1:-1, :] = 0
        hero = None
        for item in self.objects:
            a[item.y + 1:item.y + item.size + 1, item.x + 1:item.x + item.size + 1, item.channel] = item.intensity
            if item.name == 'hero':
                hero = item
        if self.partial == True:
            # The +2 border keeps this crop in-bounds at grid edges.
            a = a[hero.y:hero.y + 3, hero.x:hero.x + 3, :]
        # NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and
        # removed in SciPy >= 1.3; this code requires an old SciPy (plus PIL).
        # A drop-in replacement would be PIL.Image.resize or skimage.
        b = scipy.misc.imresize(a[:, :, 0], [84, 84, 1], interp='nearest')
        c = scipy.misc.imresize(a[:, :, 1], [84, 84, 1], interp='nearest')
        d = scipy.misc.imresize(a[:, :, 2], [84, 84, 1], interp='nearest')
        a = np.stack([b, c, d], axis=2)
        return a

    def step(self, action):
        """Advance the environment by one action.

        Returns:
            (state, reward, done): rendered observation, reward plus movement
            penalty (penalty is always 0.0), and done (always False).
        """
        penalty = self.moveChar(action)
        reward, done = self.checkGoal()
        state = self.renderEnv()
        # checkGoal always returns a numeric reward, so the legacy
        # 'reward == None' debug branch was dead code and has been removed.
        return state, (reward + penalty), done
--------------------------------------------------------------------------------