├── .gitignore ├── README.md ├── Stock Market Sentiment with LSTMs and TensorFlow.ipynb ├── checkpoints └── .gitignore ├── data └── StockTwits_SPY_Sentiment_2017.gz ├── dockerfiles ├── Dockerfile.cpu ├── Dockerfile.gpu └── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints/ 3 | .DS_Store 4 | checkpoints/checkpoint 5 | checkpoints/*.ckpt.* 6 | 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction to LSTMs with TensorFlow 2 | 3 | In this tutorial, we will build a Long Short Term Memory (LSTM) Network to predict the stock market sentiment based on comments about the market from ["StockTwits.com"](https://stocktwits.com). 4 | 5 | This repository contains source code corresponding to our article ["Introduction to LSTMs with TensorFlow"](https://oreilly.com/ideas/introduction-to-lstms-with-tensorflow). 6 | 7 | ## Setup 8 | 9 | ### Download via Git 10 | 11 | 1. Go to your home directory by opening your terminal and entering `cd ~` 12 | 13 | 2. Clone the repository by entering 14 | 15 | ``` 16 | git clone https://github.com/dmonn/lstm-oreilly.git 17 | ``` 18 | 19 | ### Option 1: Dockerfiles (Recommended) 20 | 21 | 3. After cloning the repo to your machine, enter 22 | 23 | ``` 24 | docker build -t lstm_ -f ./dockerfiles/Dockerfile. ./dockerfiles/ 25 | ``` 26 | 27 | where `` is either `gpu` or `cpu`. (Note that, in order to run these files on your GPU, you'll need to have a compatible GPU, with drivers installed and configured properly [as described in TensorFlow's documentation](https://www.tensorflow.org/install/).) 28 | 29 | 4. 
Run the Docker image by entering 30 | 31 | ``` 32 | docker run -it -p 8888:8888 -v :/root lstm_ 33 | ``` 34 | 35 | where `` is either `gpu` or `cpu`, depending on the image you built in the last step. 36 | 37 | 5. After building, starting, and attaching to the appropriate Docker container, run the provided Jupyter notebooks by entering 38 | 39 | ``` 40 | jupyter notebook --ip 0.0.0.0 --allow-root 41 | ``` 42 | 43 | and navigate to the specified URL `http://0.0.0.0:8888/?token=` in your browser. 44 | 45 | 6. Choose `Stock Market Sentiment with LSTMs and TensorFlow.ipynb` to open the Notebook. 46 | 47 | #### Debugging Docker 48 | If you receive an error of the form: 49 | 50 | ``` 51 | WARNING: Error loading config file:/home/rp/.docker/config.json - stat /home/rp/.docker/config.json: permission denied 52 | Got permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock: Get http://%2Fvar%2Frun%2Fdocker.sock/v1.26/images/json: dial unix /var/run/docker.sock: connect: permission denied 53 | ``` 54 | 55 | It's most likely because you installed Docker using sudo permissions with a package manager such as `brew` or `apt-get`. To resolve this `permission denied` error, run Docker with `sudo` (i.e., run `docker` commands with `sudo docker ` instead of just `docker `). 56 | 57 | ### Option 2: Local setup using Miniconda 58 | 59 | If you don't have or don't want to use Docker, you can follow these steps to set up the notebook. 60 | 61 | 3. Install Miniconda using [one of the installers and the Miniconda installation instructions](https://conda.io/miniconda.html). Use Python 3.6. 62 | 63 | 4. After the installation, create a new virtual environment and activate it using these commands. 64 | ``` 65 | $ conda create -n lstm 66 | $ source activate lstm 67 | ``` 68 | 69 | 5. You are now in a virtual environment. Next up, [install TensorFlow by following the instructions](https://www.tensorflow.org/install/). 70 | 71 | 6. 
To install the rest of the dependencies, navigate into your repository and run 72 | 73 | ``` 74 | $ pip install -r dockerfiles/requirements.txt 75 | ``` 76 | 77 | 7. Now you can run 78 | 79 | ``` 80 | jupyter notebook 81 | ``` 82 | 83 | to finally start up the notebook. A browser should open automatically. If not, navigate to [http://127.0.0.1:8888](http://127.0.0.1:8888) in your browser. 84 | 85 | 8. Choose `Stock Market Sentiment with LSTMs and TensorFlow.ipynb` to open the Notebook. 86 | 87 | #### Notes 88 | 89 | The saved model files in the `checkpoints/` directory were not uploaded to GitHub due to size constraints. If you run the code in your Docker container or Miniconda virtual environment, the model will be retrained and saved at that time. 90 | -------------------------------------------------------------------------------- /Stock Market Sentiment with LSTMs and TensorFlow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Modeling Stock Market Sentiment with LSTMs and TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial, we will build a Long Short Term Memory (LSTM) Network to predict the stock market sentiment based on a comment about the market." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "We will use the following libraries for our analysis:\n", 29 | "\n", 30 | "* numpy - numerical computing library used to work with our data\n", 31 | "* pandas - data analysis library used to read in our data from CSV\n", 32 | "* tensorflow - deep learning framework used for modeling" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "We will also use Python's Counter object to count our vocabulary items, and we have a utils module that abstracts away a lot of the details of our data processing. Please read through utils.py to get a better understanding of how to preprocess the data for analysis." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import pandas as pd\n", 52 | "import tensorflow as tf\n", 53 | "import utils as utl\n", 54 | "from collections import Counter" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "source": [ 63 | "## Processing Data" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "We will train the model using messages tagged with SPY, the S&P 500 index fund, from [StockTwits.com](https://www.stocktwits.com). StockTwits is a social media network for traders and investors to share their views about the stock market. 
When a user posts a message, they tag the relevant stock ticker ($SPY in our case) and have the option to tag the messages with their sentiment – “bullish” if they believe the stock will go up and “bearish” if they believe the stock will go down.\n", 71 | "\n", 72 | "Our dataset consists of approximately 100,000 messages posted in 2017 that are tagged with $SPY where the user indicated their sentiment. Before we get to our LSTM Network we have to perform some processing on our data to get it ready for modeling." 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "#### Read and View Data" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "First we simply read in our data using pandas, pull out our message and sentiment data into numpy arrays. Let's also take a look at a few samples to get familiar with the data set." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 2, 92 | "metadata": { 93 | "collapsed": false 94 | }, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Messages: $SPY crazy day so far!... Sentiment: bearish\n", 101 | "Messages: $SPY Will make a new ATH this week. Watch it!... Sentiment: bullish\n", 102 | "Messages: $SPY $DJIA white elephant in room is $AAPL. Up 14% since election. Strong headwinds w/Trump trade & Strong dollar. How many 7's do you see?... Sentiment: bearish\n", 103 | "Messages: $SPY blocks above. We break above them We should push to double top... Sentiment: bullish\n", 104 | "Messages: $SPY Nothing happening in the market today, guess I'll go to the store and spend some $.... Sentiment: bearish\n", 105 | "Messages: $SPY What an easy call. Good jobs report: good economy, markets go up. Bad jobs report: no more rate hikes, markets go up. Win-win.... Sentiment: bullish\n", 106 | "Messages: $SPY BS market.... 
Sentiment: bullish\n", 107 | "Messages: $SPY this rally all the cheerleaders were screaming about this morning is pretty weak. I keep adding 2 my short at all spikes... Sentiment: bearish\n", 108 | "Messages: $SPY Dollar ripping higher!... Sentiment: bearish\n", 109 | "Messages: $SPY no reason to go down !... Sentiment: bullish\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# read data from csv file\n", 115 | "data = pd.read_csv(\"data/StockTwits_SPY_Sentiment_2017.gz\",\n", 116 | " encoding=\"utf-8\",\n", 117 | " compression=\"gzip\",\n", 118 | " index_col=0)\n", 119 | "\n", 120 | "# get messages and sentiment labels\n", 121 | "messages = data.message.values\n", 122 | "labels = data.sentiment.values\n", 123 | "\n", 124 | "# View sample of messages with sentiment\n", 125 | "\n", 126 | "for i in range(10):\n", 127 | " print(\"Messages: {}...\".format(messages[i]),\n", 128 | " \"Sentiment: {}\".format(labels[i]))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "#### Preprocess Messages" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "Working with raw text data often requires preprocessing the text in some fashion to normalize for context. In our case we want to normalize for known unique \"entities\" that appear within messages that carry a similar contextual meaning when analyzing sentiment. This means we want to replace references to specific stock tickers, user names, url links or numbers with a special token identifying the \"entity\". Here we will also make everything lower case and remove punctuation." 
143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 3, 148 | "metadata": { 149 | "collapsed": true 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "messages = np.array([utl.preprocess_ST_message(message) for message in messages])" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "#### Generate Vocab to Index Mapping" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "To work with raw text we need some encoding from words to numbers for our algorithm to work with the inputs. The first step of doing this is keeping a collection of our full vocabulary and creating a mapping of each word to a unique index. We will use this word to index mapping in a little bit to prepare our messages for analysis. \n", 168 | "\n", 169 | "Note that in practice we may want to build the vocabulary from our training set only. The model will likely see new words once it is out in the wild, and we want our validation and test results to reflect that. Here, for simplicity and demonstration purposes, we will use our entire data set." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 4, 175 | "metadata": { 176 | "collapsed": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "full_lexicon = \" \".join(messages).split()\n", 181 | "vocab_to_int, int_to_vocab = utl.create_lookup_tables(full_lexicon)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "#### Check Message Lengths" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "We will also want to get a sense of the distribution of the length of our inputs. We check for the longest and average messages. 
We will need to make our input length uniform to feed the data into our model, so later we will have to decide whether to truncate the longest messages. We also notice that one message has no content remaining after we preprocessed the data, so we will remove this message from our data set." 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 5, 201 | "metadata": { 202 | "collapsed": false 203 | }, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "Zero-length messages: 1\n", 210 | "Maximum message length: 244\n", 211 | "Average message length: 78.21856920395598\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "messages_lens = Counter([len(x) for x in messages])\n", 217 | "print(\"Zero-length messages: {}\".format(messages_lens[0]))\n", 218 | "print(\"Maximum message length: {}\".format(max(messages_lens)))\n", 219 | "print(\"Average message length: {}\".format(np.mean([len(x) for x in messages])))" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 6, 225 | "metadata": { 226 | "collapsed": true 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "messages, labels = utl.drop_empty_messages(messages, labels)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "#### Encode Messages and Labels" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "Earlier we mentioned that we need to \"translate\" our text to numbers for our algorithm to take in as inputs. We call this translation an encoding. We encode our messages to sequences of numbers where each number is the word index from the mapping we made earlier. The phrase \"I am bullish\" would now look something like [1, 234, 5345] where each number is the index for the respective word in the message. 
For our sentiment labels we will simply encode \"bearish\" as 0 and \"bullish\" as 1." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 7, 250 | "metadata": { 251 | "collapsed": true 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "messages = utl.encode_ST_messages(messages, vocab_to_int)\n", 256 | "labels = utl.encode_ST_labels(labels)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "#### Pad Messages" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "The last thing we need to do is make our message inputs the same length. In our case, the longest message is 244 words. LSTMs can usually handle sequence inputs up to 500 items in length so we won't truncate any of the messages here. We need to zero-pad the rest of the messages that are shorter. We will use left padding, which pads all of the messages that are shorter than 244 words with 0s at the beginning. So our encoded \"I am bullish\" message goes from [1, 234, 5345] (length 3) to [0, 0, 0, 0, 0, 0, ... , 0, 0, 1, 234, 5345] (length 244)." 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 8, 276 | "metadata": { 277 | "collapsed": true 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "messages = utl.zero_pad_messages(messages, seq_len=244)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "#### Train, Test, Validation Split" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "The last thing we do is split our data into training, validation and test sets and observe the size of each set." 
296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 9, 301 | "metadata": { 302 | "collapsed": false 303 | }, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "Data Set Size\n", 310 | "Train set: \t\t(77572, 244) \n", 311 | "Validation set: \t(9697, 244) \n", 312 | "Test set: \t\t(9697, 244)\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "train_x, val_x, test_x, train_y, val_y, test_y = utl.train_val_test_split(messages, labels, split_frac=0.80)\n", 318 | "\n", 319 | "print(\"Data Set Size\")\n", 320 | "print(\"Train set: \\t\\t{}\".format(train_x.shape), \n", 321 | " \"\\nValidation set: \\t{}\".format(val_x.shape),\n", 322 | " \"\\nTest set: \\t\\t{}\".format(test_x.shape))" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "## Building and Training our LSTM Network" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "In this section we will define a number of functions that will construct the items in our network. We will then use these functions to build and train our network." 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "#### Model Inputs" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "Here we simply define a function to build TensorFlow Placeholders for our message sequences, our labels and a variable called keep probability associated with drop out (we will talk more about this later). 
" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 10, 356 | "metadata": { 357 | "collapsed": true 358 | }, 359 | "outputs": [], 360 | "source": [ 361 | "def model_inputs():\n", 362 | " \"\"\"\n", 363 | " Create the model inputs\n", 364 | " \"\"\"\n", 365 | " inputs_ = tf.placeholder(tf.int32, [None, None], name='inputs')\n", 366 | " labels_ = tf.placeholder(tf.int32, [None, None], name='labels')\n", 367 | " keep_prob_ = tf.placeholder(tf.float32, name='keep_prob')\n", 368 | " \n", 369 | " return inputs_, labels_, keep_prob_" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": {}, 375 | "source": [ 376 | "#### Embedding Layer" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "In TensorFlow the word embeddings are represented as a vocabulary size x embedding size matrix, and we will learn these weights during our training process. The embedding lookup is then just a simple lookup from our embedding matrix based on the index of the current word." 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 11, 389 | "metadata": { 390 | "collapsed": true 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | "def build_embedding_layer(inputs_, vocab_size, embed_size):\n", 395 | " \"\"\"\n", 396 | " Create the embedding layer\n", 397 | " \"\"\"\n", 398 | " embedding = tf.Variable(tf.random_uniform((vocab_size, embed_size), -1, 1))\n", 399 | " embed = tf.nn.embedding_lookup(embedding, inputs_)\n", 400 | " \n", 401 | " return embed" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "#### LSTM Layers" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": {}, 414 | "source": [ 415 | "TensorFlow makes it extremely easy to build LSTM Layers and stack them on top of each other. We represent each LSTM layer as a BasicLSTMCell and keep these in a list to stack them together later. 
Here we will define a list with our LSTM layer sizes and the number of layers. \n", 416 | "\n", 417 | "We then take each of these LSTM layers and wrap them in a Dropout Layer. Dropout is a regularization technique used in Neural Networks in which any individual node has a probability of “dropping out” of the network during a given iteration of learning. This makes the model more generalizable by ensuring that it is not too dependent on any given nodes. \n", 418 | "\n", 419 | "Finally, we stack these layers using a MultiRNNCell, generate a zero initial state and connect our stacked LSTM layer to our word embedding inputs using dynamic_rnn. Here we track the output and the final state of the LSTM cell, which we will need to pass between mini-batches during training." 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 12, 425 | "metadata": { 426 | "collapsed": true 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):\n", 431 | " \"\"\"\n", 432 | " Create the LSTM layers\n", 433 | " \"\"\"\n", 434 | " lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]\n", 435 | " # Add dropout to the cell\n", 436 | " drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_) for lstm in lstms]\n", 437 | " # Stack up multiple LSTM layers, for deep learning\n", 438 | " cell = tf.contrib.rnn.MultiRNNCell(drops)\n", 439 | " # Getting an initial state of all zeros\n", 440 | " initial_state = cell.zero_state(batch_size, tf.float32)\n", 441 | " \n", 442 | " lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed, initial_state=initial_state)\n", 443 | " \n", 444 | " return initial_state, lstm_outputs, cell, final_state" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "#### Loss Function and Optimizer" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "First, we 
get our predictions by passing the final output of the LSTM layers through a TensorFlow fully connected layer with a sigmoid activation function. We only need the final output to make predictions, so we pull it out using the [:, -1] indexing on our LSTM outputs. We then pass these predictions to our mean squared error loss function and use the Adadelta Optimizer to minimize the loss." 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 13, 464 | "metadata": { 465 | "collapsed": true 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "def build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate):\n", 470 | " \"\"\"\n", 471 | " Create the Loss function and Optimizer\n", 472 | " \"\"\"\n", 473 | " predictions = tf.contrib.layers.fully_connected(lstm_outputs[:, -1], 1, activation_fn=tf.sigmoid)\n", 474 | " loss = tf.losses.mean_squared_error(labels_, predictions)\n", 475 | " optimizer = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss)\n", 476 | " \n", 477 | " return predictions, loss, optimizer" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": {}, 483 | "source": [ 484 | "#### Accuracy" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "Finally, we define our accuracy metric for assessing the model performance across our training, validation and test sets. Even though accuracy is just a calculation based on results, everything in TensorFlow is part of a Computation Graph. Therefore, we need to define our loss and accuracy nodes in the context of the rest of our network graph." 
492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 14, 497 | "metadata": { 498 | "collapsed": true 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "def build_accuracy(predictions, labels_):\n", 503 | " \"\"\"\n", 504 | " Create accuracy\n", 505 | " \"\"\"\n", 506 | " correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)\n", 507 | " accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n", 508 | " \n", 509 | " return accuracy" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "#### Training" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "metadata": {}, 522 | "source": [ 523 | "We are finally ready to build and train our LSTM Network! First, we call each of the functions we have defined to construct the network. Then we define a Saver so that we can write our model to disk and load it for future use. Finally, we open a TensorFlow Session to train the model over a predefined number of epochs using mini-batches. At the end of each epoch we will print the loss, training accuracy and validation accuracy to monitor the results as we train." 
524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 15, 529 | "metadata": { 530 | "collapsed": true 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "def build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 535 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y):\n", 536 | " \n", 537 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 538 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 539 | " initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)\n", 540 | " predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)\n", 541 | " accuracy = build_accuracy(predictions, labels_)\n", 542 | " \n", 543 | " saver = tf.train.Saver()\n", 544 | " \n", 545 | " with tf.Session() as sess:\n", 546 | " \n", 547 | " sess.run(tf.global_variables_initializer())\n", 548 | " n_batches = len(train_x)//batch_size\n", 549 | " for e in range(epochs):\n", 550 | " state = sess.run(initial_state)\n", 551 | " \n", 552 | " train_acc = []\n", 553 | " for ii, (x, y) in enumerate(utl.get_batches(train_x, train_y, batch_size), 1):\n", 554 | " feed = {inputs_: x,\n", 555 | " labels_: y[:, None],\n", 556 | " keep_prob_: keep_prob,\n", 557 | " initial_state: state}\n", 558 | " loss_, state, _, batch_acc = sess.run([loss, final_state, optimizer, accuracy], feed_dict=feed)\n", 559 | " train_acc.append(batch_acc)\n", 560 | " \n", 561 | " if (ii + 1) % n_batches == 0:\n", 562 | " \n", 563 | " val_acc = []\n", 564 | " val_state = sess.run(lstm_cell.zero_state(batch_size, tf.float32))\n", 565 | " for xx, yy in utl.get_batches(val_x, val_y, batch_size):\n", 566 | " feed = {inputs_: xx,\n", 567 | " labels_: yy[:, None],\n", 568 | " keep_prob_: 1,\n", 569 | " initial_state: val_state}\n", 570 | " val_batch_acc, val_state = sess.run([accuracy, final_state], feed_dict=feed)\n", 571 | " val_acc.append(val_batch_acc)\n", 572 | " 
\n", 573 | " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", 574 | " \"Batch: {}/{}...\".format(ii+1, n_batches),\n", 575 | " \"Train Loss: {:.3f}...\".format(loss_),\n", 576 | " \"Train Accruacy: {:.3f}...\".format(np.mean(train_acc)),\n", 577 | " \"Val Accuracy: {:.3f}\".format(np.mean(val_acc)))\n", 578 | " \n", 579 | " saver.save(sess, \"checkpoints/sentiment.ckpt\")" 580 | ] 581 | }, 582 | { 583 | "cell_type": "markdown", 584 | "metadata": {}, 585 | "source": [ 586 | "Next we define our model hyperparameters. We will build a 2 Layer LSTM Network with hidden layer sizes of 128 and 64 respectively. We will use an embedding size of 300 and train over 50 epochs with mini-batches of size 256. We will use an initial learning rate of 0.1, though our Adadelta Optimizer will adapt this over time, and a keep probability of 0.5. " 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 16, 592 | "metadata": { 593 | "collapsed": true 594 | }, 595 | "outputs": [], 596 | "source": [ 597 | "# Define Inputs and Hyperparameters\n", 598 | "lstm_sizes = [128, 64]\n", 599 | "vocab_size = len(vocab_to_int) + 1 #add one for padding\n", 600 | "embed_size = 300\n", 601 | "epochs = 50\n", 602 | "batch_size = 256\n", 603 | "learning_rate = 0.1\n", 604 | "keep_prob = 0.5" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "and now we train!" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 17, 617 | "metadata": { 618 | "collapsed": false, 619 | "scrolled": false 620 | }, 621 | "outputs": [ 622 | { 623 | "name": "stdout", 624 | "output_type": "stream", 625 | "text": [ 626 | "Epoch: 1/50... Batch: 303/303... Train Loss: 0.247... Train Accruacy: 0.562... Val Accuracy: 0.578\n", 627 | "Epoch: 2/50... Batch: 303/303... Train Loss: 0.245... Train Accruacy: 0.583... Val Accuracy: 0.596\n", 628 | "Epoch: 3/50... Batch: 303/303... Train Loss: 0.247... Train Accruacy: 0.597... 
Val Accuracy: 0.617\n", 629 | "Epoch: 4/50... Batch: 303/303... Train Loss: 0.240... Train Accruacy: 0.610... Val Accuracy: 0.627\n", 630 | "Epoch: 5/50... Batch: 303/303... Train Loss: 0.238... Train Accruacy: 0.620... Val Accuracy: 0.632\n", 631 | "Epoch: 6/50... Batch: 303/303... Train Loss: 0.234... Train Accruacy: 0.632... Val Accuracy: 0.642\n", 632 | "Epoch: 7/50... Batch: 303/303... Train Loss: 0.230... Train Accruacy: 0.636... Val Accuracy: 0.648\n", 633 | "Epoch: 8/50... Batch: 303/303... Train Loss: 0.227... Train Accruacy: 0.641... Val Accuracy: 0.653\n", 634 | "Epoch: 9/50... Batch: 303/303... Train Loss: 0.223... Train Accruacy: 0.646... Val Accuracy: 0.656\n", 635 | "Epoch: 10/50... Batch: 303/303... Train Loss: 0.221... Train Accruacy: 0.652... Val Accuracy: 0.659\n", 636 | "Epoch: 11/50... Batch: 303/303... Train Loss: 0.225... Train Accruacy: 0.656... Val Accuracy: 0.663\n", 637 | "Epoch: 12/50... Batch: 303/303... Train Loss: 0.220... Train Accruacy: 0.661... Val Accuracy: 0.666\n", 638 | "Epoch: 13/50... Batch: 303/303... Train Loss: 0.215... Train Accruacy: 0.665... Val Accuracy: 0.668\n", 639 | "Epoch: 14/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.668... Val Accuracy: 0.670\n", 640 | "Epoch: 15/50... Batch: 303/303... Train Loss: 0.210... Train Accruacy: 0.669... Val Accuracy: 0.673\n", 641 | "Epoch: 16/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.673... Val Accuracy: 0.675\n", 642 | "Epoch: 17/50... Batch: 303/303... Train Loss: 0.212... Train Accruacy: 0.675... Val Accuracy: 0.676\n", 643 | "Epoch: 18/50... Batch: 303/303... Train Loss: 0.206... Train Accruacy: 0.681... Val Accuracy: 0.679\n", 644 | "Epoch: 19/50... Batch: 303/303... Train Loss: 0.208... Train Accruacy: 0.683... Val Accuracy: 0.681\n", 645 | "Epoch: 20/50... Batch: 303/303... Train Loss: 0.202... Train Accruacy: 0.684... Val Accuracy: 0.684\n", 646 | "Epoch: 21/50... Batch: 303/303... Train Loss: 0.206... Train Accruacy: 0.685... 
Val Accuracy: 0.686\n", 647 | "Epoch: 22/50... Batch: 303/303... Train Loss: 0.204... Train Accruacy: 0.689... Val Accuracy: 0.689\n", 648 | "Epoch: 23/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.690... Val Accuracy: 0.691\n", 649 | "Epoch: 24/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.695... Val Accuracy: 0.692\n", 650 | "Epoch: 25/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.696... Val Accuracy: 0.694\n", 651 | "Epoch: 26/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.698... Val Accuracy: 0.695\n", 652 | "Epoch: 27/50... Batch: 303/303... Train Loss: 0.197... Train Accruacy: 0.701... Val Accuracy: 0.695\n", 653 | "Epoch: 28/50... Batch: 303/303... Train Loss: 0.199... Train Accruacy: 0.703... Val Accuracy: 0.698\n", 654 | "Epoch: 29/50... Batch: 303/303... Train Loss: 0.187... Train Accruacy: 0.704... Val Accuracy: 0.698\n", 655 | "Epoch: 30/50... Batch: 303/303... Train Loss: 0.190... Train Accruacy: 0.708... Val Accuracy: 0.701\n", 656 | "Epoch: 31/50... Batch: 303/303... Train Loss: 0.189... Train Accruacy: 0.708... Val Accuracy: 0.702\n", 657 | "Epoch: 32/50... Batch: 303/303... Train Loss: 0.184... Train Accruacy: 0.710... Val Accuracy: 0.704\n", 658 | "Epoch: 33/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.714... Val Accuracy: 0.704\n", 659 | "Epoch: 34/50... Batch: 303/303... Train Loss: 0.190... Train Accruacy: 0.715... Val Accuracy: 0.704\n", 660 | "Epoch: 35/50... Batch: 303/303... Train Loss: 0.186... Train Accruacy: 0.714... Val Accuracy: 0.707\n", 661 | "Epoch: 36/50... Batch: 303/303... Train Loss: 0.178... Train Accruacy: 0.717... Val Accuracy: 0.707\n", 662 | "Epoch: 37/50... Batch: 303/303... Train Loss: 0.183... Train Accruacy: 0.722... Val Accuracy: 0.707\n", 663 | "Epoch: 38/50... Batch: 303/303... Train Loss: 0.181... Train Accruacy: 0.721... Val Accuracy: 0.710\n", 664 | "Epoch: 39/50... Batch: 303/303... Train Loss: 0.181... Train Accruacy: 0.723... 
Val Accuracy: 0.712\n", 665 | "Epoch: 40/50... Batch: 303/303... Train Loss: 0.179... Train Accruacy: 0.726... Val Accuracy: 0.712\n", 666 | "Epoch: 41/50... Batch: 303/303... Train Loss: 0.180... Train Accruacy: 0.726... Val Accuracy: 0.713\n", 667 | "Epoch: 42/50... Batch: 303/303... Train Loss: 0.177... Train Accruacy: 0.729... Val Accuracy: 0.714\n", 668 | "Epoch: 43/50... Batch: 303/303... Train Loss: 0.176... Train Accruacy: 0.731... Val Accuracy: 0.714\n", 669 | "Epoch: 44/50... Batch: 303/303... Train Loss: 0.180... Train Accruacy: 0.732... Val Accuracy: 0.716\n", 670 | "Epoch: 45/50... Batch: 303/303... Train Loss: 0.169... Train Accruacy: 0.734... Val Accuracy: 0.716\n", 671 | "Epoch: 46/50... Batch: 303/303... Train Loss: 0.173... Train Accruacy: 0.735... Val Accuracy: 0.717\n", 672 | "Epoch: 47/50... Batch: 303/303... Train Loss: 0.170... Train Accruacy: 0.736... Val Accuracy: 0.717\n", 673 | "Epoch: 48/50... Batch: 303/303... Train Loss: 0.173... Train Accruacy: 0.739... Val Accuracy: 0.718\n", 674 | "Epoch: 49/50... Batch: 303/303... Train Loss: 0.175... Train Accruacy: 0.740... Val Accuracy: 0.717\n", 675 | "Epoch: 50/50... Batch: 303/303... Train Loss: 0.175... Train Accruacy: 0.745... Val Accuracy: 0.718\n" 676 | ] 677 | } 678 | ], 679 | "source": [ 680 | "with tf.Graph().as_default():\n", 681 | " build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 682 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y)" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": {}, 688 | "source": [ 689 | "## Testing our Network" 690 | ] 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "metadata": {}, 695 | "source": [ 696 | "The last thing we want to do is check the model accuracy on our testing data to make sure it is in line with expectations.
We build the computational graph just as we did before. This time, however, instead of training we restore our saved model from the checkpoint directory and run our test data through it." 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 18, 702 | "metadata": { 703 | "collapsed": true 704 | }, 705 | "outputs": [], 706 | "source": [ 707 | "def test_network(model_dir, batch_size, test_x, test_y):\n", 708 | " \n", 709 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 710 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 711 | " initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)\n", 712 | " predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)\n", 713 | " accuracy = build_accuracy(predictions, labels_)\n", 714 | " \n", 715 | " saver = tf.train.Saver()\n", 716 | " \n", 717 | " test_acc = []\n", 718 | " with tf.Session() as sess:\n", 719 | " saver.restore(sess, tf.train.latest_checkpoint(model_dir))\n", 720 | " test_state = sess.run(lstm_cell.zero_state(batch_size, tf.float32))\n", 721 | " for ii, (x, y) in enumerate(utl.get_batches(test_x, test_y, batch_size), 1):\n", 722 | " feed = {inputs_: x,\n", 723 | " labels_: y[:, None],\n", 724 | " keep_prob_: 1,\n", 725 | " initial_state: test_state}\n", 726 | " batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed)\n", 727 | " test_acc.append(batch_acc)\n", 728 | " print(\"Test Accuracy: {:.3f}\".format(np.mean(test_acc)))" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": 19, 734 | "metadata": { 735 | "collapsed": false 736 | }, 737 | "outputs": [ 738 | { 739 | "name": "stdout", 740 | "output_type": "stream", 741 | "text": [ 742 | "INFO:tensorflow:Restoring parameters from checkpoints/sentiment.ckpt\n", 743 | "Test Accuracy: 0.717\n" 744 | ] 745 | } 746 | ], 747 | "source": [ 748 | "with tf.Graph().as_default():\n", 749 |
" test_network('checkpoints', batch_size, test_x, test_y)" 750 | ] 751 | } 752 | ], 753 | "metadata": { 754 | "anaconda-cloud": {}, 755 | "kernelspec": { 756 | "display_name": "Python [TensorFlow]", 757 | "language": "python", 758 | "name": "Python [TensorFlow]" 759 | }, 760 | "language_info": { 761 | "codemirror_mode": { 762 | "name": "ipython", 763 | "version": 3 764 | }, 765 | "file_extension": ".py", 766 | "mimetype": "text/x-python", 767 | "name": "python", 768 | "nbconvert_exporter": "python", 769 | "pygments_lexer": "ipython3", 770 | "version": "3.6.3" 771 | } 772 | }, 773 | "nbformat": 4, 774 | "nbformat_minor": 1 775 | } 776 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint 2 | *.ckpt.* 3 | -------------------------------------------------------------------------------- /data/StockTwits_SPY_Sentiment_2017.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/garretthoffman/lstm-oreilly/907bc5a31b5767622c845f31d2dee02fdcf2c971/data/StockTwits_SPY_Sentiment_2017.gz -------------------------------------------------------------------------------- /dockerfiles/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # Set the base image to Ubuntu 2 | FROM tensorflow/tensorflow:latest-py3-jupyter 3 | 4 | # File Author / Maintainer 5 | MAINTAINER Garrett Hoffman 6 | 7 | # Install git and TF dependencies 8 | RUN apt-get update && \ 9 | apt-get install -y --no-install-recommends libboost-all-dev && \ 10 | apt-get install -y software-properties-common \ 11 | git \ 12 | wget \ 13 | cmake \ 14 | python-zmq \ 15 | python-dev \ 16 | libzmq3-dev \ 17 | libssl-dev \ 18 | libgflags-dev \ 19 | libgoogle-glog-dev \ 20 | liblmdb-dev \ 21 | libatlas-base-dev \ 22 | libblas-dev \ 23 | liblapack-dev \ 24 | libgflags-dev 
\ 25 | libgoogle-glog-dev \ 26 | liblmdb-dev \ 27 | libprotobuf-dev \ 28 | libleveldb-dev \ 29 | libsnappy-dev \ 30 | libopencv-dev \ 31 | libhdf5-serial-dev \ 32 | protobuf-compiler 33 | 34 | COPY requirements.txt /root/ 35 | 36 | RUN pip install -r /root/requirements.txt 37 | RUN rm /root/requirements.txt 38 | 39 | 40 | WORKDIR /root 41 | 42 | CMD ["/bin/bash"] 43 | -------------------------------------------------------------------------------- /dockerfiles/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | 2 | # Set the base image to Ubuntu 3 | FROM tensorflow/tensorflow:latest-gpu-py3-jupyter 4 | 5 | # File Author / Maintainer 6 | MAINTAINER Garrett Hoffman 7 | 8 | RUN apt-get update && \ 9 | apt-get install -y --no-install-recommends libboost-all-dev && \ 10 | apt-get install -y software-properties-common \ 11 | git \ 12 | wget \ 13 | cmake \ 14 | python-zmq \ 15 | python-dev \ 16 | libzmq3-dev \ 17 | libssl-dev \ 18 | libgflags-dev \ 19 | libgoogle-glog-dev \ 20 | liblmdb-dev \ 21 | libatlas-base-dev \ 22 | libblas-dev \ 23 | liblapack-dev \ 24 | libgflags-dev \ 25 | libgoogle-glog-dev \ 26 | liblmdb-dev \ 27 | libprotobuf-dev \ 28 | libleveldb-dev \ 29 | libsnappy-dev \ 30 | libopencv-dev \ 31 | libhdf5-serial-dev \ 32 | protobuf-compiler \ 33 | python-tk 34 | 35 | COPY requirements.txt /root/ 36 | 37 | RUN pip install -r /root/requirements.txt 38 | RUN rm /root/requirements.txt 39 | 40 | WORKDIR /root 41 | 42 | CMD ["/bin/bash"] 43 | -------------------------------------------------------------------------------- /dockerfiles/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from collections import Counter 4 | import numpy as np 
5 | 6 | def preprocess_ST_message(text): 7 | """ 8 | Preprocesses raw message data for analysis 9 | :param text: String. ST Message 10 | :return: String. Processed text tokens joined by single spaces 11 | """ 12 | # Define ST Regex Patterns 13 | REGEX_PRICE_SIGN = re.compile(r'\$(?!\d*\.?\d+%)\d*\.?\d+|(?!\d*\.?\d+%)\d*\.?\d+\$') 14 | REGEX_PRICE_NOSIGN = re.compile(r'(?!\d*\.?\d+%)(?!\d*\.?\d+k)\d*\.?\d+') 15 | REGEX_TICKER = re.compile(r'\$[a-zA-Z]+') 16 | REGEX_USER = re.compile(r'\@\w+') 17 | REGEX_LINK = re.compile(r'https?:\/\/[^\s]+') 18 | REGEX_HTML_ENTITY = re.compile(r'\&\w+') 19 | REGEX_NON_ACSII = re.compile('[^\x00-\x7f]') 20 | REGEX_PUNCTUATION = re.compile('[%s]' % re.escape(string.punctuation.replace('<', '')).replace('>', '')) 21 | REGEX_NUMBER = re.compile(r'[-+]?[0-9]+') 22 | 23 | text = text.lower() 24 | 25 | # Replace ST "entities" with a unique token 26 | text = re.sub(REGEX_TICKER, ' ', text) 27 | text = re.sub(REGEX_USER, ' ', text) 28 | text = re.sub(REGEX_LINK, ' ', text) 29 | text = re.sub(REGEX_PRICE_SIGN, ' ', text) 30 | text = re.sub(REGEX_PRICE_NOSIGN, ' ', text) 31 | text = re.sub(REGEX_NUMBER, ' ', text) 32 | # Remove extraneous text data 33 | text = re.sub(REGEX_HTML_ENTITY, "", text) 34 | text = re.sub(REGEX_NON_ACSII, "", text) 35 | text = re.sub(REGEX_PUNCTUATION, "", text) 36 | # Tokenize and remove < and > that are not in special tokens 37 | words = " ".join(token.replace("<", "").replace(">", "") 38 | if token not in ['', '', '', '', ''] 39 | else token 40 | for token 41 | in text.split()) 42 | 43 | return words 44 | 45 | def create_lookup_tables(words): 46 | """ 47 | Create lookup tables for vocabulary 48 | :param words: Input list of words 49 | :return: A tuple of dicts.
The first dict maps a vocab word to an integer 50 | The second maps an integer back to the vocab word 51 | """ 52 | word_counts = Counter(words) 53 | sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True) 54 | int_to_vocab = {ii: word for ii, word in enumerate(sorted_vocab, 1)} 55 | vocab_to_int = {word: ii for ii, word in int_to_vocab.items()} 56 | 57 | return vocab_to_int, int_to_vocab 58 | 59 | def encode_ST_messages(messages, vocab_to_int): 60 | """ 61 | Encode ST messages as sequences of integers 62 | :param messages: list of strings. Each message is a space-separated string of tokens 63 | :param vocab_to_int: mapping of vocab to idx 64 | :return: numpy array of lists of ints. The encoded messages 65 | """ 66 | messages_encoded = [] 67 | for message in messages: 68 | messages_encoded.append([vocab_to_int[word] for word in message.split()]) 69 | 70 | return np.array(messages_encoded) 71 | 72 | def encode_ST_labels(labels): 73 | """ 74 | Encode ST Sentiment Labels 75 | :param labels: Input list of labels 76 | :return: numpy array. The encoded labels 77 | """ 78 | return np.array([1 if sentiment == 'bullish' else 0 for sentiment in labels]) 79 | 80 | def drop_empty_messages(messages, labels): 81 | """ 82 | Drop messages that are left empty after preprocessing 83 | :param messages: list of encoded messages 84 | :return: tuple of arrays. First array is non-empty messages, second array is non-empty labels 85 | """ 86 | non_zero_idx = [ii for ii, message in enumerate(messages) if len(message) != 0] 87 | messages_non_zero = np.array([messages[ii] for ii in non_zero_idx]) 88 | labels_non_zero = np.array([labels[ii] for ii in non_zero_idx]) 89 | return messages_non_zero, labels_non_zero 90 | 91 | def zero_pad_messages(messages, seq_len): 92 | """ 93 | Zero Pad input messages 94 | :param messages: Input list of encoded messages 95 | :param seq_len: Input int, maximum sequence input length 96 | :return: numpy array.
The zero-padded messages 97 | """ 98 | messages_padded = np.zeros((len(messages), seq_len), dtype=int) 99 | for i, row in enumerate(messages): 100 | messages_padded[i, -len(row):] = np.array(row)[:seq_len] 101 | 102 | return np.array(messages_padded) 103 | 104 | def train_val_test_split(messages, labels, split_frac, random_seed=None): 105 | """ 106 | Split data into training, validation, and test sets 107 | :param messages: Input list of encoded messages 108 | :param labels: Input list of encoded labels 109 | :param split_frac: Input float, fraction of the data used for training 110 | :return: tuple of arrays train_x, val_x, test_x, train_y, val_y, test_y 111 | """ 112 | # make sure that number of messages and labels align 113 | assert len(messages) == len(labels) 114 | # random shuffle data 115 | if random_seed is not None: 116 | np.random.seed(random_seed) 117 | shuf_idx = np.random.permutation(len(messages)) 118 | messages_shuf = np.array(messages)[shuf_idx] 119 | labels_shuf = np.array(labels)[shuf_idx] 120 | 121 | # make splits 122 | split_idx = int(len(messages_shuf)*split_frac) 123 | train_x, val_x = messages_shuf[:split_idx], messages_shuf[split_idx:] 124 | train_y, val_y = labels_shuf[:split_idx], labels_shuf[split_idx:] 125 | 126 | test_idx = int(len(val_x)*0.5) 127 | val_x, test_x = val_x[:test_idx], val_x[test_idx:] 128 | val_y, test_y = val_y[:test_idx], val_y[test_idx:] 129 | 130 | return train_x, val_x, test_x, train_y, val_y, test_y 131 | 132 | def get_batches(x, y, batch_size=100): 133 | """ 134 | Batch Generator for Training 135 | :param x: Input array of x data 136 | :param y: Input array of y data 137 | :param batch_size: Input int, size of batch 138 | :return: generator that returns a tuple of our x batch and y batch 139 | """ 140 | n_batches = len(x)//batch_size 141 | x, y = x[:n_batches*batch_size], y[:n_batches*batch_size] 142 | for ii in range(0, len(x), batch_size): 143 | yield x[ii:ii+batch_size], y[ii:ii+batch_size] --------------------------------------------------------------------------------
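The padding and batching helpers in `utils.py` can be easier to follow on toy data. The block below is a minimal sketch, not the repository's code: it restates the logic of `zero_pad_messages` and `get_batches` with plain Python lists instead of NumPy so that it runs without dependencies. The names `left_pad` and `batches` and the sample values are hypothetical and chosen only for illustration.

```python
def left_pad(messages, seq_len):
    """Left-pad each encoded message with zeros, keeping at most the first seq_len tokens."""
    padded = []
    for row in messages:
        row = list(row)[:seq_len]                     # truncate long messages
        padded.append([0] * (seq_len - len(row)) + row)  # zeros go on the left
    return padded


def batches(x, y, batch_size):
    """Yield full (x, y) batches; leftover rows that do not fill a batch are dropped."""
    usable = (len(x) // batch_size) * batch_size
    for start in range(0, usable, batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


# Hypothetical integer-encoded messages and sentiment labels.
encoded = [[4, 2], [7], [1, 3, 5, 9, 8]]
labels = [1, 0, 1]

padded = left_pad(encoded, seq_len=4)
print(padded)           # [[0, 0, 4, 2], [0, 0, 0, 7], [1, 3, 5, 9]]

batch_list = list(batches(padded, labels, batch_size=2))
print(len(batch_list))  # 1 (the third message does not fill a second batch)
```

Two behaviours are worth noting. Padding goes on the left, so the real tokens sit at the end of each sequence, and every batch has exactly `batch_size` rows, which suits a graph whose LSTM state is built for a fixed batch size.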