├── .gitignore
├── Network.py
├── README.md
├── Transformer.ipynb
├── data.py
├── doc
├── buildfig.png
├── compare.png
├── comparegraph.png
├── hyperparams.png
├── imports,png.PNG
├── init.png
├── lossgraph.png
└── training.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.py[cod]
4 | .git
5 | data/
6 |
--------------------------------------------------------------------------------
/Network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from utils import *
6 |
7 |
class EncoderLayer(torch.nn.Module):
    """One Transformer encoder layer.

    Applies multi-head self-attention followed by a two-layer
    position-wise feed-forward network, each sub-layer wrapped in a
    residual connection and layer normalisation.
    """

    def __init__(self, dim_val, dim_attn, n_heads = 1):
        super(EncoderLayer, self).__init__()
        # Multi-head self-attention over the input sequence.
        self.attn = MultiHeadAttentionBlock(dim_val, dim_attn , n_heads)
        # Feed-forward sub-layer, applied as fc1(elu(fc2(x))).
        self.fc1 = nn.Linear(dim_val, dim_val)
        self.fc2 = nn.Linear(dim_val, dim_val)

        self.norm1 = nn.LayerNorm(dim_val)
        self.norm2 = nn.LayerNorm(dim_val)

    def forward(self, x):
        # Residual self-attention sub-layer.
        attn_out = self.attn(x)
        x = self.norm1(x + attn_out)

        # Residual feed-forward sub-layer.
        ff_out = self.fc1(F.elu(self.fc2(x)))
        return self.norm2(x + ff_out)
26 |
class DecoderLayer(torch.nn.Module):
    """One Transformer decoder layer.

    Runs self-attention on the decoder input, then attention over the
    encoder output, then a two-layer feed-forward network; every
    sub-layer has a residual connection and layer normalisation.
    """

    def __init__(self, dim_val, dim_attn, n_heads = 1):
        super(DecoderLayer, self).__init__()
        # Self-attention over the decoder input.
        self.attn1 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads)
        # Encoder-decoder attention (keys/values come from the encoder).
        self.attn2 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads)
        # Feed-forward sub-layer, applied as fc1(elu(fc2(x))).
        self.fc1 = nn.Linear(dim_val, dim_val)
        self.fc2 = nn.Linear(dim_val, dim_val)

        self.norm1 = nn.LayerNorm(dim_val)
        self.norm2 = nn.LayerNorm(dim_val)
        self.norm3 = nn.LayerNorm(dim_val)

    def forward(self, x, enc):
        # Residual self-attention sub-layer.
        self_attn_out = self.attn1(x)
        x = self.norm1(self_attn_out + x)

        # Residual encoder-decoder attention sub-layer.
        cross_attn_out = self.attn2(x, kv = enc)
        x = self.norm2(cross_attn_out + x)

        # Residual feed-forward sub-layer.
        ff_out = self.fc1(F.elu(self.fc2(x)))
        return self.norm3(x + ff_out)
