├── .ipynb_checkpoints ├── Four Observations-checkpoint.ipynb ├── LSTM, BPTT=8-checkpoint.ipynb ├── MDP_Size_9-checkpoint.ipynb ├── Single Observation-checkpoint.ipynb └── Two Observations-checkpoint.ipynb ├── Four Observations.ipynb ├── LSTM, BPTT=8.ipynb ├── MDP_Size_9.ipynb ├── README.md ├── Single Observation.ipynb ├── Two Observations.ipynb ├── __pycache__ ├── gridworld.cpython-35.pyc ├── gridworld.cpython-36.pyc └── helper.cpython-36.pyc ├── data ├── .ipynb_checkpoints │ └── Check Performance-checkpoint.ipynb ├── Check Performance.ipynb ├── FOUR_OBSERV_NINE.pkl ├── FOUR_OBSERV_NINE_WEIGHTS.torch ├── GIFs │ ├── LSTM_SIZE_9.gif │ ├── LSTM_SIZE_9_frames.gif │ ├── LSTM_SIZE_9_local.gif │ ├── MDP_SIZE_9.gif │ ├── SINGEL_SIZE_9_local.gif │ ├── SINGLE_OBSERV_9.gif │ ├── SINGLE_SIZE_9_frames.gif │ └── perf.png ├── LSTM_POMDP_V4.pkl ├── LSTM_POMDP_V4_TEST.pkl ├── LSTM_POMDP_V4_WEIGHTS.torch ├── MDP_ENV_SIZE_NINE.pkl ├── MDP_ENV_SIZE_NINE_WEIGHTS.torch ├── SINGLE_OBSERV_NINE.pkl ├── SINGLE_OBSERV_NINE_WEIGHTS.torch ├── TWO_OBSERV_NINE.pkl ├── TWO_OBSERV_NINE_WEIGHTS.torch ├── algo.png ├── download (1).png └── download.png └── gridworld.py /.ipynb_checkpoints/Four Observations-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from gridworld import gameEnv\n", 12 | "import numpy as np\n", 13 | "%matplotlib inline\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from collections import deque\n", 16 | "import pickle\n", 17 | "from skimage.color import rgb2gray\n", 18 | "import random\n", 19 | "import torch\n", 20 | "import torch.nn as nn" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "

Define Environment Object

" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxp
k8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB65NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2d
ND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n", 38 | "text/plain": [ 39 | "
" 40 | ] 41 | }, 42 | "metadata": {}, 43 | "output_type": "display_data" 44 | } 45 | ], 46 | "source": [ 47 | "env = gameEnv(partial=True,size=9)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | }, 65 | { 66 | "data": { 67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLlJREFUeJzt3X/oXfV9x/Hna4nW1m41cS5kxs2UiiIDowlOsYxNzWZd0f1RRCmjDMF/uk3XQqvbH6WwP1oYbf1jFETbyXD+qNU1hGLn0pQyGKlff6zVRJtoY01QEzudnYNtad/745xsX0OynG++997v9/h5PuBy7znn3pzP4fC659yT832/U1VIassvLPUAJM2ewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSvJckj1Jbp3UoCRNV070Bp4kK4AfApuBfcBjwA1VtXNyw5M0DSsX8dmLgT1V9QJAkvuAa4FjBj+JtwlqUTZu3LjUQ1jW9u7dy2uvvZbjvW8xwT8TeGne9D7gNxfx70nHNTc3t9RDWNY2bdo06H2LCf4gSW4Cbpr2eiQNt5jg7wfOmje9rp/3NlV1B3AHeKovLReLuar/GHBOkvVJTgauB7ZMZliSpumEj/hVdSjJHwPfAlYAX6mqZyY2MklTs6jf+FX1TeCbExqLpBnxzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQccNfpKvJDmQ5Ol581YneTTJ7v551XSHKWmShhzx/wa46oh5twLbquocYFs/LWkkjhv8qvou8K9HzL4WuLt/fTfwBxMel6QpOtHf+Guq6uX+9SvAmgmNR9IMLLqTTlXV/9cow0460vJzokf8V5OsBeifDxzrjVV1R1VtqqphTb0kTd2JBn8L8LH+9ceAb0xmOJJmYch/590L/DNwbpJ9SW4EPgdsTrIbuLKfljQSx/2NX1U3HGPRFRMei6QZ8c49qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFDSm+dlWR7kp1Jnklycz/fbjrSSA054h8CPllV5wOXAB9Pcj5205FGa0gnnZer6on+9U+BXcCZ2E1HGq0FNdRIcjZwIbCDgd10bKghLT+DL+4leS/wdeCWqnpz/rKqKuCo3XRsqCEtP4OCn+QkutDfU1UP9bMHd9ORtLwMuaof4C5gV1V9Yd4iu+lIIzXkN/5lwB8CP0jyVD/vz+m65zzQd9Z5EbhuOkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw1aUM29RdsIzM10jeOTo1YwkybKI77UIIMvNWhIzb1Tknwvyb/0nXQ+289fn2RHkj1J7k9y8vSHK2kShhzx/xO4vKouADYAVyW5BPg88MWq+gD
wOnDj9IYpaZKGdNKpqvr3fvKk/lHA5cCD/Xw76UgjMrSu/oq+wu4B4FHgeeCNqjrUv2UfXVuto332piRzSeY4OIkhS1qsQcGvqp9V1QZgHXAxcN7QFbytk84ZJzhKSRO1oKv6VfUGsB24FDgtyeH7ANYB+yc8NklTMuSq/hlJTutfvxvYTNcxdzvwkf5tdtKRRmTInXtrgbuTrKD7onigqrYm2Qncl+QvgSfp2mxJGoEhnXS+T9ca+8j5L9D93pc0Mt65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVocPD7EttPJtnaT9tJRxqphRzxb6YrsnmYnXSkkRraUGMd8PvAnf10sJOONFpDj/hfAj4F/LyfPh076UijNaSu/oeBA1X1+ImswE460vIzpK7+ZcA1Sa4GTgF+CbidvpNOf9S3k440IkO65d5WVeuq6mzgeuDbVfVR7KQjjdZi/h//08Ankuyh+81vJx1pJIac6v+vqvoO8J3+tZ10pJHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGDSrEkWQv8FPgZ8ChqtqUZDVwP3A2sBe4rqpen84wJU3SQo74v1NVG6pqUz99K7Ctqs4BtvXTkkZgMaf619I10gAbakijMjT4BfxDkseT3NTPW1NVL/evXwHWTHx0kqZiaLHND1bV/iS/Ajya5Nn5C6uqktTRPth/UXRfFr+2mKFKmpRBR/yq2t8/HwAepquu+2qStQD984FjfNZOOtIyM6SF1qlJfvHwa+B3gaeBLXSNNMCGGtKoDDnVXwM83DXIZSXwd1X1SJLHgAeS3Ai8CFw3vWFKmqTjBr9vnHHBUeb/BLhiGoOSNF3euSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aOhf503ERjYyx9wsVzk+R/0bR2myPOJLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgQcFPclqSB5M8m2RXkkuTrE7yaJLd/fOqaQ9W0mQMPeLfDjxSVefRleHahZ10pNEaUmX3fcBvAXcBVNV/VdUb2ElHGq0hR/z1wEHgq0meTHJnX2bbTjrSSA0J/krgIuDLVXUh8BZHnNZXVXGMu8yT3JRkLsncwYMHFzteSRMwJPj7gH1VtaOffpDui2DBnXTOOMNWOtJycNzgV9UrwEtJzu1nXQHsxE460mgN/bPcPwHuSXIy8ALwR3RfGnbSkUZoUPCr6ilg01EW2UlHGiHv3JMaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaNKSu/rlJnpr3eDPJLXbSkcZrSLHN56pqQ1VtADYC/wE8jJ10pNFa6Kn+FcDzVfUidtKRRmuhwb8euLd/bScdaaQGB78vrX0N8LUjl9lJRxqXhRzxPwQ8UVWv9tN20pFGaiHBv4H/O80HO+lIozUo+H133M3AQ/Nmfw7YnGQ3cGU/LWkEhnbSeQs4/Yh5P8FOOtIoeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KChpbf+LMkzSZ5Ocm+SU5KsT7IjyZ4k9/dVeCWNwJAWWmcCfwpsqqrfAFbQ1df/PPDFqvoA8Dpw4zQHKmlyhp7qrwTenWQl8B7gZeBy4MF+uZ10pBEZ0jtvP/BXwI/pAv9vwOPAG1V1qH/bPuDMaQ1S0mQNOdVfRdcnbz3wq8CpwFVDV2AnHWn5GXKqfyXwo6o6WFX/TVdb/zLgtP7UH2AdsP9oH7aTjrT8DAn+j4FLkrwnSehq6e8EtgMf6d9jJx1pRIb8xt9BdxHvCeAH/WfuAD4NfCLJHrpmG3dNcZySJmh
oJ53PAJ85YvYLwMUTH5GkqfPOPalBBl9qkMGXGmTwpQalqma3suQg8Bbw2sxWOn2/jNuzXL2TtgWGbc+vV9Vxb5iZafABksxV1aaZrnSK3J7l6520LTDZ7fFUX2qQwZcatBTBv2MJ1jlNbs/y9U7aFpjg9sz8N76kpeepvtSgmQY/yVVJnuvr9N06y3UvVpKzkmxPsrOvP3hzP391kkeT7O6fVy31WBciyYokTybZ2k+PtpZiktOSPJjk2SS7klw65v0zzVqXMwt+khXAXwMfAs4Hbkhy/qzWPwGHgE9W1fnAJcDH+/HfCmyrqnOAbf30mNwM7Jo3PeZaircDj1TVecAFdNs1yv0z9VqXVTWTB3Ap8K1507cBt81q/VPYnm8Am4HngLX9vLXAc0s9tgVswzq6MFwObAVCd4PIyqPts+X8AN4H/Ij+utW8+aPcP3Sl7F4CVtP9Fe1W4PcmtX9meap/eEMOG22dviRnAxcCO4A1VfVyv+gVYM0SDetEfAn4FPDzfvp0xltLcT1wEPhq/9PlziSnMtL9U1OudenFvQVK8l7g68AtVfXm/GXVfQ2P4r9JknwYOFBVjy/1WCZkJXAR8OWqupDu1vC3ndaPbP8sqtbl8cwy+PuBs+ZNH7NO33KV5CS60N9TVQ/1s19NsrZfvhY4sFTjW6DLgGuS7AXuozvdv52BtRSXoX3AvuoqRkFXNeoixrt/FlXr8nhmGfzHgHP6q5In012o2DLD9S9KX2/wLmBXVX1h3qItdDUHYUS1B6vqtqpaV1Vn0+2Lb1fVRxlpLcWqegV4Kcm5/azDtSFHuX+Ydq3LGV+wuBr4IfA88BdLfQFlgWP/IN1p4veBp/rH1XS/i7cBu4F/BFYv9VhPYNt+G9jav34/8D1gD/A14F1LPb4FbMcGYK7fR38PrBrz/gE+CzwLPA38LfCuSe0f79yTGuTFPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb9DzfZzH2Ei/mJAAAAAElFTkSuQmCC\n", 68 | "text/plain": [ 69 | "
" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "BATCH_SIZE = 32\n", 104 | "FREEZE_INTERVAL = 20000 # steps\n", 105 | "MEMORY_SIZE = 60000 \n", 106 | "OUTPUT_SIZE = 4\n", 107 | "TOTAL_EPISODES = 10000\n", 108 | "MAX_STEPS = 50\n", 109 | "INITIAL_EPSILON = 1.0\n", 110 | "FINAL_EPSILON = 0.1\n", 111 | "GAMMA = 0.99\n", 112 | "INPUT_IMAGE_DIM = 84\n", 113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "

Save Dictionary Function

" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": { 127 | "collapsed": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def save_obj(obj, name ):\n", 132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "

Experience Replay

" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": { 147 | "collapsed": true 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "class Memory():\n", 152 | " \n", 153 | " def __init__(self,memsize):\n", 154 | " self.memsize = memsize\n", 155 | " self.memory = deque(maxlen=self.memsize)\n", 156 | " \n", 157 | " def add_sample(self,sample):\n", 158 | " self.memory.append(sample)\n", 159 | " \n", 160 | " def get_batch(self,size):\n", 161 | " return random.sample(self.memory,k=size)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "

Frame Collector

" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "class FrameCollector():\n", 180 | " \n", 181 | " def __init__(self,num_frames,img_dim):\n", 182 | " self.num_frames = num_frames\n", 183 | " self.img_dim = img_dim\n", 184 | " self.frames = deque(maxlen=self.num_frames)\n", 185 | " \n", 186 | " def reset(self):\n", 187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n", 188 | " for i in range(0,self.num_frames):\n", 189 | " self.frames.append(tmp)\n", 190 | " \n", 191 | " def add_frame(self,frame):\n", 192 | " self.frames.append(frame)\n", 193 | " \n", 194 | " def get_state(self):\n", 195 | " return np.array(self.frames)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "

Preprocess Images

" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def preprocess_image(image):\n", 214 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 215 | " return np.copy(image)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 9, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "" 227 | ] 228 | }, 229 | "execution_count": 9, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | }, 233 | { 234 | "data": { 235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADM1JREFUeJzt3X/oXfV9x/Hna4nW1k5jnIbMyMxoUGRg1OAUy9jUbKktuj+KKDLKEPJPt+laaHX7oxT2RwujrcIoiLaT4fxRq2sIxc6lljEYqfHHWk1ME22sCWqi80fnYFva9/64J9u3WbKcb+6P7/f4eT7gcu85596cz+Hwuufc8z15v1NVSGrLLy30ACTNnsGXGmTwpQYZfKlBBl9qkMGXGmTwpQaNFfwkG5LsTLI7ya2TGpSk6crx3sCTZAnwI2A9sBd4ArihqrZPbniSpmHpGJ+9BNhdVS8CJLkfuBY4avCTeJugxnLxxRcv9BAWtT179vD666/nWO8bJ/hnAS/Pmd4L/OYY/550TNu2bVvoISxq69at6/W+cYLfS5KNwMZpr0dSf+MEfx9w9pzpVd28X1BVdwJ3gqf60mIxzlX9J4A1SVYnORG4Htg0mWFJmqbjPuJX1cEkfwR8B1gCfK2qnpvYyCRNzVi/8avq28C3JzQWSTPinXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg44Z/CRfS7I/ybNz5i1P8liSXd3zadMdpqRJ6nPE/2tgw2HzbgW2VNUaYEs3LWkgjhn8qvpH4F8Pm30tcE/3+h7g9yc8LklTdLy/8VdU1Svd61eBFRMaj6QZGLuTTlXV/9cow0460uJzvEf815KsBOie9x/tjVV1Z1Wtq6p+Tb0kTd3xBn8T8Inu9SeAb01mOJJmoc+f8+4D/hk4N8neJDcBXwDWJ9kFXNVNSxqIY/7Gr6objrLoygmPRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KA+pbfOTvJ4ku1JnktyczffbjrSQPU54h8EPl1V5wOXAp9Mcj5205EGq08nnVeq6qnu9U+BHcBZ2E1HGqx5NdRIcg5wIbCVnt10bKghLT69L+4l+SDwTeCWqnpn7rKqKuCI3XRsqCEtPr2Cn+QERqG/t6oe7mb37qYjaXHpc1U/wN3Ajqr60pxFdtORBqrPb/zLgT8AfpjkmW7enzHqnvN
g11nnJeC66QxR0qT16aTzT0COsthuOtIAeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVoXjX3xrVmzRruuOOOWa5ycG688caFHoIa4BFfapDBlxrUp+beSUm+n+Rfuk46n+/mr06yNcnuJA8kOXH6w5U0CX2O+P8BXFFVFwBrgQ1JLgW+CHy5qj4EvAncNL1hSpqkPp10qqr+rZs8oXsUcAXwUDffTjrSgPStq7+kq7C7H3gMeAF4q6oOdm/Zy6it1pE+uzHJtiTb3n777UmMWdKYegW/qn5WVWuBVcAlwHl9VzC3k86pp556nMOUNEnzuqpfVW8BjwOXAcuSHLoPYBWwb8JjkzQlfa7qn5FkWff6/cB6Rh1zHwc+3r3NTjrSgPS5c28lcE+SJYy+KB6sqs1JtgP3J/kL4GlGbbYkDUCfTjo/YNQa+/D5LzL6vS9pYLxzT2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ7+B3JbafTrK5m7aTjjRQ8zni38yoyOYhdtKRBqpvQ41VwEeBu7rpYCcdabD6HvG/AnwG+Hk3fTp20pEGq09d/Y8B+6vqyeNZgZ10pMWnT139y4FrklwNnAScAtxO10mnO+rbSUcakD7dcm+rqlVVdQ5wPfDdqroRO+lIgzXO3/E/C3wqyW5Gv/ntpCMNRJ9T/f9RVd8Dvte9tpOONFDeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgXoU4kuwBfgr8DDhYVeuSLAceAM4B9gDXVdWb0xmmpEmazxH/d6pqbVWt66ZvBbZU1RpgSzctaQDGOdW/llEjDbChhjQofYNfwN8neTLJxm7eiqp6pXv9KrBi4qOTNBV9i21+uKr2JTkTeCzJ83MXVlUlqSN9sPui2Ahw5plnjjVYSZPR64hfVfu65/3AI4yq676WZCVA97z/KJ+1k460yPRpoXVykl8+9Br4XeBZYBOjRhpgQw1pUPqc6q8AHhk1yGUp8LdV9WiSJ4AHk9wEvARcN71hSpqkYwa/a5xxwRHmvwFcOY1BSZou79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtT3f+dNxCmnnMKGDRtmucrBeeONNxZ6CGqAR3ypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUK/hJliV5KMnzSXYkuSzJ8iSPJdnVPZ827cFKmoy+R/zbgUer6jxGZbh2YCcdabD6VNk9Ffgt4G6AqvrPqnoLO+lIg9XniL8aOAB8PcnTSe7qymzbSUcaqD7BXwpcBHy1qi4E3uWw0/qqKkZttv6PJBuTbEuy7cCBA+OOV9IE9An+XmBvVW3tph9i9EUw7046Z5xxxiTGLGlMxwx+Vb0KvJzk3G7WlcB27KQjDVbf/5b7x8C9SU4EXgT+kNGXhp10pAHqFfyqegZYd4RFdtKRBsg796QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9amrf26SZ+Y83klyi510pOHqU2xzZ1Wtraq1wMXAvwOPYCcdabDme6p/JfBCVb2EnXSkwZpv8K8H7ute20lHGqjewe9Ka18DfOPwZXbSkYZlPkf8jwBPVdVr3bSddKSBmk/wb+B/T/PBTjrSYPUKftcddz3w8JzZXwDWJ9kFXNVNSxqAvp103gVOP2zeG9hJRxok79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtS39NafJnkuybNJ7ktyUpLVSbYm2Z3kga4Kr6QB6NNC6yzgT4B1VfUbwBJG9fW
/CHy5qj4EvAncNM2BSpqcvqf6S4H3J1kKfAB4BbgCeKhbbicdaUD69M7bB/wl8BNGgX8beBJ4q6oOdm/bC5w1rUFKmqw+p/qnMeqTtxr4VeBkYEPfFdhJR1p8+pzqXwX8uKoOVNV/MaqtfzmwrDv1B1gF7DvSh+2kIy0+fYL/E+DSJB9IEka19LcDjwMf795jJx1pQPr8xt/K6CLeU8APu8/cCXwW+FSS3Yyabdw9xXFKmqC+nXQ+B3zusNkvApdMfESSps4796QGGXypQQZfapDBlxqUqprdypIDwLvA6zNb6fT9Cm7PYvVe2hbotz2/VlXHvGFmpsEHSLKtqtbNdKVT5PYsXu+lbYHJbo+n+lKDDL7UoIUI/p0LsM5pcnsWr/fStsAEt2fmv/ElLTxP9aUGzTT4STYk2dnV6bt1luseV5KzkzyeZHtXf/Dmbv7yJI8l2dU9n7bQY52PJEuSPJ1kczc92FqKSZYleSjJ80l2JLlsyPtnmrUuZxb8JEuAvwI+ApwP3JDk/FmtfwIOAp+uqvOBS4FPduO/FdhSVWuALd30kNwM7JgzPeRaircDj1bVecAFjLZrkPtn6rUuq2omD+Ay4Dtzpm8DbpvV+qewPd8C1gM7gZXdvJXAzoUe2zy2YRWjMFwBbAbC6AaRpUfaZ4v5AZwK/JjuutWc+YPcP4xK2b0MLGf0v2g3A783qf0zy1P9QxtyyGDr9CU5B7gQ2AqsqKpXukWvAisWaFjH4yvAZ4Cfd9OnM9xaiquBA8DXu58udyU5mYHun5pyrUsv7s1Tkg8C3wRuqap35i6r0dfwIP5MkuRjwP6qenKhxzIhS4GLgK9W1YWMbg3/hdP6ge2fsWpdHsssg78POHvO9FHr9C1WSU5gFPp7q+rhbvZrSVZ2y1cC+xdqfPN0OXBNkj3A/YxO92+nZy3FRWgvsLdGFaNgVDXqIoa7f8aqdXksswz+E8Ca7qrkiYwuVGya4frH0tUbvBvYUVVfmrNoE6OagzCg2oNVdVtVraqqcxjti+9W1Y0MtJZiVb0KvJzk3G7WodqQg9w/TLvW5YwvWFwN/Ah4Afjzhb6AMs+xf5jRaeIPgGe6x9WMfhdvAXYB/wAsX+ixHse2/TawuXv968D3gd3AN4D3LfT45rEda4Ft3T76O+C0Ie8f4PPA88CzwN8A75vU/vHOPalBXtyTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9q0H8DxmXSvDxTWXAAAAAASUVORK5CYII=\n", 236 | "text/plain": [ 237 | "
" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "processed_prev_state = preprocess_image(prev_state)\n", 246 | "plt.imshow(processed_prev_state,cmap='gray')" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "

Build Model

" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 10, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Network(\n", 266 | " (conv_layer1): Conv2d(4, 64, kernel_size=(8, 8), stride=(4, 4))\n", 267 | " (conv_layer2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))\n", 268 | " (conv_layer3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", 269 | " (fc1): Linear(in_features=6272, out_features=512, bias=True)\n", 270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 271 | " (relu): ReLU()\n", 272 | ")\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "import torch.nn as nn\n", 278 | "import torch\n", 279 | "\n", 280 | "class Network(nn.Module):\n", 281 | " \n", 282 | " def __init__(self,image_input_size,out_size):\n", 283 | " super(Network,self).__init__()\n", 284 | " self.image_input_size = image_input_size\n", 285 | " self.out_size = out_size\n", 286 | "\n", 287 | " self.conv_layer1 = nn.Conv2d(in_channels=4,out_channels=64,kernel_size=8,stride=4) # GRAY - 1\n", 288 | " self.conv_layer2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2)\n", 289 | " self.conv_layer3 = nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,stride=1)\n", 290 | " self.fc1 = nn.Linear(in_features=7*7*128,out_features=512)\n", 291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 292 | " self.relu = nn.ReLU()\n", 293 | "\n", 294 | " def forward(self,x,bsize):\n", 295 | " x = x.view(bsize,4,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 296 | " conv_out = self.conv_layer1(x)\n", 297 | " conv_out = self.relu(conv_out)\n", 298 | " conv_out = self.conv_layer2(conv_out)\n", 299 | " conv_out = self.relu(conv_out)\n", 300 | " conv_out = self.conv_layer3(conv_out)\n", 301 | " conv_out = self.relu(conv_out)\n", 302 | " out = self.fc1(conv_out.view(bsize,7*7*128))\n", 303 
| " out = self.relu(out)\n", 304 | " out = self.fc2(out)\n", 305 | " return out\n", 306 | "\n", 307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n", 308 | "print(main_model)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "

Deep Q Learning with Freeze Network

" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "Populated 60000 Samples in Episodes : 1200\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "mem = Memory(memsize=MEMORY_SIZE)\n", 333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n", 334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n", 335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=4)\n", 336 | "\n", 337 | "target_model.load_state_dict(main_model.state_dict())\n", 338 | "criterion = nn.SmoothL1Loss()\n", 339 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 340 | "\n", 341 | "# filling memory with transitions\n", 342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 343 | " \n", 344 | " prev_state = env.reset()\n", 345 | " frameObj.reset()\n", 346 | " processed_prev_state = preprocess_image(prev_state)\n", 347 | " frameObj.add_frame(processed_prev_state)\n", 348 | " prev_frames = frameObj.get_state()\n", 349 | " step_count = 0\n", 350 | " game_over = False\n", 351 | " \n", 352 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 353 | " \n", 354 | " step_count +=1\n", 355 | " action = np.random.randint(0,4)\n", 356 | " next_state,reward, game_over = env.step(action)\n", 357 | " processed_next_state = preprocess_image(next_state)\n", 358 | " frameObj.add_frame(processed_next_state)\n", 359 | " next_frames = frameObj.get_state()\n", 360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 361 | " \n", 362 | " prev_state = next_state\n", 363 | " processed_prev_state = processed_next_state\n", 364 | " prev_frames = next_frames\n", 365 | "\n", 366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n", 367 | "\n", 
368 | "\n", 369 | "# Algorithm Starts\n", 370 | "total_steps = 0\n", 371 | "epsilon = INITIAL_EPSILON\n", 372 | "loss_stat = []\n", 373 | "total_reward_stat = []\n", 374 | "\n", 375 | "for episode in range(0,TOTAL_EPISODES):\n", 376 | " \n", 377 | " prev_state = env.reset()\n", 378 | " frameObj.reset()\n", 379 | " processed_prev_state = preprocess_image(prev_state)\n", 380 | " frameObj.add_frame(processed_prev_state)\n", 381 | " prev_frames = frameObj.get_state()\n", 382 | " game_over = False\n", 383 | " step_count = 0\n", 384 | " total_reward = 0\n", 385 | " \n", 386 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 387 | " \n", 388 | " step_count +=1\n", 389 | " total_steps +=1\n", 390 | " \n", 391 | " if np.random.rand() <= epsilon:\n", 392 | " action = np.random.randint(0,4)\n", 393 | " else:\n", 394 | " with torch.no_grad():\n", 395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 396 | "\n", 397 | " model_out = main_model.forward(torch_x,bsize=1)\n", 398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 399 | " \n", 400 | " next_state, reward, game_over = env.step(action)\n", 401 | " processed_next_state = preprocess_image(next_state)\n", 402 | " frameObj.add_frame(processed_next_state)\n", 403 | " next_frames = frameObj.get_state()\n", 404 | " total_reward += reward\n", 405 | " \n", 406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 407 | " \n", 408 | " prev_state = next_state\n", 409 | " processed_prev_state = processed_next_state\n", 410 | " prev_frames = next_frames\n", 411 | " \n", 412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 413 | " target_model.load_state_dict(main_model.state_dict())\n", 414 | " \n", 415 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 416 | " current_states = []\n", 417 | " next_states = []\n", 418 | " acts = []\n", 419 | " rewards = []\n", 420 | " game_status = []\n", 421 | " \n", 422 | " for element in batch:\n", 423 | " 
current_states.append(element[0])\n", 424 | " acts.append(element[1])\n", 425 | " rewards.append(element[2])\n", 426 | " next_states.append(element[3])\n", 427 | " game_status.append(element[4])\n", 428 | " \n", 429 | " current_states = np.array(current_states)\n", 430 | " next_states = np.array(next_states)\n", 431 | " rewards = np.array(rewards)\n", 432 | " game_status = [not b for b in game_status]\n", 433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 434 | " torch_acts = torch.tensor(acts)\n", 435 | " \n", 436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 439 | " Q_max_next = Q_max_next.double()\n", 440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 441 | " \n", 442 | " target_values = (rewards + (GAMMA * Q_max_next))\n", 443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 444 | " \n", 445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n", 446 | " \n", 447 | " # save performance measure\n", 448 | " loss_stat.append(loss.item())\n", 449 | " \n", 450 | " # make previous grad zero\n", 451 | " optimizer.zero_grad()\n", 452 | " \n", 453 | " # back - propogate \n", 454 | " loss.backward()\n", 455 | " \n", 456 | " # update params\n", 457 | " optimizer.step()\n", 458 | " \n", 459 | " # save performance measure\n", 460 | " total_reward_stat.append(total_reward)\n", 461 | " \n", 462 | " if epsilon > FINAL_EPSILON:\n", 463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 464 | " \n", 465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 466 | " perf = {}\n", 467 | " perf['loss'] = loss_stat\n", 468 | " perf['total_reward'] = total_reward_stat\n", 469 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n", 470 | " \n", 471 | " 
#print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "

Save Primary Network Weights

" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 12, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "torch.save(main_model.state_dict(),'data/FOUR_OBSERV_NINE_WEIGHTS.torch')" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "

Testing Policy

" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "

Load Primary Network Weights

" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 13, 509 | "metadata": { 510 | "collapsed": true 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "weights = torch.load('data/FOUR_OBSERV_NINE_WEIGHTS.torch')\n", 515 | "main_model.load_state_dict(weights)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "

Testing Policy

" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": { 529 | "collapsed": true 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "# Algorithm Starts\n", 534 | "epsilon = INITIAL_EPSILON\n", 535 | "FINAL_EPSILON = 0.01\n", 536 | "total_reward_stat = []\n", 537 | "\n", 538 | "for episode in range(0,TOTAL_EPISODES):\n", 539 | " \n", 540 | " prev_state = env.reset()\n", 541 | " processed_prev_state = preprocess_image(prev_state)\n", 542 | " frameObj.reset()\n", 543 | " frameObj.add_frame(processed_prev_state)\n", 544 | " prev_frames = frameObj.get_state()\n", 545 | " game_over = False\n", 546 | " step_count = 0\n", 547 | " total_reward = 0\n", 548 | " \n", 549 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 550 | " \n", 551 | " step_count +=1\n", 552 | " \n", 553 | " if np.random.rand() <= epsilon:\n", 554 | " action = np.random.randint(0,4)\n", 555 | " else:\n", 556 | " with torch.no_grad():\n", 557 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 558 | "\n", 559 | " model_out = main_model.forward(torch_x,bsize=1)\n", 560 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 561 | " \n", 562 | " next_state, reward, game_over = env.step(action)\n", 563 | " processed_next_state = preprocess_image(next_state)\n", 564 | " frameObj.add_frame(processed_next_state)\n", 565 | " next_frames = frameObj.get_state()\n", 566 | " \n", 567 | " total_reward += reward\n", 568 | " \n", 569 | " prev_state = next_state\n", 570 | " processed_prev_state = processed_next_state\n", 571 | " prev_frames = next_frames\n", 572 | " \n", 573 | " # save performance measure\n", 574 | " total_reward_stat.append(total_reward)\n", 575 | " \n", 576 | " if epsilon > FINAL_EPSILON:\n", 577 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 578 | " \n", 579 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 580 | " perf = {}\n", 581 | " perf['total_reward'] = total_reward_stat\n", 
582 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n", 583 | " \n", 584 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 585 | ] 586 | } 587 | ], 588 | "metadata": { 589 | "kernelspec": { 590 | "display_name": "Python [conda env:myenv]", 591 | "language": "python", 592 | "name": "conda-env-myenv-py" 593 | }, 594 | "language_info": { 595 | "codemirror_mode": { 596 | "name": "ipython", 597 | "version": 3 598 | }, 599 | "file_extension": ".py", 600 | "mimetype": "text/x-python", 601 | "name": "python", 602 | "nbconvert_exporter": "python", 603 | "pygments_lexer": "ipython3", 604 | "version": "3.6.5" 605 | } 606 | }, 607 | "nbformat": 4, 608 | "nbformat_minor": 2 609 | } 610 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/MDP_Size_9-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from gridworld import gameEnv\n", 10 | "import numpy as np\n", 11 | "%matplotlib inline\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from collections import deque\n", 14 | "import pickle\n", 15 | "from skimage.color import rgb2gray\n", 16 | "import random" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "

Define Environment Object

" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADONJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppIiWkvoki9SKQqCRdVJARJUUX5EwKNZUWk1CGqKlUO5k8TsCE2xARbBpsUSkqltk7eXuy4PXFtzhyf3T07/n0/0urszOye+Y1Hz5nZ2fH7pqqQ1JZfWO4BSJo+gy81yOBLDTL4UoMMvtQggy81yOBLDVpS8JNck+SFJLuTbBrXoCRNVo73Bp4kK4AfABuBvcATwE1VtWN8w5M0CSuX8N5Lgd1V9RJAkvuB64FjBv/MM8+subm5JaxS0jvZs2cPr7/+ehZ63VKCfzbwyrzpvcBvvtMb5ubm2L59+xJWKemdbNiwodfrJn5xL8ktSbYn2X7w4MFJr05SD0sJ/j7gnHnTa7t5P6eq7qyqDVW14ayzzlrC6iSNy1KC/wRwXpJ1SU4GbgQ2j2dYkibpuD/jV9WhJH8EfAtYAXylqp4b28gkTcxSLu5RVd8EvjmmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGLem/5c6CZMG6gtLMWq429R7xpQYZfKlBCwY/yVeSHEjy7Lx5q5I8lmRX9/P0yQ5T0jj1OeL/NXDNEfM2AVur6jxgazctaSAWDH5V/SPwr0fMvh64p3t+D/D7Yx6XpAk63s/4q6tqf/f8VWD1mMYjaQqWfHGvRt9HHPM7CTvpSLPneIP/WpI1AN3PA8d6oZ10pNlzvMHfDHyse/4x4BvjGY6kaejzdd59wD8D5yfZm+Rm4HPAxiS7gKu7aUkDseAtu1V10zEWXTXmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUp/TWOUkeT7IjyXNJbu3m201HGqg+R/xDwCer6kLgMuDjSS7EbjrSYPXppLO/qp7qnv8E2Amcjd10pMFa1Gf8JHPAxcA2enbTsaGGNHt6Bz/Je4GvA7dV1Vvzl71TNx0bakizp1fwk5zEKPT3VtXD3eze3XQkzZY+V/UD3A3srKovzFtkNx1poBZsqAFcAfwB8P0kz3Tz/oxR95wHu846LwM3TGaIksatTyedfwJyjMV205EGyDv3pAb1OdXXCeioX8EsQcb5C491fjlDxv3vN20e8aUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfalCfmnunJPlukn/pOul8tpu/Lsm2JLuTPJDk5MkPV9I49Dni/ydwZVVdBKwHrklyGfB54ItV9X7gDeDmyQ1T0jj16aRTVfXv3eRJ3aOAK4GHuvl20pEGpG9d/RVdhd0DwGPAi8CbVXWoe8leRm21jvZeO+lIM6ZX8Kvqp1W1HlgLXApc0HcFdtKZTRnzY7y/bPYNfVMXdVW/qt4EHgcuB05LcrhY51pg35jHJmlC+lzVPyvJad3zdwMbGXXMfRz4SPcyO+lIA9KnvPYa4J4kKxj9oXiwqrYk2QHcn+QvgKcZtdmSNAB9Oul8j1Fr7CPnv8To876kgfHOPalBBl9qkMGXGmTwpQY
ZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYPfldh+OsmWbtpOOtJALeaIfyujIpuH2UlHGqi+DTXWAh8C7uqmg510pMHqe8T/EvAp4Gfd9BnYSUcarD519T8MHKiqJ49nBXbSkWZPn7r6VwDXJbkWOAX4JeAOuk463VHfTjrSgPTplnt7Va2tqjngRuDbVfVR7KQjDdZSvsf/NPCJJLsZfea3k440EH1O9f9XVX0H+E733E460kB5557UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDFnUDjxapxvz7Mubfp2Z5xJcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9vsdPsgf4CfBT4FBVbUiyCngAmAP2ADdU1RuTGaakcVrMEf93qmp9VW3opjcBW6vqPGBrNy1pAJZyqn89o0YaYEMNaVD6Br+Av0/yZJJbunmrq2p/9/xVYPXYRydpIvreq/+BqtqX5JeBx5I8P39hVVWSo96Z3v2huAXg3HPPXdJgJY1HryN+Ve3rfh4AHmFUXfe1JGsAup8HjvFeO+lIM6ZPC61Tk/zi4efA7wLPApsZNdIAG2pIg9LnVH818MioQS4rgb+tqkeTPAE8mORm4GXghskNU9I4LRj8rnHGRUeZ/2PgqkkMStJkeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yWpKHkjyfZGeSy5OsSvJYkl3dz9MnPVhJ49H3iH8H8GhVXcCoDNdO7KQjDVafKrvvA34LuBugqv6rqt7ETjrSYPU54q8DDgJfTfJ0kru6Mtt20pEGqk/wVwKXAF+uqouBtznitL6qilGbrf8nyS1JtifZfvDgwaWOV9IY9An+XmBvVW3rph9i9IfATjoLyZgfs6zG+NDELRj8qnoVeCXJ+d2sq4Ad2ElHGqy+TTP/GLg3ycnAS8AfMvqjYScdaYB6Bb+qngE2HGWRnXSkAfLOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBferqn5/kmXmPt5LcduJ10hlntcgGq0a2UlT0BNGn2OYLVbW+qtYDvwH8B/AIdtKRBmuxp/pXAS9W1cvYSUcarMUG/0bgvu65nXSkgeod/K609nXA145cZicdaVgWc8T/IPBUVb3WTdtJRxqoxQT/Jv7vNB/spCMNVq/gd91xNwIPz5v9OWBjkl3A1d20pAHo20nnbeCMI+b9GDvpSIPknXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg3rduTfLRv8xcFbN8tjUMo/4UoMMvtQggy81yOBLDTL4UoMMvtQggy81qG/prT9N8lySZ5Pcl+SUJOuSbEuyO8kDXRVeSQPQp4XW2cCfABuq6teBFYzq638e+GJVvR94A7h5kgOVND59T/VXAu9OshJ4D7AfuBJ4qFtuJx1pQPr0ztsH/CXwI0aB/zfgSeDNqjrUvWwvcPakBilpvPqc6p/OqE/eOuBXgFOBa/quwE460uzpc6p/NfDDqjpYVf/NqLb+FcBp3ak/wFpg39HebCcdafb0Cf6PgMuSvCdJGNXS3wE8Dnyke42ddKQB6fMZfxuji3hPAd/v3nMn8GngE0l2M2q2cfcExylpjPp20vkM8JkjZr8EXDr2EUmaOO/ckxpk8KUGGXypQQZfalCmWawyyUHgbeD1qa108s7E7ZlVJ9K2QL/t+dWqWvCGmakGHyDJ9qraMNWVTpDbM7tOpG2B8W6Pp/pSgwy+1KDlCP6dy7DOSXJ7ZteJtC0wxu2Z+md8ScvPU32pQVMNfpJrkrzQ1enbNM11L1WSc5I8nmRHV3/w1m7+qiSPJdnV/Tx9uce6GElWJHk6yZZuerC1FJOcluShJM8
n2Znk8iHvn0nWupxa8JOsAP4K+CBwIXBTkguntf4xOAR8sqouBC4DPt6NfxOwtarOA7Z200NyK7Bz3vSQayneATxaVRcAFzHarkHun4nXuqyqqTyAy4FvzZu+Hbh9WuufwPZ8A9gIvACs6eatAV5Y7rEtYhvWMgrDlcAWIIxuEFl5tH02yw/gfcAP6a5bzZs/yP3DqJTdK8AqRv+Ldgvwe+PaP9M81T+8IYcNtk5fkjngYmAbsLqq9neLXgVWL9OwjseXgE8BP+umz2C4tRTXAQeBr3YfXe5KcioD3T814VqXXtxbpCTvBb4O3FZVb81fVqM/w4P4miTJh4EDVfXkco9lTFYClwBfrqqLGd0a/nOn9QPbP0uqdbmQaQZ/H3DOvOlj1umbVUlOYhT6e6vq4W72a0nWdMvXAAeWa3yLdAVwXZI9wP2MTvfvoGctxRm0F9hbo4pRMKoadQnD3T9LqnW5kGkG/wngvO6q5MmMLlRsnuL6l6SrN3g3sLOqvjBv0WZGNQdhQLUHq+r2qlpbVXOM9sW3q+qjDLSWYlW9CryS5Pxu1uHakIPcP0y61uWUL1hcC/wAeBH48+W+gLLIsX+A0Wni94Bnuse1jD4XbwV2Af8ArFrusR7Htv02sKV7/mvAd4HdwNeAdy33+BaxHeuB7d0++jvg9CHvH+CzwPPAs8DfAO8a1/7xzj2pQV7ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfatD/ADPvEPBps6hQAAAAAElFTkSuQmCC\n", 34 | "text/plain": [ 35 | "
" 36 | ] 37 | }, 38 | "metadata": { 39 | "needs_background": "light" 40 | }, 41 | "output_type": "display_data" 42 | } 43 | ], 44 | "source": [ 45 | "env = gameEnv(partial=False,size=9)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "" 57 | ] 58 | }, 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "output_type": "execute_result" 62 | }, 63 | { 64 | "data": { 65 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADPlJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppICbQXUaReJFKVhIsqEoKkqKL8CYEGWREpdYiiSJWD+dMEbIgNMcEWYJNCSanU1snbix23B9fGc3xmz9nx7/uRVrszu3vmN2f0nJmdnfO+qSokteWXlnsAkpaewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSPJckp1Jbh5qUJKmK0d7AU+SFcCPgI3AbuAx4Lqq2jbc8CRNw8pFvPdCYGdVvQCQ5B7gauCwwT/11FNrbm5uEYuU9E527drFa6+9liO9bjHBPx14ad70buC33+kNc3NzbN26dRGLlPRONmzY0Ot1Uz+5l+SGJFuTbN23b9+0Fyeph8UEfw9wxrzptd28t6mq26pqQ1VtOO200xaxOElDWUzwHwPOSrIuyfHAtcBDwwxL0jQd9Wf8qtqf5E+AbwErgK9U1TODjUzS1Czm5B5V9U3gmwONRdIS8co9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2rQov4tdxYkR6wr2M80uoUPNDQdu5arTb17fKlBBl9q0BGDn+QrSfYmeXrevFVJHkmyo7s/ebrDlDSkPnv8vwWuOGjezcDmqjoL2NxNSxqJIwa/qr4L/OtBs68G7uwe3wn84cDjkjRFR/sZf3VVvdw9fgVYPdB4JC2BRZ/cq8n3EYf9TsJOOtLsOdrgv5pkDUB3v/dwL7STjjR7jjb4DwEf6x5/DPjGMMORtBT6fJ13N/DPwNlJdie5HvgcsDHJDuDyblrSSBzxkt2quu4wT1028FgkLRGv3JMaZPClBhl8qUEGX2qQwZcaZPClBo2+As9grJajhrjHlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Sm9dUaSR5NsS/JMkhu7+XbTkUaqzx5/P/DJqjoXuAj4eJJzsZuONFp9Oum8XFVPdI9/BmwHTsduOtJoLegzfpI54HxgCz276dhQQ5o9vYOf5L3A14GbqurN+c+9UzcdG2pIs6dX8JMcxyT0d1XVA93s3t10JM2WPmf1A9wBbK+qL8x7ym460kj1qcBzCfBHwA+TPNXN+wsm3XPu6zrrvAhcM50hShpan0463+PwhanspiONkFfuSQ2y2OaoHPKLk6PUUHXRIX9tB4z81+ceX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZc
aZPClBvWpuXdCku8n+Zeuk85nu/nrkmxJsjPJvUmOn/5wJQ2hzx7/P4FLq+o8YD1wRZKLgM8DX6yq9wOvA9dPb5iShtSnk05V1b93k8d1twIuBe7v5ttJRxqRvnX1V3QVdvcCjwDPA29U1f7uJbuZtNU61HvtpCPNmF7Br6qfV9V6YC1wIXBO3wXYSWdIGfDWkCF/bcfIr29BZ/Wr6g3gUeBi4KQkB4p1rgX2DDw2SVPS56z+aUlO6h6/G9jIpGPuo8BHupfZSUcakT7ltdcAdyZZweQPxX1VtSnJNuCeJH8FPMmkzZakEejTSecHTFpjHzz/BSaf9yWNjFfuSQ0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3qHfyuxPaTSTZ103bSkUZqIXv8G5kU2TzATjrSSPVtqLEW+BBwezcd7KQjjVbfPf6XgE8Bv+imT8FOOtJo9amr/2Fgb1U9fjQLsJOONHv61NW/BLgqyZXACcCvALfSddLp9vp20pFGpE+33Fuqam1VzQHXAt+uqo9iJx1ptBbzPf6ngU8k2cnkM7+ddKSR6HOo/7+q6jvAd7rHdtKRRsor96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUqxBHkl3Az4CfA/urakOSVcC9wBywC7imql6fzjAlDWkhe/zfq6r1VbWhm74Z2FxVZwGbu2lJI7CYQ/2rmTTSABtqSKPSN/gF/GOSx5Pc0M1bXVUvd49fAVYPPjpJU9G32OYHqmpPkl8FHkny7Pwnq6qS1KHe2P2huAHgzDPPXNRgJQ2j1x6/qvZ093uBB5lU1301yRqA7n7vYd5rJx1pxvRpoXVikl8+8Bj4feBp4CEmjTTAhhrSqPQ51F8NPDhpkMtK4O+r6uEkjwH3JbkeeBG4ZnrDlDSkIwa/a5xx3iHm/xS4bBqDkjRdXrknNcjgSw1aUO+8mXTILxGP4sdkmJ8z3xR+pDQI9/hSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yUpL7kzybZHuSi5OsSvJIkh3d/cnTHqykYfTd498KPFxV5zApw7UdO+lIo9Wnyu77gN8B7gCoqv+qqjewk440Wn32+OuAfcBXkzyZ5PauzLaddKSR6hP8lcAFwJer6nzgLQ46rK+q4jBFsJLckGRrkq379u1b7HglDaBP8HcDu6tqSzd9P5M/BLPRSSfD3Ab6MW+76dhVA92WyxGDX1WvAC8lObubdRmwDTvpSKPVt8runwJ3JTkeeAH4YyZ/NOykI41Qr+BX1VPAhkM8ZScdaYS8ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUJ+6+mcneWre7c0kN9lJp4ehKjIud2VG/T9jL8jap9jmc1W1vqrWA78F/AfwIHbSkUZroYf6lwHPV9WL2ElHGq2FBv9a4O7usZ10pJHqHfyutPZVwNcOfs5OOtK4LGSP/0Hgiap6tZuejU46khZsIcG/jv87zAc76Uij1Sv4XXfcjcAD82Z/DtiYZAdweTctaQT6dtJ5CzjloHk/xU460ih55Z7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoF5X7s2yyT8GNqKhVdV0uceXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBfUtv/XmSZ5I8neTuJCckWZdkS5KdSe7tqvBKGoE+LbROB/4M2FBVvwmsYFJf//PAF6vq/cDrwPXTHKik4fQ91F8JvDvJSuA9wMvApcD93fN20pFGpE/vvD3AXwM/YRL4fwMeB96oqv3dy3YDp09rkJKG1edQ/2QmffLWAb8GnAhc0XcBdtK
RZk+fQ/3LgR9X1b6q+m8mtfUvAU7qDv0B1gJ7DvVmO+lIs6dP8H8CXJTkPUnCpJb+NuBR4CPda+ykI41In8/4W5icxHsC+GH3ntuATwOfSLKTSbONO6Y4TkkD6ttJ5zPAZw6a/QJw4eAjkjR1XrknNcjgSw0y+FKDDL7UoCxlscok+4C3gNeWbKHTdyquz6w6ltYF+q3Pr1fVES+YWdLgAyTZWlUblnShU+T6zK5jaV1g2PXxUF9qkMGXGrQcwb9tGZY5Ta7P7DqW1gUGXJ8l/4wvafl5qC81aEmDn+SKJM91dfpuXsplL1aSM5I8mmRbV3/wxm7+qiSPJNnR3Z+83GNdiCQrkjyZZFM3PdpaiklOSnJ/kmeTbE9y8Zi3zzRrXS5Z8JOsAP4G+CBwLnBdknOXavkD2A98sqrOBS4CPt6N/2Zgc1WdBWzupsfkRmD7vOkx11K8FXi4qs4BzmOyXqPcPlOvdVlVS3IDLga+NW/6FuCWpVr+FNbnG8BG4DlgTTdvDfDcco9tAeuwlkkYLgU2AWFygcjKQ22zWb4B7wN+THfeat78UW4fJqXsXgJWMfkv2k3AHwy1fZbyUP/Aihww2jp9SeaA84EtwOqqerl76hVg9TIN62h8CfgU8Itu+hTGW0txHbAP+Gr30eX2JCcy0u1TU6516cm9BUryXuDrwE1V9eb852ryZ3gUX5Mk+TCwt6oeX+6xDGQlcAHw5ao6n8ml4W87rB/Z9llUrcsjWcrg7wHOmDd92Dp9syrJcUxCf1dVPdDNfjXJmu75NcDe5RrfAl0CXJVkF3APk8P9W+lZS3EG7QZ216RiFEyqRl3AeLfPompdHslSBv8x4KzurOTxTE5UPLSEy1+Urt7gHcD2qvrCvKceYlJzEEZUe7CqbqmqtVU1x2RbfLuqPspIaylW1SvAS0nO7mYdqA05yu3DtGtdLvEJiyuBHwHPA3+53CdQFjj2DzA5TPwB8FR3u5LJ5+LNwA7gn4BVyz3Wo1i33wU2dY9/A/g+sBP4GvCu5R7fAtZjPbC120b/AJw85u0DfBZ4Fnga+DvgXUNtH6/ckxrkyT2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG/Q++8g/3YgFKrwAAAABJRU5ErkJggg==\n", 66 | "text/plain": [ 67 | "
" 68 | ] 69 | }, 70 | "metadata": { 71 | "needs_background": "light" 72 | }, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 7, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "BATCH_SIZE = 64\n", 102 | "FREEZE_INTERVAL = 20000 # steps\n", 103 | "MEMORY_SIZE = 60000 \n", 104 | "OUTPUT_SIZE = 4\n", 105 | "TOTAL_EPISODES = 10000\n", 106 | "MAX_STEPS = 50\n", 107 | "INITIAL_EPSILON = 1.0\n", 108 | "FINAL_EPSILON = 0.01\n", 109 | "GAMMA = 0.99\n", 110 | "INPUT_IMAGE_DIM = 84\n", 111 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "

Save Dictionay Function

" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": { 125 | "collapsed": true 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "def save_obj(obj, name ):\n", 130 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 131 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "

Experience Replay

" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "metadata": { 145 | "collapsed": true 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "class Memory():\n", 150 | " \n", 151 | " def __init__(self,memsize):\n", 152 | " self.memsize = memsize\n", 153 | " self.memory = deque(maxlen=self.memsize)\n", 154 | " \n", 155 | " def add_sample(self,sample):\n", 156 | " self.memory.append(sample)\n", 157 | " \n", 158 | " def get_batch(self,size):\n", 159 | " return random.sample(self.memory,k=size)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "

Preprocess Images

" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 5, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "def preprocess_image(image):\n", 176 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 177 | " return np.copy(image)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 7, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "" 189 | ] 190 | }, 191 | "execution_count": 7, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | }, 195 | { 196 | "data": { 197 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADVNJREFUeJzt3X+oX/V9x/Hna4nW1m71VxYyo4ulosjA6C6ZYhmbmi1xRfdHEUVGGYL/dJtdC61usFLYHy2Mtv4xCkHbueH8UVtXCY2dSy1jMKLxx1o1WqONNUGTWHXtHGxL+94f3yO7zRLvubnne+89fp4PuHy/53x/nM/h8Pqe8z3fc9/vVBWS2vILSz0ASYvP4EsNMvhSgwy+1CCDLzXI4EsNMvhSgxYU/CSbkjybZHeSm4YalKTpyrFewJNkBfB9YCOwF3gEuLaqnh5ueJKmYeUCXrsB2F1VLwAkuQu4Cjhq8E877bRat27dAhYp6e3s2bOHV199NXM9byHBPx14adb0XuA33u4F69atY+fOnQtYpKS3MzMz0+t5Uz+5l+SGJDuT7Dx48OC0Fyeph4UEfx9wxqzptd28n1NVW6pqpqpmVq1atYDFSRrKQoL/CHB2krOSHA9cA9w/zLAkTdMxf8evqkNJ/gj4FrAC+HJVPTXYyCRNzUJO7lFV3wS+OdBYJC0Sr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYt6N9yl4NkzrqC0rK1VG3q3eNLDTL4UoPmDH6SLyc5kOTJWfNOSfJgkue625OnO0xJQ+qzx/8bYNNh824CtlfV2cD2blrSSMwZ/Kr6Z+C1w2ZfBdze3b8d+P2BxyVpio71O/7qqnq5u/8KsHqg8UhaBAs+uVeT3yOO+puEnXSk5edYg78/yRqA7vbA0Z5oJx1p+TnW4N8PfKS7/xHgG8MMR9Ji6PNz3p3AvwLnJNmb5Hrgs8DGJM8Bl3fTkkZizkt2q+raozx02cBjkbRIvHJPapDBlxpk8KUGGXypQQZfapDBlxo0+go8LdmwYcNg7/Xwww8P9l4aH/f4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPqW3zkjyUJKnkzyV5MZuvt10pJHqs8c/BHyiqs4DLgI+muQ87KYjjVafTjovV9Vj3f2fALuA07GbjjRa8/qOn2QdcAGwg57ddGyoIS0/vYOf5L3A14CPVdWPZz/2dt10bKghLT+9gp/kOCahv6Oqvt7N7t1NR9Ly0uesfoDbgF1V9flZD9lNRxqpPhV4LgH+AP
hekie6eX/GpHvOPV1nnReBq6czRElD69NJ51+AHOVhu+lII+SVe1KDRl9sc9u2bYO8z+bNmwd5n2myQKaG4h5fapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9am5d0KSh5P8W9dJ5zPd/LOS7EiyO8ndSY6f/nAlDaHPHv+/gEur6nxgPbApyUXA54AvVNUHgNeB66c3TElD6tNJp6rqP7rJ47q/Ai4F7u3m20lHGpG+dfVXdBV2DwAPAs8Db1TVoe4pe5m01TrSa+2kIy0zvWruVdVPgfVJTgLuA87tu4Cq2gJsAZiZmTlit52FGEOtvKFs2LBhsPeyfl/b5nVWv6reAB4CLgZOSvLWB8daYN/AY5M0JX3O6q/q9vQkeTewkUnH3IeAD3dPs5OONCJ9DvXXALcnWcHkg+Keqtqa5GngriR/CTzOpM2WpBHo00nnu0xaYx8+/wVguC+dkhaNV+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6ld7S8mC5LA3FPb7UIIMvNah38LsS248n2dpN20lHGqn57PFvZFJk8y120pFGqm9DjbXA7wG3dtPBTjrSaPXd438R+CTws276VOykI41Wn7r6HwIOVNWjx7KAqtpSVTNVNbNq1apjeQtJA+vzO/4lwJVJrgBOAH4JuIWuk06317eTjjQifbrl3lxVa6tqHXAN8O2qug476UijtZDf8T8FfDzJbibf+e2kI43EvC7ZrarvAN/p7ttJRxopr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkDX3NIht27YN9l6bN28e7L10ZO7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/f8ZPsAX4C/BQ4VFUzSU4B7gbWAXuAq6vq9ekMU9KQ5rPH/+2qWl9VM930TcD2qjob2N5NSxqBhRzqX8WkkQbYUEMalb7BL+Afkzya5IZu3uqqerm7/wqwevDRSZqKvtfqf7Cq9iX5ZeDBJM/MfrCqKkkd6YXdB8UNAGeeeeaCBitpGL32+FW1r7s9ANzHpLru/iRrALrbA0d5rZ10pGWmTwutE5P84lv3gd8BngTuZ9JIA2yoIY1Kn0P91cB9kwa5rAT+vqoeSPIIcE+S64EXgaunN0xJQ5oz+F3jjPOPMP9HwGXTGJSk6fLKPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYKf5KQk9yZ5JsmuJBcnOSXJg0me625PnvZgJQ2j7x7/FuCBqjqXSRmuXdhJRxqtPlV23wf8JnAbQFX9d1W9gZ10pNHqs8c/CzgIfCXJ40lu7cps20lHGqk+wV8JXAh8qaouAN7ksMP6qiombbb+nyQ3JNmZZOfBgwcXOl5JA+hTV38vsLeqdnTT9zIJ/v4ka6rq5bk66QBbAGZmZo744aDxu+6665Z6CJqHOff4VfUK8FKSc7pZlwFPYycdabT6Ns38Y+COJMcDLwB/yORDw0460gj1Cn5VPQHMHOEhO+lII+SVe1KDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKD5qzA09Xau3vWrPcDfwH8bTd/HbAHuLqqXh9+iG9v27Ztg7zP5s2bB3mfVr322mtLPQTNQ59im89W1fqqWg/8OvCfwH3YSUcarfke6l8GPF9VL2InHWm05hv8a4A7u/t20pFGqnfwu9LaVwJfPfwxO+lI4zKfPf5m4LGq2t9N7+866DBXJ52qmqmqmVWrVi1stJIGMZ/gX8v/HeaDnXSk0eoV/K477kbg67NmfxbYmOQ54PJuWtII9O2k8yZw6mHzfoSddKRR8so9qUEGX2qQwZcaZPClBh
l8qUEGX2qQwZcaZPClBhl8qUG9rtxbzjZt2jTI+0z+wVBqg3t8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZca1Lf01p8meSrJk0nuTHJCkrOS7EiyO8ndXRVeSSMwZ/CTnA78CTBTVb8GrGBSX/9zwBeq6gPA68D10xyopOH0PdRfCbw7yUrgPcDLwKXAvd3jdtKRRqRP77x9wF8BP2QS+H8HHgXeqKpD3dP2AqdPa5CShtXnUP9kJn3yzgJ+BTgR6H2BvJ10pOWnz6H+5cAPqupgVf0Pk9r6lwAndYf+AGuBfUd6sZ10pOWnT/B/CFyU5D1JwqSW/tPAQ8CHu+fYSUcakT7f8XcwOYn3GPC97jVbgE8BH0+ym0mzjdumOE5JA+rbSefTwKcPm/0CsGHwEUmaOq/ckxpk8KUGGXypQQZfalAWs8hkkoPAm8Cri7bQ6TsN12e5eietC/Rbn1+tqjkvmFnU4AMk2VlVM4u60ClyfZavd9K6wLDr46G+1CCDLzVoKYK/ZQmWOU2uz/L1TloXGHB9Fv07vqSl56G+1KBFDX6STUme7er03bSYy16oJGckeSjJ0139wRu7+ackeTDJc93tyUs91vlIsiLJ40m2dtOjraWY5KQk9yZ5JsmuJBePeftMs9blogU/yQrgr4HNwHnAtUnOW6zlD+AQ8ImqOg+4CPhoN/6bgO1VdTawvZsekxuBXbOmx1xL8Rbggao6FzifyXqNcvtMvdZlVS3KH3Ax8K1Z0zcDNy/W8qewPt8ANgLPAmu6eWuAZ5d6bPNYh7VMwnApsBUIkwtEVh5pmy3nP+B9wA/ozlvNmj/K7cOklN1LwClM/ot2K/C7Q22fxTzUf2tF3jLaOn1J1gEXADuA1VX1cvfQK8DqJRrWsfgi8EngZ930qYy3luJZwEHgK91Xl1uTnMhIt09NudalJ/fmKcl7ga8BH6uqH89+rCYfw6P4mSTJh4ADVfXoUo9lICuBC4EvVdUFTC4N/7nD+pFtnwXVupzLYgZ/H3DGrOmj1ulbrpIcxyT0d1TV17vZ+5Os6R5fAxxYqvHN0yXAlUn2AHcxOdy/hZ61FJehvcDemlSMgknVqAsZ7/ZZUK3LuSxm8B8Bzu7OSh7P5ETF/Yu4/AXp6g3eBuyqqs/Peuh+JjUHYUS1B6vq5qpaW1XrmGyLb1fVdYy0lmJVvQK8lOScbtZbtSFHuX2Ydq3LRT5hcQXwfeB54M+X+gTKPMf+QSaHid8Fnuj+rmDyvXg78BzwT8ApSz3WY1i33wK2dvffDzwM7Aa+Crxrqcc3j/VYD+zsttE/ACePefsAnwGeAZ4E/g5411Dbxyv3pAZ5ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlB/wsPrvc22jHaHAAAAABJRU5ErkJggg==\n", 198 | "text/plain": [ 199 | "
" 200 | ] 201 | }, 202 | "metadata": {}, 203 | "output_type": "display_data" 204 | } 205 | ], 206 | "source": [ 207 | "processed_prev_state = preprocess_image(prev_state)\n", 208 | "plt.imshow(processed_prev_state,cmap='gray')" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "

Build Model

" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 8, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "Network(\n", 228 | " (conv_layer1): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))\n", 229 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n", 230 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 231 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n", 232 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 233 | " (relu): ReLU()\n", 234 | ")\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "import torch.nn as nn\n", 240 | "import torch\n", 241 | "\n", 242 | "class Network(nn.Module):\n", 243 | " \n", 244 | " def __init__(self,image_input_size,out_size):\n", 245 | " super(Network,self).__init__()\n", 246 | " self.image_input_size = image_input_size\n", 247 | " self.out_size = out_size\n", 248 | "\n", 249 | " self.conv_layer1 = nn.Conv2d(in_channels=1,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n", 250 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n", 251 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n", 252 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n", 253 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 254 | " self.relu = nn.ReLU()\n", 255 | "\n", 256 | " def forward(self,x,bsize):\n", 257 | " x = x.view(bsize,1,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 258 | " conv_out = self.conv_layer1(x)\n", 259 | " conv_out = self.relu(conv_out)\n", 260 | " conv_out = self.conv_layer2(conv_out)\n", 261 | " conv_out = self.relu(conv_out)\n", 262 | " conv_out = self.conv_layer3(conv_out)\n", 263 | " conv_out = self.relu(conv_out)\n", 264 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n", 265 | " out 
= self.relu(out)\n", 266 | " out = self.fc2(out)\n", 267 | " return out\n", 268 | "\n", 269 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE)\n", 270 | "print(main_model)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "

Deep Q Learning with Target Freeze

" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Populated 200 Samples\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "mem = Memory(memsize=MEMORY_SIZE)\n", 295 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n", 296 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n", 297 | "\n", 298 | "target_model.load_state_dict(main_model.state_dict())\n", 299 | "criterion = nn.SmoothL1Loss()\n", 300 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 301 | "\n", 302 | "# filling memory with transitions\n", 303 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 304 | " \n", 305 | " prev_state = env.reset()\n", 306 | " processed_prev_state = preprocess_image(prev_state)\n", 307 | " step_count = 0\n", 308 | " game_over = False\n", 309 | " \n", 310 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 311 | " \n", 312 | " step_count +=1\n", 313 | " action = np.random.randint(0,4)\n", 314 | " next_state,reward, game_over = env.step(action)\n", 315 | " processed_next_state = preprocess_image(next_state)\n", 316 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n", 317 | " \n", 318 | " prev_state = next_state\n", 319 | " processed_prev_state = processed_next_state\n", 320 | "\n", 321 | "print('Populated %d Samples'%(len(mem.memory)))\n", 322 | "\n", 323 | "# Algorithm Starts\n", 324 | "total_steps = 0\n", 325 | "epsilon = INITIAL_EPSILON\n", 326 | "loss_stat = []\n", 327 | "total_reward_stat = []\n", 328 | "\n", 329 | "for episode in range(0,TOTAL_EPISODES):\n", 330 | " \n", 331 | " prev_state = env.reset()\n", 332 | " processed_prev_state = preprocess_image(prev_state)\n", 333 | " game_over = False\n", 334 | " step_count = 0\n", 335 | " total_reward = 0\n", 336 | " 
\n", 337 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 338 | " \n", 339 | " step_count +=1\n", 340 | " total_steps +=1\n", 341 | " \n", 342 | " if np.random.rand() <= epsilon:\n", 343 | " action = np.random.randint(0,4)\n", 344 | " else:\n", 345 | " with torch.no_grad():\n", 346 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n", 347 | "\n", 348 | " model_out = main_model.forward(torch_x,bsize=1)\n", 349 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 350 | " \n", 351 | " next_state, reward, game_over = env.step(action)\n", 352 | " processed_next_state = preprocess_image(next_state)\n", 353 | " total_reward += reward\n", 354 | " \n", 355 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n", 356 | " \n", 357 | " prev_state = next_state\n", 358 | " processed_prev_state = processed_next_state\n", 359 | " \n", 360 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 361 | " target_model.load_state_dict(main_model.state_dict())\n", 362 | " \n", 363 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 364 | " current_states = []\n", 365 | " next_states = []\n", 366 | " acts = []\n", 367 | " rewards = []\n", 368 | " game_status = []\n", 369 | " \n", 370 | " for element in batch:\n", 371 | " current_states.append(element[0])\n", 372 | " acts.append(element[1])\n", 373 | " rewards.append(element[2])\n", 374 | " next_states.append(element[3])\n", 375 | " game_status.append(element[4])\n", 376 | " \n", 377 | " current_states = np.array(current_states)\n", 378 | " next_states = np.array(next_states)\n", 379 | " rewards = np.array(rewards)\n", 380 | " game_status = [not b for b in game_status]\n", 381 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 382 | " torch_acts = torch.tensor(acts)\n", 383 | " \n", 384 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 385 | " Q_s = 
main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 386 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 387 | " Q_max_next = Q_max_next.double()\n", 388 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 389 | " \n", 390 | " target_values = (rewards + (GAMMA * Q_max_next)).cuda()\n", 391 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 392 | " \n", 393 | " loss = criterion(Q_s_a,target_values.float())\n", 394 | " \n", 395 | " # save performance measure\n", 396 | " loss_stat.append(loss.item())\n", 397 | " \n", 398 | " # make previous grad zero\n", 399 | " optimizer.zero_grad()\n", 400 | " \n", 401 | " # back - propogate \n", 402 | " loss.backward()\n", 403 | " \n", 404 | " # update params\n", 405 | " optimizer.step()\n", 406 | " \n", 407 | " # save performance measure\n", 408 | " total_reward_stat.append(total_reward)\n", 409 | " \n", 410 | " if epsilon > FINAL_EPSILON:\n", 411 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 412 | " \n", 413 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 414 | " perf = {}\n", 415 | " perf['loss'] = loss_stat\n", 416 | " perf['total_reward'] = total_reward_stat\n", 417 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n", 418 | " \n", 419 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "

Save Primary Network Weights

" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 18, 432 | "metadata": { 433 | "collapsed": true 434 | }, 435 | "outputs": [], 436 | "source": [ 437 | "torch.save(main_model.state_dict(),'data/MDP_ENV_SIZE_NINE_WEIGHTS.torch')" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "

Testing Policy

" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "

Load Primary Network Weights

" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 9, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "weights = torch.load('data/MDP_ENV_SIZE_NINE_WEIGHTS.torch', map_location='cpu')\n", 461 | "main_model.load_state_dict(weights)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "

Test Policy

" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "# Algorithm Starts\n", 478 | "epsilon = INITIAL_EPSILON\n", 479 | "FINAL_EPSILON = 0.01\n", 480 | "total_reward_stat = []\n", 481 | "\n", 482 | "for episode in range(0,TOTAL_EPISODES):\n", 483 | " \n", 484 | " prev_state = env.reset()\n", 485 | " processed_prev_state = preprocess_image(prev_state)\n", 486 | " game_over = False\n", 487 | " step_count = 0\n", 488 | " total_reward = 0\n", 489 | " \n", 490 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 491 | " \n", 492 | " step_count +=1\n", 493 | " \n", 494 | " if np.random.rand() <= epsilon:\n", 495 | " action = np.random.randint(0,4)\n", 496 | " else:\n", 497 | " with torch.no_grad():\n", 498 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n", 499 | "\n", 500 | " model_out = main_model.forward(torch_x,bsize=1)\n", 501 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 502 | " \n", 503 | " next_state, reward, game_over = env.step(action)\n", 504 | " processed_next_state = preprocess_image(next_state)\n", 505 | " total_reward += reward\n", 506 | " \n", 507 | " prev_state = next_state\n", 508 | " processed_prev_state = processed_next_state\n", 509 | " \n", 510 | " # save performance measure\n", 511 | " total_reward_stat.append(total_reward)\n", 512 | " \n", 513 | " if epsilon > FINAL_EPSILON:\n", 514 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 515 | " \n", 516 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 517 | " perf = {}\n", 518 | " perf['total_reward'] = total_reward_stat\n", 519 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n", 520 | " \n", 521 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": {}, 527 | "source": [ 528 | "

Create Policy GIF

" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "

Collect Frames Of an Episode Using Trained Network

" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 28, 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "Total Reward : 14\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "frames = []\n", 553 | "random.seed(110)\n", 554 | "np.random.seed(110)\n", 555 | "\n", 556 | "for episode in range(0,1):\n", 557 | " \n", 558 | " prev_state = env.reset()\n", 559 | " processed_prev_state = preprocess_image(prev_state)\n", 560 | " frames.append(prev_state)\n", 561 | " game_over = False\n", 562 | " step_count = 0\n", 563 | " total_reward = 0\n", 564 | " \n", 565 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 566 | " \n", 567 | " step_count +=1\n", 568 | " \n", 569 | " with torch.no_grad():\n", 570 | " torch_x = torch.from_numpy(processed_prev_state).float()\n", 571 | " model_out = main_model.forward(torch_x,bsize=1)\n", 572 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 573 | " \n", 574 | " next_state, reward, game_over = env.step(action)\n", 575 | " frames.append(next_state)\n", 576 | " processed_next_state = preprocess_image(next_state)\n", 577 | " total_reward += reward\n", 578 | " \n", 579 | " prev_state = next_state\n", 580 | " processed_prev_state = processed_next_state\n", 581 | "\n", 582 | "print('Total Reward : %d'%(total_reward)) # This should output same value which verifies seed is working correctly\n", 583 | " " 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 29, 589 | "metadata": {}, 590 | "outputs": [], 591 | "source": [ 592 | "from PIL import Image, ImageDraw\n", 593 | "\n", 594 | "for idx, img in enumerate(frames):\n", 595 | " image = Image.fromarray(img)\n", 596 | " drawer = ImageDraw.Draw(image)\n", 597 | " drawer.rectangle([(7,7),(76,76)], outline=(255, 255, 0))\n", 598 | " #plt.imshow(np.array(image))\n", 599 | " frames[idx] = np.array(image)\n", 600 | " " 601 | ] 602 | }, 603 | { 604 
| "cell_type": "markdown", 605 | "metadata": {}, 606 | "source": [ 607 | "

Frames to GIF

" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 30, 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "name": "stderr", 617 | "output_type": "stream", 618 | "text": [ 619 | "/home/mayank/miniconda3/envs/rdqn/lib/python3.5/site-packages/ipykernel_launcher.py:5: DeprecationWarning: `imresize` is deprecated!\n", 620 | "`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n", 621 | "Use ``skimage.transform.resize`` instead.\n", 622 | " \"\"\"\n" 623 | ] 624 | } 625 | ], 626 | "source": [ 627 | "import imageio\n", 628 | "from scipy.misc import imresize\n", 629 | "resized_frames = []\n", 630 | "for frame in frames:\n", 631 | " resized_frames.append(imresize(frame,(256,256)))\n", 632 | "imageio.mimsave('data/GIFs/MDP_SIZE_9.gif',resized_frames,fps=4)" 633 | ] 634 | } 635 | ], 636 | "metadata": { 637 | "kernelspec": { 638 | "display_name": "Python 3", 639 | "language": "python", 640 | "name": "python3" 641 | }, 642 | "language_info": { 643 | "codemirror_mode": { 644 | "name": "ipython", 645 | "version": 3 646 | }, 647 | "file_extension": ".py", 648 | "mimetype": "text/x-python", 649 | "name": "python", 650 | "nbconvert_exporter": "python", 651 | "pygments_lexer": "ipython3", 652 | "version": "3.5.6" 653 | } 654 | }, 655 | "nbformat": 4, 656 | "nbformat_minor": 2 657 | } 658 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/Two Observations-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from gridworld import gameEnv\n", 12 | "import numpy as np\n", 13 | "%matplotlib inline\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from collections import deque\n", 16 | "import pickle\n", 17 | "from skimage.color import rgb2gray\n", 18 | "import random\n", 19 | 
"import torch\n", 20 | "import torch.nn as nn" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "

Define Environment Object

" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADKdJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUOtFERUtYBbQiLoRYRAURVVSNykLTSREmivIvUikaokXFSRECRFFeVPCDTIikipQ1RVqhyOgSZgQ2wIBFuA3RRKSqW2Tt5ezDg9ce2cOT67e87w+36k1e7M7Gp+o/GzMzue876pKiS15RdWegCSZs/gSw0y+FKDDL7UIIMvNcjgSw0y+FKDlhX8JFcmeS7J3iS3TGpQkqYrx3sDT5I1wPeArcA+4HHg+qraNbnhSZqGtcv47PuAvVX1AkCSe4FrgGMG//TTT6+5ubllrPLtb+fOnSs9BI1cVWWx9ywn+GcCLy+Y3gf85s/7wNzcHPPz88tY5dtfsug+k5Zt6hf3ktyYZD7J/MGDB6e9OkkDLCf4+4GzFkxv7Of9jKq6vaq2VNWWM844YxmrkzQpywn+48A5STYlORG4Dnh4MsOSNE3H/Ru/qg4l+SPgG8Aa4EtV9czERiZpapZzcY+q+jrw9QmNRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBFg5/kS0kOJHl6wbx1SR5Nsqd/PnW6w5Q0SUOO+H8NXHnEvFuA7VV1DrC9n5Y0EosGv6r+Efi3I2ZfA9zVv74L+P0Jj0vSFB3vb/z1VfVK//pVYP2ExiNpBpZ9ca+6rpvH7LxpJx1p9Tne4L+WZANA/3zgWG+0k460+hxv8B8GPtq//ijwtckMR9IsDPnvvHuAfwbOTbIvyQ3AZ4CtSfYAV/TTkkZi0U46VXX9MRZdPuGxSJoR79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGjSk9NZZSR5LsivJM0lu6ufbTUcaqSFH/EPAJ6rqfOBi4GNJzsduOtJoDemk80pVPdG//hGwGzgTu+lIo7Wk3/hJ5oALgR0M7KZjQw1p9Rkc/CTvBr4K3FxVby5c9vO66dhQQ1p9BgU/yQl0ob+7qh7sZw/upiNpdRlyVT/AncDuqvrcgkV205FGatGGGsClwB8A303yVD/vz+i659zfd9Z5Cbh2OkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0a8vf4mqmjVjDTTx3rL8S1FB7xpQYZfKlBQ2runZTk20n+pe+k8+l+/qYkO5LsTXJfkhOnP1xJkzDkiP9fwGVVdQGwGbgyycXAZ4HPV9V7gdeBG6Y3TEmTNKSTTlXVf/STJ/SPAi4DHujn20lHGpGhdfXX9BV2DwCPAs8Db1TVof4t++jaah3ts3bSkVaZQcGvqh9X1WZgI/A+4LyhK7CTjrT6LOmqflW9ATwGXAKckuTwfQAbgf0THpukKRlyVf+MJKf0r98JbKXrmPsY8OH+bXbSkUZkyJ17G4C7kqyh+6K4v6q2JdkF3JvkL4An6dpsSRqBIZ10vkPXGvvI+S/Q/d6XNDLeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQ
ggy81yOBLDTL4UoMMvtQggy81aHDw+xLbTybZ1k/bSUcaqaUc8W+iK7J5mJ10pJEa2lBjI/BB4I5+OthJRxqtoUf8LwCfBH7ST5+GnXSk0RpSV/9DwIGq2nk8K7CTjrT6DKmrfylwdZKrgJOAXwJuo++k0x/17aQjjciQbrm3VtXGqpoDrgO+WVUfwU460mgt5//xPwV8PMleut/8dtKRRmLIqf5PVdW3gG/1r+2kI42Ud+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aFAhjiQvAj8CfgwcqqotSdYB9wFzwIvAtVX1+nSGKWmSlnLE/52q2lxVW/rpW4DtVXUOsL2fljQCyznVv4aukQbYUEMalaHBL+Dvk+xMcmM/b31VvdK/fhVYP/HRSZqKocU2319V+5P8MvBokmcXLqyqSlJH+2D/RXEjwNlnn72swUqajEFH/Kra3z8fAB6iq677WpINAP3zgWN81k460iozpIXWyUl+8fBr4HeBp4GH6RppgA01pFEZcqq/Hnioa5DLWuBvq+qRJI8D9ye5AXgJuHZ6w5Q0SYsGv2+cccFR5v8QuHwag5I0Xd65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVo6F/naWay0gNQAzziSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEHBT3JKkgeSPJtkd5JLkqxL8miSPf3zqdMerKTJGHrEvw14pKrOoyvDtRs76UijNaTK7nuA3wLuBKiq/66qN7CTjjRaQ474m4CDwJeTPJnkjr7Mtp10pJEaEvy1wEXAF6vqQuAtjjitr6qia7P1/yS5Mcl8kvmDBw8ud7ySJmBI8PcB+6pqRz/9AN0XgZ10pJFaNPhV9SrwcpJz+1mXA7uwk440WkP/LPePgbuTnAi8APwh3ZeGnXSkERoU/Kp6CthylEV20pFGyDv3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYNqat/bpKnFjzeTHKznXSk8RpSbPO5qtpcVZuB3wD+E3gIO+lIo7XUU/3Lgeer6iXspCON1lKDfx1wT//aTjrSSA0Ofl9a+2rgK0cus5OONC5LOeJ/AHiiql7rp+2kI43UUoJ/Pf93mg920pFGa1Dw++64W4EHF8z+DLA1yR7gin5a0ggM7aTzFnDaEfN+iJ10pFHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQUNLb/1pkmeSPJ3kniQnJdmUZEeSvUnu66vwShqBIS20zgT+BNhSVb8OrKGrr/9Z4PNV9V7gdeCGaQ5U0uQMPdVfC7wzyVrgXcArwGXAA/1yO+lIIzKkd95+4C+BH9AF/t+BncAbVXWof9s+4MxpDVLSZA051T+Vrk/eJuBXgJOBK4euwE460uoz5FT/CuD7VXWwqv6Hrrb+pcAp/ak/wEZg/9E+bCcdafUZEvwfABcneVeS0NXS3wU8Bny4f4+ddKQRGfIbfwfdRbwngO/2n7kd+BTw8SR76Zpt3DnFcUqaoHSNbmdjy5YtNT8/P7P1jVF3UiUdv6pa9B+Rd+5JDTL4UoMMvtQggy81aKYX95IcBN4C/nVmK52+03F7Vqu307bAsO351apa9IaZmQYfIMl8VW2Z6UqnyO1Zvd5O2wKT3R5P9aUGGXypQSsR/NtXYJ3T5PasXm+nbYEJbs/Mf+NLWnme6ksNmmnwk1yZ5Lm+Tt8ts1z3ciU5K8ljSXb19Qdv6uevS/Jokj3986krPdalSLImyZNJtvXTo62lmOSUJA8keTbJ7iSXjHn/TLPW5cyCn2QN8FfAB4DzgeuTnD+r9U/AIeATVXU+cDHwsX78twDbq+ocYHs/PSY3Abs
XTI+5luJtwCNVdR5wAd12jXL/TL3WZVXN5AFcAnxjwfStwK2zWv8UtudrwFbgOWBDP28D8NxKj20J27CRLgyXAduA0N0gsvZo+2w1P4D3AN+nv261YP4o9w9dKbuXgXV0NS+3Ab83qf0zy1P9wxty2Gjr9CWZAy4EdgDrq+qVftGrwPoVGtbx+ALwSeAn/fRpjLeW4ibgIPDl/qfLHUlOZqT7p6Zc69KLe0uU5N3AV4Gbq+rNhcuq+xoexX+TJPkQcKCqdq70WCZkLXAR8MWqupDu1vCfOa0f2f5ZVq3Lxcwy+PuBsxZMH7NO32qV5AS60N9dVQ/2s19LsqFfvgE4sFLjW6JLgauTvAjcS3e6fxsDaymuQvuAfdVVjIKuatRFjHf/LKvW5WJmGfzHgXP6q5In0l2oeHiG61+Wvt7gncDuqvrcgkUP09UchBHVHqyqW6tqY1XN0e2Lb1bVRxhpLcWqehV4Ocm5/azDtSFHuX+Ydq3LGV+wuAr4HvA88OcrfQFliWN/P91p4neAp/rHVXS/i7cDe4B/ANat9FiPY9t+G9jWv/414NvAXuArwDtWenxL2I7NwHy/j/4OOHXM+wf4NPAs8DTwN8A7JrV/vHNPapAX96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxr0v95G4v3/3GYuAAAAAElFTkSuQmCC\n", 38 | "text/plain": [ 39 | "
" 40 | ] 41 | }, 42 | "metadata": {}, 43 | "output_type": "display_data" 44 | } 45 | ], 46 | "source": [ 47 | "env = gameEnv(partial=True,size=9)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | }, 65 | { 66 | "data": { 67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RR
wDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB65NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5O
O1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n", 68 | "text/plain": [ 69 | "
" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "BATCH_SIZE = 32\n", 104 | "FREEZE_INTERVAL = 20000 # steps\n", 105 | "MEMORY_SIZE = 60000 \n", 106 | "OUTPUT_SIZE = 4\n", 107 | "TOTAL_EPISODES = 10000\n", 108 | "MAX_STEPS = 50\n", 109 | "INITIAL_EPSILON = 1.0\n", 110 | "FINAL_EPSILON = 0.1\n", 111 | "GAMMA = 0.99\n", 112 | "INPUT_IMAGE_DIM = 84\n", 113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "

Save Dictionary Function

" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | "metadata": { 127 | "collapsed": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def save_obj(obj, name ):\n", 132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "

Experience Replay

" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 146 | "metadata": { 147 | "collapsed": true 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "class Memory():\n", 152 | " \n", 153 | " def __init__(self,memsize):\n", 154 | " self.memsize = memsize\n", 155 | " self.memory = deque(maxlen=self.memsize)\n", 156 | " \n", 157 | " def add_sample(self,sample):\n", 158 | " self.memory.append(sample)\n", 159 | " \n", 160 | " def get_batch(self,size):\n", 161 | " return random.sample(self.memory,k=size)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "

Frame Collector

" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 5, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "class FrameCollector():\n", 180 | " \n", 181 | " def __init__(self,num_frames,img_dim):\n", 182 | " self.num_frames = num_frames\n", 183 | " self.img_dim = img_dim\n", 184 | " self.frames = deque(maxlen=self.num_frames)\n", 185 | " \n", 186 | " def reset(self):\n", 187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n", 188 | " for i in range(0,self.num_frames):\n", 189 | " self.frames.append(tmp)\n", 190 | " \n", 191 | " def add_frame(self,frame):\n", 192 | " self.frames.append(frame)\n", 193 | " \n", 194 | " def get_state(self):\n", 195 | " return np.array(self.frames)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "

Preprocess Images

" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def preprocess_image(image):\n", 214 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 215 | " return np.copy(image)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 7, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "" 227 | ] 228 | }, 229 | "execution_count": 7, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | }, 233 | { 234 | "data": { 235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3V2MXPV5x/Hvr14ICUljm7SWi0kxigVCVTGRlYLggpLSEhpBLqIUlEhpldY3qUraSsG0Fy2VIiVSlYSLKpIFSVGV8hKHJhYXSV2HpL1ysDFtwcbBJBBs+YUKyNsFqsPTizluF7p4zu7O7O7h//1Iq5lz5uX8j45+c15m9nlSVUhqyy8s9wAkLT2DLzXI4EsNMvhSgwy+1CCDLzXI4EsNWlTwk1yf5FCSw0m2TWpQkqYrC/0BT5JVwPeA64AjwCPALVV1YHLDkzQNM4t47XuAw1X1fYAk9wE3Aa8b/CT+TFCasqrKuOcs5lD/fOC5WdNHunmSVrjF7PF7SbIV2Drt5UjqbzHBPwpcMGt6QzfvVapqO7AdPNSXVorFHOo/AmxKsjHJ2cDNwM7JDEvSNC14j19Vp5L8MfBNYBXwxap6YmIjkzQ1C/46b0EL81BfmrppX9WXNFAGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFjg5/ki0lOJnl81ry1SXYleaq7XTPdYUqapD57/L8Hrn/NvG3A7qraBOzupiUNxNjgV9W/Ai+8ZvZNwD3d/XuAD0x4XJKmaKHn+Ouq6lh3/ziwbkLjkbQEFt1Jp6rqTNVz7aQjrTwL3eOfSLIeoLs9+XpPrKrtVbWlqrYscFmSJmyhwd8JfLS7/1Hg65MZjqSlMLahRpJ7gWuAdwAngL8CvgY8ALwTeBb4UFW99gLgXO9lQw1pyvo01LCTjvQGYycdSXMy+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3q00nngiQPJzmQ5Ikkt3bz7aYjDVSfmnvrgfVV9WiStwH7GDXQ+H3ghar6dJJtwJqqum3Me1l6S5qyiZTeqqpjVfVod/8nwEHgfOymIw3WvBpqJLkQuBzYQ89uOjbUkFae3lV2k7wV+A7wqap6MMlLVbV61uMvVtUZz/M91Jemb2JVdpOcBXwV+HJVPdjN7t1NR9LK0ueqfoC7gYNV9dlZD9lNRxqoPlf1rwb+DfhP4JVu9l8wOs+fVzc
dD/Wl6bOTjtQgO+lImpPBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxo0r5p7mr6l/DfpIRrVhdFiuceXGmTwpQb1qbl3TpLvJvn3rpPOHd38jUn2JDmc5P4kZ09/uJImoc8e/2Xg2qq6DNgMXJ/kCuAzwOeq6l3Ai8DHpjdMSZPUp5NOVdVPu8mzur8CrgV2dPPtpCMNSN+6+quSPMaodv4u4Gngpao61T3lCKO2WnO9dmuSvUn2TmLAkhavV/Cr6udVtRnYALwHuKTvAqpqe1VtqaotCxyjpAmb11X9qnoJeBi4Elid5PTvADYARyc8NklT0ueq/i8lWd3dfzNwHaOOuQ8DH+yeZicdaUD6dNL5dUYX71Yx+qB4oKr+JslFwH3AWmA/8JGqennMe/mztDH85d6Z+cu98eykM0AG/8wM/nh20pE0J4MvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoN7B70ps70/yUDdtJx1poOazx7+VUZHN0+ykIw1U34YaG4DfBe7qpoOddKTB6rvH/zzwSeCVbvo87KQjDVafuvrvB05W1b6FLMBOOtLKMzP+KVwF3JjkBuAc4BeBO+k66XR7fTvpSAPSp1vu7VW1oaouBG4GvlVVH8ZOOtJgLeZ7/NuAP0tymNE5/92TGZKkabOTzgpjJ50zs5POeHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoP61NwjyTPAT4CfA6eqakuStcD9wIXAM8CHqurF6QxT0iTNZ4//m1W1eVa13G3A7qraBOzupiUNwGIO9W9i1EgDbKghDUrf4Bfwz0n2JdnazVtXVce6+8eBdRMfnaSp6HWOD1xdVUeT/DKwK8mTsx+sqnq9QprdB8XWuR6TtDzmXWU3yV8DPwX+CLimqo4lWQ98u6ouHvNaS8iOYZXdM7PK7ngTqbKb5Nwkbzt9H/ht4HFgJ6NGGmBDDWlQxu7xk1wE/FM3OQP8Y1V9Ksl5wAPAO4FnGX2d98KY93J3NoZ7/DNzjz9enz2+DTVWGIN/ZgZ/PBtqSJqTwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ3//O0xLxl2laCu7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mdZEeSJ5McTHJlkrVJdiV5qrtdM+3BSpqMvnv8O4FvVNUlwGXAQeykIw1Wn2KbbwceAy6qWU9OcgjLa0srzqRq7m0Enge+lGR/kru6Mtt20pEGqk/wZ4B3A1+oqsuBn/Gaw/ruSOB1O+kk2Ztk72IHK2ky+gT/CHCkqvZ00zsYfRCc6A7x6W5PzvXiqtpeVVtmddmVtMzGBr+qjgPPJTl9/v5e4AB20pEGq1dDjSSbgbuAs4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9ycZLHZv39OMkn7KQjDde8Sm8lWQUcBX4D+DjwQlV9Osk2YE1V3Tbm9ZbekqZsGqW33gs8XVXPAjcB93Tz7wE+MM/3krRM5hv8m4F7u/t20pEGqnfwk5wN3Ah85bWP2UlHGpb57PHfBzxaVSe6aTvpSAM1n+Dfwv8d5oOddKTB6ttJ51zgh4xaZf+om3cedtKRVhw76UgNspOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb1Cn6SP03yRJLHk9yb5JwkG5PsSXI4yf1dFV5JA9Cnhdb5wJ8AW6rq14BVjOrrfwb4XFW9C3gR+Ng0Byppcvoe6s8Ab04yA7wFOAZ
cC+zoHreTjjQgY4NfVUeBv2VUZfcY8CNgH/BSVZ3qnnYEOH9ag5Q0WX0O9dcw6pO3EfgV4Fzg+r4LsJOOtPLM9HjObwE/qKrnAZI8CFwFrE4y0+31NzDqovv/VNV2YHv3WstrSytAn3P8HwJXJHlLkjDqmHsAeBj4YPccO+lIA9K3k84dwO8Bp4D9wB8yOqe/D1jbzftIVb085n3c40tTZicdqUF20pE0J4MvNcjgSw0y+FKD+nyPP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akr1VtWVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7gb1+GZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOf5Pokh7o6fduWctmLleSCJA8nOdDVH7y1m782ya4kT3W3a5Z7rPORZFWS/Uke6qYHW0sxyeokO5I8meRgkiuHvH2mWetyyYKfZBXwd8D7gEuBW5JculTLn4BTwJ9X1aXAFcDHu/FvA3ZX1SZgdzc9JLcCB2dND7mW4p3AN6rqEuAyRus1yO0z9VqXVbUkf8CVwDdnTd8O3L5Uy5/C+nwduA44BKzv5q0HDi332OaxDhsYheFa4CEgjH4gMjPXNlvJf8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3pFThtsnb4kFwKXA3uAdVV1rHvoOLBumYa1EJ8HPgm80k2fx3BrKW4Enge+1J263JXkXAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsTXJEneD5ysqn3LPZYJmQHeDXyhqi5n9NPwVx3WD2z7LKrW5ThLGfyjwAWzpl+3Tt9KleQsRqH/clU92M0+kWR99/h64ORyjW+ergJuTPIMo0pK1zI6R17dlVGHYW2jI8CRqtrTTe9g9EEw1O3zv7Uuq+q/gVfVuuyes+Dts5TBfwTY1F2VPJvRhYqdS7j8RenqDd4NHKyqz856aCejmoMwoNqDVXV7VW2oqgsZbYtvVdWHGWgtxao6DjyX5OJu1unakIPcPky71uUSX7C4Afge8DTwl8t9AWWeY7+a0WHifwCPdX83MDov3g08BfwLsHa5x7qAdbsGeKi7fxHwXeAw8BXgTcs9vnmsx2Zgb7eNvgasGfL2Ae4AngQeB/4BeNOkto+/3JMa5MU9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBv0PANIhuBFMr1IAAAAASUVORK5CYII=\n", 236 | "text/plain": [ 237 | "
" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "processed_prev_state = preprocess_image(prev_state)\n", 246 | "plt.imshow(processed_prev_state,cmap='gray')" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "

Build Model

" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Network(\n", 266 | " (conv_layer1): Conv2d(2, 32, kernel_size=(8, 8), stride=(4, 4))\n", 267 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n", 268 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 269 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n", 270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 271 | " (relu): ReLU()\n", 272 | ")\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "import torch.nn as nn\n", 278 | "import torch\n", 279 | "\n", 280 | "class Network(nn.Module):\n", 281 | " \n", 282 | " def __init__(self,image_input_size,out_size):\n", 283 | " super(Network,self).__init__()\n", 284 | " self.image_input_size = image_input_size\n", 285 | " self.out_size = out_size\n", 286 | "\n", 287 | " self.conv_layer1 = nn.Conv2d(in_channels=2,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n", 288 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n", 289 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n", 290 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n", 291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 292 | " self.relu = nn.ReLU()\n", 293 | "\n", 294 | " def forward(self,x,bsize):\n", 295 | " x = x.view(bsize,2,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 296 | " conv_out = self.conv_layer1(x)\n", 297 | " conv_out = self.relu(conv_out)\n", 298 | " conv_out = self.conv_layer2(conv_out)\n", 299 | " conv_out = self.relu(conv_out)\n", 300 | " conv_out = self.conv_layer3(conv_out)\n", 301 | " conv_out = self.relu(conv_out)\n", 302 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n", 303 | " out 
= self.relu(out)\n", 304 | " out = self.fc2(out)\n", 305 | " return out\n", 306 | "\n", 307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n", 308 | "print(main_model)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "

Deep Q Learning with Freeze Network

" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "Populated 60000 Samples in Episodes : 1200\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "mem = Memory(memsize=MEMORY_SIZE)\n", 333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n", 334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n", 335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=2)\n", 336 | "\n", 337 | "target_model.load_state_dict(main_model.state_dict())\n", 338 | "criterion = nn.SmoothL1Loss()\n", 339 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 340 | "\n", 341 | "# filling memory with transitions\n", 342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 343 | " \n", 344 | " prev_state = env.reset()\n", 345 | " frameObj.reset()\n", 346 | " processed_prev_state = preprocess_image(prev_state)\n", 347 | " frameObj.add_frame(processed_prev_state)\n", 348 | " prev_frames = frameObj.get_state()\n", 349 | " step_count = 0\n", 350 | " game_over = False\n", 351 | " \n", 352 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 353 | " \n", 354 | " step_count +=1\n", 355 | " action = np.random.randint(0,4)\n", 356 | " next_state,reward, game_over = env.step(action)\n", 357 | " processed_next_state = preprocess_image(next_state)\n", 358 | " frameObj.add_frame(processed_next_state)\n", 359 | " next_frames = frameObj.get_state()\n", 360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 361 | " \n", 362 | " prev_state = next_state\n", 363 | " processed_prev_state = processed_next_state\n", 364 | " prev_frames = next_frames\n", 365 | "\n", 366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n", 367 | "\n", 
368 | "\n", 369 | "# Algorithm Starts\n", 370 | "total_steps = 0\n", 371 | "epsilon = INITIAL_EPSILON\n", 372 | "loss_stat = []\n", 373 | "total_reward_stat = []\n", 374 | "\n", 375 | "for episode in range(0,TOTAL_EPISODES):\n", 376 | " \n", 377 | " prev_state = env.reset()\n", 378 | " frameObj.reset()\n", 379 | " processed_prev_state = preprocess_image(prev_state)\n", 380 | " frameObj.add_frame(processed_prev_state)\n", 381 | " prev_frames = frameObj.get_state()\n", 382 | " game_over = False\n", 383 | " step_count = 0\n", 384 | " total_reward = 0\n", 385 | " \n", 386 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 387 | " \n", 388 | " step_count +=1\n", 389 | " total_steps +=1\n", 390 | " \n", 391 | " if np.random.rand() <= epsilon:\n", 392 | " action = np.random.randint(0,4)\n", 393 | " else:\n", 394 | " with torch.no_grad():\n", 395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 396 | "\n", 397 | " model_out = main_model.forward(torch_x,bsize=1)\n", 398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 399 | " \n", 400 | " next_state, reward, game_over = env.step(action)\n", 401 | " processed_next_state = preprocess_image(next_state)\n", 402 | " frameObj.add_frame(processed_next_state)\n", 403 | " next_frames = frameObj.get_state()\n", 404 | " total_reward += reward\n", 405 | " \n", 406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 407 | " \n", 408 | " prev_state = next_state\n", 409 | " processed_prev_state = processed_next_state\n", 410 | " prev_frames = next_frames\n", 411 | " \n", 412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 413 | " target_model.load_state_dict(main_model.state_dict())\n", 414 | " \n", 415 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 416 | " current_states = []\n", 417 | " next_states = []\n", 418 | " acts = []\n", 419 | " rewards = []\n", 420 | " game_status = []\n", 421 | " \n", 422 | " for element in batch:\n", 423 | " 
current_states.append(element[0])\n", 424 | " acts.append(element[1])\n", 425 | " rewards.append(element[2])\n", 426 | " next_states.append(element[3])\n", 427 | " game_status.append(element[4])\n", 428 | " \n", 429 | " current_states = np.array(current_states)\n", 430 | " next_states = np.array(next_states)\n", 431 | " rewards = np.array(rewards)\n", 432 | " game_status = [not b for b in game_status]\n", 433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 434 | " torch_acts = torch.tensor(acts)\n", 435 | " \n", 436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 439 | " Q_max_next = Q_max_next.double()\n", 440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 441 | " \n", 442 | " target_values = (rewards + (GAMMA * Q_max_next))\n", 443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 444 | " \n", 445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n", 446 | " \n", 447 | " # save performance measure\n", 448 | " loss_stat.append(loss.item())\n", 449 | " \n", 450 | " # make previous grad zero\n", 451 | " optimizer.zero_grad()\n", 452 | " \n", 453 | " # back - propogate \n", 454 | " loss.backward()\n", 455 | " \n", 456 | " # update params\n", 457 | " optimizer.step()\n", 458 | " \n", 459 | " # save performance measure\n", 460 | " total_reward_stat.append(total_reward)\n", 461 | " \n", 462 | " if epsilon > FINAL_EPSILON:\n", 463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 464 | " \n", 465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 466 | " perf = {}\n", 467 | " perf['loss'] = loss_stat\n", 468 | " perf['total_reward'] = total_reward_stat\n", 469 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n", 470 | " \n", 471 | " 
#print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "

Save Primary Network Weights

" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 19, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "torch.save(main_model.state_dict(),'data/TWO_OBSERV_NINE_WEIGHTS.torch')" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "

Testing Policy

" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "

Load Primary Network Weights

" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 10, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [ 512 | "weights = torch.load('data/TWO_OBSERV_NINE_WEIGHTS.torch')\n", 513 | "main_model.load_state_dict(weights)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "

Testing Policy

" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": { 527 | "collapsed": true 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "# Algorithm Starts\n", 532 | "epsilon = INITIAL_EPSILON\n", 533 | "FINAL_EPSILON = 0.01\n", 534 | "total_reward_stat = []\n", 535 | "\n", 536 | "for episode in range(0,TOTAL_EPISODES):\n", 537 | " \n", 538 | " prev_state = env.reset()\n", 539 | " processed_prev_state = preprocess_image(prev_state)\n", 540 | " frameObj.reset()\n", 541 | " frameObj.add_frame(processed_prev_state)\n", 542 | " prev_frames = frameObj.get_state()\n", 543 | " game_over = False\n", 544 | " step_count = 0\n", 545 | " total_reward = 0\n", 546 | " \n", 547 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 548 | " \n", 549 | " step_count +=1\n", 550 | " \n", 551 | " if np.random.rand() <= epsilon:\n", 552 | " action = np.random.randint(0,4)\n", 553 | " else:\n", 554 | " with torch.no_grad():\n", 555 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 556 | "\n", 557 | " model_out = main_model.forward(torch_x,bsize=1)\n", 558 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 559 | " \n", 560 | " next_state, reward, game_over = env.step(action)\n", 561 | " processed_next_state = preprocess_image(next_state)\n", 562 | " frameObj.add_frame(processed_next_state)\n", 563 | " next_frames = frameObj.get_state()\n", 564 | " \n", 565 | " total_reward += reward\n", 566 | " \n", 567 | " prev_state = next_state\n", 568 | " processed_prev_state = processed_next_state\n", 569 | " prev_frames = next_frames\n", 570 | " \n", 571 | " # save performance measure\n", 572 | " total_reward_stat.append(total_reward)\n", 573 | " \n", 574 | " if epsilon > FINAL_EPSILON:\n", 575 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 576 | " \n", 577 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 578 | " perf = {}\n", 579 | " perf['total_reward'] = total_reward_stat\n", 
580 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n", 581 | " \n", 582 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 583 | ] 584 | } 585 | ], 586 | "metadata": { 587 | "kernelspec": { 588 | "display_name": "Python [conda env:myenv]", 589 | "language": "python", 590 | "name": "conda-env-myenv-py" 591 | }, 592 | "language_info": { 593 | "codemirror_mode": { 594 | "name": "ipython", 595 | "version": 3 596 | }, 597 | "file_extension": ".py", 598 | "mimetype": "text/x-python", 599 | "name": "python", 600 | "nbconvert_exporter": "python", 601 | "pygments_lexer": "ipython3", 602 | "version": "3.6.5" 603 | } 604 | }, 605 | "nbformat": 4, 606 | "nbformat_minor": 2 607 | } 608 | -------------------------------------------------------------------------------- /Four Observations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from gridworld import gameEnv\n", 12 | "import numpy as np\n", 13 | "%matplotlib inline\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from collections import deque\n", 16 | "import pickle\n", 17 | "from skimage.color import rgb2gray\n", 18 | "import random\n", 19 | "import torch\n", 20 | "import torch.nn as nn" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "

Define Environment Object

" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RRwDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxp
k8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB65NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5OO1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2d
ND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n", 38 | "text/plain": [ 39 | "
" 40 | ] 41 | }, 42 | "metadata": {}, 43 | "output_type": "display_data" 44 | } 45 | ], 46 | "source": [ 47 | "env = gameEnv(partial=True,size=9)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | }, 65 | { 66 | "data": { 67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLlJREFUeJzt3X/oXfV9x/Hna4nW1m41cS5kxs2UiiIDowlOsYxNzWZd0f1RRCmjDMF/uk3XQqvbH6WwP1oYbf1jFETbyXD+qNU1hGLn0pQyGKlff6zVRJtoY01QEzudnYNtad/745xsX0OynG++997v9/h5PuBy7znn3pzP4fC659yT832/U1VIassvLPUAJM2ewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSvJckj1Jbp3UoCRNV070Bp4kK4AfApuBfcBjwA1VtXNyw5M0DSsX8dmLgT1V9QJAkvuAa4FjBj+JtwlqUTZu3LjUQ1jW9u7dy2uvvZbjvW8xwT8TeGne9D7gNxfx70nHNTc3t9RDWNY2bdo06H2LCf4gSW4Cbpr2eiQNt5jg7wfOmje9rp/3NlV1B3AHeKovLReLuar/GHBOkvVJTgauB7ZMZliSpumEj/hVdSjJHwPfAlYAX6mqZyY2MklTs6jf+FX1TeCbExqLpBnxzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQccNfpKvJDmQ5Ol581YneTTJ7v551XSHKWmShhzx/wa46oh5twLbquocYFs/LWkkjhv8qvou8K9HzL4WuLt/fTfwBxMel6QpOtHf+Guq6uX+9SvAmgmNR9IMLLqTTlXV/9cow0460vJzokf8V5OsBeifDxzrjVV1R1VtqqphTb0kTd2JBn8L8LH+9ceAb0xmOJJmYch/590L/DNwbpJ9SW4EPgdsTrIbuLKfljQSx/2NX1U3HGPRFRMei6QZ8c49qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFDSm+dlWR7kp1Jnklycz/fbjrSSA054h8CPllV5wOXAB9Pcj5205FGa0gnnZer6on+9U+BXcCZ2E1HGq0FNdRIcjZwIbCDgd10bKghLT+DL+4leS/wdeCWqnpz/rKqKuCo3XRsqCEtP4OCn+QkutDfU1UP9bMHd9ORtLwMuaof4C5gV1V9Yd4iu+lIIzXkN/5lwB8CP0jyVD/vz+m65zzQd9Z5EbhuOkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw1aUM29RdsIzM10jeOTo1YwkybKI77UIIMvNWhIzb1Tknwvyb/0nXQ+289fn2RHkj1J7k9y8vSHK2kShhzx/xO4vKouADYAVyW5BPg88MWq+gD
wOnDj9IYpaZKGdNKpqvr3fvKk/lHA5cCD/Xw76UgjMrSu/oq+wu4B4FHgeeCNqjrUv2UfXVuto332piRzSeY4OIkhS1qsQcGvqp9V1QZgHXAxcN7QFbytk84ZJzhKSRO1oKv6VfUGsB24FDgtyeH7ANYB+yc8NklTMuSq/hlJTutfvxvYTNcxdzvwkf5tdtKRRmTInXtrgbuTrKD7onigqrYm2Qncl+QvgSfp2mxJGoEhnXS+T9ca+8j5L9D93pc0Mt65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVocPD7EttPJtnaT9tJRxqphRzxb6YrsnmYnXSkkRraUGMd8PvAnf10sJOONFpDj/hfAj4F/LyfPh076UijNaSu/oeBA1X1+ImswE460vIzpK7+ZcA1Sa4GTgF+CbidvpNOf9S3k440IkO65d5WVeuq6mzgeuDbVfVR7KQjjdZi/h//08Ankuyh+81vJx1pJIac6v+vqvoO8J3+tZ10pJHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGDSrEkWQv8FPgZ8ChqtqUZDVwP3A2sBe4rqpen84wJU3SQo74v1NVG6pqUz99K7Ctqs4BtvXTkkZgMaf619I10gAbakijMjT4BfxDkseT3NTPW1NVL/evXwHWTHx0kqZiaLHND1bV/iS/Ajya5Nn5C6uqktTRPth/UXRfFr+2mKFKmpRBR/yq2t8/HwAepquu+2qStQD984FjfNZOOtIyM6SF1qlJfvHwa+B3gaeBLXSNNMCGGtKoDDnVXwM83DXIZSXwd1X1SJLHgAeS3Ai8CFw3vWFKmqTjBr9vnHHBUeb/BLhiGoOSNF3euSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aOhf503ERjYyx9wsVzk+R/0bR2myPOJLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgQcFPclqSB5M8m2RXkkuTrE7yaJLd/fOqaQ9W0mQMPeLfDjxSVefRleHahZ10pNEaUmX3fcBvAXcBVNV/VdUb2ElHGq0hR/z1wEHgq0meTHJnX2bbTjrSSA0J/krgIuDLVXUh8BZHnNZXVXGMu8yT3JRkLsncwYMHFzteSRMwJPj7gH1VtaOffpDui2DBnXTOOMNWOtJycNzgV9UrwEtJzu1nXQHsxE460mgN/bPcPwHuSXIy8ALwR3RfGnbSkUZoUPCr6ilg01EW2UlHGiHv3JMaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaNKSu/rlJnpr3eDPJLXbSkcZrSLHN56pqQ1VtADYC/wE8jJ10pNFa6Kn+FcDzVfUidtKRRmuhwb8euLd/bScdaaQGB78vrX0N8LUjl9lJRxqXhRzxPwQ8UVWv9tN20pFGaiHBv4H/O80HO+lIozUo+H133M3AQ/Nmfw7YnGQ3cGU/LWkEhnbSeQs4/Yh5P8FOOtIoeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KChpbf+LMkzSZ5Ocm+SU5KsT7IjyZ4k9/dVeCWNwJAWWmcCfwpsqqrfAFbQ1df/PPDFqvoA8Dpw4zQHKmlyhp7qrwTenWQl8B7gZeBy4MF+uZ10pBEZ0jtvP/BXwI/pAv9vwOPAG1V1qH/bPuDMaQ1S0mQNOdVfRdcnbz3wq8CpwFVDV2AnHWn5GXKqfyXwo6o6WFX/TVdb/zLgtP7UH2AdsP9oH7aTjrT8DAn+j4FLkrwnSehq6e8EtgMf6d9jJx1pRIb8xt9BdxHvCeAH/WfuAD4NfCLJHrpmG3dNcZySJmh
oJ53PAJ85YvYLwMUTH5GkqfPOPalBBl9qkMGXGmTwpQalqma3suQg8Bbw2sxWOn2/jNuzXL2TtgWGbc+vV9Vxb5iZafABksxV1aaZrnSK3J7l6520LTDZ7fFUX2qQwZcatBTBv2MJ1jlNbs/y9U7aFpjg9sz8N76kpeepvtSgmQY/yVVJnuvr9N06y3UvVpKzkmxPsrOvP3hzP391kkeT7O6fVy31WBciyYokTybZ2k+PtpZiktOSPJjk2SS7klw65v0zzVqXMwt+khXAXwMfAs4Hbkhy/qzWPwGHgE9W1fnAJcDH+/HfCmyrqnOAbf30mNwM7Jo3PeZaircDj1TVecAFdNs1yv0z9VqXVTWTB3Ap8K1507cBt81q/VPYnm8Am4HngLX9vLXAc0s9tgVswzq6MFwObAVCd4PIyqPts+X8AN4H/Ij+utW8+aPcP3Sl7F4CVtP9Fe1W4PcmtX9meap/eEMOG22dviRnAxcCO4A1VfVyv+gVYM0SDetEfAn4FPDzfvp0xltLcT1wEPhq/9PlziSnMtL9U1OudenFvQVK8l7g68AtVfXm/GXVfQ2P4r9JknwYOFBVjy/1WCZkJXAR8OWqupDu1vC3ndaPbP8sqtbl8cwy+PuBs+ZNH7NO33KV5CS60N9TVQ/1s19NsrZfvhY4sFTjW6DLgGuS7AXuozvdv52BtRSXoX3AvuoqRkFXNeoixrt/FlXr8nhmGfzHgHP6q5In012o2DLD9S9KX2/wLmBXVX1h3qItdDUHYUS1B6vqtqpaV1Vn0+2Lb1fVRxlpLcWqegV4Kcm5/azDtSFHuX+Ydq3LGV+wuBr4IfA88BdLfQFlgWP/IN1p4veBp/rH1XS/i7cBu4F/BFYv9VhPYNt+G9jav34/8D1gD/A14F1LPb4FbMcGYK7fR38PrBrz/gE+CzwLPA38LfCuSe0f79yTGuTFPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb9DzfZzH2Ei/mJAAAAAElFTkSuQmCC\n", 68 | "text/plain": [ 69 | "
" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "BATCH_SIZE = 32\n", 104 | "FREEZE_INTERVAL = 20000 # steps\n", 105 | "MEMORY_SIZE = 60000 \n", 106 | "OUTPUT_SIZE = 4\n", 107 | "TOTAL_EPISODES = 10000\n", 108 | "MAX_STEPS = 50\n", 109 | "INITIAL_EPSILON = 1.0\n", 110 | "FINAL_EPSILON = 0.1\n", 111 | "GAMMA = 0.99\n", 112 | "INPUT_IMAGE_DIM = 84\n", 113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "

Save Dictionary Function

" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": { 127 | "collapsed": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def save_obj(obj, name ):\n", 132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "

Experience Replay

" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": { 147 | "collapsed": true 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "class Memory():\n", 152 | " \n", 153 | " def __init__(self,memsize):\n", 154 | " self.memsize = memsize\n", 155 | " self.memory = deque(maxlen=self.memsize)\n", 156 | " \n", 157 | " def add_sample(self,sample):\n", 158 | " self.memory.append(sample)\n", 159 | " \n", 160 | " def get_batch(self,size):\n", 161 | " return random.sample(self.memory,k=size)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "

Frame Collector

" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "class FrameCollector():\n", 180 | " \n", 181 | " def __init__(self,num_frames,img_dim):\n", 182 | " self.num_frames = num_frames\n", 183 | " self.img_dim = img_dim\n", 184 | " self.frames = deque(maxlen=self.num_frames)\n", 185 | " \n", 186 | " def reset(self):\n", 187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n", 188 | " for i in range(0,self.num_frames):\n", 189 | " self.frames.append(tmp)\n", 190 | " \n", 191 | " def add_frame(self,frame):\n", 192 | " self.frames.append(frame)\n", 193 | " \n", 194 | " def get_state(self):\n", 195 | " return np.array(self.frames)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "

Preprocess Images

" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def preprocess_image(image):\n", 214 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 215 | " return np.copy(image)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 9, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "" 227 | ] 228 | }, 229 | "execution_count": 9, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | }, 233 | { 234 | "data": { 235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADM1JREFUeJzt3X/oXfV9x/Hna4nW1k5jnIbMyMxoUGRg1OAUy9jUbKktuj+KKDLKEPJPt+laaHX7oxT2RwujrcIoiLaT4fxRq2sIxc6lljEYqfHHWk1ME22sCWqi80fnYFva9/64J9u3WbKcb+6P7/f4eT7gcu85596cz+Hwuufc8z15v1NVSGrLLy30ACTNnsGXGmTwpQYZfKlBBl9qkMGXGmTwpQaNFfwkG5LsTLI7ya2TGpSk6crx3sCTZAnwI2A9sBd4ArihqrZPbniSpmHpGJ+9BNhdVS8CJLkfuBY4avCTeJugxnLxxRcv9BAWtT179vD666/nWO8bJ/hnAS/Pmd4L/OYY/550TNu2bVvoISxq69at6/W+cYLfS5KNwMZpr0dSf+MEfx9w9pzpVd28X1BVdwJ3gqf60mIxzlX9J4A1SVYnORG4Htg0mWFJmqbjPuJX1cEkfwR8B1gCfK2qnpvYyCRNzVi/8avq28C3JzQWSTPinXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg44Z/CRfS7I/ybNz5i1P8liSXd3zadMdpqRJ6nPE/2tgw2HzbgW2VNUaYEs3LWkgjhn8qvpH4F8Pm30tcE/3+h7g9yc8LklTdLy/8VdU1Svd61eBFRMaj6QZGLuTTlXV/9cow0460uJzvEf815KsBOie9x/tjVV1Z1Wtq6p+Tb0kTd3xBn8T8Inu9SeAb01mOJJmoc+f8+4D/hk4N8neJDcBXwDWJ9kFXNVNSxqIY/7Gr6objrLoygmPRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KA+pbfOTvJ4ku1JnktyczffbjrSQPU54h8EPl1V5wOXAp9Mcj5205EGq08nnVeq6qnu9U+BHcBZ2E1HGqx5NdRIcg5wIbCVnt10bKghLT69L+4l+SDwTeCWqnpn7rKqKuCI3XRsqCEtPr2Cn+QERqG/t6oe7mb37qYjaXHpc1U/wN3Ajqr60pxFdtORBqrPb/zLgT8AfpjkmW7enzHqnvN
g11nnJeC66QxR0qT16aTzT0COsthuOtIAeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVoXjX3xrVmzRruuOOOWa5ycG688caFHoIa4BFfapDBlxrUp+beSUm+n+Rfuk46n+/mr06yNcnuJA8kOXH6w5U0CX2O+P8BXFFVFwBrgQ1JLgW+CHy5qj4EvAncNL1hSpqkPp10qqr+rZs8oXsUcAXwUDffTjrSgPStq7+kq7C7H3gMeAF4q6oOdm/Zy6it1pE+uzHJtiTb3n777UmMWdKYegW/qn5WVWuBVcAlwHl9VzC3k86pp556nMOUNEnzuqpfVW8BjwOXAcuSHLoPYBWwb8JjkzQlfa7qn5FkWff6/cB6Rh1zHwc+3r3NTjrSgPS5c28lcE+SJYy+KB6sqs1JtgP3J/kL4GlGbbYkDUCfTjo/YNQa+/D5LzL6vS9pYLxzT2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ7+B3JbafTrK5m7aTjjRQ8zni38yoyOYhdtKRBqpvQ41VwEeBu7rpYCcdabD6HvG/AnwG+Hk3fTp20pEGq09d/Y8B+6vqyeNZgZ10pMWnT139y4FrklwNnAScAtxO10mnO+rbSUcakD7dcm+rqlVVdQ5wPfDdqroRO+lIgzXO3/E/C3wqyW5Gv/ntpCMNRJ9T/f9RVd8Dvte9tpOONFDeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgXoU4kuwBfgr8DDhYVeuSLAceAM4B9gDXVdWb0xmmpEmazxH/d6pqbVWt66ZvBbZU1RpgSzctaQDGOdW/llEjDbChhjQofYNfwN8neTLJxm7eiqp6pXv9KrBi4qOTNBV9i21+uKr2JTkTeCzJ83MXVlUlqSN9sPui2Ahw5plnjjVYSZPR64hfVfu65/3AI4yq676WZCVA97z/KJ+1k460yPRpoXVykl8+9Br4XeBZYBOjRhpgQw1pUPqc6q8AHhk1yGUp8LdV9WiSJ4AHk9wEvARcN71hSpqkYwa/a5xxwRHmvwFcOY1BSZou79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtT3f+dNxCmnnMKGDRtmucrBeeONNxZ6CGqAR3ypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUK/hJliV5KMnzSXYkuSzJ8iSPJdnVPZ827cFKmoy+R/zbgUer6jxGZbh2YCcdabD6VNk9Ffgt4G6AqvrPqnoLO+lIg9XniL8aOAB8PcnTSe7qymzbSUcaqD7BXwpcBHy1qi4E3uWw0/qqKkZttv6PJBuTbEuy7cCBA+OOV9IE9An+XmBvVW3tph9i9EUw7046Z5xxxiTGLGlMxwx+Vb0KvJzk3G7WlcB27KQjDVbf/5b7x8C9SU4EXgT+kNGXhp10pAHqFfyqegZYd4RFdtKRBsg796QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9amrf26SZ+Y83klyi510pOHqU2xzZ1Wtraq1wMXAvwOPYCcdabDme6p/JfBCVb2EnXSkwZpv8K8H7ute20lHGqjewe9Ka18DfOPwZXbSkYZlPkf8jwBPVdVr3bSddKSBmk/wb+B/T/PBTjrSYPUKftcddz3w8JzZXwDWJ9kFXNVNSxqAvp103gVOP2zeG9hJRxok79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtS39NafJnkuybNJ7ktyUpLVSbYm2Z3kga4Kr6QB6NNC6yzgT4B1VfUbwBJG9fW
/CHy5qj4EvAncNM2BSpqcvqf6S4H3J1kKfAB4BbgCeKhbbicdaUD69M7bB/wl8BNGgX8beBJ4q6oOdm/bC5w1rUFKmqw+p/qnMeqTtxr4VeBkYEPfFdhJR1p8+pzqXwX8uKoOVNV/MaqtfzmwrDv1B1gF7DvSh+2kIy0+fYL/E+DSJB9IEka19LcDjwMf795jJx1pQPr8xt/K6CLeU8APu8/cCXwW+FSS3Yyabdw9xXFKmqC+nXQ+B3zusNkvApdMfESSps4796QGGXypQQZfapDBlxqUqprdypIDwLvA6zNb6fT9Cm7PYvVe2hbotz2/VlXHvGFmpsEHSLKtqtbNdKVT5PYsXu+lbYHJbo+n+lKDDL7UoIUI/p0LsM5pcnsWr/fStsAEt2fmv/ElLTxP9aUGzTT4STYk2dnV6bt1luseV5KzkzyeZHtXf/Dmbv7yJI8l2dU9n7bQY52PJEuSPJ1kczc92FqKSZYleSjJ80l2JLlsyPtnmrUuZxb8JEuAvwI+ApwP3JDk/FmtfwIOAp+uqvOBS4FPduO/FdhSVWuALd30kNwM7JgzPeRaircDj1bVecAFjLZrkPtn6rUuq2omD+Ay4Dtzpm8DbpvV+qewPd8C1gM7gZXdvJXAzoUe2zy2YRWjMFwBbAbC6AaRpUfaZ4v5AZwK/JjuutWc+YPcP4xK2b0MLGf0v2g3A783qf0zy1P9QxtyyGDr9CU5B7gQ2AqsqKpXukWvAisWaFjH4yvAZ4Cfd9OnM9xaiquBA8DXu58udyU5mYHun5pyrUsv7s1Tkg8C3wRuqap35i6r0dfwIP5MkuRjwP6qenKhxzIhS4GLgK9W1YWMbg3/hdP6ge2fsWpdHsssg78POHvO9FHr9C1WSU5gFPp7q+rhbvZrSVZ2y1cC+xdqfPN0OXBNkj3A/YxO92+nZy3FRWgvsLdGFaNgVDXqIoa7f8aqdXksswz+E8Ca7qrkiYwuVGya4frH0tUbvBvYUVVfmrNoE6OagzCg2oNVdVtVraqqcxjti+9W1Y0MtJZiVb0KvJzk3G7WodqQg9w/TLvW5YwvWFwN/Ah4Afjzhb6AMs+xf5jRaeIPgGe6x9WMfhdvAXYB/wAsX+ixHse2/TawuXv968D3gd3AN4D3LfT45rEda4Ft3T76O+C0Ie8f4PPA88CzwN8A75vU/vHOPalBXtyTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9q0H8DxmXSvDxTWXAAAAAASUVORK5CYII=\n", 236 | "text/plain": [ 237 | "
" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "processed_prev_state = preprocess_image(prev_state)\n", 246 | "plt.imshow(processed_prev_state,cmap='gray')" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "

Build Model

" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 10, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Network(\n", 266 | " (conv_layer1): Conv2d(4, 64, kernel_size=(8, 8), stride=(4, 4))\n", 267 | " (conv_layer2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))\n", 268 | " (conv_layer3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", 269 | " (fc1): Linear(in_features=6272, out_features=512, bias=True)\n", 270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 271 | " (relu): ReLU()\n", 272 | ")\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "import torch.nn as nn\n", 278 | "import torch\n", 279 | "\n", 280 | "class Network(nn.Module):\n", 281 | " \n", 282 | " def __init__(self,image_input_size,out_size):\n", 283 | " super(Network,self).__init__()\n", 284 | " self.image_input_size = image_input_size\n", 285 | " self.out_size = out_size\n", 286 | "\n", 287 | " self.conv_layer1 = nn.Conv2d(in_channels=4,out_channels=64,kernel_size=8,stride=4) # GRAY - 1\n", 288 | " self.conv_layer2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2)\n", 289 | " self.conv_layer3 = nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,stride=1)\n", 290 | " self.fc1 = nn.Linear(in_features=7*7*128,out_features=512)\n", 291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 292 | " self.relu = nn.ReLU()\n", 293 | "\n", 294 | " def forward(self,x,bsize):\n", 295 | " x = x.view(bsize,4,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 296 | " conv_out = self.conv_layer1(x)\n", 297 | " conv_out = self.relu(conv_out)\n", 298 | " conv_out = self.conv_layer2(conv_out)\n", 299 | " conv_out = self.relu(conv_out)\n", 300 | " conv_out = self.conv_layer3(conv_out)\n", 301 | " conv_out = self.relu(conv_out)\n", 302 | " out = self.fc1(conv_out.view(bsize,7*7*128))\n", 303 
| " out = self.relu(out)\n", 304 | " out = self.fc2(out)\n", 305 | " return out\n", 306 | "\n", 307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n", 308 | "print(main_model)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "

Deep Q Learning with Freeze Network

" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "Populated 60000 Samples in Episodes : 1200\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "mem = Memory(memsize=MEMORY_SIZE)\n", 333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n", 334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n", 335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=4)\n", 336 | "\n", 337 | "target_model.load_state_dict(main_model.state_dict())\n", 338 | "criterion = nn.SmoothL1Loss()\n", 339 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 340 | "\n", 341 | "# filling memory with transitions\n", 342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 343 | " \n", 344 | " prev_state = env.reset()\n", 345 | " frameObj.reset()\n", 346 | " processed_prev_state = preprocess_image(prev_state)\n", 347 | " frameObj.add_frame(processed_prev_state)\n", 348 | " prev_frames = frameObj.get_state()\n", 349 | " step_count = 0\n", 350 | " game_over = False\n", 351 | " \n", 352 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 353 | " \n", 354 | " step_count +=1\n", 355 | " action = np.random.randint(0,4)\n", 356 | " next_state,reward, game_over = env.step(action)\n", 357 | " processed_next_state = preprocess_image(next_state)\n", 358 | " frameObj.add_frame(processed_next_state)\n", 359 | " next_frames = frameObj.get_state()\n", 360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 361 | " \n", 362 | " prev_state = next_state\n", 363 | " processed_prev_state = processed_next_state\n", 364 | " prev_frames = next_frames\n", 365 | "\n", 366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n", 367 | "\n", 
368 | "\n", 369 | "# Algorithm Starts\n", 370 | "total_steps = 0\n", 371 | "epsilon = INITIAL_EPSILON\n", 372 | "loss_stat = []\n", 373 | "total_reward_stat = []\n", 374 | "\n", 375 | "for episode in range(0,TOTAL_EPISODES):\n", 376 | " \n", 377 | " prev_state = env.reset()\n", 378 | " frameObj.reset()\n", 379 | " processed_prev_state = preprocess_image(prev_state)\n", 380 | " frameObj.add_frame(processed_prev_state)\n", 381 | " prev_frames = frameObj.get_state()\n", 382 | " game_over = False\n", 383 | " step_count = 0\n", 384 | " total_reward = 0\n", 385 | " \n", 386 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 387 | " \n", 388 | " step_count +=1\n", 389 | " total_steps +=1\n", 390 | " \n", 391 | " if np.random.rand() <= epsilon:\n", 392 | " action = np.random.randint(0,4)\n", 393 | " else:\n", 394 | " with torch.no_grad():\n", 395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 396 | "\n", 397 | " model_out = main_model.forward(torch_x,bsize=1)\n", 398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 399 | " \n", 400 | " next_state, reward, game_over = env.step(action)\n", 401 | " processed_next_state = preprocess_image(next_state)\n", 402 | " frameObj.add_frame(processed_next_state)\n", 403 | " next_frames = frameObj.get_state()\n", 404 | " total_reward += reward\n", 405 | " \n", 406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 407 | " \n", 408 | " prev_state = next_state\n", 409 | " processed_prev_state = processed_next_state\n", 410 | " prev_frames = next_frames\n", 411 | " \n", 412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 413 | " target_model.load_state_dict(main_model.state_dict())\n", 414 | " \n", 415 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 416 | " current_states = []\n", 417 | " next_states = []\n", 418 | " acts = []\n", 419 | " rewards = []\n", 420 | " game_status = []\n", 421 | " \n", 422 | " for element in batch:\n", 423 | " 
current_states.append(element[0])\n", 424 | " acts.append(element[1])\n", 425 | " rewards.append(element[2])\n", 426 | " next_states.append(element[3])\n", 427 | " game_status.append(element[4])\n", 428 | " \n", 429 | " current_states = np.array(current_states)\n", 430 | " next_states = np.array(next_states)\n", 431 | " rewards = np.array(rewards)\n", 432 | " game_status = [not b for b in game_status]\n", 433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 434 | " torch_acts = torch.tensor(acts)\n", 435 | " \n", 436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 439 | " Q_max_next = Q_max_next.double()\n", 440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 441 | " \n", 442 | " target_values = (rewards + (GAMMA * Q_max_next))\n", 443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 444 | " \n", 445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n", 446 | " \n", 447 | " # save performance measure\n", 448 | " loss_stat.append(loss.item())\n", 449 | " \n", 450 | " # make previous grad zero\n", 451 | " optimizer.zero_grad()\n", 452 | " \n", 453 | " # back - propogate \n", 454 | " loss.backward()\n", 455 | " \n", 456 | " # update params\n", 457 | " optimizer.step()\n", 458 | " \n", 459 | " # save performance measure\n", 460 | " total_reward_stat.append(total_reward)\n", 461 | " \n", 462 | " if epsilon > FINAL_EPSILON:\n", 463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 464 | " \n", 465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 466 | " perf = {}\n", 467 | " perf['loss'] = loss_stat\n", 468 | " perf['total_reward'] = total_reward_stat\n", 469 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n", 470 | " \n", 471 | " 
#print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "

Save Primary Network Weights

" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 12, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "torch.save(main_model.state_dict(),'data/FOUR_OBSERV_NINE_WEIGHTS.torch')" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "

Testing Policy

" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "

Load Primary Network Weights

" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 13, 509 | "metadata": { 510 | "collapsed": true 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "weights = torch.load('data/FOUR_OBSERV_NINE_WEIGHTS.torch')\n", 515 | "main_model.load_state_dict(weights)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "

Testing Policy

" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": { 529 | "collapsed": true 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "# Algorithm Starts\n", 534 | "epsilon = INITIAL_EPSILON\n", 535 | "FINAL_EPSILON = 0.01\n", 536 | "total_reward_stat = []\n", 537 | "\n", 538 | "for episode in range(0,TOTAL_EPISODES):\n", 539 | " \n", 540 | " prev_state = env.reset()\n", 541 | " processed_prev_state = preprocess_image(prev_state)\n", 542 | " frameObj.reset()\n", 543 | " frameObj.add_frame(processed_prev_state)\n", 544 | " prev_frames = frameObj.get_state()\n", 545 | " game_over = False\n", 546 | " step_count = 0\n", 547 | " total_reward = 0\n", 548 | " \n", 549 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 550 | " \n", 551 | " step_count +=1\n", 552 | " \n", 553 | " if np.random.rand() <= epsilon:\n", 554 | " action = np.random.randint(0,4)\n", 555 | " else:\n", 556 | " with torch.no_grad():\n", 557 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 558 | "\n", 559 | " model_out = main_model.forward(torch_x,bsize=1)\n", 560 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 561 | " \n", 562 | " next_state, reward, game_over = env.step(action)\n", 563 | " processed_next_state = preprocess_image(next_state)\n", 564 | " frameObj.add_frame(processed_next_state)\n", 565 | " next_frames = frameObj.get_state()\n", 566 | " \n", 567 | " total_reward += reward\n", 568 | " \n", 569 | " prev_state = next_state\n", 570 | " processed_prev_state = processed_next_state\n", 571 | " prev_frames = next_frames\n", 572 | " \n", 573 | " # save performance measure\n", 574 | " total_reward_stat.append(total_reward)\n", 575 | " \n", 576 | " if epsilon > FINAL_EPSILON:\n", 577 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 578 | " \n", 579 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 580 | " perf = {}\n", 581 | " perf['total_reward'] = total_reward_stat\n", 
582 | " save_obj(name='FOUR_OBSERV_NINE',obj=perf)\n", 583 | " \n", 584 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 585 | ] 586 | } 587 | ], 588 | "metadata": { 589 | "kernelspec": { 590 | "display_name": "Python [conda env:myenv]", 591 | "language": "python", 592 | "name": "conda-env-myenv-py" 593 | }, 594 | "language_info": { 595 | "codemirror_mode": { 596 | "name": "ipython", 597 | "version": 3 598 | }, 599 | "file_extension": ".py", 600 | "mimetype": "text/x-python", 601 | "name": "python", 602 | "nbconvert_exporter": "python", 603 | "pygments_lexer": "ipython3", 604 | "version": "3.6.5" 605 | } 606 | }, 607 | "nbformat": 4, 608 | "nbformat_minor": 2 609 | } 610 | -------------------------------------------------------------------------------- /MDP_Size_9.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from gridworld import gameEnv\n", 10 | "import numpy as np\n", 11 | "%matplotlib inline\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from collections import deque\n", 14 | "import pickle\n", 15 | "from skimage.color import rgb2gray\n", 16 | "import random" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "

Define Environment Object

" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADONJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppIiWkvoki9SKQqCRdVJARJUUX5EwKNZUWk1CGqKlUO5k8TsCE2xARbBpsUSkqltk7eXuy4PXFtzhyf3T07/n0/0urszOye+Y1Hz5nZ2fH7pqqQ1JZfWO4BSJo+gy81yOBLDTL4UoMMvtQggy81yOBLDVpS8JNck+SFJLuTbBrXoCRNVo73Bp4kK4AfABuBvcATwE1VtWN8w5M0CSuX8N5Lgd1V9RJAkvuB64FjBv/MM8+subm5JaxS0jvZs2cPr7/+ehZ63VKCfzbwyrzpvcBvvtMb5ubm2L59+xJWKemdbNiwodfrJn5xL8ktSbYn2X7w4MFJr05SD0sJ/j7gnHnTa7t5P6eq7qyqDVW14ayzzlrC6iSNy1KC/wRwXpJ1SU4GbgQ2j2dYkibpuD/jV9WhJH8EfAtYAXylqp4b28gkTcxSLu5RVd8EvjmmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGLem/5c6CZMG6gtLMWq429R7xpQYZfKlBCwY/yVeSHEjy7Lx5q5I8lmRX9/P0yQ5T0jj1OeL/NXDNEfM2AVur6jxgazctaSAWDH5V/SPwr0fMvh64p3t+D/D7Yx6XpAk63s/4q6tqf/f8VWD1mMYjaQqWfHGvRt9HHPM7CTvpSLPneIP/WpI1AN3PA8d6oZ10pNlzvMHfDHyse/4x4BvjGY6kaejzdd59wD8D5yfZm+Rm4HPAxiS7gKu7aUkDseAtu1V10zEWXTXmsUiaEu/ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUp/TWOUkeT7IjyXNJbu3m201HGqg+R/xDwCer6kLgMuDjSS7EbjrSYPXppLO/qp7qnv8E2Amcjd10pMFa1Gf8JHPAxcA2enbTsaGGNHt6Bz/Je4GvA7dV1Vvzl71TNx0bakizp1fwk5zEKPT3VtXD3eze3XQkzZY+V/UD3A3srKovzFtkNx1poBZsqAFcAfwB8P0kz3Tz/oxR95wHu846LwM3TGaIksatTyedfwJyjMV205EGyDv3pAb1OdXXCeioX8EsQcb5C491fjlDxv3vN20e8aUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfalCfmnunJPlukn/pOul8tpu/Lsm2JLuTPJDk5MkPV9I49Dni/ydwZVVdBKwHrklyGfB54ItV9X7gDeDmyQ1T0jj16aRTVfXv3eRJ3aOAK4GHuvl20pEGpG9d/RVdhd0DwGPAi8CbVXWoe8leRm21jvZeO+lIM6ZX8Kvqp1W1HlgLXApc0HcFdtKZTRnzY7y/bPYNfVMXdVW/qt4EHgcuB05LcrhY51pg35jHJmlC+lzVPyvJad3zdwMbGXXMfRz4SPcyO+lIA9KnvPYa4J4kKxj9oXiwqrYk2QHcn+QvgKcZtdmSNAB9Oul8j1Fr7CPnv8To876kgfHOPalBBl9qkMGXGmTwpQY
ZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYPfldh+OsmWbtpOOtJALeaIfyujIpuH2UlHGqi+DTXWAh8C7uqmg510pMHqe8T/EvAp4Gfd9BnYSUcarD519T8MHKiqJ49nBXbSkWZPn7r6VwDXJbkWOAX4JeAOuk463VHfTjrSgPTplnt7Va2tqjngRuDbVfVR7KQjDdZSvsf/NPCJJLsZfea3k440EH1O9f9XVX0H+E733E460kB5557UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDFnUDjxapxvz7Mubfp2Z5xJcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9vsdPsgf4CfBT4FBVbUiyCngAmAP2ADdU1RuTGaakcVrMEf93qmp9VW3opjcBW6vqPGBrNy1pAJZyqn89o0YaYEMNaVD6Br+Av0/yZJJbunmrq2p/9/xVYPXYRydpIvreq/+BqtqX5JeBx5I8P39hVVWSo96Z3v2huAXg3HPPXdJgJY1HryN+Ve3rfh4AHmFUXfe1JGsAup8HjvFeO+lIM6ZPC61Tk/zi4efA7wLPApsZNdIAG2pIg9LnVH818MioQS4rgb+tqkeTPAE8mORm4GXghskNU9I4LRj8rnHGRUeZ/2PgqkkMStJkeeee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yWpKHkjyfZGeSy5OsSvJYkl3dz9MnPVhJ49H3iH8H8GhVXcCoDNdO7KQjDVafKrvvA34LuBugqv6rqt7ETjrSYPU54q8DDgJfTfJ0kru6Mtt20pEGqk/wVwKXAF+uqouBtznitL6qilGbrf8nyS1JtifZfvDgwaWOV9IY9An+XmBvVW3rph9i9IfATjoLyZgfs6zG+NDELRj8qnoVeCXJ+d2sq4Ad2ElHGqy+TTP/GLg3ycnAS8AfMvqjYScdaYB6Bb+qngE2HGWRnXSkAfLOPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBferqn5/kmXmPt5LcduJ10hlntcgGq0a2UlT0BNGn2OYLVbW+qtYDvwH8B/AIdtKRBmuxp/pXAS9W1cvYSUcarMUG/0bgvu65nXSkgeod/K609nXA145cZicdaVgWc8T/IPBUVb3WTdtJRxqoxQT/Jv7vNB/spCMNVq/gd91xNwIPz5v9OWBjkl3A1d20pAHo20nnbeCMI+b9GDvpSIPknXtSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSg3rduTfLRv8xcFbN8tjUMo/4UoMMvtQggy81yOBLDTL4UoMMvtQggy81qG/prT9N8lySZ5Pcl+SUJOuSbEuyO8kDXRVeSQPQp4XW2cCfABuq6teBFYzq638e+GJVvR94A7h5kgOVND59T/VXAu9OshJ4D7AfuBJ4qFtuJx1pQPr0ztsH/CXwI0aB/zfgSeDNqjrUvWwvcPakBilpvPqc6p/OqE/eOuBXgFOBa/quwE460uzpc6p/NfDDqjpYVf/NqLb+FcBp3ak/wFpg39HebCcdafb0Cf6PgMuSvCdJGNXS3wE8Dnyke42ddKQB6fMZfxuji3hPAd/v3nMn8GngE0l2M2q2cfcExylpjPp20vkM8JkjZr8EXDr2EUmaOO/ckxpk8KUGGXypQQZfalCmWawyyUHgbeD1qa108s7E7ZlVJ9K2QL/t+dWqWvCGmakGHyDJ9qraMNWVTpDbM7tOpG2B8W6Pp/pSgwy+1KDlCP6dy7DOSXJ7ZteJtC0wxu2Z+md8ScvPU32pQVMNfpJrkrzQ1enbNM11L1WSc5I8nmRHV3/w1m7+qiSPJdnV/Tx9uce6GElWJHk6yZZuerC1FJOcluShJM8
n2Znk8iHvn0nWupxa8JOsAP4K+CBwIXBTkguntf4xOAR8sqouBC4DPt6NfxOwtarOA7Z200NyK7Bz3vSQayneATxaVRcAFzHarkHun4nXuqyqqTyAy4FvzZu+Hbh9WuufwPZ8A9gIvACs6eatAV5Y7rEtYhvWMgrDlcAWIIxuEFl5tH02yw/gfcAP6a5bzZs/yP3DqJTdK8AqRv+Ldgvwe+PaP9M81T+8IYcNtk5fkjngYmAbsLqq9neLXgVWL9OwjseXgE8BP+umz2C4tRTXAQeBr3YfXe5KcioD3T814VqXXtxbpCTvBb4O3FZVb81fVqM/w4P4miTJh4EDVfXkco9lTFYClwBfrqqLGd0a/nOn9QPbP0uqdbmQaQZ/H3DOvOlj1umbVUlOYhT6e6vq4W72a0nWdMvXAAeWa3yLdAVwXZI9wP2MTvfvoGctxRm0F9hbo4pRMKoadQnD3T9LqnW5kGkG/wngvO6q5MmMLlRsnuL6l6SrN3g3sLOqvjBv0WZGNQdhQLUHq+r2qlpbVXOM9sW3q+qjDLSWYlW9CryS5Pxu1uHakIPcP0y61uWUL1hcC/wAeBH48+W+gLLIsX+A0Wni94Bnuse1jD4XbwV2Af8ArFrusR7Htv02sKV7/mvAd4HdwNeAdy33+BaxHeuB7d0++jvg9CHvH+CzwPPAs8DfAO8a1/7xzj2pQV7ckxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfatD/ADPvEPBps6hQAAAAAElFTkSuQmCC\n", 34 | "text/plain": [ 35 | "
" 36 | ] 37 | }, 38 | "metadata": { 39 | "needs_background": "light" 40 | }, 41 | "output_type": "display_data" 42 | } 43 | ], 44 | "source": [ 45 | "env = gameEnv(partial=False,size=9)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "" 57 | ] 58 | }, 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "output_type": "execute_result" 62 | }, 63 | { 64 | "data": { 65 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADPlJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUMtCiKqWsAtIRH0IkKgqIoqJG7SFppICbQXUaReJFKVhIsqEoKkqKL8CYEGWREpdYiiSJWD+dMEbIgNMcEWYJNCSanU1snbix23B9fGc3xmz9nx7/uRVrszu3vmN2f0nJmdnfO+qSokteWXlnsAkpaewZcaZPClBhl8qUEGX2qQwZcaZPClBi0q+EmuSPJckp1Jbh5qUJKmK0d7AU+SFcCPgI3AbuAx4Lqq2jbc8CRNw8pFvPdCYGdVvQCQ5B7gauCwwT/11FNrbm5uEYuU9E527drFa6+9liO9bjHBPx14ad70buC33+kNc3NzbN26dRGLlPRONmzY0Ot1Uz+5l+SGJFuTbN23b9+0Fyeph8UEfw9wxrzptd28t6mq26pqQ1VtOO200xaxOElDWUzwHwPOSrIuyfHAtcBDwwxL0jQd9Wf8qtqf5E+AbwErgK9U1TODjUzS1Czm5B5V9U3gmwONRdIS8co9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2rQov4tdxYkR6wr2M80uoUPNDQdu5arTb17fKlBBl9q0BGDn+QrSfYmeXrevFVJHkmyo7s/ebrDlDSkPnv8vwWuOGjezcDmqjoL2NxNSxqJIwa/qr4L/OtBs68G7uwe3wn84cDjkjRFR/sZf3VVvdw9fgVYPdB4JC2BRZ/cq8n3EYf9TsJOOtLsOdrgv5pkDUB3v/dwL7STjjR7jjb4DwEf6x5/DPjGMMORtBT6fJ13N/DPwNlJdie5HvgcsDHJDuDyblrSSBzxkt2quu4wT1028FgkLRGv3JMaZPClBhl8qUEGX2qQwZcaZPClBo2+As9grJajhrjHlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Sm9dUaSR5NsS/JMkhu7+XbTkUaqzx5/P/DJqjoXuAj4eJJzsZuONFp9Oum8XFVPdI9/BmwHTsduOtJoLegzfpI54HxgCz276dhQQ5o9vYOf5L3A14GbqurN+c+9UzcdG2pIs6dX8JMcxyT0d1XVA93s3t10JM2WPmf1A9wBbK+qL8x7ym460kj1qcBzCfBHwA+TPNXN+wsm3XPu6zrrvAhcM50hShpan0463+PwhanspiONkFfuSQ2y2OaoHPKLk6PUUHXRIX9tB4z81+ceX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZc
aZPClBvWpuXdCku8n+Zeuk85nu/nrkmxJsjPJvUmOn/5wJQ2hzx7/P4FLq+o8YD1wRZKLgM8DX6yq9wOvA9dPb5iShtSnk05V1b93k8d1twIuBe7v5ttJRxqRvnX1V3QVdvcCjwDPA29U1f7uJbuZtNU61HvtpCPNmF7Br6qfV9V6YC1wIXBO3wXYSWdIGfDWkCF/bcfIr29BZ/Wr6g3gUeBi4KQkB4p1rgX2DDw2SVPS56z+aUlO6h6/G9jIpGPuo8BHupfZSUcakT7ltdcAdyZZweQPxX1VtSnJNuCeJH8FPMmkzZakEejTSecHTFpjHzz/BSaf9yWNjFfuSQ0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3qHfyuxPaTSTZ103bSkUZqIXv8G5kU2TzATjrSSPVtqLEW+BBwezcd7KQjjVbfPf6XgE8Bv+imT8FOOtJo9amr/2Fgb1U9fjQLsJOONHv61NW/BLgqyZXACcCvALfSddLp9vp20pFGpE+33Fuqam1VzQHXAt+uqo9iJx1ptBbzPf6ngU8k2cnkM7+ddKSR6HOo/7+q6jvAd7rHdtKRRsor96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUqxBHkl3Az4CfA/urakOSVcC9wBywC7imql6fzjAlDWkhe/zfq6r1VbWhm74Z2FxVZwGbu2lJI7CYQ/2rmTTSABtqSKPSN/gF/GOSx5Pc0M1bXVUvd49fAVYPPjpJU9G32OYHqmpPkl8FHkny7Pwnq6qS1KHe2P2huAHgzDPPXNRgJQ2j1x6/qvZ093uBB5lU1301yRqA7n7vYd5rJx1pxvRpoXVikl8+8Bj4feBp4CEmjTTAhhrSqPQ51F8NPDhpkMtK4O+r6uEkjwH3JbkeeBG4ZnrDlDSkIwa/a5xx3iHm/xS4bBqDkjRdXrknNcjgSw1aUO+8mXTILxGP4sdkmJ8z3xR+pDQI9/hSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBewU9yUpL7kzybZHuSi5OsSvJIkh3d/cnTHqykYfTd498KPFxV5zApw7UdO+lIo9Wnyu77gN8B7gCoqv+qqjewk440Wn32+OuAfcBXkzyZ5PauzLaddKSR6hP8lcAFwJer6nzgLQ46rK+q4jBFsJLckGRrkq379u1b7HglDaBP8HcDu6tqSzd9P5M/BLPRSSfD3Ab6MW+76dhVA92WyxGDX1WvAC8lObubdRmwDTvpSKPVt8runwJ3JTkeeAH4YyZ/NOykI41Qr+BX1VPAhkM8ZScdaYS8ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUJ+6+mcneWre7c0kN9lJp4ehKjIud2VG/T9jL8jap9jmc1W1vqrWA78F/AfwIHbSkUZroYf6lwHPV9WL2ElHGq2FBv9a4O7usZ10pJHqHfyutPZVwNcOfs5OOtK4LGSP/0Hgiap6tZuejU46khZsIcG/jv87zAc76Uij1Sv4XXfcjcAD82Z/DtiYZAdweTctaQT6dtJ5CzjloHk/xU460ih55Z7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoF5X7s2yyT8GNqKhVdV0uceXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBfUtv/XmSZ5I8neTuJCckWZdkS5KdSe7tqvBKGoE+LbROB/4M2FBVvwmsYFJf//PAF6vq/cDrwPXTHKik4fQ91F8JvDvJSuA9wMvApcD93fN20pFGpE/vvD3AXwM/YRL4fwMeB96oqv3dy3YDp09rkJKG1edQ/2QmffLWAb8GnAhc0XcBdtK
RZk+fQ/3LgR9X1b6q+m8mtfUvAU7qDv0B1gJ7DvVmO+lIs6dP8H8CXJTkPUnCpJb+NuBR4CPda+ykI41In8/4W5icxHsC+GH3ntuATwOfSLKTSbONO6Y4TkkD6ttJ5zPAZw6a/QJw4eAjkjR1XrknNcjgSw0y+FKDDL7UoCxlscok+4C3gNeWbKHTdyquz6w6ltYF+q3Pr1fVES+YWdLgAyTZWlUblnShU+T6zK5jaV1g2PXxUF9qkMGXGrQcwb9tGZY5Ta7P7DqW1gUGXJ8l/4wvafl5qC81aEmDn+SKJM91dfpuXsplL1aSM5I8mmRbV3/wxm7+qiSPJNnR3Z+83GNdiCQrkjyZZFM3PdpaiklOSnJ/kmeTbE9y8Zi3zzRrXS5Z8JOsAP4G+CBwLnBdknOXavkD2A98sqrOBS4CPt6N/2Zgc1WdBWzupsfkRmD7vOkx11K8FXi4qs4BzmOyXqPcPlOvdVlVS3IDLga+NW/6FuCWpVr+FNbnG8BG4DlgTTdvDfDcco9tAeuwlkkYLgU2AWFygcjKQ22zWb4B7wN+THfeat78UW4fJqXsXgJWMfkv2k3AHwy1fZbyUP/Aihww2jp9SeaA84EtwOqqerl76hVg9TIN62h8CfgU8Itu+hTGW0txHbAP+Gr30eX2JCcy0u1TU6516cm9BUryXuDrwE1V9eb852ryZ3gUX5Mk+TCwt6oeX+6xDGQlcAHw5ao6n8ml4W87rB/Z9llUrcsjWcrg7wHOmDd92Dp9syrJcUxCf1dVPdDNfjXJmu75NcDe5RrfAl0CXJVkF3APk8P9W+lZS3EG7QZ216RiFEyqRl3AeLfPompdHslSBv8x4KzurOTxTE5UPLSEy1+Urt7gHcD2qvrCvKceYlJzEEZUe7CqbqmqtVU1x2RbfLuqPspIaylW1SvAS0nO7mYdqA05yu3DtGtdLvEJiyuBHwHPA3+53CdQFjj2DzA5TPwB8FR3u5LJ5+LNwA7gn4BVyz3Wo1i33wU2dY9/A/g+sBP4GvCu5R7fAtZjPbC120b/AJw85u0DfBZ4Fnga+DvgXUNtH6/ckxrkyT2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG/Q++8g/3YgFKrwAAAABJRU5ErkJggg==\n", 66 | "text/plain": [ 67 | "
" 68 | ] 69 | }, 70 | "metadata": { 71 | "needs_background": "light" 72 | }, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 7, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "BATCH_SIZE = 64\n", 102 | "FREEZE_INTERVAL = 20000 # steps\n", 103 | "MEMORY_SIZE = 60000 \n", 104 | "OUTPUT_SIZE = 4\n", 105 | "TOTAL_EPISODES = 10000\n", 106 | "MAX_STEPS = 50\n", 107 | "INITIAL_EPSILON = 1.0\n", 108 | "FINAL_EPSILON = 0.01\n", 109 | "GAMMA = 0.99\n", 110 | "INPUT_IMAGE_DIM = 84\n", 111 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "

Save Dictionay Function

" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": { 125 | "collapsed": true 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "def save_obj(obj, name ):\n", 130 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 131 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "

Experience Replay

" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "metadata": { 145 | "collapsed": true 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "class Memory():\n", 150 | " \n", 151 | " def __init__(self,memsize):\n", 152 | " self.memsize = memsize\n", 153 | " self.memory = deque(maxlen=self.memsize)\n", 154 | " \n", 155 | " def add_sample(self,sample):\n", 156 | " self.memory.append(sample)\n", 157 | " \n", 158 | " def get_batch(self,size):\n", 159 | " return random.sample(self.memory,k=size)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "

Preprocess Images

" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 5, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "def preprocess_image(image):\n", 176 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 177 | " return np.copy(image)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 7, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "" 189 | ] 190 | }, 191 | "execution_count": 7, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | }, 195 | { 196 | "data": { 197 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADVNJREFUeJzt3X+oX/V9x/Hna4nW1m71VxYyo4ulosjA6C6ZYhmbmi1xRfdHEUVGGYL/dJtdC61usFLYHy2Mtv4xCkHbueH8UVtXCY2dSy1jMKLxx1o1WqONNUGTWHXtHGxL+94f3yO7zRLvubnne+89fp4PuHy/53x/nM/h8Pqe8z3fc9/vVBWS2vILSz0ASYvP4EsNMvhSgwy+1CCDLzXI4EsNMvhSgxYU/CSbkjybZHeSm4YalKTpyrFewJNkBfB9YCOwF3gEuLaqnh5ueJKmYeUCXrsB2F1VLwAkuQu4Cjhq8E877bRat27dAhYp6e3s2bOHV199NXM9byHBPx14adb0XuA33u4F69atY+fOnQtYpKS3MzMz0+t5Uz+5l+SGJDuT7Dx48OC0Fyeph4UEfx9wxqzptd28n1NVW6pqpqpmVq1atYDFSRrKQoL/CHB2krOSHA9cA9w/zLAkTdMxf8evqkNJ/gj4FrAC+HJVPTXYyCRNzUJO7lFV3wS+OdBYJC0Sr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYt6N9yl4NkzrqC0rK1VG3q3eNLDTL4UoPmDH6SLyc5kOTJWfNOSfJgkue625OnO0xJQ+qzx/8bYNNh824CtlfV2cD2blrSSMwZ/Kr6Z+C1w2ZfBdze3b8d+P2BxyVpio71O/7qqnq5u/8KsHqg8UhaBAs+uVeT3yOO+puEnXSk5edYg78/yRqA7vbA0Z5oJx1p+TnW4N8PfKS7/xHgG8MMR9Ji6PNz3p3AvwLnJNmb5Hrgs8DGJM8Bl3fTkkZizkt2q+raozx02cBjkbRIvHJPapDBlxpk8KUGGXypQQZfapDBlxo0+go8LdmwYcNg7/Xwww8P9l4aH/f4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPqW3zkjyUJKnkzyV5MZuvt10pJHqs8c/BHyiqs4DLgI+muQ87KYjjVafTjovV9Vj3f2fALuA07GbjjRa8/qOn2QdcAGwg57ddGyoIS0/vYOf5L3A14CPVdWPZz/2dt10bKghLT+9gp/kOCahv6Oqvt7N7t1NR9Ly0uesfoDbgF1V9flZD9lNRxqpPhV4LgH+AP
hekie6eX/GpHvOPV1nnReBq6czRElD69NJ51+AHOVhu+lII+SVe1KDRl9sc9u2bYO8z+bNmwd5n2myQKaG4h5fapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9am5d0KSh5P8W9dJ5zPd/LOS7EiyO8ndSY6f/nAlDaHPHv+/gEur6nxgPbApyUXA54AvVNUHgNeB66c3TElD6tNJp6rqP7rJ47q/Ai4F7u3m20lHGpG+dfVXdBV2DwAPAs8Db1TVoe4pe5m01TrSa+2kIy0zvWruVdVPgfVJTgLuA87tu4Cq2gJsAZiZmTlit52FGEOtvKFs2LBhsPeyfl/b5nVWv6reAB4CLgZOSvLWB8daYN/AY5M0JX3O6q/q9vQkeTewkUnH3IeAD3dPs5OONCJ9DvXXALcnWcHkg+Keqtqa5GngriR/CTzOpM2WpBHo00nnu0xaYx8+/wVguC+dkhaNV+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6ld7S8mC5LA3FPb7UIIMvNah38LsS248n2dpN20lHGqn57PFvZFJk8y120pFGqm9DjbXA7wG3dtPBTjrSaPXd438R+CTws276VOykI41Wn7r6HwIOVNWjx7KAqtpSVTNVNbNq1apjeQtJA+vzO/4lwJVJrgBOAH4JuIWuk06317eTjjQifbrl3lxVa6tqHXAN8O2qug476UijtZDf8T8FfDzJbibf+e2kI43EvC7ZrarvAN/p7ttJRxopr9yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkDX3NIht27YN9l6bN28e7L10ZO7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/f8ZPsAX4C/BQ4VFUzSU4B7gbWAXuAq6vq9ekMU9KQ5rPH/+2qWl9VM930TcD2qjob2N5NSxqBhRzqX8WkkQbYUEMalb7BL+Afkzya5IZu3uqqerm7/wqwevDRSZqKvtfqf7Cq9iX5ZeDBJM/MfrCqKkkd6YXdB8UNAGeeeeaCBitpGL32+FW1r7s9ANzHpLru/iRrALrbA0d5rZ10pGWmTwutE5P84lv3gd8BngTuZ9JIA2yoIY1Kn0P91cB9kwa5rAT+vqoeSPIIcE+S64EXgaunN0xJQ5oz+F3jjPOPMP9HwGXTGJSk6fLKPalBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBvYKf5KQk9yZ5JsmuJBcnOSXJg0me625PnvZgJQ2j7x7/FuCBqjqXSRmuXdhJRxqtPlV23wf8JnAbQFX9d1W9gZ10pNHqs8c/CzgIfCXJ40lu7cps20lHGqk+wV8JXAh8qaouAN7ksMP6qiombbb+nyQ3JNmZZOfBgwcXOl5JA+hTV38vsLeqdnTT9zIJ/v4ka6rq5bk66QBbAGZmZo744aDxu+6665Z6CJqHOff4VfUK8FKSc7pZlwFPYycdabT6Ns38Y+COJMcDLwB/yORDw0460gj1Cn5VPQHMHOEhO+lII+SVe1KDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKD5qzA09Xau3vWrPcDfwH8bTd/HbAHuLqqXh9+iG9v27Ztg7zP5s2bB3mfVr322mtLPQTNQ59im89W1fqqWg/8OvCfwH3YSUcarfke6l8GPF9VL2InHWm05hv8a4A7u/t20pFGqnfwu9LaVwJfPfwxO+lI4zKfPf5m4LGq2t9N7+866DBXJ52qmqmqmVWrVi1stJIGMZ/gX8v/HeaDnXSk0eoV/K477kbg67NmfxbYmOQ54PJuWtII9O2k8yZw6mHzfoSddKRR8so9qUEGX2qQwZcaZPClBh
l8qUEGX2qQwZcaZPClBhl8qUG9rtxbzjZt2jTI+0z+wVBqg3t8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZca1Lf01p8meSrJk0nuTHJCkrOS7EiyO8ndXRVeSSMwZ/CTnA78CTBTVb8GrGBSX/9zwBeq6gPA68D10xyopOH0PdRfCbw7yUrgPcDLwKXAvd3jdtKRRqRP77x9wF8BP2QS+H8HHgXeqKpD3dP2AqdPa5CShtXnUP9kJn3yzgJ+BTgR6H2BvJ10pOWnz6H+5cAPqupgVf0Pk9r6lwAndYf+AGuBfUd6sZ10pOWnT/B/CFyU5D1JwqSW/tPAQ8CHu+fYSUcakT7f8XcwOYn3GPC97jVbgE8BH0+ym0mzjdumOE5JA+rbSefTwKcPm/0CsGHwEUmaOq/ckxpk8KUGGXypQQZfalAWs8hkkoPAm8Cri7bQ6TsN12e5eietC/Rbn1+tqjkvmFnU4AMk2VlVM4u60ClyfZavd9K6wLDr46G+1CCDLzVoKYK/ZQmWOU2uz/L1TloXGHB9Fv07vqSl56G+1KBFDX6STUme7er03bSYy16oJGckeSjJ0139wRu7+ackeTDJc93tyUs91vlIsiLJ40m2dtOjraWY5KQk9yZ5JsmuJBePeftMs9blogU/yQrgr4HNwHnAtUnOW6zlD+AQ8ImqOg+4CPhoN/6bgO1VdTawvZsekxuBXbOmx1xL8Rbggao6FzifyXqNcvtMvdZlVS3KH3Ax8K1Z0zcDNy/W8qewPt8ANgLPAmu6eWuAZ5d6bPNYh7VMwnApsBUIkwtEVh5pmy3nP+B9wA/ozlvNmj/K7cOklN1LwClM/ot2K/C7Q22fxTzUf2tF3jLaOn1J1gEXADuA1VX1cvfQK8DqJRrWsfgi8EngZ930qYy3luJZwEHgK91Xl1uTnMhIt09NudalJ/fmKcl7ga8BH6uqH89+rCYfw6P4mSTJh4ADVfXoUo9lICuBC4EvVdUFTC4N/7nD+pFtnwXVupzLYgZ/H3DGrOmj1ulbrpIcxyT0d1TV17vZ+5Os6R5fAxxYqvHN0yXAlUn2AHcxOdy/hZ61FJehvcDemlSMgknVqAsZ7/ZZUK3LuSxm8B8Bzu7OSh7P5ETF/Yu4/AXp6g3eBuyqqs/Peuh+JjUHYUS1B6vq5qpaW1XrmGyLb1fVdYy0lmJVvQK8lOScbtZbtSFHuX2Ydq3LRT5hcQXwfeB54M+X+gTKPMf+QSaHid8Fnuj+rmDyvXg78BzwT8ApSz3WY1i33wK2dvffDzwM7Aa+Crxrqcc3j/VYD+zsttE/ACePefsAnwGeAZ4E/g5411Dbxyv3pAZ5ck9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlB/wsPrvc22jHaHAAAAABJRU5ErkJggg==\n", 198 | "text/plain": [ 199 | "
" 200 | ] 201 | }, 202 | "metadata": {}, 203 | "output_type": "display_data" 204 | } 205 | ], 206 | "source": [ 207 | "processed_prev_state = preprocess_image(prev_state)\n", 208 | "plt.imshow(processed_prev_state,cmap='gray')" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "

Build Model

" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 8, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "Network(\n", 228 | " (conv_layer1): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))\n", 229 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n", 230 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 231 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n", 232 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 233 | " (relu): ReLU()\n", 234 | ")\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "import torch.nn as nn\n", 240 | "import torch\n", 241 | "\n", 242 | "class Network(nn.Module):\n", 243 | " \n", 244 | " def __init__(self,image_input_size,out_size):\n", 245 | " super(Network,self).__init__()\n", 246 | " self.image_input_size = image_input_size\n", 247 | " self.out_size = out_size\n", 248 | "\n", 249 | " self.conv_layer1 = nn.Conv2d(in_channels=1,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n", 250 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n", 251 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n", 252 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n", 253 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 254 | " self.relu = nn.ReLU()\n", 255 | "\n", 256 | " def forward(self,x,bsize):\n", 257 | " x = x.view(bsize,1,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 258 | " conv_out = self.conv_layer1(x)\n", 259 | " conv_out = self.relu(conv_out)\n", 260 | " conv_out = self.conv_layer2(conv_out)\n", 261 | " conv_out = self.relu(conv_out)\n", 262 | " conv_out = self.conv_layer3(conv_out)\n", 263 | " conv_out = self.relu(conv_out)\n", 264 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n", 265 | " out 
= self.relu(out)\n", 266 | " out = self.fc2(out)\n", 267 | " return out\n", 268 | "\n", 269 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE)\n", 270 | "print(main_model)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "

Deep Q Learning with Target Freeze

" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Populated 200 Samples\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "mem = Memory(memsize=MEMORY_SIZE)\n", 295 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n", 296 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda()\n", 297 | "\n", 298 | "target_model.load_state_dict(main_model.state_dict())\n", 299 | "criterion = nn.SmoothL1Loss()\n", 300 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 301 | "\n", 302 | "# filling memory with transitions\n", 303 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 304 | " \n", 305 | " prev_state = env.reset()\n", 306 | " processed_prev_state = preprocess_image(prev_state)\n", 307 | " step_count = 0\n", 308 | " game_over = False\n", 309 | " \n", 310 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 311 | " \n", 312 | " step_count +=1\n", 313 | " action = np.random.randint(0,4)\n", 314 | " next_state,reward, game_over = env.step(action)\n", 315 | " processed_next_state = preprocess_image(next_state)\n", 316 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n", 317 | " \n", 318 | " prev_state = next_state\n", 319 | " processed_prev_state = processed_next_state\n", 320 | "\n", 321 | "print('Populated %d Samples'%(len(mem.memory)))\n", 322 | "\n", 323 | "# Algorithm Starts\n", 324 | "total_steps = 0\n", 325 | "epsilon = INITIAL_EPSILON\n", 326 | "loss_stat = []\n", 327 | "total_reward_stat = []\n", 328 | "\n", 329 | "for episode in range(0,TOTAL_EPISODES):\n", 330 | " \n", 331 | " prev_state = env.reset()\n", 332 | " processed_prev_state = preprocess_image(prev_state)\n", 333 | " game_over = False\n", 334 | " step_count = 0\n", 335 | " total_reward = 0\n", 336 | " 
\n", 337 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 338 | " \n", 339 | " step_count +=1\n", 340 | " total_steps +=1\n", 341 | " \n", 342 | " if np.random.rand() <= epsilon:\n", 343 | " action = np.random.randint(0,4)\n", 344 | " else:\n", 345 | " with torch.no_grad():\n", 346 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n", 347 | "\n", 348 | " model_out = main_model.forward(torch_x,bsize=1)\n", 349 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 350 | " \n", 351 | " next_state, reward, game_over = env.step(action)\n", 352 | " processed_next_state = preprocess_image(next_state)\n", 353 | " total_reward += reward\n", 354 | " \n", 355 | " mem.add_sample((processed_prev_state,action,reward,processed_next_state,game_over))\n", 356 | " \n", 357 | " prev_state = next_state\n", 358 | " processed_prev_state = processed_next_state\n", 359 | " \n", 360 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 361 | " target_model.load_state_dict(main_model.state_dict())\n", 362 | " \n", 363 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 364 | " current_states = []\n", 365 | " next_states = []\n", 366 | " acts = []\n", 367 | " rewards = []\n", 368 | " game_status = []\n", 369 | " \n", 370 | " for element in batch:\n", 371 | " current_states.append(element[0])\n", 372 | " acts.append(element[1])\n", 373 | " rewards.append(element[2])\n", 374 | " next_states.append(element[3])\n", 375 | " game_status.append(element[4])\n", 376 | " \n", 377 | " current_states = np.array(current_states)\n", 378 | " next_states = np.array(next_states)\n", 379 | " rewards = np.array(rewards)\n", 380 | " game_status = [not b for b in game_status]\n", 381 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 382 | " torch_acts = torch.tensor(acts)\n", 383 | " \n", 384 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 385 | " Q_s = 
main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 386 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 387 | " Q_max_next = Q_max_next.double()\n", 388 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 389 | " \n", 390 | " target_values = (rewards + (GAMMA * Q_max_next)).cuda()\n", 391 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 392 | " \n", 393 | " loss = criterion(Q_s_a,target_values.float())\n", 394 | " \n", 395 | " # save performance measure\n", 396 | " loss_stat.append(loss.item())\n", 397 | " \n", 398 | " # make previous grad zero\n", 399 | " optimizer.zero_grad()\n", 400 | " \n", 401 | " # back - propogate \n", 402 | " loss.backward()\n", 403 | " \n", 404 | " # update params\n", 405 | " optimizer.step()\n", 406 | " \n", 407 | " # save performance measure\n", 408 | " total_reward_stat.append(total_reward)\n", 409 | " \n", 410 | " if epsilon > FINAL_EPSILON:\n", 411 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 412 | " \n", 413 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 414 | " perf = {}\n", 415 | " perf['loss'] = loss_stat\n", 416 | " perf['total_reward'] = total_reward_stat\n", 417 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n", 418 | " \n", 419 | " #print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "

Save Primary Network Weights

" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 18, 432 | "metadata": { 433 | "collapsed": true 434 | }, 435 | "outputs": [], 436 | "source": [ 437 | "torch.save(main_model.state_dict(),'data/MDP_ENV_SIZE_NINE_WEIGHTS.torch')" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "

Testing Policy

" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "

Load Primary Network Weights

" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 9, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "weights = torch.load('data/MDP_ENV_SIZE_NINE_WEIGHTS.torch', map_location='cpu')\n", 461 | "main_model.load_state_dict(weights)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "

Test Policy

" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "# Algorithm Starts\n", 478 | "epsilon = INITIAL_EPSILON\n", 479 | "FINAL_EPSILON = 0.01\n", 480 | "total_reward_stat = []\n", 481 | "\n", 482 | "for episode in range(0,TOTAL_EPISODES):\n", 483 | " \n", 484 | " prev_state = env.reset()\n", 485 | " processed_prev_state = preprocess_image(prev_state)\n", 486 | " game_over = False\n", 487 | " step_count = 0\n", 488 | " total_reward = 0\n", 489 | " \n", 490 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 491 | " \n", 492 | " step_count +=1\n", 493 | " \n", 494 | " if np.random.rand() <= epsilon:\n", 495 | " action = np.random.randint(0,4)\n", 496 | " else:\n", 497 | " with torch.no_grad():\n", 498 | " torch_x = torch.from_numpy(processed_prev_state).float().cuda()\n", 499 | "\n", 500 | " model_out = main_model.forward(torch_x,bsize=1)\n", 501 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 502 | " \n", 503 | " next_state, reward, game_over = env.step(action)\n", 504 | " processed_next_state = preprocess_image(next_state)\n", 505 | " total_reward += reward\n", 506 | " \n", 507 | " prev_state = next_state\n", 508 | " processed_prev_state = processed_next_state\n", 509 | " \n", 510 | " # save performance measure\n", 511 | " total_reward_stat.append(total_reward)\n", 512 | " \n", 513 | " if epsilon > FINAL_EPSILON:\n", 514 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 515 | " \n", 516 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 517 | " perf = {}\n", 518 | " perf['total_reward'] = total_reward_stat\n", 519 | " save_obj(name='MDP_ENV_SIZE_NINE',obj=perf)\n", 520 | " \n", 521 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": {}, 527 | "source": [ 528 | "

Create Policy GIF

" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "

Collect Frames Of an Episode Using Trained Network

" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 28, 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "Total Reward : 14\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "frames = []\n", 553 | "random.seed(110)\n", 554 | "np.random.seed(110)\n", 555 | "\n", 556 | "for episode in range(0,1):\n", 557 | " \n", 558 | " prev_state = env.reset()\n", 559 | " processed_prev_state = preprocess_image(prev_state)\n", 560 | " frames.append(prev_state)\n", 561 | " game_over = False\n", 562 | " step_count = 0\n", 563 | " total_reward = 0\n", 564 | " \n", 565 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 566 | " \n", 567 | " step_count +=1\n", 568 | " \n", 569 | " with torch.no_grad():\n", 570 | " torch_x = torch.from_numpy(processed_prev_state).float()\n", 571 | " model_out = main_model.forward(torch_x,bsize=1)\n", 572 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 573 | " \n", 574 | " next_state, reward, game_over = env.step(action)\n", 575 | " frames.append(next_state)\n", 576 | " processed_next_state = preprocess_image(next_state)\n", 577 | " total_reward += reward\n", 578 | " \n", 579 | " prev_state = next_state\n", 580 | " processed_prev_state = processed_next_state\n", 581 | "\n", 582 | "print('Total Reward : %d'%(total_reward)) # This should output same value which verifies seed is working correctly\n", 583 | " " 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 29, 589 | "metadata": {}, 590 | "outputs": [], 591 | "source": [ 592 | "from PIL import Image, ImageDraw\n", 593 | "\n", 594 | "for idx, img in enumerate(frames):\n", 595 | " image = Image.fromarray(img)\n", 596 | " drawer = ImageDraw.Draw(image)\n", 597 | " drawer.rectangle([(7,7),(76,76)], outline=(255, 255, 0))\n", 598 | " #plt.imshow(np.array(image))\n", 599 | " frames[idx] = np.array(image)\n", 600 | " " 601 | ] 602 | }, 603 | { 604 
| "cell_type": "markdown", 605 | "metadata": {}, 606 | "source": [ 607 | "

Frames to GIF

" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 30, 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "name": "stderr", 617 | "output_type": "stream", 618 | "text": [ 619 | "/home/mayank/miniconda3/envs/rdqn/lib/python3.5/site-packages/ipykernel_launcher.py:5: DeprecationWarning: `imresize` is deprecated!\n", 620 | "`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n", 621 | "Use ``skimage.transform.resize`` instead.\n", 622 | " \"\"\"\n" 623 | ] 624 | } 625 | ], 626 | "source": [ 627 | "import imageio\n", 628 | "from scipy.misc import imresize\n", 629 | "resized_frames = []\n", 630 | "for frame in frames:\n", 631 | " resized_frames.append(imresize(frame,(256,256)))\n", 632 | "imageio.mimsave('data/GIFs/MDP_SIZE_9.gif',resized_frames,fps=4)" 633 | ] 634 | } 635 | ], 636 | "metadata": { 637 | "kernelspec": { 638 | "display_name": "Python 3", 639 | "language": "python", 640 | "name": "python3" 641 | }, 642 | "language_info": { 643 | "codemirror_mode": { 644 | "name": "ipython", 645 | "version": 3 646 | }, 647 | "file_extension": ".py", 648 | "mimetype": "text/x-python", 649 | "name": "python", 650 | "nbconvert_exporter": "python", 651 | "pygments_lexer": "ipython3", 652 | "version": "3.5.6" 653 | } 654 | }, 655 | "nbformat": 4, 656 | "nbformat_minor": 2 657 | } 658 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recurrent-Deep-Q-Learning 2 | 3 | # Introduction 4 | Partially Observable Markov Decision Process (POMDP) is a generalization of Markov Decision Process where agent cannot directly observe 5 | the underlying state and only an observation is available. Earlier methods suggests to maintain a belief (a pmf) over all the possible states which encodes the probability of being in each state. This quickly limits the size of the problem to which we can use this method. 
However, the paper [Playing Atari with Deep Reinforcement Learning](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) presented an approach which uses last 4 observations as input to the learning 6 | algorithm, which can be seen as 4th order markov decision process. Many papers suggest that much performance can be obtained if we use more than last 4 frames but this is expensive from computational and storage point of view (experience replay). Recurrent networks can be used to summarize what agent has seen in past observations. **In this project, I investgated this using a simple Partially Observable Environment and found that using a single recurrent layer able to achieve much better performance than using some last k-frames.** 7 | 8 | # Environment 9 | 10 | Environment used was a 9x9 grid where *Red* colored blocks represents agent location in the grid. *Green* colored blocks are the goal points for agent. *Blue* colored blocks are the blocks which agent needs to avoid. A reward of +1 is given to agent when it eats *green* block. A reward of -1 is given to the agent when it eats *blue* block. Other movements results in zero. 11 | Observation is the RGB image of neigbouring cells of agent. Below figure describes the observation. 12 | 13 | Underlying MDP | Observation to Agent 14 | :-------------------------:|:-------------------------: 15 | ![](https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/master/data/download%20(1).png) | ![](https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/master/data/download.png) 16 | 17 | # Algorithm 18 | 19 |

