├── data
│   └── imgs
│       ├── Slide2.JPG
│       └── dfhead.png
├── LICENSE
├── README.md
└── Sentiment_Analysis_torchtext.ipynb

/data/imgs/Slide2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpanwar08/sentiment-analysis-torchtext/HEAD/data/imgs/Slide2.JPG
--------------------------------------------------------------------------------
/data/imgs/dfhead.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpanwar08/sentiment-analysis-torchtext/HEAD/data/imgs/dfhead.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Himanshu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sentiment Analysis in Torch Text
2 | Sentiment analysis is a classification task where each sample is assigned a positive or negative label.
3 |
4 | This repo contains the code for this [blog post](https://medium.com/@sonicboom8/sentiment-analysis-torchtext-55fb57b1fab8).
5 |
6 | ## Typical components of a classification task in NLP
7 | 1. Preprocessing and tokenization
8 | 2. Generating a vocabulary of unique tokens and converting words to indices
9 | 3. Loading pretrained vectors, e.g. GloVe, Word2vec, fastText
10 | 4. Padding text with zeros in case of variable lengths
11 | 5. Data loading and batching
12 | 6. Model creation and training
13 |
14 | ## Why use torchtext
15 | Torchtext provides a set of classes that are useful in NLP tasks. These classes take care of the first five points above with minimal code.
16 |
17 | ## Prerequisites
18 | * Python 3.6
19 | * [Pytorch 0.4](http://pytorch.org/)
20 | * [TorchText 0.2.3](https://github.com/pytorch/text)
21 | * Understanding of GRU/LSTM [1]
22 |
23 | ## What is covered in the [notebook](Sentiment_Analysis_torchtext.ipynb)
24 |
25 | 1. Train validation split
26 | 2. Define how to process data
27 | 3. Create torchtext dataset
28 | 4. Load pretrained word vectors and build the vocabulary
29 | 5. Load the data in batches
30 | 6. Simple GRU model
31 | 7. GRU model with concat pooling
32 | 8. 
Training 33 | 34 | 35 | ## Data Overview 36 | 37 | ![Top 5 rows of dataset](data/imgs/dfhead.png "Top 5 rows of dataset") 38 | 39 | ## Concat Pooling model architecture [2] 40 | ![GRU model with concat pooling](data/imgs/Slide2.JPG "GRU model with concat pooling") 41 | 42 | ## References 43 | [1] https://colah.github.io/posts/2015-08-Understanding-LSTMs/ 44 | [2] https://arxiv.org/abs/1801.06146 45 | -------------------------------------------------------------------------------- /Sentiment_Analysis_torchtext.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sentiment Analysis in torchtext" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "import os, sys\n", 25 | "import re\n", 26 | "import string\n", 27 | "import pathlib\n", 28 | "import random\n", 29 | "from collections import Counter, OrderedDict\n", 30 | "import numpy as np\n", 31 | "import pandas as pd\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "import seaborn as sns\n", 34 | "import spacy\n", 35 | "from tqdm import tqdm, tqdm_notebook, tnrange\n", 36 | "tqdm.pandas(desc='Progress')\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.optim as optim\n", 41 | "from torch.autograd import Variable\n", 42 | "import torch.nn.functional as F\n", 43 | "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n", 44 | "\n", 45 | "import torchtext\n", 46 | "from torchtext import data\n", 47 | "from torchtext import vocab\n", 48 | "\n", 49 | "from sklearn.model_selection import StratifiedShuffleSplit, train_test_split\n", 50 | "from sklearn.metrics import accuracy_score\n", 51 | "\n", 52 | "from 
IPython.core.interactiveshell import InteractiveShell\n", 53 | "InteractiveShell.ast_node_interactivity='all'\n", 54 | "\n", 55 | "import warnings\n", 56 | "warnings.filterwarnings('ignore')\n", 57 | "\n", 58 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 48, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Python version: 3.6.4 | packaged by conda-forge | (default, Dec 23 2017, 16:31:06) \n", 71 | "[GCC 4.8.2 20140120 (Red Hat 4.8.2-15)]\n", 72 | "Pandas version: 0.22.0\n", 73 | "Pytorch version: 0.4.0\n", 74 | "Torch Text version: 0.2.3\n", 75 | "Spacy version: 2.0.8\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "print('Python version:',sys.version)\n", 81 | "print('Pandas version:',pd.__version__)\n", 82 | "print('Pytorch version:', torch.__version__)\n", 83 | "print('Torch Text version:', torchtext.__version__)\n", 84 | "print('Spacy version:', spacy.__version__)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "### Load data" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 2, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "data_root = pathlib.Path('./data')" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stderr", 110 | "output_type": "stream", 111 | "text": [ 112 | "b'Skipping line 8836: expected 4 fields, saw 5\\n'\n", 113 | "b'Skipping line 535882: expected 4 fields, saw 7\\n'\n" 114 | ] 115 | }, 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "(1578612, 4)" 120 | ] 121 | }, 122 | "execution_count": 3, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | }, 126 | { 127 | "data": { 128 | "text/html": [ 129 | "
\n", 130 | "\n", 143 | "\n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | "
|   | ItemID | Sentiment | SentimentSource | SentimentText |
|---|---|---|---|---|
| 0 | 1 | 0 | Sentiment140 | is so sad for my APL frie... |
| 1 | 2 | 0 | Sentiment140 | I missed the New Moon trail... |
| 2 | 3 | 1 | Sentiment140 | omg its already 7:30 :O |
| 3 | 4 | 0 | Sentiment140 | .. Omgaga. Im sooo im gunna CRy. I'... |
| 4 | 5 | 0 | Sentiment140 | i think mi bf is cheating on me!!! ... |
\n", 191 | "
" 192 | ], 193 | "text/plain": [ 194 | " ItemID Sentiment SentimentSource \\\n", 195 | "0 1 0 Sentiment140 \n", 196 | "1 2 0 Sentiment140 \n", 197 | "2 3 1 Sentiment140 \n", 198 | "3 4 0 Sentiment140 \n", 199 | "4 5 0 Sentiment140 \n", 200 | "\n", 201 | " SentimentText \n", 202 | "0 is so sad for my APL frie... \n", 203 | "1 I missed the New Moon trail... \n", 204 | "2 omg its already 7:30 :O \n", 205 | "3 .. Omgaga. Im sooo im gunna CRy. I'... \n", 206 | "4 i think mi bf is cheating on me!!! ... " 207 | ] 208 | }, 209 | "execution_count": 3, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | } 213 | ], 214 | "source": [ 215 | "df = pd.read_csv(data_root/'Sentiment Analysis Dataset.csv', error_bad_lines=False)\n", 216 | "df.shape\n", 217 | "df.head()" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 4, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "data": { 227 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgoAAAFACAYAAADd6lTCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGwhJREFUeJzt3X+w3fVd5/HnyyAttVIoDSwmVNBmXGm1FLIQ7W63FjeE+iPMTpmluyvZymycLnX9sVul7o7Rtox1dewu3cpuViKho6X0h0vUYJql1a4OUEJBKMWaK1W4BktoKFKppanv/eN8sh5uz+fec1JO7s3N8zFz5ny/7+/n8/187h+Z88r3Z6oKSZKkUb5usScgSZKWLoOCJEnqMihIkqQug4IkSeoyKEiSpC6DgiRJ6jIoSJKkLoOCJEnqMihIkqSu4xZ7AkvFi170ojrzzDMXexqSJB0Rd91112NVtXKhdgaF5swzz2TPnj2LPQ1Jko6IJH8xTjtPPUiSpC6DgiRJ6jIoSJKkLoOCJEnqmmpQSPITSe5P8skk703y3CRnJbkjyd4k70tyfGv7nLY+07afObSft7T6p5NcNFTf0GozSa4aqo8cQ5IkTWZqQSHJKuDfA2ur6mXACuAy4BeBd1bVGuBx4IrW5Qrg8ap6CfDO1o4kZ7d+LwU2AL+aZEWSFcC7gYuBs4HXt7bMM4YkSZrAtE89HAeckOQ44HnAI8BrgA+07duBS9ryxrZO235hkrT6jVX1par6DDADnN8+M1X1YFU9DdwIbGx9emNIkqQJTC0oVNVfAr8MPMQgIDwB3AV8vqoOtmazwKq2vAp4uPU92NqfMlyf06dXP2WeMSRJ0gSmeerhZAZHA84Cvgn4BganCeaqQ106256t+qg5bk6yJ8me/fv3j2oiSdIxbZqnHr4X+ExV7a+qLwMfAr4bOKmdigBYDexry7PAGQBt+wuAA8P1OX169cfmGeMZqmprVa2tqrUrVy74FEtJko450wwKDwHrkjyvXTdwIfAp4KPA61qbTcDNbX
[remaining base64-encoded PNG data omitted: bar plot of Sentiment label counts]\n", 228 | "text/plain": [ 229 | "
" 230 | ] 231 | }, 232 | "metadata": {}, 233 | "output_type": "display_data" 234 | } 235 | ], 236 | "source": [ 237 | "fig = plt.figure(figsize=(8,5))\n", 238 | "ax = sns.barplot(x=df.Sentiment.value_counts().index, y=df.Sentiment.value_counts().values);\n", 239 | "ax.set(xlabel='Labels');" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## Train validation split" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "##### torchtext has trouble handling \\n. Replace each \\n character with a space." 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 5, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stderr", 263 | "output_type": "stream", 264 | "text": [ 265 | "Progress: 100%|██████████| 1578612/1578612 [00:02<00:00, 655831.85it/s]\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "df['SentimentText'] = df.SentimentText.progress_apply(lambda x: re.sub('\\n', ' ', x))" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 6, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def split_train_test(df, test_size=0.2):\n", 280 | " train, val = train_test_split(df, test_size=test_size,random_state=42)\n", 281 | " return train.reset_index(drop=True), val.reset_index(drop=True)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 7, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "traindf, valdf = split_train_test(df, test_size=0.2)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 8, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "(1262889, 4)" 302 | ] 303 | }, 304 | "execution_count": 8, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | }, 308 | { 309 | "data": { 310 | "text/html": [ 311 | "
\n", 312 | "\n", 325 | "\n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | "
|   | ItemID | Sentiment | SentimentSource | SentimentText |
|---|---|---|---|---|
| 0 | 363919 | 1 | Sentiment140 | @p3cia hihi.. already looked |
| 1 | 1002689 | 1 | Sentiment140 | @lizzylou62 Good luck with the exams! |
| 2 | 1257543 | 0 | Sentiment140 | The krispy kreme in CT is so closed |
| 3 | 495896 | 1 | Sentiment140 | @TomJ93 because of what @_nanu_ said |
| 4 | 445470 | 0 | Sentiment140 | @TellYaFriday I have nothing else to do...i'm... |
\n", 373 | "
" 374 | ], 375 | "text/plain": [ 376 | " ItemID Sentiment SentimentSource \\\n", 377 | "0 363919 1 Sentiment140 \n", 378 | "1 1002689 1 Sentiment140 \n", 379 | "2 1257543 0 Sentiment140 \n", 380 | "3 495896 1 Sentiment140 \n", 381 | "4 445470 0 Sentiment140 \n", 382 | "\n", 383 | " SentimentText \n", 384 | "0 @p3cia hihi.. already looked \n", 385 | "1 @lizzylou62 Good luck with the exams! \n", 386 | "2 The krispy kreme in CT is so closed \n", 387 | "3 @TomJ93 because of what @_nanu_ said \n", 388 | "4 @TellYaFriday I have nothing else to do...i'm... " 389 | ] 390 | }, 391 | "execution_count": 8, 392 | "metadata": {}, 393 | "output_type": "execute_result" 394 | }, 395 | { 396 | "data": { 397 | "text/plain": [ 398 | "1 632124\n", 399 | "0 630765\n", 400 | "Name: Sentiment, dtype: int64" 401 | ] 402 | }, 403 | "execution_count": 8, 404 | "metadata": {}, 405 | "output_type": "execute_result" 406 | } 407 | ], 408 | "source": [ 409 | "traindf.shape\n", 410 | "traindf.head()\n", 411 | "traindf.Sentiment.value_counts()" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 9, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "data": { 421 | "text/plain": [ 422 | "(315723, 4)" 423 | ] 424 | }, 425 | "execution_count": 9, 426 | "metadata": {}, 427 | "output_type": "execute_result" 428 | }, 429 | { 430 | "data": { 431 | "text/html": [ 432 | "
\n", 433 | "\n", 446 | "\n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | "
|   | ItemID | Sentiment | SentimentSource | SentimentText |
|---|---|---|---|---|
| 0 | 1432717 | 1 | Sentiment140 | http://www.popsugar.com/2999655 keep voting fo... |
| 1 | 815480 | 1 | Sentiment140 | I follow @actionchick because she always has ... |
| 2 | 1143701 | 1 | Sentiment140 | Slow This Dance Now |
| 3 | 1044045 | 0 | Sentiment140 | no win on the ipod for tonight |
| 4 | 979854 | 0 | Sentiment140 | @LegendaryWriter tell me about it |
\n", 494 | "
" 495 | ], 496 | "text/plain": [ 497 | " ItemID Sentiment SentimentSource \\\n", 498 | "0 1432717 1 Sentiment140 \n", 499 | "1 815480 1 Sentiment140 \n", 500 | "2 1143701 1 Sentiment140 \n", 501 | "3 1044045 0 Sentiment140 \n", 502 | "4 979854 0 Sentiment140 \n", 503 | "\n", 504 | " SentimentText \n", 505 | "0 http://www.popsugar.com/2999655 keep voting fo... \n", 506 | "1 I follow @actionchick because she always has ... \n", 507 | "2 Slow This Dance Now \n", 508 | "3 no win on the ipod for tonight \n", 509 | "4 @LegendaryWriter tell me about it " 510 | ] 511 | }, 512 | "execution_count": 9, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | }, 516 | { 517 | "data": { 518 | "text/plain": [ 519 | "1 158053\n", 520 | "0 157670\n", 521 | "Name: Sentiment, dtype: int64" 522 | ] 523 | }, 524 | "execution_count": 9, 525 | "metadata": {}, 526 | "output_type": "execute_result" 527 | } 528 | ], 529 | "source": [ 530 | "valdf.shape\n", 531 | "valdf.head()\n", 532 | "valdf.Sentiment.value_counts()" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "##### Save the train and validation df" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 10, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "traindf.to_csv(data_root/'traindf.csv', index=False)\n", 549 | "valdf.to_csv(data_root/'valdf.csv', index=False)" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": {}, 555 | "source": [ 556 | "## 1. 
Define how to process data" 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": {}, 562 | "source": [ 563 | "##### Preprocessing" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 11, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "nlp = spacy.load('en',disable=['parser', 'tagger', 'ner'])\n", 573 | "def tokenizer(s): return [w.text.lower() for w in nlp(tweet_clean(s))]" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 12, 579 | "metadata": {}, 580 | "outputs": [], 581 | "source": [ 582 | "def tweet_clean(text):\n", 583 | " text = re.sub(r'[^A-Za-z0-9]+', ' ', text) # remove non alphanumeric character\n", 584 | " text = re.sub(r'https?:/\\/\\S+', ' ', text) # remove links\n", 585 | " return text.strip()" 586 | ] 587 | }, 588 | { 589 | "cell_type": "markdown", 590 | "metadata": {}, 591 | "source": [ 592 | "##### Define fields" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 13, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "txt_field = data.Field(sequential=True, tokenize=tokenizer, include_lengths=True, use_vocab=True)\n", 602 | "label_field = data.Field(sequential=False, use_vocab=False, pad_token=None, unk_token=None)\n", 603 | "\n", 604 | "train_val_fields = [\n", 605 | " ('ItemID', None),\n", 606 | " ('Sentiment', label_field),\n", 607 | " ('SentimentSource', None),\n", 608 | " ('SentimentText', txt_field)\n", 609 | "]" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "metadata": {}, 615 | "source": [ 616 | "## 2. 
Create torchtext dataset" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 14, 622 | "metadata": {}, 623 | "outputs": [ 624 | { 625 | "name": "stdout", 626 | "output_type": "stream", 627 | "text": [ 628 | "CPU times: user 2min 55s, sys: 787 ms, total: 2min 56s\n", 629 | "Wall time: 2min 56s\n" 630 | ] 631 | } 632 | ], 633 | "source": [ 634 | "%%time\n", 635 | "trainds, valds = data.TabularDataset.splits(path='./data', format='csv', train='traindf.csv', validation='valdf.csv', fields=train_val_fields, skip_header=True)" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": 15, 641 | "metadata": {}, 642 | "outputs": [ 643 | { 644 | "data": { 645 | "text/plain": [ 646 | "torchtext.data.dataset.TabularDataset" 647 | ] 648 | }, 649 | "execution_count": 15, 650 | "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "type(trainds)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 16, 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/plain": [ 666 | "(1262889, 315723)" 667 | ] 668 | }, 669 | "execution_count": 16, 670 | "metadata": {}, 671 | "output_type": "execute_result" 672 | } 673 | ], 674 | "source": [ 675 | "len(trainds), len(valds)" 676 | ] 677 | }, 678 | { 679 | "cell_type": "code", 680 | "execution_count": 17, 681 | "metadata": {}, 682 | "outputs": [ 683 | { 684 | "data": { 685 | "text/plain": [ 686 | "torchtext.data.example.Example" 687 | ] 688 | }, 689 | "execution_count": 17, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | }, 693 | { 694 | "data": { 695 | "text/plain": [ 696 | "dict_items([('ItemID', None), ('Sentiment', ), ('SentimentSource', None), ('SentimentText', )])" 697 | ] 698 | }, 699 | "execution_count": 17, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | }, 703 | { 704 | "data": { 705 | "text/plain": [ 706 | "'1'" 707 | ] 708 | }, 709 | "execution_count": 17, 710 | 
"metadata": {}, 711 | "output_type": "execute_result" 712 | }, 713 | { 714 | "data": { 715 | "text/plain": [ 716 | "['p3cia', 'hihi', 'already', 'looked']" 717 | ] 718 | }, 719 | "execution_count": 17, 720 | "metadata": {}, 721 | "output_type": "execute_result" 722 | } 723 | ], 724 | "source": [ 725 | "ex = trainds[0]\n", 726 | "type(ex)\n", 727 | "trainds.fields.items()\n", 728 | "ex.Sentiment\n", 729 | "ex.SentimentText" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": 18, 735 | "metadata": {}, 736 | "outputs": [ 737 | { 738 | "data": { 739 | "text/plain": [ 740 | "torchtext.data.example.Example" 741 | ] 742 | }, 743 | "execution_count": 18, 744 | "metadata": {}, 745 | "output_type": "execute_result" 746 | }, 747 | { 748 | "data": { 749 | "text/plain": [ 750 | "'1'" 751 | ] 752 | }, 753 | "execution_count": 18, 754 | "metadata": {}, 755 | "output_type": "execute_result" 756 | }, 757 | { 758 | "data": { 759 | "text/plain": [ 760 | "['http',\n", 761 | " 'www',\n", 762 | " 'popsugar',\n", 763 | " 'com',\n", 764 | " '2999655',\n", 765 | " 'keep',\n", 766 | " 'voting',\n", 767 | " 'for',\n", 768 | " 'robert',\n", 769 | " 'pattinson',\n", 770 | " 'in',\n", 771 | " 'the',\n", 772 | " 'popsugar100',\n", 773 | " 'as',\n", 774 | " 'well']" 775 | ] 776 | }, 777 | "execution_count": 18, 778 | "metadata": {}, 779 | "output_type": "execute_result" 780 | } 781 | ], 782 | "source": [ 783 | "ex = valds[0]\n", 784 | "type(ex)\n", 785 | "ex.Sentiment\n", 786 | "ex.SentimentText" 787 | ] 788 | }, 789 | { 790 | "cell_type": "markdown", 791 | "metadata": {}, 792 | "source": [ 793 | "## 3. 
Load pretrained word vectors and building vocabulary" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 19, 799 | "metadata": {}, 800 | "outputs": [ 801 | { 802 | "name": "stdout", 803 | "output_type": "stream", 804 | "text": [ 805 | "CPU times: user 311 ms, sys: 464 ms, total: 775 ms\n", 806 | "Wall time: 1.31 s\n" 807 | ] 808 | } 809 | ], 810 | "source": [ 811 | "%%time\n", 812 | "vec = vocab.Vectors('glove.twitter.27B.100d.txt', './data/glove_embedding/')" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": 20, 818 | "metadata": {}, 819 | "outputs": [ 820 | { 821 | "name": "stdout", 822 | "output_type": "stream", 823 | "text": [ 824 | "CPU times: user 9.28 s, sys: 36 ms, total: 9.32 s\n", 825 | "Wall time: 9.32 s\n" 826 | ] 827 | } 828 | ], 829 | "source": [ 830 | "%%time\n", 831 | "txt_field.build_vocab(trainds, valds, max_size=100000, vectors=vec)\n", 832 | "label_field.build_vocab(trainds)" 833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": 21, 838 | "metadata": {}, 839 | "outputs": [ 840 | { 841 | "data": { 842 | "text/plain": [ 843 | "torch.Size([100002, 100])" 844 | ] 845 | }, 846 | "execution_count": 21, 847 | "metadata": {}, 848 | "output_type": "execute_result" 849 | } 850 | ], 851 | "source": [ 852 | "txt_field.vocab.vectors.shape" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 22, 858 | "metadata": {}, 859 | "outputs": [ 860 | { 861 | "data": { 862 | "text/plain": [ 863 | "tensor([ 0.0952, 0.3702, 0.5429, 0.1962, 0.0482, 0.3203, -0.5964,\n", 864 | " 0.0159, -0.1299, -0.6303, 0.0819, 0.2416, -6.0990, -0.6856,\n", 865 | " 0.5035, -0.0341, 0.1170, -0.0077, -0.0865, 0.4362, -0.4398,\n", 866 | " 0.2612, -0.0403, -0.1919, 0.0832, -0.5825, -0.0319, 0.1263,\n", 867 | " 0.4012, 0.0689, -0.1052, -0.2080, -0.4255, 0.4780, 0.3465,\n", 868 | " 0.2406, 0.0502, -0.0726, -0.0024, -0.5034, -1.0601, -0.3159,\n", 869 | " -0.0325, -0.0763, 0.7904, 0.0864, -0.1963, 
0.0576, 0.8413,\n", 870 | " -0.4202, -0.0011, -0.0856, 0.0619, 0.2142, -0.1036, -0.0369,\n", 871 | " -0.2600, -0.3566, 0.0543, 0.0309, 0.1409, -0.0920, -0.4184,\n", 872 | " -0.3113, -0.1494, -0.0002, -0.3345, -0.1485, -0.1194, -0.2717,\n", 873 | " 0.3132, -0.1100, -0.4752, 0.1406, 0.3964, -0.0494, -0.4260,\n", 874 | " -0.2358, 0.0615, -0.0353, 2.4161, 0.2898, 0.3888, 0.3678,\n", 875 | " 0.2069, 0.1399, -0.4246, 0.4459, 0.2623, -0.4483, 0.0037,\n", 876 | " -0.2252, 0.1476, -0.3642, -0.1849, 0.2228, 0.4763, -0.5108,\n", 877 | " 0.4688, 0.3488])" 878 | ] 879 | }, 880 | "execution_count": 22, 881 | "metadata": {}, 882 | "output_type": "execute_result" 883 | } 884 | ], 885 | "source": [ 886 | "txt_field.vocab.vectors[txt_field.vocab.stoi['the']]" 887 | ] 888 | }, 889 | { 890 | "cell_type": "markdown", 891 | "metadata": {}, 892 | "source": [ 893 | "## 4. Loading the data in batches" 894 | ] 895 | }, 896 | { 897 | "cell_type": "code", 898 | "execution_count": 24, 899 | "metadata": {}, 900 | "outputs": [], 901 | "source": [ 902 | "traindl, valdl = data.BucketIterator.splits(datasets=(trainds, valds), \n", 903 | " batch_sizes=(3,3), \n", 904 | " sort_key=lambda x: len(x.SentimentText), \n", 905 | " device=None, \n", 906 | " sort_within_batch=True, \n", 907 | " repeat=False)" 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": 25, 913 | "metadata": {}, 914 | "outputs": [ 915 | { 916 | "data": { 917 | "text/plain": [ 918 | "(420963, 105241)" 919 | ] 920 | }, 921 | "execution_count": 25, 922 | "metadata": {}, 923 | "output_type": "execute_result" 924 | } 925 | ], 926 | "source": [ 927 | "len(traindl), len(valdl)" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": 26, 933 | "metadata": {}, 934 | "outputs": [ 935 | { 936 | "data": { 937 | "text/plain": [ 938 | "torchtext.data.batch.Batch" 939 | ] 940 | }, 941 | "execution_count": 26, 942 | "metadata": {}, 943 | "output_type": "execute_result" 944 | } 945 | ], 946 | "source": [ 
947 | "batch = next(iter(traindl))\n", 948 | "type(batch)" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 27, 954 | "metadata": {}, 955 | "outputs": [ 956 | { 957 | "data": { 958 | "text/plain": [ 959 | "tensor([ 1, 0, 0], device='cuda:0')" 960 | ] 961 | }, 962 | "execution_count": 27, 963 | "metadata": {}, 964 | "output_type": "execute_result" 965 | } 966 | ], 967 | "source": [ 968 | "batch.Sentiment" 969 | ] 970 | }, 971 | { 972 | "cell_type": "markdown", 973 | "metadata": {}, 974 | "source": [ 975 | "##### returns word indices and lengths" 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "execution_count": 28, 981 | "metadata": {}, 982 | "outputs": [ 983 | { 984 | "data": { 985 | "text/plain": [ 986 | "(tensor([[ 3590, 0, 88],\n", 987 | " [ 88, 183, 386],\n", 988 | " [ 274, 100, 2],\n", 989 | " [ 2, 22, 14],\n", 990 | " [ 49, 7, 17],\n", 991 | " [ 49, 13, 5732],\n", 992 | " [ 49, 378, 17],\n", 993 | " [ 21087, 89, 427],\n", 994 | " [ 2, 9, 846],\n", 995 | " [ 67, 95, 816],\n", 996 | " [ 103, 4329, 3986],\n", 997 | " [ 17, 2, 22],\n", 998 | " [ 299, 14, 2],\n", 999 | " [ 7, 434, 66],\n", 1000 | " [ 18, 1286, 15],\n", 1001 | " [ 48, 2, 64],\n", 1002 | " [ 646, 29, 134],\n", 1003 | " [ 183, 15, 3264],\n", 1004 | " [ 23, 132, 35],\n", 1005 | " [ 379, 3, 2],\n", 1006 | " [ 3, 1001, 40],\n", 1007 | " [ 279, 18, 134],\n", 1008 | " [ 265, 959, 5332]], device='cuda:0'),\n", 1009 | " tensor([ 23, 23, 23], device='cuda:0'))" 1010 | ] 1011 | }, 1012 | "execution_count": 28, 1013 | "metadata": {}, 1014 | "output_type": "execute_result" 1015 | } 1016 | ], 1017 | "source": [ 1018 | "batch.SentimentText" 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": 29, 1024 | "metadata": {}, 1025 | "outputs": [ 1026 | { 1027 | "data": { 1028 | "text/plain": [ 1029 | "{'ItemID': None,\n", 1030 | " 'Sentiment': ,\n", 1031 | " 'SentimentSource': None,\n", 1032 | " 'SentimentText': }" 1033 | ] 1034 | }, 1035 | 
"execution_count": 29, 1036 | "metadata": {}, 1037 | "output_type": "execute_result" 1038 | } 1039 | ], 1040 | "source": [ 1041 | "batch.dataset.fields" 1042 | ] 1043 | }, 1044 | { 1045 | "cell_type": "code", 1046 | "execution_count": 30, 1047 | "metadata": {}, 1048 | "outputs": [ 1049 | { 1050 | "data": { 1051 | "text/plain": [ 1052 | "''" 1053 | ] 1054 | }, 1055 | "execution_count": 30, 1056 | "metadata": {}, 1057 | "output_type": "execute_result" 1058 | } 1059 | ], 1060 | "source": [ 1061 | "txt_field.vocab.itos[1]" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": {}, 1067 | "source": [ 1068 | "##### convert index to string" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "markdown", 1073 | "metadata": {}, 1074 | "source": [ 1075 | "Function to convert batch to text" 1076 | ] 1077 | }, 1078 | { 1079 | "cell_type": "code", 1080 | "execution_count": 31, 1081 | "metadata": {}, 1082 | "outputs": [], 1083 | "source": [ 1084 | "def idxtosent(batch, idx):\n", 1085 | " return ' '.join([txt_field.vocab.itos[i] for i in batch.SentimentText[0][:,idx].cpu().data.numpy()])" 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "code", 1090 | "execution_count": 32, 1091 | "metadata": {}, 1092 | "outputs": [ 1093 | { 1094 | "data": { 1095 | "text/plain": [ 1096 | "'therealjordin oh also i love love love papercut i really hope that makes it on your album ok just wanted to tell ya'" 1097 | ] 1098 | }, 1099 | "execution_count": 32, 1100 | "metadata": {}, 1101 | "output_type": "execute_result" 1102 | } 1103 | ], 1104 | "source": [ 1105 | "idxtosent(batch,0)" 1106 | ] 1107 | }, 1108 | { 1109 | "cell_type": "code", 1110 | "execution_count": 33, 1111 | "metadata": {}, 1112 | "outputs": [ 1113 | { 1114 | "data": { 1115 | "text/plain": [ 1116 | "' ok then but it s funny when you re hyper i m sooo gutted i can t come to yours on wednesday'" 1117 | ] 1118 | }, 1119 | "execution_count": 33, 1120 | "metadata": {}, 1121 | "output_type": "execute_result" 1122 | } 1123 
| ], 1124 | "source": [ 1125 | "idxtosent(batch,1)" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "code", 1130 | "execution_count": 34, 1131 | "metadata": {}, 1132 | "outputs": [ 1133 | { 1134 | "data": { 1135 | "text/plain": [ 1136 | "'oh shit i m that douche that tweets during band sets but i don t know them nor do i like them srry'" 1137 | ] 1138 | }, 1139 | "execution_count": 34, 1140 | "metadata": {}, 1141 | "output_type": "execute_result" 1142 | } 1143 | ], 1144 | "source": [ 1145 | "idxtosent(batch,2)" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "execution_count": 35, 1151 | "metadata": {}, 1152 | "outputs": [ 1153 | { 1154 | "data": { 1155 | "text/plain": [ 1156 | "{'Sentiment': tensor([ 1, 0, 0], device='cuda:0'),\n", 1157 | " 'SentimentText': (tensor([[ 3590, 0, 88],\n", 1158 | " [ 88, 183, 386],\n", 1159 | " [ 274, 100, 2],\n", 1160 | " [ 2, 22, 14],\n", 1161 | " [ 49, 7, 17],\n", 1162 | " [ 49, 13, 5732],\n", 1163 | " [ 49, 378, 17],\n", 1164 | " [ 21087, 89, 427],\n", 1165 | " [ 2, 9, 846],\n", 1166 | " [ 67, 95, 816],\n", 1167 | " [ 103, 4329, 3986],\n", 1168 | " [ 17, 2, 22],\n", 1169 | " [ 299, 14, 2],\n", 1170 | " [ 7, 434, 66],\n", 1171 | " [ 18, 1286, 15],\n", 1172 | " [ 48, 2, 64],\n", 1173 | " [ 646, 29, 134],\n", 1174 | " [ 183, 15, 3264],\n", 1175 | " [ 23, 132, 35],\n", 1176 | " [ 379, 3, 2],\n", 1177 | " [ 3, 1001, 40],\n", 1178 | " [ 279, 18, 134],\n", 1179 | " [ 265, 959, 5332]], device='cuda:0'),\n", 1180 | " tensor([ 23, 23, 23], device='cuda:0')),\n", 1181 | " 'batch_size': 3,\n", 1182 | " 'dataset': ,\n", 1183 | " 'fields': dict_keys(['ItemID', 'Sentiment', 'SentimentSource', 'SentimentText']),\n", 1184 | " 'train': True}" 1185 | ] 1186 | }, 1187 | "execution_count": 35, 1188 | "metadata": {}, 1189 | "output_type": "execute_result" 1190 | } 1191 | ], 1192 | "source": [ 1193 | "batch.__dict__" 1194 | ] 1195 | }, 1196 | { 1197 | "cell_type": "code", 1198 | "execution_count": 36, 1199 | "metadata": {}, 1200 | 
"outputs": [ 1201 | { 1202 | "data": { 1203 | "text/plain": [ 1204 | "{'Sentiment': tensor([ 0, 1, 0], device='cuda:0'),\n", 1205 | " 'SentimentText': (tensor([[ 0, 67373, 82141]], device='cuda:0'),\n", 1206 | " tensor([ 1, 1, 1], device='cuda:0')),\n", 1207 | " 'batch_size': 3,\n", 1208 | " 'dataset': ,\n", 1209 | " 'fields': dict_keys(['ItemID', 'Sentiment', 'SentimentSource', 'SentimentText']),\n", 1210 | " 'train': False}" 1211 | ] 1212 | }, 1213 | "execution_count": 36, 1214 | "metadata": {}, 1215 | "output_type": "execute_result" 1216 | } 1217 | ], 1218 | "source": [ 1219 | "val_batch = next(iter(valdl))\n", 1220 | "val_batch.__dict__" 1221 | ] 1222 | }, 1223 | { 1224 | "cell_type": "markdown", 1225 | "metadata": {}, 1226 | "source": [ 1227 | "##### Note that BucketIterator returns a Batch object rather than the text indices and labels directly, and, unlike the batches from a PyTorch DataLoader, a Batch object is not iterable. A single Batch object holds the data of one batch, and the text and labels are accessed by column name. \n", 1228 | "##### This is one of the small hiccups in torchtext, but it can be overcome in two ways: either write some extra code in the training loop to pull the data out of the Batch object, or write an iterable wrapper around the BucketIterator that yields the desired data. I will take the second approach as it is much cleaner." 
1229 | ] 1230 | }, 1231 | { 1232 | "cell_type": "code", 1233 | "execution_count": 37, 1234 | "metadata": {}, 1235 | "outputs": [], 1236 | "source": [ 1237 | "class BatchGenerator:\n", 1238 | " def __init__(self, dl, x_field, y_field):\n", 1239 | " self.dl, self.x_field, self.y_field = dl, x_field, y_field\n", 1240 | " \n", 1241 | " def __len__(self):\n", 1242 | " return len(self.dl)\n", 1243 | " \n", 1244 | " def __iter__(self):\n", 1245 | " for batch in self.dl:\n", 1246 | " X = getattr(batch, self.x_field)\n", 1247 | " y = getattr(batch, self.y_field)\n", 1248 | " yield (X,y)" 1249 | ] 1250 | }, 1251 | { 1252 | "cell_type": "code", 1253 | "execution_count": 38, 1254 | "metadata": {}, 1255 | "outputs": [ 1256 | { 1257 | "data": { 1258 | "text/plain": [ 1259 | "((tensor([[ 2948, 3499, 89132],\n", 1260 | " [ 2473, 8096, 53994],\n", 1261 | " [ 111, 5777, 13980],\n", 1262 | " [ 2, 3163, 384],\n", 1263 | " [ 707, 545, 5],\n", 1264 | " [ 0, 14366, 675],\n", 1265 | " [ 71, 44, 52],\n", 1266 | " [ 8187, 59, 3],\n", 1267 | " [ 4529, 22192, 35],\n", 1268 | " [ 39, 63, 138],\n", 1269 | " [ 141, 0, 1]], device='cuda:0'),\n", 1270 | " tensor([ 11, 11, 10], device='cuda:0')),\n", 1271 | " tensor([ 1, 1, 0], device='cuda:0'))" 1272 | ] 1273 | }, 1274 | "execution_count": 38, 1275 | "metadata": {}, 1276 | "output_type": "execute_result" 1277 | } 1278 | ], 1279 | "source": [ 1280 | "train_batch_it = BatchGenerator(traindl, 'SentimentText', 'Sentiment')\n", 1281 | "next(iter(train_batch_it))" 1282 | ] 1283 | }, 1284 | { 1285 | "cell_type": "code", 1286 | "execution_count": null, 1287 | "metadata": {}, 1288 | "outputs": [], 1289 | "source": [] 1290 | }, 1291 | { 1292 | "cell_type": "markdown", 1293 | "metadata": {}, 1294 | "source": [ 1295 | "## 5. 
Finally Model and training" 1296 | ] 1297 | }, 1298 | { 1299 | "cell_type": "code", 1300 | "execution_count": 39, 1301 | "metadata": {}, 1302 | "outputs": [], 1303 | "source": [ 1304 | "vocab_size = len(txt_field.vocab)\n", 1305 | "embedding_dim = 100\n", 1306 | "n_hidden = 64\n", 1307 | "n_out = 2" 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "markdown", 1312 | "metadata": {}, 1313 | "source": [ 1314 | "#### Simple GRU model" 1315 | ] 1316 | }, 1317 | { 1318 | "cell_type": "code", 1319 | "execution_count": 40, 1320 | "metadata": {}, 1321 | "outputs": [], 1322 | "source": [ 1323 | "class SimpleGRU(nn.Module):\n", 1324 | " def __init__(self, vocab_size, embedding_dim, n_hidden, n_out, pretrained_vec, bidirectional=True):\n", 1325 | " super().__init__()\n", 1326 | " self.vocab_size,self.embedding_dim,self.n_hidden,self.n_out,self.bidirectional = vocab_size, embedding_dim, n_hidden, n_out, bidirectional\n", 1327 | " self.emb = nn.Embedding(self.vocab_size, self.embedding_dim)\n", 1328 | " self.emb.weight.data.copy_(pretrained_vec)\n", 1329 | " self.emb.weight.requires_grad = False\n", 1330 | " self.gru = nn.GRU(self.embedding_dim, self.n_hidden, bidirectional=bidirectional)\n", 1331 | " self.out = nn.Linear(self.n_hidden, self.n_out)\n", 1332 | " \n", 1333 | " def forward(self, seq, lengths):\n", 1334 | " bs = seq.size(1) # batch size\n", 1335 | " seq = seq.transpose(0,1)\n", 1336 | " self.h = self.init_hidden(bs) # initialize hidden state of GRU\n", 1337 | " embs = self.emb(seq)\n", 1338 | " embs = embs.transpose(0,1)\n", 1339 | " embs = pack_padded_sequence(embs, lengths) # unpad\n", 1340 | " gru_out, self.h = self.gru(embs, self.h) # gru returns hidden state of all timesteps as well as hidden state at last timestep\n", 1341 | " gru_out, lengths = pad_packed_sequence(gru_out) # pad the sequence to the max length in the batch\n", 1342 | " # since it is as classification problem, we will grab the last hidden state\n", 1343 | " outp = self.out(self.h[-1]) # 
self.h[-1] is the final hidden state (of the backward direction when bidirectional)\n", 1344 | "# pass dim explicitly; calling log_softmax without dim is deprecated\n", 1345 | " return F.log_softmax(outp, dim=-1)\n", 1346 | " \n", 1347 | " def init_hidden(self, batch_size): \n", 1348 | " if self.bidirectional:\n", 1349 | " return torch.zeros((2,batch_size,self.n_hidden)).to(device)\n", 1350 | " else:\n", 1351 | " return torch.zeros((1,batch_size,self.n_hidden)).to(device)" 1352 | ] 1353 | }, 1354 | { 1355 | "cell_type": "markdown", 1356 | "metadata": {}, 1357 | "source": [ 1358 | "#### Concat Pooling model" 1359 | ] 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "execution_count": 41, 1364 | "metadata": {}, 1365 | "outputs": [], 1366 | "source": [ 1367 | "class ConcatPoolingGRUAdaptive(nn.Module):\n", 1368 | " def __init__(self, vocab_size, embedding_dim, n_hidden, n_out, pretrained_vec, bidirectional=True):\n", 1369 | " super().__init__()\n", 1370 | " self.vocab_size = vocab_size\n", 1371 | " self.embedding_dim = embedding_dim\n", 1372 | " self.n_hidden = n_hidden\n", 1373 | " self.n_out = n_out\n", 1374 | " self.bidirectional = bidirectional\n", 1375 | " \n", 1376 | " self.emb = nn.Embedding(self.vocab_size, self.embedding_dim)\n", 1377 | " self.emb.weight.data.copy_(pretrained_vec)\n", 1378 | " self.emb.weight.requires_grad = False\n", 1379 | " self.gru = nn.GRU(self.embedding_dim, self.n_hidden, bidirectional=bidirectional)\n", 1380 | " if bidirectional:\n", 1381 | " self.out = nn.Linear(self.n_hidden*2*2, self.n_out)\n", 1382 | " else:\n", 1383 | " self.out = nn.Linear(self.n_hidden*2, self.n_out)\n", 1384 | " \n", 1385 | " def forward(self, seq, lengths):\n", 1386 | " bs = seq.size(1)\n", 1387 | " self.h = self.init_hidden(bs)\n", 1388 | " seq = seq.transpose(0,1)\n", 1389 | " embs = self.emb(seq)\n", 1390 | " embs = embs.transpose(0,1)\n", 1391 | " embs = pack_padded_sequence(embs, lengths)\n", 1392 | " gru_out, self.h = self.gru(embs, self.h)\n", 1393 | " gru_out, lengths = pad_packed_sequence(gru_out) \n", 1394 | " \n", 
1395 | " avg_pool = F.adaptive_avg_pool1d(gru_out.permute(1,2,0),1).view(bs,-1)\n", 1396 | " max_pool = F.adaptive_max_pool1d(gru_out.permute(1,2,0),1).view(bs,-1) \n", 1397 | " outp = self.out(torch.cat([avg_pool,max_pool],dim=1))\n", 1398 | " return F.log_softmax(outp, dim=-1)\n", 1399 | " \n", 1400 | " def init_hidden(self, batch_size): \n", 1401 | " if self.bidirectional:\n", 1402 | " return torch.zeros((2,batch_size,self.n_hidden)).to(device)\n", 1403 | " else:\n", 1404 | " return torch.zeros((1,batch_size,self.n_hidden)).to(device)" 1405 | ] 1406 | }, 1407 | { 1408 | "cell_type": "markdown", 1409 | "metadata": {}, 1410 | "source": [ 1411 | "#### Training function" 1412 | ] 1413 | }, 1414 | { 1415 | "cell_type": "code", 1416 | "execution_count": 42, 1417 | "metadata": {}, 1418 | "outputs": [], 1419 | "source": [ 1420 | "def fit(model, train_dl, val_dl, loss_fn, opt, epochs=3):\n", 1421 | " num_batch = len(train_dl)\n", 1422 | " for epoch in tnrange(epochs): \n", 1423 | " y_true_train = list()\n", 1424 | " y_pred_train = list()\n", 1425 | " total_loss_train = 0 \n", 1426 | " \n", 1427 | " t = tqdm_notebook(iter(train_dl), leave=False, total=num_batch)\n", 1428 | " for (X,lengths),y in t:\n", 1429 | " t.set_description(f'Epoch {epoch}')\n", 1430 | " lengths = lengths.cpu().numpy()\n", 1431 | " \n", 1432 | " opt.zero_grad()\n", 1433 | " pred = model(X, lengths)\n", 1434 | " loss = loss_fn(pred, y)\n", 1435 | " loss.backward()\n", 1436 | " opt.step()\n", 1437 | " \n", 1438 | " t.set_postfix(loss=loss.item())\n", 1439 | " pred_idx = torch.max(pred, dim=1)[1]\n", 1440 | " \n", 1441 | " y_true_train += list(y.cpu().data.numpy())\n", 1442 | " y_pred_train += list(pred_idx.cpu().data.numpy())\n", 1443 | " total_loss_train += loss.item()\n", 1444 | " \n", 1445 | " train_acc = accuracy_score(y_true_train, y_pred_train)\n", 1446 | " train_loss = total_loss_train/len(train_dl)\n", 1447 | " \n", 1448 | " if val_dl:\n", 1449 | " y_true_val = list()\n", 1450 | " y_pred_val = 
list()\n", 1451 | " total_loss_val = 0\n", 1452 | " for (X,lengths),y in tqdm_notebook(val_dl, leave=False):\n", 1453 | " pred = model(X, lengths.cpu().numpy())\n", 1454 | " loss = loss_fn(pred, y)\n", 1455 | " pred_idx = torch.max(pred, 1)[1]\n", 1456 | " y_true_val += list(y.cpu().data.numpy())\n", 1457 | " y_pred_val += list(pred_idx.cpu().data.numpy())\n", 1458 | " total_loss_val += loss.item()\n", 1459 | " valacc = accuracy_score(y_true_val, y_pred_val)\n", 1460 | " valloss = total_loss_val/len(val_dl)\n", 1461 | " print(f'Epoch {epoch}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {valloss:.4f} val_acc: {valacc:.4f}')\n", 1462 | " else:\n", 1463 | " print(f'Epoch {epoch}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f}')" 1464 | ] 1465 | }, 1466 | { 1467 | "cell_type": "markdown", 1468 | "metadata": {}, 1469 | "source": [ 1470 | "##### Dataloader" 1471 | ] 1472 | }, 1473 | { 1474 | "cell_type": "code", 1475 | "execution_count": 43, 1476 | "metadata": {}, 1477 | "outputs": [], 1478 | "source": [ 1479 | "traindl, valdl = data.BucketIterator.splits(datasets=(trainds, valds), batch_sizes=(512,1024), sort_key=lambda x: len(x.SentimentText), device=0, sort_within_batch=True, repeat=False)\n", 1480 | "train_batch_it = BatchGenerator(traindl, 'SentimentText', 'Sentiment')\n", 1481 | "val_batch_it = BatchGenerator(valdl, 'SentimentText', 'Sentiment')" 1482 | ] 1483 | }, 1484 | { 1485 | "cell_type": "markdown", 1486 | "metadata": {}, 1487 | "source": [ 1488 | "##### Train simple GRU model" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "code", 1493 | "execution_count": 45, 1494 | "metadata": {}, 1495 | "outputs": [ 1496 | { 1497 | "data": { 1498 | "application/vnd.jupyter.widget-view+json": { 1499 | "model_id": "1ea67e63ee244414846d756b54668357", 1500 | "version_major": 2, 1501 | "version_minor": 0 1502 | }, 1503 | "text/html": [ 1504 | "


\n" 1517 | ], 1518 | "text/plain": [ 1519 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 1520 | ] 1521 | }, 1522 | "metadata": {}, 1523 | "output_type": "display_data" 1524 | }, 1525 | { 1526 | "data": { 1527 | "application/vnd.jupyter.widget-view+json": { 1528 | "model_id": "cb08b1c6c1604ee5b6e1f3737748b129", 1529 | "version_major": 2, 1530 | "version_minor": 0 1531 | }, 1532 | "text/html": [ 1533 | "


\n" 1546 | ], 1547 | "text/plain": [ 1548 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1549 | ] 1550 | }, 1551 | "metadata": {}, 1552 | "output_type": "display_data" 1553 | }, 1554 | { 1555 | "data": { 1556 | "application/vnd.jupyter.widget-view+json": { 1557 | "model_id": "5e2c250c21d94c14b6390f9510269423", 1558 | "version_major": 2, 1559 | "version_minor": 0 1560 | }, 1561 | "text/html": [ 1562 | "


\n" 1575 | ], 1576 | "text/plain": [ 1577 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1578 | ] 1579 | }, 1580 | "metadata": {}, 1581 | "output_type": "display_data" 1582 | }, 1583 | { 1584 | "name": "stdout", 1585 | "output_type": "stream", 1586 | "text": [ 1587 | "Epoch 0: train_loss: 0.4463 train_acc: 0.7892 | val_loss: 0.4154 val_acc: 0.8077\n" 1588 | ] 1589 | }, 1590 | { 1591 | "data": { 1592 | "application/vnd.jupyter.widget-view+json": { 1593 | "model_id": "6da1ef4b2d8f485cbd563fe7c736cce8", 1594 | "version_major": 2, 1595 | "version_minor": 0 1596 | }, 1597 | "text/html": [ 1598 | "


\n" 1611 | ], 1612 | "text/plain": [ 1613 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1614 | ] 1615 | }, 1616 | "metadata": {}, 1617 | "output_type": "display_data" 1618 | }, 1619 | { 1620 | "data": { 1621 | "application/vnd.jupyter.widget-view+json": { 1622 | "model_id": "1b6fe4fffa884a2b83b18c1bbf580d8b", 1623 | "version_major": 2, 1624 | "version_minor": 0 1625 | }, 1626 | "text/html": [ 1627 | "


\n" 1640 | ], 1641 | "text/plain": [ 1642 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1643 | ] 1644 | }, 1645 | "metadata": {}, 1646 | "output_type": "display_data" 1647 | }, 1648 | { 1649 | "name": "stdout", 1650 | "output_type": "stream", 1651 | "text": [ 1652 | "Epoch 1: train_loss: 0.4071 train_acc: 0.8130 | val_loss: 0.4001 val_acc: 0.8178\n" 1653 | ] 1654 | }, 1655 | { 1656 | "data": { 1657 | "application/vnd.jupyter.widget-view+json": { 1658 | "model_id": "71cef71e4cec4327865021d7be8f22c7", 1659 | "version_major": 2, 1660 | "version_minor": 0 1661 | }, 1662 | "text/html": [ 1663 | "


\n" 1676 | ], 1677 | "text/plain": [ 1678 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1679 | ] 1680 | }, 1681 | "metadata": {}, 1682 | "output_type": "display_data" 1683 | }, 1684 | { 1685 | "data": { 1686 | "application/vnd.jupyter.widget-view+json": { 1687 | "model_id": "499de9c307ae429dbb79c7ec89a3cd17", 1688 | "version_major": 2, 1689 | "version_minor": 0 1690 | }, 1691 | "text/html": [ 1692 | "


\n" 1705 | ], 1706 | "text/plain": [ 1707 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1708 | ] 1709 | }, 1710 | "metadata": {}, 1711 | "output_type": "display_data" 1712 | }, 1713 | { 1714 | "name": "stdout", 1715 | "output_type": "stream", 1716 | "text": [ 1717 | "Epoch 2: train_loss: 0.3952 train_acc: 0.8199 | val_loss: 0.4060 val_acc: 0.8146\n" 1718 | ] 1719 | }, 1720 | { 1721 | "data": { 1722 | "application/vnd.jupyter.widget-view+json": { 1723 | "model_id": "d2ba9f170596427fbdf673c2da483037", 1724 | "version_major": 2, 1725 | "version_minor": 0 1726 | }, 1727 | "text/html": [ 1728 | "


\n" 1741 | ], 1742 | "text/plain": [ 1743 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1744 | ] 1745 | }, 1746 | "metadata": {}, 1747 | "output_type": "display_data" 1748 | }, 1749 | { 1750 | "data": { 1751 | "application/vnd.jupyter.widget-view+json": { 1752 | "model_id": "c50af3074c3442b0818944adb72eae87", 1753 | "version_major": 2, 1754 | "version_minor": 0 1755 | }, 1756 | "text/html": [ 1757 | "


\n" 1770 | ], 1771 | "text/plain": [ 1772 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1773 | ] 1774 | }, 1775 | "metadata": {}, 1776 | "output_type": "display_data" 1777 | }, 1778 | { 1779 | "name": "stdout", 1780 | "output_type": "stream", 1781 | "text": [ 1782 | "Epoch 3: train_loss: 0.3877 train_acc: 0.8240 | val_loss: 0.3876 val_acc: 0.8248\n" 1783 | ] 1784 | }, 1785 | { 1786 | "data": { 1787 | "application/vnd.jupyter.widget-view+json": { 1788 | "model_id": "1e2643812f914dcfaa447c5df332dace", 1789 | "version_major": 2, 1790 | "version_minor": 0 1791 | }, 1792 | "text/html": [ 1793 | "


\n" 1806 | ], 1807 | "text/plain": [ 1808 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1809 | ] 1810 | }, 1811 | "metadata": {}, 1812 | "output_type": "display_data" 1813 | }, 1814 | { 1815 | "data": { 1816 | "application/vnd.jupyter.widget-view+json": { 1817 | "model_id": "bf93640eae874cbea9a154690a4fa9ed", 1818 | "version_major": 2, 1819 | "version_minor": 0 1820 | }, 1821 | "text/html": [ 1822 | "


\n" 1835 | ], 1836 | "text/plain": [ 1837 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1838 | ] 1839 | }, 1840 | "metadata": {}, 1841 | "output_type": "display_data" 1842 | }, 1843 | { 1844 | "name": "stdout", 1845 | "output_type": "stream", 1846 | "text": [ 1847 | "Epoch 4: train_loss: 0.3822 train_acc: 0.8270 | val_loss: 0.3861 val_acc: 0.8256\n", 1848 | "\n" 1849 | ] 1850 | } 1851 | ], 1852 | "source": [ 1853 | "m = SimpleGRU(vocab_size, embedding_dim, n_hidden, n_out, trainds.fields['SentimentText'].vocab.vectors).to(device)\n", 1854 | "opt = optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), 1e-3)\n", 1855 | "\n", 1856 | "fit(model=m, train_dl=train_batch_it, val_dl=val_batch_it, loss_fn=F.nll_loss, opt=opt, epochs=5)" 1857 | ] 1858 | }, 1859 | { 1860 | "cell_type": "markdown", 1861 | "metadata": {}, 1862 | "source": [ 1863 | "##### Train Concat Pooling model" 1864 | ] 1865 | }, 1866 | { 1867 | "cell_type": "code", 1868 | "execution_count": 46, 1869 | "metadata": {}, 1870 | "outputs": [ 1871 | { 1872 | "data": { 1873 | "application/vnd.jupyter.widget-view+json": { 1874 | "model_id": "66ccb790ac2e4a6c8032a3a0df71bace", 1875 | "version_major": 2, 1876 | "version_minor": 0 1877 | }, 1878 | "text/html": [ 1879 | "


\n" 1892 | ], 1893 | "text/plain": [ 1894 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 1895 | ] 1896 | }, 1897 | "metadata": {}, 1898 | "output_type": "display_data" 1899 | }, 1900 | { 1901 | "data": { 1902 | "application/vnd.jupyter.widget-view+json": { 1903 | "model_id": "5ef3cfad0728452f88009d677b482d87", 1904 | "version_major": 2, 1905 | "version_minor": 0 1906 | }, 1907 | "text/html": [ 1908 | "


\n" 1921 | ], 1922 | "text/plain": [ 1923 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1924 | ] 1925 | }, 1926 | "metadata": {}, 1927 | "output_type": "display_data" 1928 | }, 1929 | { 1930 | "data": { 1931 | "application/vnd.jupyter.widget-view+json": { 1932 | "model_id": "9243767da3034e80b730c782435e6421", 1933 | "version_major": 2, 1934 | "version_minor": 0 1935 | }, 1936 | "text/html": [ 1937 | "


\n" 1950 | ], 1951 | "text/plain": [ 1952 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 1953 | ] 1954 | }, 1955 | "metadata": {}, 1956 | "output_type": "display_data" 1957 | }, 1958 | { 1959 | "name": "stdout", 1960 | "output_type": "stream", 1961 | "text": [ 1962 | "Epoch 0: train_loss: 0.4349 train_acc: 0.7959 | val_loss: 0.4036 val_acc: 0.8153\n" 1963 | ] 1964 | }, 1965 | { 1966 | "data": { 1967 | "application/vnd.jupyter.widget-view+json": { 1968 | "model_id": "7efe0017eda740f589749a31671674e5", 1969 | "version_major": 2, 1970 | "version_minor": 0 1971 | }, 1972 | "text/html": [ 1973 | "


\n" 1986 | ], 1987 | "text/plain": [ 1988 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 1989 | ] 1990 | }, 1991 | "metadata": {}, 1992 | "output_type": "display_data" 1993 | }, 1994 | { 1995 | "data": { 1996 | "application/vnd.jupyter.widget-view+json": { 1997 | "model_id": "3b9ef438202f4a61b4216c90866a8210", 1998 | "version_major": 2, 1999 | "version_minor": 0 2000 | }, 2001 | "text/html": [ 2002 | "


\n" 2015 | ], 2016 | "text/plain": [ 2017 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 2018 | ] 2019 | }, 2020 | "metadata": {}, 2021 | "output_type": "display_data" 2022 | }, 2023 | { 2024 | "name": "stdout", 2025 | "output_type": "stream", 2026 | "text": [ 2027 | "Epoch 1: train_loss: 0.3975 train_acc: 0.8189 | val_loss: 0.3913 val_acc: 0.8227\n" 2028 | ] 2029 | }, 2030 | { 2031 | "data": { 2032 | "application/vnd.jupyter.widget-view+json": { 2033 | "model_id": "b5fef79a65304fac8c381c4708cdde51", 2034 | "version_major": 2, 2035 | "version_minor": 0 2036 | }, 2037 | "text/html": [ 2038 | "


\n" 2051 | ], 2052 | "text/plain": [ 2053 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 2054 | ] 2055 | }, 2056 | "metadata": {}, 2057 | "output_type": "display_data" 2058 | }, 2059 | { 2060 | "data": { 2061 | "application/vnd.jupyter.widget-view+json": { 2062 | "model_id": "23985883852c4784aadd4db473436f78", 2063 | "version_major": 2, 2064 | "version_minor": 0 2065 | }, 2066 | "text/html": [ 2067 | "


\n" 2080 | ], 2081 | "text/plain": [ 2082 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 2083 | ] 2084 | }, 2085 | "metadata": {}, 2086 | "output_type": "display_data" 2087 | }, 2088 | { 2089 | "name": "stdout", 2090 | "output_type": "stream", 2091 | "text": [ 2092 | "Epoch 2: train_loss: 0.3853 train_acc: 0.8257 | val_loss: 0.3877 val_acc: 0.8250\n" 2093 | ] 2094 | }, 2095 | { 2096 | "data": { 2097 | "application/vnd.jupyter.widget-view+json": { 2098 | "model_id": "1220e354f91944f0b676d51ab018c564", 2099 | "version_major": 2, 2100 | "version_minor": 0 2101 | }, 2102 | "text/html": [ 2103 | "


\n" 2116 | ], 2117 | "text/plain": [ 2118 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 2119 | ] 2120 | }, 2121 | "metadata": {}, 2122 | "output_type": "display_data" 2123 | }, 2124 | { 2125 | "data": { 2126 | "application/vnd.jupyter.widget-view+json": { 2127 | "model_id": "ffd489350f1b47bea4aa191dd29a66cd", 2128 | "version_major": 2, 2129 | "version_minor": 0 2130 | }, 2131 | "text/html": [ 2132 | "


\n" 2145 | ], 2146 | "text/plain": [ 2147 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 2148 | ] 2149 | }, 2150 | "metadata": {}, 2151 | "output_type": "display_data" 2152 | }, 2153 | { 2154 | "name": "stdout", 2155 | "output_type": "stream", 2156 | "text": [ 2157 | "Epoch 3: train_loss: 0.3777 train_acc: 0.8300 | val_loss: 0.3822 val_acc: 0.8283\n" 2158 | ] 2159 | }, 2160 | { 2161 | "data": { 2162 | "application/vnd.jupyter.widget-view+json": { 2163 | "model_id": "c06d5486bfab4167a57cb91b4a42c20c", 2164 | "version_major": 2, 2165 | "version_minor": 0 2166 | }, 2167 | "text/html": [ 2168 | "


\n" 2181 | ], 2182 | "text/plain": [ 2183 | "HBox(children=(IntProgress(value=0, max=2467), HTML(value='')))" 2184 | ] 2185 | }, 2186 | "metadata": {}, 2187 | "output_type": "display_data" 2188 | }, 2189 | { 2190 | "data": { 2191 | "application/vnd.jupyter.widget-view+json": { 2192 | "model_id": "314e0deb0730405c926399a01991ec09", 2193 | "version_major": 2, 2194 | "version_minor": 0 2195 | }, 2196 | "text/html": [ 2197 | "


\n" 2210 | ], 2211 | "text/plain": [ 2212 | "HBox(children=(IntProgress(value=0, max=309), HTML(value='')))" 2213 | ] 2214 | }, 2215 | "metadata": {}, 2216 | "output_type": "display_data" 2217 | }, 2218 | { 2219 | "name": "stdout", 2220 | "output_type": "stream", 2221 | "text": [ 2222 | "Epoch 4: train_loss: 0.3715 train_acc: 0.8332 | val_loss: 0.3804 val_acc: 0.8289\n", 2223 | "\n" 2224 | ] 2225 | } 2226 | ], 2227 | "source": [ 2228 | "m = ConcatPoolingGRUAdaptive(vocab_size, embedding_dim, n_hidden, n_out, trainds.fields['SentimentText'].vocab.vectors).to(device)\n", 2229 | "opt = optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), 1e-3)\n", 2230 | "\n", 2231 | "fit(model=m, train_dl=train_batch_it, val_dl=val_batch_it, loss_fn=F.nll_loss, opt=opt, epochs=5)" 2232 | ] 2233 | } 2234 | ], 2235 | "metadata": { 2236 | "kernelspec": { 2237 | "display_name": "Python 3", 2238 | "language": "python", 2239 | "name": "python3" 2240 | }, 2241 | "language_info": { 2242 | "codemirror_mode": { 2243 | "name": "ipython", 2244 | "version": 3 2245 | }, 2246 | "file_extension": ".py", 2247 | "mimetype": "text/x-python", 2248 | "name": "python", 2249 | "nbconvert_exporter": "python", 2250 | "pygments_lexer": "ipython3", 2251 | "version": "3.6.4" 2252 | } 2253 | }, 2254 | "nbformat": 4, 2255 | "nbformat_minor": 2 2256 | } 2257 | --------------------------------------------------------------------------------