50 |
class Transformer(torch.nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence prediction.

    The full input sequence feeds the encoder stack; the last
    ``dec_seq_len`` steps of the input feed the decoder stack, and the
    flattened decoder output is projected to ``out_seq_len`` values.

    Args:
        dim_val: model (value) dimension of each layer.
        dim_attn: attention dimension of each layer.
        input_size: number of features per time step.
        dec_seq_len: how many trailing input steps the decoder consumes.
        out_seq_len: length of the predicted output sequence.
        n_decoder_layers / n_encoder_layers / n_heads: stack sizes.
    """

    def __init__(self, dim_val, dim_attn, input_size, dec_seq_len, out_seq_len, n_decoder_layers = 1, n_encoder_layers = 1, n_heads = 1):
        super(Transformer, self).__init__()
        self.dec_seq_len = dec_seq_len

        # Initiate encoder and decoder layers.
        # BUG FIX: these were plain Python lists, so the layers' parameters
        # were never registered with the module — optimizer/.parameters(),
        # .to(device) and state_dict() all silently skipped them and the
        # attention layers never trained. nn.ModuleList registers them.
        self.encs = nn.ModuleList(
            EncoderLayer(dim_val, dim_attn, n_heads) for _ in range(n_encoder_layers))
        self.decs = nn.ModuleList(
            DecoderLayer(dim_val, dim_attn, n_heads) for _ in range(n_decoder_layers))

        self.pos = PositionalEncoding(dim_val)

        # Dense layers for managing network inputs and outputs.
        self.enc_input_fc = nn.Linear(input_size, dim_val)
        self.dec_input_fc = nn.Linear(input_size, dim_val)
        self.out_fc = nn.Linear(dec_seq_len * dim_val, out_seq_len)

    def forward(self, x):
        # Encoder: embed, add positional encoding, run the stack.
        e = self.encs[0](self.pos(self.enc_input_fc(x)))
        for enc in self.encs[1:]:
            e = enc(e)

        # Decoder: consumes only the last dec_seq_len input steps,
        # attending to the encoder output in every layer.
        d = self.decs[0](self.dec_input_fc(x[:,-self.dec_seq_len:]), e)
        for dec in self.decs[1:]:
            d = dec(d, e)

        # Output: flatten the decoder sequence and project to out_seq_len.
        x = self.out_fc(d.flatten(start_dim=1))

        return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is an implementation of the Transformer algorithm on time series data in pytorch. In this case the modelling of the sigmoid function is used as a toy problem
2 |
3 | Usage:
4 | First all the necessary imports as well as matplotlib for visualisation.
5 | 
6 | Next we need to define some hyperparameters which will vary depending on the task.
7 | 
8 | We initialise the Network and an optimizer, in this case Adam, as well as an empty list to track losses for visualisation.
9 | 
10 | Using matplotlib in jupyter notebook we can graph losses in real time, first lets initialise a figure.
11 | 
12 | We can now begin training
13 | 
14 | You should see a live plot that looks similar to this tracking the output error
15 | 
16 | Now that the network is trained, lets give it the first few values of the sigmoid function and see how it approximates the rest.
17 | We create another figure to visualise this.
18 | 
19 | If all went well, the output should look something like this :
20 | 
21 | Note that the network uses past values instead of the x axis for its predictions , so it makes sense that the output is offset.
22 | However, it did successfully capture the shape.
23 |
24 | Resources:
25 | * Attention is all you need : https://arxiv.org/abs/1706.03762
26 | * Deep Transformer Models for Time Series Forecasting : https://arxiv.org/abs/2001.08317
27 |
28 |
29 |
--------------------------------------------------------------------------------
/Transformer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torch.nn.functional as F\n",
12 | "import numpy as np\n",
13 | "from utils import *\n",
14 | "from Network import *\n",
15 | "from data import *\n",
16 | "\n",
17 | "%matplotlib notebook\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "\n",
20 | "#hyperparams\n",
21 | "enc_seq_len = 6\n",
22 | "dec_seq_len = 2\n",
23 | "output_sequence_length = 1\n",
24 | "\n",
25 | "dim_val = 10\n",
26 | "dim_attn = 5\n",
27 | "lr = 0.002\n",
28 | "epochs = 20\n",
29 | "\n",
30 | "n_heads = 3 \n",
31 | "\n",
32 | "n_decoder_layers = 3\n",
33 | "n_encoder_layers = 3\n",
34 | "\n",
35 | "batch_size = 15\n",
36 | "\n",
37 | "#init network and optimizer\n",
38 | "t = Transformer(dim_val, dim_attn, 1,dec_seq_len, output_sequence_length, n_decoder_layers, n_encoder_layers, n_heads)\n",
39 | "optimizer = torch.optim.Adam(t.parameters(), lr=lr)\n",
40 | "\n",
41 | "#keep track of loss for graph\n",
42 | "losses = []"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stderr",
52 | "output_type": "stream",
53 | "text": [
54 | "D:\\OneDrive\\GitHub\\Pytorch-Chatbot\\data.py:52: FutureWarning: \n",
55 | "Passing list-likes to .loc or [] with any missing label will raise\n",
56 | "KeyError in the future, you can use .reindex() as an alternative.\n",
57 | "\n",
58 | "See the documentation here:\n",
59 | "https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#deprecate-loc-reindex-listlike\n",
60 | " vectorized_targets.append(torch.tensor(word_vecs.loc[b[idx + 1]].values).type(dtype))\n",
61 | "D:\\OneDrive\\GitHub\\Pytorch-Chatbot\\data.py:51: FutureWarning: \n",
62 | "Passing list-likes to .loc or [] with any missing label will raise\n",
63 | "KeyError in the future, you can use .reindex() as an alternative.\n",
64 | "\n",
65 | "See the documentation here:\n",
66 | "https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#deprecate-loc-reindex-listlike\n",
67 | " vectorized_inputs.append(torch.tensor(word_vecs.loc[b[idx]].values).type(dtype))\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "g = gen_data(4, gpu = True)\n",
73 | "vec_ins, vec_targs, ins, targs = next(g)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 2,
79 | "metadata": {
80 | "scrolled": false
81 | },
82 | "outputs": [
83 | {
84 | "data": {
85 | "application/javascript": [
86 | "/* Put everything inside the global mpl namespace */\n",
87 | "window.mpl = {};\n",
88 | "\n",
89 | "\n",
90 | "mpl.get_websocket_type = function() {\n",
91 | " if (typeof(WebSocket) !== 'undefined') {\n",
92 | " return WebSocket;\n",
93 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
94 | " return MozWebSocket;\n",
95 | " } else {\n",
96 | " alert('Your browser does not have WebSocket support. ' +\n",
97 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
98 | " 'Firefox 4 and 5 are also supported but you ' +\n",
99 | " 'have to enable WebSockets in about:config.');\n",
100 | " };\n",
101 | "}\n",
102 | "\n",
103 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
104 | " this.id = figure_id;\n",
105 | "\n",
106 | " this.ws = websocket;\n",
107 | "\n",
108 | " this.supports_binary = (this.ws.binaryType != undefined);\n",
109 | "\n",
110 | " if (!this.supports_binary) {\n",
111 | " var warnings = document.getElementById(\"mpl-warnings\");\n",
112 | " if (warnings) {\n",
113 | " warnings.style.display = 'block';\n",
114 | " warnings.textContent = (\n",
115 | " \"This browser does not support binary websocket messages. \" +\n",
116 | " \"Performance may be slow.\");\n",
117 | " }\n",
118 | " }\n",
119 | "\n",
120 | " this.imageObj = new Image();\n",
121 | "\n",
122 | " this.context = undefined;\n",
123 | " this.message = undefined;\n",
124 | " this.canvas = undefined;\n",
125 | " this.rubberband_canvas = undefined;\n",
126 | " this.rubberband_context = undefined;\n",
127 | " this.format_dropdown = undefined;\n",
128 | "\n",
129 | " this.image_mode = 'full';\n",
130 | "\n",
131 | " this.root = $('
');\n",
132 | " this._root_extra_style(this.root)\n",
133 | " this.root.attr('style', 'display: inline-block');\n",
134 | "\n",
135 | " $(parent_element).append(this.root);\n",
136 | "\n",
137 | " this._init_header(this);\n",
138 | " this._init_canvas(this);\n",
139 | " this._init_toolbar(this);\n",
140 | "\n",
141 | " var fig = this;\n",
142 | "\n",
143 | " this.waiting = false;\n",
144 | "\n",
145 | " this.ws.onopen = function () {\n",
146 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
147 | " fig.send_message(\"send_image_mode\", {});\n",
148 | " if (mpl.ratio != 1) {\n",
149 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
150 | " }\n",
151 | " fig.send_message(\"refresh\", {});\n",
152 | " }\n",
153 | "\n",
154 | " this.imageObj.onload = function() {\n",
155 | " if (fig.image_mode == 'full') {\n",
156 | " // Full images could contain transparency (where diff images\n",
157 | " // almost always do), so we need to clear the canvas so that\n",
158 | " // there is no ghosting.\n",
159 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
160 | " }\n",
161 | " fig.context.drawImage(fig.imageObj, 0, 0);\n",
162 | " };\n",
163 | "\n",
164 | " this.imageObj.onunload = function() {\n",
165 | " fig.ws.close();\n",
166 | " }\n",
167 | "\n",
168 | " this.ws.onmessage = this._make_on_message_function(this);\n",
169 | "\n",
170 | " this.ondownload = ondownload;\n",
171 | "}\n",
172 | "\n",
173 | "mpl.figure.prototype._init_header = function() {\n",
174 | " var titlebar = $(\n",
175 | " '');\n",
177 | " var titletext = $(\n",
178 | " '');\n",
180 | " titlebar.append(titletext)\n",
181 | " this.root.append(titlebar);\n",
182 | " this.header = titletext[0];\n",
183 | "}\n",
184 | "\n",
185 | "\n",
186 | "\n",
187 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
188 | "\n",
189 | "}\n",
190 | "\n",
191 | "\n",
192 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
193 | "\n",
194 | "}\n",
195 | "\n",
196 | "mpl.figure.prototype._init_canvas = function() {\n",
197 | " var fig = this;\n",
198 | "\n",
199 | " var canvas_div = $('');\n",
200 | "\n",
201 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
202 | "\n",
203 | " function canvas_keyboard_event(event) {\n",
204 | " return fig.key_event(event, event['data']);\n",
205 | " }\n",
206 | "\n",
207 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
208 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
209 | " this.canvas_div = canvas_div\n",
210 | " this._canvas_extra_style(canvas_div)\n",
211 | " this.root.append(canvas_div);\n",
212 | "\n",
213 | " var canvas = $('');\n",
214 | " canvas.addClass('mpl-canvas');\n",
215 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
216 | "\n",
217 | " this.canvas = canvas[0];\n",
218 | " this.context = canvas[0].getContext(\"2d\");\n",
219 | "\n",
220 | " var backingStore = this.context.backingStorePixelRatio ||\n",
221 | "\tthis.context.webkitBackingStorePixelRatio ||\n",
222 | "\tthis.context.mozBackingStorePixelRatio ||\n",
223 | "\tthis.context.msBackingStorePixelRatio ||\n",
224 | "\tthis.context.oBackingStorePixelRatio ||\n",
225 | "\tthis.context.backingStorePixelRatio || 1;\n",
226 | "\n",
227 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
228 | "\n",
229 | " var rubberband = $('');\n",
230 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
231 | "\n",
232 | " var pass_mouse_events = true;\n",
233 | "\n",
234 | " canvas_div.resizable({\n",
235 | " start: function(event, ui) {\n",
236 | " pass_mouse_events = false;\n",
237 | " },\n",
238 | " resize: function(event, ui) {\n",
239 | " fig.request_resize(ui.size.width, ui.size.height);\n",
240 | " },\n",
241 | " stop: function(event, ui) {\n",
242 | " pass_mouse_events = true;\n",
243 | " fig.request_resize(ui.size.width, ui.size.height);\n",
244 | " },\n",
245 | " });\n",
246 | "\n",
247 | " function mouse_event_fn(event) {\n",
248 | " if (pass_mouse_events)\n",
249 | " return fig.mouse_event(event, event['data']);\n",
250 | " }\n",
251 | "\n",
252 | " rubberband.mousedown('button_press', mouse_event_fn);\n",
253 | " rubberband.mouseup('button_release', mouse_event_fn);\n",
254 | " // Throttle sequential mouse events to 1 every 20ms.\n",
255 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
256 | "\n",
257 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
258 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
259 | "\n",
260 | " canvas_div.on(\"wheel\", function (event) {\n",
261 | " event = event.originalEvent;\n",
262 | " event['data'] = 'scroll'\n",
263 | " if (event.deltaY < 0) {\n",
264 | " event.step = 1;\n",
265 | " } else {\n",
266 | " event.step = -1;\n",
267 | " }\n",
268 | " mouse_event_fn(event);\n",
269 | " });\n",
270 | "\n",
271 | " canvas_div.append(canvas);\n",
272 | " canvas_div.append(rubberband);\n",
273 | "\n",
274 | " this.rubberband = rubberband;\n",
275 | " this.rubberband_canvas = rubberband[0];\n",
276 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
277 | " this.rubberband_context.strokeStyle = \"#000000\";\n",
278 | "\n",
279 | " this._resize_canvas = function(width, height) {\n",
280 | " // Keep the size of the canvas, canvas container, and rubber band\n",
281 | " // canvas in synch.\n",
282 | " canvas_div.css('width', width)\n",
283 | " canvas_div.css('height', height)\n",
284 | "\n",
285 | " canvas.attr('width', width * mpl.ratio);\n",
286 | " canvas.attr('height', height * mpl.ratio);\n",
287 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
288 | "\n",
289 | " rubberband.attr('width', width);\n",
290 | " rubberband.attr('height', height);\n",
291 | " }\n",
292 | "\n",
293 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
294 | " // upon first draw.\n",
295 | " this._resize_canvas(600, 600);\n",
296 | "\n",
297 | " // Disable right mouse context menu.\n",
298 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
299 | " return false;\n",
300 | " });\n",
301 | "\n",
302 | " function set_focus () {\n",
303 | " canvas.focus();\n",
304 | " canvas_div.focus();\n",
305 | " }\n",
306 | "\n",
307 | " window.setTimeout(set_focus, 100);\n",
308 | "}\n",
309 | "\n",
310 | "mpl.figure.prototype._init_toolbar = function() {\n",
311 | " var fig = this;\n",
312 | "\n",
313 | " var nav_element = $('');\n",
314 | " nav_element.attr('style', 'width: 100%');\n",
315 | " this.root.append(nav_element);\n",
316 | "\n",
317 | " // Define a callback function for later on.\n",
318 | " function toolbar_event(event) {\n",
319 | " return fig.toolbar_button_onclick(event['data']);\n",
320 | " }\n",
321 | " function toolbar_mouse_event(event) {\n",
322 | " return fig.toolbar_button_onmouseover(event['data']);\n",
323 | " }\n",
324 | "\n",
325 | " for(var toolbar_ind in mpl.toolbar_items) {\n",
326 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
327 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
328 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
329 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
330 | "\n",
331 | " if (!name) {\n",
332 | " // put a spacer in here.\n",
333 | " continue;\n",
334 | " }\n",
335 | " var button = $('');\n",
336 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
337 | " 'ui-button-icon-only');\n",
338 | " button.attr('role', 'button');\n",
339 | " button.attr('aria-disabled', 'false');\n",
340 | " button.click(method_name, toolbar_event);\n",
341 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
342 | "\n",
343 | " var icon_img = $('');\n",
344 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
345 | " icon_img.addClass(image);\n",
346 | " icon_img.addClass('ui-corner-all');\n",
347 | "\n",
348 | " var tooltip_span = $('');\n",
349 | " tooltip_span.addClass('ui-button-text');\n",
350 | " tooltip_span.html(tooltip);\n",
351 | "\n",
352 | " button.append(icon_img);\n",
353 | " button.append(tooltip_span);\n",
354 | "\n",
355 | " nav_element.append(button);\n",
356 | " }\n",
357 | "\n",
358 | " var fmt_picker_span = $('');\n",
359 | "\n",
360 | " var fmt_picker = $('');\n",
361 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
362 | " fmt_picker_span.append(fmt_picker);\n",
363 | " nav_element.append(fmt_picker_span);\n",
364 | " this.format_dropdown = fmt_picker[0];\n",
365 | "\n",
366 | " for (var ind in mpl.extensions) {\n",
367 | " var fmt = mpl.extensions[ind];\n",
368 | " var option = $(\n",
369 | " '', {selected: fmt === mpl.default_extension}).html(fmt);\n",
370 | " fmt_picker.append(option);\n",
371 | " }\n",
372 | "\n",
373 | " // Add hover states to the ui-buttons\n",
374 | " $( \".ui-button\" ).hover(\n",
375 | " function() { $(this).addClass(\"ui-state-hover\");},\n",
376 | " function() { $(this).removeClass(\"ui-state-hover\");}\n",
377 | " );\n",
378 | "\n",
379 | " var status_bar = $('');\n",
380 | " nav_element.append(status_bar);\n",
381 | " this.message = status_bar[0];\n",
382 | "}\n",
383 | "\n",
384 | "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
385 | " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
386 | " // which will in turn request a refresh of the image.\n",
387 | " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
388 | "}\n",
389 | "\n",
390 | "mpl.figure.prototype.send_message = function(type, properties) {\n",
391 | " properties['type'] = type;\n",
392 | " properties['figure_id'] = this.id;\n",
393 | " this.ws.send(JSON.stringify(properties));\n",
394 | "}\n",
395 | "\n",
396 | "mpl.figure.prototype.send_draw_message = function() {\n",
397 | " if (!this.waiting) {\n",
398 | " this.waiting = true;\n",
399 | " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
400 | " }\n",
401 | "}\n",
402 | "\n",
403 | "\n",
404 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
405 | " var format_dropdown = fig.format_dropdown;\n",
406 | " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
407 | " fig.ondownload(fig, format);\n",
408 | "}\n",
409 | "\n",
410 | "\n",
411 | "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
412 | " var size = msg['size'];\n",
413 | " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
414 | " fig._resize_canvas(size[0], size[1]);\n",
415 | " fig.send_message(\"refresh\", {});\n",
416 | " };\n",
417 | "}\n",
418 | "\n",
419 | "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
420 | " var x0 = msg['x0'] / mpl.ratio;\n",
421 | " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
422 | " var x1 = msg['x1'] / mpl.ratio;\n",
423 | " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
424 | " x0 = Math.floor(x0) + 0.5;\n",
425 | " y0 = Math.floor(y0) + 0.5;\n",
426 | " x1 = Math.floor(x1) + 0.5;\n",
427 | " y1 = Math.floor(y1) + 0.5;\n",
428 | " var min_x = Math.min(x0, x1);\n",
429 | " var min_y = Math.min(y0, y1);\n",
430 | " var width = Math.abs(x1 - x0);\n",
431 | " var height = Math.abs(y1 - y0);\n",
432 | "\n",
433 | " fig.rubberband_context.clearRect(\n",
434 | " 0, 0, fig.canvas.width / mpl.ratio, fig.canvas.height / mpl.ratio);\n",
435 | "\n",
436 | " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
437 | "}\n",
438 | "\n",
439 | "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
440 | " // Updates the figure title.\n",
441 | " fig.header.textContent = msg['label'];\n",
442 | "}\n",
443 | "\n",
444 | "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
445 | " var cursor = msg['cursor'];\n",
446 | " switch(cursor)\n",
447 | " {\n",
448 | " case 0:\n",
449 | " cursor = 'pointer';\n",
450 | " break;\n",
451 | " case 1:\n",
452 | " cursor = 'default';\n",
453 | " break;\n",
454 | " case 2:\n",
455 | " cursor = 'crosshair';\n",
456 | " break;\n",
457 | " case 3:\n",
458 | " cursor = 'move';\n",
459 | " break;\n",
460 | " }\n",
461 | " fig.rubberband_canvas.style.cursor = cursor;\n",
462 | "}\n",
463 | "\n",
464 | "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
465 | " fig.message.textContent = msg['message'];\n",
466 | "}\n",
467 | "\n",
468 | "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
469 | " // Request the server to send over a new figure.\n",
470 | " fig.send_draw_message();\n",
471 | "}\n",
472 | "\n",
473 | "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
474 | " fig.image_mode = msg['mode'];\n",
475 | "}\n",
476 | "\n",
477 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
478 | " // Called whenever the canvas gets updated.\n",
479 | " this.send_message(\"ack\", {});\n",
480 | "}\n",
481 | "\n",
482 | "// A function to construct a web socket function for onmessage handling.\n",
483 | "// Called in the figure constructor.\n",
484 | "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
485 | " return function socket_on_message(evt) {\n",
486 | " if (evt.data instanceof Blob) {\n",
487 | " /* FIXME: We get \"Resource interpreted as Image but\n",
488 | " * transferred with MIME type text/plain:\" errors on\n",
489 | " * Chrome. But how to set the MIME type? It doesn't seem\n",
490 | " * to be part of the websocket stream */\n",
491 | " evt.data.type = \"image/png\";\n",
492 | "\n",
493 | " /* Free the memory for the previous frames */\n",
494 | " if (fig.imageObj.src) {\n",
495 | " (window.URL || window.webkitURL).revokeObjectURL(\n",
496 | " fig.imageObj.src);\n",
497 | " }\n",
498 | "\n",
499 | " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
500 | " evt.data);\n",
501 | " fig.updated_canvas_event();\n",
502 | " fig.waiting = false;\n",
503 | " return;\n",
504 | " }\n",
505 | " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
506 | " fig.imageObj.src = evt.data;\n",
507 | " fig.updated_canvas_event();\n",
508 | " fig.waiting = false;\n",
509 | " return;\n",
510 | " }\n",
511 | "\n",
512 | " var msg = JSON.parse(evt.data);\n",
513 | " var msg_type = msg['type'];\n",
514 | "\n",
515 | " // Call the \"handle_{type}\" callback, which takes\n",
516 | " // the figure and JSON message as its only arguments.\n",
517 | " try {\n",
518 | " var callback = fig[\"handle_\" + msg_type];\n",
519 | " } catch (e) {\n",
520 | " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
521 | " return;\n",
522 | " }\n",
523 | "\n",
524 | " if (callback) {\n",
525 | " try {\n",
526 | " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
527 | " callback(fig, msg);\n",
528 | " } catch (e) {\n",
529 | " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
530 | " }\n",
531 | " }\n",
532 | " };\n",
533 | "}\n",
534 | "\n",
535 | "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
536 | "mpl.findpos = function(e) {\n",
537 | " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
538 | " var targ;\n",
539 | " if (!e)\n",
540 | " e = window.event;\n",
541 | " if (e.target)\n",
542 | " targ = e.target;\n",
543 | " else if (e.srcElement)\n",
544 | " targ = e.srcElement;\n",
545 | " if (targ.nodeType == 3) // defeat Safari bug\n",
546 | " targ = targ.parentNode;\n",
547 | "\n",
548 | " // jQuery normalizes the pageX and pageY\n",
549 | " // pageX,Y are the mouse positions relative to the document\n",
550 | " // offset() returns the position of the element relative to the document\n",
551 | " var x = e.pageX - $(targ).offset().left;\n",
552 | " var y = e.pageY - $(targ).offset().top;\n",
553 | "\n",
554 | " return {\"x\": x, \"y\": y};\n",
555 | "};\n",
556 | "\n",
557 | "/*\n",
558 | " * return a copy of an object with only non-object keys\n",
559 | " * we need this to avoid circular references\n",
560 | " * http://stackoverflow.com/a/24161582/3208463\n",
561 | " */\n",
562 | "function simpleKeys (original) {\n",
563 | " return Object.keys(original).reduce(function (obj, key) {\n",
564 | " if (typeof original[key] !== 'object')\n",
565 | " obj[key] = original[key]\n",
566 | " return obj;\n",
567 | " }, {});\n",
568 | "}\n",
569 | "\n",
570 | "mpl.figure.prototype.mouse_event = function(event, name) {\n",
571 | " var canvas_pos = mpl.findpos(event)\n",
572 | "\n",
573 | " if (name === 'button_press')\n",
574 | " {\n",
575 | " this.canvas.focus();\n",
576 | " this.canvas_div.focus();\n",
577 | " }\n",
578 | "\n",
579 | " var x = canvas_pos.x * mpl.ratio;\n",
580 | " var y = canvas_pos.y * mpl.ratio;\n",
581 | "\n",
582 | " this.send_message(name, {x: x, y: y, button: event.button,\n",
583 | " step: event.step,\n",
584 | " guiEvent: simpleKeys(event)});\n",
585 | "\n",
586 | " /* This prevents the web browser from automatically changing to\n",
587 | " * the text insertion cursor when the button is pressed. We want\n",
588 | " * to control all of the cursor setting manually through the\n",
589 | " * 'cursor' event from matplotlib */\n",
590 | " event.preventDefault();\n",
591 | " return false;\n",
592 | "}\n",
593 | "\n",
594 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
595 | " // Handle any extra behaviour associated with a key event\n",
596 | "}\n",
597 | "\n",
598 | "mpl.figure.prototype.key_event = function(event, name) {\n",
599 | "\n",
600 | " // Prevent repeat events\n",
601 | " if (name == 'key_press')\n",
602 | " {\n",
603 | " if (event.which === this._key)\n",
604 | " return;\n",
605 | " else\n",
606 | " this._key = event.which;\n",
607 | " }\n",
608 | " if (name == 'key_release')\n",
609 | " this._key = null;\n",
610 | "\n",
611 | " var value = '';\n",
612 | " if (event.ctrlKey && event.which != 17)\n",
613 | " value += \"ctrl+\";\n",
614 | " if (event.altKey && event.which != 18)\n",
615 | " value += \"alt+\";\n",
616 | " if (event.shiftKey && event.which != 16)\n",
617 | " value += \"shift+\";\n",
618 | "\n",
619 | " value += 'k';\n",
620 | " value += event.which.toString();\n",
621 | "\n",
622 | " this._key_event_extra(event, name);\n",
623 | "\n",
624 | " this.send_message(name, {key: value,\n",
625 | " guiEvent: simpleKeys(event)});\n",
626 | " return false;\n",
627 | "}\n",
628 | "\n",
629 | "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
630 | " if (name == 'download') {\n",
631 | " this.handle_save(this, null);\n",
632 | " } else {\n",
633 | " this.send_message(\"toolbar_button\", {name: name});\n",
634 | " }\n",
635 | "};\n",
636 | "\n",
637 | "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
638 | " this.message.textContent = tooltip;\n",
639 | "};\n",
640 | "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
641 | "\n",
642 | "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
643 | "\n",
644 | "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
645 | " // Create a \"websocket\"-like object which calls the given IPython comm\n",
646 | " // object with the appropriate methods. Currently this is a non binary\n",
647 | " // socket, so there is still some room for performance tuning.\n",
648 | " var ws = {};\n",
649 | "\n",
650 | " ws.close = function() {\n",
651 | " comm.close()\n",
652 | " };\n",
653 | " ws.send = function(m) {\n",
654 | " //console.log('sending', m);\n",
655 | " comm.send(m);\n",
656 | " };\n",
657 | " // Register the callback with on_msg.\n",
658 | " comm.on_msg(function(msg) {\n",
659 | " //console.log('receiving', msg['content']['data'], msg);\n",
660 | " // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
661 | " ws.onmessage(msg['content']['data'])\n",
662 | " });\n",
663 | " return ws;\n",
664 | "}\n",
665 | "\n",
666 | "mpl.mpl_figure_comm = function(comm, msg) {\n",
667 | " // This is the function which gets called when the mpl process\n",
668 | " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
669 | "\n",
670 | " var id = msg.content.data.id;\n",
671 | " // Get hold of the div created by the display call when the Comm\n",
672 | " // socket was opened in Python.\n",
673 | " var element = $(\"#\" + id);\n",
674 | " var ws_proxy = comm_websocket_adapter(comm)\n",
675 | "\n",
676 | " function ondownload(figure, format) {\n",
677 | " window.open(figure.imageObj.src);\n",
678 | " }\n",
679 | "\n",
680 | " var fig = new mpl.figure(id, ws_proxy,\n",
681 | " ondownload,\n",
682 | " element.get(0));\n",
683 | "\n",
684 | " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
685 | " // web socket which is closed, not our websocket->open comm proxy.\n",
686 | " ws_proxy.onopen();\n",
687 | "\n",
688 | " fig.parent_element = element.get(0);\n",
689 | " fig.cell_info = mpl.find_output_cell(\"\");\n",
690 | " if (!fig.cell_info) {\n",
691 | " console.error(\"Failed to find cell for figure\", id, fig);\n",
692 | " return;\n",
693 | " }\n",
694 | "\n",
695 | " var output_index = fig.cell_info[2]\n",
696 | " var cell = fig.cell_info[0];\n",
697 | "\n",
698 | "};\n",
699 | "\n",
700 | "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
701 | " var width = fig.canvas.width/mpl.ratio\n",
702 | " fig.root.unbind('remove')\n",
703 | "\n",
704 | " // Update the output cell to use the data from the current canvas.\n",
705 | " fig.push_to_output();\n",
706 | " var dataURL = fig.canvas.toDataURL();\n",
707 | " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
708 | " // the notebook keyboard shortcuts fail.\n",
709 | " IPython.keyboard_manager.enable()\n",
710 | " $(fig.parent_element).html('
');\n",
711 | " fig.close_ws(fig, msg);\n",
712 | "}\n",
713 | "\n",
714 | "mpl.figure.prototype.close_ws = function(fig, msg){\n",
715 | " fig.send_message('closing', msg);\n",
716 | " // fig.ws.close()\n",
717 | "}\n",
718 | "\n",
719 | "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
720 | " // Turn the data on the canvas into data in the output cell.\n",
721 | " var width = this.canvas.width/mpl.ratio\n",
722 | " var dataURL = this.canvas.toDataURL();\n",
723 |     "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
724 | "}\n",
725 | "\n",
726 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
727 | " // Tell IPython that the notebook contents must change.\n",
728 | " IPython.notebook.set_dirty(true);\n",
729 | " this.send_message(\"ack\", {});\n",
730 | " var fig = this;\n",
731 | " // Wait a second, then push the new image to the DOM so\n",
732 | " // that it is saved nicely (might be nice to debounce this).\n",
733 | " setTimeout(function () { fig.push_to_output() }, 1000);\n",
734 | "}\n",
735 | "\n",
736 | "mpl.figure.prototype._init_toolbar = function() {\n",
737 | " var fig = this;\n",
738 | "\n",
739 | " var nav_element = $('');\n",
740 | " nav_element.attr('style', 'width: 100%');\n",
741 | " this.root.append(nav_element);\n",
742 | "\n",
743 | " // Define a callback function for later on.\n",
744 | " function toolbar_event(event) {\n",
745 | " return fig.toolbar_button_onclick(event['data']);\n",
746 | " }\n",
747 | " function toolbar_mouse_event(event) {\n",
748 | " return fig.toolbar_button_onmouseover(event['data']);\n",
749 | " }\n",
750 | "\n",
751 | " for(var toolbar_ind in mpl.toolbar_items){\n",
752 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
753 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
754 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
755 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
756 | "\n",
757 | " if (!name) { continue; };\n",
758 | "\n",
759 | " var button = $('');\n",
760 | " button.click(method_name, toolbar_event);\n",
761 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
762 | " nav_element.append(button);\n",
763 | " }\n",
764 | "\n",
765 | " // Add the status bar.\n",
766 | " var status_bar = $('');\n",
767 | " nav_element.append(status_bar);\n",
768 | " this.message = status_bar[0];\n",
769 | "\n",
770 | " // Add the close button to the window.\n",
771 | " var buttongrp = $('');\n",
772 | " var button = $('');\n",
773 | " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
774 | " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
775 | " buttongrp.append(button);\n",
776 | " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
777 | " titlebar.prepend(buttongrp);\n",
778 | "}\n",
779 | "\n",
780 | "mpl.figure.prototype._root_extra_style = function(el){\n",
781 | " var fig = this\n",
782 | " el.on(\"remove\", function(){\n",
783 | "\tfig.close_ws(fig, {});\n",
784 | " });\n",
785 | "}\n",
786 | "\n",
787 | "mpl.figure.prototype._canvas_extra_style = function(el){\n",
788 | " // this is important to make the div 'focusable\n",
789 | " el.attr('tabindex', 0)\n",
790 | " // reach out to IPython and tell the keyboard manager to turn it's self\n",
791 | " // off when our div gets focus\n",
792 | "\n",
793 | " // location in version 3\n",
794 | " if (IPython.notebook.keyboard_manager) {\n",
795 | " IPython.notebook.keyboard_manager.register_events(el);\n",
796 | " }\n",
797 | " else {\n",
798 | " // location in version 2\n",
799 | " IPython.keyboard_manager.register_events(el);\n",
800 | " }\n",
801 | "\n",
802 | "}\n",
803 | "\n",
804 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
805 | " var manager = IPython.notebook.keyboard_manager;\n",
806 | " if (!manager)\n",
807 | " manager = IPython.keyboard_manager;\n",
808 | "\n",
809 | " // Check for shift+enter\n",
810 | " if (event.shiftKey && event.which == 13) {\n",
811 | " this.canvas_div.blur();\n",
812 | " event.shiftKey = false;\n",
813 | " // Send a \"J\" for go to next cell\n",
814 | " event.which = 74;\n",
815 | " event.keyCode = 74;\n",
816 | " manager.command_mode();\n",
817 | " manager.handle_keydown(event);\n",
818 | " }\n",
819 | "}\n",
820 | "\n",
821 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
822 | " fig.ondownload(fig, null);\n",
823 | "}\n",
824 | "\n",
825 | "\n",
826 | "mpl.find_output_cell = function(html_output) {\n",
827 | " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
828 | " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
829 | " // IPython event is triggered only after the cells have been serialised, which for\n",
830 | " // our purposes (turning an active figure into a static one), is too late.\n",
831 | " var cells = IPython.notebook.get_cells();\n",
832 | " var ncells = cells.length;\n",
833 |     "    for (var i=0; i<ncells; i++) {\n",
834 |     "        var cell = cells[i];\n",
835 |     "        if (cell.cell_type === 'code'){\n",
836 |     "            for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
837 |     "                var data = cell.output_area.outputs[j];\n",
838 |     "                if (data.data) {\n",
839 |     "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
840 | " data = data.data;\n",
841 | " }\n",
842 | " if (data['text/html'] == html_output) {\n",
843 | " return [cell, data, j];\n",
844 | " }\n",
845 | " }\n",
846 | " }\n",
847 | " }\n",
848 | "}\n",
849 | "\n",
850 | "// Register the function which deals with the matplotlib target/channel.\n",
851 | "// The kernel may be null if the page has been refreshed.\n",
852 | "if (IPython.notebook.kernel != null) {\n",
853 | " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
854 | "}\n"
855 | ],
856 | "text/plain": [
857 | ""
858 | ]
859 | },
860 | "metadata": {},
861 | "output_type": "display_data"
862 | },
863 | {
864 | "data": {
865 | "text/html": [
866 | "
"
867 | ],
868 | "text/plain": [
869 | ""
870 | ]
871 | },
872 | "metadata": {},
873 | "output_type": "display_data"
874 | }
875 | ],
876 | "source": [
877 | "#build live matplotlib fig\n",
878 | "fig = plt.figure()\n",
879 | "\n",
880 | "ax = fig.add_subplot(111)\n",
881 | "plt.ion()\n",
882 | "\n",
883 | "fig.show()\n",
884 | "fig.canvas.draw()\n",
885 | "\n",
886 | " \n",
887 | "for e in range(epochs):\n",
888 | " out = []\n",
889 | " \n",
890 | " for b in range(-10- enc_seq_len, 10 - enc_seq_len):\n",
891 | " optimizer.zero_grad()\n",
892 | " X, Y = get_data(batch_size, enc_seq_len, output_sequence_length)\n",
893 | " \n",
894 | " #Forward pass and calculate loss\n",
895 | " net_out = t(X)\n",
896 | " #print(net_out.shape,Y.shape)\n",
897 | " loss = torch.mean((net_out - Y) ** 2)\n",
898 | "\n",
899 | " #backwards pass\n",
900 | " loss.backward()\n",
901 | " optimizer.step()\n",
902 | "\n",
903 |     "    #Track losses and draw graph\n",
904 | " out.append([net_out.detach().numpy(), Y])\n",
905 | " losses.append(loss)\n",
906 | "\n",
907 | " ax.clear()\n",
908 | " ax.plot(losses)\n",
909 | " ax.set_title(\"Mean Squared Error\")\n",
910 | " fig.canvas.draw()\n",
911 | "\n"
912 | ]
913 | },
914 | {
915 | "cell_type": "code",
916 | "execution_count": 5,
917 | "metadata": {
918 | "scrolled": false
919 | },
920 | "outputs": [
921 | {
922 | "data": {
923 | "application/javascript": [
924 | "/* Put everything inside the global mpl namespace */\n",
925 | "window.mpl = {};\n",
926 | "\n",
927 | "\n",
928 | "mpl.get_websocket_type = function() {\n",
929 | " if (typeof(WebSocket) !== 'undefined') {\n",
930 | " return WebSocket;\n",
931 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
932 | " return MozWebSocket;\n",
933 | " } else {\n",
934 | " alert('Your browser does not have WebSocket support. ' +\n",
935 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
936 | " 'Firefox 4 and 5 are also supported but you ' +\n",
937 | " 'have to enable WebSockets in about:config.');\n",
938 | " };\n",
939 | "}\n",
940 | "\n",
941 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
942 | " this.id = figure_id;\n",
943 | "\n",
944 | " this.ws = websocket;\n",
945 | "\n",
946 | " this.supports_binary = (this.ws.binaryType != undefined);\n",
947 | "\n",
948 | " if (!this.supports_binary) {\n",
949 | " var warnings = document.getElementById(\"mpl-warnings\");\n",
950 | " if (warnings) {\n",
951 | " warnings.style.display = 'block';\n",
952 | " warnings.textContent = (\n",
953 | " \"This browser does not support binary websocket messages. \" +\n",
954 | " \"Performance may be slow.\");\n",
955 | " }\n",
956 | " }\n",
957 | "\n",
958 | " this.imageObj = new Image();\n",
959 | "\n",
960 | " this.context = undefined;\n",
961 | " this.message = undefined;\n",
962 | " this.canvas = undefined;\n",
963 | " this.rubberband_canvas = undefined;\n",
964 | " this.rubberband_context = undefined;\n",
965 | " this.format_dropdown = undefined;\n",
966 | "\n",
967 | " this.image_mode = 'full';\n",
968 | "\n",
969 | " this.root = $('');\n",
970 | " this._root_extra_style(this.root)\n",
971 | " this.root.attr('style', 'display: inline-block');\n",
972 | "\n",
973 | " $(parent_element).append(this.root);\n",
974 | "\n",
975 | " this._init_header(this);\n",
976 | " this._init_canvas(this);\n",
977 | " this._init_toolbar(this);\n",
978 | "\n",
979 | " var fig = this;\n",
980 | "\n",
981 | " this.waiting = false;\n",
982 | "\n",
983 | " this.ws.onopen = function () {\n",
984 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
985 | " fig.send_message(\"send_image_mode\", {});\n",
986 | " if (mpl.ratio != 1) {\n",
987 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
988 | " }\n",
989 | " fig.send_message(\"refresh\", {});\n",
990 | " }\n",
991 | "\n",
992 | " this.imageObj.onload = function() {\n",
993 | " if (fig.image_mode == 'full') {\n",
994 | " // Full images could contain transparency (where diff images\n",
995 | " // almost always do), so we need to clear the canvas so that\n",
996 | " // there is no ghosting.\n",
997 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
998 | " }\n",
999 | " fig.context.drawImage(fig.imageObj, 0, 0);\n",
1000 | " };\n",
1001 | "\n",
1002 | " this.imageObj.onunload = function() {\n",
1003 | " fig.ws.close();\n",
1004 | " }\n",
1005 | "\n",
1006 | " this.ws.onmessage = this._make_on_message_function(this);\n",
1007 | "\n",
1008 | " this.ondownload = ondownload;\n",
1009 | "}\n",
1010 | "\n",
1011 | "mpl.figure.prototype._init_header = function() {\n",
1012 | " var titlebar = $(\n",
1013 | " '');\n",
1015 | " var titletext = $(\n",
1016 | " '');\n",
1018 | " titlebar.append(titletext)\n",
1019 | " this.root.append(titlebar);\n",
1020 | " this.header = titletext[0];\n",
1021 | "}\n",
1022 | "\n",
1023 | "\n",
1024 | "\n",
1025 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
1026 | "\n",
1027 | "}\n",
1028 | "\n",
1029 | "\n",
1030 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
1031 | "\n",
1032 | "}\n",
1033 | "\n",
1034 | "mpl.figure.prototype._init_canvas = function() {\n",
1035 | " var fig = this;\n",
1036 | "\n",
1037 | " var canvas_div = $('');\n",
1038 | "\n",
1039 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
1040 | "\n",
1041 | " function canvas_keyboard_event(event) {\n",
1042 | " return fig.key_event(event, event['data']);\n",
1043 | " }\n",
1044 | "\n",
1045 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
1046 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
1047 | " this.canvas_div = canvas_div\n",
1048 | " this._canvas_extra_style(canvas_div)\n",
1049 | " this.root.append(canvas_div);\n",
1050 | "\n",
1051 | " var canvas = $('');\n",
1052 | " canvas.addClass('mpl-canvas');\n",
1053 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
1054 | "\n",
1055 | " this.canvas = canvas[0];\n",
1056 | " this.context = canvas[0].getContext(\"2d\");\n",
1057 | "\n",
1058 | " var backingStore = this.context.backingStorePixelRatio ||\n",
1059 | "\tthis.context.webkitBackingStorePixelRatio ||\n",
1060 | "\tthis.context.mozBackingStorePixelRatio ||\n",
1061 | "\tthis.context.msBackingStorePixelRatio ||\n",
1062 | "\tthis.context.oBackingStorePixelRatio ||\n",
1063 | "\tthis.context.backingStorePixelRatio || 1;\n",
1064 | "\n",
1065 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
1066 | "\n",
1067 | " var rubberband = $('');\n",
1068 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
1069 | "\n",
1070 | " var pass_mouse_events = true;\n",
1071 | "\n",
1072 | " canvas_div.resizable({\n",
1073 | " start: function(event, ui) {\n",
1074 | " pass_mouse_events = false;\n",
1075 | " },\n",
1076 | " resize: function(event, ui) {\n",
1077 | " fig.request_resize(ui.size.width, ui.size.height);\n",
1078 | " },\n",
1079 | " stop: function(event, ui) {\n",
1080 | " pass_mouse_events = true;\n",
1081 | " fig.request_resize(ui.size.width, ui.size.height);\n",
1082 | " },\n",
1083 | " });\n",
1084 | "\n",
1085 | " function mouse_event_fn(event) {\n",
1086 | " if (pass_mouse_events)\n",
1087 | " return fig.mouse_event(event, event['data']);\n",
1088 | " }\n",
1089 | "\n",
1090 | " rubberband.mousedown('button_press', mouse_event_fn);\n",
1091 | " rubberband.mouseup('button_release', mouse_event_fn);\n",
1092 | " // Throttle sequential mouse events to 1 every 20ms.\n",
1093 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
1094 | "\n",
1095 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
1096 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
1097 | "\n",
1098 | " canvas_div.on(\"wheel\", function (event) {\n",
1099 | " event = event.originalEvent;\n",
1100 | " event['data'] = 'scroll'\n",
1101 | " if (event.deltaY < 0) {\n",
1102 | " event.step = 1;\n",
1103 | " } else {\n",
1104 | " event.step = -1;\n",
1105 | " }\n",
1106 | " mouse_event_fn(event);\n",
1107 | " });\n",
1108 | "\n",
1109 | " canvas_div.append(canvas);\n",
1110 | " canvas_div.append(rubberband);\n",
1111 | "\n",
1112 | " this.rubberband = rubberband;\n",
1113 | " this.rubberband_canvas = rubberband[0];\n",
1114 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
1115 | " this.rubberband_context.strokeStyle = \"#000000\";\n",
1116 | "\n",
1117 | " this._resize_canvas = function(width, height) {\n",
1118 | " // Keep the size of the canvas, canvas container, and rubber band\n",
1119 | " // canvas in synch.\n",
1120 | " canvas_div.css('width', width)\n",
1121 | " canvas_div.css('height', height)\n",
1122 | "\n",
1123 | " canvas.attr('width', width * mpl.ratio);\n",
1124 | " canvas.attr('height', height * mpl.ratio);\n",
1125 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
1126 | "\n",
1127 | " rubberband.attr('width', width);\n",
1128 | " rubberband.attr('height', height);\n",
1129 | " }\n",
1130 | "\n",
1131 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
1132 | " // upon first draw.\n",
1133 | " this._resize_canvas(600, 600);\n",
1134 | "\n",
1135 | " // Disable right mouse context menu.\n",
1136 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
1137 | " return false;\n",
1138 | " });\n",
1139 | "\n",
1140 | " function set_focus () {\n",
1141 | " canvas.focus();\n",
1142 | " canvas_div.focus();\n",
1143 | " }\n",
1144 | "\n",
1145 | " window.setTimeout(set_focus, 100);\n",
1146 | "}\n",
1147 | "\n",
1148 | "mpl.figure.prototype._init_toolbar = function() {\n",
1149 | " var fig = this;\n",
1150 | "\n",
1151 | " var nav_element = $('');\n",
1152 | " nav_element.attr('style', 'width: 100%');\n",
1153 | " this.root.append(nav_element);\n",
1154 | "\n",
1155 | " // Define a callback function for later on.\n",
1156 | " function toolbar_event(event) {\n",
1157 | " return fig.toolbar_button_onclick(event['data']);\n",
1158 | " }\n",
1159 | " function toolbar_mouse_event(event) {\n",
1160 | " return fig.toolbar_button_onmouseover(event['data']);\n",
1161 | " }\n",
1162 | "\n",
1163 | " for(var toolbar_ind in mpl.toolbar_items) {\n",
1164 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
1165 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
1166 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
1167 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
1168 | "\n",
1169 | " if (!name) {\n",
1170 | " // put a spacer in here.\n",
1171 | " continue;\n",
1172 | " }\n",
1173 | " var button = $('');\n",
1174 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
1175 | " 'ui-button-icon-only');\n",
1176 | " button.attr('role', 'button');\n",
1177 | " button.attr('aria-disabled', 'false');\n",
1178 | " button.click(method_name, toolbar_event);\n",
1179 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
1180 | "\n",
1181 | " var icon_img = $('');\n",
1182 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
1183 | " icon_img.addClass(image);\n",
1184 | " icon_img.addClass('ui-corner-all');\n",
1185 | "\n",
1186 | " var tooltip_span = $('');\n",
1187 | " tooltip_span.addClass('ui-button-text');\n",
1188 | " tooltip_span.html(tooltip);\n",
1189 | "\n",
1190 | " button.append(icon_img);\n",
1191 | " button.append(tooltip_span);\n",
1192 | "\n",
1193 | " nav_element.append(button);\n",
1194 | " }\n",
1195 | "\n",
1196 | " var fmt_picker_span = $('');\n",
1197 | "\n",
1198 | " var fmt_picker = $('');\n",
1199 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
1200 | " fmt_picker_span.append(fmt_picker);\n",
1201 | " nav_element.append(fmt_picker_span);\n",
1202 | " this.format_dropdown = fmt_picker[0];\n",
1203 | "\n",
1204 | " for (var ind in mpl.extensions) {\n",
1205 | " var fmt = mpl.extensions[ind];\n",
1206 | " var option = $(\n",
1207 | " '', {selected: fmt === mpl.default_extension}).html(fmt);\n",
1208 | " fmt_picker.append(option);\n",
1209 | " }\n",
1210 | "\n",
1211 | " // Add hover states to the ui-buttons\n",
1212 | " $( \".ui-button\" ).hover(\n",
1213 | " function() { $(this).addClass(\"ui-state-hover\");},\n",
1214 | " function() { $(this).removeClass(\"ui-state-hover\");}\n",
1215 | " );\n",
1216 | "\n",
1217 | " var status_bar = $('');\n",
1218 | " nav_element.append(status_bar);\n",
1219 | " this.message = status_bar[0];\n",
1220 | "}\n",
1221 | "\n",
1222 | "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
1223 | " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
1224 | " // which will in turn request a refresh of the image.\n",
1225 | " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
1226 | "}\n",
1227 | "\n",
1228 | "mpl.figure.prototype.send_message = function(type, properties) {\n",
1229 | " properties['type'] = type;\n",
1230 | " properties['figure_id'] = this.id;\n",
1231 | " this.ws.send(JSON.stringify(properties));\n",
1232 | "}\n",
1233 | "\n",
1234 | "mpl.figure.prototype.send_draw_message = function() {\n",
1235 | " if (!this.waiting) {\n",
1236 | " this.waiting = true;\n",
1237 | " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
1238 | " }\n",
1239 | "}\n",
1240 | "\n",
1241 | "\n",
1242 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
1243 | " var format_dropdown = fig.format_dropdown;\n",
1244 | " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
1245 | " fig.ondownload(fig, format);\n",
1246 | "}\n",
1247 | "\n",
1248 | "\n",
1249 | "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
1250 | " var size = msg['size'];\n",
1251 | " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
1252 | " fig._resize_canvas(size[0], size[1]);\n",
1253 | " fig.send_message(\"refresh\", {});\n",
1254 | " };\n",
1255 | "}\n",
1256 | "\n",
1257 | "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
1258 | " var x0 = msg['x0'] / mpl.ratio;\n",
1259 | " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
1260 | " var x1 = msg['x1'] / mpl.ratio;\n",
1261 | " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
1262 | " x0 = Math.floor(x0) + 0.5;\n",
1263 | " y0 = Math.floor(y0) + 0.5;\n",
1264 | " x1 = Math.floor(x1) + 0.5;\n",
1265 | " y1 = Math.floor(y1) + 0.5;\n",
1266 | " var min_x = Math.min(x0, x1);\n",
1267 | " var min_y = Math.min(y0, y1);\n",
1268 | " var width = Math.abs(x1 - x0);\n",
1269 | " var height = Math.abs(y1 - y0);\n",
1270 | "\n",
1271 | " fig.rubberband_context.clearRect(\n",
1272 | " 0, 0, fig.canvas.width / mpl.ratio, fig.canvas.height / mpl.ratio);\n",
1273 | "\n",
1274 | " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
1275 | "}\n",
1276 | "\n",
1277 | "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
1278 | " // Updates the figure title.\n",
1279 | " fig.header.textContent = msg['label'];\n",
1280 | "}\n",
1281 | "\n",
1282 | "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
1283 | " var cursor = msg['cursor'];\n",
1284 | " switch(cursor)\n",
1285 | " {\n",
1286 | " case 0:\n",
1287 | " cursor = 'pointer';\n",
1288 | " break;\n",
1289 | " case 1:\n",
1290 | " cursor = 'default';\n",
1291 | " break;\n",
1292 | " case 2:\n",
1293 | " cursor = 'crosshair';\n",
1294 | " break;\n",
1295 | " case 3:\n",
1296 | " cursor = 'move';\n",
1297 | " break;\n",
1298 | " }\n",
1299 | " fig.rubberband_canvas.style.cursor = cursor;\n",
1300 | "}\n",
1301 | "\n",
1302 | "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
1303 | " fig.message.textContent = msg['message'];\n",
1304 | "}\n",
1305 | "\n",
1306 | "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
1307 | " // Request the server to send over a new figure.\n",
1308 | " fig.send_draw_message();\n",
1309 | "}\n",
1310 | "\n",
1311 | "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
1312 | " fig.image_mode = msg['mode'];\n",
1313 | "}\n",
1314 | "\n",
1315 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
1316 | " // Called whenever the canvas gets updated.\n",
1317 | " this.send_message(\"ack\", {});\n",
1318 | "}\n",
1319 | "\n",
1320 | "// A function to construct a web socket function for onmessage handling.\n",
1321 | "// Called in the figure constructor.\n",
1322 | "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
1323 | " return function socket_on_message(evt) {\n",
1324 | " if (evt.data instanceof Blob) {\n",
1325 | " /* FIXME: We get \"Resource interpreted as Image but\n",
1326 | " * transferred with MIME type text/plain:\" errors on\n",
1327 | " * Chrome. But how to set the MIME type? It doesn't seem\n",
1328 | " * to be part of the websocket stream */\n",
1329 | " evt.data.type = \"image/png\";\n",
1330 | "\n",
1331 | " /* Free the memory for the previous frames */\n",
1332 | " if (fig.imageObj.src) {\n",
1333 | " (window.URL || window.webkitURL).revokeObjectURL(\n",
1334 | " fig.imageObj.src);\n",
1335 | " }\n",
1336 | "\n",
1337 | " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
1338 | " evt.data);\n",
1339 | " fig.updated_canvas_event();\n",
1340 | " fig.waiting = false;\n",
1341 | " return;\n",
1342 | " }\n",
1343 | " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
1344 | " fig.imageObj.src = evt.data;\n",
1345 | " fig.updated_canvas_event();\n",
1346 | " fig.waiting = false;\n",
1347 | " return;\n",
1348 | " }\n",
1349 | "\n",
1350 | " var msg = JSON.parse(evt.data);\n",
1351 | " var msg_type = msg['type'];\n",
1352 | "\n",
1353 | " // Call the \"handle_{type}\" callback, which takes\n",
1354 | " // the figure and JSON message as its only arguments.\n",
1355 | " try {\n",
1356 | " var callback = fig[\"handle_\" + msg_type];\n",
1357 | " } catch (e) {\n",
1358 | " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
1359 | " return;\n",
1360 | " }\n",
1361 | "\n",
1362 | " if (callback) {\n",
1363 | " try {\n",
1364 | " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
1365 | " callback(fig, msg);\n",
1366 | " } catch (e) {\n",
1367 | " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
1368 | " }\n",
1369 | " }\n",
1370 | " };\n",
1371 | "}\n",
1372 | "\n",
1373 | "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
1374 | "mpl.findpos = function(e) {\n",
1375 | " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
1376 | " var targ;\n",
1377 | " if (!e)\n",
1378 | " e = window.event;\n",
1379 | " if (e.target)\n",
1380 | " targ = e.target;\n",
1381 | " else if (e.srcElement)\n",
1382 | " targ = e.srcElement;\n",
1383 | " if (targ.nodeType == 3) // defeat Safari bug\n",
1384 | " targ = targ.parentNode;\n",
1385 | "\n",
1386 | " // jQuery normalizes the pageX and pageY\n",
1387 | " // pageX,Y are the mouse positions relative to the document\n",
1388 | " // offset() returns the position of the element relative to the document\n",
1389 | " var x = e.pageX - $(targ).offset().left;\n",
1390 | " var y = e.pageY - $(targ).offset().top;\n",
1391 | "\n",
1392 | " return {\"x\": x, \"y\": y};\n",
1393 | "};\n",
1394 | "\n",
1395 | "/*\n",
1396 | " * return a copy of an object with only non-object keys\n",
1397 | " * we need this to avoid circular references\n",
1398 | " * http://stackoverflow.com/a/24161582/3208463\n",
1399 | " */\n",
1400 | "function simpleKeys (original) {\n",
1401 | " return Object.keys(original).reduce(function (obj, key) {\n",
1402 | " if (typeof original[key] !== 'object')\n",
1403 | " obj[key] = original[key]\n",
1404 | " return obj;\n",
1405 | " }, {});\n",
1406 | "}\n",
1407 | "\n",
1408 | "mpl.figure.prototype.mouse_event = function(event, name) {\n",
1409 | " var canvas_pos = mpl.findpos(event)\n",
1410 | "\n",
1411 | " if (name === 'button_press')\n",
1412 | " {\n",
1413 | " this.canvas.focus();\n",
1414 | " this.canvas_div.focus();\n",
1415 | " }\n",
1416 | "\n",
1417 | " var x = canvas_pos.x * mpl.ratio;\n",
1418 | " var y = canvas_pos.y * mpl.ratio;\n",
1419 | "\n",
1420 | " this.send_message(name, {x: x, y: y, button: event.button,\n",
1421 | " step: event.step,\n",
1422 | " guiEvent: simpleKeys(event)});\n",
1423 | "\n",
1424 | " /* This prevents the web browser from automatically changing to\n",
1425 | " * the text insertion cursor when the button is pressed. We want\n",
1426 | " * to control all of the cursor setting manually through the\n",
1427 | " * 'cursor' event from matplotlib */\n",
1428 | " event.preventDefault();\n",
1429 | " return false;\n",
1430 | "}\n",
1431 | "\n",
1432 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
1433 | " // Handle any extra behaviour associated with a key event\n",
1434 | "}\n",
1435 | "\n",
1436 | "mpl.figure.prototype.key_event = function(event, name) {\n",
1437 | "\n",
1438 | " // Prevent repeat events\n",
1439 | " if (name == 'key_press')\n",
1440 | " {\n",
1441 | " if (event.which === this._key)\n",
1442 | " return;\n",
1443 | " else\n",
1444 | " this._key = event.which;\n",
1445 | " }\n",
1446 | " if (name == 'key_release')\n",
1447 | " this._key = null;\n",
1448 | "\n",
1449 | " var value = '';\n",
1450 | " if (event.ctrlKey && event.which != 17)\n",
1451 | " value += \"ctrl+\";\n",
1452 | " if (event.altKey && event.which != 18)\n",
1453 | " value += \"alt+\";\n",
1454 | " if (event.shiftKey && event.which != 16)\n",
1455 | " value += \"shift+\";\n",
1456 | "\n",
1457 | " value += 'k';\n",
1458 | " value += event.which.toString();\n",
1459 | "\n",
1460 | " this._key_event_extra(event, name);\n",
1461 | "\n",
1462 | " this.send_message(name, {key: value,\n",
1463 | " guiEvent: simpleKeys(event)});\n",
1464 | " return false;\n",
1465 | "}\n",
1466 | "\n",
1467 | "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
1468 | " if (name == 'download') {\n",
1469 | " this.handle_save(this, null);\n",
1470 | " } else {\n",
1471 | " this.send_message(\"toolbar_button\", {name: name});\n",
1472 | " }\n",
1473 | "};\n",
1474 | "\n",
1475 | "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
1476 | " this.message.textContent = tooltip;\n",
1477 | "};\n",
1478 | "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
1479 | "\n",
1480 | "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
1481 | "\n",
1482 | "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
1483 | " // Create a \"websocket\"-like object which calls the given IPython comm\n",
1484 | " // object with the appropriate methods. Currently this is a non binary\n",
1485 | " // socket, so there is still some room for performance tuning.\n",
1486 | " var ws = {};\n",
1487 | "\n",
1488 | " ws.close = function() {\n",
1489 | " comm.close()\n",
1490 | " };\n",
1491 | " ws.send = function(m) {\n",
1492 | " //console.log('sending', m);\n",
1493 | " comm.send(m);\n",
1494 | " };\n",
1495 | " // Register the callback with on_msg.\n",
1496 | " comm.on_msg(function(msg) {\n",
1497 | " //console.log('receiving', msg['content']['data'], msg);\n",
1498 | " // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
1499 | " ws.onmessage(msg['content']['data'])\n",
1500 | " });\n",
1501 | " return ws;\n",
1502 | "}\n",
1503 | "\n",
1504 | "mpl.mpl_figure_comm = function(comm, msg) {\n",
1505 | " // This is the function which gets called when the mpl process\n",
1506 | " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
1507 | "\n",
1508 | " var id = msg.content.data.id;\n",
1509 | " // Get hold of the div created by the display call when the Comm\n",
1510 | " // socket was opened in Python.\n",
1511 | " var element = $(\"#\" + id);\n",
1512 | " var ws_proxy = comm_websocket_adapter(comm)\n",
1513 | "\n",
1514 | " function ondownload(figure, format) {\n",
1515 | " window.open(figure.imageObj.src);\n",
1516 | " }\n",
1517 | "\n",
1518 | " var fig = new mpl.figure(id, ws_proxy,\n",
1519 | " ondownload,\n",
1520 | " element.get(0));\n",
1521 | "\n",
1522 | " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
1523 | " // web socket which is closed, not our websocket->open comm proxy.\n",
1524 | " ws_proxy.onopen();\n",
1525 | "\n",
1526 | " fig.parent_element = element.get(0);\n",
1527 | " fig.cell_info = mpl.find_output_cell(\"\");\n",
1528 | " if (!fig.cell_info) {\n",
1529 | " console.error(\"Failed to find cell for figure\", id, fig);\n",
1530 | " return;\n",
1531 | " }\n",
1532 | "\n",
1533 | " var output_index = fig.cell_info[2]\n",
1534 | " var cell = fig.cell_info[0];\n",
1535 | "\n",
1536 | "};\n",
1537 | "\n",
1538 | "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
1539 | " var width = fig.canvas.width/mpl.ratio\n",
1540 | " fig.root.unbind('remove')\n",
1541 | "\n",
1542 | " // Update the output cell to use the data from the current canvas.\n",
1543 | " fig.push_to_output();\n",
1544 | " var dataURL = fig.canvas.toDataURL();\n",
1545 | " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
1546 | " // the notebook keyboard shortcuts fail.\n",
1547 | " IPython.keyboard_manager.enable()\n",
1548 |     "    $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
1549 | " fig.close_ws(fig, msg);\n",
1550 | "}\n",
1551 | "\n",
1552 | "mpl.figure.prototype.close_ws = function(fig, msg){\n",
1553 | " fig.send_message('closing', msg);\n",
1554 | " // fig.ws.close()\n",
1555 | "}\n",
1556 | "\n",
1557 | "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
1558 | " // Turn the data on the canvas into data in the output cell.\n",
1559 | " var width = this.canvas.width/mpl.ratio\n",
1560 | " var dataURL = this.canvas.toDataURL();\n",
1561 |     "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
1562 | "}\n",
1563 | "\n",
1564 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
1565 | " // Tell IPython that the notebook contents must change.\n",
1566 | " IPython.notebook.set_dirty(true);\n",
1567 | " this.send_message(\"ack\", {});\n",
1568 | " var fig = this;\n",
1569 | " // Wait a second, then push the new image to the DOM so\n",
1570 | " // that it is saved nicely (might be nice to debounce this).\n",
1571 | " setTimeout(function () { fig.push_to_output() }, 1000);\n",
1572 | "}\n",
1573 | "\n",
1574 | "mpl.figure.prototype._init_toolbar = function() {\n",
1575 | " var fig = this;\n",
1576 | "\n",
1577 | " var nav_element = $('');\n",
1578 | " nav_element.attr('style', 'width: 100%');\n",
1579 | " this.root.append(nav_element);\n",
1580 | "\n",
1581 | " // Define a callback function for later on.\n",
1582 | " function toolbar_event(event) {\n",
1583 | " return fig.toolbar_button_onclick(event['data']);\n",
1584 | " }\n",
1585 | " function toolbar_mouse_event(event) {\n",
1586 | " return fig.toolbar_button_onmouseover(event['data']);\n",
1587 | " }\n",
1588 | "\n",
1589 | " for(var toolbar_ind in mpl.toolbar_items){\n",
1590 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
1591 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
1592 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
1593 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
1594 | "\n",
1595 | " if (!name) { continue; };\n",
1596 | "\n",
1597 | " var button = $('');\n",
1598 | " button.click(method_name, toolbar_event);\n",
1599 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
1600 | " nav_element.append(button);\n",
1601 | " }\n",
1602 | "\n",
1603 | " // Add the status bar.\n",
1604 | " var status_bar = $('');\n",
1605 | " nav_element.append(status_bar);\n",
1606 | " this.message = status_bar[0];\n",
1607 | "\n",
1608 | " // Add the close button to the window.\n",
1609 | " var buttongrp = $('');\n",
1610 | " var button = $('');\n",
1611 | " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
1612 | " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
1613 | " buttongrp.append(button);\n",
1614 | " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
1615 | " titlebar.prepend(buttongrp);\n",
1616 | "}\n",
1617 | "\n",
1618 | "mpl.figure.prototype._root_extra_style = function(el){\n",
1619 | " var fig = this\n",
1620 | " el.on(\"remove\", function(){\n",
1621 | "\tfig.close_ws(fig, {});\n",
1622 | " });\n",
1623 | "}\n",
1624 | "\n",
1625 | "mpl.figure.prototype._canvas_extra_style = function(el){\n",
1626 | " // this is important to make the div 'focusable\n",
1627 | " el.attr('tabindex', 0)\n",
1628 | " // reach out to IPython and tell the keyboard manager to turn it's self\n",
1629 | " // off when our div gets focus\n",
1630 | "\n",
1631 | " // location in version 3\n",
1632 | " if (IPython.notebook.keyboard_manager) {\n",
1633 | " IPython.notebook.keyboard_manager.register_events(el);\n",
1634 | " }\n",
1635 | " else {\n",
1636 | " // location in version 2\n",
1637 | " IPython.keyboard_manager.register_events(el);\n",
1638 | " }\n",
1639 | "\n",
1640 | "}\n",
1641 | "\n",
1642 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
1643 | " var manager = IPython.notebook.keyboard_manager;\n",
1644 | " if (!manager)\n",
1645 | " manager = IPython.keyboard_manager;\n",
1646 | "\n",
1647 | " // Check for shift+enter\n",
1648 | " if (event.shiftKey && event.which == 13) {\n",
1649 | " this.canvas_div.blur();\n",
1650 | " event.shiftKey = false;\n",
1651 | " // Send a \"J\" for go to next cell\n",
1652 | " event.which = 74;\n",
1653 | " event.keyCode = 74;\n",
1654 | " manager.command_mode();\n",
1655 | " manager.handle_keydown(event);\n",
1656 | " }\n",
1657 | "}\n",
1658 | "\n",
1659 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
1660 | " fig.ondownload(fig, null);\n",
1661 | "}\n",
1662 | "\n",
1663 | "\n",
1664 | "mpl.find_output_cell = function(html_output) {\n",
1665 | " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
1666 | " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
1667 | " // IPython event is triggered only after the cells have been serialised, which for\n",
1668 | " // our purposes (turning an active figure into a static one), is too late.\n",
1669 | " var cells = IPython.notebook.get_cells();\n",
1670 | " var ncells = cells.length;\n",
1671 | " for (var i=0; i= 3 moved mimebundle to data attribute of output\n",
1678 | " data = data.data;\n",
1679 | " }\n",
1680 | " if (data['text/html'] == html_output) {\n",
1681 | " return [cell, data, j];\n",
1682 | " }\n",
1683 | " }\n",
1684 | " }\n",
1685 | " }\n",
1686 | "}\n",
1687 | "\n",
1688 | "// Register the function which deals with the matplotlib target/channel.\n",
1689 | "// The kernel may be null if the page has been refreshed.\n",
1690 | "if (IPython.notebook.kernel != null) {\n",
1691 | " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
1692 | "}\n"
1693 | ],
1694 | "text/plain": [
1695 | ""
1696 | ]
1697 | },
1698 | "metadata": {},
1699 | "output_type": "display_data"
1700 | },
1701 | {
1702 | "data": {
1703 | "text/html": [
1704 | "
"
1705 | ],
1706 | "text/plain": [
1707 | ""
1708 | ]
1709 | },
1710 | "metadata": {},
1711 | "output_type": "display_data"
1712 | },
1713 | {
1714 | "data": {
1715 | "text/plain": [
1716 | ""
1717 | ]
1718 | },
1719 | "execution_count": 5,
1720 | "metadata": {},
1721 | "output_type": "execute_result"
1722 | }
1723 | ],
1724 | "source": [
1725 | "fig = plt.figure()\n",
1726 | "ax = fig.add_subplot(111)\n",
1727 | "plt.ion()\n",
1728 | "\n",
1729 | "fig.show()\n",
1730 | "fig.canvas.draw()\n",
1731 | "\n",
1732 | "o = []\n",
1733 | "x = [torch.sigmoid(torch.arange(-10,-1).float()).unsqueeze(-1).numpy().tolist()]\n",
1734 | "\n",
1735 | "#Draw graph comparing to sigmoid\n",
1736 | "for i in range(-10, 10, output_sequence_length):\n",
1737 | " o.append([torch.sigmoid(torch.tensor(i).float())])\n",
1738 | " q = torch.tensor(x).float()\n",
1739 | " \n",
1740 | " if(output_sequence_length == 1):\n",
1741 | " x[0].append([t(q).detach().squeeze().numpy()])\n",
1742 | " else:\n",
1743 | " for a in t(q).detach().squeeze().numpy():\n",
1744 | " x[0].append([a])\n",
1745 | " \n",
1746 | "ax.clear()\n",
1747 | "ax.plot(x[0], label='Network output')\n",
1748 | "ax.plot(o, label='Sigmoid function')\n",
1749 | "ax.set_title(\"\")\n",
1750 | "ax.legend(loc='upper left', frameon=False)\n"
1751 | ]
1752 | },
1753 | {
1754 | "cell_type": "code",
1755 | "execution_count": null,
1756 | "metadata": {},
1757 | "outputs": [],
1758 | "source": []
1759 | }
1760 | ],
1761 | "metadata": {
1762 | "kernelspec": {
1763 | "display_name": "Python 3",
1764 | "language": "python",
1765 | "name": "python3"
1766 | },
1767 | "language_info": {
1768 | "codemirror_mode": {
1769 | "name": "ipython",
1770 | "version": 3
1771 | },
1772 | "file_extension": ".py",
1773 | "mimetype": "text/x-python",
1774 | "name": "python",
1775 | "nbconvert_exporter": "python",
1776 | "pygments_lexer": "ipython3",
1777 | "version": "3.7.4"
1778 | }
1779 | },
1780 | "nbformat": 4,
1781 | "nbformat_minor": 2
1782 | }
1783 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import csv
3 | import torch
4 | import random
5 | import numpy as np
6 | from torch.nn.utils.rnn import pad_sequence
7 |
def gen_data(batch_size, gpu = True):
    """Infinite generator of chatbot training batches from the Cornell
    Movie-Dialogs corpus, vectorized with 50-d GloVe word embeddings.

    Args:
        batch_size: number of conversations accumulated before each yield.
        gpu: when True, embedding tensors are created as torch.cuda.FloatTensor.

    Yields:
        (padded_inputs, padded_targets, inputs, targets):
            padded_inputs / padded_targets -- pad_sequence'd (batch_first)
            tensors of per-word GloVe vectors for each utterance and its reply;
            inputs / targets -- the corresponding raw token lists.

    NOTE(review): reads 'cornell_data/movie_lines.txt',
    'cornell_data/movie_conversations.txt' and 'glove.6B.50d.txt' relative to
    the working directory -- confirm paths against the caller's setup.
    """

    #https://www.kaggle.com/shashankasubrahmanya/preprocessing-cornell-movie-dialogue-corpus
    # The corpus delimits fields with ' +++$+++ '; every '+' and '$' is
    # regex-escaped because the python engine treats sep as a regex.
    movie_lines = pd.read_csv('cornell_data/movie_lines.txt', sep = "\+\+\+\$\+\+\+", engine = "python",
                             index_col = False, names = ["LineID", "Character", "Movie", "Name", "Line"])
    movie_lines = movie_lines[["LineID", "Line"]]

    # Strip punctuation and collapse double spaces, then drop any remaining
    # non-word / non-space characters via the final regex pattern.
    movie_lines["Line"] = movie_lines['Line'].str.replace('.','')
    movie_lines["Line"] = movie_lines['Line'].str.replace('!','')
    movie_lines["Line"] = movie_lines['Line'].str.replace('?','')
    movie_lines["Line"] = movie_lines['Line'].str.replace(' ',' ')
    movie_lines["Line"] = movie_lines['Line'].str.replace('[^\w\s.!?]','')
    movie_lines["Line"] = movie_lines["Line"].str.lower()

    movie_lines["LineID"] = movie_lines["LineID"].apply(str.strip)
    # [1:] presumably drops the empty leading token left by the separator's
    # surrounding spaces -- TODO confirm against the raw file format.
    movie_lines["Line"] = movie_lines["Line"].apply(lambda x : str(x).split(" ")[1:])

    movie_conversations = pd.read_csv("cornell_data/movie_conversations.txt", sep = "\+\+\+\$\+\+\+",
                    engine = "python", index_col = False, names = ["Character1", "Character2", "Movie", "Conversation"])
    movie_conversations = movie_conversations["Conversation"]

    #convert from strings of lists to actual lists
    # NOTE(review): eval on raw file text is fine for this trusted corpus, but
    # unsafe for untrusted input -- ast.literal_eval would be safer.
    movie_conversations = movie_conversations.apply(eval)

    # GloVe table indexed by word; rows are 50-d embedding vectors.
    word_vecs = pd.read_table("glove.6B.50d.txt", sep=" ", index_col=0, header=None, quoting=csv.QUOTE_NONE)
    if(not gpu):
        dtype = torch.FloatTensor
    else:
        dtype = torch.cuda.FloatTensor

    ba = 0  # conversations processed since the last yield
    vectorized_inputs = []
    vectorized_targets = []

    inputs = []
    targets = []
    while True:
        # Sample a random contiguous slice of conversations and resolve each
        # conversation's LineIDs to its preprocessed lines.
        # NOTE(review): .loc slicing is end-inclusive, so this actually takes
        # batch_size + 1 conversations per iteration -- confirm intent.
        i = random.randint(0, movie_conversations.size - (batch_size + 1))
        batch = movie_conversations.loc[i:i+batch_size].apply(lambda x : movie_lines.loc[(movie_lines['LineID'].isin(x))])
        batch = batch.apply(lambda x : x['Line'].values).values

        for b in batch:
            # Each consecutive (utterance, reply) pair in a conversation
            # becomes one training example.
            for idx in range(len(b) - 1):
                vectorized_inputs.append(torch.tensor(word_vecs.loc[b[idx]].values).type(dtype))
                vectorized_targets.append(torch.tensor(word_vecs.loc[b[idx + 1]].values).type(dtype))

                inputs.append(b[idx])
                targets.append(b[idx + 1])

            ba += 1
            if ba >= batch_size:
                # Pad variable-length utterances to a common length and hand
                # the batch to the consumer, then start accumulating afresh.
                ba = 0
                yield (pad_sequence(vectorized_inputs, batch_first = True)
                       , pad_sequence(vectorized_targets, batch_first = True), inputs, targets)
                vectorized_inputs = []
                vectorized_targets = []
                inputs = []
                targets = []
--------------------------------------------------------------------------------
/doc/buildfig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/buildfig.png
--------------------------------------------------------------------------------
/doc/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/compare.png
--------------------------------------------------------------------------------
/doc/comparegraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/comparegraph.png
--------------------------------------------------------------------------------
/doc/hyperparams.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/hyperparams.png
--------------------------------------------------------------------------------
/doc/imports,png.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/imports,png.PNG
--------------------------------------------------------------------------------
/doc/init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/init.png
--------------------------------------------------------------------------------
/doc/lossgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/lossgraph.png
--------------------------------------------------------------------------------
/doc/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/training.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
def a_norm(Q, K):
    """Scaled dot-product attention weights: softmax(Q K^T / sqrt(d))."""
    scale = torch.sqrt(torch.tensor(Q.shape[-1]).float())
    scores = torch.matmul(Q, K.transpose(2, 1).float()) / scale
    return torch.softmax(scores, -1)
12 |
13 |
def attention(Q, K, V):
    """Full attention step: weight the values V by a_norm(Q, K).

    Returns a tensor of shape (batch_size, seq_len_q, dim_val).
    """
    weights = a_norm(Q, K)  # (batch_size, seq_len_q, seq_len_k)
    return torch.matmul(weights, V)
19 |
class AttentionBlock(torch.nn.Module):
    """A single attention head with its own value/key/query projections."""

    def __init__(self, dim_val, dim_attn):
        super(AttentionBlock, self).__init__()
        self.value = Value(dim_val, dim_val)
        self.key = Key(dim_val, dim_attn)
        self.query = Query(dim_val, dim_attn)

    def forward(self, x, kv = None):
        """Self-attention over x when kv is None (encoder); otherwise
        cross-attention with x as the query source and kv as K/V (decoder)."""
        source = x if kv is None else kv
        return attention(self.query(x), self.key(source), self.value(source))
34 |
class MultiHeadAttentionBlock(torch.nn.Module):
    """Runs n_heads AttentionBlocks in parallel and mixes their concatenated
    outputs back down to dim_val with a bias-free linear layer.

    Fix: the heads are held in an nn.ModuleList instead of a plain Python
    list. A plain list does not register the sub-modules, so .parameters(),
    .to(device), .cuda() and .state_dict() would silently skip every head
    and their weights would never be trained or saved.
    """

    def __init__(self, dim_val, dim_attn, n_heads):
        super(MultiHeadAttentionBlock, self).__init__()
        self.heads = nn.ModuleList(
            [AttentionBlock(dim_val, dim_attn) for _ in range(n_heads)])

        self.fc = nn.Linear(n_heads * dim_val, dim_val, bias = False)

    def forward(self, x, kv = None):
        """x: query source; kv: optional external K/V source (encoder output)."""
        a = [h(x, kv = kv) for h in self.heads]

        a = torch.stack(a, dim = -1)   # (batch, seq, dim_val, n_heads)
        a = a.flatten(start_dim = 2)   # (batch, seq, dim_val * n_heads)

        return self.fc(a)
56 |
class Value(torch.nn.Module):
    """Bias-free linear projection producing the value vectors for attention."""

    def __init__(self, dim_input, dim_val):
        super(Value, self).__init__()
        self.dim_val = dim_val
        self.fc1 = nn.Linear(dim_input, dim_val, bias = False)

    def forward(self, x):
        """Project x from dim_input to dim_val along the last dimension."""
        return self.fc1(x)
70 |
class Key(torch.nn.Module):
    """Bias-free linear projection producing the key vectors for attention."""

    def __init__(self, dim_input, dim_attn):
        super(Key, self).__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias = False)

    def forward(self, x):
        """Project x from dim_input to dim_attn along the last dimension."""
        return self.fc1(x)
84 |
class Query(torch.nn.Module):
    """Bias-free linear projection producing the query vectors for attention."""

    def __init__(self, dim_input, dim_attn):
        super(Query, self).__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias = False)

    def forward(self, x):
        """Project x from dim_input to dim_attn along the last dimension."""
        return self.fc1(x)
100 |
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings (Vaswani et al. 2017) to x.

    NOTE(review): the `dropout` argument is accepted but never used -- no
    nn.Dropout layer is created. Kept for interface compatibility.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d_model) for even indices i.
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)

        # Stored as (max_len, 1, d_model); a buffer moves with the module
        # (.to/.cuda) but is not a trainable parameter.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # pe[:seq_len] is (seq_len, 1, d_model); squeezing dim 1 yields
        # (seq_len, d_model), which broadcasts across the batch dimension.
        return x + self.pe[:x.size(1), :].squeeze(1)
121 |
def get_data(batch_size, input_sequence_length, output_sequence_length):
    """Toy dataset: random integer-step windows sampled from the sigmoid
    curve on [-10, 10).

    Returns:
        inputs:  (batch_size, input_sequence_length, 1)
        targets: (batch_size, output_sequence_length)
    """
    total_len = input_sequence_length + output_sequence_length

    # One random integer offset per sample so each window starts somewhere
    # that keeps the whole window inside the curve's domain.
    offsets = torch.zeros(batch_size, 1).uniform_(0, 20 - total_len).int()
    steps = torch.arange(-10, -10 + total_len).unsqueeze(0).repeat(batch_size, 1)

    curve = torch.sigmoid((steps + offsets).float())
    inputs = curve[:, :input_sequence_length].unsqueeze(-1)
    targets = curve[:, -output_sequence_length:]
    return inputs, targets
--------------------------------------------------------------------------------