20 | 21 | # How to Run ? 22 | 23 | I ran the experiment for the following cases. The corresponding code/jupyter files are linked to each experiment. 24 | * [MDP Case](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/MDP_Size_9.ipynb) - The underlying state was fully visible. The whole grid was given as the input to the agent. 25 | * [Single Observation](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/Single%20Observation.ipynb) - In this case, the most recent observation was used as the input to agent. 26 | * [Last Two Observations](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/Two%20Observations.ipynb) - In this case, the last two most recent observation was used as the input to agent to encode the temporal information among observations. 27 | * [LSTM Case](https://github.com/mynkpl1998/Recurrent-Deep-Q-Learning/blob/master/LSTM%2C%20BPTT%3D8.ipynb) - In this case, an LSTM layer is used to pass the temporal information among observations. 28 | 29 | # Learned Policies 30 | 31 | Fully Observable | Single Observation | LSTM 32 | :-------------------------:|:-------------------------:|:-------------------------: 33 | ![](https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/master/data/GIFs/MDP_SIZE_9.gif) | ![](https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/master/data/GIFs/SINGLE_SIZE_9_frames.gif) | ![](https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/master/data/GIFs/LSTM_SIZE_9_frames.gif) 34 | 35 | # Results 36 | 37 | The figure given below compares the performance of different cases. MDP case is the best we can do as the underlying state is fully visible to the agent. However, the challenge is to perform better given an observation. The graph clearly shows the LSTM consistently performed better as the total reward per episode was much higher than using some last k-frames. 38 | 39 |

