├── MANIFEST.in ├── setup.cfg ├── .travis.yml ├── gym_gridworlds ├── envs │ ├── __init__.py │ ├── cliff_env.py │ ├── windy_gridworld_env.py │ └── gridworld_env.py └── __init__.py ├── tests ├── test_cliff.py ├── test_windy_gridworld.py └── test_gridworld.py ├── setup.py ├── LICENSE.txt ├── README.rst └── .gitignore /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.rst 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: '3.6' 3 | install: 'python setup.py install' 4 | script: 'python setup.py test' 5 | -------------------------------------------------------------------------------- /gym_gridworlds/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_gridworlds.envs.gridworld_env import GridworldEnv # noqa 2 | from gym_gridworlds.envs.windy_gridworld_env import WindyGridworldEnv # noqa 3 | from gym_gridworlds.envs.cliff_env import CliffEnv # noqa 4 | -------------------------------------------------------------------------------- /gym_gridworlds/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='Gridworld-v0', 5 | entry_point='gym_gridworlds.envs:GridworldEnv', 6 | ) 7 | register( 8 | id='WindyGridworld-v0', 9 | entry_point='gym_gridworlds.envs:WindyGridworldEnv', 10 | ) 11 | register( 12 | id='Cliff-v0', 13 | entry_point='gym_gridworlds.envs:CliffEnv', 14 | ) 15 | -------------------------------------------------------------------------------- /tests/test_cliff.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def cliff(): 6 | import gym 7 | import gym_gridworlds # noqa 8 | return gym.make('Cliff-v0') 9 | 10 | 11 | def test_cliff_action_space(cliff): 12 | assert cliff.action_space.n == 4 13 | 14 | 15 | def test_cliff_step_on_cliff(cliff): 16 | S = cliff.reset() 17 | # check that in start state 18 | assert S == (3, 0) 19 | # move right 20 | S, R, _, _ = cliff.step(1) 21 | assert R == -100 22 | # check movement to start state 23 | assert S == (3, 0) 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | with open('README.rst') as f: 5 | long_description = f.read() 6 | 7 | setup( 8 | name='gym_gridworlds', 9 | version='0.0.2', 10 | description='Gridworlds environments for OpenAI gym.', 11 | long_description=long_description, 12 | author='Ondřej Podsztavek', 13 | author_email='ondrej.podsztavek@gmail.com', 14 | license='MIT License', 15 | url='https://github.com/podondra/gym-gridworlds', 16 | packages=find_packages(), 17 | install_requires=['gym', 'numpy'], 18 | setup_requires=['pytest-runner'], 19 | tests_require=['pytest'], 20 | ) 21 | -------------------------------------------------------------------------------- /tests/test_windy_gridworld.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def windy_gridworld(): 6 | import gym 7 | import gym_gridworlds # noqa 8 | return gym.make('WindyGridworld-v0') 9 | 10 | 11 | def test_windy_gridworld_action_space(windy_gridworld): 12 | assert windy_gridworld.action_space.n == 4 13 | 14 | 15 | def test_windy_gridworld_reset(windy_gridworld): 16 | # reset windy_gridworld 17 | S = windy_gridworld.reset() 18 | # check that in start
state 19 | assert S == (3, 0) 20 | # move right 21 | S, _, _, _ = windy_gridworld.step(1) 22 | # check movement to right state 23 | assert S == (3, 1) 24 | # reset again 25 | S = windy_gridworld.reset() 26 | assert S == (3, 0) 27 | -------------------------------------------------------------------------------- /tests/test_gridworld.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy 3 | 4 | 5 | @pytest.fixture 6 | def gridworld(): 7 | import gym 8 | import gym_gridworlds # noqa 9 | return gym.make('Gridworld-v0') 10 | 11 | 12 | def test_gridworld_action_space(gridworld): 13 | assert gridworld.action_space.n == 4 14 | 15 | 16 | def test_gridworld_observation_space(gridworld): 17 | assert gridworld.observation_space.n == 15 18 | 19 | 20 | def test_gridworld_transition_probabilities(gridworld): 21 | assert gridworld.P.shape == (4, 15, 15) 22 | # ensure that probabilities sum to 1 23 | assert numpy.all(gridworld.P.sum(axis=-1) == 1) 24 | 25 | 26 | def test_gridworld_rewards(gridworld): 27 | assert gridworld.R.shape == (4, 15) 28 | # assert that there is 0 reward when moving in terminal state 29 | assert numpy.all(gridworld.R[:, 0] == 0) 30 | # all other rewards are -1 31 | assert numpy.all(gridworld.R[:, 1:] == -1) 32 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ondřej Podsztavek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /gym_gridworlds/envs/cliff_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | 4 | 5 | class CliffEnv(gym.Env): # cliff walking on a 4 x 12 grid; states are (row, col) tuples 6 | def __init__(self): 7 | self.height = 4 8 | self.width = 12 9 | self.action_space = spaces.Discrete(4) 10 | self.observation_space = spaces.Tuple(( 11 | spaces.Discrete(self.height), 12 | spaces.Discrete(self.width) 13 | )) 14 | self.moves = { 15 | 0: (-1, 0), # up 16 | 1: (0, 1), # right 17 | 2: (1, 0), # down 18 | 3: (0, -1), # left 19 | } 20 | 21 | # begin in start state 22 | self.reset() 23 | 24 | def step(self, action): # returns (state, reward, done, info) 25 | x, y = self.moves[action] 26 | self.S = self.S[0] + x, self.S[1] + y 27 | 28 | self.S = max(0, self.S[0]), max(0, self.S[1]) # clip moves that would leave the grid 29 | self.S = (min(self.S[0], self.height - 1), 30 | min(self.S[1], self.width - 1)) 31 | 32 | if self.S == (self.height - 1, self.width - 1): # goal: bottom-right corner 33 | return self.S, -1, True, {} 34 | elif self.S[1] != 0 and self.S[0] == self.height - 1: 35 | # the cliff: bottom row except start and goal; the agent is sent back to the start 36 | return self.reset(), -100, False, {} 37 | return self.S, -1, False, {} 38 | 39 | def reset(self): # start state: bottom-left corner 40 | self.S = (3, 0) 41 | return self.S 42 | -------------------------------------------------------------------------------- 
/gym_gridworlds/envs/windy_gridworld_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | 4 | 5 | class WindyGridworldEnv(gym.Env): # windy gridworld on a 7 x 10 grid; states are (row, col) tuples 6 | def __init__(self): 7 | self.height = 7 8 | self.width = 10 9 | self.action_space = spaces.Discrete(4) 10 | self.observation_space = spaces.Tuple(( 11 | spaces.Discrete(self.height), 12 | spaces.Discrete(self.width) 13 | )) 14 | self.moves = { 15 | 0: (-1, 0), # up 16 | 1: (0, 1), # right 17 | 2: (1, 0), # down 18 | 3: (0, -1), # left 19 | } 20 | 21 | # begin in start state 22 | self.reset() 23 | 24 | def step(self, action): # returns (state, reward, done, info) 25 | if self.S[1] in (3, 4, 5, 8): # wind of strength 1, taken from the column before the move 26 | self.S = self.S[0] - 1, self.S[1] 27 | elif self.S[1] in (6, 7): # wind of strength 2 28 | self.S = self.S[0] - 2, self.S[1] 29 | 30 | x, y = self.moves[action] 31 | self.S = self.S[0] + x, self.S[1] + y 32 | 33 | self.S = max(0, self.S[0]), max(0, self.S[1]) # clip once, after wind and move are combined 34 | self.S = (min(self.S[0], self.height - 1), 35 | min(self.S[1], self.width - 1)) 36 | 37 | if self.S == (3, 7): # goal state 38 | return self.S, -1, True, {} 39 | return self.S, -1, False, {} 40 | 41 | def reset(self): 42 | self.S = (3, 0) 43 | return self.S 44 | -------------------------------------------------------------------------------- /gym_gridworlds/envs/gridworld_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy 4 | 5 | 6 | class GridworldEnv(gym.Env): # model-only 4 x 4 gridworld: exposes P and R; no step/reset defined in this file 7 | reward_range = (-1, 0) 8 | action_space = spaces.Discrete(4) 9 | # although there are 2 terminal squares in the grid 10 | # they are considered as 1 state 11 | # therefore observation is between 0 and 14 12 | observation_space = spaces.Discrete(15) 13 | 14 | def __init__(self): 15 | gridworld = numpy.arange( # cell -> state id, numbered row by row 16 | self.observation_space.n + 1 17 | ).reshape((4, 4)) 18 | gridworld[-1, -1] = 0 # bottom-right terminal square shares state 0 with the top-left one 19 | # state transition matrix P[action, state, next_state] 20 | self.P = numpy.zeros((self.action_space.n, 21 | self.observation_space.n, 22 |
self.observation_space.n)) 23 | # any action taken in terminal state has no effect 24 | self.P[:, 0, 0] = 1 25 | 26 | for s in gridworld.flat[1:-1]: # non-terminal states 1..14 27 | row, col = numpy.argwhere(gridworld == s)[0] 28 | for a, d in zip( 29 | range(self.action_space.n), 30 | [(-1, 0), (0, 1), (1, 0), (0, -1)] # up, right, down, left 31 | ): 32 | next_row = max(0, min(row + d[0], 3)) # off-grid moves leave the state unchanged 33 | next_col = max(0, min(col + d[1], 3)) 34 | s_prime = gridworld[next_row, next_col] 35 | self.P[a, s, s_prime] = 1 36 | 37 | self.R = numpy.full((self.action_space.n, # expected reward R[action, state] 38 | self.observation_space.n), -1) 39 | self.R[:, 0] = 0 # zero reward for actions taken in the terminal state 40 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | OpenAI gym Gridworlds 2 | ===================== 3 | 4 | Implementation of three gridworld environments 5 | from the book `Reinforcement Learning: An Introduction 6 | <http://incompleteideas.net/book/the-book-2nd.html>`_ 7 | compatible with `OpenAI gym <https://github.com/openai/gym>`_. 8 | 9 | Usage 10 | ----- 11 | 12 | .. code:: python 13 | 14 | >>> import gym 15 | >>> import gym_gridworlds 16 | >>> env = gym.make('Gridworld-v0') # substitute the environment's name 17 | 18 | ``Gridworld-v0`` 19 | ---------------- 20 | 21 | Gridworld is the simple 4 x 4 gridworld from example 4.1 in the book_. 22 | There are four actions in each state (0: up, 1: right, 2: down, 3: left) 23 | which deterministically cause the corresponding state transitions, 24 | but actions that would take the agent off the grid leave the state unchanged. 25 | The reward is -1 for all transitions until the terminal state is reached. 26 | The terminal state is in the top-left and bottom-right corners. 27 | 28 | ``WindyGridworld-v0`` 29 | --------------------- 30 | 31 | Windy gridworld is from example 6.5 in the book_. 32 | Windy gridworld is a standard gridworld as described above, 33 | but there is a crosswind running upward through the middle of the grid.
34 | The actions are standard, but in the middle region the resultant next states are 35 | shifted upward by a wind whose strength varies from column to column. 36 | 37 | .. _book: http://incompleteideas.net/book/the-book-2nd.html 38 | 39 | ``Cliff-v0`` 40 | ------------ 41 | 42 | Cliff walking is the gridworld from example 6.6 in the book_. 43 | Again, the reward is -1 on all transitions except those into the region 44 | marked as the cliff. 45 | Stepping into this region incurs a reward of -100 46 | and sends the agent instantly back to the start. 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | --------------------------------------------------------------------------------