├── .DS_Store
├── .gitattributes
├── ckpt
├── model-3700000
├── model-3800000
└── model-3808836
├── docs
└── images
│ └── q-learning.png
├── .ipynb_checkpoints
├── tmp-checkpoint.ipynb
├── main-shallow-network-checkpoint.ipynb
├── dqn-named-tuple-checkpoint.ipynb
├── dqn-checkpoint.ipynb
└── main-checkpoint.ipynb
├── README.md
└── dqn.ipynb
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidreiman/pytorch-atari-dqn/HEAD/.DS_Store
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/ckpt/model-3700000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidreiman/pytorch-atari-dqn/HEAD/ckpt/model-3700000
--------------------------------------------------------------------------------
/ckpt/model-3800000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidreiman/pytorch-atari-dqn/HEAD/ckpt/model-3800000
--------------------------------------------------------------------------------
/ckpt/model-3808836:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidreiman/pytorch-atari-dqn/HEAD/ckpt/model-3808836
--------------------------------------------------------------------------------
/docs/images/q-learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidreiman/pytorch-atari-dqn/HEAD/docs/images/q-learning.png
--------------------------------------------------------------------------------
/.ipynb_checkpoints/tmp-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ------------
6 |
7 | ## Built With
8 |
9 | * [PyTorch](https://github.com/pytorch/pytorch) - Tensors and dynamic neural networks in Python with strong GPU acceleration
10 | * [OpenAI Gym](https://github.com/openai/gym) - A toolkit for developing and comparing reinforcement learning algorithms
11 |
12 | ## Authors
13 |
14 | * **David Reiman** - [davidreiman](https://github.com/davidreiman)
15 |
16 | ## References
17 |
18 | * [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/pdf/1312.5602v1.pdf)
19 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/main-shallow-network-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import os \n",
14 | "import gym\n",
15 | "import time\n",
16 | "import copy\n",
17 | "import random\n",
18 | "import numpy as np\n",
19 | "\n",
20 | "import torch\n",
21 | "import torchvision\n",
22 | "import torch.nn as nn\n",
23 | "\n",
24 | "from IPython import display\n",
25 | "from collections import deque\n",
26 | "from skimage.color import rgb2grey\n",
27 | "from skimage.transform import rescale\n",
28 | "from matplotlib import pyplot as plt\n",
29 | "from tqdm import tqdm_notebook as tqdm"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {
36 | "collapsed": true
37 | },
38 | "outputs": [],
39 | "source": [
40 | "plt.style.use('seaborn')"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "collapsed": true
48 | },
49 | "outputs": [],
50 | "source": [
51 | "class DeepQNetwork(nn.Module):\n",
52 | " def __init__(self, num_frames, num_actions):\n",
53 | " super(DeepQNetwork, self).__init__()\n",
54 | " self.num_frames = num_frames\n",
55 | " self.num_actions = num_actions\n",
56 | " \n",
57 | " # Layers\n",
58 | " self.conv1 = nn.Conv2d(\n",
59 | " in_channels=num_frames,\n",
60 | " out_channels=16,\n",
61 | " kernel_size=8,\n",
62 | " stride=4,\n",
63 | " padding=2\n",
64 | " )\n",
65 | " self.conv2 = nn.Conv2d(\n",
66 | " in_channels=16,\n",
67 | " out_channels=32,\n",
68 | " kernel_size=4,\n",
69 | " stride=2,\n",
70 | " padding=1\n",
71 | " )\n",
72 | " self.fc1 = nn.Linear(\n",
73 | " in_features=3200,\n",
74 | " out_features=256,\n",
75 | " )\n",
76 | " self.fc2 = nn.Linear(\n",
77 | " in_features=256,\n",
78 | " out_features=num_actions,\n",
79 | " )\n",
80 | " \n",
81 | " # Activations\n",
82 | " self.relu = nn.ReLU()\n",
83 | " \n",
84 | " def flatten(self, x):\n",
85 | " batch_size = x.size()[0]\n",
86 | " x = x.view(batch_size, -1)\n",
87 | " return x\n",
88 | " \n",
89 | " def forward(self, x):\n",
90 | " \n",
91 | " # Forward pass\n",
92 | " x = self.relu(self.conv1(x)) # In: (80, 80, 4), Out: (20, 20, 16)\n",
93 | " x = self.relu(self.conv2(x)) # In: (20, 20, 16), Out: (10, 10, 32)\n",
94 | " x = self.flatten(x) # In: (10, 10, 32), Out: (3200,)\n",
95 | " x = self.relu(self.fc1(x)) # In: (3200,), Out: (256,)\n",
96 | " x = self.fc2(x) # In: (256,), Out: (4,)\n",
97 | " \n",
98 | " return x"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {
105 | "collapsed": true
106 | },
107 | "outputs": [],
108 | "source": [
109 | "def output_size(w, k, s, p):\n",
110 | " return ((w - k + 2*p)/s) + 1"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {
117 | "collapsed": true
118 | },
119 | "outputs": [],
120 | "source": [
121 | "class Agent:\n",
122 | " def __init__(self, model, memory_depth, gamma, epsilon_i, epsilon_f, anneal_time):\n",
123 | " \n",
124 | " self.cuda = True if torch.cuda.is_available() else False\n",
125 | " self.to_tensor = torch.cuda.FloatTensor if self.cuda else torch.FloatTensor\n",
126 | " self.to_byte_tensor = torch.cuda.ByteTensor if self.cuda else torch.ByteTensor\n",
127 | " \n",
128 | " self.model = model\n",
129 | " self.memory_depth = memory_depth\n",
130 | " self.gamma = self.to_tensor([gamma])\n",
131 | " self.e_i = epsilon_i\n",
132 | " self.e_f = epsilon_f\n",
133 | " self.anneal_time = anneal_time\n",
134 | " \n",
135 | " self.memory = deque(maxlen=memory_depth)\n",
136 | " self.clone()\n",
137 | " \n",
138 | " self.loss = nn.MSELoss()\n",
139 | " self.opt = torch.optim.RMSprop(self.model.parameters(), lr=2.5e-4)\n",
140 | " \n",
141 | " def clone(self):\n",
142 | " self.clone_model = copy.deepcopy(self.model)\n",
143 | " \n",
144 | " for p in self.clone_model.parameters():\n",
145 | " p.requires_grad = False\n",
146 | " \n",
147 | " def remember(self, state, action, reward, terminal, next_state):\n",
148 | " state, next_state = state.data.numpy(), next_state.data.numpy()\n",
149 | " state, next_state = 255.*state, 255.*next_state\n",
150 | " state, next_state = state.astype(np.uint8), next_state.astype(np.uint8)\n",
151 | " self.memory.append([state, action, reward, terminal, next_state])\n",
152 | " \n",
153 | " def retrieve(self, batch_size):\n",
154 | " # Note: Use lists for inhomogenous data!\n",
155 | " \n",
156 | " if batch_size > self.memories:\n",
157 | " batch_size = self.memories\n",
158 | " \n",
159 | " batch = random.sample(self.memory, batch_size)\n",
160 | " \n",
161 | " state = np.concatenate([batch[i][0] for i in range(batch_size)]).astype(np.int64)\n",
162 | " action = np.array([batch[i][1] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
163 | " reward = np.array([batch[i][2] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
164 | " terminal = np.array([batch[i][3] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
165 | " next_state = np.concatenate([batch[i][4] for i in range(batch_size)]).astype(np.int64)\n",
166 | " \n",
167 | " state = self.to_tensor(state/255.)\n",
168 | "        next_state = self.to_tensor(next_state/255.)\n",
169 | " reward = self.to_tensor(reward)\n",
170 | " terminal = self.to_byte_tensor(terminal)\n",
171 | "\n",
172 | " return state, action, reward, terminal, next_state\n",
173 | " \n",
174 | " @property\n",
175 | " def memories(self):\n",
176 | " return len(self.memory)\n",
177 | " \n",
178 | " def act(self, state):\n",
179 | " q_values = self.model(state).detach()\n",
180 | " action = np.argmax(q_values.numpy())\n",
181 | " return action\n",
182 | " \n",
183 | " def process(self, state):\n",
184 | " state = rgb2grey(state[35:195, :, :])\n",
185 | " state = rescale(state, scale=0.5)\n",
186 | " state = state[np.newaxis, np.newaxis, :, :]\n",
187 | " return self.to_tensor(state)\n",
188 | " \n",
189 | " def exploration_rate(self, t):\n",
190 | " if t < self.anneal_time:\n",
191 | " return self.e_i - t*(self.e_i - self.e_f)/self.anneal_time\n",
192 | " elif t >= self.anneal_time:\n",
193 | " return self.e_f\n",
194 | " \n",
195 | " def huber_loss(self, x, y):\n",
196 | " error = x - y\n",
197 | " quadratic = 0.5 * error**2\n",
198 | " linear = np.absolute(error) - 0.5\n",
199 | " \n",
200 | " is_quadratic = (np.absolute(error) <= 1)\n",
201 | " \n",
202 | " return is_quadratic*quadratic + ~is_quadratic*linear\n",
203 | " \n",
204 | " def save(self, t, savedir=\"\"):\n",
205 | " save_path = os.path.join(savedir, 'model-{}'.format(t))\n",
206 | " torch.save(self.model.state_dict(), save_path)\n",
207 | " \n",
208 | " def update(self, batch_size, verbose=False):\n",
209 | " \n",
210 | " start = time.time()\n",
211 | " state, action, reward, terminal, next_state = self.retrieve(batch_size)\n",
212 | " \n",
213 | " if verbose:\n",
214 | " print(\"Sampled memory in {:0.2f} seconds.\".format(time.time() - start))\n",
215 | " \n",
216 | " start = time.time()\n",
217 | " \n",
218 | " q = self.model(state)[range(batch_size), action.flatten()][:, None]\n",
219 | " qmax = self.clone_model(next_state).max(dim=1)[0][:, None]\n",
220 | " \n",
221 | " nonterminal_target = reward + self.gamma*qmax\n",
222 | " terminal_target = reward\n",
223 | " \n",
224 | " target = terminal.float()*terminal_target + (~terminal).float()*nonterminal_target\n",
225 | " \n",
226 | " loss = self.loss(q, target)\n",
227 | " \n",
228 | "        self.opt.zero_grad()\n        loss.backward()\n",
229 | " self.opt.step()\n",
230 | " \n",
231 | " if verbose:\n",
232 | " print(\"Updated parameters in {:0.2f} seconds.\".format(time.time() - start))"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {
239 | "collapsed": true
240 | },
241 | "outputs": [],
242 | "source": [
243 | "# Hyperparameters\n",
244 | "\n",
245 | "batch_size = 32\n",
246 | "update_interval = 32\n",
247 | "clone_interval = 128\n",
248 | "save_interval = int(1e5)\n",
249 | "frame_skip = 4\n",
250 | "num_frames = 4\n",
251 | "num_actions = 4\n",
252 | "episodes = 10000\n",
253 | "memory_depth = int(1e5)\n",
254 | "epsilon_i = 1.0\n",
255 | "epsilon_f = 0.1\n",
256 | "anneal_time = 1000000\n",
257 | "burn_in = 50000\n",
258 | "gamma = 0.99"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": [
269 | "model = DeepQNetwork(num_frames, num_actions)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {
276 | "collapsed": true
277 | },
278 | "outputs": [],
279 | "source": [
280 | "agent = Agent(model, memory_depth, gamma, epsilon_i, epsilon_f, anneal_time)"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {
287 | "collapsed": true
288 | },
289 | "outputs": [],
290 | "source": [
291 | "env = gym.make('Breakout-v0')"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {
298 | "collapsed": true
299 | },
300 | "outputs": [],
301 | "source": [
302 | "def q_iteration(episodes, plot=True, render=True, verbose=False):\n",
303 | " \n",
304 | " t = 0\n",
305 | " metadata = dict(episode=[], reward=[])\n",
306 | " \n",
307 | " progress_bar = tqdm(range(episodes))\n",
308 | " \n",
309 | " for episode in progress_bar:\n",
310 | " \n",
311 | " state = env.reset()\n",
312 | " state = agent.process(state)\n",
313 | " \n",
314 | " done = False\n",
315 | " total_reward = 0\n",
316 | "\n",
317 | " while not done:\n",
318 | " \n",
319 | " if render:\n",
320 | " env.render()\n",
321 | " \n",
322 | " while state.size()[1] < num_frames:\n",
323 | " action = np.random.choice(num_actions)\n",
324 | " \n",
325 | " new_frame, reward, done, info = env.step(action)\n",
326 | " new_frame = agent.process(new_frame)\n",
327 | " \n",
328 | " state = torch.cat([state, new_frame], 1)\n",
329 | " \n",
330 | " if np.random.uniform() < agent.exploration_rate(t) or t < burn_in:\n",
331 | " action = np.random.choice(num_actions)\n",
332 | "\n",
333 | " else:\n",
334 | " action = agent.act(state)\n",
335 | " \n",
336 | " new_frame, reward, done, info = env.step(action)\n",
337 | " new_frame = agent.process(new_frame)\n",
338 | " \n",
339 | " new_state = torch.cat([state, new_frame], 1)\n",
340 | " new_state = new_state[:, 1:, :, :]\n",
341 | "\n",
342 | " agent.remember(state, action, reward, done, new_state)\n",
343 | "\n",
344 | " state = new_state\n",
345 | " total_reward += reward\n",
346 | " t += 1\n",
347 | " \n",
348 | " if t % update_interval == 0 and t > burn_in:\n",
349 | " agent.update(batch_size, verbose=verbose)\n",
350 | " \n",
351 | " if t % clone_interval == 0 and t > burn_in:\n",
352 | " agent.clone()\n",
353 | " \n",
354 | " if t % save_interval == 0 and t > burn_in:\n",
355 | " agent.save(t)\n",
356 | " \n",
357 | " if t % 1000 == 0:\n",
358 | " progress_bar.set_description(\"t = {}\".format(t))\n",
359 | " \n",
360 | " metadata['episode'].append(episode)\n",
361 | " metadata['reward'].append(total_reward)\n",
362 | " \n",
363 | " if plot:\n",
364 | " plt.scatter(episode, total_reward)\n",
365 | " plt.xlim(0, episodes)\n",
366 | " plt.xlabel(\"Episode\")\n",
367 | " plt.ylabel(\"Return\")\n",
368 | " display.clear_output(wait=True)\n",
369 | " display.display(plt.gcf())\n",
370 | " \n",
371 | " return metadata"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {
378 | "scrolled": false
379 | },
380 | "outputs": [],
381 | "source": [
382 | "metadata = q_iteration(episodes, plot=True, render=True)"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {
389 | "collapsed": true
390 | },
391 | "outputs": [],
392 | "source": []
393 | }
394 | ],
395 | "metadata": {
396 | "kernelspec": {
397 | "display_name": "Python 3",
398 | "language": "python",
399 | "name": "python3"
400 | },
401 | "language_info": {
402 | "codemirror_mode": {
403 | "name": "ipython",
404 | "version": 3
405 | },
406 | "file_extension": ".py",
407 | "mimetype": "text/x-python",
408 | "name": "python",
409 | "nbconvert_exporter": "python",
410 | "pygments_lexer": "ipython3",
411 | "version": "3.6.1"
412 | }
413 | },
414 | "nbformat": 4,
415 | "nbformat_minor": 2
416 | }
417 |
--------------------------------------------------------------------------------
/dqn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "\n",
11 | "import os\n",
12 | "import re\n",
13 | "import gym\n",
14 | "import time\n",
15 | "import copy\n",
16 | "import random\n",
17 | "import warnings\n",
18 | "import numpy as np\n",
19 | "\n",
20 | "import torch\n",
21 | "import torchvision\n",
22 | "import torch.nn as nn\n",
23 | "\n",
24 | "from IPython import display\n",
25 | "from skimage.color import rgb2grey\n",
26 | "from skimage.transform import rescale\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm import tqdm_notebook as tqdm\n",
29 | "from collections import deque, namedtuple"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "plt.style.use('seaborn')\n",
39 | "warnings.filterwarnings('ignore')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "class DeepQNetwork(nn.Module):\n",
49 | " def __init__(self, num_frames, num_actions):\n",
50 | " super(DeepQNetwork, self).__init__()\n",
51 | " self.num_frames = num_frames\n",
52 | " self.num_actions = num_actions\n",
53 | " \n",
54 | " # Layers\n",
55 | " self.conv1 = nn.Conv2d(\n",
56 | " in_channels=num_frames,\n",
57 | " out_channels=16,\n",
58 | " kernel_size=8,\n",
59 | " stride=4,\n",
60 | " padding=2\n",
61 | " )\n",
62 | " self.conv2 = nn.Conv2d(\n",
63 | " in_channels=16,\n",
64 | " out_channels=32,\n",
65 | " kernel_size=4,\n",
66 | " stride=2,\n",
67 | " padding=1\n",
68 | " )\n",
69 | " self.fc1 = nn.Linear(\n",
70 | " in_features=3200,\n",
71 | " out_features=256,\n",
72 | " )\n",
73 | " self.fc2 = nn.Linear(\n",
74 | " in_features=256,\n",
75 | " out_features=num_actions,\n",
76 | " )\n",
77 | " \n",
78 | " # Activation Functions\n",
79 | " self.relu = nn.ReLU()\n",
80 | " \n",
81 | " def flatten(self, x):\n",
82 | " batch_size = x.size()[0]\n",
83 | " x = x.view(batch_size, -1)\n",
84 | " return x\n",
85 | " \n",
86 | " def forward(self, x):\n",
87 | " \n",
88 | " # Forward pass\n",
89 | " x = self.relu(self.conv1(x)) # In: (80, 80, 4) Out: (20, 20, 16)\n",
90 | " x = self.relu(self.conv2(x)) # In: (20, 20, 16) Out: (10, 10, 32)\n",
91 | " x = self.flatten(x) # In: (10, 10, 32) Out: (3200,)\n",
92 | " x = self.relu(self.fc1(x)) # In: (3200,) Out: (256,)\n",
93 | " x = self.fc2(x) # In: (256,) Out: (4,)\n",
94 | " \n",
95 | " return x"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 4,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "Transition = namedtuple('Transition', ['state', 'action', 'reward', 'terminal', 'next_state'])"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 5,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "class Agent:\n",
114 | " def __init__(self, model, memory_depth, lr, gamma, epsilon_i, epsilon_f, anneal_time, ckptdir):\n",
115 | " \n",
116 | " self.cuda = True if torch.cuda.is_available() else False\n",
117 | " \n",
118 | " self.model = model\n",
119 | " self.device = torch.device(\"cuda\" if self.cuda else \"cpu\")\n",
120 | " \n",
121 | " if self.cuda:\n",
122 | " self.model = self.model.cuda()\n",
123 | " \n",
124 | " self.memory_depth = memory_depth\n",
125 | " self.gamma = torch.tensor([gamma], device=self.device)\n",
126 | " self.e_i = epsilon_i\n",
127 | " self.e_f = epsilon_f\n",
128 | " self.anneal_time = anneal_time\n",
129 | " self.ckptdir = ckptdir\n",
130 | " \n",
131 | " if not os.path.isdir(ckptdir):\n",
132 | " os.makedirs(ckptdir)\n",
133 | " \n",
134 | " self.memory = deque(maxlen=memory_depth)\n",
135 | " self.clone()\n",
136 | " \n",
137 | " self.loss = nn.SmoothL1Loss()\n",
138 | " self.opt = torch.optim.RMSprop(self.model.parameters(), lr=lr, alpha=0.95, eps=0.01)\n",
139 | " \n",
140 | " def clone(self):\n",
141 | " try:\n",
142 | " del self.clone_model\n",
143 | "        except AttributeError:\n",
144 | " pass\n",
145 | " \n",
146 | " self.clone_model = copy.deepcopy(self.model)\n",
147 | " \n",
148 | " for p in self.clone_model.parameters():\n",
149 | " p.requires_grad = False\n",
150 | " \n",
151 | " if self.cuda:\n",
152 | " self.clone_model = self.clone_model.cuda()\n",
153 | " \n",
154 | " def remember(self, *args):\n",
155 | " self.memory.append(Transition(*args))\n",
156 | " \n",
157 | " def retrieve(self, batch_size):\n",
158 | " transitions = random.sample(self.memory, batch_size)\n",
159 | " batch = Transition(*zip(*transitions))\n",
160 | " state, action, reward, terminal, next_state = map(torch.cat, [*batch])\n",
161 | " return state, action, reward, terminal, next_state\n",
162 | " \n",
163 | " @property\n",
164 | " def memories(self):\n",
165 | " return len(self.memory)\n",
166 | " \n",
167 | " def act(self, state):\n",
168 | " q_values = self.model(state).detach()\n",
169 | " action = torch.argmax(q_values)\n",
170 | " return action.item()\n",
171 | " \n",
172 | " def process(self, state):\n",
173 | " state = rgb2grey(state[35:195, :, :])\n",
174 | " state = rescale(state, scale=0.5)\n",
175 | " state = state[np.newaxis, np.newaxis, :, :]\n",
176 | " return torch.tensor(state, device=self.device, dtype=torch.float)\n",
177 | " \n",
178 | " def exploration_rate(self, t):\n",
179 | " if 0 <= t < self.anneal_time:\n",
180 | " return self.e_i - t*(self.e_i - self.e_f)/self.anneal_time\n",
181 | " elif t >= self.anneal_time:\n",
182 | " return self.e_f\n",
183 | " elif t < 0:\n",
184 | " return self.e_i\n",
185 | " \n",
186 | " def save(self, t):\n",
187 | " save_path = os.path.join(self.ckptdir, 'model-{}'.format(t))\n",
188 | " torch.save(self.model.state_dict(), save_path)\n",
189 | " \n",
190 | " def load(self):\n",
191 | " ckpts = [file for file in os.listdir(self.ckptdir) if 'model' in file]\n",
192 | " steps = [int(re.search('\\d+', file).group(0)) for file in ckpts]\n",
193 | " \n",
194 | " latest_ckpt = ckpts[np.argmax(steps)]\n",
195 | " self.t = np.max(steps)\n",
196 | " \n",
197 | " print(\"Loading checkpoint: {}\".format(latest_ckpt))\n",
198 | " \n",
199 | " self.model.load_state_dict(torch.load(os.path.join(self.ckptdir, latest_ckpt)))\n",
200 | " \n",
201 | " def update(self, batch_size):\n",
202 | " self.model.zero_grad()\n",
203 | "\n",
204 | " state, action, reward, terminal, next_state = self.retrieve(batch_size)\n",
205 | " q = self.model(state).gather(1, action.view(batch_size, 1))\n",
206 | " qmax = self.clone_model(next_state).max(dim=1)[0]\n",
207 | " \n",
208 | " nonterminal_target = reward + self.gamma*qmax\n",
209 | " terminal_target = reward\n",
210 | " \n",
211 | " target = terminal.float()*terminal_target + (~terminal).float()*nonterminal_target\n",
212 | " \n",
213 | " loss = self.loss(q.view(-1), target)\n",
214 | " loss.backward()\n",
215 | " self.opt.step()\n",
216 | "\n",
217 | " def play(self, episodes, train=False, load=False, plot=False, render=False, verbose=False):\n",
218 | " \n",
219 | " self.t = 0\n",
220 | " metadata = dict(episode=[], reward=[])\n",
221 | " \n",
222 | " if load:\n",
223 | " self.load()\n",
224 | "\n",
225 | " try:\n",
226 | " progress_bar = tqdm(range(episodes), unit='episode')\n",
227 | " \n",
228 | " i = 0\n",
229 | " for episode in progress_bar:\n",
230 | "\n",
231 | " state = env.reset()\n",
232 | " state = self.process(state)\n",
233 | " \n",
234 | " done = False\n",
235 | " total_reward = 0\n",
236 | "\n",
237 | " while not done:\n",
238 | "\n",
239 | " if render:\n",
240 | " env.render()\n",
241 | "\n",
242 | " while state.size()[1] < num_frames:\n",
243 | " action = 1 # Fire\n",
244 | "\n",
245 | " new_frame, reward, done, info = env.step(action)\n",
246 | " new_frame = self.process(new_frame)\n",
247 | "\n",
248 | " state = torch.cat([state, new_frame], 1)\n",
249 | " \n",
250 | " if train and np.random.uniform() < self.exploration_rate(self.t-burn_in):\n",
251 | " action = np.random.choice(num_actions)\n",
252 | "\n",
253 | " else:\n",
254 | " action = self.act(state)\n",
255 | "\n",
256 | " new_frame, reward, done, info = env.step(action)\n",
257 | " new_frame = self.process(new_frame)\n",
258 | "\n",
259 | " new_state = torch.cat([state, new_frame], 1)\n",
260 | " new_state = new_state[:, 1:, :, :]\n",
261 | " \n",
262 | " if train:\n",
263 | " reward = torch.tensor([reward], device=self.device, dtype=torch.float)\n",
264 | " action = torch.tensor([action], device=self.device, dtype=torch.long)\n",
265 | " done = torch.tensor([done], device=self.device, dtype=torch.uint8)\n",
266 | " \n",
267 | " self.remember(state, action, reward, done, new_state)\n",
268 | "\n",
269 | " state = new_state\n",
270 | "                    total_reward += float(reward)\n",
271 | " self.t += 1\n",
272 | " i += 1\n",
273 | " \n",
274 | " if not train:\n",
275 | " time.sleep(0.1)\n",
276 | " \n",
277 | " if train and self.t > burn_in and i > batch_size:\n",
278 | "\n",
279 | " if self.t % update_interval == 0:\n",
280 | " self.update(batch_size)\n",
281 | "\n",
282 | " if self.t % clone_interval == 0:\n",
283 | " self.clone()\n",
284 | "\n",
285 | " if self.t % save_interval == 0:\n",
286 | " self.save(self.t)\n",
287 | "\n",
288 | " if self.t % 1000 == 0:\n",
289 | " progress_bar.set_description(\"t = {}\".format(self.t))\n",
290 | "\n",
291 | " metadata['episode'].append(episode)\n",
292 | " metadata['reward'].append(total_reward)\n",
293 | "\n",
294 | " if episode % 100 == 0 and episode != 0:\n",
295 | " avg_return = np.mean(metadata['reward'][-100:])\n",
296 | " print(\"Average return (last 100 episodes): {:.2f}\".format(avg_return))\n",
297 | "\n",
298 | " if plot:\n",
299 | " plt.scatter(metadata['episode'], metadata['reward'])\n",
300 | " plt.xlim(0, episodes)\n",
301 | " plt.xlabel(\"Episode\")\n",
302 | " plt.ylabel(\"Return\")\n",
303 | " display.clear_output(wait=True)\n",
304 | " display.display(plt.gcf())\n",
305 | " \n",
306 | " env.close()\n",
307 | " return metadata\n",
308 | "\n",
309 | " except KeyboardInterrupt:\n",
310 | " if train:\n",
311 | " print(\"Saving model before quitting...\")\n",
312 | " self.save(self.t)\n",
313 | " \n",
314 | " env.close()\n",
315 | " return metadata"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 6,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "# Hyperparameters\n",
325 | "\n",
326 | "batch_size = 32\n",
327 | "update_interval = 4\n",
328 | "clone_interval = int(1e4)\n",
329 | "save_interval = int(1e5)\n",
330 | "frame_skip = None\n",
331 | "num_frames = 4\n",
332 | "num_actions = 4\n",
333 | "episodes = int(1e5)\n",
334 | "memory_depth = int(1e5)\n",
335 | "epsilon_i = 1.0\n",
336 | "epsilon_f = 0.1\n",
337 | "anneal_time = int(1e6)\n",
338 | "burn_in = int(5e4)\n",
339 | "gamma = 0.99\n",
340 | "learning_rate = 2.5e-4"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 7,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "model = DeepQNetwork(num_frames, num_actions)"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 8,
355 | "metadata": {},
356 | "outputs": [],
357 | "source": [
358 | "agent = Agent(model, memory_depth, learning_rate, gamma, epsilon_i, epsilon_f, anneal_time, 'ckpt')"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 9,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "env = gym.make('BreakoutDeterministic-v4')"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 10,
373 | "metadata": {
374 | "scrolled": false
375 | },
376 | "outputs": [
377 | {
378 | "name": "stdout",
379 | "output_type": "stream",
380 | "text": [
381 | "Loading checkpoint: model-3808836\n"
382 | ]
383 | },
384 | {
385 | "data": {
386 | "application/vnd.jupyter.widget-view+json": {
387 | "model_id": "9187fc04563b48c68e1f9f3d4faacc08",
388 | "version_major": 2,
389 | "version_minor": 0
390 | },
391 | "text/plain": [
392 | "HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))"
393 | ]
394 | },
395 | "metadata": {},
396 | "output_type": "display_data"
397 | },
398 | {
399 | "name": "stdout",
400 | "output_type": "stream",
401 | "text": [
402 | "Saving model before quitting...\n"
403 | ]
404 | }
405 | ],
406 | "source": [
407 | "metadata = agent.play(episodes, train=True, load=True)"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {},
414 | "outputs": [],
415 | "source": []
416 | }
417 | ],
418 | "metadata": {
419 | "kernelspec": {
420 | "display_name": "Python 3",
421 | "language": "python",
422 | "name": "python3"
423 | },
424 | "language_info": {
425 | "codemirror_mode": {
426 | "name": "ipython",
427 | "version": 3
428 | },
429 | "file_extension": ".py",
430 | "mimetype": "text/x-python",
431 | "name": "python",
432 | "nbconvert_exporter": "python",
433 | "pygments_lexer": "ipython3",
434 | "version": "3.6.5"
435 | }
436 | },
437 | "nbformat": 4,
438 | "nbformat_minor": 2
439 | }
440 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/dqn-named-tuple-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "\n",
11 | "import os\n",
12 | "import re\n",
13 | "import gym\n",
14 | "import time\n",
15 | "import copy\n",
16 | "import random\n",
17 | "import warnings\n",
18 | "import numpy as np\n",
19 | "\n",
20 | "import torch\n",
21 | "import torchvision\n",
22 | "import torch.nn as nn\n",
23 | "\n",
24 | "from IPython import display\n",
25 | "from skimage.color import rgb2grey\n",
26 | "from skimage.transform import rescale\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm import tqdm_notebook as tqdm\n",
29 | "from collections import deque, namedtuple"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "plt.style.use('seaborn')\n",
39 | "warnings.filterwarnings('ignore')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "class DeepQNetwork(nn.Module):\n",
49 | " def __init__(self, num_frames, num_actions):\n",
50 | " super(DeepQNetwork, self).__init__()\n",
51 | " self.num_frames = num_frames\n",
52 | " self.num_actions = num_actions\n",
53 | " \n",
54 | " # Layers\n",
55 | " self.conv1 = nn.Conv2d(\n",
56 | " in_channels=num_frames,\n",
57 | " out_channels=16,\n",
58 | " kernel_size=8,\n",
59 | " stride=4,\n",
60 | " padding=2\n",
61 | " )\n",
62 | " self.conv2 = nn.Conv2d(\n",
63 | " in_channels=16,\n",
64 | " out_channels=32,\n",
65 | " kernel_size=4,\n",
66 | " stride=2,\n",
67 | " padding=1\n",
68 | " )\n",
69 | " self.fc1 = nn.Linear(\n",
70 | " in_features=3200,\n",
71 | " out_features=256,\n",
72 | " )\n",
73 | " self.fc2 = nn.Linear(\n",
74 | " in_features=256,\n",
75 | " out_features=num_actions,\n",
76 | " )\n",
77 | " \n",
78 | " # Activation Functions\n",
79 | " self.relu = nn.ReLU()\n",
80 | " \n",
81 | " def flatten(self, x):\n",
82 | " batch_size = x.size()[0]\n",
83 | " x = x.view(batch_size, -1)\n",
84 | " return x\n",
85 | " \n",
86 | " def forward(self, x):\n",
87 | " \n",
88 | " # Forward pass\n",
89 | " x = self.relu(self.conv1(x)) # In: (80, 80, 4) Out: (20, 20, 16)\n",
90 | " x = self.relu(self.conv2(x)) # In: (20, 20, 16) Out: (10, 10, 32)\n",
91 | " x = self.flatten(x) # In: (10, 10, 32) Out: (3200,)\n",
92 | " x = self.relu(self.fc1(x)) # In: (3200,) Out: (256,)\n",
93 | " x = self.fc2(x) # In: (256,) Out: (4,)\n",
94 | " \n",
95 | " return x"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 4,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "Transition = namedtuple('Transition', ['state', 'action', 'reward', 'terminal', 'next_state'])"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 5,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "class Agent:\n",
114 | " def __init__(self, model, memory_depth, lr, gamma, epsilon_i, epsilon_f, anneal_time, ckptdir):\n",
115 | " \n",
116 | " self.cuda = True if torch.cuda.is_available() else False\n",
117 | " \n",
118 | " self.model = model\n",
119 | " self.device = torch.device(\"cuda\" if self.cuda else \"cpu\")\n",
120 | " \n",
121 | " if self.cuda:\n",
122 | " self.model = self.model.cuda()\n",
123 | " \n",
124 | " self.memory_depth = memory_depth\n",
125 | " self.gamma = torch.tensor([gamma], device=self.device)\n",
126 | " self.e_i = epsilon_i\n",
127 | " self.e_f = epsilon_f\n",
128 | " self.anneal_time = anneal_time\n",
129 | " self.ckptdir = ckptdir\n",
130 | " \n",
131 | " if not os.path.isdir(ckptdir):\n",
132 | " os.makedirs(ckptdir)\n",
133 | " \n",
134 | " self.memory = deque(maxlen=memory_depth)\n",
135 | " self.clone()\n",
136 | " \n",
137 | " self.loss = nn.SmoothL1Loss()\n",
138 | " self.opt = torch.optim.RMSprop(self.model.parameters(), lr=lr, alpha=0.95, eps=0.01)\n",
139 | " \n",
140 | " def clone(self):\n",
141 | " try:\n",
142 | " del self.clone_model\n",
143 | " except:\n",
144 | " pass\n",
145 | " \n",
146 | " self.clone_model = copy.deepcopy(self.model)\n",
147 | " \n",
148 | " for p in self.clone_model.parameters():\n",
149 | " p.requires_grad = False\n",
150 | " \n",
151 | " if self.cuda:\n",
152 | " self.clone_model = self.clone_model.cuda()\n",
153 | " \n",
154 | " def remember(self, *args):\n",
155 | " self.memory.append(Transition(*args))\n",
156 | " \n",
157 | " def retrieve(self, batch_size):\n",
158 | " transitions = random.sample(self.memory, batch_size)\n",
159 | " batch = Transition(*zip(*transitions))\n",
160 | " state, action, reward, terminal, next_state = map(torch.cat, [*batch])\n",
161 | " return state, action, reward, terminal, next_state\n",
162 | " \n",
163 | " @property\n",
164 | " def memories(self):\n",
165 | " return len(self.memory)\n",
166 | " \n",
167 | " def act(self, state):\n",
168 | " q_values = self.model(state).detach()\n",
169 | " action = torch.argmax(q_values)\n",
170 | " return action.item()\n",
171 | " \n",
172 | " def process(self, state):\n",
173 | " state = rgb2grey(state[35:195, :, :])\n",
174 | " state = rescale(state, scale=0.5)\n",
175 | " state = state[np.newaxis, np.newaxis, :, :]\n",
176 | " return torch.tensor(state, device=self.device, dtype=torch.float)\n",
177 | " \n",
178 | " def exploration_rate(self, t):\n",
179 | " if 0 <= t < self.anneal_time:\n",
180 | " return self.e_i - t*(self.e_i - self.e_f)/self.anneal_time\n",
181 | " elif t >= self.anneal_time:\n",
182 | " return self.e_f\n",
183 | " elif t < 0:\n",
184 | " return self.e_i\n",
185 | " \n",
186 | " def save(self, t):\n",
187 | " save_path = os.path.join(self.ckptdir, 'model-{}'.format(t))\n",
188 | " torch.save(self.model.state_dict(), save_path)\n",
189 | " \n",
190 | " def load(self):\n",
191 | " ckpts = [file for file in os.listdir(self.ckptdir) if 'model' in file]\n",
192 | " steps = [int(re.search('\\d+', file).group(0)) for file in ckpts]\n",
193 | " \n",
194 | " latest_ckpt = ckpts[np.argmax(steps)]\n",
195 | " self.t = np.max(steps)\n",
196 | " \n",
197 | " print(\"Loading checkpoint: {}\".format(latest_ckpt))\n",
198 | " \n",
199 | " self.model.load_state_dict(torch.load(os.path.join(self.ckptdir, latest_ckpt)))\n",
200 | " \n",
201 | " def update(self, batch_size):\n",
202 | " self.model.zero_grad()\n",
203 | "\n",
204 | " state, action, reward, terminal, next_state = self.retrieve(batch_size)\n",
205 | " q = self.model(state).gather(1, action.view(batch_size, 1))\n",
206 | " qmax = self.clone_model(next_state).max(dim=1)[0]\n",
207 | " \n",
208 | " nonterminal_target = reward + self.gamma*qmax\n",
209 | " terminal_target = reward\n",
210 | " \n",
211 | " target = terminal.float()*terminal_target + (~terminal).float()*nonterminal_target\n",
212 | " \n",
213 | " loss = self.loss(q.view(-1), target)\n",
214 | " loss.backward()\n",
215 | " self.opt.step()\n",
216 | "\n",
217 | " def play(self, episodes, train=False, load=False, plot=False, render=False, verbose=False):\n",
218 | " \n",
219 | " self.t = 0\n",
220 | " metadata = dict(episode=[], reward=[])\n",
221 | " \n",
222 | " if load:\n",
223 | " self.load()\n",
224 | "\n",
225 | " try:\n",
226 | " progress_bar = tqdm(range(episodes), unit='episode')\n",
227 | " \n",
228 | " i = 0\n",
229 | " for episode in progress_bar:\n",
230 | "\n",
231 | " state = env.reset()\n",
232 | " state = self.process(state)\n",
233 | " \n",
234 | " done = False\n",
235 | " total_reward = 0\n",
236 | "\n",
237 | " while not done:\n",
238 | "\n",
239 | " if render:\n",
240 | " env.render()\n",
241 | "\n",
242 | " while state.size()[1] < num_frames:\n",
243 | " action = 1 # Fire\n",
244 | "\n",
245 | " new_frame, reward, done, info = env.step(action)\n",
246 | " new_frame = self.process(new_frame)\n",
247 | "\n",
248 | " state = torch.cat([state, new_frame], 1)\n",
249 | " \n",
250 | " if train and np.random.uniform() < self.exploration_rate(self.t-burn_in):\n",
251 | " action = np.random.choice(num_actions)\n",
252 | "\n",
253 | " else:\n",
254 | " action = self.act(state)\n",
255 | "\n",
256 | " new_frame, reward, done, info = env.step(action)\n",
257 | " new_frame = self.process(new_frame)\n",
258 | "\n",
259 | " new_state = torch.cat([state, new_frame], 1)\n",
260 | " new_state = new_state[:, 1:, :, :]\n",
261 | " \n",
262 | " if train:\n",
263 | " reward = torch.tensor([reward], device=self.device, dtype=torch.float)\n",
264 | " action = torch.tensor([action], device=self.device, dtype=torch.long)\n",
265 | " done = torch.tensor([done], device=self.device, dtype=torch.uint8)\n",
266 | " \n",
267 | " self.remember(state, action, reward, done, new_state)\n",
268 | "\n",
269 | " state = new_state\n",
270 | " total_reward += reward\n",
271 | " self.t += 1\n",
272 | " i += 1\n",
273 | " \n",
274 | " if not train:\n",
275 | " time.sleep(0.1)\n",
276 | " \n",
277 | " if train and self.t > burn_in and i > batch_size:\n",
278 | "\n",
279 | " if self.t % update_interval == 0:\n",
280 | " self.update(batch_size)\n",
281 | "\n",
282 | " if self.t % clone_interval == 0:\n",
283 | " self.clone()\n",
284 | "\n",
285 | " if self.t % save_interval == 0:\n",
286 | " self.save(self.t)\n",
287 | "\n",
288 | " if self.t % 1000 == 0:\n",
289 | " progress_bar.set_description(\"t = {}\".format(self.t))\n",
290 | "\n",
291 | " metadata['episode'].append(episode)\n",
292 | " metadata['reward'].append(total_reward)\n",
293 | "\n",
294 | " if episode % 100 == 0 and episode != 0:\n",
295 | " avg_return = np.mean(metadata['reward'][-100:])\n",
296 | " print(\"Average return (last 100 episodes): {:.2f}\".format(avg_return))\n",
297 | "\n",
298 | " if plot:\n",
299 | " plt.scatter(metadata['episode'], metadata['reward'])\n",
300 | " plt.xlim(0, episodes)\n",
301 | " plt.xlabel(\"Episode\")\n",
302 | " plt.ylabel(\"Return\")\n",
303 | " display.clear_output(wait=True)\n",
304 | " display.display(plt.gcf())\n",
305 | " \n",
306 | " env.close()\n",
307 | " return metadata\n",
308 | "\n",
309 | " except KeyboardInterrupt:\n",
310 | " if train:\n",
311 | " print(\"Saving model before quitting...\")\n",
312 | " self.save(self.t)\n",
313 | " \n",
314 | " env.close()\n",
315 | " return metadata"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 6,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "# Hyperparameters\n",
325 | "\n",
326 | "batch_size = 32\n",
327 | "update_interval = 4\n",
328 | "clone_interval = int(1e4)\n",
329 | "save_interval = int(1e5)\n",
330 | "frame_skip = None\n",
331 | "num_frames = 4\n",
332 | "num_actions = 4\n",
333 | "episodes = int(1e5)\n",
334 | "memory_depth = int(1e5)\n",
335 | "epsilon_i = 1.0\n",
336 | "epsilon_f = 0.1\n",
337 | "anneal_time = int(1e6)\n",
338 | "burn_in = int(5e4)\n",
339 | "gamma = 0.99\n",
340 | "learning_rate = 2.5e-4"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 7,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "model = DeepQNetwork(num_frames, num_actions)"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 8,
355 | "metadata": {},
356 | "outputs": [],
357 | "source": [
358 | "agent = Agent(model, memory_depth, learning_rate, gamma, epsilon_i, epsilon_f, anneal_time, 'ckpt')"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 9,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "env = gym.make('BreakoutDeterministic-v4')"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 10,
373 | "metadata": {
374 | "scrolled": false
375 | },
376 | "outputs": [
377 | {
378 | "name": "stdout",
379 | "output_type": "stream",
380 | "text": [
381 | "Loading checkpoint: model-3808836\n"
382 | ]
383 | },
384 | {
385 | "data": {
386 | "application/vnd.jupyter.widget-view+json": {
387 | "model_id": "9187fc04563b48c68e1f9f3d4faacc08",
388 | "version_major": 2,
389 | "version_minor": 0
390 | },
391 | "text/plain": [
392 | "HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))"
393 | ]
394 | },
395 | "metadata": {},
396 | "output_type": "display_data"
397 | },
398 | {
399 | "name": "stdout",
400 | "output_type": "stream",
401 | "text": [
402 | "Saving model before quitting...\n"
403 | ]
404 | }
405 | ],
406 | "source": [
407 | "metadata = agent.play(episodes, train=True, load=True)"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {},
414 | "outputs": [],
415 | "source": []
416 | }
417 | ],
418 | "metadata": {
419 | "kernelspec": {
420 | "display_name": "Python 3",
421 | "language": "python",
422 | "name": "python3"
423 | },
424 | "language_info": {
425 | "codemirror_mode": {
426 | "name": "ipython",
427 | "version": 3
428 | },
429 | "file_extension": ".py",
430 | "mimetype": "text/x-python",
431 | "name": "python",
432 | "nbconvert_exporter": "python",
433 | "pygments_lexer": "ipython3",
434 | "version": "3.6.5"
435 | }
436 | },
437 | "nbformat": 4,
438 | "nbformat_minor": 2
439 | }
440 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/dqn-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "\n",
11 | "import os\n",
12 | "import re\n",
13 | "import gym\n",
14 | "import time\n",
15 | "import copy\n",
16 | "import random\n",
17 | "import warnings\n",
18 | "import numpy as np\n",
19 | "\n",
20 | "import torch\n",
21 | "import torchvision\n",
22 | "import torch.nn as nn\n",
23 | "\n",
24 | "from IPython import display\n",
25 | "from collections import deque\n",
26 | "from skimage.color import rgb2grey\n",
27 | "from skimage.transform import rescale\n",
28 | "from matplotlib import pyplot as plt\n",
29 | "from tqdm import tqdm_notebook as tqdm"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "plt.style.use('seaborn')\n",
39 | "warnings.filterwarnings('ignore')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "class DeepQNetwork(nn.Module):\n",
49 | " def __init__(self, num_frames, num_actions):\n",
50 | " super(DeepQNetwork, self).__init__()\n",
51 | " self.num_frames = num_frames\n",
52 | " self.num_actions = num_actions\n",
53 | " \n",
54 | " # Layers\n",
55 | " self.conv1 = nn.Conv2d(\n",
56 | " in_channels=num_frames,\n",
57 | " out_channels=16,\n",
58 | " kernel_size=8,\n",
59 | " stride=4,\n",
60 | " padding=2\n",
61 | " )\n",
62 | " self.conv2 = nn.Conv2d(\n",
63 | " in_channels=16,\n",
64 | " out_channels=32,\n",
65 | " kernel_size=4,\n",
66 | " stride=2,\n",
67 | " padding=1\n",
68 | " )\n",
69 | " self.fc1 = nn.Linear(\n",
70 | " in_features=3200,\n",
71 | " out_features=256,\n",
72 | " )\n",
73 | " self.fc2 = nn.Linear(\n",
74 | " in_features=256,\n",
75 | " out_features=num_actions,\n",
76 | " )\n",
77 | " \n",
78 | " # Activations\n",
79 | " self.relu = nn.ReLU()\n",
80 | " \n",
81 | " def flatten(self, x):\n",
82 | " batch_size = x.size()[0]\n",
83 | " x = x.view(batch_size, -1)\n",
84 | " return x\n",
85 | " \n",
86 | " def forward(self, x):\n",
87 | " \n",
88 | " # Forward pass\n",
89 | " x = self.relu(self.conv1(x)) # In: (80, 80, 4), Out: (20, 20, 16)\n",
90 | " x = self.relu(self.conv2(x)) # In: (20, 20, 16), Out: (10, 10, 32)\n",
91 | " x = self.flatten(x) # In: (10, 10, 32), Out: (3200,)\n",
92 | " x = self.relu(self.fc1(x)) # In: (3200,), Out: (256,)\n",
93 | " x = self.fc2(x) # In: (256,), Out: (4,)\n",
94 | " \n",
95 | " return x"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 4,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "def output_size(w, k, s, p):\n",
105 | " return ((w - k + 2*p)/s) + 1"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 5,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "class Agent:\n",
115 | " def __init__(self, model, memory_depth, lr, gamma, epsilon_i, epsilon_f, anneal_time, ckptdir):\n",
116 | " \n",
117 | " self.cuda = True if torch.cuda.is_available() else False\n",
118 | " self.to_tensor = torch.cuda.FloatTensor if self.cuda else torch.FloatTensor\n",
119 | " self.to_byte_tensor = torch.cuda.ByteTensor if self.cuda else torch.ByteTensor\n",
120 | " \n",
121 | " self.model = model\n",
122 | " \n",
123 | " if self.cuda:\n",
124 | " self.model = self.model.cuda()\n",
125 | " \n",
126 | " self.memory_depth = memory_depth\n",
127 | " self.gamma = self.to_tensor([gamma])\n",
128 | " self.e_i = epsilon_i\n",
129 | " self.e_f = epsilon_f\n",
130 | " self.anneal_time = anneal_time\n",
131 | " self.ckptdir = ckptdir\n",
132 | " \n",
133 | " if not os.path.isdir(ckptdir):\n",
134 | " os.makedirs(ckptdir)\n",
135 | " \n",
136 | " self.memory = deque(maxlen=memory_depth)\n",
137 | " self.clone()\n",
138 | " \n",
139 | " self.loss = nn.SmoothL1Loss()\n",
140 | " self.opt = torch.optim.RMSprop(self.model.parameters(), lr=lr, alpha=0.95, eps=0.01)\n",
141 | " \n",
142 | " def clone(self):\n",
143 | " try:\n",
144 | " del self.clone_model\n",
145 | " except:\n",
146 | " pass\n",
147 | " \n",
148 | " self.clone_model = copy.deepcopy(self.model)\n",
149 | " \n",
150 | " for p in self.clone_model.parameters():\n",
151 | " p.requires_grad = False\n",
152 | " \n",
153 | " if self.cuda:\n",
154 | " self.clone_model = self.clone_model.cuda()\n",
155 | " \n",
156 | " def remember(self, state, action, reward, terminal, next_state):\n",
157 | " \n",
158 | " if self.cuda:\n",
159 | " state, next_state = state.cpu(), next_state.cpu()\n",
160 | " \n",
161 | " state, next_state = state.data.numpy(), next_state.data.numpy()\n",
162 | " state, next_state = 255.*state, 255.*next_state\n",
163 | " state, next_state = state.astype(np.uint8), next_state.astype(np.uint8)\n",
164 | " \n",
165 | " self.memory.append([state, action, reward, terminal, next_state])\n",
166 | " \n",
167 | " def retrieve(self, batch_size):\n",
168 | " \n",
169 | " if batch_size > self.memories:\n",
170 | " batch_size = self.memories\n",
171 | " \n",
172 | " batch = random.sample(self.memory, batch_size)\n",
173 | " \n",
174 | " state = np.concatenate([batch[i][0] for i in range(batch_size)]).astype(np.int64)\n",
175 | " action = np.array([batch[i][1] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
176 | " reward = np.array([batch[i][2] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
177 | " terminal = np.array([batch[i][3] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
178 | " next_state = np.concatenate([batch[i][4] for i in range(batch_size)]).astype(np.int64)\n",
179 | " \n",
180 | " state = self.to_tensor(state/255.)\n",
181 | " next_state = self.to_tensor(next_state/255.)\n",
182 | " reward = self.to_tensor(reward)\n",
183 | " terminal = self.to_byte_tensor(terminal)\n",
184 | "\n",
185 | " return state, action, reward, terminal, next_state\n",
186 | " \n",
187 | " @property\n",
188 | " def memories(self):\n",
189 | " return len(self.memory)\n",
190 | " \n",
191 | " def act(self, state):\n",
192 | " q_values = self.model(state).detach()\n",
193 | " action = torch.argmax(q_values)\n",
194 | " return action.item()\n",
195 | " \n",
196 | " def process(self, state):\n",
197 | " state = rgb2grey(state[35:195, :, :])\n",
198 | " state = rescale(state, scale=0.5)\n",
199 | " state = state[np.newaxis, np.newaxis, :, :]\n",
200 | " return self.to_tensor(state)\n",
201 | " \n",
202 | " def exploration_rate(self, t):\n",
203 | " if 0 <= t < self.anneal_time:\n",
204 | " return self.e_i - t*(self.e_i - self.e_f)/self.anneal_time\n",
205 | " elif t >= self.anneal_time:\n",
206 | " return self.e_f\n",
207 | " elif t < 0:\n",
208 | " return self.e_i\n",
209 | " \n",
210 | " def save(self, t):\n",
211 | " save_path = os.path.join(self.ckptdir, 'model-{}'.format(t))\n",
212 | " torch.save(self.model.state_dict(), save_path)\n",
213 | " \n",
214 | " def load(self):\n",
215 | " ckpts = [file for file in os.listdir(self.ckptdir) if 'model' in file]\n",
216 | " steps = [int(re.search('\\d+', file).group(0)) for file in ckpts]\n",
217 | " \n",
218 | " latest_ckpt = ckpts[np.argmax(steps)]\n",
219 | " self.t = np.max(steps)\n",
220 | " \n",
221 | " print(\"Loading checkpoint: {}\".format(latest_ckpt))\n",
222 | " \n",
223 | " self.model.load_state_dict(torch.load(os.path.join(self.ckptdir, latest_ckpt)))\n",
224 | " \n",
225 | " def update(self, batch_size, verbose=False):\n",
226 | " \n",
227 | " self.model.zero_grad()\n",
228 | " \n",
229 | " start = time.time()\n",
230 | " state, action, reward, terminal, next_state = self.retrieve(batch_size)\n",
231 | " \n",
232 | " if verbose:\n",
233 | " print(\"Sampled memory in {:0.2f} seconds.\".format(time.time() - start))\n",
234 | " \n",
235 | " start = time.time()\n",
236 | " \n",
237 | " q = self.model(state)[range(batch_size), action.flatten()][:, None]\n",
238 | " qmax = self.clone_model(next_state).max(dim=1)[0][:, None]\n",
239 | " \n",
240 | " nonterminal_target = reward + self.gamma*qmax\n",
241 | " terminal_target = reward\n",
242 | " \n",
243 | " target = terminal.float()*terminal_target + (~terminal).float()*nonterminal_target\n",
244 | " \n",
245 | " loss = self.loss(q, target)\n",
246 | " \n",
247 | " loss.backward()\n",
248 | " self.opt.step()\n",
249 | " \n",
250 | " if verbose:\n",
251 | " print(\"Updated parameters in {:0.2f} seconds.\".format(time.time() - start))\n",
252 | "\n",
253 | " def play(self, episodes, train=False, load=False, plot=False, render=False, verbose=False):\n",
254 | " \n",
255 | " self.t = 0\n",
256 | " metadata = dict(episode=[], reward=[])\n",
257 | " \n",
258 | " if load:\n",
259 | " self.load()\n",
260 | "\n",
261 | " try:\n",
262 | " progress_bar = tqdm(range(episodes), unit='episode')\n",
263 | " \n",
264 | " i = 0\n",
265 | " for episode in progress_bar:\n",
266 | "\n",
267 | " state = env.reset()\n",
268 | " state = self.process(state)\n",
269 | " \n",
270 | " done = False\n",
271 | " total_reward = 0\n",
272 | "\n",
273 | " while not done:\n",
274 | "\n",
275 | " if render:\n",
276 | " env.render()\n",
277 | "\n",
278 | " while state.size()[1] < num_frames:\n",
279 | " action = 1 # Fire\n",
280 | "\n",
281 | " new_frame, reward, done, info = env.step(action)\n",
282 | " new_frame = self.process(new_frame)\n",
283 | "\n",
284 | " state = torch.cat([state, new_frame], 1)\n",
285 | " \n",
286 | " if train and np.random.uniform() < self.exploration_rate(self.t-burn_in):\n",
287 | " action = np.random.choice(num_actions)\n",
288 | "\n",
289 | " else:\n",
290 | " action = self.act(state)\n",
291 | "\n",
292 | " new_frame, reward, done, info = env.step(action)\n",
293 | " new_frame = self.process(new_frame)\n",
294 | "\n",
295 | " new_state = torch.cat([state, new_frame], 1)\n",
296 | " new_state = new_state[:, 1:, :, :]\n",
297 | " \n",
298 | " if train:\n",
299 | " self.remember(state, action, reward, done, new_state)\n",
300 | "\n",
301 | " state = new_state\n",
302 | " total_reward += reward\n",
303 | " self.t += 1\n",
304 | " i += 1\n",
305 | " \n",
306 | " if not train:\n",
307 | " time.sleep(0.1)\n",
308 | " \n",
309 | " if train and i > batch_size:\n",
310 | "\n",
311 | " if self.t % update_interval == 0:\n",
312 | " self.update(batch_size, verbose=verbose)\n",
313 | "\n",
314 | " if self.t % clone_interval == 0:\n",
315 | " self.clone()\n",
316 | "\n",
317 | " if self.t % save_interval == 0:\n",
318 | " self.save(self.t)\n",
319 | "\n",
320 | " if self.t % 1000 == 0:\n",
321 | " progress_bar.set_description(\"t = {}\".format(self.t))\n",
322 | "\n",
323 | " metadata['episode'].append(episode)\n",
324 | " metadata['reward'].append(total_reward)\n",
325 | "\n",
326 | " if episode % 100 == 0 and episode != 0:\n",
327 | " avg_return = np.mean(metadata['reward'][-100:])\n",
328 | " print(\"Average return (last 100 episodes): {}\".format(avg_return))\n",
329 | "\n",
330 | " if plot:\n",
331 | " plt.scatter(metadata['episode'], metadata['reward'])\n",
332 | " plt.xlim(0, episodes)\n",
333 | " plt.xlabel(\"Episode\")\n",
334 | " plt.ylabel(\"Return\")\n",
335 | " display.clear_output(wait=True)\n",
336 | " display.display(plt.gcf())\n",
337 | " \n",
338 | " env.close()\n",
339 | " return metadata\n",
340 | "\n",
341 | " except KeyboardInterrupt:\n",
342 | " if train:\n",
343 | " print(\"Saving model before quitting...\")\n",
344 | " self.save(self.t)\n",
345 | " \n",
346 | " env.close()\n",
347 | " return metadata"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": 6,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "# Hyperparameters\n",
357 | "\n",
358 | "batch_size = 32\n",
359 | "update_interval = 4\n",
360 | "clone_interval = int(1e4)\n",
361 | "save_interval = int(1e5)\n",
362 | "frame_skip = None\n",
363 | "num_frames = 4\n",
364 | "num_actions = 4\n",
365 | "episodes = int(1e5)\n",
366 | "memory_depth = int(1e5)\n",
367 | "epsilon_i = 1.0\n",
368 | "epsilon_f = 0.1\n",
369 | "anneal_time = int(1e6)\n",
370 | "burn_in = int(5e4)\n",
371 | "gamma = 0.99\n",
372 | "learning_rate = 2.5e-4"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": 7,
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "model = DeepQNetwork(num_frames, num_actions)"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": 8,
387 | "metadata": {},
388 | "outputs": [],
389 | "source": [
390 | "agent = Agent(model, memory_depth, learning_rate, gamma, epsilon_i, epsilon_f, anneal_time, 'ckpt')"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 9,
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "env = gym.make('BreakoutDeterministic-v4')"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": 10,
405 | "metadata": {
406 | "scrolled": false
407 | },
408 | "outputs": [
409 | {
410 | "name": "stdout",
411 | "output_type": "stream",
412 | "text": [
413 | "Loading checkpoint: model-2985892\n"
414 | ]
415 | },
416 | {
417 | "data": {
418 | "application/vnd.jupyter.widget-view+json": {
419 | "model_id": "4728272ffe5c4a82b000f8f94f94b086",
420 | "version_major": 2,
421 | "version_minor": 0
422 | },
423 | "text/plain": [
424 | "HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))"
425 | ]
426 | },
427 | "metadata": {},
428 | "output_type": "display_data"
429 | },
430 | {
431 | "name": "stdout",
432 | "output_type": "stream",
433 | "text": [
434 | "Average return (last 100 episodes): 4.43\n",
435 | "Average return (last 100 episodes): 4.67\n",
436 | "Average return (last 100 episodes): 3.88\n",
437 | "Average return (last 100 episodes): 2.97\n",
438 | "Average return (last 100 episodes): 3.05\n",
439 | "Average return (last 100 episodes): 2.68\n",
440 | "Average return (last 100 episodes): 3.1\n",
441 | "Average return (last 100 episodes): 4.96\n",
442 | "Average return (last 100 episodes): 6.26\n",
443 | "Average return (last 100 episodes): 6.2\n",
444 | "Average return (last 100 episodes): 5.65\n",
445 | "Average return (last 100 episodes): 6.75\n",
446 | "Average return (last 100 episodes): 7.27\n",
447 | "Average return (last 100 episodes): 6.72\n",
448 | "Average return (last 100 episodes): 5.78\n",
449 | "Average return (last 100 episodes): 6.35\n",
450 | "Average return (last 100 episodes): 6.86\n",
451 | "Average return (last 100 episodes): 6.94\n",
452 | "Average return (last 100 episodes): 7.54\n",
453 | "Average return (last 100 episodes): 8.38\n",
454 | "Average return (last 100 episodes): 9.32\n",
455 | "Average return (last 100 episodes): 10.21\n",
456 | "Average return (last 100 episodes): 10.75\n",
457 | "Saving model before quitting...\n"
458 | ]
459 | }
460 | ],
461 | "source": [
462 | "metadata = agent.play(episodes, train=True, load=True)"
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": null,
468 | "metadata": {},
469 | "outputs": [],
470 | "source": []
471 | }
472 | ],
473 | "metadata": {
474 | "kernelspec": {
475 | "display_name": "Python 3",
476 | "language": "python",
477 | "name": "python3"
478 | },
479 | "language_info": {
480 | "codemirror_mode": {
481 | "name": "ipython",
482 | "version": 3
483 | },
484 | "file_extension": ".py",
485 | "mimetype": "text/x-python",
486 | "name": "python",
487 | "nbconvert_exporter": "python",
488 | "pygments_lexer": "ipython3",
489 | "version": "3.6.5"
490 | }
491 | },
492 | "nbformat": 4,
493 | "nbformat_minor": 2
494 | }
495 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/main-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import os \n",
14 | "import gym\n",
15 | "import time\n",
16 | "import copy\n",
17 | "import random\n",
18 | "import numpy as np\n",
19 | "\n",
20 | "import torch\n",
21 | "import torchvision\n",
22 | "import torch.nn as nn\n",
23 | "\n",
24 | "from IPython import display\n",
25 | "from collections import deque\n",
26 | "from skimage.color import rgb2grey\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm import tqdm_notebook as tqdm"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {
35 | "collapsed": true
36 | },
37 | "outputs": [],
38 | "source": [
39 | "plt.style.use('seaborn')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {
46 | "collapsed": true
47 | },
48 | "outputs": [],
49 | "source": [
50 | "class DeepQNetwork(nn.Module):\n",
51 | " def __init__(self, num_frames, num_actions):\n",
52 | " super(DeepQNetwork, self).__init__()\n",
53 | " self.num_frames = num_frames\n",
54 | " self.num_actions = num_actions\n",
55 | " \n",
56 | " # Layers\n",
57 | " self.conv1 = nn.Conv2d(\n",
58 | " in_channels=num_frames,\n",
59 | " out_channels=32,\n",
60 | " kernel_size=3,\n",
61 | " stride=2,\n",
62 | " padding=1\n",
63 | " )\n",
64 | " self.conv2 = nn.Conv2d(\n",
65 | " in_channels=32,\n",
66 | " out_channels=64,\n",
67 | " kernel_size=3,\n",
68 | " stride=2,\n",
69 | " padding=1\n",
70 | " )\n",
71 | " self.conv3 = nn.Conv2d(\n",
72 | " in_channels=64,\n",
73 | " out_channels=128,\n",
74 | " kernel_size=3,\n",
75 | " stride=2,\n",
76 | " padding=1\n",
77 | " )\n",
78 | " self.conv4 = nn.Conv2d(\n",
79 | " in_channels=128,\n",
80 | " out_channels=256,\n",
81 | " kernel_size=3,\n",
82 | " stride=2,\n",
83 | " padding=1\n",
84 | " )\n",
85 | " self.fc1 = nn.Linear(\n",
86 | " in_features=25600,\n",
87 | " out_features=512,\n",
88 | " )\n",
89 | " self.fc2 = nn.Linear(\n",
90 | " in_features=512,\n",
91 | " out_features=num_actions\n",
92 | " )\n",
93 | " \n",
94 | " # Activations\n",
95 | " self.relu = nn.ReLU()\n",
96 | " \n",
97 | " def flatten(self, x):\n",
98 | " batch_size = x.size()[0]\n",
99 | " x = x.view(batch_size, -1)\n",
100 | " return x\n",
101 | " \n",
102 | " def forward(self, x):\n",
103 | " \n",
104 | " # Forward pass\n",
105 | " x = self.relu(self.conv1(x))\n",
106 | " x = self.relu(self.conv2(x))\n",
107 | " x = self.relu(self.conv3(x))\n",
108 | " x = self.relu(self.conv4(x))\n",
109 | " x = self.flatten(x)\n",
110 | " x = self.relu(self.fc1(x))\n",
111 | " x = self.fc2(x)\n",
112 | " \n",
113 | " return x"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 4,
119 | "metadata": {
120 | "collapsed": true
121 | },
122 | "outputs": [],
123 | "source": [
124 | "class Agent:\n",
125 | " def __init__(self, model, memory_depth, gamma, epsilon_i, epsilon_f, anneal_time):\n",
126 | " \n",
127 | " self.cuda = True if torch.cuda.is_available() else False\n",
128 | " self.to_tensor = torch.cuda.FloatTensor if self.cuda else torch.FloatTensor\n",
129 | " self.to_byte_tensor = torch.cuda.ByteTensor if self.cuda else torch.ByteTensor\n",
130 | " \n",
131 | " self.model = model\n",
132 | " self.memory_depth = memory_depth\n",
133 | " self.gamma = self.to_tensor([gamma])\n",
134 | " self.e_i = epsilon_i\n",
135 | " self.e_f = epsilon_f\n",
136 | " self.anneal_time = anneal_time\n",
137 | " \n",
138 | " self.memory = deque(maxlen=memory_depth)\n",
139 | " self.clone()\n",
140 | " \n",
141 | " self.loss = nn.MSELoss()\n",
142 | " self.opt = torch.optim.RMSprop(self.model.parameters(), lr=2.5e-4)\n",
143 | " \n",
144 | " def clone(self):\n",
145 | " self.clone_model = copy.deepcopy(self.model)\n",
146 | " \n",
147 | " for p in self.clone_model.parameters():\n",
148 | " p.requires_grad = False\n",
149 | " \n",
150 | " def remember(self, state, action, reward, terminal, next_state):\n",
151 | " state, next_state = state.data.numpy(), next_state.data.numpy()\n",
152 | " state, next_state = 255.*state, 255.*next_state\n",
153 | " state, next_state = state.astype(np.uint8), next_state.astype(np.uint8)\n",
154 | " self.memory.append([state, action, reward, terminal, next_state])\n",
155 | " \n",
156 | " def retrieve(self, batch_size):\n",
157 | "        # Note: Use lists for inhomogeneous data!\n",
158 | " \n",
159 | " if batch_size > self.memories:\n",
160 | " batch_size = self.memories\n",
161 | " \n",
162 | " batch = random.sample(self.memory, batch_size)\n",
163 | " \n",
164 | " state = np.concatenate([batch[i][0] for i in range(batch_size)]).astype(np.int64)\n",
165 | " action = np.array([batch[i][1] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
166 | " reward = np.array([batch[i][2] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
167 | " terminal = np.array([batch[i][3] for i in range(batch_size)], dtype=np.int64)[:, None]\n",
168 | " next_state = np.concatenate([batch[i][4] for i in range(batch_size)]).astype(np.int64)\n",
169 | " \n",
170 | " state = self.to_tensor(state/255.)\n",
171 | "        next_state = self.to_tensor(next_state/255.)\n",
172 | " reward = self.to_tensor(reward)\n",
173 | " terminal = self.to_byte_tensor(terminal)\n",
174 | "\n",
175 | " return state, action, reward, terminal, next_state\n",
176 | " \n",
177 | " @property\n",
178 | " def memories(self):\n",
179 | " return len(self.memory)\n",
180 | " \n",
181 | " def act(self, state):\n",
182 | " q_values = self.model(state).detach()\n",
183 | " action = np.argmax(q_values.numpy())\n",
184 | " return action\n",
185 | " \n",
186 | " def process(self, state):\n",
187 | " state = rgb2grey(state[35:195, :, :])\n",
188 | " state = state[np.newaxis, np.newaxis, :, :]\n",
189 | " return self.to_tensor(state)\n",
190 | " \n",
191 | " def exploration_rate(self, t):\n",
192 | " if t < self.anneal_time:\n",
193 | " return self.e_i - t*(self.e_i - self.e_f)/self.anneal_time\n",
194 | " elif t >= self.anneal_time:\n",
195 | " return self.e_f\n",
196 | " \n",
197 | " def huber_loss(self, x, y):\n",
198 | " error = x - y\n",
199 | " quadratic = 0.5 * error**2\n",
200 | " linear = np.absolute(error) - 0.5\n",
201 | " \n",
202 | " is_quadratic = (np.absolute(error) <= 1)\n",
203 | " \n",
204 | " return is_quadratic*quadratic + ~is_quadratic*linear\n",
205 | " \n",
206 | " def save(self, t, savedir=\"\"):\n",
207 | " save_path = os.path.join(savedir, 'model-{}'.format(t))\n",
208 | "        torch.save(self.model.state_dict(), save_path)\n",
209 | " \n",
210 | " def update(self, batch_size, verbose=False):\n",
211 | " \n",
212 | " start = time.time()\n",
213 | " state, action, reward, terminal, next_state = self.retrieve(batch_size)\n",
214 | " \n",
215 | " if verbose:\n",
216 | " print(\"Sampled memory in {:0.2f} seconds.\".format(time.time() - start))\n",
217 | " \n",
218 | " start = time.time()\n",
219 | " \n",
220 | " q = self.model(state)[range(batch_size), action.flatten()][:, None]\n",
221 | " qmax = self.clone_model(next_state).max(dim=1)[0][:, None]\n",
222 | " \n",
223 | " nonterminal_target = reward + self.gamma*qmax\n",
224 | " terminal_target = reward\n",
225 | " \n",
226 | " target = terminal.float()*terminal_target + (~terminal).float()*nonterminal_target\n",
227 | " \n",
228 | " loss = self.loss(q, target)\n",
229 | " \n",
230 | "        self.opt.zero_grad()\n",
230 | "        loss.backward()\n",
231 | " self.opt.step()\n",
232 | " \n",
233 | " if verbose:\n",
234 | " print(\"Updated parameters in {:0.2f} seconds.\".format(time.time() - start))"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 5,
240 | "metadata": {
241 | "collapsed": true
242 | },
243 | "outputs": [],
244 | "source": [
245 | "# Hyperparameters\n",
246 | "\n",
247 | "batch_size = 32\n",
248 | "update_interval = 32\n",
249 | "clone_interval = 128\n",
249 | "save_interval = 100000\n",
250 | "frame_skip = 4\n",
251 | "num_frames = 4\n",
252 | "num_actions = 4\n",
253 | "episodes = 10000\n",
254 | "memory_depth = int(1e5)\n",
255 | "epsilon_i = 1.0\n",
256 | "epsilon_f = 0.1\n",
257 | "anneal_time = 100000\n",
258 | "burn_in = 50000\n",
259 | "gamma = 0.99"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 6,
265 | "metadata": {
266 | "collapsed": true
267 | },
268 | "outputs": [],
269 | "source": [
270 | "model = DeepQNetwork(num_frames, num_actions)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 7,
276 | "metadata": {
277 | "collapsed": true
278 | },
279 | "outputs": [],
280 | "source": [
281 | "agent = Agent(model, memory_depth, gamma, epsilon_i, epsilon_f, anneal_time)"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 8,
287 | "metadata": {
288 | "collapsed": true
289 | },
290 | "outputs": [],
291 | "source": [
292 | "env = gym.make('Breakout-v0')"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 9,
298 | "metadata": {
299 | "collapsed": true
300 | },
301 | "outputs": [],
302 | "source": [
303 | "def q_iteration(episodes, plot=True, render=True):\n",
304 | " \n",
305 | " t = 0\n",
306 | " metadata = dict(episode=[], reward=[])\n",
307 | " \n",
308 | " for episode in tqdm(range(episodes)):\n",
309 | " \n",
310 | " state = env.reset()\n",
311 | " state = agent.process(state)\n",
312 | " \n",
313 | " done = False\n",
314 | " total_reward = 0\n",
315 | "\n",
316 | " while not done:\n",
317 | " \n",
318 | " if render:\n",
319 | " env.render()\n",
320 | " \n",
321 | " while state.size()[1] < num_frames:\n",
322 | " action = np.random.choice(num_actions)\n",
323 | " \n",
324 | " new_frame, reward, done, info = env.step(action)\n",
325 | " new_frame = agent.process(new_frame)\n",
326 | " \n",
327 | " state = torch.cat([state, new_frame], 1)\n",
328 | " \n",
329 | " if np.random.uniform() < agent.exploration_rate(t) or t < burn_in:\n",
330 | " action = np.random.choice(num_actions)\n",
331 | "\n",
332 | " else:\n",
333 | " action = agent.act(state)\n",
334 | " \n",
335 | " new_frame, reward, done, info = env.step(action)\n",
336 | " new_frame = agent.process(new_frame)\n",
337 | " \n",
338 | " new_state = torch.cat([state, new_frame], 1)\n",
339 | " new_state = new_state[:, 1:, :, :]\n",
340 | "\n",
341 | " agent.remember(state, action, reward, done, new_state)\n",
342 | "\n",
343 | " state = new_state\n",
344 | " total_reward += reward\n",
345 | " t += 1\n",
346 | " \n",
347 | " if t % update_interval == 0 and t > burn_in:\n",
348 | " agent.update(batch_size, verbose=False)\n",
349 | " \n",
350 | " if t % clone_interval == 0 and t > burn_in:\n",
351 | " agent.clone()\n",
352 | " \n",
353 | " if t % save_interval == 0 and t > burn_in:\n",
354 | " agent.save(t)\n",
355 | " \n",
356 | " metadata['episode'].append(episode)\n",
357 | " metadata['reward'].append(total_reward)\n",
358 | " \n",
359 | " if plot:\n",
360 | " plt.scatter(episode, total_reward)\n",
361 | " plt.xlim(0, episodes)\n",
362 | " plt.xlabel(\"Episode\")\n",
363 | " plt.ylabel(\"Return\")\n",
364 | " display.clear_output(wait=True)\n",
365 | " display.display(plt.gcf())\n",
366 | " \n",
367 | " return metadata"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 10,
373 | "metadata": {
374 | "scrolled": false
375 | },
376 | "outputs": [
377 | {
378 | "data": {
379 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfgAAAFXCAYAAABOYlxEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xm4nHV9///nfd8zc896zpw1+0oSCJsQFrEtWPTbghV/\nisWi9MKtV78u+BW0pVArNgqWUqxWaWtdqlZqC1hRS60boCyCoBACgSQkISQ52c5+zqz3/vsjYZJj\nTjY4M0nuvB7XxXWdmXv5vO/3OeE19+eemduIoihCREREYsU80gWIiIjI1FPAi4iIxJACXkREJIYU\n8CIiIjGkgBcREYkhBbyIiEgMJY50AQfi+wEjI9UjXUbsdXRk1ecmU4+bTz1uDfW5+Xp6ClOyn6P6\nDD6RsI50CccF9bn51OPmU49bQ30+dhzVAS8iIiIvjwJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp\n4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAL+ELhByFDdxQ3CI12KiIjIITmqv4v+SAuiiB9uHmT1\naJlR16eYSrC0mOcNc7uxDONIlyciIrJfCvgD+OHmQR7pH208HnH9xuNL5vUcqbJEREQOSlP0++EG\nIatHy5MuWz1a1nS9iIgc1RTw+1HyfEZdf9Jlo65PyZt8mYiIyNFAAb8fhWSCYmryKxjFVIJCUlc3\nRETk6KWA34+UZbK0mJ902dJinpSl1omIyNFLp6EH8Ia53QCTvoteRETkaKaAPwDLMLhkXg+/P7uL\nkudTSCZ05i4iIscEBfwhSFkmXVbqSJchIiJyyHQ6KiIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAi\nIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuIiMRQUwN+5cqVXHnllQCsXr2a\nK664giuvvJI/+ZM/YXBwsJlDi4iIHNeaFvBf+cpX+PjHP47jOAB8+tOf5oYbbuD222/n937v9/jK\nV77SrKFFRESOe00L+Llz53Lbbbc1Hn/2s59l6dKlAARBgG3bzRpaRETkuNe028VedNFF9PX1NR73\n9vYC8OSTT/Lv//7vfOtb3zqk/fT0FJpSn0ykPjefetx86nFrqM/HhpbeD/5///d/+eIXv8iXv/xl\nOjs7D2mbgYFSk6uSnp6C+txk6nHzqcetoT4331S9gGpZwH//+9/nzjvv5Pbbb6dYLLZqWBERkeNS\nSwI+CAI+/elPM2PGDP7f//t/AJxzzjl8+MMfbsXwIiIix52mBvzs2bO56667AHj88cebOZSIiIjs\nRV90IyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIi\nEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBER\nkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuI\niMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkNNDfiVK1dy5ZVXArBp0ybe\n8Y53cMUVV/DXf/3XhGHYzKFFRESOa00L+K985St8/OMfx3EcAG6++WauueYa/uM//oMoirjvvvua\nNbSIiMhxr2kBP3fuXG677bbG42effZZzzz0XgAsuuIBHHnmkWUOLiIgc9xLN2vFFF11EX19f43EU\nRRiGAUAul6NUKh3Sfnp6Ck2pTyZSn5tPPW4+9bg11OdjQ9MC/jeZ5p7JgkqlQltb2yFtNzBwaC8E\n5OXr6Smoz02mHjefetwa6nPzTdULqJa9i/7kk0/mscceA+DBBx/k7LPPbtXQIiIix52WBfx1113H\nbbfdxuWXX47neVx00UWtGlpEROS4Y0
RRFB3pIg5EU0HNpym35lOPm089bg31ufmOuSl6ERERaR0F\nvIiISAwp4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp\n4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJI\nAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJIAS8iIhJD\nCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiaFEKwfzPI/rr7+erVu3YpomN954IyeccEIr\nSxARETkutPQM/oEHHsD3fe644w6uuuoq/uEf/qGVw4uIiBw3WhrwCxYsIAgCwjCkXC6TSLR0AkFE\nROS4YURRFLVqsO3bt/PBD36QarXKyMgI//Iv/8KyZctaNbyIiMhxo6UBf/PNN5NKpfizP/sztm/f\nzrve9S7uuecebNve7zYDA6VWlXfc6ukpqM9Nph43n3rcGupz8/X0FKZkPy2dI29rayOZTALQ3t6O\n7/sEQdDKEkRERI4LLQ34d7/73XzsYx/jiiuuwPM8PvKRj5DNZltZgoiIyHGhpQGfy+X4/Oc/38oh\nRUREjkv6ohsREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJIAS8iIhJD\nCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiaFDupvc+Pg499xzD6Ojo0RR1Hj+Qx/6UNMK\nExERkZfvkAL+6quvplAosHjxYgzDaHZNIiIi8godUsAPDg7y9a9/vdm1iIiIyBQ5pGvwS5cuZc2a\nNc2uRURERKbIIZ3Br1u3jksvvZSuri5s2yaKIgzD4L777mt2fSIiIvIyHFLAf/KTn6Srq6vZtYiI\niMgUOaSAv+666/jhD3/Y7FpERERkihxSwJ900kl873vf4/TTTyedTjeenzlzZtMKExERkZfvkAJ+\n5cqVrFy5csJzugYvIiJy9DqkgL///vubXYeIiIhMoUMK+L/8y7+c9Pmbb755SosRERGRqXFIAX/u\nuec2fvZ9n/vuu4+FCxc2rSgRERF5ZQ4p4C+99NIJjy+77DLe8Y53NKUgEREReeVe1t3kNmzYQH9/\n/1TXIiIiIlPkkD8m99JNZqIoorOzk49+9KNNLUxERERevkMK+Mm+h9513SkvRkRERKbGIU3RX375\n5RMeh2HIH/7hHzalIBEREXnlDngG/853vpPHH38c2DVN39gokeB1r3tdcysTERGRl+2AAf/Nb34T\ngJtuuomPf/zjLSlIREREXrlDmqL/+Mc/zj333MPnPvc5arUa3/ve95pdl4iIiLwChxTwn/nMZ3jg\ngQf4yU9+gu/7fOc73+Fv//ZvX9aAX/rSl7j88st561vfyre//e2XtQ8RERE5sEMK+Icffphbb70V\n27YpFAp8/etf58EHHzzswR577DFWrFjBf/7nf3L77bezY8eOw96HiIiIHNwhfUzONHe9Dnjps/Cu\n6zaeOxwPP/wwS5Ys4aqrrqJcLvMXf/EXh70PERERObhDCviLL76Ya665hrGxMb7xjW/w/e9/n0su\nueSwBxsZGWHbtm38y7/8C319fXzgAx/gRz/6UeOFg4iIiEyNgwb8Cy+8wJvf/GaWLl3KzJkz2bFj\nB+9+97v59a9/fdiDFYtFFi5cSCqVYuHChdi2zfDwMF1dXfvdpqencNjjyOFTn5tPPW4+9bg11Odj\nwwED/rbbbuNrX/saAP/4j//In//5n/Ov//qvfPKTn+TMM8887MHOOussvvnNb/Ke97yH/v5+arUa\nxWLxgNsMDJQOexw5PD09BfW5ydTj5lOPW0N9br6pegF1wID/3ve+x49//GP6+/v5whe+wFe/+lUG\nBw
f5/Oc/z/nnn3/Yg1144YX86le/4rLLLiOKIj7xiU9gWdbLLl5EREQmd8CAz+Vy9Pb20tvby9NP\nP81b3vIWvvrVr76iUNYb60RERJrvgAG/9zvlOzo6uP7665tekIiIiLxyB/ys297vbk+n000vRkRE\nRKbGAc/g161bx+tf/3oAdu7c2fg5iiIMw+C+++5rfoUiIiJy2A4Y8D/+8Y9bVYeIiIhMoQMG/KxZ\ns1pVh4iIiEyhw/++WRERETnqKeBFRERiSAEvIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjF0\nTAS8G4QM1V3cIDzSpTRN6Di4/f2EjvOytncDl4HqEG7gTnFlIiJyLDrgF90caUEY8T+bBlg9WmbU\n9SmmEiwt5nnD3G6svb4n/1gWBQED376D8oon8YeHSXR2kj9zGT1vezvGIdy1LwgD7l7/A54eeJYR\nZ5QOu8jpPafw1kVvxDJ1K14RkePVUR3w317TxyP9o43HI67feHzJvJ4jVdaUGvj2HYze+9PGY39o\nqPG49+1/fNDt717/A37e93Dj8bAz0nj8tiX/3xRXKyIix4qjeor+qZ2jkz6/erQci+n60HEor3hy\n0mXlFSsOOl3vBi5PDzw76bJnBp/VdL2IyHHsqA744Zo36fOjrk/J81tczdTzx8bwh4cnXzYyjD82\ndsDtx5wSI87kL4KG66OMOaVXXKOIiBybjuqA78wkJ32+mEpQSB7VVxcOSaK9nURn5+TLOjpJtLcf\ncPt2u0CHXZx0WWe6SLtdeMU1iojIsemoDvgzpk0eXkuLeVLWUV36ITFtm/yZyyZdlj/zTEzbPuD2\nKSvF6T2nTLrstO5TSFmpV1yjiIgcm47q0+C3nTSbWtWb9F30cdHztrcDu665+yPDJDo6yZ95ZuP5\ng3nrojcCu665D9dH6UwXOa37lMbzIiJyfDKiKIqOdBEHMjBQwg1CSp5PIZmIxZn7ZELHwR8bI9He\nftAz98m4gcuYU6LdLhz2mXtPT4GBAV2vbyb1uPnU49ZQn5uvp2dqLq8e1WfwL0lZJl0xn242bZtU\nb+/L3j5lpejJdk1hRSIiciyL5+mwiIjIcU4BLyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIx\npIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRg6IgE/NDTEa1/7WjZs2HAkhj+o0HFw+/sJ\nHQcAxwvoH6nieMGEn/fH8wLGRmp4B1jnpf3u3DlKedv2xljNsvcxHWp9IiJy7Gr5zWY8z+MTn/gE\n6XS61UMfVBQEDHz7DsornsQfHsbs7OKBuRewxuhiaNwhnTIBA8cN6GyzOXNJD5e/bhGWuet1UhiG\nPHL/BjY+P0h53CHfZrNgSTe/9boTMM09r6WCMOTOe5/H/tn/MGdkI21+BTfbRu955zLt8ndgWFZT\njskdHmHDzN9mID+PWpBo1PfmPzpjysYTEZGjQ8vP4G+55Rbe/va30/sK7pzWLAPfvoPRe3+KPzQE\nUcRPmM8vSnmGxnedXdfdkLobEAFD4w73/rqPO+9f39j+kfs38Myvt1LevX553OGZX2/lkfsnzlTc\nef96oh9/j9MGVlH0K5hAujrO+P33MvDtO5p2TOs6z2Jz5gRqQWJCfT+557kpHVNERI68lgb83Xff\nTWdnJ+eff34rhz0koeNQXvFk47FnWDyfn3PQ7VY8P4jjBXhewMbnBydd58XnBxvT4Y4X8PSa7Syu\nbJl03dKKJ6dsun7vYwoMi8H83EnXe37VDk3Xi4jETEun6L/zne9gGAaPPvooq1ev5rrrruOLX/wi\nPT09+91mqm58fzC17RX84eHG47KVYTyRO+h2I6U6ViqJDZRLkwdzueSQTiXp7M6xfbCCNzJKm1+Z\ndN1gZIQ2yyfT0/2yjmNvex+TY2Wp7+d4xkZrjfqkeVr1t3w8U49b
Q30+NrQ04L/1rW81fr7yyitZ\nvnz5AcMdYGCg1OyyAAiDBInOzl3T80A+qNHmVxhPHvgPuaOQJnA96kC+YDem5/eWL9jUXY+BgRKB\nF5DsKDK+NUdxkpC3OjoYDxKUp+C49z4mO6iS9ivUJzme9mKmUZ80R09PQf1tMvW4NdTn5puqF1D6\nmNxupm2TP3NZ43EyClhSnnwafW9nLunGTlokkxYLlkx+1j1/STfJ5K43ztlJi9NPmsG63OTT/4Uz\nl2Ha9ss4gn3tfUxWFNBd3jzpektOnd6oT0RE4sFavnz58iMx8Fvf+lY6OzsPul616ragml2yS08h\nrNfwx8YJnTqLMi5h7wyq6QJ1JyCdskhYJmEY0dmW5rdPm87lr1uEaRgAzJ7fgev41MounhtQaLM5\n8bTp/NbrTsDYvQ7AyfM7eNbsYah/FNutkgpdnFw7neefT+8fvQPDnLrXXXsfU3FkI2G6gJcu4GM2\n6nvjW0+jVvOmbEzZVy5nt/Rv+XikHreG+tx8udzUnOQZURRFU7KnJjkSU0Gh4+CPjZFob8e0bRwv\nYKzs0J7f1fSXfrb3c9breQHVsks2nzrgmbHjBYwOl8gFNbJdnVN25j6ZvY8pMBMT6tOUW/Opx82n\nHreG+tx8UzVF3/LPwR8LTNsmtdfH+OykRW9HtvF4758nk0xatHdkDjqOnbSYNq0IFF92rYdq72My\n4ZDqExGRY5euwYuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBER\nkRhSwIuIiMSQAl5ERCSGjvqvqg1Dj8ArYSULmGbysLZ1vIDR8TIp08EJ0+TzGYZrVUbGRumx22gv\npqn4Y6SCiLqfo6Mt1/h++brvMFYdpz3bRjpx8O+If+m73utZm51+lRnZDvKpNABuEFLyfArJBCnL\nbOx/sDQECYPebBcpK7XPcY+VBtlecZhe7CEyk43tPS9gbLxOPRFgWg5tdgEnMBvLy+Nj7Njcx/S5\ns8m3te8z/m9+176IiMTPUR3wW9Z8n6HtzxB4Y1jJdjLFE+mY9fsYxoEnHoIw5M77nydTf4Qtw2k2\nDhfxZ3cTWiuYNpSlMN7ByLQXedXcAbZuXsjGgR7G6jYduYhlJ83CnllmXTVFiTQFtrE46/KWpctI\nmPveOCYKAga+fQejK1dw57m99LdXCalimXlm5ZZwUs+FPD9WZdT1KaYSnNSWoVbZxBMjq6j4W4mi\nMraZ5bwZZ/CHi9+EaRjs3Pwjvvj0cwwmxrELZ5BMzMcy87QnLGa8WGFHMmBHbiWusQU7eQrJ5AJM\nM0fRgJk/ewzHKeJaOZLhC5ROacedP5txL6AjYXL+r35Gz4bV+MPDJDo7yZ+5jO4P/mmzfoUiInKE\nHNUB37/54cbPgTdGeeBxADpnX3zA7e68fz1W6WH6azZPbZtOYXE7YeIJpg+n6B6cw7Y5z3HGgh0M\nvngiT22Z1dhupGKw09rOYHXPvdpLZHmymoXVT3LZKefsM9bAt+9g9N6fcucbT2ZH++CeesMy/U6K\nsYHxPft3fTYOrmdj/QVcb23jeSes8sDWRzAMk9dnU3zx6WcYzA2RTp2HbZ+2Z7DnhunLWAwUV+F6\nz5FOTlzec/8jlMJ5jd9q/0mzKM8qgBcAcMIDP6Jj1a/wd6/vDw0xeu9P2ZhOUXjL2w7YUxERObYc\nc9fga6PPE4b7v3e54wU8s34HJ3SNsKa/C0yDZEeCKLGdtpFphGZAtWMn863kruV7SSZCqrnJ71G/\nrpqi7jsTngsdh/KKJ6nZCfrbqr+xhUUiMe83nvEZi9J4/ouTjrFq4BkGBlczmCjvs70RhKSHa1S7\nkru3n7jcqlcJ3T21h6ZBrXvPHeMsz2Pui89POu7w478idJxJl4mIyLHpmAv4wBsj8PZ/L+KxsoPn\nlDCNiLG6jWWbREaNFBFJN42X
rJPO1DG8NGP1idef8/mQqjn5bVTLpBmrjk94zh8bwx8eZqgjT2hM\nDHjTyGKa+QnPZahTjiKiqLyfYyuxs1wnSjn7bG86IUZk4NkOUVTeZ3lmbAjXyjUeh7ZJkN5zSSFb\nLZMrj006rjM4iD82+TIRETk2HXMBbyXbsZKF/S5vz9sk7QJhZNCedgicECPK4GLgpeokvTT1Wpoo\nWac9PfGstVw2yYa1Sfebp057tm3Cc4n2dhKdnXSNlDGjifeID6MqYTgxyGukyRsGhjEx+PccW4Fp\n+TSGa++zfWibREZE0rExjPw+y2vtXaSCSuOx6YRY9aDxuJrNU8m3Tzqu3d1Non3yZSIicmw65gI+\nU1xywHfT20mL0xZNZ8NQByf1DkEY4Y34GP4Mxjt2YoYW2ZFpvBh4u5bvxfNNspXhSfe7OOvu8256\n07bJn7mMjOPTO579jS0CfH/TbzyToN2ok0zMn3SMU3tOo6d7Kd1+fp/tI8uk3pkhO+Tt3n7i8iCd\nxUztqd0MIzKDe16sBMkkm+cvmXTcznPP0bvpRURixlq+fPnyI13E/gRejXp1nCh0sJJFcl2n734X\nvXHA7U6e38Ez23IUrc1kE3XGdgZE+YWUslsxjBrF4em86PksnNtHDou6a+P4Jh25iMVdM5nTOU7Z\nC/GwKFDj1GyFtyxdhjnJu/ezS08hrNdY9OQLbOruoZqGCA/LLDAjU+SMnpOo+gFOENKRSrC4YxrT\nrQRDbogXOYCLbWb5nZnn8IeLLyHbtoil6RLP9VUZ9zeBZWIYWUwjSWZajpljAYbXQz0V4obrIYow\nzRyGkSRcOIfuTc+BZxIYSdoGhkikPVJdRZwworZwMb1GSKFeJXTqJDq7aPvt32bJ/30v1dr+39cg\nr1wuZ1Otuke6jFhTj1tDfW6+XG5qTriMKIqiKdlTk+zcOazPwTf5c/A9PQUGBvb/vgZ55dTj5lOP\nW0N9br6env1fhj4cR33A6w+p+fQPtvnU4+ZTj1tDfW6+qQr4Y+4avIiIiBycAl5ERCSGFPAiIiIx\npIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxlDjSBRyM\n5wVUyy4Jw6c2OIbdk6eeDMgkcziBSQYDr+aTzaeITGOfm7oc7lij42X8pENXrkjKSjXGz+ZTJHff\niOalm73UTKBawaoPMZboZHZPgbRVxwhT1McqjCXTECVoyyYZr1WJrAqFhEU1zJF0Tdrb01hW2LiZ\nDkCtPkw5iuhIdzZuQON4AaPDJWyngpvMERXSFG0D3ykxXoKUW6FuGXT29JKx99zgZqhexXMrZIMM\n7e05kklr0hvfiIhI/LQ04D3P42Mf+xhbt27FdV0+8IEP8PrXv36/6//oe6t4dsUWkv1bqCbzvLhw\nM+WOnZhtZ2Bb82nfEJAbqmPWAyond1DvyVK3oJhKsLSY5w1zu7EOcmtZgDAMefj+dTww8nMGc9vw\n7RrZKM/JI68mMdBGedwh32Yzf3EXIfBYrcrWoTKX2Q/yU38ZO4dTvG7eL3B7Bsk8uZ1Hi6/mKX82\n46MBkRuQnLOG/zN3kHJqGSPrukkM+iSdgJNP2cjs6cMkrSqhkeRn1SrPuy7jYUQxYXN679n4m5Zg\n//wHYLSz4YzTqEzLc1ZmFX3PB3Q/s5GBWTPZEExjzEnTbq/i1PkJus44jV9uu5fcOoO24U6SbgYz\nE2Ism8NYW5Ix15/QIxERiZ+WBvx///d/UywWufXWWxkdHeUtb3nLAQP+8Yc2UqxuZzQ7g21znmN4\nxibSqfOw7dNoXztKW18FgJHFbZSnZxvbjbg+j/SPAnDJvJ6D1vXI/Rv46c77GJ7xYuO5ts1zqO+0\nAQeA8rjDqie2MbK4jW01nz+2f86Po7Pp25biohNf4DXzt+E9NMgjuXP5lbmIan8ZgMSctfz+gh
24\nqVfTv2Z6o+alJ25gwaxtjfHur5R4wvEbj0d9h3sfHeR3V63GThR45tyzKM8t8BrjCbY971FcsZXh\nedN5sjq/sc2Yk+bpsB2z76d0rg/o3rmgsWx4dgfltAGuv0+P3tPbdtAeiYjIsaWlc7QXX3wxV199\nNQBRFGFZ1gHXN0KfWrJAaAaMd+wELBKJeRhBSGawBkBoGtS6M5Nuv3q0jBuEBxzD8wLWr9uxe/+7\nxw1M2kam7bNuaBpUOtIwOErb9IhtA2mSZsBJvUNEXoiz2WXTnMXUd9eGGWB37GRBMsVGf1ajZtMM\nmN47tKeGKGKdF0wYKwpMjKEuTqhuo79tHrXuDAl8ZkdbWb+jnfnuNp4PZkws0DRIdSXxnc0T6j9Y\nj5yD9EhERI49LT2Dz+VyAJTLZT784Q9zzTXXHHD9VFDDSeTwkjV8u4ZpFDDNPGYtJFHfFUqhbRKk\nJ3+hMOb6JAtpenL2fscYHqww6ozj27XGc0kvTdJN77NuaJt4JvQEowwlivj1kI6MS3vaISoFVMM0\n46k8YX3XWbqRdChkHBJmO7WqTbE+DkDadkmnncZ+y2HEeBhNGCvybPK1CDuKqOYKBGmLNspETkBQ\nCbHSJmPOxBot28S0XRJVf0L9B+vRWN2jd4ruPyz7N1X3eJb9U49bQ30+NrT8TXbbt2/nqquu4oor\nruBNb3rTAdd1rQy2XyX0MiScDH66ShiWCe08ftokWQ8xnRCrHhBk9z2U9lQCr1RnoOrudwzPCyja\nbbv3vyvkvWQdL1Un5WYnrGs6IckQBqwiXf5aEmmTspNirG5TzIZkzTptbpnhtEVYD4g8m1LNxs9X\nyNhOo+a6k6Jet8lmd4V83jRoM40JIW8kHcoZA8cwyFZKWPVeqtkMhm1h5UyC/pB2u86Ys+fMPHBC\nQieFn0pMqP9gPWpPJxkYKB3wdyGvTE9PQT1uMvW4NdTn5puqF1AtnaIfHBzkve99L9deey2XXXbZ\nQdePzAQZr4QZWrunnAN8fxORZTamnM0wakx9/6alxfxB3ymeTFosWjx9wpR2ZIUTpuxfYoYRuZE6\ndBcZ32Ews6eOF1qs6e/CSJrYc1PM27KO9EvT4aGFMzKNjZ7LgsTWRs1haLGjv2tPDYbB4uTEM2zD\nCom6htiQnUnv+CYygzV8EvQZs1g0fYwXUzNZYm2fWGAY4Q55JOy5E+o/WI9svZteRCR2rOXLly9v\n1WCf/exnefbZZ1m/fj3f/e53+e53v8sf/MEfkEhMPpFQq3nsLEF2dCu58R6cpIFjbiBMmPhdnRhB\ngoQXkun7V8/aAAAaWklEQVSvY2USGJkkgQkdqQTLutp4w9xuzEN4F/3s+R0kB9oZGh7HMeuElk/U\nVmdGZjrpIIfnBhTabJacOo05doqKFfJIaSaXJH7JYK6XVTu6SeLTtshi/ua1hHaW0UwR14sIRzp5\nMaqzKPsCyd48Jb8AbsRwf5FkOiSfC7BMj/mpDG4UUQ5D3AiKiTSvOWkRYfdvw6bn6No+TDVd4MXU\nbJZMG6I/VaBrfR/tXQFV08YNLIq2w2kdDkuXvIaNqRfx3RKWm8AKLLL1GsWeNsxsEicIJ/Qon7Op\nHmCWQ165nHrcdOpxa6jPzZc7wGXlw2FEURQdfLUjZ9u2UX0Ovsmfg9eUW/Opx82nHreG+tx8UzVF\nf9QHvP6Qmk//YJtPPW4+9bg11OfmOyavwYuIiEhrKOBFRERiSAEvIiISQwp4ERGRGFLAi4iIxJAC\nXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiqOW3iz0cgePg9veTaG/HTxgMVIZwnJCEl6Qj\ncjByULYMLN9kaNykK6jhJwNCO0F7AKEVMFqqYaWnEeVtTMuhO9NOykpR9x0GxkYhsuko5Cj5DmEw\nTkfCwrLaqVcjgrBG//aNTJvRTSHbhTtSxrEyZPNJKI3gpU
MqySQd6U4AhspjJD2b9vYclhVO+F75\nhB/hj40R5jOUcGm3C43vuh8YrjDqh8xsS1Otuoz6ITM6MoSmge261HYOku1KYbcVcepVxktQiDyy\nXZ0EZoJq2SWdijAqZRLt7Zi2TRh61Kpjje+qLxk+Zj5Jm5Wk5KYw7YCe3d+3LyIi8XNUB/yKD11N\nbXCQX5zXTfmk83lxfZGzn32EM04e5t45Jp3pFE+sP4+Tnn2aE860eKh3jO4nh1nSm2btdoeR8qms\nO+1V9Hf/CpctRJTJJNo4MXsez63JUhp0yc7JYXY9w/l2H4uskJXrF7F9R5G20oOcujAgPSvDwL0l\nHh46icHkTOaMPcfs00f4xRyT9UbIeAQpLHo2LyU33E3KS3PSyevZ0LGVDX6dkh/y+qdqLNwe8PgJ\nFi/MzTCeMei0O5i/bRnPjRYYK7vMcENKxRTlmk9uboFMl83Jv17NggVD9EwfpTTqcu/aeXQ83ce8\n8W3k/SrPT/stRtvnMnvnU/RWt2B7ZZLdXST+zxx+3D+NtpUb6J81g3XTSvzOjBG2b1nA+sQYfmEI\n065jk+e82a/i/V3vONK/ahERmWJH9RS90z/AQ2dkKZ3826x7YTpnr3yE807q57FFCdrTCZ7a+Fuc\ntPJpTjrT5Lk5I7T9eojTu9L0DzpUhk9m7RmvZueM9TisJqIMwOzUqaxc08ZYX53cnAJW72rOT73A\nOWnYtn4RmzbPpm38QZbN9cmd2k64apy125fQl1vKnLE1LDh1gEcWJXiSXeEO0LlpCR075pJysyxd\n8gIbO19khVtjPIw4/8kyp6wp8/gJJk+dlGU8a4ABqXXTWTWcY2S4ziw3pFxMMT7qkptbIDe3jaW/\nWsPcE8aYt7CfbNblp+sWkFuxnVcNr6PoV9jQdTbbC4uZvfMp5o6tJu2VMYDoxIgf9neRfWITw7Om\n80xXndfOGWSkbxHPUiHs2YqZroMBjlHmga2/4PaV3zlyv2QREWmKozrgPQtenJ1nnDn4/SUWO1sI\n52fZ6AXMNtP070ywyNlGarbPCzWfhTtcgtlpcptc+tvmUe1K4vkvNvaXwGKcOdQHa2AapLqSRP6L\nLE5ZBIHJzv4uAurMHx3GWpAj8kLcjQ6D+bmYoU/37vHXeUFjn0Zg0jYyDQDTDOjqGWwsT/gRC7c6\neBZsmJWesE1udDqlmo8JFAyDUs0H08DuzpCou4SezfTeIQDcwGT9jnYWV7YAEBjWnpp2P7drQINw\nboH1O9qZ725jbdiL3bGT+VaS1QNFrI6dk/b58S0rcQPd31lEJE6O6oCvZCyiQp6ykyZbKdOecqjk\nLELA8dpJVyqkswZ+yiWs+LRhUDMg6aSp5gp4tkMUlRv7y1k5yk6asB5g2Sam7ZIzKrSZJnUnRa2e\nBmOUNhOMQoKoGuDUUtQTOeygSiZVp5KzGA/33GE36aVJurvCO227+Kl6Y3muFlCohFQyFqWcOWGb\ngAyhE5IEopRJ6IRYtomVTpAbK2FmTdJpB4CykyKohLT5FQAcK9uoKe3vOT4ja1FOZQkqIVbapBQZ\nFDIOhpemFIJh1yft81BthDFHt38UEYmTozrgc7UAo1Qmb9ep5vKMuTa5SoAJ2Mkx6rkc9WpEwk1h\n5hKME5GJwLPrZCslko6NYeQb+6sEFfJ2HTNtETghoZOiEuUYD0PStksmXYeoyHgIUcnHyFrYGZe0\nX8GxstTcNLlKQJtpNPbpJet4qV3BWXdSJNx0Y/lLwf5S0O+9jUUN0zbxAMMNMW2TwAkJ6j6V9gJh\nNaRetwHI2y5WzmQ8kQPYHeyV3UG/5/iiakDerWLlTIJ6SMGIKNVsomSdggmRs2cWYW9dmQ7a7am5\n/7CIiBwdjuqATwYwv69MG1tI9BZYZ8/BfLHKgqRFX1ind5rPensmbl+ChZkEL0xPYfXVqcxL0Tu+\nieyQRzIxv7E/n4A2tp
DuzkAY4Q55GIn5rHMDLCtkWu8QFmleLHYSbKxgJE1SC2y6y5sJzQSDu8df\nnLQa+4yskPHdU99haDE00N1Y7icMXphlkwzghK31CdtUijsoZBKEQCmKKGQSEEY4gzX8dAoz6bCj\nvwuAlBWyaPoY63JzALCiYE9Nu5/bNWCEubnEouljvJiayYlmP87INF4MPJb2jBLsvpTwm86d8yq9\nm15EJGas5cuXLz/SRexP//0/Z+bGcXbag7Qt7uK51GLC1cOcbVXYkItY0LONZ1JnwdM7ODGbZcti\nk9ILJeb2pHGDHeQ2ezjZk3GyBqFRB1zq0RinzOih7GYo76iRSMxki+2RYJx53UOYQYIh/0SG+7fQ\n4ddInJCns7odbwx2ZBcQbqlySrKK325SMcAB3OIIZpDCcpOMDHSzuBBh5+pUI5/105LkfZNTN4cE\nYUA1m8BNGNi9AUv8bspGjoEgpKviYxRTVIZ2vRAYXTSd1LMlDExStseS3iHWJGczPp4k7blMq2zC\nTebo7zqRKAiwwzpW4GHVspx4WpI17bMorttCPtHBk0GCE2dtIzPey0g9RWB5GJaPTZ7fmX0Of3LO\nH1GreUf0dx13uZxNtar3OTSTetwa6nPz5XL2lOzHiKIoOvhqR0bgOOxY36fPwTf5c/A9PQUGBnQN\nvpnU4+ZTj1tDfW6+np6puWR6VAc8oD+kFtA/2OZTj5tPPW4N9bn5pirgj+pr8CIiIvLyKOBFRERi\nSAEvIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiSAEvIiIS\nQ4kjXcCBbNo5wkh5mPZUgWotoj1vYyctPC9gdKSG67m0Gw7Zrk78hMGYUyJjZhkf9yECK5ekI5Mi\nZZmEjkNtcIA6Abm2DsxaHTOTwS1VGjeQ8UolxpJpiBJ0t6ex97otrBu4jNaGMQKP4dBgVq6bfCqN\n5wVUyy7ZfIpk0sLxAsbKTqPW+vgY4zu20jZ9FlYmv0/dpm039j/mlGi3C0Q1n9KOYQrTO7Hz2X2W\nA42f977Nq+vU2TkwwHiUYV5PO/l0ckI/967VskICr4SV1H3gRUTiqKUBH4Yhy5cvZ+3ataRSKW66\n6SbmzZu33/U/+sMbCXbOJRydQeTadBVsTkyn2DJa4aRtj3FCZQtjQYUHzulg87wcQ1sXMmNoDtbC\nLrzuDEHGIuuHvP6ZB0l1lshNd0g8uYPBFyqElZD13ecwmJ3FrLE1bD1lPk8zg9JYSOiEZLIJfmvp\ndN524QK+v+F/sUefYYWbZ9AfJYyqWORZuP0susY6KY87ZAspBjIJdtY9hscduvMml85aRbG3TCrl\n8uiji/n19k5OHljBCZUtlPwKbraN7vPO4RdnFXh6cDWj1VHOWbEY35uOY2ZJh3VmFn3GXh/yzNBq\nhp0RbNMGY1fgd9hFTu85hTcvuJjnHr+bu1/oYMeQjV8PSdgmc+cV+Ys3n0bCMHjk/g1sfH6QSqnO\naaduYvq0QZJWFSvZjjtyGnbn72IYmtAREYmLlgb8vffei+u63HnnnTz11FP87d/+LV/84hf3u37Q\nP5egf0Hjcbbksr7kctrA45w1tgaAB5blWLUoibtpATP752MvLlKeu+es9JTH7yM1vULPojreQ4ME\nz4xjAOu7zqGveDKLBx5j+8mzeCK1gGpfubFdrepz3xN9bDQeZWl2Hc+E3fR72xrLezbNwd6Zo4wD\nwOpSnf69brB0Se/TzJw/CsCzqxfy6PZpE+oGSFfH+cHoL3lqaw6As544gQqLYPfEQd3K8nDxOYa3\nvdjYxgmdxs/Dzgg/73uYGX07uHf7XPq2JoEQAN8JeeH5YW79/ipe25blmV9vBWDpiRuYM3PPcQTe\nGP2bHyZfc+mcffF+fxciInJsaekp2xNPPMH5558PwBlnnMGqVasOuH4wMq3xswm0AeXQ
Z3FlCwCe\nBRtmpYkCk2h4Gu2mSa0709jG8jxmbdlAbk5A5IUEG6u79mtYDObnYoY+xfp2Ns1ZTH2wtm8BZsCA\n/wJzEkl2eKONp43ApG2v2gIiRvfaLINDz6zKrmWBybadXRPqfslL9QMkXYPInzlheWgGjHfsPGCP\nkoFJNjnOtoH0pMs3bx7hhecHdx2OGTC9d2jS9WqjzxOG3gHHEhGRY0dLz+DL5TL5fL7x2LIsfN8n\nkdhPGe6e0EoCBpAKarT5u8KzkrEo5Uwi1ybppUlkTIL0nuvm2WoZ2wpIpz2iUgAlHwDHylJP5Mj4\nZSIbxlN5wnpln+GNpEPOruPQSRjtCcaklya5V20e4O613cxkhXR6V1jWnRTjjk0qKDXqfslL9QMU\nSklcKzdhuZes49uTvPDYSzFMU47S+PVw0uVmPaRS33XWn7Zd0mln0vUCb4xiIcTO6pp8s0zVPZ5l\n/9Tj1lCfjw0tDfh8Pk+lsifkwjDcf7gDpOrg7nqTmQdEgGtlGE/kKPoVcrWAQiVkPOvgJuv4TgKr\nHhBkd+2zms3jBBb1epJMNoRCAko+dlAl7VdwrCyGA21umeG0RVgPJgwfeTYVJ42dq2MaWcJo1wyA\nl6zjpeqkdteWBFLsCfltXo56PUk265G2XdpsBzfcU/dLXqq/VLAoFTxSQQU3secfTtJLk3Ay+On9\nh/yoWSdv1EmkzUlDPkyb5FI2lXGHupOiXrfJZvcNeSvZzmjJxKyU9lkmr1xPT4GBAfW2mdTj1lCf\nm2+qXkC1dIp+2bJlPPjggwA89dRTLFmy5IDrW3tNT4fAOJA3E6zLzQEgGcAJW+sYVojRuZOxMCSz\n11R7kEyydc4JVLZYGEkTa8GuQLaigO7yZkIzwWh6BvO2rCO919T+nkEtehIL2eJ7TE8WG09HVjhh\n6tzCoLjXZjVsBnZfV7eskJnThibU/ZKX6gfwUhFGYtuE5WZoTbgUMBnPCql6bczsqU+6fO7cDhYu\n6d51OKHFjv6uSdfLFJdgmslJl4mIyLHHWr58+fJWDbZw4UIeeughvvSlL/HQQw+xfPlyOjs797v+\nt9fdBUTgZSCwSBZslrRn2JjsxfMcskGdE7ZVqGVsvDkV+kODzLYkKdMiSlpElsHozAXM27wTpwbW\n4jRW6BFWfDrGtuKbKXYUFjKjbyP5LhjLteP5EAUR2VyCC06fyf+98ALW1KvM9Poo0UUtConwqBXr\ntJndFCnguQEzCzbZYprQMnDcgD5vBrPDElYqYFrvIKm6zdPRQvzAJxvUSYYubradMxadi3XiIkpe\nmY3d25m/08D0UgRGgnRY4xTyTH/VbMpemZpfJ23ZWGaCMArpSnfw6uln8/pXXUZXZQVbgxRVL0Ho\nQyJtMn9hB3/x5tOYt7AT1/GplV2272gjk4FM1scyfaxkkZ7ZZ5Of9noMw2jVn8JxJ5ezqVbdg68o\nL5t63Brqc/PlcvaU7MeIoiiakj01gT4H35rPwU+b1qkptybTtGbzqcetoT4331RN0R/VAQ/oD6kF\n9A+2+dTj5lOPW0N9br5j8hq8iIiItIYCXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiSAEv\nIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjF01H8XvYiIiBw+ncGLiIjEkAJeREQkhhTwIiIi\nMaSAFxERiSEFvIiISAwp4EVERGIocaQLmEwYhixfvpy1a9eSSqW46aabmDdv3pEu65jleR4f+9jH\n2Lp1K67r8oEPfIBFixZx/fXXYxgGixcv5q//+q8xTZO77rqLO+64g0QiwQc+8AEuvPDCI13+MWVo\naIi3vvWtfO1rXyORSKjHU+xLX/oS999/P57n8Y53vINzzz1XPZ5inudx/fXXs3XrVkzT5MYbb9Tf\n8hRauXIln/nMZ7j99tvZtGnTIfe1Xq9z7bXXMjQ0
RC6X45ZbbqGzs/PAg0VHoR//+MfRddddF0VR\nFK1YsSJ6//vff4QrOrb913/9V3TTTTdFURRFIyMj0Wtf+9rofe97X/TLX/4yiqIouuGGG6Kf/OQn\nUX9/f3TJJZdEjuNE4+PjjZ/l0LiuG33wgx+Mfv/3fz9av369ejzFfvnLX0bve9/7oiAIonK5HH3h\nC19Qj5vgpz/9afThD384iqIoevjhh6MPfehD6vMU+fKXvxxdcskl0dve9rYoiqLD6uvXvva16Atf\n+EIURVH0P//zP9GNN9540PGOyin6J554gvPPPx+AM844g1WrVh3hio5tF198MVdffTUAURRhWRbP\nPvss5557LgAXXHABjzzyCE8//TRnnnkmqVSKQqHA3LlzWbNmzZEs/Zhyyy238Pa3v53e3l4A9XiK\nPfzwwyxZsoSrrrqK97///fzu7/6uetwECxYsIAgCwjCkXC6TSCTU5ykyd+5cbrvttsbjw+nr3rl4\nwQUX8Oijjx50vKMy4MvlMvl8vvHYsix83z+CFR3bcrkc+XyecrnMhz/8Ya655hqiKMIwjMbyUqlE\nuVymUChM2K5cLh+pso8pd999N52dnY1/gIB6PMVGRkZYtWoVn//85/nkJz/Jn//5n6vHTZDNZtm6\ndStveMMbuOGGG7jyyivV5yly0UUXkUjsuTJ+OH3d+/mX1j2Yo/IafD6fp1KpNB6HYTihKXL4tm/f\nzlVXXcUVV1zBm970Jm699dbGskqlQltb2z59r1QqE/7QZP++853vYBgGjz76KKtXr+a6665jeHi4\nsVw9fuWKxSILFy4klUqxcOFCbNtmx44djeXq8dT4xje+we/8zu/wZ3/2Z2zfvp13vetdeJ7XWK4+\nTx3T3HOOfbC+7v38S+sedP9TX/Irt2zZMh588EEAnnrqKZYsWXKEKzq2DQ4O8t73vpdrr72Wyy67\nDICTTz6Zxx57DIAHH3yQs88+m9NPP50nnngCx3EolUps2LBBvT9E3/rWt/j3f/93br/9dpYuXcot\nt9zCBRdcoB5PobPOOouHHnqIKIrYuXMntVqN17zmNerxFGtra2sEdXt7O77v6/8XTXI4fV22bBkP\nPPBAY92zzjrroPs/Km8289K76J9//nmiKOJv/uZvOOGEE450Wcesm266iR/+8IcsXLiw8dxf/dVf\ncdNNN+F5HgsXLuSmm27Csizuuusu7rzzTqIo4n3vex8XXXTREaz82HTllVeyfPlyTNPkhhtuUI+n\n0N/93d/x2GOPEUURH/nIR5g9e7Z6PMUqlQof+9jHGBgYwPM83vnOd3Lqqaeqz1Okr6+Pj370o9x1\n111s3LjxkPtaq9W47rrrGBgYIJlM8vd///f09PQccKyjMuBFRETklTkqp+hFRETklVHAi4iIxJAC\nXkREJIYU8CIiIjGkgBcREYkhBbxIjPT19XHqqafy5je/ecJ/3/rWt/a7zZ/+6Z+yc+fOVzTuY489\nxpVXXvmK9iEiU0tfDycSM729vXz/+98/5PW/8pWvNLEaETlSFPAix4nzzjuPCy+8kFWrVpHL5fjM\nZz7D7Nmzed3rXsc3v/lNyuUyn/jEJ/B9H9u2ufnmm5k/fz4/+9nP+Id/+AfCMGTOnDl86lOforu7\nm4cffpibb74Z27ZZsGBBY5xNmzaxfPlyRkdHSafT3HDDDZx88slH8MhFjk+aoheJmf7+/n2m6Neu\nXcvIyAjnnnsu99xzD2984xu56aabJmz3b//2b7znPe/h7rvv5sorr+Spp55iaGiIT3ziE/zTP/0T\n99xzD8uWLeNTn/oUruty/fXX84UvfIG7776bdDrd2M91113Htddey3e/+11uvPFGPvKRj7S6BSKC\nzuBFYmd/U/S2bfOWt7wFgEsvvZTPfvazE5a/9rWv5VOf+hQPPfQQF154IRdddBEPPvggp59+OrNn\nzwbg8ssv58tf
/jJr166lt7e38RXSl156KZ///OepVCqsWrWKv/zLv2zst1qtMjIyQkdHR7MOWUQm\noYAXOU6Yptm4NWUYhliWNWH5xRdfzJlnnsnPfvYz/u3f/o0HHniACy+8cMI6URTh+z6GYRCGYeP5\nl/YVhiGpVGrCC4wdO3ZQLBabdVgish+aohc5TtRqNe6//35g1/3rL7jgggnLr7nmGp5++mne/va3\nc/XVV/Pcc8/xqle9ipUrV9LX1wfAnXfeyatf/WpOPPFEhoaGWLNmDQA/+MEPACgUCsyfP78R8L/4\nxS/44z/+41YdoojsRWfwIjHz0jX4vZ1zzjkA/OhHP+Jzn/scvb293HLLLRPWef/7389f/dVf8c//\n/M9YlsX1119Pd3c3n/rUp/jQhz6E53nMnDmTT3/60ySTST772c9y7bXXkkgkJryJ7tZbb2X58uV8\n9atfJZlM8rnPfa4xcyAiraO7yYkcJ0488UTWrl17pMsQkRbRFL2IiEgM6QxeREQkhnQGLyIiEkMK\neBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJof8fOg9FfyxpJFAAAAAASUVORK5CYII=\n",
380 | "text/plain": [
381 | ""
382 | ]
383 | },
384 | "metadata": {},
385 | "output_type": "display_data"
386 | },
387 | {
388 | "name": "stdout",
389 | "output_type": "stream",
390 | "text": [
391 | "Sampled memory in 0.04 seconds.\n",
392 | "Updated parameters in 0.81 seconds.\n",
393 | "Sampled memory in 0.03 seconds.\n",
394 | "Updated parameters in 0.89 seconds.\n",
395 | "Sampled memory in 0.04 seconds.\n",
396 | "Updated parameters in 0.81 seconds.\n",
397 | "Sampled memory in 0.03 seconds.\n",
398 | "Updated parameters in 0.81 seconds.\n",
399 | "Sampled memory in 0.04 seconds.\n",
400 | "Updated parameters in 0.93 seconds.\n",
401 | "Sampled memory in 0.04 seconds.\n",
402 | "Updated parameters in 0.82 seconds.\n",
403 | "Sampled memory in 0.05 seconds.\n"
404 | ]
405 | },
406 | {
407 | "ename": "KeyboardInterrupt",
408 | "evalue": "",
409 | "output_type": "error",
410 | "traceback": [
411 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
412 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
413 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mq_iteration\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepisodes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
414 | "\u001b[0;32m\u001b[0m in \u001b[0;36mq_iteration\u001b[0;34m(episodes, render)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mupdate_interval\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mburn_in\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mclone_interval\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mburn_in\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
415 | "\u001b[0;32m\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch_size, verbose)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0mqmax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
416 | "\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
417 | "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# Forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv4\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
418 | "\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
419 | "\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/nn/modules/activation.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthreshold\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthreshold\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
420 | "\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mthreshold\u001b[0;34m(input, threshold, value, inplace)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthreshold_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthreshold\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthreshold\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthreshold\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
421 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
422 | ]
423 | },
424 | {
425 | "data": {
426 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfgAAAFXCAYAAABOYlxEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xm4nHV9///nfd8zc896zpw1+0oSCJsQFrEtWPTbghV/\nisWi9MKtV78u+BW0pVArNgqWUqxWaWtdqlZqC1hRS60boCyCoBACgSQkISQ52c5+zqz3/vsjYZJj\nTjY4M0nuvB7XxXWdmXv5vO/3OeE19+eemduIoihCREREYsU80gWIiIjI1FPAi4iIxJACXkREJIYU\n8CIiIjGkgBcREYkhBbyIiEgMJY50AQfi+wEjI9UjXUbsdXRk1ecmU4+bTz1uDfW5+Xp6ClOyn6P6\nDD6RsI50CccF9bn51OPmU49bQ30+dhzVAS8iIiIvjwJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp\n4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAL+ELhByFDdxQ3CI12KiIjIITmqv4v+SAuiiB9uHmT1\naJlR16eYSrC0mOcNc7uxDONIlyciIrJfCvgD+OHmQR7pH208HnH9xuNL5vUcqbJEREQOSlP0++EG\nIatHy5MuWz1a1nS9iIgc1RTw+1HyfEZdf9Jlo65PyZt8mYiIyNFAAb8fhWSCYmryKxjFVIJCUlc3\nRETk6KWA34+UZbK0mJ902dJinpSl1omIyNFLp6EH8Ia53QCTvoteRETkaKaAPwDLMLhkXg+/P7uL\nkudTSCZ05i4iIscEBfwhSFkmXVbqSJchIiJyyHQ6KiIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAi\nIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuIiMRQUwN+5cqVXHnllQCsXr2a\nK664giuvvJI/+ZM/YXBwsJlDi4iIHNeaFvBf+cpX+PjHP47jOAB8+tOf5oYbbuD222/n937v9/jK\nV77SrKFFRESOe00L+Llz53Lbbbc1Hn/2s59l6dKlAARBgG3bzRpaRETkuNe028VedNFF9PX1NR73\n9vYC8OSTT/Lv//7vfOtb3zqk/fT0FJpSn0ykPjefetx86nFrqM/HhpbeD/5///d/+eIXv8iXv/xl\nOjs7D2mbgYFSk6uSnp6C+txk6nHzqcetoT4331S9gGpZwH//+9/nzjvv5Pbbb6dYLLZqWBERkeNS\nSwI+CAI+/elPM2PGDP7f//t/AJxzzjl8+MMfbsXwIiIix52mBvzs2bO56667AHj88cebOZSIiIjs\nRV90IyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIi\nEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBER\nkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuI\niMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkNNDfiVK1dy5ZVXArBp0ybe\n8Y53cMUVV/DXf/3XhGHYzKFFRESOa00L+K985St8/OMfx3EcAG6++WauueYa/uM//oMoirjvvvua\nNbSIiMhxr2kBP3fuXG677bbG42effZZzzz0XgAsuuIBHHnmkWUOLiIgc9xLN2vFFF11EX19f43EU\nRRiGAUAul6NUKh3Sfnp6Ck2pTyZSn5tPPW4+9bg11OdjQ9MC/jeZ5p7JgkqlQltb2yFtNzBwaC8E\n5OXr6Smoz02mHjefetwa6nPzTdULqJa9i/7kk0/mscceA+DBBx/k7LPPbtXQIiIix52WBfx1113H\nbbfdxuWXX47neVx00UWtGlpEROS4Y0
RRFB3pIg5EU0HNpym35lOPm089bg31ufmOuSl6ERERaR0F\nvIiISAwp4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp\n4EVERGJIAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJI\nAS8iIhJDCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJIAS8iIhJD\nCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiaFEKwfzPI/rr7+erVu3YpomN954IyeccEIr\nSxARETkutPQM/oEHHsD3fe644w6uuuoq/uEf/qGVw4uIiBw3WhrwCxYsIAgCwjCkXC6TSLR0AkFE\nROS4YURRFLVqsO3bt/PBD36QarXKyMgI//Iv/8KyZctaNbyIiMhxo6UBf/PNN5NKpfizP/sztm/f\nzrve9S7uuecebNve7zYDA6VWlXfc6ukpqM9Nph43n3rcGupz8/X0FKZkPy2dI29rayOZTALQ3t6O\n7/sEQdDKEkRERI4LLQ34d7/73XzsYx/jiiuuwPM8PvKRj5DNZltZgoiIyHGhpQGfy+X4/Oc/38oh\nRUREjkv6ohsREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiSEFvIiISAwp4EVERGJIAS8iIhJD\nCngREZEYUsCLiIjEkAJeREQkhhTwIiIiMaSAFxERiaFDupvc+Pg499xzD6Ojo0RR1Hj+Qx/6UNMK\nExERkZfvkAL+6quvplAosHjxYgzDaHZNIiIi8godUsAPDg7y9a9/vdm1iIiIyBQ5pGvwS5cuZc2a\nNc2uRURERKbIIZ3Br1u3jksvvZSuri5s2yaKIgzD4L777mt2fSIiIvIyHFLAf/KTn6Srq6vZtYiI\niMgUOaSAv+666/jhD3/Y7FpERERkihxSwJ900kl873vf4/TTTyedTjeenzlzZtMKExERkZfvkAJ+\n5cqVrFy5csJzugYvIiJy9DqkgL///vubXYeIiIhMoUMK+L/8y7+c9Pmbb755SosRERGRqXFIAX/u\nuec2fvZ9n/vuu4+FCxc2rSgRERF5ZQ4p4C+99NIJjy+77DLe8Y53NKUgEREReeVe1t3kNmzYQH9/\n/1TXIiIiIlPkkD8m99JNZqIoorOzk49+9KNNLUxERERevkMK+Mm+h9513SkvRkRERKbGIU3RX375\n5RMeh2HIH/7hHzalIBEREXnlDngG/853vpPHH38c2DVN39gokeB1r3tdcysTERGRl+2AAf/Nb34T\ngJtuuomPf/zjLSlIREREXrlDmqL/+Mc/zj333MPnPvc5arUa3/ve95pdl4iIiLwChxTwn/nMZ3jg\ngQf4yU9+gu/7fOc73+Fv//ZvX9aAX/rSl7j88st561vfyre//e2XtQ8RERE5sEMK+Icffphbb70V\n27YpFAp8/etf58EHHzzswR577DFWrFjBf/7nf3L77bezY8eOw96HiIiIHNwhfUzONHe9Dnjps/Cu\n6zaeOxwPP/wwS5Ys4aqrrqJcLvMXf/EXh70PERERObhDCviLL76Ya665hrGxMb7xjW/w/e9/n0su\nueSwBxsZGWHbtm38y7/8C319fXzgAx/gRz/6UeOFg4iIiEyNgwb8Cy+8wJvf/GaWLl3KzJkz2bFj\nB+9+97v59a9/fdiDFYtFFi5cSCqVYuHChdi2zfDwMF1dXfvdpqencNjjyOFTn5tPPW4+9bg11Odj\nwwED/rbbbuNrX/saAP/4j//In//5n/Ov//qvfPKTn+TMM8887MHOOussvvnNb/Ke97yH/v5+arUa\nxWLxgNsMDJQOexw5PD09BfW5ydTj5lOPW0N9br6pegF1wID/3ve+x49//GP6+/v5whe+wFe/+lUG\nBw
f5/Oc/z/nnn3/Yg1144YX86le/4rLLLiOKIj7xiU9gWdbLLl5EREQmd8CAz+Vy9Pb20tvby9NP\nP81b3vIWvvrVr76iUNYb60RERJrvgAG/9zvlOzo6uP7665tekIiIiLxyB/ys297vbk+n000vRkRE\nRKbGAc/g161bx+tf/3oAdu7c2fg5iiIMw+C+++5rfoUiIiJy2A4Y8D/+8Y9bVYeIiIhMoQMG/KxZ\ns1pVh4iIiEyhw/++WRERETnqKeBFRERiSAEvIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjF0\nTAS8G4QM1V3cIDzSpTRN6Di4/f2EjvOytncDl4HqEG7gTnFlIiJyLDrgF90caUEY8T+bBlg9WmbU\n9SmmEiwt5nnD3G6svb4n/1gWBQED376D8oon8YeHSXR2kj9zGT1vezvGIdy1LwgD7l7/A54eeJYR\nZ5QOu8jpPafw1kVvxDJ1K14RkePVUR3w317TxyP9o43HI67feHzJvJ4jVdaUGvj2HYze+9PGY39o\nqPG49+1/fNDt717/A37e93Dj8bAz0nj8tiX/3xRXKyIix4qjeor+qZ2jkz6/erQci+n60HEor3hy\n0mXlFSsOOl3vBi5PDzw76bJnBp/VdL2IyHHsqA744Zo36fOjrk/J81tczdTzx8bwh4cnXzYyjD82\ndsDtx5wSI87kL4KG66OMOaVXXKOIiBybjuqA78wkJ32+mEpQSB7VVxcOSaK9nURn5+TLOjpJtLcf\ncPt2u0CHXZx0WWe6SLtdeMU1iojIsemoDvgzpk0eXkuLeVLWUV36ITFtm/yZyyZdlj/zTEzbPuD2\nKSvF6T2nTLrstO5TSFmpV1yjiIgcm47q0+C3nTSbWtWb9F30cdHztrcDu665+yPDJDo6yZ95ZuP5\ng3nrojcCu665D9dH6UwXOa37lMbzIiJyfDKiKIqOdBEHMjBQwg1CSp5PIZmIxZn7ZELHwR8bI9He\nftAz98m4gcuYU6LdLhz2mXtPT4GBAV2vbyb1uPnU49ZQn5uvp2dqLq8e1WfwL0lZJl0xn242bZtU\nb+/L3j5lpejJdk1hRSIiciyL5+mwiIjIcU4BLyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIx\npIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRg6IgE/NDTEa1/7WjZs2HAkhj+o0HFw+/sJ\nHQcAxwvoH6nieMGEn/fH8wLGRmp4B1jnpf3u3DlKedv2xljNsvcxHWp9IiJy7Gr5zWY8z+MTn/gE\n6XS61UMfVBQEDHz7DsornsQfHsbs7OKBuRewxuhiaNwhnTIBA8cN6GyzOXNJD5e/bhGWuet1UhiG\nPHL/BjY+P0h53CHfZrNgSTe/9boTMM09r6WCMOTOe5/H/tn/MGdkI21+BTfbRu955zLt8ndgWFZT\njskdHmHDzN9mID+PWpBo1PfmPzpjysYTEZGjQ8vP4G+55Rbe/va30/sK7pzWLAPfvoPRe3+KPzQE\nUcRPmM8vSnmGxnedXdfdkLobEAFD4w73/rqPO+9f39j+kfs38Myvt1LevX553OGZX2/lkfsnzlTc\nef96oh9/j9MGVlH0K5hAujrO+P33MvDtO5p2TOs6z2Jz5gRqQWJCfT+557kpHVNERI68lgb83Xff\nTWdnJ+eff34rhz0koeNQXvFk47FnWDyfn3PQ7VY8P4jjBXhewMbnBydd58XnBxvT4Y4X8PSa7Syu\nbJl03dKKJ6dsun7vYwoMi8H83EnXe37VDk3Xi4jETEun6L/zne9gGAaPPvooq1ev5rrrruOLX/wi\nPT09+91mqm58fzC17RX84eHG47KVYTyRO+h2I6U6ViqJDZRLkwdzueSQTiXp7M6xfbCCNzJKm1+Z\ndN1gZIQ2yyfT0/2yjmNvex+TY2Wp7+d4xkZrjfqkeVr1t3w8U49b
Q30+NrQ04L/1rW81fr7yyitZ\nvnz5AcMdYGCg1OyyAAiDBInOzl3T80A+qNHmVxhPHvgPuaOQJnA96kC+YDem5/eWL9jUXY+BgRKB\nF5DsKDK+NUdxkpC3OjoYDxKUp+C49z4mO6iS9ivUJzme9mKmUZ80R09PQf1tMvW4NdTn5puqF1D6\nmNxupm2TP3NZ43EyClhSnnwafW9nLunGTlokkxYLlkx+1j1/STfJ5K43ztlJi9NPmsG63OTT/4Uz\nl2Ha9ss4gn3tfUxWFNBd3jzpektOnd6oT0RE4sFavnz58iMx8Fvf+lY6OzsPul616ragml2yS08h\nrNfwx8YJnTqLMi5h7wyq6QJ1JyCdskhYJmEY0dmW5rdPm87lr1uEaRgAzJ7fgev41MounhtQaLM5\n8bTp/NbrTsDYvQ7AyfM7eNbsYah/FNutkgpdnFw7neefT+8fvQPDnLrXXXsfU3FkI2G6gJcu4GM2\n6nvjW0+jVvOmbEzZVy5nt/Rv+XikHreG+tx8udzUnOQZURRFU7KnJjkSU0Gh4+CPjZFob8e0bRwv\nYKzs0J7f1fSXfrb3c9breQHVsks2nzrgmbHjBYwOl8gFNbJdnVN25j6ZvY8pMBMT6tOUW/Opx82n\nHreG+tx8UzVF3/LPwR8LTNsmtdfH+OykRW9HtvF4758nk0xatHdkDjqOnbSYNq0IFF92rYdq72My\n4ZDqExGRY5euwYuIiMSQAl5ERCSGFPAiIiIxpIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBER\nkRhSwIuIiMSQAl5ERCSGjvqvqg1Dj8ArYSULmGbysLZ1vIDR8TIp08EJ0+TzGYZrVUbGRumx22gv\npqn4Y6SCiLqfo6Mt1/h++brvMFYdpz3bRjpx8O+If+m73utZm51+lRnZDvKpNABuEFLyfArJBCnL\nbOx/sDQECYPebBcpK7XPcY+VBtlecZhe7CEyk43tPS9gbLxOPRFgWg5tdgEnMBvLy+Nj7Njcx/S5\ns8m3te8z/m9+176IiMTPUR3wW9Z8n6HtzxB4Y1jJdjLFE+mY9fsYxoEnHoIw5M77nydTf4Qtw2k2\nDhfxZ3cTWiuYNpSlMN7ByLQXedXcAbZuXsjGgR7G6jYduYhlJ83CnllmXTVFiTQFtrE46/KWpctI\nmPveOCYKAga+fQejK1dw57m99LdXCalimXlm5ZZwUs+FPD9WZdT1KaYSnNSWoVbZxBMjq6j4W4mi\nMraZ5bwZZ/CHi9+EaRjs3Pwjvvj0cwwmxrELZ5BMzMcy87QnLGa8WGFHMmBHbiWusQU7eQrJ5AJM\nM0fRgJk/ewzHKeJaOZLhC5ROacedP5txL6AjYXL+r35Gz4bV+MPDJDo7yZ+5jO4P/mmzfoUiInKE\nHNUB37/54cbPgTdGeeBxADpnX3zA7e68fz1W6WH6azZPbZtOYXE7YeIJpg+n6B6cw7Y5z3HGgh0M\nvngiT22Z1dhupGKw09rOYHXPvdpLZHmymoXVT3LZKefsM9bAt+9g9N6fcucbT2ZH++CeesMy/U6K\nsYHxPft3fTYOrmdj/QVcb23jeSes8sDWRzAMk9dnU3zx6WcYzA2RTp2HbZ+2Z7DnhunLWAwUV+F6\nz5FOTlzec/8jlMJ5jd9q/0mzKM8qgBcAcMIDP6Jj1a/wd6/vDw0xeu9P2ZhOUXjL2w7YUxERObYc\nc9fga6PPE4b7v3e54wU8s34HJ3SNsKa/C0yDZEeCKLGdtpFphGZAtWMn863kruV7SSZCqrnJ71G/\nrpqi7jsTngsdh/KKJ6nZCfrbqr+xhUUiMe83nvEZi9J4/ouTjrFq4BkGBlczmCjvs70RhKSHa1S7\nkru3n7jcqlcJ3T21h6ZBrXvPHeMsz2Pui89POu7w478idJxJl4mIyLHpmAv4wBsj8PZ/L+KxsoPn\nlDCNiLG6jWWbREaNFBFJN42X
rJPO1DG8NGP1idef8/mQqjn5bVTLpBmrjk94zh8bwx8eZqgjT2hM\nDHjTyGKa+QnPZahTjiKiqLyfYyuxs1wnSjn7bG86IUZk4NkOUVTeZ3lmbAjXyjUeh7ZJkN5zSSFb\nLZMrj006rjM4iD82+TIRETk2HXMBbyXbsZKF/S5vz9sk7QJhZNCedgicECPK4GLgpeokvTT1Wpoo\nWac9PfGstVw2yYa1Sfebp057tm3Cc4n2dhKdnXSNlDGjifeID6MqYTgxyGukyRsGhjEx+PccW4Fp\n+TSGa++zfWibREZE0rExjPw+y2vtXaSCSuOx6YRY9aDxuJrNU8m3Tzqu3d1Non3yZSIicmw65gI+\nU1xywHfT20mL0xZNZ8NQByf1DkEY4Y34GP4Mxjt2YoYW2ZFpvBh4u5bvxfNNspXhSfe7OOvu8256\n07bJn7mMjOPTO579jS0CfH/TbzyToN2ok0zMn3SMU3tOo6d7Kd1+fp/tI8uk3pkhO+Tt3n7i8iCd\nxUztqd0MIzKDe16sBMkkm+cvmXTcznPP0bvpRURixlq+fPnyI13E/gRejXp1nCh0sJJFcl2n734X\nvXHA7U6e38Ez23IUrc1kE3XGdgZE+YWUslsxjBrF4em86PksnNtHDou6a+P4Jh25iMVdM5nTOU7Z\nC/GwKFDj1GyFtyxdhjnJu/ezS08hrNdY9OQLbOruoZqGCA/LLDAjU+SMnpOo+gFOENKRSrC4YxrT\nrQRDbogXOYCLbWb5nZnn8IeLLyHbtoil6RLP9VUZ9zeBZWIYWUwjSWZajpljAYbXQz0V4obrIYow\nzRyGkSRcOIfuTc+BZxIYSdoGhkikPVJdRZwworZwMb1GSKFeJXTqJDq7aPvt32bJ/30v1dr+39cg\nr1wuZ1Otuke6jFhTj1tDfW6+XG5qTriMKIqiKdlTk+zcOazPwTf5c/A9PQUGBvb/vgZ55dTj5lOP\nW0N9br6env1fhj4cR33A6w+p+fQPtvnU4+ZTj1tDfW6+qQr4Y+4avIiIiBycAl5ERCSGFPAiIiIx\npIAXERGJIQW8iIhIDCngRUREYkgBLyIiEkMKeBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxlDjSBRyM\n5wVUyy4Jw6c2OIbdk6eeDMgkcziBSQYDr+aTzaeITGOfm7oc7lij42X8pENXrkjKSjXGz+ZTJHff\niOalm73UTKBawaoPMZboZHZPgbRVxwhT1McqjCXTECVoyyYZr1WJrAqFhEU1zJF0Tdrb01hW2LiZ\nDkCtPkw5iuhIdzZuQON4AaPDJWyngpvMERXSFG0D3ykxXoKUW6FuGXT29JKx99zgZqhexXMrZIMM\n7e05kklr0hvfiIhI/LQ04D3P42Mf+xhbt27FdV0+8IEP8PrXv36/6//oe6t4dsUWkv1bqCbzvLhw\nM+WOnZhtZ2Bb82nfEJAbqmPWAyond1DvyVK3oJhKsLSY5w1zu7EOcmtZgDAMefj+dTww8nMGc9vw\n7RrZKM/JI68mMdBGedwh32Yzf3EXIfBYrcrWoTKX2Q/yU38ZO4dTvG7eL3B7Bsk8uZ1Hi6/mKX82\n46MBkRuQnLOG/zN3kHJqGSPrukkM+iSdgJNP2cjs6cMkrSqhkeRn1SrPuy7jYUQxYXN679n4m5Zg\n//wHYLSz4YzTqEzLc1ZmFX3PB3Q/s5GBWTPZEExjzEnTbq/i1PkJus44jV9uu5fcOoO24U6SbgYz\nE2Ism8NYW5Ix15/QIxERiZ+WBvx///d/UywWufXWWxkdHeUtb3nLAQP+8Yc2UqxuZzQ7g21znmN4\nxibSqfOw7dNoXztKW18FgJHFbZSnZxvbjbg+j/SPAnDJvJ6D1vXI/Rv46c77GJ7xYuO5ts1zqO+0\nAQeA8rjDqie2MbK4jW01nz+2f86Po7Pp25biohNf4DXzt+E9NMgjuXP5lbmIan8ZgMSctfz+gh
24\nqVfTv2Z6o+alJ25gwaxtjfHur5R4wvEbj0d9h3sfHeR3V63GThR45tyzKM8t8BrjCbY971FcsZXh\nedN5sjq/sc2Yk+bpsB2z76d0rg/o3rmgsWx4dgfltAGuv0+P3tPbdtAeiYjIsaWlc7QXX3wxV199\nNQBRFGFZ1gHXN0KfWrJAaAaMd+wELBKJeRhBSGawBkBoGtS6M5Nuv3q0jBuEBxzD8wLWr9uxe/+7\nxw1M2kam7bNuaBpUOtIwOErb9IhtA2mSZsBJvUNEXoiz2WXTnMXUd9eGGWB37GRBMsVGf1ajZtMM\nmN47tKeGKGKdF0wYKwpMjKEuTqhuo79tHrXuDAl8ZkdbWb+jnfnuNp4PZkws0DRIdSXxnc0T6j9Y\nj5yD9EhERI49LT2Dz+VyAJTLZT784Q9zzTXXHHD9VFDDSeTwkjV8u4ZpFDDNPGYtJFHfFUqhbRKk\nJ3+hMOb6JAtpenL2fscYHqww6ozj27XGc0kvTdJN77NuaJt4JvQEowwlivj1kI6MS3vaISoFVMM0\n46k8YX3XWbqRdChkHBJmO7WqTbE+DkDadkmnncZ+y2HEeBhNGCvybPK1CDuKqOYKBGmLNspETkBQ\nCbHSJmPOxBot28S0XRJVf0L9B+vRWN2jd4ruPyz7N1X3eJb9U49bQ30+NrT8TXbbt2/nqquu4oor\nruBNb3rTAdd1rQy2XyX0MiScDH66ShiWCe08ftokWQ8xnRCrHhBk9z2U9lQCr1RnoOrudwzPCyja\nbbv3vyvkvWQdL1Un5WYnrGs6IckQBqwiXf5aEmmTspNirG5TzIZkzTptbpnhtEVYD4g8m1LNxs9X\nyNhOo+a6k6Jet8lmd4V83jRoM40JIW8kHcoZA8cwyFZKWPVeqtkMhm1h5UyC/pB2u86Ys+fMPHBC\nQieFn0pMqP9gPWpPJxkYKB3wdyGvTE9PQT1uMvW4NdTn5puqF1AtnaIfHBzkve99L9deey2XXXbZ\nQdePzAQZr4QZWrunnAN8fxORZTamnM0wakx9/6alxfxB3ymeTFosWjx9wpR2ZIUTpuxfYoYRuZE6\ndBcZ32Ews6eOF1qs6e/CSJrYc1PM27KO9EvT4aGFMzKNjZ7LgsTWRs1haLGjv2tPDYbB4uTEM2zD\nCom6htiQnUnv+CYygzV8EvQZs1g0fYwXUzNZYm2fWGAY4Q55JOy5E+o/WI9svZteRCR2rOXLly9v\n1WCf/exnefbZZ1m/fj3f/e53+e53v8sf/MEfkEhMPpFQq3nsLEF2dCu58R6cpIFjbiBMmPhdnRhB\ngoQXkun7V8/aAAAaWklEQVSvY2USGJkkgQkdqQTLutp4w9xuzEN4F/3s+R0kB9oZGh7HMeuElk/U\nVmdGZjrpIIfnBhTabJacOo05doqKFfJIaSaXJH7JYK6XVTu6SeLTtshi/ua1hHaW0UwR14sIRzp5\nMaqzKPsCyd48Jb8AbsRwf5FkOiSfC7BMj/mpDG4UUQ5D3AiKiTSvOWkRYfdvw6bn6No+TDVd4MXU\nbJZMG6I/VaBrfR/tXQFV08YNLIq2w2kdDkuXvIaNqRfx3RKWm8AKLLL1GsWeNsxsEicIJ/Qon7Op\nHmCWQ165nHrcdOpxa6jPzZc7wGXlw2FEURQdfLUjZ9u2UX0Ovsmfg9eUW/Opx82nHreG+tx8UzVF\nf9QHvP6Qmk//YJtPPW4+9bg11OfmOyavwYuIiEhrKOBFRERiSAEvIiISQwp4ERGRGFLAi4iIxJAC\nXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiqOW3iz0cgePg9veTaG/HTxgMVIZwnJCEl6Qj\ncjByULYMLN9kaNykK6jhJwNCO0F7AKEVMFqqYaWnEeVtTMuhO9NOykpR9x0GxkYhsuko5Cj5DmEw\nTkfCwrLaqVcjgrBG//aNTJvRTSHbhTtSxrEyZPNJKI3gpU
MqySQd6U4AhspjJD2b9vYclhVO+F75\nhB/hj40R5jOUcGm3C43vuh8YrjDqh8xsS1Otuoz6ITM6MoSmge261HYOku1KYbcVcepVxktQiDyy\nXZ0EZoJq2SWdijAqZRLt7Zi2TRh61Kpjje+qLxk+Zj5Jm5Wk5KYw7YCe3d+3LyIi8XNUB/yKD11N\nbXCQX5zXTfmk83lxfZGzn32EM04e5t45Jp3pFE+sP4+Tnn2aE860eKh3jO4nh1nSm2btdoeR8qms\nO+1V9Hf/CpctRJTJJNo4MXsez63JUhp0yc7JYXY9w/l2H4uskJXrF7F9R5G20oOcujAgPSvDwL0l\nHh46icHkTOaMPcfs00f4xRyT9UbIeAQpLHo2LyU33E3KS3PSyevZ0LGVDX6dkh/y+qdqLNwe8PgJ\nFi/MzTCeMei0O5i/bRnPjRYYK7vMcENKxRTlmk9uboFMl83Jv17NggVD9EwfpTTqcu/aeXQ83ce8\n8W3k/SrPT/stRtvnMnvnU/RWt2B7ZZLdXST+zxx+3D+NtpUb6J81g3XTSvzOjBG2b1nA+sQYfmEI\n065jk+e82a/i/V3vONK/ahERmWJH9RS90z/AQ2dkKZ3826x7YTpnr3yE807q57FFCdrTCZ7a+Fuc\ntPJpTjrT5Lk5I7T9eojTu9L0DzpUhk9m7RmvZueM9TisJqIMwOzUqaxc08ZYX53cnAJW72rOT73A\nOWnYtn4RmzbPpm38QZbN9cmd2k64apy125fQl1vKnLE1LDh1gEcWJXiSXeEO0LlpCR075pJysyxd\n8gIbO19khVtjPIw4/8kyp6wp8/gJJk+dlGU8a4ABqXXTWTWcY2S4ziw3pFxMMT7qkptbIDe3jaW/\nWsPcE8aYt7CfbNblp+sWkFuxnVcNr6PoV9jQdTbbC4uZvfMp5o6tJu2VMYDoxIgf9neRfWITw7Om\n80xXndfOGWSkbxHPUiHs2YqZroMBjlHmga2/4PaV3zlyv2QREWmKozrgPQtenJ1nnDn4/SUWO1sI\n52fZ6AXMNtP070ywyNlGarbPCzWfhTtcgtlpcptc+tvmUe1K4vkvNvaXwGKcOdQHa2AapLqSRP6L\nLE5ZBIHJzv4uAurMHx3GWpAj8kLcjQ6D+bmYoU/37vHXeUFjn0Zg0jYyDQDTDOjqGWwsT/gRC7c6\neBZsmJWesE1udDqlmo8JFAyDUs0H08DuzpCou4SezfTeIQDcwGT9jnYWV7YAEBjWnpp2P7drQINw\nboH1O9qZ725jbdiL3bGT+VaS1QNFrI6dk/b58S0rcQPd31lEJE6O6oCvZCyiQp6ykyZbKdOecqjk\nLELA8dpJVyqkswZ+yiWs+LRhUDMg6aSp5gp4tkMUlRv7y1k5yk6asB5g2Sam7ZIzKrSZJnUnRa2e\nBmOUNhOMQoKoGuDUUtQTOeygSiZVp5KzGA/33GE36aVJurvCO227+Kl6Y3muFlCohFQyFqWcOWGb\ngAyhE5IEopRJ6IRYtomVTpAbK2FmTdJpB4CykyKohLT5FQAcK9uoKe3vOT4ja1FOZQkqIVbapBQZ\nFDIOhpemFIJh1yft81BthDFHt38UEYmTozrgc7UAo1Qmb9ep5vKMuTa5SoAJ2Mkx6rkc9WpEwk1h\n5hKME5GJwLPrZCslko6NYeQb+6sEFfJ2HTNtETghoZOiEuUYD0PStksmXYeoyHgIUcnHyFrYGZe0\nX8GxstTcNLlKQJtpNPbpJet4qV3BWXdSJNx0Y/lLwf5S0O+9jUUN0zbxAMMNMW2TwAkJ6j6V9gJh\nNaRetwHI2y5WzmQ8kQPYHeyV3UG/5/iiakDerWLlTIJ6SMGIKNVsomSdggmRs2cWYW9dmQ7a7am5\n/7CIiBwdjuqATwYwv69MG1tI9BZYZ8/BfLHKgqRFX1ind5rPensmbl+ChZkEL0xPYfXVqcxL0Tu+\nieyQRzIxv7E/n4A2tp
DuzkAY4Q55GIn5rHMDLCtkWu8QFmleLHYSbKxgJE1SC2y6y5sJzQSDu8df\nnLQa+4yskPHdU99haDE00N1Y7icMXphlkwzghK31CdtUijsoZBKEQCmKKGQSEEY4gzX8dAoz6bCj\nvwuAlBWyaPoY63JzALCiYE9Nu5/bNWCEubnEouljvJiayYlmP87INF4MPJb2jBLsvpTwm86d8yq9\nm15EJGas5cuXLz/SRexP//0/Z+bGcXbag7Qt7uK51GLC1cOcbVXYkItY0LONZ1JnwdM7ODGbZcti\nk9ILJeb2pHGDHeQ2ezjZk3GyBqFRB1zq0RinzOih7GYo76iRSMxki+2RYJx53UOYQYIh/0SG+7fQ\n4ddInJCns7odbwx2ZBcQbqlySrKK325SMcAB3OIIZpDCcpOMDHSzuBBh5+pUI5/105LkfZNTN4cE\nYUA1m8BNGNi9AUv8bspGjoEgpKviYxRTVIZ2vRAYXTSd1LMlDExStseS3iHWJGczPp4k7blMq2zC\nTebo7zqRKAiwwzpW4GHVspx4WpI17bMorttCPtHBk0GCE2dtIzPey0g9RWB5GJaPTZ7fmX0Of3LO\nH1GreUf0dx13uZxNtar3OTSTetwa6nPz5XL2lOzHiKIoOvhqR0bgOOxY36fPwTf5c/A9PQUGBnQN\nvpnU4+ZTj1tDfW6+np6puWR6VAc8oD+kFtA/2OZTj5tPPW4N9bn5pirgj+pr8CIiIvLyKOBFRERi\nSAEvIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiSAEvIiIS\nQ4kjXcCBbNo5wkh5mPZUgWotoj1vYyctPC9gdKSG67m0Gw7Zrk78hMGYUyJjZhkf9yECK5ekI5Mi\nZZmEjkNtcIA6Abm2DsxaHTOTwS1VGjeQ8UolxpJpiBJ0t6ex97otrBu4jNaGMQKP4dBgVq6bfCqN\n5wVUyy7ZfIpk0sLxAsbKTqPW+vgY4zu20jZ9FlYmv0/dpm039j/mlGi3C0Q1n9KOYQrTO7Hz2X2W\nA42f977Nq+vU2TkwwHiUYV5PO/l0ckI/967VskICr4SV1H3gRUTiqKUBH4Yhy5cvZ+3ataRSKW66\n6SbmzZu33/U/+sMbCXbOJRydQeTadBVsTkyn2DJa4aRtj3FCZQtjQYUHzulg87wcQ1sXMmNoDtbC\nLrzuDEHGIuuHvP6ZB0l1lshNd0g8uYPBFyqElZD13ecwmJ3FrLE1bD1lPk8zg9JYSOiEZLIJfmvp\ndN524QK+v+F/sUefYYWbZ9AfJYyqWORZuP0susY6KY87ZAspBjIJdtY9hscduvMml85aRbG3TCrl\n8uiji/n19k5OHljBCZUtlPwKbraN7vPO4RdnFXh6cDWj1VHOWbEY35uOY2ZJh3VmFn3GXh/yzNBq\nhp0RbNMGY1fgd9hFTu85hTcvuJjnHr+bu1/oYMeQjV8PSdgmc+cV+Ys3n0bCMHjk/g1sfH6QSqnO\naaduYvq0QZJWFSvZjjtyGnbn72IYmtAREYmLlgb8vffei+u63HnnnTz11FP87d/+LV/84hf3u37Q\nP5egf0Hjcbbksr7kctrA45w1tgaAB5blWLUoibtpATP752MvLlKeu+es9JTH7yM1vULPojreQ4ME\nz4xjAOu7zqGveDKLBx5j+8mzeCK1gGpfubFdrepz3xN9bDQeZWl2Hc+E3fR72xrLezbNwd6Zo4wD\nwOpSnf69brB0Se/TzJw/CsCzqxfy6PZpE+oGSFfH+cHoL3lqaw6As544gQqLYPfEQd3K8nDxOYa3\nvdjYxgmdxs/Dzgg/73uYGX07uHf7XPq2JoEQAN8JeeH5YW79/ipe25blmV9vBWDpiRuYM3PPcQTe\nGP2bHyZfc+mcffF+fxciInJsaekp2xNPPMH5558PwBlnnMGqVasOuH4wMq3xswm0AeXQ
Z3FlCwCe\nBRtmpYkCk2h4Gu2mSa0709jG8jxmbdlAbk5A5IUEG6u79mtYDObnYoY+xfp2Ns1ZTH2wtm8BZsCA\n/wJzEkl2eKONp43ApG2v2gIiRvfaLINDz6zKrmWBybadXRPqfslL9QMkXYPInzlheWgGjHfsPGCP\nkoFJNjnOtoH0pMs3bx7hhecHdx2OGTC9d2jS9WqjzxOG3gHHEhGRY0dLz+DL5TL5fL7x2LIsfN8n\nkdhPGe6e0EoCBpAKarT5u8KzkrEo5Uwi1ybppUlkTIL0nuvm2WoZ2wpIpz2iUgAlHwDHylJP5Mj4\nZSIbxlN5wnpln+GNpEPOruPQSRjtCcaklya5V20e4O613cxkhXR6V1jWnRTjjk0qKDXqfslL9QMU\nSklcKzdhuZes49uTvPDYSzFMU47S+PVw0uVmPaRS33XWn7Zd0mln0vUCb4xiIcTO6pp8s0zVPZ5l\n/9Tj1lCfjw0tDfh8Pk+lsifkwjDcf7gDpOrg7nqTmQdEgGtlGE/kKPoVcrWAQiVkPOvgJuv4TgKr\nHhBkd+2zms3jBBb1epJMNoRCAko+dlAl7VdwrCyGA21umeG0RVgPJgwfeTYVJ42dq2MaWcJo1wyA\nl6zjpeqkdteWBFLsCfltXo56PUk265G2XdpsBzfcU/dLXqq/VLAoFTxSQQU3secfTtJLk3Ay+On9\nh/yoWSdv1EmkzUlDPkyb5FI2lXGHupOiXrfJZvcNeSvZzmjJxKyU9lkmr1xPT4GBAfW2mdTj1lCf\nm2+qXkC1dIp+2bJlPPjggwA89dRTLFmy5IDrW3tNT4fAOJA3E6zLzQEgGcAJW+sYVojRuZOxMCSz\n11R7kEyydc4JVLZYGEkTa8GuQLaigO7yZkIzwWh6BvO2rCO919T+nkEtehIL2eJ7TE8WG09HVjhh\n6tzCoLjXZjVsBnZfV7eskJnThibU/ZKX6gfwUhFGYtuE5WZoTbgUMBnPCql6bczsqU+6fO7cDhYu\n6d51OKHFjv6uSdfLFJdgmslJl4mIyLHHWr58+fJWDbZw4UIeeughvvSlL/HQQw+xfPlyOjs797v+\nt9fdBUTgZSCwSBZslrRn2JjsxfMcskGdE7ZVqGVsvDkV+kODzLYkKdMiSlpElsHozAXM27wTpwbW\n4jRW6BFWfDrGtuKbKXYUFjKjbyP5LhjLteP5EAUR2VyCC06fyf+98ALW1KvM9Poo0UUtConwqBXr\ntJndFCnguQEzCzbZYprQMnDcgD5vBrPDElYqYFrvIKm6zdPRQvzAJxvUSYYubradMxadi3XiIkpe\nmY3d25m/08D0UgRGgnRY4xTyTH/VbMpemZpfJ23ZWGaCMArpSnfw6uln8/pXXUZXZQVbgxRVL0Ho\nQyJtMn9hB3/x5tOYt7AT1/GplV2272gjk4FM1scyfaxkkZ7ZZ5Of9noMw2jVn8JxJ5ezqVbdg68o\nL5t63Brqc/PlcvaU7MeIoiiakj01gT4H35rPwU+b1qkptybTtGbzqcetoT4331RN0R/VAQ/oD6kF\n9A+2+dTj5lOPW0N9br5j8hq8iIiItIYCXkREJIYU8CIiIjGkgBcREYkhBbyIiEgMKeBFRERiSAEv\nIiISQwp4ERGRGFLAi4iIxJACXkREJIYU8CIiIjF01H8XvYiIiBw+ncGLiIjEkAJeREQkhhTwIiIi\nMaSAFxERiSEFvIiISAwp4EVERGIocaQLmEwYhixfvpy1a9eSSqW46aabmDdv3pEu65jleR4f+9jH\n2Lp1K67r8oEPfIBFixZx/fXXYxgGixcv5q//+q8xTZO77rqLO+64g0QiwQc+8AEuvPDCI13+MWVo\naIi3vvWtfO1rXyORSKjHU+xLX/oS999/P57n8Y53vINzzz1XPZ5inudx/fXXs3XrVkzT5MYbb9Tf\n8hRauXIln/nMZ7j99tvZtGnTIfe1Xq9z7bXXMjQ0
RC6X45ZbbqGzs/PAg0VHoR//+MfRddddF0VR\nFK1YsSJ6//vff4QrOrb913/9V3TTTTdFURRFIyMj0Wtf+9rofe97X/TLX/4yiqIouuGGG6Kf/OQn\nUX9/f3TJJZdEjuNE4+PjjZ/l0LiuG33wgx+Mfv/3fz9av369ejzFfvnLX0bve9/7oiAIonK5HH3h\nC19Qj5vgpz/9afThD384iqIoevjhh6MPfehD6vMU+fKXvxxdcskl0dve9rYoiqLD6uvXvva16Atf\n+EIURVH0P//zP9GNN9540PGOyin6J554gvPPPx+AM844g1WrVh3hio5tF198MVdffTUAURRhWRbP\nPvss5557LgAXXHABjzzyCE8//TRnnnkmqVSKQqHA3LlzWbNmzZEs/Zhyyy238Pa3v53e3l4A9XiK\nPfzwwyxZsoSrrrqK97///fzu7/6uetwECxYsIAgCwjCkXC6TSCTU5ykyd+5cbrvttsbjw+nr3rl4\nwQUX8Oijjx50vKMy4MvlMvl8vvHYsix83z+CFR3bcrkc+XyecrnMhz/8Ya655hqiKMIwjMbyUqlE\nuVymUChM2K5cLh+pso8pd999N52dnY1/gIB6PMVGRkZYtWoVn//85/nkJz/Jn//5n6vHTZDNZtm6\ndStveMMbuOGGG7jyyivV5yly0UUXkUjsuTJ+OH3d+/mX1j2Yo/IafD6fp1KpNB6HYTihKXL4tm/f\nzlVXXcUVV1zBm970Jm699dbGskqlQltb2z59r1QqE/7QZP++853vYBgGjz76KKtXr+a6665jeHi4\nsVw9fuWKxSILFy4klUqxcOFCbNtmx44djeXq8dT4xje+we/8zu/wZ3/2Z2zfvp13vetdeJ7XWK4+\nTx3T3HOOfbC+7v38S+sedP9TX/Irt2zZMh588EEAnnrqKZYsWXKEKzq2DQ4O8t73vpdrr72Wyy67\nDICTTz6Zxx57DIAHH3yQs88+m9NPP50nnngCx3EolUps2LBBvT9E3/rWt/j3f/93br/9dpYuXcot\nt9zCBRdcoB5PobPOOouHHnqIKIrYuXMntVqN17zmNerxFGtra2sEdXt7O77v6/8XTXI4fV22bBkP\nPPBAY92zzjrroPs/Km8289K76J9//nmiKOJv/uZvOOGEE450Wcesm266iR/+8IcsXLiw8dxf/dVf\ncdNNN+F5HgsXLuSmm27Csizuuusu7rzzTqIo4n3vex8XXXTREaz82HTllVeyfPlyTNPkhhtuUI+n\n0N/93d/x2GOPEUURH/nIR5g9e7Z6PMUqlQof+9jHGBgYwPM83vnOd3Lqqaeqz1Okr6+Pj370o9x1\n111s3LjxkPtaq9W47rrrGBgYIJlM8vd///f09PQccKyjMuBFRETklTkqp+hFRETklVHAi4iIxJAC\nXkREJIYU8CIiIjGkgBcREYkhBbxIjPT19XHqqafy5je/ecJ/3/rWt/a7zZ/+6Z+yc+fOVzTuY489\nxpVXXvmK9iEiU0tfDycSM729vXz/+98/5PW/8pWvNLEaETlSFPAix4nzzjuPCy+8kFWrVpHL5fjM\nZz7D7Nmzed3rXsc3v/lNyuUyn/jEJ/B9H9u2ufnmm5k/fz4/+9nP+Id/+AfCMGTOnDl86lOforu7\nm4cffpibb74Z27ZZsGBBY5xNmzaxfPlyRkdHSafT3HDDDZx88slH8MhFjk+aoheJmf7+/n2m6Neu\nXcvIyAjnnnsu99xzD2984xu56aabJmz3b//2b7znPe/h7rvv5sorr+Spp55iaGiIT3ziE/zTP/0T\n99xzD8uWLeNTn/oUruty/fXX84UvfIG7776bdDrd2M91113Htddey3e/+11uvPFGPvKRj7S6BSKC\nzuBFYmd/U/S2bfOWt7wFgEsvvZTPfvazE5a/9rWv5VOf+hQPPfQQF154IRdddBEPPvggp59+OrNn\nzwbg8ssv58tf
/jJr166lt7e38RXSl156KZ///OepVCqsWrWKv/zLv2zst1qtMjIyQkdHR7MOWUQm\noYAXOU6Yptm4NWUYhliWNWH5xRdfzJlnnsnPfvYz/u3f/o0HHniACy+8cMI6URTh+z6GYRCGYeP5\nl/YVhiGpVGrCC4wdO3ZQLBabdVgish+aohc5TtRqNe6//35g1/3rL7jgggnLr7nmGp5++mne/va3\nc/XVV/Pcc8/xqle9ipUrV9LX1wfAnXfeyatf/WpOPPFEhoaGWLNmDQA/+MEPACgUCsyfP78R8L/4\nxS/44z/+41YdoojsRWfwIjHz0jX4vZ1zzjkA/OhHP+Jzn/scvb293HLLLRPWef/7389f/dVf8c//\n/M9YlsX1119Pd3c3n/rUp/jQhz6E53nMnDmTT3/60ySTST772c9y7bXXkkgkJryJ7tZbb2X58uV8\n9atfJZlM8rnPfa4xcyAiraO7yYkcJ0488UTWrl17pMsQkRbRFL2IiEgM6QxeREQkhnQGLyIiEkMK\neBERkRhSwIuIiMSQAl5ERCSGFPAiIiIxpIAXERGJof8fOg9FfyxpJFAAAAAASUVORK5CYII=\n",
427 | "text/plain": [
428 | ""
429 | ]
430 | },
431 | "metadata": {},
432 | "output_type": "display_data"
433 | }
434 | ],
435 | "source": [
436 | "metadata = q_iteration(episodes, render=False)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "metadata": {
443 | "collapsed": true
444 | },
445 | "outputs": [],
446 | "source": []
447 | }
448 | ],
449 | "metadata": {
450 | "kernelspec": {
451 | "display_name": "Python 3",
452 | "language": "python",
453 | "name": "python3"
454 | },
455 | "language_info": {
456 | "codemirror_mode": {
457 | "name": "ipython",
458 | "version": 3
459 | },
460 | "file_extension": ".py",
461 | "mimetype": "text/x-python",
462 | "name": "python",
463 | "nbconvert_exporter": "python",
464 | "pygments_lexer": "ipython3",
465 | "version": "3.6.1"
466 | }
467 | },
468 | "nbformat": 4,
469 | "nbformat_minor": 2
470 | }
471 |
--------------------------------------------------------------------------------