├── .gitignore
├── LICENSE
├── README.md
├── abae.ipynb
├── abae_pytorch
│   ├── __init__.py
│   ├── data.py
│   ├── model.py
│   ├── train.py
│   ├── utils.py
│   └── word2vec.py
└── data
    └── restaurant.train.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Curtis Ogle
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # abae_pytorch
2 | Attention-based aspect extraction
3 |
4 | PyTorch implementation of the model described in [An Unsupervised Neural Attention Model for Aspect Extraction](https://www.aclweb.org/anthology/P17-1036.pdf).
5 |
--------------------------------------------------------------------------------
/abae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "nbAgg\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "%load_ext autoreload\n",
18 | "%autoreload 2\n",
19 | "#%matplotlib inline\n",
20 | "%matplotlib notebook\n",
21 | "import matplotlib\n",
22 | "import matplotlib.pyplot as plt\n",
23 | "print(plt.get_backend())"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {
30 | "scrolled": false
31 | },
32 | "outputs": [
33 | {
34 | "name": "stdout",
35 | "output_type": "stream",
36 | "text": [
37 | "loading abae model: \"./data/restaurant.train.txt.prep.abae.pt\"\n",
38 | "n_vocab: 8311 | d_embed: 200 | n_aspects: 15\n"
39 | ]
40 | },
41 | {
42 | "data": {
43 | "application/javascript": [
44 | "/* Put everything inside the global mpl namespace */\n",
45 | "window.mpl = {};\n",
46 | "\n",
47 | "\n",
48 | "mpl.get_websocket_type = function() {\n",
49 | " if (typeof(WebSocket) !== 'undefined') {\n",
50 | " return WebSocket;\n",
51 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
52 | " return MozWebSocket;\n",
53 | " } else {\n",
54 | " alert('Your browser does not have WebSocket support.' +\n",
55 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
56 | " 'Firefox 4 and 5 are also supported but you ' +\n",
57 | " 'have to enable WebSockets in about:config.');\n",
58 | " };\n",
59 | "}\n",
60 | "\n",
61 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
62 | " this.id = figure_id;\n",
63 | "\n",
64 | " this.ws = websocket;\n",
65 | "\n",
66 | " this.supports_binary = (this.ws.binaryType != undefined);\n",
67 | "\n",
68 | " if (!this.supports_binary) {\n",
69 | " var warnings = document.getElementById(\"mpl-warnings\");\n",
70 | " if (warnings) {\n",
71 | " warnings.style.display = 'block';\n",
72 | " warnings.textContent = (\n",
73 | " \"This browser does not support binary websocket messages. \" +\n",
74 | " \"Performance may be slow.\");\n",
75 | " }\n",
76 | " }\n",
77 | "\n",
78 | " this.imageObj = new Image();\n",
79 | "\n",
80 | " this.context = undefined;\n",
81 | " this.message = undefined;\n",
82 | " this.canvas = undefined;\n",
83 | " this.rubberband_canvas = undefined;\n",
84 | " this.rubberband_context = undefined;\n",
85 | " this.format_dropdown = undefined;\n",
86 | "\n",
87 | " this.image_mode = 'full';\n",
88 | "\n",
89 | " this.root = $('<div/>');\n",
90 | " this._root_extra_style(this.root)\n",
91 | " this.root.attr('style', 'display: inline-block');\n",
92 | "\n",
93 | " $(parent_element).append(this.root);\n",
94 | "\n",
95 | " this._init_header(this);\n",
96 | " this._init_canvas(this);\n",
97 | " this._init_toolbar(this);\n",
98 | "\n",
99 | " var fig = this;\n",
100 | "\n",
101 | " this.waiting = false;\n",
102 | "\n",
103 | " this.ws.onopen = function () {\n",
104 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
105 | " fig.send_message(\"send_image_mode\", {});\n",
106 | " if (mpl.ratio != 1) {\n",
107 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
108 | " }\n",
109 | " fig.send_message(\"refresh\", {});\n",
110 | " }\n",
111 | "\n",
112 | " this.imageObj.onload = function() {\n",
113 | " if (fig.image_mode == 'full') {\n",
114 | " // Full images could contain transparency (where diff images\n",
115 | " // almost always do), so we need to clear the canvas so that\n",
116 | " // there is no ghosting.\n",
117 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
118 | " }\n",
119 | " fig.context.drawImage(fig.imageObj, 0, 0);\n",
120 | " };\n",
121 | "\n",
122 | " this.imageObj.onunload = function() {\n",
123 | " fig.ws.close();\n",
124 | " }\n",
125 | "\n",
126 | " this.ws.onmessage = this._make_on_message_function(this);\n",
127 | "\n",
128 | " this.ondownload = ondownload;\n",
129 | "}\n",
130 | "\n",
131 | "mpl.figure.prototype._init_header = function() {\n",
132 | " var titlebar = $(\n",
133 | " '');\n",
135 | " var titletext = $(\n",
136 | " '');\n",
138 | " titlebar.append(titletext)\n",
139 | " this.root.append(titlebar);\n",
140 | " this.header = titletext[0];\n",
141 | "}\n",
142 | "\n",
143 | "\n",
144 | "\n",
145 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
146 | "\n",
147 | "}\n",
148 | "\n",
149 | "\n",
150 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
151 | "\n",
152 | "}\n",
153 | "\n",
154 | "mpl.figure.prototype._init_canvas = function() {\n",
155 | " var fig = this;\n",
156 | "\n",
157 | " var canvas_div = $('<div/>');\n",
158 | "\n",
159 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
160 | "\n",
161 | " function canvas_keyboard_event(event) {\n",
162 | " return fig.key_event(event, event['data']);\n",
163 | " }\n",
164 | "\n",
165 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
166 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
167 | " this.canvas_div = canvas_div\n",
168 | " this._canvas_extra_style(canvas_div)\n",
169 | " this.root.append(canvas_div);\n",
170 | "\n",
171 | " var canvas = $('<canvas/>');\n",
172 | " canvas.addClass('mpl-canvas');\n",
173 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
174 | "\n",
175 | " this.canvas = canvas[0];\n",
176 | " this.context = canvas[0].getContext(\"2d\");\n",
177 | "\n",
178 | " var backingStore = this.context.backingStorePixelRatio ||\n",
179 | "\tthis.context.webkitBackingStorePixelRatio ||\n",
180 | "\tthis.context.mozBackingStorePixelRatio ||\n",
181 | "\tthis.context.msBackingStorePixelRatio ||\n",
182 | "\tthis.context.oBackingStorePixelRatio ||\n",
183 | "\tthis.context.backingStorePixelRatio || 1;\n",
184 | "\n",
185 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
186 | "\n",
187 | " var rubberband = $('<canvas/>');\n",
188 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
189 | "\n",
190 | " var pass_mouse_events = true;\n",
191 | "\n",
192 | " canvas_div.resizable({\n",
193 | " start: function(event, ui) {\n",
194 | " pass_mouse_events = false;\n",
195 | " },\n",
196 | " resize: function(event, ui) {\n",
197 | " fig.request_resize(ui.size.width, ui.size.height);\n",
198 | " },\n",
199 | " stop: function(event, ui) {\n",
200 | " pass_mouse_events = true;\n",
201 | " fig.request_resize(ui.size.width, ui.size.height);\n",
202 | " },\n",
203 | " });\n",
204 | "\n",
205 | " function mouse_event_fn(event) {\n",
206 | " if (pass_mouse_events)\n",
207 | " return fig.mouse_event(event, event['data']);\n",
208 | " }\n",
209 | "\n",
210 | " rubberband.mousedown('button_press', mouse_event_fn);\n",
211 | " rubberband.mouseup('button_release', mouse_event_fn);\n",
212 | " // Throttle sequential mouse events to 1 every 20ms.\n",
213 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
214 | "\n",
215 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
216 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
217 | "\n",
218 | " canvas_div.on(\"wheel\", function (event) {\n",
219 | " event = event.originalEvent;\n",
220 | " event['data'] = 'scroll'\n",
221 | " if (event.deltaY < 0) {\n",
222 | " event.step = 1;\n",
223 | " } else {\n",
224 | " event.step = -1;\n",
225 | " }\n",
226 | " mouse_event_fn(event);\n",
227 | " });\n",
228 | "\n",
229 | " canvas_div.append(canvas);\n",
230 | " canvas_div.append(rubberband);\n",
231 | "\n",
232 | " this.rubberband = rubberband;\n",
233 | " this.rubberband_canvas = rubberband[0];\n",
234 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
235 | " this.rubberband_context.strokeStyle = \"#000000\";\n",
236 | "\n",
237 | " this._resize_canvas = function(width, height) {\n",
238 | " // Keep the size of the canvas, canvas container, and rubber band\n",
239 | " // canvas in synch.\n",
240 | " canvas_div.css('width', width)\n",
241 | " canvas_div.css('height', height)\n",
242 | "\n",
243 | " canvas.attr('width', width * mpl.ratio);\n",
244 | " canvas.attr('height', height * mpl.ratio);\n",
245 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
246 | "\n",
247 | " rubberband.attr('width', width);\n",
248 | " rubberband.attr('height', height);\n",
249 | " }\n",
250 | "\n",
251 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
252 | " // upon first draw.\n",
253 | " this._resize_canvas(600, 600);\n",
254 | "\n",
255 | " // Disable right mouse context menu.\n",
256 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
257 | " return false;\n",
258 | " });\n",
259 | "\n",
260 | " function set_focus () {\n",
261 | " canvas.focus();\n",
262 | " canvas_div.focus();\n",
263 | " }\n",
264 | "\n",
265 | " window.setTimeout(set_focus, 100);\n",
266 | "}\n",
267 | "\n",
268 | "mpl.figure.prototype._init_toolbar = function() {\n",
269 | " var fig = this;\n",
270 | "\n",
271 | " var nav_element = $('')\n",
272 | " nav_element.attr('style', 'width: 100%');\n",
273 | " this.root.append(nav_element);\n",
274 | "\n",
275 | " // Define a callback function for later on.\n",
276 | " function toolbar_event(event) {\n",
277 | " return fig.toolbar_button_onclick(event['data']);\n",
278 | " }\n",
279 | " function toolbar_mouse_event(event) {\n",
280 | " return fig.toolbar_button_onmouseover(event['data']);\n",
281 | " }\n",
282 | "\n",
283 | " for(var toolbar_ind in mpl.toolbar_items) {\n",
284 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
285 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
286 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
287 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
288 | "\n",
289 | " if (!name) {\n",
290 | " // put a spacer in here.\n",
291 | " continue;\n",
292 | " }\n",
293 | " var button = $('');\n",
294 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
295 | " 'ui-button-icon-only');\n",
296 | " button.attr('role', 'button');\n",
297 | " button.attr('aria-disabled', 'false');\n",
298 | " button.click(method_name, toolbar_event);\n",
299 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
300 | "\n",
301 | " var icon_img = $('');\n",
302 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
303 | " icon_img.addClass(image);\n",
304 | " icon_img.addClass('ui-corner-all');\n",
305 | "\n",
306 | " var tooltip_span = $('');\n",
307 | " tooltip_span.addClass('ui-button-text');\n",
308 | " tooltip_span.html(tooltip);\n",
309 | "\n",
310 | " button.append(icon_img);\n",
311 | " button.append(tooltip_span);\n",
312 | "\n",
313 | " nav_element.append(button);\n",
314 | " }\n",
315 | "\n",
316 | " var fmt_picker_span = $('');\n",
317 | "\n",
318 | " var fmt_picker = $('');\n",
319 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
320 | " fmt_picker_span.append(fmt_picker);\n",
321 | " nav_element.append(fmt_picker_span);\n",
322 | " this.format_dropdown = fmt_picker[0];\n",
323 | "\n",
324 | " for (var ind in mpl.extensions) {\n",
325 | " var fmt = mpl.extensions[ind];\n",
326 | " var option = $(\n",
327 | " '', {selected: fmt === mpl.default_extension}).html(fmt);\n",
328 | " fmt_picker.append(option)\n",
329 | " }\n",
330 | "\n",
331 | " // Add hover states to the ui-buttons\n",
332 | " $( \".ui-button\" ).hover(\n",
333 | " function() { $(this).addClass(\"ui-state-hover\");},\n",
334 | " function() { $(this).removeClass(\"ui-state-hover\");}\n",
335 | " );\n",
336 | "\n",
337 | " var status_bar = $('');\n",
338 | " nav_element.append(status_bar);\n",
339 | " this.message = status_bar[0];\n",
340 | "}\n",
341 | "\n",
342 | "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
343 | " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
344 | " // which will in turn request a refresh of the image.\n",
345 | " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
346 | "}\n",
347 | "\n",
348 | "mpl.figure.prototype.send_message = function(type, properties) {\n",
349 | " properties['type'] = type;\n",
350 | " properties['figure_id'] = this.id;\n",
351 | " this.ws.send(JSON.stringify(properties));\n",
352 | "}\n",
353 | "\n",
354 | "mpl.figure.prototype.send_draw_message = function() {\n",
355 | " if (!this.waiting) {\n",
356 | " this.waiting = true;\n",
357 | " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
358 | " }\n",
359 | "}\n",
360 | "\n",
361 | "\n",
362 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
363 | " var format_dropdown = fig.format_dropdown;\n",
364 | " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
365 | " fig.ondownload(fig, format);\n",
366 | "}\n",
367 | "\n",
368 | "\n",
369 | "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
370 | " var size = msg['size'];\n",
371 | " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
372 | " fig._resize_canvas(size[0], size[1]);\n",
373 | " fig.send_message(\"refresh\", {});\n",
374 | " };\n",
375 | "}\n",
376 | "\n",
377 | "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
378 | " var x0 = msg['x0'] / mpl.ratio;\n",
379 | " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
380 | " var x1 = msg['x1'] / mpl.ratio;\n",
381 | " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
382 | " x0 = Math.floor(x0) + 0.5;\n",
383 | " y0 = Math.floor(y0) + 0.5;\n",
384 | " x1 = Math.floor(x1) + 0.5;\n",
385 | " y1 = Math.floor(y1) + 0.5;\n",
386 | " var min_x = Math.min(x0, x1);\n",
387 | " var min_y = Math.min(y0, y1);\n",
388 | " var width = Math.abs(x1 - x0);\n",
389 | " var height = Math.abs(y1 - y0);\n",
390 | "\n",
391 | " fig.rubberband_context.clearRect(\n",
392 | " 0, 0, fig.canvas.width, fig.canvas.height);\n",
393 | "\n",
394 | " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
395 | "}\n",
396 | "\n",
397 | "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
398 | " // Updates the figure title.\n",
399 | " fig.header.textContent = msg['label'];\n",
400 | "}\n",
401 | "\n",
402 | "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
403 | " var cursor = msg['cursor'];\n",
404 | " switch(cursor)\n",
405 | " {\n",
406 | " case 0:\n",
407 | " cursor = 'pointer';\n",
408 | " break;\n",
409 | " case 1:\n",
410 | " cursor = 'default';\n",
411 | " break;\n",
412 | " case 2:\n",
413 | " cursor = 'crosshair';\n",
414 | " break;\n",
415 | " case 3:\n",
416 | " cursor = 'move';\n",
417 | " break;\n",
418 | " }\n",
419 | " fig.rubberband_canvas.style.cursor = cursor;\n",
420 | "}\n",
421 | "\n",
422 | "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
423 | " fig.message.textContent = msg['message'];\n",
424 | "}\n",
425 | "\n",
426 | "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
427 | " // Request the server to send over a new figure.\n",
428 | " fig.send_draw_message();\n",
429 | "}\n",
430 | "\n",
431 | "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
432 | " fig.image_mode = msg['mode'];\n",
433 | "}\n",
434 | "\n",
435 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
436 | " // Called whenever the canvas gets updated.\n",
437 | " this.send_message(\"ack\", {});\n",
438 | "}\n",
439 | "\n",
440 | "// A function to construct a web socket function for onmessage handling.\n",
441 | "// Called in the figure constructor.\n",
442 | "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
443 | " return function socket_on_message(evt) {\n",
444 | " if (evt.data instanceof Blob) {\n",
445 | " /* FIXME: We get \"Resource interpreted as Image but\n",
446 | " * transferred with MIME type text/plain:\" errors on\n",
447 | " * Chrome. But how to set the MIME type? It doesn't seem\n",
448 | " * to be part of the websocket stream */\n",
449 | " evt.data.type = \"image/png\";\n",
450 | "\n",
451 | " /* Free the memory for the previous frames */\n",
452 | " if (fig.imageObj.src) {\n",
453 | " (window.URL || window.webkitURL).revokeObjectURL(\n",
454 | " fig.imageObj.src);\n",
455 | " }\n",
456 | "\n",
457 | " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
458 | " evt.data);\n",
459 | " fig.updated_canvas_event();\n",
460 | " fig.waiting = false;\n",
461 | " return;\n",
462 | " }\n",
463 | " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
464 | " fig.imageObj.src = evt.data;\n",
465 | " fig.updated_canvas_event();\n",
466 | " fig.waiting = false;\n",
467 | " return;\n",
468 | " }\n",
469 | "\n",
470 | " var msg = JSON.parse(evt.data);\n",
471 | " var msg_type = msg['type'];\n",
472 | "\n",
473 | " // Call the \"handle_{type}\" callback, which takes\n",
474 | " // the figure and JSON message as its only arguments.\n",
475 | " try {\n",
476 | " var callback = fig[\"handle_\" + msg_type];\n",
477 | " } catch (e) {\n",
478 | " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
479 | " return;\n",
480 | " }\n",
481 | "\n",
482 | " if (callback) {\n",
483 | " try {\n",
484 | " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
485 | " callback(fig, msg);\n",
486 | " } catch (e) {\n",
487 | " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
488 | " }\n",
489 | " }\n",
490 | " };\n",
491 | "}\n",
492 | "\n",
493 | "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
494 | "mpl.findpos = function(e) {\n",
495 | " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
496 | " var targ;\n",
497 | " if (!e)\n",
498 | " e = window.event;\n",
499 | " if (e.target)\n",
500 | " targ = e.target;\n",
501 | " else if (e.srcElement)\n",
502 | " targ = e.srcElement;\n",
503 | " if (targ.nodeType == 3) // defeat Safari bug\n",
504 | " targ = targ.parentNode;\n",
505 | "\n",
506 | " // jQuery normalizes the pageX and pageY\n",
507 | " // pageX,Y are the mouse positions relative to the document\n",
508 | " // offset() returns the position of the element relative to the document\n",
509 | " var x = e.pageX - $(targ).offset().left;\n",
510 | " var y = e.pageY - $(targ).offset().top;\n",
511 | "\n",
512 | " return {\"x\": x, \"y\": y};\n",
513 | "};\n",
514 | "\n",
515 | "/*\n",
516 | " * return a copy of an object with only non-object keys\n",
517 | " * we need this to avoid circular references\n",
518 | " * http://stackoverflow.com/a/24161582/3208463\n",
519 | " */\n",
520 | "function simpleKeys (original) {\n",
521 | " return Object.keys(original).reduce(function (obj, key) {\n",
522 | " if (typeof original[key] !== 'object')\n",
523 | " obj[key] = original[key]\n",
524 | " return obj;\n",
525 | " }, {});\n",
526 | "}\n",
527 | "\n",
528 | "mpl.figure.prototype.mouse_event = function(event, name) {\n",
529 | " var canvas_pos = mpl.findpos(event)\n",
530 | "\n",
531 | " if (name === 'button_press')\n",
532 | " {\n",
533 | " this.canvas.focus();\n",
534 | " this.canvas_div.focus();\n",
535 | " }\n",
536 | "\n",
537 | " var x = canvas_pos.x * mpl.ratio;\n",
538 | " var y = canvas_pos.y * mpl.ratio;\n",
539 | "\n",
540 | " this.send_message(name, {x: x, y: y, button: event.button,\n",
541 | " step: event.step,\n",
542 | " guiEvent: simpleKeys(event)});\n",
543 | "\n",
544 | " /* This prevents the web browser from automatically changing to\n",
545 | " * the text insertion cursor when the button is pressed. We want\n",
546 | " * to control all of the cursor setting manually through the\n",
547 | " * 'cursor' event from matplotlib */\n",
548 | " event.preventDefault();\n",
549 | " return false;\n",
550 | "}\n",
551 | "\n",
552 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
553 | " // Handle any extra behaviour associated with a key event\n",
554 | "}\n",
555 | "\n",
556 | "mpl.figure.prototype.key_event = function(event, name) {\n",
557 | "\n",
558 | " // Prevent repeat events\n",
559 | " if (name == 'key_press')\n",
560 | " {\n",
561 | " if (event.which === this._key)\n",
562 | " return;\n",
563 | " else\n",
564 | " this._key = event.which;\n",
565 | " }\n",
566 | " if (name == 'key_release')\n",
567 | " this._key = null;\n",
568 | "\n",
569 | " var value = '';\n",
570 | " if (event.ctrlKey && event.which != 17)\n",
571 | " value += \"ctrl+\";\n",
572 | " if (event.altKey && event.which != 18)\n",
573 | " value += \"alt+\";\n",
574 | " if (event.shiftKey && event.which != 16)\n",
575 | " value += \"shift+\";\n",
576 | "\n",
577 | " value += 'k';\n",
578 | " value += event.which.toString();\n",
579 | "\n",
580 | " this._key_event_extra(event, name);\n",
581 | "\n",
582 | " this.send_message(name, {key: value,\n",
583 | " guiEvent: simpleKeys(event)});\n",
584 | " return false;\n",
585 | "}\n",
586 | "\n",
587 | "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
588 | " if (name == 'download') {\n",
589 | " this.handle_save(this, null);\n",
590 | " } else {\n",
591 | " this.send_message(\"toolbar_button\", {name: name});\n",
592 | " }\n",
593 | "};\n",
594 | "\n",
595 | "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
596 | " this.message.textContent = tooltip;\n",
597 | "};\n",
598 | "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
599 | "\n",
600 | "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
601 | "\n",
602 | "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
603 | " // Create a \"websocket\"-like object which calls the given IPython comm\n",
604 | " // object with the appropriate methods. Currently this is a non binary\n",
605 | " // socket, so there is still some room for performance tuning.\n",
606 | " var ws = {};\n",
607 | "\n",
608 | " ws.close = function() {\n",
609 | " comm.close()\n",
610 | " };\n",
611 | " ws.send = function(m) {\n",
612 | " //console.log('sending', m);\n",
613 | " comm.send(m);\n",
614 | " };\n",
615 | " // Register the callback with on_msg.\n",
616 | " comm.on_msg(function(msg) {\n",
617 | " //console.log('receiving', msg['content']['data'], msg);\n",
618 | " // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
619 | " ws.onmessage(msg['content']['data'])\n",
620 | " });\n",
621 | " return ws;\n",
622 | "}\n",
623 | "\n",
624 | "mpl.mpl_figure_comm = function(comm, msg) {\n",
625 | " // This is the function which gets called when the mpl process\n",
626 | " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
627 | "\n",
628 | " var id = msg.content.data.id;\n",
629 | " // Get hold of the div created by the display call when the Comm\n",
630 | " // socket was opened in Python.\n",
631 | " var element = $(\"#\" + id);\n",
632 | " var ws_proxy = comm_websocket_adapter(comm)\n",
633 | "\n",
634 | " function ondownload(figure, format) {\n",
635 | " window.open(figure.imageObj.src);\n",
636 | " }\n",
637 | "\n",
638 | " var fig = new mpl.figure(id, ws_proxy,\n",
639 | " ondownload,\n",
640 | " element.get(0));\n",
641 | "\n",
642 | " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
643 | " // web socket which is closed, not our websocket->open comm proxy.\n",
644 | " ws_proxy.onopen();\n",
645 | "\n",
646 | " fig.parent_element = element.get(0);\n",
647 | " fig.cell_info = mpl.find_output_cell(\"\");\n",
648 | " if (!fig.cell_info) {\n",
649 | " console.error(\"Failed to find cell for figure\", id, fig);\n",
650 | " return;\n",
651 | " }\n",
652 | "\n",
653 | " var output_index = fig.cell_info[2]\n",
654 | " var cell = fig.cell_info[0];\n",
655 | "\n",
656 | "};\n",
657 | "\n",
658 | "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
659 | " var width = fig.canvas.width/mpl.ratio\n",
660 | " fig.root.unbind('remove')\n",
661 | "\n",
662 | " // Update the output cell to use the data from the current canvas.\n",
663 | " fig.push_to_output();\n",
664 | " var dataURL = fig.canvas.toDataURL();\n",
665 | " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
666 | " // the notebook keyboard shortcuts fail.\n",
667 | " IPython.keyboard_manager.enable()\n",
668 | " $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
669 | " fig.close_ws(fig, msg);\n",
670 | "}\n",
671 | "\n",
672 | "mpl.figure.prototype.close_ws = function(fig, msg){\n",
673 | " fig.send_message('closing', msg);\n",
674 | " // fig.ws.close()\n",
675 | "}\n",
676 | "\n",
677 | "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
678 | " // Turn the data on the canvas into data in the output cell.\n",
679 | " var width = this.canvas.width/mpl.ratio\n",
680 | " var dataURL = this.canvas.toDataURL();\n",
681 | " this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
682 | "}\n",
683 | "\n",
684 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
685 | " // Tell IPython that the notebook contents must change.\n",
686 | " IPython.notebook.set_dirty(true);\n",
687 | " this.send_message(\"ack\", {});\n",
688 | " var fig = this;\n",
689 | " // Wait a second, then push the new image to the DOM so\n",
690 | " // that it is saved nicely (might be nice to debounce this).\n",
691 | " setTimeout(function () { fig.push_to_output() }, 1000);\n",
692 | "}\n",
693 | "\n",
694 | "mpl.figure.prototype._init_toolbar = function() {\n",
695 | " var fig = this;\n",
696 | "\n",
697 | " var nav_element = $('')\n",
698 | " nav_element.attr('style', 'width: 100%');\n",
699 | " this.root.append(nav_element);\n",
700 | "\n",
701 | " // Define a callback function for later on.\n",
702 | " function toolbar_event(event) {\n",
703 | " return fig.toolbar_button_onclick(event['data']);\n",
704 | " }\n",
705 | " function toolbar_mouse_event(event) {\n",
706 | " return fig.toolbar_button_onmouseover(event['data']);\n",
707 | " }\n",
708 | "\n",
709 | " for(var toolbar_ind in mpl.toolbar_items){\n",
710 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
711 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
712 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
713 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
714 | "\n",
715 | " if (!name) { continue; };\n",
716 | "\n",
717 | " var button = $('');\n",
718 | " button.click(method_name, toolbar_event);\n",
719 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
720 | " nav_element.append(button);\n",
721 | " }\n",
722 | "\n",
723 | " // Add the status bar.\n",
724 | " var status_bar = $('');\n",
725 | " nav_element.append(status_bar);\n",
726 | " this.message = status_bar[0];\n",
727 | "\n",
728 | " // Add the close button to the window.\n",
729 | " var buttongrp = $('');\n",
730 | " var button = $('');\n",
731 | " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
732 | " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
733 | " buttongrp.append(button);\n",
734 | " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
735 | " titlebar.prepend(buttongrp);\n",
736 | "}\n",
737 | "\n",
738 | "mpl.figure.prototype._root_extra_style = function(el){\n",
739 | " var fig = this\n",
740 | " el.on(\"remove\", function(){\n",
741 | "\tfig.close_ws(fig, {});\n",
742 | " });\n",
743 | "}\n",
744 | "\n",
745 | "mpl.figure.prototype._canvas_extra_style = function(el){\n",
746 | " // this is important to make the div 'focusable\n",
747 | " el.attr('tabindex', 0)\n",
748 | " // reach out to IPython and tell the keyboard manager to turn it's self\n",
749 | " // off when our div gets focus\n",
750 | "\n",
751 | " // location in version 3\n",
752 | " if (IPython.notebook.keyboard_manager) {\n",
753 | " IPython.notebook.keyboard_manager.register_events(el);\n",
754 | " }\n",
755 | " else {\n",
756 | " // location in version 2\n",
757 | " IPython.keyboard_manager.register_events(el);\n",
758 | " }\n",
759 | "\n",
760 | "}\n",
761 | "\n",
762 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
763 | " var manager = IPython.notebook.keyboard_manager;\n",
764 | " if (!manager)\n",
765 | " manager = IPython.keyboard_manager;\n",
766 | "\n",
767 | " // Check for shift+enter\n",
768 | " if (event.shiftKey && event.which == 13) {\n",
769 | " this.canvas_div.blur();\n",
770 | " event.shiftKey = false;\n",
771 | " // Send a \"J\" for go to next cell\n",
772 | " event.which = 74;\n",
773 | " event.keyCode = 74;\n",
774 | " manager.command_mode();\n",
775 | " manager.handle_keydown(event);\n",
776 | " }\n",
777 | "}\n",
778 | "\n",
779 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
780 | " fig.ondownload(fig, null);\n",
781 | "}\n",
782 | "\n",
783 | "\n",
784 | "mpl.find_output_cell = function(html_output) {\n",
785 | " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
786 | " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
787 | " // IPython event is triggered only after the cells have been serialised, which for\n",
788 | " // our purposes (turning an active figure into a static one), is too late.\n",
789 | " var cells = IPython.notebook.get_cells();\n",
790 | " var ncells = cells.length;\n",
791 | " for (var i=0; i= 3 moved mimebundle to data attribute of output\n",
798 | " data = data.data;\n",
799 | " }\n",
800 | " if (data['text/html'] == html_output) {\n",
801 | " return [cell, data, j];\n",
802 | " }\n",
803 | " }\n",
804 | " }\n",
805 | " }\n",
806 | "}\n",
807 | "\n",
808 | "// Register the function which deals with the matplotlib target/channel.\n",
809 | "// The kernel may be null if the page has been refreshed.\n",
810 | "if (IPython.notebook.kernel != null) {\n",
811 | " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
812 | "}\n"
813 | ],
814 | "text/plain": [
815 | ""
816 | ]
817 | },
818 | "metadata": {},
819 | "output_type": "display_data"
820 | },
821 | {
822 | "data": {
823 | "text/html": [
824 | "
"
825 | ],
826 | "text/plain": [
827 | ""
828 | ]
829 | },
830 | "metadata": {},
831 | "output_type": "display_data"
832 | },
833 | {
834 | "name": "stderr",
835 | "output_type": "stream",
836 | "text": [
837 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 18.38479: 100%|██████████| 500/500 [00:41<00:00, 11.97it/s]\n"
838 | ]
839 | },
840 | {
841 | "name": "stdout",
842 | "output_type": "stream",
843 | "text": [
844 | "Aspect 0: booth, mirror, dark, wall, art, comfy, picture, suit\n",
845 | "Aspect 1: indian, spark, peter, morton, ribbon, luger, life, nobu\n",
846 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
847 | "Aspect 3: seitan, foreign, concoction, pure, buco, solely, marketing, produced\n",
848 | "Aspect 4: feeling, home, italian, taste, cake, homemade, feel, satisfied\n",
849 | "Aspect 5: wedding, celebrate, saturday, yesterday, colleague, bday, meeting, holiday\n",
850 | "Aspect 6: min, already, busboy, repeatedly, asked, eventually, asking, leaving\n",
851 | "Aspect 7: traditional, treat, fare, southern, ingredient, cooking, cuisine, food\n",
852 | "Aspect 8: coffee, strawberry, espresso, tea, milk, apple, cup, cooky\n",
853 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
854 | "Aspect 10: steamed, quesadilla, veal, vegetable, rubbery, overcooked, clam, bone\n",
855 | "Aspect 11: location, brooklyn, midtown, height, queen, ues, joint, tribeca\n",
856 | "Aspect 12: didnt, matter, might, cause, mind, unless, guess, dont\n",
857 | "Aspect 13: good, outstanding, terrific, fabulous, great, excellent, amazing, fantastic\n",
858 | "Aspect 14: backyard, lighting, ambience, atmosphere, garden, patio, view, weather\n"
859 | ]
860 | },
861 | {
862 | "name": "stderr",
863 | "output_type": "stream",
864 | "text": [
865 | "TRAIN EPOCH: 1 | LR: 0.00010 | MEAN-TRAIN-LOSS: 15.51406: 100%|██████████| 500/500 [03:47<00:00, 2.05it/s] \n",
866 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 17.90987: 100%|██████████| 500/500 [00:41<00:00, 11.12it/s]\n",
867 | " 0%| | 0/500 [00:00, ?it/s]"
868 | ]
869 | },
870 | {
871 | "name": "stdout",
872 | "output_type": "stream",
873 | "text": [
874 | "Aspect 0: booth, dark, mirror, wall, art, comfy, suit, picture\n",
875 | "Aspect 1: spark, indian, peter, morton, ribbon, luger, life, nobu\n",
876 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
877 | "Aspect 3: foreign, painful, seitan, buco, solely, pure, marketing, produced\n",
878 | "Aspect 4: home, italian, feeling, taste, cake, homemade, feel, satisfied\n",
879 | "Aspect 5: wedding, celebrate, saturday, yesterday, bday, holiday, meeting, colleague\n",
880 | "Aspect 6: minute, min, already, busboy, asked, asking, eventually, leaving\n",
881 | "Aspect 7: healthy, treat, fare, southern, ingredient, cooking, cuisine, food\n",
882 | "Aspect 8: coffee, strawberry, espresso, tea, milk, apple, cup, cooky\n",
883 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
884 | "Aspect 10: spiced, veal, quesadilla, vegetable, rubbery, overcooked, clam, bone\n",
885 | "Aspect 11: location, brooklyn, height, midtown, queen, joint, ues, tribeca\n",
886 | "Aspect 12: didnt, matter, might, mind, cause, unless, guess, dont\n",
887 | "Aspect 13: good, outstanding, terrific, fabulous, great, excellent, amazing, fantastic\n",
888 | "Aspect 14: backyard, lighting, ambience, atmosphere, garden, patio, view, weather\n"
889 | ]
890 | },
891 | {
892 | "name": "stderr",
893 | "output_type": "stream",
894 | "text": [
895 | "TRAIN EPOCH: 2 | LR: 0.00009 | MEAN-TRAIN-LOSS: 12.06173: 100%|██████████| 500/500 [03:48<00:00, 2.24it/s] \n",
896 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 14.96455: 100%|██████████| 500/500 [00:42<00:00, 12.49it/s]\n",
897 | " 0%| | 0/500 [00:00, ?it/s]"
898 | ]
899 | },
900 | {
901 | "name": "stdout",
902 | "output_type": "stream",
903 | "text": [
904 | "Aspect 0: wall, curtain, lined, comfy, mirror, art, suit, picture\n",
905 | "Aspect 1: indian, spark, peter, morton, ribbon, luger, life, nobu\n",
906 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
907 | "Aspect 3: buco, solely, substandard, seitan, victim, marketing, produced, pure\n",
908 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
909 | "Aspect 5: girlfriend, saturday, today, bday, holiday, yesterday, meeting, colleague\n",
910 | "Aspect 6: asked, min, apologized, already, busboy, asking, leaving, eventually\n",
911 | "Aspect 7: healthy, fare, southern, treat, ingredient, food, cooking, cuisine\n",
912 | "Aspect 8: shot, strawberry, espresso, tea, milk, apple, cup, cooky\n",
913 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
914 | "Aspect 10: chunk, vegetable, veal, quesadilla, clam, rubbery, overcooked, bone\n",
915 | "Aspect 11: location, brooklyn, midtown, height, queen, joint, tribeca, ues\n",
916 | "Aspect 12: hate, matter, mind, unless, might, cause, guess, dont\n",
917 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
918 | "Aspect 14: backyard, inside, ambience, atmosphere, garden, patio, view, weather\n"
919 | ]
920 | },
921 | {
922 | "name": "stderr",
923 | "output_type": "stream",
924 | "text": [
925 | "TRAIN EPOCH: 3 | LR: 0.00009 | MEAN-TRAIN-LOSS: 3.03627: 100%|██████████| 500/500 [03:47<00:00, 2.25it/s] \n",
926 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 13.23022: 100%|██████████| 500/500 [00:41<00:00, 12.21it/s]\n",
927 | " 0%| | 0/500 [00:00, ?it/s]"
928 | ]
929 | },
930 | {
931 | "name": "stdout",
932 | "output_type": "stream",
933 | "text": [
934 | "Aspect 0: booth, comfy, curtain, lined, mirror, suit, art, picture\n",
935 | "Aspect 1: spark, indian, peter, morton, ribbon, luger, life, nobu\n",
936 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
937 | "Aspect 3: substandard, hungarian, cracker, produced, profit, victim, marketing, pure\n",
938 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
939 | "Aspect 5: parent, bday, girlfriend, holiday, today, yesterday, meeting, colleague\n",
940 | "Aspect 6: asked, min, already, apologized, asking, busboy, leaving, eventually\n",
941 | "Aspect 7: healthy, fare, southern, ingredient, treat, cooking, cuisine, food\n",
942 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
943 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
944 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, overcooked, rubbery, bone\n",
945 | "Aspect 11: brooklyn, location, height, midtown, tribeca, queen, joint, ues\n",
946 | "Aspect 12: hate, matter, mind, unless, cause, might, guess, dont\n",
947 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
948 | "Aspect 14: backyard, inside, ambience, atmosphere, garden, patio, view, weather\n"
949 | ]
950 | },
951 | {
952 | "name": "stderr",
953 | "output_type": "stream",
954 | "text": [
955 | "TRAIN EPOCH: 4 | LR: 0.00008 | MEAN-TRAIN-LOSS: 5.27515: 100%|██████████| 500/500 [03:49<00:00, 2.07it/s] \n",
956 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 13.14703: 100%|██████████| 500/500 [00:41<00:00, 12.05it/s]\n",
957 | " 0%| | 0/500 [00:00, ?it/s]"
958 | ]
959 | },
960 | {
961 | "name": "stdout",
962 | "output_type": "stream",
963 | "text": [
964 | "Aspect 0: lined, comfy, curtain, wall, mirror, suit, art, picture\n",
965 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
966 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
967 | "Aspect 3: substandard, cracker, hungarian, produced, victim, profit, marketing, pure\n",
968 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
969 | "Aspect 5: client, parent, girlfriend, holiday, today, yesterday, meeting, colleague\n",
970 | "Aspect 6: asked, min, already, apologized, asking, busboy, leaving, eventually\n",
971 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
972 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
973 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
974 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, overcooked, rubbery, bone\n",
975 | "Aspect 11: brooklyn, location, midtown, height, tribeca, queen, joint, ues\n",
976 | "Aspect 12: hate, matter, mind, unless, cause, might, guess, dont\n",
977 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
978 | "Aspect 14: inside, backyard, atmosphere, ambience, garden, patio, view, weather\n"
979 | ]
980 | },
981 | {
982 | "name": "stderr",
983 | "output_type": "stream",
984 | "text": [
985 | "TRAIN EPOCH: 5 | LR: 0.00008 | MEAN-TRAIN-LOSS: 19.31808: 100%|██████████| 500/500 [03:40<00:00, 2.48it/s] \n",
986 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 12.25975: 100%|██████████| 500/500 [00:32<00:00, 15.67it/s]\n",
987 | " 0%| | 0/500 [00:00, ?it/s]"
988 | ]
989 | },
990 | {
991 | "name": "stdout",
992 | "output_type": "stream",
993 | "text": [
994 | "Aspect 0: lined, booth, curtain, wall, suit, mirror, art, picture\n",
995 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
996 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
997 | "Aspect 3: cracker, substandard, hungarian, marketing, victim, produced, profit, pure\n",
998 | "Aspect 4: italian, home, feeling, taste, cake, homemade, feel, satisfied\n",
999 | "Aspect 5: bday, parent, girlfriend, holiday, today, yesterday, meeting, colleague\n",
1000 | "Aspect 6: minute, min, apologized, already, asking, busboy, leaving, eventually\n",
1001 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1002 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
1003 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1004 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, overcooked, rubbery, bone\n",
1005 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, joint, ues\n",
1006 | "Aspect 12: hate, matter, mind, unless, cause, might, guess, dont\n",
1007 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
1008 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1009 | ]
1010 | },
1011 | {
1012 | "name": "stderr",
1013 | "output_type": "stream",
1014 | "text": [
1015 | "TRAIN EPOCH: 6 | LR: 0.00007 | MEAN-TRAIN-LOSS: 4.42645: 100%|██████████| 500/500 [03:35<00:00, 2.44it/s] \n",
1016 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.49877: 100%|██████████| 500/500 [00:32<00:00, 15.30it/s]\n",
1017 | " 0%| | 0/500 [00:00, ?it/s]"
1018 | ]
1019 | },
1020 | {
1021 | "name": "stdout",
1022 | "output_type": "stream",
1023 | "text": [
1024 | "Aspect 0: lined, booth, curtain, wall, suit, mirror, art, picture\n",
1025 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1026 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1027 | "Aspect 3: cracker, substandard, marketing, hungarian, victim, profit, produced, pure\n",
1028 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
1029 | "Aspect 5: bday, parent, girlfriend, holiday, today, yesterday, meeting, colleague\n",
1030 | "Aspect 6: minute, min, apologized, already, asking, busboy, leaving, eventually\n",
1031 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1032 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
1033 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1034 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1035 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, joint, ues\n",
1036 | "Aspect 12: cash, matter, mind, unless, cause, might, guess, dont\n",
1037 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
1038 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1039 | ]
1040 | },
1041 | {
1042 | "name": "stderr",
1043 | "output_type": "stream",
1044 | "text": [
1045 | "TRAIN EPOCH: 7 | LR: 0.00007 | MEAN-TRAIN-LOSS: 32.07655: 100%|██████████| 500/500 [03:35<00:00, 2.22it/s] \n",
1046 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.76541: 100%|██████████| 500/500 [00:32<00:00, 15.68it/s]\n",
1047 | " 0%| | 0/500 [00:00, ?it/s]"
1048 | ]
1049 | },
1050 | {
1051 | "name": "stdout",
1052 | "output_type": "stream",
1053 | "text": [
1054 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1055 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1056 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1057 | "Aspect 3: painful, substandard, hungarian, marketing, victim, profit, produced, pure\n",
1058 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
1059 | "Aspect 5: saturday, parent, girlfriend, holiday, today, yesterday, meeting, colleague\n",
1060 | "Aspect 6: minute, apologized, already, min, asking, busboy, leaving, eventually\n",
1061 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1062 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
1063 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1064 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1065 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, joint, ues\n",
1066 | "Aspect 12: cash, matter, mind, unless, cause, might, guess, dont\n",
1067 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1068 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1069 | ]
1070 | },
1071 | {
1072 | "name": "stderr",
1073 | "output_type": "stream",
1074 | "text": [
1075 | "TRAIN EPOCH: 8 | LR: 0.00006 | MEAN-TRAIN-LOSS: 59.14827: 100%|██████████| 500/500 [03:31<00:00, 2.36it/s] \n",
1076 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.67389: 100%|██████████| 500/500 [00:32<00:00, 14.88it/s]\n",
1077 | " 0%| | 0/500 [00:00, ?it/s]"
1078 | ]
1079 | },
1080 | {
1081 | "name": "stdout",
1082 | "output_type": "stream",
1083 | "text": [
1084 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1085 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1086 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1087 | "Aspect 3: seitan, marketing, substandard, hungarian, victim, profit, produced, pure\n",
1088 | "Aspect 4: delight, home, feeling, taste, cake, homemade, feel, satisfied\n",
1089 | "Aspect 5: client, parent, holiday, girlfriend, today, yesterday, meeting, colleague\n",
1090 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1091 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1092 | "Aspect 8: shot, strawberry, espresso, milk, tea, apple, cup, cooky\n",
1093 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1094 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1095 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1096 | "Aspect 12: cash, matter, mind, unless, cause, might, guess, dont\n",
1097 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
1098 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1099 | ]
1100 | },
1101 | {
1102 | "name": "stderr",
1103 | "output_type": "stream",
1104 | "text": [
1105 | "TRAIN EPOCH: 9 | LR: 0.00006 | MEAN-TRAIN-LOSS: 1.41365: 100%|██████████| 500/500 [03:33<00:00, 2.14it/s] \n",
1106 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.91851: 100%|██████████| 500/500 [00:32<00:00, 15.04it/s]\n",
1107 | " 0%| | 0/500 [00:00, ?it/s]"
1108 | ]
1109 | },
1110 | {
1111 | "name": "stdout",
1112 | "output_type": "stream",
1113 | "text": [
1114 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1115 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1116 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1117 | "Aspect 3: marketing, seitan, substandard, hungarian, victim, profit, produced, pure\n",
1118 | "Aspect 4: delight, home, taste, feeling, cake, homemade, feel, satisfied\n",
1119 | "Aspect 5: saturday, parent, holiday, girlfriend, today, yesterday, meeting, colleague\n",
1120 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1121 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1122 | "Aspect 8: martini, strawberry, espresso, milk, apple, tea, cup, cooky\n",
1123 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1124 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1125 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, ues, joint\n",
1126 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1127 | "Aspect 13: terrific, good, outstanding, fabulous, great, excellent, amazing, fantastic\n",
1128 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1129 | ]
1130 | },
1131 | {
1132 | "name": "stderr",
1133 | "output_type": "stream",
1134 | "text": [
1135 | "TRAIN EPOCH: 10 | LR: 0.00005 | MEAN-TRAIN-LOSS: 1.20811: 100%|██████████| 500/500 [03:33<00:00, 2.60it/s] \n",
1136 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.72387: 100%|██████████| 500/500 [00:32<00:00, 14.68it/s]\n",
1137 | " 0%| | 0/500 [00:00, ?it/s]"
1138 | ]
1139 | },
1140 | {
1141 | "name": "stdout",
1142 | "output_type": "stream",
1143 | "text": [
1144 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1145 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1146 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1147 | "Aspect 3: painful, seitan, substandard, victim, hungarian, profit, produced, pure\n",
1148 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1149 | "Aspect 5: saturday, parent, holiday, girlfriend, today, yesterday, meeting, colleague\n",
1150 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1151 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1152 | "Aspect 8: martini, strawberry, espresso, milk, apple, tea, cup, cooky\n",
1153 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1154 | "Aspect 10: vegetable, chunk, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1155 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, ues, joint\n",
1156 | "Aspect 12: cash, matter, cause, mind, unless, might, guess, dont\n",
1157 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1158 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1159 | ]
1160 | },
1161 | {
1162 | "name": "stderr",
1163 | "output_type": "stream",
1164 | "text": [
1165 | "TRAIN EPOCH: 11 | LR: 0.00005 | MEAN-TRAIN-LOSS: 0.76233: 100%|██████████| 500/500 [03:35<00:00, 2.25it/s] \n",
1166 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.91820: 100%|██████████| 500/500 [00:32<00:00, 14.88it/s]\n",
1167 | " 0%| | 0/500 [00:00, ?it/s]"
1168 | ]
1169 | },
1170 | {
1171 | "name": "stdout",
1172 | "output_type": "stream",
1173 | "text": [
1174 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1175 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1176 | "Aspect 2: manner, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1177 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1178 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1179 | "Aspect 5: client, parent, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1180 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1181 | "Aspect 7: healthy, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1182 | "Aspect 8: martini, strawberry, espresso, milk, apple, tea, cup, cooky\n",
1183 | "Aspect 9: list, prix, tapa, price, bargain, choice, priced, selection\n",
1184 | "Aspect 10: chunk, vegetable, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1185 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, ues, joint\n",
1186 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1187 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1188 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1189 | ]
1190 | },
1191 | {
1192 | "name": "stderr",
1193 | "output_type": "stream",
1194 | "text": [
1195 | "TRAIN EPOCH: 12 | LR: 0.00004 | MEAN-TRAIN-LOSS: 23.23868: 100%|██████████| 500/500 [03:37<00:00, 2.51it/s] \n",
1196 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.76833: 100%|██████████| 500/500 [00:32<00:00, 15.36it/s]\n",
1197 | " 0%| | 0/500 [00:00, ?it/s]"
1198 | ]
1199 | },
1200 | {
1201 | "name": "stdout",
1202 | "output_type": "stream",
1203 | "text": [
1204 | "Aspect 0: comfy, curtain, booth, wall, suit, mirror, art, picture\n",
1205 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1206 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1207 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1208 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1209 | "Aspect 5: saturday, parent, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1210 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1211 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1212 | "Aspect 8: martini, strawberry, espresso, milk, apple, tea, cup, cooky\n",
1213 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1214 | "Aspect 10: chunk, vegetable, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1215 | "Aspect 11: location, brooklyn, midtown, height, tribeca, queen, ues, joint\n",
1216 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1217 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1218 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1219 | ]
1220 | },
1221 | {
1222 | "name": "stderr",
1223 | "output_type": "stream",
1224 | "text": [
1225 | "TRAIN EPOCH: 13 | LR: 0.00004 | MEAN-TRAIN-LOSS: 2.89502: 100%|██████████| 500/500 [03:34<00:00, 2.50it/s] \n",
1226 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.53663: 100%|██████████| 500/500 [00:33<00:00, 15.48it/s]\n",
1227 | " 0%| | 0/500 [00:00, ?it/s]"
1228 | ]
1229 | },
1230 | {
1231 | "name": "stdout",
1232 | "output_type": "stream",
1233 | "text": [
1234 | "Aspect 0: comfy, curtain, booth, wall, mirror, suit, art, picture\n",
1235 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1236 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1237 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1238 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1239 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1240 | "Aspect 6: already, minute, apologized, min, asking, busboy, leaving, eventually\n",
1241 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1242 | "Aspect 8: strawberry, martini, espresso, milk, apple, tea, cup, cooky\n",
1243 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1244 | "Aspect 10: chunk, vegetable, veal, quesadilla, clam, rubbery, overcooked, bone\n",
1245 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1246 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1247 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1248 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1249 | ]
1250 | },
1251 | {
1252 | "name": "stderr",
1253 | "output_type": "stream",
1254 | "text": [
1255 | "TRAIN EPOCH: 14 | LR: 0.00003 | MEAN-TRAIN-LOSS: 1.80120: 100%|██████████| 500/500 [03:33<00:00, 2.24it/s] \n",
1256 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.62922: 100%|██████████| 500/500 [00:32<00:00, 15.45it/s]\n",
1257 | " 0%| | 0/500 [00:00, ?it/s]"
1258 | ]
1259 | },
1260 | {
1261 | "name": "stdout",
1262 | "output_type": "stream",
1263 | "text": [
1264 | "Aspect 0: comfy, curtain, booth, wall, mirror, suit, art, picture\n",
1265 | "Aspect 1: spark, peter, indian, morton, ribbon, luger, life, nobu\n",
1266 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1267 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1268 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1269 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1270 | "Aspect 6: minute, already, apologized, asking, min, busboy, leaving, eventually\n",
1271 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1272 | "Aspect 8: strawberry, martini, espresso, milk, apple, tea, cup, cooky\n",
1273 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1274 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1275 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1276 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1277 | "Aspect 13: terrific, good, outstanding, fabulous, great, amazing, excellent, fantastic\n",
1278 | "Aspect 14: inside, backyard, atmosphere, garden, ambience, patio, view, weather\n"
1279 | ]
1280 | },
1281 | {
1282 | "name": "stderr",
1283 | "output_type": "stream",
1284 | "text": [
1285 | "TRAIN EPOCH: 15 | LR: 0.00003 | MEAN-TRAIN-LOSS: 5.10454: 100%|██████████| 500/500 [03:37<00:00, 2.22it/s] \n",
1286 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.15482: 100%|██████████| 500/500 [00:32<00:00, 15.38it/s]\n",
1287 | " 0%| | 0/500 [00:00, ?it/s]"
1288 | ]
1289 | },
1290 | {
1291 | "name": "stdout",
1292 | "output_type": "stream",
1293 | "text": [
1294 | "Aspect 0: comfy, curtain, booth, wall, mirror, suit, art, picture\n",
1295 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1296 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1297 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1298 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1299 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1300 | "Aspect 6: minute, already, apologized, min, asking, busboy, leaving, eventually\n",
1301 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1302 | "Aspect 8: strawberry, martini, espresso, milk, apple, tea, cup, cooky\n",
1303 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1304 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1305 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1306 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1307 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1308 | "Aspect 14: inside, backyard, atmosphere, garden, patio, ambience, view, weather\n"
1309 | ]
1310 | },
1311 | {
1312 | "name": "stderr",
1313 | "output_type": "stream",
1314 | "text": [
1315 | "TRAIN EPOCH: 16 | LR: 0.00002 | MEAN-TRAIN-LOSS: 1.92227: 100%|██████████| 500/500 [03:35<00:00, 2.18it/s] \n",
1316 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 9.89297: 100%|██████████| 500/500 [00:32<00:00, 16.87it/s] \n",
1317 | " 0%| | 0/500 [00:00, ?it/s]"
1318 | ]
1319 | },
1320 | {
1321 | "name": "stdout",
1322 | "output_type": "stream",
1323 | "text": [
1324 | "Aspect 0: tv, curtain, booth, wall, suit, mirror, art, picture\n",
1325 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1326 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1327 | "Aspect 3: painful, substandard, seitan, victim, hungarian, profit, produced, pure\n",
1328 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1329 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1330 | "Aspect 6: already, minute, apologized, asking, min, busboy, leaving, eventually\n",
1331 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1332 | "Aspect 8: strawberry, martini, espresso, milk, apple, tea, cup, cooky\n",
1333 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1334 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1335 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1336 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1337 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1338 | "Aspect 14: backyard, inside, atmosphere, garden, patio, ambience, view, weather\n"
1339 | ]
1340 | },
1341 | {
1342 | "name": "stderr",
1343 | "output_type": "stream",
1344 | "text": [
1345 | "TRAIN EPOCH: 17 | LR: 0.00002 | MEAN-TRAIN-LOSS: 4.52363: 100%|██████████| 500/500 [03:40<00:00, 2.36it/s] \n",
1346 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.64311: 100%|██████████| 500/500 [00:42<00:00, 11.76it/s]\n",
1347 | " 0%| | 0/500 [00:00, ?it/s]"
1348 | ]
1349 | },
1350 | {
1351 | "name": "stdout",
1352 | "output_type": "stream",
1353 | "text": [
1354 | "Aspect 0: tv, curtain, booth, wall, mirror, suit, art, picture\n",
1355 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1356 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1357 | "Aspect 3: painful, substandard, victim, seitan, hungarian, profit, produced, pure\n",
1358 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1359 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1360 | "Aspect 6: already, minute, apologized, min, asking, busboy, leaving, eventually\n",
1361 | "Aspect 7: soup, southern, fare, ingredient, treat, cooking, cuisine, food\n",
1362 | "Aspect 8: strawberry, martini, espresso, apple, milk, tea, cup, cooky\n",
1363 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1364 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1365 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1366 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1367 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1368 | "Aspect 14: backyard, inside, atmosphere, garden, patio, ambience, view, weather\n"
1369 | ]
1370 | },
1371 | {
1372 | "name": "stderr",
1373 | "output_type": "stream",
1374 | "text": [
1375 | "TRAIN EPOCH: 18 | LR: 0.00001 | MEAN-TRAIN-LOSS: 59.69238: 100%|██████████| 500/500 [03:45<00:00, 2.40it/s] \n",
1376 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 11.10278: 100%|██████████| 500/500 [00:41<00:00, 11.97it/s]\n",
1377 | " 0%| | 0/500 [00:00, ?it/s]"
1378 | ]
1379 | },
1380 | {
1381 | "name": "stdout",
1382 | "output_type": "stream",
1383 | "text": [
1384 | "Aspect 0: tv, curtain, booth, wall, mirror, suit, art, picture\n",
1385 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1386 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1387 | "Aspect 3: painful, substandard, victim, seitan, hungarian, profit, produced, pure\n",
1388 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1389 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1390 | "Aspect 6: already, apologized, minute, asking, min, busboy, leaving, eventually\n",
1391 | "Aspect 7: southern, soup, fare, ingredient, treat, cooking, cuisine, food\n",
1392 | "Aspect 8: strawberry, martini, espresso, apple, milk, tea, cup, cooky\n",
1393 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1394 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1395 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1396 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1397 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1398 | "Aspect 14: backyard, inside, atmosphere, garden, patio, ambience, view, weather\n"
1399 | ]
1400 | },
1401 | {
1402 | "name": "stderr",
1403 | "output_type": "stream",
1404 | "text": [
1405 | "TRAIN EPOCH: 19 | LR: 0.00001 | MEAN-TRAIN-LOSS: 0.96658: 100%|██████████| 500/500 [03:46<00:00, 2.31it/s] \n",
1406 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.77935: 100%|██████████| 500/500 [00:38<00:00, 13.09it/s]\n",
1407 | " 0%| | 0/500 [00:00, ?it/s]"
1408 | ]
1409 | },
1410 | {
1411 | "name": "stdout",
1412 | "output_type": "stream",
1413 | "text": [
1414 | "Aspect 0: tv, curtain, booth, wall, mirror, suit, art, picture\n",
1415 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1416 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1417 | "Aspect 3: painful, substandard, victim, seitan, hungarian, profit, produced, pure\n",
1418 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1419 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1420 | "Aspect 6: already, apologized, minute, asking, min, busboy, leaving, eventually\n",
1421 | "Aspect 7: southern, soup, fare, ingredient, treat, cooking, cuisine, food\n",
1422 | "Aspect 8: strawberry, martini, espresso, apple, milk, tea, cup, cooky\n",
1423 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1424 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1425 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1426 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1427 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1428 | "Aspect 14: backyard, inside, atmosphere, garden, patio, ambience, view, weather\n"
1429 | ]
1430 | },
1431 | {
1432 | "name": "stderr",
1433 | "output_type": "stream",
1434 | "text": [
1435 | "TRAIN EPOCH: 20 | LR: 0.00000 | MEAN-TRAIN-LOSS: 3.73918: 100%|██████████| 500/500 [03:35<00:00, 2.43it/s] \n",
1436 | "VAL BATCH: 500 | MEAN-VAL-LOSS: 10.63897: 100%|██████████| 500/500 [00:32<00:00, 16.04it/s]\n"
1437 | ]
1438 | },
1439 | {
1440 | "name": "stdout",
1441 | "output_type": "stream",
1442 | "text": [
1443 | "Aspect 0: tv, curtain, booth, wall, mirror, suit, art, picture\n",
1444 | "Aspect 1: manhattan, peter, indian, morton, ribbon, luger, life, nobu\n",
1445 | "Aspect 2: rude, waiter, waitress, hostess, host, server, staff, waitstaff\n",
1446 | "Aspect 3: painful, substandard, victim, seitan, hungarian, profit, produced, pure\n",
1447 | "Aspect 4: room, home, taste, feeling, cake, homemade, feel, satisfied\n",
1448 | "Aspect 5: parent, saturday, holiday, today, girlfriend, yesterday, meeting, colleague\n",
1449 | "Aspect 6: already, apologized, minute, asking, min, busboy, leaving, eventually\n",
1450 | "Aspect 7: southern, soup, fare, ingredient, treat, cooking, cuisine, food\n",
1451 | "Aspect 8: strawberry, martini, espresso, apple, milk, tea, cup, cooky\n",
1452 | "Aspect 9: list, tapa, prix, price, bargain, choice, priced, selection\n",
1453 | "Aspect 10: chunk, vegetable, veal, quesadilla, rubbery, clam, overcooked, bone\n",
1454 | "Aspect 11: location, brooklyn, height, midtown, tribeca, queen, ues, joint\n",
1455 | "Aspect 12: hate, matter, cause, mind, unless, might, guess, dont\n",
1456 | "Aspect 13: terrific, outstanding, good, fabulous, great, excellent, amazing, fantastic\n",
1457 | "Aspect 14: backyard, inside, atmosphere, garden, patio, ambience, view, weather\n",
1458 | "saving abae model: \"./data/restaurant.train.txt.prep.abae.pt\"\n"
1459 | ]
1460 | }
1461 | ],
1462 | "source": [
1463 | "import os\n",
1464 | "import time\n",
1465 | "from abae_pytorch.data import dataloader, preprocess\n",
1466 | "from abae_pytorch.train import train\n",
1467 | "from abae_pytorch import aspect_model\n",
1468 | "\n",
1469 | "\n",
1470 | "#data = './data/wiki_01'\n",
1471 | "#data = './data/beer.train.txt'\n",
1472 | "data = './data/restaurant.train.txt'\n",
1473 | "prep = data + '.prep'\n",
1474 | "if not os.path.isfile(prep):\n",
1475 | " preprocess(data, prep)\n",
1476 | "\n",
1477 | "\n",
1478 | "min_count = 25\n",
1479 | "d_embed = 200\n",
1480 | "n_aspects = 15\n",
1481 | "w2v = prep + '.w2v'\n",
1482 | "\n",
1483 | "abae_path = prep + '.abae.pt'\n",
1484 | "device = 'cpu'\n",
1485 | "\n",
1486 | "aspector = aspect_model(prep, w2v, min_count, d_embed, n_aspects, device)\n",
1487 | "if os.path.isfile(abae_path):\n",
1488 | " aspector.load_abae(abae_path)\n",
1489 | "\n",
1490 | "\n",
1491 | "x = (aspector.w2v.n_vocab, aspector.w2v.d_embed, aspector.w2v.n_aspects)\n",
1492 | "print('n_vocab: %d | d_embed: %d | n_aspects: %d' % x)\n",
1493 | "\n",
1494 | "\n",
1495 | "split = {'train': 0.9, 'val': 0.05, 'test': 0.05}\n",
1496 | "with dataloader(aspector.w2v.w2i, prep, split=split) as dl:\n",
1497 | "\n",
1498 | " epochs = 20\n",
1499 | " epochsize = 500\n",
1500 | " batchsize = 100\n",
1501 | " negsize = 20\n",
1502 | " ortho_reg = 0.1\n",
1503 | " initial_lr = 0.0001\n",
1504 | "\n",
1505 | " if epochs > 0:\n",
1506 | " train(aspector.ab, dl, \n",
1507 | " device=device,\n",
1508 | " epochs=epochs, \n",
1509 | " epochsize=epochsize,\n",
1510 | " batchsize=batchsize,\n",
1511 | " negsize=negsize,\n",
1512 | " ortho_reg=ortho_reg,\n",
1513 | " initial_lr=initial_lr)\n",
1514 | "\n",
1515 | " aspector.save_abae(abae_path)"
1516 | ]
1517 | },
1518 | {
1519 | "cell_type": "code",
1520 | "execution_count": 4,
1521 | "metadata": {},
1522 | "outputs": [
1523 | {
1524 | "data": {
1525 | "text/plain": [
1526 | "tensor([13, 13])"
1527 | ]
1528 | },
1529 | "execution_count": 4,
1530 | "metadata": {},
1531 | "output_type": "execute_result"
1532 | }
1533 | ],
1534 | "source": [
1535 | "#x = 'like roll tiny order anyway often get order wrong stray menu'\n",
1536 | "#x = 'staff staff rube strawberry banana syrup'\n",
1537 | "#x = 'the strawberries were delicious'\n",
1538 | "#x = 'the strawberry banana pudding'\n",
1539 | "x = 'superb amazing delicious'\n",
1540 | "#x = 'staff manager waiter'\n",
1541 | "#x = \"What don't I like? The rolls are tiny so you have to order more anyway and they will often get your order wrong if you stray from the menu\"\n",
1542 | "#x = \"Go here if you want a little bit of Korean combined with a little bit of Japanese food\"\n",
1543 | "#x = \"Place is casual, not fancy\"\n",
1544 | "aspector.predict(x, x)"
1545 | ]
1546 | },
1547 | {
1548 | "cell_type": "code",
1549 | "execution_count": 5,
1550 | "metadata": {},
1551 | "outputs": [
1552 | {
1553 | "name": "stdout",
1554 | "output_type": "stream",
1555 | "text": [
1556 | "What do I like about Jeollado? I like the 2 for 1 rolls (sometimes 3 for 1) the prices and the variety on the menu\r\n",
1557 | "What don't I like? The rolls are tiny so you have to order more anyway and they will often get your order wrong if you stray from the menu\r\n",
1558 | "For the money, it's a dependable and fun place to get sushi - bring friends and share the 2 for 1 rolls (they have to be 2 of the same\r\n",
1559 | ")\r\n",
1560 | "This place is a great deal for the price and the food they give you\r\n",
1561 | "Crab rolls are made with real crab, not the imitation crab\r\n",
1562 | "They also have a great unagi bim bim bap that you must order\r\n",
1563 | "Go here if you want a little bit of Korean combined with a little bit of Japanese food\r\n",
1564 | "Place is casual, not fancy\r\n",
1565 | "Short of cash, with a big group, starving? This is the place\r\n"
1566 | ]
1567 | }
1568 | ],
1569 | "source": [
1570 | "!head -n 10 ./data/restaurant.train.txt"
1571 | ]
1572 | },
1573 | {
1574 | "cell_type": "markdown",
1575 | "metadata": {},
1576 | "source": [
1577 | "#\n",
1578 | " preprocessing script for some known datasets\n",
1579 | " num tag for preprocessing\n",
1580 | " \n",
1581 | " \n",
1582 | " model wrap class\n",
1583 | "\n",
1584 | " given preprocessed data path\n",
1585 | " train w2v models\n",
1586 | " word embeddings trained on partitions too...\n",
1587 | " optionally use different w2v training corpus\n",
1588 | " initialize aspect matrix\n",
1589 | " inferring n_aspects?\n",
1590 | " downweight specificity?\n",
1591 | "\n",
1592 | " given preprocessed data path, train abae model\n",
1593 | " \n",
1594 | " given sentences, provide aspect predictions\n",
1595 | "\n",
1596 | " save and load combinations of components\n",
1597 | "\n",
1598 | "\n",
1599 | " break into package\n",
1600 | " cli\n",
1601 | " documentation\n",
1602 | " setup.py\n",
1603 | " requirements.txt"
1604 | ]
1605 | }
1606 | ],
1607 | "metadata": {
1608 | "kernelspec": {
1609 | "display_name": "Python 3",
1610 | "language": "python",
1611 | "name": "python3"
1612 | },
1613 | "language_info": {
1614 | "codemirror_mode": {
1615 | "name": "ipython",
1616 | "version": 3
1617 | },
1618 | "file_extension": ".py",
1619 | "mimetype": "text/x-python",
1620 | "name": "python",
1621 | "nbconvert_exporter": "python",
1622 | "pygments_lexer": "ipython3",
1623 | "version": "3.7.1"
1624 | }
1625 | },
1626 | "nbformat": 4,
1627 | "nbformat_minor": 2
1628 | }
1629 |
--------------------------------------------------------------------------------
/abae_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import preprocess_sentence
2 | from .word2vec import word2vec
3 | from .model import abae
4 | import numpy as np
5 | import torch
6 | import os
7 |
8 |
class aspect_model:
    """Bundle a word2vec vocabulary with an ABAE network built on top of it.

    Construction trains (or loads) the word embeddings, derives the initial
    aspect matrix from them and instantiates the ``abae`` network on
    ``device``.
    """

    def __init__(self, data_path, w2v_path,
                 min_count=10, d_embed=100, n_aspects=10, device='cpu'):
        """Build the vocabulary, the aspect matrix and the network.

        Args:
            data_path: path of the preprocessed corpus (one sentence per line).
            w2v_path: path of the word2vec model file (loaded if it exists,
                otherwise trained and saved there).
            min_count: minimum word frequency forwarded to word2vec.
            d_embed: embedding dimensionality.
            n_aspects: number of aspects (k-means clusters).
            device: torch device the network is moved to.
        """
        self.w2v = word2vec(data_path)
        self.w2v.embed(w2v_path, d_embed, min_count=min_count)
        self.w2v.aspect(n_aspects)
        self.ab = abae(self.w2v.E, self.w2v.T).to(device)

    def _device(self):
        """Return the device the network parameters currently live on."""
        return next(self.ab.parameters()).device

    def save_abae(self, abae_path):
        """Write the network's state dict to ``abae_path``."""
        print('saving abae model: "%s"' % abae_path)
        torch.save(self.ab.state_dict(), abae_path)

    def load_abae(self, abae_path):
        """Load a state dict previously written by ``save_abae``."""
        print('loading abae model: "%s"' % abae_path)
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # restored onto whatever device this model is on.
        state = torch.load(abae_path, map_location=self._device())
        self.ab.load_state_dict(state)

    def predict(self, *sentences):
        """Return the index of the most probable aspect for each sentence.

        Sentences are preprocessed, mapped to vocabulary indices (unknown
        words map to the ``''`` entry) and right-padded with that same entry
        so that sentences of different lengths can share one batch.

        NOTE: the network has no padding mask, so padded positions still take
        part in attention for the shorter sentences of a mixed-length batch.

        Returns:
            LongTensor of shape ``(len(sentences),)``.
        """
        if not sentences:
            return torch.empty(0, dtype=torch.long)
        vocab = self.w2v.w2i
        unk = vocab['']
        rows = [[vocab.get(w, unk) for w in preprocess_sentence(s)]
                for s in sentences]
        # At least one column, so a sentence made only of stop words still
        # yields a well-formed (all-unknown) input instead of an empty one.
        width = max(1, max(len(row) for row in rows))
        padded = [row + [unk] * (width - len(row)) for row in rows]
        x = torch.tensor(padded, dtype=torch.long, device=self._device())
        with torch.no_grad():
            p_t, _ = self.ab.predict(x)
        return torch.argmax(p_t, dim=1)
