├── README.md ├── strata-2019-dl-for-nlp ├── dockerfiles │ ├── requirements.txt │ ├── Dockerfile.cpu │ └── Dockerfile.gpu ├── Deep Learning for NLP Strata 2019.pdf ├── data │ └── StockTwits_SPY_Sentiment_2017.gz ├── README.md ├── utils.py ├── 03_CNN_Glove_Keras.ipynb ├── 01_LSTM_N2N_TF.ipynb ├── 02_BiLSTM_N2N_TF.ipynb └── 04_ULMFiT_fastai.ipynb └── .gitignore /README.md: -------------------------------------------------------------------------------- 1 | # talks-and-tutorials 2 | Repository of my conference / meetup talks and tutorials 3 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/dockerfiles/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | sklearn 4 | matplotlib 5 | keras==2.0.8 6 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/Deep Learning for NLP Strata 2019.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/garretthoffman/talks-and-tutorials/HEAD/strata-2019-dl-for-nlp/Deep Learning for NLP Strata 2019.pdf -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/data/StockTwits_SPY_Sentiment_2017.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/garretthoffman/talks-and-tutorials/HEAD/strata-2019-dl-for-nlp/data/StockTwits_SPY_Sentiment_2017.gz -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/dockerfiles/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # Set the base image to Ubuntu 2 | FROM tensorflow/tensorflow:latest-py3-jupyter 3 | 4 | # File Author / Maintainer 5 | MAINTAINER Garrett Hoffman 6 | 7 | COPY requirements.txt /root/ 8 | 9 | RUN pip install -r /root/requirements.txt 10 | RUN rm /root/requirements.txt 11 | 12 | WORKDIR /root 13 | 14 | CMD ["/bin/bash"] 15 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/dockerfiles/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | 2 | # Set the base image to Ubuntu 3 | FROM tensorflow/tensorflow:latest-gpu-py3-jupyter 4 | 5 | # File Author / Maintainer 6 | MAINTAINER Garrett Hoffman 7 | 8 | COPY requirements.txt /root/ 9 | 10 | RUN pip install -r /root/requirements.txt 11 | RUN rm /root/requirements.txt 12 | 13 | WORKDIR /root 14 | 15 | CMD ["/bin/bash"] 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | .ipython 29 | .keras 30 | .local 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # pyenv
80 | .python-version
81 | 
82 | # celery beat schedule file
83 | celerybeat-schedule
84 | 
85 | # SageMath parsed files
86 | *.sage.py
87 | 
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 | 
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 | 
101 | # Rope project settings
102 | .ropeproject
103 | 
104 | # mkdocs documentation
105 | /site
106 | 
107 | # mypy
108 | .mypy_cache/
109 | .DS_Store
110 | 
111 | #vscode
112 | .vscode/
113 | 
114 | #extras
115 | .ipython
116 | .keras
117 | .local
118 | 
119 | checkpoints/
120 | 
121 | .bash_history
122 | 
--------------------------------------------------------------------------------
/strata-2019-dl-for-nlp/README.md:
--------------------------------------------------------------------------------
1 | # Deep Learning Methodologies for NLP - Strata 2019
2 | Slides and code tutorials for the 2019 Strata tutorial on Deep Learning Methodologies for Natural Language Processing
3 | 
4 | ## Notes
5 | 
6 | You can access the ULMFiT notebook on [Google Colab](https://colab.research.google.com/drive/1Q5lUfTt3WIj4K9VNiMyEFk82gGerZk05)
7 | 
8 | For a code sample of RNNs with attention, check out [Taming Recurrent Neural Networks for Better Summarization](http://www.abigailsee.com/2017/04/16/taming-rnns-for-better-summarization.html), which links to the accompanying TensorFlow implementation.
9 | 
10 | ## Setup
11 | 
12 | ### Download via Git
13 | 
14 | 1. Go to your desktop by opening your terminal and entering `cd Desktop`.
15 | 
16 | 2. Clone the repository by entering
17 | 
18 | ```
19 | git clone https://github.com/GarrettHoffman/talks-and-tutorials.git
20 | ```
21 | 
22 | 3. Move into this directory with `cd talks-and-tutorials/strata-2019-dl-for-nlp`.
23 | 
24 | ### Download Twitter GloVe Vectors
25 | 
26 | Download the pre-trained Twitter GloVe word vectors from [here](https://nlp.stanford.edu/projects/glove/) and place the file `glove.twitter.27B.50d.txt` in the `data` directory.
27 | 
28 | ### Setup Virtual Environment
29 | 
30 | #### Option 1: Dockerfiles (Recommended)
31 | 
32 | 1. After cloning the repo to your machine, navigate into the repo and enter
33 | 
34 | ```
35 | docker build -t dl_for_nlp_<processor> -f ./dockerfiles/Dockerfile.<processor> ./dockerfiles/
36 | ```
37 | 
38 | where `<processor>` is either `gpu` or `cpu`. (Note that, in order to run these files on your GPU, you'll need to have a compatible GPU, with drivers installed and configured properly [as described in TensorFlow's documentation](https://www.tensorflow.org/install/).)
39 | 
40 | 2. Run the Docker image by entering
41 | 
42 | ```
43 | docker run -it -p 8888:8888 -v ~/Desktop/talks-and-tutorials/strata-2019-dl-for-nlp:/root dl_for_nlp_<processor>
44 | ```
45 | 
46 | where `<processor>` is either `gpu` or `cpu`, depending on the image you built in the last step.
47 | 
48 | 3. 
After building, starting, and attaching to the appropriate Docker container, run the provided Jupyter notebooks by entering
49 | 
50 | ```
51 | jupyter notebook --ip 0.0.0.0 --allow-root
52 | ```
53 | 
54 | and navigate to the specified URL `http://0.0.0.0:8888/?token=<token>` in your browser.
55 | 
56 | 4. Choose `0X_Notebook_Title.ipynb` to open the applicable notebook. Note: the ULMFiT notebook must be run on Google Colab; see the link above.
57 | 
58 | ###### Debugging Docker
59 | If you receive an error of the form:
60 | 
61 | ```
62 | WARNING: Error loading config file:/home/rp/.docker/config.json - stat /home/rp/.docker/config.json: permission denied
63 | Got permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock: Get http://%2Fvar%2Frun%2Fdocker.sock/v1.26/images/json: dial unix /var/run/docker.sock: connect: permission denied
64 | ```
65 | 
66 | it's most likely because you installed Docker with sudo permissions using a package manager such as `brew` or `apt-get`. To resolve the `permission denied` error, simply run Docker with `sudo` (i.e. run `docker` commands as `sudo docker <command>` instead of just `docker <command>`).
67 | 
68 | #### Option 2: Local setup using Miniconda
69 | 
70 | If you don't have or don't want to use Docker, you can follow these steps to set up the notebooks.
71 | 
72 | 1. Install Miniconda using [one of the installers and the Miniconda installation instructions](https://conda.io/miniconda.html). Use Python 3.6.
73 | 
74 | 2. After the installation, create and activate a new virtual environment using these commands:
75 | ```
76 | $ conda create -n strata_nlp python=3.6
77 | $ source activate strata_nlp
78 | ```
79 | 
80 | 3. You are now in a virtual environment. Next up, [install TensorFlow by following the instructions](https://www.tensorflow.org/install/).
81 | 
82 | 4. To install the rest of the dependencies, navigate into your repository and run
83 | 
84 | ```
85 | $ pip install -r dockerfiles/requirements.txt
86 | ```
87 | 
88 | 5. Now you can run
89 | 
90 | ```
91 | jupyter notebook
92 | ```
93 | 
94 | to finally start up the notebook server. A browser should open automatically. If not, navigate to [http://127.0.0.1:8888](http://127.0.0.1:8888) in your browser.
95 | 
96 | 6. Choose `0X_Notebook_Title.ipynb` to open the applicable notebook. Note: the ULMFiT notebook must be run on Google Colab; see the link above.
97 | 
--------------------------------------------------------------------------------
/strata-2019-dl-for-nlp/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from collections import Counter
4 | import numpy as np
5 | 
6 | def preprocess_ST_message(text):
7 |     """
8 |     Preprocesses raw message data for analysis
9 |     :param text: String. ST message
10 |     :return: String. 
Processed message text with ST entities replaced by special tokens
11 |     """
12 |     # Define ST regex patterns
13 |     REGEX_PRICE_SIGN = re.compile(r'\$(?!\d*\.?\d+%)\d*\.?\d+|(?!\d*\.?\d+%)\d*\.?\d+\$')
14 |     REGEX_PRICE_NOSIGN = re.compile(r'(?!\d*\.?\d+%)(?!\d*\.?\d+k)\d*\.?\d+')
15 |     REGEX_TICKER = re.compile(r'\$[a-zA-Z]+')
16 |     REGEX_USER = re.compile(r'\@\w+')
17 |     REGEX_LINK = re.compile(r'https?:\/\/[^\s]+')
18 |     REGEX_HTML_ENTITY = re.compile(r'\&\w+')
19 |     REGEX_NON_ASCII = re.compile(r'[^\x00-\x7f]')
20 |     REGEX_PUNCTUATION = re.compile(r'[%s]' % re.escape(string.punctuation.replace('<', '')).replace('>', ''))
21 |     REGEX_NUMBER = re.compile(r'[-+]?[0-9]+')
22 | 
23 |     text = text.lower()
24 | 
25 |     # Replace ST "entities" with a unique token
26 |     text = re.sub(REGEX_TICKER, ' <TICKER> ', text)
27 |     text = re.sub(REGEX_USER, ' <USER> ', text)
28 |     text = re.sub(REGEX_LINK, ' <LINK> ', text)
29 |     text = re.sub(REGEX_PRICE_SIGN, ' <PRICE> ', text)
30 |     text = re.sub(REGEX_PRICE_NOSIGN, ' <NUMBER> ', text)
31 |     text = re.sub(REGEX_NUMBER, ' <NUMBER> ', text)
32 |     # Remove extraneous text data
33 |     text = re.sub(REGEX_HTML_ENTITY, "", text)
34 |     text = re.sub(REGEX_NON_ASCII, "", text)
35 |     text = re.sub(REGEX_PUNCTUATION, "", text)
36 |     # Tokenize and remove < and > that are not part of the special tokens
37 |     words = " ".join(token.replace("<", "").replace(">", "")
38 |                      if token not in ['<TICKER>', '<USER>', '<LINK>', '<PRICE>', '<NUMBER>']
39 |                      else token
40 |                      for token
41 |                      in text.split())
42 | 
43 |     return words
44 | 
45 | def create_lookup_tables_w2v(words):
46 |     """
47 |     Create lookup tables for vocabulary
48 |     :param words: Input list of words
49 |     :return: A tuple of dicts. The first dict maps a vocab word to an integer.
50 |              The second maps an integer back to the vocab word
51 |     """
52 |     word_counts = Counter(words)
53 |     sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
54 |     int_to_vocab = {ii: word for ii, word in enumerate(sorted_vocab)}
55 |     vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}
56 | 
57 |     return vocab_to_int, int_to_vocab
58 | 
59 | def create_lookup_tables(words):
60 |     """
61 |     Create lookup tables for vocabulary (indices start at 1; 0 is reserved for padding)
62 |     :param words: Input list of words
63 |     :return: A tuple of dicts. The first dict maps a vocab word to an integer.
64 |              The second maps an integer back to the vocab word
65 |     """
66 |     word_counts = Counter(words)
67 |     sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
68 |     int_to_vocab = {ii: word for ii, word in enumerate(sorted_vocab, 1)}
69 |     vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}
70 | 
71 |     return vocab_to_int, int_to_vocab
72 | 
73 | def encode_ST_messages(messages, vocab_to_int):
74 |     """
75 |     Encode ST messages as sequences of word indices
76 |     :param messages: list of strings. List of preprocessed messages
77 |     :param vocab_to_int: mapping of vocab to idx
78 |     :return: list of lists of ints. The encoded messages
79 |     """
80 |     messages_encoded = []
81 |     for message in messages:
82 |         messages_encoded.append([vocab_to_int[word] for word in message.split()])
83 | 
84 |     return np.array(messages_encoded)
85 | 
86 | def encode_ST_labels(labels):
87 |     """
88 |     Encode ST sentiment labels ('bullish' -> 1, 'bearish' -> 0)
89 |     :param labels: Input list of labels
90 |     :return: numpy array. The encoded labels
91 |     """
92 |     return np.array([1 if sentiment == 'bullish' else 0 for sentiment in labels])
93 | 
94 | def drop_empty_messages(messages, labels):
95 |     """
96 |     Drop messages that are left empty after preprocessing
97 |     :param messages: list of encoded messages
98 |     :return: tuple of arrays. 
First array is non-empty messages, second array is non-empty labels 99 | """ 100 | non_zero_idx = [ii for ii, message in enumerate(messages) if len(message) != 0] 101 | messages_non_zero = np.array([messages[ii] for ii in non_zero_idx]) 102 | labels_non_zero = np.array([labels[ii] for ii in non_zero_idx]) 103 | return messages_non_zero, labels_non_zero 104 | 105 | def zero_pad_messages(messages, seq_len): 106 | """ 107 | Zero Pad input messages 108 | :param messages: Input list of encoded messages 109 | :param seq_ken: Input int, maximum sequence input length 110 | :return: numpy array. The encoded labels 111 | """ 112 | messages_padded = np.zeros((len(messages), seq_len), dtype=int) 113 | for i, row in enumerate(messages): 114 | messages_padded[i, -len(row):] = np.array(row)[:seq_len] 115 | 116 | return np.array(messages_padded) 117 | 118 | def train_val_test_split(messages, labels, split_frac, random_seed=None): 119 | """ 120 | Zero Pad input messages 121 | :param messages: Input list of encoded messages 122 | :param labels: Input list of encoded labels 123 | :param split_frac: Input float, training split percentage 124 | :return: tuple of arrays train_x, val_x, test_x, train_y, val_y, test_y 125 | """ 126 | # make sure that number of messages and labels allign 127 | assert len(messages) == len(labels) 128 | # random shuffle data 129 | if random_seed: 130 | np.random.seed(random_seed) 131 | shuf_idx = np.random.permutation(len(messages)) 132 | messages_shuf = np.array(messages)[shuf_idx] 133 | labels_shuf = np.array(labels)[shuf_idx] 134 | 135 | #make splits 136 | split_idx = int(len(messages_shuf)*split_frac) 137 | train_x, val_x = messages_shuf[:split_idx], messages_shuf[split_idx:] 138 | train_y, val_y = labels_shuf[:split_idx], labels_shuf[split_idx:] 139 | 140 | test_idx = int(len(val_x)*0.5) 141 | val_x, test_x = val_x[:test_idx], val_x[test_idx:] 142 | val_y, test_y = val_y[:test_idx], val_y[test_idx:] 143 | 144 | return train_x, val_x, test_x, train_y, val_y, test_y 145 | 146 | def get_batches(x, y, batch_size=100): 147 | """ 148 | Batch Generator for Training 149 | :param x: Input array of x data 150 | :param y: Input array of y data 151 | :param batch_size: Input int, size of batch 152 | :return: generator that returns a tuple of our x batch and y batch 153 | """ 154 | n_batches = len(x)//batch_size 155 | x, y = x[:n_batches*batch_size], y[:n_batches*batch_size] 156 | for ii in range(0, len(x), batch_size): 157 | yield x[ii:ii+batch_size], y[ii:ii+batch_size] -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/03_CNN_Glove_Keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Modeling Stock Market Sentiment with CNNs and Keras" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial, we will build a CNN Network to predict the stock market sentiment based on a comment about the market." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "We will use the following libraries for our analysis:\n", 29 | "\n", 30 | "* numpy - numerical computing library used to work with our data\n", 31 | "* pandas - data analysis library used to read in our data from csv\n", 32 | "* tensorflow - a lower level deep learning framework used for modeling\n", 33 | "* keras - a higher level deep learning library that absracts away a lot of DL details. Keras will use Tensforflow in the background" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "We will also be using the python Counter object for counting our vocabulary items and we have a util module that extracts away a lot of the details of our data processing. Please read through the util.py to get a better understanding of how to preprocess the data for analysis." 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 1, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "Using TensorFlow backend.\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "import numpy as np\n", 58 | "import pandas as pd\n", 59 | "import tensorflow as tf\n", 60 | "import utils as utl\n", 61 | "from collections import Counter\n", 62 | "\n", 63 | "from keras.preprocessing.text import Tokenizer\n", 64 | "from keras.preprocessing.sequence import pad_sequences\n", 65 | "from keras.utils import to_categorical\n", 66 | "\n", 67 | "from keras.layers import Dense, Input, Flatten, Dropout, Merge\n", 68 | "from keras.layers import Conv1D, MaxPooling1D, Embedding\n", 69 | "from keras.models import Model" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "collapsed": true 76 | }, 77 | "source": [ 78 | "## Processing Data" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "We will train the model using messages tagged with SPY, the S&P 500 index fund, from [StockTwits.com](https://www.stocktwits.com). StockTwits is a social media network for traders and investors to share their views about the stock market. When a user posts a message, they tag the relevant stock ticker ($SPY in our case) and have the option to tag the messages with their sentiment – “bullish” if they believe the stock will go up and “bearish” if they believe the stock will go down.\n", 86 | "\n", 87 | "Our dataset consists of approximately 100,000 messages posted in 2017 that are tagged with $SPY where the user indicated their sentiment. Before we get to our CNN Network we have to perform some processing on our data to get it ready for modeling." 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "#### Read and View Data" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "First we simply read in our data using pandas, pull out our message and sentiment data into numpy arrays. Let's also take a look at a few samples to get familiar with the data set." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 2, 107 | "metadata": { 108 | "scrolled": true 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "Messages: $SPY crazy day so far!... 
Sentiment: bearish\n", 116 | "Messages: $SPY Will make a new ATH this week. Watch it!... Sentiment: bullish\n", 117 | "Messages: $SPY $DJIA white elephant in room is $AAPL. Up 14% since election. Strong headwinds w/Trump trade & Strong dollar. How many 7's do you see?... Sentiment: bearish\n", 118 | "Messages: $SPY blocks above. We break above them We should push to double top... Sentiment: bullish\n", 119 | "Messages: $SPY Nothing happening in the market today, guess I'll go to the store and spend some $.... Sentiment: bearish\n", 120 | "Messages: $SPY What an easy call. Good jobs report: good economy, markets go up. Bad jobs report: no more rate hikes, markets go up. Win-win.... Sentiment: bullish\n", 121 | "Messages: $SPY BS market.... Sentiment: bullish\n", 122 | "Messages: $SPY this rally all the cheerleaders were screaming about this morning is pretty weak. I keep adding 2 my short at all spikes... Sentiment: bearish\n", 123 | "Messages: $SPY Dollar ripping higher!... Sentiment: bearish\n", 124 | "Messages: $SPY no reason to go down !... Sentiment: bullish\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "# read data from csv file\n", 130 | "data = pd.read_csv(\"data/StockTwits_SPY_Sentiment_2017.gz\",\n", 131 | " encoding=\"utf-8\",\n", 132 | " compression=\"gzip\",\n", 133 | " index_col=0)\n", 134 | "\n", 135 | "# get messages and sentiment labels\n", 136 | "messages = data.message.values\n", 137 | "labels = data.sentiment.values\n", 138 | "\n", 139 | "# View sample of messages with sentiment\n", 140 | "\n", 141 | "for i in range(10):\n", 142 | " print(\"Messages: {}...\".format(messages[i]),\n", 143 | " \"Sentiment: {}\".format(labels[i]))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "#### Check Message Lengths" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "We will also want to get a sense of the distribution of the length of our inputs. We check for the longest and average messages. We will need to make our input length uniform to feed the data into our model so later we will have some decisions to make about possibly truncating some of the longer messages if they are too long. We also notice that one message has no content remaining after we preprocessed the data, so we will remove this message from our data set." 
158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 3, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "Zero-length messages: 0\n", 170 | "Maximum message length: 635\n", 171 | "Average message length: 75.64462136603174\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "messages_lens = Counter([len(x) for x in messages])\n", 177 | "print(\"Zero-length messages: {}\".format(messages_lens[0]))\n", 178 | "print(\"Maximum message length: {}\".format(max(messages_lens)))\n", 179 | "print(\"Average message length: {}\".format(np.mean([len(x) for x in messages])))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 4, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "messages, labels = utl.drop_empty_messages(messages, labels)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "#### Preprocess Messages" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "Working with raw text data often requires preprocessing the text in some fashion to normalize for context. In our case we want to normalize for known unique \"entities\" that appear within messages that carry a similar contextual meaning when analyzing sentiment. This means we want to replace references to specific stock tickers, user names, url links or numbers with a special token identifying the \"entity\". Here we will also make everything lower case and remove punctuation." 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 5, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "messages = np.array([utl.preprocess_ST_message(message) for message in messages])" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "#### Generate Vocab to Index Mapping" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "We will use a Keras `Tokenizer` in order to generate our word index. The tockenizer takes our vocabulary and assigns each word a unique index from 1 to *VOCAB_SIZE*. Zero is reserved for padding which we will get to in a bit" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 6, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "VOCAB = sorted(list(set(messages)))\n", 235 | "VOCAB_SIZE = len(VOCAB)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 7, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "tokenizer = Tokenizer(num_words=VOCAB_SIZE)\n", 245 | "tokenizer.fit_on_texts(messages)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 8, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "Found 31975 unique tokens.\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "word_index = tokenizer.word_index\n", 263 | "print('Found %s unique tokens.' % len(word_index))" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "#### Encode Messages and Labels" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "We need to \"translate\" our text to number for our algorithm to take in as inputs. We call this translation an encoding. 
We encode our messages to sequences of numbers where each nummber is the word index from the mapping we made earlier. The phrase \"I am bullish\" would now look something like [1, 234, 5345] where each number is the index for the respective word in the message. We can do this very easily with our tokenizer by calling the `text_to_sequences` method. For our sentiment labels we will simply encode \"bearish\" as 0 and \"bullish\" as 1." 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 9, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "sequences = tokenizer.texts_to_sequences(messages)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 10, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "labels = to_categorical(utl.encode_ST_labels(labels))" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "#### Pad Messages" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "The last thing we need to do is make our message inputs the same length. In our case, the average message length is 78 words so we will use a max length of around this amount. We need to Zero Pad the rest of the messages that are shorter. We will use a left padding that will pad all of the messages that are shorter than 244 words with 0s at the beginning. So our encoded \"I am bullish\" messages goes from [1, 234, 5345] (length 3) to [0, 0, 0, 0, 0, 0, ... , 0, 0, 1, 234, 5345] (length 80). Keras has a build in processing function called `pad_sequences` to do this for us." 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 11, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "MAX_SEQUENCE_LENGTH = 80\n", 319 | "cnn_data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "#### Train and Validation Split" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "The last thing we do is split our data into tranining and validation sets. Typically we will want a test set as well but we will skip this for demonstration purposes." 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 12, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "VALIDATION_SPLIT = .2\n", 343 | "num_validation_samples = int(VALIDATION_SPLIT * cnn_data.shape[0])" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 13, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "indices = np.arange(cnn_data.shape[0])\n", 353 | "np.random.shuffle(indices)\n", 354 | "cnn_data = cnn_data[indices]\n", 355 | "labels = labels[indices]" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 14, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "x_train = cnn_data[:-num_validation_samples]\n", 365 | "y_train = labels[:-num_validation_samples]\n", 366 | "x_val = cnn_data[-num_validation_samples:]\n", 367 | "y_val = labels[-num_validation_samples:]" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "## Building and Training our CNN Network" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": {}, 380 | "source": [ 381 | "In this section we will load our pretrained word embeddings and build out CNN Model." 
382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "#### Glove Embeddings" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "metadata": {}, 394 | "source": [ 395 | "For this example we will use the Twitter GloVe embeddings that can be found here https://nlp.stanford.edu/projects/glove/. We have our embeddings saved in a text file in our data directory so first we load parse these and load them in to a dictionary.\n", 396 | "\n", 397 | "Next we get the mean and standard devation of all embedding values. The pretrained GloVe embeddings won't contain all of the words in our vocabularly so we will seed our embedding matrix for our vocabularly with random draws from a normal distribution with mean *emb_mean* and standard deviation *emb_std*. Next we iterate through all of the words in our vocabulary and set our word embeddings to the GloVe vectors where they are available." 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 26, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "EMBEDDING_DIM = 50\n", 407 | "EMBEDDING_FILE = 'data/glove.twitter.27B.50d.txt'\n", 408 | "\n", 409 | "def get_embed_coefs(word, *arr): \n", 410 | " return word, np.asarray(arr, dtype='float32')" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 18, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "embeddings_index = dict(get_embed_coefs(*o.rstrip().rsplit(' ')) for o in open(EMBEDDING_FILE))" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 19, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "embeddings_values=list(embeddings_index.values())\n", 429 | "all_embs = np.stack(embeddings_values)\n", 430 | "emb_mean,emb_std = all_embs.mean(), all_embs.std()" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 20, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "embedding_matrix = np.random.normal(emb_mean, emb_std, (len(word_index) + 1, EMBEDDING_DIM))" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 21, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "for word, i in word_index.items():\n", 449 | " embedding_vector = embeddings_index.get(word)\n", 450 | " if embedding_vector is not None: \n", 451 | " embedding_matrix[i] = embedding_vector" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 22, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "(31976, 50)" 463 | ] 464 | }, 465 | "execution_count": 22, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | } 469 | ], 470 | "source": [ 471 | "embedding_matrix.shape" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "#### Model and Training" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Here we build our Convoluational Neural Network using Keras. We will use the architecture from the Yoon Kim model (https://arxiv.org/abs/1408.5882) with some adjustments. We first start with our embeddings layer and then have 3 parallel 1D convulational layers with 128 filters and sizes [3, 4, 5] respectively. 
We concatenate these results, pass to a dropout layer for regularization, and then to a fully connected layer with relu activation and finally a softmax layer to make out predictions.\n", 486 | "\n", 487 | "Here we also define that we will use a categorical crossentropy loss funcation and an Adam optimizer." 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 27, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "def ConvNet(embeddings, max_sequence_length, num_words, embedding_dim, labels_index, trainable=False, extra_conv=True):\n", 497 | " \n", 498 | " embedding_layer = Embedding(num_words,\n", 499 | " embedding_dim,\n", 500 | " weights=[embeddings],\n", 501 | " input_length=max_sequence_length,\n", 502 | " trainable=trainable)\n", 503 | "\n", 504 | " sequence_input = Input(shape=(max_sequence_length,), dtype='int32')\n", 505 | " embedded_sequences = embedding_layer(sequence_input)\n", 506 | "\n", 507 | " # Yoon Kim model (https://arxiv.org/abs/1408.5882)\n", 508 | " convs = []\n", 509 | " filter_sizes = [3,4,5]\n", 510 | "\n", 511 | " for filter_size in filter_sizes:\n", 512 | " l_conv = Conv1D(filters=128, kernel_size=filter_size, activation='relu')(embedded_sequences)\n", 513 | " l_pool = MaxPooling1D(pool_size=3)(l_conv)\n", 514 | " convs.append(l_pool)\n", 515 | "\n", 516 | " l_merge = Merge(mode='concat', concat_axis=1)(convs)\n", 517 | "\n", 518 | " # add a 1D convnet with global maxpooling, instead of Yoon Kim model\n", 519 | " conv = Conv1D(filters=128, kernel_size=3, activation='relu')(embedded_sequences)\n", 520 | " pool = MaxPooling1D(pool_size=3)(conv)\n", 521 | "\n", 522 | " if extra_conv==True:\n", 523 | " x = Dropout(0.5)(l_merge) \n", 524 | " else:\n", 525 | " # Original Yoon Kim model\n", 526 | " x = Dropout(0.5)(pool)\n", 527 | " x = Flatten()(x)\n", 528 | " x = Dense(128, activation='relu')(x)\n", 529 | " #x = Dropout(0.5)(x)\n", 530 | "\n", 531 | " preds = Dense(labels_index, activation='softmax')(x)\n", 532 | "\n", 533 | " model = Model(sequence_input, preds)\n", 534 | " model.compile(loss='categorical_crossentropy',\n", 535 | " optimizer='adam',\n", 536 | " metrics=['acc'])\n", 537 | "\n", 538 | " return model" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 24, 544 | "metadata": {}, 545 | "outputs": [ 546 | { 547 | "name": "stderr", 548 | "output_type": "stream", 549 | "text": [ 550 | "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:21: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n" 551 | ] 552 | } 553 | ], 554 | "source": [ 555 | "model = ConvNet(embedding_matrix, MAX_SEQUENCE_LENGTH, len(word_index)+1, EMBEDDING_DIM, \n", 556 | " len(labels[0]), False)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": {}, 562 | "source": [ 563 | "and now we train!" 
564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 25, 569 | "metadata": {}, 570 | "outputs": [ 571 | { 572 | "name": "stdout", 573 | "output_type": "stream", 574 | "text": [ 575 | "Train on 77574 samples, validate on 19393 samples\n", 576 | "Epoch 1/50\n", 577 | "77574/77574 [==============================] - 12s - loss: 0.6753 - acc: 0.5859 - val_loss: 0.6544 - val_acc: 0.6047\n", 578 | "Epoch 2/50\n", 579 | "77574/77574 [==============================] - 9s - loss: 0.6345 - acc: 0.6348 - val_loss: 0.6146 - val_acc: 0.6532\n", 580 | "Epoch 3/50\n", 581 | "77574/77574 [==============================] - 9s - loss: 0.6028 - acc: 0.6689 - val_loss: 0.5894 - val_acc: 0.6806\n", 582 | "Epoch 4/50\n", 583 | "77574/77574 [==============================] - 9s - loss: 0.5756 - acc: 0.6942 - val_loss: 0.5797 - val_acc: 0.6917\n", 584 | "Epoch 5/50\n", 585 | "77574/77574 [==============================] - 9s - loss: 0.5539 - acc: 0.7101 - val_loss: 0.5636 - val_acc: 0.6999\n", 586 | "Epoch 6/50\n", 587 | "77574/77574 [==============================] - 9s - loss: 0.5345 - acc: 0.7219 - val_loss: 0.5570 - val_acc: 0.7073\n", 588 | "Epoch 7/50\n", 589 | "77574/77574 [==============================] - 9s - loss: 0.5184 - acc: 0.7345 - val_loss: 0.5548 - val_acc: 0.7085\n", 590 | "Epoch 8/50\n", 591 | "77574/77574 [==============================] - 9s - loss: 0.5021 - acc: 0.7469 - val_loss: 0.5630 - val_acc: 0.7026\n", 592 | "Epoch 9/50\n", 593 | "77574/77574 [==============================] - 9s - loss: 0.4874 - acc: 0.7565 - val_loss: 0.5496 - val_acc: 0.7124\n", 594 | "Epoch 10/50\n", 595 | "77574/77574 [==============================] - 9s - loss: 0.4739 - acc: 0.7652 - val_loss: 0.5515 - val_acc: 0.7129\n", 596 | "Epoch 11/50\n", 597 | "77574/77574 [==============================] - 9s - loss: 0.4613 - acc: 0.7736 - val_loss: 0.5528 - val_acc: 0.7099\n", 598 | "Epoch 12/50\n", 599 | "77574/77574 [==============================] - 9s - loss: 0.4490 - acc: 0.7805 - val_loss: 0.5545 - val_acc: 0.7110\n", 600 | "Epoch 13/50\n", 601 | "77574/77574 [==============================] - 9s - loss: 0.4388 - acc: 0.7867 - val_loss: 0.5598 - val_acc: 0.7110\n", 602 | "Epoch 14/50\n", 603 | "77574/77574 [==============================] - 9s - loss: 0.4262 - acc: 0.7957 - val_loss: 0.5577 - val_acc: 0.7103\n", 604 | "Epoch 15/50\n", 605 | "77574/77574 [==============================] - 9s - loss: 0.4147 - acc: 0.8016 - val_loss: 0.5626 - val_acc: 0.7091\n", 606 | "Epoch 16/50\n", 607 | "77574/77574 [==============================] - 9s - loss: 0.4066 - acc: 0.8061 - val_loss: 0.5721 - val_acc: 0.7082\n", 608 | "Epoch 17/50\n", 609 | "77574/77574 [==============================] - 9s - loss: 0.3935 - acc: 0.8117 - val_loss: 0.5746 - val_acc: 0.7129\n", 610 | "Epoch 18/50\n", 611 | "77574/77574 [==============================] - 9s - loss: 0.3832 - acc: 0.8180 - val_loss: 0.5828 - val_acc: 0.7137\n", 612 | "Epoch 19/50\n", 613 | "77574/77574 [==============================] - 9s - loss: 0.3772 - acc: 0.8221 - val_loss: 0.5809 - val_acc: 0.7141\n", 614 | "Epoch 20/50\n", 615 | "77574/77574 [==============================] - 9s - loss: 0.3720 - acc: 0.8254 - val_loss: 0.5847 - val_acc: 0.7085\n", 616 | "Epoch 21/50\n", 617 | "77574/77574 [==============================] - 9s - loss: 0.3646 - acc: 0.8298 - val_loss: 0.5838 - val_acc: 0.7073\n", 618 | "Epoch 22/50\n", 619 | "77574/77574 [==============================] - 9s - loss: 0.3575 - acc: 0.8341 - val_loss: 0.5841 - val_acc: 0.7103\n", 
620 | "Epoch 23/50\n", 621 | "77574/77574 [==============================] - 9s - loss: 0.3514 - acc: 0.8380 - val_loss: 0.5921 - val_acc: 0.7118\n", 622 | "Epoch 24/50\n", 623 | "77574/77574 [==============================] - 9s - loss: 0.3466 - acc: 0.8396 - val_loss: 0.5975 - val_acc: 0.7066\n", 624 | "Epoch 25/50\n", 625 | "77574/77574 [==============================] - 9s - loss: 0.3424 - acc: 0.8423 - val_loss: 0.5915 - val_acc: 0.7112\n", 626 | "Epoch 26/50\n", 627 | "77574/77574 [==============================] - 9s - loss: 0.3368 - acc: 0.8449 - val_loss: 0.6010 - val_acc: 0.7091\n", 628 | "Epoch 27/50\n", 629 | "77574/77574 [==============================] - 9s - loss: 0.3331 - acc: 0.8474 - val_loss: 0.6044 - val_acc: 0.7111\n", 630 | "Epoch 28/50\n", 631 | "77574/77574 [==============================] - 9s - loss: 0.3263 - acc: 0.8495 - val_loss: 0.6139 - val_acc: 0.7096\n", 632 | "Epoch 29/50\n", 633 | "77574/77574 [==============================] - 9s - loss: 0.3217 - acc: 0.8543 - val_loss: 0.6137 - val_acc: 0.7047\n", 634 | "Epoch 30/50\n", 635 | "77574/77574 [==============================] - 9s - loss: 0.3163 - acc: 0.8550 - val_loss: 0.6122 - val_acc: 0.7059\n", 636 | "Epoch 31/50\n", 637 | "77574/77574 [==============================] - 9s - loss: 0.3172 - acc: 0.8552 - val_loss: 0.6215 - val_acc: 0.7074\n", 638 | "Epoch 32/50\n", 639 | "77574/77574 [==============================] - 9s - loss: 0.3096 - acc: 0.8585 - val_loss: 0.6237 - val_acc: 0.7090\n", 640 | "Epoch 33/50\n", 641 | "77574/77574 [==============================] - 9s - loss: 0.3030 - acc: 0.8628 - val_loss: 0.6231 - val_acc: 0.7080\n", 642 | "Epoch 34/50\n", 643 | "77574/77574 [==============================] - 9s - loss: 0.3026 - acc: 0.8621 - val_loss: 0.6233 - val_acc: 0.7102\n", 644 | "Epoch 35/50\n", 645 | "77574/77574 [==============================] - 9s - loss: 0.2999 - acc: 0.8643 - val_loss: 0.6312 - val_acc: 0.7129\n", 646 | "Epoch 36/50\n", 647 | "77574/77574 [==============================] - 9s - loss: 0.2992 - acc: 0.8646 - val_loss: 0.6212 - val_acc: 0.7113\n", 648 | "Epoch 37/50\n", 649 | "77574/77574 [==============================] - 9s - loss: 0.2923 - acc: 0.8695 - val_loss: 0.6390 - val_acc: 0.7023\n", 650 | "Epoch 38/50\n", 651 | "77574/77574 [==============================] - 9s - loss: 0.2921 - acc: 0.8683 - val_loss: 0.6325 - val_acc: 0.7132\n", 652 | "Epoch 39/50\n", 653 | "77574/77574 [==============================] - 9s - loss: 0.2904 - acc: 0.8700 - val_loss: 0.6335 - val_acc: 0.7070\n", 654 | "Epoch 40/50\n", 655 | "77574/77574 [==============================] - 9s - loss: 0.2864 - acc: 0.8711 - val_loss: 0.6445 - val_acc: 0.7093\n", 656 | "Epoch 41/50\n", 657 | "77574/77574 [==============================] - 9s - loss: 0.2833 - acc: 0.8720 - val_loss: 0.6386 - val_acc: 0.7074\n", 658 | "Epoch 42/50\n", 659 | "77574/77574 [==============================] - 9s - loss: 0.2815 - acc: 0.8721 - val_loss: 0.6494 - val_acc: 0.7045\n", 660 | "Epoch 43/50\n", 661 | "77574/77574 [==============================] - 9s - loss: 0.2768 - acc: 0.8752 - val_loss: 0.6433 - val_acc: 0.7026\n", 662 | "Epoch 44/50\n", 663 | "77574/77574 [==============================] - 9s - loss: 0.2790 - acc: 0.8740 - val_loss: 0.6509 - val_acc: 0.7007\n", 664 | "Epoch 45/50\n", 665 | "77574/77574 [==============================] - 9s - loss: 0.2750 - acc: 0.8764 - val_loss: 0.6492 - val_acc: 0.7077\n", 666 | "Epoch 46/50\n", 667 | "77574/77574 [==============================] - 9s - loss: 0.2709 - acc: 
0.8787 - val_loss: 0.6567 - val_acc: 0.7113\n", 668 | "Epoch 47/50\n", 669 | "77574/77574 [==============================] - 9s - loss: 0.2699 - acc: 0.8785 - val_loss: 0.6500 - val_acc: 0.7094\n", 670 | "Epoch 48/50\n", 671 | "77574/77574 [==============================] - 9s - loss: 0.2679 - acc: 0.8791 - val_loss: 0.6596 - val_acc: 0.7059\n", 672 | "Epoch 49/50\n", 673 | "77574/77574 [==============================] - 9s - loss: 0.2660 - acc: 0.8801 - val_loss: 0.6665 - val_acc: 0.7064\n", 674 | "Epoch 50/50\n", 675 | "77574/77574 [==============================] - 9s - loss: 0.2662 - acc: 0.8799 - val_loss: 0.6540 - val_acc: 0.7039\n" 676 | ] 677 | }, 678 | { 679 | "data": { 680 | "text/plain": [ 681 | "" 682 | ] 683 | }, 684 | "execution_count": 25, 685 | "metadata": {}, 686 | "output_type": "execute_result" 687 | } 688 | ], 689 | "source": [ 690 | "model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, batch_size=128)" 691 | ] 692 | } 693 | ], 694 | "metadata": { 695 | "anaconda-cloud": {}, 696 | "kernelspec": { 697 | "display_name": "Python 3", 698 | "language": "python", 699 | "name": "python3" 700 | }, 701 | "language_info": { 702 | "codemirror_mode": { 703 | "name": "ipython", 704 | "version": 3 705 | }, 706 | "file_extension": ".py", 707 | "mimetype": "text/x-python", 708 | "name": "python", 709 | "nbconvert_exporter": "python", 710 | "pygments_lexer": "ipython3", 711 | "version": "3.5.2" 712 | } 713 | }, 714 | "nbformat": 4, 715 | "nbformat_minor": 1 716 | } 717 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/01_LSTM_N2N_TF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Modeling Stock Market Sentiment with LSTMs and TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial, we will build a Long Short Term Memory (LSTM) Network to predict the stock market sentiment based on a comment about the market." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "We will use the following libraries for our analysis:\n", 29 | "\n", 30 | "* numpy - numerical computing library used to work with our data\n", 31 | "* pandas - data analysis library used to read in our data from csv\n", 32 | "* tensorflow - deep learning framework used for modeling" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "We will also be using the python Counter object for counting our vocabulary items and we have a util module that extracts away a lot of the details of our data processing. Please read through the util.py to get a better understanding of how to preprocess the data for analysis." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import pandas as pd\n", 52 | "import tensorflow as tf\n", 53 | "import utils as utl\n", 54 | "from collections import Counter" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "source": [ 63 | "## Processing Data" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "We will train the model using messages tagged with SPY, the S&P 500 index fund, from [StockTwits.com](https://www.stocktwits.com). StockTwits is a social media network for traders and investors to share their views about the stock market. When a user posts a message, they tag the relevant stock ticker ($SPY in our case) and have the option to tag the messages with their sentiment – “bullish” if they believe the stock will go up and “bearish” if they believe the stock will go down.\n", 71 | "\n", 72 | "Our dataset consists of approximately 100,000 messages posted in 2017 that are tagged with $SPY where the user indicated their sentiment. Before we get to our LSTM Network we have to perform some processing on our data to get it ready for modeling." 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "#### Read and View Data" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "First we simply read in our data using pandas, pull out our message and sentiment data into numpy arrays. Let's also take a look at a few samples to get familiar with the data set." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 2, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Messages: $SPY crazy day so far!... Sentiment: bearish\n", 99 | "Messages: $SPY Will make a new ATH this week. Watch it!... Sentiment: bullish\n", 100 | "Messages: $SPY $DJIA white elephant in room is $AAPL. Up 14% since election. Strong headwinds w/Trump trade & Strong dollar. How many 7's do you see?... Sentiment: bearish\n", 101 | "Messages: $SPY blocks above. We break above them We should push to double top... Sentiment: bullish\n", 102 | "Messages: $SPY Nothing happening in the market today, guess I'll go to the store and spend some $.... Sentiment: bearish\n", 103 | "Messages: $SPY What an easy call. Good jobs report: good economy, markets go up. Bad jobs report: no more rate hikes, markets go up. Win-win.... Sentiment: bullish\n", 104 | "Messages: $SPY BS market.... Sentiment: bullish\n", 105 | "Messages: $SPY this rally all the cheerleaders were screaming about this morning is pretty weak. I keep adding 2 my short at all spikes... Sentiment: bearish\n", 106 | "Messages: $SPY Dollar ripping higher!... Sentiment: bearish\n", 107 | "Messages: $SPY no reason to go down !... 
Sentiment: bullish\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "# read data from csv file\n", 113 | "data = pd.read_csv(\"data/StockTwits_SPY_Sentiment_2017.gz\",\n", 114 | " encoding=\"utf-8\",\n", 115 | " compression=\"gzip\",\n", 116 | " index_col=0)\n", 117 | "\n", 118 | "# get messages and sentiment labels\n", 119 | "messages = data.message.values\n", 120 | "labels = data.sentiment.values\n", 121 | "\n", 122 | "# View sample of messages with sentiment\n", 123 | "\n", 124 | "for i in range(10):\n", 125 | " print(\"Messages: {}...\".format(messages[i]),\n", 126 | " \"Sentiment: {}\".format(labels[i]))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "#### Preprocess Messages" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "Working with raw text data often requires preprocessing the text in some fashion to normalize for context. In our case we want to normalize for known unique \"entities\" that appear within messages that carry a similar contextual meaning when analyzing sentiment. This means we want to replace references to specific stock tickers, user names, url links or numbers with a special token identifying the \"entity\". Here we will also make everything lower case and remove punctuation." 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 3, 146 | "metadata": { 147 | "collapsed": true 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "messages = np.array([utl.preprocess_ST_message(message) for message in messages])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "#### Generate Vocab to Index Mapping" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "To work with raw text we need some encoding from words to numbers for our algorithm to work with the inputs. The first step of doing this is keeping a collection of our full vocabularly and creating a mapping of each word to a unique index. We will use this word to index mapping in a little bit to prep out messages for analysis. \n", 166 | "\n", 167 | "Note that in practice we may want to only include the vocabularly from our training set here to account for the fact that we will likely see new words when our model is out in the wild when we are assessing the results on our validation and test sets. Here, for simplicity and demonstration purposes, we will use our entire data set." 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 4, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "full_lexicon = \" \".join(messages).split()\n", 179 | "vocab_to_int, int_to_vocab = utl.create_lookup_tables(full_lexicon)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "#### Check Message Lengths" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "We will also want to get a sense of the distribution of the length of our inputs. We check for the longest and average messages. We will need to make our input length uniform to feed the data into our model so later we will have some decisions to make about possibly truncating some of the longer messages if they are too long. We also notice that one message has no content remaining after we preprocessed the data, so we will remove this message from our data set." 
194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "Zero-length messages: 1\n", 206 | "Maximum message length: 244\n", 207 | "Average message length: 78.21856920395598\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "messages_lens = Counter([len(x) for x in messages])\n", 213 | "print(\"Zero-length messages: {}\".format(messages_lens[0]))\n", 214 | "print(\"Maximum message length: {}\".format(max(messages_lens)))\n", 215 | "print(\"Average message length: {}\".format(np.mean([len(x) for x in messages])))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "metadata": { 222 | "collapsed": true 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "messages, labels = utl.drop_empty_messages(messages, labels)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "#### Encode Messages and Labels" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "Earlier we mentioned that we need to \"translate\" our text to number for our algorithm to take in as inputs. We call this translation an encoding. We encode our messages to sequences of numbers where each nummber is the word index from the mapping we made earlier. The phrase \"I am bullish\" would now look something like [1, 234, 5345] where each number is the index for the respective word in the message. For our sentiment labels we will simply encode \"bearish\" as 0 and \"bullish\" as 1." 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 7, 246 | "metadata": { 247 | "collapsed": true 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "messages = utl.encode_ST_messages(messages, vocab_to_int)\n", 252 | "labels = utl.encode_ST_labels(labels)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "#### Pad Messages" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "The last thing we need to do is make our message inputs the same length. In our case, the longest message is 244 words. LSTMs can usually handle sequence inputs up to 500 items in length so we won't truncate any of the messages here. We need to Zero Pad the rest of the messages that are shorter. We will use a left padding that will pad all of the messages that are shorter than 244 words with 0s at the beginning. So our encoded \"I am bullish\" messages goes from [1, 234, 5345] (length 3) to [0, 0, 0, 0, 0, 0, ... , 0, 0, 1, 234, 5345] (length 244)." 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 8, 272 | "metadata": { 273 | "collapsed": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "messages = utl.zero_pad_messages(messages, seq_len=244)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "#### Train, Test, Validation Split" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "The last thing we do is split our data into tranining, validation and test sets and observe the size of each set." 
292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 9, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "Data Set Size\n", 304 | "Train set: \t\t(77572, 244) \n", 305 | "Validation set: \t(9697, 244) \n", 306 | "Test set: \t\t(9697, 244)\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "train_x, val_x, test_x, train_y, val_y, test_y = utl.train_val_test_split(messages, labels, split_frac=0.80)\n", 312 | "\n", 313 | "print(\"Data Set Size\")\n", 314 | "print(\"Train set: \\t\\t{}\".format(train_x.shape), \n", 315 | " \"\\nValidation set: \\t{}\".format(val_x.shape),\n", 316 | " \"\\nTest set: \\t\\t{}\".format(test_x.shape))" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "## Building and Training our LSTM Network" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "In this section we will define a number of functions that will construct the items in our network. We will then use these functions to build and train our network." 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "#### Model Inputs" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "Here we simply define a function to build TensorFlow Placeholders for our message sequences, our labels and a variable called keep probability associated with drop out (we will talk more about this later). " 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 10, 350 | "metadata": { 351 | "collapsed": true 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "def model_inputs():\n", 356 | " \"\"\"\n", 357 | " Create the model inputs\n", 358 | " \"\"\"\n", 359 | " inputs_ = tf.placeholder(tf.int32, [None, None], name='inputs')\n", 360 | " labels_ = tf.placeholder(tf.int32, [None, None], name='labels')\n", 361 | " keep_prob_ = tf.placeholder(tf.float32, name='keep_prob')\n", 362 | " \n", 363 | " return inputs_, labels_, keep_prob_" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "#### Embedding Layer" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "metadata": {}, 376 | "source": [ 377 | "In TensorFlow the word embeddings are represented as a vocabulary size x embedding size matrix and will learn these weights during our training process. The embedding lookup is then just a simple lookup from our embedding matrix based on the index of the current word." 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 11, 383 | "metadata": { 384 | "collapsed": true 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "def build_embedding_layer(inputs_, vocab_size, embed_size):\n", 389 | " \"\"\"\n", 390 | " Create the embedding layer\n", 391 | " \"\"\"\n", 392 | " embedding = tf.Variable(tf.random_uniform((vocab_size, embed_size), -1, 1))\n", 393 | " embed = tf.nn.embedding_lookup(embedding, inputs_)\n", 394 | " \n", 395 | " return embed" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": {}, 401 | "source": [ 402 | "#### LSTM Layers" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "TensorFlow makes it extremely easy to build LSTM Layers and stack them on top of each other. 
We represent each LSTM layer as a BasicLSTMCell and keep these in a list to stack them together later. Here we will define a list with our LSTM layer sizes and the number of layers. \n", 410 | "\n", 411 | "We then take each of these LSTM layers and wrap them in a Dropout Layer. Dropout is a regularization technique used in Neural Networks in which any individual node has a probability of “dropping out” of the network during a given iteration of learning. This makes the model more generalizable by ensuring that it is not too dependent on any given nodes. \n", 412 | "\n", 413 | "Finally, we stack these layers using a MultiRNNCell, generate a zero initial state and connect our stacked LSTM layer to our word embedding inputs using dynamic_rnn. Here we track the output and the final state of the LSTM cell, which we will need to pass between mini-batches during training." 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 12, 419 | "metadata": { 420 | "collapsed": true 421 | }, 422 | "outputs": [], 423 | "source": [ 424 | "def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):\n", 425 | " \"\"\"\n", 426 | " Create the LSTM layers\n", 427 | " \"\"\"\n", 428 | " lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]\n", 429 | " # Add dropout to the cell\n", 430 | " drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_) for lstm in lstms]\n", 431 | " # Stack up multiple LSTM layers, for deep learning\n", 432 | " cell = tf.contrib.rnn.MultiRNNCell(drops)\n", 433 | " # Getting an initial state of all zeros\n", 434 | " initial_state = cell.zero_state(batch_size, tf.float32)\n", 435 | " \n", 436 | " lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed, initial_state=initial_state)\n", 437 | " \n", 438 | " return initial_state, lstm_outputs, cell, final_state" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "#### Loss Function and Optimizer" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "First, we get our predictions by passing the final output of the LSTM layers through a TensorFlow fully connected layer with a sigmoid activation function. We only need the final time step for making predictions, so we pull it out with [:, -1] indexing on our LSTM outputs before feeding it to that layer. We then pass these predictions to our mean squared error loss function and use the Adadelta Optimizer to minimize the loss."
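To make the `[:, -1]` indexing concrete, here is a quick shape check with NumPy arrays standing in for the real LSTM outputs; it only illustrates the tensor shapes involved, not the TensorFlow graph itself.

```python
import numpy as np

# dynamic_rnn returns lstm_outputs with shape [batch_size, seq_len, lstm_size];
# indexing with [:, -1] keeps only the output at the final time step, which is
# what the single-unit sigmoid layer sees.
batch_size, seq_len, lstm_size = 4, 244, 64
lstm_outputs = np.random.randn(batch_size, seq_len, lstm_size)

final_step_output = lstm_outputs[:, -1]
print(final_step_output.shape)  # (4, 64) -> one vector per message in the batch
```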
453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 13, 458 | "metadata": { 459 | "collapsed": true 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "def build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate):\n", 464 | " \"\"\"\n", 465 | " Create the Loss function and Optimizer\n", 466 | " \"\"\"\n", 467 | " predictions = tf.contrib.layers.fully_connected(lstm_outputs[:, -1], 1, activation_fn=tf.sigmoid)\n", 468 | " loss = tf.losses.mean_squared_error(labels_, predictions)\n", 469 | " optimzer = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss)\n", 470 | " \n", 471 | " return predictions, loss, optimzer" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "#### Accuracy" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Finally, we define our accuracy metric for assessing the model performance across our training, validation and test sets. Even though accuracy is just a calculation based on results, everything in TensorFlow is part of a Computation Graph. Therefore, we need to define our loss and accuracy nodes in the context of the rest of our network graph." 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 14, 491 | "metadata": { 492 | "collapsed": true 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "def build_accuracy(predictions, labels_):\n", 497 | " \"\"\"\n", 498 | " Create accuracy\n", 499 | " \"\"\"\n", 500 | " correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)\n", 501 | " accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n", 502 | " \n", 503 | " return accuracy" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": {}, 509 | "source": [ 510 | "#### Training" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "We are finally ready to build and train our LSTM Network! First, we call each of our each of the functions we have defined to construct the network. Then we define a Saver to be able to write our model to disk to load for future use. Finally, we call a Tensorflow Session to train the model over a predefined number of epochs using mini-batches. At the end of each epoch we will print the loss, training accuracy and validation accuracy to monitor the results as we train." 
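The mini-batching used in the loop below comes from `utl.get_batches` in `utils.py`. The generator here is an assumed, simplified stand-in that shows the contract the training loop relies on: fixed-size batches, with the final partial batch dropped because the LSTM state is created for exactly `batch_size` examples.

```python
import numpy as np

def get_batches_sketch(x, y, batch_size):
    # Yield consecutive, fixed-size mini-batches; the last partial batch is
    # dropped because the initial LSTM state is built with
    # cell.zero_state(batch_size, tf.float32) for exactly batch_size rows.
    n_batches = len(x) // batch_size
    for start in range(0, n_batches * batch_size, batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]

# quick usage check with dummy data
dummy_x = np.arange(20).reshape(10, 2)
dummy_y = np.arange(10)
for bx, by in get_batches_sketch(dummy_x, dummy_y, batch_size=4):
    print(bx.shape, by.shape)  # (4, 2) (4,) printed twice; 2 leftover rows dropped
```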
518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 15, 523 | "metadata": { 524 | "collapsed": true 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "def build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 529 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y):\n", 530 | " \n", 531 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 532 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 533 | " initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)\n", 534 | " predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)\n", 535 | " accuracy = build_accuracy(predictions, labels_)\n", 536 | " \n", 537 | " saver = tf.train.Saver()\n", 538 | " \n", 539 | " with tf.Session() as sess:\n", 540 | " \n", 541 | " sess.run(tf.global_variables_initializer())\n", 542 | " n_batches = len(train_x)//batch_size\n", 543 | " for e in range(epochs):\n", 544 | " state = sess.run(initial_state)\n", 545 | " \n", 546 | " train_acc = []\n", 547 | " for ii, (x, y) in enumerate(utl.get_batches(train_x, train_y, batch_size), 1):\n", 548 | " feed = {inputs_: x,\n", 549 | " labels_: y[:, None],\n", 550 | " keep_prob_: keep_prob,\n", 551 | " initial_state: state}\n", 552 | " loss_, state, _, batch_acc = sess.run([loss, final_state, optimizer, accuracy], feed_dict=feed)\n", 553 | " train_acc.append(batch_acc)\n", 554 | " \n", 555 | " if (ii + 1) % n_batches == 0:\n", 556 | " \n", 557 | " val_acc = []\n", 558 | " val_state = sess.run(lstm_cell.zero_state(batch_size, tf.float32))\n", 559 | " for xx, yy in utl.get_batches(val_x, val_y, batch_size):\n", 560 | " feed = {inputs_: xx,\n", 561 | " labels_: yy[:, None],\n", 562 | " keep_prob_: 1,\n", 563 | " initial_state: val_state}\n", 564 | " val_batch_acc, val_state = sess.run([accuracy, final_state], feed_dict=feed)\n", 565 | " val_acc.append(val_batch_acc)\n", 566 | " \n", 567 | " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", 568 | " \"Batch: {}/{}...\".format(ii+1, n_batches),\n", 569 | " \"Train Loss: {:.3f}...\".format(loss_),\n", 570 | " \"Train Accruacy: {:.3f}...\".format(np.mean(train_acc)),\n", 571 | " \"Val Accuracy: {:.3f}\".format(np.mean(val_acc)))\n", 572 | " \n", 573 | " saver.save(sess, \"checkpoints/sentiment.ckpt\")" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": {}, 579 | "source": [ 580 | "Next we define our model hyper parameters. We will build a 2 Layer LSTM Newtork with hidden layer sizes of 128 and 64 respectively. We will use an embedding size of 300 and train over 50 epochs with mini-batches of size 256. We will use an initial learning rate of 0.1, though our Adadelta Optimizer will adapt this over time, and a keep probability of 0.5. " 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 16, 586 | "metadata": { 587 | "collapsed": true 588 | }, 589 | "outputs": [], 590 | "source": [ 591 | "# Define Inputs and Hyperparameters\n", 592 | "lstm_sizes = [128, 64]\n", 593 | "vocab_size = len(vocab_to_int) + 1 #add one for padding\n", 594 | "embed_size = 300\n", 595 | "epochs = 50\n", 596 | "batch_size = 256\n", 597 | "learning_rate = 0.1\n", 598 | "keep_prob = 0.5" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "and now we train!" 
606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 17, 611 | "metadata": { 612 | "scrolled": false 613 | }, 614 | "outputs": [ 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | "Epoch: 1/50... Batch: 303/303... Train Loss: 0.247... Train Accruacy: 0.562... Val Accuracy: 0.578\n", 620 | "Epoch: 2/50... Batch: 303/303... Train Loss: 0.245... Train Accruacy: 0.583... Val Accuracy: 0.596\n", 621 | "Epoch: 3/50... Batch: 303/303... Train Loss: 0.247... Train Accruacy: 0.597... Val Accuracy: 0.617\n", 622 | "Epoch: 4/50... Batch: 303/303... Train Loss: 0.240... Train Accruacy: 0.610... Val Accuracy: 0.627\n", 623 | "Epoch: 5/50... Batch: 303/303... Train Loss: 0.238... Train Accruacy: 0.620... Val Accuracy: 0.632\n", 624 | "Epoch: 6/50... Batch: 303/303... Train Loss: 0.234... Train Accruacy: 0.632... Val Accuracy: 0.642\n", 625 | "Epoch: 7/50... Batch: 303/303... Train Loss: 0.230... Train Accruacy: 0.636... Val Accuracy: 0.648\n", 626 | "Epoch: 8/50... Batch: 303/303... Train Loss: 0.227... Train Accruacy: 0.641... Val Accuracy: 0.653\n", 627 | "Epoch: 9/50... Batch: 303/303... Train Loss: 0.223... Train Accruacy: 0.646... Val Accuracy: 0.656\n", 628 | "Epoch: 10/50... Batch: 303/303... Train Loss: 0.221... Train Accruacy: 0.652... Val Accuracy: 0.659\n", 629 | "Epoch: 11/50... Batch: 303/303... Train Loss: 0.225... Train Accruacy: 0.656... Val Accuracy: 0.663\n", 630 | "Epoch: 12/50... Batch: 303/303... Train Loss: 0.220... Train Accruacy: 0.661... Val Accuracy: 0.666\n", 631 | "Epoch: 13/50... Batch: 303/303... Train Loss: 0.215... Train Accruacy: 0.665... Val Accuracy: 0.668\n", 632 | "Epoch: 14/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.668... Val Accuracy: 0.670\n", 633 | "Epoch: 15/50... Batch: 303/303... Train Loss: 0.210... Train Accruacy: 0.669... Val Accuracy: 0.673\n", 634 | "Epoch: 16/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.673... Val Accuracy: 0.675\n", 635 | "Epoch: 17/50... Batch: 303/303... Train Loss: 0.212... Train Accruacy: 0.675... Val Accuracy: 0.676\n", 636 | "Epoch: 18/50... Batch: 303/303... Train Loss: 0.206... Train Accruacy: 0.681... Val Accuracy: 0.679\n", 637 | "Epoch: 19/50... Batch: 303/303... Train Loss: 0.208... Train Accruacy: 0.683... Val Accuracy: 0.681\n", 638 | "Epoch: 20/50... Batch: 303/303... Train Loss: 0.202... Train Accruacy: 0.684... Val Accuracy: 0.684\n", 639 | "Epoch: 21/50... Batch: 303/303... Train Loss: 0.206... Train Accruacy: 0.685... Val Accuracy: 0.686\n", 640 | "Epoch: 22/50... Batch: 303/303... Train Loss: 0.204... Train Accruacy: 0.689... Val Accuracy: 0.689\n", 641 | "Epoch: 23/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.690... Val Accuracy: 0.691\n", 642 | "Epoch: 24/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.695... Val Accuracy: 0.692\n", 643 | "Epoch: 25/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.696... Val Accuracy: 0.694\n", 644 | "Epoch: 26/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.698... Val Accuracy: 0.695\n", 645 | "Epoch: 27/50... Batch: 303/303... Train Loss: 0.197... Train Accruacy: 0.701... Val Accuracy: 0.695\n", 646 | "Epoch: 28/50... Batch: 303/303... Train Loss: 0.199... Train Accruacy: 0.703... Val Accuracy: 0.698\n", 647 | "Epoch: 29/50... Batch: 303/303... Train Loss: 0.187... Train Accruacy: 0.704... Val Accuracy: 0.698\n", 648 | "Epoch: 30/50... Batch: 303/303... Train Loss: 0.190... Train Accruacy: 0.708... 
Val Accuracy: 0.701\n", 649 | "Epoch: 31/50... Batch: 303/303... Train Loss: 0.189... Train Accruacy: 0.708... Val Accuracy: 0.702\n", 650 | "Epoch: 32/50... Batch: 303/303... Train Loss: 0.184... Train Accruacy: 0.710... Val Accuracy: 0.704\n", 651 | "Epoch: 33/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.714... Val Accuracy: 0.704\n", 652 | "Epoch: 34/50... Batch: 303/303... Train Loss: 0.190... Train Accruacy: 0.715... Val Accuracy: 0.704\n", 653 | "Epoch: 35/50... Batch: 303/303... Train Loss: 0.186... Train Accruacy: 0.714... Val Accuracy: 0.707\n", 654 | "Epoch: 36/50... Batch: 303/303... Train Loss: 0.178... Train Accruacy: 0.717... Val Accuracy: 0.707\n", 655 | "Epoch: 37/50... Batch: 303/303... Train Loss: 0.183... Train Accruacy: 0.722... Val Accuracy: 0.707\n", 656 | "Epoch: 38/50... Batch: 303/303... Train Loss: 0.181... Train Accruacy: 0.721... Val Accuracy: 0.710\n", 657 | "Epoch: 39/50... Batch: 303/303... Train Loss: 0.181... Train Accruacy: 0.723... Val Accuracy: 0.712\n", 658 | "Epoch: 40/50... Batch: 303/303... Train Loss: 0.179... Train Accruacy: 0.726... Val Accuracy: 0.712\n", 659 | "Epoch: 41/50... Batch: 303/303... Train Loss: 0.180... Train Accruacy: 0.726... Val Accuracy: 0.713\n", 660 | "Epoch: 42/50... Batch: 303/303... Train Loss: 0.177... Train Accruacy: 0.729... Val Accuracy: 0.714\n", 661 | "Epoch: 43/50... Batch: 303/303... Train Loss: 0.176... Train Accruacy: 0.731... Val Accuracy: 0.714\n", 662 | "Epoch: 44/50... Batch: 303/303... Train Loss: 0.180... Train Accruacy: 0.732... Val Accuracy: 0.716\n", 663 | "Epoch: 45/50... Batch: 303/303... Train Loss: 0.169... Train Accruacy: 0.734... Val Accuracy: 0.716\n", 664 | "Epoch: 46/50... Batch: 303/303... Train Loss: 0.173... Train Accruacy: 0.735... Val Accuracy: 0.717\n", 665 | "Epoch: 47/50... Batch: 303/303... Train Loss: 0.170... Train Accruacy: 0.736... Val Accuracy: 0.717\n", 666 | "Epoch: 48/50... Batch: 303/303... Train Loss: 0.173... Train Accruacy: 0.739... Val Accuracy: 0.718\n", 667 | "Epoch: 49/50... Batch: 303/303... Train Loss: 0.175... Train Accruacy: 0.740... Val Accuracy: 0.717\n", 668 | "Epoch: 50/50... Batch: 303/303... Train Loss: 0.175... Train Accruacy: 0.745... Val Accuracy: 0.718\n" 669 | ] 670 | } 671 | ], 672 | "source": [ 673 | "with tf.Graph().as_default():\n", 674 | " build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 675 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y)" 676 | ] 677 | }, 678 | { 679 | "cell_type": "markdown", 680 | "metadata": {}, 681 | "source": [ 682 | "## Testing our Network" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": {}, 688 | "source": [ 689 | "The last thing we want to do is check the model accuracy on our testing data to make sure it is in line with expecations. We build the Computational Graph just like we did before, however, now instead of training we restore our saved model from our checkpoint directory and then run our test data through the model. 
" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 18, 695 | "metadata": { 696 | "collapsed": true 697 | }, 698 | "outputs": [], 699 | "source": [ 700 | "def test_network(model_dir, batch_size, test_x, test_y):\n", 701 | " \n", 702 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 703 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 704 | " initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)\n", 705 | " predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)\n", 706 | " accuracy = build_accuracy(predictions, labels_)\n", 707 | " \n", 708 | " saver = tf.train.Saver()\n", 709 | " \n", 710 | " test_acc = []\n", 711 | " with tf.Session() as sess:\n", 712 | " saver.restore(sess, tf.train.latest_checkpoint(model_dir))\n", 713 | " test_state = sess.run(lstm_cell.zero_state(batch_size, tf.float32))\n", 714 | " for ii, (x, y) in enumerate(utl.get_batches(test_x, test_y, batch_size), 1):\n", 715 | " feed = {inputs_: x,\n", 716 | " labels_: y[:, None],\n", 717 | " keep_prob_: 1,\n", 718 | " initial_state: test_state}\n", 719 | " batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed)\n", 720 | " test_acc.append(batch_acc)\n", 721 | " print(\"Test Accuracy: {:.3f}\".format(np.mean(test_acc)))" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 19, 727 | "metadata": {}, 728 | "outputs": [ 729 | { 730 | "name": "stdout", 731 | "output_type": "stream", 732 | "text": [ 733 | "INFO:tensorflow:Restoring parameters from checkpoints/sentiment.ckpt\n", 734 | "Test Accuracy: 0.717\n" 735 | ] 736 | } 737 | ], 738 | "source": [ 739 | "with tf.Graph().as_default():\n", 740 | " test_network('checkpoints', batch_size, test_x, test_y)" 741 | ] 742 | } 743 | ], 744 | "metadata": { 745 | "anaconda-cloud": {}, 746 | "kernelspec": { 747 | "display_name": "Python 3", 748 | "language": "python", 749 | "name": "python3" 750 | }, 751 | "language_info": { 752 | "codemirror_mode": { 753 | "name": "ipython", 754 | "version": 3 755 | }, 756 | "file_extension": ".py", 757 | "mimetype": "text/x-python", 758 | "name": "python", 759 | "nbconvert_exporter": "python", 760 | "pygments_lexer": "ipython3", 761 | "version": "3.6.8" 762 | } 763 | }, 764 | "nbformat": 4, 765 | "nbformat_minor": 1 766 | } 767 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/02_BiLSTM_N2N_TF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Modeling Stock Market Sentiment with BiLSTMs and TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial, we will build a Bidirectional Long Short Term Memory (BiLSTM) Network to predict the stock market sentiment based on a comment about the market." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "We will use the following libraries for our analysis:\n", 29 | "\n", 30 | "* numpy - numerical computing library used to work with our data\n", 31 | "* pandas - data analysis library used to read in our data from csv\n", 32 | "* tensorflow - deep learning framework used for modeling" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "We will also be using the python Counter object for counting our vocabulary items and we have a util module that extracts away a lot of the details of our data processing. Please read through the util.py to get a better understanding of how to preprocess the data for analysis." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import numpy as np\n", 49 | "import pandas as pd\n", 50 | "import tensorflow as tf\n", 51 | "import utils as utl\n", 52 | "from collections import Counter" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 58 | "collapsed": true 59 | }, 60 | "source": [ 61 | "## Processing Data" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "We will train the model using messages tagged with SPY, the S&P 500 index fund, from [StockTwits.com](https://www.stocktwits.com). StockTwits is a social media network for traders and investors to share their views about the stock market. When a user posts a message, they tag the relevant stock ticker ($SPY in our case) and have the option to tag the messages with their sentiment – “bullish” if they believe the stock will go up and “bearish” if they believe the stock will go down.\n", 69 | "\n", 70 | "Our dataset consists of approximately 100,000 messages posted in 2017 that are tagged with $SPY where the user indicated their sentiment. Before we get to our LSTM Network we have to perform some processing on our data to get it ready for modeling." 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "#### Read and View Data" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "First we simply read in our data using pandas, pull out our message and sentiment data into numpy arrays. Let's also take a look at a few samples to get familiar with the data set." 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 2, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Messages: $SPY crazy day so far!... Sentiment: bearish\n", 97 | "Messages: $SPY Will make a new ATH this week. Watch it!... Sentiment: bullish\n", 98 | "Messages: $SPY $DJIA white elephant in room is $AAPL. Up 14% since election. Strong headwinds w/Trump trade & Strong dollar. How many 7's do you see?... Sentiment: bearish\n", 99 | "Messages: $SPY blocks above. We break above them We should push to double top... Sentiment: bullish\n", 100 | "Messages: $SPY Nothing happening in the market today, guess I'll go to the store and spend some $.... Sentiment: bearish\n", 101 | "Messages: $SPY What an easy call. Good jobs report: good economy, markets go up. Bad jobs report: no more rate hikes, markets go up. Win-win.... Sentiment: bullish\n", 102 | "Messages: $SPY BS market.... 
Sentiment: bullish\n", 103 | "Messages: $SPY this rally all the cheerleaders were screaming about this morning is pretty weak. I keep adding 2 my short at all spikes... Sentiment: bearish\n", 104 | "Messages: $SPY Dollar ripping higher!... Sentiment: bearish\n", 105 | "Messages: $SPY no reason to go down !... Sentiment: bullish\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "# read data from csv file\n", 111 | "data = pd.read_csv(\"data/StockTwits_SPY_Sentiment_2017.gz\",\n", 112 | " encoding=\"utf-8\",\n", 113 | " compression=\"gzip\",\n", 114 | " index_col=0)\n", 115 | "\n", 116 | "# get messages and sentiment labels\n", 117 | "messages = data.message.values\n", 118 | "labels = data.sentiment.values\n", 119 | "\n", 120 | "# View sample of messages with sentiment\n", 121 | "\n", 122 | "for i in range(10):\n", 123 | " print(\"Messages: {}...\".format(messages[i]),\n", 124 | " \"Sentiment: {}\".format(labels[i]))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "#### Preprocess Messages" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "Working with raw text data often requires preprocessing the text in some fashion to normalize for context. In our case we want to normalize for known unique \"entities\" that appear within messages that carry a similar contextual meaning when analyzing sentiment. This means we want to replace references to specific stock tickers, user names, url links or numbers with a special token identifying the \"entity\". Here we will also make everything lower case and remove punctuation." 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 3, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "messages = np.array([utl.preprocess_ST_message(message) for message in messages])" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "#### Generate Vocab to Index Mapping" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "To work with raw text we need some encoding from words to numbers for our algorithm to work with the inputs. The first step of doing this is keeping a collection of our full vocabularly and creating a mapping of each word to a unique index. We will use this word to index mapping in a little bit to prep out messages for analysis. \n", 162 | "\n", 163 | "Note that in practice we may want to only include the vocabularly from our training set here to account for the fact that we will likely see new words when our model is out in the wild when we are assessing the results on our validation and test sets. Here, for simplicity and demonstration purposes, we will use our entire data set." 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 4, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "full_lexicon = \" \".join(messages).split()\n", 173 | "vocab_to_int, int_to_vocab = utl.create_lookup_tables(full_lexicon)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "#### Check Message Lengths" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "We will also want to get a sense of the distribution of the length of our inputs. We check for the longest and average messages. 
We will need to make our input length uniform to feed the data into our model so later we will have some decisions to make about possibly truncating some of the longer messages if they are too long. We also notice that one message has no content remaining after we preprocessed the data, so we will remove this message from our data set." 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 5, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "Zero-length messages: 1\n", 200 | "Maximum message length: 244\n", 201 | "Average message length: 78.21856920395598\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "messages_lens = Counter([len(x) for x in messages])\n", 207 | "print(\"Zero-length messages: {}\".format(messages_lens[0]))\n", 208 | "print(\"Maximum message length: {}\".format(max(messages_lens)))\n", 209 | "print(\"Average message length: {}\".format(np.mean([len(x) for x in messages])))" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 6, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "messages, labels = utl.drop_empty_messages(messages, labels)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "#### Encode Messages and Labels" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "Earlier we mentioned that we need to \"translate\" our text to number for our algorithm to take in as inputs. We call this translation an encoding. We encode our messages to sequences of numbers where each nummber is the word index from the mapping we made earlier. The phrase \"I am bullish\" would now look something like [1, 234, 5345] where each number is the index for the respective word in the message. For our sentiment labels we will simply encode \"bearish\" as 0 and \"bullish\" as 1." 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 7, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "messages = utl.encode_ST_messages(messages, vocab_to_int)\n", 242 | "labels = utl.encode_ST_labels(labels)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "#### Pad Messages" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "The last thing we need to do is make our message inputs the same length. In our case, the longest message is 244 words. LSTMs can usually handle sequence inputs up to 500 items in length so we won't truncate any of the messages here. We need to Zero Pad the rest of the messages that are shorter. We will use a left padding that will pad all of the messages that are shorter than 244 words with 0s at the beginning. So our encoded \"I am bullish\" messages goes from [1, 234, 5345] (length 3) to [0, 0, 0, 0, 0, 0, ... , 0, 0, 1, 234, 5345] (length 244)." 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "messages = utl.zero_pad_messages(messages, seq_len=244)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "#### Train, Test, Validation Split" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "The last thing we do is split our data into tranining, validation and test sets and observe the size of each set." 
280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 9, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "Data Set Size\n", 292 | "Train set: \t\t(77572, 244) \n", 293 | "Validation set: \t(9697, 244) \n", 294 | "Test set: \t\t(9697, 244)\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "train_x, val_x, test_x, train_y, val_y, test_y = utl.train_val_test_split(messages, labels, split_frac=0.80)\n", 300 | "\n", 301 | "print(\"Data Set Size\")\n", 302 | "print(\"Train set: \\t\\t{}\".format(train_x.shape), \n", 303 | " \"\\nValidation set: \\t{}\".format(val_x.shape),\n", 304 | " \"\\nTest set: \\t\\t{}\".format(test_x.shape))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Building and Training our BiLSTM Network" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "In this section we will define a number of functions that will construct the items in our network. We will then use these functions to build and train our network." 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "#### Model Inputs" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "Here we simply define a function to build TensorFlow Placeholders for our message sequences, our labels and a variable called keep probability associated with drop out (we will talk more about this later). " 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 10, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "def model_inputs():\n", 342 | " \"\"\"\n", 343 | " Create the model inputs\n", 344 | " \"\"\"\n", 345 | " inputs_ = tf.placeholder(tf.int32, [None, None], name='inputs')\n", 346 | " labels_ = tf.placeholder(tf.int32, [None, None], name='labels')\n", 347 | " keep_prob_ = tf.placeholder(tf.float32, name='keep_prob')\n", 348 | " \n", 349 | " return inputs_, labels_, keep_prob_" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "#### Embedding Layer" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "In TensorFlow the word embeddings are represented as a vocabulary size x embedding size matrix and will learn these weights during our training process. The embedding lookup is then just a simple lookup from our embedding matrix based on the index of the current word." 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 11, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "def build_embedding_layer(inputs_, vocab_size, embed_size):\n", 373 | " \"\"\"\n", 374 | " Create the embedding layer\n", 375 | " \"\"\"\n", 376 | " embedding = tf.Variable(tf.random_uniform((vocab_size, embed_size), -1, 1))\n", 377 | " embed = tf.nn.embedding_lookup(embedding, inputs_)\n", 378 | " \n", 379 | " return embed" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "#### BiLSTM Layers" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "TensorFlow makes it extremely easy to build BiLSTM Layers. We represent each LSTM (foward and backward) as a BasicLSTMCell.\n", 394 | "\n", 395 | "We then take each of these LSTM layers and wrap them in a Dropout Layer. 
Dropout is a regularization technique using in Neural Networks in which any individual node has a probability of “dropping out” of the network during a given iteration of learning. The makes the model more generalizable by ensuring that it is not too dependent on any given nodes. \n", 396 | "\n", 397 | "Finally, we generate a zero initial state and connect our BiLSTM layer to our word embedding inputs using **bidirectional_dynamic_rnn**. Here we track the output and the final state of the LSTM cell, which we will need to pass between mini-batches during training. Finally we concatenate the output of the forward and backward LSTMs layers to get our final output." 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 12, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "def build_bilstm_layer(lstm_size, embed, keep_prob_, batch_size):\n", 407 | " \"\"\"\n", 408 | " Create the LSTM layers\n", 409 | " \"\"\"\n", 410 | " lstms_fwd = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n", 411 | " lstms_bwd = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n", 412 | " # Add dropout to the cell\n", 413 | " cell_fwd = tf.contrib.rnn.DropoutWrapper(lstms_fwd, output_keep_prob=keep_prob_)\n", 414 | " cell_bwd = tf.contrib.rnn.DropoutWrapper(lstms_bwd, output_keep_prob=keep_prob_)\n", 415 | " \n", 416 | " initial_state_fwd = cell_fwd.zero_state(batch_size, tf.float32)\n", 417 | " initial_state_bwd = cell_bwd.zero_state(batch_size, tf.float32)\n", 418 | "\n", 419 | " (output_fw, output_bw), (final_state_fwd, final_state_bwd) = tf.nn.bidirectional_dynamic_rnn(cell_fwd, cell_bwd, embed,\n", 420 | " initial_state_fw=initial_state_fwd,\n", 421 | " initial_state_bw=initial_state_bwd)\n", 422 | "\n", 423 | " bi_lstm_output = tf.concat([output_fw, output_bw], axis=2)\n", 424 | "\n", 425 | " return initial_state_fwd, initial_state_bwd, bi_lstm_output, cell_fwd, cell_bwd, final_state_fwd, final_state_bwd" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "#### Loss Function and Optimizer" 433 | ] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": {}, 438 | "source": [ 439 | "First, we get our predictions by passing the final output of the LSTM layers to a sigmoid activation function via a Tensorflow fully connected layer. we only care to use the final output for making predictions so we pull this out using the [: , -1] indexing on our LSTM outputs and pass it through a sigmoid activation function to make the predictions. We pass then pass these predictions to our mean squared error loss function and use the Adadelta Optimizer to minimize the loss." 
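Since the forward and backward directions each produce `lstm_size` features per time step, the concatenation doubles the feature dimension. A quick NumPy shape check (random arrays standing in for the real outputs of `bidirectional_dynamic_rnn`) makes this explicit.

```python
import numpy as np

# NumPy stand-in for tf.concat([output_fw, output_bw], axis=2): each direction
# outputs [batch_size, seq_len, lstm_size], so the concatenated BiLSTM output
# carries 2 * lstm_size features per time step.
batch_size, seq_len, lstm_size = 4, 244, 128
output_fw = np.random.randn(batch_size, seq_len, lstm_size)
output_bw = np.random.randn(batch_size, seq_len, lstm_size)

bi_lstm_output = np.concatenate([output_fw, output_bw], axis=2)
print(bi_lstm_output.shape)  # (4, 244, 256)
```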
440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 13, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "def build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate):\n", 449 | " \"\"\"\n", 450 | " Create the Loss function and Optimizer\n", 451 | " \"\"\"\n", 452 | " predictions = tf.contrib.layers.fully_connected(lstm_outputs[:, -1], 1, activation_fn=tf.sigmoid)\n", 453 | " loss = tf.losses.mean_squared_error(labels_, predictions)\n", 454 | " optimzer = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss)\n", 455 | " \n", 456 | " return predictions, loss, optimzer" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "#### Accuracy" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "Finally, we define our accuracy metric for assessing the model performance across our training, validation and test sets. Even though accuracy is just a calculation based on results, everything in TensorFlow is part of a Computation Graph. Therefore, we need to define our loss and accuracy nodes in the context of the rest of our network graph." 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 14, 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "def build_accuracy(predictions, labels_):\n", 480 | " \"\"\"\n", 481 | " Create accuracy\n", 482 | " \"\"\"\n", 483 | " correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)\n", 484 | " accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n", 485 | " \n", 486 | " return accuracy" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "#### Training" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "We are finally ready to build and train our LSTM Network! First, we call each of our each of the functions we have defined to construct the network. Then we define a Saver to be able to write our model to disk to load for future use. Finally, we call a Tensorflow Session to train the model over a predefined number of epochs using mini-batches. At the end of each epoch we will print the loss, training accuracy and validation accuracy to monitor the results as we train." 
501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 15, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "def build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 510 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y):\n", 511 | " \n", 512 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 513 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 514 | " initial_state_fwd, initial_state_bwd, bi_lstm_outputs, cell_fwd, cell_bwd, final_state_fwd, final_state_bwd = build_bilstm_layer(lstm_sizes, embed, keep_prob_, batch_size)\n", 515 | " predictions, loss, optimizer = build_cost_fn_and_opt(bi_lstm_outputs, labels_, learning_rate)\n", 516 | " accuracy = build_accuracy(predictions, labels_)\n", 517 | " \n", 518 | " saver = tf.train.Saver()\n", 519 | " \n", 520 | " with tf.Session() as sess:\n", 521 | " \n", 522 | " sess.run(tf.global_variables_initializer())\n", 523 | " n_batches = len(train_x)//batch_size\n", 524 | " for e in range(epochs):\n", 525 | " state_fwd = sess.run(initial_state_fwd)\n", 526 | " state_bwd = sess.run(initial_state_bwd)\n", 527 | " \n", 528 | " train_acc = []\n", 529 | " for ii, (x, y) in enumerate(utl.get_batches(train_x, train_y, batch_size), 1):\n", 530 | " feed = {inputs_: x,\n", 531 | " labels_: y[:, None],\n", 532 | " keep_prob_: keep_prob,\n", 533 | " initial_state_fwd: state_fwd,\n", 534 | " initial_state_bwd: state_bwd}\n", 535 | " loss_, state_fwd, state_bwd, _, batch_acc = sess.run([loss, final_state_fwd, final_state_bwd, optimizer, accuracy], feed_dict=feed)\n", 536 | " train_acc.append(batch_acc)\n", 537 | " \n", 538 | " if (ii + 1) % n_batches == 0:\n", 539 | " \n", 540 | " val_acc = []\n", 541 | " val_state_fwd = sess.run(cell_fwd.zero_state(batch_size, tf.float32))\n", 542 | " val_state_bwd = sess.run(cell_bwd.zero_state(batch_size, tf.float32))\n", 543 | " for xx, yy in utl.get_batches(val_x, val_y, batch_size):\n", 544 | " feed = {inputs_: xx,\n", 545 | " labels_: yy[:, None],\n", 546 | " keep_prob_: 1,\n", 547 | " initial_state_fwd: val_state_fwd,\n", 548 | " initial_state_bwd: val_state_bwd}\n", 549 | " val_batch_acc, val_state_fwd, val_state_bwd = sess.run([accuracy, final_state_fwd, final_state_bwd], feed_dict=feed)\n", 550 | " val_acc.append(val_batch_acc)\n", 551 | " \n", 552 | " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", 553 | " \"Batch: {}/{}...\".format(ii+1, n_batches),\n", 554 | " \"Train Loss: {:.3f}...\".format(loss_),\n", 555 | " \"Train Accruacy: {:.3f}...\".format(np.mean(train_acc)),\n", 556 | " \"Val Accuracy: {:.3f}\".format(np.mean(val_acc)))\n", 557 | " \n", 558 | " saver.save(sess, \"checkpoints/sentiment.ckpt\")" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "Next we define our model hyper parameters. We will build a single layer BiLSTM Newtork with hidden layer size of 128. We will use an embedding size of 300 and train over 50 epochs with mini-batches of size 256. We will use an initial learning rate of 0.1, though our Adadelta Optimizer will adapt this over time, and a keep probability of 0.5. 
" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 16, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "# Define Inputs and Hyperparameters\n", 575 | "lstm_sizes = 128\n", 576 | "vocab_size = len(vocab_to_int) + 1 #add one for padding\n", 577 | "embed_size = 300\n", 578 | "epochs = 50\n", 579 | "batch_size = 256\n", 580 | "learning_rate = 0.1\n", 581 | "keep_prob = 0.5" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "metadata": {}, 587 | "source": [ 588 | "and now we train!" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 17, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "tf.reset_default_graph()" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 18, 603 | "metadata": { 604 | "scrolled": false 605 | }, 606 | "outputs": [ 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | "Epoch: 1/50... Batch: 303/303... Train Loss: 0.254... Train Accruacy: 0.549... Val Accuracy: 0.587\n", 612 | "Epoch: 2/50... Batch: 303/303... Train Loss: 0.249... Train Accruacy: 0.575... Val Accuracy: 0.612\n", 613 | "Epoch: 3/50... Batch: 303/303... Train Loss: 0.244... Train Accruacy: 0.596... Val Accuracy: 0.624\n", 614 | "Epoch: 4/50... Batch: 303/303... Train Loss: 0.248... Train Accruacy: 0.614... Val Accuracy: 0.636\n", 615 | "Epoch: 5/50... Batch: 303/303... Train Loss: 0.238... Train Accruacy: 0.630... Val Accuracy: 0.643\n", 616 | "Epoch: 6/50... Batch: 303/303... Train Loss: 0.234... Train Accruacy: 0.639... Val Accuracy: 0.651\n", 617 | "Epoch: 7/50... Batch: 303/303... Train Loss: 0.227... Train Accruacy: 0.651... Val Accuracy: 0.656\n", 618 | "Epoch: 8/50... Batch: 303/303... Train Loss: 0.224... Train Accruacy: 0.659... Val Accuracy: 0.658\n", 619 | "Epoch: 9/50... Batch: 303/303... Train Loss: 0.220... Train Accruacy: 0.665... Val Accuracy: 0.663\n", 620 | "Epoch: 10/50... Batch: 303/303... Train Loss: 0.219... Train Accruacy: 0.669... Val Accuracy: 0.665\n", 621 | "Epoch: 11/50... Batch: 303/303... Train Loss: 0.221... Train Accruacy: 0.672... Val Accuracy: 0.667\n", 622 | "Epoch: 12/50... Batch: 303/303... Train Loss: 0.220... Train Accruacy: 0.675... Val Accuracy: 0.671\n", 623 | "Epoch: 13/50... Batch: 303/303... Train Loss: 0.224... Train Accruacy: 0.680... Val Accuracy: 0.677\n", 624 | "Epoch: 14/50... Batch: 303/303... Train Loss: 0.226... Train Accruacy: 0.683... Val Accuracy: 0.680\n", 625 | "Epoch: 15/50... Batch: 303/303... Train Loss: 0.214... Train Accruacy: 0.685... Val Accuracy: 0.680\n", 626 | "Epoch: 16/50... Batch: 303/303... Train Loss: 0.219... Train Accruacy: 0.689... Val Accuracy: 0.683\n", 627 | "Epoch: 17/50... Batch: 303/303... Train Loss: 0.216... Train Accruacy: 0.690... Val Accuracy: 0.684\n", 628 | "Epoch: 18/50... Batch: 303/303... Train Loss: 0.216... Train Accruacy: 0.693... Val Accuracy: 0.687\n", 629 | "Epoch: 19/50... Batch: 303/303... Train Loss: 0.211... Train Accruacy: 0.695... Val Accuracy: 0.689\n", 630 | "Epoch: 20/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.696... Val Accuracy: 0.689\n", 631 | "Epoch: 21/50... Batch: 303/303... Train Loss: 0.218... Train Accruacy: 0.700... Val Accuracy: 0.692\n", 632 | "Epoch: 22/50... Batch: 303/303... Train Loss: 0.214... Train Accruacy: 0.700... Val Accuracy: 0.693\n", 633 | "Epoch: 23/50... Batch: 303/303... Train Loss: 0.211... Train Accruacy: 0.703... Val Accuracy: 0.695\n", 634 | "Epoch: 24/50... Batch: 303/303... Train Loss: 0.207... 
Train Accruacy: 0.704... Val Accuracy: 0.695\n", 635 | "Epoch: 25/50... Batch: 303/303... Train Loss: 0.210... Train Accruacy: 0.706... Val Accuracy: 0.696\n", 636 | "Epoch: 26/50... Batch: 303/303... Train Loss: 0.208... Train Accruacy: 0.707... Val Accuracy: 0.698\n", 637 | "Epoch: 27/50... Batch: 303/303... Train Loss: 0.213... Train Accruacy: 0.709... Val Accuracy: 0.699\n", 638 | "Epoch: 28/50... Batch: 303/303... Train Loss: 0.202... Train Accruacy: 0.710... Val Accuracy: 0.701\n", 639 | "Epoch: 29/50... Batch: 303/303... Train Loss: 0.201... Train Accruacy: 0.711... Val Accuracy: 0.701\n", 640 | "Epoch: 30/50... Batch: 303/303... Train Loss: 0.207... Train Accruacy: 0.714... Val Accuracy: 0.704\n", 641 | "Epoch: 31/50... Batch: 303/303... Train Loss: 0.202... Train Accruacy: 0.716... Val Accuracy: 0.704\n", 642 | "Epoch: 32/50... Batch: 303/303... Train Loss: 0.204... Train Accruacy: 0.717... Val Accuracy: 0.706\n", 643 | "Epoch: 33/50... Batch: 303/303... Train Loss: 0.198... Train Accruacy: 0.719... Val Accuracy: 0.707\n", 644 | "Epoch: 34/50... Batch: 303/303... Train Loss: 0.202... Train Accruacy: 0.721... Val Accuracy: 0.707\n", 645 | "Epoch: 35/50... Batch: 303/303... Train Loss: 0.197... Train Accruacy: 0.723... Val Accuracy: 0.709\n", 646 | "Epoch: 36/50... Batch: 303/303... Train Loss: 0.193... Train Accruacy: 0.723... Val Accuracy: 0.710\n", 647 | "Epoch: 37/50... Batch: 303/303... Train Loss: 0.197... Train Accruacy: 0.725... Val Accuracy: 0.711\n", 648 | "Epoch: 38/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.726... Val Accuracy: 0.712\n", 649 | "Epoch: 39/50... Batch: 303/303... Train Loss: 0.197... Train Accruacy: 0.730... Val Accuracy: 0.713\n", 650 | "Epoch: 40/50... Batch: 303/303... Train Loss: 0.200... Train Accruacy: 0.730... Val Accuracy: 0.714\n", 651 | "Epoch: 41/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.731... Val Accuracy: 0.714\n", 652 | "Epoch: 42/50... Batch: 303/303... Train Loss: 0.196... Train Accruacy: 0.733... Val Accuracy: 0.715\n", 653 | "Epoch: 43/50... Batch: 303/303... Train Loss: 0.196... Train Accruacy: 0.734... Val Accuracy: 0.716\n", 654 | "Epoch: 44/50... Batch: 303/303... Train Loss: 0.194... Train Accruacy: 0.737... Val Accuracy: 0.716\n", 655 | "Epoch: 45/50... Batch: 303/303... Train Loss: 0.199... Train Accruacy: 0.738... Val Accuracy: 0.719\n", 656 | "Epoch: 46/50... Batch: 303/303... Train Loss: 0.192... Train Accruacy: 0.738... Val Accuracy: 0.719\n", 657 | "Epoch: 47/50... Batch: 303/303... Train Loss: 0.195... Train Accruacy: 0.741... Val Accuracy: 0.720\n", 658 | "Epoch: 48/50... Batch: 303/303... Train Loss: 0.185... Train Accruacy: 0.742... Val Accuracy: 0.721\n", 659 | "Epoch: 49/50... Batch: 303/303... Train Loss: 0.181... Train Accruacy: 0.745... Val Accuracy: 0.722\n", 660 | "Epoch: 50/50... Batch: 303/303... Train Loss: 0.191... Train Accruacy: 0.744... Val Accuracy: 0.723\n" 661 | ] 662 | } 663 | ], 664 | "source": [ 665 | "with tf.Graph().as_default():\n", 666 | " build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,\n", 667 | " learning_rate, keep_prob, train_x, val_x, train_y, val_y)" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "metadata": {}, 673 | "source": [ 674 | "## Testing our Network" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "metadata": {}, 680 | "source": [ 681 | "The last thing we want to do is check the model accuracy on our testing data to make sure it is in line with expecations. 
We build the Computational Graph just like we did before, however, now instead of training we restore our saved model from our checkpoint directory and then run our test data through the model. " 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 21, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "def test_network(model_dir, batch_size, test_x, test_y):\n", 691 | " \n", 692 | " inputs_, labels_, keep_prob_ = model_inputs()\n", 693 | " embed = build_embedding_layer(inputs_, vocab_size, embed_size)\n", 694 | " initial_state_fwd, initial_state_bwd, bi_lstm_outputs, cell_fwd, cell_bwd, final_state_fwd, final_state_bwd = build_bilstm_layer(lstm_sizes, embed, keep_prob_, batch_size)\n", 695 | " predictions, loss, optimizer = build_cost_fn_and_opt(bi_lstm_outputs, labels_, learning_rate)\n", 696 | " accuracy = build_accuracy(predictions, labels_)\n", 697 | " \n", 698 | " saver = tf.train.Saver()\n", 699 | " \n", 700 | " test_acc = []\n", 701 | " with tf.Session() as sess:\n", 702 | " saver.restore(sess, tf.train.latest_checkpoint(model_dir))\n", 703 | " test_state_fwd = sess.run(cell_fwd.zero_state(batch_size, tf.float32))\n", 704 | " test_state_bwd = sess.run(cell_bwd.zero_state(batch_size, tf.float32))\n", 705 | " for ii, (x, y) in enumerate(utl.get_batches(test_x, test_y, batch_size), 1):\n", 706 | " feed = {inputs_: x,\n", 707 | " labels_: y[:, None],\n", 708 | " keep_prob_: 1,\n", 709 | " initial_state_fwd: test_state_fwd,\n", 710 | " initial_state_bwd: test_state_bwd}\n", 711 | " batch_acc, test_state_fwd, test_state_bwd = sess.run([accuracy, final_state_fwd, final_state_bwd], feed_dict=feed)\n", 712 | " test_acc.append(batch_acc)\n", 713 | " print(\"Test Accuracy: {:.3f}\".format(np.mean(test_acc)))" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": 22, 719 | "metadata": {}, 720 | "outputs": [ 721 | { 722 | "name": "stdout", 723 | "output_type": "stream", 724 | "text": [ 725 | "INFO:tensorflow:Restoring parameters from checkpoints/sentiment.ckpt\n", 726 | "Test Accuracy: 0.725\n" 727 | ] 728 | } 729 | ], 730 | "source": [ 731 | "with tf.Graph().as_default():\n", 732 | " test_network('checkpoints', batch_size, test_x, test_y)" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [] 741 | } 742 | ], 743 | "metadata": { 744 | "anaconda-cloud": {}, 745 | "kernelspec": { 746 | "display_name": "Python 3", 747 | "language": "python", 748 | "name": "python3" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.5.2" 761 | } 762 | }, 763 | "nbformat": 4, 764 | "nbformat_minor": 1 765 | } 766 | -------------------------------------------------------------------------------- /strata-2019-dl-for-nlp/04_ULMFiT_fastai.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "StockTwits Sentiment with ULMFiT", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "metadata": { 20 | "id": "vwmMXgxIz2dZ", 21 | 
"colab_type": "text" 22 | }, 23 | "cell_type": "markdown", 24 | "source": [ 25 | "\n", 26 | "# Modeling Stock Market Sentiment with ULMFiT\n", 27 | "\n", 28 | "In this tutorial, we will fine-tune a ULMFiT model to predict the stock market sentiment based on a comment about the market. \n" 29 | ] 30 | }, 31 | { 32 | "metadata": { 33 | "id": "c_G95Xix0iTS", 34 | "colab_type": "text" 35 | }, 36 | "cell_type": "markdown", 37 | "source": [ 38 | "## Setup\n", 39 | "\n", 40 | "First we import the necesary libraries for our modeling.\n" 41 | ] 42 | }, 43 | { 44 | "metadata": { 45 | "id": "yA8zetPccyoM", 46 | "colab_type": "code", 47 | "outputId": "7b602f98-d015-4ed1-af77-da2ea712b6ba", 48 | "colab": { 49 | "base_uri": "https://localhost:8080/", 50 | "height": 153 51 | } 52 | }, 53 | "cell_type": "code", 54 | "source": [ 55 | "# Install PyTorch 1.0 that works with Colab\n", 56 | "!pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl" 57 | ], 58 | "execution_count": 0, 59 | "outputs": [ 60 | { 61 | "output_type": "stream", 62 | "text": [ 63 | "Collecting torch==1.0.0 from https://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl\n", 64 | "\u001b[?25l Downloading https://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl (532.5MB)\n", 65 | "\u001b[K 100% |████████████████████████████████| 532.5MB 32kB/s \n", 66 | "\u001b[?25hInstalling collected packages: torch\n", 67 | " Found existing installation: torch 1.0.1.post2\n", 68 | " Uninstalling torch-1.0.1.post2:\n", 69 | " Successfully uninstalled torch-1.0.1.post2\n", 70 | "Successfully installed torch-1.0.0\n" 71 | ], 72 | "name": "stdout" 73 | } 74 | ] 75 | }, 76 | { 77 | "metadata": { 78 | "id": "5UdbcNQsz0AH", 79 | "colab_type": "code", 80 | "colab": {} 81 | }, 82 | "cell_type": "code", 83 | "source": [ 84 | "# import necesary libaries, classes and functions\n", 85 | "\n", 86 | "import pandas as pd\n", 87 | "import numpy as np\n", 88 | "from fastai.text.data import TextLMDataBunch, TextClasDataBunch\n", 89 | "from fastai.text.learner import language_model_learner, text_classifier_learner\n", 90 | "from fastai.text.models import AWD_LSTM\n", 91 | "from sklearn.model_selection import train_test_split" 92 | ], 93 | "execution_count": 0, 94 | "outputs": [] 95 | }, 96 | { 97 | "metadata": { 98 | "id": "ejt1Si7ZynFS", 99 | "colab_type": "text" 100 | }, 101 | "cell_type": "markdown", 102 | "source": [ 103 | "## Processing Data\n", 104 | "\n", 105 | "We will train the model using messages tagged with SPY, the S&P 500 index fund, from StockTwits.com. StockTwits is a social media network for traders and investors to share their views about the stock market. When a user posts a message, they tag the relevant stock ticker ($SPY in our case) and have the option to tag the messages with their sentiment – “bullish” if they believe the stock will go up and “bearish” if they believe the stock will go down.\n", 106 | "\n", 107 | "Our dataset consists of approximately 100,000 messages posted in 2017 that are tagged with $SPY where the user indicated their sentiment. Before we get to our LSTM Network we have to perform some processing on our data to get it ready for modeling.\n" 108 | ] 109 | }, 110 | { 111 | "metadata": { 112 | "id": "xA5mgkVL_9w-", 113 | "colab_type": "text" 114 | }, 115 | "cell_type": "markdown", 116 | "source": [ 117 | "#### Read and View Data\n", 118 | "\n", 119 | "First we simply read in our data using pandas, pull out our message and sentiment data into numpy arrays. 
Let's also take a look at a few samples to get familiar with the data set." 120 | ] 121 | }, 122 | { 123 | "metadata": { 124 | "id": "Jijw947ruc49", 125 | "colab_type": "code", 126 | "outputId": "2bb13e02-e5fe-49db-ba7d-90fae927bf74", 127 | "colab": { 128 | "base_uri": "https://localhost:8080/", 129 | "height": 207 130 | } 131 | }, 132 | "cell_type": "code", 133 | "source": [ 134 | "# read our data directly from github\n", 135 | "\n", 136 | "data_url = 'https://github.com/GarrettHoffman/AI_Conf_2019_DL_4_NLP/blob/master/data/StockTwits_SPY_Sentiment_2017.gz?raw=true'\n", 137 | "data = df = pd.read_csv(data_url,\n", 138 | " encoding=\"utf-8\",\n", 139 | " compression='gzip', \n", 140 | " index_col=0)\n", 141 | "\n", 142 | "# get messages and sentiment labels\n", 143 | "messages = data.message.values\n", 144 | "labels = data.sentiment.values\n", 145 | "\n", 146 | "# View sample of messages with sentiment\n", 147 | "\n", 148 | "for i in range(10):\n", 149 | " print(\"Messages: {}...\".format(messages[i]),\n", 150 | " \"Sentiment: {}\".format(labels[i]))" 151 | ], 152 | "execution_count": 0, 153 | "outputs": [ 154 | { 155 | "output_type": "stream", 156 | "text": [ 157 | "Messages: $SPY crazy day so far!... Sentiment: bearish\n", 158 | "Messages: $SPY Will make a new ATH this week. Watch it!... Sentiment: bullish\n", 159 | "Messages: $SPY $DJIA white elephant in room is $AAPL. Up 14% since election. Strong headwinds w/Trump trade & Strong dollar. How many 7's do you see?... Sentiment: bearish\n", 160 | "Messages: $SPY blocks above. We break above them We should push to double top... Sentiment: bullish\n", 161 | "Messages: $SPY Nothing happening in the market today, guess I'll go to the store and spend some $.... Sentiment: bearish\n", 162 | "Messages: $SPY What an easy call. Good jobs report: good economy, markets go up. Bad jobs report: no more rate hikes, markets go up. Win-win.... Sentiment: bullish\n", 163 | "Messages: $SPY BS market.... Sentiment: bullish\n", 164 | "Messages: $SPY this rally all the cheerleaders were screaming about this morning is pretty weak. I keep adding 2 my short at all spikes... Sentiment: bearish\n", 165 | "Messages: $SPY Dollar ripping higher!... Sentiment: bearish\n", 166 | "Messages: $SPY no reason to go down !... Sentiment: bullish\n" 167 | ], 168 | "name": "stdout" 169 | } 170 | ] 171 | }, 172 | { 173 | "metadata": { 174 | "id": "BsgYP2e0_E1q", 175 | "colab_type": "text" 176 | }, 177 | "cell_type": "markdown", 178 | "source": [ 179 | "#### Create Data DF\n", 180 | "\n", 181 | "The fast.ai library allows us to create our input data from a pandas DataFrame. By default the factory methods that create these objects assume by default that a single label column comes before a single text column, so we will make sure these labels come first." 182 | ] 183 | }, 184 | { 185 | "metadata": { 186 | "id": "njaoOeg26RQh", 187 | "colab_type": "code", 188 | "colab": {} 189 | }, 190 | "cell_type": "code", 191 | "source": [ 192 | "df = pd.DataFrame({'label': labels, 'text': messages })" 193 | ], 194 | "execution_count": 0, 195 | "outputs": [] 196 | }, 197 | { 198 | "metadata": { 199 | "id": "BHlfJOCnAF-j", 200 | "colab_type": "text" 201 | }, 202 | "cell_type": "markdown", 203 | "source": [ 204 | "#### Train/Val Split\n", 205 | "\n", 206 | "The last thing we do is split our data into tranining and validation sets." 
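The `stratify` argument in the split below is what keeps the bullish/bearish ratio identical in both sets. Here is a tiny self-contained check on made-up data (the notebook's own split uses the full StockTwits DataFrame).

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy frame, invented only to illustrate stratification: both splits keep the
# same bullish/bearish proportions as the original data.
toy = pd.DataFrame({
    "label": ["bullish"] * 10 + ["bearish"] * 10,
    "text": ["some message"] * 20,
})
toy_trn, toy_val = train_test_split(toy, stratify=toy["label"], test_size=0.2, random_state=42)

print(toy_trn["label"].value_counts(normalize=True))  # bullish 0.5, bearish 0.5
print(toy_val["label"].value_counts(normalize=True))  # bullish 0.5, bearish 0.5
```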
207 | ] 208 | }, 209 | { 210 | "metadata": { 211 | "id": "AwyX2Clc6RVF", 212 | "colab_type": "code", 213 | "colab": {} 214 | }, 215 | "cell_type": "code", 216 | "source": [ 217 | "df_trn, df_val = train_test_split(df, stratify = df['label'], test_size=0.2, random_state=42)" 218 | ], 219 | "execution_count": 0, 220 | "outputs": [] 221 | }, 222 | { 223 | "metadata": { 224 | "id": "J0A-Z1ZM6RXc", 225 | "colab_type": "code", 226 | "outputId": "8f7e991a-1f31-4852-b98c-c3b911b09dda", 227 | "colab": { 228 | "base_uri": "https://localhost:8080/", 229 | "height": 68 230 | } 231 | }, 232 | "cell_type": "code", 233 | "source": [ 234 | "print(\"Data Set Size\")\n", 235 | "print(\"Train set: \\t\\t{}\".format(df_trn.shape), \n", 236 | " \"\\nValidation set: \\t{}\".format(df_val.shape))" 237 | ], 238 | "execution_count": 0, 239 | "outputs": [ 240 | { 241 | "output_type": "stream", 242 | "text": [ 243 | "Data Set Size\n", 244 | "Train set: \t\t(77573, 2) \n", 245 | "Validation set: \t(19394, 2)\n" 246 | ], 247 | "name": "stdout" 248 | } 249 | ] 250 | }, 251 | { 252 | "metadata": { 253 | "id": "HphlDFduGbof", 254 | "colab_type": "text" 255 | }, 256 | "cell_type": "markdown", 257 | "source": [ 258 | "## Training ULMFiT for Sentiment Classification\n", 259 | "\n" 260 | ] 261 | }, 262 | { 263 | "metadata": { 264 | "id": "t6iLjGGLGbxG", 265 | "colab_type": "text" 266 | }, 267 | "cell_type": "markdown", 268 | "source": [ 269 | "#### Model Inputs\n", 270 | "\n", 271 | "The fast.ai library provides custom data classes that allow us to quickly and easily assemble our data for modeling. The base class is called a `TextDataBunch`. From the [fast.ai docs](https://docs.fast.ai/text.data.html#Quickly-assemble-your-data):\n", 272 | "\n", 273 | "You should get your data in one of the following formats to make the most of the fastai library and use one of the factory methods of one of the TextDataBunch classes:\n", 274 | "\n", 275 | "* raw text files in folders train, valid, test in an ImageNet style,\n", 276 | "* a csv where some column(s) gives the label(s) and the following one the associated text,\n", 277 | "* a dataframe structured the same way,\n", 278 | "* tokens and labels arrays,\n", 279 | "* ids, vocabulary (correspondence id to word) and labels.\n", 280 | "\n", 281 | "We will use two extensions of this class:\n", 282 | "\n", 283 | "* `TextLMDataBunch` -- a `TextDataBunch` for training a language model, and\n", 284 | "* `TextClasDataBunch` -- a `TextDataBunch` for training a text classifier\n" 285 | ] 286 | }, 287 | { 288 | "metadata": { 289 | "id": "B_yLJeWbwMX4", 290 | "colab_type": "code", 291 | "colab": {} 292 | }, 293 | "cell_type": "code", 294 | "source": [ 295 | "# Language model data\n", 296 | "data_lm = TextLMDataBunch.from_df(path='', train_df = df_trn, valid_df = df_val, \n", 297 | " text_cols='text')\n", 298 | "\n", 299 | "# Classifier model data\n", 300 | "data_clas = TextClasDataBunch.from_df(path='', train_df=df_trn, valid_df=df_val, \n", 301 | " vocab=data_lm.train_ds.vocab, bs=256, \n", 302 | " text_cols='text', label_cols='label')" 303 | ], 304 | "execution_count": 0, 305 | "outputs": [] 306 | }, 307 | { 308 | "metadata": { 309 | "id": "Z8NFj_mPJCyE", 310 | "colab_type": "text" 311 | }, 312 | "cell_type": "markdown", 313 | "source": [ 314 | "#### Fine Tune Pre-Trained Language Model\n", 315 | "\n", 316 | "The fast.ai library includes the **AWD_LSTM** language model that was pretrained on the Wikipedia data set.
Here we use [`language_model_learner`](https://docs.fast.ai/text.learner.html#language_model_learner) to define our base language model that we will fine-tune with our sentiment classification task corpus.\n", 317 | "\n", 318 | "We will fine-tune the model using a [one cycle policy](https://docs.fast.ai/callbacks.one_cycle.html#What-is-1cycle?).\n" 319 | ] 320 | }, 321 | { 322 | "metadata": { 323 | "id": "woTkLRu8xAXK", 324 | "colab_type": "code", 325 | "colab": {} 326 | }, 327 | "cell_type": "code", 328 | "source": [ 329 | "# instantiate the language model learner\n", 330 | "lm_learn = language_model_learner(data_lm, arch=AWD_LSTM, pretrained=True, drop_mult=0.5)" 331 | ], 332 | "execution_count": 0, 333 | "outputs": [] 334 | }, 335 | { 336 | "metadata": { 337 | "id": "nf-8GONsxAZl", 338 | "colab_type": "code", 339 | "outputId": "177d630a-46be-4873-ad8e-c5ae050925a7", 340 | "colab": { 341 | "base_uri": "https://localhost:8080/", 342 | "height": 34 343 | } 344 | }, 345 | "cell_type": "code", 346 | "source": [ 347 | "# find max learning rate for one cycle training\n", 348 | "lm_learn.lr_find(start_lr=1e-8, end_lr=1e2)" 349 | ], 350 | "execution_count": 0, 351 | "outputs": [ 352 | { 353 | "output_type": "display_data", 354 | "data": { 355 | "text/html": [ 356 | "" 357 | ], 358 | "text/plain": [ 359 | "" 360 | ] 361 | }, 362 | "metadata": { 363 | "tags": [] 364 | } 365 | }, 366 | { 367 | "output_type": "stream", 368 | "text": [ 369 | "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" 370 | ], 371 | "name": "stdout" 372 | } 373 | ] 374 | }, 375 | { 376 | "metadata": { 377 | "id": "98Rg16vUxAiM", 378 | "colab_type": "code", 379 | "outputId": "4d315b31-a70a-4cb8-e2a7-869a3556527d", 380 | "colab": { 381 | "base_uri": "https://localhost:8080/", 382 | "height": 283 383 | } 384 | }, 385 | "cell_type": "code", 386 | "source": [ 387 | "# plot learning rate performance to get max learning rate for one cycle fitting\n", 388 | "lm_learn.recorder.plot()" 389 | ], 390 | "execution_count": 0, 391 | "outputs": [ 392 | { 393 | "output_type": "display_data", 394 | "data": { 395 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xuc3HV97/HXZ2Z29r7JJru5kATC\nNSAQIwkXwXBAQJH6EBC10tKDouVRtSL10lPrOdXWYytY66labNOqUKu0RaAHPCJQSUiVayBcEkkg\nwdxvm2ySvc7sXD7nj/ltssTNZS8zv9/s7/18POaxs7/9zXw/O9nMe77f3+/3/Zq7IyIi8ZUIuwAR\nEQmXgkBEJOYUBCIiMacgEBGJOQWBiEjMKQhERGJOQSAiEnMKAhGRmFMQiIjEXCrsAo5FW1ubz507\nN+wyRESqynPPPbfb3duPtl9VBMHcuXNZsWJF2GWIiFQVM9t4LPtpaEhEJOYUBCIiMacgEBGJOQWB\niEjMKQhERGJOQSAiEnMKAhGRmFMQiIhE0I79Gb7+yFpe7+gpe1sKAhGRCNrU2ce3HlvHtn2Zsrel\nIBARiaDuTA6A5rryTwChIBARiaCuIAha6mvK3paCQEQkgrozeUA9AhGR2Orq19CQiEisdWfy1KYS\n1KaSZW9LQSAiEkFdmTzNdeU/PgAKAhGRSOrK5Gipr8ySMQoCEZEI6laPQEQk3rozOVoqcKAYFAQi\nIpHU1Z+jRT0CEZH4Kg0NqUcgIhJbpYPF6hGIiMRSrlAkkyvSXKsegYhILFVyeglQEIiIRM7g9BIa\nGhIRiamDPQIFgYhILB2YglpDQyIi8XRwURr1CEREYqlLB4tFROJNB4tFRGJu8GBxk64jEBGJp65M\njqbaFMmEVaQ9BYGISMR0Z/IVO2MIFAQiIpHTnclV7IwhKGMQmNn3zGyXma0asu39ZrbazIpmtqhc\nbYuIVLOu/nzFVieD8vYI7gSuPGTbKuC9wPIytisiUtW6sxOkR+Duy4HOQ7a94u5ry9WmiMhEUMm1\nCCDCxwjM7GYzW2FmKzo6OsIuR0SkYiq5OhlEOAjcfYm7L3L3Re3t7WGXIyJSEe6uHoGISJz15wrk\ni16xq4pBQSAiEimVXpQGynv66N3Ak8A8M9tiZh8xs2vNbAvwVuD/mdnD5WpfRKQaVXrmUYCyRY67\nX3+YH91frjZFRKrd/v5Sj0BXFouIxFQYPQIFgYhIhAyuRaAegYhITA32CHTWkIhITE2os4ZERGTk\nuvpzpBJGfU2yYm0qCEREImTwqmKzyixKAwoCEZFI6arwWgSgIBARiZTuTGXXIgAFgYhIpHRncjTX\nqkcgIhJblV6dDBQEIiKRUun1ikFBICISKZVeiwAUBCIikVEoOt3ZfEVXJwMFgYhIZPRkK39VMSgI\nREQio6s/mGdIPQIRkXganGdIZw2JiMRUGGsRgIJARCQyDq5FoCAQEYmlgz0CDQ2JiMTS4MFiBYGI\nSEwdXJRGQ0MiIrHUnc1TV5MgnarsW7OCQEQkIrr6cxU/UAwKAhGRyAhjniFQEIiIREYYq5OBgkBE\nJDK6Mnla6hUEIiKxVVqLQENDIiKx1dWfp0VBICISX90ZnTUkIhJb2XyBbL44sYaGzOx7ZrbLzFYN\n2TbFzB41s9eCr63lal9EpJrs7S1NLzGlsbbibZezR3AncOUh2/4E+Lm7nwr8PPheRCT29vRmAZjS\nOIGGhtx9OdB5yOargbuC+3cB15SrfRGRatLZOwBMvB7BcKa7+/bg/g5geoXbFxGJpINBkK5426Ed\nLHZ3B/xwPzezm81shZmt6OjoqGBlIiKVF6cg2GlmMwGCr7sOt6O7L3H3Re6+qL29vWIFioiEobN3\ngITB5BhcWfwAcGNw/0bg/1a4fRGRSNrTO0BrQ5pEwiredjlPH70beBKYZ2ZbzOwjwFeBK8zsNeDy\n4HsRkdjb2ztAawjDQgBlu3LB3a8/zI8uK1ebIiLVak/vQCjHB0BXFouIREJn7wBTFQQiIvEV5tCQ\ngkBEJGTForO3Tz0CEZHY2tefo+jhXEMACgIRkdCFeTEZKAhEREKnIBARibnOAzOPKghERGKpM1iL\nYGoIM4+CgkBEJHSDPYLWENYiAAWBiEjo9vQO0FSbojaVDKV9BYGISMg6Q5xeAhQEIiKh6wzxqmJQ\nEIiIhC7MeYZAQSAiEjoNDYmIxJi7KwhEROKsb6BANl9UEIiIxFXY00vAMQaBmZ1sZrXB/UvM7BYz\nm1ze0kREJr4DQdAQ8SAA7gUKZnYKsASYA/yobFWJiMTEgSBoin4QFN09D1wLfMvdPwfMLF9ZIiLx\nsCcIgmo4fTRnZtcDNwI/CbaFMymGiMgEsrdajhEAHwbeCnzF3X9tZicCPyhfWSIi8bCnd4CapNFU\nmwqthmNq2d1/BdwCYGatQLO731bOwkRE4qCzN8uUxjRmFloNx3rW0DIzazGzKcDzwD+a2d+UtzQR\nkYmvdDFZOOsQDDrWoaFJ7t4FvBf4Z3c/H7i8fGWJiMRDKQjCPeR6rEGQMrOZwAc4eLBYRETGqJp6\nBH8BPAysd/dnzewk4LXylSUiEg97Qp55FI79YPE9wD1Dvn8duK5cRYmIxEGuUKQ7k6c1xKuK4dgP\nFs82s/vNbFdwu9fMZpe7OBGRiWxvBK4qhmMfGvo+8ABwXHB7MNgmIiKjFIWriuHYg6Dd3b/v7vng\ndifQXsa6REQmvMEeQVUMDQF7zOwGM0sGtxuAPaNt1Mw+ZWarzGy1md062ucREalmB3oEVTI0dBOl\nU0d3ANuB9wEfGk2DZnYW8PvAecCbgXcHs5qKiMRKFNYigGMMAnff6O7vcfd2d5/m7tcw+rOGzgCe\ndve+YEbTxyldqCYiEiudvQOYweT66rigbDifHuXjVgGLzWyqmTUAV1Fa30BEJFY6eweYVF9DKhnu\nYpFjme5uVDMkufsrZnYb8AjQC7wAFH7jyc1uBm4GOP7448dQpohINIW9aP2gscSQj/qB7t9194Xu\nfjGwF3h1mH2WuPsid1/U3q4TlERk4tnTmw391FE4So/AzLoZ/g3fgPrRNmpm09x9l5kdT+n4wAWj\nfS4RkWq1p2eAE9sawy7jyEHg7s1lavdeM5sK5IBPuPu+MrUjIhJJ7s62ff287dS2sEsZ0zGCUXP3\nxWG0KyISFXv7cvQOFJjd2hB2KWM6RiAiIqO0ZW8fAHNaRz3KPm4UBCIiIdjc2Q+gHoGISFxtDnoE\ns6eoRyAiEktb9vYxqb6GlrpwryoGBYGISCg2d/YzJwK9AVAQiIiEYsvePuZE4PgAKAhERCrO3dmy\nt5/ZEThjCBQEIiIV19GTJZsvMmeKegQiIrF08NRR9QhERGLp4MVk6hGIiMTSlr2lHsEs9QhEROJp\nc2cfbU1pGtKhTPf2GxQEIiIVVjpjKBrDQqAgEBGpuM17
+yJzoBgUBCIiFVUoltYhiMqpo6AgEBGp\nqJ1dGXIFV49ARCSuNndG69RRUBCIiFTU4Kmj6hGIiMTU4DoEUbmGABQEIiIVtWVvP9NbaqlNJcMu\n5QAFgYhIBW3ujM7004MUBCIiFbRlb7ROHQUFgYhIxeQKRbbvj846BIMUBCIiFbJ9X4aiR+vUUVAQ\niIhUzOD00+oRiIjE1OCpozpGICISU1v29pMwmDGpLuxS3kBBICJSIRv39DFzUj01yWi99UarGhGR\nCezVnd2cNr0p7DJ+g4JARKQCBvJF1u3q4fSZLWGX8hsUBCIiFfD67h7yRef0Gc1hl/IbQgkCM/sj\nM1ttZqvM7G4zi9aRExGRcbZ2RzcA8xQEYGazgFuARe5+FpAEPljpOkREKumV7d3UJI2T2nSMYFAK\nqDezFNAAbAupDhGRili7o4uT25tIp6I3Il/xitx9K/DXwCZgO7Df3R+pdB0iIpW0dkd3JIeFIJyh\noVbgauBE4Dig0cxuGGa/m81shZmt6OjoqHSZIiLjZn9/jm37M5w+I3pnDEE4Q0OXA7929w53zwH3\nARceupO7L3H3Re6+qL29veJFioiMl8EDxVE8YwjCCYJNwAVm1mBmBlwGvBJCHSIiFbF2RxcQzTOG\nIJxjBE8DPwaeB14OalhS6TpERCplzY5uWupSzIzYHEODUmE06u5fBL4YRtsiIpW2Zkc3p89ooTQI\nEj3RO49JRGQCcXdejfAZQ6AgEBEpq637+unO5jl9poJARCSWon7GECgIRETKak0QBKdNVxCIiMTS\nmh3dzG6tp7muJuxSDktBICJSRmt3dEV6WAhCOn20Utw9tNO1ikVnX3+Ozt4se3oG6MrkSRgkEkYq\nYaSTCSY11DCpvnSrr0m+odZsvkBftkBPNk/fQIGBfJGBQpF8oUg2X6Qnm6c7k6M7U/q5Ozge/N6l\n370YbEuYkUokSCVLbaeSieBrqY6pTWnammppa6plSmOa2lQisqe5iVSTbL7A+o5ernjT9LBLOaIJ\nHQR//chaHlq1gwWzJ7Pg+Mm8efZkWupr6OjOBrcM2XyRgjvupTfvRMJImJEwSCaMdCpBTfDGaWb0\nD5TeePsGCnRn8uzrG2Bv3wB7+3J0Z3L0Bm/evQN53EdXtxmjfuzQ5zDAzCgGv9+xShg0pFPUp5M0\nppM019XQVJuiqS5FXU2SmqRRk0hQkzLqUknq00nqakq3VKL02pkZZqVVmQZvDtQkE6RTCdLJUtDk\ni06h6BTdSSUS1NUkqU0lDraTSpBOJjCgP1d63ftzBdyd+nSKhpokDekk6SC8Bv/dmmpTB0I2FbH1\nYSU+1u/qpVB05kV0jqFBEzoITpvezNodPSx/bTf3rdw67s/fVJtickMNrQ1pJjfUMGtyPY21SRpr\nUzTVpmhtSDO1Kc2UxjST6kvjg/miUyw6mVyRrkyOfX059vfnyOQKpc/zXvpcXxe8wTXWpmhIJ6lN\nld4Y08kENakEzXUpmutqaK5LUV+TJGEWvPGXajv0E32h6OSLRfIFL92KRfJFJ5srsqc3y+6eAXb3\nZOnsHaB/YPANN09PtkBPJkdPNs/mzj6ywZt6vlgkV3CyuQJ9ucKYg6ucGtOlsEonE9QGQdOQTtKQ\nLr22dTWlIEkHoeNe+vfJ5AtkcgXSqSRNtSma61I0plO01Kdoqauhpb6GlroUjcHPmmpL99WjkkFr\nd5amljhDQ0PhuXrBLK5eMAt3Z/v+DC9u3kffQIFpLbW0N9fS3lRLfTp4EzUwDMcpFqHgpU+q+ULp\nDS9XKFJ0P/DmUV+TJJGonv/syYSRTCSpHeZf/PipDWN6bncnmy+SyRUoOhS9FHYOpAd7AKnSp/pc\nwRkolMLEDJJmJJNG0oxcoUgmVySbL5DJFckVBm+lHkN9TekNvTSMBpmgh9CbLZAvFkttF5180enN\n5tnfXwrarkyObL5ANlc8UGd/rkBvNs/uniz9uQK5YOgtmy+SMKOuJkF9TSmABwpFujN5erN5+nOF\nY3qtG9JJGtOlgBj8oNDakGZSQyk8SiFSc2BYrr25ltaGNMkq+puSo3ty/R6aa1Oc2NYYdilHNKGD\nYJCZcdzkeo6bXB92KROSmR0YGjqaVBLqOfp+UZUPQqErk6Orv/S1J5unJ1MaDuzO5OkfKNA7UAqO\nrv48e/sG2Linjxc276MrkyOTKw773KmEMbu1nuOnNnLClAbmTKlneksd01vqmNFSx3GT6yO5qIkM\nr1h0lq7t4OLT2iM/PBmLIBAZL6lkgtbGNK2N6VE/RzZfOr60vz/HnmBIbndPlh37M2zs7GPTnj5W\nbtpLdyb/hselUwnOPK6FBXMms2BO6ZjXCVMbNAwVUau3ddHRneXtp08Lu5SjUhCIVFhtKkltU5K2\nplpOPsxSG+5OVybPrq4MO7uy7OjK8OrObl7YtI+7n9nE93+5AYDmuhRnHtfC/NmTOW/uFM4/aUqk\nz1ePk8fW7MIMLpkX/fVUFAQiEWRmB856OvWQK1JzhSJrd3Szaut+Vm3bz8tbu7jziQ0sWf46yYSx\nYM5kLjqljYtPbWPBnMmRH5aYqB5bu4s3z57M1KbasEs5KgWBSJWpSSY4a9Ykzpo16cC2TK7A85v2\n8sS6Pfxi3W6+/dhrfPPnr9Fcm+LCU6by9tOnceWZM5nUoN5CJXR0Z3lx8z4+fcVpYZdyTBQEIhNA\nXU2SC09u48KT2/jsO+exvy/HL9btZvmrHSx/rYOHV+/kf/3Hai47YxrXvGUWl8xrpzZVvQfto27Z\n2l0AVXF8ABQEIhPSpIYafmv+TH5r/kzcnZe37uf+lVt58MVtPLRqB21NaW644AR+9/wTaG+O/tBF\ntVm6dhfTmms587hoX0g2SEEgMsGZGfNnT2b+7Ml84aoz+K91u/nBkxv5P//5GncsXc/VC47jI4tP\n5PSIX/1aLXKFIv/16m5+a/7MqjmjS0EgEiOpZIJL503j0nnTeL2jhzuf2MA9K7Zwz3NbuPi0dm5e\nfBIXnTK1at7AoujZDZ10Z/NcWiXDQqDZR0Vi66T2Jv7i6rN46vOX8cdXzuOV7V3c8N2nueqbv+CB\nF7dRKEZ43pAIW7pmF+lkgred0hZ2KcdMQSASc5Maavj4Jafwi/9xKbe/bz65QpFb7l7JFd94nPue\n30K+MPyV0DK8x9bs4vyTptA43HwuEaUgEBGgdKHbBxbN4ZFbL+aO3z2HdDLBp//9RS77m8f50dOb\nyBzDPEtxt3FPL+s7erl0XvUMC4GCQEQOkUgYV509k5/espglv7eQyfU1/On9L7P49qX8w+Pr6c7k\nwi4xsh5evQOAy8+I9voDh1IQiMiwEgnjHWfO4D8+cRE/+uj5zJvezF89tIa33baUO5ato28gf/Qn\niZmHVu3gzONaxjyjb6UpCETkiMyMC09p418+ej4P/OFFLDqhldt/tpaLb1/GXU9sIJvXkBHA9v39\nrNy0j6vOnhl
2KSOmIBCRYzZ/9mS++6Fzufdjb+Xk9ka++MBqLvv64zz44jY8yqsTVcDPVpWGha48\na0bIlYycgkBERmzhCVP415sv4J9vOo/muho+efdKrr3jCZ7b2Bl2aaF5aNUO5k1v5uT2prBLGTEF\ngYiMiplx8Wnt/OSTb+P2981n275+rvvOk3z8h8+xcU9v2OVVVEd3lmc3dFZlbwB0ZbGIjFEyYXxg\n0RzePX8mS5a/zpLlr/Por3byexfM5ZbLTmFyw+gX8akWD6/egTu86+zqDAL1CERkXDSkU9x6+Wks\n++wlXHfObO584tdcfPtS/m7pOnqyE/sMo5+t2sFJbY3Mmx7tReoPR0EgIuNqWksdX71uPj/91GIW\nntDK1x5ey+LbHuM7y9bTOwEDYW/vAE++vocrz5pRtXM0VTwIzGyemb0w5NZlZrdWug4RKa/TZ7Tw\n/Q+fx/0fv5D5sydz28/WsPj2pdz1xAZyE2jaikd/tZNC0avytNFBFQ8Cd1/r7gvcfQGwEOgD7q90\nHSJSGW85vpW7bjqPez92IadNb+KLD6zmnd9YziOrd0yIU04fWrWd2a31VbP2wHDCHhq6DFjv7htD\nrkNEymzhCa3c/fsX8N0bF2EGN//gOX77H57iiXW7qzYQNu3p4xfrdvOuKh4WgvCD4IPA3SHXICIV\nYmZcdsZ0Hr71Yr58zVls7Ozld/7pad7390/y+KsdVRUI7s7n73+J2lSSm952YtjljEloQWBmaeA9\nwD2H+fnNZrbCzFZ0dHRUtjgRKatUMsHvXXACj3/uUr589Zls39fPjd97hmvueIKla3ZVRSD827Ob\n+eW6PXz+qtOZOak+7HLGxMJ6wc3sauAT7v6Oo+27aNEiX7FiRQWqEpEwDOSL3Pv8Fv5u6Tq27O3n\nzbMn8anLT+XSedMiOeSyY3+GK/7mcc6c1cKPPnoBiUT0agQws+fcfdHR9gtzaOh6NCwkIkA6leD6\n845n6Wcv4bbrzqazb4Cb7lzBu79VWi0tSovjuDv/8z9eJlcs8tX3zo9sCIxEKEFgZo3AFcB9YbQv\nItFUk0zw2+cez2OfuYTb3zef/lyBW+5eydu//jg/eHJDJKa+fvCl7fznK7v4zBXzmNvWGHY54yK0\noaGR0NCQSDwVi86jr+zk7x9fz8pN+2iqTfGeBcfxwXPncPasSRUfNvrJS9v43D0vcdqMZu772IUk\nI94bONahIc01JCKRlUgY7zxzBu9403Se27iXu5/ZzH3Pb+FHT2/ijJkt/Pai2Vzzlllln8+oUHS+\n/sha7li2noUntPKdG86JfAiMhHoEIlJVujI5HnhhG//67CZWbe0inUrwzjNncN05s3jryVOpTSXH\ntb39/Tn+6N9e4LE1u7j+vDl86T1njnsb5XKsPQIFgYhUrdXb9nPPii3cv3Ir+/tzNKaTXHRKG5ee\nPo3Fp7Yxu3X0S0a+urObHz61kfue30p/rsAX33MmN5x/fCTPYjocBYGIxEYmV+CX63azdO0ulq7p\nYOu+fgBmTqpj0dwpnDu3lVPam2hrrqWtqZbJ9TW/cbbP/v4cq7fu56Wt+3lszS6e+XUn6WSCq86e\nwUcXn8RZsyaF8auNiYJARGLJ3Xl1Zw9Pvb6HZzd0smLDXnZ0Zd6wTyph1KeT1NUkqU0lcOdAeACc\n1N7IBxbN4f0LZzO1qbbSv8K40cFiEYklM2PejGbmzWjmxgvn4u5s3dfP5s5+dvdkD9z6BgpkckWy\nuQIFd35n+vGcPWsSZ8+aRGvjxF9MZygFgYhMaGbG7NaGMR0vmOjCnnRORERCpiAQEYk5BYGISMwp\nCEREYk5BICIScwoCEZGYUxCIiMScgkBEJOaqYooJM9sPvDZk0yRg/yH3D/e1Ddg9guaGPvfRth+6\n7Uh1DW4jonUNbqtRXaHUdbQ6DlfXcDWqrolT10jfKw5t6wR3bz/qo9w98jdgyeG+H7x/hK8rxtLW\nkbaPpK7B+1Gta0h9qiuEuo5Wx+FqGK5G1TVx6gq+HnNtR6rrSLdqGRp68AjfP3iUr2Nt60jbR1LX\n4P2o1nWkNlRX+es6Wh2Hq2G4elSX6hqRqhgaGgszW+HHMPtepamukVFdI6O6RiaqdUFlaquWHsFY\nLAm7gMNQXSOjukZGdY1MVOuCCtQ24XsEIiJyZHHoEYiIyBFUVRCY2ffMbJeZrRrFYxea2ctmts7M\nvmnBwqNm9m9m9kJw22BmL0ShruBnnzSzNWa22sxuj0JdZvYlM9s65DW7Kgp1Dfn5Z8zMzawtCnWZ\n2ZfN7KXgtXrEzI6LSF1fC/62XjKz+81sckTqen/w9140sxGNi4+lnsM8341m9lpwu/FotUegrq+Y\n2WYz6xnxk47mVKOwbsDFwDnAqlE89hngAsCAh4B3DbPP14E/i0JdwKXAfwK1wffTIlLXl4DPRvHf\nEZgDPAxsBNqiUBfQMmSfW4C/j0hd7wBSwf3bgNsiUtcZwDxgGbCoEvUEbc09ZNsU4PXga2twv/Vo\nf4Mh13UBMBPoGem/R1X1CNx9OdA5dJuZnWxmPzOz58zsv8zs9EMfZ2YzKf2HfMpLr9g/A9ccso8B\nHwDujkhdHwO+6u7ZoI1dEalrzMpY1zeAPwZGdeCrHHW5e9eQXRtHU1uZ6nrE3fPBrk8BsyNS1yvu\nvnaktYylnsN4J/Cou3e6+17gUeDK0fzfqERdQTtPufv2Y3yeN6iqIDiMJcAn3X0h8FngjmH2mQVs\nGfL9lmDbUIuBne7+GuNjrHWdBiw2s6fN7HEzOzcidQH8YTCk8D0za41CXWZ2NbDV3V8cp3rGpa6g\ntq+Y2Wbgd4E/i0pdQ9xE6ZNt1OqqVD3DmQVsHvL9YI3jVft41zUmVb1msZk1ARcC9wwZpqsd5dNd\nzyh6A2WsK0Wp+3cBcC7w72Z2UvApJMy6vgN8mdIn2y9TGk67abQ1jUddZtYA/Cml4Y5xM15/X+7+\nBeALZvZ54A+BL0ahruC5vgDkgR+Opabxrms8HKkeM/sw8Klg2ynAT81sAPi1u18bt7qqOggo9Wj2\nufuCoRvNLAk8F3z7AKU3r6Fd39nA1iH7p4D3AgsjVNcW4L7gjf8ZMytSmnOkI8y63H3nkMf9I/CT\nMdQzXnWdDJwIvBj8x5oNPG9m57n7jhDrOtQPgZ8yxiAYr7rM7EPAu4HLxvIBY7zrGkfD1gPg7t8H\nvh/Utwz4kLtvGLLLVuCSQ2pcFmwfa+3lqGtsRnpQIewbMJchB12AJ4D3B/cNePNhHnfoAZ6rhvzs\nSuDxKNUF/AHwF8H90yh1By0Cdc0css8fAf8ahdfrkH02MIqDxWV6vU4dss8ngR9HpK4rgV8B7VH6\nux/y82WM8GDxaOvh8Adlf03pgGxrcH/Ksf4NhlHXkH1GfLB41H8AYdwoDd1sB3KUPjF/hNInwZ8B\nLwZ/2MOe9QMsAlYB64FvM+RNFbgT+IMo1QWkgX8JfvY88PaI1PUD4GXg
JUqf7mZGoa5D9tnA6M4a\nKsfrdW+w/SVK88DMikhd6yh9uHghuI3mbKZy1HVt8FxZYCfwcLnrYZg33GD7TcHrtA748Ej+BkOq\n6/bg+YvB1y8d62unK4tFRGJuIpw1JCIiY6AgEBGJOQWBiEjMKQhERGJOQSAiEnMKAqlKo5phcWzt\n/ZOZvWmcnqtgpdlIV5nZg3aUmT/NbLKZfXw82hYZjk4flapkZj3u3jSOz5fyg5OwldXQ2s3sLuBV\nd//KEfafC/zE3c+qRH0SP+oRyIRhZu1mdq+ZPRvcLgq2n2dmT5rZSjN7wszmBds/ZGYPmNljwM/N\n7BIzW2ZmP7bSXP0/NDswT/4yC+bHN7OeYCK5F83sKTObHmw/Ofj+ZTP738fYa3mSgxPnNZnZz83s\n+eA5rg72+SpwctCL+Fqw7+eC3/ElM/vzcXwZJYYUBDKR/C3wDXc/F7gO+Kdg+xpgsbu/hdLsn385\n5DHnAO9z9/8WfP8W4FbgTcBJwEXDtNMIPOXubwaWA78/pP2/dfezeeMMlcMK5uC5jNIV2gAZ4Fp3\nP4fSehRfD4LoT4D17r7A3T9nZu8ATgXOAxYAC83s4qO1J3I41T7pnMhQlwNvGjKjY0sw0+Mk4C4z\nO5XSrKk1Qx7zqLsPnSv+GXffAmCl1ermAr84pJ0BDk629xxwRXD/rRycm/5HwF8fps764LlnAa9Q\nmlMeSnPO/GXwpl4Mfj59mMctwtQAAAABZUlEQVS/I7itDL5vohQMyw/TnsgRKQhkIkkAF7h7ZuhG\nM/s2sNTdrw3G25cN+XHvIc+RHXK/wPD/R3J+8ODa4fY5kn53XxBMn/0w8Angm5TWK2gHFrp7zsw2\nAHXDPN6Av3L3fxhhuyLD0tCQTCSPUJrhEwAzG5zmdxIHpwr+UBnbf4rSkBTAB4+2s7v3UVq+8jPB\nVOiTgF1BCFwKnBDs2g00D3now8BNQW8HM5tlZtPG6XeQGFIQSLVqMLMtQ26fpvSmuig4gPorSlN5\nQ2lWxr8ys5WUtxd8K/BpM3uJ0qIi+4/2AHdfSWlm0usprVewyMxeBv47pWMbuPse4JfB6aZfc/dH\nKA09PRns+2PeGBQiI6LTR0XGSTDU0+/ubmYfBK5396uP9jiRsOkYgcj4WQh8OzjTZx9jXMJTpFLU\nIxARiTkdIxARiTkFgYhIzCkIRERiTkEgIhJzCgIRkZhTEIiIxNz/BxYOh1Lv4od9AAAAAElFTkSu\nQmCC\n", 396 | "text/plain": [ 397 | "
" 398 | ] 399 | }, 400 | "metadata": { 401 | "tags": [] 402 | } 403 | } 404 | ] 405 | }, 406 | { 407 | "metadata": { 408 | "id": "WX2qaTTxxAlO", 409 | "colab_type": "code", 410 | "outputId": "35109d8f-c041-4703-c4fa-0d848c22fa2d", 411 | "colab": { 412 | "base_uri": "https://localhost:8080/", 413 | "height": 80 414 | } 415 | }, 416 | "cell_type": "code", 417 | "source": [ 418 | "# run one cycle training with max learning rate equal to ~1/10 (1 order of magnitude) of min learning rate in plot\n", 419 | "\n", 420 | "# Run one epoch with lower layers \n", 421 | "lm_learn.fit_one_cycle(cyc_len=1, max_lr=1e-1, moms=(0.8, 0.7))" 422 | ], 423 | "execution_count": 0, 424 | "outputs": [ 425 | { 426 | "output_type": "display_data", 427 | "data": { 428 | "text/html": [ 429 | "\n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | "
epochtrain_lossvalid_lossaccuracytime
04.2276023.9151310.35391903:02
" 449 | ], 450 | "text/plain": [ 451 | "" 452 | ] 453 | }, 454 | "metadata": { 455 | "tags": [] 456 | } 457 | } 458 | ] 459 | }, 460 | { 461 | "metadata": { 462 | "id": "00KrBsFp_uty", 463 | "colab_type": "code", 464 | "outputId": "ea6643f6-0fc0-4dbc-bf5f-f30e86c51eab", 465 | "colab": { 466 | "base_uri": "https://localhost:8080/", 467 | "height": 111 468 | } 469 | }, 470 | "cell_type": "code", 471 | "source": [ 472 | "# Run for many epochs with all layers unfrozen and reduce max learning rate by factor of 2.6 (according to paper)\n", 473 | "lm_learn.unfreeze()\n", 474 | "lm_learn.fit_one_cycle(cyc_len=2, max_lr=4e-3, moms=(0.8, 0.7))" 475 | ], 476 | "execution_count": 0, 477 | "outputs": [ 478 | { 479 | "output_type": "display_data", 480 | "data": { 481 | "text/html": [ 482 | "\n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | "
epochtrain_lossvalid_lossaccuracytime
03.6693643.5361960.39720003:41
13.4612983.4229950.41141503:41
" 509 | ], 510 | "text/plain": [ 511 | "" 512 | ] 513 | }, 514 | "metadata": { 515 | "tags": [] 516 | } 517 | } 518 | ] 519 | }, 520 | { 521 | "metadata": { 522 | "id": "hgIO8l6vxAno", 523 | "colab_type": "code", 524 | "colab": {} 525 | }, 526 | "cell_type": "code", 527 | "source": [ 528 | "# save the fine tuned encoder\n", 529 | "lm_learn.save_encoder('ft_sentiment_enc')" 530 | ], 531 | "execution_count": 0, 532 | "outputs": [] 533 | }, 534 | { 535 | "metadata": { 536 | "id": "r6xTBfZps-ze", 537 | "colab_type": "text" 538 | }, 539 | "cell_type": "markdown", 540 | "source": [ 541 | "#### Train Target Task Classification\n", 542 | "\n", 543 | "Now we use our fine tuned encoder (Generalized Language Model) and train our classification task on top of this. The fast.ai library provides the [text_classification_learner](https://docs.fast.ai/text.learner.html#text_classifier_learner) function that allows us to easily build this model and load our fine tuned encoder into it.\n", 544 | "\n" 545 | ] 546 | }, 547 | { 548 | "metadata": { 549 | "id": "vQ56Tdn6xAsT", 550 | "colab_type": "code", 551 | "colab": {} 552 | }, 553 | "cell_type": "code", 554 | "source": [ 555 | "# instantiate classifcation task learner\n", 556 | "cla_learn = text_classifier_learner(data_clas, arch=AWD_LSTM, drop_mult=0.5)\n", 557 | "\n", 558 | "# load finetuned encoder and freeze layers\n", 559 | "cla_learn.load_encoder('ft_sentiment_enc')\n", 560 | "cla_learn.freeze()" 561 | ], 562 | "execution_count": 0, 563 | "outputs": [] 564 | }, 565 | { 566 | "metadata": { 567 | "id": "eLoO2I551eo5", 568 | "colab_type": "code", 569 | "outputId": "d88560d8-0254-425c-85f1-a2d68bc4be51", 570 | "colab": { 571 | "base_uri": "https://localhost:8080/", 572 | "height": 34 573 | } 574 | }, 575 | "cell_type": "code", 576 | "source": [ 577 | "# find max learning rate for one cycle training\n", 578 | "cla_learn.lr_find(start_lr=1e-8, end_lr=1e2)" 579 | ], 580 | "execution_count": 0, 581 | "outputs": [ 582 | { 583 | "output_type": "display_data", 584 | "data": { 585 | "text/html": [ 586 | "" 587 | ], 588 | "text/plain": [ 589 | "" 590 | ] 591 | }, 592 | "metadata": { 593 | "tags": [] 594 | } 595 | }, 596 | { 597 | "output_type": "stream", 598 | "text": [ 599 | "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" 600 | ], 601 | "name": "stdout" 602 | } 603 | ] 604 | }, 605 | { 606 | "metadata": { 607 | "id": "a3bC64RqF702", 608 | "colab_type": "code", 609 | "outputId": "4dcefada-f920-4a8d-c0af-e62bb3eeab8f", 610 | "colab": { 611 | "base_uri": "https://localhost:8080/", 612 | "height": 283 613 | } 614 | }, 615 | "cell_type": "code", 616 | "source": [ 617 | "cla_learn.recorder.plot()" 618 | ], 619 | "execution_count": 0, 620 | "outputs": [ 621 | { 622 | "output_type": "display_data", 623 | "data": { 624 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VOXZ//HPlT0ECAHCvoRNFtkJ\noKjUtaK2om1VUCuoFWtdHmu12s1atb/H1lpbrRsqta48itZSxYW6omDZN5F9DShhXzJZJ/fvjzng\nGJJMgDmZmfB9v17zYs597jPnypDMNfdy7mPOOURERGqTFOsAREQk/ilZiIhIREoWIiISkZKFiIhE\npGQhIiIRKVmIiEhEShYiIhKRkoWIiETka7Iws1FmtsLMVpvZHdXsf9DMFnqPlWa2O2xfMGzfVD/j\nFBGR2plfV3CbWTKwEjgLKADmAGOdc8tqqH8jMMg5d5W3vd8517iu52vZsqXLy8s76rhFRI4l8+bN\n2+6cy41UL8XHGIYBq51zawHMbDIwGqg2WQBjgd8e6cny8vKYO3fukR4uInJMMrMNdannZzdUe2BT\n2HaBV3YIM+sMdAHeDyvOMLO5ZvaZmV3gX5giIhKJny2LwzEGmOKcC4aVdXbObTazrsD7ZrbEObcm\n/CAzmwBMAOjUqVP9RSsicozxs2WxGegYtt3BK6vOGOCl8ALn3Gbv37XAh8Cgqgc55yY65/Kdc/m5\nuRG73ERE5Aj5mSzmAD3MrIuZpRFKCIfMajKzXkAOMCusLMfM0r3nLYGTqHmsQ0REfOZbN5RzrsLM\nbgDeAZKBSc65z83sbmCuc+5A4hgDTHbfnJbVG3jCzCoJJbT7appFJSIi/vNt6mx9y8/Pd5oNJSJy\neMxsnnMuP1I9XcEtIiIRKVmIiCSwV+cV8OJ/N/p+HiULEZEE9tqCAqbM2xS54lFSshARSWCBsiBZ\n6f5fMqdkISKSwAKlQRqlJft+HiULEZEEFiivoFGaWhYiIlILtSxERCSiQJmShYiI1KKy0lFcHlQ3\nlIiI1Ky4PLRQt1oWIiJSo6KyCgAaaeqsiIjUpLjMa1mkqmUhIiI1KCoNJYusdCULERGpQcDrhsrU\nALeIiNQk4HVDZWmAW0REavJ1y0LJQkREavB1y0LdUCIiUoOiMl1nISIiERTrOgsREYnkwNTZTF1n\nISIiNSkuD5KRmkRykvl+LiULEZEEVVRaP/eyAJ+ThZmNMrMVZrbazO6oZv+DZrbQe6w0s91V9jc1\nswIz+5ufcYqIJKLielqeHMC3lGRmycAjwFlAATDHzKY655YdqOOc+2lY/RuBQVVe5h7gY79iFBFJ\nZEVlFfUybRb8bVkMA1Y759Y658qAycDoWuqPBV46sGFmQ4DWwLs+xigikrACZcF6uSAP/E0W7YFN\nYdsFXtkhzKwz0AV439tOAh4AbvUxPhGRhBYoC9bLIoIQPwPcY4Apzrmgt/0TYJpzrqC2g8xsgpnN\nNbO527Zt8z1IEZF4EigLkplaP91Qfp5lM9AxbLuDV1adMcD1YdsnAqeY2U+AxkCame13zn1jkNw5\nNxGYCJCfn++iFbiISCIIlFXUW8vCz2QxB+hhZl0IJYkxwKVVK5lZLyAHmHWgzDl3Wdj+8UB+1UQh\nInKsC9TjbCjfuqGccxXADcA7wBfAy865z83sbjM7P6zqGGCyc04tAxGRwxCox+ssfD2Lc24aMK1K\n2Z1Vtu+K8BrPAM9EOTQRkYTmnCNQ3gBaFiIi4p+S8kqco2FcwS0iIv4oOrDirFoWIiJSk+J6vJcF\nKFmIiCSkr1sW6oYSEZEaHLilaqNj7ApuERE5DAHvxkeN6uHGR6BkISKSkAJeN1RWPdxSFZQsREQS\n0oFuqIaw6qyIiPjkQLJoCPezEBERnxzohlLLQkREahTQdRYiIhJJUVkFaclJpCbXz8e4koWISAIq\nLgvW2zUWoGQhIpKQikqD9Ta4DUoWIiIJqbi8ot4Gt0HJQkQkIYVaFkoWIiJSi+KyoFoWIiJSu6Ky\nCo1ZiIhI7QJqWYiISCQBtSxERCSSQKlaFiIiUgvnHIHyIFkN5aI8MxtlZivMbLWZ3VHN/gfNbKH3\nWGlmu73yzmY23yv/3Mx+7GecIiKJpLSikmClq7dbqgL4diYzSwYeAc4CCoA5ZjbVObfsQB3n3E/D\n6t8IDPI2vwROdM6VmlljYKl37Ba/4hURSRTF9byIIPjbshgGrHbOrXXOlQGTgdG11B8LvATgnCtz\nzpV65ek+xykiklCKvOXJG0qyaA9sCtsu8MoOYWadgS7A+2FlHc1ssfcaf1CrQkQk5OuWxbE3G2oM\nMMU5FzxQ4Jzb5JzrD3QHxplZ66oHmdkEM5trZnO3bdtWj+GKiMROUQPrhtoMdAzb7uCVVWcMXhdU\nVV6LYilwSjX7Jjrn8p1z+bm5uUcZrohIYggc7IZqGC2LOUAPM+tiZmmEEsLUqpXMrBeQA8wKK+tg\nZpne8xzgZGCFj7GKiCSMQGn9tyx8S0vOuQozuwF4B0gGJjnnPjezu4G5zrkDiWMMMNk558IO7w08\nYGYOMOBPzrklfsUqIpJIAuWhZFGf11n42oZxzk0DplUpu7PK9l3VHDcd6O9nbCIiiSpQGuqGymwg\n3VAiIuKDgDfArftZiIhIjRraALeIiPggUBYkJclIS6m/j3AlCxGRBBMoC9brTChQshARSThFpRX1\n2gUFShYiIgknUB6kUT1OmwUlCxGRhBMorVA3lIiI1C40ZqFuKBERqYUGuEVEJKJAWQVZalmIiEht\nAmVBMtWyEBGR2gTKgvW61AcoWYiIJJxAWUW9LiIIShYiIgmlrKKS8qBTy0JERGp24P7bGrMQEZEa\nBcpDK85mpasbSkREalAUg1uqgpKFiEhCOdANpSu4RUSkRkUHb3ykloWIiNTg65aFkoWIiNSgKAa3\nVAUlCxGRhBJoiAPcZjbKzFaY2Wozu6Oa/Q+a2ULvsdLMdnvlA81slpl9bmaLzewSP+MUEUkUgbLY\nTJ317Wxmlgw8ApwFFABzzGyqc27ZgTrOuZ+G1b8RGORtBoArnHOrzKwdMM/M3nHO7fYrXhGRRFDU\nAMcshgGrnXNrnXNlwGRgdC31xwIvATjnVjrnVnnPtwCFQK6PsYqIJITisiBJBukp9TuK4OfZ2gOb\nwrYLvLJDmFlnoAvwfjX7hgFpwBofYhQRSShFZRU0SkvBzOr1vPEywD0GmOKcC4YXmllb4DngSudc\nZdWDzGyCmc01s7nbtm2rp1BFRGKnOAZ3yQN/k8VmoGPYdgevrDpj8LqgDjCzpsCbwK+cc59Vd5Bz\nbqJzLt85l5+bq14qEWn4ihpgspgD9DCzLmaWRighTK1aycx6ATnArLCyNOCfwLPOuSk+xigiklCK\nvW6o+uZbsnDOVQA3AO8AXwAvO+c+N7O7zez8sKpjgMnO
ORdWdjEwEhgfNrV2oF+xiogkiqLS2LQs\n6pSezKwbUOCcKzWzU4H+hL711zqV1Tk3DZhWpezOKtt3VXPc88DzdYlNRORYEigPkp2ZWu/nrWvL\n4lUgaGbdgYmExiJe9C0qERGpVqC0gkap8TtmUel1K10IPOycuw1o619YIiJSnUBZkEbp8Zssys1s\nLDAOeMMrq/92kIjIMS5QVhHXs6GuBE4Efu+cW2dmXQhd/yAiIvUoUBYkKwazoep0Rm89p5sAzCwH\naOKc+4OfgYmIyDcFKx2lFZVkxmvLwsw+NLOmZtYcmA88aWZ/9jc0EREJF6u75EHdu6GynXN7ge8R\nmjI7HDjTv7BERKSqPYFygLieOpvirdN0MV8PcIuISD3a7SWLnEZp9X7uuiaLuwldib3GOTfHzLoC\nq/wLS0REqtoZKAOgeVb9J4u6DnC/ArwStr0W+L5fQYmIyKF2e8miWby2LMysg5n908wKvcerZtbB\n7+BERORrO4vivGUB/J3Q8h4XeduXe2Vn+RGU1Mw5x97iCorLg5SUBykuD5JkxnGtG9f7zVBEpH7t\nKirDLDYD3HVNFrnOub+HbT9jZjf7EZDUbOnmPfzmX0tZsPHQ9RuHdM7hjnN6MTSveQwiE5H6sCtQ\nTnZmKslJ9f/FsK7JYoeZXc7XNygaC+zwJySpam9JOX9+dyXPzlpP86w0bju7JzmN0shITSIjNZnC\nvSU8+uEaLnp8Fmf2bs3to3rSo3WTWIctIlG2M1BG8xiMV0Ddk8VVwMPAg4ADZgLjfYqpQSrYFeDV\neZs5vVcr+nXIrtMxhftKeO+LQv48fSXb95dy+fDO3Hp2z2qboJcM7cSkT9fx+IdrOPsvH3Pj6T24\n+cwe6poSaUB2B8po1ig2y/LVdTbUBiD8hkV43VB/8SOohsI5x/yNu3j6k3W8vfQrKh1M/HgNf79y\nGMO6HNpdVBGs5P3lhXy6ejsz1+xgVeF+APq1z+bpcfn079CsxnNlpiVz/WnduXRYJ+55cxl/fW8V\nxeVBfnFOLyUMkQZiZ1E57ZtlxOTcR7Ma1S0oWRz04YpC/rVwC5XOEax0OAcbdwZYsnkPTTNSuGZk\nV87p25afvbyQcZNm8+QV+Zzco+XB41dt3cetUxazaNNuMlOTGdqlOd8f0oETu7agX/tskurYR5mT\nlcYDFw2gSXoKEz9eS1lFJb/9bh8lDJEGYFdRGX3bNY3JuY8mWejTx7NpZ4Drnp9PemoSzTJTSTLD\nDBpnpHL36OP5/uAOZKWH3urJE07kh0//l6v+MYfHLx/MyB65TJyxlr9MX0VWejJ/uWQg5/ZrS1rK\nkd/x1sy46/zjSUlO4ulP1lEerOSe0X3rnHBEJP4459gVKCMnBtNm4eiShYtcpeFzzvHLfy4hyeDN\nm06hfbPMWuvnNknnpWtO4IpJs7n2uXl0y23M8q/2cW6/Ntw9ui8tG6dHJS4z49fn9SY1OYnHP1pD\ncXmQ31/QLyarVYrI0SsuD1JaURmTpT4gQrIws31UnxQMqP1T8Rjx2vzNzFi1nd+df3zERHFATlYa\nz/9oOFc9M4d124t45NLBnNc/+jceNDNuH9WTzNRkHvzPShZu2s2DFw9kQMeaxz7qan9pBbe/upiU\nJOPWb/ekY/NGUYhYRGqy6+C6UHE4wO2c0/zLWmzfX8o9by5jcKdm/PCEzod1bHZmKi9feyIVlZWk\np/j3bd/M+J8ze5Cfl8Otryzie4/N5IbTunPD6d1JTT6yrq6te0u48u9zWLF1H6nJxltLvuLKk/L4\nyWndY3KxUDQEKx0fLC9kV6CMUX3b0CQjMX8Oabh2eVdvJ2I31DHvd/9eRqA0yB++3/+IxgOSk4zk\npPrpFjqpe0vevnkkd039nL++t4oPVhRyz+i+h93KWLl1H+MnzWZPcTlPj8unV5um/OndFUycsZaX\n527itrN7cenwTj79FNG3v7SCV+Zu4pmZ69mwIwDAnf/6nPP6t2XM0I4M6ZyjyQESF2K51Af4nCzM\nbBTwVyAZeMo5d1+V/Q8Cp3mbjYBWzrlm3r63gROAT5xz3/Ezzkg27gjwr4WbaZ2dQbfcxnTPbczc\nDTv596It/PTM4xLmArjszFQevGQgZ/VpzW9eX8roRz7lvH5tue3snuS1zIp4/Kw1O5jw3FwyUpP5\nv2tPpG/70PUif7poAONH5HHvm8v45T+X0DwrjVF92/j94xyVwn0lPPnxWibP3sS+0gqGdM7h52f3\non1OJv83ZxNTF25myrwCerRqzM1nHse5/dooaUhM7fIWEYxVN5Q55884tZklAysJrR9VAMwBxnq3\naK2u/o3AIOfcVd72GYQSyLV1SRb5+flu7ty50Qr/oE9Xb+f6F+cfXEf+gCSD7q0a88aNpxzVzKVY\n2VdSzpMz1vHUjND02kuHd+Ls49t4rZ3QY19JBcu27GXplj0s27KXdduL6N6qMc9cOZQOOYeOUZRV\nVHLho5/y1Z4S3v3pSFpEabA+mrbtK+WJj9bw/H83UB50nNevLVed3IWBVVpYRaUVvLn4S57+ZB0r\ntu5jUKdm/Orc3uRrORWJkWc+Xcdd/17GvF+fGdW/LTOb55zLj1jPx2RxInCXc+5sb/sXAM65/62h\n/kzgt8656WFlpwK3xiJZOOd4ZuZ67n3zC7rlZvHED/NJMlhduJ812/azcWeAy4Z3pnfb2Mx5jpbC\nfSU89N4qXpq9iWBl9b8LHXIy6dsum34dsrl8eGeya/lms/yrvZz/8Kec0bsVj142OOK38feXb+XR\nD9Zwy7ePY0S3lrXWPRprt+3npdkbee6zDaGkNqgDN57ePWKLKljpmDJvEw+8u5LCfaWcfXxrbh/V\ni665jX2LVaQ6D05fyUPvr2LVveeQcoTjjdWpa7LwsxuqPbApbLsAGF5dRTPrDHQB3j+cE5jZBGAC\nQKdO0esnL60I8pvXl/Ly3AK+3ac1f75kII296yQ6t8jijN6to3auWGvVJIN7L+jHdad2p2BngKB3\nUWGw0pGekkzvtk0Oa+38Xm2acvNZPfjj2yuYumgLowe2r7Fu4d4Sbnl5EXuKy7n0yf9y2fBO3HFO\nr0MGl/eXVrC4YDdLN+9h6ea9LN28h027AuQ0SqN10wxaN02nVdMMOuRk0qVFFp1bZNG5RSMK95Uy\nbcmXvLH4S774ci9JBhcMas+Np/egSx263SA0rnTJ0E58d0A7np6xjsc/WsN7X3zM5Sd05qYzesSs\n/1iOPbsCZTTNSI1qojgc8TLAPQaY4pwLHs5BzrmJwEQItSyOJoDKSseCTbt4e+lXTFvyFZt3F3PT\nGT24+Ywex8TFbO2bZdZ56m8kE07pyvRlW7nzX59zQtcWtG566PIEzjl+/upiisuC/PuGk3l9wWae\n/nQdHywv5J4L+pKZlszM1TuYuWY7iwr2HGz1tMvOoG/7bM7q05pdgTIK95WyeXcJ8zfuPjgAWNWQ\nzjn85jt9OLd
fG9pmH9nP2CgthRvP6MGYYZ34y39Cizq+Oq+A60/vzvgReWSk6voV8dfOorKYfjnx\nM1lsBjqGbXfwyqozBrjex1hqtH1/KQ+9t4p3Pv+KrXtLSU02Tu7eknsv6MtpvVrFIqSEl5KcxAMX\nDeDch2Zwx6uLmTR+6CHdUS/O3siHK7Zx13f70Ld9Nn3bZ3Nu/7b8fMpirv5HqDsxOcno3yGbH3+r\nK/l5zenXPrvWixb3lZSzYUeADTsCrN9RREZqMqP6tolaEoTQRZW/v7Af40fkcd9by7nvreU8O3M9\nN5zegx8M6ZCQ41eSGHYHymO2iCD4O2aRQmiA+wxCSWIOcKlz7vMq9XoBbwNdXJVg6mPMYn9pBSfd\n9z7DuzTn3H5tOb13K5pqjn1UTPpkHXe/sYzz+rflF+f0Ojgovn57Eef8dQZDOufw7FXDvtFyKykP\nMnXhFlo2SWNoXvO4v95h5urt3P/uChZs3E37ZpncdEZ3vje4wxFfwyJSk3P/OoO22Rk8PX5oVF83\n5mMWzrkKM7sBeIfQ1NlJzrnPzexuYK5zbqpXdQwwuZpEMQPoBTQ2swLgaufcO9GOs3F6CnN+daa+\nEfpg/Ig89hSX88THa5i+bCs/OrkL147sxi0vLyQ12bj/okOvT8lITebioR1reMX4M6J7S17r1oKP\nVm7jwekruf3VJTzywRp+eW5vzj6+tabbStTsDpTRJ0aLCIKPLYv65tfUWTl6W3YX86d3VvDags1k\npCZRUl7JX8cMrHXwOxE55/hgRSF/fHsFy7/ax8jjcvnd+ccfMpi+dW8JZRWVWiJFDkuv37zFD0/o\nzK/O6xPV1415y0LkgHbNMvnzJQMZf1Ie97+zgs4tGnH+gHaxDivqzIzTe7VmZI9cnp21gQenr+Ts\nBz/mmpFd6NqyMf9dt4P/rtvJhh0B0lKSmDzhBAZ3yol12JIAisuClJRXxmypD1CykHrUv0Mznru6\n2tnTDUpKchJXndyF7wxoy33TlvPIB2uA0BX0w7o05/LhnXnusw1MeHYe/7rhpKgOwEvD9PXV20oW\nIg1OqyYZ/PmSgVz7rW5UOkfP1k0OjtGc2jOX7z06k6ufmcOU60YcvI5HpDoHpoXHMlloVFfEZz3b\nNKF326bfGMzv0boJf7tsMKsK93Pz5AU1Xj0vAhxcbiiW11koWYjEyLeOy+Wu7/bhP18Uct9bX8Q6\nHIljO2O8iCCoG0okpn54Yh6rC/fz5Ix1bNpZzM9H9dS6U3KIWN/LApQsRGLuN9/pQ4vG6Tzx0Rqm\nf7GVscM68j9nHEduk/hbtVdi48AAd7MY3lxM3VAiMZaSnMRNZ/Tgw9tO47LhnZg8exPfuv8Dnvho\njcYyBAi1LJpmpMRsEUFQshCJG7lN0rl7dF+m3/ItRnRryf++tZzLnvqMLbuLYx2axNiuQHnMVzhW\nshCJM11aZvHkFUP44w/6s7hgD6P+8jFvLN4S67AkhnYFyg7rVgF+ULIQiUNmxsX5HZl20yl0zW3M\nDS8u4ObJC9i0MxDr0CQGYr08OShZiMS1vJZZvPLjE7npjB5MW/IVp/7pQ259ZRFrt+2PdWhSj3YH\nymN6QR4oWYjEvdTkJG456zhm3H4a407M443FWzjzzx9x00sLKNxXEuvwpB7sLCqL6TUWoGQhkjBa\nN83gzu/2YcbPT+eakV15d9lXXPjITFZu3Rfr0MRHJeVBisuDMb3GApQsRBJObpN0fnFOb165dgRl\nwUq+/9hMZq7eHuuwxCfxsIggKFmIJKx+HbJ5/fqTaJudwRWTZjNlXkGsQxIf7Co6sC6UuqFE5Ai1\nb5bJlOtGcELXFtz6yiIe+3BNrEOSKFPLQkSiomlGKn+/cijnD2jHH95ezr8Wbo51SBJFO+NgXSjQ\n2lAiDUJqchJ/umgAW/eWcNsri2nXLJOhec1jHZZEwW61LEQkmtJSknjih0PokJPJhGfnsn57UaxD\nkijY6Y1ZNNPUWRGJlmaN0pg0figAVz0z5+C3UklcuwJlNMlIITWGiwiCkoVIg5PXMouJV+RTsKuY\na5+bR3FZ8LCOLw9W+hSZHIldgdgv9QE+JwszG2VmK8xstZndUc3+B81sofdYaWa7w/aNM7NV3mOc\nn3GKNDRD85pz/0X9mb1+J1c9M4ei0oqIx5RVVPLzKYsYdPd0lm7eUw9RSl3sLIr9IoLgY7Iws2Tg\nEeAcoA8w1sz6hNdxzv3UOTfQOTcQeBh4zTu2OfBbYDgwDPitmeX4FatIQzR6YHv+cslAZq/fybhJ\ns9lbUl5j3T2BcsZNms3Lcwsw4PoX57OnuOb6Un92B8ppHuPxCvC3ZTEMWO2cW+ucKwMmA6NrqT8W\neMl7fjYw3Tm30zm3C5gOjPIxVpEGafTA9jw8dhALN+3mh0/9lz2BQxPAxh0BvvfYp8zdsJMHLhrA\n368cyuZdxdz2yiKc082XYm1nUVnMp82Cv1Nn2wObwrYLCLUUDmFmnYEuwPu1HNu+muMmABMAOnXq\ndPQRizRA5/ZrS1pyEj95YT5jn/yMHwzpQGpKEmnJRkWl48/vrqSi0vHc1cM5oWsLAO44pxf3vvkF\nT81YxzUjux58rWCl4z9fbCWnURrDumhqbn3YHSiL+bRZiJ/rLMYAU5xzhzUS55ybCEwEyM/P11cg\nkRqc2ac1T47L5/oX5nP3G8u+sa9zi0ZMGj+UbrmND5ZdfXIX5m3YxX1vL2dgp2YM7NiM1xds5rGP\n1rB2WxHNs9KYecfpZKQm1/ePckwpKQ9SVBaMiwFuP5PFZqBj2HYHr6w6Y4Drqxx7apVjP4xibCLH\nnG8dl8u835xJcVmQ8qCjPFhJebCSNtkZpKd880PfzPjjD/qz/G+fct3z80lPSWLz7mJ6t23KDad1\n528frOaNxV/ygyEdYvTTHBt2B+LjGgvwd8xiDtDDzLqYWRqhhDC1aiUz6wXkALPCit8Bvm1mOd7A\n9re9MhE5CukpyTRrlEZuk3TaNcukc4usQxLFAU0yUnn0ssGUlAdp3TSdSePzmXbTyfzs28fRo1Vj\n/jFzvcY0fHZgXajmDbkbyjlXYWY3EPqQTwYmOec+N7O7gbnOuQOJYwww2YX91jnndprZPYQSDsDd\nzrmdfsUqItXr3bYp839zFqnJhpkdLL9iRB6/eX0p8zfuZkhnTVT0y644WRcKfB6zcM5NA6ZVKbuz\nyvZdNRw7CZjkW3AiUidpKYd2QHxvUHv++PZy/jFzvZKFj3Z53VDxMMCtK7hF5LBlpadw0ZCOTFvy\nJYV7dWtXv2zYGVrfq0VjJQsRSVBXnNiZoHO88N+NsQ6lQQpWOl6avZGheTm0bJwe63CULETkyOS1\nzOLU43J5cfZGyiq0nlS0TV/2FZt2FnP1yV1iHQqgZCEiR2HciDy27SvlraVfxjqUBuepGevo2DyT\ns/q0iXUogJKFiByFkT1y6dIyi2dmro91KA3Kwk27mbthF1eO6EJykkU+
oB4oWYjIEUtKMsaPyGPB\nxt08NWNtrMNpMJ7+ZB1N0lO4eGjHyJXrSbws9yEiCeqy4Z2YvW4n9775BZlpyVw2vHOsQ0pom3cX\nM23Jl1x1Uh6N0+PnIzp+IhGRhJSSnMSDlwykuDzIr19fSmZqMt8brGVAjtSz3pXx40bkxTqUb1A3\nlIgctbSUJB69bDAndm3Bra8s4q0lGvA+EkWlFbw4eyPn9G1Lh5xGsQ7nG9SyEJGoyEhN5skr8rli\n0mxumryAEXM2kZqcRGqykZqcRMfmmZzXrx292zb5xtIh8rVX5m5iX0kFV58SH9NlwylZiEjUZKWn\n8Pcrh/Krfy5l484A5RWVB1e3fXPJlzzywRq65mbx3f7t+O6AdnRv1Tjyix5DXl+4hX7tsxncKf6W\nUFGyEJGoapqRysNjBx1SvrOojLeXfsXURZt56P1V/PW9VQzq1IxL8jvynQHt4mowNxacc6wu3M/3\nBh9yn7e4cGz/74hIvWmelcalwztx6fBObN1bwtSFW/i/uZu447Ul3P3GMs7r15abzuhBx+bx1Vdf\nX7buLWV/aUXctraULESk3rVumsE1I7vyo1O6sGDTbl6es4mpi7YwY9V2Jk84gbyWWbEOsd6tLtwP\nQPfc+EwWmg0lIjFjZgzulMN93+/Paz8ZQWlFkLFPfsaGHUWxDq3erS7cBxC3LQslCxGJC73aNOWF\nH51ASXmQsRM/Y+OOQKxDqleQzOwUAAAPNUlEQVSrt+2nSUYKuU1iv8JsdZQsRCRu9GnXlOd/NJxA\neaiFsWnnsZMwVhfup1tu47idVqxkISJx5fh22Tx/9XD2l1Zw8ROzWLZlb6xDqhdrthXFbRcUKFmI\nSBzq2z6bl645AYAfPD6T/yzbGuOI/LWnuJxt+0qVLEREDlefdk15/fqT6JbbmGuem8tTM9binIt1\nWL6I95lQoGQhInGsddMMXr72RM7u04Z73/yCX/5zKaUVwViHFXVrDiQLtSxERI5MZloyj142mOtO\n7cZLszdywSMzWbl1X6zDiqrV2/aTlpwU1xck+poszGyUma0ws9VmdkcNdS42s2Vm9rmZvRhW/gcz\nW+o9LvEzThGJb0lJxu2jevHkFfkU7i3hOw9/wqRP1lFZ2TC6pVYX7qdLy6y4uStedXy7gtvMkoFH\ngLOAAmCOmU11zi0Lq9MD+AVwknNul5m18srPAwYDA4F04EMze8s5d2xMixCRap3VpzUDO47k9lcX\nc/cby3h/eSFXn9KFttkZtGmaQXZmatxOPa3Nmm376dsuO9Zh1MrP5T6GAaudc2sBzGwyMBpYFlbn\nGuAR59wuAOdcoVfeB/jYOVcBVJjZYmAU8LKP8YpIAshtks7T4/J5cfZG7n3jCz5Zvf3gvozUJFo1\nySAnK43mjVJpnpVObpN0TuuZy9C85iTF4Tf3kvIgm3YGGD0wPhcQPMDPZNEe2BS2XQAMr1LnOAAz\n+xRIBu5yzr0NLAJ+a2YPAI2A0/hmksE7bgIwAaBTp07Rjl9E4pSZcdnwzpzbty1rtu3nq70lfLWn\nhK17SyjcV8quQDnb9peycut+CveV8PhHa2jTNIPz+rfluwPaMaBDdty0QNZtL6LSxffgNsR+IcEU\noAdwKtAB+NjM+jnn3jWzocBMYBswCzhkCoRzbiIwESA/P79hdF6KSJ3lZKWRn9W81jpFpRW8t7yQ\nfy/awnOzNvD0J+sYeVwuD48dRHZmaj1FWrNEmDYL/g5wbwY6hm138MrCFQBTnXPlzrl1wEpCyQPn\n3O+dcwOdc2cB5u0TETksWekpnD+gHU9ekc+cX5/Jr8/rzaw127nw0U9Ztz32CxauLtyPGXTNje+V\ndv1MFnOAHmbWxczSgDHA1Cp1XifUqsDMWhLqllprZslm1sIr7w/0B971MVYROQZkZ6byo1O68vzV\nw9lVVMYFj3zKzLAxj1hYs20/HXIyyUhNjmkckfiWLLzB6RuAd4AvgJedc5+b2d1mdr5X7R1gh5kt\nAz4AbnPO7QBSgRle+UTgcu/1RESO2vCuLZh6w8m0bprODyfN5vnPNsQsltWF++O+Cwp8HrNwzk0D\nplUpuzPsuQNu8R7hdUoIzYgSEfFFx+aNePW6EfzP5IX8+vWlbNldzG1n96zXge9gpWPt9iJO6dGy\n3s55pHQFt4gcs5pkpPLkFflcNrwTj364hp+9soiyisp6O3/BrgBlFZVxPxMKYj8bSkQkppKTjHsv\n6Evb7Az+9O5Ktu0r5bHLh9A43f+Px9UJsCbUAWpZiMgxz8y44fQe3P+D/sxcs4NLnpjF1r0lvp93\nzbZQsuiWAGMWShYiIp6L8jvy1Lh81m0v4rsPf8KCjbt8Pd/qwv20bJxGs0Zpvp4nGpQsRETCnNaz\nFa/9ZATpqUlcMvEzXp1X4Nu5DtxKNREoWYiIVNGrTVOmXn8y+Z1z+Nkri7j3jWVUBKM78O2cC02b\nTYDxClCyEBGpVk5WGv+4ahjjR+Tx1CfruGLSbLbtKz3q191bUs6r8wq48pk57C2poEeCJAvNhhIR\nqUFqchJ3nX88x7dryq9fX8p5D83g4bGDGN61xWG9Tnmwkve+KGTKvAI+XrmNsmAl7ZtlMmFkVy4c\n3MGn6KNLyUJEJIKL8jvSt302178wn7FPfsatZ/fkxyO7RVzyfOOOAJPnbOSVeQVs21dKqybpXH5C\nZ74zoC2DOjaLm5Vv60LJQkSkDnq3bcrUG0/mF68t4Y9vr2Du+l08NHZQtddjlJQHuf3Vxfxr4RaS\nDE7v1YqxwzrxreNySUlOzN5/JQsRkTpqnJ7CQ2MGMiwvh7v+vYwxE2cxafxQWjXJOFhnb0k5P/rH\nXOas38l1p3bjihM70zY7M4ZRR0dipjgRkRgxM354Yh5PjctnTWER339sJmu9i+u27StlzBOfMX/D\nLv5yyUBuH9WrQSQKULIQETkip/VsxeQJJxAoDfL9x2by5uIvuejxmazdvp+nxuXH/W1SD5eShYjI\nERrQsRmvXjeCppmpXP/ifHYWlfHCj4Zzas9WsQ4t6jRmISJyFPJaZvHqdSP42/urGTusEz3bNIl1\nSL5QshAROUotG6dz1/nHxzoMX6kbSkREIlKyEBGRiJQsREQkIiULERGJSMlCREQiUrIQEZGIlCxE\nRCQiJQsREYnInHOxjiEqzGwbsMHbzAb21PK8urKWwPbDPG3469R1X9XymrZrizvasda0P1JZIr23\ndY1b723De2/rEvux/N52ds7lRqztnGtwD2Bibc9rKJt7NOep676q5TVt1xZ3tGOtaX+kskR6b+sa\nt97bhvfe1iV2vbeRHw21G+rfEZ7XtP9ozlPXfVXLa9qOFPfhinRsdfsjlSXSe3s4cR8uvbe1P4/1\ne1uX2PXeRtBguqGOlpnNdc7lxzqOukikWCGx4k2kWCGx4k2kWCGx4q2PWBtqy+JITIx1AIchkWKF\nxIo3kWKFxIo3kWKFxIrX91j
VshARkYjUshARkYgaXLIws0lmVmhmS4/g2CFmtsTMVpvZQ2ZmXvn/\nmdlC77HezBbGc7zevhvNbLmZfW5mf4zXWM3sLjPbHPb+nhuNWP2KN2z/z8zMmVnLeI3VzO4xs8Xe\n+/qumbWLRqw+xnu/9zu72Mz+aWbN4jjWi7y/rUozO+qxgqOJsYbXG2dmq7zHuLDyWn+va3W4063i\n/QGMBAYDS4/g2NnACYABbwHnVFPnAeDOeI4XOA34D5DubbeK41jvAm5NpN8FoCPwDqHrelrGa6xA\n07A6NwGPx/N7C3wbSPGe/wH4QxzH2hvoCXwI5McqRu/8eVXKmgNrvX9zvOc5kX6vIz0aXMvCOfcx\nsDO8zMy6mdnbZjbPzGaYWa+qx5lZW0J/XJ+50Lv6LHBBlToGXAy8FOfxXgfc55wr9c5RGMex+sbH\neB8Efg5EbcDPj1idc3vDqmYlQLzvOucqvKqfAR3iONYvnHMrohHf0cRYg7OB6c65nc65XcB0YNTR\n/h02uGRRg4nAjc65IcCtwKPV1GkPFIRtF3hl4U4BtjrnVvkS5deONt7jgFPM7L9m9pGZDY3jWAFu\n8LoeJplZjn+hAkcZr5mNBjY75xb5HCdE4b01s9+b2SbgMuBOH2OF6P2dAVxF6JuvX6IZq1/qEmN1\n2gObwrYPxH1UP0+Dvwe3mTUGRgCvhHXPpR/hy40liq2K6kQp3hRCTdATgKHAy2bW1fs2ETVRivUx\n4B5C33rvIdTNd1W0Ygx3tPGaWSPgl4S6S3wVrd9b59yvgF+Z2S+AG4DfRi3IMNH8OzOzXwEVwAvR\nie6Q14/mZ4IvaovRzK4E/scr6w5MM7MyYJ1z7kK/YmrwyYJQ62m3c25geKGZJQPzvM2phD60wpu9\nHYDNYfVTgO8BQ3yNNjrxFgCveclhtplVElo7Zlu8xeqc2xp23JPAG1GOMZrxdgO6AIu8P+AOwHwz\nG+ac+yrOYq3qBWAaPiULovd3Nh74DnBGtL/cRDtWn1UbI4Bz7u/A3wHM7ENgvHNufViVzcCpYdsd\nCI1tbOZofp6jHZiJxweQR9hAETATuMh7bsCAGo6rOvhzbti+UcBHiRAv8GPgbu/5cYSapBansbYN\nq/NTYHI8v7dV6qwnSgPcPr23PcLq3AhMief31vsbWwbkRjNOP38PiNIA95HGSM0D3OsIDW7neM+b\n1/X3usb4ov2fEusHoW6iL4FyQt+wryb0bfBtYJH3y1jtbCYgH1gKrAH+RtgHLPAM8ONEiBdIA573\n9s0HTo/jWJ8DlgCLCX2baxuNWP38XQirs57ozYby47191StfTGgdoPbx/N4Cqwl9sVnoPaIye8un\nWC/0XqsU2Aq8E4sYqSZZeOVXee/nauDKw/m9rumhK7hFRCSiY2U2lIiIHAUlCxERiUjJQkREIlKy\nEBGRiJQsREQkIiULadDMbH89n+8pM+sTpdcKWmjF2KVm9m+LsAqrmTUzs59E49wiVWnqrDRoZrbf\nOdc4iq+X4r5e7M5X4bGb2T+Alc6539dSPw94wznXtz7ik2OLWhZyzDGzXDN71czmeI+TvPJhZjbL\nzBaY2Uwz6+mVjzezqWb2PvCemZ1qZh+a2RQL3X/hBW9FYrzyfO/5fm8hv0Vm9pmZtfbKu3nbS8zs\n3jq2fmbx9WKGjc3sPTOb773GaK/OfUA3rzVyv1f3Nu9nXGxmv4vi2yjHGCULORb9FXjQOTcU+D7w\nlFe+HDjFOTeI0Aqt/y/smMHAD5xz3/K2BwE3A32ArsBJ1ZwnC/jMOTcA+Bi4Juz8f3XO9eObq4BW\ny1uz6AxCV7gDlAAXOucGE7p3yQNesroDWOOcG+icu83Mvg30AIYBA4EhZjYy0vlEqnMsLCQoUtWZ\nQJ+w1Tybeqt8ZgP/MLMehFbBTQ07ZrpzLvx+A7OdcwUAFrpzYh7wSZXzlPH1wojzgLO85yfy9X0E\nXgT+VEOcmd5rtwe+IHRfAgit6/P/vA/+Sm9/62qO/7b3WOBtNyaUPD6u4XwiNVKykGNREnCCc64k\nvNDM/gZ84Jy70Ov//zBsd1GV1ygNex6k+r+lcvf1oGBNdWpT7Jwb6C2N/g5wPfAQoXtT5AJDnHPl\nZrYeyKjmeAP+1zn3xGGeV+QQ6oaSY9G7hFZhBcDMDiwDnc3XSzaP9/H8nxHq/gIYE6mycy5A6Lao\nP/OWys8GCr1EcRrQ2au6D2gSdug7wFVeqwkza29mraL0M8gxRslCGrpGZlYQ9riF0Advvjfou4zQ\nku4AfwT+18wW4G+r+2bgFjNbTOjmNXsiHeCcW0Bo9dixhO5NkW9mS4ArCI214JzbAXzqTbW93zn3\nLqFurlle3Sl8M5mI1JmmzorUM69bqdg558xsDDDWOTc60nEisaQxC5H6NwT4mzeDaTc+3UZWJJrU\nshARkYg0ZiEiIhEpWYiISERKFiIiEpGShYiIRKRkISIiESlZiIhIRP8faKedvOAPLt4AAAAASUVO\nRK5CYII=\n", 625 | "text/plain": [ 626 | "
" 627 | ] 628 | }, 629 | "metadata": { 630 | "tags": [] 631 | } 632 | } 633 | ] 634 | }, 635 | { 636 | "metadata": { 637 | "id": "zLNRa2nP1zrR", 638 | "colab_type": "code", 639 | "outputId": "573f1fec-0df2-4e7a-8b36-79f23e7b1828", 640 | "colab": { 641 | "base_uri": "https://localhost:8080/", 642 | "height": 80 643 | } 644 | }, 645 | "cell_type": "code", 646 | "source": [ 647 | "# run one cycle training with max learning rate equal to 1/10 of min learning rate in plot\n", 648 | "\n", 649 | "# train final layer for 1 epoch\n", 650 | "cla_learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, moms=(0.8, 0.7))" 651 | ], 652 | "execution_count": 0, 653 | "outputs": [ 654 | { 655 | "output_type": "display_data", 656 | "data": { 657 | "text/html": [ 658 | "\n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | "
epochtrain_lossvalid_lossaccuracytime
00.6222800.5989380.65886401:13
" 678 | ], 679 | "text/plain": [ 680 | "" 681 | ] 682 | }, 683 | "metadata": { 684 | "tags": [] 685 | } 686 | } 687 | ] 688 | }, 689 | { 690 | "metadata": { 691 | "id": "TyuC7O885orY", 692 | "colab_type": "code", 693 | "outputId": "cbddc552-41f7-45f7-ce5f-02fbbad3b67c", 694 | "colab": { 695 | "base_uri": "https://localhost:8080/", 696 | "height": 204 697 | } 698 | }, 699 | "cell_type": "code", 700 | "source": [ 701 | "# gradual unfreeze to train model\n", 702 | "cla_learn.freeze_to(-2)\n", 703 | "cla_learn.fit_one_cycle(5, max_lr=1e-3, moms=(0.8,0.7))" 704 | ], 705 | "execution_count": 0, 706 | "outputs": [ 707 | { 708 | "output_type": "display_data", 709 | "data": { 710 | "text/html": [ 711 | "\n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | "
epochtrain_lossvalid_lossaccuracytime
00.5962620.5785580.67948801:27
10.5696170.5412800.71058101:25
20.5438960.5148470.73141201:28
30.5260320.5043790.74218801:30
40.5230160.5029110.74275501:27
" 759 | ], 760 | "text/plain": [ 761 | "" 762 | ] 763 | }, 764 | "metadata": { 765 | "tags": [] 766 | } 767 | } 768 | ] 769 | }, 770 | { 771 | "metadata": { 772 | "id": "0mBqyYgMGsvk", 773 | "colab_type": "code", 774 | "colab": {} 775 | }, 776 | "cell_type": "code", 777 | "source": [ 778 | "" 779 | ], 780 | "execution_count": 0, 781 | "outputs": [] 782 | } 783 | ] 784 | } --------------------------------------------------------------------------------