├── CONTRIBUTING.md ├── LICENSE ├── README.md └── srl ├── __init__.py ├── all_tests.py ├── context.py ├── context_test.py ├── grid.py ├── grid_test.py ├── movement.py ├── policy_gradient.py ├── policy_gradient_test.py ├── simulation.py ├── simulation_test.py ├── world.py └── world_test.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to any Google project must be accompanied by a Contributor License 9 | Agreement. This is necessary because you own the copyright to your changes, even 10 | after your contribution becomes part of this project. So this agreement simply 11 | gives us permission to use and redistribute your contributions as part of the 12 | project. Head over to <https://cla.developers.google.com/> to see your current 13 | agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more 23 | information on using pull requests. 24 | 25 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/ 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below).
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple Reinforcement Learning 2 | 3 | This demonstrates reinforcement learning. Specifically, it 4 | uses [Q-learning](https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node65.html) 5 | to move a player (`@`) around a fixed maze and avoid traps (`^`) while getting 6 | treasure (`$`) as fast as possible. 7 | 8 | Add the directory containing srl to PYTHONPATH. 
Then there are three 9 | ways to run the code: 10 | 11 | 1. `srl/grid.py --interactive [--random]`: Use the arrow keys to walk 12 | around the maze. The episode ends when you reach a trap or the 13 | treasure. Press space to restart or Q to quit. No learning happens 14 | in this mode. Use `--random` to generate a random maze instead of 15 | the fixed maze. 16 | 17 | 1. `srl/grid.py --q [--random]`: An ε-greedy Q-learner 18 | repeatedly runs the maze. The parameters are not tuned to learn 19 | quickly. Over the course of several minutes the player first learns 20 | to avoid traps, then to reach the treasure, and eventually to reach the 21 | treasure in the minimum number of steps. 22 | 23 | Learning is not saved between runs. 24 | 25 | The Q-table is keyed only by the player's position and is not reset 26 | between episodes, so it cannot generalize to new random maps; expect very poor performance with `--random`. 27 | 28 | 1. `srl/all_tests.py`: Run the unit tests. 29 | 30 | Here are some ideas for extending grid.py, in roughly increasing order of difficulty. 31 | Some early steps may be useful for later ones. 32 | 33 | * Watch different random mazes and see how features like long hallways 34 | affect learning. 35 | * Change the learning rate α and future reward discount γ to try to 36 | improve the effectiveness of the learner. 37 | * Display summaries of results. For example, graph the distribution of rewards 38 | (Simulation.score) over repeated runs. It may be useful to separate the game 39 | simulation loop and the display so that you can simulate faster. 40 | * Implement TD(λ) 41 | and [eligibility traces](https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node75.html). How do features like long hallways affect learning now? 42 | * Q-learning is an "off-policy" learning algorithm, which means the policy 43 | controlling the player and the policy being learned can be different. Adapt 44 | the HumanPlayer to permit a learner to learn by observing a human play. 45 | * Save policies to disk and load them later. This will let you checkpoint and 46 | restart learning. 47 | * Generate a new maze each episode. The state space will be too large 48 | for a QTable to be useful, so implement a value function 49 | approximator such as a neural network. The QTable memorizes the 50 | fixed map; with multiple maps you will need to feed the maze as 51 | input to the neural network so it can "see" the map instead of 52 | memorizing it. 53 | * Connect the learner to an actual roguelike such 54 | as [NetHack](http://www.nethack.org/) to speed-run for dungeon depth. 55 | * Change the problem from one with discrete states (a grid) and actions (up, 56 | down, left, right) to continuous states (the player is at fine-grained x-y 57 | coordinates with a certain heading and velocity) and actions (acceleration and 58 | turning). How will you relate states to each other in a continuous space? 59 | -------------------------------------------------------------------------------- /srl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/simple-reinforcement-learning/9bdac29427cd5c556d7ea7531b807645f043aae3/srl/__init__.py -------------------------------------------------------------------------------- /srl/all_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017 Google Inc.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import unittest 18 | 19 | from srl import context_test 20 | from srl import grid_test 21 | from srl import policy_gradient_test 22 | from srl import simulation_test 23 | from srl import world_test 24 | 25 | 26 | def load_tests(loader, unused_tests, unused_pattern): 27 | # pylint: disable=unused-argument 28 | test_modules = [ 29 | context_test, 30 | grid_test, 31 | policy_gradient_test, 32 | simulation_test, 33 | world_test, 34 | ] 35 | return unittest.TestSuite(map(loader.loadTestsFromModule, test_modules)) 36 | 37 | 38 | if __name__ == '__main__': 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /srl/context.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import collections 16 | import curses 17 | 18 | 19 | class StubFailure(Exception): 20 | pass 21 | 22 | 23 | class StubWindow(object): 24 | '''A no-op implementation of the game display.''' 25 | def addstr(self, y, x, s): 26 | pass 27 | 28 | def erase(self): 29 | pass 30 | 31 | def getch(self): 32 | raise StubFailure('"getch" not implemented; use a mock') 33 | 34 | def move(self, y, x): 35 | pass 36 | 37 | def refresh(self): 38 | pass 39 | 40 | 41 | class StubContext(object): 42 | def __init__(self): 43 | self.run_loop = RunLoop() 44 | self.window = StubWindow() 45 | 46 | def start(self): 47 | self.run_loop.start() 48 | 49 | 50 | class Context(object): 51 | '''Provides the shared curses window and a run loop to other objects. 52 | 53 | Properties: 54 | run_loop: See RunLoop. 55 | window: A curses window to display text. 
56 | ''' 57 | 58 | def __init__(self): 59 | self.run_loop = RunLoop() 60 | self.window = None 61 | 62 | def start(self): 63 | '''Initializes the context and starts the run loop.''' 64 | curses.wrapper(self._capture_window) 65 | 66 | def _capture_window(self, window): 67 | self.window = window 68 | self.run_loop.start() 69 | 70 | 71 | class RunLoop(object): 72 | '''A run loop invokes its tasks until there are none left.''' 73 | 74 | def __init__(self): 75 | self._tasks = collections.deque() 76 | self._quit = object() 77 | 78 | def start(self): 79 | while len(self._tasks): 80 | task = self._tasks.popleft() 81 | if task is self._quit: 82 | return 83 | task() 84 | 85 | def post_task(self, task, repeat=False): 86 | if repeat: 87 | def repeater(): 88 | task() 89 | self.post_task(repeater) 90 | self.post_task(repeater) 91 | else: 92 | self._tasks.append(task) 93 | 94 | def post_quit(self): 95 | self._tasks.append(self._quit) 96 | -------------------------------------------------------------------------------- /srl/context_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from srl import context 18 | 19 | 20 | class TestRunLoop(unittest.TestCase): 21 | def test_empty_run_loop_quits(self): 22 | run_loop = context.RunLoop() 23 | run_loop.start() 24 | # Test passes if this does not hang. 
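# An empty task queue means start() has nothing to dequeue, so it returns
# immediately instead of blocking.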
25 | 26 | def test_post_task(self): 27 | run_loop = context.RunLoop() 28 | log = [] 29 | run_loop.post_task(lambda: log.append('a')) 30 | self.assertEqual([], log, 31 | 'post_task should not complete tasks synchronously') 32 | run_loop.start() 33 | self.assertEqual(['a'], log, 'run loop should have run the callback') 34 | 35 | def test_posted_tasks_run_in_order(self): 36 | run_loop = context.RunLoop() 37 | log = [] 38 | run_loop.post_task(lambda: log.append('a')) 39 | run_loop.post_task(lambda: log.append('b')) 40 | run_loop.start() 41 | self.assertEqual(['a', 'b'], log, 42 | 'run loop should run tasks in the order they are posted') 43 | 44 | def test_post_quit(self): 45 | run_loop = context.RunLoop() 46 | log = [] 47 | run_loop.post_task(lambda: log.append('a')) 48 | run_loop.post_task(lambda: run_loop.post_quit()) 49 | run_loop.post_task(lambda: run_loop.post_task(lambda: log.append('c'))) 50 | run_loop.post_task(lambda: log.append('b')) 51 | run_loop.start() 52 | self.assertEqual( 53 | ['a', 'b'], log, 54 | 'run loop should run tasks posted before quit, but not after') 55 | 56 | def test_post_repeat(self): 57 | run_loop = context.RunLoop() 58 | n = 0 59 | def count(): 60 | nonlocal n 61 | n += 1 62 | if n == 3: 63 | run_loop.post_quit() 64 | run_loop.post_task(count, repeat=True) 65 | run_loop.start() 66 | self.assertEqual(3, n, 'the task should have run repetitively') 67 | -------------------------------------------------------------------------------- /srl/grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # TODO: 18 | # - Implement approximate value functions 19 | 20 | import argparse 21 | import collections 22 | import curses 23 | import random 24 | import sys 25 | import tensorflow as tf 26 | import time 27 | 28 | from srl import context 29 | from srl import movement 30 | from srl import policy_gradient 31 | from srl import simulation 32 | from srl import world 33 | 34 | 35 | # There is also an interactive version of the game. These are keycodes 36 | # for interacting with it. 37 | KEY_Q = ord('q') 38 | KEY_ESC = 27 39 | KEY_SPACE = ord(' ') 40 | KEY_UP = 259 41 | KEY_DOWN = 258 42 | KEY_LEFT = 260 43 | KEY_RIGHT = 261 44 | KEY_ACTION_MAP = { 45 | KEY_UP: movement.ACTION_UP, 46 | KEY_DOWN: movement.ACTION_DOWN, 47 | KEY_LEFT: movement.ACTION_LEFT, 48 | KEY_RIGHT: movement.ACTION_RIGHT 49 | } 50 | QUIT_KEYS = set([KEY_Q, KEY_ESC]) 51 | 52 | 53 | class Game(object): 54 | '''A simulation that uses curses.''' 55 | def __init__(self, ctx, generator, driver): 56 | '''Creates a new game in world where driver will interact with the game.''' 57 | self._context = ctx 58 | self._sim = simulation.Simulation(generator) 59 | self._driver = driver 60 | self._wins = 0 61 | self._losses = 0 62 | self._was_in_terminal_state = False 63 | 64 | # The game loop. 
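# Each step paints the window, lets the driver interact with the simulation
# once, and tallies a win or a loss the first time the simulation enters a
# terminal state.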
65 | def step(self): 66 | # Paint 67 | self._draw(self._context.window) 68 | # Get input, etc. 69 | self._driver.interact(self._context, self._sim) 70 | if self._sim.in_terminal_state and not self._was_in_terminal_state: 71 | if self._sim.score < 0: 72 | self._losses += 1 73 | else: 74 | self._wins += 1 75 | self._was_in_terminal_state = self._sim.in_terminal_state 76 | 77 | # Paints the window. 78 | def _draw(self, window): 79 | window.erase() 80 | # Draw the environment 81 | for y, line in enumerate(self._sim.world._lines): 82 | window.addstr(y, 0, line) 83 | # Draw the player 84 | window.addstr(self._sim.y, self._sim.x, '@') 85 | # Draw status 86 | window.addstr(self._sim.world.h, 0, 87 | 'W/L: %d/%d Score: %d' % 88 | (self._wins, self._losses, self._sim.score)) 89 | window.move(self._sim.y, self._sim.x) 90 | # TODO: Add a display so multiple things can contribute to the output. 91 | window.refresh() 92 | 93 | 94 | class Player(object): 95 | '''A Player provides input to the game as a simulation evolves.''' 96 | def interact(self, ctx, sim): 97 | # All players have the same interface 98 | # pylint: disable=unused-argument 99 | pass 100 | 101 | 102 | class HumanPlayer(Player): 103 | '''A game driver that reads input from the keyboard.''' 104 | def __init__(self): 105 | super(HumanPlayer, self).__init__() 106 | self._ch = 0 107 | 108 | def interact(self, ctx, sim): 109 | self._ch = ctx.window.getch() 110 | if self._ch in KEY_ACTION_MAP and not sim.in_terminal_state: 111 | sim.act(KEY_ACTION_MAP[self._ch]) 112 | elif self._ch == KEY_SPACE and sim.in_terminal_state: 113 | sim.reset() 114 | elif self._ch in QUIT_KEYS: 115 | ctx.run_loop.post_quit() 116 | 117 | 118 | class MachinePlayer(Player): 119 | '''A game driver which applies a policy, observed by a learner. 120 | 121 | The learner can adjust the policy. 122 | ''' 123 | 124 | def __init__(self, policy, learner): 125 | super(MachinePlayer, self).__init__() 126 | self._policy = policy 127 | self._learner = learner 128 | 129 | def interact(self, ctx, sim): 130 | super(MachinePlayer, self).interact(ctx, sim) 131 | if sim.in_terminal_state: 132 | sim.reset() 133 | else: 134 | old_state = sim.state 135 | action = self._policy.pick_action(sim.state) 136 | reward = sim.act(action) 137 | self._learner.observe(old_state, action, reward, sim.state) 138 | 139 | 140 | class StubLearner(object): 141 | '''Plugs in as a learner but doesn't update anything.''' 142 | def observe(self, old_state, action, reward, new_state): 143 | pass 144 | 145 | 146 | class RandomPolicy(object): 147 | '''A policy which picks actions at random.''' 148 | def pick_action(self, _): 149 | return random.choice(movement.ALL_ACTIONS) 150 | 151 | 152 | class EpsilonPolicy(object): 153 | '''Pursues policy A, but uses policy B with probability epsilon. 154 | 155 | Be careful when using a learned function for one of these policies; 156 | the epsilon policy needs an off-policy learner. 157 | ''' 158 | def __init__(self, policy_a, policy_b, epsilon): 159 | self._policy_a = policy_a 160 | self._policy_b = policy_b 161 | self._epsilon = epsilon 162 | 163 | def pick_action(self, state): 164 | if random.random() < self._epsilon: 165 | return self._policy_b.pick_action(state) 166 | else: 167 | return self._policy_a.pick_action(state) 168 | 169 | 170 | class QTable(object): 171 | '''An approximation of the Q function based on a look-up table. 
172 | As such it is only appropriate for discrete state-action spaces.''' 173 | def __init__(self, init_reward = 0): 174 | self._table = collections.defaultdict(lambda: init_reward) 175 | 176 | def get(self, state, action): 177 | return self._table[(state, action)] 178 | 179 | def set(self, state, action, value): 180 | self._table[(state, action)] = value 181 | 182 | def best(self, state): 183 | '''Gets the best predicted action and its value for |state|.''' 184 | best_value = -1e20 185 | best_action = None 186 | for action in movement.ALL_ACTIONS: 187 | value = self.get(state, action) 188 | if value > best_value: 189 | best_action, best_value = action, value 190 | return best_action, best_value 191 | 192 | 193 | class GreedyQ(object): 194 | '''A policy which chooses the action with the highest reward estimate.''' 195 | def __init__(self, q): 196 | self._q = q 197 | 198 | def pick_action(self, state): 199 | return self._q.best(state)[0] 200 | 201 | 202 | class QLearner(object): 203 | '''An off-policy learner which updates a QTable.''' 204 | def __init__(self, q, learning_rate, discount_rate): 205 | self._q = q 206 | self._alpha = learning_rate 207 | self._gamma = discount_rate 208 | 209 | def observe(self, old_state, action, reward, new_state): 210 | prev = self._q.get(old_state, action) 211 | self._q.set(old_state, action, prev + self._alpha * ( 212 | reward + self._gamma * self._q.best(new_state)[1] - prev)) 213 | 214 | 215 | def main(): 216 | parser = argparse.ArgumentParser(description='Simple Reinforcement Learning.') 217 | group = parser.add_mutually_exclusive_group(required=True) 218 | group.add_argument('--interactive', action='store_true', 219 | help='use the keyboard arrow keys to play') 220 | group.add_argument('--q', action='store_true', 221 | help='play automatically with Q-learning') 222 | group.add_argument('--pg', action='store_true', 223 | help='play automatically with policy gradients') 224 | parser.add_argument('--random', action='store_true', 225 | help='generate a random map') 226 | 227 | args = parser.parse_args() 228 | 229 | ctx = context.Context() 230 | 231 | if args.random: 232 | generator = world.Generator(25, 15) 233 | else: 234 | generator = world.Static(world.World.parse('''\ 235 | ######## 236 | #..#...# 237 | #.@#.$.# 238 | #.##^^.# 239 | #......# 240 | ######## 241 | ''')) 242 | 243 | if args.interactive: 244 | player = HumanPlayer() 245 | elif args.q: 246 | q = QTable() 247 | learner = QLearner(q, 0.05, 0.1) 248 | policy = EpsilonPolicy(GreedyQ(q), RandomPolicy(), 0.01) 249 | player = MachinePlayer(policy, learner) 250 | elif args.pg: 251 | g = tf.Graph() 252 | s = tf.Session(graph=g) 253 | player = policy_gradient.PolicyGradientPlayer(g, s, generator.size) 254 | with g.as_default(): 255 | init = tf.global_variables_initializer() 256 | s.run(init) 257 | else: 258 | sys.exit(1) 259 | 260 | is_automatic = args.q or args.pg 261 | if is_automatic: 262 | # Slow the game down to make it fun? to watch. 263 | ctx.run_loop.post_task(lambda: time.sleep(0.1), repeat=True) 264 | 265 | game = Game(ctx, generator, player) 266 | ctx.run_loop.post_task(game.step, repeat=True) 267 | 268 | ctx.start() 269 | 270 | 271 | if __name__ == '__main__': 272 | main() 273 | -------------------------------------------------------------------------------- /srl/grid_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from unittest.mock import patch 16 | import unittest 17 | 18 | from srl import context 19 | from srl import movement 20 | from srl import simulation 21 | from srl import world 22 | from srl import grid 23 | 24 | 25 | class TestMachinePlayer(unittest.TestCase): 26 | def test_interact(self): 27 | TEST_ACTION = movement.ACTION_RIGHT 28 | q = grid.QTable(-1) 29 | q.set((0, 0), TEST_ACTION, 1) 30 | 31 | player = grid.MachinePlayer(grid.GreedyQ(q), grid.StubLearner()) 32 | w = world.World.parse('@.') 33 | with patch.object(simulation.Simulation, 'act') as mock_act: 34 | sim = simulation.Simulation(world.Static(w)) 35 | ctx = context.StubContext() 36 | player.interact(ctx, sim) 37 | mock_act.assert_called_once_with(TEST_ACTION) 38 | -------------------------------------------------------------------------------- /srl/movement.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # The player can take four actions: move up, down, left or right. 16 | ACTION_UP = 'u' 17 | ACTION_DOWN = 'd' 18 | ACTION_LEFT = 'l' 19 | ACTION_RIGHT = 'r' 20 | 21 | MOVEMENT = { 22 | ACTION_UP: (0, -1), 23 | ACTION_DOWN: (0, 1), 24 | ACTION_LEFT: (-1, 0), 25 | ACTION_RIGHT: (1, 0) 26 | } 27 | 28 | ALL_ACTIONS = [ACTION_UP, ACTION_RIGHT, ACTION_DOWN, ACTION_LEFT] 29 | ALL_MOTIONS = [MOVEMENT[act] for act in ALL_ACTIONS] 30 | -------------------------------------------------------------------------------- /srl/policy_gradient.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | '''A policy gradient-based maze runner in TensorFlow. 
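Actions are sampled from a softmax policy over the four moves, and the
network is trained REINFORCE-style on discounted, normalized episode
returns (see PolicyGradientNetwork.train below).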
16 | 17 | This is based on 18 | https://medium.com/@awjuliani/super-simple-reinforcement-learning-tutorial-part-2-ded33892c724 19 | ''' 20 | 21 | import collections 22 | import numpy as np 23 | import sys 24 | import tensorflow as tf 25 | 26 | from srl import grid 27 | from srl import movement 28 | from srl import world 29 | 30 | 31 | _EMBEDDING_SIZE = 3 32 | _FEATURE_SIZE = len(world.VALID_CHARS) 33 | _ACTION_SIZE = len(movement.ALL_ACTIONS) 34 | 35 | 36 | class PolicyGradientNetwork(object): 37 | '''A policy gradient network. 38 | 39 | This has a number of properties for operating on the network with 40 | TensorFlow: 41 | 42 | state: Feed a world-sized array for selecting an action. 43 | [-1, h, w] int32 44 | action_out: Produces an action, fed state. 45 | 46 | action_in: Feed the responsible actions for the rewards. 47 | [-1, 1] int32 0 <= x < len(movement.ALL_ACTIONS) 48 | advantage: Feed the "goodness" of each given state. 49 | [-1, 1] float32 50 | update: Given batches of experience, train the network. 51 | 52 | loss: Training loss. 53 | summary: Merged summaries. 54 | ''' 55 | 56 | def __init__(self, name, graph, world_size_h_w): 57 | '''Creates a PolicyGradientNetwork. 58 | 59 | Args: 60 | name: The name of this network. TF variables are in a scope with 61 | this name. 62 | graph: The TF graph to build operations in. 63 | world_size_h_w: The size of the world, height by width. 64 | ''' 65 | h, w = world_size_h_w 66 | self._h, self._w = h, w 67 | with graph.as_default(): 68 | initializer = tf.contrib.layers.xavier_initializer() 69 | with tf.variable_scope(name) as self.variables: 70 | self.state = tf.placeholder(tf.int32, shape=[None, h, w]) 71 | 72 | # Input embedding 73 | embedding = tf.get_variable( 74 | 'embedding', shape=[_FEATURE_SIZE, _EMBEDDING_SIZE], 75 | initializer=initializer) 76 | embedding_lookup = tf.nn.embedding_lookup( 77 | embedding, tf.reshape(self.state, [-1, h * w]), 78 | name='embedding_lookup') 79 | embedding_lookup = tf.reshape(embedding_lookup, 80 | [-1, h, w, _EMBEDDING_SIZE]) 81 | 82 | # First convolution. 83 | conv_1_out_channels = 27 84 | conv_1 = tf.contrib.layers.conv2d( 85 | trainable=True, 86 | inputs=embedding_lookup, 87 | num_outputs=conv_1_out_channels, 88 | kernel_size=[3, 3], 89 | stride=1, 90 | padding='SAME', 91 | activation_fn=tf.nn.relu, 92 | weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(), 93 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 94 | # TODO: What's a good initializer for biases? Below too. 95 | biases_initializer=initializer) 96 | 97 | shrunk_h = h 98 | shrunk_w = w 99 | 100 | # Second convolution. 101 | conv_2_out_channels = 50 102 | conv_2_stride = 2 103 | conv_2 = tf.contrib.layers.conv2d( 104 | trainable=True, 105 | inputs=conv_1, 106 | num_outputs=conv_2_out_channels, 107 | kernel_size=[5, 5], 108 | stride=conv_2_stride, 109 | padding='SAME', 110 | activation_fn=tf.nn.relu, 111 | weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(), 112 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 113 | biases_initializer=initializer) 114 | shrunk_h = (h + conv_2_stride - 1) // conv_2_stride 115 | shrunk_w = (w + conv_2_stride - 1) // conv_2_stride 116 | 117 | # Third convolution. 
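# Like the second convolution, this uses 'SAME' padding with stride 2, so
# each spatial dimension shrinks to ceil(n / 2); for the default 25x15
# random maze the feature map goes 15x25 -> 8x13 after conv_2 and
# 8x13 -> 4x7 after conv_3.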
118 | conv_3_out_channels = 100 119 | conv_3_stride = 2 120 | conv_3 = tf.contrib.layers.conv2d( 121 | trainable=True, 122 | inputs=conv_2, 123 | num_outputs=conv_3_out_channels, 124 | kernel_size=[5, 5], 125 | stride=conv_3_stride, 126 | padding='SAME', 127 | activation_fn=tf.nn.relu, 128 | weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(), 129 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 130 | biases_initializer=initializer) 131 | shrunk_h = (shrunk_h + conv_3_stride - 1) // conv_3_stride 132 | shrunk_w = (shrunk_w + conv_3_stride - 1) // conv_3_stride 133 | 134 | # Resupply the input at this point. 135 | resupply = tf.concat([ 136 | tf.reshape(conv_3, 137 | [-1, shrunk_h * shrunk_w * conv_3_out_channels]), 138 | tf.reshape(embedding_lookup, [-1, h * w * _EMBEDDING_SIZE]) 139 | ], 1, name='resupply') 140 | 141 | # First fully connected layer. 142 | connected_1 = tf.contrib.layers.fully_connected( 143 | trainable=True, 144 | inputs=resupply, 145 | num_outputs=h+w, 146 | activation_fn=tf.nn.relu, 147 | weights_initializer=initializer, 148 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 149 | biases_initializer=initializer) 150 | 151 | # Second fully connected layer, steps down. 152 | connected_2 = tf.contrib.layers.fully_connected( 153 | trainable=True, 154 | inputs=connected_1, 155 | num_outputs=17, 156 | activation_fn=tf.nn.relu, 157 | weights_initializer=initializer, 158 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 159 | biases_initializer=initializer) 160 | 161 | # Logits, softmax, random sample. 162 | connected_3 = tf.contrib.layers.fully_connected( 163 | trainable=True, 164 | inputs=connected_2, 165 | num_outputs=_ACTION_SIZE, 166 | activation_fn=tf.nn.sigmoid, 167 | weights_initializer=initializer, 168 | weights_regularizer=tf.contrib.layers.l2_regularizer(1.0), 169 | biases_initializer=initializer) 170 | self.action_softmax = tf.nn.softmax(connected_3, name='action_softmax') 171 | 172 | # Sum the components of the softmax 173 | probability_histogram = tf.cumsum(self.action_softmax, axis=1) 174 | sample = tf.random_uniform(tf.shape(probability_histogram)[:-1]) 175 | filtered = tf.where(probability_histogram >= sample, 176 | probability_histogram, 177 | tf.ones_like(probability_histogram)) 178 | 179 | self.action_out = tf.argmin(filtered, 1) 180 | 181 | self.action_in = tf.placeholder(tf.int32, shape=[None, 1]) 182 | self.advantage = tf.placeholder(tf.float32, shape=[None, 1]) 183 | 184 | action_one_hot = tf.one_hot(self.action_in, _ACTION_SIZE, 185 | dtype=tf.float32) 186 | action_advantage = self.advantage * action_one_hot 187 | loss_policy = -tf.reduce_mean( 188 | tf.reduce_sum(tf.log(self.action_softmax) * action_advantage, 1), 189 | name='loss_policy') 190 | # TODO: Investigate whether regularization losses are sums or 191 | # means and consider removing the division. 192 | loss_regularization = (0.05 / tf.to_float(tf.shape(self.state)[0]) * 193 | sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))) 194 | self.loss = loss_policy + loss_regularization 195 | 196 | tf.summary.scalar('loss_policy', loss_policy) 197 | tf.summary.scalar('loss_regularization', loss_regularization) 198 | 199 | # TODO: Use a decaying learning rate 200 | optimizer = tf.train.AdamOptimizer(learning_rate=0.05) 201 | self.update = optimizer.minimize(self.loss) 202 | 203 | self.summary = tf.summary.merge_all() 204 | 205 | def predict(self, session, states): 206 | '''Chooses actions for a list of states. 
207 | 208 | Args: 209 | session: The TensorFlow session to run the net in. 210 | states: A list of simulation states which have been serialized 211 | to arrays. 212 | 213 | Returns: 214 | An array of actions, each in the range 0 .. 3, and an array of 215 | per-action probability vectors. 216 | ''' 217 | return session.run([self.action_out, self.action_softmax], 218 | feed_dict={self.state: states}) 219 | 220 | def train(self, session, episodes): 221 | '''Trains the network. 222 | 223 | Args: 224 | episodes: A list of episodes. Each episode is a list of 225 | 3-tuples with the state, the chosen action, and the 226 | reward. 227 | ''' 228 | size = sum(map(len, episodes)) 229 | state = np.empty([size, self._h, self._w]) 230 | action_in = np.empty([size, 1]) 231 | advantage = np.empty([size, 1]) 232 | i = 0 233 | for episode in episodes: 234 | r = 0.0 235 | for step_state, action, reward in reversed(episode): 236 | state[i,:,:] = step_state 237 | action_in[i,0] = action 238 | r = reward + 0.97 * r 239 | advantage[i,0] = r 240 | i += 1 241 | # Scale the returns to have zero mean and unit variance. 242 | advantage = (advantage - np.mean(advantage)) / np.std(advantage) 243 | 244 | session.run([self.summary, self.update], feed_dict={ 245 | self.state: state, 246 | self.action_in: action_in, 247 | self.advantage: advantage 248 | }) 249 | 250 | 251 | _EXPERIENCE_BUFFER_SIZE = 5 252 | 253 | 254 | class PolicyGradientPlayer(grid.Player): 255 | def __init__(self, graph, session, world_size_w_h): 256 | super(PolicyGradientPlayer, self).__init__() 257 | w, h = world_size_w_h 258 | self._net = PolicyGradientNetwork('net', graph, (h, w)) 259 | self._experiences = collections.deque([], _EXPERIENCE_BUFFER_SIZE) 260 | self._experience = [] 261 | self._session = session 262 | 263 | def interact(self, ctx, sim): 264 | if sim.in_terminal_state: 265 | self._experiences.append(self._experience) 266 | self._experience = [] 267 | self._net.train(self._session, self._experiences) 268 | sim.reset() 269 | else: 270 | state = sim.to_array() 271 | [[action], _] = self._net.predict(self._session, [state]) 272 | reward = sim.act(movement.ALL_ACTIONS[action]) 273 | self._experience.append((state, action, reward)) 274 | -------------------------------------------------------------------------------- /srl/policy_gradient_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import unittest 18 | 19 | from srl import movement 20 | from srl import simulation 21 | from srl import world 22 | import srl.policy_gradient as pg 23 | 24 | 25 | class TestPolicyGradientNetwork(unittest.TestCase): 26 | def testPredict(self): 27 | g = tf.Graph() 28 | net = pg.PolicyGradientNetwork('testPredict', g, (7, 11)) 29 | 30 | s = tf.Session(graph=g) 31 | with g.as_default(): 32 | init = tf.global_variables_initializer() 33 | s.run(init) 34 | 35 | sim = simulation.Simulation(world.Generator(11, 7)) 36 | state = sim.to_array() 37 | [[act], _] = net.predict(s, [state]) 38 | self.assertTrue(0 <= act) 39 | self.assertTrue(act < len(movement.ALL_ACTIONS)) 40 | 41 | def testTrain(self): 42 | g = tf.Graph() 43 | net = pg.PolicyGradientNetwork('testTrain', g, (4, 4)) 44 | 45 | s = tf.Session(graph=g) 46 | with g.as_default(): 47 | init = tf.global_variables_initializer() 48 | s.run(init) 49 | 50 | sim = simulation.Simulation(world.Generator(4, 4)) 51 | state = sim.to_array() 52 | net.train(s, [[(state, 3, 7), (state, 3, -1)], [(state, 0, 1000)]]) 53 | 54 | def testActionOut_untrainedPrediction(self): 55 | g = tf.Graph() 56 | net = pg.PolicyGradientNetwork('testActionOut_untrainedPrediction', g, 57 | (17, 13)) 58 | s = tf.Session(graph=g) 59 | with g.as_default(): 60 | init = tf.global_variables_initializer() 61 | s.run(init) 62 | act = s.run(net.action_out, 63 | feed_dict={net.state: [np.zeros((17, 13))]}) 64 | self.assertTrue(0 <= act) 65 | self.assertTrue(act < len(movement.ALL_ACTIONS)) 66 | 67 | def testUpdate(self): 68 | g = tf.Graph() 69 | net = pg.PolicyGradientNetwork('testUpdate', g, (13, 23)) 70 | s = tf.Session(graph=g) 71 | with g.as_default(): 72 | init = tf.global_variables_initializer() 73 | s.run(init) 74 | s.run(net.update, feed_dict={ 75 | net.state: np.zeros((7, 13, 23)), 76 | net.action_in: np.zeros((7, 1)), 77 | net.advantage: np.zeros((7, 1)), 78 | }) 79 | 80 | def testUpdate_lossDecreases(self): 81 | w = world.World.parse('@.....$') 82 | 83 | g = tf.Graph() 84 | net = pg.PolicyGradientNetwork('testUpdate_lossDecreases', g, (w.h, w.w)) 85 | s = tf.Session(graph=g) 86 | with g.as_default(): 87 | init = tf.global_variables_initializer() 88 | s.run(init) 89 | 90 | state = simulation.Simulation(world.Static(w)).to_array() 91 | losses = [] 92 | for _ in range(10): 93 | loss, _ = s.run([net.loss, net.update], feed_dict={ 94 | net.state: [state], 95 | net.action_in: [[1]], 96 | net.advantage: [[2]] 97 | }) 98 | losses.append(loss) 99 | self.assertTrue(losses[-1] < losses[0]) 100 | -------------------------------------------------------------------------------- /srl/simulation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
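# Reward scheme implemented below: every step costs 1, bumping a wall costs
# an extra 5, stepping onto a trap ('^') scores -10000, and reaching the
# treasure ('$') scores +10000. An episode is terminal on a trap, on the
# treasure, or once the cumulative score drops below -500.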
14 | 15 | import numpy as np 16 | 17 | from srl import movement 18 | 19 | 20 | class Simulation(object): 21 | '''Tracks the player in a world and implements the rules and rewards. 22 | 23 | score is the cumulative score of the player in this run of the 24 | simulation. 25 | ''' 26 | def __init__(self, generator): 27 | self._generator = generator 28 | 29 | # Initialized by reset() 30 | self.state = None 31 | self.world = None 32 | 33 | self.reset() 34 | 35 | def reset(self): 36 | '''Resets the simulation to the initial state.''' 37 | self.world = self._generator.generate() 38 | self.state = self.world.init_state 39 | self.score = 0 40 | 41 | @property 42 | def in_terminal_state(self): 43 | '''Whether the simulation is in a terminal state (stopped.)''' 44 | return self.world.at(self.state) in ['^', '$'] or self.score < -500 45 | 46 | @property 47 | def x(self): 48 | '''The x coordinate of the player.''' 49 | return self.state[0] 50 | 51 | @property 52 | def y(self): 53 | '''The y coordinate of the player.''' 54 | return self.state[1] 55 | 56 | def act(self, action): 57 | '''Performs action and returns the reward from that step.''' 58 | reward = -1 59 | 60 | delta = movement.MOVEMENT[action] 61 | new_state = self.x + delta[0], self.y + delta[1] 62 | 63 | if self._valid_move(new_state): 64 | ch = self.world.at(new_state) 65 | if ch == '^': 66 | reward = -10000 67 | elif ch == '$': 68 | reward = 10000 69 | self.state = new_state 70 | else: 71 | # Penalty for hitting the walls. 72 | reward -= 5 73 | 74 | self.score += reward 75 | return reward 76 | 77 | def _valid_move(self, new_state): 78 | '''Gets whether movement to new_state is a valid move.''' 79 | new_x, new_y = new_state 80 | # TODO: Could check that there's no teleportation cheating. 81 | return (0 <= new_x and new_x < self.world.w and 82 | 0 <= new_y and new_y < self.world.h and 83 | self.world.at(new_state) in ['.', '^', '$']) 84 | 85 | def to_array(self): 86 | '''Converts the state of a simulation to numpy ndarray. 87 | 88 | The returned array has numpy.int8 units with the following mapping. 89 | This mapping has no special meaning because these indices are fed 90 | into an embedding layer. 91 | ' ' -> 0 92 | '#' -> 1 93 | '$' -> 2 94 | '.' -> 3 95 | '@' -> 4 96 | '^' -> 5 97 | Args: 98 | sim: A simulation.Simulation to externalize the state of. 99 | Returns: 100 | The world map and player position represented as an numpy ndarray. 101 | ''' 102 | key = ' #$.@^' 103 | w = np.empty(shape=(self.world.h, self.world.w), dtype=np.int8) 104 | for v in range(self.world.h): 105 | for u in range(self.world.w): 106 | w[v, u] = key.index(self.world.at((u, v))) 107 | w[self.y, self.x] = key.index('@') 108 | return w 109 | -------------------------------------------------------------------------------- /srl/simulation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import unittest 17 | 18 | from srl import movement 19 | from srl import simulation 20 | from srl import world 21 | 22 | 23 | class TestSimulation(unittest.TestCase): 24 | def test_in_terminal_state(self): 25 | w = world.World.parse('@^') 26 | sim = simulation.Simulation(world.Static(w)) 27 | self.assertFalse(sim.in_terminal_state) 28 | sim.act(movement.ACTION_RIGHT) 29 | self.assertTrue(sim.in_terminal_state) 30 | 31 | def test_act_accumulates_score(self): 32 | w = world.World.parse('@.') 33 | sim = simulation.Simulation(world.Static(w)) 34 | sim.act(movement.ACTION_RIGHT) 35 | sim.act(movement.ACTION_LEFT) 36 | self.assertEqual(-2, sim.score) 37 | 38 | def test_to_array(self): 39 | w = world.World.parse('$.@^#') 40 | sim = simulation.Simulation(world.Static(w)) 41 | self.assertTrue( 42 | (np.array([[2, 3, 4, 5, 1]], dtype=np.int8) == sim.to_array()) 43 | .all()) 44 | -------------------------------------------------------------------------------- /srl/world.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | 17 | from srl import movement 18 | 19 | 20 | # Grid world maps are specified with characters a bit like NetHack: 21 | # #, (blank) are impassable 22 | # . is passable 23 | # @ is the player start point 24 | # ^ is a trap, with a large negative reward 25 | # $ is the goal 26 | VALID_CHARS = set(['#', '.', '@', '$', '^', ' ']) 27 | 28 | 29 | class WorldFailure(Exception): 30 | pass 31 | 32 | 33 | class World(object): 34 | '''A grid world.''' 35 | def __init__(self, init_state, lines): 36 | '''Creates a grid world. 37 | init_state: the (x,y) player start position 38 | lines: list of strings of VALID_CHARS, the map''' 39 | self.init_state = init_state 40 | self._lines = [] + lines 41 | 42 | @classmethod 43 | def parse(cls, s): 44 | '''Parses a grid world in the string s. 45 | s must be made up of equal-length lines of VALID_CHARS with one start position 46 | denoted by @.''' 47 | init_state = None 48 | 49 | lines = s.split() 50 | if not lines: 51 | raise WorldFailure('no content') 52 | for (y, line) in enumerate(lines): 53 | if y > 0 and len(line) != len(lines[0]): 54 | raise WorldFailure('line %d is a different length (%d vs %d)' % 55 | (y, len(line), len(lines[0]))) 56 | for (x, ch) in enumerate(line): 57 | if not ch in VALID_CHARS: 58 | raise WorldFailure('invalid char "%c" at (%d, %d)' % (ch, x, y)) 59 | if ch == '@': 60 | if init_state: 61 | raise WorldFailure('multiple initial states, at %o and ' 62 | '(%d, %d)' % (init_state, x, y)) 63 | init_state = (x, y) 64 | if not init_state: 65 | raise WorldFailure('no initial state, use "@"') 66 | # The player start position is in fact ordinary ground. 67 | x, y = init_state 68 | line = lines[y] 69 | lines[y] = line[0:x] + '.' 
+ line[x+1:] 70 | return World(init_state, lines) 71 | 72 | @property 73 | def size(self): 74 | '''The size of the grid world, width by height.''' 75 | return self.w, self.h 76 | 77 | @property 78 | def h(self): 79 | '''The height of the grid world.''' 80 | return len(self._lines) 81 | 82 | @property 83 | def w(self): 84 | '''The width of the grid world.''' 85 | return len(self._lines[0]) 86 | 87 | def at(self, pos): 88 | '''Gets the character at an (x, y) coordinate. 89 | Positions are indexed from the origin 0,0 at the top, left of the map.''' 90 | x, y = pos 91 | return self._lines[y][x] 92 | 93 | def pretty_str(self): 94 | copy = [] + self._lines 95 | x, y = self.init_state 96 | start_line = copy[y] 97 | copy[y] = start_line[0:x] + '@' + start_line[x+1:] 98 | return '\n'.join(copy) 99 | 100 | 101 | class Static(object): 102 | def __init__(self, value): 103 | self._value = value 104 | 105 | def generate(self): 106 | return self._value 107 | 108 | @property 109 | def size(self): 110 | '''The size of the grid world, width by height.''' 111 | return self._value.size 112 | 113 | 114 | class Generator(object): 115 | '''Generates random grid worlds.''' 116 | def __init__(self, width, height): 117 | '''Creates a generator for worlds with a fixed size. 118 | width: The width of the world. Must be at least two cells wide. 119 | height: The height of the world. Must be at least one cell high. 120 | ''' 121 | assert 2 <= width 122 | assert 1 <= height 123 | self._width = width 124 | self._height = height 125 | self._grid = None 126 | self._passable = None 127 | 128 | @property 129 | def size(self): 130 | return (self._width, self._height) 131 | 132 | def generate(self): 133 | '''Generates and returns a new world.''' 134 | self._grid = list(map(lambda _: [' '] * self._width, range(self._height))) 135 | self._passable = set() 136 | 137 | x = random.randrange(0, self._width - 1) 138 | y = random.randrange(0, self._height) 139 | 140 | # Make at least two squares passable 141 | self._paint((x, y), '.') 142 | self._paint((x + 1, y), '.') 143 | 144 | # Take a random walk, for a while 145 | d = random.randrange(0, 4) 146 | for _ in range(random.randrange(self._width + self._height, 147 | self._width * self._height + 2)): 148 | self._paint((x, y), '.') 149 | dx, dy = movement.ALL_MOTIONS[d] 150 | x += dx 151 | y += dy 152 | x = max(0, min(x, self._width - 1)) 153 | y = max(0, min(y, self._height - 1)) 154 | d = (d + random.choice([-1, 0, 0, 0, 0, 0, 1])) % 4 # Turn sometimes 155 | 156 | # Pick a start and end position 157 | start = self._random_passable() 158 | # Start is technically passable, but we do not want to overwrite it 159 | self._passable.discard(start) 160 | end = self._random_passable() 161 | self._paint(end, '$') 162 | 163 | # Paint some traps. 164 | n_squares = len(self._passable) 165 | for _ in range(random.randrange(n_squares // 6, n_squares // 4 + 1)): 166 | p = self._random_passable() 167 | self._paint(p, '^') 168 | if not self._is_reachable(start, end): 169 | # Oops, put it back. 
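# The trap just placed made the goal unreachable from the start, so revert
# the square to open floor.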
170 | self._paint(p, '.') 171 | 172 | grid = list(map(''.join, self._grid)) 173 | return World(start, grid) 174 | 175 | def _random_passable(self): 176 | return random.choice(tuple(self._passable)) 177 | 178 | def _paint(self, p, ch): 179 | self._grid[p[1]][p[0]] = ch 180 | if ch == '.': 181 | self._passable.add(p) 182 | else: 183 | self._passable.discard(p) 184 | 185 | def _is_reachable(self, start, end): 186 | work = [start] 187 | visited = set(work) 188 | if start == end: 189 | return True 190 | while work: 191 | (x, y) = work.pop() 192 | for dx, dy in movement.ALL_MOTIONS: 193 | p = x + dx, y + dy 194 | if p == end: 195 | # Subtly this permits reaching end even if it is not 196 | # "passable". This is so the generator can discard the end 197 | # as "passable" so it is not selected to be overwritten with 198 | # a trap. 199 | return True 200 | elif self._is_passable(p) and p not in visited: 201 | visited.add(p) 202 | work.append(p) 203 | return False 204 | 205 | def _is_passable(self, p): 206 | return p in self._passable 207 | -------------------------------------------------------------------------------- /srl/world_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from srl import world 18 | 19 | 20 | class TestWorld(unittest.TestCase): 21 | def test_size(self): 22 | g = world.World.parse('@$') 23 | self.assertEqual((2, 1), g.size) 24 | 25 | def test_init_state(self): 26 | g = world.World.parse('####\n#.@#\n####') 27 | self.assertEqual((2, 1), g.init_state) 28 | 29 | def test_parse_no_init_state_fails(self): 30 | with self.assertRaises(world.WorldFailure): 31 | world.World.parse('#') 32 | 33 | 34 | class TestGenerator(unittest.TestCase): 35 | def test_generate_tiny_world(self): 36 | g = world.Generator(2, 1) 37 | w = g.generate() 38 | # The world should have a start and goal 39 | if w.init_state == (0, 0): 40 | self.assertEqual('$', w.at((1, 0))) 41 | elif w.init_state == (1, 0): 42 | self.assertEqual('$', w.at((0, 0))) 43 | else: 44 | self.fail('the start position %s is invalid' % (w.init_state,)) 45 | --------------------------------------------------------------------------------
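For reference, the `--q` mode of `srl/grid.py` wires the classes above together inside a curses `Game`: an `EpsilonPolicy` wraps `GreedyQ` and `RandomPolicy`, and `QLearner.observe` applies the standard tabular update Q(s, a) ← Q(s, a) + α · (r + γ · max over a′ of Q(s′, a′) − Q(s, a)). Below is a minimal headless sketch of that same loop; it is not part of the repository. It assumes the directory containing `srl` is on `PYTHONPATH` and that TensorFlow is installed (importing `srl.grid` pulls it in); the maze and hyperparameters are copied from `grid.main()`, while the episode count is an arbitrary choice.

```python
# Headless tabular Q-learning on the fixed maze from srl/grid.py.
# Illustrative sketch only; run where the srl package is importable.
from srl import grid
from srl import simulation
from srl import world

maze = world.World.parse('''\
########
#..#...#
#.@#.$.#
#.##^^.#
#......#
########
''')
sim = simulation.Simulation(world.Static(maze))

q = grid.QTable()
learner = grid.QLearner(q, learning_rate=0.05, discount_rate=0.1)
policy = grid.EpsilonPolicy(grid.GreedyQ(q), grid.RandomPolicy(), 0.01)

for episode in range(1000):  # 1000 episodes is an arbitrary choice.
  sim.reset()
  while not sim.in_terminal_state:
    old_state = sim.state
    action = policy.pick_action(old_state)
    reward = sim.act(action)
    learner.observe(old_state, action, reward, sim.state)
  print('episode %d finished with score %d' % (episode, sim.score))
```

Graphing the per-episode `sim.score` produced by a loop like this is an easy way to start on the "display summaries of results" extension idea from the README.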