32 |
33 |
34 |
--------------------------------------------------------------------------------
/abae_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from sklearn.feature_extraction.text import CountVectorizer
2 | from nltk.corpus import stopwords
3 | from nltk.stem.wordnet import WordNetLemmatizer
4 | import numpy as np
5 | import torch
6 | import mmap
7 | import tqdm
8 | import traceback
9 | import json
10 | import os
11 | from abae_pytorch.utils import linecount
12 |
13 |
# Shared text-processing helpers, built once at import time.
# lmtzr: WordNet lemmatizer applied to every kept token.
lmtzr = WordNetLemmatizer().lemmatize
# stop: English stop words; a frozenset gives O(1) membership tests, which
# matters because every token of every corpus line is checked against it.
stop = frozenset(stopwords.words('english'))
# token: scikit-learn's default CountVectorizer tokenizer (str -> list of str).
token = CountVectorizer().build_tokenizer()
17 |
18 |
def preprocess_sentence(x):
    """Lower-case and tokenize ``x``, drop stop words, lemmatize the rest.

    Returns the surviving lemmas as a list, in their original order.
    """
    lemmas = []
    for tok in token(x.lower()):
        if tok in stop:
            continue
        lemmas.append(lmtzr(tok))
    return lemmas
21 |
22 |
def preprocess(input_path, output_path):
    """Write a cleaned copy of ``input_path`` to ``output_path``.

    Each input line is run through ``preprocess_sentence``; lines that end up
    with more than 5 and fewer than 100 tokens are written out as one
    space-separated line, all others are dropped.

    Both files are handled as UTF-8 so that the output matches what the
    readers of the preprocessed file in this package decode it as.
    """
    lc = linecount(input_path)
    desc = 'preprocessing "%s"' % input_path
    with open(input_path, 'r', encoding='utf-8') as in_f, \
            open(output_path, 'w', encoding='utf-8') as out_f:
        for line in tqdm.tqdm(in_f, total=lc, desc=desc):
            tokens = preprocess_sentence(line)
            if 5 < len(tokens) < 100:
                out_f.write(' '.join(tokens) + '\n')