40 | 41 | # References 42 | * [Deep Recurrent Q Learning for POMDPs](https://arxiv.org/pdf/1507.06527.pdf) 43 | * [Recurrent Neural Networks for Reinforcement Learning: an Investigation of Relevant Design Choices](https://esc.fnwi.uva.nl/thesis/centraal/files/f499544468.pdf) 44 | * [Grid World POMDP Environment](https://github.com/awjuliani/DeepRL-Agents) 45 | 46 | # Requirements 47 | * Python >= 3.5 48 | * PyTorch >= 0.4.1 49 | -------------------------------------------------------------------------------- /Two Observations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from gridworld import gameEnv\n", 12 | "import numpy as np\n", 13 | "%matplotlib inline\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from collections import deque\n", 16 | "import pickle\n", 17 | "from skimage.color import rgb2gray\n", 18 | "import random\n", 19 | "import torch\n", 20 | "import torch.nn as nn" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "

Define Environment Object

" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADKdJREFUeJzt3V+sHPV5xvHvUxtCQtqAgVouhh5XQSBUCUOtFERUtYBbQiLoRYRAURVVSNykLTSREmivIvUikaokXFSRECRFFeVPCDTIikipQ1RVqhyOgSZgQ2wIBFuA3RRKSqW2Tt5ezDg9ce2cOT67e87w+36k1e7M7Gp+o/GzMzue876pKiS15RdWegCSZs/gSw0y+FKDDL7UIIMvNcjgSw0y+FKDlhX8JFcmeS7J3iS3TGpQkqYrx3sDT5I1wPeArcA+4HHg+qraNbnhSZqGtcv47PuAvVX1AkCSe4FrgGMG//TTT6+5ubllrPLtb+fOnSs9BI1cVWWx9ywn+GcCLy+Y3gf85s/7wNzcHPPz88tY5dtfsug+k5Zt6hf3ktyYZD7J/MGDB6e9OkkDLCf4+4GzFkxv7Of9jKq6vaq2VNWWM844YxmrkzQpywn+48A5STYlORG4Dnh4MsOSNE3H/Ru/qg4l+SPgG8Aa4EtV9czERiZpapZzcY+q+jrw9QmNRdKMeOee1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KBFg5/kS0kOJHl6wbx1SR5Nsqd/PnW6w5Q0SUOO+H8NXHnEvFuA7VV1DrC9n5Y0EosGv6r+Efi3I2ZfA9zVv74L+P0Jj0vSFB3vb/z1VfVK//pVYP2ExiNpBpZ9ca+6rpvH7LxpJx1p9Tne4L+WZANA/3zgWG+0k460+hxv8B8GPtq//ijwtckMR9IsDPnvvHuAfwbOTbIvyQ3AZ4CtSfYAV/TTkkZi0U46VXX9MRZdPuGxSJoR79yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGjSk9NZZSR5LsivJM0lu6ufbTUcaqSFH/EPAJ6rqfOBi4GNJzsduOtJoDemk80pVPdG//hGwGzgTu+lIo7Wk3/hJ5oALgR0M7KZjQw1p9Rkc/CTvBr4K3FxVby5c9vO66dhQQ1p9BgU/yQl0ob+7qh7sZw/upiNpdRlyVT/AncDuqvrcgkV205FGatGGGsClwB8A303yVD/vz+i659zfd9Z5Cbh2OkOUNGlDOun8E5BjLLabjjRC3rknNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0a8vf4mqmjVjDTTx3rL8S1FB7xpQYZfKlBQ2runZTk20n+pe+k8+l+/qYkO5LsTXJfkhOnP1xJkzDkiP9fwGVVdQGwGbgyycXAZ4HPV9V7gdeBG6Y3TEmTNKSTTlXVf/STJ/SPAi4DHujn20lHGpGhdfXX9BV2DwCPAs8Db1TVof4t++jaah3ts3bSkVaZQcGvqh9X1WZgI/A+4LyhK7CTjrT6LOmqflW9ATwGXAKckuTwfQAbgf0THpukKRlyVf+MJKf0r98JbKXrmPsY8OH+bXbSkUZkyJ17G4C7kqyh+6K4v6q2JdkF3JvkL4An6dpsSRqBIZ10vkPXGvvI+S/Q/d6XNDLeuSc1yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQ
ggy81yOBLDTL4UoMMvtQggy81aHDw+xLbTybZ1k/bSUcaqaUc8W+iK7J5mJ10pJEa2lBjI/BB4I5+OthJRxqtoUf8LwCfBH7ST5+GnXSk0RpSV/9DwIGq2nk8K7CTjrT6DKmrfylwdZKrgJOAXwJuo++k0x/17aQjjciQbrm3VtXGqpoDrgO+WVUfwU460mgt5//xPwV8PMleut/8dtKRRmLIqf5PVdW3gG/1r+2kI42Ud+5JDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81aFAhjiQvAj8CfgwcqqotSdYB9wFzwIvAtVX1+nSGKWmSlnLE/52q2lxVW/rpW4DtVXUOsL2fljQCyznVv4aukQbYUEMalaHBL+Dvk+xMcmM/b31VvdK/fhVYP/HRSZqKocU2319V+5P8MvBokmcXLqyqSlJH+2D/RXEjwNlnn72swUqajEFH/Kra3z8fAB6iq677WpINAP3zgWN81k460iozpIXWyUl+8fBr4HeBp4GH6RppgA01pFEZcqq/Hnioa5DLWuBvq+qRJI8D9ye5AXgJuHZ6w5Q0SYsGv2+cccFR5v8QuHwag5I0Xd65JzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVo6F/naWay0gNQAzziSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEHBT3JKkgeSPJtkd5JLkqxL8miSPf3zqdMerKTJGHrEvw14pKrOoyvDtRs76UijNaTK7nuA3wLuBKiq/66qN7CTjjRaQ474m4CDwJeTPJnkjr7Mtp10pJEaEvy1wEXAF6vqQuAtjjitr6qia7P1/yS5Mcl8kvmDBw8ud7ySJmBI8PcB+6pqRz/9AN0XgZ10pJFaNPhV9SrwcpJz+1mXA7uwk440WkP/LPePgbuTnAi8APwh3ZeGnXSkERoU/Kp6CthylEV20pFGyDv3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYNqat/bpKnFjzeTHKznXSk8RpSbPO5qtpcVZuB3wD+E3gIO+lIo7XUU/3Lgeer6iXspCON1lKDfx1wT//aTjrSSA0Ofl9a+2rgK0cus5OONC5LOeJ/AHiiql7rp+2kI43UUoJ/Pf93mg920pFGa1Dw++64W4EHF8z+DLA1yR7gin5a0ggM7aTzFnDaEfN+iJ10pFHyzj2pQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQUNLb/1pkmeSPJ3kniQnJdmUZEeSvUnu66vwShqBIS20zgT+BNhSVb8OrKGrr/9Z4PNV9V7gdeCGaQ5U0uQMPdVfC7wzyVrgXcArwGXAA/1yO+lIIzKkd95+4C+BH9AF/t+BncAbVXWof9s+4MxpDVLSZA051T+Vrk/eJuBXgJOBK4euwE460uoz5FT/CuD7VXWwqv6Hrrb+pcAp/ak/wEZg/9E+bCcdafUZEvwfABcneVeS0NXS3wU8Bny4f4+ddKQRGfIbfwfdRbwngO/2n7kd+BTw8SR76Zpt3DnFcUqaoHSNbmdjy5YtNT8/P7P1jVF3UiUdv6pa9B+Rd+5JDTL4UoMMvtQggy81aKYX95IcBN4C/nVmK52+03F7Vqu307bAsO351apa9IaZmQYfIMl8VW2Z6UqnyO1Zvd5O2wKT3R5P9aUGGXypQSsR/NtXYJ3T5PasXm+nbYEJbs/Mf+NLWnme6ksNmmnwk1yZ5Lm+Tt8ts1z3ciU5K8ljSXb19Qdv6uevS/Jokj3986krPdalSLImyZNJtvXTo62lmOSUJA8keTbJ7iSXjHn/TLPW5cyCn2QN8FfAB4DzgeuTnD+r9U/AIeATVXU+cDHwsX78twDbq+ocYHs/PSY3Abs
XTI+5luJtwCNVdR5wAd12jXL/TL3WZVXN5AFcAnxjwfStwK2zWv8UtudrwFbgOWBDP28D8NxKj20J27CRLgyXAduA0N0gsvZo+2w1P4D3AN+nv261YP4o9w9dKbuXgXV0NS+3Ab83qf0zy1P9wxty2Gjr9CWZAy4EdgDrq+qVftGrwPoVGtbx+ALwSeAn/fRpjLeW4ibgIPDl/qfLHUlOZqT7p6Zc69KLe0uU5N3AV4Gbq+rNhcuq+xoexX+TJPkQcKCqdq70WCZkLXAR8MWqupDu1vCfOa0f2f5ZVq3Lxcwy+PuBsxZMH7NO32qV5AS60N9dVQ/2s19LsqFfvgE4sFLjW6JLgauTvAjcS3e6fxsDaymuQvuAfdVVjIKuatRFjHf/LKvW5WJmGfzHgXP6q5In0l2oeHiG61+Wvt7gncDuqvrcgkUP09UchBHVHqyqW6tqY1XN0e2Lb1bVRxhpLcWqehV4Ocm5/azDtSFHuX+Ydq3LGV+wuAr4HvA88OcrfQFliWN/P91p4neAp/rHVXS/i7cDe4B/ANat9FiPY9t+G9jWv/414NvAXuArwDtWenxL2I7NwHy/j/4OOHXM+wf4NPAs8DTwN8A7JrV/vHNPapAX96QGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxr0v95G4v3/3GYuAAAAAElFTkSuQmCC\n", 38 | "text/plain": [ 39 | "
" 40 | ] 41 | }, 42 | "metadata": {}, 43 | "output_type": "display_data" 44 | } 45 | ], 46 | "source": [ 47 | "env = gameEnv(partial=True,size=9)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | }, 65 | { 66 | "data": { 67 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3VuMXeV5xvH/Uw8OCUljm7SWi0ltFARCVTGRlYLggpLSOjSCXEQpKJHSKi03qUraSsG0Fy2VIiVSlYSLKpIFSVGVcohDE4uLpK7jpL1yMIe2YONgEgi2DKYCcrpAdXh7sZfbwR17r5nZe2YW3/8njfZeax/Wt2bp2eswe943VYWktvzCcg9A0tIz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoEUFP8m2JIeSHE6yfVKDkjRdWegXeJKsAr4HXAscAR4CbqqqA5MbnqRpmFnEa98DHK6q7wMkuRe4ATht8JP4NUFpyqoq456zmEP984DnZk0f6eZJWuEWs8fvJcnNwM3TXo6k/hYT/KPA+bOmN3bzXqeqdgA7wEN9aaVYzKH+Q8CFSTYnWQ3cCOyazLAkTdOC9/hVdSLJHwPfBFYBX6yqJyY2MklTs+A/5y1oYR7qS1M37av6kgbK4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzVobPCTfDHJ8SSPz5q3LsnuJE91t2unO0xJk9Rnj//3wLZT5m0H9lTVhcCeblrSQIwNflX9K/DSKbNvAO7u7t8NfGDC45I0RQs9x19fVce6+88D6yc0HklLYNGddKqqzlQ910460sqz0D3+C0k2AHS3x0/3xKraUVVbq2rrApclacIWGvxdwEe7+x8Fvj6Z4UhaCmMbaiS5B7gaeAfwAvBXwNeA+4F3As8CH6qqUy8AzvVeNtSQpqxPQw076UhvMHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgPp10zk+yN8mBJE8kuaWbbzcdaaD61NzbAGyoqkeSvA14mFEDjd8HXqqqTyfZDqytqlvHvJelt6Qpm0jprao6VlWPdPd/AhwEzsNuOtJgzauhRpJNwGXAPnp207GhhrTy9K6ym+StwHeAT1XVA0leqao1sx5/uarOeJ7vob40fROrspvkLOCrwJer6oFudu9uOpJWlj5X9QPcBRysqs/OeshuOtJA9bmqfxXwb8B/Aq91s/+C0Xn+vLrpeKgvTZ+ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgedXc01LwP5fPbOx/nKoH9/hSgwy+1KA+NffOTvLdJP/eddK5vZu/Ocm+JIeT3Jdk9fSHK2kS+uzxXwWuqapLgS3AtiSXA58BPldV7wJeBj42vWFKmqQ+nXSqqn7aTZ7V/RR
wDbCzm28nHWlA+tbVX5XkMUa183cDTwOvVNWJ7ilHGLXVmuu1NyfZn2T/JAYsafF6Bb+qfl5VW4CNwHuAi/suoKp2VNXWqtq6wDFKmrB5XdWvqleAvcAVwJokJ78HsBE4OuGxSZqSPlf1fynJmu7+m4FrGXXM3Qt8sHuanXSkAenTSefXGV28W8Xog+L+qvqbJBcA9wLrgEeBj1TVq2Pey6+ljeWv6Mz85t44dtIZJH9FZ2bwx7GTjqQ5GXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUG9Q5+V2L70SQPdtN20pEGaj57/FsYFdk8yU460kD1baixEfhd4M5uOthJRxqsvnv8zwOfBF7rps/FTjrSYPWpq/9+4HhVPbyQBdhJR1p5ZsY/hSuB65NcB5wN/CJwB10nnW6vbycdaUD6dMu9rao2VtUm4EbgW1X1YeykIw3WYv6OfyvwZ0kOMzrnv2syQ5I0bXbSWXH8FZ2ZnXTGsZOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn5h5JngF+AvwcOFFVW5OsA+4DNgHPAB+qqpenM0xJkzSfPf5vVtWWWdVytwN7qupCYE83LWkAFnOofwOjRhpgQw1pUPoGv4B/TvJwkpu7eeur6lh3/3lg/cRHJ2kqep3jA1dV1dEkvwzsTvLk7Aerqk5XSLP7oLh5rsckLY95V9lN8tfAT4E/Aq6uqmNJNgDfrqqLxrzWErJj+Ss6M6vsjjORKrtJzknytpP3gd8GHgd2MWqkATbUkAZl7B4/yQXAP3WTM8A/VtWnkpwL3A+8E3iW0Z/zXhrzXu7OxvJXdGbu8cfps8e3ocaK46/ozAz+ODbUkDQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtSgvv+dpyXjN9M0fe7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mTZGeSJ5McTHJFknVJdid5qrtdO+3BSpqMvnv8O4BvVNXFwKXAQeykIw1Wn2KbbwceAy6oWU9OcgjLa0srzqRq7m0GXgS+lOTRJHd2ZbbtpCMNVJ/gzwDvBr5QVZcBP+OUw/ruSOC0nXSS7E+yf7GDlTQZfYJ/BDhSVfu66Z2MPghe6A7x6W6Pz/XiqtpRVVtnddmVtMzGBr+qngeeS3Ly/P29wAHspCMNVq+GGkm2AHcCq4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9yUZLHZv38OMkn7KQjDde8Sm8lWQUcBX4D+DjwUlV9Osl2YG1V3Trm9ZbekqZsGqW33gs8XVXPAjcAd3fz7wY+MM/3krRM5hv8G4F7uvt20pEGqnfwk6wGrge+cupjdtKRhmU+e/z3AY9U1QvdtJ10pIGaT/Bv4v8O88FOOtJg9e2kcw7wQ0atsn/UzTsXO+lIK46ddKQG2UlH0pwMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoN6BT/JnyZ5IsnjSe5JcnaSzUn2JTmc5L6uCq+kAejTQus84E+ArVX1a8AqRvX1PwN8rqreBbwMfGyaA5U0OX0P9WeANyeZAd4CHAOuAXZ2j9tJRxqQscGvqqPA3zKqsnsM+BHwMPBKVZ3onnYEOG9ag5Q0WX0O9dcy6pO3GfgV4BxgW98F2ElHWnlmejznt4AfVNWLAEkeAK4E1iSZ6fb6Gxl10f1/qmoHsKN7reW1pRWgzzn+D4HLk7wlSRh1zD0A7AU+2D3HTjrSgPTtpHM78HvACeBR4A8ZndPfC6zr5n2kql4d8z7u8aUps5O
O1CA76Uiak8GXGmTwpQYZfKlBff6OP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akv1VtXVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7g71iGZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOfZFuSQ12dvu1LuezFSnJ+kr1JDnT1B2/p5q9LsjvJU93t2uUe63wkWZXk0SQPdtODraWYZE2SnUmeTHIwyRVD3j7TrHW5ZMFPsgr4O+B9wCXATUkuWarlT8AJ4M+r6hLgcuDj3fi3A3uq6kJgTzc9JLcAB2dND7mW4h3AN6rqYuBSRus1yO0z9VqXVbUkP8AVwDdnTd8G3LZUy5/C+nwduBY4BGzo5m0ADi332OaxDhsZheEa4EEgjL4gMjPXNlvJP8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3JFThpsnb4km4DLgH3A+qo61j30PLB+mYa1EJ8HPgm81k2fy3BrKW4GXgS+1J263JnkHAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsSfSZK8HzheVQ8v91gmZAZ4N/CFqrqM0VfDX3dYP7Dts6hal+MsZfCPAufPmj5tnb6VKslZjEL/5ap6oJv9QpIN3eMbgOPLNb55uhK4PskzjCopXcPoHHlNV0YdhrWNjgBHqmpfN72T0QfBULfP/9a6rKr/Bl5X67J7zoK3z1IG/yHgwu6q5GpGFyp2LeHyF6WrN3gXcLCqPjvroV2Mag7CgGoPVtVtVbWxqjYx2hbfqqoPM9BailX1PPBckou6WSdrQw5y+zDtWpdLfMHiOuB7wNPAXy73BZR5jv0qRoeJ/wE81v1cx+i8eA/wFPAvwLrlHusC1u1q4MHu/gXAd4HDwFeANy33+OaxHluA/d02+hqwdsjbB7gdeBJ4HPgH4E2T2j5+c09qkBf3pAYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGvQ/4fcTtlJMEyYAAAAASUVORK5CYII=\n", 68 | "text/plain": [ 69 | "
" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "prev_state = env.reset()\n", 78 | "plt.imshow(prev_state)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "

Training Q Network

" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "

Hyper-parameters

" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "BATCH_SIZE = 32\n", 104 | "FREEZE_INTERVAL = 20000 # steps\n", 105 | "MEMORY_SIZE = 60000 \n", 106 | "OUTPUT_SIZE = 4\n", 107 | "TOTAL_EPISODES = 10000\n", 108 | "MAX_STEPS = 50\n", 109 | "INITIAL_EPSILON = 1.0\n", 110 | "FINAL_EPSILON = 0.1\n", 111 | "GAMMA = 0.99\n", 112 | "INPUT_IMAGE_DIM = 84\n", 113 | "PERFORMANCE_SAVE_INTERVAL = 500 # episodes" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "

Save Dictionay Function

" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | "metadata": { 127 | "collapsed": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def save_obj(obj, name ):\n", 132 | " with open('data/'+ name + '.pkl', 'wb') as f:\n", 133 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "

Experience Replay

" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 146 | "metadata": { 147 | "collapsed": true 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "class Memory():\n", 152 | " \n", 153 | " def __init__(self,memsize):\n", 154 | " self.memsize = memsize\n", 155 | " self.memory = deque(maxlen=self.memsize)\n", 156 | " \n", 157 | " def add_sample(self,sample):\n", 158 | " self.memory.append(sample)\n", 159 | " \n", 160 | " def get_batch(self,size):\n", 161 | " return random.sample(self.memory,k=size)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "

Frame Collector

" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 5, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "class FrameCollector():\n", 180 | " \n", 181 | " def __init__(self,num_frames,img_dim):\n", 182 | " self.num_frames = num_frames\n", 183 | " self.img_dim = img_dim\n", 184 | " self.frames = deque(maxlen=self.num_frames)\n", 185 | " \n", 186 | " def reset(self):\n", 187 | " tmp = np.zeros((self.img_dim,self.img_dim))\n", 188 | " for i in range(0,self.num_frames):\n", 189 | " self.frames.append(tmp)\n", 190 | " \n", 191 | " def add_frame(self,frame):\n", 192 | " self.frames.append(frame)\n", 193 | " \n", 194 | " def get_state(self):\n", 195 | " return np.array(self.frames)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "

Preprocess Images

" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def preprocess_image(image):\n", 214 | " image = rgb2gray(image) # this automatically scales the color for block between 0 - 1\n", 215 | " return np.copy(image)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 7, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "" 227 | ] 228 | }, 229 | "execution_count": 7, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | }, 233 | { 234 | "data": { 235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADLhJREFUeJzt3V2MXPV5x/Hvr14ICUljm7SWi0kxigVCVTGRlYLggpLSEhpBLqIUlEhpldY3qUraSsG0Fy2VIiVSlYSLKpIFSVGV8hKHJhYXSV2HpL1ysDFtwcbBJBBs+YUKyNsFqsPTizluF7p4zu7O7O7h//1Iq5lz5uX8j45+c15m9nlSVUhqyy8s9wAkLT2DLzXI4EsNMvhSgwy+1CCDLzXI4EsNWlTwk1yf5FCSw0m2TWpQkqYrC/0BT5JVwPeA64AjwCPALVV1YHLDkzQNM4t47XuAw1X1fYAk9wE3Aa8b/CT+TFCasqrKuOcs5lD/fOC5WdNHunmSVrjF7PF7SbIV2Drt5UjqbzHBPwpcMGt6QzfvVapqO7AdPNSXVorFHOo/AmxKsjHJ2cDNwM7JDEvSNC14j19Vp5L8MfBNYBXwxap6YmIjkzQ1C/46b0EL81BfmrppX9WXNFAGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUFjg5/ki0lOJnl81ry1SXYleaq7XTPdYUqapD57/L8Hrn/NvG3A7qraBOzupiUNxNjgV9W/Ai+8ZvZNwD3d/XuAD0x4XJKmaKHn+Ouq6lh3/ziwbkLjkbQEFt1Jp6rqTNVz7aQjrTwL3eOfSLIeoLs9+XpPrKrtVbWlqrYscFmSJmyhwd8JfLS7/1Hg65MZjqSlMLahRpJ7gWuAdwAngL8CvgY8ALwTeBb4UFW99gLgXO9lQw1pyvo01LCTjvQGYycdSXMy+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw3q00nngiQPJzmQ5Ikkt3bz7aYjDVSfmnvrgfVV9WiStwH7GDXQ+H3ghar6dJJtwJqqum3Me1l6S5qyiZTeqqpjVfVod/8nwEHgfOymIw3WvBpqJLkQuBzYQ89uOjbUkFae3lV2k7wV+A7wqap6MMlLVbV61uMvVtUZz/M91Jemb2JVdpOcBXwV+HJVPdjN7t1NR9LK0ueqfoC7gYNV9dlZD9lNRxqoPlf1rwb+DfhP4JVu9l8wOs+fVzc
dD/Wl6bOTjtQgO+lImpPBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxo0r5p7mr6l/DfpIRrVhdFiuceXGmTwpQb1qbl3TpLvJvn3rpPOHd38jUn2JDmc5P4kZ09/uJImoc8e/2Xg2qq6DNgMXJ/kCuAzwOeq6l3Ai8DHpjdMSZPUp5NOVdVPu8mzur8CrgV2dPPtpCMNSN+6+quSPMaodv4u4Gngpao61T3lCKO2WnO9dmuSvUn2TmLAkhavV/Cr6udVtRnYALwHuKTvAqpqe1VtqaotCxyjpAmb11X9qnoJeBi4Elid5PTvADYARyc8NklT0ueq/i8lWd3dfzNwHaOOuQ8DH+yeZicdaUD6dNL5dUYX71Yx+qB4oKr+JslFwH3AWmA/8JGqennMe/mztDH85d6Z+cu98eykM0AG/8wM/nh20pE0J4MvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UoN7B70ps70/yUDdtJx1poOazx7+VUZHN0+ykIw1U34YaG4DfBe7qpoOddKTB6rvH/zzwSeCVbvo87KQjDVafuvrvB05W1b6FLMBOOtLKMzP+KVwF3JjkBuAc4BeBO+k66XR7fTvpSAPSp1vu7VW1oaouBG4GvlVVH8ZOOtJgLeZ7/NuAP0tymNE5/92TGZKkabOTzgpjJ50zs5POeHbSkTQngy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoP61NwjyTPAT4CfA6eqakuStcD9wIXAM8CHqurF6QxT0iTNZ4//m1W1eVa13G3A7qraBOzupiUNwGIO9W9i1EgDbKghDUrf4Bfwz0n2JdnazVtXVce6+8eBdRMfnaSp6HWOD1xdVUeT/DKwK8mTsx+sqnq9QprdB8XWuR6TtDzmXWU3yV8DPwX+CLimqo4lWQ98u6ouHvNaS8iOYZXdM7PK7ngTqbKb5Nwkbzt9H/ht4HFgJ6NGGmBDDWlQxu7xk1wE/FM3OQP8Y1V9Ksl5wAPAO4FnGX2d98KY93J3NoZ7/DNzjz9enz2+DTVWGIN/ZgZ/PBtqSJqTwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQ3//O0xLxl2laCu7xpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qUK/gJ1mdZEeSJ5McTHJlkrVJdiV5qrtdM+3BSpqMvnv8O4FvVNUlwGXAQeykIw1Wn2KbbwceAy6qWU9OcgjLa0srzqRq7m0Enge+lGR/kru6Mtt20pEGqk/wZ4B3A1+oqsuBn/Gaw/ruSOB1O+kk2Ztk72IHK2ky+gT/CHCkqvZ00zsYfRCc6A7x6W5PzvXiqtpeVVtmddmVtMzGBr+qjgPPJTl9/v5e4AB20pEGq1dDjSSbgbuAs4HvA3/A6EPDTjrSCmMnHalBdtKRNCeDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1CCDLzXI4EsNMvhSgwy+1KCxwU9ycZLHZv39OMkn7KQjDde8Sm8lWQUcBX4D+DjwQlV9Osk2YE1V3Tbm9ZbekqZsGqW33gs8XVXPAjcB93Tz7wE+MM/3krRM5hv8m4F7u/t20pEGqnfwk5wN3Ah85bWP2UlHGpb57PHfBzxaVSe6aTvpSAM1n+Dfwv8d5oOddKTB6ttJ51zgh4xaZf+om3cedtKRVhw76UgNspOOpDkZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQb1Cn6SP03yRJLHk9yb5JwkG5PsSXI4yf1dFV5JA9Cnhdb5wJ8AW6rq14BVjOrrfwb4XFW9C3gR+Ng0Byppcvoe6s8Ab04yA7wFOAZ
cC+zoHreTjjQgY4NfVUeBv2VUZfcY8CNgH/BSVZ3qnnYEOH9ag5Q0WX0O9dcw6pO3EfgV4Fzg+r4LsJOOtPLM9HjObwE/qKrnAZI8CFwFrE4y0+31NzDqovv/VNV2YHv3WstrSytAn3P8HwJXJHlLkjDqmHsAeBj4YPccO+lIA9K3k84dwO8Bp4D9wB8yOqe/D1jbzftIVb085n3c40tTZicdqUF20pE0J4MvNcjgSw0y+FKD+nyPP0n/Bfysu32jeAeuz0r1RloX6Lc+v9rnjZb0qj5Akr1VtWVJFzpFrs/K9UZaF5js+nioLzXI4EsNWo7gb1+GZU6T67NyvZHWBSa4Pkt+ji9p+XmoLzVoSYOf5Pokh7o6fduWctmLleSCJA8nOdDVH7y1m782ya4kT3W3a5Z7rPORZFWS/Uke6qYHW0sxyeokO5I8meRgkiuHvH2mWetyyYKfZBXwd8D7gEuBW5JculTLn4BTwJ9X1aXAFcDHu/FvA3ZX1SZgdzc9JLcCB2dND7mW4p3AN6rqEuAyRus1yO0z9VqXVbUkf8CVwDdnTd8O3L5Uy5/C+nwduA44BKzv5q0HDi332OaxDhsYheFa4CEgjH4gMjPXNlvJf8DbgR/QXbeaNX+Q24fRv70/x+jf3me67fM7k9o+S3mof3pFThtsnb4kFwKXA3uAdVV1rHvoOLBumYa1EJ8HPgm80k2fx3BrKW4Enge+1J263JXkXAa6fWrKtS69uDdPSd4KfBX4RFX9ePZjNfoYHsTXJEneD5ysqn3LPZYJmQHeDXyhqi5n9NPwVx3WD2z7LKrW5ThLGfyjwAWzpl+3Tt9KleQsRqH/clU92M0+kWR99/h64ORyjW+ergJuTPIMo0pK1zI6R17dlVGHYW2jI8CRqtrTTe9g9EEw1O3zv7Uuq+q/gVfVuuyes+Dts5TBfwTY1F2VPJvRhYqdS7j8RenqDd4NHKyqz856aCejmoMwoNqDVXV7VW2oqgsZbYtvVdWHGWgtxao6DjyX5OJu1unakIPcPky71uUSX7C4Afge8DTwl8t9AWWeY7+a0WHifwCPdX83MDov3g08BfwLsHa5x7qAdbsGeKi7fxHwXeAw8BXgTcs9vnmsx2Zgb7eNvgasGfL2Ae4AngQeB/4BeNOkto+/3JMa5MU9qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBv0PANIhuBFMr1IAAAAASUVORK5CYII=\n", 236 | "text/plain": [ 237 | "
" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "processed_prev_state = preprocess_image(prev_state)\n", 246 | "plt.imshow(processed_prev_state,cmap='gray')" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "

Build Model

" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Network(\n", 266 | " (conv_layer1): Conv2d(2, 32, kernel_size=(8, 8), stride=(4, 4))\n", 267 | " (conv_layer2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n", 268 | " (conv_layer3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 269 | " (fc1): Linear(in_features=3136, out_features=512, bias=True)\n", 270 | " (fc2): Linear(in_features=512, out_features=4, bias=True)\n", 271 | " (relu): ReLU()\n", 272 | ")\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "import torch.nn as nn\n", 278 | "import torch\n", 279 | "\n", 280 | "class Network(nn.Module):\n", 281 | " \n", 282 | " def __init__(self,image_input_size,out_size):\n", 283 | " super(Network,self).__init__()\n", 284 | " self.image_input_size = image_input_size\n", 285 | " self.out_size = out_size\n", 286 | "\n", 287 | " self.conv_layer1 = nn.Conv2d(in_channels=2,out_channels=32,kernel_size=8,stride=4) # GRAY - 1\n", 288 | " self.conv_layer2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4,stride=2)\n", 289 | " self.conv_layer3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1)\n", 290 | " self.fc1 = nn.Linear(in_features=7*7*64,out_features=512)\n", 291 | " self.fc2 = nn.Linear(in_features=512,out_features=OUTPUT_SIZE)\n", 292 | " self.relu = nn.ReLU()\n", 293 | "\n", 294 | " def forward(self,x,bsize):\n", 295 | " x = x.view(bsize,2,self.image_input_size,self.image_input_size) # (N,Cin,H,W) batch size, input channel, height , width\n", 296 | " conv_out = self.conv_layer1(x)\n", 297 | " conv_out = self.relu(conv_out)\n", 298 | " conv_out = self.conv_layer2(conv_out)\n", 299 | " conv_out = self.relu(conv_out)\n", 300 | " conv_out = self.conv_layer3(conv_out)\n", 301 | " conv_out = self.relu(conv_out)\n", 302 | " out = self.fc1(conv_out.view(bsize,7*7*64))\n", 303 | " out 
= self.relu(out)\n", 304 | " out = self.fc2(out)\n", 305 | " return out\n", 306 | "\n", 307 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).cuda()\n", 308 | "print(main_model)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "

Deep Q Learning with Freeze Network

" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "Populated 60000 Samples in Episodes : 1200\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "mem = Memory(memsize=MEMORY_SIZE)\n", 333 | "main_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Primary Network\n", 334 | "target_model = Network(image_input_size=INPUT_IMAGE_DIM,out_size=OUTPUT_SIZE).float().cuda() # Target Network\n", 335 | "frameObj = FrameCollector(img_dim=INPUT_IMAGE_DIM,num_frames=2)\n", 336 | "\n", 337 | "target_model.load_state_dict(main_model.state_dict())\n", 338 | "criterion = nn.SmoothL1Loss()\n", 339 | "optimizer = torch.optim.Adam(main_model.parameters())\n", 340 | "\n", 341 | "# filling memory with transitions\n", 342 | "for i in range(0,int(MEMORY_SIZE/MAX_STEPS)):\n", 343 | " \n", 344 | " prev_state = env.reset()\n", 345 | " frameObj.reset()\n", 346 | " processed_prev_state = preprocess_image(prev_state)\n", 347 | " frameObj.add_frame(processed_prev_state)\n", 348 | " prev_frames = frameObj.get_state()\n", 349 | " step_count = 0\n", 350 | " game_over = False\n", 351 | " \n", 352 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 353 | " \n", 354 | " step_count +=1\n", 355 | " action = np.random.randint(0,4)\n", 356 | " next_state,reward, game_over = env.step(action)\n", 357 | " processed_next_state = preprocess_image(next_state)\n", 358 | " frameObj.add_frame(processed_next_state)\n", 359 | " next_frames = frameObj.get_state()\n", 360 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 361 | " \n", 362 | " prev_state = next_state\n", 363 | " processed_prev_state = processed_next_state\n", 364 | " prev_frames = next_frames\n", 365 | "\n", 366 | "print('Populated %d Samples in Episodes : %d'%(len(mem.memory),int(MEMORY_SIZE/MAX_STEPS)))\n", 367 | "\n", 
368 | "\n", 369 | "# Algorithm Starts\n", 370 | "total_steps = 0\n", 371 | "epsilon = INITIAL_EPSILON\n", 372 | "loss_stat = []\n", 373 | "total_reward_stat = []\n", 374 | "\n", 375 | "for episode in range(0,TOTAL_EPISODES):\n", 376 | " \n", 377 | " prev_state = env.reset()\n", 378 | " frameObj.reset()\n", 379 | " processed_prev_state = preprocess_image(prev_state)\n", 380 | " frameObj.add_frame(processed_prev_state)\n", 381 | " prev_frames = frameObj.get_state()\n", 382 | " game_over = False\n", 383 | " step_count = 0\n", 384 | " total_reward = 0\n", 385 | " \n", 386 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 387 | " \n", 388 | " step_count +=1\n", 389 | " total_steps +=1\n", 390 | " \n", 391 | " if np.random.rand() <= epsilon:\n", 392 | " action = np.random.randint(0,4)\n", 393 | " else:\n", 394 | " with torch.no_grad():\n", 395 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 396 | "\n", 397 | " model_out = main_model.forward(torch_x,bsize=1)\n", 398 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 399 | " \n", 400 | " next_state, reward, game_over = env.step(action)\n", 401 | " processed_next_state = preprocess_image(next_state)\n", 402 | " frameObj.add_frame(processed_next_state)\n", 403 | " next_frames = frameObj.get_state()\n", 404 | " total_reward += reward\n", 405 | " \n", 406 | " mem.add_sample((prev_frames,action,reward,next_frames,game_over))\n", 407 | " \n", 408 | " prev_state = next_state\n", 409 | " processed_prev_state = processed_next_state\n", 410 | " prev_frames = next_frames\n", 411 | " \n", 412 | " if (total_steps % FREEZE_INTERVAL) == 0:\n", 413 | " target_model.load_state_dict(main_model.state_dict())\n", 414 | " \n", 415 | " batch = mem.get_batch(size=BATCH_SIZE)\n", 416 | " current_states = []\n", 417 | " next_states = []\n", 418 | " acts = []\n", 419 | " rewards = []\n", 420 | " game_status = []\n", 421 | " \n", 422 | " for element in batch:\n", 423 | " 
current_states.append(element[0])\n", 424 | " acts.append(element[1])\n", 425 | " rewards.append(element[2])\n", 426 | " next_states.append(element[3])\n", 427 | " game_status.append(element[4])\n", 428 | " \n", 429 | " current_states = np.array(current_states)\n", 430 | " next_states = np.array(next_states)\n", 431 | " rewards = np.array(rewards)\n", 432 | " game_status = [not b for b in game_status]\n", 433 | " game_status_bool = np.array(game_status,dtype='float') # FALSE 1, TRUE 0\n", 434 | " torch_acts = torch.tensor(acts)\n", 435 | " \n", 436 | " Q_next = target_model.forward(torch.from_numpy(next_states).float().cuda(),bsize=BATCH_SIZE)\n", 437 | " Q_s = main_model.forward(torch.from_numpy(current_states).float().cuda(),bsize=BATCH_SIZE)\n", 438 | " Q_max_next, _ = Q_next.detach().max(dim=1)\n", 439 | " Q_max_next = Q_max_next.double()\n", 440 | " Q_max_next = torch.from_numpy(game_status_bool).cuda()*Q_max_next\n", 441 | " \n", 442 | " target_values = (rewards + (GAMMA * Q_max_next))\n", 443 | " Q_s_a = Q_s.gather(dim=1,index=torch_acts.cuda().unsqueeze(dim=1)).squeeze(dim=1)\n", 444 | " \n", 445 | " loss = criterion(Q_s_a,target_values.float().cuda())\n", 446 | " \n", 447 | " # save performance measure\n", 448 | " loss_stat.append(loss.item())\n", 449 | " \n", 450 | " # make previous grad zero\n", 451 | " optimizer.zero_grad()\n", 452 | " \n", 453 | " # back - propogate \n", 454 | " loss.backward()\n", 455 | " \n", 456 | " # update params\n", 457 | " optimizer.step()\n", 458 | " \n", 459 | " # save performance measure\n", 460 | " total_reward_stat.append(total_reward)\n", 461 | " \n", 462 | " if epsilon > FINAL_EPSILON:\n", 463 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 464 | " \n", 465 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 466 | " perf = {}\n", 467 | " perf['loss'] = loss_stat\n", 468 | " perf['total_reward'] = total_reward_stat\n", 469 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n", 470 | " \n", 471 | " 
#print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Loss : ',loss.item(),'Steps : ',step_count)\n" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "

Save Primary Network Weights

" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 19, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "torch.save(main_model.state_dict(),'data/TWO_OBSERV_NINE_WEIGHTS.torch')" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "

Testing Policy

" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "

Load Primary Network Weights

" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 10, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [ 512 | "weights = torch.load('data/TWO_OBSERV_NINE_WEIGHTS.torch')\n", 513 | "main_model.load_state_dict(weights)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "

Testing Policy

" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": { 527 | "collapsed": true 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "# Algorithm Starts\n", 532 | "epsilon = INITIAL_EPSILON\n", 533 | "FINAL_EPSILON = 0.01\n", 534 | "total_reward_stat = []\n", 535 | "\n", 536 | "for episode in range(0,TOTAL_EPISODES):\n", 537 | " \n", 538 | " prev_state = env.reset()\n", 539 | " processed_prev_state = preprocess_image(prev_state)\n", 540 | " frameObj.reset()\n", 541 | " frameObj.add_frame(processed_prev_state)\n", 542 | " prev_frames = frameObj.get_state()\n", 543 | " game_over = False\n", 544 | " step_count = 0\n", 545 | " total_reward = 0\n", 546 | " \n", 547 | " while (game_over == False) and (step_count < MAX_STEPS):\n", 548 | " \n", 549 | " step_count +=1\n", 550 | " \n", 551 | " if np.random.rand() <= epsilon:\n", 552 | " action = np.random.randint(0,4)\n", 553 | " else:\n", 554 | " with torch.no_grad():\n", 555 | " torch_x = torch.from_numpy(prev_frames).float().cuda()\n", 556 | "\n", 557 | " model_out = main_model.forward(torch_x,bsize=1)\n", 558 | " action = int(torch.argmax(model_out.view(OUTPUT_SIZE),dim=0))\n", 559 | " \n", 560 | " next_state, reward, game_over = env.step(action)\n", 561 | " processed_next_state = preprocess_image(next_state)\n", 562 | " frameObj.add_frame(processed_next_state)\n", 563 | " next_frames = frameObj.get_state()\n", 564 | " \n", 565 | " total_reward += reward\n", 566 | " \n", 567 | " prev_state = next_state\n", 568 | " processed_prev_state = processed_next_state\n", 569 | " prev_frames = next_frames\n", 570 | " \n", 571 | " # save performance measure\n", 572 | " total_reward_stat.append(total_reward)\n", 573 | " \n", 574 | " if epsilon > FINAL_EPSILON:\n", 575 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/TOTAL_EPISODES\n", 576 | " \n", 577 | " if (episode + 1)% PERFORMANCE_SAVE_INTERVAL == 0:\n", 578 | " perf = {}\n", 579 | " perf['total_reward'] = total_reward_stat\n", 
580 | " save_obj(name='TWO_OBSERV_NINE',obj=perf)\n", 581 | " \n", 582 | " print('Completed episode : ',episode+1,' Epsilon : ',epsilon,' Reward : ',total_reward,'Steps : ',step_count)" 583 | ] 584 | } 585 | ], 586 | "metadata": { 587 | "kernelspec": { 588 | "display_name": "Python [conda env:myenv]", 589 | "language": "python", 590 | "name": "conda-env-myenv-py" 591 | }, 592 | "language_info": { 593 | "codemirror_mode": { 594 | "name": "ipython", 595 | "version": 3 596 | }, 597 | "file_extension": ".py", 598 | "mimetype": "text/x-python", 599 | "name": "python", 600 | "nbconvert_exporter": "python", 601 | "pygments_lexer": "ipython3", 602 | "version": "3.6.5" 603 | } 604 | }, 605 | "nbformat": 4, 606 | "nbformat_minor": 2 607 | } 608 | -------------------------------------------------------------------------------- /__pycache__/gridworld.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/gridworld.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/gridworld.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/gridworld.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/helper.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/__pycache__/helper.cpython-36.pyc -------------------------------------------------------------------------------- /data/FOUR_OBSERV_NINE.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/FOUR_OBSERV_NINE.pkl -------------------------------------------------------------------------------- /data/FOUR_OBSERV_NINE_WEIGHTS.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/FOUR_OBSERV_NINE_WEIGHTS.torch -------------------------------------------------------------------------------- /data/GIFs/LSTM_SIZE_9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9.gif -------------------------------------------------------------------------------- /data/GIFs/LSTM_SIZE_9_frames.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9_frames.gif -------------------------------------------------------------------------------- /data/GIFs/LSTM_SIZE_9_local.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/LSTM_SIZE_9_local.gif -------------------------------------------------------------------------------- /data/GIFs/MDP_SIZE_9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/MDP_SIZE_9.gif -------------------------------------------------------------------------------- /data/GIFs/SINGEL_SIZE_9_local.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGEL_SIZE_9_local.gif -------------------------------------------------------------------------------- /data/GIFs/SINGLE_OBSERV_9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGLE_OBSERV_9.gif -------------------------------------------------------------------------------- /data/GIFs/SINGLE_SIZE_9_frames.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/SINGLE_SIZE_9_frames.gif -------------------------------------------------------------------------------- /data/GIFs/perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/GIFs/perf.png -------------------------------------------------------------------------------- /data/LSTM_POMDP_V4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4.pkl -------------------------------------------------------------------------------- /data/LSTM_POMDP_V4_TEST.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4_TEST.pkl -------------------------------------------------------------------------------- /data/LSTM_POMDP_V4_WEIGHTS.torch: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/LSTM_POMDP_V4_WEIGHTS.torch -------------------------------------------------------------------------------- /data/MDP_ENV_SIZE_NINE.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/MDP_ENV_SIZE_NINE.pkl -------------------------------------------------------------------------------- /data/MDP_ENV_SIZE_NINE_WEIGHTS.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/MDP_ENV_SIZE_NINE_WEIGHTS.torch -------------------------------------------------------------------------------- /data/SINGLE_OBSERV_NINE.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/SINGLE_OBSERV_NINE.pkl -------------------------------------------------------------------------------- /data/SINGLE_OBSERV_NINE_WEIGHTS.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/SINGLE_OBSERV_NINE_WEIGHTS.torch -------------------------------------------------------------------------------- /data/TWO_OBSERV_NINE.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/TWO_OBSERV_NINE.pkl -------------------------------------------------------------------------------- 
/data/TWO_OBSERV_NINE_WEIGHTS.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/TWO_OBSERV_NINE_WEIGHTS.torch -------------------------------------------------------------------------------- /data/algo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/algo.png -------------------------------------------------------------------------------- /data/download (1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/download (1).png -------------------------------------------------------------------------------- /data/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mynkpl1998/Recurrent-Deep-Q-Learning/ccec7cc282aefa95d6ab73dec4c5fc8826cf8226/data/download.png -------------------------------------------------------------------------------- /gridworld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import itertools 4 | import scipy.misc 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class gameOb(): 9 | def __init__(self,coordinates,size,intensity,channel,reward,name): 10 | self.x = coordinates[0] 11 | self.y = coordinates[1] 12 | self.size = size 13 | self.intensity = intensity 14 | self.channel = channel 15 | self.reward = reward 16 | self.name = name 17 | 18 | class gameEnv(): 19 | def __init__(self,partial,size): 20 | self.sizeX = size 21 | self.sizeY = size 22 | self.actions = 4 23 | self.objects = [] 24 | self.partial = partial 25 | a = self.reset() 26 | 
plt.imshow(a,interpolation="nearest") 27 | 28 | 29 | def reset(self): 30 | self.objects = [] 31 | hero = gameOb(self.newPosition(),1,1,2,None,'hero') 32 | self.objects.append(hero) 33 | bug = gameOb(self.newPosition(),1,1,1,1,'goal') 34 | self.objects.append(bug) 35 | hole = gameOb(self.newPosition(),1,1,0,-1,'fire') 36 | self.objects.append(hole) 37 | bug2 = gameOb(self.newPosition(),1,1,1,1,'goal') 38 | self.objects.append(bug2) 39 | hole2 = gameOb(self.newPosition(),1,1,0,-1,'fire') 40 | self.objects.append(hole2) 41 | bug3 = gameOb(self.newPosition(),1,1,1,1,'goal') 42 | self.objects.append(bug3) 43 | bug4 = gameOb(self.newPosition(),1,1,1,1,'goal') 44 | self.objects.append(bug4) 45 | state = self.renderEnv() 46 | self.state = state 47 | return state 48 | 49 | def moveChar(self,direction): 50 | # 0 - up, 1 - down, 2 - left, 3 - right 51 | hero = self.objects[0] 52 | heroX = hero.x 53 | heroY = hero.y 54 | penalize = 0. 55 | if direction == 0 and hero.y >= 1: 56 | hero.y -= 1 57 | if direction == 1 and hero.y <= self.sizeY-2: 58 | hero.y += 1 59 | if direction == 2 and hero.x >= 1: 60 | hero.x -= 1 61 | if direction == 3 and hero.x <= self.sizeX-2: 62 | hero.x += 1 63 | if hero.x == heroX and hero.y == heroY: 64 | penalize = 0.0 65 | self.objects[0] = hero 66 | return penalize 67 | 68 | def newPosition(self): 69 | iterables = [ range(self.sizeX), range(self.sizeY)] 70 | points = [] 71 | for t in itertools.product(*iterables): 72 | points.append(t) 73 | currentPositions = [] 74 | for objectA in self.objects: 75 | if (objectA.x,objectA.y) not in currentPositions: 76 | currentPositions.append((objectA.x,objectA.y)) 77 | for pos in currentPositions: 78 | points.remove(pos) 79 | location = np.random.choice(range(len(points)),replace=False) 80 | return points[location] 81 | 82 | def checkGoal(self): 83 | others = [] 84 | for obj in self.objects: 85 | if obj.name == 'hero': 86 | hero = obj 87 | else: 88 | others.append(obj) 89 | ended = False 90 | for other in others: 
91 | if hero.x == other.x and hero.y == other.y: 92 | self.objects.remove(other) 93 | if other.reward == 1: 94 | self.objects.append(gameOb(self.newPosition(),1,1,1,1,'goal')) 95 | else: 96 | self.objects.append(gameOb(self.newPosition(),1,1,0,-1,'fire')) 97 | return other.reward,False 98 | if ended == False: 99 | return 0.0,False 100 | 101 | def renderEnv(self): 102 | #a = np.zeros([self.sizeY,self.sizeX,3]) 103 | a = np.ones([self.sizeY+2,self.sizeX+2,3]) 104 | a[1:-1,1:-1,:] = 0 105 | hero = None 106 | for item in self.objects: 107 | a[item.y+1:item.y+item.size+1,item.x+1:item.x+item.size+1,item.channel] = item.intensity 108 | if item.name == 'hero': 109 | hero = item 110 | if self.partial == True: 111 | a = a[hero.y:hero.y+3,hero.x:hero.x+3,:] 112 | b = scipy.misc.imresize(a[:,:,0],[84,84,1],interp='nearest') 113 | c = scipy.misc.imresize(a[:,:,1],[84,84,1],interp='nearest') 114 | d = scipy.misc.imresize(a[:,:,2],[84,84,1],interp='nearest') 115 | a = np.stack([b,c,d],axis=2) 116 | return a 117 | 118 | def step(self,action): 119 | penalty = self.moveChar(action) 120 | reward,done = self.checkGoal() 121 | state = self.renderEnv() 122 | if reward == None: 123 | print(done) 124 | print(reward) 125 | print(penalty) 126 | return state,(reward+penalty),done 127 | else: 128 | return state,(reward+penalty),done --------------------------------------------------------------------------------