├── .gitignore ├── README.md ├── frozen_lake.ipynb ├── gym_basic ├── __init__.py └── envs │ ├── __init__.py │ ├── basic_env.py │ └── basic_env_2.py ├── gym_basic_env_test.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pydev,macos,windows,pydev,sublimetext,visualstudiocode,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pydev,macos,windows,pydev,sublimetext,visualstudiocode,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | # Thumbnails 29 | ._* 30 | 31 | # Files that might appear in the root of a volume 32 | .DocumentRevisions-V100 33 | .fseventsd 34 | .Spotlight-V100 35 | .TemporaryItems 36 | .Trashes 37 | .VolumeIcon.icns 38 | .com.apple.timemachine.donotpresent 39 | 40 | # Directories potentially created on remote AFP share 41 | .AppleDB 42 | .AppleDesktop 43 | Network Trash Folder 44 | Temporary Items 45 | .apdisk 46 | 47 | ### pydev ### 48 | .pydevproject 49 | 50 | ### Python ### 51 | # Byte-compiled / optimized / DLL files 52 | __pycache__/ 53 | *.py[cod] 54 | *$py.class 55 | 56 | # C extensions 57 | *.so 58 | 59 | # Distribution / packaging 60 | .Python 61 | build/ 62 | develop-eggs/ 63 | dist/ 64 | downloads/ 65 | eggs/ 66 | .eggs/ 67 | lib/ 68 | lib64/ 69 | parts/ 70 | sdist/ 71 | var/ 72 | wheels/ 73 | pip-wheel-metadata/ 74 | share/python-wheels/ 75 | *.egg-info/ 76 | .installed.cfg 77 | *.egg 78 | MANIFEST 79 | 80 | # PyInstaller 81 | # Usually these files are written by a python script from a template 82 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 83 | *.manifest 84 | *.spec 85 | 86 | # Installer logs 87 | pip-log.txt 88 | pip-delete-this-directory.txt 89 | 90 | # Unit test / coverage reports 91 | htmlcov/ 92 | .tox/ 93 | .nox/ 94 | .coverage 95 | .coverage.* 96 | .cache 97 | nosetests.xml 98 | coverage.xml 99 | *.cover 100 | *.py,cover 101 | .hypothesis/ 102 | .pytest_cache/ 103 | pytestdebug.log 104 | 105 | # Translations 106 | *.mo 107 | *.pot 108 | 109 | # Django stuff: 110 | *.log 111 | local_settings.py 112 | db.sqlite3 113 | db.sqlite3-journal 114 | 115 | # Flask stuff: 116 | instance/ 117 | .webassets-cache 118 | 119 | # Scrapy stuff: 120 | .scrapy 121 | 122 | # Sphinx documentation 123 | docs/_build/ 124 | doc/_build/ 125 | 126 | # PyBuilder 127 | target/ 128 | 129 | # Jupyter Notebook 130 | 131 | # IPython 132 | 133 | # pyenv 134 | .python-version 135 | 136 | # pipenv 137 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 138 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 139 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 140 | # install all needed dependencies. 141 | #Pipfile.lock 142 | 143 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 144 | __pypackages__/ 145 | 146 | # Celery stuff 147 | celerybeat-schedule 148 | celerybeat.pid 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # Environments 154 | .env 155 | .venv 156 | env/ 157 | venv/ 158 | ENV/ 159 | env.bak/ 160 | venv.bak/ 161 | pythonenv* 162 | 163 | # Spyder project settings 164 | .spyderproject 165 | .spyproject 166 | 167 | # Rope project settings 168 | .ropeproject 169 | 170 | # mkdocs documentation 171 | /site 172 | 173 | # mypy 174 | .mypy_cache/ 175 | .dmypy.json 176 | dmypy.json 177 | 178 | # Pyre type checker 179 | .pyre/ 180 | 181 | # pytype static type analyzer 182 | .pytype/ 183 | 184 | # profiling data 185 | .prof 186 | 187 | ### SublimeText ### 188 | # Cache files for Sublime Text 189 | *.tmlanguage.cache 190 | *.tmPreferences.cache 191 | *.stTheme.cache 192 | 193 | # Workspace files are user-specific 194 | *.sublime-workspace 195 | 196 | # Project files should be checked into the repository, unless a significant 197 | # proportion of contributors will probably not be using Sublime Text 198 | # *.sublime-project 199 | 200 | # SFTP configuration file 201 | sftp-config.json 202 | 203 | # Package control specific files 204 | Package Control.last-run 205 | Package Control.ca-list 206 | Package Control.ca-bundle 207 | Package Control.system-ca-bundle 208 | Package Control.cache/ 209 | Package Control.ca-certs/ 210 | Package Control.merged-ca-bundle 211 | Package Control.user-ca-bundle 212 | oscrypto-ca-bundle.crt 213 | bh_unicode_properties.cache 214 | 215 | # Sublime-github package stores a github token in this file 216 | # https://packagecontrol.io/packages/sublime-github 217 | GitHub.sublime-settings 218 | 219 | ### VisualStudioCode ### 220 | .vscode/* 221 | !.vscode/tasks.json 222 | !.vscode/launch.json 223 | *.code-workspace 224 | 225 | ### VisualStudioCode Patch ### 226 | # Ignore all local history of files 227 | .history 228 | .ionide 229 | 230 | ### Windows ### 231 | # Windows thumbnail cache files 232 | Thumbs.db 233 | Thumbs.db:encryptable 234 | ehthumbs.db 235 | ehthumbs_vista.db 236 | 237 | # Dump file 238 | *.stackdump 239 | 240 | # Folder config file 241 | [Dd]esktop.ini 242 | 243 | # Recycle Bin used on file shares 244 | $RECYCLE.BIN/ 245 | 246 | # Windows Installer files 247 | *.cab 248 | *.msi 249 | *.msix 250 | *.msm 251 | *.msp 252 | 253 | # Windows shortcuts 254 | *.lnk 255 | 256 | # End of https://www.toptal.com/developers/gitignore/api/python,pydev,macos,windows,pydev,sublimetext,visualstudiocode,jupyternotebooks 257 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Setting up & testing a custom environment in Gym. 
2 | 3 | For more info, have a look at my article: 4 | -------------------------------------------------------------------------------- /frozen_lake.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Following tutorial here: \n", 8 | "https://www.youtube.com/watch?v=QK_PP_2KgGE&list=PLZbbT5o_s2xoWNVdDudn51XM8lOuZ_Njv&index=8" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import gym\n", 19 | "import random\n", 20 | "import time\n", 21 | "from IPython.display import clear_output" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "env = gym.make(\"FrozenLake-v0\")" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 4, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "[[0. 0. 0. 0.]\n", 43 | " [0. 0. 0. 0.]\n", 44 | " [0. 0. 0. 0.]\n", 45 | " [0. 0. 0. 0.]\n", 46 | " [0. 0. 0. 0.]\n", 47 | " [0. 0. 0. 0.]\n", 48 | " [0. 0. 0. 0.]\n", 49 | " [0. 0. 0. 0.]\n", 50 | " [0. 0. 0. 0.]\n", 51 | " [0. 0. 0. 0.]\n", 52 | " [0. 0. 0. 0.]\n", 53 | " [0. 0. 0. 0.]\n", 54 | " [0. 0. 0. 0.]\n", 55 | " [0. 0. 0. 0.]\n", 56 | " [0. 0. 0. 0.]\n", 57 | " [0. 0. 0. 0.]]\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "action_space_size = env.action_space.n\n", 63 | "state_space_size = env.observation_space.n\n", 64 | "\n", 65 | "q_table = np.zeros((state_space_size, action_space_size))\n", 66 | "\n", 67 | "print(q_table)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 9, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "num_episodes = 10000\n", 77 | "max_steps_per_episode = 100\n", 78 | "\n", 79 | "learning_rate = 0.1\n", 80 | "discount_rate = 0.99\n", 81 | "\n", 82 | "exploration_rate = 1\n", 83 | "max_exploration_rate = 1\n", 84 | "min_exploration_rate = 0.01\n", 85 | "\n", 86 | "exploration_decay_rate = 0.001" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 10, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "********** Average reward per thousand episodes **********\n", 99 | "\n", 100 | "1000 : 0.04000000000000003\n", 101 | "2000 : 0.21200000000000016\n", 102 | "3000 : 0.4070000000000003\n", 103 | "4000 : 0.5750000000000004\n", 104 | "5000 : 0.6110000000000004\n", 105 | "6000 : 0.6410000000000005\n", 106 | "7000 : 0.6470000000000005\n", 107 | "8000 : 0.6960000000000005\n", 108 | "9000 : 0.6930000000000005\n", 109 | "10000 : 0.7090000000000005\n", 110 | "\n", 111 | "\n", 112 | "********** Q-table **********\n", 113 | "\n", 114 | "[[0.58024171 0.54229881 0.53260582 0.53139742]\n", 115 | " [0.3371445 0.26884291 0.27842931 0.52097642]\n", 116 | " [0.41808381 0.43173947 0.39995505 0.48437107]\n", 117 | " [0.36820324 0.30313074 0.25543035 0.46334875]\n", 118 | " [0.5999856 0.39049707 0.37160209 0.42678134]\n", 119 | " [0. 0. 0. 0. ]\n", 120 | " [0.23352799 0.11040702 0.49197896 0.09404461]\n", 121 | " [0. 0. 0. 0. ]\n", 122 | " [0.37630744 0.39149089 0.37812894 0.65489582]\n", 123 | " [0.46439661 0.72593727 0.46364658 0.38923149]\n", 124 | " [0.7203016 0.39882864 0.47009292 0.26180184]\n", 125 | " [0. 0. 0. 0. ]\n", 126 | " [0. 0. 0. 0. 
]\n", 127 | " [0.42688433 0.66110586 0.79314389 0.48597383]\n", 128 | " [0.78862736 0.87495019 0.77836452 0.73360119]\n", 129 | " [0. 0. 0. 0. ]]\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "rewards_all_episodes = []\n", 135 | "\n", 136 | "# Q-Learning algorithm\n", 137 | "for episode in range(num_episodes):\n", 138 | " state = env.reset()\n", 139 | " \n", 140 | " done = False\n", 141 | " rewards_current_episode = 0\n", 142 | " \n", 143 | " for step in range(max_steps_per_episode):\n", 144 | " \n", 145 | " # Exploration -exploitation trade-off\n", 146 | " exploration_rate_threshold = random.uniform(0,1)\n", 147 | " if exploration_rate_threshold > exploration_rate: \n", 148 | " action = np.argmax(q_table[state,:])\n", 149 | " else:\n", 150 | " action = env.action_space.sample()\n", 151 | " \n", 152 | " new_state, reward, done, info = env.step(action)\n", 153 | " \n", 154 | " # Update Q-table for Q(s,a)\n", 155 | " q_table[state, action] = (1 - learning_rate) * q_table[state, action] + \\\n", 156 | " learning_rate * (reward + discount_rate * np.max(q_table[new_state,:]))\n", 157 | " \n", 158 | " state = new_state\n", 159 | " rewards_current_episode += reward\n", 160 | " \n", 161 | " if done == True: \n", 162 | " break\n", 163 | " \n", 164 | " # Exploration rate decay\n", 165 | " exploration_rate = min_exploration_rate + \\\n", 166 | " (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)\n", 167 | " \n", 168 | " rewards_all_episodes.append(rewards_current_episode)\n", 169 | " \n", 170 | "# Calculate and print the average reward per thousand episodes\n", 171 | "rewards_per_thousand_episodes = np.split(np.array(rewards_all_episodes), num_episodes / 1000)\n", 172 | "count = 1000\n", 173 | "print(\"********** Average reward per thousand episodes **********\\n\")\n", 174 | "\n", 175 | "for r in rewards_per_thousand_episodes:\n", 176 | " print(count, \": \", str(sum(r / 1000)))\n", 177 | " count += 1000\n", 178 | " \n", 179 | "# Print updated Q-table\n", 180 | "print(\"\\n\\n********** Q-table **********\\n\")\n", 181 | "print(q_table)\n", 182 | " " 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "Now, we let the agent play the game. 
" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 11, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | " (Down)\n", 202 | "SFFF\n", 203 | "FHFH\n", 204 | "FFFH\n", 205 | "HFF\u001b[41mG\u001b[0m\n", 206 | "*****You reached your goal!*****\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "for episode in range(3):\n", 212 | " state = env.reset()\n", 213 | " done = False\n", 214 | " print(\"***** EPISODE \", episode + 1, \" *****\\n\\n\\n\")\n", 215 | " time.sleep(1)\n", 216 | " \n", 217 | " for step in range(max_steps_per_episode):\n", 218 | " clear_output(wait = True)\n", 219 | " env.render()\n", 220 | " time.sleep(0.3)\n", 221 | " \n", 222 | " action = np.argmax(q_table[state,:])\n", 223 | " new_state, reward, done, info = env.step(action)\n", 224 | " \n", 225 | " if done: \n", 226 | " clear_output(wait = True)\n", 227 | " env.render()\n", 228 | " if reward == 1: \n", 229 | " print(\"*****You reached your goal!*****\")\n", 230 | " time.sleep(3)\n", 231 | " else:\n", 232 | " print(\"*****You fall through a hole!*****\")\n", 233 | " time.sleep(3)\n", 234 | " clear_output(wait = True)\n", 235 | " break\n", 236 | " \n", 237 | " state = new_state\n", 238 | " \n", 239 | "env.close()" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.7.4" 274 | }, 275 | "toc": { 276 | "base_numbering": 1, 277 | "nav_menu": {}, 278 | "number_sections": true, 279 | "sideBar": true, 280 | "skip_h1_title": false, 281 | "title_cell": "Table of Contents", 282 | "title_sidebar": "Contents", 283 | "toc_cell": false, 284 | "toc_position": {}, 285 | "toc_section_display": true, 286 | "toc_window_display": false 287 | }, 288 | "varInspector": { 289 | "cols": { 290 | "lenName": 16, 291 | "lenType": 16, 292 | "lenVar": 40 293 | }, 294 | "kernels_config": { 295 | "python": { 296 | "delete_cmd_postfix": "", 297 | "delete_cmd_prefix": "del ", 298 | "library": "var_list.py", 299 | "varRefreshCmd": "print(var_dic_list())" 300 | }, 301 | "r": { 302 | "delete_cmd_postfix": ") ", 303 | "delete_cmd_prefix": "rm(", 304 | "library": "var_list.r", 305 | "varRefreshCmd": "cat(var_dic_list()) " 306 | } 307 | }, 308 | "types_to_exclude": [ 309 | "module", 310 | "function", 311 | "builtin_function_or_method", 312 | "instance", 313 | "_Feature" 314 | ], 315 | "window_display": false 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 2 320 | } 321 | -------------------------------------------------------------------------------- /gym_basic/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='basic-v0', 5 | entry_point='gym_basic.envs:BasicEnv', 6 | ) 7 | 8 | register( 9 | id='basic-v2', 10 | entry_point='gym_basic.envs:BasicEnv2', 11 | ) 
-------------------------------------------------------------------------------- /gym_basic/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_basic.envs.basic_env import BasicEnv 2 | from gym_basic.envs.basic_env_2 import BasicEnv2 -------------------------------------------------------------------------------- /gym_basic/envs/basic_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error, spaces, utils 3 | from gym.utils import seeding 4 | 5 | class BasicEnv(gym.Env): 6 | metadata = {'render.modes': ['human']} 7 | 8 | def __init__(self): 9 | # There are five actions; action 2 yields a reward of 1, every other action yields -1. 10 | self.action_space = spaces.Discrete(5) 11 | self.observation_space = spaces.Discrete(2) 12 | 13 | def step(self, action): 14 | 15 | # if we took an action, we were in state 1 16 | state = 1 17 | 18 | if action == 2: 19 | reward = 1 20 | else: 21 | reward = -1 22 | 23 | # regardless of the action, game is done after a single step 24 | done = True 25 | 26 | info = {} 27 | 28 | return state, reward, done, info 29 | 30 | def reset(self): 31 | state = 0 32 | return state 33 | 34 | def render(self, mode='human'): 35 | pass 36 | 37 | def close(self): 38 | pass 39 | -------------------------------------------------------------------------------- /gym_basic/envs/basic_env_2.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error, spaces, utils 3 | from gym.utils import seeding 4 | import numpy as np 5 | 6 | # same as BasicEnv, with one difference: the reward for each action is a normal random variable 7 | # purpose is to see if we can use libraries 8 | 9 | class BasicEnv2(gym.Env): 10 | metadata = {'render.modes': ['human']} 11 | 12 | def __init__(self): 13 | # There are five actions; each action's reward is drawn from a normal distribution (see step). 14 | self.action_space = spaces.Discrete(5) 15 | self.observation_space = spaces.Discrete(2) 16 | 17 | def step(self, action): 18 | 19 | # if we took an action, we were in state 1 20 | state = 1 21 | 22 | reward = np.random.normal(loc = action, scale = action) 23 | 24 | # regardless of the action, game is done after a single step 25 | done = True 26 | 27 | info = {} 28 | 29 | return state, reward, done, info 30 | 31 | def reset(self): 32 | state = 0 33 | return state 34 | 35 | def render(self, mode='human'): 36 | pass 37 | 38 | def close(self): 39 | pass 40 | -------------------------------------------------------------------------------- /gym_basic_env_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Infrastructure" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Purpose of this project: set up a custom Gym environment from scratch, with different versions; train a model with simple Q-learning; and see how to make it compatible with Stable Baselines." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import gym\n", 25 | "import random" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "Creating environments." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 28, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "class BasicEnv(gym.Env):\n", 42 | "\n", 43 | " def __init__(self):\n", 44 | " # There are two actions, first will get reward of 1, second reward of -1. \n", 45 | "# self.action_space = gym.spaces.Discrete(5)\n", 46 | " self.action_space = 1\n", 47 | " self.observation_space = gym.spaces.Discrete(2)\n", 48 | "\n", 49 | " def step(self, action):\n", 50 | "\n", 51 | " # if we took an action, we were in state 1\n", 52 | " state = 1\n", 53 | " \n", 54 | " if action == 2:\n", 55 | " reward = 1\n", 56 | " else:\n", 57 | " reward = -1\n", 58 | " \n", 59 | " # regardless of the action, game is done after a single step\n", 60 | " done = True\n", 61 | "\n", 62 | " info = {}\n", 63 | "\n", 64 | " return state, reward, done, info\n", 65 | "\n", 66 | " def reset(self):\n", 67 | " state = 0\n", 68 | " return state" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# same as BasicEnv, with one difference: the reward for each action is a normal variable\n", 78 | "# purpose is to see if we can use libraries\n", 79 | "\n", 80 | "class BasicEnv2(gym.Env):\n", 81 | " metadata = {'render.modes': ['human']}\n", 82 | "\n", 83 | " def __init__(self):\n", 84 | " # There are two actions, first will get reward of 1, second reward of -1. \n", 85 | " self.action_space = gym.spaces.Discrete(5)\n", 86 | " self.observation_space = gym.spaces.Discrete(2)\n", 87 | "\n", 88 | " def step(self, action):\n", 89 | "\n", 90 | " # if we took an action, we were in state 1\n", 91 | " state = 1\n", 92 | " \n", 93 | " reward = np.random.normal(loc = action, scale = action)\n", 94 | " \n", 95 | " # regardless of the action, game is done after a single step\n", 96 | " done = True\n", 97 | "\n", 98 | " info = {}\n", 99 | "\n", 100 | " return state, reward, done, info\n", 101 | "\n", 102 | " def reset(self):\n", 103 | " state = 0\n", 104 | " return state\n", 105 | " \n", 106 | " def render(self, mode='human'):\n", 107 | " pass\n", 108 | "\n", 109 | " def close(self):\n", 110 | " pass" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 4, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# method 1 - build from gym package\n", 120 | "env = gym.make(\"gym_basic:basic-v2\")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# method 2 - use local test class\n", 130 | "env = BasicEnv2()" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "# Q-Learning" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "Source: https://deeplizard.com/learn/video/HGeI30uATws\n", 145 | "\n", 146 | "I copied the code and tested it with the custom environment instead of the built-in Frozen Lake environment. " 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 18, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "env = BasicEnv()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 19, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "[[0. 0. 0. 0. 0.]\n", 168 | " [0. 0. 0. 0. 
0.]]\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "action_space_size = env.action_space.n\n", 174 | "state_space_size = env.observation_space.n\n", 175 | "\n", 176 | "q_table = np.zeros((state_space_size, action_space_size))\n", 177 | "\n", 178 | "print(q_table)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 20, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "num_episodes = 1000\n", 188 | "max_steps_per_episode = 10 # but it won't go higher than 1\n", 189 | "\n", 190 | "learning_rate = 0.1\n", 191 | "discount_rate = 0.99\n", 192 | "\n", 193 | "exploration_rate = 1\n", 194 | "max_exploration_rate = 1\n", 195 | "min_exploration_rate = 0.01\n", 196 | "\n", 197 | "exploration_decay_rate = 0.01 # if we decrease it, learning will be slower" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 21, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "********** Average reward per hundred episodes **********\n", 210 | "\n", 211 | "100 : -0.060000000000000005\n", 212 | "200 : 0.5800000000000003\n", 213 | "300 : 0.7200000000000004\n", 214 | "400 : 0.9800000000000006\n", 215 | "500 : 0.9600000000000006\n", 216 | "600 : 0.9800000000000006\n", 217 | "700 : 0.9800000000000006\n", 218 | "800 : 0.9800000000000006\n", 219 | "900 : 0.9800000000000006\n", 220 | "1000 : 1.0000000000000007\n", 221 | "\n", 222 | "\n", 223 | "********** Q-table **********\n", 224 | "\n", 225 | "[[-0.94185026 -0.89058101 1. -0.92023356 -0.91137062]\n", 226 | " [ 0. 0. 0. 0. 0. ]]\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "rewards_all_episodes = []\n", 232 | "\n", 233 | "# Q-Learning algorithm\n", 234 | "for episode in range(num_episodes):\n", 235 | " state = env.reset()\n", 236 | " \n", 237 | " done = False\n", 238 | " rewards_current_episode = 0\n", 239 | " \n", 240 | " for step in range(max_steps_per_episode):\n", 241 | " \n", 242 | " # Exploration-exploitation trade-off\n", 243 | " exploration_rate_threshold = random.uniform(0,1)\n", 244 | " if exploration_rate_threshold > exploration_rate: \n", 245 | " action = np.argmax(q_table[state,:])\n", 246 | " else:\n", 247 | " action = env.action_space.sample()\n", 248 | " \n", 249 | " new_state, reward, done, info = env.step(action)\n", 250 | " \n", 251 | " # Update Q-table for Q(s,a)\n", 252 | " q_table[state, action] = (1 - learning_rate) * q_table[state, action] + \\\n", 253 | " learning_rate * (reward + discount_rate * np.max(q_table[new_state,:]))\n", 254 | " \n", 255 | " state = new_state\n", 256 | " rewards_current_episode += reward\n", 257 | " \n", 258 | " if done == True: \n", 259 | " break\n", 260 | " \n", 261 | " # Exploration rate decay\n", 262 | " exploration_rate = min_exploration_rate + \\\n", 263 | " (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)\n", 264 | " \n", 265 | " rewards_all_episodes.append(rewards_current_episode)\n", 266 | " \n", 267 | "# Calculate and print the average reward per hundred episodes\n", 268 | "rewards_per_hundred_episodes = np.split(np.array(rewards_all_episodes), num_episodes / 100)\n", 269 | "count = 100\n", 270 | "print(\"********** Average reward per hundred episodes **********\\n\")\n", 271 | "\n", 272 | "for r in rewards_per_hundred_episodes:\n", 273 | " print(count, \": \", str(sum(r / 100)))\n", 274 | " count += 100\n", 275 | " \n", 276 | "# Print updated Q-table\n", 277 | "print(\"\\n\\n********** Q-table **********\\n\")\n", 278 | 
"print(q_table)\n", 279 | " " 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "# Verify Environment with Stable Baselines" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 22, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "from stable_baselines import PPO2\n", 296 | "from stable_baselines.common.policies import MlpPolicy\n", 297 | "from stable_baselines.common.vec_env import DummyVecEnv" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 25, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "2\n", 310 | "2\n", 311 | "2\n", 312 | "2\n", 313 | "2\n", 314 | "2\n", 315 | "2\n", 316 | "2\n", 317 | "2\n", 318 | "2\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "env = gym.make('gym_basic:basic-v0')\n", 324 | "\n", 325 | "# Optional: PPO2 requires a vectorized environment to run\n", 326 | "# the env is now wrapped automatically when passing it to the constructor\n", 327 | "# env = DummyVecEnv([lambda: env])\n", 328 | "\n", 329 | "model = PPO2(MlpPolicy, env, verbose=False)\n", 330 | "model.learn(total_timesteps=10000)\n", 331 | "\n", 332 | "obs = env.reset()\n", 333 | "for i in range(10):\n", 334 | " action, _states = model.predict(obs)\n", 335 | " print(action)\n", 336 | " obs, rewards, dones, info = env.step(action)\n", 337 | " env.render()" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 10, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "from stable_baselines.common.env_checker import check_env" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 30, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# env = gym.make('gym_basic:basic-v0')\n", 356 | "env = BasicEnv()" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 31, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "ename": "AssertionError", 366 | "evalue": "The action space must inherit from gym.spaces cf https://github.com/openai/gym/blob/master/gym/spaces/", 367 | "output_type": "error", 368 | "traceback": [ 369 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 370 | "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", 371 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcheck_env\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 372 | "\u001b[0;32m~/opt/anaconda3/lib/python3.7/site-packages/stable_baselines/common/env_checker.py\u001b[0m in \u001b[0;36mcheck_env\u001b[0;34m(env, warn, skip_render_check)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;31m# ============= Check the spaces (observation and action) ================\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0m_check_spaces\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;31m# Define aliases for convenience\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 373 | 
"\u001b[0;32m~/opt/anaconda3/lib/python3.7/site-packages/stable_baselines/common/env_checker.py\u001b[0m in \u001b[0;36m_check_spaces\u001b[0;34m(env)\u001b[0m\n\u001b[1;32m 133\u001b[0m assert isinstance(env.observation_space,\n\u001b[1;32m 134\u001b[0m spaces.Space), \"The observation space must inherit from gym.spaces\" + gym_spaces\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspaces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSpace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"The action space must inherit from gym.spaces\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mgym_spaces\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 374 | "\u001b[0;31mAssertionError\u001b[0m: The action space must inherit from gym.spaces cf https://github.com/openai/gym/blob/master/gym/spaces/" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "check_env(env)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Python 3", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.7.4" 407 | }, 408 | "toc": { 409 | "base_numbering": 1, 410 | "nav_menu": {}, 411 | "number_sections": true, 412 | "sideBar": true, 413 | "skip_h1_title": false, 414 | "title_cell": "Table of Contents", 415 | "title_sidebar": "Contents", 416 | "toc_cell": false, 417 | "toc_position": {}, 418 | "toc_section_display": true, 419 | "toc_window_display": false 420 | }, 421 | "varInspector": { 422 | "cols": { 423 | "lenName": 16, 424 | "lenType": 16, 425 | "lenVar": 40 426 | }, 427 | "kernels_config": { 428 | "python": { 429 | "delete_cmd_postfix": "", 430 | "delete_cmd_prefix": "del ", 431 | "library": "var_list.py", 432 | "varRefreshCmd": "print(var_dic_list())" 433 | }, 434 | "r": { 435 | "delete_cmd_postfix": ") ", 436 | "delete_cmd_prefix": "rm(", 437 | "library": "var_list.r", 438 | "varRefreshCmd": "cat(var_dic_list()) " 439 | } 440 | }, 441 | "types_to_exclude": [ 442 | "module", 443 | "function", 444 | "builtin_function_or_method", 445 | "instance", 446 | "_Feature" 447 | ], 448 | "window_display": false 449 | } 450 | }, 451 | "nbformat": 4, 452 | "nbformat_minor": 2 453 | } 454 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='gym_basic', 4 | version='0.0.1', 5 | install_requires=['gym'] 6 | ) --------------------------------------------------------------------------------