├── .gitignore
├── LICENSE.md
├── LICENSE2.md
├── README.md
├── TODO.txt
├── data
│   └── example.gif
├── notebooks
│   ├── DoublePendulum.ipynb
│   ├── MLP.ipynb
│   ├── MLP_copy.ipynb
│   ├── XOR Network.ipynb
│   ├── game.ipynb
│   ├── game_memory.ipynb
│   ├── karpathy_game.ipynb
│   ├── pong.py
│   └── tf_rl
├── requirements.txt
├── saved_models
│   ├── checkpoint
│   └── karpathy_game.ckpt
├── scripts
│   └── make_gif.sh
└── tf_rl
    ├── __init__.py
    ├── controller
    │   ├── __init__.py
    │   ├── discrete_deepq.py
    │   └── human_controller.py
    ├── models.py
    ├── simulate.py
    ├── simulation
    │   ├── __init__.py
    │   ├── discrete_hill.py
    │   ├── double_pendulum.py
    │   └── karpathy_game.py
    └── utils
        ├── __init__.py
        ├── event_queue.py
        ├── geometry.py
        ├── getch.py
        └── svg.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.pyc
4 | *.pyo
5 | notebooks/Goofiness.ipynb
6 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 Szymon Sidor
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE2.md:
--------------------------------------------------------------------------------
1 | Copyright 2016 Szymon Sidor. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright 2015, The TensorFlow Authors.
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This repository is now obsolete!
2 |
3 | Check out the new, simpler, better-performing, and more complete implementation that we released at OpenAI:
4 |
5 | https://github.com/openai/baselines
6 |
7 |
8 | (scroll for docs of the obsolete version)
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | ### Reinforcement Learning using TensorFlow
43 |
44 |
45 | #### Quick start
46 |
47 | Check out the Karpathy game in the `notebooks` folder.
48 |
49 | ![](data/example.gif)
50 |
51 | *The image above depicts a strategy learned by the DeepQ controller. Available actions are accelerating up, down, left, or right. The reward signal is +1 for the green fellas, -1 for red, and -5 for orange.*
52 |
53 | #### Requirements
54 |
55 | - `future==0.15.2`
56 | - `euclid==0.1`
57 | - `inkscape` (for animation gif creation)
58 |
59 | #### How does this all fit together?
60 |
61 | `tf_rl` has controllers and simulations which can be pieced together using the `simulate` function.
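
To see how the pieces interact, here is the interaction loop condensed from `notebooks/game.ipynb`: a `DiscreteDeepQ` controller plays one episode of the `DiscreteHill` simulation. The parameter values are illustrative, and the TensorFlow calls match the old API used throughout this repo.

```python
import tensorflow as tf

from tf_rl.controller import DiscreteDeepQ
from tf_rl.simulation import DiscreteHill
from tf_rl.models import MLP

# Always reset the graph before creating a new controller.
tf.ops.reset_default_graph()
session = tf.InteractiveSession()

# The "brain" maps a 4-number observation to Q-values for 4 actions.
brain = MLP([4,], [10, 4], [tf.tanh, tf.identity])
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)
controller = DiscreteDeepQ((4,), 4, brain, optimizer, session,
                           discount_rate=0.9, exploration_period=100)
session.run(tf.initialize_all_variables())
session.run(controller.target_network_update)

# Observe -> act -> reward -> store -> train: the loop simulate() runs for you.
game = DiscreteHill()
observation = game.observe()
iterations = 0
while iterations < 50 and not game.is_over():
    action = controller.action(observation)
    reward = game.collect_reward(action)
    game.perform_action(action)
    new_observation = game.observe()
    controller.store(observation, action, reward, new_observation)
    controller.training_step()
    observation = new_observation
    iterations += 1
```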
62 |
63 | #### Using the human controller
64 | Want to have some fun controlling the simulation by yourself? You got it!
65 | Use `tf_rl.controller.HumanController` in your simulation.
66 |
67 | To issue commands, run in a terminal:
68 | ```bash
69 | python3 tf_rl/controller/human_controller.py
70 | ```
71 | For this to work, you also need a Redis server running locally.
72 |
73 | #### Writing your own controller
74 | To write your own controller, define a controller class with the following 3 functions (a sketch follows the list):
75 | - `action(self, observation)` given an observation (usually a tensor of numbers), returns the action to perform.
76 | - `store(self, observation, action, reward, newobservation)` called each time a transition from `observation` to `newobservation` is observed. The transition is a consequence of `action` and has an associated `reward`.
77 | - `training_step(self)` if your controller requires training, this is the place to do it. It should not take too long, because it is called roughly once per executed action.
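
For example, a controller that acts at random could satisfy this interface as follows (the class name and the `num_actions` argument are made up for illustration; they are not part of `tf_rl`):

```python
import random

class RandomController(object):
    def __init__(self, num_actions):
        self.num_actions = num_actions

    def action(self, observation):
        # Ignore the observation and pick an action uniformly at random.
        return random.randint(0, self.num_actions - 1)

    def store(self, observation, action, reward, newobservation):
        # A random controller keeps no memory of transitions.
        pass

    def training_step(self):
        # Nothing to train here; DeepQ-style controllers update their model.
        pass
```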
78 |
79 | #### Writing your own simulation
80 | To write your own simulation, define a simulation class with the following 5 functions (a sketch follows the list):
81 | - `observe(self)` returns the current observation.
82 | - `collect_reward(self)` returns the reward accumulated since the last time this function was called.
83 | - `perform_action(self, action)` updates the internal state to reflect the fact that `action` was executed.
84 | - `step(self, dt)` updates the internal state as if `dt` of simulation time has passed.
85 | - `to_html(self, info=[])` generates an HTML visualization of the game. `info` is an optional list of strings that should be displayed along with the visualization.
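
A minimal simulation implementing this interface might look like the sketch below (the `CoinFlip` game and its internals are made up for illustration):

```python
import random

class CoinFlip(object):
    def __init__(self):
        self.accumulated_reward = 0.0

    def observe(self):
        # A single meaningless observation dimension.
        return [random.random()]

    def collect_reward(self):
        # Hand back the reward accumulated since the previous call.
        reward, self.accumulated_reward = self.accumulated_reward, 0.0
        return reward

    def perform_action(self, action):
        # Reward action 1, penalize everything else.
        self.accumulated_reward += 1.0 if action == 1 else -1.0

    def step(self, dt):
        # No internal dynamics to advance.
        pass

    def to_html(self, info=[]):
        return "<pre>%s</pre>" % "\n".join(info)
```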
86 |
87 |
88 |
89 | #### Creating GIFs based on simulation
90 | The `simulate` method accepts a `save_path` argument: a folder where all the consecutive images will be stored.
91 | To turn them into a GIF, use `scripts/make_gif.sh PATH`, where `PATH` is the same path you passed to the `save_path` argument.
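
A hypothetical end-to-end invocation, with `d` being a simulation such as the `DoublePendulum` from the notebooks (only `save_path` is documented above; `fps` mirrors the call in `notebooks/DoublePendulum.ipynb`):

```python
from tf_rl import simulate

# Consecutive frames end up in save_path as individual images.
simulate(d, fps=30, save_path="/tmp/frames")
# Afterwards, from a shell: scripts/make_gif.sh /tmp/frames
```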
92 |
--------------------------------------------------------------------------------
/TODO.txt:
--------------------------------------------------------------------------------
1 | -> Variable naming
2 | -> Documentation
3 | -> move to separate files
4 | -> more interesting examples
5 |
--------------------------------------------------------------------------------
/data/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siemanko/tensorflow-deepq/149e69e5340984d75df3ff1a374920d870517fb9/data/example.gif
--------------------------------------------------------------------------------
/notebooks/DoublePendulum.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%load_ext autoreload\n",
12 | "%autoreload 2"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "collapsed": true
20 | },
21 | "outputs": [],
22 | "source": [
23 | "import math\n",
24 | "import random\n",
25 | "import time\n",
26 | "\n",
27 | "from collections import defaultdict\n",
28 | "\n"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 8,
34 | "metadata": {
35 | "collapsed": true
36 | },
37 | "outputs": [],
38 | "source": [
39 | "from tf_rl.simulation import DoublePendulum\n",
40 | "from tf_rl import simulate"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 9,
46 | "metadata": {
47 | "collapsed": false
48 | },
49 | "outputs": [],
50 | "source": [
51 | "DOUBLE_PENDULUM_PARAMS = {\n",
52 | " 'g_ms2': 9.8, # acceleration due to gravity, in m/s^2\n",
53 | " 'l1_m': 1.0, # length of pendulum 1 in m\n",
54 | " 'l2_m': 2.0, # length of pendulum 2 in m\n",
55 | " 'm1_kg': 1.0, # mass of pendulum 1 in kg\n",
56 | " 'm2_kg': 1.0, # mass of pendulum 2 in kg\n",
57 | " 'damping': 0.4,\n",
58 | " 'max_control_input': 20.0\n",
59 | "}"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 10,
65 | "metadata": {
66 | "collapsed": false
67 | },
68 | "outputs": [],
69 | "source": [
70 | "d = DoublePendulum(DOUBLE_PENDULUM_PARAMS)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 11,
76 | "metadata": {
77 | "collapsed": true
78 | },
79 | "outputs": [],
80 | "source": [
81 | "d.perform_action(0.2)"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 13,
87 | "metadata": {
88 | "collapsed": false
89 | },
90 | "outputs": [
91 | {
92 | "data": {
93 | "text/html": [
94 | "\n",
95 | "\n",
96 | "\n"
130 | ],
131 | "text/plain": [
132 | ""
133 | ]
134 | },
135 | "metadata": {},
136 | "output_type": "display_data"
137 | },
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "Interrupted\n"
143 | ]
144 | }
145 | ],
146 | "source": [
147 | "try:\n",
148 | " simulate(d, fps=30, actions_per_simulation_second=1, speed=1.0, simulation_resultion=0.01)\n",
149 | "except KeyboardInterrupt:\n",
150 | " print(\"Interrupted\")"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {
157 | "collapsed": true
158 | },
159 | "outputs": [],
160 | "source": []
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "Python 2",
166 | "language": "python",
167 | "name": "python2"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 2
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython2",
179 | "version": "2.7.8"
180 | }
181 | },
182 | "nbformat": 4,
183 | "nbformat_minor": 0
184 | }
185 |
--------------------------------------------------------------------------------
/notebooks/MLP.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import tensorflow as tf\n",
13 | "\n",
14 | "from __future__ import print_function"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "collapsed": true
22 | },
23 | "outputs": [],
24 | "source": [
25 | "def create_examples(N, batch_size):\n",
26 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n",
27 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n",
28 | "\n",
29 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n",
30 | " X[:,:N], X[:,N:] = A, B\n",
31 | "\n",
32 | " Y = (A ^ B).astype(np.float32)\n",
33 | " return X,Y"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {
40 | "collapsed": true
41 | },
42 | "outputs": [],
43 | "source": [
44 | "import math\n",
45 | "\n",
46 | "class Layer(object):\n",
47 | " def __init__(self, input_sizes, output_size):\n",
48 | " \"\"\"Cretes a neural network layer.\"\"\"\n",
49 | " if type(input_sizes) != list:\n",
50 | " input_sizes = [input_sizes]\n",
51 | " \n",
52 | " self.input_sizes = input_sizes\n",
53 | " self.output_size = output_size\n",
54 | " \n",
55 | " self.Ws = []\n",
56 | " for input_size in input_sizes:\n",
57 | " tensor_W = tf.random_uniform((input_size, output_size),\n",
58 | " -1.0 / math.sqrt(input_size),\n",
59 | " 1.0 / math.sqrt(input_size))\n",
60 | " self.Ws.append(tf.Variable(tensor_W))\n",
61 | "\n",
62 | " tensor_b = tf.zeros((output_size,))\n",
63 | " self.b = tf.Variable(tensor_b)\n",
64 | " \n",
65 | " def __call__(self, xs):\n",
66 | " if type(xs) != list:\n",
67 | " xs = [xs]\n",
68 | " assert len(xs) == len(self.Ws), \\\n",
69 | " \"Expected %d input vectors, got %d\" % (len(self.Ws), len(x))\n",
70 | " return sum([tf.matmul(x, W) for x, W in zip(xs, self.Ws)]) + self.b\n",
71 | "\n",
72 | " \n",
73 | "class MLP(object):\n",
74 | " def __init__(self, input_sizes, hiddens, nonlinearities):\n",
75 | " self.input_sizes = input_sizes\n",
76 | " self.hiddens = hiddens\n",
77 | " self.input_nonlinearity, self.layer_nonlinearities = nonlinearities[0], nonlinearities[1:]\n",
78 | "\n",
79 | " assert len(hiddens) == len(nonlinearities), \\\n",
80 | " \"Number of hiddens must be equal to number of nonlinearities\"\n",
81 | " \n",
82 | " self.input_layer = Layer(input_sizes, hiddens[0])\n",
83 | " self.layers = [Layer(h_from, h_to) for h_from, h_to in zip(hiddens[:-1], hiddens[1:])]\n",
84 | "\n",
85 | " def __call__(self, xs):\n",
86 | " if type(xs) != list:\n",
87 | " xs = [xs]\n",
88 | " hidden = self.input_nonlinearity(self.input_layer(xs))\n",
89 | " for layer, nonlinearity in zip(self.layers, self.layer_nonlinearities):\n",
90 | " hidden = nonlinearity(layer(hidden))\n",
91 | " return hidden"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 19,
97 | "metadata": {
98 | "collapsed": false
99 | },
100 | "outputs": [
101 | {
102 | "name": "stderr",
103 | "output_type": "stream",
104 | "text": [
105 | "Exception AssertionError: AssertionError() in > ignored\n"
106 | ]
107 | }
108 | ],
109 | "source": [
110 | "tf.ops.reset_default_graph()\n",
111 | "sess = tf.InteractiveSession()"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 20,
117 | "metadata": {
118 | "collapsed": false
119 | },
120 | "outputs": [],
121 | "source": [
122 | "N = 5\n",
123 | "# we add a single hidden layer of size 12\n",
124 | "# otherwise code is similar to above\n",
125 | "HIDDEN_SIZE = 12\n",
126 | "\n",
127 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n",
128 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n",
129 | "\n",
130 | "mlp = MLP(2 * N, [HIDDEN_SIZE, N], [tf.tanh, tf.sigmoid])\n",
131 | "y = mlp(x)\n",
132 | "\n",
133 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n",
134 | "\n",
135 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n",
136 | "train_op = optimizer.minimize(cost)\n",
137 | "sess.run(tf.initialize_all_variables())"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 21,
143 | "metadata": {
144 | "collapsed": false
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "0.241206\n",
152 | "0.246315\n",
153 | "0.193208\n",
154 | "0.107224\n",
155 | "0.100235\n",
156 | "0.0613462\n",
157 | "0.0480775\n",
158 | "0.0498072\n",
159 | "0.0403215\n",
160 | "0.0474323\n"
161 | ]
162 | }
163 | ],
164 | "source": [
165 | "for t in range(5000):\n",
166 | " example_x, example_y = create_examples(N, 10)\n",
167 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n",
168 | " if t % 500 == 0: \n",
169 | " print(cost_t.mean())"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 22,
175 | "metadata": {
176 | "collapsed": false
177 | },
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "Accuracy over 1000 examples: 98 %\n"
184 | ]
185 | }
186 | ],
187 | "source": [
188 | "N_EXAMPLES = 1000\n",
189 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n",
190 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n",
191 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n",
192 | "\n",
193 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n",
194 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {
201 | "collapsed": true
202 | },
203 | "outputs": [],
204 | "source": []
205 | }
206 | ],
207 | "metadata": {
208 | "kernelspec": {
209 | "display_name": "Python 2",
210 | "language": "python",
211 | "name": "python2"
212 | },
213 | "language_info": {
214 | "codemirror_mode": {
215 | "name": "ipython",
216 | "version": 2
217 | },
218 | "file_extension": ".py",
219 | "mimetype": "text/x-python",
220 | "name": "python",
221 | "nbconvert_exporter": "python",
222 | "pygments_lexer": "ipython2",
223 | "version": "2.7.8"
224 | }
225 | },
226 | "nbformat": 4,
227 | "nbformat_minor": 0
228 | }
229 |
--------------------------------------------------------------------------------
/notebooks/MLP_copy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%load_ext autoreload\n",
12 | "%autoreload 2"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "collapsed": true
20 | },
21 | "outputs": [],
22 | "source": [
23 | "import numpy as np\n",
24 | "import tensorflow as tf\n",
25 | "\n",
26 | "from __future__ import print_function"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {
33 | "collapsed": true
34 | },
35 | "outputs": [],
36 | "source": [
37 | "def create_examples(N, batch_size):\n",
38 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n",
39 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n",
40 | "\n",
41 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n",
42 | " X[:,:N], X[:,N:] = A, B\n",
43 | "\n",
44 | " Y = (A ^ B).astype(np.float32)\n",
45 | " return X,Y"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "from tf_rl.models import MLP"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 5,
62 | "metadata": {
63 | "collapsed": false
64 | },
65 | "outputs": [],
66 | "source": [
67 | "tf.ops.reset_default_graph()\n",
68 | "sess = tf.InteractiveSession()"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 6,
74 | "metadata": {
75 | "collapsed": false
76 | },
77 | "outputs": [],
78 | "source": [
79 | "N = 5\n",
80 | "# we add a single hidden layer of size 12\n",
81 | "# otherwise code is similar to above\n",
82 | "HIDDEN_SIZE = 12\n",
83 | "\n",
84 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n",
85 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n",
86 | "\n",
87 | "mlp = MLP(2 * N, [HIDDEN_SIZE, N], [tf.tanh, tf.sigmoid])\n",
88 | "y = mlp(x)\n",
89 | "\n",
90 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n",
91 | "\n",
92 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n",
93 | "train_op = optimizer.minimize(cost)\n",
94 | "sess.run(tf.initialize_all_variables())"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 7,
100 | "metadata": {
101 | "collapsed": false
102 | },
103 | "outputs": [],
104 | "source": [
105 | "mlp2 = mlp.copy(sess)"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 8,
111 | "metadata": {
112 | "collapsed": false
113 | },
114 | "outputs": [
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "0.25368\n",
120 | "0.24415\n",
121 | "0.170396\n",
122 | "0.0975643\n",
123 | "0.0823987\n",
124 | "0.0314237\n",
125 | "0.0417636\n",
126 | "0.0364044\n",
127 | "0.039408\n",
128 | "0.026788\n"
129 | ]
130 | }
131 | ],
132 | "source": [
133 | "for t in range(5000):\n",
134 | " example_x, example_y = create_examples(N, 10)\n",
135 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n",
136 | " if t % 500 == 0: \n",
137 | " print(cost_t.mean())"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 9,
143 | "metadata": {
144 | "collapsed": false
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Accuracy over 1000 examples: 99 %\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "N_EXAMPLES = 1000\n",
157 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n",
158 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n",
159 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n",
160 | "\n",
161 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n",
162 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 10,
168 | "metadata": {
169 | "collapsed": false
170 | },
171 | "outputs": [
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "Accuracy over 1000 examples: 52 %\n"
177 | ]
178 | }
179 | ],
180 | "source": [
181 | "# If copy works accuracy should be around 50% for this one\n",
182 | "N_EXAMPLES = 1000\n",
183 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n",
184 | "is_correct = tf.less_equal(tf.abs(mlp2(x) - y_golden), tf.constant(0.5))\n",
185 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n",
186 | "\n",
187 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n",
188 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {
195 | "collapsed": true
196 | },
197 | "outputs": [],
198 | "source": []
199 | }
200 | ],
201 | "metadata": {
202 | "kernelspec": {
203 | "display_name": "Python 2",
204 | "language": "python",
205 | "name": "python2"
206 | },
207 | "language_info": {
208 | "codemirror_mode": {
209 | "name": "ipython",
210 | "version": 2
211 | },
212 | "file_extension": ".py",
213 | "mimetype": "text/x-python",
214 | "name": "python",
215 | "nbconvert_exporter": "python",
216 | "pygments_lexer": "ipython2",
217 | "version": "2.7.8"
218 | }
219 | },
220 | "nbformat": 4,
221 | "nbformat_minor": 0
222 | }
223 |
--------------------------------------------------------------------------------
/notebooks/XOR Network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import tensorflow as tf\n",
13 | "\n",
14 | "from __future__ import print_function"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "# XOR Network"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "### Data generation"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {
35 | "collapsed": true
36 | },
37 | "outputs": [],
38 | "source": [
39 | "def create_examples(N, batch_size):\n",
40 | " A = np.random.binomial(n=1, p=0.5, size=(batch_size, N))\n",
41 | " B = np.random.binomial(n=1, p=0.5, size=(batch_size, N,))\n",
42 | "\n",
43 | " X = np.zeros((batch_size, 2 *N,), dtype=np.float32)\n",
44 | " X[:,:N], X[:,N:] = A, B\n",
45 | "\n",
46 | " Y = (A ^ B).astype(np.float32)\n",
47 | " return X,Y"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {
54 | "collapsed": false
55 | },
56 | "outputs": [
57 | {
58 | "name": "stdout",
59 | "output_type": "stream",
60 | "text": [
61 | "[ 0. 1. 0.] xor [ 1. 1. 1.] equals [ 1. 0. 1.]\n",
62 | "[ 0. 0. 1.] xor [ 1. 1. 0.] equals [ 1. 1. 1.]\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "X, Y = create_examples(3, 2)\n",
68 | "print(X[0,:3], \"xor\", X[0,3:],\"equals\", Y[0])\n",
69 | "print(X[1,:3], \"xor\", X[1,3:],\"equals\", Y[1])\n"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "### Xor cannot be solved with single layer of neural network"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 22,
82 | "metadata": {
83 | "collapsed": false
84 | },
85 | "outputs": [],
86 | "source": [
87 | "import math\n",
88 | "\n",
89 | "class Layer(object):\n",
90 | " def __init__(self, input_size, output_size):\n",
91 | " tensor_b = tf.zeros((output_size,))\n",
92 | " self.b = tf.Variable(tensor_b)\n",
93 | " tensor_W = tf.random_uniform((input_size, output_size),\n",
94 | " -1.0 / math.sqrt(input_size),\n",
95 | " 1.0 / math.sqrt(input_size))\n",
96 | " self.W = tf.Variable(tensor_W)\n",
97 | "\n",
98 | " def __call__(self, x):\n",
99 | " return tf.matmul(x, self.W) + self.b"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 105,
105 | "metadata": {
106 | "collapsed": false
107 | },
108 | "outputs": [],
109 | "source": [
110 | "tf.ops.reset_default_graph()\n",
111 | "sess = tf.InteractiveSession()"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 106,
117 | "metadata": {
118 | "collapsed": false
119 | },
120 | "outputs": [],
121 | "source": [
122 | "N = 5\n",
123 | "# x represents input data\n",
124 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n",
125 | "# y_golden is a reference output data.\n",
126 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n",
127 | "\n",
128 | "layer1 = Layer(2 * N, N)\n",
129 | "# y is a linear projection of x with nonlinearity applied to the result.\n",
130 | "y = tf.nn.sigmoid(layer1(x))\n",
131 | "\n",
132 | "# mean squared error over all examples and all N output dimensions.\n",
133 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n",
134 | "\n",
135 | "# create a function that will optimize the neural network\n",
136 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n",
137 | "train_op = optimizer.minimize(cost)\n",
138 | "\n",
139 | "# initialize the variables\n",
140 | "sess.run(tf.initialize_all_variables())"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 107,
146 | "metadata": {
147 | "collapsed": false
148 | },
149 | "outputs": [
150 | {
151 | "name": "stdout",
152 | "output_type": "stream",
153 | "text": [
154 | "0.262958\n",
155 | "0.249229\n",
156 | "0.259427\n",
157 | "0.245061\n",
158 | "0.252946\n",
159 | "0.24782\n",
160 | "0.250937\n",
161 | "0.246418\n",
162 | "0.246755\n",
163 | "0.244774\n"
164 | ]
165 | }
166 | ],
167 | "source": [
168 | "for t in range(5000):\n",
169 | " example_x, example_y = create_examples(N, 10)\n",
170 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n",
171 | " if t % 500 == 0: \n",
172 | " print(cost_t.mean())"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 | "### Notice that the error is far from zero.\n",
180 | "\n",
181 | "Actually network always predicts approximately $0.5$, regardless of input data. That yields error of about $0.25$, because we use mean squared error ($0.5^2 = 0.25$). "
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 109,
187 | "metadata": {
188 | "collapsed": false
189 | },
190 | "outputs": [
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "[[ 1. 0. 1. 1. 1. 1. 0. 0. 1. 1.]\n",
196 | " [ 1. 0. 1. 1. 0. 1. 1. 1. 1. 1.]\n",
197 | " [ 0. 0. 1. 0. 1. 0. 0. 1. 1. 1.]]\n",
198 | "[array([[ 0.56099683, 0.54470569, 0.4940519 , 0.49518651, 0.54470527],\n",
199 | " [ 0.56658453, 0.52068532, 0.48442408, 0.4748241 , 0.5073036 ],\n",
200 | " [ 0.53004831, 0.52866411, 0.48705727, 0.48926324, 0.53761232]], dtype=float32)]\n"
201 | ]
202 | }
203 | ],
204 | "source": [
205 | "X, _ = create_examples(N, 3)\n",
206 | "prediction = sess.run([y], {x: X})\n",
207 | "print(X)\n",
208 | "print(prediction)"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | "### Accuracy is not that hard to predict..."
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 113,
221 | "metadata": {
222 | "collapsed": false
223 | },
224 | "outputs": [
225 | {
226 | "name": "stdout",
227 | "output_type": "stream",
228 | "text": [
229 | "Accuracy over 1000 examples: 48 %\n"
230 | ]
231 | }
232 | ],
233 | "source": [
234 | "N_EXAMPLES = 1000\n",
235 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n",
236 | "# one day I need to write a wrapper which will turn the expression\n",
237 | "# below to:\n",
238 | "# tf.abs(y - y_golden) < 0.5\n",
239 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n",
240 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n",
241 | "\n",
242 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n",
243 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {},
249 | "source": [
250 | "### Xor Network with 2 layers"
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": 149,
256 | "metadata": {
257 | "collapsed": false
258 | },
259 | "outputs": [
260 | {
261 | "name": "stderr",
262 | "output_type": "stream",
263 | "text": [
264 | "Exception AssertionError: AssertionError() in > ignored\n"
265 | ]
266 | }
267 | ],
268 | "source": [
269 | "tf.ops.reset_default_graph()\n",
270 | "sess = tf.InteractiveSession()"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 150,
276 | "metadata": {
277 | "collapsed": false
278 | },
279 | "outputs": [],
280 | "source": [
281 | "N = 5\n",
282 | "# we add a single hidden layer of size 12\n",
283 | "# otherwise code is similar to above\n",
284 | "HIDDEN_SIZE = 12\n",
285 | "\n",
286 | "x = tf.placeholder(tf.float32, (None, 2 * N), name=\"x\")\n",
287 | "y_golden = tf.placeholder(tf.float32, (None, N), name=\"y\")\n",
288 | "\n",
289 | "layer1 = Layer(2 * N, HIDDEN_SIZE)\n",
290 | "layer2 = Layer(HIDDEN_SIZE, N) # <------- HERE IT IS!\n",
291 | "\n",
292 | "hidden_repr = tf.nn.tanh(layer1(x))\n",
293 | "y = tf.nn.sigmoid(layer2(hidden_repr))\n",
294 | "\n",
295 | "cost = tf.reduce_mean(tf.square(y - y_golden))\n",
296 | "\n",
297 | "optimizer = tf.train.AdagradOptimizer(learning_rate=0.3)\n",
298 | "train_op = optimizer.minimize(cost)\n",
299 | "sess.run(tf.initialize_all_variables())"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 151,
305 | "metadata": {
306 | "collapsed": false
307 | },
308 | "outputs": [
309 | {
310 | "name": "stdout",
311 | "output_type": "stream",
312 | "text": [
313 | "0.241089\n",
314 | "0.240045\n",
315 | "0.1631\n",
316 | "0.0709099\n",
317 | "0.0326128\n",
318 | "0.0087687\n",
319 | "0.00526247\n",
320 | "0.00518266\n",
321 | "0.00272845\n",
322 | "0.00213744\n"
323 | ]
324 | }
325 | ],
326 | "source": [
327 | "for t in range(5000):\n",
328 | " example_x, example_y = create_examples(N, 10)\n",
329 | " cost_t, _ = sess.run([cost, train_op], {x: example_x, y_golden: example_y})\n",
330 | " if t % 500 == 0: \n",
331 | " print(cost_t.mean())"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "### This time the network works a tad better"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 156,
344 | "metadata": {
345 | "collapsed": false
346 | },
347 | "outputs": [
348 | {
349 | "name": "stdout",
350 | "output_type": "stream",
351 | "text": [
352 | "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
353 | " [ 1. 0. 0. 1. 1. 1. 0. 1. 1. 1.]\n",
354 | " [ 0. 1. 1. 1. 0. 0. 0. 0. 1. 0.]]\n",
355 | "[[ 0. 0. 0. 0. 1.]\n",
356 | " [ 0. 0. 1. 0. 0.]\n",
357 | " [ 0. 1. 1. 0. 0.]]\n",
358 | "[array([[ 0.10384335, 0.04389301, 0.05774897, 0.04509954, 0.9374879 ],\n",
359 | " [ 0.05130127, 0.02655722, 0.97246277, 0.03545236, 0.04168396],\n",
360 | " [ 0.03924223, 0.96327722, 0.96935028, 0.03265698, 0.0310236 ]], dtype=float32)]\n"
361 | ]
362 | }
363 | ],
364 | "source": [
365 | "X, Y = create_examples(N, 3)\n",
366 | "prediction = sess.run([y], {x: X})\n",
367 | "print(X)\n",
368 | "print(Y)\n",
369 | "print(prediction)"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": 152,
375 | "metadata": {
376 | "collapsed": false
377 | },
378 | "outputs": [
379 | {
380 | "name": "stdout",
381 | "output_type": "stream",
382 | "text": [
383 | "Accuracy over 1000 examples: 100 %\n"
384 | ]
385 | }
386 | ],
387 | "source": [
388 | "N_EXAMPLES = 1000\n",
389 | "example_x, example_y = create_examples(N, N_EXAMPLES)\n",
390 | "is_correct = tf.less_equal(tf.abs(y - y_golden), tf.constant(0.5))\n",
391 | "accuracy = tf.reduce_mean(tf.cast(is_correct, \"float\"))\n",
392 | "\n",
393 | "acc_result = sess.run(accuracy, {x: example_x, y_golden: example_y})\n",
394 | "print(\"Accuracy over %d examples: %.0f %%\" % (N_EXAMPLES, 100.0 * acc_result))"
395 | ]
396 | }
397 | ],
398 | "metadata": {
399 | "kernelspec": {
400 | "display_name": "Python 2",
401 | "language": "python",
402 | "name": "python2"
403 | },
404 | "language_info": {
405 | "codemirror_mode": {
406 | "name": "ipython",
407 | "version": 2
408 | },
409 | "file_extension": ".py",
410 | "mimetype": "text/x-python",
411 | "name": "python",
412 | "nbconvert_exporter": "python",
413 | "pygments_lexer": "ipython2",
414 | "version": "2.7.8"
415 | }
416 | },
417 | "nbformat": 4,
418 | "nbformat_minor": 0
419 | }
420 |
--------------------------------------------------------------------------------
/notebooks/game.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 4,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [
10 | {
11 | "name": "stdout",
12 | "output_type": "stream",
13 | "text": [
14 | "The autoreload extension is already loaded. To reload it, use:\n",
15 | " %reload_ext autoreload\n"
16 | ]
17 | }
18 | ],
19 | "source": [
20 | "%load_ext autoreload\n",
21 | "%autoreload 2"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 5,
27 | "metadata": {
28 | "collapsed": true
29 | },
30 | "outputs": [],
31 | "source": [
32 | "import matplotlib.pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 6,
39 | "metadata": {
40 | "collapsed": false
41 | },
42 | "outputs": [],
43 | "source": [
44 | "import numpy as np\n",
45 | "import tempfile\n",
46 | "import tensorflow as tf\n",
47 | "\n",
48 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n",
49 | "from tf_rl.simulation import KarpathyGame\n",
50 | "from tf_rl import simulate\n",
51 | "from tf_rl.models import MLP"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 7,
57 | "metadata": {
58 | "collapsed": false
59 | },
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "/tmp/tmpdzxofD\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "LOG_DIR = tempfile.mkdtemp()\n",
71 | "print(LOG_DIR)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 8,
77 | "metadata": {
78 | "collapsed": false
79 | },
80 | "outputs": [],
81 | "source": [
82 | "from tf_rl.simulation import DiscreteHill"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 9,
88 | "metadata": {
89 | "collapsed": false
90 | },
91 | "outputs": [],
92 | "source": [
93 | "# Tensorflow business - it is always good to reset a graph before creating a new controller.\n",
94 | "tf.ops.reset_default_graph()\n",
95 | "session = tf.InteractiveSession()\n",
96 | "\n",
97 | "# This little guy will let us run tensorboard\n",
98 | "# tensorboard --logdir [LOG_DIR]\n",
99 | "journalist = tf.train.SummaryWriter(LOG_DIR)\n",
100 | "\n",
101 | "# Brain maps from observation to Q values for different actions.\n",
102 | "# Here it is a done using a multi layer perceptron with 2 hidden\n",
103 | "# layers\n",
104 | "brain = MLP([4,], [10, 4], \n",
105 | " [tf.tanh, tf.identity])\n",
106 | "\n",
107 | "# The optimizer to use. Here we use RMSProp as recommended\n",
108 | "# by the publication\n",
109 | "optimizer = tf.train.RMSPropOptimizer(learning_rate= 0.001, decay=0.9)\n",
110 | "\n",
111 | "# DiscreteDeepQ object\n",
112 | "current_controller = DiscreteDeepQ((4,), 4, brain, optimizer, session,\n",
113 | " discount_rate=0.9, exploration_period=100, max_experience=10000, \n",
114 | " store_every_nth=1, train_every_nth=4, target_network_update_rate=0.1,\n",
115 | " summary_writer=journalist)\n",
116 | "\n",
117 | "session.run(tf.initialize_all_variables())\n",
118 | "session.run(current_controller.target_network_update)\n",
119 | "# graph was not available when journalist was created \n",
120 | "journalist.add_graph(session.graph_def)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 10,
126 | "metadata": {
127 | "collapsed": false
128 | },
129 | "outputs": [
130 | {
131 | "name": "stdout",
132 | "output_type": "stream",
133 | "text": [
134 | "Game 9900: iterations before success 12. Pos: (-3, 7), Target: (-3, 7)\n"
135 | ]
136 | }
137 | ],
138 | "source": [
139 | "performances = []\n",
140 | "\n",
141 | "try:\n",
142 | " for game_idx in range(10000):\n",
143 | " game = DiscreteHill()\n",
144 | " game_iterations = 0\n",
145 | "\n",
146 | " observation = game.observe()\n",
147 | "\n",
148 | " while game_iterations < 50 and not game.is_over():\n",
149 | " action = current_controller.action(observation)\n",
150 | " reward = game.collect_reward(action)\n",
151 | " game.perform_action(action)\n",
152 | " new_observation = game.observe()\n",
153 | " current_controller.store(observation, action, reward, new_observation)\n",
154 | " current_controller.training_step()\n",
155 | " observation = new_observation\n",
156 | " game_iterations += 1\n",
157 | " performance = float(game_iterations - (game.shortest_path)) / game.shortest_path\n",
158 | " performances.append(performance)\n",
159 | " if game_idx % 100 == 0:\n",
160 | " print \"\\rGame %d: iterations before success %d.\" % (game_idx, game_iterations),\n",
161 | " print \"Pos: %s, Target: %s\" % (game.position, game.target),\n",
162 | "except KeyboardInterrupt:\n",
163 | " print \"Interrupted\""
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 11,
169 | "metadata": {
170 | "collapsed": false
171 | },
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "[]"
177 | ]
178 | },
179 | "execution_count": 11,
180 | "metadata": {},
181 | "output_type": "execute_result"
182 | },
183 | {
184 | "data": {
185 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEACAYAAABbMHZzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XuUVNWZ9/HvIyCKN0QUubRCBFQUBVEkxkiZMV4YhpmM\nGp1JIJoJuhwTL1HH6wyYiTHrXXFwjEYxGkOMUTJoFBW8UxqNooaLCCKiEAG5CrZyE7r7ef/YVVR1\ndXV3dXdVn7r8PmvVqlPn7D77qd3dzzm1z6m9zd0REZHKslvUAYiISPtT8hcRqUBK/iIiFUjJX0Sk\nAin5i4hUICV/EZEK1DGXQma2HPgcqAV2uvvwjO0x4Ango8SqR939p/kLU0RE8imn5A84EHP3jU2U\nedndx+QhJhERKbCWdPtYG7eLiEiRyDX5O/CCmb1tZuMb2X6Smc03sxlmNih/IYqISL7l2u3zNXdf\nbWYHAs+b2WJ3/3Pa9jlAlbtvNbOzgMeBgfkOVkRE8sNaOraPmU0ANrv7bU2UWQYMS79GYGYaREhE\npBXcPe/d6s12+5hZFzPbJ7G8F3A6sCCjTA8zs8TycMJBpcHFYXfXw50JEyZEHkOxPNQWagu1RdOP\nQsml26cH8KdEbu8IPOTuz5nZxYmEPhk4B7jEzGqArcD5BYpXRETyoNnk7+7LgCFZ1k9OW74LuCu/\noYmISKHoG74RiMViUYdQNNQWKWqLFLVF4bX4gm+rKzLz9qpLRKRcmBkexQVfEREpP0r+IiIVSMlf\nRKQCKfmLiFQgJX8RkQqk5C8iUoGU/EVEKpCSv4hIBVLyFxGpQEr+IiIVKPLkv3p11BGIiFSeSJP/\n8uXQqxesXx9lFCIilSfS5L9uXXhesiTKKEREKk8kyf+aa2DLFqirC6/HZ5sSXkRECianIZ3NbDnw\nOVAL7HT34VnK3AGcRZjJ6wJ3n5ux3d2dTZugWzc48EDYsQOqq8P2rVuhY0fo1Kmtb0lEpHwUakjn\nXKZxBHAg5lnm5QUws1FAf3cfYGYnAncDI7KV3ZjYQ2Y/f5cu4Xn+fDjmmByjEhGRVmlJt09TR54x\nwBQAd58NdDWzHtkK9u/fdCXHHdeCiEREpFVyTf4OvGBmb5tZth763sCKtNcrgT6ZhXbubL6iu+/O\nMSIREWm1XLt9vubuq83sQOB5M1vs7n/OKJP5yaDBxYTdd5+4a/mWW2IMHRrjlVfg//4vnPHvuy9Y\n3nu2RERKRzweJx6PF7yeFs/ha2YTgM3uflvaunuAuLs/kni9GBjp7mvTynjyePD22zBsWMN9X3op\nDBoUnkVEJMI5fM2si5ntk1jeCzgdWJBRbDowLlFmBPBZeuJPuvlmcM+e+AH22AO2bWtR/CIi0grN\nnvmbWT/gT4mXHYGH3P1WM7sYwN0nJ8rdCZwJbAEudPc5Gfvx1audgw9uvK4TT4Q33wwHCBERKdyZ\nf4u7fVpdkZlXVzv77tt4md/+Fi68UMlfRCSpLJL/zp1OxyYuMW/bFu7337ABDjigXcISESlqkfX5\n51NTiR9gzz3D84oVTZcTEZG2iXxI52yGDoWRI1NDP4iISH4VZfIHeOUVeOedqKMQESlPRZf8p0xJ\nLX/5ZXRxiIiUs6JL/uPGpZaV/EVECqPokj9A377heceOSMMQESlbRZn8J0wIzzU10cYhIlKu2vU+\n/5bUZQYDB8L77xcwKBGRIhf1ZC7t7vDDYfDgqKMQESlPRX3mD2GeXw3zLCKVqiy+4dsSP/pReN5r\nrzC/r4iI5E/RJv+rrgrP27bB9OnRxiIiUm6Kttsn/ExqWSN9ikglqrhuH4Brr00tb94cXRwiIuWm\nqJP/rbfCs8+G5SefjDYWEZFyklPyN7MOZjbXzBqkYDOLmVl1YvtcM7spX8GZwemnw+jRuuNHRCSf\ncj3zvxxYRHIG9oZedvehicdP8xNayqBBYXpHERHJj1wmcO8DjALuAxo7/y7oefnxx8OkSbroKyKS\nL7mc+U8CrgHqGtnuwElmNt/MZpjZoLxFl/DNb4ZnTe4iIpIfTQ7vYGajgXXuPtfMYo0UmwNUuftW\nMzsLeBwYmK3gxIkTdy3HYjFiscZ2WV/XrtCvH6xfH5ZFRMpVPB4nHo8XvJ4m7/M3s58BY4EaYA9g\nX+BRdx/XxM8sA4a5+8aM9S2+z7/+z0OvXrBqVat3ISJSciK5z9/db3D3KnfvB5wPvJSZ+M2sh1m4\nF8fMhhMOKBuz7K7NOncuxF5FRCpPS+/zdwAzu9jMLk6sOwdYYGbzgNsJB4m8mzULevcuxJ5FRCpP\nUQ/vkG7JEhg1CpYuzWNQIiJFriKHd0jXqRPs3Bl1FCIi5aFkkr8ZfPxxGOVTRETapmSSf/Ksf/Xq\naOMQESkHJZP8998/PI9r9CZTERHJVckk/+7dw62e6vcXEWm7kkn+AL/7nZK/iEg+lFTyHzZMt3qK\niORDSSX/vn3hiy9g0aKoIxERKW0llfw7dIATToAFC6KORESktJVU8gcYOhQ2bYo6ChGR0lZyyX+f\nfULXj4iItF7JJf+994bNm6OOQkSktJVc8t9nH/j886ijEBEpbSWX/Hv2hE8+iToKEZHSVnLJ/5BD\nYMWKqKMQESltJZf8q6rC6J4iItJ6OSV/M+tgZnPN7MlGtt9hZh+Y2XwzG5rfEOvr1QvWrdMwDyIi\nbZHrmf/lwCIS0zimM7NRQH93HwBcBNydv/Aa6tQJevRQv7+ISFs0m/zNrA8wCrgPyDaV2BhgCoC7\nzwa6mlmPfAaZSV0/IiJtk8uZ/yTgGqCuke29gfRLsCuBPm2Mq0m9eunMX0SkLTo2tdHMRgPr3H2u\nmcWaKprxOutM7RMnTty1HIvFiMWa2mXj9t1XX/QSkfIUj8eJx+MFr8fcs+bpsNHsZ8BYoAbYA9gX\neNTdx6WVuQeIu/sjideLgZHuvjZjX95UXS1x2WXQv394FhEpZ2aGu2frcm+TJrt93P0Gd69y937A\n+cBL6Yk/YTowLhHkCOCzzMSfb927w9qC1iAiUt5aep+/A5jZxWZ2MYC7zwA+MrOlwGTg3/MbYkM9\nesCGDYWuRUSkfDXZ55/O3V8GXk4sT87Y9sM8x9Wk5OBur7wCp5zSnjWLiJSHkvuGL4TB3WbPhpEj\nIU+XEUREKkpJJv8DD4QPPwzLut9fRKTlSjL59+yZWl6zJro4RERKVcknf33ZS0Sk5Uoy+XfuDA8+\nGJa3bIk2FhGRUlSSyR9g9OjwvHhxtHGIiJSikk3+XbvC1VdraGcRkdYo2eQPcNhhsHBh1FGIiJSe\nkk7+778PTz8ddRQiIqWnpJP/GWeE59raaOMQESk1TY7qmdeK8jiqZ/39wrJl0Ldv3nctIhK5SEb1\nLAVdu0J1ddRRiIiUlpJP/oMGwRdfRB2FiEhpKfnkv2wZvP561FGIiJSWkk/+I0bAp59GHYWISGlp\nNvmb2R5mNtvM5pnZIjO7NUuZm
JlVm9ncxOOmwoTb0FFHQZcu7VWbiEh5aHYyF3ffbmanuvtWM+sI\nvGpmJ7v7qxlFX3b3MYUJs3GdO8O2be1dq4hIacup28fdtyYWdwc6ABuzFMv7rUi56NwZduyIomYR\nkdKVU/I3s93MbB6wFpjl7osyijhwkpnNN7MZZjYo34E2pnNn+PLL9qpNRKQ85HrmX+fuQ4A+wClm\nFssoMgeocvdjgV8Cj+c1yiYo+YuItFzOE7gDuHu1mT0NHA/E09Z/kbY808x+ZWbd3L1e99DEiRN3\nLcdiMWKxWOuiTqPkLyLlJB6PE4/HC15Ps8M7mFl3oMbdPzOzPYFngZvd/cW0Mj2Ade7uZjYc+KO7\n983YT0GGd5g2Df7wB3jssbzvWkQkcoUa3iGXM/+ewBQz243QTfSgu79oZhcDuPtk4BzgEjOrAbYC\n5+c70Mb06QMrV7ZXbSIi5aHkB3ZbuRKGD9dcviJSngp15l/yyb+mJvT7a1hnESlHGtWzER07Ql0d\nbN8edSQiIqWj5JN/ki74iojkriyS//e+p9s9RURaouT7/AGOPhqWLNEwDyJSftTn34QTToCdO6OO\nQkSkdJRF8r/0UhgyJOooRERKR1l0+6xbBz16hLt+LJKxRUVECkPdPk046KDwrBm9RERyUxbJH+CY\nY+Djj6OOQkSkNJRN8j/0UPjb36KOQkSkNJRN8j/kECV/EZFclU3yP/RQdfuIiOSqrJK/zvxFRHJT\nNslf3T4iIrkri/v8IXWvf00NdOhQsGpERNpVJPf5m9keZjbbzOaZ2SIzu7WRcneY2QdmNt/MhuY7\nyFwcdBAccACsWhVF7SIipaXJ5O/u24FT3X0IcAxwqpmdnF7GzEYB/d19AHARcHehgm3OkUfCsmVR\n1S4iUjqa7fN3962Jxd2BDsDGjCJjgCmJsrOBrokJ3dvdV74CH34YRc0iIqWl2eRvZruZ2TxgLTDL\n3RdlFOkNrEh7vRLok78QczdwILz/fhQ1i4iUllzO/OsS3T59gFPMLJalWObFiPa5ipzhkEN0r7+I\nSC465lrQ3avN7GngeCCetmkVUJX2uk9iXQMTJ07ctRyLxYjFYrlHmoMBA8KkLiIipSoejxOPxwte\nT5O3eppZd6DG3T8zsz2BZ4Gb3f3FtDKjgB+6+ygzGwHc7u4jsuyroLd6AmzcCH37QnW1hnYWkfJQ\nqFs9mzvz7wlMMbPdCF1ED7r7i2Z2MYC7T3b3GWY2ysyWAluAC/MdZK66dYMvvoBFi+Coo6KKQkSk\n+JXNl7xS9cC998L48QWvSkSk4Ap15l92yX/UKJg5E9rpbYmIFJSSf44WLw5f9qqthd3KZuQiEalU\nmsYxR0ccEYZ52LAh6khERIpX2SV/gJ49YfXqqKMQESleZZn8e/VS8hcRaUrZ9fmHusKzLvqKSKlT\nn38LHH101BGIiBS3skz+DzwAQyOZVUBEpDSUZfI/9FBYvjzqKEREildZJv/u3cN0jps2RR2JiEhx\nKsvkbwaHHQZLl0YdiYhIcSrL5A/QuTMMHx51FCIixalsk/9110FVVfPlREQqUdkm/6qq0PcvIiIN\nlW3y33132LEj6ihERIpT2Sb/zp3hyy+jjkJEpDg1m/zNrMrMZpnZQjN718wuy1ImZmbVZjY38bip\nMOHmTmf+IiKNy2UC953Ale4+z8z2Bv5qZs+7+3sZ5V529zH5D7F1unSBLVuijkJEpDg1e+bv7mvc\nfV5ieTPwHtArS9GimjL9wAPDmf+6dVFHIiJSfFrU529mfYGhwOyMTQ6cZGbzzWyGmQ3KT3itZwbH\nHgvz50cdiYhI8cml2weARJfPNODyxCeAdHOAKnffamZnAY8DAzP3MXHixF3LsViMWCzWipBz16MH\nVFcXtAoRkbyKx+PE4/GC15PTeP5m1gl4Cpjp7rfnUH4ZMMzdN6ata7fx/JN+8AM48UQYP75dqxUR\nyZvIxvM3MwPuBxY1lvjNrEeiHGY2nHBQ2ZitbHvq3RtWrow6ChGR4pNLt8/XgO8C75jZ3MS6G4BD\nANx9MnAOcImZ1QBbgfMLEGuL9esHL70UdRQiIsWn2eTv7q/SzCcEd78LuCtfQeVLv36wbFnUUYiI\nFJ+ynMM3aeXKMMbPmjXh4q+ISKnRHL6t0KdPeH7uuWjjEBEpNmWd/AHOPhs6dYo6ChGR4lL2yb+q\nClatijoKEZHiUvbJf//9NZeviEgmJX8RkQpU9sm/e3f49NOooxARKS5ln/z33RemTo06ChGR4lL2\nyb9///C8fXu0cYiIFJOyT/6HHw49e6rrR0QkXdknfwj9/prURUQkpSKSf00NTJkSdRQiIsWjrMf2\nSdUdniOqXkSk1TS2TxsUeMIwEZGSUxHJf9q08HzPPdHGISJSLCoi+R9wAFx6KSxfHnUkIiLFIZdp\nHKvMbJaZLTSzd83sskbK3WFmH5jZfDMbmv9Q22b0aHjhhaijEBEpDrmc+e8ErnT3o4ARwKVmdmR6\nATMbBfR39wHARcDdeY+0jQ47DP76V9iyJepIRESi12zyd/c17j4vsbwZeA/olVFsDDAlUWY20NXM\nimrurAEDwvPChdHGISJSDFrU529mfYGhwOyMTb2BFWmvVwJ92hJYIYwcCc88E3UUIiLRa3YC9yQz\n2xuYBlye+ATQoEjG6wZ31U+cOHHXciwWI9bO92COHw/Tp7drlSIiLRKPx4nH4wWvJ6cveZlZJ+Ap\nYKa7355l+z1A3N0fSbxeDIx097VpZSL7klfSs8/CbbdpTl8RKR2RfcnLzAy4H1iULfEnTAfGJcqP\nAD5LT/zFYq+94LXXoo5CRCR6zZ75m9nJwCvAO6S6cm4ADgFw98mJcncCZwJbgAvdfU7GfiI/86+u\nhq5d4bPPYL/9Ig1FRBqxfTvssUfUURSPQp35V8TYPvXjgEmT4Ioroo5EpDJ88UWYVGnrVthzz+xl\n3GG3tH6IurrUmFyVTmP75NGVV4Y/rN/+NupIRMrXlCmwc2dI/ADf/nbjZWfOrP/6qqsKF5cEFZf8\n0yd1ufDC6OIQKVfjx4eTqwsugEcegQ4d4M474amnGh9Z95NP4OSTU9snTWq3cCtWxSX/bt1C3//r\nr8PAgVFHI1Je3n4b7rsv9XrcODj66NSJ1qZN0K9f/TJjx4YDxquvhtfJLtnq6vaJuVJVXPKH8DF0\nwABYsgQ2bIg6GpHyMXZseB4yJNWHP38+dOkChx4KK1aEARZffDFse+IJ+P3vw3Jy9N3k14E0Cm9h\nVdwF36TkBaZZszTev0hbrF8P550X/pcg3FK9eTNs2xaS/rPPwumnN7yAe+ONMGMGzJ0b9tG9e2pb\nsuwDD4RPDZs3h/0Wo/feC/OEd+1amP3rbp8CGDcOTj1Vff8irbFxIxxySMPBEt98E044oWH5xu7e\nuffe0O2TbuvWhsm+yNLHLmYhlxRqqljd7VMA/fvD0qVRRyFSOmprQ7IzC/NkJBP/P/4jzJ
[remainder of base64 PNG data omitted: smoothed performance plot]",
186 | "text/plain": [
187 | ""
188 | ]
189 | },
190 | "metadata": {},
191 | "output_type": "display_data"
192 | }
193 | ],
194 | "source": [
195 | "N = 500\n",
196 | "smooth_performances = [float(sum(performances[i:i+N])) / N for i in range(0, len(performances) - N)]\n",
197 | "\n",
198 | "plt.plot(range(len(smooth_performances)), smooth_performances)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 19,
204 | "metadata": {
205 | "collapsed": false
206 | },
207 | "outputs": [
208 | {
209 | "data": {
210 | "text/plain": [
211 | "1.6869445151941282"
212 | ]
213 | },
214 | "execution_count": 19,
215 | "metadata": {},
216 | "output_type": "execute_result"
217 | }
218 | ],
219 | "source": [
220 | "np.average(performances[-1000:])"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 231,
226 | "metadata": {
227 | "collapsed": false
228 | },
229 | "outputs": [
230 | {
231 | "data": {
232 | "text/plain": [
233 | ""
234 | ]
235 | },
236 | "execution_count": 231,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | },
240 | {
241 | "data": {
242 | "image/png": "[base64 PNG data omitted: weight-matrix heatmap from plt.matshow]",
243 | "text/plain": [
244 | ""
245 | ]
246 | },
247 | "metadata": {},
248 | "output_type": "display_data"
249 | }
250 | ],
251 | "source": [
252 | "x = brain.layers[0].Ws[0].eval()\n",
253 | "import matplotlib.pyplot as plt\n",
254 | "%matplotlib inline\n",
255 | "plt.matshow(x)\n",
256 | "plt.colorbar()"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 138,
262 | "metadata": {
263 | "collapsed": false
264 | },
265 | "outputs": [
266 | {
267 | "data": {
268 | "text/plain": [
269 | "array([ 11.01352692, 11.28201485, 12.03692055, 12.26954937], dtype=float32)"
270 | ]
271 | },
272 | "execution_count": 138,
273 | "metadata": {},
274 | "output_type": "execute_result"
275 | }
276 | ],
277 | "source": [
278 | "brain.input_layer.b.eval()"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 88,
284 | "metadata": {
285 | "collapsed": false
286 | },
287 | "outputs": [
288 | {
289 | "data": {
290 | "text/plain": [
291 | "-2.0"
292 | ]
293 | },
294 | "execution_count": 88,
295 | "metadata": {},
296 | "output_type": "execute_result"
297 | }
298 | ],
299 | "source": [
300 | "game.collect_reward(0)"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": 7,
306 | "metadata": {
307 | "collapsed": false
308 | },
309 | "outputs": [],
310 | "source": [
311 | "x = tf.Variable(tf.zeros((5,5)))"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 8,
317 | "metadata": {
318 | "collapsed": false
319 | },
320 | "outputs": [
321 | {
322 | "data": {
323 | "text/plain": [
324 | ""
325 | ]
326 | },
327 | "execution_count": 8,
328 | "metadata": {},
329 | "output_type": "execute_result"
330 | }
331 | ],
332 | "source": [
333 | "tf.clip_by_norm(x, 5)"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {
340 | "collapsed": true
341 | },
342 | "outputs": [],
343 | "source": []
344 | }
345 | ],
346 | "metadata": {
347 | "kernelspec": {
348 | "display_name": "Python 2",
349 | "language": "python",
350 | "name": "python2"
351 | },
352 | "language_info": {
353 | "codemirror_mode": {
354 | "name": "ipython",
355 | "version": 2
356 | },
357 | "file_extension": ".py",
358 | "mimetype": "text/x-python",
359 | "name": "python",
360 | "nbconvert_exporter": "python",
361 | "pygments_lexer": "ipython2",
362 | "version": "2.7.8"
363 | }
364 | },
365 | "nbformat": 4,
366 | "nbformat_minor": 0
367 | }
368 |
--------------------------------------------------------------------------------
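
Note on /notebooks/game.ipynb above: the `smooth_performances` cell computes a length-N sliding-window
average of the per-game scores with a pure-Python list comprehension. A vectorized equivalent - just a
sketch, assuming NumPy is imported as `np` and `performances` and `N` are defined as in the notebook:

    # Moving average over windows of N consecutive scores. mode='valid' yields
    # len(performances) - N + 1 windows, so the last one is dropped to match the
    # notebook's range(0, len(performances) - N) exactly.
    smooth_performances = np.convolve(performances, np.ones(N) / N, mode='valid')[:-1]
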
/notebooks/game_memory.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import tempfile\n",
13 | "import tensorflow as tf\n",
14 | "\n",
15 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n",
16 | "from tf_rl.simulation import KarpathyGame\n",
17 | "from tf_rl import simulate\n",
18 | "from tf_rl.models import MLP"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 20,
24 | "metadata": {
25 | "collapsed": false
26 | },
27 | "outputs": [
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "/tmp/tmp3rSEdd\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "LOG_DIR = tempfile.mkdtemp()\n",
38 | "print(LOG_DIR)"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 316,
44 | "metadata": {
45 | "collapsed": false
46 | },
47 | "outputs": [],
48 | "source": [
49 | "from random import randint, gauss\n",
50 | "\n",
51 | "import numpy as np\n",
52 | "\n",
53 | "class DiscreteHill(object):\n",
54 | " \n",
55 | " directions = [(0,1), (0,-1), (1,0), (-1,0)]\n",
56 | " \n",
57 | " def __init__(self, board=(10,10), variance=4.):\n",
58 | " self.variance = variance\n",
59 | " self.target = (0, 0)\n",
60 | " while self.target == (0, 0):\n",
61 | " self.target = (randint(-board[0], board[0]), randint(-board[1], board[1]))\n",
62 | " self.position = (0, 0)\n",
63 | " \n",
64 | " @staticmethod\n",
65 | " def add(p, q):\n",
66 | " return (p[0] + q[0], p[1] + q[1])\n",
67 | " \n",
68 | " @staticmethod\n",
69 | " def distance(p, q):\n",
70 | " return abs(p[0] - q[0]) + abs(p[1] - q[1])\n",
71 | " \n",
72 | " def estimate_distance(self, p):  # change in distance to target if we moved to p, plus non-negative noise\n",
73 | " distance = DiscreteHill.distance(self.target, p) - DiscreteHill.distance(self.target, self.position)\n",
74 | " return distance + abs(gauss(0, self.variance))\n",
75 | " \n",
76 | " def observe(self):  # noisy distance estimates for the four neighboring cells\n",
77 | " return np.array([self.estimate_distance(DiscreteHill.add(self.position, delta)) \n",
78 | " for delta in DiscreteHill.directions])\n",
79 | " \n",
80 | " def perform_action(self, action):\n",
81 | " self.position = DiscreteHill.add(self.position, DiscreteHill.directions[action])\n",
82 | " \n",
83 | " def is_over(self):\n",
84 | " return self.position == self.target\n",
85 | " \n",
86 | " def collect_reward(self, action):  # decrease in distance to target, minus a step cost of 2\n",
87 | " return -DiscreteHill.distance(self.target, DiscreteHill.add(self.position, DiscreteHill.directions[action])) \\\n",
88 | " + DiscreteHill.distance(self.target, self.position) - 2"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 329,
94 | "metadata": {
95 | "collapsed": false
96 | },
97 | "outputs": [
98 | {
99 | "name": "stderr",
100 | "output_type": "stream",
101 | "text": [
102 | "Exception AssertionError: AssertionError() in > ignored\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "n_prev_frames = 3\n",
108 | "\n",
109 | "# Tensorflow business - it is always good to reset the graph before creating a new controller.\n",
110 | "tf.reset_default_graph()\n",
111 | "session = tf.InteractiveSession()\n",
112 | "\n",
113 | "# This little guy will let us run tensorboard\n",
114 | "# tensorboard --logdir [LOG_DIR]\n",
115 | "journalist = tf.train.SummaryWriter(LOG_DIR)\n",
116 | "\n",
117 | "# Brain maps from observation to Q values for different actions.\n",
118 | "# Here it is done using a single linear layer (this MLP has\n",
119 | "# no hidden layers).\n",
120 | "brain = MLP([n_prev_frames * 4 + n_prev_frames - 1,], [4], \n",
121 | " [tf.identity])\n",
122 | "\n",
123 | "# The optimizer to use. Here we use RMSProp as recommended\n",
124 | "# by the publication\n",
125 | "optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)\n",
126 | "\n",
127 | "# DiscreteDeepQ object\n",
128 | "current_controller = DiscreteDeepQ(n_prev_frames * 4 + n_prev_frames - 1, 4, brain, optimizer, session,\n",
129 | " discount_rate=0.9, exploration_period=100, max_experience=10000, \n",
130 | " store_every_nth=1, train_every_nth=4, target_network_update_rate=0.1,\n",
131 | " summary_writer=journalist)\n",
132 | "\n",
133 | "session.run(tf.initialize_all_variables())\n",
134 | "session.run(current_controller.target_network_update)\n",
135 | "# graph was not available when journalist was created \n",
136 | "journalist.add_graph(session.graph_def)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 330,
142 | "metadata": {
143 | "collapsed": false
144 | },
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "Game 9900: iterations before success 16. Pos: (6, 6), Target: (6, 6)\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "performances = []\n",
156 | "\n",
157 | "try:\n",
158 | " for game_idx in range(10000):\n",
159 | " game = DiscreteHill()\n",
160 | " game_iterations = 0\n",
161 | "\n",
162 | " observation = game.observe()\n",
163 | " \n",
164 | " prev_frames = [(observation, -1)] * (n_prev_frames - 1)\n",
165 | " memory = np.concatenate([np.concatenate([observation, np.array([-1])])] * (n_prev_frames - 1) + [observation])\n",
166 | " \n",
167 | " while game_iterations < 50 and not game.is_over():\n",
168 | " action = current_controller.action(memory)\n",
169 | " if n_prev_frames > 1:\n",
170 | " prev_frames = prev_frames[1:] + [(observation, action)]\n",
171 | " reward = game.collect_reward(action)\n",
172 | " game.perform_action(action)\n",
173 | " observation = game.observe()\n",
174 | " new_memory = np.concatenate([np.concatenate([a, np.array([b])]) for (a, b) in prev_frames] + [observation])\n",
175 | " current_controller.store(memory, action, reward, new_memory)\n",
176 | " current_controller.training_step()\n",
177 | " memory = new_memory\n",
178 | " game_iterations += 1\n",
179 | " cost = abs(game.target[0]) + abs(game.target[1])\n",
180 | " performances.append((game_iterations - cost) / float(cost))\n",
181 | " if game_idx % 100 == 0:\n",
182 | " print \"\\rGame %d: iterations before success %d.\" % (game_idx, game_iterations),\n",
183 | " print \"Pos: %s, Target: %s\" % (game.position, game.target),\n",
184 | "except KeyboardInterrupt:\n",
185 | " print \"Interrupted\""
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 327,
191 | "metadata": {
192 | "collapsed": false
193 | },
194 | "outputs": [
195 | {
196 | "data": {
197 | "text/plain": [
198 | "[]"
199 | ]
200 | },
201 | "execution_count": 327,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | },
205 | {
206 | "data": {
207 | "image/png": "[base64 PNG data omitted: smoothed performance plot]",
208 | "text/plain": [
209 | ""
210 | ]
211 | },
212 | "metadata": {},
213 | "output_type": "display_data"
214 | }
215 | ],
216 | "source": [
217 | "N = 500\n",
218 | "smooth_performances = [float(sum(performances[i:i+N])) / N for i in range(0, len(performances) - N)]\n",
219 | "\n",
220 | "plt.plot(range(len(smooth_performances)), smooth_performances)"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 231,
226 | "metadata": {
227 | "collapsed": false
228 | },
229 | "outputs": [
230 | {
231 | "data": {
232 | "text/plain": [
233 | ""
234 | ]
235 | },
236 | "execution_count": 231,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | },
240 | {
241 | "data": {
242 | "image/png": "[base64 PNG data omitted: weight-matrix heatmap from plt.matshow]",
243 | "text/plain": [
244 | ""
245 | ]
246 | },
247 | "metadata": {},
248 | "output_type": "display_data"
249 | }
250 | ],
251 | "source": [
252 | "x = brain.layers[0].Ws[0].eval()\n",
253 | "import matplotlib.pyplot as plt\n",
254 | "%matplotlib inline\n",
255 | "plt.matshow(x)\n",
256 | "plt.colorbar()"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 138,
262 | "metadata": {
263 | "collapsed": false
264 | },
265 | "outputs": [
266 | {
267 | "data": {
268 | "text/plain": [
269 | "array([ 11.01352692, 11.28201485, 12.03692055, 12.26954937], dtype=float32)"
270 | ]
271 | },
272 | "execution_count": 138,
273 | "metadata": {},
274 | "output_type": "execute_result"
275 | }
276 | ],
277 | "source": [
278 | "brain.input_layer.b.eval()"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 88,
284 | "metadata": {
285 | "collapsed": false
286 | },
287 | "outputs": [
288 | {
289 | "data": {
290 | "text/plain": [
291 | "-2.0"
292 | ]
293 | },
294 | "execution_count": 88,
295 | "metadata": {},
296 | "output_type": "execute_result"
297 | }
298 | ],
299 | "source": [
300 | "game.collect_reward(0)"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": 269,
306 | "metadata": {
307 | "collapsed": false
308 | },
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/plain": [
313 | "array([ 1.39407934, 2.791605 , 0.92194436, 1.73143704, -1. ])"
314 | ]
315 | },
316 | "execution_count": 269,
317 | "metadata": {},
318 | "output_type": "execute_result"
319 | }
320 | ],
321 | "source": [
322 | "np.concatenate([observation, np.array([-1])])"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 278,
328 | "metadata": {
329 | "collapsed": false
330 | },
331 | "outputs": [
332 | {
333 | "data": {
334 | "text/plain": [
335 | "1"
336 | ]
337 | },
338 | "execution_count": 278,
339 | "metadata": {},
340 | "output_type": "execute_result"
341 | }
342 | ],
343 | "source": [
344 | "n_prev_frames"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": 285,
350 | "metadata": {
351 | "collapsed": false
352 | },
353 | "outputs": [
354 | {
355 | "data": {
356 | "text/plain": [
357 | "1"
358 | ]
359 | },
360 | "execution_count": 285,
361 | "metadata": {},
362 | "output_type": "execute_result"
363 | }
364 | ],
365 | "source": [
366 | "action"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": 294,
372 | "metadata": {
373 | "collapsed": false
374 | },
375 | "outputs": [
376 | {
377 | "data": {
378 | "text/plain": [
379 | "[]"
380 | ]
381 | },
382 | "execution_count": 294,
383 | "metadata": {},
384 | "output_type": "execute_result"
385 | }
386 | ],
387 | "source": [
388 | "performances[:10]"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 306,
394 | "metadata": {
395 | "collapsed": true
396 | },
397 | "outputs": [],
398 | "source": [
399 | "performances_1 = performances[:]"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": 328,
405 | "metadata": {
406 | "collapsed": false
407 | },
408 | "outputs": [
409 | {
410 | "data": {
411 | "text/plain": [
412 | "0.89357223079637949"
413 | ]
414 | },
415 | "execution_count": 328,
416 | "metadata": {},
417 | "output_type": "execute_result"
418 | }
419 | ],
420 | "source": [
421 | "np.average(performances[-1000:])"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": null,
427 | "metadata": {
428 | "collapsed": true
429 | },
430 | "outputs": [],
431 | "source": [
432 | "np"
433 | ]
434 | }
435 | ],
436 | "metadata": {
437 | "kernelspec": {
438 | "display_name": "Python 2",
439 | "language": "python",
440 | "name": "python2"
441 | },
442 | "language_info": {
443 | "codemirror_mode": {
444 | "name": "ipython",
445 | "version": 2
446 | },
447 | "file_extension": ".py",
448 | "mimetype": "text/x-python",
449 | "name": "python",
450 | "nbconvert_exporter": "python",
451 | "pygments_lexer": "ipython2",
452 | "version": "2.7.6"
453 | }
454 | },
455 | "nbformat": 4,
456 | "nbformat_minor": 0
457 | }
458 |
--------------------------------------------------------------------------------
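
Note on /notebooks/game_memory.ipynb above: DiscreteHill.collect_reward is a shaped reward - each step
earns the decrease in Manhattan distance to the target minus a constant step cost of 2, so shorter paths
collect strictly more reward. The training loop then widens the controller input by stacking the last
n_prev_frames observations together with the actions taken between them, for
n_prev_frames * 4 + (n_prev_frames - 1) features in total. A minimal sketch of that memory layout
(hypothetical standalone helper, mirroring the np.concatenate expression in the notebook):

    import numpy as np

    def build_memory(prev_frames, observation):
        # prev_frames: (observation, action) pairs for the n_prev_frames - 1 earlier
        # steps; each observation holds 4 noisy distance estimates, one per direction.
        # The result is laid out as [obs_1, act_1, ..., obs_{n-1}, act_{n-1}, obs_n].
        parts = [np.concatenate([obs, np.array([act])]) for (obs, act) in prev_frames]
        return np.concatenate(parts + [observation])
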
/notebooks/karpathy_game.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%load_ext autoreload\n",
12 | "%autoreload 2\n",
13 | "%matplotlib inline"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {
20 | "collapsed": false
21 | },
22 | "outputs": [],
23 | "source": [
24 | "from __future__ import print_function\n",
25 | "\n",
26 | "import numpy as np\n",
27 | "import tempfile\n",
28 | "import tensorflow as tf\n",
29 | "\n",
30 | "from tf_rl.controller import DiscreteDeepQ, HumanController\n",
31 | "from tf_rl.simulation import KarpathyGame\n",
32 | "from tf_rl import simulate\n",
33 | "from tf_rl.models import MLP\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {
40 | "collapsed": false
41 | },
42 | "outputs": [
43 | {
44 | "name": "stdout",
45 | "output_type": "stream",
46 | "text": [
47 | "/tmp/tmp7dz3utz0\n"
48 | ]
49 | }
50 | ],
51 | "source": [
52 | "LOG_DIR = tempfile.mkdtemp()\n",
53 | "print(LOG_DIR)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 4,
59 | "metadata": {
60 | "collapsed": false
61 | },
62 | "outputs": [],
63 | "source": [
64 | "current_settings = {\n",
65 | " 'objects': [\n",
66 | " 'friend',\n",
67 | " 'enemy',\n",
68 | " ],\n",
69 | " 'colors': {\n",
70 | " 'hero': 'yellow',\n",
71 | " 'friend': 'green',\n",
72 | " 'enemy': 'red',\n",
73 | " },\n",
74 | " 'object_reward': {\n",
75 | " 'friend': 0.1,\n",
76 | " 'enemy': -0.1,\n",
77 | " },\n",
78 | " 'hero_bounces_off_walls': False,\n",
79 | " 'world_size': (700,500),\n",
80 | " 'hero_initial_position': [400, 300],\n",
81 | " 'hero_initial_speed': [0, 0],\n",
82 | " \"maximum_speed\": [50, 50],\n",
83 | " \"object_radius\": 10.0,\n",
84 | " \"num_objects\": {\n",
85 | " \"friend\" : 25,\n",
86 | " \"enemy\" : 25,\n",
87 | " },\n",
88 | " \"num_observation_lines\" : 32,\n",
89 | " \"observation_line_length\": 120.,\n",
90 | " \"tolerable_distance_to_wall\": 50,\n",
91 | " \"wall_distance_penalty\": -0.0,\n",
92 | " \"delta_v\": 50\n",
93 | "}"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 5,
99 | "metadata": {
100 | "collapsed": false
101 | },
102 | "outputs": [],
103 | "source": [
104 | "# create the game simulator\n",
105 | "g = KarpathyGame(current_settings)"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "metadata": {
112 | "collapsed": false
113 | },
114 | "outputs": [],
115 | "source": [
116 | "human_control = False\n",
117 | "\n",
118 | "if human_control:\n",
119 | " # WSAD CONTROL (requires extra setup - check out README)\n",
120 | " current_controller = HumanController({b\"w\": 3, b\"d\": 0, b\"s\": 1,b\"a\": 2,}) \n",
121 | "else:\n",
122 | " # Tensorflow business - it is always good to reset the graph before creating a new controller.\n",
123 | " tf.reset_default_graph()\n",
124 | " session = tf.InteractiveSession()\n",
125 | "\n",
126 | " # This little guy will let us run tensorboard\n",
127 | " # tensorboard --logdir [LOG_DIR]\n",
128 | " journalist = tf.train.SummaryWriter(LOG_DIR)\n",
129 | "\n",
130 | " # Brain maps from observation to Q values for different actions.\n",
131 | " # Here it is done using a multi-layer perceptron with 2 hidden\n",
132 | " # layers\n",
133 | " brain = MLP([g.observation_size,], [200, 200, g.num_actions], \n",
134 | " [tf.tanh, tf.tanh, tf.identity])\n",
135 | " \n",
136 | " # The optimizer to use. Here we use RMSProp as recommended\n",
137 | " # by the publication\n",
138 | " optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)\n",
139 | "\n",
140 | " # DiscreteDeepQ object\n",
141 | " current_controller = DiscreteDeepQ((g.observation_size,), g.num_actions, brain, optimizer, session,\n",
142 | " discount_rate=0.99, exploration_period=5000, max_experience=10000, \n",
143 | " store_every_nth=4, train_every_nth=4,\n",
144 | " summary_writer=journalist)\n",
145 | " \n",
146 | " session.run(tf.initialize_all_variables())\n",
147 | " session.run(current_controller.target_network_update)\n",
148 | " # graph was not available when journalist was created \n",
149 | " journalist.add_graph(session.graph)"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 7,
155 | "metadata": {
156 | "collapsed": false,
157 | "scrolled": false
158 | },
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/html": [
163 | "\n",
164 | "\n",
165 | "\n"
469 | ],
470 | "text/plain": [
471 | ""
472 | ]
473 | },
474 | "metadata": {},
475 | "output_type": "display_data"
476 | },
477 | {
478 | "name": "stdout",
479 | "output_type": "stream",
480 | "text": [
481 | "Interrupted\n"
482 | ]
483 | }
484 | ],
485 | "source": [
486 | "FPS = 30\n",
487 | "ACTION_EVERY = 3\n",
488 | " \n",
489 | "fast_mode = True\n",
490 | "if fast_mode:\n",
491 | " WAIT, VISUALIZE_EVERY = False, 50\n",
492 | "else:\n",
493 | " WAIT, VISUALIZE_EVERY = True, 1\n",
494 | "\n",
495 | " \n",
496 | "try:\n",
497 | " with tf.device(\"/cpu:0\"):\n",
498 | " simulate(simulation=g,\n",
499 | " controller=current_controller,\n",
500 | " fps=FPS,\n",
501 | " visualize_every=VISUALIZE_EVERY,\n",
502 | " action_every=ACTION_EVERY,\n",
503 | " wait=WAIT,\n",
504 | " disable_training=False,\n",
505 | " simulation_resolution=0.001,\n",
506 | " save_path=None)\n",
507 | "except KeyboardInterrupt:\n",
508 | " print(\"Interrupted\")"
509 | ]
510 | },
511 | {
512 | "cell_type": "code",
513 | "execution_count": 8,
514 | "metadata": {
515 | "collapsed": true
516 | },
517 | "outputs": [],
518 | "source": [
519 | "session.run(current_controller.target_network_update)"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 9,
525 | "metadata": {
526 | "collapsed": false
527 | },
528 | "outputs": [
529 | {
530 | "data": {
531 | "text/plain": [
532 | "array([[ 0.00824914, -0.04792793, 0.08260646, ..., -0.00135659,\n",
533 | " -0.0149605 , -0.00065048],\n",
534 | " [ 0.04895933, -0.01720949, 0.03015076, ..., 0.04350275,\n",
535 | " -0.00071916, -0.00507376],\n",
536 | " [-0.03408033, -0.00734746, -0.07286905, ..., 0.06636748,\n",
537 | " 0.0507561 , 0.04723936],\n",
538 | " ..., \n",
539 | " [-0.01454929, -0.00313209, 0.02152171, ..., 0.01723659,\n",
540 | " -0.01757577, -0.02262094],\n",
541 | " [ 0.03226471, -0.09545884, 0.01721121, ..., 0.0179732 ,\n",
542 | " -0.01188065, 0.03430547],\n",
543 | " [ 0.02971489, 0.06272104, -0.05087573, ..., -0.00265156,\n",
544 | " -0.00139228, 0.01183042]], dtype=float32)"
545 | ]
546 | },
547 | "execution_count": 9,
548 | "metadata": {},
549 | "output_type": "execute_result"
550 | }
551 | ],
552 | "source": [
553 | "current_controller.q_network.input_layer.Ws[0].eval()"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": 10,
559 | "metadata": {
560 | "collapsed": false
561 | },
562 | "outputs": [
563 | {
564 | "data": {
565 | "text/plain": [
566 | "array([[ 0.00792158, -0.04618821, 0.08227044, ..., -0.00076906,\n",
567 | " -0.01525626, -0.00064818],\n",
568 | " [ 0.04957703, -0.01709567, 0.03077196, ..., 0.04223152,\n",
569 | " -0.00179744, -0.00452804],\n",
570 | " [-0.03365123, -0.00761478, -0.07296818, ..., 0.0649555 ,\n",
571 | " 0.0494989 , 0.04785381],\n",
572 | " ..., \n",
573 | " [-0.01843384, -0.00160771, 0.02573192, ..., 0.01739593,\n",
574 | " -0.01991378, -0.01864818],\n",
575 | " [ 0.03407663, -0.09417476, 0.01543032, ..., 0.01674915,\n",
576 | " -0.00885472, 0.03387342],\n",
577 | " [ 0.02986682, 0.06386744, -0.05153623, ..., -0.00791921,\n",
578 | " -0.00157495, 0.01063347]], dtype=float32)"
579 | ]
580 | },
581 | "execution_count": 10,
582 | "metadata": {},
583 | "output_type": "execute_result"
584 | }
585 | ],
586 | "source": [
587 | "current_controller.target_q_network.input_layer.Ws[0].eval()"
588 | ]
589 | },
590 | {
591 | "cell_type": "markdown",
592 | "metadata": {},
593 | "source": [
594 | "# Average Reward over time"
595 | ]
596 | },
597 | {
598 | "cell_type": "code",
599 | "execution_count": 11,
600 | "metadata": {
601 | "collapsed": false
602 | },
603 | "outputs": [
604 | {
605 | "data": {
606 | "image/png": "[base64 PNG data omitted: average reward over time plot]",
wfbzIWpuuumIhKxBI2pgup9+8JnPxt/3mHDLG9XV/Tq5U2gBZ7H\nIzEpHeSZCCHi6E6X4ETq682byBW77RY/W+OOHbb2N335ry0xKTwSEyFEHKUoJomeiRvX2bgx+bUl\nJoVHzVxCVBC7dnlv9anIlZhkMrFWOhLFxBWLVGKyaZMmySo0EhMhKojvftfiIV2RCzE54ACYkMPc\n3qnE5OijO5ft2xduuMFiPqJwqJlLiAritddg/fquy+RCTL7znezqJ5IYM2lvhyuugDPP7FzWHdvy\n/PO5tUF0jTwTISqIWIDUrLkQk1yTzDNJZWMum9dEcCQmQog4wi4mmcxbL3KHxESICsL1TLryUEpR\nTJI1c8kzKS0kJkJUCMuWwSuv2PasWanLlaKYuJ7Jzp0wfTo8/nh6z2THDpgzp3A2VjoSEyEqhGee\nsUmvzjgDZsxIXa6UxeSDD+BrX4NVq+D445OX7d3bUsFMnAiPPlpQMysa9eYSokKIRuHss+Gww+Cu\nxLzePkpZTDY5c6pOmADDu5jk+3Ofs8m55s8vjH1CnokQFUM0avGESCT5/OwupSgmbsykK7sTSXef\nIrdITISoEMIsJq5n4todpIuzxKSwSEyEqBASxSTVAzkMYhIEiUlhkZgIUSG4YlJfD1VVqZMhfvxx\n5pNi5ZtBg2DaNFi50vaPPDJ9nUjEJtEShUFiIkSF0NrqjcHo6q197VpoaCicXUG4+GLLuTV3Lvzu\nd3DrrenrDBwIbW3pE1uK3CAxEaIC2LUrXiS6EhPXgyklqqpgjz3gvfeC21ZdDQMGwLp1+bVNGBIT\nISqAtjZruqqttf1UYrJ5s/Wa2n33wtoXhEgEVq/OTOgUNykcEhMhKoBEbyNVPMFtCquqKpxtQfE3\n0WVSR2JSGCQmQlQAiWLS2GgpVZYv77pcKeGmSckkniMxKRwSEyEqgPXrLX7gcvLJNjr8zjvjy5Wy\nmJx+Olx2mXUTDorEpHAonYoQFUB7e3x338MPh4sugjfeiC9XymJyzDG2ZILEpHDIMxGiAkg2EDES\ngTVr4o+Vsph0B4lJ4ZCYCFEBJBOTwYM7P2glJqK7SEyEqABSeSaJD1r/wMZyQGJSOCQmQlQAqTyT\nZctg0iR48UU7Fo2W17S3EpPCITERogJIJiZ9+8Ls2TBqFLz1lh376CPo16/w9uULiUnhkJgIUQG0\nt9sMhInstx984hNe0sdSzBicDX37wvbtNrJf5BeJiRAVQFciUV9fvmJSVWXNdsoenH8kJkJUAJUq\nJqCmrkIhMRGiDPnoI5tMykViYttr1wabpVFkjsREiDKkf3847zxvP6iYdHSUt5g0NsIDDxTXnnIl\nF2IyGZgHLACuSlHmF87nbwMTAtT9AvA+sBM4KOFc1zjl5wHHZ2m7EGVJLAYLFnj7QcRk2zbbd9PU\nlwuJzVyrVxfPlnImWzGpBm7HRGEc8CVg34QyJwF7AnsBFwC/ClD3XeAM4G8J5xoHnO2sJwN35OAe\nhChL/M05QcSkHJu4QDGTQpHtg3gisBBoAbYDDwGnJZQ5FZjubL8K9AeGpKk7D/ggyfVOA2Y45Vuc\n+hOzvAchyh6JSbGtKH+yFZPhwDLf/nLnWJAywwLUTWSYUy6TOkJUPBKTYltR/mQrJkH7ReRz3jb1\nzRCiC267zQLryQYtgicmM2YU1q5C4YqJ2+yn3lz5Idv5TFYAI337I4n3HJKVGeGUqQlQN931RjjH\nOjF16tR/bTc1NdHU1JTm1EKUJ5dcYuseKV4dBwywybOeey6+B1i54IqJOwpeo+E9mpubaW5uLrYZ\ngInRImA0UAu8RfIA/ExnexLwSgZ1XwAO9u2Pc8rVAmOc+sm8npgQlQzEYgcd5G139S+xc2cs1rNn\nLLbXXrHYO+8Uxr5CsnlzLFZTE4utWWPfw5VXFtui0oUsWnqy9Ux2AFOAp7HeWfcBc4ELnc/vwoTk\nJCxY3g6cl6YuWE+uXwANwF+AN4ETgTnA/zrrHcC3UTOXEFnRo4fNq75gQXmln3fZbTeoq4OVK23f\nHVMjcksupu19yln83JWwPyWDugCPO0syfuIsQogk7Npl60xiA4MG2fiLQYPyY1OxiUSgpcW2JSb5\nQWM0hCgDOjpMRLZts/gHWEqVoFQ5jcU9c/F6WYI0NsIHzmCDbMQkFvMGd4p4JCZChJyHH7YeWVdf\nbc1Ve+wBQ4fCokXwxBPBzjFxoqWiL1f22w+mTrWU9Js2df88P/0p9OqlHmHJkJgIEXJWr7Zuv3Pm\nWCqUjg6LD3zjG7BmjTVdpUvBft99Jj7lyt132/fy3HOW7LG7uE1lmXh9lYLERIiQ097uxQT8AXR3\n/MiWLfY2LbIfwOjW1SDIzkhMhAg5HR32kFy8OLmYbN0qMXFpbIwfwJgpEpPUSEyECDmuZ+KKikt9\nvcUHdu6Empri2VdK9O5t30V34ybRKIwbJzFJhsREiBAzaxZs2OCJSKKYbNhgXklVPhMahYxIxIsh\nffAB3HSTxZt++UtYssQrt24d/PznFkt64w07Fo1aMF9i0pky7QgoRGVwvDOjz9VX2zpRTNavVxNX\nIgMGmMgC3H8/3HwzvPkmPPSQNQlecYV99re/wWWXwQ9+YF5fR4d9Pno0tLUVzfySRZ6JEGVAKs9k\n3TobAS48/DNLRqMwfjy8/77t+3u9+WefdD+LROLrCw+JiRAhxZ+wMJWYyDPpjF8MWlut2eq992zf\n33yVKBjRqMSkKyQmQoQU/1t0V2IizySeZJ5JLAb9+0tMskFiIkRI8T/4GhpsnayZS55JPL17x4vJ\nfvvZdmJgvb0d+vTx9leulJh0hQLwQoSMhQvhhBPg44+9Y4MGWQqVIUO8YwMH2oDFfv0Kb2MpU19v\n87YsXGjisf/+dvyQQ6z31qhRtr90KRx0EMyebfvf/CZ8//sSk1TIMxEiZHzwgeXfeu01S+uxapXt\nf/hhvHDsvTesWAEzZ6Y+VyXiTk38179acH34cEs7c+ONsGyZ9eK65x4rc8459v3uuaft33CDxCQV\n8kyECBnRqImH+wbdt6+tk8VGhg0rnF1hwR3AuWqVjYivqvKaB13Pzv0u+/SxY+44nV69JCapkGci\nRMhwA8Gie7hdfd0YSDLceV3q6mztT78iMUmOxESIkCExyQ6/EKT6Ht15XbZs6fyZxCQ5EhMhQobE\nJDv8QuDGT4KU9deRmHRGMRMhQobEJDsuuQT23deC7ocdlrrclClw8sm2/fOfe3m7JCbJKdf0b7GY\npkITZcrBB9tkTwcfXGxLKpNt20xQtm8vtiW5p8p6GnRLF9TMJUTIiEatF5IoDrW1ttZc8PFITIQI\nEbGYxKQUUFNXZyQmQoSIjz6ysQ5ul1VRHCQmnZGYCFHCxGJw3HG2/t3v4NRTYfDgYlsl6utttPwp\np8BJJ1lqlkpHYiJECfPRR/DsszYZ05//DIcfDo8/XmyrRH29zb744Yfmobj5uyoZiYkQJY
ybxTYa\nteXYY70st6J41NfD4sWWs0vT+BoSEyFKmEQx0fiS0sAVk0jEFomJxESIksZ9SK1ZY4vEpDSQmHRG\nYiJECePOpvjuuxY/cRMQiuKSKCarV9v8J6nGSm/ebPPP7Npl++3t5dcbTGIiRAmzbp313vr97+Go\no6CH/mNLggkTbKzPpz5lYvLHP9qUAO++m7x8v342VcCNN9p+nz62vPFG4WzON/ppClHCbNkC3/qW\ndT195pliWyNcrrgCFiyApqb4AaQrViQvv2OHrefNiz++YUNezCsKEhMhSpgtW5JPeiVKB38cy22W\n7IqdO71tNzVLOSAxEaKE2brVRryL0qV/f287SCB+3Tpv2531sRyQmAhRwkhMSh83jjVgQHoxaW+P\nL+P3UsKOxESEhrY2uP9+WL4crroKNm0qtkX54Te/sZ4/YM1cEpNwMHx4ejH55z/hhhu8/XJKYy8x\nEaHhoYfgq1+Fv/0Nbropdc+ZsHP++TBzpm1v3aqYSRh48UX40Y9Si8m4cfDjH8OVV8Khh0JDgx2X\nmMQzGZgHLACuSlHmF87nbwMTAtQdCMwCPgCeAdxWydHAZuBNZ7kjB/aLkOBOseofFV6uuPeqZq5w\ncMQRMHp06t/kzp1w5plw8cVw+eVw5JF2vJzmRMlWTKqB2zFRGAd8Cdg3ocxJwJ7AXsAFwK8C1L0a\nE5O9geecfZeFmCBNAL6dpf0iRLgP1ZUrbV2OYuIOZKty5rpTM1d46GokfKq/ozwTj4nYw70F2A48\nBJyWUOZUYLqz/SrmZQxJU9dfZzpwepZ2ijJg82Zbz51rffvLUUzce3JFRc1c4cH9TSYbBZ/q7yjP\nxGM4sMy3v9w5FqTMsC7qDgbWONtrnH2XMVgTVzNwePdNF2HDfcC+9175ZmpNJibyTMJB797W1TdZ\nx5BUf0d5Jh4pMtF0IsgE9VUpzhfzHV8JjMSauK4AHgT6BrRBlDCf/az9I9bUWOqJ9eu9zw47DKZO\n9R6wS5bAMcfAbbd5TV7lgjvo7Xe/s+/i73+XmISJsWNh4EDvtzxihHkqiWLymc/Yupw8k55Z1l+B\nPdxdRmIeRldlRjhlapIcd5MRrMGawlYDQwH3HXSbswDMBhZhsZhOU9NMnTr1X9tNTU00NTUFuiFR\nHObOtd5ZY8fCgQdaWoqBA+2zl1+2tBNnnQU//CFcfz307GmTRC1dCsOGFdf2XOJ2CZ49G4YOtdn8\nJCbhYfbs+LEjAweap5IYM7n0Upgzp/ieSXNzM83NzTk5V7Zi8jr2MB+NeQ1nY4F0P08CU7CYyCSg\nDROLdV3UfRI4F/iZs37COd4AbAB2Ap9w6n+YzDC/mIjSJxq1fvo1NZbYMLEJy82y2tBgQgL2sC23\npi7X+9q40TyyZcsUMwkTPXrEJ+OMRMx7rqryfrcuNTXF90wSX7SnTZvW7XNl28y1AxOKp4E5wMPA\nXOBCZwGYiT3wFwJ34fXASlUX4EbgOKxr8NHOPsARWPfiN4FHnGu0ZXkPosh0dNjbXJ8+tp+sV4wr\nJm6X2VTlwo4/Lbk7o6I8k/ASiaT2Lmtri++Z5JJsPROAp5zFz10J+1MyqAuwHjg2yfE/OIsoI1pb\nrSeM2x02WU+tjo7OYlKOPbra2y1V+aZNMH68HZOYhJfGxtRiUgqeSS7RCHhRNHbsgOeft/Qh/syr\nkQg8+6ylXXfZvBnef7+zZ/LCC3D33VY2FoMHHrCJisJKe7v3XezrjLoqp2SAlUYkAn/6U/KmynLz\nTCQmomg8/bT1yrrnHvjOd7zjJ59svbluv932hwyByZPh4INh4kSv3AknwJgxMH26lV26FM45Bx5+\nuLD3kUva2+H44+37GDfOerENGFBsq0R3+eIXzTu57LLOn5WbZ5KLZi4huoXr+n/603Deed7xCRNs\nQqinnAbQXbvg17+2gLuf8ePNK3ngASvrNnkFmVOiVGlvh4MOgosusv3rry+uPSI7jj3WlmTU1lrz\nbbkgz0QUDXdiIDfw7scfXE+MlaQqWw45u9Ldqygfys0zkZiIorF1q62rkgxpdQUiFrO3t6Bi0rev\nxESEg5oaxUyEyAmumCQjEoE1ayzwXlsL1dVdl3XFJOxpViQmlUNtrTwTIXKCKyZjx3b+rLHRRr3/\n+7/D7rt3fZ6GBli71sRn/HiJiQgH5eaZKAAvisaWLZaT6+qrO39WUwPz5tl4C/8c28morbXmrfnz\nbV6JJ5/Mj72FQGJSOZxxhvXcKxckJqJobN1qTVSpmrDGjAl+rkjEsgmfe278KPKwITGpHAYO9PLP\nlQNq5hJFI5fp1SMRG2cyerQF7JPNKREGJCYirEhMRNHYsiV3SQzdUeNDh1qzlzuRVtiQmIiwIjER\nRSOXnkljo7eurw9fU1csZmn3JSYirEhMRNHIpZgcdRSceKJ5OmEUk4cftomUqquVi0uEE4mJKBqJ\nEwZlw1lnwcyZth1GMXFTwMgrEWFFYiKKxtat+Zn4KYxismuXrXv3Lq4dQnQXiYkoGrls5vITRjHZ\nuNHW5TSITVQWEhNRNHLZzOUnjGLiT2opRBiRmIikbN8O3/ymbW/YAP/xH/GTVXWHF1+Es8+25Yor\n4Lnn8tPM1acP3HgjXH652e2fLCsWg69/Pdg4lK9/3eZHKcSYFVdMFHwXYUViIpKyahXce689SOfM\ngRkz4B//yO6cM2faw3LXLrjlFjjzTDjllNzY6+eGG6BnT7j1Vnj8cUvL4rJpk82N4jYrdcWvf21z\npQQpmy0bN8JPfwrNzfm/lhD5QGIikrJ+va23bs3dPCHRKBx9NBx5pO2ffHL6JI7dYZ99bFZGsJQs\n/qajoPfi90YKkTiyo8MmCTvwwPxfS4h8IDERSfG34edSTCIRb7S6f973XOMOYuyumPhTgxdCTDRY\nUYQdJXoUSUkUk6FDcycm7lSl+RQTt4ttY2P3xGTLls518onERIQdiYlISkuLrVetsgSK2Uw61dFh\nMYwlS+zh/tFHdjyfGVPdZip/z672dvjwQ9v238u2bZ5NYHEd/8Rd7neRTyQmIuxITEQnduyA666z\n7UmTbPKpyy+3YHZ38D8khwyxuUeg69kTs+WQQ2DQoHgx2WcfE7Z+/bwR5wBf/jI8/bQ3J31bW3wg\n/Lvftd5n+URiIsKOYiaiE2vXWhOUGxxvbbUutt3xTPyB7P/8T6irM3HKd3fbSZPsPlwx2bHDughH\no3DVVfFNX0uXwjPPWPm1a212x0WLvM979ICdO/Nrr8REhB2JieiEG9vwP3AbG+14piLgP0c+YySp\ncMVk3ToYMMC8ocRBje79+u1ctszb793bi/Pkg+3brbu06xkJEUYkJqIT7sN1xw6vKaq+3rY//jjz\nc7kUU0z8ghFETJYu9fZ7987vyHTXK6mqyt81hMg3EhPRicQHr0skknlTVzRqwXvwuusWknRi0t5u\n3lbiffo9k3ynZ1ETlygHFIAPSDQKs2bBHnvAZ
z5TbGvyi//B26+fdzwSgZ/9DIYNs/1jjun6u1iz\nBn75Sxvr8d57xXlgphKTF16AqVPN/kgk3iuIROCvf/X26+rg+uvh+OPhq1/NvY0SE1EOyDMJyLXX\nwle+AkccUWxL8o/74H3jDXvoulx7rSckc+bAHXd0fZ6XXrJzXHYZPPIIjB+fP5tTMWCA9c5KFJPW\nVrjpJksZk9j8dtpp3vbbb1u85fe/h6uvzo+NbW3Qv39+zi1EoZBnEpB89+YpJaJR61p70EHxx085\nxcul9eyzlkuqK9rbLa3J0Ufnx84guB0HkjXdjR8P774Lo0fH13H3Bw+GAw6IbxLLB9GoXUuIMCPP\nRHQiMSCdjCDxk1JovnHtTCYm++0H77+f+l7b2mztdjpw4yu5Zs2a4nROECKXSEwC4s6EVwnkSkw6\nOoovJoMGWQr9VauSiwkkv9fq6vhR8L162XgTf86uXBHk+xai1JGYBMTNogvlP4FRa2v6h1tDg30n\nXTX/lYJn0rOnxSPmzs1MTBKzGdfX25KP8SYSE1EOSEwC4n8LHznSHqZ77AGjRtl2Y2N+JnrKN+ef\nb/YPHWpv39//vt1rum68PXvaA3fECNh7b/PcolEYPtzO19BgvaWKLSZgf6fFi81WsMmzAMaNs/XI\nkZ3rfPaa5EJZAAAQJ0lEQVSznth86lMweXJ+ughPm2YdGZLZIESYUAA+INEozJ9vDyT37dQVkLlz\nrftrGHn9dXjsMeuptXq1zYa4fbuXP6sr+vSx5I0bN8LmzZYQMRKxLtRz5ti8JaUgJn//uzVPud5G\nfb03IHPjxuT3On26Fx+ZPdu6Du+zT+7F5M034e674fOfz+15hSg0ufBMJgPzgAXAVSnK/ML5/G1g\nQoC6A4FZwAfAM4C/4+Q1Tvl5wPHZmx+M1lZLUti7t/fmDdaOPmpUoazIPdGoeRYDBtj+smWdx12k\nwi3jH8sxdKh9N3vv7X1WbOrqOjdbuSP7+/VLfq9VVfa3BVtXVeXHM4lGYexYjX4X4SdbMakGbsdE\nYRzwJWDfhDInAXsCewEXAL8KUPdqTEz2Bp5z9nHKne2sJwN35OAe0tLR0fXbelgfBLt22RgKVxgh\nPlAdlGQDAwcNsnVdXW5sLQXyISZB4lNChIFsH8QTgYVAC7AdeAg4LaHMqcB0Z/tVzMsYkqauv850\n4HRn+zRghlO+xak/Mct7SIv7Dx9ENPKdDTeXrF9vb+Y1NfHHM324ubmr/GLintPfIyrs5MszkZiI\nciBbMRkO+LIYsdw5FqTMsC7qDgbWONtrnH2cOsvTXC/nPPts8LnK/TP0pWPzZhsZ/vLL8Mc/Bu9+\nvH49/OlP2QvXCy94wWg/QeIlftyH7MyZnR+M+cy2W2jq6+G55yxN/SOPwB/+YLGX7rJli/0Ggv62\nhChlsg3AB32cBWkIqkpxvlia6yT9bOrUqf/abmpqoqmpKYAJyfnlL5OP4n70Ua8H1223wcUX20M1\naNPO88/bPCHuA2nOHNg3sZEwCQ8+aNdassR6KnUX/31NnWo9llpa4LDDgtW/915rFrv3Xpg3z9Kn\n3Hqr9/m118Kpp3bfvlLjc5+DH/zAPNVZs0woX3gBDjywe+fLxOMVIh80NzfT7J8JLguyFZMVgL9T\n40jiPYdkZUY4ZWqSHF/hbK/BmsJWA0MBt2NusnOtIAl+McmWjz6CKVM6Hz/zTG97yhTL9dTeHh+D\n6Ipo1LqnvvOOtx9ETPzzmGcjJm1tcMkltn3wwbZkwjHH2HrGDJsO96CD4lOw/PjH3betFPnyl+Ef\n/7Dl2GNNQDPxRBNRE5coNokv2tOmTev2ubJt5nodC6yPBmqx4PiTCWWeBNxcq5OANkwsuqr7JHCu\ns30u8ITv+Bed8mOc+q9leQ9pCfpPn2mbuj89u7sftF4m5bs6Ty4eZvX1No6jEh6MkYhlQI5EzCvN\nJiYkMRHlRLaeyQ5gCvA01jvrPmAucKHz+V3ATKxH10KgHTgvTV2AG4H/Bb6OBdrPco7PcY7Pcep/\nm+BNbd1iyxZ7YPhTsaeiO2LiZtLt3TszMamvz05Mdu2yZpZczDFSX2/T3IZ1rE0mRCI2ZiUSsUGe\nEhMhjFwMWnzKWfzclbCfpJEoZV2A9cCxKer8xFkKgvvADdKu3R0x+dSnbHvEiMzEZPz47MSkrc3s\nzcVUsa5nMjHv/eqKjyu+jY0mJmrmEsLQCPg03HRT8IB6d8TEfZh88pMWvK6rg+9+1+bU2LTJelb1\n6GHjNo49Fu65x0ZNn3OOBf3HjIEvfCH4NTs67Dy1tZ27BHeXfv1gxQob1FnuuPc4ZEh8M9evf21/\nr0svDX4uiYkoJyQmaXjtNbjuumBluysmy5ZZUPdPf4JrrrFJuF5/3VJ6nHSSV37WLPNknn/eAvd9\n+lhvokzExL0WmBjlgiuvhBNP7H6vpjAxaRK88op1NHjgAc8zufRSS1WfqZgUY8IwIfKBEj2mYcMG\nmygqCN0VkxEj4idHchMmnnBC5/Ljx5s9ffvaujtzsrscemhmdVMxYIBNf5vp+JQwUl1t33tNjWIm\nQviRmKQhk3/4TMQkFosPgPuv4V6zR8JfZ+fO+HJB5hRJxF9eD7LsUG8uITwkJl2wdavFGILOz52J\nmLS1WQ+uXr1sP5mYJCMXYuKmYs9FT65KRgF4ITzKPmbS0QELF1rKim3bLMNv0B5M//xn8J5cYGKy\nYoV1u3Vn5VuyBPbaK75cLGaxGP+DxC9Y77wTXExWrfIGPQ4ebFPA+hk50ssIvHixjZEYORKWLw/n\n/CulRK9eNljznXe8LAYtLfZ9b9pkLxZjx9rcL36WLLHvfvlyCboQpU7M5cc/jsUGD47F7BEei91z\nTywQixbFYr16xWLnnx+sfCwWi/3Xf9k17rzT9u+7z/ZXrYov9847sVhdXSz2ne/EHx81KhY75JBY\nbP/9Y7FHHrFjX/1qLLbXXnZ89OhY7MMPvfI7d8ZiRx5p5QcNsmuNGmX7++9v22efbWW3bo3Famvt\n+D33xGKf/GTw+xLJmTrVvvN99onFDj3Ulrq6WOzzn/d+b48/3rme+1m/foW3WYiuIItxe2XvmaxY\nAT/8oeWyAli5Mni9gw+G++4Lfi137o4lS2zd1mbr1avju82uXAmHHw633x5fv6Wl8zmnT+98zKVH\nD3DT6lx3naUv+e1vwc2O0NwM119v22vXwsCBnhfzjW8EuiXRBW4T5RtvWJMlmAf84YdemcTfmz+Z\n5y235Nc+IQpJ2cdMotHOPaW6Uy8IiRNBufGTxGvmo63cPV+qmIra53OPG3x3hQTsO16wwNtP/Nv7\nsyjr7yHKiYoQk8TgdnfqBcEdBLh9u60lJuXNxo2dj0UiFi9xSfzb+zto6O8hygmJScB6QXDfOteu\ntXV7uzVFFUJM3AD+wIHesYEDLePx9u0Sk3zw0UedjyUG1JOJidvlW38PUU6Ubczkhhust8ySJd4/\neK9e
8O67cOGFNi/Fn/8MZ53lff6jH1l8A2yU+AUXZHZNtz38hRes7sMPW7qTBx6wdOUuL70E3/te\ndveXiNt+7x+b0qOHDbJ78EG4887ggy9FMJJNTuZPCDpkiP2OfvMbOM9Jb9rebr+JRYvUk0uUF+U6\nLU8MYhx3nP0Tf/GL1j24rs5mTbzzTgvIf+Urlg/r0kvtwdCzp00Y5T6QTzsts7jJ1q3w979bF9x3\n3rF0JTff3HlkeFUVnHFG8HlPgrBrl83YePjh8ccvuQSWLrWZHOfPh733zt01K50NG6yjhn8agXnz\nLGYyZox5hrffbi8PL71kn7/yiv3ebr7ZsgYIUUpU2TiIbulC2XomYLMIfulLtu2O9fja1+Cttzo3\nP2zdam/xF13U/ev16uVNGPXqqyYme+8Np5zS/XMGpUePzkICcPrp8O1vW5OKhCS3DBjgjeFx2Wcf\nW1xOP91yqrm0t1tHDQmJKDfKOmaSqk26sdHeKMFSlID3T57ra+fynN21Y/58tc8Xi8SUK7n+nQlR\nKlSkmEQiNt86xAfLy1VM/GtRWBKTQUpMRLlS1mKS6p82ErFAPFjAPRbL/T+5e65kQdpCMmiQrf29\nvEThSMzfJTER5UpZx0z22CP58U98wjySxka4/34LvF9wQX7+yYvtEVRX2zoxP5goDImeyccf2zw0\nQpQbZSsmXXkE++8Pmzfb9rPPwk9/mp83xmJ7JS6lYkclkhgz8U87IEQ5UdbNXEFwR4l3dKj5QeSe\nRM9Eg0dFuSIxccSkvV3NDyL3uDET1zuUmIhypeLFpKEB1q+3fEryTESuqa62xZ3vRGIiypWKF5Oe\nPS1t+GOPSUxEfvA3dUWjipmI8qTixQRg2jTr7XT22cW2RJQj/u7BGzZ0HjUvRDlQtrm5YurCJEqE\n4cNtmubhw01YNm7UlMmiNMkmN5c8EyHyjNvMtWOHLW6GZyHKCYmJEHnGFRN3LFNVubYHiIpGYiJE\nntltN4uZKJWKKGfKdgS8EKXCbrvBscfCEUdITET5Is9EiDwzaJCNZXriCYmJKF8kJkLkGf+MmhIT\nUa5ITITIMzU13rbERJQrEhMhCkhdXbEtECI/SEyEKCDu1AdClBvqzSVEnrnoIpvpsqMDDjyw2NYI\nkR+yGT41EHgYGAW0AGcBbUnKTQZuBaqBe4GfBah/DXA+sBO4BHjGOd4MDAHc97vjgLVJrql0KkII\nkSHFSqdyNTAL2Bt4ztlPpBq4HROUccCXgH3T1B8HnO2sJwN34N1cDPgPYIKzJBOS0NPc3FxsE7pN\nmG0H2V9sZH94yUZMTgWmO9vTgdOTlJkILMQ8j+3AQ8BpaeqfBsxwyrc49Q/xnbPsk1GE+QcZZttB\n9hcb2R9eshGTwcAaZ3uNs5/IcGCZb3+5c6yr+sOccv46w3z704E3gR9213AhhBC5JV0AfhYWo0jk\n2oT9mLMkknisqotyQYIcXwZWAn2Ax4BzgPsD1BNCCFGizMMTmqHOfiKTgL/69q8BrkpT/2ri4y9/\nJb6Zy+Vc4LYUti3EEygtWrRo0RJsWUgRuAlPGK4GbkxSpiewCBgN1AJv4QXgU9Uf55SrBcY49auw\nYL6bmKIGeBS4ICd3IoQQomgMBJ4FPsC67vZ3jg8D/uIrdyIwH1O8awLUB/iBU34ecIJzrB54HXgb\neA+4hQoIxgshhBBCCCFCyGTMm1mA14RWavwa6732ru/YQKyzQzIv7RrsfuYBxxfIxq4YCbwAvI95\niJc4x8NwD7sBr2LNqHOAnzrHw2C7n2qsR+OfnP0w2d8CvIPZ/5pzLEz298ea2Odiv6FDCI/9n8S+\nd3fZiP3/hsX+glGNNY2NxmIq/vhMKfEZbMClX0xuAr7vbF9F5/hRDXZfCyl+PrUhgJsUpA/WhLkv\n4bmH3s66J/AKcDjhsd3lCuD3wJPOfpjsX4w9vPyEyf7pWHYOsN/Q7oTLfpcewCrs5TCM9ueVQ4nv\nOZbYK6yUGE28mMzDG2czBK9nm7/3G9j9Tcq3cRnyBHAs4buH3sA/gfGEy/YRWKzxKDzPJEz2LwYG\nJRwLi/27Ax8mOR4W+/0cD7zkbOfE/nJSma4GSJY6mQzgLKV7Go15Wa8Snnvogb1trcFrrguL7WAd\nT64EdvmOhcn+GCaGrwPfdI6Fxf4xQCvwG2A2cA/WMSgs9vv5IpZpBHJkfzmJSazYBuQIt793V5+X\nAu7A0UuBTQmflfI97MKa6UYAR2Bv+H5K2faTgSjW3p2qJ2Mp2w9wGPYCciLwHazZ108p298TOAjL\nF3gQ0E7n1o9Stt+lFjgFeCTJZ922v5zEZAXW/ucyknhVLWXWED+AM+psJ97TCOdYsanBhOR+rJkL\nwncPG7Eu7AcTHts/jeW0W4y9VR6N/Q3CYj9YOz3YG/7jWP6+sNi/3Fn+6ew/ionKasJhv8uJwBvY\n3wDC8/0XjK4GSJYao+kcgM9kAGcxqQJ+hzW3+AnDPTTg9VSpA/4GHEM4bE/kSLyYSVjs7w30dbbr\ngf/D2u7DYj/Yb2ZvZ3sqZnuY7AdLuHuubz9s9heEVAMkS4kZWH6xbViM5zwyH8BZTA7Hmorewuti\nOJlw3MP+WFv3W1j31Cud42GwPZEj8XpzhcX+Mdh3/xbWrdz9Hw2L/QCfwjyTt4E/YEH5MNlfj03d\n0dd3LEz2CyGEEEIIIYQQQgghhBBCCCGEEEIIIYQQQgghhBBCCCGEEKLc+X+LTdyDrkqycQAAAABJ\nRU5ErkJggg==\n",
607 | "text/plain": [
608 | ""
609 | ]
610 | },
611 | "metadata": {},
612 | "output_type": "display_data"
613 | }
614 | ],
615 | "source": [
616 | "g.plot_reward(smoothing=100)"
617 | ]
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": 12,
622 | "metadata": {
623 | "collapsed": true
624 | },
625 | "outputs": [],
626 | "source": [
627 | "session.run(current_controller.target_network_update)"
628 | ]
629 | },
630 | {
631 | "cell_type": "code",
632 | "execution_count": 13,
633 | "metadata": {
634 | "collapsed": false
635 | },
636 | "outputs": [
637 | {
638 | "data": {
639 | "text/plain": [
640 | "array([[ 0.00824914, -0.04792793, 0.08260646, ..., -0.00135659,\n",
641 | " -0.0149605 , -0.00065048],\n",
642 | " [ 0.04895933, -0.01720949, 0.03015076, ..., 0.04350275,\n",
643 | " -0.00071916, -0.00507376],\n",
644 | " [-0.03408033, -0.00734746, -0.07286905, ..., 0.06636748,\n",
645 | " 0.0507561 , 0.04723936],\n",
646 | " ..., \n",
647 | " [-0.01454929, -0.00313209, 0.02152171, ..., 0.01723659,\n",
648 | " -0.01757577, -0.02262094],\n",
649 | " [ 0.03226471, -0.09545884, 0.01721121, ..., 0.0179732 ,\n",
650 | " -0.01188065, 0.03430547],\n",
651 | " [ 0.02971489, 0.06272104, -0.05087573, ..., -0.00265156,\n",
652 | " -0.00139228, 0.01183042]], dtype=float32)"
653 | ]
654 | },
655 | "execution_count": 13,
656 | "metadata": {},
657 | "output_type": "execute_result"
658 | }
659 | ],
660 | "source": [
661 | "current_controller.q_network.input_layer.Ws[0].eval()"
662 | ]
663 | },
664 | {
665 | "cell_type": "code",
666 | "execution_count": 14,
667 | "metadata": {
668 | "collapsed": false
669 | },
670 | "outputs": [
671 | {
672 | "data": {
673 | "text/plain": [
674 | "array([[ 0.00792486, -0.04620561, 0.0822738 , ..., -0.00077494,\n",
675 | " -0.0152533 , -0.0006482 ],\n",
676 | " [ 0.04957085, -0.01709681, 0.03076575, ..., 0.04224423,\n",
677 | " -0.00178666, -0.0045335 ],\n",
678 | " [-0.03365552, -0.00761211, -0.07296719, ..., 0.06496961,\n",
679 | " 0.04951147, 0.04784767],\n",
680 | " ..., \n",
681 | " [-0.01839499, -0.00162295, 0.02568982, ..., 0.01739434,\n",
682 | " -0.0198904 , -0.01868791],\n",
683 | " [ 0.03405851, -0.09418759, 0.01544813, ..., 0.01676139,\n",
684 | " -0.00888498, 0.03387775],\n",
685 | " [ 0.0298653 , 0.06385598, -0.05152962, ..., -0.00786654,\n",
686 | " -0.00157313, 0.01064544]], dtype=float32)"
687 | ]
688 | },
689 | "execution_count": 14,
690 | "metadata": {},
691 | "output_type": "execute_result"
692 | }
693 | ],
694 | "source": [
695 | "current_controller.target_q_network.input_layer.Ws[0].eval()"
696 | ]
697 | },
698 | {
699 | "cell_type": "markdown",
700 | "metadata": {},
701 | "source": [
702 | "# Visualizing what the agent is seeing\n",
703 | "\n",
704 | "Starting with the ray pointing all the way right, we have one row per ray in clockwise order.\n",
705 | "The numbers for each ray are the following:\n",
706 | "- first three numbers are normalized distances to the closest visible (intersecting with the ray) object. If no object is visible then all of them are $1$. If there's many objects in sight, then only the closest one is visible. The numbers represent distance to friend, enemy and wall in order.\n",
707 | "- the last two numbers represent the speed of moving object (x and y components). Speed of wall is ... zero.\n",
708 | "\n",
709 | "Finally the last two numbers in the representation correspond to speed of the hero."
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": 15,
715 | "metadata": {
716 | "collapsed": false
717 | },
718 | "outputs": [
719 | {
720 | "name": "stdout",
721 | "output_type": "stream",
722 | "text": [
723 | "[[1.00 0.55 1.00 -0.42 -0.40]\n",
724 | " [1.00 1.00 1.00 0.00 0.00]\n",
725 | " [1.00 1.00 1.00 0.00 0.00]\n",
726 | " [1.00 0.83 1.00 0.42 0.63]\n",
727 | " [1.00 1.00 1.00 0.00 0.00]\n",
728 | " [1.00 1.00 1.00 0.00 0.00]\n",
729 | " [1.00 1.00 1.00 0.00 0.00]\n",
730 | " [1.00 1.00 1.00 0.00 0.00]\n",
731 | " [1.00 1.00 1.00 0.00 0.00]\n",
732 | " [1.00 1.00 1.00 0.00 0.00]\n",
733 | " [1.00 1.00 1.00 0.00 0.00]\n",
734 | " [1.00 1.00 1.00 0.00 0.00]\n",
735 | " [1.00 1.00 1.00 0.00 0.00]\n",
736 | " [1.00 1.00 1.00 0.00 0.00]\n",
737 | " [1.00 1.00 1.00 0.00 0.00]\n",
738 | " [1.00 0.44 1.00 -0.18 0.76]\n",
739 | " [1.00 0.46 1.00 -0.18 0.76]\n",
740 | " [1.00 1.00 1.00 0.00 0.00]\n",
741 | " [0.89 1.00 1.00 -0.95 0.78]\n",
742 | " [1.00 1.00 1.00 0.00 0.00]\n",
743 | " [1.00 1.00 1.00 0.00 0.00]\n",
744 | " [1.00 0.44 1.00 0.45 -0.81]\n",
745 | " [1.00 0.20 1.00 -0.64 0.14]\n",
746 | " [1.00 0.19 1.00 -0.64 0.14]\n",
747 | " [1.00 0.21 1.00 -0.64 0.14]\n",
748 | " [1.00 0.57 1.00 0.56 0.78]\n",
749 | " [1.00 1.00 1.00 0.00 0.00]\n",
750 | " [1.00 1.00 1.00 0.00 0.00]\n",
751 | " [1.00 0.92 1.00 0.41 0.77]\n",
752 | " [1.00 1.00 1.00 0.00 0.00]\n",
753 | " [1.00 1.00 1.00 0.00 0.00]\n",
754 | " [1.00 1.00 1.00 0.00 0.00]]\n",
755 | "[1.00 -0.94 -0.25 0.46]\n"
756 | ]
757 | },
758 | {
759 | "data": {
760 | "text/html": [
761 | "\n",
762 | "\n",
763 | "\n"
1061 | ],
1062 | "text/plain": [
1063 | ""
1064 | ]
1065 | },
1066 | "execution_count": 15,
1067 | "metadata": {},
1068 | "output_type": "execute_result"
1069 | }
1070 | ],
1071 | "source": [
1072 | "g.__class__ = KarpathyGame\n",
1073 | "np.set_printoptions(formatter={'float': (lambda x: '%.2f' % (x,))})\n",
1074 | "x = g.observe()\n",
1075 | "new_shape = (x[:-4].shape[0]//g.eye_observation_size, g.eye_observation_size)\n",
1076 | "print(x[:-4].reshape(new_shape))\n",
1077 | "print(x[-4:])\n",
1078 | "g.to_html()"
1079 | ]
1080 | },
1081 | {
1082 | "cell_type": "code",
1083 | "execution_count": null,
1084 | "metadata": {
1085 | "collapsed": true
1086 | },
1087 | "outputs": [],
1088 | "source": []
1089 | }
1090 | ],
1091 | "metadata": {
1092 | "kernelspec": {
1093 | "display_name": "Python 3",
1094 | "language": "python",
1095 | "name": "python3"
1096 | },
1097 | "language_info": {
1098 | "codemirror_mode": {
1099 | "name": "ipython",
1100 | "version": 3
1101 | },
1102 | "file_extension": ".py",
1103 | "mimetype": "text/x-python",
1104 | "name": "python",
1105 | "nbconvert_exporter": "python",
1106 | "pygments_lexer": "ipython3",
1107 | "version": "3.4.3"
1108 | }
1109 | },
1110 | "nbformat": 4,
1111 | "nbformat_minor": 0
1112 | }
1113 |
--------------------------------------------------------------------------------
/notebooks/pong.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import tensorflow as tf
4 | import time
5 | import os
6 |
7 | from itertools import count
8 |
9 | from tf_rl.models import Layer, LambdaLayer, ConvLayer, SeqLayer
10 | from tf_rl.controller.discrete_deepq import DiscreteDeepQ
11 |
12 | # CRAZY VARIABLES
13 | REAL_TIME = False
14 | RENDER = True
15 |
16 | ENVIRONMENT = "Pong-v0"
17 | MODEL_SAVE_DIR = "./{}-model/".format(ENVIRONMENT)
18 | MODEL_SAVE_EVERY_S = 60
19 |
20 | # SENSIBLE VARIABLES
21 | FPS = 60
22 | MAX_FRAMES = 1000
23 | IMAGE_SHAPE = (210, 160, 3)
24 | OBS_SHAPE = (210, 160, 6)
25 | NUM_ACTIONS = 6
26 |
27 |
28 | def make_model():
29 | """Create a tensorflow convnet that takes image as input
30 | and outputs a predicted discounted score for every action"""
31 |
32 | with tf.variable_scope('convnet'):
33 | convnet = SeqLayer([
34 | ConvLayer(3, 3, 6, 32, stride=(1,1), scope='conv1'), # out.shape = (B, 210, 160, 32)
35 | LambdaLayer(tf.nn.sigmoid),
36 | ConvLayer(2, 2, 32, 64, stride=(2,2), scope='conv2'), # out.shape = (B, 105, 80, 64)
37 | LambdaLayer(tf.nn.sigmoid),
38 | ConvLayer(3, 3, 64, 64, stride=(1,1), scope='conv3'), # out.shape = (B, 105, 80, 64)
39 | LambdaLayer(tf.nn.sigmoid),
40 | ConvLayer(2, 2, 64, 128, stride=(2,2), scope='conv4'), # out.shape = (B, 53, 40, 128)
41 | LambdaLayer(tf.nn.sigmoid),
42 | ConvLayer(3, 3, 128, 128, stride=(1,1), scope='conv5'), # out.shape = (B, 53, 40, 128)
43 | LambdaLayer(tf.nn.sigmoid),
44 | ConvLayer(2, 2, 128, 256, stride=(2,2), scope='conv6'), # out.shape = (B, 27, 20, 256)
45 | LambdaLayer(tf.nn.sigmoid),
46 | LambdaLayer(lambda x: tf.reshape(x, [-1, 27 * 20 * 256])), # out.shape = (B, 27 * 20 * 256)
47 | Layer(27 * 20 * 256, 6, scope='proj_actions') # out.shape = (B, 6)
48 | ], scope='convnet')
49 | return convnet
50 |
51 |
52 | def make_controller():
53 | """Create a deepq controller"""
54 | session = tf.Session()
55 |
56 | model = make_model()
57 |
58 | optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
59 |
60 | return DiscreteDeepQ(OBS_SHAPE,
61 | NUM_ACTIONS,
62 | model,
63 | optimizer,
64 | session,
65 | random_action_probability=0.1,
66 | minibatch_size=8,
67 | discount_rate=0.99,
68 | exploration_period=500000,
69 | max_experience=10000,
70 | target_network_update_rate=0.01,
71 | store_every_nth=4,
72 | train_every_nth=4)
73 |
74 | # TODO(szymon): apparently both DeepMind and Karpathy
75 | # normalize their frames to 80x80 grayscale.
76 | # Should we do this?
77 | def normalize_frame(o):
78 | """Change from uint in range (0, 255) to float in range (0, 1)"""
79 | return o.astype(np.float32) / 255.0
80 |
81 | def main():
82 | env = gym.make(ENVIRONMENT)
83 | controller = make_controller()
84 |
85 | # Load existing model.
86 | if os.path.exists(MODEL_SAVE_DIR):
87 | print "loading model... ",
88 | controller.restore(MODEL_SAVE_DIR)
89 | print 'done.'
90 | last_model_save = time.time()
91 |
92 | # For every game
93 | for game_no in count():
94 | # Reset simulator
95 | frame_tm1 = normalize_frame(env.reset())
96 | frame_t, _, _, _ = env.step(env.action_space.sample())
97 | frame_t = normalize_frame(frame_t)
98 |
99 | rewards = []
100 | for _ in range(MAX_FRAMES):
101 | start_time = time.time()
102 |
103 | # observation consists of the two last frames;
104 | # this is important so that we can detect speed.
105 | observation_t = np.concatenate([frame_tm1, frame_t], 2)
106 |
107 | # pick an action according to Q-function learned so far.
108 | action = controller.action(observation_t)
109 |
110 | if RENDER: env.render()
111 |
112 | # advance simulator
113 | frame_tp1, reward, done, info = env.step(action)
114 | frame_tp1 = normalize_frame(frame_tp1)
115 | if done: break
116 |
117 | observation_tp1 = np.concatenate([frame_t, frame_tp1], 2)
118 |
119 | # store transitions
120 | controller.store(observation_t, action, reward, observation_tp1)
121 | # run a single iteration of SGD
122 | controller.training_step()
123 |
124 |
125 | frame_tm1, frame_t = frame_t, frame_tp1
126 | rewards.append(reward)
127 |
128 | # if real-time visualization is requested, throttle down FPS.
129 | if REAL_TIME:
130 | time_passed = time.time() - start_time
131 | time_left = 1.0 / FPS - time_passed
132 |
133 | if time_left > 0:
134 | time.sleep(time_left)
135 |
136 | # save model if time since last save is greater than
137 | # MODEL_SAVE_EVERY_S
138 | if time.time() - last_model_save >= MODEL_SAVE_EVERY_S:
139 | if not os.path.exists(MODEL_SAVE_DIR):
140 | os.makedirs(MODEL_SAVE_DIR)
141 | controller.save(MODEL_SAVE_DIR, debug=True)
142 | last_model_save = time.time()
143 |
144 | # Count scores. This relies on specific score values being
145 | # assigned by openai gym and might break in the future.
146 | points_lost = rewards.count(-1.0)
147 | points_won = rewards.count(1.0)
148 | exploration_done = controller.exploration_completed()
149 |
150 | print "Game no %d is over. Exploration %.1f done. Points lost: %d, points won: %d" % \
151 | (game_no, 100.0 * exploration_done, points_lost, points_won)
152 |
153 | if __name__ == '__main__':
154 | main()
155 |
156 |
--------------------------------------------------------------------------------
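The TODO in pong.py above asks whether frames should be reduced to 80x80
grayscale, as the DeepMind and Karpathy setups do. A minimal sketch of such a
preprocessor, assuming only numpy; the crop bounds and the 80x80 target are
illustrative assumptions, not values taken from this repository:

    import numpy as np

    def preprocess_frame(frame):
        """Crop a 210x160x3 uint8 Pong frame, downsample 2x, grayscale it."""
        frame = frame[35:195]        # drop scoreboard and borders (illustrative bounds)
        frame = frame[::2, ::2]      # naive 2x downsample -> (80, 80, 3)
        gray = frame.mean(axis=2)    # average color channels to grayscale
        return gray.astype(np.float32) / 255.0   # scale to [0, 1], like normalize_frame

Note that IMAGE_SHAPE, OBS_SHAPE and the convnet in make_model would have to be
adjusted to the new frame size before wiring this in.

--------------------------------------------------------------------------------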
/notebooks/tf_rl:
--------------------------------------------------------------------------------
1 | ../tf_rl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | future==0.15.2
2 | euclid==0.1
3 |
--------------------------------------------------------------------------------
/saved_models/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "../saved_models/karpathy_game.ckpt"
2 | all_model_checkpoint_paths: "../saved_models/karpathy_game.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/karpathy_game.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siemanko/tensorflow-deepq/149e69e5340984d75df3ff1a374920d870517fb9/saved_models/karpathy_game.ckpt
--------------------------------------------------------------------------------
/scripts/make_gif.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # stop script on error and print it
4 | set -e
5 | # inform me of undefined variables
6 | set -u
7 | # handle cascading failures well
8 | set -o pipefail
9 |
10 | SCRIPT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
11 |
12 |
13 | images_directory=${1:-}
14 | if [[ -z "$images_directory" ]]
15 | then
16 | echo "Usage $0 images_directory"
17 | exit 1
18 | fi
19 |
20 | for img in "$images_directory"/*.svg
21 | do
22 | if [ ! -f "$img.png" ]; then
23 | echo "Converting $img."
24 | inkscape -z -e "$img.png" -b white "$img"
25 | fi
26 | done
27 | convert -delay 3 -loop 0 $(ls $images_directory/*.png | sort -V) animation.gif
28 |
--------------------------------------------------------------------------------
/tf_rl/__init__.py:
--------------------------------------------------------------------------------
1 | from .simulate import simulate
2 |
--------------------------------------------------------------------------------
/tf_rl/controller/__init__.py:
--------------------------------------------------------------------------------
1 | from .discrete_deepq import DiscreteDeepQ
2 | from .human_controller import HumanController
3 |
--------------------------------------------------------------------------------
/tf_rl/controller/discrete_deepq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import tensorflow as tf
4 | import os
5 | import pickle
6 | import time
7 |
8 | from collections import deque
9 |
10 | class DiscreteDeepQ(object):
11 | def __init__(self, observation_shape,
12 | num_actions,
13 | observation_to_actions,
14 | optimizer,
15 | session,
16 | random_action_probability=0.05,
17 | exploration_period=1000,
18 | store_every_nth=5,
19 | train_every_nth=5,
20 | minibatch_size=32,
21 | discount_rate=0.95,
22 | max_experience=30000,
23 | target_network_update_rate=0.01,
24 | summary_writer=None):
25 | """Initialized the Deepq object.
26 |
27 | Based on:
28 | https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
29 |
30 | Parameters
31 | -------
32 | observation_shape : tuple
33 | shape of a single observation
34 | num_actions : int
35 | number of actions that the model can execute
36 | observation_to_actions: dali model
37 | model that implements an activate function
38 | that can take in an observation vector or a batch
39 | and return scores (of unbounded values) for each
40 | action for each observation.
41 | input shape: [batch_size] + observation_shape
42 | output shape: [batch_size, num_actions]
43 | optimizer: tf.train.*
44 | optimizer for prediction error
45 | session: tf.Session
46 | session on which to execute the computation
47 | random_action_probability: float (0 to 1)
48 | final probability of choosing a random
49 | action (epsilon from the paper)
50 | exploration_period: int
51 | number of actions over which epsilon is annealed
52 | linearly from 1 down to random_action_probability
53 | store_every_nth: int
54 | to further decorrelate samples, do not store all
55 | transitions, but rather every nth transition.
56 | For example if store_every_nth is 5, then
57 | only 20% of all the transitions are stored.
58 | train_every_nth: int
59 | normally training_step is invoked every
60 | time action is executed. Depending on the
61 | setup that might be too often. When this
62 | variable is set to n, then only every
63 | n-th time training_step is called will
64 | the training procedure actually be executed.
65 | minibatch_size: int
66 | number of state,action,reward,newstate
67 | tuples considered during experience replay
68 | discount_rate: float (0 to 1)
69 | how much we care about future rewards.
70 | max_experience: int
71 | maximum size of the replay buffer
72 | target_network_update_rate: float
73 | how much to update target network after each
74 | iteration. Let's call target_network_update_rate
75 | alpha, target network T, and network N. Every
76 | time N gets updated we execute:
77 | T = (1-alpha)*T + alpha*N
78 | summary_writer: tf.train.SummaryWriter
79 | writer to log metrics
80 | """
81 | # memorize arguments
82 | self.observation_shape = observation_shape
83 | self.num_actions = num_actions
84 |
85 | self.q_network = observation_to_actions
86 | self.optimizer = optimizer
87 | self.s = session
88 |
89 | self.random_action_probability = random_action_probability
90 | self.exploration_period = exploration_period
91 | self.store_every_nth = store_every_nth
92 | self.train_every_nth = train_every_nth
93 | self.minibatch_size = minibatch_size
94 | self.discount_rate = tf.constant(discount_rate)
95 | self.max_experience = max_experience
96 | self.target_network_update_rate = \
97 | tf.constant(target_network_update_rate)
98 |
99 | # deepq state
100 | self.actions_executed_so_far = 0
101 | self.experience = deque()
102 |
103 | self.iteration = 0
104 | self.summary_writer = summary_writer
105 |
106 | self.number_of_times_store_called = 0
107 | self.number_of_times_train_called = 0
108 |
109 | self.create_variables()
110 |
111 | self.s.run(tf.initialize_all_variables())
112 | self.s.run(self.target_network_update)
113 |
114 | self.saver = tf.train.Saver()
115 |
116 | def linear_annealing(self, n, total, p_initial, p_final):
117 | """Linear annealing between p_initial and p_final
118 | over total steps - computes value at step n"""
119 | if n >= total:
120 | return p_final
121 | else:
122 | return p_initial - (n * (p_initial - p_final)) / (total)
123 |
124 |
125 | def observation_batch_shape(self, batch_size):
126 | return tuple([batch_size] + list(self.observation_shape))
127 |
128 | def create_variables(self):
129 | self.target_q_network = self.q_network.copy(scope="target_network")
130 |
131 | # FOR REGULAR ACTION SCORE COMPUTATION
132 | with tf.name_scope("taking_action"):
133 | self.observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="observation")
134 | self.action_scores = tf.identity(self.q_network(self.observation), name="action_scores")
135 | tf.histogram_summary("action_scores", self.action_scores)
136 | self.predicted_actions = tf.argmax(self.action_scores, dimension=1, name="predicted_actions")
137 |
138 | with tf.name_scope("estimating_future_rewards"):
139 | # FOR PREDICTING TARGET FUTURE REWARDS
140 | self.next_observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="next_observation")
141 | self.next_observation_mask = tf.placeholder(tf.float32, (None,), name="next_observation_mask")
142 | self.next_action_scores = tf.stop_gradient(self.target_q_network(self.next_observation))
143 | tf.histogram_summary("target_action_scores", self.next_action_scores)
144 | self.rewards = tf.placeholder(tf.float32, (None,), name="rewards")
145 | target_values = tf.reduce_max(self.next_action_scores, reduction_indices=[1,]) * self.next_observation_mask
146 | self.future_rewards = self.rewards + self.discount_rate * target_values
147 |
148 | with tf.name_scope("q_value_precition"):
149 | # FOR PREDICTION ERROR
150 | self.action_mask = tf.placeholder(tf.float32, (None, self.num_actions), name="action_mask")
151 | self.masked_action_scores = tf.reduce_sum(self.action_scores * self.action_mask, reduction_indices=[1,])
152 | temp_diff = self.masked_action_scores - self.future_rewards
153 | self.prediction_error = tf.reduce_mean(tf.square(temp_diff))
154 | gradients = self.optimizer.compute_gradients(self.prediction_error)
155 | for i, (grad, var) in enumerate(gradients):
156 | if grad is not None:
157 | gradients[i] = (tf.clip_by_norm(grad, 5), var)
158 | # Add histograms for gradients.
159 | for grad, var in gradients:
160 | tf.histogram_summary(var.name, var)
161 | if grad is not None:
162 | tf.histogram_summary(var.name + '/gradients', grad)
163 | self.train_op = self.optimizer.apply_gradients(gradients)
164 |
165 | # UPDATE TARGET NETWORK
166 | with tf.name_scope("target_network_update"):
167 | self.target_network_update = []
168 | for v_source, v_target in zip(self.q_network.variables(), self.target_q_network.variables()):
169 | # this is equivalent to target = (1-alpha) * target + alpha * source
170 | update_op = v_target.assign_sub(self.target_network_update_rate * (v_target - v_source))
171 | self.target_network_update.append(update_op)
172 | self.target_network_update = tf.group(*self.target_network_update)
173 |
174 | # summaries
175 | tf.scalar_summary("prediction_error", self.prediction_error)
176 |
177 | self.summarize = tf.merge_all_summaries()
178 | self.no_op1 = tf.no_op()
179 |
180 |
181 | def action(self, observation):
182 | """Given observation returns the action that should be chosen using
183 | DeepQ learning strategy. Does not backprop."""
184 | assert observation.shape == self.observation_shape, \
185 | "Action is performed based on single observation."
186 |
187 | self.actions_executed_so_far += 1
188 | exploration_p = self.linear_annealing(self.actions_executed_so_far,
189 | self.exploration_period,
190 | 1.0,
191 | self.random_action_probability)
192 |
193 | if random.random() < exploration_p:
194 | return random.randint(0, self.num_actions - 1)
195 | else:
196 | return self.s.run(self.predicted_actions, {self.observation: observation[np.newaxis,:]})[0]
197 |
198 | def exploration_completed(self):
199 | return min(float(self.actions_executed_so_far) / self.exploration_period, 1.0)
200 |
201 | def store(self, observation, action, reward, newobservation):
202 | """Store experience, where starting with observation and
203 | execution action, we arrived at the newobservation and got thetarget_network_update
204 | reward reward
205 |
206 | If newstate is None, the state/action pair is assumed to be terminal
207 | """
208 | if self.number_of_times_store_called % self.store_every_nth == 0:
209 | self.experience.append((observation, action, reward, newobservation))
210 | if len(self.experience) > self.max_experience:
211 | self.experience.popleft()
212 | self.number_of_times_store_called += 1
213 |
214 | def training_step(self):
215 | """Pick a self.minibatch_size exeperiences from reply buffer
216 | and backpropage the value function.
217 | """
218 | if self.number_of_times_train_called % self.train_every_nth == 0:
219 | if len(self.experience) < self.minibatch_size:
220 | return
221 |
222 | # sample experience.
223 | samples = random.sample(range(len(self.experience)), self.minibatch_size)
224 | samples = [self.experience[i] for i in samples]
225 |
226 | # batch states
227 | states = np.empty(self.observation_batch_shape(len(samples)))
228 | newstates = np.empty(self.observation_batch_shape(len(samples)))
229 | action_mask = np.zeros((len(samples), self.num_actions))
230 |
231 | newstates_mask = np.empty((len(samples),))
232 | rewards = np.empty((len(samples),))
233 |
234 | for i, (state, action, reward, newstate) in enumerate(samples):
235 | states[i] = state
236 | action_mask[i] = 0
237 | action_mask[i][action] = 1
238 | rewards[i] = reward
239 | if newstate is not None:
240 | newstates[i] = newstate
241 | newstates_mask[i] = 1
242 | else:
243 | newstates[i] = 0
244 | newstates_mask[i] = 0
245 |
246 |
247 | calculate_summaries = self.iteration % 100 == 0 and \
248 | self.summary_writer is not None
249 |
250 | cost, _, summary_str = self.s.run([
251 | self.prediction_error,
252 | self.train_op,
253 | self.summarize if calculate_summaries else self.no_op1,
254 | ], {
255 | self.observation: states,
256 | self.next_observation: newstates,
257 | self.next_observation_mask: newstates_mask,
258 | self.action_mask: action_mask,
259 | self.rewards: rewards,
260 | })
261 |
262 | self.s.run(self.target_network_update)
263 |
264 | if calculate_summaries:
265 | self.summary_writer.add_summary(summary_str, self.iteration)
266 |
267 | self.iteration += 1
268 |
269 | self.number_of_times_train_called += 1
270 |
271 | def save(self, save_dir, debug=False):
272 | STATE_FILE = os.path.join(save_dir, 'deepq_state')
273 | MODEL_FILE = os.path.join(save_dir, 'model')
274 |
275 | # deepq state
276 | state = {
277 | 'actions_executed_so_far': self.actions_executed_so_far,
278 | 'iteration': self.iteration,
279 | 'number_of_times_store_called': self.number_of_times_store_called,
280 | 'number_of_times_train_called': self.number_of_times_train_called,
281 | }
282 |
283 | if debug:
284 | print('Saving model... ')
285 |
286 | saving_started = time.time()
287 |
288 | self.saver.save(self.s, MODEL_FILE)
289 | with open(STATE_FILE, "wb") as f:
290 | pickle.dump(state, f)
291 |
292 | print('done in {} s'.format(time.time() - saving_started))
293 |
294 | def restore(self, save_dir, debug=False):
295 | # deepq state
296 | STATE_FILE = os.path.join(save_dir, 'deepq_state')
297 | MODEL_FILE = os.path.join(save_dir, 'model')
298 |
299 | with open(STATE_FILE, "rb") as f:
300 | state = pickle.load(f)
301 | self.saver.restore(self.s, MODEL_FILE)
302 |
303 | self.actions_executed_so_far = state['actions_executed_so_far']
304 | self.iteration = state['iteration']
305 | self.number_of_times_store_called = state['number_of_times_store_called']
306 | self.number_of_times_train_called = state['number_of_times_train_called']
307 |
308 |
309 |
310 |
--------------------------------------------------------------------------------
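A minimal usage sketch for DiscreteDeepQ, mirroring how the notebooks wire it
together; the layer sizes and optimizer settings below are illustrative
assumptions, not canonical values:

    import tensorflow as tf

    from tf_rl.models import MLP
    from tf_rl.controller import DiscreteDeepQ

    observation_size, num_actions = 135, 4    # hypothetical sizes

    session = tf.Session()
    brain = MLP([observation_size], [200, 200, num_actions],
                [tf.tanh, tf.tanh, tf.identity])
    optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)
    controller = DiscreteDeepQ((observation_size,), num_actions, brain,
                               optimizer, session,
                               discount_rate=0.99,
                               exploration_period=5000,
                               max_experience=10000)

    # typical interaction loop: observe, act, store, train
    # action = controller.action(observation)
    # controller.store(observation, action, reward, new_observation)
    # controller.training_step()

--------------------------------------------------------------------------------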
/tf_rl/controller/human_controller.py:
--------------------------------------------------------------------------------
1 | from tf_rl.utils.getch import getch
2 | from redis import StrictRedis
3 |
4 |
5 |
6 | class HumanController(object):
7 | def __init__(self, mapping):
8 | self.mapping = mapping
9 | self.r = StrictRedis()
10 | self.experience = []
11 |
12 | def action(self, o):
13 | return self.mapping[self.r.get("action")]
14 |
15 | def store(self, observation, action, reward, newobservation):
16 | pass
17 |
18 | def training_step(self):
19 | pass
20 |
21 |
22 |
23 | def control_me():
24 | r = StrictRedis()
25 | while True:
26 | c = getch()
27 | r.set("action", c)
28 |
29 |
30 | if __name__ == '__main__':
31 | control_me()
32 |
--------------------------------------------------------------------------------
/tf_rl/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import tensorflow as tf
3 |
4 | from .utils import base_name
5 |
6 |
7 | class Layer(object):
8 | def __init__(self, input_sizes, output_size, scope):
9 | """Cretes a neural network layer."""
10 | if type(input_sizes) != list:
11 | input_sizes = [input_sizes]
12 |
13 | self.input_sizes = input_sizes
14 | self.output_size = output_size
15 | self.scope = scope or "Layer"
16 |
17 | with tf.variable_scope(self.scope):
18 | self.Ws = []
19 | for input_idx, input_size in enumerate(input_sizes):
20 | W_name = "W_%d" % (input_idx,)
21 | W_initializer = tf.random_uniform_initializer(
22 | -1.0 / math.sqrt(input_size), 1.0 / math.sqrt(input_size))
23 | W_var = tf.get_variable(W_name, (input_size, output_size), initializer=W_initializer)
24 | self.Ws.append(W_var)
25 | self.b = tf.get_variable("b", (output_size,), initializer=tf.constant_initializer(0))
26 |
27 | def __call__(self, xs):
28 | if type(xs) != list:
29 | xs = [xs]
30 | assert len(xs) == len(self.Ws), \
31 | "Expected %d input vectors, got %d" % (len(self.Ws), len(xs))
32 | with tf.variable_scope(self.scope):
33 | return sum([tf.matmul(x, W) for x, W in zip(xs, self.Ws)]) + self.b
34 |
35 | def variables(self):
36 | return [self.b] + self.Ws
37 |
38 | def copy(self, scope=None):
39 | scope = scope or self.scope + "_copy"
40 |
41 | with tf.variable_scope(scope) as sc:
42 | for v in self.variables():
43 | tf.get_variable(base_name(v), v.get_shape(),
44 | initializer=lambda x,dtype=tf.float32,partition_info=None: v.initialized_value())
45 | sc.reuse_variables()
46 | return Layer(self.input_sizes, self.output_size, scope=sc)
47 |
48 | class MLP(object):
49 | def __init__(self, input_sizes, hiddens, nonlinearities, scope=None, given_layers=None):
50 | self.input_sizes = input_sizes
51 | self.hiddens = hiddens
52 | self.input_nonlinearity, self.layer_nonlinearities = nonlinearities[0], nonlinearities[1:]
53 | self.scope = scope or "MLP"
54 |
55 | assert len(hiddens) == len(nonlinearities), \
56 | "Number of hiddens must be equal to number of nonlinearities"
57 |
58 | with tf.variable_scope(self.scope):
59 | if given_layers is not None:
60 | self.input_layer = given_layers[0]
61 | self.layers = given_layers[1:]
62 | else:
63 | self.input_layer = Layer(input_sizes, hiddens[0], scope="input_layer")
64 | self.layers = []
65 |
66 | for l_idx, (h_from, h_to) in enumerate(zip(hiddens[:-1], hiddens[1:])):
67 | self.layers.append(Layer(h_from, h_to, scope="hidden_layer_%d" % (l_idx,)))
68 |
69 | def __call__(self, xs):
70 | if type(xs) != list:
71 | xs = [xs]
72 | with tf.variable_scope(self.scope):
73 | hidden = self.input_nonlinearity(self.input_layer(xs))
74 | for layer, nonlinearity in zip(self.layers, self.layer_nonlinearities):
75 | hidden = nonlinearity(layer(hidden))
76 | return hidden
77 |
78 | def variables(self):
79 | res = self.input_layer.variables()
80 | for layer in self.layers:
81 | res.extend(layer.variables())
82 | return res
83 |
84 | def copy(self, scope=None):
85 | scope = scope or self.scope + "_copy"
86 | nonlinearities = [self.input_nonlinearity] + self.layer_nonlinearities
87 | given_layers = [self.input_layer.copy()] + [layer.copy() for layer in self.layers]
88 | return MLP(self.input_sizes, self.hiddens, nonlinearities, scope=scope,
89 | given_layers=given_layers)
90 |
91 |
92 | class ConvLayer(object):
93 | def __init__(self, filter_H, filter_W,
94 | in_C, out_C,
95 | stride=(1,1),
96 | scope="Convolution"):
97 | self.filter_H, self.filter_W = filter_H, filter_W
98 | self.in_C, self.out_C = in_C, out_C
99 | self.stride = stride
100 | self.scope = scope
101 |
102 | with tf.variable_scope(self.scope):
103 | input_size = filter_H * filter_W * in_C
104 | W_initializer = tf.random_uniform_initializer(
105 | -1.0 / math.sqrt(input_size),
106 | 1.0 / math.sqrt(input_size))
107 | self.W = tf.get_variable('W',
108 | (filter_H, filter_W, in_C, out_C),
109 | initializer=W_initializer)
110 | self.b = tf.get_variable('b',
111 | (out_C,),
112 | initializer=tf.constant_initializer(0))
113 |
114 | def __call__(self, X):
115 | with tf.variable_scope(self.scope):
116 | return tf.nn.conv2d(X, self.W,
117 | strides=[1] + list(self.stride) + [1],
118 | padding='SAME') + self.b
119 |
120 | def variables(self):
121 | return [self.W, self.b]
122 |
123 | def copy(self, scope=None):
124 | scope = scope or self.scope + "_copy"
125 |
126 | with tf.variable_scope(scope) as sc:
127 | for v in self.variables():
128 | tf.get_variable(base_name(v), v.get_shape(),
129 | initializer=lambda x,dtype=tf.float32,partition_info=None: v.initialized_value())
130 | sc.reuse_variables()
131 | return ConvLayer(self.filter_H, self.filter_W, self.in_C, self.out_C, self.stride, scope=sc)
132 |
133 | class SeqLayer(object):
134 | def __init__(self, layers, scope='seq_layer'):
135 | self.scope = scope
136 | self.layers = layers
137 |
138 | def __call__(self, x):
139 | for l in self.layers:
140 | x = l(x)
141 | return x
142 |
143 | def variables(self):
144 | return sum([l.variables() for l in self.layers], [])
145 |
146 | def copy(self, scope=None):
147 | scope = scope or self.scope + "_copy"
148 | with tf.variable_scope(self.scope):
149 | copied_layers = [layer.copy() for layer in self.layers]
150 | return SeqLayer(copied_layers, scope=scope)
151 |
152 |
153 | class LambdaLayer(object):
154 | def __init__(self, f):
155 | self.f = f
156 |
157 | def __call__(self, x):
158 | return self.f(x)
159 |
160 | def variables(self):
161 | return []
162 |
163 | def copy(self):
164 | return LambdaLayer(self.f)
165 |
--------------------------------------------------------------------------------
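The copy() methods above exist so that DiscreteDeepQ can spawn a target network
that starts from identical weights. A minimal sketch, with hypothetical sizes:

    import tensorflow as tf

    from tf_rl.models import MLP

    mlp = MLP([10], [32, 4], [tf.tanh, tf.identity])
    target = mlp.copy(scope="target_network")  # same architecture, initialized to mlp's weights

    session = tf.Session()
    session.run(tf.initialize_all_variables())

    x = tf.placeholder(tf.float32, (None, 10))
    q = mlp(x)            # scores from the online network
    q_target = target(x)  # scores from the slowly-updated copy

--------------------------------------------------------------------------------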
/tf_rl/simulate.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 | import time
5 |
6 | import matplotlib.pyplot as plt
7 | from itertools import count
8 | from os.path import join, exists
9 | from os import makedirs
10 | from IPython.display import clear_output, display, HTML
11 |
12 | def simulate(simulation,
13 | controller=None,
14 | fps=60,
15 | visualize_every=1,
16 | action_every=1,
17 | simulation_resolution=None,
18 | wait=False,
19 | disable_training=False,
20 | save_path=None):
21 | """Start the simulation. Performs three tasks
22 |
23 | - visualizes simulation in iPython notebook
24 | - advances simulator state
25 | - reports state to controller and chooses actions
26 | to be performed.
27 |
28 | Parameters
29 | -------
30 | simulation: tf_rl.simulation
31 | simulation that will be simulated ;-)
32 | controller: tf_rl.controller
33 | controller used
34 | fps: int
35 | frames per second
36 | visualize_every: int
37 | visualize every `visualize_every`-th frame.
38 | action_every: int
39 | take action every `action_every`-th frame
40 | simulation_resolution: float
41 | simulate at most 'simulation_resolution' seconds at a time.
42 | If None, it is set to 1/fps (default).
43 | wait: boolean
44 | whether to intentionally slow down the simulation
45 | to appear real time.
46 | disable_training: bool
47 | if True, training_step is never called.
48 | save_path: str
49 | save svg visualization (only tf_rl.utils.svg
50 | supported for the moment)
51 | """
52 |
53 | # prepare path to save simulation images
54 | if save_path is not None:
55 | if not exists(save_path):
56 | makedirs(save_path)
57 | last_image = 0
58 |
59 | # calculate simulation times
60 | chunks_per_frame = 1
61 | chunk_length_s = 1.0 / fps
62 |
63 | if simulation_resolution is not None:
64 | frame_length_s = 1.0 / fps
65 | chunks_per_frame = int(math.ceil(frame_length_s / simulation_resolution))
66 | chunks_per_frame = max(chunks_per_frame, 1)
67 | chunk_length_s = frame_length_s / chunks_per_frame
68 |
69 | # state transition bookkeeping
70 | last_observation = None
71 | last_action = None
72 |
73 | simulation_started_time = time.time()
74 |
75 | # setup rendering handles for reuse
76 | if hasattr(simulation, 'setup_draw'):
77 | simulation.setup_draw()
78 |
79 | for frame_no in count():
80 | for _ in range(chunks_per_frame):
81 | simulation.step(chunk_length_s)
82 |
83 | if frame_no % action_every == 0:
84 | new_observation = simulation.observe()
85 | reward = simulation.collect_reward()
86 | # store last transition
87 | if last_observation is not None:
88 | controller.store(last_observation, last_action, reward, new_observation)
89 |
90 | # act
91 | new_action = controller.action(new_observation)
92 | simulation.perform_action(new_action)
93 |
94 | #train
95 | if not disable_training:
96 | controller.training_step()
97 |
98 | # update current state as last state.
99 | last_action = new_action
100 | last_observation = new_observation
101 |
102 | # adding 1 to make it less likely to happen at the same time as
103 | # action taking.
104 | if (frame_no + 1) % visualize_every == 0:
105 | fps_estimate = frame_no / (time.time() - simulation_started_time)
106 |
107 | # draw simulated environment; all rendering is handled within the simulation object
108 | stats = ["fps = %.1f" % (fps_estimate, )]
109 | if hasattr(simulation, 'draw'): # render with the draw function
110 | simulation.draw(stats)
111 | elif hasattr(simulation, 'to_html'): # in case some class only supports svg rendering
112 | clear_output(wait=True)
113 | svg_html = simulation.to_html(stats)
114 | display(svg_html)
115 |
116 | if save_path is not None:
117 | img_path = join(save_path, "%d.svg" % (last_image,))
118 | with open(img_path, "w") as f:
119 | svg_html.write_svg(f)
120 | last_image += 1
121 |
122 | time_should_have_passed = frame_no / fps
123 | time_passed = (time.time() - simulation_started_time)
124 | if wait and (time_should_have_passed > time_passed):
125 | time.sleep(time_should_have_passed - time_passed)
126 |
--------------------------------------------------------------------------------
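For reference, a sketch of how simulate() is typically invoked from the
notebooks; g and current_controller play the roles they have there, and the
parameter values are illustrative rather than canonical:

    from tf_rl import simulate

    simulate(simulation=g,                  # e.g. a KarpathyGame instance
             controller=current_controller,
             fps=30,
             visualize_every=4,             # redraw every 4th frame
             action_every=3,                # choose a new action every 3rd frame
             simulation_resolution=0.001,   # integrate in 1 ms chunks
             wait=False,                    # don't throttle to real time
             disable_training=False)

--------------------------------------------------------------------------------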
/tf_rl/simulation/__init__.py:
--------------------------------------------------------------------------------
1 | from .karpathy_game import KarpathyGame
2 | from .double_pendulum import DoublePendulum
3 | from .discrete_hill import DiscreteHill
4 |
--------------------------------------------------------------------------------
/tf_rl/simulation/discrete_hill.py:
--------------------------------------------------------------------------------
1 | from random import randint, gauss
2 |
3 | import numpy as np
4 |
5 | class DiscreteHill(object):
6 |
7 | directions = [(0,1), (0,-1), (1,0), (-1,0)]
8 |
9 | def __init__(self, board=(10,10), variance=4.):
10 | self.variance = variance
11 | self.target = (0,0)
12 | while self.target == (0,0):
13 | self.target = (randint(-board[0], board[0]), randint(-board[1], board[1]))
14 | self.position = (0,0)
15 |
16 | self.shortest_path = self.distance(self.position, self.target)
17 |
18 | @staticmethod
19 | def add(p, q):
20 | return (p[0] + q[0], p[1] + q[1])
21 |
22 | @staticmethod
23 | def distance(p, q):
24 | return abs(p[0] - q[0]) + abs(p[1] - q[1])
25 |
26 | def estimate_distance(self, p):
27 | distance = DiscreteHill.distance(self.target, p) - DiscreteHill.distance(self.target, self.position)
28 | return distance + abs(gauss(0, self.variance))
29 |
30 | def observe(self):
31 | return np.array([self.estimate_distance(DiscreteHill.add(self.position, delta))
32 | for delta in DiscreteHill.directions])
33 |
34 | def perform_action(self, action):
35 | self.position = DiscreteHill.add(self.position, DiscreteHill.directions[action])
36 |
37 | def is_over(self):
38 | return self.position == self.target
39 |
40 | def collect_reward(self, action):
41 | return -DiscreteHill.distance(self.target, DiscreteHill.add(self.position, DiscreteHill.directions[action])) \
42 | + DiscreteHill.distance(self.target, self.position) - 2
43 |
--------------------------------------------------------------------------------
/tf_rl/simulation/double_pendulum.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ..utils import svg
4 |
5 | class DoublePendulum(object):
6 | def __init__(self, params):
7 | """Double Pendulum simulation, where control is
8 | only applied to joint1.
9 |
10 | state of the system is encoded as the following
11 | four values:
12 | state[0]:
13 | angle of first bar from center
14 | (w.r.t. vertical axis)
15 | state[1]:
16 | angular velocity of state[0]
17 | state[2]:
18 | angle of second bar from center
19 | (w.r.t vertical axis)
20 | state[3]:
21 | angular velocity of state[2]
22 |
23 | Params
24 | -------
25 | g_ms2 : float
26 | gravity acceleration
27 | l1_m : float
28 | length of the first bar (closer to center)
29 | m1_kg: float
30 | mass of the first joint
31 | l2_m : float
32 | length of the second bar
33 | m2_kg : float
34 | mass of the second joint
35 | max_control_input : float
36 | maximum value of angular force applied
37 | to the first joint
38 | """
39 | self.state = np.array([0.0, 0.0, 0.0, 0.0])
40 | self.control_input = 0.0
41 | self.params = params
42 | self.size = (400, 300)
43 |
44 | def external_derivatives(self):
45 | """How state of the world changes
46 | naturally due to gravity and momentum
47 |
48 | Returns a vector of four values
49 | which are the time derivatives of the
50 | four values in the internal state representation."""
51 | # code below is horrible; if somebody
52 | # wants to clean it up, I will gladly
53 | # accept a pull request.
54 | G = self.params['g_ms2']
55 | L1 = self.params['l1_m']
56 | L2 = self.params['l2_m']
57 | M1 = self.params['m1_kg']
58 | M2 = self.params['m2_kg']
59 | damping = self.params['damping']
60 | state = self.state
61 |
62 | dydx = np.zeros_like(state)
63 | dydx[0] = state[1]
64 |
65 | del_ = state[2]-state[0]
66 | den1 = (M1+M2)*L1 - M2*L1*np.cos(del_)*np.cos(del_)
67 | dydx[1] = (M2*L1*state[1]*state[1]*np.sin(del_)*np.cos(del_)
68 | + M2*G*np.sin(state[2])*np.cos(del_)
69 | + M2*L2*state[3]*state[3]*np.sin(del_)
70 | - (M1+M2)*G*np.sin(state[0])) / den1
71 | dydx[1] -= damping * state[1]
72 |
73 | dydx[2] = state[3]
74 |
75 | den2 = (L2/L1)*den1
76 | dydx[3] = (-M2*L2*state[3]*state[3]*np.sin(del_)*np.cos(del_)
77 | + (M1+M2)*G*np.sin(state[0])*np.cos(del_)
78 | - (M1+M2)*L1*state[1]*state[1]*np.sin(del_)
79 | - (M1+M2)*G*np.sin(state[2]))/den2
80 | dydx[3] -= damping * state[3]
81 |
82 |
83 | return np.array(dydx)
84 |
85 | def control_derivative(self):
86 | """Derivative of self.state due to control"""
87 | return np.array([0., 0., 0., 1.]) * self.control_input
88 |
89 | def observe(self):
90 | """Returns an observation."""
91 | return self.state
92 |
93 | def perform_action(self, action):
94 | """Expects action to be in range [-1, 1]"""
95 | self.control_input = action * self.params['max_control_input']
96 |
97 | def step(self, dt):
98 | """Advance simulation by dt seconds"""
99 | dstate = self.external_derivatives() + self.control_derivative()
100 | self.state += dt * dstate
101 |
102 | def collect_reward(self):
103 | """Reward corresponds to how high is the first joint."""
104 | _, joint2 = self.joint_positions()
105 | return -joint2[1]
106 |
107 | def joint_positions(self):
108 | """Returns abosolute positions of both joints in coordinate system
109 | where center of system is the attachement point"""
110 | x1 = self.params['l1_m'] * np.sin(self.state[0])
111 | y1 = self.params['l1_m'] * np.cos(self.state[0])
112 |
113 | x2 = self.params['l2_m'] * np.sin(self.state[2]) + x1
114 | y2 = self.params['l2_m'] * np.cos(self.state[2]) + y1
115 |
116 | return (x1, y1), (x2, y2)
117 |
118 | def to_html(self, info=[]):
119 | """Visualize"""
120 | info = info[:]
121 | info.append("Reward = %.1f" % self.collect_reward())
122 | joint1, joint2 = self.joint_positions()
123 |
124 | total_length = self.params['l1_m'] + self.params['l2_m']
125 | # 8/10 of half the smaller screen dimension
126 | total_length_px = (8./10.) * (min(self.size) / 2.)
127 | scaling_ratio = total_length_px / total_length
128 | center = (self.size[0] / 2, self.size[1] / 2)
129 |
130 | def transform(point):
131 | """Transforms from state reference world
132 | to screen and pixels reference world"""
133 |
134 | x = center[0] + scaling_ratio * point[0]
135 | y = center[1] + scaling_ratio * point[1]
136 | return int(x), int(y)
137 |
138 |
139 | scene = svg.Scene((self.size[0] + 20, self.size[1] + 20 + 20 * len(info)))
140 | scene.add(svg.Rectangle((10, 10), self.size))
141 |
142 | joint1 = transform(joint1)
143 | joint2 = transform(joint2)
144 | scene.add(svg.Line(center, joint1))
145 | scene.add(svg.Line(joint1, joint2))
146 |
147 | scene.add(svg.Circle(center, 5, color='red'))
148 | scene.add(svg.Circle(joint1, 3, color='blue'))
149 | scene.add(svg.Circle(joint2, 3, color='green'))
150 |
151 |
152 | offset = self.size[1] + 15
153 | for txt in info:
154 | scene.add(svg.Text((10, offset + 20), txt, 15))
155 | offset += 20
156 | return scene
157 |
--------------------------------------------------------------------------------
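The step(dt) method above advances the pendulum with a single explicit-Euler update. A minimal standalone sketch of that scheme (the names here are illustrative; it assumes the [theta1, omega1, theta2, omega2] state layout and a caller-supplied derivative function):

    import numpy as np

    def euler_step(state, derivatives_fn, control_input, dt):
        # natural dynamics plus control; as in control_derivative() above,
        # the control input only accelerates the second joint's angular speed
        dstate = derivatives_fn(state) + np.array([0., 0., 0., 1.]) * control_input
        return state + dt * dstate
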
/tf_rl/simulation/karpathy_game.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import random
5 | import time
6 |
7 | from collections import defaultdict
8 | from euclid import Circle, Point2, Vector2, LineSegment2
9 |
10 | from ..utils import svg
11 | from IPython.display import clear_output, display, HTML
12 |
13 | class GameObject(object):
14 | def __init__(self, position, speed, obj_type, settings):
15 |         """Essentially represents circles of different kinds, which have
16 | position and speed."""
17 | self.settings = settings
18 | self.radius = self.settings["object_radius"]
19 |
20 | self.obj_type = obj_type
21 | self.position = position
22 | self.speed = speed
23 | self.bounciness = 1.0
24 |
25 | def wall_collisions(self):
26 | """Update speed upon collision with the wall."""
27 | world_size = self.settings["world_size"]
28 |
29 | for dim in range(2):
30 | if self.position[dim] - self.radius <= 0 and self.speed[dim] < 0:
31 | self.speed[dim] = - self.speed[dim] * self.bounciness
32 | elif self.position[dim] + self.radius + 1 >= world_size[dim] and self.speed[dim] > 0:
33 | self.speed[dim] = - self.speed[dim] * self.bounciness
34 |
35 | def move(self, dt):
36 | """Move as if dt seconds passed"""
37 | self.position += dt * self.speed
38 | self.position = Point2(*self.position)
39 |
40 | def step(self, dt):
41 |         """Move and bounce off walls."""
42 | self.wall_collisions()
43 | self.move(dt)
44 |
45 | def as_circle(self):
46 | return Circle(self.position, float(self.radius))
47 |
48 | def draw(self):
49 | """Return svg object for this item."""
50 | color = self.settings["colors"][self.obj_type]
51 | return svg.Circle(self.position + Point2(10, 10), self.radius, color=color)
52 |
53 | class KarpathyGame(object):
54 | def __init__(self, settings):
55 |         """Initialize the game simulator with settings."""
56 | self.settings = settings
57 | self.size = self.settings["world_size"]
58 | self.walls = [LineSegment2(Point2(0,0), Point2(0,self.size[1])),
59 | LineSegment2(Point2(0,self.size[1]), Point2(self.size[0], self.size[1])),
60 | LineSegment2(Point2(self.size[0], self.size[1]), Point2(self.size[0], 0)),
61 | LineSegment2(Point2(self.size[0], 0), Point2(0,0))]
62 |
63 | self.hero = GameObject(Point2(*self.settings["hero_initial_position"]),
64 | Vector2(*self.settings["hero_initial_speed"]),
65 | "hero",
66 | self.settings)
67 | if not self.settings["hero_bounces_off_walls"]:
68 | self.hero.bounciness = 0.0
69 |
70 | self.objects = []
71 | for obj_type, number in settings["num_objects"].items():
72 | for _ in range(number):
73 | self.spawn_object(obj_type)
74 |
75 | self.observation_lines = self.generate_observation_lines()
76 |
77 | self.object_reward = 0
78 | self.collected_rewards = []
79 |
80 |         # every observation line sees one of the object types or a wall, plus
81 |         # two numbers representing the speed of the observed object (if applicable)
82 | self.eye_observation_size = len(self.settings["objects"]) + 3
83 |         # additionally, two numbers for the agent's own speed and two for its position.
84 | self.observation_size = self.eye_observation_size * len(self.observation_lines) + 2 + 2
85 |
86 | self.directions = [Vector2(*d) for d in [[1,0], [0,1], [-1,0],[0,-1],[0.0,0.0]]]
87 | self.num_actions = len(self.directions)
88 |
89 | self.objects_eaten = defaultdict(lambda: 0)
90 |
91 | def perform_action(self, action_id):
92 |         """Damp the hero's speed and accelerate it along one of the predefined directions."""
93 | assert 0 <= action_id < self.num_actions
94 | self.hero.speed *= 0.5
95 | self.hero.speed += self.directions[action_id] * self.settings["delta_v"]
96 |
97 | def spawn_object(self, obj_type):
98 | """Spawn object of a given type and add it to the objects array"""
99 | radius = self.settings["object_radius"]
100 | position = np.random.uniform([radius, radius], np.array(self.size) - radius)
101 | position = Point2(float(position[0]), float(position[1]))
102 | max_speed = np.array(self.settings["maximum_speed"])
103 | speed = np.random.uniform(-max_speed, max_speed).astype(float)
104 | speed = Vector2(float(speed[0]), float(speed[1]))
105 |
106 | self.objects.append(GameObject(position, speed, obj_type, self.settings))
107 |
108 | def step(self, dt):
109 |         """Simulate all the objects for a given amount of time.
110 |
111 | Also resolve collisions with the hero"""
112 | for obj in self.objects + [self.hero] :
113 | obj.step(dt)
114 | self.resolve_collisions()
115 |
116 | def squared_distance(self, p1, p2):
117 | return (p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2
118 |
119 | def resolve_collisions(self):
120 |         """If the hero touches an object it eats it, and the reward gets updated."""
121 | collision_distance = 2 * self.settings["object_radius"]
122 | collision_distance2 = collision_distance ** 2
123 | to_remove = []
124 | for obj in self.objects:
125 | if self.squared_distance(self.hero.position, obj.position) < collision_distance2:
126 | to_remove.append(obj)
127 | for obj in to_remove:
128 | self.objects.remove(obj)
129 | self.objects_eaten[obj.obj_type] += 1
130 | self.object_reward += self.settings["object_reward"][obj.obj_type]
131 | self.spawn_object(obj.obj_type)
132 |
133 | def inside_walls(self, point):
134 | """Check if the point is inside the walls"""
135 | EPS = 1e-4
136 | return (EPS <= point[0] < self.size[0] - EPS and
137 | EPS <= point[1] < self.size[1] - EPS)
138 |
139 | def observe(self):
140 |         """Return the observation vector. For each observation direction it contains a
141 |         representation of the closest object seen by the hero - which might be nothing,
142 |         another object, or a wall. The per-direction representations are concatenated.
143 | """
144 | num_obj_types = len(self.settings["objects"]) + 1 # and wall
145 | max_speed_x, max_speed_y = self.settings["maximum_speed"]
146 |
147 | observable_distance = self.settings["observation_line_length"]
148 |
149 | relevant_objects = [obj for obj in self.objects
150 | if obj.position.distance(self.hero.position) < observable_distance]
151 | # objects sorted from closest to furthest
152 | relevant_objects.sort(key=lambda x: x.position.distance(self.hero.position))
153 |
154 | observation = np.zeros(self.observation_size)
155 | observation_offset = 0
156 | for i, observation_line in enumerate(self.observation_lines):
157 | # shift to hero position
158 | observation_line = LineSegment2(self.hero.position + Vector2(*observation_line.p1),
159 | self.hero.position + Vector2(*observation_line.p2))
160 |
161 | observed_object = None
162 | # if end of observation line is outside of walls, we see the wall.
163 | if not self.inside_walls(observation_line.p2):
164 | observed_object = "**wall**"
165 | for obj in relevant_objects:
166 | if observation_line.distance(obj.position) < self.settings["object_radius"]:
167 | observed_object = obj
168 | break
169 | object_type_id = None
170 | speed_x, speed_y = 0, 0
171 | proximity = 0
172 | if observed_object == "**wall**": # wall seen
173 | object_type_id = num_obj_types - 1
174 | # a wall has fairly low speed...
175 | speed_x, speed_y = 0, 0
176 | # best candidate is intersection between
177 | # observation_line and a wall, that's
178 | # closest to the hero
179 | best_candidate = None
180 | for wall in self.walls:
181 | candidate = observation_line.intersect(wall)
182 | if candidate is not None:
183 | if (best_candidate is None or
184 | best_candidate.distance(self.hero.position) >
185 | candidate.distance(self.hero.position)):
186 | best_candidate = candidate
187 | if best_candidate is None:
188 | # assume it is due to rounding errors
189 | # and wall is barely touching observation line
190 | proximity = observable_distance
191 | else:
192 | proximity = best_candidate.distance(self.hero.position)
193 | elif observed_object is not None: # agent seen
194 | object_type_id = self.settings["objects"].index(observed_object.obj_type)
195 | speed_x, speed_y = tuple(observed_object.speed)
196 |                 intersection_segment = observed_object.as_circle().intersect(observation_line)
197 | assert intersection_segment is not None
198 | try:
199 | proximity = min(intersection_segment.p1.distance(self.hero.position),
200 | intersection_segment.p2.distance(self.hero.position))
201 | except AttributeError:
202 | proximity = observable_distance
203 | for object_type_idx_loop in range(num_obj_types):
204 | observation[observation_offset + object_type_idx_loop] = 1.0
205 | if object_type_id is not None:
206 | observation[observation_offset + object_type_id] = proximity / observable_distance
207 | observation[observation_offset + num_obj_types] = speed_x / max_speed_x
208 | observation[observation_offset + num_obj_types + 1] = speed_y / max_speed_y
209 | assert num_obj_types + 2 == self.eye_observation_size
210 | observation_offset += self.eye_observation_size
211 |
212 | observation[observation_offset] = self.hero.speed[0] / max_speed_x
213 | observation[observation_offset + 1] = self.hero.speed[1] / max_speed_y
214 | observation_offset += 2
215 |
216 |         # add the hero's location, normalized assuming a 700x500 world
217 | observation[observation_offset] = self.hero.position[0] / 350.0 - 1.0
218 | observation[observation_offset + 1] = self.hero.position[1] / 250.0 - 1.0
219 |
220 | assert observation_offset + 2 == self.observation_size
221 |
222 | return observation
223 |
224 | def distance_to_walls(self):
225 |         """Returns the distance from the hero to the closest wall."""
226 | res = float('inf')
227 | for wall in self.walls:
228 | res = min(res, self.hero.position.distance(wall))
229 | return res - self.settings["object_radius"]
230 |
231 | def collect_reward(self):
232 |         """Return the accumulated object-eating reward plus the current wall-proximity penalty."""
233 | wall_reward = self.settings["wall_distance_penalty"] * \
234 | np.exp(-self.distance_to_walls() / self.settings["tolerable_distance_to_wall"])
235 | assert wall_reward < 1e-3, "You are rewarding hero for being close to the wall!"
236 | total_reward = wall_reward + self.object_reward
237 | self.object_reward = 0
238 | self.collected_rewards.append(total_reward)
239 | return total_reward
240 |
241 | def plot_reward(self, smoothing = 30):
242 | """Plot evolution of reward over time."""
243 | plottable = self.collected_rewards[:]
244 | while len(plottable) > 1000:
245 | for i in range(0, len(plottable) - 1, 2):
246 | plottable[i//2] = (plottable[i] + plottable[i+1]) / 2
247 | plottable = plottable[:(len(plottable) // 2)]
248 | x = []
249 | for i in range(smoothing, len(plottable)):
250 | chunk = plottable[i-smoothing:i]
251 | x.append(sum(chunk) / len(chunk))
252 | plt.plot(list(range(len(x))), x)
253 |
254 | def generate_observation_lines(self):
255 | """Generate observation segments in settings["num_observation_lines"] directions"""
256 | result = []
257 | start = Point2(0.0, 0.0)
258 | end = Point2(self.settings["observation_line_length"],
259 | self.settings["observation_line_length"])
260 | for angle in np.linspace(0, 2*np.pi, self.settings["num_observation_lines"], endpoint=False):
261 | rotation = Point2(math.cos(angle), math.sin(angle))
262 | current_start = Point2(start[0] * rotation[0], start[1] * rotation[1])
263 | current_end = Point2(end[0] * rotation[0], end[1] * rotation[1])
264 | result.append( LineSegment2(current_start, current_end))
265 | return result
266 |
267 | def _repr_html_(self):
268 |         return self.to_html()._repr_html_()
269 |
270 | def to_html(self, stats=[]):
271 | """Return svg representation of the simulator"""
272 |
273 | stats = stats[:]
274 | recent_reward = self.collected_rewards[-100:] + [0]
275 | objects_eaten_str = ', '.join(["%s: %s" % (o,c) for o,c in self.objects_eaten.items()])
276 | stats.extend([
277 | "nearest wall = %.1f" % (self.distance_to_walls(),),
278 | "reward = %.1f" % (sum(recent_reward)/len(recent_reward),),
279 | "objects eaten => %s" % (objects_eaten_str,),
280 | ])
281 |
282 | scene = svg.Scene((self.size[0] + 20, self.size[1] + 20 + 20 * len(stats)))
283 | scene.add(svg.Rectangle((10, 10), self.size))
284 |
285 |
286 | for line in self.observation_lines:
287 | scene.add(svg.Line(line.p1 + self.hero.position + Point2(10,10),
288 | line.p2 + self.hero.position + Point2(10,10)))
289 |
290 | for obj in self.objects + [self.hero] :
291 | scene.add(obj.draw())
292 |
293 | offset = self.size[1] + 15
294 | for txt in stats:
295 | scene.add(svg.Text((10, offset + 20), txt, 15))
296 | offset += 20
297 |
298 | return scene
299 |
300 | def setup_draw(self):
301 | """
302 |         An optional method, triggered in simulate(...), that initialises
303 |         the figure handles for rendering.
304 |         simulate(...) will run with or without this method declared in the simulation class.
305 |         As we are using SVG strings in KarpathyGame, it is not currently used.
306 | """
307 | pass
308 |
309 | def draw(self, stats=[]):
310 | """
311 |         An optional method, triggered in simulate(...), that renders the simulated environment.
312 |         It is called once per simulation iteration.
313 |         simulate(...) will run with or without this method declared in the simulation class.
314 | """
315 | clear_output(wait=True)
316 | svg_html = self.to_html(stats)
317 | display(svg_html)
318 |
--------------------------------------------------------------------------------
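The observation vector assembled by observe() has a fixed layout determined by the settings. A small sketch of the bookkeeping with hypothetical settings values (the keys match the ones the class reads; the numbers are made up):

    settings = {
        "objects": ["friend", "enemy"],   # hypothetical object types
        "num_observation_lines": 32,      # hypothetical number of eyes
    }
    num_obj_types = len(settings["objects"]) + 1          # +1 for walls
    eye_observation_size = num_obj_types + 2              # proximity encoding + (speed_x, speed_y)
    observation_size = (eye_observation_size * settings["num_observation_lines"]
                        + 2   # hero speed
                        + 2)  # hero position
    print(observation_size)  # 5 * 32 + 4 = 164 numbers per observation
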
/tf_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def base_name(var):
4 | """Extracts value passed to name= when creating a variable"""
5 | return var.name.split('/')[-1].split(':')[0]
6 |
7 | def copy_variables(variables):
8 | res = {}
9 | for v in variables:
10 | name = base_name(v)
11 | copied_var = tf.Variable(v.initialized_value(), name=name)
12 | res[name] = copied_var
13 | return res
14 |
--------------------------------------------------------------------------------
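copy_variables duplicates a list of variables while preserving their base names, which is how a frozen snapshot of network parameters (e.g. a target network) can be built. A hedged usage sketch with made-up variables, written against the old TF 0.x API this repo uses:

    import tensorflow as tf

    w = tf.Variable(tf.zeros([4, 2]), name="w")
    b = tf.Variable(tf.zeros([2]), name="b")
    snapshot = copy_variables([w, b])
    # snapshot == {"w": <new tf.Variable>, "b": <new tf.Variable>}
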
/tf_rl/utils/event_queue.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from queue import PriorityQueue
4 |
5 | class EqItem(object):
6 |     """Function and scheduled execution timestamp.
7 | 
8 |     This class is needed because if
9 |     we used a tuple instead, Python would occasionally
10 |     complain that it does not know how to compare
11 |     functions."""
12 | def __init__(self, ts, f):
13 | self.ts = ts
14 | self.f = f
15 |
16 | def __lt__(self, other):
17 | return self.ts < other.ts
18 |
19 | def __eq__(self, other):
20 | return self.ts == other.ts
21 |
22 | class EventQueue(object):
23 | def __init__(self):
24 | """Event queue for executing events at
25 | specific timepoints.
26 |
27 | In current form it is NOT thread safe."""
28 | self.q = PriorityQueue()
29 |
30 | def schedule(self, f, ts):
31 |         """Schedule f to be executed at time ts."""
32 | self.q.put(EqItem(ts, f))
33 |
34 | def schedule_recurring(self, f, interval):
35 | """Schedule f to be run every interval seconds.
36 |
37 | It will be run for the first time interval seconds
38 | from now"""
39 |         def recurring_f():
40 |             f()
41 |             self.schedule(recurring_f, time.time() + interval)
42 |         self.schedule(recurring_f, time.time() + interval)
43 |
44 |
45 | def run(self):
46 |         """Execute events in the queue as close to their scheduled times as possible."""
47 | while True:
48 | event = self.q.get()
49 | now = time.time()
50 | if now < event.ts:
51 | time.sleep(event.ts - now)
52 | event.f()
53 |
54 |
--------------------------------------------------------------------------------
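A minimal usage sketch (not from the repo): schedule a recurring callback, then process the queue on the current thread. Note that run() loops forever, sleeping until each event's timestamp before firing it:

    eq = EventQueue()
    eq.schedule_recurring(lambda: print("tick"), interval=1.0)  # fires every ~1 s
    eq.run()  # blocks: there is no termination condition
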
/tf_rl/utils/geometry.py:
--------------------------------------------------------------------------------
1 | """
2 | This module assumes that all geometrical points are
3 | represented as 1D numpy arrays.
4 |
5 | It was designed and tested on 2D points,
6 | but if you try it on 3D points you may
7 | be pleasantly surprised ;-)
8 | """
9 | import numpy as np
10 |
11 |
12 | def point_distance(x, y):
13 |     """Returns the Euclidean distance between points x and y."""
14 | return np.linalg.norm(x-y)
15 |
16 | def point_projected_on_line(line_s, line_e, point):
17 |     """Project point onto the line that goes through line_s and line_e.
18 | 
19 |     Assumes line_e is neither equal nor very close to line_s.
20 |     """
21 | line_along = line_e - line_s
22 |
23 | transformed_point = point - line_s
24 |
25 | point_dot_line = np.dot(transformed_point, line_along)
26 | line_along_norm = np.dot(line_along, line_along)
27 |
28 | transformed_projection = (point_dot_line / line_along_norm) * line_along
29 |
30 | return transformed_projection + line_s
31 |
32 | def point_segment_distance(segment_s, segment_e, point):
33 |     """Returns the distance from point to the closest point on the segment
34 |     connecting segment_s and segment_e."""
35 | projected = point_projected_on_line(segment_s, segment_e, point)
36 | if np.isclose(point_distance(segment_s, projected) + point_distance(projected, segment_e),
37 | point_distance(segment_s, segment_e)):
38 | # projected on segment
39 | return point_distance(point, projected)
40 | else:
41 | return min(point_distance(point, segment_s), point_distance(point, segment_e))
42 |
--------------------------------------------------------------------------------
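A worked example of point_segment_distance (values chosen for illustration): for the unit segment on the x-axis, a point above the midpoint projects onto the segment, while a point past an endpoint falls back to the nearer endpoint:

    import numpy as np

    s, e = np.array([0., 0.]), np.array([1., 0.])
    print(point_segment_distance(s, e, np.array([0.5, 1.0])))  # 1.0, projection (0.5, 0) is on the segment
    print(point_segment_distance(s, e, np.array([2.0, 0.0])))  # 1.0, endpoint e is the closest point
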
/tf_rl/utils/getch.py:
--------------------------------------------------------------------------------
1 | class _Getch:
2 | """Gets a single character from standard input. Does not echo to the
3 | screen."""
4 | def __init__(self):
5 | try:
6 | self.impl = _GetchWindows()
7 | except ImportError:
8 | self.impl = _GetchUnix()
9 |
10 | def __call__(self): return self.impl()
11 |
12 |
13 | class _GetchUnix:
14 | def __init__(self):
15 | import tty, sys
16 |
17 | def __call__(self):
18 | import sys, tty, termios
19 | fd = sys.stdin.fileno()
20 | old_settings = termios.tcgetattr(fd)
21 | try:
22 | tty.setraw(sys.stdin.fileno())
23 | ch = sys.stdin.read(1)
24 | finally:
25 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
26 | return ch
27 |
28 |
29 | class _GetchWindows:
30 | def __init__(self):
31 | import msvcrt
32 |
33 | def __call__(self):
34 | import msvcrt
35 | return msvcrt.getch()
36 |
37 |
38 | getch = _Getch()
39 |
--------------------------------------------------------------------------------
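Usage is a single blocking call (a sketch; note that on Windows msvcrt.getch() returns bytes rather than str):

    key = getch()  # blocks until one keypress, without echoing it
    if key == 'q':
        pass       # e.g. quit a manual-control loop
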
/tf_rl/utils/svg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """\
3 | SVG.py - Construct/display SVG scenes.
4 |
5 | The following code is a lightweight wrapper around SVG files. The metaphor
6 | is to construct a scene, add objects to it, and then write it to a file
7 | to display it.
8 |
9 | In this trimmed-down copy, scenes are rendered inline in Jupyter notebooks
10 | via the _repr_html_ hook, or written to a file with write_svg.
11 | """
12 |
13 | import os
14 |
15 | def colorstr(rgb):
16 | if type(rgb) == tuple:
17 | return "#%02x%02x%02x" % rgb
18 | else:
19 | return rgb
20 |
21 | def compute_style(style):
22 | color = style.get("color")
23 | style_str = []
24 | if color is None:
25 | color="none"
26 | style_str.append('fill:%s;' % (colorstr(color),))
27 |
28 | style_str = 'style="%s"' % (';'.join(style_str),)
29 | return style_str
30 |
31 | class Scene:
32 | def __init__(self, size=(400,400)):
33 | self.items = []
34 | self.size = size
35 |
36 | def add(self,item):
37 | self.items.append(item)
38 |
39 | def strarray(self):
40 |         var = [
41 |             "<svg width=\"%d\" height=\"%d\">\n" % (self.size[0], self.size[1]),
42 |             " <g style=\"fill-opacity:1.0; stroke:black; stroke-width:1;\">\n"]
43 |         # append every item's svg fragment, then close the group and svg tags
44 |         for item in self.items:
45 |             var += item.strarray()
46 |         var += [" </g>\n",
47 |                 "</svg>\n"]
48 |         return var
49 |
50 | def write_svg(self, file):
51 | file.writelines(self.strarray())
52 |
53 | def _repr_html_(self):
54 | return '\n'.join(self.strarray())
55 |
56 | class Line:
57 | def __init__(self,start,end):
58 | self.start = start #xy tuple
59 | self.end = end #xy tuple
60 |
61 | def strarray(self):
62 |         return ["  <line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" />\n" %\
63 |                 (self.start[0],self.start[1],self.end[0],self.end[1])]
64 |
65 |
66 | class Circle:
67 | def __init__(self,center,radius, **style_kwargs):
68 | self.center = center
69 | self.radius = radius
70 | self.style_kwargs = style_kwargs
71 |
72 | def strarray(self):
73 | style_str = compute_style(self.style_kwargs)
74 |
75 |         return [
76 |             "  <circle cx=\"%d\" cy=\"%d\" r=\"%d\"\n" % (self.center[0], self.center[1], self.radius),
77 |             "    %s />\n" % (style_str,)
78 |         ]
79 |
80 | class Rectangle:
81 | def __init__(self, origin, size, **style_kwargs):
82 | self.origin = origin
83 | self.size = size
84 | self.style_kwargs = style_kwargs
85 |
86 | def strarray(self):
87 | style_str = compute_style(self.style_kwargs)
88 |
89 |         return [
90 |             "  <rect x=\"%d\" y=\"%d\" height=\"%d\"\n" % (self.origin[0], self.origin[1], self.size[1]),
91 |             "    width=\"%d\" %s />\n" % (self.size[0], style_str)
92 |         ]
93 |
94 | class Text:
95 | def __init__(self,origin,text,size=24):
96 | self.origin = origin
97 | self.text = text
98 | self.size = size
99 | return
100 |
101 | def strarray(self):
102 |         return ["  <text x=\"%d\" y=\"%d\" font-size=\"%d\">\n" %\
103 |                 (self.origin[0],self.origin[1],self.size),
104 |                 "   %s\n" % self.text,
105 |                 "  </text>\n"]
106 |
107 |
108 |
109 |
110 | def test():
111 | scene = Scene()
112 | scene.add(Rectangle((100,100),(200,200), **{"color":(0,255,255)} ))
113 | scene.add(Line((200,200),(200,300)))
114 | scene.add(Line((200,200),(300,200)))
115 | scene.add(Line((200,200),(100,200)))
116 | scene.add(Line((200,200),(200,100)))
117 | scene.add(Circle((200,200),30, **{"color":(0,0,255)} ))
118 | scene.add(Circle((200,300),30, **{"color":(0,255,0)} ))
119 | scene.add(Circle((300,200),30, **{"color":(255,0,0)} ))
120 | scene.add(Circle((100,200),30, **{"color":(255,255,0)} ))
121 | scene.add(Circle((200,100),30, **{"color":(255,0,255)} ))
122 | scene.add(Text((50,50),"Testing SVG"))
123 | with open("test.svg", "w") as f:
124 | scene.write_svg(f)
125 |
126 | if __name__ == '__main__':
127 | test()
128 |
--------------------------------------------------------------------------------
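Because Scene implements _repr_html_, a scene also renders inline in a Jupyter notebook; a small sketch using the classes above:

    from IPython.display import display

    scene = Scene(size=(120, 80))
    scene.add(Rectangle((0, 0), (120, 80)))
    scene.add(Circle((60, 40), 20, color=(255, 0, 0)))
    display(scene)  # IPython calls scene._repr_html_() and shows the SVG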