├── .gitignore
├── LICENSE.md
├── LICENSE2.md
├── README.md
├── TODO.txt
├── data
│   └── example.gif
├── notebooks
│   ├── DoublePendulum.ipynb
│   ├── MLP.ipynb
│   ├── MLP_copy.ipynb
│   ├── XOR Network.ipynb
│   ├── game.ipynb
│   ├── game_memory.ipynb
│   ├── karpathy_game.ipynb
│   ├── pong.py
│   └── tf_rl
├── requirements.txt
├── saved_models
│   ├── checkpoint
│   └── karpathy_game.ckpt
├── scripts
│   └── make_gif.sh
└── tf_rl
    ├── __init__.py
    ├── controller
    │   ├── __init__.py
    │   ├── discrete_deepq.py
    │   └── human_controller.py
    ├── models.py
    ├── simulate.py
    ├── simulation
    │   ├── __init__.py
    │   ├── discrete_hill.py
    │   ├── double_pendulum.py
    │   └── karpathy_game.py
    └── utils
        ├── __init__.py
        ├── event_queue.py
        ├── geometry.py
        ├── getch.py
        └── svg.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.pyc
4 | *.pyo
5 | notebooks/Goofiness.ipynb
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2015 Szymon Sidor
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/LICENSE2.md:
--------------------------------------------------------------------------------
1 | Copyright 2016 Szymon Sidor. All rights reserved.
2 | 
3 |                                  Apache License
4 |                            Version 2.0, January 2004
5 |                         http://www.apache.org/licenses/
6 | 
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 | 
9 | 1. Definitions.
10 | 
11 |    "License" shall mean the terms and conditions for use, reproduction,
12 |    and distribution as defined by Sections 1 through 9 of this document.
13 | 
14 |    "Licensor" shall mean the copyright owner or entity authorized by
15 |    the copyright owner that is granting the License.
16 | 
17 |    "Legal Entity" shall mean the union of the acting entity and all
18 |    other entities that control, are controlled by, or are under common
19 |    control with that entity. For the purposes of this definition,
20 |    "control" means (i) the power, direct or indirect, to cause the
21 |    direction or management of such entity, whether by contract or
22 |    otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 |    outstanding shares, or (iii) beneficial ownership of such entity.
24 | 
25 |    "You" (or "Your") shall mean an individual or Legal Entity
26 |    exercising permissions granted by this License.
27 | 
28 |    "Source" form shall mean the preferred form for making modifications,
29 |    including but not limited to software source code, documentation
30 |    source, and configuration files.
31 | 
32 |    "Object" form shall mean any form resulting from mechanical
33 |    transformation or translation of a Source form, including but
34 |    not limited to compiled object code, generated documentation,
35 |    and conversions to other media types.
36 | 
37 |    "Work" shall mean the work of authorship, whether in Source or
38 |    Object form, made available under the License, as indicated by a
39 |    copyright notice that is included in or attached to the work
40 |    (an example is provided in the Appendix below).
41 | 
42 |    "Derivative Works" shall mean any work, whether in Source or Object
43 |    form, that is based on (or derived from) the Work and for which the
44 |    editorial revisions, annotations, elaborations, or other modifications
45 |    represent, as a whole, an original work of authorship. For the purposes
46 |    of this License, Derivative Works shall not include works that remain
47 |    separable from, or merely link (or bind by name) to the interfaces of,
48 |    the Work and Derivative Works thereof.
49 | 
50 |    "Contribution" shall mean any work of authorship, including
51 |    the original version of the Work and any modifications or additions
52 |    to that Work or Derivative Works thereof, that is intentionally
53 |    submitted to Licensor for inclusion in the Work by the copyright owner
54 |    or by an individual or Legal Entity authorized to submit on behalf of
55 |    the copyright owner. For the purposes of this definition, "submitted"
56 |    means any form of electronic, verbal, or written communication sent
57 |    to the Licensor or its representatives, including but not limited to
58 |    communication on electronic mailing lists, source code control systems,
59 |    and issue tracking systems that are managed by, or on behalf of, the
60 |    Licensor for the purpose of discussing and improving the Work, but
61 |    excluding communication that is conspicuously marked or otherwise
62 |    designated in writing by the copyright owner as "Not a Contribution."
63 | 
64 |    "Contributor" shall mean Licensor and any individual or Legal Entity
65 |    on behalf of whom a Contribution has been received by Licensor and
66 |    subsequently incorporated within the Work.
67 | 
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 |    this License, each Contributor hereby grants to You a perpetual,
70 |    worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 |    copyright license to reproduce, prepare Derivative Works of,
72 |    publicly display, publicly perform, sublicense, and distribute the
73 |    Work and such Derivative Works in Source or Object form.
74 | 
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 |    this License, each Contributor hereby grants to You a perpetual,
77 |    worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 |    (except as stated in this section) patent license to make, have made,
79 |    use, offer to sell, sell, import, and otherwise transfer the Work,
80 |    where such license applies only to those patent claims licensable
81 |    by such Contributor that are necessarily infringed by their
82 |    Contribution(s) alone or by combination of their Contribution(s)
83 |    with the Work to which such Contribution(s) was submitted. If You
84 |    institute patent litigation against any entity (including a
85 |    cross-claim or counterclaim in a lawsuit) alleging that the Work
86 |    or a Contribution incorporated within the Work constitutes direct
87 |    or contributory patent infringement, then any patent licenses
88 |    granted to You under this License for that Work shall terminate
89 |    as of the date such litigation is filed.
90 | 
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 |    Work or Derivative Works thereof in any medium, with or without
93 |    modifications, and in Source or Object form, provided that You
94 |    meet the following conditions:
95 | 
96 |    (a) You must give any other recipients of the Work or
97 |        Derivative Works a copy of this License; and
98 | 
99 |    (b) You must cause any modified files to carry prominent notices
100 |        stating that You changed the files; and
101 | 
102 |    (c) You must retain, in the Source form of any Derivative Works
103 |        that You distribute, all copyright, patent, trademark, and
104 |        attribution notices from the Source form of the Work,
105 |        excluding those notices that do not pertain to any part of
106 |        the Derivative Works; and
107 | 
108 |    (d) If the Work includes a "NOTICE" text file as part of its
109 |        distribution, then any Derivative Works that You distribute must
110 |        include a readable copy of the attribution notices contained
111 |        within such NOTICE file, excluding those notices that do not
112 |        pertain to any part of the Derivative Works, in at least one
113 |        of the following places: within a NOTICE text file distributed
114 |        as part of the Derivative Works; within the Source form or
115 |        documentation, if provided along with the Derivative Works; or,
116 |        within a display generated by the Derivative Works, if and
117 |        wherever such third-party notices normally appear. The contents
118 |        of the NOTICE file are for informational purposes only and
119 |        do not modify the License. You may add Your own attribution
120 |        notices within Derivative Works that You distribute, alongside
121 |        or as an addendum to the NOTICE text from the Work, provided
122 |        that such additional attribution notices cannot be construed
123 |        as modifying the License.
124 | 
125 |    You may add Your own copyright statement to Your modifications and
126 |    may provide additional or different license terms and conditions
127 |    for use, reproduction, or distribution of Your modifications, or
128 |    for any such Derivative Works as a whole, provided Your use,
129 |    reproduction, and distribution of the Work otherwise complies with
130 |    the conditions stated in this License.
131 | 
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 |    any Contribution intentionally submitted for inclusion in the Work
134 |    by You to the Licensor shall be under the terms and conditions of
135 |    this License, without any additional terms or conditions.
136 |    Notwithstanding the above, nothing herein shall supersede or modify
137 |    the terms of any separate license agreement you may have executed
138 |    with Licensor regarding such Contributions.
139 | 
140 | 6. Trademarks. This License does not grant permission to use the trade
141 |    names, trademarks, service marks, or product names of the Licensor,
142 |    except as required for reasonable and customary use in describing the
143 |    origin of the Work and reproducing the content of the NOTICE file.
144 | 
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 |    agreed to in writing, Licensor provides the Work (and each
147 |    Contributor provides its Contributions) on an "AS IS" BASIS,
148 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 |    implied, including, without limitation, any warranties or conditions
150 |    of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 |    PARTICULAR PURPOSE. You are solely responsible for determining the
152 |    appropriateness of using or redistributing the Work and assume any
153 |    risks associated with Your exercise of permissions under this License.
154 | 
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 |    whether in tort (including negligence), contract, or otherwise,
157 |    unless required by applicable law (such as deliberate and grossly
158 |    negligent acts) or agreed to in writing, shall any Contributor be
159 |    liable to You for damages, including any direct, indirect, special,
160 |    incidental, or consequential damages of any character arising as a
161 |    result of this License or out of the use or inability to use the
162 |    Work (including but not limited to damages for loss of goodwill,
163 |    work stoppage, computer failure or malfunction, or any and all
164 |    other commercial damages or losses), even if such Contributor
165 |    has been advised of the possibility of such damages.
166 | 
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 |    the Work or Derivative Works thereof, You may choose to offer,
169 |    and charge a fee for, acceptance of support, warranty, indemnity,
170 |    or other liability obligations and/or rights consistent with this
171 |    License. However, in accepting such obligations, You may act only
172 |    on Your own behalf and on Your sole responsibility, not on behalf
173 |    of any other Contributor, and only if You agree to indemnify,
174 |    defend, and hold each Contributor harmless for any liability
175 |    incurred by, or claims asserted against, such Contributor by reason
176 |    of your accepting any such warranty or additional liability.
177 | 
178 | END OF TERMS AND CONDITIONS
179 | 
180 | APPENDIX: How to apply the Apache License to your work.
181 | 
182 |    To apply the Apache License to your work, attach the following
183 |    boilerplate notice, with the fields enclosed by brackets "[]"
184 |    replaced with your own identifying information. (Don't include
185 |    the brackets!) The text should be enclosed in the appropriate
186 |    comment syntax for the file format. We also recommend that a
187 |    file or class name and description of purpose be included on the
188 |    same "printed page" as the copyright notice for easier
189 |    identification within third-party archives.
190 | 
191 | Copyright 2015, The TensorFlow Authors.
192 | 
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 | 
197 |    http://www.apache.org/licenses/LICENSE-2.0
198 | 
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This repository is now obsolete!
2 | 
3 | Check out the new, simpler, better-performing, and more complete implementation that we released at OpenAI:
4 | 
5 | https://github.com/openai/baselines
6 | 
7 | 
8 | (scroll down for docs of the obsolete version)
9 | 
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | ### Reinforcement Learning using TensorFlow
43 | 
44 | 
45 | #### Quick start
46 | 
47 | Check out the Karpathy game in the `notebooks` folder.
48 | 
49 | ![example](data/example.gif)
50 | 
51 | *The image above depicts a strategy learned by the DeepQ controller. Available actions are accelerating top, bottom, left or right. The reward signal is +1 for the green fellas, -1 for red and -5 for orange.*
52 | 
53 | #### Requirements
54 | 
55 | - `future==0.15.2`
56 | - `euclid==0.1`
57 | - `inkscape` (for animation GIF creation)
58 | 
59 | #### How does this all fit together?
60 | 
61 | `tf_rl` has controllers and simulators which can be pieced together using the `simulate` function.
62 | 
63 | #### Using the human controller
64 | Want to have some fun controlling the simulation by yourself? You got it!
65 | Use `tf_rl.controller.HumanController` in your simulation.
66 | 
67 | To issue commands, run in a terminal:
68 | ```bash
69 | python3 tf_rl/controller/human_controller.py
70 | ```
71 | For it to work, you also need a Redis server running locally.
72 | 
73 | #### Writing your own controller
74 | To write your own controller, define a controller class with 3 functions (both interfaces are sketched in code below):
75 | - `action(self, observation)` given an observation (usually a tensor of numbers), returns the action to perform.
76 | - `store(self, observation, action, reward, newobservation)` called each time a transition is observed from `observation` to `newobservation`. The transition is a consequence of `action` and has an associated `reward`.
77 | - `training_step(self)` if your controller requires training, this is the place to do it; it should not take too long, because it will be called roughly once per executed action.
78 | 
79 | #### Writing your own simulation
80 | To write your own simulation, define a simulation class with 5 functions (see the sketch below):
81 | - `observe(self)` returns the current observation.
82 | - `collect_reward(self)` returns the reward accumulated since the last time the function was called.
83 | - `perform_action(self, action)` updates internal state to reflect the fact that `action` was executed.
84 | - `step(self, dt)` updates internal state as if `dt` of simulation time has passed.
85 | - `to_html(self, info=[])` generates an HTML visualization of the game. `info` is an optional list of strings to be displayed along with the visualization.
86 | 
87 | 
88 | 
89 | #### Creating GIFs based on simulation
90 | The `simulate` method accepts a `save_path` argument, which is a folder where all the consecutive images will be stored.
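
To make the two interfaces above concrete, here is a minimal sketch of a controller and a simulation pieced together. This is illustrative only: `RandomController`, `ToySimulation`, and the `/tmp/frames` path are hypothetical names, not part of `tf_rl`; the classes implement just the method contracts documented above, and the `simulate` keyword arguments in the closing comment mirror `notebooks/DoublePendulum.ipynb` (the authoritative signature lives in `tf_rl/simulate.py`).

```python
# Minimal sketch only: RandomController and ToySimulation are hypothetical
# illustrations of the documented interfaces, not classes shipped with tf_rl.
import random


class RandomController(object):
    """A trivial controller implementing the 3 required functions."""
    def __init__(self, num_actions):
        self.num_actions = num_actions

    def action(self, observation):
        # Given an observation, return the action to perform.
        return random.randrange(self.num_actions)

    def store(self, observation, action, reward, newobservation):
        # A learning controller (e.g. DiscreteDeepQ) would record this
        # transition in its experience replay memory here.
        pass

    def training_step(self):
        # A learning controller would run one training update here; it is
        # called roughly once per executed action, so keep it cheap.
        pass


class ToySimulation(object):
    """A 1-D world implementing the 5 required functions: the agent drifts
    left or right and is penalized for its distance from the origin."""
    def __init__(self):
        self.position = 10.0
        self.velocity = 0.0
        self.accumulated_reward = 0.0

    def observe(self):
        return [self.position, self.velocity]

    def collect_reward(self):
        # Hand back the reward accumulated since the previous call.
        reward, self.accumulated_reward = self.accumulated_reward, 0.0
        return reward

    def perform_action(self, action):
        # Action 0 pushes left, action 1 pushes right.
        self.velocity = -1.0 if action == 0 else 1.0

    def step(self, dt):
        # Advance the world by dt seconds of simulation time.
        self.position += self.velocity * dt
        self.accumulated_reward -= abs(self.position) * dt

    def to_html(self, info=[]):
        lines = list(info) + ["position: %.2f" % self.position]
        return "<pre>%s</pre>" % "\n".join(lines)


# Pieced together (keyword names mirror notebooks/DoublePendulum.ipynb; how
# the controller is attached is defined by the signature in tf_rl/simulate.py):
#
#     simulate(ToySimulation(), fps=30, speed=1.0, save_path="/tmp/frames")
```

With `save_path` set, `simulate` writes one image per rendered frame into that folder.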
91 | To make them into a GIF use `scripts/make_gif.sh PATH` where path is the same as the path you passed to `save_path` argument 92 | -------------------------------------------------------------------------------- /TODO.txt: -------------------------------------------------------------------------------- 1 | -> Variable naming 2 | -> Documentation 3 | -> move to separate files 4 | -> more interesting examples 5 | -------------------------------------------------------------------------------- /data/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siemanko/tensorflow-deepq/149e69e5340984d75df3ff1a374920d870517fb9/data/example.gif -------------------------------------------------------------------------------- /notebooks/DoublePendulum.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "collapsed": true 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import math\n", 24 | "import random\n", 25 | "import time\n", 26 | "\n", 27 | "from collections import defaultdict\n", 28 | "\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 8, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "from tf_rl.simulation import DoublePendulum\n", 40 | "from tf_rl import simulate" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 9, 46 | "metadata": { 47 | "collapsed": false 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "DOUBLE_PENDULUM_PARAMS = {\n", 52 | " 'g_ms2': 9.8, # acceleration due to gravity, in m/s^2\n", 53 | " 'l1_m': 1.0, # length of pendulum 1 in m\n", 54 | " 'l2_m': 2.0, # length of pendulum 2 in m\n", 55 | " 'm1_kg': 1.0, # mass of pendulum 1 in kg\n", 56 | " 'm2_kg': 1.0, # mass of pendulum 2 in kg\n", 57 | " 'damping': 0.4,\n", 58 | " 'max_control_input': 20.0\n", 59 | "}" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 10, 65 | "metadata": { 66 | "collapsed": false 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "d = DoublePendulum(DOUBLE_PENDULUM_PARAMS)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 11, 76 | "metadata": { 77 | "collapsed": true 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "d.perform_action(0.2)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 13, 87 | "metadata": { 88 | "collapsed": false 89 | }, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/html": [ 94 | "\n", 95 | "\n", 96 | "\n", 97 | "\n", 98 | " \n", 101 | "\n", 102 | " \n", 105 | "\n", 106 | " \n", 107 | "\n", 108 | " \n", 109 | "\n", 110 | " \n", 113 | "\n", 114 | " \n", 117 | "\n", 118 | " \n", 121 | "\n", 122 | " \n", 123 | "\n", 124 | " Reward = -2.1\n", 125 | "\n", 126 | " \n", 127 | "\n", 128 | " \n", 129 | "\n" 130 | ], 131 | "text/plain": [ 132 | "" 133 | ] 134 | }, 135 | "metadata": {}, 136 | "output_type": "display_data" 137 | }, 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "Interrupted\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "try:\n", 148 | " simulate(d, fps=30, actions_per_simulation_second=1, speed=1.0, simulation_resultion=0.01)\n", 149 | "except 
KeyboardInterrupt:\n", 150 | " print(\"Interrupted\")" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": { 157 | "collapsed": true 158 | }, 159 | "outputs": [], 160 | "source": [] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "Python 2", 166 | "language": "python", 167 | "name": "python2" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 2 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython2", 179 | "version": "2.7.8" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 0 184 | } 185 | -------------------------------------------------------------------------------- /notebooks/MLP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import tensorflow as tf\n", 13 | "\n", 14 | "from __future__ import print_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "def create_examples(N, batch_size):\n", 26 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n", 27 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n", 28 | "\n", 29 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n", 30 | " X[:,:N], X[:,N:] = A, B\n", 31 | "\n", 32 | " Y = (A ^ B).astype(np.float32)\n", 33 | " return X,Y" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "import math\n", 45 | "\n", 46 | "class Layer(object):\n", 47 | " def __init__(self, input_sizes, output_size):\n", 48 | " \"\"\"Cretes a neural network layer.\"\"\"\n", 49 | " if type(input_sizes) != list:\n", 50 | " input_sizes = [input_sizes]\n", 51 | " \n", 52 | " self.input_sizes = input_sizes\n", 53 | " self.output_size = output_size\n", 54 | " \n", 55 | " self.Ws = []\n", 56 | " for input_size in input_sizes:\n", 57 | " tensor_W = tf.random_uniform((input_size, output_size),\n", 58 | " -1.0 / math.sqrt(input_size),\n", 59 | " 1.0 / math.sqrt(input_size))\n", 60 | " self.Ws.append(tf.Variable(tensor_W))\n", 61 | "\n", 62 | " tensor_b = tf.zeros((output_size,))\n", 63 | " self.b = tf.Variable(tensor_b)\n", 64 | " \n", 65 | " def __call__(self, xs):\n", 66 | " if type(xs) != list:\n", 67 | " xs = [xs]\n", 68 | " assert len(xs) == len(self.Ws), \\\n", 69 | " \"Expected %d input vectors, got %d\" % (len(self.Ws), len(x))\n", 70 | " return sum([tf.matmul(x, W) for x, W in zip(xs, self.Ws)]) + self.b\n", 71 | "\n", 72 | " \n", 73 | "class MLP(object):\n", 74 | " def __init__(self, input_sizes, hiddens, nonlinearities):\n", 75 | " self.input_sizes = input_sizes\n", 76 | " self.hiddens = hiddens\n", 77 | " self.input_nonlinearity, self.layer_nonlinearities = nonlinearities[0], nonlinearities[1:]\n", 78 | "\n", 79 | " assert len(hiddens) == len(nonlinearities), \\\n", 80 | " \"Number of hiddens must be equal to number of nonlinearities\"\n", 81 | " \n", 82 | " self.input_layer = Layer(input_sizes, hiddens[0])\n", 83 | " self.layers = [Layer(h_from, h_to) for h_from, h_to in zip(hiddens[:-1], hiddens[1:])]\n", 84 | 
"\n", 85 | " def __call__(self, xs):\n", 86 | " if type(xs) != list:\n", 87 | " xs = [xs]\n", 88 | " hidden = self.input_nonlinearity(self.input_layer(xs))\n", 89 | " for layer, nonlinearity in zip(self.layers, self.layer_nonlinearities):\n", 90 | " hidden = nonlinearity(layer(hidden))\n", 91 | " return hidden" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 19, 97 | "metadata": { 98 | "collapsed": false 99 | }, 100 | "outputs": [ 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "Exception AssertionError: AssertionError() in > ignored\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "tf.ops.reset_default_graph()\n", 111 | "sess = tf.InteractiveSession()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 20, 117 | "metadata": { 118 | "collapsed": false 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "N = 5\n", 123 | "# we add a single hidden layer of size 12\n", 124 | "# otherwise code is similar to above\n", 125 | "HIDDEN_SIZE = 12\n", 126 | "\n", 127 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n", 128 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n", 129 | "\n", 130 | "mlp = MLP(2 * N, [HIDDEN_SIZE, N], [tf.tanh, tf.sigmoid])\n", 131 | "y = mlp(x)\n", 132 | "\n", 133 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n", 134 | "\n", 135 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n", 136 | "train_op = optimizer.minimize(cost)\n", 137 | "sess.run(tf.initialize_all_variables())" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 21, 143 | "metadata": { 144 | "collapsed": false 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "0.241206\n", 152 | "0.246315\n", 153 | "0.193208\n", 154 | "0.107224\n", 155 | "0.100235\n", 156 | "0.0613462\n", 157 | "0.0480775\n", 158 | "0.0498072\n", 159 | "0.0403215\n", 160 | "0.0474323\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "for t in range(5000):\n", 166 | " example_x, example_y = create_examples(N, 10)\n", 167 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n", 168 | " if t % 500 == 0: \n", 169 | " print(cost_t.mean())" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 22, 175 | "metadata": { 176 | "collapsed": false 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "Accuracy over 1000 examples: 98 %\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "N_EXAMPLES = 1000\n", 189 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n", 190 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n", 191 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n", 192 | "\n", 193 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n", 194 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "collapsed": true 202 | }, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 2", 210 | "language": "python", 211 | "name": "python2" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 2 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | 
"nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython2", 223 | "version": "2.7.8" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 0 228 | } 229 | -------------------------------------------------------------------------------- /notebooks/MLP_copy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "collapsed": true 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import tensorflow as tf\n", 25 | "\n", 26 | "from __future__ import print_function" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "def create_examples(N, batch_size):\n", 38 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n", 39 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n", 40 | "\n", 41 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n", 42 | " X[:,:N], X[:,N:] = A, B\n", 43 | "\n", 44 | " Y = (A ^ B).astype(np.float32)\n", 45 | " return X,Y" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "from tf_rl.models import MLP" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 5, 62 | "metadata": { 63 | "collapsed": false 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "tf.ops.reset_default_graph()\n", 68 | "sess = tf.InteractiveSession()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 6, 74 | "metadata": { 75 | "collapsed": false 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "N = 5\n", 80 | "# we add a single hidden layer of size 12\n", 81 | "# otherwise code is similar to above\n", 82 | "HIDDEN_SIZE = 12\n", 83 | "\n", 84 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n", 85 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n", 86 | "\n", 87 | "mlp = MLP(2 * N, [HIDDEN_SIZE, N], [tf.tanh, tf.sigmoid])\n", 88 | "y = mlp(x)\n", 89 | "\n", 90 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n", 91 | "\n", 92 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n", 93 | "train_op = optimizer.minimize(cost)\n", 94 | "sess.run(tf.initialize_all_variables())" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 7, 100 | "metadata": { 101 | "collapsed": false 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "mlp2 = mlp.copy(sess)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 8, 111 | "metadata": { 112 | "collapsed": false 113 | }, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "0.25368\n", 120 | "0.24415\n", 121 | "0.170396\n", 122 | "0.0975643\n", 123 | "0.0823987\n", 124 | "0.0314237\n", 125 | "0.0417636\n", 126 | "0.0364044\n", 127 | "0.039408\n", 128 | "0.026788\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "for t in range(5000):\n", 134 | " example_x, example_y = create_examples(N, 10)\n", 135 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n", 136 | " if t % 500 == 0: \n", 137 | " print(cost_t.mean())" 138 | ] 139 | }, 140 | { 141 
| "cell_type": "code", 142 | "execution_count": 9, 143 | "metadata": { 144 | "collapsed": false 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Accuracy over 1000 examples: 99 %\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "N_EXAMPLES = 1000\n", 157 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n", 158 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n", 159 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n", 160 | "\n", 161 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n", 162 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 10, 168 | "metadata": { 169 | "collapsed": false 170 | }, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "Accuracy over 1000 examples: 52 %\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "# If copy works accuracy should be around 50% for this one\n", 182 | "N_EXAMPLES = 1000\n", 183 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n", 184 | "is_correct = tf.less_equal(tf.abs(mlp2(x) - y_golden), tf.constant(0.5))\n", 185 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n", 186 | "\n", 187 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n", 188 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "collapsed": true 196 | }, 197 | "outputs": [], 198 | "source": [] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python 2", 204 | "language": "python", 205 | "name": "python2" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 2 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython2", 217 | "version": "2.7.8" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 0 222 | } 223 | -------------------------------------------------------------------------------- /notebooks/XOR Network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import tensorflow as tf\n", 13 | "\n", 14 | "from __future__ import print_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# XOR Network" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "### Data generation" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "def create_examples(N, batch_size):\n", 40 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n", 41 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n", 42 | "\n", 43 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n", 44 | " X[:,:N], X[:,N:] = A, B\n", 45 | "\n", 46 | " Y = (A ^ B).astype(np.float32)\n", 47 | " return X,Y" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": { 54 
| "collapsed": false 55 | }, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "[ 0. 1. 0.] xor [ 1. 1. 1.] equals [ 1. 0. 1.]\n", 62 | "[ 0. 0. 1.] xor [ 1. 1. 0.] equals [ 1. 1. 1.]\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "X, Y = create_examples(3, 2)\n", 68 | "print(X[0,:3], \"xor\", X[0,3:],\"equals\", Y[0])\n", 69 | "print(X[1,:3], \"xor\", X[1,3:],\"equals\", Y[1])\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "### Xor cannot be solved with single layer of neural network" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 22, 82 | "metadata": { 83 | "collapsed": false 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "import math\n", 88 | "\n", 89 | "class Layer(object):\n", 90 | " def __init__(self, input_size, output_size):\n", 91 | " tensor_b = tf.zeros((output_size,))\n", 92 | " self.b = tf.Variable(tensor_b)\n", 93 | " tensor_W = tf.random_uniform((input_size, output_size),\n", 94 | " -1.0 / math.sqrt(input_size),\n", 95 | " 1.0 / math.sqrt(input_size))\n", 96 | " self.W = tf.Variable(tensor_W)\n", 97 | "\n", 98 | " def __call__(self, x):\n", 99 | " return tf.matmul(x, self.W) + self.b" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 105, 105 | "metadata": { 106 | "collapsed": false 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "tf.ops.reset_default_graph()\n", 111 | "sess = tf.InteractiveSession()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 106, 117 | "metadata": { 118 | "collapsed": false 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "N = 5\n", 123 | "# x represents input data\n", 124 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n", 125 | "# y_golden is a reference output data.\n", 126 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n", 127 | "\n", 128 | "layer1 = Layer(2 * N, N)\n", 129 | "# y is a linear projection of x with nonlinearity applied to the result.\n", 130 | "y = tf.nn.sigmoid(layer1(x))\n", 131 | "\n", 132 | "# mean squared error over all examples and all N output dimensions.\n", 133 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n", 134 | "\n", 135 | "# create a function that will optimize the neural network\n", 136 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n", 137 | "train_op = optimizer.minimize(cost)\n", 138 | "\n", 139 | "# initialize the variables\n", 140 | "sess.run(tf.initialize_all_variables())" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 107, 146 | "metadata": { 147 | "collapsed": false 148 | }, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "0.262958\n", 155 | "0.249229\n", 156 | "0.259427\n", 157 | "0.245061\n", 158 | "0.252946\n", 159 | "0.24782\n", 160 | "0.250937\n", 161 | "0.246418\n", 162 | "0.246755\n", 163 | "0.244774\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "for t in range(5000):\n", 169 | " example_x, example_y = create_examples(N, 10)\n", 170 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n", 171 | " if t % 500 == 0: \n", 172 | " print(cost_t.mean())" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "### Notice that the error is far from zero.\n", 180 | "\n", 181 | "Actually network always predicts approximately $0.5$, regardless of input data. 
That yields error of about $0.25$, because we use mean squared error ($0.5^2 = 0.25$). " 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 109, 187 | "metadata": { 188 | "collapsed": false 189 | }, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "[[ 1. 0. 1. 1. 1. 1. 0. 0. 1. 1.]\n", 196 | " [ 1. 0. 1. 1. 0. 1. 1. 1. 1. 1.]\n", 197 | " [ 0. 0. 1. 0. 1. 0. 0. 1. 1. 1.]]\n", 198 | "[array([[ 0.56099683, 0.54470569, 0.4940519 , 0.49518651, 0.54470527],\n", 199 | " [ 0.56658453, 0.52068532, 0.48442408, 0.4748241 , 0.5073036 ],\n", 200 | " [ 0.53004831, 0.52866411, 0.48705727, 0.48926324, 0.53761232]], dtype=float32)]\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "X, _ = create_examples(N, 3)\n", 206 | "prediction = sess.run([y], {x: X})\n", 207 | "print(X)\n", 208 | "print(prediction)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "### Accuracy is not that hard to predict..." 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 113, 221 | "metadata": { 222 | "collapsed": false 223 | }, 224 | "outputs": [ 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "Accuracy over 1000 examples: 48 %\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "N_EXAMPLES = 1000\n", 235 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n", 236 | "# one day I need to write a wrapper which will turn the expression\n", 237 | "# below to:\n", 238 | "# tf.abs(y - y_golden) < 0.5\n", 239 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n", 240 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n", 241 | "\n", 242 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n", 243 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "### Xor Network with 2 layers" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 149, 256 | "metadata": { 257 | "collapsed": false 258 | }, 259 | "outputs": [ 260 | { 261 | "name": "stderr", 262 | "output_type": "stream", 263 | "text": [ 264 | "Exception AssertionError: AssertionError() in > ignored\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "tf.ops.reset_default_graph()\n", 270 | "sess = tf.InteractiveSession()" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 150, 276 | "metadata": { 277 | "collapsed": false 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "N = 5\n", 282 | "# we add a single hidden layer of size 12\n", 283 | "# otherwise code is similar to above\n", 284 | "HIDDEN_SIZE = 12\n", 285 | "\n", 286 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n", 287 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n", 288 | "\n", 289 | "layer1 = Layer(2 * N, HIDDEN_SIZE)\n", 290 | "layer2 = Layer(HIDDEN_SIZE, N) # <------- HERE IT IS!\n", 291 | "\n", 292 | "hidden_repr = tf.nn.tanh(layer1(x))\n", 293 | "y = tf.nn.sigmoid(layer2(hidden_repr))\n", 294 | "\n", 295 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n", 296 | "\n", 297 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n", 298 | "train_op = optimizer.minimize(cost)\n", 299 | "sess.run(tf.initialize_all_variables())" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 151, 305 | "metadata": { 306 | "collapsed": false 307 | 
}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "0.241089\n", 314 | "0.240045\n", 315 | "0.1631\n", 316 | "0.0709099\n", 317 | "0.0326128\n", 318 | "0.0087687\n", 319 | "0.00526247\n", 320 | "0.00518266\n", 321 | "0.00272845\n", 322 | "0.00213744\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "for t in range(5000):\n", 328 | " example_x, example_y = create_examples(N, 10)\n", 329 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n", 330 | " if t % 500 == 0: \n", 331 | " print(cost_t.mean())" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "### This time the network works a tad better" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 156, 344 | "metadata": { 345 | "collapsed": false 346 | }, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 353 | " [ 1. 0. 0. 1. 1. 1. 0. 1. 1. 1.]\n", 354 | " [ 0. 1. 1. 1. 0. 0. 0. 0. 1. 0.]]\n", 355 | "[[ 0. 0. 0. 0. 1.]\n", 356 | " [ 0. 0. 1. 0. 0.]\n", 357 | " [ 0. 1. 1. 0. 0.]]\n", 358 | "[array([[ 0.10384335, 0.04389301, 0.05774897, 0.04509954, 0.9374879 ],\n", 359 | " [ 0.05130127, 0.02655722, 0.97246277, 0.03545236, 0.04168396],\n", 360 | " [ 0.03924223, 0.96327722, 0.96935028, 0.03265698, 0.0310236 ]], dtype=float32)]\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "X, Y = create_examples(N, 3)\n", 366 | "prediction = sess.run([y], {x: X})\n", 367 | "print(X)\n", 368 | "print(Y)\n", 369 | "print(prediction)" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 152, 375 | "metadata": { 376 | "collapsed": false 377 | }, 378 | "outputs": [ 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "Accuracy over 1000 examples: 100 %\n" 384 | ] 385 | } 386 | ], 387 | "source": [ 388 | "N_EXAMPLES = 1000\n", 389 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n", 390 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n", 391 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n", 392 | "\n", 393 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n", 394 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))" 395 | ] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "Python 2", 401 | "language": "python", 402 | "name": "python2" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 2 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython2", 414 | "version": "2.7.8" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 0 419 | } 420 | -------------------------------------------------------------------------------- /notebooks/game.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "The autoreload extension is already loaded. 
To reload it, use:\n", 15 | " %reload_ext autoreload\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 5, 27 | "metadata": { 28 | "collapsed": true 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "import matplotlib.pyplot as plt\n", 33 | "%matplotlib inline" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 6, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "import numpy as np\n", 45 | "import tempfile\n", 46 | "import tensorflow as tf\n", 47 | "\n", 48 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n", 49 | "from tf_rl.simulation import KarpathyGame\n", 50 | "from tf_rl import simulate\n", 51 | "from tf_rl.models import MLP" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 7, 57 | "metadata": { 58 | "collapsed": false 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "/tmp/tmpdzxofD\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "LOG_DIR = tempfile.mkdtemp()\n", 71 | "print(LOG_DIR)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 8, 77 | "metadata": { 78 | "collapsed": false 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "from tf_rl.simulation import DiscreteHill" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 9, 88 | "metadata": { 89 | "collapsed": false 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "# Tensorflow business - it is always good to reset a graph before creating a new controller.\n", 94 | "tf.ops.reset_default_graph()\n", 95 | "session = tf.InteractiveSession()\n", 96 | "\n", 97 | "# This little guy will let us run tensorboard\n", 98 | "# tensorboard --logdir [LOG_DIR]\n", 99 | "journalist = tf.train.SummaryWriter(LOG_DIR)\n", 100 | "\n", 101 | "# Brain maps from observation to Q values for different actions.\n", 102 | "# Here it is a done using a multi layer perceptron with 2 hidden\n", 103 | "# layers\n", 104 | "brain = MLP([4,], [10, 4], \n", 105 | " [tf.tanh, tf.identity])\n", 106 | "\n", 107 | "# The optimizer to use. Here we use RMSProp as recommended\n", 108 | "# by the publication\n", 109 | "optimizer = tf.train.RMSPropOptimizer(learning_rate= 0.001, decay=0.9)\n", 110 | "\n", 111 | "# DiscreteDeepQ object\n", 112 | "current_controller = DiscreteDeepQ((4,), 4, brain, optimizer, session,\n", 113 | " discount_rate=0.9, exploration_period=100, max_experience=10000, \n", 114 | " store_every_nth=1, train_every_nth=4, target_network_update_rate=0.1,\n", 115 | " summary_writer=journalist)\n", 116 | "\n", 117 | "session.run(tf.initialize_all_variables())\n", 118 | "session.run(current_controller.target_network_update)\n", 119 | "# graph was not available when journalist was created \n", 120 | "journalist.add_graph(session.graph_def)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 10, 126 | "metadata": { 127 | "collapsed": false 128 | }, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "Game 9900: iterations before success 12. 
Pos: (-3, 7), Target: (-3, 7)\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "performances = []\n", 140 | "\n", 141 | "try:\n", 142 | " for game_idx in range(10000):\n", 143 | " game = DiscreteHill()\n", 144 | " game_iterations = 0\n", 145 | "\n", 146 | " observation = game.observe()\n", 147 | "\n", 148 | " while game_iterations < 50 and not game.is_over():\n", 149 | " action = current_controller.action(observation)\n", 150 | " reward = game.collect_reward(action)\n", 151 | " game.perform_action(action)\n", 152 | " new_observation = game.observe()\n", 153 | " current_controller.store(observation, action, reward, new_observation)\n", 154 | " current_controller.training_step()\n", 155 | " observation = new_observation\n", 156 | " game_iterations += 1\n", 157 | " performance = float(game_iterations - (game.shortest_path)) / game.shortest_path\n", 158 | " performances.append(performance)\n", 159 | " if game_idx % 100 == 0:\n", 160 | " print \"\\rGame %d: iterations before success %d.\" % (game_idx, game_iterations),\n", 161 | " print \"Pos: %s, Target: %s\" % (game.position, game.target),\n", 162 | "except KeyboardInterrupt:\n", 163 | " print \"Interrupted\"" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 11, 169 | "metadata": { 170 | "collapsed": false 171 | }, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "[]" 177 | ] 178 | }, 179 | "execution_count": 11, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | }, 183 | { 184 | "data": { 185 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEACAYAAABbMHZzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XuUVNWZ9/HvIyCKN0QUubRCBFQUBVEkxkiZMV4YhpmM\nGp1JIJoJuhwTL1HH6wyYiTHrXXFwjEYxGkOMUTJoFBW8UxqNooaLCCKiEAG5CrZyE7r7ef/YVVR1\ndXV3dXdVn7r8PmvVqlPn7D77qd3dzzm1z6m9zd0REZHKslvUAYiISPtT8hcRqUBK/iIiFUjJX0Sk\nAin5i4hUICV/EZEK1DGXQma2HPgcqAV2uvvwjO0x4Ango8SqR939p/kLU0RE8imn5A84EHP3jU2U\nedndx+QhJhERKbCWdPtYG7eLiEiRyDX5O/CCmb1tZuMb2X6Smc03sxlmNih/IYqISL7l2u3zNXdf\nbWYHAs+b2WJ3/3Pa9jlAlbtvNbOzgMeBgfkOVkRE8sNaOraPmU0ANrv7bU2UWQYMS79GYGYaREhE\npBXcPe/d6s12+5hZFzPbJ7G8F3A6sCCjTA8zs8TycMJBpcHFYXfXw50JEyZEHkOxPNQWagu1RdOP\nQsml26cH8KdEbu8IPOTuz5nZxYmEPhk4B7jEzGqArcD5BYpXRETyoNnk7+7LgCFZ1k9OW74LuCu/\noYmISKHoG74RiMViUYdQNNQWKWqLFLVF4bX4gm+rKzLz9qpLRKRcmBkexQVfEREpP0r+IiIVSMlf\nRKQCKfmLiFQgJX8RkQqk5C8iUoGU/EVEKpCSv4hIBVLyFxGpQEr+IiIVKPLkv3p11BGIiFSeSJP/\n8uXQqxesXx9lFCIilSfS5L9uXXhesiTKKEREKk8kyf+aa2DLFqirC6/HZ5sSXkRECianIZ3NbDnw\nOVAL7HT34VnK3AGcRZjJ6wJ3n5ux3d2dTZugWzc48EDYsQOqq8P2rVuhY0fo1Kmtb0lEpHwUakjn\nXKZxBHAg5lnm5QUws1FAf3cfYGYnAncDI7KV3ZjYQ2Y/f5cu4Xn+fDjmmByjEhGRVmlJt09TR54x\nwBQAd58NdDWzHtkK9u/fdCXHHdeCiEREpFVyTf4OvGBmb5tZth763sCKtNcrgT6ZhXbubL6iu+/O\nMSIREWm1XLt9vubuq83sQOB5M1vs7n/OKJP5yaDBxYTdd5+4a/mWW2IMHRrjlVfg//4vnPHvuy9Y\n3nu2RERKRzweJx6PF7yeFs/ha2YTgM3uflvaunuAuLs/kni9GBjp7mvTynjyePD22zBsWMN9X3op\nDBoUnkVEJMI5fM2si5ntk1jeCzgdWJBRbDowLlFmBPBZeuJPuvlmcM+e+AH22AO2bWtR/CIi0grN\nnvmbWT/gT4mXHYGH3P1WM7sYwN0nJ8rdCZwJbAEudPc5Gfvx1audgw9uvK4TT4Q33wwHCBERKdyZ\nf4u7fVpdkZlXVzv77tt4md/+Fi68UMlfRCSpLJL/zp1OxyYuMW/bFu7337ABDjigXcISESlqkfX5\n51NTiR9gzz3D84oVTZcTEZG2iXxI52yGDoWRI1NDP4iISH4VZfIHeOUVeOedqKMQESlPRZf8p0xJ\nLX/5ZXRxiIiUs6JL/uPGpZaV/EVECqPokj9A377heceOSMMQESlbRZn8J0wIzzU10cYhIlKu2vU+\n/5bUZQYDB8L77xcwKBGRIhf1ZC7t7vDDYfDgqKMQESlPRX3mD2GeXw3zLCKVqiy+4dsSP/pReN5r\nrzC/r4iI5E/RJv+rrgrP27bB9OnRxiIiUm6Kttsn/ExqWSN9ikglqrhuH4Brr00tb94cXRwiIuWm\nqJP/rbfCs8+G5SefjDYWEZFyklPyN7MOZjbXzBqkYDOLmVl1YvtcM7spX8GZwemnw+jRuuNHRCSf\ncj3zvxxYRHIG9oZedvehicdP8xNayqBBYXpHERHJj1wmcO8DjALuAxo7
/y7oefnxx8OkSbroKyKS\nL7mc+U8CrgHqGtnuwElmNt/MZpjZoLxFl/DNb4ZnTe4iIpIfTQ7vYGajgXXuPtfMYo0UmwNUuftW\nMzsLeBwYmK3gxIkTdy3HYjFiscZ2WV/XrtCvH6xfH5ZFRMpVPB4nHo8XvJ4m7/M3s58BY4EaYA9g\nX+BRdx/XxM8sA4a5+8aM9S2+z7/+z0OvXrBqVat3ISJSciK5z9/db3D3KnfvB5wPvJSZ+M2sh1m4\nF8fMhhMOKBuz7K7NOncuxF5FRCpPS+/zdwAzu9jMLk6sOwdYYGbzgNsJB4m8mzULevcuxJ5FRCpP\nUQ/vkG7JEhg1CpYuzWNQIiJFriKHd0jXqRPs3Bl1FCIi5aFkkr8ZfPxxGOVTRETapmSSf/Ksf/Xq\naOMQESkHJZP8998/PI9r9CZTERHJVckk/+7dw62e6vcXEWm7kkn+AL/7nZK/iEg+lFTyHzZMt3qK\niORDSSX/vn3hiy9g0aKoIxERKW0llfw7dIATToAFC6KORESktJVU8gcYOhQ2bYo6ChGR0lZyyX+f\nfULXj4iItF7JJf+994bNm6OOQkSktJVc8t9nH/j886ijEBEpbSWX/Hv2hE8+iToKEZHSVnLJ/5BD\nYMWKqKMQESltJZf8q6rC6J4iItJ6OSV/M+tgZnPN7MlGtt9hZh+Y2XwzG5rfEOvr1QvWrdMwDyIi\nbZHrmf/lwCIS0zimM7NRQH93HwBcBNydv/Aa6tQJevRQv7+ISFs0m/zNrA8wCrgPyDaV2BhgCoC7\nzwa6mlmPfAaZSV0/IiJtk8uZ/yTgGqCuke29gfRLsCuBPm2Mq0m9eunMX0SkLTo2tdHMRgPr3H2u\nmcWaKprxOutM7RMnTty1HIvFiMWa2mXj9t1XX/QSkfIUj8eJx+MFr8fcs+bpsNHsZ8BYoAbYA9gX\neNTdx6WVuQeIu/sjideLgZHuvjZjX95UXS1x2WXQv394FhEpZ2aGu2frcm+TJrt93P0Gd69y937A\n+cBL6Yk/YTowLhHkCOCzzMSfb927w9qC1iAiUt5aep+/A5jZxWZ2MYC7zwA+MrOlwGTg3/MbYkM9\nesCGDYWuRUSkfDXZ55/O3V8GXk4sT87Y9sM8x9Wk5OBur7wCp5zSnjWLiJSHkvuGL4TB3WbPhpEj\nIU+XEUREKkpJJv8DD4QPPwzLut9fRKTlSjL59+yZWl6zJro4RERKVcknf33ZS0Sk5Uoy+XfuDA8+\nGJa3bIk2FhGRUlSSyR9g9OjwvHhxtHGIiJSikk3+XbvC1VdraGcRkdYo2eQPcNhhsHBh1FGIiJSe\nkk7+778PTz8ddRQiIqWnpJP/GWeE59raaOMQESk1TY7qmdeK8jiqZ/39wrJl0Ldv3nctIhK5SEb1\nLAVdu0J1ddRRiIiUlpJP/oMGwRdfRB2FiEhpKfnkv2wZvP561FGIiJSWkk/+I0bAp59GHYWISGlp\nNvmb2R5mNtvM5pnZIjO7NUuZmJlVm9ncxOOmwoTb0FFHQZcu7VWbiEh5aHYyF3ffbmanuvtWM+sI\nvGpmJ7v7qxlFX3b3MYUJs3GdO8O2be1dq4hIacup28fdtyYWdwc6ABuzFMv7rUi56NwZduyIomYR\nkdKVU/I3s93MbB6wFpjl7osyijhwkpnNN7MZZjYo34E2pnNn+PLL9qpNRKQ85HrmX+fuQ4A+wClm\nFssoMgeocvdjgV8Cj+c1yiYo+YuItFzOE7gDuHu1mT0NHA/E09Z/kbY808x+ZWbd3L1e99DEiRN3\nLcdiMWKxWOuiTqPkLyLlJB6PE4/HC15Ps8M7mFl3oMbdPzOzPYFngZvd/cW0Mj2Ade7uZjYc+KO7\n983YT0GGd5g2Df7wB3jssbzvWkQkcoUa3iGXM/+ewBQz243QTfSgu79oZhcDuPtk4BzgEjOrAbYC\n5+c70Mb06QMrV7ZXbSIi5aHkB3ZbuRKGD9dcviJSngp15l/yyb+mJvT7a1hnESlHGtWzER07Ql0d\nbN8edSQiIqWj5JN/ki74iojkriyS//e+p9s9RURaouT7/AGOPhqWLNEwDyJSftTn34QTToCdO6OO\nQkSkdJRF8r/0UhgyJOooRERKR1l0+6xbBz16hLt+LJKxRUVECkPdPk046KDwrBm9RERyUxbJH+CY\nY+Djj6OOQkSkNJRN8j/0UPjb36KOQkSkNJRN8j/kECV/EZFclU3yP/RQdfuIiOSqrJK/zvxFRHJT\nNslf3T4iIrkri/v8IXWvf00NdOhQsGpERNpVJPf5m9keZjbbzOaZ2SIzu7WRcneY2QdmNt/MhuY7\nyFwcdBAccACsWhVF7SIipaXJ5O/u24FT3X0IcAxwqpmdnF7GzEYB/d19AHARcHehgm3OkUfCsmVR\n1S4iUjqa7fN3962Jxd2BDsDGjCJjgCmJsrOBrokJ3dvdV74CH34YRc0iIqWl2eRvZruZ2TxgLTDL\n3RdlFOkNrEh7vRLok78QczdwILz/fhQ1i4iUllzO/OsS3T59gFPMLJalWObFiPa5ipzhkEN0r7+I\nSC465lrQ3avN7GngeCCetmkVUJX2uk9iXQMTJ07ctRyLxYjFYrlHmoMBA8KkLiIipSoejxOPxwte\nT5O3eppZd6DG3T8zsz2BZ4Gb3f3FtDKjgB+6+ygzGwHc7u4jsuyroLd6AmzcCH37QnW1hnYWkfJQ\nqFs9mzvz7wlMMbPdCF1ED7r7i2Z2MYC7T3b3GWY2ysyWAluAC/MdZK66dYMvvoBFi+Coo6KKQkSk\n+JXNl7xS9cC998L48QWvSkSk4Ap15l92yX/UKJg5E9rpbYmIFJSSf44WLw5f9qqthd3KZuQiEalU\nmsYxR0ccEYZ52LAh6khERIpX2SV/gJ49YfXqqKMQESleZZn8e/VS8hcRaUrZ9fmHusKzLvqKSKlT\nn38LHH101BGIiBS3skz+DzwAQyOZVUBEpDSUZfI/9FBYvjzqKEREildZJv/u3cN0jps2RR2JiEhx\nKsvkbwaHHQZLl0YdiYhIcSrL5A/QuTMMHx51FCIixalsk/9110FVVfPlREQqUdkm/6qq0PcvIiIN\nlW3y33132LEj6ihERIpT2Sb/zp3hyy+jjkJEpDg1m/zNrMrMZpnZQjN718wuy1ImZmbVZjY38bip\nMOHmTmf+IiKNy2UC953Ale4+z8z2Bv5qZs+7+3sZ5V529zH5D7F1unSBLVuijkJEpDg1e+bv7mvc\nfV5ieTPwHtArS9GimjL9wAPDmf+6dVFHIiJSfFrU529mfYGhwOyMTQ6cZGbzzWyGmQ3KT3itZwbH\nHgvz50cdiYhI8cml2weARJfPNODyxCeAdHOAKnffamZnAY8DAzP3MXHixF3LsViMWCzWipBz16MH\nVFcXtAoRkbyKx+PE4/GC15PTeP5m1gl4Cpjp7rfnUH4ZMMzdN6ata7fx/JN+8AM48UQYP75dqxUR\nyZvIxvM3MwPuBxY1lvjNrEe
<remainder of base64-encoded PNG (smoothed performance plot) omitted>\n", 186 | "text/plain": [ 187 | "" 188 | ] 189 | }, 190 | "metadata": {}, 191 | "output_type": "display_data" 192 | } 193 | ],
194 | "source": [ 195 | "N = 500\n", 196 | "smooth_performances = [float(sum(performances[i:i+N])) / N for i in range(0, len(performances) - N)]\n", 197 | "\n", 198 | "plt.plot(range(len(smooth_performances)), smooth_performances)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 19, 204 | "metadata": { 205 | "collapsed": false 206 | }, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "1.6869445151941282" 212 | ] 213 | }, 214 | "execution_count": 19, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "np.average(performances[-1000:])" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 231, 226 | "metadata": { 227 | "collapsed": false 228 | }, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/plain": [ 233 | "" 234 | ] 235 | }, 236 | "execution_count": 231, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | }, 240 | { 241 | "data": { 242 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAJBCAYAAABRfpDFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGgJJREFUeJzt3X+wZGV95/HPhzsgDojEUAGBSU3WRYNGUYpfhUm8GLAm\nxJ8pE0MkGkxZVrJjTKqSqCQxQ2o3WZO4IVk2LiKiBCOpoBgMjmRivCpRR0aZAWeGBTRUZoZlQFkw\ngETuzHf/6AYu95zb/fTTp/ucnuf9quqy+/bpc773yvnO5zzn6XMcEQKAYQ5quwAAs4FmASAJzQJA\nEpoFgCQ0CwBJaBYAktAsgBlk+1Dbm21vtb3D9h/XLPNG29ts32L7X2y/aJxtrhrnwwDaERGP2j4r\nIh6xvUrSjbZ/PCJuXLLYtyT9ZEQ8aHudpA9IOiN3mzQLYEZFxCP9p4dImpN0/7L3v7zk5WZJx4+z\nPQ5DgBll+yDbWyXtlfS5iNgxYPFfkfTpcbZHswBmVETsj4gXq5cYftL2fN1yts+S9BZJ7xxnexyG\nABlsT/1LVRHhFX7+oO3rJZ0iaWHpe/1BzcskrYuI/zfO9mkWQKYNLW7L9lGSFiPiAdtPl3SOpIuW\nLfPDkj4h6fyIuHPcGmgWwGx6tqSP2D5IveGEv46Iz9p+myRFxKWS3iPpByS937YkPRYRp+Vu0HxF\nHRid7fivU9ze72nlw5BpYYATQBIOQ4BMB7ddwJSRLAAkIVkAmUrbeUgWAJKU1hyBxjBmAQA1SBZA\nptJ2HpIFgCSlNUegMYxZAEANmgWAJByGAJlK23lIFgCSlNYcgcYwwAkANUgWQKbSdh6SBYAkpTVH\noDGMWQBADZIFkIlkAQA1SBZAptJ2HpIFgCQ0CwBJSktSQGMY4ASAGiQLIFNpOw/JAkCS0poj0BjG\nLACgBskCyFTazkOyAJCktOYINIYxCwCoQbMAkITDECBTaTsPyQJAkk42C9vrbN9m+w7b72y7nmFs\nf8j2Xtu3tl1LKttrbH/O9nbb37D9623XNIjtQ21vtr3V9g7bf9x2TQdP8dEFnWsWtuckXSJpnaTn\nSzrP9ontVjXUFerVO0sek/SbEfECSWdI+i9d/jtHxKOSzoqIF0t6kaSzbP94y2UVpYuHXadJujMi\n7pIk21dLeo2knW0WNUhEfNH22rbrGEVE3CPpnv7zh2zvlHSsuv13fqT/9BBJc5Lub7GcTu48k9S5\nZCHpOEm7lrze3f8ZJqTf6F4iaXO7lQxm+yDbWyXtlfS5iNjRdk0l6WJzjLYLKIntwyVdI+kdEfFQ\n2/UMEhH7Jb3Y9jMl3WB7PiIW2qqnK2MJ09LFZLFH0polr9eoly7QMNsHS/q4pKsi4pNt15MqIh6U\ndL2kU9qupSRdTBZbJJ3Qj8Z3S3qDpPPaLOhAZNuSLpe0IyIubrueYWwfJWkxIh6w/XRJ50i6qM2a\nSBYti4hFSesl3SBph6S/jYjODrpJku2PSfqSpOfa3mX7grZrSvBSSeerd1bh5v6jy2d0ni3pn/tj\nFpslfSoiPttyTUVxBEMEwKhsxz1T3N4xkiLCU9xkReeSBYBuolkASNLFAU5gJhw8zb1ncYrbWgHJ\nAkASkgWQaVVhyWLsX9c2p1NwwGj7jEOXNdQb/6CZ1VQsSJqfwHrXTmCdkvRJSa+d0LpPntB63y/p\nVye07r0TWu+Vkt40gfW+YqSlD56bQAkdxpgFgCSMWQCZpjpm0QEdTxZr2y5gRD/adgEZZvG7WCe1\nXUCROt4b17ZdwIhmsVmc2nYBGbrRLKY6z6IDOp4sAHQFzQJAksKCFNAgTp0CQBXJAshV2N5DsgCQ\npLDeCDSosL2HZAHMoFFuP2n7VNuLtn92nG0W1huBBrW79zx++8mt/Xu/fM32puUXt+7fDvS9kj4j\naaxv1JIsgBkUEfdExNb+84fUu+3ksTWLvl29m0jdN+42SRZAro7Ms1jp9pO2j1PvPsEvV29e/1jX\nnhnaLPr3krhYvT/NByPiveNsEMBwC49IC98bvtyQ209eLOldERH9m0qNdRgy8L4h/eOd/yPpbPVu\nK3iTpPOWHhf1rpQ1qYvfTMratgvIMKmL30zSpC5+MymvSL5Slu2IEyddz5Lt7axexat/+8l/kLSx\n7q5ytr+lJxvEUZIekfTWiLgup4ZhyeI0SXdGxF39jV+tXqzp9B3CgANdyu0nI+I/LVn+CvXu4pbV\nKKThzeI4SbuWvN4t6fTcjQFozOO3n7zF9s39n10o6YclKSIubXqDw5oFF+MFVtLi6YGIuFEjnM2M\niLHvvzvs190jac2S12vUSxfLLCx5vlazOSaA8mzrP5BiWLPYIumE/qmZuyW9QdJ51cXmm60KmIqT\n9NSrbl012sc7cup0WgY2i4hYtL1e0g3q/WkuXz5DDEAZhh51RcRGSRunUAswWwqb0sh0bwBJCuuN\nQIMK23tIFgCSFNYbgQYVdjaEZAEgCc0CQBIOQ4Bche09JAsA
SQrrjUCDCtt7SBYAkhTWG4EGFbb3\nkCwAJCmsNwINYlIWAFSRLIBche09JAsASQrrjUCDCtt7SBYAkjTSG4+PNzaxmqnZ/boT2i5hZD99\n7SfaLmFkP6jvtF3CSK4a6+Z+B77CghTQIE6dAkAVyQLIVdjeQ7IAkKSw3gg0qLC9h2QBIElhvRFo\nEGdDAKCKZAHkKmzvIVkASFJYbwQaVNjeQ7IAkIRmASBJYUEKaFBhew/JAkCSwnoj0CAmZQFAFckC\nyFXY3kOyAJCksN4INKiwvYdkASBJYb0RaBBnQwCgimYBIAmHIUCuwvYekgWAJIX1RqBBhe09JAsA\nSQrrjUCDCtt7SBYAkhTWG4EGMSkLAKpIFkCuwvYekgWAJIX1RqBBhe09JAsASWgWAJIUFqSABnHq\nFACqSBZArsL2HpIFgCQ0CyDXqik+lrH9Idt7bd+6Unm2523fbPsbthfG/XVpFsBsukLSupXetH2k\npP8l6VUR8WOSXj/uBhs56nqP/rCJ1UzN3LX72i5hZG954cfaLmFkD2+ZrX+Lrhr1Ay2eDYmIL9pe\nO2CRX5T08YjY3V/+2+Nuc7b+3wSQ6gRJz7L9OdtbbP/SuCssbDwXaFC3956DJZ0s6ackrZb0Zdtf\niYg7clfY7V8XKNTCzt5jDLskfTsivifpe7a/IOkkSTQL4EAyf2Lv8biLrh15FX8v6RLbc5KeJul0\nSf9jnJpoFkCuFvce2x+T9DJJR9neJekP1Dv0UERcGhG32f6MpFsk7Zd0WUTsGGebNAtgBkXEeQnL\n/JmkP2tqmzQLIFdhew+nTgEkKaw3Ag3iK+oAUEWyAHIVtveQLAAkKaw3Ag0qbO8hWQBIUlhvBBrE\n2RAAqKJZAEjCYQiQq7C9Z2iysL2mf7Wd7f0Lf/76NAoD0C0pvfExSb8ZEVttHy7pa7Y3RcR4l+YA\nZh3J4qki4p6I2Np//pCknZKOnXRhALplpN7Yv5rwSyRtnkQxwEwp7NRpcrPoH4JcI+kd/YTxhOs2\nbHvi+fPmj9bz5o9prEBgUr7w+dAXvxBtlzEzHDH8j2X7YEn/IGljRFy87L34QJw/ofImY07cN2Qa\nZu2+IYcful8R4ZRlbUd8ZdIVLdneGUqubVJSzoZY0uWSdixvFADKkXIY8lJJ50u6xfbN/Z+9OyI+\nM7mygBlQ2NmQob9uRNwoZnoCxaMJAEhSWJACGlTY3kOyAJCksN4INKiwSVkkCwBJSBZArsL2HpIF\ngCSF9UagQYXtPSQLAEkK641AgzgbAgBVJAsgV2F7D8kCQBKaBYAkhQUpoEGF7T0kCwBJCuuNQIMK\n23sa+XX3+KomVjM1G17edgUZrmu7gNEdtnV/2yWgQYX1RqA5waQsAKgiWQCZ9hW295AsACQprDcC\nzSFZAECNwnoj0JzFuWn+W9v+aWiSBYAkNAsASTgMATLtWzXN3ef7U9xWPZIFgCQkCyDTvrmy5nuT\nLAAkIVkAmfYVdnlvkgWAJCQLINMiyQIAqkgWQKZ9he0+JAsASWgWAJKUlaOABnHqFABqkCyATCQL\nAKhBsgAykSwAoAbNAsi0qLmpPerYXmf7Ntt32H5nzftH2f6M7a22v2H7l8f5fWkWwAyyPSfpEknr\nJD1f0nm2T1y22HpJN0fEiyXNS3qf7eyhB8YsgEwtT/c+TdKdEXGXJNm+WtJrJO1cssz/lfSi/vMj\nJH0nIhZzN0izAGbTcZJ2LXm9W9Lpy5a5TNI/275b0jMk/fw4G6RZAJkmeTbkpoVHtGXhkUGLRMJq\nLpS0NSLmbT9H0ibbJ0XEv+fURLMAOujU+dU6dX71E6//90X3L19kj6Q1S16vUS9dLHWmpP8mSRHx\nTdv/Kul5krbk1MQAJzCbtkg6wfZa24dIeoOk65Ytc5uksyXJ9tHqNYpv5W6QZAFkanNSVkQs2l4v\n6QZJc5Iuj4idtt/Wf/9SSX8k6Qrb29QLBr8TEZWIkopmAcyoiNgoaeOyn1265Pm3Jb2qqe3RLIBM\nXIMTAGqQLIBMXIMTAGqU1RqBBvEVdQCoQbIAMpEsAKAGzQJAkkYOQz4cO4cv1CGffco3e2fDjZef\n03YJozu87QJG5ZGW5jAEAGowwAlkYro3ANQgWQCZmO4NADXKao1AgzgbAgA1SBZAJpIFANQgWQCZ\nmGcBADVoFgCScBgCZGJSFgDUKKs1Ag3i1CkA1CBZAJlIFjVsz9m+2fanJl0QgG5KTRbvkLRD0jMm\nWAswU5iUtYzt4yWdK+mDGvUihQAOGCnJ4s8l/bakIyZcCzBTSptnMfC3tf1KSfdGxM2251da7oEN\nlzzx/ND503To/GmNFQhMzPYFacdC21XMjGGt8UxJr7Z9rqRDJR1h+8qIeNPShY7csH5S9QGT84L5\n3uNxH7+orUpmwsBmEREXSrpQkmy/TNJvLW8UQKk4dTpYTKQKAJ2XPEITEZ+X9PkJ1gLMFJIFANQo\n69wP0CCSBQDUIFkAmZjuDQA1SBZAptKme5MsACQpqzUCDeJsCADUoFkASMJhCJCJwxAAqEGyADIx\nKQsAapAsgExMygKAGmW1RqBBnA0BgBqNJIu7PnxiE6uZmofPm70eedgr9rddwuiuaLuA0Yx6By2S\nBQDUYMwCyESyAIAaNAsASTgMATIx3RsAapAsgExM9waAGmW1RqBBnDoFMBNsr7N9m+07bL9zhWX+\nsv/+NtsvGWd7JAsgU5vJwvacpEsknS1pj6SbbF8XETuXLHOupP8cESfYPl3S+yWdkbtNkgUwm06T\ndGdE3BURj0m6WtJrli3zakkfkaSI2CzpSNtH526QZAFkanmexXGSdi15vVvS6QnLHC9pb84GSRbA\nbIrE5ZZ/mTb1cxUkC6CD7l64Q3cv3DlokT2S1ix5vUa95DBomeP7P8tCswAyTXJS1tHzJ+ro+Sev\nE/P1i25YvsgWSSfYXivpbklvkHTesmWuk7Re0tW2z5D0QERkHYJINAtgJkXEou31km6QNCfp8ojY\naftt/fcvjYhP2z7X9p2SHpZ0wTjbpFkAmdqelBURGyVtXPazS5e9Xt/U9hjgBJCEZAFkajtZTBvJ\nAkASkgWQiWQBADVIFkAmLqsHADVIFkAmLqsHADVoFgCSlJWjgAZx6hQAapAsgEwkCwCoQbIAMjEp\nCwBqkCyATEzKAoAaZbVGoEGcDQGAGjQLAEk4DAEycRgCADVIFkAmJmUBQI1mksUZjaxlav79ac9o\nu4SRHfbog22XMLq/a7uAyWJSFgDUKKs1Ag3ibAgA1CBZAJlIFgBQg2QBZCJZAEANmgWAJByGAJmY\n7g0ANUgWQCamewNAjbJaI9AgTp0CQA2SBZCJZAEANUgWQCbmWSxj+0jb19jeaXuH7Rm7LhaAJqQk\ni7+Q9OmIeL3tVZIOm3BNADpoYLOw/UxJPxERb5akiFiUNIMXgwSax6Ssp/oRSffZvsL2121fZnv1\nNAoD0C3DmsUqSSd
L+quIOFnSw5LeNfGqgBmwT3NTe3TBsBy1W9LuiLip//oa1TSLDZc8+Xz+tN4D\n6LqFh6SFh9uuYnYMbBYRcY/tXbafGxG3Szpb0vbly21YP6nygMmZP7z3eNxF9432+a78iz8tKSM0\nb5f0UduHSPqmpAsmWxKALhraLCJim6RTp1ALMFP27S8rWTDdG0CSsk4UAw1aXCRZAEAFyQLItG+x\nrN2HZAEgCc0CQJKychTQoH0McAJAFckCyESyAIAaJAsg0+JjJAsAqCBZAJn27ytr9yFZAAcY28+y\nvcn27bb/0faRKyz3btvbbd9q+29sP23QemkWQK7Fuek9RvMuSZsi4rmSPquaq9vZXivprZJOjogX\nSpqT9AuDVkqzAA48r5b0kf7zj0h6bc0y35X0mKTV/Vt8rJa0Z9BKyzroAprU3XkWR0fE3v7zvZKO\nXr5ARNxv+32S/k3S9yTdEBH/NGilNAtgBtneJOmYmrd+d+mLiAjbUfP550j6DUlr1bsX0N/ZfmNE\nfHSlbTbSLOKlTaxlenz/7N0n6Y62C8gwMNNisK8sSJsXVnw7Is5Z6T3be20f07/g9rMl3Vuz2CmS\nvhQR3+l/5hOSzpQ02WYBFGnRk1v3KWf1Ho/7yz8c5dPXSXqzpPf2//eTNcvcJun3bT9d0qPqXbn/\nq4NWygAncOD575LOsX27pJf3X8v2sbavl564EPeVkrZIuqX/uQ8MWqkjKoczI7Ed+5811iqm7t77\n265gdN9tu4AMs3YYcpakiEiKC7ZD28fbd0byAifXNikkCwBJGLMAci22XcB0kSwAJCFZALlIFgBQ\nRbIAcj3WdgHTRbIAkIRmASAJhyFArn1tFzBdJAsASUgWQC5OnQJAFckCyEWyAIAqkgWQi2QBAFUk\nCyAXyQIAqkgWQC6SBQBU0SwAJOEwBMjFYQgAVJEsgFxcKQsAqkgWQC4ufgMAVSQLIBdnQwCgimQB\n5CJZAEAVzQJAEg5DgFwchgBAFckCyEWyAIAqkgWQi2QBAFUkCyAXyQIAqkgWQK7CLn7TSLM46P5o\nYjXTc0rbBWR4bdsFZLin7QJGdInbrqDTSBZALi5+AwBVNAsASTgMAXJx6hQAqkgWQC6SBQBUkSyA\nXCQLAKgiWQC5CpvuTbIAkIRkAeRiujcAVNEsACThMATIxalTAKgamixsv1vS+ZL2S7pV0gUR8R+T\nLgzoPJLFk2yvlfRWSSdHxAslzUn6hcmXBaBrhiWL76o39WS17X2SVkvaM/GqgFnApKwnRcT9kt4n\n6d8k3S3pgYj4p2kUBqBbhh2GPEfSb0haK+lYSYfbfuMU6gK6b98UHx0w7DDkFElfiojvSJLtT0g6\nU9JHn7rYhiXP5/sPoON2L0h7FtquYmYMaxa3Sfp920+X9KiksyV9tbrYhqbrAibv+Pne43E3XTTa\n5zkb8qSI2CbpSklbJN3S//EHJl0UgHy2f872dtv7bJ88YLkjbV9je6ftHbbPGLTeofMsIuJPJP1J\nRs3Aga27yeJWSa+TdOmQ5f5C0qcj4vW2V0k6bNDCTPcGDjARcZsk2SvfYc32MyX9RES8uf+ZRUkP\nDlov072BMv2IpPtsX2H767Yvs7160AdIFkCuSU7KundBum9hxbdtb5J0TM1bF0bEpxK2sErSyZLW\nR8RNti+W9C5J7xn0AQBd80Pzvcfjdj71TE1EnDPmFnZL2h0RN/VfX6Nes1gRzQLI1ZHJUkPUDlxE\nxD22d9l+bkTcrt60iO2DVsSYBXCAsf0627sknSHpetsb+z8/1vb1SxZ9u6SP2t4m6UWS/mjQekkW\nQK6OnjqNiGslXVvz87sl/cyS19sknZq6XpIFgCQkCyBXR5PFpJAsACQhWQC5uPgNAFTRLAAk4TAE\nyDUbk7IaQ7IAkIRkAeTi1CkAVJEsgFyFJYuGmsXmZlYzLVtmsEduqblOcuf9UNsFoEEzuNcAHcGk\nLACoIlkAuZhnAQBVJAsgV2FnQ0gWAJLQLAAk4TAEyMVhCABUkSyAXEzKAoAqkgWQi0lZAFBFsgBy\ncTYEAKpIFkAukgUAVJEsgFzMswCAKpoFgCQchgC5mJQFAFUkCyAXp04BoIpkAeQiWQBAFckCyMWk\nLACoIlkAuZhnAQBVNAsASTgMAXJF2wVMF8kCQBKaBYAkNAsASWgWAJLQLAAkoVkASEKzAJCEeRZA\ntrK+SdbxZPG1tgsY0Za2C8hwe9sFZPhG2wUUqePN4uttFzCiWWtu0mw2i+1tF9C3OMVH+zreLAB0\nBc0CQBJHjPdtGNuFfZ0GB7KIcMpyvf/uH5x0OUs8M7m2SRn7bEjbvwCA6eDUKZCtGwOP08KYBYAk\nJAsgG5OyAKCCZAFkI1kAQAXJAsjWzbMhtv9U0islfV/SNyVdEBG1k0Jsz6n3pabdEfGqQeslWQAH\nnn+U9IKIOEm9L/+8e8Cy75C0QwnXKqdZAAeYiNgUEfv7LzdLOr5uOdvHSzpX0gclDZ1cyWEIkG0m\nBjjfIuljK7z355J+W9IRKSuiWQAzyPYmScfUvHVhRHyqv8zvSvp+RPxNzedfKeneiLjZ9nzKNmkW\nQLZJDnB+tf+oFxHnDPq07V9W7xDjp1ZY5ExJr7Z9rqRDJR1h+8qIeNOK6xz3W6dAiXrfOt0xxS0+\nf5RvxK6T9D5JL4uIbycs/zJJvzXsbAjJAsjW2TGL/ynpEEmbbEvSlyPi12wfK+myiPiZms8MTQ0k\nCyBDL1lsm+IWT2r9chAkCyBbNydlTQrzLAAkIVkA2To7ZjERJAsASUgWQDbGLACggmYBIAmHIUA2\nBjgBoIJkAWRjgBMAKkgWQDbGLACggmQBZGPMAgAqSBZANsYsAKCCZgEgCYchQDYOQwCggmQBZOPU\nKQBUkCyAbIxZAEAFyQLIxpgFAFSQLIBsjFkAQAXJAsjGmAUAVNAsACThMATIxgAnAFSQLIBsDHAC\nQAXJAsjGmAUAVJAsgGyMWQBAhSOi7RqAmWN76jtORHja21yKZgEgCYchAJLQLAAkoVkASEKzAJCE\nZgEgyf8HmzwGuPTq2GoAAAAASUVORK5CYII=\n", 243 | "text/plain": [ 244 | "" 245 | ] 246 | }, 247 | "metadata": {}, 248 | "output_type": "display_data" 249 | } 250 | ], 251 | "source": [ 252 | "x = brain.layers[0].Ws[0].eval()\n", 253 | "import matplotlib.pyplot as plt\n", 254 | "%matplotlib inline\n", 255 | "plt.matshow(x)\n", 256 | "plt.colorbar()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 138, 262 | "metadata": { 
263 | "collapsed": false 264 | }, 265 | "outputs": [ 266 | { 267 | "data": { 268 | "text/plain": [ 269 | "array([ 11.01352692, 11.28201485, 12.03692055, 12.26954937], dtype=float32)" 270 | ] 271 | }, 272 | "execution_count": 138, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "brain.input_layer.b.eval()" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 88, 284 | "metadata": { 285 | "collapsed": false 286 | }, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "-2.0" 292 | ] 293 | }, 294 | "execution_count": 88, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | } 298 | ], 299 | "source": [ 300 | "game.collect_reward(0)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "metadata": { 307 | "collapsed": false 308 | }, 309 | "outputs": [], 310 | "source": [ 311 | "x = tf.Variable(tf.zeros((5,5)))" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 8, 317 | "metadata": { 318 | "collapsed": false 319 | }, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "text/plain": [ 324 | "" 325 | ] 326 | }, 327 | "execution_count": 8, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "tf.clip_by_norm(x, 5)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "collapsed": true 341 | }, 342 | "outputs": [], 343 | "source": [] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 2", 349 | "language": "python", 350 | "name": "python2" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 2 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython2", 362 | "version": "2.7.8" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 0 367 | } 368 | -------------------------------------------------------------------------------- /notebooks/game_memory.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import tempfile\n", 13 | "import tensorflow as tf\n", 14 | "\n", 15 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n", 16 | "from tf_rl.simulation import KarpathyGame\n", 17 | "from tf_rl import simulate\n", 18 | "from tf_rl.models import MLP" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 20, 24 | "metadata": { 25 | "collapsed": false 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "/tmp/tmp3rSEdd\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "LOG_DIR = tempfile.mkdtemp()\n", 38 | "print(LOG_DIR)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 316, 44 | "metadata": { 45 | "collapsed": false 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "from random import randint, gauss\n", 50 | "\n", 51 | "import numpy as np\n", 52 | "\n", 53 | "class DiscreteHill(object):\n", 54 | " \n", 55 | " directions = [(0,1), (0,-1), (1,0), (-1,0)]\n", 56 | " \n", 57 | " def __init__(self, board=(10,10), variance=4.):\n", 58 | " self.variance = variance\n", 59 | " self.target = (0, 0)\n", 60 | " while 
self.target == (0, 0):\n", 61 | "            self.target = (randint(-board[0], board[0]), randint(-board[1], board[1]))\n", 62 | "        self.position = (0, 0)\n", 63 | "        \n", 64 | "    @staticmethod\n", 65 | "    def add(p, q):\n", 66 | "        return (p[0] + q[0], p[1] + q[1])\n", 67 | "    \n", 68 | "    @staticmethod\n", 69 | "    def distance(p, q):\n", 70 | "        return abs(p[0] - q[0]) + abs(p[1] - q[1])\n", 71 | "    \n", 72 | "    def estimate_distance(self, p):\n", 73 | "        distance = DiscreteHill.distance(self.target, p) - DiscreteHill.distance(self.target, self.position)\n", 74 | "        return distance + abs(gauss(0, self.variance))\n", 75 | "    \n", 76 | "    def observe(self): \n", 77 | "        return np.array([self.estimate_distance(DiscreteHill.add(self.position, delta)) \n", 78 | "                         for delta in DiscreteHill.directions])\n", 79 | "    \n", 80 | "    def perform_action(self, action):\n", 81 | "        self.position = DiscreteHill.add(self.position, DiscreteHill.directions[action])\n", 82 | "    \n", 83 | "    def is_over(self):\n", 84 | "        return self.position == self.target\n", 85 | "    \n", 86 | "    def collect_reward(self, action):\n", "        # reward is the decrease in Manhattan distance to the target, minus 2 per step\n", 87 | "        return -DiscreteHill.distance(self.target, DiscreteHill.add(self.position, DiscreteHill.directions[action])) \\\n", 88 | "               + DiscreteHill.distance(self.target, self.position) - 2" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 329, 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "outputs": [ 98 | { 99 | "name": "stderr", 100 | "output_type": "stream", 101 | "text": [ 102 | "Exception AssertionError: AssertionError() in > ignored\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "n_prev_frames = 3\n", 108 | "\n", 109 | "# Tensorflow business - it is always good to reset a graph before creating a new controller.\n", 110 | "tf.reset_default_graph()\n", 111 | "session = tf.InteractiveSession()\n", 112 | "\n", 113 | "# This little guy will let us run tensorboard\n", 114 | "# tensorboard --logdir [LOG_DIR]\n", 115 | "journalist = tf.train.SummaryWriter(LOG_DIR)\n", 116 | "\n", 117 | "# Brain maps from observation to Q values for different actions.\n", 118 | "# Here it is done using a linear model -- an MLP with no\n", 119 | "# hidden layers\n", 120 | "brain = MLP([n_prev_frames * 4 + n_prev_frames - 1,], [4], \n", 121 | "            [tf.identity])\n", 122 | "\n", 123 | "# The optimizer to use. Here we use RMSProp as recommended\n", 124 | "# by the publication\n", 125 | "optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)\n", 126 | "\n", 127 | "# DiscreteDeepQ object\n", 128 | "current_controller = DiscreteDeepQ(n_prev_frames * 4 + n_prev_frames - 1, 4, brain, optimizer, session,\n", 129 | "                                   discount_rate=0.9, exploration_period=100, max_experience=10000, \n", 130 | "                                   store_every_nth=1, train_every_nth=4, target_network_update_rate=0.1,\n", 131 | "                                   summary_writer=journalist)\n", 132 | "\n", 133 | "session.run(tf.initialize_all_variables())\n", 134 | "session.run(current_controller.target_network_update)\n", 135 | "# graph was not available when journalist was created \n", 136 | "journalist.add_graph(session.graph_def)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 330, 142 | "metadata": { 143 | "collapsed": false 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "Game 9900: iterations before success 16. 
Pos: (6, 6), Target: (6, 6)\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "performances = []\n", 156 | "\n", 157 | "try:\n", 158 | " for game_idx in range(10000):\n", 159 | " game = DiscreteHill()\n", 160 | " game_iterations = 0\n", 161 | "\n", 162 | " observation = game.observe()\n", 163 | " \n", 164 | " prev_frames = [(observation, -1)] * (n_prev_frames - 1)\n", 165 | " memory = np.concatenate([np.concatenate([observation, np.array([-1])])] * (n_prev_frames - 1) + [observation])\n", 166 | " \n", 167 | " while game_iterations < 50 and not game.is_over():\n", 168 | " action = current_controller.action(memory)\n", 169 | " if n_prev_frames > 1:\n", 170 | " prev_frames = prev_frames[1:] + [(observation, action)]\n", 171 | " reward = game.collect_reward(action)\n", 172 | " game.perform_action(action)\n", 173 | " observation = game.observe()\n", 174 | " new_memory = np.concatenate([np.concatenate([a, np.array([b])]) for (a, b) in prev_frames] + [observation])\n", 175 | " current_controller.store(memory, action, reward, new_memory)\n", 176 | " current_controller.training_step()\n", 177 | " memory = new_memory\n", 178 | " game_iterations += 1\n", 179 | " cost = abs(game.target[0]) + abs(game.target[1])\n", 180 | " performances.append((game_iterations - cost) / float(cost))\n", 181 | " if game_idx % 100 == 0:\n", 182 | " print \"\\rGame %d: iterations before success %d.\" % (game_idx, game_iterations),\n", 183 | " print \"Pos: %s, Target: %s\" % (game.position, game.target),\n", 184 | "except KeyboardInterrupt:\n", 185 | " print \"Interrupted\"" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 327, 191 | "metadata": { 192 | "collapsed": false 193 | }, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "[]" 199 | ] 200 | }, 201 | "execution_count": 327, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | }, 205 | { 206 | "data": { 207 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXUAAAEACAYAAABMEua6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGzdJREFUeJzt3XuUFPWd9/H3F2aG4X4xgiA35RijBCFuRPGymcQbxLia\nk+yqJ0qSx71kicTk8WxQkydgTtYVJXHDusnj8bau8QmwrCis4CWRUdYrLiLITVQMcr8Ml5GLzDDf\n549ftz0DM9M9M91d3dWf1zl9qrqquupbv+n59q9+9asqc3dERCQeOkUdgIiIZI+SuohIjCipi4jE\niJK6iEiMKKmLiMSIkrqISIykTepm1sfM5prZGjNbbWbn5SMwERFpu7IMlvk1sNDdv2lmZUD3HMck\nIiLtZK1dfGRmvYG33P3U/IUkIiLtla755RRgp5k9YmbLzOwBM+uWj8BERKTt0iX1MuBs4DfufjZw\nALg151GJiEi7pGtT3wRscvelifdzOSapm5luHiMi0g7ubtleZ6s1dXffBnxkZp9NTLoEWNXMcnq5\nM3Xq1MhjKJSXykJlobJo/ZUrmfR+mQw8bmYVwPvAd3MWjYiIdEjapO7ubwPn5CEWERHpIF1RmkVV\nVVVRh1AwVBYpKosUlUXutdpPPaMVmHku24dEROLIzPB8nygVEZHikvWk/tprsHBhttcqIiKZyHrz\niyUOJtQiIyLSMjW/iIhIWllN6ps3Z3NtIiLSVplcfJSRK66AbdtS72tqoF+/bK1dREQykZWa+qFD\n4eTosmWpaQsWZGPNIiLSFllJ6vPmNX0/ZoyaYkREopCVpF5fDxUVYXzBArj4Yigvz8aaRUSkLbKS\n1L/9bThyJIx37gzdu8Ojj8LOndlYu4iIZCrrXRo//hhWrIBVq+C667K9dhERaU1WLj4Cp7ISZs+G\nr34VeveGgwdh+HDYsCE7gYqIxEmuLj7KWlL/8EMYNixM27EjnCzdulVXloqINKfgryhNJnSA/v1h\nzpww/uGH2dqCiIikk5WkvmPH8dOGDg3DNWuysQUREclEVpL6iScePy2Z1B99NBtbEBGRTOT0IRnJ\nOzbu2we9enVoMyIisVLwberNmTEjDH/3u1xuRUREkrJ2Q6/m3HILPP88NDTkcisiIpKU82eU6qEZ\nIiLHK8rml8aefTZfWxIRKV05T+p33BGG48eH55eKiEju5Dyp/+xnoW0d4PBh2LVLTTEiIrmSl+aX\ne+4J94Gprw992nv0yMdWRURKT16SuhmMHAmXXhreHzyYj62KiJSevJ0oPbZb47p1+dqyiEjpyKif\nupl9COwHjgJ17j62rRtatCgM3UPNXbV1EZHsy/TiIweq3L0mGxs96yzo1OgYYdOmcCuBkSOzsXYR\nkdLVluaXDnWSf+EFePXVML5iBXz/+/DRR+H9NdfA5z/fkbWLiAhkeEWpmX0A7CM0v9zv7g80mtfq\nFaXNry81fvRoeK4pqKujiJSOqK8ovcDdvwBMAL5vZhd1ZKPz5qXGDx3qyJpERKSxjNrU3X1rYrjT\nzOYBY4ElyfnTpk37dNmqqiqqqqpaXd/VV8OECfDyy6mHaIwb17bARUSKSXV1NdXV1TnfTtrmFzPr\nBnR291oz6w48B9zh7s8l5re5+QWgrg4qKmD0aHj77TDtmWfg8svbvCoRkaITZfPLAGCJmS0HXgf+\nK5nQO6K8PAzffhsmTYKuXdXNUUSko9I2v7j7BmBMLoPo3Ts0x+i+6yIiHZO3K0qbc999YThiROi3\nPn16at7994eXiIhkLqdPPkqnoiIMKyth7twwbhZq7N/7Xnj/d38XTWwiIsUo0pp6bW0Ydu0Ka9em\npr/ySjTxiIgUu0iTerI747vvwumnQ03iJgQXXgg9e4bx996LJjYRkWKU82eUtmbLFrjoonDbgO7d\nk+truswrr6gPu4jET666NEaa1JtfX2q8V6/wGLzZs7O2ehGRghD1bQLy5q/+KjX+uc/BnDnRxSIi\nUmwKrqYO8PDDcOONsHcv9OkT7g9TWZnVTYiIRKpkml+OX38YNjQc394uIlKsSqb55Vh9+4bhm29G\nG4eISDEo+KR+zz1huHp1tHGIiBSDgm9+AZg8OdxSYOdO+MxncropEZG8KNnmF4Dt28Pw/fejjUNE\npNAVRVKfNQsuuQQ2bIg6EhGRwlYUSb1TJ+jfH371q6gjEREpbEWR1AFuuAGWLo06ChGRwlY0Sf3c\nc8PDNEREpGVFk9T79AF32L076khERApX0SR1Mxg1CpYtizoSEZHCVRT91FPbCsM8bU5EJGdKup96\n0s9+Fu6/LiIizSuqpD5mTOpeMCIicryiSuqdOsH8+VBfH3UkIiKFqaiS+vjxYZh8tqmIiDRVVEm9\nSxcYOhQeeCDqSEREClNRJXWA889XTV1EpCVFl9S/8Q3o2TPqKEREClPRJfVhw2DjxqijEBEpTEV1\n8RGEe6ufdJKeWSoixS3Si4/MrLOZvWVmC7IdQFsNGAAVFfCnP0UdiYhI4cm0+eVmYDVQEBfon302\nLFoUdRQiIoUnbVI3s8HAV4EHgYJo8Ni8GSZNijoKEZHCk0lN/V7gH4CGHMeSsd/9Ds45J+ooREQK\nT1lrM83sa8AOd3/LzKpaWm7atGmfjldVVVFV1eKiWXHmmbByZU43ISKSVdXV1VRXV+d8O632fjGz\nO4EbgHqgEugF/Ke7T2y0TF57v0C49W6nTqEZZtCgvG5aRCQrIun94u63u/sQdz8FuBZ4oXFCj4oZ\nDB8O77wTdSQiIoWlrRcfFUTvF4CdO2HWrKijEBEpLBkndXd/0d3/IpfBtMXDD0NNTdRRiIgUlqK7\nTUDSF7+o55WKiByr6G4TkFRXB927w4EDUF6e982LiHSInlF6jPLykNiXL486EhGRwlG0ST1p8uSo\nIxARKRxFndQnToTXX486ChGRwlG0beoAe/dC377hQdSdO0cSgohIu6hNvRl9+oRb8W7bFnUkIiKF\noaiTOsC+fVBbG3UUIiKFoeiTerdusH591FGIiBSGok/ql10G+/dHHYWISGEo+qTevz/s2BF1FCIi\nhaHok/rJJ4db8IqIiJK6iEisFH1SHzwYNm2KOgoRkcJQ9EldNXURkZSivqIU4NChcBHSoUPhEXci\nIsVAV5S2oGtXqKjQBUgiIhCDpA7QpQt88knUUYiIRE9JXUQkRmKR1CsrldRFRCBGSf3QoaijEBGJ\nXiyS+sknq6+6iAjEJKn37RtuwSsiUupikdR79dKdGkVEIEZJfe/eqKMQEYleLJL65z8Py5dHHYWI\nSPRikdT791ebuogIxCSpq0ujiEiQNqmbWaWZvW5my81stZn9Uz4Ca4vKSjh8OOooRESiV5ZuAXc/\nbGZfdveDZlYG/LeZXeju/52H+DLSrRscPBh1FCIi0cuo+cXdkymzAugM1OQsonbo3x+2b486ChGR\n6GWU1M2sk5ktB7YDi919dW7DapsBA2D3bqirizoSEZFo
pW1+AXD3BmCMmfUGnjWzKnevTs6fNm3a\np8tWVVVRVVWV3SjTKCsLtfWtW2Ho0LxuWkQkI9XV1VRXV+d8O21+8pGZ/R/gkLvPSLyP9MlHSePG\nwYwZcMEFUUciIpJeZE8+MrPPmFmfxHhX4FLgrWwH0lF6ALWISGbNLwOBR82sE+FH4DF3/2Nuw2q7\nYcNgw4aooxARiVYmXRpXAmfnIZYOOessWLQo6ihERKIViytKAcaMgbcKrlFIRCS/2nyi9LgVFMiJ\n0ro66N0bduyAHj2ijkZEpHWRnSgtFuXlcPrpsHZt1JGIiEQnNkkd4LTTYP36qKMQEYlOrJL66afD\nunVRRyEiEp1YJfXdu+GOO6KOQkQkOrFK6mPHRh2BiEi0YtP7BWDPHujXD2pqoG/fqKMREWmZer9k\nIJnI778/2jhERKISq6QOcNttcOBA1FGIiEQjdkl93Dh49dWooxARiUbskvro0boASURKV6xOlEJ4\nAHXXrrB/P/TsGXU0IiLN04nSDFVWhqFuwysipSh2SR3g0kthy5aooxARyb9YJvXu3WHJkqijEBHJ\nv1gm9TPOgJ07o45CRCT/YpnUhw+HAjp3KyKSN7FM6pWVoReMiEipiWVS79Il3AdGRKTUxDKpNzTA\n00/rmaUiUnpid/ERhPb0Tp3g2mvh97+POhoRkePp4qM2MIM774Q+faKOREQkv2KZ1AGGDYO9e6OO\nQkQkv2Kb1Lt1gzfeiDoKEZH8im1SHzwYPvgg6ihERPIrlidKIfRT7907DC3rpyJERDpGJ0rbqLIy\nJHNdhCQipSRtUjezIWa22MxWmdk7ZvaDfASWDb16hfuqi4iUikxq6nXAj9x9JHAe8H0zOyO3YWVH\nz55QWxt1FCIi+ZM2qbv7Nndfnhj/GFgDDMp1YNlQUaFujSJSWtrUpm5mw4EvAK/nIphsq6+HFSui\njkJEJH8y7v1iZj2AauAX7v5ko+k+derUT5erqqqiqqoqu1G2U+fO4T4wBdg5R0RKTHV1NdXV1Z++\nv+OOO3LS+yWjpG5m5cB/AYvc/Z+PmVeQXRoB7r8fvvc9ePddOO20qKMREUmJrEujmRnwELD62IRe\n6P76r8PwpZeijUNEJF/S1tTN7ELgJWAFkFz4Nnd/JjG/YGvqAP36hXurF3CIIlKCclVTj+0VpUn7\n94crS5cvh9Gjo45GRCTQFaXt1KsXfP3roV1dRCTuYp/UIdzca8uWqKMQEcm9kkjqAwfC5s1RRyEi\nknslkdQ/9zn1gBGR0hD7E6UQ7qs+YgRs3w79+0cdjYiITpR2yKmnwoQJ8MADUUciIpJbJZHUAa6+\nWj1gRCT+Siapn3mmkrqIxF9JtKkD7NgBAwbArl1wwglRRyMipU5t6h3Uvz/07Qvz5kUdiYhI7pRM\nUodw18Y5c6KOQkQkd0qm+QXgwAHo0QOWLYPPfha6d486IhEpVbqhV5ZYoyIsorBFJGbUpp4lf/mX\nqfE//CG6OEREcqHkauoQHkb9ta/B2rWhV0ynkvtpE5Goqfkly5YuhbFjw3hDQ9NmGRGRXFPzS5ad\ncw68+WYYf+21kNSTr6VLo41NRKS9SramntSzJ3z88fHT6+qgrCz/8YhIaVBNPUeWLAnDe++F2lpY\nvbrpdBGRYlLyNfXm3HcfTJ4Mt9wCM2ZEHY2IxJFOlOZRfT2MHBluABazXRORAqHmlzwqK4OVK8P4\nyy9HG4uISFsoqbegogJuvBEuvTTU3EVEioGaX1qxbx/06RPuF1NbG3U0IhInan6JQO/ecOutocvj\nc8+FaUruIlLIVFNP4+jR5vurb9wIQ4bkPx4RiQfV1CPSuXO4jcDJJzedfu+90cQjItIaJfUMmMGm\nTaF7o3u4u+O994bpV14ZdXQiIilpm1/M7GHgCmCHu49qZn6sm19a0twNwHRjMBHJVJTNL48A47O9\n4WKXrLUvWpSa1qkTXHUVvPNOdHFJ9tXV5fYitK1bYf/+MF5fDy+8ECoIIu2RNqm7+xJgTx5iKUrj\nx4d/+HHjwvv582HUqFBjX7gQnn8e7rwzPEpPcsc93Er5O98JTWUnnBD+BsmLyFrS0ABHjsD//E9Y\nfvp0OPHEMP7b38LBg+GahU6dYMUKuPnmsO7rrgv35W/rNQy1tamE/cYbcNddMGhQ6Gn1538O5eVw\n8cXhXM6kSXDPPTBlCtx+ezhpL5KWu6d9AcOBlS3Mc3Gvq3PfsycMn3zSvWfPZF2+6etv/iYsE6WG\nBvf33w9xbtjQ8nI1Ne6XXea+fn1u4zl61P3RR90nTgxl9MknqXnf+laYVlvrPmGCe9eu7nv3pubP\nn+8+b577r37VfHmD+3vvhWU//ND9wIGw/y++6P7gg80vX17e8rqaez3xhPuYMe5TpqTf10zW9+tf\ntzyvcdlIcUvkzoxycFteSuo59NFH7vv2ua9d2/w/6B//GBLamjXuf/pTx7d33XXuZWXuu3alpu3c\nmdreTTe5z5rVfCwff5z6TEND88usWRPmv/CCe3V1x+N1dz98uPltXXRRiKktyXXgwLDvTzwR9mHX\nrtS87t1b/tyYMe5nnOHeo0d4n9y3F190nzQpxOju/tJL4YewocF99Wr3QYOaT/DN2b8/tcy554bh\n5Mnuhw6llqmpCT88yXJZvjzM37XL/ec/T31+5MjslH22LFuWWUWlutp99+7wPaqpyX1cha6gk/rU\nqVM/fS1evDinBVHsFi8OpX7mmccnhLvuanstvqEh1N5efbXpuv7xH1tPgA8+6L5li/uzz6amLVoU\n1nn99eF9nz6hRv+jHzW/jl/+8vhYDh92v/nmMP/ll8P0l192X7Uq/ICdf37z65o71/3gwfBDeNNN\nTedt2eJ+yinhx+TgQfdTT00lx3/5F/d169yHDGm+prx7d2o9v/iF+0MPuT/+ePgRfegh94ULmy6/\ncqV7fX3m5V9b6z56tPu//7t7r17+aU37yBH3GTNCmQ4dmorhxz/OfN3Huuee1HoKxSWXpGJ68cWm\n8xoa3KdNa/k72KOH+9atHY9h507373zH/cor3f/+71Pr//rX3Z96qumRXZQWL17cJFcWdFKX9kl+\n+WbOdD/ppNT7zp3dN20KyzQ0hET/29+GeU8+GZLcsmXuTz/d9J/km98MCenYf553302ta+fO4+PY\ns6fp9iE0aRzrscfCvFtucZ86NYx/61uhRplJTXrs2OanJ2vCjU2f7n7BBU1rsoXu6NHW9z8bR2NH\nj7pXVobvSLbW2R6bN6f265ZbUuMTJ4bvSfL7Cu5/+7fuZuGoZ+ZM94svTs077bTwA9ge69eHWn8m\n3z0IR7L/9m/hyOzxx7NbHu0RWVIHfg9sAT4BPgK+e8z8nO98XDU0pL7Q9fWhptP4SzhiROZf2K98\nJfzDu4cmn2Stuy2uuCKsa8KEzJafPv34OKZMcb/qqjD/yivDtFGjUvNXr257XMUmua//8R+hLHfv\nzu76Z85sWua
nnx6O1JJHGPlod//GN8K2k812+/Yd/12YOLFpU+CxfvCD1LLXXOM+e3aIvaEhNHXN\nnu2+YEGovKxbF5b74IOwzEsvNd3W3LnHr7+hIfz4jBqVOoo69vX++2G5KERaU291BUrqWVVfHw7p\nk1+6QYNC+/KhQ+HLt2FDaDa44gr3f/3XqKMNRxTJtvbWHD3qvm1b7uMpBAsXtr/22RaHD7vfcEPz\nyerqq8OwS5fQzj9njvvGjaFtv7kk9tRToSad/PxPfxqOBhtLJvBk+35ztd2jR8PnHnkks2R5+HDr\n5ztae/XsGWremVq1yv3558N446PcQYNCs16+5Sqp694vIkWupgZ++lN48MHQpz6dgQNh2bLQdbNz\n5/D5E05oftkvfxl++MNwY7s1a1LTH3sMrr8+O/FDiHvJEnj6aRg+HM46K4wPGBC6k3btGm6FXVEB\nffvCnj1QXQ1f+lL7tzlzZuiimrRsGcyaBXffnZr2k5/Az38etveTn4RrCH7zm1AWN94I3bq1f/t6\n8pGIZMQ9dWVzbS188gk8/DD07x+upxg7NnWx0333wU03wamnhttfDB8ePl9fD0uXwoUXptY7eHC4\nBuDEE2HHjrzvVs5s3AjDhrX9c/X14UexvZTURSQrDh+GZ5+Fq69OTduyJdTgj3XkCOzaFS6QirO6\nunAl+LBh0K9fanp9PVx+eThauPvucAQxfTqMGAHXXNOxbSqpi0jWvfIKnHtux2qc0j5K6iIiMaL7\nqYuISFpK6iIiMaKkLiISI0rqIiIxoqQuIhIjSuoiIjGipC4iEiNK6iIiMaKkLiISI0rqIiIxoqQu\nIhIjSuoiIjGipC4iEiNK6iIiMaKkLiISI0rqIiIxoqQuIhIjSuoiIjGipC4iEiNK6iIiMaKkLiIS\nI0rqIiIxoqQuIhIjaZO6mY03s7Vmtt7MpuQjKBERaZ9Wk7qZdQbuA8YDZwLXmdkZ+QisGFVXV0cd\nQsFQWaSoLFJUFrmXrqY+FnjP3T909zpgFnBV7sMqTvrCpqgsUlQWKSqL3EuX1E8GPmr0flNimoiI\nFKB0Sd3zEoWIiGSFubect83sPGCau49PvL8NaHD36Y2WUeIXEWkHd7dsrzNdUi8D1gEXA1uAN4Dr\n3H1NtgMREZGOK2ttprvXm9lNwLNAZ+AhJXQRkcLVak1dRESKS4euKI37hUlmNsTMFpvZKjN7x8x+\nkJjez8yeN7N3zew5M+vT6DO3JcpjrZld1mj6n5nZysS8X0exP9lgZp3N7C0zW5B4X5JlYWZ9zGyu\nma0xs9Vmdm4Jl8Vtif+RlWb2/8ysS6mUhZk9bGbbzWxlo2lZ2/dEWc5OTH/NzIalDcrd2/UiNMe8\nBwwHyoHlwBntXV8hvoCTgDGJ8R6E8wtnAHcDP05MnwLclRg/M1EO5YlyeY/U0dAbwNjE+EJgfNT7\n184y+d/A48D8xPuSLAvgUeB/JcbLgN6lWBaJ/fkA6JJ4Pxv4dqmUBXAR8AVgZaNpWdt3YBLwm8T4\nNcCstDF1YGfGAc80en8rcGvUhZzjP+CTwCXAWmBAYtpJwNrE+G3AlEbLPwOcBwwE1jSafi3wf6Pe\nn3bs/2DgD8CXgQWJaSVXFokE/kEz00uxLPoRKjt9CT9uC4BLS6ksEgm6cVLP2r4nljk3MV4G7EwX\nT0eaX0rqwiQzG074RX6d8Afbnpi1HRiQGB9EKIekZJkcO30zxVlW9wL/ADQ0mlaKZXEKsNPMHjGz\nZWb2gJl1pwTLwt1rgF8CGwk95Pa6+/OUYFk0ks19/zTPuns9sM/M+rW28Y4k9ZI5w2pmPYD/BG52\n99rG8zz8hMa+LMzsa8AOd38LaLZvbamUBaHGdDbhsPhs4ADhSPVTpVIWZjYC+CGhtjoI6GFm1zde\nplTKojlR7HtHkvpmYEij90No+msTC2ZWTkjoj7n7k4nJ283spMT8gcCOxPRjy2QwoUw2J8YbT9+c\ny7hz4HzgL8xsA/B74Ctm9hilWRabgE3uvjTxfi4hyW8rwbL4IvCKu+9O1CSfIDTNlmJZJGXjf2JT\no88MTayrDOidODpqUUeS+pvAaWY23MwqCI348zuwvoJjZgY8BKx2939uNGs+4WQQieGTjaZfa2YV\nZnYKcBrwhrtvA/YnekgYcEOjzxQFd7/d3Ye4+ymENr8X3P0GSrMstgEfmdlnE5MuAVYR2pNLqiwI\n7cfnmVnXxD5cAqymNMsiKRv/E081s65vAn9Mu/UOniCYQDhJ8h5wW9QnLHJwAuRCQvvxcuCtxGs8\n4eTQH4B3geeAPo0+c3uiPNYClzea/mfAysS8mVHvWwfL5Uuker+UZFkAo4GlwNuE2mnvEi6LHxN+\n1FYSegWVl0pZEI5atwBHCG3f383mvgNdgDnAeuA1YHi6mHTxkYhIjOhxdiIiMaKkLiISI0rqIiIx\noqQuIhIjSuoiIjGipC4iEiNK6iIiMaKkLiISI/8fTk8JmejJIq0AAAAASUVORK5CYII=\n", 208 | "text/plain": [ 209 | "" 210 | ] 211 | }, 212 | "metadata": {}, 213 | "output_type": "display_data" 214 | } 215 | ], 216 | "source": [ 217 | "N = 500\n", 218 | "smooth_performances = [float(sum(performances[i:i+N])) / N for i in range(0, len(performances) - N)]\n", 219 | "\n", 220 | "plt.plot(range(len(smooth_performances)), smooth_performances)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 231, 226 | "metadata": { 227 | "collapsed": false 228 | }, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/plain": [ 233 | "" 234 | ] 235 | }, 236 | "execution_count": 231, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | }, 240 | { 241 | "data": { 242 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQsAAAJBCAYAAABRfpDFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGgJJREFUeJzt3X+wZGV95/HPhzsgDojEUAGBSU3WRYNGUYpfhUm8GLAm\nxJ8pE0MkGkxZVrJjTKqSqCQxQ2o3WZO4IVk2LiKiBCOpoBgMjmRivCpRR0aZAWeGBTRUZoZlQFkw\ngETuzHf/6AYu95zb/fTTp/ucnuf9quqy+/bpc773yvnO5zzn6XMcEQKAYQ5quwAAs4FmASAJzQJA\nEpoFgCQ0CwBJaBYAktAsgBlk+1Dbm21vtb3D9h/XLPNG29ts32L7X2y/aJxtrhrnwwDaERGP2j4r\nIh6xvUrSjbZ/PCJuXLLYtyT9ZEQ8aHudpA9IOiN3mzQLYEZFxCP9p4dImpN0/7L3v7zk5WZJx4+z\nPQ5DgBll+yDbWyXtlfS5iNgxYPFfkfTpcbZHswBmVETsj4gXq5cYftL2fN1yts+S9BZJ7xxnexyG\nABlsT/1LVRHhFX7+oO3rJZ0iaWHpe/1BzcskrYuI/zfO9mkWQKYNLW7L9lGSFiPiAdtPl3SOpIuW\nLfPDkj4h6fyIuHPcGmgWwGx6tqSP2D5IveGEv46Iz9p+myRFxKWS3iPpByS937YkPRYRp+Vu0HxF\nHRid7fivU9ze72nlw5BpYYATQBIOQ4BMB7ddwJSRLAAkIVkAmUrbeUgWAJKU1hyBxjBmAQA1SBZA\nptJ2HpIFgCSlNUegMYxZAEANmgWAJByGAJlK23lIFgCSlNYcgcYwwAkANUgWQKbSdh6SBYAkpTVH\noDGMWQBADZIFkIlkAQA1SBZAptJ2HpIFgCQ0CwBJSktSQGMY4ASAGiQLIFNpOw/JAkCS0poj0BjG\nLACgBskCyFTazkOyAJCktOYINIYxCwCoQbMAkITDECBTaTsPyQJAkk42C9vrbN9m+w7b72y7nmFs\nf8j2Xtu3tl1LKttrbH/O9nbb37D9623XNIjtQ21vtr3V9g7bf9x2TQdP8dEFnWsWtuckXSJpnaTn\nSzrP9ontVjXUFerVO0sek/SbEfECSWdI+i9d/jtHxKOSzoqIF0t6kaSzbP94y2UVpYuHXadJujMi\n7pIk21dLeo2knW0WNUhEfNH22rbrGEVE3CPpnv7zh2zvlHSsuv13fqT/9BBJc5Lub7GcTu48k9S5\nZCHpOEm7lrze3f8ZJqTf6F4iaXO7lQxm+yDbWyXtlfS5iNjRdk0l6WJzjLYLKIntwyVdI+kdEfFQ\n2/UMEhH7Jb3Y9jMl3WB7PiIW2qqnK2MJ09LFZLFH0polr9eoly7QMNsHS/q4pKsi4pNt15MqIh6U\ndL2kU9qupSRdTBZbJJ3Qj8Z3S3qDpPPaLOhAZNuSLpe0IyIubrueYWwfJWkxIh6w/XRJ50i6qM2a\nSBYti4hFSesl3SBph6S/jYjODrpJku2PSfqSpOfa3mX7grZrSvBSSeerd1bh5v6jy2d0ni3pn/tj\nFpslfSoiPttyTUVxBEMEwKhsxz1T3N4xkiLCU9xkReeSBYBuolkASNLFAU5gJhw8zb1ncYrbWgHJ\nAkASkgWQaVVhyWLsX9c2p1NwwGj7jEOXNdQb/6CZ1VQsSJqfwHrXTmCdkvRJSa+d0LpPntB63y/p\nVye07r0TWu+Vkt40gfW+YqSlD56bQAkdxpgFgCSMWQCZpjpm0QEdTxZr2y5gRD/adgEZZvG7WCe1\nXUCROt4b17ZdwIhmsVmc2nYBGbrRLKY6z6IDOp4sAHQFzQJAksKCFNAgTp0CQBXJAshV2N5DsgCQ\npLDeCDSosL2HZAHMoFFuP2n7VNuLtn92nG0W1huBBrW79zx++8mt/Xu/fM32puUXt+7fDvS9kj4j\naaxv1JIsgBkUEfdExNb+84fUu+3ksTWLvl29m0jdN+42SRZAro7Ms1jp9pO2j1PvPsEvV29e/1jX\nnhnaLPr3krhYvT/NByPiveNsEMBwC49IC98bvtyQ209eLOldERH9m0qNdRgy8L4h/eOd/yPpbPVu\nK3iTpPOWHhf1rpQ1qYvfTMratgvIMKmL30zSpC5+MymvSL5Slu2IEyddz5Lt7axexat/+8l/kLSx\n7q5ytr+lJxvEUZIekfTWiLgup4ZhyeI0SXdGxF39jV+tXqzp9B3CgANdyu0nI+I/LVn+CvXu4pbV\nKKThzeI4SbuWvN4t6fTcjQFozOO3n7zF9s39n10o6YclKSIubXqDw5oFF+MFVtLi6YGIuFEjnM2M\niLHvvzvs190jac2S12vUSxfLLCx5vlazOSaA8mzrP5BiWLPYIumE/qmZuyW9QdJ51cXmm60KmIqT\n9NSrbl012sc7cup0WgY2i4hYtL1e0g3q/WkuXz5DDEAZhh51RcRGSRunUAswWwqb0sh0bwBJCuuN\nQIMK23tIFgCSFNYbgQYVdjaEZAEgCc0CQBIOQ4Bche09JAsASQrrjUCDCtt7SBYAkhTWG4EGFbb3\nkCwAJCmsNwINYlIWAFSRLIBche09JAsASQrrjUCDCtt7SBYAkjTSG4+PNzaxmqnZ/boT2i5hZD99\n7SfaLmFkP6jvtF3CSK4a6+Z+B77CghTQIE6dAkAVyQLIVdjeQ7IAkKSw3gg0qLC9h2QBIElhvRFo\nEGdDAKCKZAHkKmzvIVkASFJYbwQaVNjeQ7IAkIRmASBJYUEKaFBhew/JAkCSwnoj0CAmZQFAFckC\nyFXY3kOyAJCksN4INKiwvYdkASBJYb0RaBBnQwCgimYBIAmHIUCuwvYekgWAJIX1RqBBhe09JAsA\nSQrrjUCDCtt7SBYAkhTWG4EGMSkLAKpIFkCuwvYekgWAJIX1RqBBhe09JAsASWgWAJIUFqSABnHq\nFACqSBZArsL2HpIFgCQ0CyDXqik+lrH9Idt7bd+6Unm2523fbPsbthfG/XVpFsBsukLSupXetH2k\npP8l6VUR8WOSXj/uBhs56nqP/rCJ1UzN3LX72i5hZG954cfaLmFkD2+ZrX+Lrhr1Ay2eDYmIL9pe\nO2CRX5T08YjY3V/+2+Nuc7b+3wSQ6gRJz7L9OdtbbP/SuCssbDwXaFC3956DJZ0s6ackrZb0Zdtf\niYg7clfY7V8XKNTCzt5jDLskfTsivifpe7a/IOkkSTQL4EAyf2Lv8biLrh15FX8v6RLbc5KeJul0\nSf9jnJpoFkCuFvce2x+T9DJJR9neJekP1Dv0UERcGhG32f6MpFsk7Zd0WUTsGGebNAtgBkXEeQnL\n/JmkP2tqmzQLIFdhew+nTgEkKaw3Ag3iK+oAUEWyAHIVtveQLAAkKaw3Ag0qbO8hWQBIUlhvBBrE\n2RAAqKJZAEjCYQiQq7C9Z2iysL2mf7Wd7f0Lf/76NAoD0C0pvfExSb8ZEVttHy7pa7Y3RcR4l+YA\nZh3J4qki4p6I2Np//pCknZKOnXRhALplpN7Yv5rwSyRt
nkQxwEwp7NRpcrPoH4JcI+kd/YTxhOs2\nbHvi+fPmj9bz5o9prEBgUr7w+dAXvxBtlzEzHDH8j2X7YEn/IGljRFy87L34QJw/ofImY07cN2Qa\nZu2+IYcful8R4ZRlbUd8ZdIVLdneGUqubVJSzoZY0uWSdixvFADKkXIY8lJJ50u6xfbN/Z+9OyI+\nM7mygBlQ2NmQob9uRNwoZnoCxaMJAEhSWJACGlTY3kOyAJCksN4INKiwSVkkCwBJSBZArsL2HpIF\ngCSF9UagQYXtPSQLAEkK641AgzgbAgBVJAsgV2F7D8kCQBKaBYAkhQUpoEGF7T0kCwBJCuuNQIMK\n23sa+XX3+KomVjM1G17edgUZrmu7gNEdtnV/2yWgQYX1RqA5waQsAKgiWQCZ9hW295AsACQprDcC\nzSFZAECNwnoj0JzFuWn+W9v+aWiSBYAkNAsASTgMATLtWzXN3ef7U9xWPZIFgCQkCyDTvrmy5nuT\nLAAkIVkAmfYVdnlvkgWAJCQLINMiyQIAqkgWQKZ9he0+JAsASWgWAJKUlaOABnHqFABqkCyATCQL\nAKhBsgAykSwAoAbNAsi0qLmpPerYXmf7Ntt32H5nzftH2f6M7a22v2H7l8f5fWkWwAyyPSfpEknr\nJD1f0nm2T1y22HpJN0fEiyXNS3qf7eyhB8YsgEwtT/c+TdKdEXGXJNm+WtJrJO1cssz/lfSi/vMj\nJH0nIhZzN0izAGbTcZJ2LXm9W9Lpy5a5TNI/275b0jMk/fw4G6RZAJkmeTbkpoVHtGXhkUGLRMJq\nLpS0NSLmbT9H0ibbJ0XEv+fURLMAOujU+dU6dX71E6//90X3L19kj6Q1S16vUS9dLHWmpP8mSRHx\nTdv/Kul5krbk1MQAJzCbtkg6wfZa24dIeoOk65Ytc5uksyXJ9tHqNYpv5W6QZAFkanNSVkQs2l4v\n6QZJc5Iuj4idtt/Wf/9SSX8k6Qrb29QLBr8TEZWIkopmAcyoiNgoaeOyn1265Pm3Jb2qqe3RLIBM\nXIMTAGqQLIBMXIMTAGqU1RqBBvEVdQCoQbIAMpEsAKAGzQJAkkYOQz4cO4cv1CGffco3e2fDjZef\n03YJozu87QJG5ZGW5jAEAGowwAlkYro3ANQgWQCZmO4NADXKao1AgzgbAgA1SBZAJpIFANQgWQCZ\nmGcBADVoFgCScBgCZGJSFgDUKKs1Ag3i1CkA1CBZAJlIFjVsz9m+2fanJl0QgG5KTRbvkLRD0jMm\nWAswU5iUtYzt4yWdK+mDGvUihQAOGCnJ4s8l/bakIyZcCzBTSptnMfC3tf1KSfdGxM2251da7oEN\nlzzx/ND503To/GmNFQhMzPYFacdC21XMjGGt8UxJr7Z9rqRDJR1h+8qIeNPShY7csH5S9QGT84L5\n3uNxH7+orUpmwsBmEREXSrpQkmy/TNJvLW8UQKk4dTpYTKQKAJ2XPEITEZ+X9PkJ1gLMFJIFANQo\n69wP0CCSBQDUIFkAmZjuDQA1SBZAptKme5MsACQpqzUCDeJsCADUoFkASMJhCJCJwxAAqEGyADIx\nKQsAapAsgExMygKAGmW1RqBBnA0BgBqNJIu7PnxiE6uZmofPm70eedgr9rddwuiuaLuA0Yx6By2S\nBQDUYMwCyESyAIAaNAsASTgMATIx3RsAapAsgExM9waAGmW1RqBBnDoFMBNsr7N9m+07bL9zhWX+\nsv/+NtsvGWd7JAsgU5vJwvacpEsknS1pj6SbbF8XETuXLHOupP8cESfYPl3S+yWdkbtNkgUwm06T\ndGdE3BURj0m6WtJrli3zakkfkaSI2CzpSNtH526QZAFkanmexXGSdi15vVvS6QnLHC9pb84GSRbA\nbIrE5ZZ/mTb1cxUkC6CD7l64Q3cv3DlokT2S1ix5vUa95DBomeP7P8tCswAyTXJS1tHzJ+ro+Sev\nE/P1i25YvsgWSSfYXivpbklvkHTesmWuk7Re0tW2z5D0QERkHYJINAtgJkXEou31km6QNCfp8ojY\naftt/fcvjYhP2z7X9p2SHpZ0wTjbpFkAmdqelBURGyVtXPazS5e9Xt/U9hjgBJCEZAFkajtZTBvJ\nAkASkgWQiWQBADVIFkAmLqsHADVIFkAmLqsHADVoFgCSlJWjgAZx6hQAapAsgEwkCwCoQbIAMjEp\nCwBqkCyATEzKAoAaZbVGoEGcDQGAGjQLAEk4DAEycRgCADVIFkAmJmUBQI1mksUZjaxlav79ac9o\nu4SRHfbog22XMLq/a7uAyWJSFgDUKKs1Ag3ibAgA1CBZAJlIFgBQg2QBZCJZAEANmgWAJByGAJmY\n7g0ANUgWQCamewNAjbJaI9AgTp0CQA2SBZCJZAEANUgWQCbmWSxj+0jb19jeaXuH7Rm7LhaAJqQk\ni7+Q9OmIeL3tVZIOm3BNADpoYLOw/UxJPxERb5akiFiUNIMXgwSax6Ssp/oRSffZvsL2121fZnv1\nNAoD0C3DmsUqSSdL+quIOFnSw5LeNfGqgBmwT3NTe3TBsBy1W9LuiLip//oa1TSLDZc8+Xz+tN4D\n6LqFh6SFh9uuYnYMbBYRcY/tXbafGxG3Szpb0vbly21YP6nygMmZP7z3eNxF9432+a78iz8tKSM0\nb5f0UduHSPqmpAsmWxKALhraLCJim6RTp1ALMFP27S8rWTDdG0CSsk4UAw1aXCRZAEAFyQLItG+x\nrN2HZAEgCc0CQJKychTQoH0McAJAFckCyESyAIAaJAsg0+JjJAsAqCBZAJn27ytr9yFZAAcY28+y\nvcn27bb/0faRKyz3btvbbd9q+29sP23QemkWQK7Fuek9RvMuSZsi4rmSPquaq9vZXivprZJOjogX\nSpqT9AuDVkqzAA48r5b0kf7zj0h6bc0y35X0mKTV/Vt8rJa0Z9BKyzroAprU3XkWR0fE3v7zvZKO\nXr5ARNxv+32S/k3S9yTdEBH/NGilNAtgBtneJOmYmrd+d+mLiAjbUfP550j6DUlr1bsX0N/ZfmNE\nfHSlbTbSLOKlTaxlenz/7N0n6Y62C8gwMNNisK8sSJsXVnw7Is5Z6T3be20f07/g9rMl3Vuz2CmS\nvhQR3+l/5hOSzpQ02WYBFGnRk1v3KWf1Ho/7yz8c5dPXSXqzpPf2//eTNcvcJun3bT9d0qPqXbn/\nq4NWygAncOD575LOsX27pJf3X8v2sbavl564EPeVkrZIuqX/uQ8MWqkjKoczI7Ed+5811iqm7t77\n265gdN9tu4AMs3YYcpakiEiKC7ZD28fbd0byAifXNikkCwBJGLMAci22XcB0kSwAJCFZALlIFgBQ\nRbIAcj3WdgHTRbIAkIRmASAJhyFArn1tFzBdJAsASUgWQC5OnQJAFckCyEWyAIAqkgWQi2QBAFUk\nCyAXyQIAqkgWQC6SBQBU0SwAJOEwBMjFYQgAVJEsgFxcKQsAqkgWQC4ufgMAVSQLIBdnQwCgimQB\n5CJZAEAVzQJ
AEg5DgFwchgBAFckCyEWyAIAqkgWQi2QBAFUkCyAXyQIAqkgWQK7CLn7TSLM46P5o\nYjXTc0rbBWR4bdsFZLin7QJGdInbrqDTSBZALi5+AwBVNAsASTgMAXJx6hQAqkgWQC6SBQBUkSyA\nXCQLAKgiWQC5CpvuTbIAkIRkAeRiujcAVNEsACThMATIxalTAKgamixsv1vS+ZL2S7pV0gUR8R+T\nLgzoPJLFk2yvlfRWSSdHxAslzUn6hcmXBaBrhiWL76o39WS17X2SVkvaM/GqgFnApKwnRcT9kt4n\n6d8k3S3pgYj4p2kUBqBbhh2GPEfSb0haK+lYSYfbfuMU6gK6b98UHx0w7DDkFElfiojvSJLtT0g6\nU9JHn7rYhiXP5/sPoON2L0h7FtquYmYMaxa3Sfp920+X9KiksyV9tbrYhqbrAibv+Pne43E3XTTa\n5zkb8qSI2CbpSklbJN3S//EHJl0UgHy2f872dtv7bJ88YLkjbV9je6ftHbbPGLTeofMsIuJPJP1J\nRs3Aga27yeJWSa+TdOmQ5f5C0qcj4vW2V0k6bNDCTPcGDjARcZsk2SvfYc32MyX9RES8uf+ZRUkP\nDlov072BMv2IpPtsX2H767Yvs7160AdIFkCuSU7KundBum9hxbdtb5J0TM1bF0bEpxK2sErSyZLW\nR8RNti+W9C5J7xn0AQBd80Pzvcfjdj71TE1EnDPmFnZL2h0RN/VfX6Nes1gRzQLI1ZHJUkPUDlxE\nxD22d9l+bkTcrt60iO2DVsSYBXCAsf0627sknSHpetsb+z8/1vb1SxZ9u6SP2t4m6UWS/mjQekkW\nQK6OnjqNiGslXVvz87sl/cyS19sknZq6XpIFgCQkCyBXR5PFpJAsACQhWQC5uPgNAFTRLAAk4TAE\nyDUbk7IaQ7IAkIRkAeTi1CkAVJEsgFyFJYuGmsXmZlYzLVtmsEduqblOcuf9UNsFoEEzuNcAHcGk\nLACoIlkAuZhnAQBVJAsgV2FnQ0gWAJLQLAAk4TAEyMVhCABUkSyAXEzKAoAqkgWQi0lZAFBFsgBy\ncTYEAKpIFkAukgUAVJEsgFzMswCAKpoFgCQchgC5mJQFAFUkCyAXp04BoIpkAeQiWQBAFckCyMWk\nLACoIlkAuZhnAQBVNAsASTgMAXJF2wVMF8kCQBKaBYAkNAsASWgWAJLQLAAkoVkASEKzAJCEeRZA\ntrK+SdbxZPG1tgsY0Za2C8hwe9sFZPhG2wUUqePN4uttFzCiWWtu0mw2i+1tF9C3OMVH+zreLAB0\nBc0CQBJHjPdtGNuFfZ0GB7KIcMpyvf/uH5x0OUs8M7m2SRn7bEjbvwCA6eDUKZCtGwOP08KYBYAk\nJAsgG5OyAKCCZAFkI1kAQAXJAsjWzbMhtv9U0islfV/SNyVdEBG1k0Jsz6n3pabdEfGqQeslWQAH\nnn+U9IKIOEm9L/+8e8Cy75C0QwnXKqdZAAeYiNgUEfv7LzdLOr5uOdvHSzpX0gclDZ1cyWEIkG0m\nBjjfIuljK7z355J+W9IRKSuiWQAzyPYmScfUvHVhRHyqv8zvSvp+RPxNzedfKeneiLjZ9nzKNmkW\nQLZJDnB+tf+oFxHnDPq07V9W7xDjp1ZY5ExJr7Z9rqRDJR1h+8qIeNOK6xz3W6dAiXrfOt0xxS0+\nf5RvxK6T9D5JL4uIbycs/zJJvzXsbAjJAsjW2TGL/ynpEEmbbEvSlyPi12wfK+myiPiZms8MTQ0k\nCyBDL1lsm+IWT2r9chAkCyBbNydlTQrzLAAkIVkA2To7ZjERJAsASUgWQDbGLACggmYBIAmHIUA2\nBjgBoIJkAWRjgBMAKkgWQDbGLACggmQBZGPMAgAqSBZANsYsAKCCZgEgCYchQDYOQwCggmQBZOPU\nKQBUkCyAbIxZAEAFyQLIxpgFAFSQLIBsjFkAQAXJAsjGmAUAVNAsACThMATIxgAnAFSQLIBsDHAC\nQAXJAsjGmAUAVJAsgGyMWQBAhSOi7RqAmWN76jtORHja21yKZgEgCYchAJLQLAAkoVkASEKzAJCE\nZgEgyf8HmzwGuPTq2GoAAAAASUVORK5CYII=\n", 243 | "text/plain": [ 244 | "" 245 | ] 246 | }, 247 | "metadata": {}, 248 | "output_type": "display_data" 249 | } 250 | ], 251 | "source": [ 252 | "x = brain.layers[0].Ws[0].eval()\n", 253 | "import matplotlib.pyplot as plt\n", 254 | "%matplotlib inline\n", 255 | "plt.matshow(x)\n", 256 | "plt.colorbar()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 138, 262 | "metadata": { 263 | "collapsed": false 264 | }, 265 | "outputs": [ 266 | { 267 | "data": { 268 | "text/plain": [ 269 | "array([ 11.01352692, 11.28201485, 12.03692055, 12.26954937], dtype=float32)" 270 | ] 271 | }, 272 | "execution_count": 138, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "brain.input_layer.b.eval()" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 88, 284 | "metadata": { 285 | "collapsed": false 286 | }, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "-2.0" 292 | ] 293 | }, 294 | "execution_count": 88, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | } 298 | ], 299 | "source": [ 300 | "game.collect_reward(0)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 269, 306 | "metadata": { 307 | "collapsed": false 308 | }, 309 | "outputs": [ 310 | { 311 | "data": { 312 | "text/plain": [ 313 | "array([ 1.39407934, 2.791605 , 0.92194436, 1.73143704, 
-1. ])" 314 | ] 315 | }, 316 | "execution_count": 269, 317 | "metadata": {}, 318 | "output_type": "execute_result" 319 | } 320 | ], 321 | "source": [ 322 | "np.concatenate([observation, np.array([-1])])" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 278, 328 | "metadata": { 329 | "collapsed": false 330 | }, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "text/plain": [ 335 | "1" 336 | ] 337 | }, 338 | "execution_count": 278, 339 | "metadata": {}, 340 | "output_type": "execute_result" 341 | } 342 | ], 343 | "source": [ 344 | "n_prev_frames" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 285, 350 | "metadata": { 351 | "collapsed": false 352 | }, 353 | "outputs": [ 354 | { 355 | "data": { 356 | "text/plain": [ 357 | "1" 358 | ] 359 | }, 360 | "execution_count": 285, 361 | "metadata": {}, 362 | "output_type": "execute_result" 363 | } 364 | ], 365 | "source": [ 366 | "action" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 294, 372 | "metadata": { 373 | "collapsed": false 374 | }, 375 | "outputs": [ 376 | { 377 | "data": { 378 | "text/plain": [ 379 | "[]" 380 | ] 381 | }, 382 | "execution_count": 294, 383 | "metadata": {}, 384 | "output_type": "execute_result" 385 | } 386 | ], 387 | "source": [ 388 | "performances[:10]" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 306, 394 | "metadata": { 395 | "collapsed": true 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "performances_1 = performances[:]" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 328, 405 | "metadata": { 406 | "collapsed": false 407 | }, 408 | "outputs": [ 409 | { 410 | "data": { 411 | "text/plain": [ 412 | "0.89357223079637949" 413 | ] 414 | }, 415 | "execution_count": 328, 416 | "metadata": {}, 417 | "output_type": "execute_result" 418 | } 419 | ], 420 | "source": [ 421 | "np.average(performances[-1000:])" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "np" 433 | ] 434 | } 435 | ], 436 | "metadata": { 437 | "kernelspec": { 438 | "display_name": "Python 2", 439 | "language": "python", 440 | "name": "python2" 441 | }, 442 | "language_info": { 443 | "codemirror_mode": { 444 | "name": "ipython", 445 | "version": 2 446 | }, 447 | "file_extension": ".py", 448 | "mimetype": "text/x-python", 449 | "name": "python", 450 | "nbconvert_exporter": "python", 451 | "pygments_lexer": "ipython2", 452 | "version": "2.7.6" 453 | } 454 | }, 455 | "nbformat": 4, 456 | "nbformat_minor": 0 457 | } 458 | -------------------------------------------------------------------------------- /notebooks/karpathy_game.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2\n", 13 | "%matplotlib inline" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": { 20 | "collapsed": false 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "from __future__ import print_function\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "import tempfile\n", 28 | "import tensorflow as tf\n", 29 | "\n", 30 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n", 31 | "from tf_rl.simulation import KarpathyGame\n", 
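"# (Editor's note -- the following comment lines are not part of the original\n", "# notebook. DiscreteHill in game_memory.ipynb above shows the minimal interface\n", "# a tf_rl simulation implements; roughly, a controller is driven like this:\n", "#     while not game.is_over():\n", "#         action = controller.action(game.observe())\n", "#         reward = game.collect_reward(action)\n", "#         game.perform_action(action)\n", "# simulate(), imported on the next line, runs a loop of this shape for\n", "# KarpathyGame -- see tf_rl/simulate.py for the exact details.)\n",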
32 | "from tf_rl import simulate\n", 33 | "from tf_rl.models import MLP\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "/tmp/tmp7dz3utz0\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "LOG_DIR = tempfile.mkdtemp()\n", 53 | "print(LOG_DIR)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": { 60 | "collapsed": false 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "current_settings = {\n", 65 | " 'objects': [\n", 66 | " 'friend',\n", 67 | " 'enemy',\n", 68 | " ],\n", 69 | " 'colors': {\n", 70 | " 'hero': 'yellow',\n", 71 | " 'friend': 'green',\n", 72 | " 'enemy': 'red',\n", 73 | " },\n", 74 | " 'object_reward': {\n", 75 | " 'friend': 0.1,\n", 76 | " 'enemy': -0.1,\n", 77 | " },\n", 78 | " 'hero_bounces_off_walls': False,\n", 79 | " 'world_size': (700,500),\n", 80 | " 'hero_initial_position': [400, 300],\n", 81 | " 'hero_initial_speed': [0, 0],\n", 82 | " \"maximum_speed\": [50, 50],\n", 83 | " \"object_radius\": 10.0,\n", 84 | " \"num_objects\": {\n", 85 | " \"friend\" : 25,\n", 86 | " \"enemy\" : 25,\n", 87 | " },\n", 88 | " \"num_observation_lines\" : 32,\n", 89 | " \"observation_line_length\": 120.,\n", 90 | " \"tolerable_distance_to_wall\": 50,\n", 91 | " \"wall_distance_penalty\": -0.0,\n", 92 | " \"delta_v\": 50\n", 93 | "}" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": { 100 | "collapsed": false 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "# create the game simulator\n", 105 | "g = KarpathyGame(current_settings)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": { 112 | "collapsed": false 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "human_control = False\n", 117 | "\n", 118 | "if human_control:\n", 119 | " # WSAD CONTROL (requires extra setup - check out README)\n", 120 | " current_controller = HumanController({b\"w\": 3, b\"d\": 0, b\"s\": 1,b\"a\": 2,}) \n", 121 | "else:\n", 122 | " # Tensorflow business - it is always good to reset a graph before creating a new controller.\n", 123 | " tf.reset_default_graph()\n", 124 | " session = tf.InteractiveSession()\n", 125 | "\n", 126 | " # This little guy will let us run tensorboard\n", 127 | " # tensorboard --logdir [LOG_DIR]\n", 128 | " journalist = tf.train.SummaryWriter(LOG_DIR)\n", 129 | "\n", 130 | " # Brain maps from observation to Q values for different actions.\n", 131 | " # Here it is a done using a multi layer perceptron with 2 hidden\n", 132 | " # layers\n", 133 | " brain = MLP([g.observation_size,], [200, 200, g.num_actions], \n", 134 | " [tf.tanh, tf.tanh, tf.identity])\n", 135 | " \n", 136 | " # The optimizer to use. 
Here we use RMSProp as recommended\n", 137 | "    # by the publication\n", 138 | "    optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)\n", 139 | "\n", 140 | "    # DiscreteDeepQ object (DQN-style controller: experience replay plus a slowly updated target network)\n", 141 | "    current_controller = DiscreteDeepQ((g.observation_size,), g.num_actions, brain, optimizer, session,\n", 142 | "                                       discount_rate=0.99, exploration_period=5000, max_experience=10000, \n", 143 | "                                       store_every_nth=4, train_every_nth=4,\n", 144 | "                                       summary_writer=journalist)\n", 145 | "    \n", 146 | "    session.run(tf.initialize_all_variables())\n", 147 | "    session.run(current_controller.target_network_update)\n", 148 | "    # graph was not available when journalist was created \n", 149 | "    journalist.add_graph(session.graph)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "metadata": { 156 | "collapsed": false, 157 | "scrolled": false 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/html": [ 163 | "<SVG frame markup stripped in this dump; the rendered visualization reported the following status lines:>\n", 445 | "      fps = 108.6\n", 451 | "      nearest wall = 112.5\n", 457 | "      reward = -0.0\n", 463 | "      objects eaten => friend: 191, enemy: 144\n", 468 | "\n" 469 | ], 470 | 
"text/plain": [ 471 | "" 472 | ] 473 | }, 474 | "metadata": {}, 475 | "output_type": "display_data" 476 | }, 477 | { 478 | "name": "stdout", 479 | "output_type": "stream", 480 | "text": [ 481 | "Interrupted\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "FPS = 30\n", 487 | "ACTION_EVERY = 3\n", 488 | " \n", 489 | "fast_mode = True\n", 490 | "if fast_mode:\n", 491 | " WAIT, VISUALIZE_EVERY = False, 50\n", 492 | "else:\n", 493 | " WAIT, VISUALIZE_EVERY = True, 1\n", 494 | "\n", 495 | " \n", 496 | "try:\n", 497 | " with tf.device(\"/cpu:0\"):\n", 498 | " simulate(simulation=g,\n", 499 | " controller=current_controller,\n", 500 | " fps=FPS,\n", 501 | " visualize_every=VISUALIZE_EVERY,\n", 502 | " action_every=ACTION_EVERY,\n", 503 | " wait=WAIT,\n", 504 | " disable_training=False,\n", 505 | " simulation_resolution=0.001,\n", 506 | " save_path=None)\n", 507 | "except KeyboardInterrupt:\n", 508 | " print(\"Interrupted\")" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 8, 514 | "metadata": { 515 | "collapsed": true 516 | }, 517 | "outputs": [], 518 | "source": [ 519 | "session.run(current_controller.target_network_update)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 9, 525 | "metadata": { 526 | "collapsed": false 527 | }, 528 | "outputs": [ 529 | { 530 | "data": { 531 | "text/plain": [ 532 | "array([[ 0.00824914, -0.04792793, 0.08260646, ..., -0.00135659,\n", 533 | " -0.0149605 , -0.00065048],\n", 534 | " [ 0.04895933, -0.01720949, 0.03015076, ..., 0.04350275,\n", 535 | " -0.00071916, -0.00507376],\n", 536 | " [-0.03408033, -0.00734746, -0.07286905, ..., 0.06636748,\n", 537 | " 0.0507561 , 0.04723936],\n", 538 | " ..., \n", 539 | " [-0.01454929, -0.00313209, 0.02152171, ..., 0.01723659,\n", 540 | " -0.01757577, -0.02262094],\n", 541 | " [ 0.03226471, -0.09545884, 0.01721121, ..., 0.0179732 ,\n", 542 | " -0.01188065, 0.03430547],\n", 543 | " [ 0.02971489, 0.06272104, -0.05087573, ..., -0.00265156,\n", 544 | " -0.00139228, 0.01183042]], dtype=float32)" 545 | ] 546 | }, 547 | "execution_count": 9, 548 | "metadata": {}, 549 | "output_type": "execute_result" 550 | } 551 | ], 552 | "source": [ 553 | "current_controller.q_network.input_layer.Ws[0].eval()" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 10, 559 | "metadata": { 560 | "collapsed": false 561 | }, 562 | "outputs": [ 563 | { 564 | "data": { 565 | "text/plain": [ 566 | "array([[ 0.00792158, -0.04618821, 0.08227044, ..., -0.00076906,\n", 567 | " -0.01525626, -0.00064818],\n", 568 | " [ 0.04957703, -0.01709567, 0.03077196, ..., 0.04223152,\n", 569 | " -0.00179744, -0.00452804],\n", 570 | " [-0.03365123, -0.00761478, -0.07296818, ..., 0.0649555 ,\n", 571 | " 0.0494989 , 0.04785381],\n", 572 | " ..., \n", 573 | " [-0.01843384, -0.00160771, 0.02573192, ..., 0.01739593,\n", 574 | " -0.01991378, -0.01864818],\n", 575 | " [ 0.03407663, -0.09417476, 0.01543032, ..., 0.01674915,\n", 576 | " -0.00885472, 0.03387342],\n", 577 | " [ 0.02986682, 0.06386744, -0.05153623, ..., -0.00791921,\n", 578 | " -0.00157495, 0.01063347]], dtype=float32)" 579 | ] 580 | }, 581 | "execution_count": 10, 582 | "metadata": {}, 583 | "output_type": "execute_result" 584 | } 585 | ], 586 | "source": [ 587 | "current_controller.target_q_network.input_layer.Ws[0].eval()" 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": {}, 593 | "source": [ 594 | "# Average Reward over time" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": 11, 600 | 
"metadata": { 601 | "collapsed": false 602 | }, 603 | "outputs": [ 604 | { 605 | "data": { 606 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEACAYAAAB27puMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXmYHFW5h9/JZCZMJglZZjo7SQwgJIAE7g1BEIY9IKso\n6FVEUAE1rIqAiCQ8XkUu94KCyKpGkMAFBFGDEJZBvAgIYc9CEjLZk54sE8JM9vT946uyqnu6p6un\n1+r+vc9TTy19TtVXPT31q+9853wHhBBCCCGEEEIIIYQQQgghhBBCCCGEEEIIIYQQAoDJwDxgAXBV\nijK/cD5/G5gQoO5/AXOd8n8Advd9do1Tfh5wfPbmCyGEKDbVwEJgNFADvAXsm1DmJGCms30I8EqA\nuscBPZztG50FYJxTrsapt9BXTgghRJHI9kE8EXugtwDbgYeA0xLKnApMd7ZfBfoDQ9LUnQXs8tUZ\n4WyfBsxwyrc49SdmeQ9CCCGyJFsxGQ4s8+0vd44FKTMsQF2A8/E8m2FOuXR1hBBCFJBsxSQWsFxV\nN89/LbANeDAHNgghhMgTPbOsvwIY6dsfSbznkKzMCKdMTZq6X8PiLcekOdeKRKPGjh0bW7RoUaAb\nEEII8S8WAXsW48I9nYuPBmpJH4CfhBeA76ruZOB9oCHhXG4AvhYY49RP5vXEwsz1119fbBO6TZht\nj8Vkf7GR/cWFLFp6svVMdgBTgKex3ln3YV16L3Q+vwsTkpOwYHk7cF6augC3YYIxy9n/B/BtYA7w\nv856h3NMzVxCCFFkshUTgKecxc9dCftTMqgLsFcX1/uJswghhCgRNEajBGlqaiq2Cd0mzLaD7C82\nsj+8dLeXVanjNP8JIYQISlVVFXRTF+SZCCGEyBqJiRBCiKyRmAghhMgaiYkQQoiskZgIIdKyYwec\ncAI891yxLRGlisRECJGWaBSeeQZefLHYlohSRWIihEhLNBq/FiIRiYkQIi0SE5EOiYkQIi3RKIwY\nITERqZGYCCHS0toK++1nYhKLwcKFMH8+7NqVvq6oDCQmQoi0RKMwbhysXQtz5piw/Nu/qXeX8MhF\n1mAhRJkTjcKECdDeDitXwuGHw7Bhti0EyDMRQgQgGoU99rDxJitXQiRii2IowkViIoRISzRq4lFf\nDy0tEhPRGYmJECItfjFZvBgaGyUmIh6JiRCiS37yE2vaamz0xKTUPJM774THHy+2FZWNAvBCiC6Z\nOhXuvRf69i1dMfnWt6y32RlnFNuSykWeiRAiJdu321iSc86x/fp6WLbME5PW1uLa56dv32JbUNlI\nTIQQKWlvNwGpciZyra+3dSRizV7uIMZi0t5u6x56mhUVff1CiJS4YuLiF5O6OqipgU2bimObi+sd\nuaIiioPERAiRkkQxcT2UPn1sXQpxk2gUdt9dYlJsJCZCiJQkisnw4bZ2RSUSgenTbX/Fitxcc7fd\n4NFHg5ePRmHMGIlJsZGYCCFSkigmt98eHyOJRODll217yZLcXHPrVnj11eDlJSalgcRECJGSRDFJ\nJBKBd9+17WI1d/nFpNidASoZiYkQIiVBxMQNgOdCTNyU9pmkto9GLelkz57m1YjiIDERQqSkvR16\n9079eSRi6332yY2YbNhg67a24HX8qV7U1FU8JCZCiJTMnBlMTPbbD55/3mvySqSjA155xdtftszi\nL/Pm2f6uXdDcDH/5i+37pwl+773U19+0CZ56yuzYtg3+7/8C3ZbIAxITIURKHnoITj459eef+Qxc\nfjlceSVUV8OttyYvd8stcOih3v6DD8L3vgd33GH7//gHHHUU/Pa3Nm+KKybnnAP775/6+i+9ZEJ1\n8MFw0klw//0Z3Z7IIRITIURSNm+G2lo49dTUZUaMgP/5H5g4ES6+OHV6lS1b4vejURg/3hMNN9bR\n2grXXusd//jjrm2MRuGss6ChAS68ENatS39fIj9ITIQQSWltteYjd0xJOroawJh4jtZWExNXfNav\nt/W8efEik+7abrwk3fVF/pGYCCGSEo1a/q2guLm6kuEGxt11NGpxFn9sBGwmxzFjrItvkGC6xKR0\nkJgIIZLif1AHoauHeWL34cRmLn+9Xr28c+3cacfcdVc2DhpkvcFSlRX5RWIihEhKa2tmnkmfPvYg\n7+iw/dWrLeZywAHexFUnnGBpV6JRm39k40Y4+ujOIjRsmB2fPdv2b7ml8/VaWizgPmyY7ffsaZ7N\nf/93RrcpcoQmxxJCJGXTJujXL3j5qiqbU2TTJutOPH++zYdSUwOvvw4DB8Jxx8HSpSYeQ4fCggUw\nahT0728TcJ1yip3rySdNzGprYcYMmDOn8/WWL4exY+GYY7xjN99s1xKFJxeeyWRgHrAAuCpFmV84\nn78NTAhQ9wvA+8BO4CDf8dHAZuBNZ7kja+uFEElJN/o9Gf6Bg66H0tAAn/ykeTmDBtkUwHV1ltBx\n5EjbbmmBPff0mqwaGmDffU0sxo71zuVn61bYY4/4eUwOOKC0JuyqJLL1TKqB24FjgRXAP4Engbm+\nMicBewJ7AYcAvwImpan7LnAGcFeSay4kXpCEEHkgWzFxH+r+Hln+aX+7OpbqnH62bDFB8qMgfPHI\n1jOZiD3cW4DtwEPAaQllTgWmO9uvAv2BIWnqzgM+yNI2IUQWZCsmybr3phKTtrbMxWTrVgvW+5GY\nFI9sxWQ4sMy3v9w5FqTMsAB1kzEGa+JqBg7PzFwhRBCWLOm+mCxdCm++mfyhnkpMqqthwIDU5wwq\nJg0NNnAxk0SRIjdk28wVNOFzwGFPaVkJjAQ2YLGUJ4DxQKeJQ6dOnfqv7aamJpqamnJkghDlz+jR\n9hD/9Kczq1dfD9ddBx98AOeeaz2tvve9+M8XL4Yjjog/1tiYeg73TJq5amqs/MaNqcVJeDQ3N9Pc\n3JyTc2UrJiuwh7vLSMzD6KrMCKdMTYC6iWxzFoDZwCIsFjM7saBfTIQQmdNdz6SlxbZXrIC7747v\nbdW7t4nJ5z8ff6yr8SyZeCb+8hKT9CS+aE+bNq3b58q2met17GE+GqgFzsaC6H6eBL7qbE8C2oA1\nAetCvFfTgAXuAT7h1P8wu1sQQqSiq4zBqcpvc173kgXV6+tNBBKbufIhJqKwZOuZ7ACmAE9jD/n7\nsN5YFzqf3wXMxHp0LQTagfPS1AXryfULTDz+gsVITgSOBKZhAftdznUymPlACJEJ3fFMXJYuTS4m\n0FlMdt899Tl795aYhIFcDFp8yln8JHbpnZJBXYDHnSWRx5xFCJEn/MHr7opJz542YDFxBH13xWTz\nZrPLH1dJFjNxzycxKTxKpyKEiMM/QLBPn8zq9u1r6zFjrG5iM5n7uV9M+vbtupmruto8kM2bvWOt\nrXD99fJMSgml
UxFCxNHebm/8jz0Ge+2VWd1vftMmwbruuuQJF08+Gf78Zxvd7nLVVck9DD91deaJ\nuJ7Nyy/bWmJSOsgzEULE0d4OgwfbzIWpuuumIhKxBI2pgup9+8JnPxt/3mHDLG9XV/Tq5U2gBZ7H\nIzEpHeSZCCHi6E6X4ETq682byBW77RY/W+OOHbb2N335ry0xKTwSEyFEHKUoJomeiRvX2bgx+bUl\nJoVHzVxCVBC7dnlv9anIlZhkMrFWOhLFxBWLVGKyaZMmySo0EhMhKojvftfiIV2RCzE54ACYkMPc\n3qnE5OijO5ft2xduuMFiPqJwqJlLiAritddg/fquy+RCTL7znezqJ5IYM2lvhyuugDPP7FzWHdvy\n/PO5tUF0jTwTISqIWIDUrLkQk1yTzDNJZWMum9dEcCQmQog4wi4mmcxbL3KHxESICsL1TLryUEpR\nTJI1c8kzKS0kJkJUCMuWwSuv2PasWanLlaKYuJ7Jzp0wfTo8/nh6z2THDpgzp3A2VjoSEyEqhGee\nsUmvzjgDZsxIXa6UxeSDD+BrX4NVq+D445OX7d3bUsFMnAiPPlpQMysa9eYSokKIRuHss+Gww+Cu\nxLzePkpZTDY5c6pOmADDu5jk+3Ofs8m55s8vjH1CnokQFUM0avGESCT5/OwupSgmbsykK7sTSXef\nIrdITISoEMIsJq5n4todpIuzxKSwSEyEqBASxSTVAzkMYhIEiUlhkZgIUSG4YlJfD1VVqZMhfvxx\n5pNi5ZtBg2DaNFi50vaPPDJ9nUjEJtEShUFiIkSF0NrqjcHo6q197VpoaCicXUG4+GLLuTV3Lvzu\nd3DrrenrDBwIbW3pE1uK3CAxEaIC2LUrXiS6EhPXgyklqqpgjz3gvfeC21ZdDQMGwLp1+bVNGBIT\nISqAtjZruqqttf1UYrJ5s/Wa2n33wtoXhEgEVq/OTOgUNykcEhMhKoBEbyNVPMFtCquqKpxtQfE3\n0WVSR2JSGCQmQlQAiWLS2GgpVZYv77pcKeGmSckkniMxKRwSEyEqgPXrLX7gcvLJNjr8zjvjy5Wy\nmJx+Olx2mXUTDorEpHAonYoQFUB7e3x338MPh4sugjfeiC9XymJyzDG2ZILEpHDIMxGiAkg2EDES\ngTVr4o+Vsph0B4lJ4ZCYCFEBJBOTwYM7P2glJqK7SEyEqABSeSaJD1r/wMZyQGJSOCQmQlQAqTyT\nZctg0iR48UU7Fo2W17S3EpPCITERogJIJiZ9+8Ls2TBqFLz1lh376CPo16/w9uULiUnhkJgIUQG0\nt9sMhInstx984hNe0sdSzBicDX37wvbtNrJf5BeJiRAVQFciUV9fvmJSVWXNdsoenH8kJkJUAJUq\nJqCmrkIhMRGiDPnoI5tMykViYttr1wabpVFkjsREiDKkf3847zxvP6iYdHSUt5g0NsIDDxTXnnIl\nF2IyGZgHLACuSlHmF87nbwMTAtT9AvA+sBM4KOFc1zjl5wHHZ2m7EGVJLAYLFnj7QcRk2zbbd9PU\nlwuJzVyrVxfPlnImWzGpBm7HRGEc8CVg34QyJwF7AnsBFwC/ClD3XeAM4G8J5xoHnO2sJwN35OAe\nhChL/M05QcSkHJu4QDGTQpHtg3gisBBoAbYDDwGnJZQ5FZjubL8K9AeGpKk7D/ggyfVOA2Y45Vuc\n+hOzvAchyh6JSbGtKH+yFZPhwDLf/nLnWJAywwLUTWSYUy6TOkJUPBKTYltR/mQrJkH7ReRz3jb1\nzRCiC267zQLryQYtgicmM2YU1q5C4YqJ2+yn3lz5Idv5TFYAI337I4n3HJKVGeGUqQlQN931RjjH\nOjF16tR/bTc1NdHU1JTm1EKUJ5dcYuseKV4dBwywybOeey6+B1i54IqJOwpeo+E9mpubaW5uLrYZ\ngInRImA0UAu8RfIA/ExnexLwSgZ1XwAO9u2Pc8rVAmOc+sm8npgQlQzEYgcd5G139S+xc2cs1rNn\nLLbXXrHYO+8Uxr5CsnlzLFZTE4utWWPfw5VXFtui0oUsWnqy9Ux2AFOAp7HeWfcBc4ELnc/vwoTk\nJCxY3g6cl6YuWE+uXwANwF+AN4ETgTnA/zrrHcC3UTOXEFnRo4fNq75gQXmln3fZbTeoq4OVK23f\nHVMjcksupu19yln83JWwPyWDugCPO0syfuIsQogk7Npl60xiA4MG2fiLQYPyY1OxiUSgpcW2JSb5\nQWM0hCgDOjpMRLZts/gHWEqVoFQ5jcU9c/F6WYI0NsIHzmCDbMQkFvMGd4p4JCZChJyHH7YeWVdf\nbc1Ve+wBQ4fCokXwxBPBzjFxoqWiL1f22w+mTrWU9Js2df88P/0p9OqlHmHJkJgIEXJWr7Zuv3Pm\nWCqUjg6LD3zjG7BmjTVdpUvBft99Jj7lyt132/fy3HOW7LG7uE1lmXh9lYLERIiQ097uxQT8AXR3\n/MiWLfY2LbIfwOjW1SDIzkhMhAg5HR32kFy8OLmYbN0qMXFpbIwfwJgpEpPUSEyECDmuZ+KKikt9\nvcUHdu6Empri2VdK9O5t30V34ybRKIwbJzFJhsREiBAzaxZs2OCJSKKYbNhgXklVPhMahYxIxIsh\nffAB3HSTxZt++UtYssQrt24d/PznFkt64w07Fo1aMF9i0pky7QgoRGVwvDOjz9VX2zpRTNavVxNX\nIgMGmMgC3H8/3HwzvPkmPPSQNQlecYV99re/wWWXwQ9+YF5fR4d9Pno0tLUVzfySRZ6JEGVAKs9k\n3TobAS48/DNLRqMwfjy8/77t+3u9+WefdD+LROLrCw+JiRAhxZ+wMJWYyDPpjF8MWlut2eq992zf\n33yVKBjRqMSkKyQmQoQU/1t0V2IizySeZJ5JLAb9+0tMskFiIkRI8T/4GhpsnayZS55JPL17x4vJ\nfvvZdmJgvb0d+vTx9leulJh0hQLwQoSMhQvhhBPg44+9Y4MGWQqVIUO8YwMH2oDFfv0Kb2MpU19v\n87YsXGjisf/+dvyQQ6z31qhRtr90KRx0EMyebfvf/CZ8//sSk1TIMxEiZHzwgeXfeu01S+uxapXt\nf/hhvHDsvTesWAEzZ6Y+VyXiTk38179acH34cEs7c+ONsGyZ9eK65x4rc8459v3uuaft33CDxCQV\n8kyECBnRqImH+wbdt6+tk8VGhg0rnF1hwR3AuWqVjYivqvKaB13Pzv0u+/SxY+44nV69JCapkGci\nRMhwA8Gie7hdfd0YSDLceV3q6mztT78iMUmOxESIkCExyQ6/EKT6Ht15XbZs6fyZxCQ5EhMhQobE\nJDv8QuDGT4KU9deRmHRGMRMhQobEJDsuuQT23deC7ocdlrrclClw8sm2/fOfe3m7JCbJKdf0b7GY\npkITZcrBB9tkTwcfXGxLKpNt20xQtm8vtiW5p8p6GnRLF9TMJUTIiEatF5IoDrW1ttZc8PFITIQI\nEbGYxKQUUFNXZyQmQoSIjz6ysQ5ul1VRHCQmnZGYCFHCxGJw3HG2/t3v4
NRTYfDgYlsl6utttPwp\np8BJJ1lqlkpHYiJECfPRR/DsszYZ05//DIcfDo8/XmyrRH29zb744Yfmobj5uyoZiYkQJYybxTYa\nteXYY70st6J41NfD4sWWs0vT+BoSEyFKmEQx0fiS0sAVk0jEFomJxESIksZ9SK1ZY4vEpDSQmHRG\nYiJECePOpvjuuxY/cRMQiuKSKCarV9v8J6nGSm/ebPPP7Npl++3t5dcbTGIiRAmzbp313vr97+Go\no6CH/mNLggkTbKzPpz5lYvLHP9qUAO++m7x8v342VcCNN9p+nz62vPFG4WzON/ppClHCbNkC3/qW\ndT195pliWyNcrrgCFiyApqb4AaQrViQvv2OHrefNiz++YUNezCsKEhMhSpgtW5JPeiVKB38cy22W\n7IqdO71tNzVLOSAxEaKE2brVRryL0qV/f287SCB+3Tpv2531sRyQmAhRwkhMSh83jjVgQHoxaW+P\nL+P3UsKOxESEhrY2uP9+WL4crroKNm0qtkX54Te/sZ4/YM1cEpNwMHx4ejH55z/hhhu8/XJKYy8x\nEaHhoYfgq1+Fv/0Nbropdc+ZsHP++TBzpm1v3aqYSRh48UX40Y9Si8m4cfDjH8OVV8Khh0JDgx2X\nmMQzGZgHLACuSlHmF87nbwMTAtQdCMwCPgCeAdxWydHAZuBNZ7kjB/aLkOBOseofFV6uuPeqZq5w\ncMQRMHp06t/kzp1w5plw8cVw+eVw5JF2vJzmRMlWTKqB2zFRGAd8Cdg3ocxJwJ7AXsAFwK8C1L0a\nE5O9geecfZeFmCBNAL6dpf0iRLgP1ZUrbV2OYuIOZKty5rpTM1d46GokfKq/ozwTj4nYw70F2A48\nBJyWUOZUYLqz/SrmZQxJU9dfZzpwepZ2ijJg82Zbz51rffvLUUzce3JFRc1c4cH9TSYbBZ/q7yjP\nxGM4sMy3v9w5FqTMsC7qDgbWONtrnH2XMVgTVzNwePdNF2HDfcC+9175ZmpNJibyTMJB797W1TdZ\nx5BUf0d5Jh4pMtF0IsgE9VUpzhfzHV8JjMSauK4AHgT6BrRBlDCf/az9I9bUWOqJ9eu9zw47DKZO\n9R6wS5bAMcfAbbd5TV7lgjvo7Xe/s+/i73+XmISJsWNh4EDvtzxihHkqiWLymc/Yupw8k55Z1l+B\nPdxdRmIeRldlRjhlapIcd5MRrMGawlYDQwH3HXSbswDMBhZhsZhOU9NMnTr1X9tNTU00NTUFuiFR\nHObOtd5ZY8fCgQdaWoqBA+2zl1+2tBNnnQU//CFcfz307GmTRC1dCsOGFdf2XOJ2CZ49G4YOtdn8\nJCbhYfbs+LEjAweap5IYM7n0Upgzp/ieSXNzM83NzTk5V7Zi8jr2MB+NeQ1nY4F0P08CU7CYyCSg\nDROLdV3UfRI4F/iZs37COd4AbAB2Ap9w6n+YzDC/mIjSJxq1fvo1NZbYMLEJy82y2tBgQgL2sC23\npi7X+9q40TyyZcsUMwkTPXrEJ+OMRMx7rqryfrcuNTXF90wSX7SnTZvW7XNl28y1AxOKp4E5wMPA\nXOBCZwGYiT3wFwJ34fXASlUX4EbgOKxr8NHOPsARWPfiN4FHnGu0ZXkPosh0dNjbXJ8+tp+sV4wr\nJm6X2VTlwo4/Lbk7o6I8k/ASiaT2Lmtri++Z5JJsPROAp5zFz10J+1MyqAuwHjg2yfE/OIsoI1pb\nrSeM2x02WU+tjo7OYlKOPbra2y1V+aZNMH68HZOYhJfGxtRiUgqeSS7RCHhRNHbsgOeft/Qh/syr\nkQg8+6ylXXfZvBnef7+zZ/LCC3D33VY2FoMHHrCJisJKe7v3XezrjLoqp2SAlUYkAn/6U/KmynLz\nTCQmomg8/bT1yrrnHvjOd7zjJ59svbluv932hwyByZPh4INh4kSv3AknwJgxMH26lV26FM45Bx5+\nuLD3kUva2+H44+37GDfOerENGFBsq0R3+eIXzTu57LLOn5WbZ5KLZi4huoXr+n/603Deed7xCRNs\nQqinnAbQXbvg17+2gLuf8ePNK3ngASvrNnkFmVOiVGlvh4MOgosusv3rry+uPSI7jj3WlmTU1lrz\nbbkgz0QUDXdiIDfw7scfXE+MlaQqWw45u9Ldqygfys0zkZiIorF1q62rkgxpdQUiFrO3t6Bi0rev\nxESEg5oaxUyEyAmumCQjEoE1ayzwXlsL1dVdl3XFJOxpViQmlUNtrTwTIXKCKyZjx3b+rLHRRr3/\n+7/D7rt3fZ6GBli71sRn/HiJiQgH5eaZKAAvisaWLZaT6+qrO39WUwPz5tl4C/8c28morbXmrfnz\nbV6JJ5/Mj72FQGJSOZxxhvXcKxckJqJobN1qTVSpmrDGjAl+rkjEsgmfe278KPKwITGpHAYO9PLP\nlQNq5hJFI5fp1SMRG2cyerQF7JPNKREGJCYirEhMRNHYsiV3SQzdUeNDh1qzlzuRVtiQmIiwIjER\nRSOXnkljo7eurw9fU1csZmn3JSYirEhMRNHIpZgcdRSceKJ5OmEUk4cftomUqquVi0uEE4mJKBqJ\nEwZlw1lnwcyZth1GMXFTwMgrEWFFYiKKxtat+Zn4KYxismuXrXv3Lq4dQnQXiYkoGrls5vITRjHZ\nuNHW5TSITVQWEhNRNHLZzOUnjGLiT2opRBiRmIikbN8O3/ymbW/YAP/xH/GTVXWHF1+Es8+25Yor\n4Lnn8tPM1acP3HgjXH652e2fLCsWg69/Pdg4lK9/3eZHKcSYFVdMFHwXYUViIpKyahXce689SOfM\ngRkz4B//yO6cM2faw3LXLrjlFjjzTDjllNzY6+eGG6BnT7j1Vnj8cUvL4rJpk82N4jYrdcWvf21z\npQQpmy0bN8JPfwrNzfm/lhD5QGIikrJ+va23bs3dPCHRKBx9NBx5pO2ffHL6JI7dYZ99bFZGsJQs\n/qajoPfi90YKkTiyo8MmCTvwwPxfS4h8IDERSfG34edSTCIRb7S6f973XOMOYuyumPhTgxdCTDRY\nUYQdJXoUSUkUk6FDcycm7lSl+RQTt4ttY2P3xGTLls518onERIQdiYlISkuLrVetsgSK2Uw61dFh\nMYwlS+zh/tFHdjyfGVPdZip/z672dvjwQ9v238u2bZ5NYHEd/8Rd7neRTyQmIuxITEQnduyA666z\n7UmTbPKpyy+3YHZ38D8khwyxuUeg69kTs+WQQ2DQoHgx2WcfE7Z+/bwR5wBf/jI8/bQ3J31bW3wg\n/Lvftd5n+URiIsKOYiaiE2vXWhOUGxxvbbUutt3xTPyB7P/8T6irM3HKd3fbSZPsPlwx2bHDughH\no3DVVfFNX0uXwjPPWPm1a212x0WLvM979ICdO/Nrr8REhB2JieiEG9vwP3AbG+14piLgP0c+YySp\ncMVk3ToYMMC8ocRBje79+u1ctszb793bi/Pkg+3brbu06xkJEUYkJqIT7sN1xw6vKaq+3rY//jjz\nc7kUU0z8ghFETJYu9fZ7987v
yHTXK6mqyt81hMg3EhPRicQHr0skknlTVzRqwXvwuusWknRi0t5u\n3lbiffo9k3ynZ1ETlygHFIAPSDQKs2bBHnvAZz5TbGvyi//B26+fdzwSgZ/9DIYNs/1jjun6u1iz\nBn75Sxvr8d57xXlgphKTF16AqVPN/kgk3iuIROCvf/X26+rg+uvh+OPhq1/NvY0SE1EOyDMJyLXX\nwle+AkccUWxL8o/74H3jDXvoulx7rSckc+bAHXd0fZ6XXrJzXHYZPPIIjB+fP5tTMWCA9c5KFJPW\nVrjpJksZk9j8dtpp3vbbb1u85fe/h6uvzo+NbW3Qv39+zi1EoZBnEpB89+YpJaJR61p70EHxx085\nxcul9eyzlkuqK9rbLa3J0Ufnx84guB0HkjXdjR8P774Lo0fH13H3Bw+GAw6IbxLLB9GoXUuIMCPP\nRHQiMSCdjCDxk1JovnHtTCYm++0H77+f+l7b2mztdjpw4yu5Zs2a4nROECKXSEwC4s6EVwnkSkw6\nOoovJoMGWQr9VauSiwkkv9fq6vhR8L162XgTf86uXBHk+xai1JGYBMTNogvlP4FRa2v6h1tDg30n\nXTX/lYJn0rOnxSPmzs1MTBKzGdfX25KP8SYSE1EOSEwC4n8LHznSHqZ77AGjRtl2Y2N+JnrKN+ef\nb/YPHWpv39//vt1rum68PXvaA3fECNh7b/PcolEYPtzO19BgvaWKLSZgf6fFi81WsMmzAMaNs/XI\nkZ3rfPaa5EJZAAAQJ0lEQVSznth86lMweXJ+ughPm2YdGZLZIESYUAA+INEozJ9vDyT37dQVkLlz\nrftrGHn9dXjsMeuptXq1zYa4fbuXP6sr+vSx5I0bN8LmzZYQMRKxLtRz5ti8JaUgJn//uzVPud5G\nfb03IHPjxuT3On26Fx+ZPdu6Du+zT+7F5M034e674fOfz+15hSg0ufBMJgPzgAXAVSnK/ML5/G1g\nQoC6A4FZwAfAM4C/4+Q1Tvl5wPHZmx+M1lZLUti7t/fmDdaOPmpUoazIPdGoeRYDBtj+smWdx12k\nwi3jH8sxdKh9N3vv7X1WbOrqOjdbuSP7+/VLfq9VVfa3BVtXVeXHM4lGYexYjX4X4SdbMakGbsdE\nYRzwJWDfhDInAXsCewEXAL8KUPdqTEz2Bp5z9nHKne2sJwN35OAe0tLR0fXbelgfBLt22RgKVxgh\nPlAdlGQDAwcNsnVdXW5sLQXyISZB4lNChIFsH8QTgYVAC7AdeAg4LaHMqcB0Z/tVzMsYkqauv850\n4HRn+zRghlO+xak/Mct7SIv7Dx9ENPKdDTeXrF9vb+Y1NfHHM324ubmr/GLintPfIyrs5MszkZiI\nciBbMRkO+LIYsdw5FqTMsC7qDgbWONtrnH2cOsvTXC/nPPts8LnK/TP0pWPzZhsZ/vLL8Mc/Bu9+\nvH49/OlP2QvXCy94wWg/QeIlftyH7MyZnR+M+cy2W2jq6+G55yxN/SOPwB/+YLGX7rJli/0Ggv62\nhChlsg3AB32cBWkIqkpxvlia6yT9bOrUqf/abmpqoqmpKYAJyfnlL5OP4n70Ua8H1223wcUX20M1\naNPO88/bPCHuA2nOHNg3sZEwCQ8+aNdassR6KnUX/31NnWo9llpa4LDDgtW/915rFrv3Xpg3z9Kn\n3Hqr9/m118Kpp3bfvlLjc5+DH/zAPNVZs0woX3gBDjywe+fLxOMVIh80NzfT7J8JLguyFZMVgL9T\n40jiPYdkZUY4ZWqSHF/hbK/BmsJWA0MBt2NusnOtIAl+McmWjz6CKVM6Hz/zTG97yhTL9dTeHh+D\n6Ipo1LqnvvOOtx9ETPzzmGcjJm1tcMkltn3wwbZkwjHH2HrGDJsO96CD4lOw/PjH3betFPnyl+Ef\n/7Dl2GNNQDPxRBNRE5coNokv2tOmTev2ubJt5nodC6yPBmqx4PiTCWWeBNxcq5OANkwsuqr7JHCu\ns30u8ITv+Bed8mOc+q9leQ9pCfpPn2mbuj89u7sftF4m5bs6Ty4eZvX1No6jEh6MkYhlQI5EzCvN\nJiYkMRHlRLaeyQ5gCvA01jvrPmAucKHz+V3ATKxH10KgHTgvTV2AG4H/Bb6OBdrPco7PcY7Pcep/\nm+BNbd1iyxZ7YPhTsaeiO2LiZtLt3TszMamvz05Mdu2yZpZczDFSX2/T3IZ1rE0mRCI2ZiUSsUGe\nEhMhjFwMWnzKWfzclbCfpJEoZV2A9cCxKer8xFkKgvvADdKu3R0x+dSnbHvEiMzEZPz47MSkrc3s\nzcVUsa5nMjHv/eqKjyu+jY0mJmrmEsLQCPg03HRT8IB6d8TEfZh88pMWvK6rg+9+1+bU2LTJelb1\n6GHjNo49Fu65x0ZNn3OOBf3HjIEvfCH4NTs67Dy1tZ27BHeXfv1gxQob1FnuuPc4ZEh8M9evf21/\nr0svDX4uiYkoJyQmaXjtNbjuumBluysmy5ZZUPdPf4JrrrFJuF5/3VJ6nHSSV37WLPNknn/eAvd9\n+lhvokzExL0WmBjlgiuvhBNP7H6vpjAxaRK88op1NHjgAc8zufRSS1WfqZgUY8IwIfKBEj2mYcMG\nmygqCN0VkxEj4idHchMmnnBC5/Ljx5s9ffvaujtzsrscemhmdVMxYIBNf5vp+JQwUl1t33tNjWIm\nQviRmKQhk3/4TMQkFosPgPuv4V6zR8JfZ+fO+HJB5hRJxF9eD7LsUG8uITwkJl2wdavFGILOz52J\nmLS1WQ+uXr1sP5mYJCMXYuKmYs9FT65KRgF4ITzKPmbS0QELF1rKim3bLMNv0B5M//xn8J5cYGKy\nYoV1u3Vn5VuyBPbaK75cLGaxGP+DxC9Y77wTXExWrfIGPQ4ebFPA+hk50ssIvHixjZEYORKWLw/n\n/CulRK9eNljznXe8LAYtLfZ9b9pkLxZjx9rcL36WLLHvfvlyCboQpU7M5cc/jsUGD47F7BEei91z\nTywQixbFYr16xWLnnx+sfCwWi/3Xf9k17rzT9u+7z/ZXrYov9847sVhdXSz2ne/EHx81KhY75JBY\nbP/9Y7FHHrFjX/1qLLbXXnZ89OhY7MMPvfI7d8ZiRx5p5QcNsmuNGmX7++9v22efbWW3bo3Famvt\n+D33xGKf/GTw+xLJmTrVvvN99onFDj3Ulrq6WOzzn/d+b48/3rme+1m/foW3WYiuIItxe2XvmaxY\nAT/8oeWyAli5Mni9gw+G++4Lfi137o4lS2zd1mbr1avju82uXAmHHw633x5fv6Wl8zmnT+98zKVH\nD3DT6lx3naUv+e1vwc2O0NwM119v22vXwsCBnhfzjW8EuiXRBW4T5RtvWJMlmAf84YdemcTfmz+Z\n5y235Nc+IQpJ2cdMotHOPaW6Uy8IiRNBufGTxGvmo63cPV+qmIra53OPG3x3hQTsO16wwNtP/Nv7\nsyjr7yHKiYoQk8TgdnfqBcEdBLh9u60lJuXNxo2dj0UiFi9xSfzb+zto6O8hygmJScB6Q
XDfOteu\ntXV7uzVFFUJM3AD+wIHesYEDLePx9u0Sk3zw0UedjyUG1JOJidvlW38PUU6Ubczkhhust8ySJd4/\neK9e8O67cOGFNi/Fn/8MZ53lff6jH1l8A2yU+AUXZHZNtz38hRes7sMPW7qTBx6wdOUuL70E3/te\ndveXiNt+7x+b0qOHDbJ78EG4887ggy9FMJJNTuZPCDpkiP2OfvMbOM9Jb9rebr+JRYvUk0uUF+U6\nLU8MYhx3nP0Tf/GL1j24rs5mTbzzTgvIf+Urlg/r0kvtwdCzp00Y5T6QTzsts7jJ1q3w979bF9x3\n3rF0JTff3HlkeFUVnHFG8HlPgrBrl83YePjh8ccvuQSWLrWZHOfPh733zt01K50NG6yjhn8agXnz\nLGYyZox5hrffbi8PL71kn7/yiv3ebr7ZsgYIUUpU2TiIbulC2XomYLMIfulLtu2O9fja1+Cttzo3\nP2zdam/xF13U/ev16uVNGPXqqyYme+8Np5zS/XMGpUePzkICcPrp8O1vW5OKhCS3DBjgjeFx2Wcf\nW1xOP91yqrm0t1tHDQmJKDfKOmaSqk26sdHeKMFSlID3T57ra+fynN21Y/58tc8Xi8SUK7n+nQlR\nKlSkmEQiNt86xAfLy1VM/GtRWBKTQUpMRLlS1mKS6p82ErFAPFjAPRbL/T+5e65kQdpCMmiQrf29\nvEThSMzfJTER5UpZx0z22CP58U98wjySxka4/34LvF9wQX7+yYvtEVRX2zoxP5goDImeyccf2zw0\nQpQbZSsmXXkE++8Pmzfb9rPPwk9/mp83xmJ7JS6lYkclkhgz8U87IEQ5UdbNXEFwR4l3dKj5QeSe\nRM9Eg0dFuSIxccSkvV3NDyL3uDET1zuUmIhypeLFpKEB1q+3fEryTESuqa62xZ3vRGIiypWKF5Oe\nPS1t+GOPSUxEfvA3dUWjipmI8qTixQRg2jTr7XT22cW2RJQj/u7BGzZ0HjUvRDlQtrm5YurCJEqE\n4cNtmubhw01YNm7UlMmiNMkmN5c8EyHyjNvMtWOHLW6GZyHKCYmJEHnGFRN3LFNVubYHiIpGYiJE\nntltN4uZKJWKKGfKdgS8EKXCbrvBscfCEUdITET5Is9EiDwzaJCNZXriCYmJKF8kJkLkGf+MmhIT\nUa5ITITIMzU13rbERJQrEhMhCkhdXbEtECI/SEyEKCDu1AdClBvqzSVEnrnoIpvpsqMDDjyw2NYI\nkR+yGT41EHgYGAW0AGcBbUnKTQZuBaqBe4GfBah/DXA+sBO4BHjGOd4MDAHc97vjgLVJrql0KkII\nkSHFSqdyNTAL2Bt4ztlPpBq4HROUccCXgH3T1B8HnO2sJwN34N1cDPgPYIKzJBOS0NPc3FxsE7pN\nmG0H2V9sZH94yUZMTgWmO9vTgdOTlJkILMQ8j+3AQ8BpaeqfBsxwyrc49Q/xnbPsk1GE+QcZZttB\n9hcb2R9eshGTwcAaZ3uNs5/IcGCZb3+5c6yr+sOccv46w3z704E3gR9213AhhBC5JV0AfhYWo0jk\n2oT9mLMkknisqotyQYIcXwZWAn2Ax4BzgPsD1BNCCFGizMMTmqHOfiKTgL/69q8BrkpT/2ri4y9/\nJb6Zy+Vc4LYUti3EEygtWrRo0RJsWUgRuAlPGK4GbkxSpiewCBgN1AJv4QXgU9Uf55SrBcY49auw\nYL6bmKIGeBS4ICd3IoQQomgMBJ4FPsC67vZ3jg8D/uIrdyIwH1O8awLUB/iBU34ecIJzrB54HXgb\neA+4hQoIxgshhBBCCCFCyGTMm1mA14RWavwa6732ru/YQKyzQzIv7RrsfuYBxxfIxq4YCbwAvI95\niJc4x8NwD7sBr2LNqHOAnzrHw2C7n2qsR+OfnP0w2d8CvIPZ/5pzLEz298ea2Odiv6FDCI/9n8S+\nd3fZiP3/hsX+glGNNY2NxmIq/vhMKfEZbMClX0xuAr7vbF9F5/hRDXZfCyl+PrUhgJsUpA/WhLkv\n4bmH3s66J/AKcDjhsd3lCuD3wJPOfpjsX4w9vPyEyf7pWHYOsN/Q7oTLfpcewCrs5TCM9ueVQ4nv\nOZbYK6yUGE28mMzDG2czBK9nm7/3G9j9Tcq3cRnyBHAs4buH3sA/gfGEy/YRWKzxKDzPJEz2LwYG\nJRwLi/27Ax8mOR4W+/0cD7zkbOfE/nJSma4GSJY6mQzgLKV7Go15Wa8Snnvogb1trcFrrguL7WAd\nT64EdvmOhcn+GCaGrwPfdI6Fxf4xQCvwG2A2cA/WMSgs9vv5IpZpBHJkfzmJSazYBuQIt793V5+X\nAu7A0UuBTQmflfI97MKa6UYAR2Bv+H5K2faTgSjW3p2qJ2Mp2w9wGPYCciLwHazZ108p298TOAjL\nF3gQ0E7n1o9Stt+lFjgFeCTJZ922v5zEZAXW/ucyknhVLWXWED+AM+psJ97TCOdYsanBhOR+rJkL\nwncPG7Eu7AcTHts/jeW0W4y9VR6N/Q3CYj9YOz3YG/7jWP6+sNi/3Fn+6ew/ionKasJhv8uJwBvY\n3wDC8/0XjK4GSJYao+kcgM9kAGcxqQJ+hzW3+AnDPTTg9VSpA/4GHEM4bE/kSLyYSVjs7w30dbbr\ngf/D2u7DYj/Yb2ZvZ3sqZnuY7AdLuHuubz9s9heEVAMkS4kZWH6xbViM5zwyH8BZTA7Hmorewuti\nOJlw3MP+WFv3W1j31Cud42GwPZEj8XpzhcX+Mdh3/xbWrdz9Hw2L/QCfwjyTt4E/YEH5MNlfj03d\n0dd3LEz2CyGEEEIIIYQQQgghhBBCCCGEEEIIIYQQQgghhBBCCCGEEKLc+X+LTdyDrkqycQAAAABJ\nRU5ErkJggg==\n", 607 | "text/plain": [ 608 | "" 609 | ] 610 | }, 611 | "metadata": {}, 612 | "output_type": "display_data" 613 | } 614 | ], 615 | "source": [ 616 | "g.plot_reward(smoothing=100)" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 12, 622 | "metadata": { 623 | "collapsed": true 624 | }, 625 | "outputs": [], 626 | "source": [ 627 | "session.run(current_controller.target_network_update)" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 13, 633 | "metadata": { 634 | "collapsed": false 635 | }, 636 | "outputs": [ 637 | { 638 | "data": { 639 | "text/plain": [ 640 | "array([[ 0.00824914, 
640 | "array([[ 0.00824914, -0.04792793, 0.08260646, ..., -0.00135659,\n", 641 | " -0.0149605 , -0.00065048],\n", 642 | " [ 0.04895933, -0.01720949, 0.03015076, ..., 0.04350275,\n", 643 | " -0.00071916, -0.00507376],\n", 644 | " [-0.03408033, -0.00734746, -0.07286905, ..., 0.06636748,\n", 645 | " 0.0507561 , 0.04723936],\n", 646 | " ..., \n", 647 | " [-0.01454929, -0.00313209, 0.02152171, ..., 0.01723659,\n", 648 | " -0.01757577, -0.02262094],\n", 649 | " [ 0.03226471, -0.09545884, 0.01721121, ..., 0.0179732 ,\n", 650 | " -0.01188065, 0.03430547],\n", 651 | " [ 0.02971489, 0.06272104, -0.05087573, ..., -0.00265156,\n", 652 | " -0.00139228, 0.01183042]], dtype=float32)" 653 | ] 654 | }, 655 | "execution_count": 13, 656 | "metadata": {}, 657 | "output_type": "execute_result" 658 | } 659 | ], 660 | "source": [ 661 | "current_controller.q_network.input_layer.Ws[0].eval()" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "execution_count": 14, 667 | "metadata": { 668 | "collapsed": false 669 | }, 670 | "outputs": [ 671 | { 672 | "data": { 673 | "text/plain": [ 674 | "array([[ 0.00792486, -0.04620561, 0.0822738 , ..., -0.00077494,\n", 675 | " -0.0152533 , -0.0006482 ],\n", 676 | " [ 0.04957085, -0.01709681, 0.03076575, ..., 0.04224423,\n", 677 | " -0.00178666, -0.0045335 ],\n", 678 | " [-0.03365552, -0.00761211, -0.07296719, ..., 0.06496961,\n", 679 | " 0.04951147, 0.04784767],\n", 680 | " ..., \n", 681 | " [-0.01839499, -0.00162295, 0.02568982, ..., 0.01739434,\n", 682 | " -0.0198904 , -0.01868791],\n", 683 | " [ 0.03405851, -0.09418759, 0.01544813, ..., 0.01676139,\n", 684 | " -0.00888498, 0.03387775],\n", 685 | " [ 0.0298653 , 0.06385598, -0.05152962, ..., -0.00786654,\n", 686 | " -0.00157313, 0.01064544]], dtype=float32)" 687 | ] 688 | }, 689 | "execution_count": 14, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | } 693 | ], 694 | "source": [ 695 | "current_controller.target_q_network.input_layer.Ws[0].eval()" 696 | ] 697 | }, 698 | { 699 | "cell_type": "markdown", 700 | "metadata": {}, 701 | "source": [ 702 | "# Visualizing what the agent is seeing\n", 703 | "\n", 704 | "Starting with the ray pointing all the way right, we have one row per ray, in clockwise order.\n", 705 | "The numbers for each ray are the following:\n", 706 | "- the first three numbers are normalized distances to the closest visible (intersecting with the ray) object. If no object is visible, all three are $1$. If there are many objects in sight, only the closest one is reported. The numbers represent the distance to a friend, an enemy and a wall, in that order.\n", 707 | "- the last two numbers represent the speed of the moving object (x and y components). The speed of a wall is, of course, zero.\n", 708 | "\n", 709 | "Finally, the last two numbers in the representation correspond to the speed of the hero." 710 | ] 711 | },
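To make this layout concrete, here is a small decoding sketch (an editorial aside, not a notebook cell). The 32-ray count and the 5-numbers-per-ray layout are read off the printed output in the next cell; note that the cell slices the last *four* values off the vector, so the helper below does the same. The helper name is hypothetical.

```python
import numpy as np

def decode_observation(x, eye_observation_size=5):
    """Hypothetical helper: split a KarpathyGame observation into rays + tail."""
    rays, tail = x[:-4], x[-4:]
    rays = rays.reshape((-1, eye_observation_size))
    for i, (d_friend, d_enemy, d_wall, vx, vy) in enumerate(rays):
        if min(d_friend, d_enemy, d_wall) < 1.0:  # something intersects this ray
            print("ray %2d: friend=%.2f enemy=%.2f wall=%.2f speed=(%.2f, %.2f)"
                  % (i, d_friend, d_enemy, d_wall, vx, vy))
    print("hero values:", tail)
```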
712 | { 713 | "cell_type": "code", 714 | "execution_count": 15, 715 | "metadata": { 716 | "collapsed": false 717 | }, 718 | "outputs": [ 719 | { 720 | "name": "stdout", 721 | "output_type": "stream", 722 | "text": [ 723 | "[[1.00 0.55 1.00 -0.42 -0.40]\n", 724 | " [1.00 1.00 1.00 0.00 0.00]\n", 725 | " [1.00 1.00 1.00 0.00 0.00]\n", 726 | " [1.00 0.83 1.00 0.42 0.63]\n", 727 | " [1.00 1.00 1.00 0.00 0.00]\n", 728 | " [1.00 1.00 1.00 0.00 0.00]\n", 729 | " [1.00 1.00 1.00 0.00 0.00]\n", 730 | " [1.00 1.00 1.00 0.00 0.00]\n", 731 | " [1.00 1.00 1.00 0.00 0.00]\n", 732 | " [1.00 1.00 1.00 0.00 0.00]\n", 733 | " [1.00 1.00 1.00 0.00 0.00]\n", 734 | " [1.00 1.00 1.00 0.00 0.00]\n", 735 | " [1.00 1.00 1.00 0.00 0.00]\n", 736 | " [1.00 1.00 1.00 0.00 0.00]\n", 737 | " [1.00 1.00 1.00 0.00 0.00]\n", 738 | " [1.00 0.44 1.00 -0.18 0.76]\n", 739 | " [1.00 0.46 1.00 -0.18 0.76]\n", 740 | " [1.00 1.00 1.00 0.00 0.00]\n", 741 | " [0.89 1.00 1.00 -0.95 0.78]\n", 742 | " [1.00 1.00 1.00 0.00 0.00]\n", 743 | " [1.00 1.00 1.00 0.00 0.00]\n", 744 | " [1.00 0.44 1.00 0.45 -0.81]\n", 745 | " [1.00 0.20 1.00 -0.64 0.14]\n", 746 | " [1.00 0.19 1.00 -0.64 0.14]\n", 747 | " [1.00 0.21 1.00 -0.64 0.14]\n", 748 | " [1.00 0.57 1.00 0.56 0.78]\n", 749 | " [1.00 1.00 1.00 0.00 0.00]\n", 750 | " [1.00 1.00 1.00 0.00 0.00]\n", 751 | " [1.00 0.92 1.00 0.41 0.77]\n", 752 | " [1.00 1.00 1.00 0.00 0.00]\n", 753 | " [1.00 1.00 1.00 0.00 0.00]\n", 754 | " [1.00 1.00 1.00 0.00 0.00]]\n", 755 | "[1.00 -0.94 -0.25 0.46]\n" 756 | ] 757 | }, 758 | { 759 | "data": { 760 | "text/html": [ <the SVG markup of the game scene was stripped during extraction; the only readable text it carried is kept below> 1043 | "nearest wall = 124.1\n", 1049 | "reward = 0.0\n", 1055 | "objects eaten => friend: 192, enemy: 144\n" 1061 | ], 1062 | "text/plain": [ 1063 | "" 1064 | ] 1065 | }, 1066 | "execution_count": 15, 1067 | "metadata": {}, 1068 | "output_type": "execute_result" 1069 | } 1070 | ], 1071 | "source": [ 1072 | "g.__class__ = KarpathyGame\n", 1073 | "np.set_printoptions(formatter={'float': (lambda x: '%.2f' % (x,))})\n", 1074 | "x = g.observe()\n", 1075 | "new_shape = (x[:-4].shape[0]//g.eye_observation_size, g.eye_observation_size)\n", 1076 | "print(x[:-4].reshape(new_shape))\n", 1077 | "print(x[-4:])\n", 1078 | "g.to_html()" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": null, 1084 | "metadata": { 1085 | "collapsed": true 1086 | }, 1087 | "outputs": [], 1088 | "source": [] 1089 | } 1090 | ], 1091 | "metadata": { 1092 | "kernelspec": { 1093 | "display_name": "Python 3", 1094 | "language": "python", 1095 | "name": "python3" 1096 | }, 1097 | "language_info": { 1098 | "codemirror_mode": { 1099 | "name": "ipython", 1100 | "version": 3 1101 | }, 1102 | "file_extension": ".py", 1103 | "mimetype": "text/x-python", 1104 | "name": "python", 1105 | "nbconvert_exporter": "python", 1106 | "pygments_lexer": "ipython3", 1107 | "version": "3.4.3" 1108 | } 1109 | }, 1110 | "nbformat": 4, 1111 | "nbformat_minor": 0 1112 | } 1113 |
-------------------------------------------------------------------------------- /notebooks/pong.py: --------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import tensorflow as tf
4 | import time
5 | import os
6 | 
7 | from itertools import count
8 | 
9 | from tf_rl.models import Layer, LambdaLayer, ConvLayer, SeqLayer
10 | from tf_rl.controller.discrete_deepq import DiscreteDeepQ
11 | 
12 | # CRAZY VARIABLES
13 | REAL_TIME = False
14 | RENDER = True
15 | 
16 | ENVIRONMENT = "Pong-v0"
17 | MODEL_SAVE_DIR = "./{}-model/".format(ENVIRONMENT)
18 | MODEL_SAVE_EVERY_S = 60
19 | 
20 | # SENSIBLE VARIABLES
21 | FPS = 60
22 | MAX_FRAMES = 1000
23 | IMAGE_SHAPE = (210, 160, 3)
24 | OBS_SHAPE = (210, 160, 6)
25 | NUM_ACTIONS = 6
26 | 
27 | 
28 | def make_model():
29 |     """Create a tensorflow convnet that takes image as input
30 |     and outputs a predicted discounted score for every action"""
31 | 
32 |     with tf.variable_scope('convnet'):
33 |         convnet = SeqLayer([
34 |             ConvLayer(3, 3, 6,   32,  stride=(1,1), scope='conv1'),    # out.shape = (B, 210, 160, 32)
35 |             LambdaLayer(tf.nn.sigmoid),
36 |             ConvLayer(2, 2, 32,  64,  stride=(2,2), scope='conv2'),    # out.shape = (B, 105, 80, 64)
37 |             LambdaLayer(tf.nn.sigmoid),
38 |             ConvLayer(3, 3, 64,  64,  stride=(1,1), scope='conv3'),    # out.shape = (B, 105, 80, 64)
39 |             LambdaLayer(tf.nn.sigmoid),
40 |             ConvLayer(2, 2, 64,  128, stride=(2,2), scope='conv4'),    # out.shape = (B, 53, 40, 128)
41 |             LambdaLayer(tf.nn.sigmoid),
42 |             ConvLayer(3, 3, 128, 128, stride=(1,1), scope='conv5'),    # out.shape = (B, 53, 40, 128)
43 |             LambdaLayer(tf.nn.sigmoid),
44 |             ConvLayer(2, 2, 128, 256, stride=(2,2), scope='conv6'),    # out.shape = (B, 27, 20, 256)
45 |             LambdaLayer(tf.nn.sigmoid),
46 |             LambdaLayer(lambda x: tf.reshape(x, [-1, 27 * 20 * 256])), # out.shape = (B, 27 * 20 * 256)
47 |             Layer(27 * 20 * 256, 6, scope='proj_actions')              # out.shape = (B, 6)
48 |         ], scope='convnet')
49 |     return convnet
50 | 
51 | 
52 | def make_controller():
53 |     """Create a deepq controller"""
54 |     session = tf.Session()
55 | 
56 |     model = make_model()
57 | 
58 |     optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
59 | 
60 |     return DiscreteDeepQ(OBS_SHAPE,
61 |                          NUM_ACTIONS,
62 |                          model,
63 |                          optimizer,
64 |                          session,
65 |                          random_action_probability=0.1,
66 |                          minibatch_size=8,
67 |                          discount_rate=0.99,
68 |                          exploration_period=500000,
69 |                          max_experience=10000,
70 |                          target_network_update_rate=0.01,
71 |                          store_every_nth=4,
72 |                          train_every_nth=4)
73 | 
74 | # TODO(szymon): apparently both DeepMind and Karpathy
75 | # people normalize their frames to sizes 80x80 and grayscale.
76 | # should we do this?
77 | def normalize_frame(o):
78 |     """Change from uint in range (0, 255) to float in range (0, 1)"""
79 |     return o.astype(np.float32) / 255.0
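On the TODO above: a preprocessing step in that spirit might look like the sketch below. This is purely hypothetical; nothing in this script performs it.

```python
def preprocess_frame(o):
    """Hypothetical: grayscale and 2x-downsample a (210, 160, 3) frame."""
    gray = o.astype(np.float32).mean(axis=2) / 255.0  # -> (210, 160) in [0, 1]
    return gray[::2, ::2]  # naive downsample -> (105, 80); cropping toward 80x80 would go here
```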
80 | 
81 | def main():
82 |     env = gym.make(ENVIRONMENT)
83 |     controller = make_controller()
84 | 
85 |     # Load existing model.
86 |     if os.path.exists(MODEL_SAVE_DIR):
87 |         print("loading model... ", end="")
88 |         controller.restore(MODEL_SAVE_DIR)
89 |         print("done.")
90 |     last_model_save = time.time()
91 | 
92 |     # For every game
93 |     for game_no in count():
94 |         # Reset simulator
95 |         frame_tm1 = normalize_frame(env.reset())
96 |         frame_t, _, _, _ = env.step(env.action_space.sample())
97 |         frame_t = normalize_frame(frame_t)
98 | 
99 |         rewards = []
100 |         for _ in range(MAX_FRAMES):
101 |             start_time = time.time()
102 | 
103 |             # observation consists of the two last frames;
104 |             # this is important so that we can detect speed.
105 |             observation_t = np.concatenate([frame_tm1, frame_t], 2)
106 | 
107 |             # pick an action according to the Q-function learned so far.
108 |             action = controller.action(observation_t)
109 | 
110 |             if RENDER: env.render()
111 | 
112 |             # advance simulator
113 |             frame_tp1, reward, done, info = env.step(action)
114 |             frame_tp1 = normalize_frame(frame_tp1)
115 |             if done: break
116 | 
117 |             observation_tp1 = np.concatenate([frame_t, frame_tp1], 2)
118 | 
119 |             # store transitions
120 |             controller.store(observation_t, action, reward, observation_tp1)
121 |             # run a single iteration of SGD
122 |             controller.training_step()
123 | 
124 | 
125 |             frame_tm1, frame_t = frame_t, frame_tp1
126 |             rewards.append(reward)
127 | 
128 |             # if real-time visualization is requested, throttle down FPS.
129 |             if REAL_TIME:
130 |                 time_passed = time.time() - start_time
131 |                 time_left = 1.0 / FPS - time_passed
132 | 
133 |                 if time_left > 0:
134 |                     time.sleep(time_left)
135 | 
136 |         # save model if time since last save is greater than
137 |         # MODEL_SAVE_EVERY_S
138 |         if time.time() - last_model_save >= MODEL_SAVE_EVERY_S:
139 |             if not os.path.exists(MODEL_SAVE_DIR):
140 |                 os.makedirs(MODEL_SAVE_DIR)
141 |             controller.save(MODEL_SAVE_DIR, debug=True)
142 |             last_model_save = time.time()
143 | 
144 |         # Count scores. This relies on specific score values being
145 |         # assigned by openai gym and might break in the future.
146 |         points_lost = rewards.count(-1.0)
147 |         points_won = rewards.count(1.0)
148 |         exploration_done = controller.exploration_completed()
149 | 
150 |         print("Game no %d is over. Exploration %.1f%% done. Points lost: %d, points won: %d"
151 |               % (game_no, 100.0 * exploration_done, points_lost, points_won))
152 | 
153 | if __name__ == '__main__':
154 |     main()
155 | 
156 | 
-------------------------------------------------------------------------------- /notebooks/tf_rl: -------------------------------------------------------------------------------- 1 | ../tf_rl
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future==0.15.2 2 | euclid==0.1
-------------------------------------------------------------------------------- /saved_models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "../saved_models/karpathy_game.ckpt" 2 | all_model_checkpoint_paths: "../saved_models/karpathy_game.ckpt"
-------------------------------------------------------------------------------- /saved_models/karpathy_game.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siemanko/tensorflow-deepq/149e69e5340984d75df3ff1a374920d870517fb9/saved_models/karpathy_game.ckpt
-------------------------------------------------------------------------------- /scripts/make_gif.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # stop script on error and print it
4 | set -e
5 | # inform me of undefined variables
6 | set -u
7 | # handle cascading failures well
8 | set -o pipefail
9 | 
10 | SCRIPT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
11 | 
12 | 
13 | images_directory=${1:-}
14 | if [[ -z "$images_directory" ]]
15 | then
16 |     echo "Usage: $0 images_directory"
17 |     exit 1
18 | fi
19 | 
20 | for img in $images_directory/*.svg
21 | do
22 |     if [ ! -f $img.png ]; then
23 |         echo "Converting $img."
24 |         inkscape -z -e $img.png -b white $img
25 |     fi
26 | done
27 | convert -delay 3 -loop 0 $(ls $images_directory/*.png | sort -V) animation.gif
-------------------------------------------------------------------------------- /tf_rl/__init__.py: -------------------------------------------------------------------------------- 1 | from .simulate import simulate
-------------------------------------------------------------------------------- /tf_rl/controller/__init__.py: -------------------------------------------------------------------------------- 1 | from .discrete_deepq import DiscreteDeepQ 2 | from .human_controller import HumanController
-------------------------------------------------------------------------------- /tf_rl/controller/discrete_deepq.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import tensorflow as tf
4 | import os
5 | import pickle
6 | import time
7 | 
8 | from collections import deque
9 | 
10 | class DiscreteDeepQ(object):
11 |     def __init__(self, observation_shape,
12 |                  num_actions,
13 |                  observation_to_actions,
14 |                  optimizer,
15 |                  session,
16 |                  random_action_probability=0.05,
17 |                  exploration_period=1000,
18 |                  store_every_nth=5,
19 |                  train_every_nth=5,
20 |                  minibatch_size=32,
21 |                  discount_rate=0.95,
22 |                  max_experience=30000,
23 |                  target_network_update_rate=0.01,
24 |                  summary_writer=None):
25 |         """Initializes the DeepQ object.
26 | 
27 |         Based on:
28 |         https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
29 | 
30 |         Parameters
31 |         -------
32 |         observation_shape : int
33 |             length of the vector passed as observation
34 |         num_actions : int
35 |             number of actions that the model can execute
36 |         observation_to_actions: dali model
37 |             model that implements activate function
38 |             that can take in observation vector or a batch
39 |             and returns scores (of unbounded values) for each
40 |             action for each observation.
41 |             input shape:  [batch_size] + observation_shape
42 |             output shape: [batch_size, num_actions]
43 |         optimizer: tf.solver.*
44 |             optimizer for prediction error
45 |         session: tf.Session
46 |             session on which to execute the computation
47 |         random_action_probability: float (0 to 1)
48 |         exploration_period: int
49 |             probability of choosing a random
50 |             action (epsilon from the paper) annealed linearly
51 |             from 1 to random_action_probability over
52 |             exploration_period
53 |         store_every_nth: int
54 |             to further decorrelate samples, do not store all
55 |             transitions, but rather every nth transition.
56 |             For example if store_every_nth is 5, then
57 |             only 20% of all the transitions is stored.
58 |         train_every_nth: int
59 |             normally training_step is invoked every
60 |             time action is executed. Depending on the
61 |             setup that might be too often. When this
62 |             variable is set to n, then only every
63 |             n-th time training_step is called will
64 |             the training procedure actually be executed.
65 |         minibatch_size: int
66 |             number of state,action,reward,newstate
67 |             tuples considered during experience replay
68 |         discount_rate: float (0 to 1)
69 |             how much we care about future rewards.
70 |         max_experience: int
71 |             maximum size of the replay buffer
72 |         target_network_update_rate: float
73 |             how much to update target network after each
74 |             iteration. Let's call target_network_update_rate
75 |             alpha, target network T, and network N. Every
76 |             time N gets updated we execute:
77 |             T = (1-alpha)*T + alpha*N
78 |         summary_writer: tf.train.SummaryWriter
79 |             writer to log metrics
80 |         """
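The update rule above can be cross-checked against the notebook output earlier in this document: with the default `alpha = 0.01`, the target weight `0.00792158` and the Q-network weight `0.00824914` combine to exactly the value the target network shows after one `session.run(target_network_update)`. A minimal illustration:

```python
alpha = 0.01                   # target_network_update_rate (class default)
T, N = 0.00792158, 0.00824914  # target / Q-network weight from the notebook above
print((1 - alpha) * T + alpha * N)  # 0.0079248556 -> matches the printed 0.00792486
```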
81 |         # memorize arguments
82 |         self.observation_shape = observation_shape
83 |         self.num_actions = num_actions
84 | 
85 |         self.q_network = observation_to_actions
86 |         self.optimizer = optimizer
87 |         self.s = session
88 | 
89 |         self.random_action_probability = random_action_probability
90 |         self.exploration_period = exploration_period
91 |         self.store_every_nth = store_every_nth
92 |         self.train_every_nth = train_every_nth
93 |         self.minibatch_size = minibatch_size
94 |         self.discount_rate = tf.constant(discount_rate)
95 |         self.max_experience = max_experience
96 |         self.target_network_update_rate = \
97 |                 tf.constant(target_network_update_rate)
98 | 
99 |         # deepq state
100 |         self.actions_executed_so_far = 0
101 |         self.experience = deque()
102 | 
103 |         self.iteration = 0
104 |         self.summary_writer = summary_writer
105 | 
106 |         self.number_of_times_store_called = 0
107 |         self.number_of_times_train_called = 0
108 | 
109 |         self.create_variables()
110 | 
111 |         self.s.run(tf.initialize_all_variables())
112 |         self.s.run(self.target_network_update)
113 | 
114 |         self.saver = tf.train.Saver()
115 | 
116 |     def linear_annealing(self, n, total, p_initial, p_final):
117 |         """Linear annealing between p_initial and p_final
118 |         over total steps - computes value at step n"""
119 |         if n >= total:
120 |             return p_final
121 |         else:
122 |             return p_initial - (n * (p_initial - p_final)) / (total)
123 | 
124 | 
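For intuition, `linear_annealing` with the settings used in `notebooks/pong.py` above (`exploration_period=500000`, `random_action_probability=0.1`) decays epsilon like this; the snippet restates the method standalone, it adds no new behavior:

```python
def eps(n, total=500000, p_initial=1.0, p_final=0.1):
    return p_final if n >= total else p_initial - (n * (p_initial - p_final)) / total

for n in (0, 250000, 500000):
    print(n, eps(n))  # 0 -> 1.0, 250000 -> 0.55, 500000 -> 0.1
```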
125 |     def observation_batch_shape(self, batch_size):
126 |         return tuple([batch_size] + list(self.observation_shape))
127 | 
128 |     def create_variables(self):
129 |         self.target_q_network = self.q_network.copy(scope="target_network")
130 | 
131 |         # FOR REGULAR ACTION SCORE COMPUTATION
132 |         with tf.name_scope("taking_action"):
133 |             self.observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="observation")
134 |             self.action_scores = tf.identity(self.q_network(self.observation), name="action_scores")
135 |             tf.histogram_summary("action_scores", self.action_scores)
136 |             self.predicted_actions = tf.argmax(self.action_scores, dimension=1, name="predicted_actions")
137 | 
138 |         with tf.name_scope("estimating_future_rewards"):
139 |             # FOR PREDICTING TARGET FUTURE REWARDS
140 |             self.next_observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="next_observation")
141 |             self.next_observation_mask = tf.placeholder(tf.float32, (None,), name="next_observation_mask")
142 |             self.next_action_scores = tf.stop_gradient(self.target_q_network(self.next_observation))
143 |             tf.histogram_summary("target_action_scores", self.next_action_scores)
144 |             self.rewards = tf.placeholder(tf.float32, (None,), name="rewards")
145 |             target_values = tf.reduce_max(self.next_action_scores, reduction_indices=[1,]) * self.next_observation_mask
146 |             self.future_rewards = self.rewards + self.discount_rate * target_values
147 | 
148 |         with tf.name_scope("q_value_prediction"):
149 |             # FOR PREDICTION ERROR
150 |             self.action_mask = tf.placeholder(tf.float32, (None, self.num_actions), name="action_mask")
151 |             self.masked_action_scores = tf.reduce_sum(self.action_scores * self.action_mask, reduction_indices=[1,])
152 |             temp_diff = self.masked_action_scores - self.future_rewards
153 |             self.prediction_error = tf.reduce_mean(tf.square(temp_diff))
154 |             gradients = self.optimizer.compute_gradients(self.prediction_error)
155 |             for i, (grad, var) in enumerate(gradients):
156 |                 if grad is not None:
157 |                     gradients[i] = (tf.clip_by_norm(grad, 5), var)
158 |             # Add histograms for gradients.
159 |             for grad, var in gradients:
160 |                 tf.histogram_summary(var.name, var)
161 |                 if grad is not None:
162 |                     tf.histogram_summary(var.name + '/gradients', grad)
163 |             self.train_op = self.optimizer.apply_gradients(gradients)
164 | 
165 |         # UPDATE TARGET NETWORK
166 |         with tf.name_scope("target_network_update"):
167 |             self.target_network_update = []
168 |             for v_source, v_target in zip(self.q_network.variables(), self.target_q_network.variables()):
169 |                 # this is equivalent to target = (1-alpha) * target + alpha * source
170 |                 update_op = v_target.assign_sub(self.target_network_update_rate * (v_target - v_source))
171 |                 self.target_network_update.append(update_op)
172 |             self.target_network_update = tf.group(*self.target_network_update)
173 | 
174 |         # summaries
175 |         tf.scalar_summary("prediction_error", self.prediction_error)
176 | 
177 |         self.summarize = tf.merge_all_summaries()
178 |         self.no_op1 = tf.no_op()
179 | 
180 | 
181 |     def action(self, observation):
182 |         """Given observation returns the action that should be chosen using
183 |         DeepQ learning strategy. Does not backprop."""
184 |         assert observation.shape == self.observation_shape, \
185 |                 "Action is performed based on single observation."
186 | 
187 |         self.actions_executed_so_far += 1
188 |         exploration_p = self.linear_annealing(self.actions_executed_so_far,
189 |                                               self.exploration_period,
190 |                                               1.0,
191 |                                               self.random_action_probability)
192 | 
193 |         if random.random() < exploration_p:
194 |             return random.randint(0, self.num_actions - 1)
195 |         else:
196 |             return self.s.run(self.predicted_actions, {self.observation: observation[np.newaxis,:]})[0]
197 | 
198 |     def exploration_completed(self):
199 |         return min(float(self.actions_executed_so_far) / self.exploration_period, 1.0)
200 | 
201 |     def store(self, observation, action, reward, newobservation):
202 |         """Store experience: starting from observation and
203 |         executing action, we arrived at newobservation and got
204 |         the reward.
205 | 
206 |         If newobservation is None, the state/action pair is assumed to be terminal.
207 |         """
208 |         if self.number_of_times_store_called % self.store_every_nth == 0:
209 |             self.experience.append((observation, action, reward, newobservation))
210 |             if len(self.experience) > self.max_experience:
211 |                 self.experience.popleft()
212 |         self.number_of_times_store_called += 1
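The `future_rewards` expression in `create_variables` above is the Q-learning target; in plain numpy, with made-up numbers, it reads:

```python
import numpy as np

rewards = np.array([1.0, 0.0])   # r_t for two sampled transitions
next_q = np.array([[0.2, 0.7],   # target-network scores for s_{t+1}
                   [0.5, 0.1]])
mask = np.array([1.0, 0.0])      # 0 marks a terminal transition
discount = 0.95
print(rewards + discount * next_q.max(axis=1) * mask)  # [ 1.665  0.   ]
```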
213 | 
214 |     def training_step(self):
215 |         """Pick self.minibatch_size experiences from the replay buffer
216 |         and backpropagate the value function.
217 |         """
218 |         if self.number_of_times_train_called % self.train_every_nth == 0:
219 |             if len(self.experience) < self.minibatch_size:
220 |                 return
221 | 
222 |             # sample experience.
223 |             samples = random.sample(range(len(self.experience)), self.minibatch_size)
224 |             samples = [self.experience[i] for i in samples]
225 | 
226 |             # batch states
227 |             states = np.empty(self.observation_batch_shape(len(samples)))
228 |             newstates = np.empty(self.observation_batch_shape(len(samples)))
229 |             action_mask = np.zeros((len(samples), self.num_actions))
230 | 
231 |             newstates_mask = np.empty((len(samples),))
232 |             rewards = np.empty((len(samples),))
233 | 
234 |             for i, (state, action, reward, newstate) in enumerate(samples):
235 |                 states[i] = state
236 |                 action_mask[i] = 0
237 |                 action_mask[i][action] = 1
238 |                 rewards[i] = reward
239 |                 if newstate is not None:
240 |                     newstates[i] = newstate
241 |                     newstates_mask[i] = 1
242 |                 else:
243 |                     newstates[i] = 0
244 |                     newstates_mask[i] = 0
245 | 
246 | 
247 |             calculate_summaries = self.iteration % 100 == 0 and \
248 |                     self.summary_writer is not None
249 | 
250 |             cost, _, summary_str = self.s.run([
251 |                 self.prediction_error,
252 |                 self.train_op,
253 |                 self.summarize if calculate_summaries else self.no_op1,
254 |             ], {
255 |                 self.observation:           states,
256 |                 self.next_observation:      newstates,
257 |                 self.next_observation_mask: newstates_mask,
258 |                 self.action_mask:           action_mask,
259 |                 self.rewards:               rewards,
260 |             })
261 | 
262 |             self.s.run(self.target_network_update)
263 | 
264 |             if calculate_summaries:
265 |                 self.summary_writer.add_summary(summary_str, self.iteration)
266 | 
267 |             self.iteration += 1
268 | 
269 |         self.number_of_times_train_called += 1
270 | 
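`save()` and `restore()` below persist the TensorFlow variables (file `model`) alongside the pickled bookkeeping counters (file `deepq_state`). A typical call cadence, mirroring `notebooks/pong.py` (`controller` is assumed to be a `DiscreteDeepQ` instance):

```python
import os, time

save_dir, last_save = "./Pong-v0-model/", time.time()
if time.time() - last_save >= 60:  # MODEL_SAVE_EVERY_S in pong.py
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    controller.save(save_dir, debug=True)
    last_save = time.time()
```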
271 |     def save(self, save_dir, debug=False):
272 |         STATE_FILE = os.path.join(save_dir, 'deepq_state')
273 |         MODEL_FILE = os.path.join(save_dir, 'model')
274 | 
275 |         # deepq state
276 |         state = {
277 |             'actions_executed_so_far':      self.actions_executed_so_far,
278 |             'iteration':                    self.iteration,
279 |             'number_of_times_store_called': self.number_of_times_store_called,
280 |             'number_of_times_train_called': self.number_of_times_train_called,
281 |         }
282 | 
283 |         if debug:
284 |             print('Saving model... ', end='')
285 | 
286 |         saving_started = time.time()
287 | 
288 |         self.saver.save(self.s, MODEL_FILE)
289 |         with open(STATE_FILE, "wb") as f:
290 |             pickle.dump(state, f)
291 | 
292 |         print('done in {} s'.format(time.time() - saving_started))
293 | 
294 |     def restore(self, save_dir, debug=False):
295 |         # deepq state
296 |         STATE_FILE = os.path.join(save_dir, 'deepq_state')
297 |         MODEL_FILE = os.path.join(save_dir, 'model')
298 | 
299 |         with open(STATE_FILE, "rb") as f:
300 |             state = pickle.load(f)
301 |         self.saver.restore(self.s, MODEL_FILE)
302 | 
303 |         self.actions_executed_so_far = state['actions_executed_so_far']
304 |         self.iteration = state['iteration']
305 |         self.number_of_times_store_called = state['number_of_times_store_called']
306 |         self.number_of_times_train_called = state['number_of_times_train_called']
307 | 
308 | 
309 | 
310 | 
-------------------------------------------------------------------------------- /tf_rl/controller/human_controller.py: --------------------------------------------------------------------------------
1 | from tf_rl.utils.getch import getch
2 | from redis import StrictRedis
3 | 
4 | 
5 | 
6 | class HumanController(object):
7 |     def __init__(self, mapping):
8 |         self.mapping = mapping
9 |         self.r = StrictRedis()
10 |         self.experience = []
11 | 
12 |     def action(self, o):
13 |         return self.mapping[self.r.get("action")]
14 | 
15 |     def store(self, observation, action, reward, newobservation):
16 |         pass
17 | 
18 |     def training_step(self):
19 |         pass
20 | 
21 | 
22 | 
23 | def control_me():
24 |     r = StrictRedis()
25 |     while True:
26 |         c = getch()
27 |         r.set("action", c)
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     control_me()
-------------------------------------------------------------------------------- /tf_rl/models.py: --------------------------------------------------------------------------------
1 | import math
2 | import tensorflow as tf
3 | 
4 | from .utils import base_name
5 | 
6 | 
7 | class Layer(object):
8 |     def __init__(self, input_sizes, output_size, scope):
9 |         """Creates a neural network layer."""
10 |         if type(input_sizes) != list:
11 |             input_sizes = [input_sizes]
12 | 
13 |         self.input_sizes = input_sizes
14 |         self.output_size = output_size
15 |         self.scope = scope or "Layer"
16 | 
17 |         with tf.variable_scope(self.scope):
18 |             self.Ws = []
19 |             for input_idx, input_size in enumerate(input_sizes):
20 |                 W_name = "W_%d" % (input_idx,)
21 |                 W_initializer = tf.random_uniform_initializer(
22 |                         -1.0 / math.sqrt(input_size), 1.0 / math.sqrt(input_size))
23 |                 W_var = tf.get_variable(W_name, (input_size, output_size), initializer=W_initializer)
24 |                 self.Ws.append(W_var)
25 |             self.b = tf.get_variable("b", (output_size,), initializer=tf.constant_initializer(0))
26 | 
27 |     def __call__(self, xs):
28 |         if type(xs) != list:
29 |             xs = [xs]
30 |         assert len(xs) == len(self.Ws), \
31 |                 "Expected %d input vectors, got %d" % (len(self.Ws), len(xs))
32 |         with tf.variable_scope(self.scope):
33 |             return sum([tf.matmul(x, W) for x, W in zip(xs, self.Ws)]) + self.b
34 | 
35 |     def variables(self):
36 |         return [self.b] + self.Ws
37 | 
38 |     def copy(self, scope=None):
39 |         scope = scope or self.scope + "_copy"
40 | 
41 |         with tf.variable_scope(scope) as sc:
42 |             for v in self.variables():
43 |                 tf.get_variable(base_name(v), v.get_shape(),
44 |                         initializer=lambda x,dtype=tf.float32,partition_info=None: v.initialized_value())
45 |             sc.reuse_variables()
46 |             return Layer(self.input_sizes, self.output_size, scope=sc)
47 | 
48 | class MLP(object):
49 |     def __init__(self, input_sizes, hiddens, nonlinearities,
scope=None, given_layers=None): 50 | self.input_sizes = input_sizes 51 | self.hiddens = hiddens 52 | self.input_nonlinearity, self.layer_nonlinearities = nonlinearities[0], nonlinearities[1:] 53 | self.scope = scope or "MLP" 54 | 55 | assert len(hiddens) == len(nonlinearities), \ 56 | "Number of hiddens must be equal to number of nonlinearities" 57 | 58 | with tf.variable_scope(self.scope): 59 | if given_layers is not None: 60 | self.input_layer = given_layers[0] 61 | self.layers = given_layers[1:] 62 | else: 63 | self.input_layer = Layer(input_sizes, hiddens[0], scope="input_layer") 64 | self.layers = [] 65 | 66 | for l_idx, (h_from, h_to) in enumerate(zip(hiddens[:-1], hiddens[1:])): 67 | self.layers.append(Layer(h_from, h_to, scope="hidden_layer_%d" % (l_idx,))) 68 | 69 | def __call__(self, xs): 70 | if type(xs) != list: 71 | xs = [xs] 72 | with tf.variable_scope(self.scope): 73 | hidden = self.input_nonlinearity(self.input_layer(xs)) 74 | for layer, nonlinearity in zip(self.layers, self.layer_nonlinearities): 75 | hidden = nonlinearity(layer(hidden)) 76 | return hidden 77 | 78 | def variables(self): 79 | res = self.input_layer.variables() 80 | for layer in self.layers: 81 | res.extend(layer.variables()) 82 | return res 83 | 84 | def copy(self, scope=None): 85 | scope = scope or self.scope + "_copy" 86 | nonlinearities = [self.input_nonlinearity] + self.layer_nonlinearities 87 | given_layers = [self.input_layer.copy()] + [layer.copy() for layer in self.layers] 88 | return MLP(self.input_sizes, self.hiddens, nonlinearities, scope=scope, 89 | given_layers=given_layers) 90 | 91 | 92 | class ConvLayer(object): 93 | def __init__(self, filter_H, filter_W, 94 | in_C, out_C, 95 | stride=(1,1), 96 | scope="Convolution"): 97 | self.filter_H, self.filter_W = filter_H, filter_W 98 | self.in_C, self.out_C = in_C, out_C 99 | self.stride = stride 100 | self.scope = scope 101 | 102 | with tf.variable_scope(self.scope): 103 | input_size = filter_H * filter_W * in_C 104 | W_initializer = tf.random_uniform_initializer( 105 | -1.0 / math.sqrt(input_size), 106 | 1.0 / math.sqrt(input_size)) 107 | self.W = tf.get_variable('W', 108 | (filter_H, filter_W, in_C, out_C), 109 | initializer=W_initializer) 110 | self.b = tf.get_variable('b', 111 | (out_C), 112 | initializer=tf.constant_initializer(0)) 113 | 114 | def __call__(self, X): 115 | with tf.variable_scope(self.scope): 116 | return tf.nn.conv2d(X, self.W, 117 | strides=[1] + list(self.stride) + [1], 118 | padding='SAME') + self.b 119 | 120 | def variables(self): 121 | return [self.W, self.b] 122 | 123 | def copy(self, scope=None): 124 | scope = scope or self.scope + "_copy" 125 | 126 | with tf.variable_scope(scope) as sc: 127 | for v in self.variables(): 128 | tf.get_variable(base_name(v), v.get_shape(), 129 | initializer=lambda x,dtype=tf.float32,partition_info=None: v.initialized_value()) 130 | sc.reuse_variables() 131 | return ConvLayer(self.filter_H, self.filter_W, self.in_C, self.out_C, self.stride, scope=sc) 132 | 133 | class SeqLayer(object): 134 | def __init__(self, layers, scope='seq_layer'): 135 | self.scope = scope 136 | self.layers = layers 137 | 138 | def __call__(self, x): 139 | for l in self.layers: 140 | x = l(x) 141 | return x 142 | 143 | def variables(self): 144 | return sum([l.variables() for l in self.layers], []) 145 | 146 | def copy(self, scope=None): 147 | scope = scope or self.scope + "_copy" 148 | with tf.variable_scope(self.scope): 149 | copied_layers = [layer.copy() for layer in self.layers] 150 | return SeqLayer(copied_layers, 
scope=scope)
151 | 
152 | 
153 | class LambdaLayer(object):
154 |     def __init__(self, f):
155 |         self.f = f
156 | 
157 |     def __call__(self, x):
158 |         return self.f(x)
159 | 
160 |     def variables(self):
161 |         return []
162 | 
163 |     def copy(self):
164 |         return LambdaLayer(self.f)
-------------------------------------------------------------------------------- /tf_rl/simulate.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | 
3 | import math
4 | import time
5 | 
6 | import matplotlib.pyplot as plt
7 | from itertools import count
8 | from os.path import join, exists
9 | from os import makedirs
10 | from IPython.display import clear_output, display, HTML
11 | 
12 | def simulate(simulation,
13 |              controller=None,
14 |              fps=60,
15 |              visualize_every=1,
16 |              action_every=1,
17 |              simulation_resolution=None,
18 |              wait=False,
19 |              disable_training=False,
20 |              save_path=None):
21 |     """Start the simulation. Performs three tasks
22 | 
23 |     - visualizes simulation in iPython notebook
24 |     - advances simulator state
25 |     - reports state to controller and chooses actions
26 |       to be performed.
27 | 
28 |     Parameters
29 |     -------
30 |     simulation: tf_rl.simulation
31 |         simulation that will be simulated ;-)
32 |     controller: tf_rl.controller
33 |         controller used
34 |     fps: int
35 |         frames per second
36 |     visualize_every: int
37 |         visualize every `visualize_every`-th frame.
38 |     action_every: int
39 |         take action every `action_every`-th frame
40 |     simulation_resolution: float
41 |         simulate at most 'simulation_resolution' seconds at a time.
42 |         If None, it is set to 1/FPS (default).
43 |     wait: boolean
44 |         whether to intentionally slow down the simulation
45 |         to appear real time.
46 |     disable_training: bool
47 |         if true training_step is never called.
48 |     save_path: str
49 |         save svg visualization (only tf_rl.utils.svg
50 |         supported for the moment)
51 |     """
52 | 
53 |     # prepare path to save simulation images
54 |     if save_path is not None:
55 |         if not exists(save_path):
56 |             makedirs(save_path)
57 |         last_image = 0
58 | 
59 |     # calculate simulation times
60 |     chunks_per_frame = 1
61 |     chunk_length_s = 1.0 / fps
62 | 
63 |     if simulation_resolution is not None:
64 |         frame_length_s = 1.0 / fps
65 |         chunks_per_frame = int(math.ceil(frame_length_s / simulation_resolution))
66 |         chunks_per_frame = max(chunks_per_frame, 1)
67 |         chunk_length_s = frame_length_s / chunks_per_frame
68 | 
69 |     # state transition bookkeeping
70 |     last_observation = None
71 |     last_action = None
72 | 
73 |     simulation_started_time = time.time()
74 | 
75 |     # setup rendering handles for reuse
76 |     if hasattr(simulation, 'setup_draw'):
77 |         simulation.setup_draw()
78 | 
79 |     for frame_no in count():
80 |         for _ in range(chunks_per_frame):
81 |             simulation.step(chunk_length_s)
82 | 
83 |         if frame_no % action_every == 0:
84 |             new_observation = simulation.observe()
85 |             reward = simulation.collect_reward()
86 |             # store last transition
87 |             if last_observation is not None:
88 |                 controller.store(last_observation, last_action, reward, new_observation)
89 | 
90 |             # act
91 |             new_action = controller.action(new_observation)
92 |             simulation.perform_action(new_action)
93 | 
94 |             # train
95 |             if not disable_training:
96 |                 controller.training_step()
97 | 
98 |             # update current state as last state.
99 |             last_action = new_action
100 |             last_observation = new_observation
101 | 
102 |         # adding 1 to make it less likely to happen at the same time as
103 |         # action taking.
104 | if (frame_no + 1) % visualize_every == 0: 105 | fps_estimate = frame_no / (time.time() - simulation_started_time) 106 | 107 | # draw simulated environment all the rendering is handled within the simulation object 108 | stats = ["fps = %.1f" % (fps_estimate, )] 109 | if hasattr(simulation, 'draw'): # render with the draw function 110 | simulation.draw(stats) 111 | elif hasattr(simulation, 'to_html'): # in case some class only support svg rendering 112 | clear_output(wait=True) 113 | svg_html = simulation.to_html(stats) 114 | display(svg_html) 115 | 116 | if save_path is not None: 117 | img_path = join(save_path, "%d.svg" % (last_image,)) 118 | with open(img_path, "w") as f: 119 | svg_html.write_svg(f) 120 | last_image += 1 121 | 122 | time_should_have_passed = frame_no / fps 123 | time_passed = (time.time() - simulation_started_time) 124 | if wait and (time_should_have_passed > time_passed): 125 | time.sleep(time_should_have_passed - time_passed) 126 | -------------------------------------------------------------------------------- /tf_rl/simulation/__init__.py: -------------------------------------------------------------------------------- 1 | from .karpathy_game import KarpathyGame 2 | from .double_pendulum import DoublePendulum 3 | from .discrete_hill import DiscreteHill 4 | -------------------------------------------------------------------------------- /tf_rl/simulation/discrete_hill.py: -------------------------------------------------------------------------------- 1 | from random import randint, gauss 2 | 3 | import numpy as np 4 | 5 | class DiscreteHill(object): 6 | 7 | directions = [(0,1), (0,-1), (1,0), (-1,0)] 8 | 9 | def __init__(self, board=(10,10), variance=4.): 10 | self.variance = variance 11 | self.target = (0,0) 12 | while self.target == (0,0): 13 | self.target = (randint(-board[0], board[0]), randint(-board[1], board[1])) 14 | self.position = (0,0) 15 | 16 | self.shortest_path = self.distance(self.position, self.target) 17 | 18 | @staticmethod 19 | def add(p, q): 20 | return (p[0] + q[0], p[1] + q[1]) 21 | 22 | @staticmethod 23 | def distance(p, q): 24 | return abs(p[0] - q[0]) + abs(p[1] - q[1]) 25 | 26 | def estimate_distance(self, p): 27 | distance = DiscreteHill.distance(self.target, p) - DiscreteHill.distance(self.target, self.position) 28 | return distance + abs(gauss(0, self.variance)) 29 | 30 | def observe(self): 31 | return np.array([self.estimate_distance(DiscreteHill.add(self.position, delta)) 32 | for delta in DiscreteHill.directions]) 33 | 34 | def perform_action(self, action): 35 | self.position = DiscreteHill.add(self.position, DiscreteHill.directions[action]) 36 | 37 | def is_over(self): 38 | return self.position == self.target 39 | 40 | def collect_reward(self, action): 41 | return -DiscreteHill.distance(self.target, DiscreteHill.add(self.position, DiscreteHill.directions[action])) \ 42 | + DiscreteHill.distance(self.target, self.position) - 2 43 | -------------------------------------------------------------------------------- /tf_rl/simulation/double_pendulum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..utils import svg 4 | 5 | class DoublePendulum(object): 6 | def __init__(self, params): 7 | """Double Pendulum simulation, where control is 8 | only applied to joint1. 9 | 10 | state of the system is encoded as the following 11 | four values: 12 | state[0]: 13 | angle of first bar from center 14 | (w.r.t. 
--------------------------------------------------------------------------------
/tf_rl/simulation/double_pendulum.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from ..utils import svg
4 | 
5 | class DoublePendulum(object):
6 |     def __init__(self, params):
7 |         """Double Pendulum simulation, where control is
8 |         only applied to joint1.
9 | 
10 |         The state of the system is encoded as the
11 |         following four values:
12 |         state[0]:
13 |             angle of first bar from center
14 |             (w.r.t. vertical axis)
15 |         state[1]:
16 |             angular velocity of state[0]
17 |         state[2]:
18 |             angle of second bar from center
19 |             (w.r.t. vertical axis)
20 |         state[3]:
21 |             angular velocity of state[2]
22 | 
23 |         Params
24 |         -------
25 |         g_ms2 : float
26 |             gravity acceleration
27 |         l1_m : float
28 |             length of the first bar (closer to center)
29 |         m1_kg : float
30 |             mass of the first joint
31 |         l2_m : float
32 |             length of the second bar
33 |         m2_kg : float
34 |             mass of the second joint
35 |         damping : float
36 |             angular velocity damping coefficient
37 |         max_control_input : float
38 |             maximum value of angular force applied
39 |             to the first joint
40 |         """
41 |         self.state = np.array([0.0, 0.0, 0.0, 0.0])
42 |         self.control_input = 0.0
43 |         self.params = params
44 |         self.size = (400, 300)
45 | 
46 |     def external_derivatives(self):
47 |         """How the state of the world changes
48 |         naturally due to gravity and momentum.
49 | 
50 |         Returns a vector of four values,
51 |         the time derivatives of the corresponding
52 |         entries of the internal state representation."""
53 |         # The code below is horrible; if somebody
54 |         # wants to clean it up, I will gladly
55 |         # accept a pull request.
56 |         G = self.params['g_ms2']
57 |         L1 = self.params['l1_m']
58 |         L2 = self.params['l2_m']
59 |         M1 = self.params['m1_kg']
60 |         M2 = self.params['m2_kg']
61 |         damping = self.params['damping']
62 |         state = self.state
63 | 
64 |         dydx = np.zeros_like(state)
65 |         dydx[0] = state[1]
66 | 
67 |         del_ = state[2] - state[0]
68 |         den1 = (M1+M2)*L1 - M2*L1*np.cos(del_)*np.cos(del_)
69 |         dydx[1] = (M2*L1*state[1]*state[1]*np.sin(del_)*np.cos(del_)
70 |                    + M2*G*np.sin(state[2])*np.cos(del_)
71 |                    + M2*L2*state[3]*state[3]*np.sin(del_)
72 |                    - (M1+M2)*G*np.sin(state[0])) / den1
73 |         dydx[1] -= damping * state[1]
74 | 
75 |         dydx[2] = state[3]
76 | 
77 |         den2 = (L2/L1)*den1
78 |         dydx[3] = (-M2*L2*state[3]*state[3]*np.sin(del_)*np.cos(del_)
79 |                    + (M1+M2)*G*np.sin(state[0])*np.cos(del_)
80 |                    - (M1+M2)*L1*state[1]*state[1]*np.sin(del_)
81 |                    - (M1+M2)*G*np.sin(state[2]))/den2
82 |         dydx[3] -= damping * state[3]
83 | 
84 |         return np.array(dydx)
85 | 
86 |     def control_derivative(self):
87 |         """Derivative of self.state due to control"""
88 |         return np.array([0., 0., 0., 1.]) * self.control_input
89 | 
90 |     def observe(self):
91 |         """Returns an observation."""
92 |         return self.state
93 | 
94 |     def perform_action(self, action):
95 |         """Expects action to be in range [-1, 1]"""
96 |         self.control_input = action * self.params['max_control_input']
97 | 
98 |     def step(self, dt):
99 |         """Advance simulation by dt seconds (simple forward Euler integration)"""
100 |         dstate = self.external_derivatives() + self.control_derivative()
101 |         self.state += dt * dstate
102 | 
103 |     def collect_reward(self):
104 |         """Reward corresponds to how high the second joint is."""
105 |         _, joint2 = self.joint_positions()
106 |         return -joint2[1]
107 | 
108 |     def joint_positions(self):
109 |         """Returns absolute positions of both joints in a coordinate system
110 |         whose origin is the attachment point"""
111 |         x1 = self.params['l1_m'] * np.sin(self.state[0])
112 |         y1 = self.params['l1_m'] * np.cos(self.state[0])
113 | 
114 |         x2 = self.params['l2_m'] * np.sin(self.state[2]) + x1
115 |         y2 = self.params['l2_m'] * np.cos(self.state[2]) + y1
116 | 
117 |         return (x1, y1), (x2, y2)
118 | 
119 |     def to_html(self, info=[]):
120 |         """Visualize"""
121 |         info = info[:]
122 |         info.append("Reward = %.1f" % self.collect_reward())
123 |         joint1, joint2 = self.joint_positions()
124 | 
125 |         total_length = self.params['l1_m'] + self.params['l2_m']
126 |         # 8/10 of half the smaller screen dimension
127 |         total_length_px = (8./10.) * (min(self.size) / 2.)
128 |         scaling_ratio = total_length_px / total_length
129 |         center = (self.size[0] / 2, self.size[1] / 2)
130 | 
131 |         def transform(point):
132 |             """Transforms from simulation coordinates
133 |             to screen/pixel coordinates"""
134 | 
135 |             x = center[0] + scaling_ratio * point[0]
136 |             y = center[1] + scaling_ratio * point[1]
137 |             return int(x), int(y)
138 | 
139 | 
140 |         scene = svg.Scene((self.size[0] + 20, self.size[1] + 20 + 20 * len(info)))
141 |         scene.add(svg.Rectangle((10, 10), self.size))
142 | 
143 |         joint1 = transform(joint1)
144 |         joint2 = transform(joint2)
145 |         scene.add(svg.Line(center, joint1))
146 |         scene.add(svg.Line(joint1, joint2))
147 | 
148 |         scene.add(svg.Circle(center, 5, color='red'))
149 |         scene.add(svg.Circle(joint1, 3, color='blue'))
150 |         scene.add(svg.Circle(joint2, 3, color='green'))
151 | 
152 | 
153 |         offset = self.size[1] + 15
154 |         for txt in info:
155 |             scene.add(svg.Text((10, offset + 20), txt, 15))
156 |             offset += 20
157 |         return scene
158 | 
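159 | 
160 | 
161 | if __name__ == '__main__':
162 |     # Smoke test (an illustrative addition; run as
163 |     # `python -m tf_rl.simulation.double_pendulum` so the relative import
164 |     # works). The parameter values are arbitrary but plausible. step() uses
165 |     # forward Euler integration, so keep dt small, e.g. the 1/fps that
166 |     # tf_rl.simulate passes in.
167 |     pendulum = DoublePendulum({'g_ms2': 9.8, 'l1_m': 1.0, 'm1_kg': 1.0,
168 |                                'l2_m': 1.0, 'm2_kg': 1.0, 'damping': 0.2,
169 |                                'max_control_input': 10.0})
170 |     pendulum.perform_action(0.5)  # constant torque on joint1, range [-1, 1]
171 |     for _ in range(600):          # ten simulated seconds at dt = 1/60
172 |         pendulum.step(1.0 / 60.0)
173 |     print("state:", pendulum.observe())
174 |     print("reward:", pendulum.collect_reward())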
--------------------------------------------------------------------------------
/tf_rl/simulation/karpathy_game.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import random
5 | import time
6 | 
7 | from collections import defaultdict
8 | from euclid import Circle, Point2, Vector2, LineSegment2
9 | 
10 | from ..utils import svg
11 | from IPython.display import clear_output, display, HTML
12 | 
13 | class GameObject(object):
14 |     def __init__(self, position, speed, obj_type, settings):
15 |         """Essentially represents circles of different kinds, which have
16 |         position and speed."""
17 |         self.settings = settings
18 |         self.radius = self.settings["object_radius"]
19 | 
20 |         self.obj_type = obj_type
21 |         self.position = position
22 |         self.speed = speed
23 |         self.bounciness = 1.0
24 | 
25 |     def wall_collisions(self):
26 |         """Update speed upon collision with the wall."""
27 |         world_size = self.settings["world_size"]
28 | 
29 |         for dim in range(2):
30 |             if self.position[dim] - self.radius <= 0 and self.speed[dim] < 0:
31 |                 self.speed[dim] = - self.speed[dim] * self.bounciness
32 |             elif self.position[dim] + self.radius + 1 >= world_size[dim] and self.speed[dim] > 0:
33 |                 self.speed[dim] = - self.speed[dim] * self.bounciness
34 | 
35 |     def move(self, dt):
36 |         """Move as if dt seconds passed"""
37 |         self.position += dt * self.speed
38 |         self.position = Point2(*self.position)
39 | 
40 |     def step(self, dt):
41 |         """Move and bounce off walls."""
42 |         self.wall_collisions()
43 |         self.move(dt)
44 | 
45 |     def as_circle(self):
46 |         return Circle(self.position, float(self.radius))
47 | 
48 |     def draw(self):
49 |         """Return svg object for this item."""
50 |         color = self.settings["colors"][self.obj_type]
51 |         return svg.Circle(self.position + Point2(10, 10), self.radius, color=color)
52 | 
53 | class KarpathyGame(object):
54 |     def __init__(self, settings):
55 |         """Initialize game simulator with settings"""
56 |         self.settings = settings
57 |         self.size = self.settings["world_size"]
58 |         self.walls = [LineSegment2(Point2(0,0), Point2(0,self.size[1])),
59 |                       LineSegment2(Point2(0,self.size[1]), Point2(self.size[0], self.size[1])),
60 |                       LineSegment2(Point2(self.size[0], self.size[1]), Point2(self.size[0], 0)),
61 |                       LineSegment2(Point2(self.size[0], 0), Point2(0,0))]
62 | 
63 |         self.hero = GameObject(Point2(*self.settings["hero_initial_position"]),
64 |                                Vector2(*self.settings["hero_initial_speed"]),
65 |                                "hero",
66 |                                self.settings)
67 |         if not self.settings["hero_bounces_off_walls"]:
68 |             self.hero.bounciness = 0.0
69 | 
70 |         self.objects = []
71 |         for obj_type, number in settings["num_objects"].items():
72 |             for _ in range(number):
73 |                 self.spawn_object(obj_type)
74 | 
75 |         self.observation_lines = self.generate_observation_lines()
76 | 
77 |         self.object_reward = 0
78 |         self.collected_rewards = []
79 | 
80 |         # every observation_line sees one of the objects or a wall, plus
81 |         # two numbers representing the speed of the object (if applicable)
82 |         self.eye_observation_size = len(self.settings["objects"]) + 3
83 |         # additionally there are two numbers each for the agent's own speed and position.
84 |         self.observation_size = self.eye_observation_size * len(self.observation_lines) + 2 + 2
85 | 
86 |         self.directions = [Vector2(*d) for d in [[1,0], [0,1], [-1,0], [0,-1], [0.0,0.0]]]
87 |         self.num_actions = len(self.directions)
88 | 
89 |         self.objects_eaten = defaultdict(lambda: 0)
90 | 
91 |     def perform_action(self, action_id):
92 |         """Accelerate the hero along one of the predefined direction vectors"""
93 |         assert 0 <= action_id < self.num_actions
94 |         self.hero.speed *= 0.5
95 |         self.hero.speed += self.directions[action_id] * self.settings["delta_v"]
96 | 
97 |     def spawn_object(self, obj_type):
98 |         """Spawn an object of a given type and add it to the objects array"""
99 |         radius = self.settings["object_radius"]
100 |         position = np.random.uniform([radius, radius], np.array(self.size) - radius)
101 |         position = Point2(float(position[0]), float(position[1]))
102 |         max_speed = np.array(self.settings["maximum_speed"])
103 |         speed = np.random.uniform(-max_speed, max_speed).astype(float)
104 |         speed = Vector2(float(speed[0]), float(speed[1]))
105 | 
106 |         self.objects.append(GameObject(position, speed, obj_type, self.settings))
107 | 
108 |     def step(self, dt):
109 |         """Simulate all the objects for a given amount of time.
110 | 
111 |         Also resolve collisions with the hero"""
112 |         for obj in self.objects + [self.hero]:
113 |             obj.step(dt)
114 |         self.resolve_collisions()
115 | 
116 |     def squared_distance(self, p1, p2):
117 |         return (p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2
118 | 
119 |     def resolve_collisions(self):
120 |         """If the hero touches an object, the hero eats it and the reward gets updated."""
121 |         collision_distance = 2 * self.settings["object_radius"]
122 |         collision_distance2 = collision_distance ** 2
123 |         to_remove = []
124 |         for obj in self.objects:
125 |             if self.squared_distance(self.hero.position, obj.position) < collision_distance2:
126 |                 to_remove.append(obj)
127 |         for obj in to_remove:
128 |             self.objects.remove(obj)
129 |             self.objects_eaten[obj.obj_type] += 1
130 |             self.object_reward += self.settings["object_reward"][obj.obj_type]
131 |             self.spawn_object(obj.obj_type)
132 | 
133 |     def inside_walls(self, point):
134 |         """Check if the point is inside the walls"""
135 |         EPS = 1e-4
136 |         return (EPS <= point[0] < self.size[0] - EPS and
137 |                 EPS <= point[1] < self.size[1] - EPS)
138 | 
139 |     def observe(self):
140 |         """Return the observation vector. For each observation direction it returns a
141 |         representation of the closest object seen by the hero - which might be nothing,
142 |         another object or a wall. The representations for all directions are concatenated.
143 |         """
144 |         num_obj_types = len(self.settings["objects"]) + 1  # plus one slot for the wall
145 |         max_speed_x, max_speed_y = self.settings["maximum_speed"]
146 | 
147 |         observable_distance = self.settings["observation_line_length"]
148 | 
149 |         relevant_objects = [obj for obj in self.objects
150 |                             if obj.position.distance(self.hero.position) < observable_distance]
151 |         # objects sorted from closest to furthest
152 |         relevant_objects.sort(key=lambda x: x.position.distance(self.hero.position))
153 | 
154 |         observation = np.zeros(self.observation_size)
155 |         observation_offset = 0
156 |         for i, observation_line in enumerate(self.observation_lines):
157 |             # shift to hero position
158 |             observation_line = LineSegment2(self.hero.position + Vector2(*observation_line.p1),
159 |                                             self.hero.position + Vector2(*observation_line.p2))
160 | 
161 |             observed_object = None
162 |             # if the end of the observation line is outside of the walls, we see the wall.
163 |             if not self.inside_walls(observation_line.p2):
164 |                 observed_object = "**wall**"
165 |             for obj in relevant_objects:
166 |                 if observation_line.distance(obj.position) < self.settings["object_radius"]:
167 |                     observed_object = obj
168 |                     break
169 |             object_type_id = None
170 |             speed_x, speed_y = 0, 0
171 |             proximity = 0
172 |             if observed_object == "**wall**":  # wall seen
173 |                 object_type_id = num_obj_types - 1
174 |                 # a wall has fairly low speed...
175 |                 speed_x, speed_y = 0, 0
176 |                 # best candidate is the intersection between
177 |                 # observation_line and a wall that is
178 |                 # closest to the hero
179 |                 best_candidate = None
180 |                 for wall in self.walls:
181 |                     candidate = observation_line.intersect(wall)
182 |                     if candidate is not None:
183 |                         if (best_candidate is None or
184 |                                 best_candidate.distance(self.hero.position) >
185 |                                 candidate.distance(self.hero.position)):
186 |                             best_candidate = candidate
187 |                 if best_candidate is None:
188 |                     # assume it is due to rounding errors
189 |                     # and the wall is barely touching the observation line
190 |                     proximity = observable_distance
191 |                 else:
192 |                     proximity = best_candidate.distance(self.hero.position)
193 |             elif observed_object is not None:  # agent seen
194 |                 object_type_id = self.settings["objects"].index(observed_object.obj_type)
195 |                 speed_x, speed_y = tuple(observed_object.speed)
196 |                 intersection_segment = observed_object.as_circle().intersect(observation_line)
197 |                 assert intersection_segment is not None
198 |                 try:
199 |                     proximity = min(intersection_segment.p1.distance(self.hero.position),
200 |                                     intersection_segment.p2.distance(self.hero.position))
201 |                 except AttributeError:
202 |                     proximity = observable_distance
203 |             for object_type_idx_loop in range(num_obj_types):
204 |                 observation[observation_offset + object_type_idx_loop] = 1.0
205 |             if object_type_id is not None:
206 |                 observation[observation_offset + object_type_id] = proximity / observable_distance
207 |             observation[observation_offset + num_obj_types] = speed_x / max_speed_x
208 |             observation[observation_offset + num_obj_types + 1] = speed_y / max_speed_y
209 |             assert num_obj_types + 2 == self.eye_observation_size
210 |             observation_offset += self.eye_observation_size
211 | 
212 |         observation[observation_offset] = self.hero.speed[0] / max_speed_x
213 |         observation[observation_offset + 1] = self.hero.speed[1] / max_speed_y
214 |         observation_offset += 2
215 | 
216 |         # add normalized location of the hero (350 and 250 assume the default 700x500 world)
217 |         observation[observation_offset] = self.hero.position[0] / 350.0 - 1.0
218 |         observation[observation_offset + 1] = self.hero.position[1] / 250.0 - 1.0
219 | 
220 |         assert observation_offset + 2 == self.observation_size
221 | 
222 |         return observation
223 | 
224 |     def distance_to_walls(self):
225 |         """Returns the distance of the hero to the nearest wall"""
226 |         res = float('inf')
227 |         for wall in self.walls:
228 |             res = min(res, self.hero.position.distance(wall))
229 |         return res - self.settings["object_radius"]
230 | 
231 |     def collect_reward(self):
232 |         """Return accumulated object eating score + current distance to walls score"""
233 |         wall_reward = self.settings["wall_distance_penalty"] * \
234 |                       np.exp(-self.distance_to_walls() / self.settings["tolerable_distance_to_wall"])
235 |         assert wall_reward < 1e-3, "You are rewarding hero for being close to the wall!"
236 |         total_reward = wall_reward + self.object_reward
237 |         self.object_reward = 0
238 |         self.collected_rewards.append(total_reward)
239 |         return total_reward
240 | 
241 |     def plot_reward(self, smoothing=30):
242 |         """Plot evolution of reward over time."""
243 |         plottable = self.collected_rewards[:]
244 |         while len(plottable) > 1000:
245 |             for i in range(0, len(plottable) - 1, 2):
246 |                 plottable[i//2] = (plottable[i] + plottable[i+1]) / 2
247 |             plottable = plottable[:(len(plottable) // 2)]
248 |         x = []
249 |         for i in range(smoothing, len(plottable)):
250 |             chunk = plottable[i-smoothing:i]
251 |             x.append(sum(chunk) / len(chunk))
252 |         plt.plot(list(range(len(x))), x)
253 | 
254 |     def generate_observation_lines(self):
255 |         """Generate observation segments in settings["num_observation_lines"] directions"""
256 |         result = []
257 |         start = Point2(0.0, 0.0)
258 |         end = Point2(self.settings["observation_line_length"],
259 |                      self.settings["observation_line_length"])
260 |         for angle in np.linspace(0, 2*np.pi, self.settings["num_observation_lines"], endpoint=False):
261 |             rotation = Point2(math.cos(angle), math.sin(angle))
262 |             current_start = Point2(start[0] * rotation[0], start[1] * rotation[1])
263 |             current_end = Point2(end[0] * rotation[0], end[1] * rotation[1])
264 |             result.append(LineSegment2(current_start, current_end))
265 |         return result
266 | 
267 |     def _repr_html_(self):
268 |         return self.to_html()
269 | 
270 |     def to_html(self, stats=[]):
271 |         """Return svg representation of the simulator"""
272 | 
273 |         stats = stats[:]
274 |         recent_reward = self.collected_rewards[-100:] + [0]
275 |         objects_eaten_str = ', '.join(["%s: %s" % (o, c) for o, c in self.objects_eaten.items()])
276 |         stats.extend([
277 |             "nearest wall = %.1f" % (self.distance_to_walls(),),
278 |             "reward = %.1f" % (sum(recent_reward)/len(recent_reward),),
279 |             "objects eaten => %s" % (objects_eaten_str,),
280 |         ])
281 | 
282 |         scene = svg.Scene((self.size[0] + 20, self.size[1] + 20 + 20 * len(stats)))
283 |         scene.add(svg.Rectangle((10, 10), self.size))
284 | 
285 | 
286 |         for line in self.observation_lines:
287 |             scene.add(svg.Line(line.p1 + self.hero.position + Point2(10,10),
288 |                                line.p2 + self.hero.position + Point2(10,10)))
289 | 
290 |         for obj in self.objects + [self.hero]:
291 |             scene.add(obj.draw())
292 | 
293 |         offset = self.size[1] + 15
294 |         for txt in stats:
295 |             scene.add(svg.Text((10, offset + 20), txt, 15))
296 |             offset += 20
297 | 
298 |         return scene
299 | 
300 |     def setup_draw(self):
301 |         """
302 |         An optional method to be triggered in simulate(...) to initialise
303 |         the figure handles for rendering.
304 |         simulate(...) will run with/without this method declared in the simulation class.
305 |         As we are using SVG strings in KarpathyGame, it is not currently used.
306 |         """
307 |         pass
308 | 
309 |     def draw(self, stats=[]):
310 |         """
311 |         An optional method to be triggered in simulate(...) to render the simulated environment.
312 |         It is repeatedly called in each simulated iteration.
313 |         simulate(...) will run with/without this method declared in the simulation class.
314 |         """
315 |         clear_output(wait=True)
316 |         svg_html = self.to_html(stats)
317 |         display(svg_html)
318 | 
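319 | 
320 | 
321 | if __name__ == '__main__':
322 |     # Example settings (an illustrative addition; these values are guesses,
323 |     # the canonical ones live in notebooks/karpathy_game.ipynb). Every key
324 |     # below is read somewhere in this module. Run as
325 |     # `python -m tf_rl.simulation.karpathy_game` so the relative import works.
326 |     settings = {
327 |         "world_size": (700, 500),  # observe() normalizes hero position for this size
328 |         "hero_initial_position": (350, 250),
329 |         "hero_initial_speed": (0, 0),
330 |         "hero_bounces_off_walls": False,
331 |         "object_radius": 10.0,
332 |         "objects": ["friend", "enemy"],
333 |         "num_objects": {"friend": 25, "enemy": 25},
334 |         "object_reward": {"friend": 0.1, "enemy": -0.1},
335 |         "colors": {"hero": "yellow", "friend": "green", "enemy": "red"},
336 |         "maximum_speed": (50, 50),
337 |         "delta_v": 50,
338 |         "num_observation_lines": 32,
339 |         "observation_line_length": 120.0,
340 |         "wall_distance_penalty": -10.0,  # must be <= 0, see collect_reward()
341 |         "tolerable_distance_to_wall": 50,
342 |     }
343 |     game = KarpathyGame(settings)
344 |     for _ in range(100):
345 |         game.perform_action(np.random.randint(game.num_actions))
346 |         game.step(0.1)
347 |     print("observation size:", game.observation_size)
348 |     print("objects eaten:", dict(game.objects_eaten))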
--------------------------------------------------------------------------------
/tf_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | def base_name(var):
4 |     """Extracts the value passed to name= when creating a variable"""
5 |     return var.name.split('/')[-1].split(':')[0]
6 | 
7 | def copy_variables(variables):
8 |     res = {}
9 |     for v in variables:
10 |         name = base_name(v)
11 |         copied_var = tf.Variable(v.initialized_value(), name=name)
12 |         res[name] = copied_var
13 |     return res
14 | 
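15 | 
16 | if __name__ == '__main__':
17 |     # Quick illustration (an added example; it uses the graph-mode TensorFlow
18 |     # API of this repository's era, where initialized_value() is available).
19 |     v = tf.Variable(tf.zeros([2, 2]), name="weights")
20 |     print(base_name(v))       # -> "weights" (scope prefix and ":0" suffix stripped)
21 |     copies = copy_variables([v])
22 |     # copies["weights"] is a fresh variable initialized from v's initial
23 |     # value, e.g. for building a separate copy of a network.
24 |     print(sorted(copies.keys()))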
--------------------------------------------------------------------------------
/tf_rl/utils/event_queue.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | from queue import PriorityQueue
4 | 
5 | class EqItem(object):
6 |     """Function and scheduled execution timestamp.
7 | 
8 |     This class is needed because if we
9 |     used a tuple instead, Python would occasionally
10 |     complain that it does not know how to compare
11 |     functions"""
12 |     def __init__(self, ts, f):
13 |         self.ts = ts
14 |         self.f = f
15 | 
16 |     def __lt__(self, other):
17 |         return self.ts < other.ts
18 | 
19 |     def __eq__(self, other):
20 |         return self.ts == other.ts
21 | 
22 | class EventQueue(object):
23 |     def __init__(self):
24 |         """Event queue for executing events at
25 |         specific timepoints.
26 | 
27 |         In current form it is NOT thread safe."""
28 |         self.q = PriorityQueue()
29 | 
30 |     def schedule(self, f, ts):
31 |         """Schedule f to be executed at time ts"""
32 |         self.q.put(EqItem(ts, f))
33 | 
34 |     def schedule_recurring(self, f, interval):
35 |         """Schedule f to be run every interval seconds.
36 | 
37 |         It will be run for the first time interval seconds
38 |         from now"""
39 |         def recurring_f():
40 |             f()
41 |             # reschedule for the next interval
42 |             self.schedule(recurring_f, time.time() + interval)
43 |         # initial run, interval seconds from now
44 |         self.schedule(recurring_f, time.time() + interval)
45 | 
46 |     def run(self):
47 |         """Execute events in the queue as timely as possible."""
48 |         while True:
49 |             event = self.q.get()
50 |             now = time.time()
51 |             if now < event.ts:
52 |                 time.sleep(event.ts - now)
53 |             event.f()
54 | 
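55 | 
56 | if __name__ == '__main__':
57 |     # Usage sketch (an illustrative addition): a recurring heartbeat plus a
58 |     # one-shot event. run() blocks forever, so interrupt to stop.
59 |     eq = EventQueue()
60 |     eq.schedule_recurring(lambda: print("tick"), 0.5)
61 |     eq.schedule(lambda: print("two seconds elapsed"), time.time() + 2.0)
62 |     eq.run()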
--------------------------------------------------------------------------------
/tf_rl/utils/geometry.py:
--------------------------------------------------------------------------------
1 | """
2 | This module assumes that all geometrical points are
3 | represented as 1D numpy arrays.
4 | 
5 | It was designed and tested on 2D points,
6 | but if you try it on 3D points you may
7 | be pleasantly surprised ;-)
8 | """
9 | import numpy as np
10 | 
11 | 
12 | def point_distance(x, y):
13 |     """Returns the Euclidean distance between points x and y"""
14 |     return np.linalg.norm(x - y)
15 | 
16 | def point_projected_on_line(line_s, line_e, point):
17 |     """Project point on the line that goes through line_s and line_e
18 | 
19 |     assumes line_e is not equal (or close) to line_s
20 |     """
21 |     line_along = line_e - line_s
22 | 
23 |     transformed_point = point - line_s
24 | 
25 |     point_dot_line = np.dot(transformed_point, line_along)
26 |     line_along_norm = np.dot(line_along, line_along)
27 | 
28 |     transformed_projection = (point_dot_line / line_along_norm) * line_along
29 | 
30 |     return transformed_projection + line_s
31 | 
32 | def point_segment_distance(segment_s, segment_e, point):
33 |     """Returns distance from point to the closest point on the segment
34 |     connecting points segment_s and segment_e"""
35 |     projected = point_projected_on_line(segment_s, segment_e, point)
36 |     if np.isclose(point_distance(segment_s, projected) + point_distance(projected, segment_e),
37 |                   point_distance(segment_s, segment_e)):
38 |         # projection falls within the segment
39 |         return point_distance(point, projected)
40 |     else:
41 |         return min(point_distance(point, segment_s), point_distance(point, segment_e))
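42 | 
43 | if __name__ == '__main__':
44 |     # Worked example (an illustrative addition): (1, 1) projects onto the
45 |     # segment (0,0)-(2,0) at (1, 0), giving distance 1; the projection of
46 |     # (5, 1) falls outside the segment, so the nearest endpoint (2, 0) is
47 |     # used instead, giving sqrt(3^2 + 1^2) ~ 3.16.
48 |     a, b = np.array([0., 0.]), np.array([2., 0.])
49 |     print(point_segment_distance(a, b, np.array([1., 1.])))  # 1.0
50 |     print(point_segment_distance(a, b, np.array([5., 1.])))  # ~3.1623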
--------------------------------------------------------------------------------
/tf_rl/utils/getch.py:
--------------------------------------------------------------------------------
1 | class _Getch:
2 |     """Gets a single character from standard input. Does not echo to the
3 |     screen."""
4 |     def __init__(self):
5 |         try:
6 |             self.impl = _GetchWindows()
7 |         except ImportError:
8 |             self.impl = _GetchUnix()
9 | 
10 |     def __call__(self): return self.impl()
11 | 
12 | 
13 | class _GetchUnix:
14 |     def __init__(self):
15 |         # fail fast at construction time if the tty module is unavailable
16 |         import tty, sys
17 | 
18 |     def __call__(self):
19 |         import sys, tty, termios
20 |         fd = sys.stdin.fileno()
21 |         old_settings = termios.tcgetattr(fd)
22 |         try:
23 |             tty.setraw(sys.stdin.fileno())
24 |             ch = sys.stdin.read(1)
25 |         finally:
26 |             termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
27 |         return ch
28 | 
29 | 
30 | class _GetchWindows:
31 |     def __init__(self):
32 |         import msvcrt
33 | 
34 |     def __call__(self):
35 |         import msvcrt
36 |         return msvcrt.getch()
37 | 
38 | 
39 | getch = _Getch()
--------------------------------------------------------------------------------
/tf_rl/utils/svg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """\
3 | SVG.py - Construct/display SVG scenes.
4 | 
5 | The following code is a lightweight wrapper around SVG files. The metaphor
6 | is to construct a scene, add objects to it, and then write it to a file
7 | to display it.
8 | 
9 | This program uses ImageMagick to display the SVG files. ImageMagick also
10 | does a remarkable job of converting SVG files into other formats.
11 | """
12 | 
13 | import os
14 | 
15 | def colorstr(rgb):
16 |     if type(rgb) == tuple:
17 |         return "#%02x%02x%02x" % rgb
18 |     else:
19 |         return rgb
20 | 
21 | def compute_style(style):
22 |     color = style.get("color")
23 |     style_str = []
24 |     if color is None:
25 |         color = "none"
26 |     style_str.append('fill:%s;' % (colorstr(color),))
27 | 
28 |     style_str = 'style="%s"' % (';'.join(style_str),)
29 |     return style_str
30 | 
31 | class Scene:
32 |     def __init__(self, size=(400,400)):
33 |         self.items = []
34 |         self.size = size
35 | 
36 |     def add(self, item):
37 |         self.items.append(item)
38 | 
39 |     def strarray(self):
40 |         var = [
41 |             "<?xml version=\"1.0\"?>\n",
42 |             "<svg height=\"%d\" width=\"%d\" >\n" % (self.size[1], self.size[0]),
43 |             " <g style=\"fill-opacity:1.0; stroke:black;\n",
44 |             "  stroke-width:1;\">\n"
45 |         ]
46 |         for item in self.items: var += item.strarray()
47 |         var += [" </g>\n</svg>\n"]
48 |         return var
49 | 
50 |     def write_svg(self, file):
51 |         file.writelines(self.strarray())
52 | 
53 |     def _repr_html_(self):
54 |         return '\n'.join(self.strarray())
55 | 
56 | class Line:
57 |     def __init__(self, start, end):
58 |         self.start = start  # xy tuple
59 |         self.end = end      # xy tuple
60 | 
61 |     def strarray(self):
62 |         return ["  <line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" />\n" %\
63 |                 (self.start[0], self.start[1], self.end[0], self.end[1])]
64 | 
65 | 
66 | class Circle:
67 |     def __init__(self, center, radius, **style_kwargs):
68 |         self.center = center
69 |         self.radius = radius
70 |         self.style_kwargs = style_kwargs
71 | 
72 |     def strarray(self):
73 |         style_str = compute_style(self.style_kwargs)
74 | 
75 |         return [
76 |             "  <circle cx=\"%d\" cy=\"%d\" r=\"%d\"\n" % (self.center[0], self.center[1], self.radius),
77 |             "    %s />\n" % (style_str,)
78 |         ]
79 | 
80 | class Rectangle:
81 |     def __init__(self, origin, size, **style_kwargs):
82 |         self.origin = origin
83 |         self.size = size
84 |         self.style_kwargs = style_kwargs
85 | 
86 |     def strarray(self):
87 |         style_str = compute_style(self.style_kwargs)
88 | 
89 |         return [
90 |             "  <rect x=\"%d\" y=\"%d\" height=\"%d\"\n" % (self.origin[0], self.origin[1], self.size[1]),
91 |             "    width=\"%d\" %s />\n" % (self.size[0], style_str)
92 |         ]
93 | 
94 | class Text:
95 |     def __init__(self, origin, text, size=24):
96 |         self.origin = origin
97 |         self.text = text
98 |         self.size = size
99 |         return
100 | 
101 |     def strarray(self):
102 |         return ["  <text x=\"%d\" y=\"%d\" font-size=\"%d\">\n" %\
103 |                 (self.origin[0], self.origin[1], self.size),
104 |                 "   %s\n" % self.text,
105 |                 "  </text>\n"]
106 | 
107 | 
108 | 
109 | 
110 | def test():
111 |     scene = Scene()
112 |     scene.add(Rectangle((100,100), (200,200), **{"color": (0,255,255)}))
113 |     scene.add(Line((200,200), (200,300)))
114 |     scene.add(Line((200,200), (300,200)))
115 |     scene.add(Line((200,200), (100,200)))
116 |     scene.add(Line((200,200), (200,100)))
117 |     scene.add(Circle((200,200), 30, **{"color": (0,0,255)}))
118 |     scene.add(Circle((200,300), 30, **{"color": (0,255,0)}))
119 |     scene.add(Circle((300,200), 30, **{"color": (255,0,0)}))
120 |     scene.add(Circle((100,200), 30, **{"color": (255,255,0)}))
121 |     scene.add(Circle((200,100), 30, **{"color": (255,0,255)}))
122 |     scene.add(Text((50,50), "Testing SVG"))
123 |     with open("test.svg", "w") as f:
124 |         scene.write_svg(f)
125 | 
126 | if __name__ == '__main__':
127 |     test()
128 | 
--------------------------------------------------------------------------------