├── .gitignore
├── Network.py
├── README.md
├── Transformer.ipynb
├── data.py
├── doc
│   ├── buildfig.png
│   ├── compare.png
│   ├── comparegraph.png
│   ├── hyperparams.png
│   ├── imports,png.PNG
│   ├── init.png
│   ├── lossgraph.png
│   └── training.png
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.py[cod]
4 | .git
5 | data/

--------------------------------------------------------------------------------
/Network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from utils import *
6 | 
7 | 
8 | class EncoderLayer(torch.nn.Module):
9 |     def __init__(self, dim_val, dim_attn, n_heads = 1):
10 |         super(EncoderLayer, self).__init__()
11 |         self.attn = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads)
12 |         self.fc1 = nn.Linear(dim_val, dim_val)
13 |         self.fc2 = nn.Linear(dim_val, dim_val)
14 | 
15 |         self.norm1 = nn.LayerNorm(dim_val)
16 |         self.norm2 = nn.LayerNorm(dim_val)
17 | 
18 |     def forward(self, x):
19 |         a = self.attn(x)
20 |         x = self.norm1(x + a)
21 | 
22 |         a = self.fc1(F.elu(self.fc2(x))) #position-wise feed-forward
23 |         x = self.norm2(x + a)
24 | 
25 |         return x
26 | 
27 | class DecoderLayer(torch.nn.Module):
28 |     def __init__(self, dim_val, dim_attn, n_heads = 1):
29 |         super(DecoderLayer, self).__init__()
30 |         self.attn1 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads)
31 |         self.attn2 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads)
32 |         self.fc1 = nn.Linear(dim_val, dim_val)
33 |         self.fc2 = nn.Linear(dim_val, dim_val)
34 | 
35 |         self.norm1 = nn.LayerNorm(dim_val)
36 |         self.norm2 = nn.LayerNorm(dim_val)
37 |         self.norm3 = nn.LayerNorm(dim_val)
38 | 
39 |     def forward(self, x, enc):
40 |         a = self.attn1(x) #self-attention over the decoder input
41 |         x = self.norm1(a + x)
42 | 
43 |         a = self.attn2(x, kv = enc) #cross-attention over the encoder output
44 |         x = self.norm2(a + x)
45 | 
46 |         a = self.fc1(F.elu(self.fc2(x))) #position-wise feed-forward
47 | 
48 |         x = self.norm3(x + a)
49 |         return x
50 | 
51 | class Transformer(torch.nn.Module):
52 |     def __init__(self, dim_val, dim_attn, input_size, dec_seq_len, out_seq_len, n_decoder_layers = 1, n_encoder_layers = 1, n_heads = 1):
53 |         super(Transformer, self).__init__()
54 |         self.dec_seq_len = dec_seq_len
55 | 
56 |         #Initialise encoder and decoder layers
57 |         #nn.ModuleList (rather than a plain Python list) so the layers register with .parameters()
58 |         self.encs = nn.ModuleList()
59 |         for i in range(n_encoder_layers):
60 |             self.encs.append(EncoderLayer(dim_val, dim_attn, n_heads))
61 | 
62 |         self.decs = nn.ModuleList()
63 |         for i in range(n_decoder_layers):
64 |             self.decs.append(DecoderLayer(dim_val, dim_attn, n_heads))
65 | 
66 |         self.pos = PositionalEncoding(dim_val)
67 | 
68 |         #Dense layers for managing network inputs and outputs
69 |         self.enc_input_fc = nn.Linear(input_size, dim_val)
70 |         self.dec_input_fc = nn.Linear(input_size, dim_val)
71 |         self.out_fc = nn.Linear(dec_seq_len * dim_val, out_seq_len)
72 | 
73 |     def forward(self, x):
74 |         #encoder
75 |         e = self.encs[0](self.pos(self.enc_input_fc(x)))
76 |         for enc in self.encs[1:]:
77 |             e = enc(e)
78 | 
79 |         #decoder: attends over the last dec_seq_len steps of the input and the encoder output
80 |         d = self.decs[0](self.dec_input_fc(x[:,-self.dec_seq_len:]), e)
81 |         for dec in self.decs[1:]:
82 |             d = dec(d, e)
83 | 
84 |         #output
85 |         x = self.out_fc(d.flatten(start_dim=1))
86 | 
87 |         return x

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is an implementation of the Transformer architecture for time series data in PyTorch.
Modelling the sigmoid function is used here as a toy problem.
2 | 
3 | Usage:
4 | First, all the necessary imports, as well as matplotlib for visualisation.
5 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/imports%2Cpng.PNG)
6 | Next we define some hyperparameters, which will vary depending on the task.
7 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/hyperparams.png)
8 | We initialise the network and an optimizer, in this case Adam, as well as an empty list to track losses for visualisation.
9 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/init.png)
10 | Using matplotlib in a Jupyter notebook we can graph losses in real time; first, let's initialise a figure.
11 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/buildfig.png)
12 | We can now begin training (a minimal training loop is sketched at the end of this section).
13 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/training.png)
14 | You should see a live plot, similar to this, tracking the output error.
15 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/lossgraph.png)
16 | Now that the network is trained, let's give it the first few values of the sigmoid function and see how it approximates the rest.
17 | We create another figure to visualise this.
18 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/compare.png)
19 | If all went well, the output should look something like this:
20 | ![](https://github.com/LiamMaclean216/Pytorch-Transfomer/blob/master/doc/comparegraph.png)
21 | Note that the network uses past values rather than the x axis for its predictions, so it makes sense that the output is offset.
22 | However, it did successfully capture the shape.
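The screenshots above are taken from `Transformer.ipynb`. For reference, a minimal training loop consistent with the code in this repository might look like the sketch below; the hyperparameter values are the ones from the notebook, and `get_data` is the sigmoid batch generator from `utils.py`. The exact loop shown in `training.png` may differ (for example, by redrawing the live plot each epoch).

```python
import torch
import torch.nn.functional as F
from Network import Transformer
from utils import get_data

# hyperparameters, as set in Transformer.ipynb
enc_seq_len, dec_seq_len, output_sequence_length = 6, 2, 1
dim_val, dim_attn, n_heads = 10, 5, 3
n_encoder_layers = n_decoder_layers = 3
lr, epochs, batch_size = 0.002, 20, 15

# input_size = 1 because each time step of the series is a single scalar
t = Transformer(dim_val, dim_attn, 1, dec_seq_len, output_sequence_length,
                n_decoder_layers, n_encoder_layers, n_heads)
optimizer = torch.optim.Adam(t.parameters(), lr=lr)
losses = []

for e in range(epochs):
    # X: (batch_size, enc_seq_len, 1) random windows of the sigmoid,
    # Y: (batch_size, output_sequence_length) the values to predict next
    X, Y = get_data(batch_size, enc_seq_len, output_sequence_length)

    optimizer.zero_grad()
    loss = F.mse_loss(t(X), Y)
    loss.backward()
    optimizer.step()

    losses.append(loss.item())  # feeds the live loss plot
```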
23 | 
24 | Resources:
25 | * Attention Is All You Need: https://arxiv.org/abs/1706.03762
26 | * Deep Transformer Models for Time Series Forecasting: https://arxiv.org/abs/2001.08317
27 | 
28 | 
29 | 
--------------------------------------------------------------------------------
/Transformer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import torch\n",
10 |     "import torch.nn as nn\n",
11 |     "import torch.nn.functional as F\n",
12 |     "import numpy as np\n",
13 |     "from utils import *\n",
14 |     "from Network import *\n",
15 |     "from data import *\n",
16 |     "\n",
17 |     "%matplotlib notebook\n",
18 |     "import matplotlib.pyplot as plt\n",
19 |     "\n",
20 |     "#hyperparams\n",
21 |     "enc_seq_len = 6\n",
22 |     "dec_seq_len = 2\n",
23 |     "output_sequence_length = 1\n",
24 |     "\n",
25 |     "dim_val = 10\n",
26 |     "dim_attn = 5\n",
27 |     "lr = 0.002\n",
28 |     "epochs = 20\n",
29 |     "\n",
30 |     "n_heads = 3\n",
31 |     "\n",
32 |     "n_decoder_layers = 3\n",
33 |     "n_encoder_layers = 3\n",
34 |     "\n",
35 |     "batch_size = 15\n",
36 |     "\n",
37 |     "#init network and optimizer\n",
38 |     "t = Transformer(dim_val, dim_attn, 1, dec_seq_len, output_sequence_length, n_decoder_layers, n_encoder_layers, n_heads)\n",
39 |     "optimizer = torch.optim.Adam(t.parameters(), lr=lr)\n",
40 |     "\n",
41 |     "#keep track of loss for graph\n",
42 |     "losses = []"
43 |    ]
44 |   },
45 |   {
46 |    "cell_type": "code",
47 |    "execution_count": 3,
48 |    "metadata": {},
49 |    "outputs": [
50 |     {
51 |      "name": "stderr",
52 |      "output_type": "stream",
53 |      "text": [
54 |       "D:\\OneDrive\\GitHub\\Pytorch-Chatbot\\data.py:52: FutureWarning: \n",
55 |       "Passing list-likes to .loc or [] with any missing label will raise\n",
56 |       "KeyError in the future, you can use .reindex() as an alternative.\n",
57 |       "\n",
58 |       "See the documentation here:\n",
59 |       "https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#deprecate-loc-reindex-listlike\n",
60 |       "  vectorized_targets.append(torch.tensor(word_vecs.loc[b[idx + 1]].values).type(dtype))\n",
61 |       "D:\\OneDrive\\GitHub\\Pytorch-Chatbot\\data.py:51: FutureWarning: \n",
62 |       "Passing list-likes to .loc or [] with any missing label will raise\n",
63 |       "KeyError in the future, you can use .reindex() as an alternative.\n",
64 |       "\n",
65 |       "See the documentation here:\n",
66 |       "https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#deprecate-loc-reindex-listlike\n",
67 |       "  vectorized_inputs.append(torch.tensor(word_vecs.loc[b[idx]].values).type(dtype))\n"
68 |      ]
69 |     }
70 |    ],
71 |    "source": [
72 |     "g = gen_data(4, gpu = True)\n",
73 |     "vec_ins, vec_targs, ins, targs = next(g)"
74 |    ]
75 |   },
76 |   {
77 |    "cell_type": "code",
78 |    "execution_count": 2,
79 |    "metadata": {
80 |     "scrolled": false
81 |    },
82 |    "outputs": [
83 |     {
84 |      "data": {
85 |       "text/plain": [
86 |        "<IPython.core.display.Javascript object>"
87 |       ]
88 |      },
89 |      "metadata": {},
90 |      "output_type": "display_data"
91 |     },
92 |     {
93 |      "data": {
94 |       "text/html": [
95 |        ""
96 |       ],
97 |       "text/plain": [
98 |        "<IPython.core.display.HTML object>"
99 |       ]
100 |      },
101 |      "metadata": {},
102 |      "output_type": "display_data"
103 |     },
104 |     {
105 |      "data": {
106 |       "text/plain": [
107 |        ""
108 |       ]
109 |      },
110 |      "execution_count": 5,
111 |      "metadata": {},
112 |      "output_type": "execute_result"
113 |     }
114 |    ],
"" 1717 | ] 1718 | }, 1719 | "execution_count": 5, 1720 | "metadata": {}, 1721 | "output_type": "execute_result" 1722 | } 1723 | ], 1724 | "source": [ 1725 | "fig = plt.figure()\n", 1726 | "ax = fig.add_subplot(111)\n", 1727 | "plt.ion()\n", 1728 | "\n", 1729 | "fig.show()\n", 1730 | "fig.canvas.draw()\n", 1731 | "\n", 1732 | "o = []\n", 1733 | "x = [torch.sigmoid(torch.arange(-10,-1).float()).unsqueeze(-1).numpy().tolist()]\n", 1734 | "\n", 1735 | "#Draw graph comparing to sigmoid\n", 1736 | "for i in range(-10, 10, output_sequence_length):\n", 1737 | " o.append([torch.sigmoid(torch.tensor(i).float())])\n", 1738 | " q = torch.tensor(x).float()\n", 1739 | " \n", 1740 | " if(output_sequence_length == 1):\n", 1741 | " x[0].append([t(q).detach().squeeze().numpy()])\n", 1742 | " else:\n", 1743 | " for a in t(q).detach().squeeze().numpy():\n", 1744 | " x[0].append([a])\n", 1745 | " \n", 1746 | "ax.clear()\n", 1747 | "ax.plot(x[0], label='Network output')\n", 1748 | "ax.plot(o, label='Sigmoid function')\n", 1749 | "ax.set_title(\"\")\n", 1750 | "ax.legend(loc='upper left', frameon=False)\n" 1751 | ] 1752 | }, 1753 | { 1754 | "cell_type": "code", 1755 | "execution_count": null, 1756 | "metadata": {}, 1757 | "outputs": [], 1758 | "source": [] 1759 | } 1760 | ], 1761 | "metadata": { 1762 | "kernelspec": { 1763 | "display_name": "Python 3", 1764 | "language": "python", 1765 | "name": "python3" 1766 | }, 1767 | "language_info": { 1768 | "codemirror_mode": { 1769 | "name": "ipython", 1770 | "version": 3 1771 | }, 1772 | "file_extension": ".py", 1773 | "mimetype": "text/x-python", 1774 | "name": "python", 1775 | "nbconvert_exporter": "python", 1776 | "pygments_lexer": "ipython3", 1777 | "version": "3.7.4" 1778 | } 1779 | }, 1780 | "nbformat": 4, 1781 | "nbformat_minor": 2 1782 | } 1783 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | def gen_data(batch_size, gpu = True): 9 | 10 | #https://www.kaggle.com/shashankasubrahmanya/preprocessing-cornell-movie-dialogue-corpus 11 | movie_lines = pd.read_csv('cornell_data/movie_lines.txt', sep = "\+\+\+\$\+\+\+", engine = "python", 12 | index_col = False, names = ["LineID", "Character", "Movie", "Name", "Line"]) 13 | movie_lines = movie_lines[["LineID", "Line"]] 14 | 15 | movie_lines["Line"] = movie_lines['Line'].str.replace('.','') 16 | movie_lines["Line"] = movie_lines['Line'].str.replace('!','') 17 | movie_lines["Line"] = movie_lines['Line'].str.replace('?','') 18 | movie_lines["Line"] = movie_lines['Line'].str.replace(' ',' ') 19 | movie_lines["Line"] = movie_lines['Line'].str.replace('[^\w\s.!?]','') 20 | movie_lines["Line"] = movie_lines["Line"].str.lower() 21 | 22 | movie_lines["LineID"] = movie_lines["LineID"].apply(str.strip) 23 | movie_lines["Line"] = movie_lines["Line"].apply(lambda x : str(x).split(" ")[1:]) 24 | 25 | movie_conversations = pd.read_csv("cornell_data/movie_conversations.txt", sep = "\+\+\+\$\+\+\+", 26 | engine = "python", index_col = False, names = ["Character1", "Character2", "Movie", "Conversation"]) 27 | movie_conversations = movie_conversations["Conversation"] 28 | 29 | #convert from strings of lists to actual lists 30 | movie_conversations = movie_conversations.apply(eval) 31 | 32 | word_vecs = pd.read_table("glove.6B.50d.txt", sep=" ", index_col=0, 
                              header=None, quoting=csv.QUOTE_NONE)
34 |     if(not gpu):
35 |         dtype = torch.FloatTensor
36 |     else:
37 |         dtype = torch.cuda.FloatTensor
38 | 
39 |     ba = 0
40 |     vectorized_inputs = []
41 |     vectorized_targets = []
42 | 
43 |     inputs = []
44 |     targets = []
45 |     while True:
46 |         i = random.randint(0, movie_conversations.size - (batch_size + 1))
47 |         batch = movie_conversations.loc[i:i+batch_size].apply(lambda x : movie_lines.loc[(movie_lines['LineID'].isin(x))])
48 |         batch = batch.apply(lambda x : x['Line'].values).values
49 | 
50 |         for b in batch:
51 |             for idx in range(len(b) - 1):
52 |                 vectorized_inputs.append(torch.tensor(word_vecs.loc[b[idx]].values).type(dtype))
53 |                 vectorized_targets.append(torch.tensor(word_vecs.loc[b[idx + 1]].values).type(dtype))
54 | 
55 |                 inputs.append(b[idx])
56 |                 targets.append(b[idx + 1])
57 | 
58 |                 ba += 1
59 |                 if ba >= batch_size:
60 |                     ba = 0
61 |                     yield (pad_sequence(vectorized_inputs, batch_first = True)
62 |                            , pad_sequence(vectorized_targets, batch_first = True), inputs, targets)
63 |                     vectorized_inputs = []
64 |                     vectorized_targets = []
65 |                     inputs = []
66 |                     targets = []

--------------------------------------------------------------------------------
/doc/buildfig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/buildfig.png

--------------------------------------------------------------------------------
/doc/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/compare.png

--------------------------------------------------------------------------------
/doc/comparegraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/comparegraph.png

--------------------------------------------------------------------------------
/doc/hyperparams.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/hyperparams.png

--------------------------------------------------------------------------------
/doc/imports,png.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/imports,png.PNG

--------------------------------------------------------------------------------
/doc/init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/init.png

--------------------------------------------------------------------------------
/doc/lossgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/lossgraph.png

--------------------------------------------------------------------------------
/doc/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiamMaclean216/Pytorch-Chatbot/d93639049fc9227bf99be14df5bd272b5f3452d1/doc/training.png
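Aside from the sigmoid toy problem, `data.py` above provides a chatbot-style batch generator over the Cornell Movie-Dialogue corpus with GloVe word vectors. A hypothetical usage sketch, assuming `cornell_data/movie_lines.txt`, `cornell_data/movie_conversations.txt` and `glove.6B.50d.txt` are present on disk as the code expects:

```python
from data import gen_data

g = gen_data(batch_size=4, gpu=False)  # gpu=False keeps the tensors on the CPU

# Each next() yields a batch of consecutive utterance pairs:
#   vec_ins / vec_targs: padded GloVe tensors of shape (batch_size, max_len, 50)
#   ins / targs: the corresponding raw token lists
# Words missing from the GloVe vocabulary come back as NaN rows (hence the
# FutureWarning about .loc recorded in the notebook).
vec_ins, vec_targs, ins, targs = next(g)
print(vec_ins.shape, vec_targs.shape)
print(ins[0], "->", targs[0])
```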
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 | 
7 | def a_norm(Q, K):
8 |     #scaled dot-product attention scores, softmaxed over the key dimension
9 |     m = torch.matmul(Q, K.transpose(2,1).float())
10 |     m /= torch.sqrt(torch.tensor(Q.shape[-1]).float())
11 | 
12 |     return torch.softmax(m, -1)
13 | 
14 | 
15 | def attention(Q, K, V):
16 |     #Attention(Q, K, V) = softmax(QK^T / sqrt(dim_attn))V
17 |     a = a_norm(Q, K) #(batch_size, seq_length_q, seq_length_k)
18 | 
19 |     return torch.matmul(a, V) #(batch_size, seq_length_q, dim_val)
20 | 
21 | class AttentionBlock(torch.nn.Module):
22 |     def __init__(self, dim_val, dim_attn):
23 |         super(AttentionBlock, self).__init__()
24 |         self.value = Value(dim_val, dim_val)
25 |         self.key = Key(dim_val, dim_attn)
26 |         self.query = Query(dim_val, dim_attn)
27 | 
28 |     def forward(self, x, kv = None):
29 |         if(kv is None):
30 |             #Self-attention: x provides Q, K and V (for the encoder)
31 |             return attention(self.query(x), self.key(x), self.value(x))
32 | 
33 |         #Cross-attention: x provides Q, the external sequence kv provides K and V (for the decoder)
34 |         return attention(self.query(x), self.key(kv), self.value(kv))
35 | 
36 | class MultiHeadAttentionBlock(torch.nn.Module):
37 |     def __init__(self, dim_val, dim_attn, n_heads):
38 |         super(MultiHeadAttentionBlock, self).__init__()
39 |         #nn.ModuleList (rather than a plain list) so the heads register their parameters
40 |         self.heads = nn.ModuleList()
41 |         for i in range(n_heads):
42 |             self.heads.append(AttentionBlock(dim_val, dim_attn))
43 | 
44 |         self.fc = nn.Linear(n_heads * dim_val, dim_val, bias = False)
45 | 
46 |     def forward(self, x, kv = None):
47 |         a = []
48 |         for h in self.heads:
49 |             a.append(h(x, kv = kv))
50 | 
51 |         a = torch.stack(a, dim = -1) #combine heads
52 |         a = a.flatten(start_dim = 2) #flatten all head outputs
53 | 
54 |         x = self.fc(a)
55 | 
56 |         return x
57 | 
58 | class Value(torch.nn.Module):
59 |     def __init__(self, dim_input, dim_val):
60 |         super(Value, self).__init__()
61 |         self.dim_val = dim_val
62 |         self.fc1 = nn.Linear(dim_input, dim_val, bias = False)
63 | 
64 |     def forward(self, x):
65 |         return self.fc1(x)
66 | 
67 | class Key(torch.nn.Module):
68 |     def __init__(self, dim_input, dim_attn):
69 |         super(Key, self).__init__()
70 |         self.dim_attn = dim_attn
71 |         self.fc1 = nn.Linear(dim_input, dim_attn, bias = False)
72 | 
73 |     def forward(self, x):
74 |         return self.fc1(x)
75 | 
76 | class Query(torch.nn.Module):
77 |     def __init__(self, dim_input, dim_attn):
78 |         super(Query, self).__init__()
79 |         self.dim_attn = dim_attn
80 |         self.fc1 = nn.Linear(dim_input, dim_attn, bias = False)
81 | 
82 |     def forward(self, x):
83 |         return self.fc1(x)
84 | 
85 | # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
86 | class PositionalEncoding(nn.Module):
87 |     def __init__(self, d_model, max_len=5000):
88 |         super(PositionalEncoding, self).__init__()
89 | 
90 |         pe = torch.zeros(max_len, d_model)
91 |         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
92 | 
93 |         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
94 | 
95 |         pe[:, 0::2] = torch.sin(position * div_term)
96 |         pe[:, 1::2] = torch.cos(position * div_term)
97 | 
98 |         pe = pe.unsqueeze(0).transpose(0, 1) #(max_len, 1, d_model)
99 | 
100 |         self.register_buffer('pe', pe)
101 | 
102 |     def forward(self, x):
103 |         #x: (batch_size, seq_length, d_model); broadcast the encoding over the batch
104 |         x = x + self.pe[:x.size(1), :].squeeze(1)
105 |         return x
106 | 
107 | def get_data(batch_size, input_sequence_length, output_sequence_length):
108 |     i = input_sequence_length + output_sequence_length
109 | 
110 |     #random integer offsets so each sample sees a different window of the sigmoid
111 |     t = torch.zeros(batch_size,1).uniform_(0,20 - i).int()
112 |     b = torch.arange(-10, -10 + i).unsqueeze(0).repeat(batch_size,1) + t
113 | 
114 |     s = torch.sigmoid(b.float())
115 |     #inputs: (batch_size, input_sequence_length, 1), targets: (batch_size, output_sequence_length)
116 |     return s[:, :input_sequence_length].unsqueeze(-1), s[:,-output_sequence_length:]
--------------------------------------------------------------------------------
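To make the tensor shapes in `utils.py` and `Network.py` concrete, here is a small smoke test; it is not part of the repository, and the values mirror the notebook's hyperparameters.

```python
import torch
from Network import Transformer
from utils import get_data

batch_size, enc_seq_len, dec_seq_len, out_seq_len = 15, 6, 2, 1

# get_data returns sliding windows of the sigmoid and the values that follow them
X, Y = get_data(batch_size, enc_seq_len, out_seq_len)
print(X.shape, Y.shape)  # torch.Size([15, 6, 1]) torch.Size([15, 1])

model = Transformer(dim_val=10, dim_attn=5, input_size=1, dec_seq_len=dec_seq_len,
                    out_seq_len=out_seq_len, n_decoder_layers=3, n_encoder_layers=3,
                    n_heads=3)
print(model(X).shape)    # torch.Size([15, 1])
```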