├── .gitignore ├── LICENSE ├── README.md ├── abae.ipynb ├── abae_pytorch ├── __init__.py ├── data.py ├── model.py ├── train.py ├── utils.py └── word2vec.py └── data └── restaurant.train.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Curtis Ogle 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # abae_pytorch 2 | Attention-based aspect extraction 3 | 4 | PyTorch implementation of the model described in [An Unsupervised Neural Attention Model for Aspect Extraction](https://www.aclweb.org/anthology/P17-1036.pdf). 5 | -------------------------------------------------------------------------------- /abae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "nbAgg\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n", 19 | "#%matplotlib inline\n", 20 | "%matplotlib notebook\n", 21 | "import matplotlib\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "print(plt.get_backend())" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": { 30 | "scrolled": false 31 | }, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "loading abae model: \"./data/restaurant.train.txt.prep.abae.pt\"\n", 38 | "n_vocab: 8311 | d_embed: 200 | n_aspects: 15\n" 39 | ] 40 | }, 41 | { 42 | "data": { 43 | "application/javascript": [ 44 | "/* Put everything inside the global mpl namespace */\n", 45 | "window.mpl = {};\n", 46 | "\n", 47 | "\n", 48 | "mpl.get_websocket_type = function() {\n", 49 | " if (typeof(WebSocket) !== 'undefined') {\n", 50 | " return WebSocket;\n", 51 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n", 52 | " return MozWebSocket;\n", 53 | " } else {\n", 54 | " alert('Your browser does not have WebSocket support.' +\n", 55 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", 56 | " 'Firefox 4 and 5 are also supported but you ' +\n", 57 | " 'have to enable WebSockets in about:config.');\n", 58 | " };\n", 59 | "}\n", 60 | "\n", 61 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", 62 | " this.id = figure_id;\n", 63 | "\n", 64 | " this.ws = websocket;\n", 65 | "\n", 66 | " this.supports_binary = (this.ws.binaryType != undefined);\n", 67 | "\n", 68 | " if (!this.supports_binary) {\n", 69 | " var warnings = document.getElementById(\"mpl-warnings\");\n", 70 | " if (warnings) {\n", 71 | " warnings.style.display = 'block';\n", 72 | " warnings.textContent = (\n", 73 | " \"This browser does not support binary websocket messages. \" +\n", 74 | " \"Performance may be slow.\");\n", 75 | " }\n", 76 | " }\n", 77 | "\n", 78 | " this.imageObj = new Image();\n", 79 | "\n", 80 | " this.context = undefined;\n", 81 | " this.message = undefined;\n", 82 | " this.canvas = undefined;\n", 83 | " this.rubberband_canvas = undefined;\n", 84 | " this.rubberband_context = undefined;\n", 85 | " this.format_dropdown = undefined;\n", 86 | "\n", 87 | " this.image_mode = 'full';\n", 88 | "\n", 89 | " this.root = $('
');\n", 90 | " this._root_extra_style(this.root)\n", 91 | " this.root.attr('style', 'display: inline-block');\n", 92 | "\n", 93 | " $(parent_element).append(this.root);\n", 94 | "\n", 95 | " this._init_header(this);\n", 96 | " this._init_canvas(this);\n", 97 | " this._init_toolbar(this);\n", 98 | "\n", 99 | " var fig = this;\n", 100 | "\n", 101 | " this.waiting = false;\n", 102 | "\n", 103 | " this.ws.onopen = function () {\n", 104 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", 105 | " fig.send_message(\"send_image_mode\", {});\n", 106 | " if (mpl.ratio != 1) {\n", 107 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", 108 | " }\n", 109 | " fig.send_message(\"refresh\", {});\n", 110 | " }\n", 111 | "\n", 112 | " this.imageObj.onload = function() {\n", 113 | " if (fig.image_mode == 'full') {\n", 114 | " // Full images could contain transparency (where diff images\n", 115 | " // almost always do), so we need to clear the canvas so that\n", 116 | " // there is no ghosting.\n", 117 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", 118 | " }\n", 119 | " fig.context.drawImage(fig.imageObj, 0, 0);\n", 120 | " };\n", 121 | "\n", 122 | " this.imageObj.onunload = function() {\n", 123 | " fig.ws.close();\n", 124 | " }\n", 125 | "\n", 126 | " this.ws.onmessage = this._make_on_message_function(this);\n", 127 | "\n", 128 | " this.ondownload = ondownload;\n", 129 | "}\n", 130 | "\n", 131 | "mpl.figure.prototype._init_header = function() {\n", 132 | " var titlebar = $(\n", 133 | " '
');\n", 135 | " var titletext = $(\n", 136 | " '
');\n", 138 | " titlebar.append(titletext)\n", 139 | " this.root.append(titlebar);\n", 140 | " this.header = titletext[0];\n", 141 | "}\n", 142 | "\n", 143 | "\n", 144 | "\n", 145 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", 146 | "\n", 147 | "}\n", 148 | "\n", 149 | "\n", 150 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", 151 | "\n", 152 | "}\n", 153 | "\n", 154 | "mpl.figure.prototype._init_canvas = function() {\n", 155 | " var fig = this;\n", 156 | "\n", 157 | " var canvas_div = $('
');\n", 158 | "\n", 159 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", 160 | "\n", 161 | " function canvas_keyboard_event(event) {\n", 162 | " return fig.key_event(event, event['data']);\n", 163 | " }\n", 164 | "\n", 165 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n", 166 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n", 167 | " this.canvas_div = canvas_div\n", 168 | " this._canvas_extra_style(canvas_div)\n", 169 | " this.root.append(canvas_div);\n", 170 | "\n", 171 | " var canvas = $('');\n", 172 | " canvas.addClass('mpl-canvas');\n", 173 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", 174 | "\n", 175 | " this.canvas = canvas[0];\n", 176 | " this.context = canvas[0].getContext(\"2d\");\n", 177 | "\n", 178 | " var backingStore = this.context.backingStorePixelRatio ||\n", 179 | "\tthis.context.webkitBackingStorePixelRatio ||\n", 180 | "\tthis.context.mozBackingStorePixelRatio ||\n", 181 | "\tthis.context.msBackingStorePixelRatio ||\n", 182 | "\tthis.context.oBackingStorePixelRatio ||\n", 183 | "\tthis.context.backingStorePixelRatio || 1;\n", 184 | "\n", 185 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", 186 | "\n", 187 | " var rubberband = $('');\n", 188 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", 189 | "\n", 190 | " var pass_mouse_events = true;\n", 191 | "\n", 192 | " canvas_div.resizable({\n", 193 | " start: function(event, ui) {\n", 194 | " pass_mouse_events = false;\n", 195 | " },\n", 196 | " resize: function(event, ui) {\n", 197 | " fig.request_resize(ui.size.width, ui.size.height);\n", 198 | " },\n", 199 | " stop: function(event, ui) {\n", 200 | " pass_mouse_events = true;\n", 201 | " fig.request_resize(ui.size.width, ui.size.height);\n", 202 | " },\n", 203 | " });\n", 204 | "\n", 205 | " function mouse_event_fn(event) {\n", 206 | " if (pass_mouse_events)\n", 207 | " return fig.mouse_event(event, event['data']);\n", 208 | " }\n", 209 | "\n", 210 | " rubberband.mousedown('button_press', mouse_event_fn);\n", 211 | " rubberband.mouseup('button_release', mouse_event_fn);\n", 212 | " // Throttle sequential mouse events to 1 every 20ms.\n", 213 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n", 214 | "\n", 215 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", 216 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", 217 | "\n", 218 | " canvas_div.on(\"wheel\", function (event) {\n", 219 | " event = event.originalEvent;\n", 220 | " event['data'] = 'scroll'\n", 221 | " if (event.deltaY < 0) {\n", 222 | " event.step = 1;\n", 223 | " } else {\n", 224 | " event.step = -1;\n", 225 | " }\n", 226 | " mouse_event_fn(event);\n", 227 | " });\n", 228 | "\n", 229 | " canvas_div.append(canvas);\n", 230 | " canvas_div.append(rubberband);\n", 231 | "\n", 232 | " this.rubberband = rubberband;\n", 233 | " this.rubberband_canvas = rubberband[0];\n", 234 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", 235 | " this.rubberband_context.strokeStyle = \"#000000\";\n", 236 | "\n", 237 | " this._resize_canvas = function(width, height) {\n", 238 | " // Keep the size of the canvas, canvas container, and rubber band\n", 239 | " // canvas in synch.\n", 240 | " canvas_div.css('width', width)\n", 241 | " canvas_div.css('height', height)\n", 242 | "\n", 243 | " canvas.attr('width', width * mpl.ratio);\n", 244 | " canvas.attr('height', height * mpl.ratio);\n", 245 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", 246 | "\n", 247 | " rubberband.attr('width', width);\n", 248 | " rubberband.attr('height', height);\n", 249 | " }\n", 250 | "\n", 251 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n", 252 | " // upon first draw.\n", 253 | " this._resize_canvas(600, 600);\n", 254 | "\n", 255 | " // Disable right mouse context menu.\n", 256 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", 257 | " return false;\n", 258 | " });\n", 259 | "\n", 260 | " function set_focus () {\n", 261 | " canvas.focus();\n", 262 | " canvas_div.focus();\n", 263 | " }\n", 264 | "\n", 265 | " window.setTimeout(set_focus, 100);\n", 266 | "}\n", 267 | "\n", 268 | "mpl.figure.prototype._init_toolbar = function() {\n", 269 | " var fig = this;\n", 270 | "\n", 271 | " var nav_element = $('
')\n", 272 | " nav_element.attr('style', 'width: 100%');\n", 273 | " this.root.append(nav_element);\n", 274 | "\n", 275 | " // Define a callback function for later on.\n", 276 | " function toolbar_event(event) {\n", 277 | " return fig.toolbar_button_onclick(event['data']);\n", 278 | " }\n", 279 | " function toolbar_mouse_event(event) {\n", 280 | " return fig.toolbar_button_onmouseover(event['data']);\n", 281 | " }\n", 282 | "\n", 283 | " for(var toolbar_ind in mpl.toolbar_items) {\n", 284 | " var name = mpl.toolbar_items[toolbar_ind][0];\n", 285 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", 286 | " var image = mpl.toolbar_items[toolbar_ind][2];\n", 287 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", 288 | "\n", 289 | " if (!name) {\n", 290 | " // put a spacer in here.\n", 291 | " continue;\n", 292 | " }\n", 293 | " var button = $('