33 |
34 |
class dataloader:
    """Memory-mapped batch source over a one-sentence-per-line corpus.

    Used as a context manager. On entry the corpus is memory-mapped and every
    newline-terminated line is assigned at random to one of the named splits;
    the resulting byte offsets are cached in a hidden JSON file in the current
    working directory so later runs can skip the scan.
    """

    def __init__(self, w2i, path, split=None, seed=0):
        """Store the configuration and seed NumPy's global RNG.

        Args:
            w2i: mapping word -> index; must contain the ``''`` fallback entry.
            path: path of the preprocessed corpus.
            split: mapping split name -> probability (defaults to all-train).
            seed: seed passed to ``np.random.seed``.
        """
        self.w2i = w2i
        self.path = path
        self.meta = './.' + os.path.basename(self.path) + '.meta.json'
        self.split = split if split else {'train': 1.0}
        np.random.seed(seed)

    def __enter__(self):
        """Map the corpus and load (or build and cache) the line offsets."""
        self.f = open(self.path, 'rb')
        try:
            self.data = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_COPY)
            usable = os.path.isfile(self.meta)
            if usable:
                self.read_meta()
                # A cache written for other split names cannot serve this
                # configuration; rebuild it instead of failing later.
                usable = set(self.offsets) == set(self.split)
            if not usable:
                self._index_lines()
                self.write_meta()
        except BaseException:
            # Do not leak the handles when setup fails half-way.
            self._close()
            raise
        return self

    def __exit__(self, *ags):
        """Print any in-flight exception, release the mapping and the file.

        NOTE: returns True, i.e. an exception raised inside the ``with`` body
        is reported on stderr and then suppressed (long-standing behaviour of
        this class that existing callers may rely on).
        """
        try:
            if ags[1]:
                traceback.print_exception(*ags)
        finally:
            self._close()
        return True

    def _close(self):
        """Close the memory map (if any) and the underlying file."""
        try:
            data = getattr(self, 'data', None)
            if data is not None:
                data.close()
                self.data = None
        finally:
            self.f.close()

    def _index_lines(self):
        """Scan the mapping for newlines and randomly assign lines to splits.

        Every newline-terminated line becomes one ``(start, end)`` byte range
        (``end`` is the position of the newline). Trailing bytes without a
        final newline are not indexed.
        """
        self.offsets = dict((s, []) for s in self.split)
        splits, probs = zip(*list(self.split.items()))
        desc = 'finding offsets in "%s"' % self.path
        size = len(self.data)
        start = 0
        with tqdm.tqdm(total=size, desc=desc) as pbar:
            while start < size:
                # mmap.find searches in C, far faster than a per-byte loop.
                end = self.data.find(b'\n', start)
                if end < 0:
                    break
                split = splits[np.random.choice(len(probs), p=probs)]
                self.offsets[split].append((start, end))
                pbar.update(end + 1 - start)
                start = end + 1
        self.linecounts = dict((s, len(self.offsets[s])) for s in self.split)
        self.linecount = sum(self.linecounts[s] for s in self.split)

    def write_meta(self):
        """Persist path, line counts and offsets to the JSON cache file."""
        meta = {
            'path': self.path,
            'linecount': self.linecount,
            'linecounts': self.linecounts,
            'offsets': self.offsets,
        }
        with open(self.meta, 'w') as f:
            f.write(json.dumps(meta))

    def read_meta(self):
        """Load line counts and offsets from the JSON cache file.

        Raises:
            AssertionError: if the cache was written for a different path.
        """
        with open(self.meta, 'r') as f:
            meta = json.loads(f.read())
        # Raised explicitly so the check survives ``python -O``.
        if self.path != meta['path']:
            raise AssertionError(
                'offset cache "%s" belongs to "%s", not "%s"'
                % (self.meta, meta['path'], self.path))
        self.linecount = meta['linecount']
        self.linecounts = meta['linecounts']
        self.offsets = meta['offsets']

    def b2i(self, batch):
        """Turn ``(start, end)`` byte ranges into a padded index matrix.

        Each range is decoded as UTF-8 and split on whitespace; words missing
        from ``w2i`` map to the ``''`` entry. Shorter rows are right-padded
        with index 0 (no dedicated padding index is used).

        Returns:
            LongTensor of shape ``(len(batch), longest_line)``.
        """
        unk = self.w2i['']
        lines = [self.data[u:v].decode('utf-8').split() for u, v in batch]
        width = max(len(words) for words in lines)
        index = np.zeros((len(lines), width), dtype=np.int64)
        for row, words in zip(index, lines):
            row[:len(words)] = [self.w2i.get(w, unk) for w in words]
        return torch.from_numpy(index)

    def batch_generator(self, split='train', device='cpu', batchsize=20, negsize=20):
        """Yield ``(positives, negatives)`` index tensors forever.

        ``positives`` has ``batchsize`` rows taken from a shuffled pass over
        the split (reshuffled after every full pass); ``negatives`` has shape
        ``(batchsize, negsize, longest_line)`` and is sampled with replacement
        from the same split.

        Raises:
            ValueError: on first iteration, if the split holds fewer than
                ``batchsize`` lines (no full batch could ever be formed).
        """
        linecount = self.linecounts[split]
        batchcount = (linecount // batchsize)
        if batchcount == 0:
            raise ValueError(
                'split "%s" has %d lines, fewer than batchsize=%d'
                % (split, linecount, batchsize))
        pos_offsets = self.offsets[split][:]
        neg_offsets = self.offsets[split][:]
        np.random.shuffle(pos_offsets)
        np.random.shuffle(neg_offsets)
        batches = 0
        while True:
            if batches == batchcount:
                # One full pass done: start a new, differently ordered one.
                np.random.shuffle(pos_offsets)
                np.random.shuffle(neg_offsets)
                batches = 0
            pos_batch = pos_offsets[batches * batchsize:(batches + 1) * batchsize]
            pos_batch = self.b2i(pos_batch)
            neg_batch = np.random.choice(linecount, batchsize * negsize)
            neg_batch = self.b2i([neg_offsets[i] for i in neg_batch])
            batch = (
                pos_batch.to(device),
                neg_batch.to(device).view(batchsize, negsize, -1),
            )
            yield batch
            batches += 1
120 |
121 |
122 |
--------------------------------------------------------------------------------
/abae_pytorch/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torch.nn.functional import normalize, softmax
4 |
5 |
class attention(nn.Module):
    """Attention over the words of a sentence, keyed on the sentence mean."""

    def __init__(self, d_embed):
        """Create the ``d_embed x d_embed`` bilinear map used for scoring."""
        super(attention, self).__init__()
        self.M = nn.Linear(d_embed, d_embed)
        self.M.weight.data.uniform_(-0.1, 0.1)

    def forward(self, e_i):
        """Return one attention weight per word.

        Args:
            e_i: word embeddings of shape ``(batch, d_embed, n_words)``.

        Returns:
            Tensor of shape ``(batch, n_words, 1)`` whose weights sum to 1
            over the word axis of every sentence.
        """
        # Sentence summary: mean embedding over the word axis.
        y_s = torch.mean(e_i, dim=-1)
        # Score of each word against the transformed summary, squashed by tanh.
        d_i = torch.bmm(e_i.transpose(1, 2), self.M(y_s).unsqueeze(2)).tanh()
        # Normalise per sentence (dim=1 is the word axis), not over the whole
        # batch, and keep all three axes so one-word sentences stay batched.
        return softmax(d_i, dim=1)
18 |
19 |
class abae(nn.Module):
    """Attention-based aspect network: sentence encoder plus aspect decoder."""

    def __init__(self, E, T):
        """Initialise the trainable embeddings from NumPy matrices.

        Args:
            E: word embedding matrix, shape ``(n_vocab, d_embed)``.
            T: aspect embedding matrix, shape ``(n_aspects, d_embed)``.

        Raises:
            ValueError: if ``E`` and ``T`` differ in embedding width.
        """
        super(abae, self).__init__()
        n_vocab, d_embed = E.shape
        n_aspects, d_aspect = T.shape
        if d_aspect != d_embed:
            raise ValueError(
                'E has embedding width %d but T has %d' % (d_embed, d_aspect))
        self.E = nn.Embedding(n_vocab, d_embed)
        self.T = nn.Embedding(n_aspects, d_embed)
        self.attention = attention(d_embed)
        self.linear = nn.Linear(d_embed, n_aspects)
        # torch.from_numpy shares memory with the array; clone so optimiser
        # steps do not silently rewrite the caller's NumPy matrices.
        # Both matrices are fine-tuned during training.
        self.E.weight = nn.Parameter(torch.from_numpy(E).clone(), requires_grad=True)
        self.T.weight = nn.Parameter(torch.from_numpy(T).clone(), requires_grad=True)

    def forward(self, pos, negs):
        """Encode a batch and its negative samples.

        Args:
            pos: word indices, shape ``(batch, n_words)``.
            negs: word indices of negative sentences; the last axis is the
                word axis (the training code passes
                ``(batch, n_negatives, n_words)``).

        Returns:
            ``(r_s, z_s, z_n)``: the unit-length reconstruction of each
            sentence from the aspect matrix, the unit-length sentence
            encoding, and the unit-length mean embedding of every negative.
        """
        p_t, z_s = self.predict(pos)
        # Reconstruction: aspect embeddings mixed by the aspect distribution.
        r_s = normalize(torch.mm(self.T.weight.t(), p_t.t()).t(), dim=-1)
        e_n = self.E(negs).transpose(-2, -1)
        z_n = normalize(torch.mean(e_n, dim=-1), dim=-1)
        return r_s, z_s, z_n

    def predict(self, x):
        """Return ``(p_t, z_s)`` for word indices ``x`` of shape ``(batch, n_words)``.

        ``z_s`` is the attention-weighted, unit-length sentence embedding and
        ``p_t`` the softmax distribution over aspects derived from it.
        """
        e_i = self.E(x).transpose(1, 2)
        a_i = self.attention(e_i)
        z_s = normalize(torch.bmm(e_i, a_i).squeeze(2), dim=-1)
        p_t = softmax(self.linear(z_s), dim=1)
        return p_t, z_s

    def aspects(self):
        """Return the cosine similarity of every aspect to every word.

        The result has shape ``(n_aspects, n_vocab)``.
        """
        E_n = normalize(self.E.weight, dim=1)
        T_n = normalize(self.T.weight, dim=1)
        projection = torch.mm(E_n, T_n.t()).t()
        return projection
53 |
54 |
55 |
--------------------------------------------------------------------------------
/abae_pytorch/train.py:
--------------------------------------------------------------------------------
1 | from torch.nn.functional import normalize
2 | import torch.optim as optim
3 | import torch
4 |
5 | import matplotlib.pyplot as plt
6 | import matplotlib.cm as cm
7 | import collections
8 | import numpy as np
9 | import tqdm
10 |
11 |
def max_margin_loss(r_s, z_s, z_n):
    """Summed hinge loss ``max(0, 1 - r_s.z_s + r_s.z_n)`` over all negatives.

    Args:
        r_s: reconstructed sentence embeddings, shape ``(batch, d)``.
        z_s: sentence embeddings, shape ``(batch, d)``.
        z_n: negative sample embeddings, shape ``(batch, n_negatives, d)``.

    Returns:
        Scalar tensor: the hinge terms summed over batch and negatives.
    """
    # (batch, 1): similarity of each sentence to its own reconstruction.
    pos = torch.bmm(z_s.unsqueeze(1), r_s.unsqueeze(2)).squeeze(2)
    # (batch, n_negatives): similarity of each negative to the reconstruction.
    # Only the trailing singleton axis is squeezed, so a batch of one sentence
    # or a single negative per sentence keeps its two-dimensional shape.
    negs = torch.bmm(z_n, r_s.unsqueeze(2)).squeeze(2)
    # ``pos`` broadcasts across the negatives axis.
    J = 1.0 - pos + negs
    return torch.sum(torch.clamp(J, min=0.0))
18 |
19 |
def orthogonal_regularization(T):
    """Penalty that grows as the rows of ``T`` stop being orthogonal.

    The rows are scaled to unit length, their Gram matrix is formed and the
    norm (``torch.norm`` default) of its difference from the identity is
    returned as a scalar tensor.
    """
    unit_rows = normalize(T, dim=1)
    gram = unit_rows.mm(unit_rows.t())
    identity = torch.eye(unit_rows.shape[0]).to(unit_rows.device)
    return torch.norm(gram - identity)
24 |
25 |
def train(ab, dl, device='cuda', epochs=5, epochsize=100,
          initial_lr=0.02, batchsize=100, negsize=20, ortho_reg=0.1):
    """Train ``ab`` with Adam on batches drawn from ``dl``.

    Before the first epoch and after every epoch the model is validated on
    the ``'val'`` split, the top words of every aspect are printed and the
    loss curves are redrawn. The learning rate decays linearly from
    ``initial_lr`` towards zero over ``epochs * epochsize`` steps. The model
    is left in eval mode on return.

    Args:
        ab: the network; called as ``ab(pos, neg)`` and expected to expose
            ``T.weight`` and ``aspects()``.
        dl: entered dataloader providing ``batch_generator`` and ``w2i``.
        device: device the batches are moved to.
        epochs: number of epochs.
        epochsize: batches per epoch (also used for each validation run).
        initial_lr: starting learning rate.
        batchsize: sentences per batch.
        negsize: negative samples per sentence.
        ortho_reg: weight of the orthogonality penalty.
    """
    batches = dl.batch_generator('train', device, batchsize, negsize)
    i2w = dict((dl.w2i[w], w) for w in dl.w2i)

    opt = optim.Adam(ab.parameters(), lr=initial_lr)
    plot = plotter()
    epoch_losses = collections.defaultdict(list)

    def report(train_loss):
        # Validate in eval mode, then record and display the epoch summary.
        ab.eval()
        val_loss = validate(ab, dl, device, 'val', epochsize,
                            batchsize, negsize, ortho_reg)
        ab.train()
        epoch_losses['Training Loss'].append(train_loss)
        epoch_losses['Validation Loss'].append(val_loss)
        sample_aspects(ab.aspects(), i2w)
        plot(epoch_losses)

    # Baseline before any update; no training loss exists yet.
    report(float('inf'))

    total_steps = epochs * epochsize
    for e in range(epochs):
        train_losses = []
        loss_sum = 0.0
        with tqdm.trange(epochsize) as pbar:
            for b in pbar:
                pos, neg = next(batches)
                r_s, z_s, z_n = ab(pos, neg)
                J = max_margin_loss(r_s, z_s, z_n)
                U = orthogonal_regularization(ab.T.weight)
                loss = J + ortho_reg * batchsize * U
                opt.zero_grad()
                loss.backward()
                opt.step()

                train_losses.append(loss.item())
                loss_sum += train_losses[-1]
                # Show the running mean of this epoch, as the label says,
                # rather than the noisy loss of the latest batch.
                x = (e + 1, opt.param_groups[0]['lr'],
                     loss_sum / len(train_losses))
                d = 'TRAIN EPOCH: %d | LR: %0.5f | MEAN-TRAIN-LOSS: %0.5f' % x
                pbar.set_description(d)

                # Linear decay, refreshed whenever the number of sentences
                # consumed so far in the epoch is a multiple of 100.
                if b * batchsize % 100 == 0:
                    lr = initial_lr * (1.0 - (e * epochsize + (b + 1)) / total_steps)
                    for pg in opt.param_groups:
                        pg['lr'] = lr

        report(np.mean(train_losses))

    ab.eval()
77 |
78 |
def validate(ab, dl, device='cuda', split='val',
             epochsize=100, batchsize=100, negsize=20, ortho_reg=0.1):
    """Return the mean loss of ``ab`` over ``epochsize`` batches of ``split``.

    The loss is the same max-margin term plus weighted orthogonality penalty
    used for training. No gradients are recorded.
    """
    losses = []
    batches = dl.batch_generator(split, device, batchsize, negsize)
    with torch.no_grad():
        # The weights do not change while validating, so the penalty is the
        # same for every batch and is computed once.
        U = orthogonal_regularization(ab.T.weight).item()
        penalty = ortho_reg * batchsize * U
        with tqdm.tqdm(range(epochsize), total=epochsize, desc='validating') as pbar:
            for b in pbar:
                pos, neg = next(batches)
                r_s, z_s, z_n = ab(pos, neg)
                J = max_margin_loss(r_s, z_s, z_n).item()
                losses.append(J + penalty)
                x = (b + 1, np.mean(losses))
                pbar.set_description('VAL BATCH: %d | MEAN-VAL-LOSS: %0.5f' % x)
    return np.mean(losses)
93 |
94 |
def plotter(figsize=(8, 4)):
    """Create a figure and return a function that redraws loss curves on it.

    The returned ``plot_losses(losses)`` takes a mapping from series name to
    a list of per-epoch values and draws one line per series on a log-scaled
    y axis.
    """
    f, ax = plt.subplots(1, 1, figsize=figsize)

    def plot_losses(losses):
        # Start from empty axes; otherwise every call would stack another
        # copy of each curve on top of the ones drawn before.
        ax.clear()
        colors = cm.rainbow(np.linspace(0, 1, len(losses)))
        longest = 0
        for name, color in zip(losses, colors):
            y = losses[name]
            ax.plot(list(range(len(y))), y,
                    color=color, label=name, lw=4, marker='o')
            longest = max(longest, len(y))
        if longest or losses:
            ax.legend()
        ax.set_yscale('log')
        ax.set_title('Losses')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_xticks(list(range(longest)))
        f.canvas.draw()

    return plot_losses
115 |
116 |
def sample_aspects(projection, i2w, n=8):
    """Print, for every row of ``projection``, its ``n`` highest-scoring words.

    Rows are sorted in ascending order, so the words of each printed line run
    from the lowest to the highest score among the selected ``n``.
    """
    _, order = torch.sort(projection, dim=1)
    for aspect_id, ranked in enumerate(order):
        best = ranked[-n:].detach().cpu().numpy()
        label = ', '.join(i2w[idx] for idx in best)
        print('Aspect %2d: %s' % (aspect_id, label))
123 |
124 |
125 |
--------------------------------------------------------------------------------
/abae_pytorch/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
def linecount(path, chunk=(8192 * 1024)):
    """Return the number of newline bytes in the file at ``path``.

    The file is read in binary mode, ``chunk`` bytes at a time, so it never
    has to fit in memory at once.
    """
    with open(path, 'rb') as handle:
        pieces = iter(lambda: handle.read(chunk), b'')
        return sum(piece.count(b'\n') for piece in pieces)
11 |
12 |
13 |
--------------------------------------------------------------------------------
/abae_pytorch/word2vec.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | import numpy as np
3 | import gensim
4 | import codecs
5 | import tqdm
6 | import os
7 |
8 |
class word2vec:
    """Vocabulary, embedding matrix and initial aspect matrix for a corpus.

    Iterating an instance yields the whitespace-split tokens of each corpus
    line, which is how the instance itself is fed to gensim as the corpus.
    """

    def __init__(self, corpus_path):
        """Remember the corpus path and start with an empty vocabulary."""
        self.corpus_path = corpus_path
        self.n_vocab = 0
        # Created here so add_word is usable before embed() has been called.
        self.i2w, self.w2i = {}, {}

    def __iter__(self):
        """Yield each corpus line as a list of tokens."""
        with codecs.open(self.corpus_path, 'r', 'utf-8') as f:
            for line in tqdm.tqdm(f, desc='training'):
                yield line.split()

    def add_word(self, *words):
        """Give every not-yet-known word the next free vocabulary index."""
        for word in words:
            if word not in self.w2i:
                self.w2i[word] = self.n_vocab
                self.i2w[self.n_vocab] = word
                self.n_vocab += 1

    def embed(self, model_path, d_embed,
              max_n_vocab=None, window=5, min_count=10, workers=16):
        """Load or train the word2vec model and build ``E``, ``w2i``, ``i2w``.

        If ``model_path`` exists the model is loaded from it; otherwise it is
        trained on this corpus and saved there. Words are indexed in sorted
        order, followed by one extra ``''`` entry whose embedding row is all
        zeros.

        Returns:
            ``self``.

        Raises:
            ValueError: if the model's vectors are not ``d_embed`` wide
                (e.g. a cached model trained with another dimensionality).
        """
        self.d_embed = d_embed
        if os.path.isfile(model_path):
            model = gensim.models.Word2Vec.load(model_path)
        else:
            model = gensim.models.Word2Vec(self,
                size=d_embed, max_final_vocab=max_n_vocab,
                window=window, min_count=min_count, workers=workers)
            model.save(model_path)
        # Start from scratch so that calling embed() again cannot leave
        # indices offset from the rows of E.
        self.i2w, self.w2i = {}, {}
        self.n_vocab = 0
        rows = []
        for word in sorted(model.wv.vocab):
            self.add_word(word)
            rows.append(list(model.wv[word]))
        if rows and len(rows[0]) != d_embed:
            raise ValueError(
                'model "%s" has %d-dimensional vectors, expected d_embed=%d'
                % (model_path, len(rows[0]), d_embed))
        self.add_word('')
        rows.append([0] * d_embed)
        self.E = np.asarray(rows).astype(np.float32)
        return self

    def aspect(self, n_aspects):
        """Initialise ``T`` with the unit-length k-means centroids of ``E``.

        Returns:
            ``self``.
        """
        self.n_aspects = n_aspects
        km = KMeans(n_clusters=n_aspects, random_state=0)
        km.fit(self.E)
        self.T = km.cluster_centers_.astype(np.float32)
        self.T /= np.linalg.norm(self.T, axis=-1, keepdims=True)
        return self
57 |
58 |
59 |
--------------------------------------------------------------------------------