├── README.md
└── NLP for Text Classification (Jupyter Notebook).ipynb

/README.md:
--------------------------------------------------------------------------------
1 | # nlptextclassification
2 | This is the source code for the 'Natural Language Processing for Text Classification with NLTK & Scikit-learn' video
3 | 
--------------------------------------------------------------------------------
/NLP for Text Classification (Jupyter Notebook).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Natural Language Processing for Text Classification with NLTK and Scikit-learn\n",
8 | "\n",
9 | "### Presented by Eduonix!\n",
10 | "\n",
11 | "In the project *Getting Started With Natural Language Processing in Python*, we learned the basics of tokenizing, part-of-speech tagging, stemming, chunking, and named entity recognition; furthermore, we dove into machine learning and text classification using a simple support vector classifier and a dataset of positive and negative movie reviews. \n",
12 | "\n",
13 | "In this tutorial, we will expand on this foundation and explore different ways to improve our text classification results. We will cover and use:\n",
14 | "\n",
15 | "* Regular Expressions\n",
16 | "* Feature Engineering\n",
17 | "* Multiple scikit-learn Classifiers\n",
18 | "* Ensemble Methods\n",
19 | "\n",
20 | "### 1. Import Necessary Libraries\n",
21 | "\n",
22 | "To ensure the necessary libraries are installed correctly and up-to-date, print the version numbers for each library. This will also improve the reproducibility of our project."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "Python: 2.7.13 |Continuum Analytics, Inc.| (default, May 11 2017, 13:17:26) [MSC v.1500 64 bit (AMD64)]\n",
35 | "NLTK: 3.2.5\n",
36 | "Scikit-learn: 0.19.1\n",
37 | "Pandas: 0.21.0\n",
38 | "Numpy: 1.14.1\n"
39 | ]
40 | }
41 | ],
42 | "source": [
43 | "import sys\n",
44 | "import nltk\n",
45 | "import sklearn\n",
46 | "import pandas\n",
47 | "import numpy\n",
48 | "\n",
49 | "print('Python: {}'.format(sys.version))\n",
50 | "print('NLTK: {}'.format(nltk.__version__))\n",
51 | "print('Scikit-learn: {}'.format(sklearn.__version__))\n",
52 | "print('Pandas: {}'.format(pandas.__version__))\n",
53 | "print('Numpy: {}'.format(numpy.__version__))"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "### 2. Load the Dataset\n",
61 | "\n",
62 | "Now that we have ensured that our libraries are installed correctly, let's load the dataset as a Pandas DataFrame. Furthermore, let's extract some useful information, such as the column information and class distributions. \n",
63 | "\n",
64 | "The dataset we will be using comes from the UCI Machine Learning Repository. It contains over 5,000 labeled SMS messages that have been collected for mobile phone spam research. It can be downloaded from the following URL:\n",
65 | "\n",
66 | "https://archive.ics.uci.edu/ml/datasets/sms+spam+collection"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 2,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "import pandas as pd\n",
76 | "import numpy as np\n",
77 | "\n",
78 | "# load the dataset of SMS messages\n",
79 | "df = pd.read_table('SMSSpamCollection', header=None, encoding='utf-8')"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 3,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "<class 'pandas.core.frame.DataFrame'>\n",
92 | "RangeIndex: 5572 entries, 0 to 5571\n",
93 | "Data columns (total 2 columns):\n",
94 | "0 5572 non-null object\n",
95 | "1 5572 non-null object\n",
96 | "dtypes: object(2)\n",
97 | "memory usage: 87.1+ KB\n",
98 | "None\n",
99 | " 0 1\n",
100 | "0 ham Go until jurong point, crazy.. Available only ...\n",
101 | "1 ham Ok lar... Joking wif u oni...\n",
102 | "2 spam Free entry in 2 a wkly comp to win FA Cup fina...\n",
103 | "3 ham U dun say so early hor... U c already then say...\n",
104 | "4 ham Nah I don't think he goes to usf, he lives aro...\n"
105 | ]
106 | }
107 | ],
108 | "source": [
109 | "# print useful information about the dataset\n",
110 | "print(df.info())\n",
111 | "print(df.head())"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 4,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "name": "stdout",
121 | "output_type": "stream",
122 | "text": [
123 | "ham 4825\n",
124 | "spam 747\n",
125 | "Name: 0, dtype: int64\n"
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "# check class distribution\n",
131 | "classes = df[0]\n",
132 | "print(classes.value_counts())"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "### 3. Preprocess the Data\n",
140 | "\n",
141 | "Preprocessing the data is an essential step in natural language processing. In the following cells, we will convert our class labels to binary values using the LabelEncoder from sklearn, replace email addresses, URLs, phone numbers, money symbols, and other numbers using regular expressions, remove stop words, and stem the remaining words."
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 5,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stdout",
151 | "output_type": "stream",
152 | "text": [
153 | "[0 0 1 0 0 1 0 0 1 1]\n"
154 | ]
155 | }
156 | ],
157 | "source": [
158 | "from sklearn.preprocessing import LabelEncoder\n",
159 | "\n",
160 | "# convert class labels to binary values, 0 = ham and 1 = spam\n",
161 | "encoder = LabelEncoder()\n",
162 | "Y = encoder.fit_transform(classes)\n",
163 | "\n",
164 | "print(Y[:10])"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 6,
170 | "metadata": {},
171 | "outputs": [
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "0 Go until jurong point, crazy.. Available only ...\n",
177 | "1 Ok lar... Joking wif u oni...\n",
178 | "2 Free entry in 2 a wkly comp to win FA Cup fina...\n",
179 | "3 U dun say so early hor... U c already then say...\n",
180 | "4 Nah I don't think he goes to usf, he lives aro...\n",
181 | "5 FreeMsg Hey there darling it's been 3 week's n...\n",
182 | "6 Even my brother is not like to speak with me. ...\n",
183 | "7 As per your request 'Melle Melle (Oru Minnamin...\n",
184 | "8 WINNER!! As a valued network customer you have...\n",
185 | "9 Had your mobile 11 months or more? U R entitle...\n",
186 | "Name: 1, dtype: object\n"
187 | ]
188 | }
189 | ],
190 | "source": [
191 | "# store the SMS message data\n",
192 | "text_messages = df[1]\n",
193 | "print(text_messages[:10])"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {},
199 | "source": [
200 | "#### 3.1 Regular Expressions\n",
201 | "\n",
202 | "Some common regular expression metacharacters - adapted from Wikipedia:\n",
203 | "\n",
204 | "**^** Matches the starting position within the string. In line-based tools, it matches the starting position of any line.\n",
205 | "\n",
206 | "**.** Matches any single character (many applications exclude newlines, and exactly which characters are considered newlines is flavor-, character-encoding-, and platform-specific, but it is safe to assume that the line feed character is included). Within POSIX bracket expressions, the dot character matches a literal dot. For example, a.c matches \"abc\", etc., but [a.c] matches only \"a\", \".\", or \"c\".\n",
207 | "\n",
208 | "**[ ]** A bracket expression. Matches a single character that is contained within the brackets. For example, [abc] matches \"a\", \"b\", or \"c\". [a-z] specifies a range which matches any lowercase letter from \"a\" to \"z\". These forms can be mixed: [abcx-z] matches \"a\", \"b\", \"c\", \"x\", \"y\", or \"z\", as does [a-cx-z].\n",
209 | "The - character is treated as a literal character if it is the last or the first (after the ^, if present) character within the brackets: [abc-], [-abc]. Note that backslash escapes are not allowed. The ] character can be included in a bracket expression if it is the first (after the ^) character: []abc].\n",
210 | "\n",
211 | "**[^ ]** Matches a single character that is not contained within the brackets. For example, [^abc] matches any character other than \"a\", \"b\", or \"c\". [^a-z] matches any single character that is not a lowercase letter from \"a\" to \"z\". Likewise, literal characters and ranges can be mixed.\n",
212 | "\n",
213 | "**$** Matches the ending position of the string or the position just before a string-ending newline. In line-based tools, it matches the ending position of any line.\n",
214 | "\n",
215 | "**( )** Defines a marked subexpression. The string matched within the parentheses can be recalled later (see the next entry, \\n). A marked subexpression is also called a block or capturing group. BRE mode requires \\( \\).\n",
216 | "\n",
217 | "**\\n** Matches what the nth marked subexpression matched, where n is a digit from 1 to 9. This construct is vaguely defined in the POSIX.2 standard. Some tools allow referencing more than nine capturing groups.\n",
218 | "\n",
219 | "**\\*** Matches the preceding element zero or more times. For example, ab*c matches \"ac\", \"abc\", \"abbbc\", etc. [xyz]* matches \"\", \"x\", \"y\", \"z\", \"zx\", \"zyx\", \"xyzzy\", and so on. (ab)* matches \"\", \"ab\", \"abab\", \"ababab\", and so on.\n",
220 | "\n",
221 | "**{m,n}** Matches the preceding element at least m and not more than n times. For example, a{3,5} matches only \"aaa\", \"aaaa\", and \"aaaaa\". This is not found in a few older instances of regexes. BRE mode requires \\{m,n\\}."
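222 | ]
223 | },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before applying these to the SMS data, here is a quick illustrative sketch of a few of the metacharacters above, using Python's built-in re module. (The sample strings here are invented purely for illustration.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# '^' anchors the match to the start of the string\n",
"print(re.match(r'^spam', 'spam alert') is not None) # True\n",
"\n",
"# '[ ]' matches any single character in the brackets; '+' repeats it\n",
"print(re.findall(r'[0-9]+', 'win 1000 now, call 555')) # ['1000', '555']\n",
"\n",
"# '{m,n}' bounds how many times the preceding element may repeat\n",
"print(re.sub(r'\\d{3,5}', 'numbr', 'txt 80082 to claim')) # txt numbr to claim"
]
},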
224 | {
225 | "cell_type": "code",
226 | "execution_count": 7,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "# use regular expressions to replace email addresses, URLs, phone numbers, other numbers\n",
231 | "\n",
232 | "# Replace email addresses with 'emailaddress'\n",
233 | "processed = text_messages.str.replace(r'^.+@[^\\.].*\\.[a-z]{2,}$',\n",
234 | " 'emailaddress')\n",
235 | "\n",
236 | "# Replace URLs with 'webaddress'\n",
237 | "processed = processed.str.replace(r'^http\\://[a-zA-Z0-9\\-\\.]+\\.[a-zA-Z]{2,3}(/\\S*)?$',\n",
238 | " 'webaddress')\n",
239 | "\n",
240 | "# Replace money symbols with 'moneysymb' (£ can be typed with ALT key + 156)\n",
241 | "processed = processed.str.replace(r'£|\\$', 'moneysymb')\n",
242 | " \n",
243 | "# Replace 10 digit phone numbers (formats include parentheses, spaces, no spaces, dashes) with 'phonenumbr'\n",
244 | "processed = processed.str.replace(r'^\\(?[\\d]{3}\\)?[\\s-]?[\\d]{3}[\\s-]?[\\d]{4}$',\n",
245 | " 'phonenumbr')\n",
246 | " \n",
247 | "# Replace numbers with 'numbr'\n",
248 | "processed = processed.str.replace(r'\\d+(\\.\\d+)?', 'numbr')"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 8,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "# Remove punctuation\n",
258 | "processed = processed.str.replace(r'[^\\w\\d\\s]', ' ')\n",
259 | "\n",
260 | "# Replace whitespace between terms with a single space\n",
261 | "processed = processed.str.replace(r'\\s+', ' ')\n",
262 | "\n",
263 | "# Remove leading and trailing whitespace\n",
264 | "processed = processed.str.replace(r'^\\s+|\\s+?$', '')"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 9,
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "0 go until jurong point crazy available only in ...\n",
277 | "1 ok lar joking wif u oni\n",
278 | "2 free entry in numbr a wkly comp to win fa cup ...\n",
279 | "3 u dun say so early hor u c already then say\n",
280 | "4 nah i don t think he goes to usf he lives arou...\n",
281 | "5 freemsg hey there darling it s been numbr week...\n",
282 | "6 even my brother is not like to speak with me t...\n",
283 | "7 as per your request melle melle oru minnaminun...\n",
284 | "8 winner as a valued network customer you have b...\n",
285 | "9 had your mobile numbr months or more u r entit...\n",
286 | "10 i m gonna be home soon and i don t want to tal...\n",
287 | "11 six chances to win cash from numbr to numbr nu...\n",
288 | "12 urgent you have won a numbr week free membersh...\n",
289 | "13 i ve been searching for the right words to tha...\n",
290 | "14 i have a date on sunday with will\n",
291 | "15 xxxmobilemovieclub to use your credit click th...\n",
292 | "16 oh k i m watching here\n",
293 | "17 eh u remember how numbr spell his name yes i d...\n",
294 | "18 fine if that s the way u feel that s the way i...\n",
295 | "19 england v macedonia dont miss the goals team n...\n",
296 | "20 is that seriously how you spell his name\n",
297 | "21 i m going to try for numbr months ha ha only j...\n",
298 | "22 so pay first lar then when is da stock comin\n",
299 | "23 aft i finish my lunch then i go str down lor a...\n",
300 | "24 ffffffffff alright no way i can meet up with y...\n",
301 | "25 just forced myself to eat a slice i m really n...\n",
302 | "26 lol your always so convincing\n",
303 | "27 did you catch the bus are you frying an egg di...\n",
304 | "28 i m back amp we re packing the car now i ll le...\n",
305 | "29 ahhh work i vaguely remember that what does it...\n",
306 | " ... \n",
307 | "5542 armand says get your ass over to epsilon\n",
308 | "5543 u still havent got urself a jacket ah\n",
309 | "5544 i m taking derek amp taylor to walmart if i m ...\n",
310 | "5545 hi its in durban are you still on this number\n",
311 | "5546 ic there are a lotta childporn cars then\n",
312 | "5547 had your contract mobile numbr mnths latest mo...\n",
313 | "5548 no i was trying it all weekend v\n",
314 | "5549 you know wot people wear t shirts jumpers hat ...\n",
315 | "5550 cool what time you think you can get here\n",
316 | "5551 wen did you get so spiritual and deep that s g...\n",
317 | "5552 have a safe trip to nigeria wish you happiness...\n",
318 | "5553 hahaha use your brain dear\n",
319 | "5554 well keep in mind i ve only got enough gas for...\n",
320 | "5555 yeh indians was nice tho it did kane me off a ...\n",
321 | "5556 yes i have so that s why u texted pshew missin...\n",
322 | "5557 no i meant the calculation is the same that lt...\n",
323 | "5558 sorry i ll call later\n",
324 | "5559 if you aren t here in the next lt gt hours imm...\n",
325 | "5560 anything lor juz both of us lor\n",
326 | "5561 get me out of this dump heap my mom decided to...\n",
327 | "5562 ok lor sony ericsson salesman i ask shuhui the...\n",
328 | "5563 ard numbr like dat lor\n",
329 | "5564 why don t you wait til at least wednesday to s...\n",
330 | "5565 huh y lei\n",
331 | "5566 reminder from onumbr to get numbr pounds free ...\n",
332 | "5567 this is the numbrnd time we have tried numbr c...\n",
333 | "5568 will b going to esplanade fr home\n",
334 | "5569 pity was in mood for that so any other suggest...\n",
335 | "5570 the guy did some bitching but i acted like i d...\n",
336 | "5571 rofl its true to its name\n",
337 | "Name: 1, Length: 5572, dtype: object\n"
338 | ]
339 | }
340 | ],
341 | "source": [
342 | "# change words to lower case - Hello, HELLO, hello are all the same word\n",
343 | "processed = processed.str.lower()\n",
344 | "print(processed)"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": 10,
350 | "metadata": {},
351 | "outputs": [],
352 | "source": [
353 | "from nltk.corpus import stopwords\n",
354 | "\n",
355 | "# remove stop words from text messages\n",
356 | "\n",
357 | "stop_words = set(stopwords.words('english'))\n",
358 | "\n",
359 | "processed = processed.apply(lambda x: ' '.join(\n",
360 | " term for term in x.split() if term not in stop_words))"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 11,
366 | "metadata": {},
367 | "outputs": [],
368 | "source": [
369 | "# Stem each word using a Porter stemmer\n",
370 | "ps = nltk.PorterStemmer()\n",
371 | "\n",
372 | "processed = processed.apply(lambda x: ' '.join(\n",
373 | " ps.stem(term) for term in x.split()))"
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "metadata": {},
379 | "source": [
380 | "### 4. Generating Features\n",
381 | "\n",
382 | "Feature engineering is the process of using domain knowledge of the data to create features for machine learning algorithms. In this project, the words in each text message will be our features. For this purpose, it will be necessary to tokenize each message into words. We will use the 1500 most common words as features."
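383 | ]
384 | },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside: the bag-of-words features we are about to build by hand can also be produced in one step with scikit-learn's CountVectorizer. The following sketch is not used in the rest of this notebook; it simply shows the equivalent approach, where max_features keeps the most frequent terms, mirroring our 1500-word list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"\n",
"# build a document-term matrix from the preprocessed messages,\n",
"# keeping only the 1500 most frequent terms in the corpus\n",
"vectorizer = CountVectorizer(max_features=1500)\n",
"X_counts = vectorizer.fit_transform(processed)\n",
"\n",
"print(X_counts.shape) # one row per message, one column per term"
]
},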
385 | {
386 | "cell_type": "code",
387 | "execution_count": 12,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "from nltk.tokenize import word_tokenize\n",
392 | "\n",
393 | "# create bag-of-words\n",
394 | "all_words = []\n",
395 | "\n",
396 | "for message in processed:\n",
397 | " words = word_tokenize(message)\n",
398 | " for w in words:\n",
399 | " all_words.append(w)\n",
400 | " \n",
401 | "all_words = nltk.FreqDist(all_words)"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": 13,
407 | "metadata": {},
408 | "outputs": [
409 | {
410 | "name": "stdout",
411 | "output_type": "stream",
412 | "text": [
413 | "Number of words: 6562\n",
414 | "Most common words: [(u'numbr', 2961), (u'u', 1207), (u'call', 679), (u'go', 456), (u'get', 451), (u'ur', 391), (u'gt', 318), (u'lt', 316), (u'come', 304), (u'ok', 293), (u'free', 284), (u'day', 276), (u'know', 275), (u'love', 266), (u'like', 261)]\n"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "# print the total number of words and the 15 most common words\n",
420 | "print('Number of words: {}'.format(len(all_words)))\n",
421 | "print('Most common words: {}'.format(all_words.most_common(15)))"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 14,
427 | "metadata": {},
428 | "outputs": [],
429 | "source": [
430 | "# use the 1500 most common words as features\n",
431 | "word_features = [word for word, count in all_words.most_common(1500)]"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 15,
437 | "metadata": {},
438 | "outputs": [
439 | {
440 | "name": "stdout",
441 | "output_type": "stream",
442 | "text": [
443 | "avail\n",
444 | "buffet\n",
445 | "world\n",
446 | "great\n"
447 | ]
448 | }
449 | ],
450 | "source": [
451 | "# The find_features function will determine which of the 1500 word features are contained in the message\n",
452 | "def find_features(message):\n",
453 | " words = word_tokenize(message)\n",
454 | " features = {}\n",
455 | " for word in word_features:\n",
456 | " features[word] = (word in words)\n",
457 | "\n",
458 | " return features\n",
459 | "\n",
460 | "# Let's see an example!\n",
461 | "features = find_features(processed[0])\n",
462 | "for key, value in features.items():\n",
463 | " if value:\n",
464 | " print(key)"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 16,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "# Now let's do it for all the messages\n",
474 | "messages = list(zip(processed, Y))\n",
475 | "\n",
476 | "# define a seed for reproducibility\n",
477 | "seed = 1\n",
478 | "np.random.seed(seed)\n",
479 | "np.random.shuffle(messages)\n",
480 | "\n",
481 | "# call find_features function for each SMS message\n",
482 | "featuresets = [(find_features(text), label) for (text, label) in messages]"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": 17,
488 | "metadata": {},
489 | "outputs": [],
490 | "source": [
491 | "# we can split the featuresets into training and testing datasets using sklearn\n",
492 | "from sklearn import model_selection\n",
493 | "\n",
494 | "# split the data into training and testing datasets\n",
495 | "training, testing = model_selection.train_test_split(featuresets, test_size = 0.25, random_state=seed)"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": 18,
501 | "metadata": {},
502 | "outputs": [
503 | {
504 | "name": "stdout",
505 | "output_type": "stream",
506 | "text": [
507 | "4179\n",
"1393\n" 509 | ] 510 | } 511 | ], 512 | "source": [ 513 | "print(len(training))\n", 514 | "print(len(testing))" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "### 4. Scikit-Learn Classifiers with NLTK\n", 522 | "\n", 523 | "Now that we have our dataset, we can start building algorithms! Let's start with a simple linear support vector classifier, then expand to other algorithms. We'll need to import each algorithm we plan on using from sklearn. We also need to import some performance metrics, such as accuracy_score and classification_report." 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 19, 529 | "metadata": {}, 530 | "outputs": [ 531 | { 532 | "name": "stdout", 533 | "output_type": "stream", 534 | "text": [ 535 | "SVC Accuracy: 96.1234745154\n" 536 | ] 537 | } 538 | ], 539 | "source": [ 540 | "# We can use sklearn algorithms in NLTK\n", 541 | "from nltk.classify.scikitlearn import SklearnClassifier\n", 542 | "from sklearn.svm import SVC\n", 543 | "\n", 544 | "model = SklearnClassifier(SVC(kernel = 'linear'))\n", 545 | "\n", 546 | "# train the model on the training data\n", 547 | "model.train(training)\n", 548 | "\n", 549 | "# and test on the testing dataset!\n", 550 | "accuracy = nltk.classify.accuracy(model, testing)*100\n", 551 | "print(\"SVC Accuracy: {}\".format(accuracy))" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 20, 557 | "metadata": {}, 558 | "outputs": [ 559 | { 560 | "name": "stdout", 561 | "output_type": "stream", 562 | "text": [ 563 | "K Nearest Neighbors Accuracy: 94.0416367552\n", 564 | "Decision Tree Accuracy: 95.2620244078\n", 565 | "Random Forest Accuracy: 95.6927494616\n", 566 | "Logistic Regression Accuracy: 95.9798994975\n", 567 | "SGD Classifier Accuracy: 95.9798994975\n", 568 | "Naive Bayes Accuracy: 96.2670495334\n", 569 | "SVM Linear Accuracy: 96.1234745154\n" 570 | ] 571 | } 572 | ], 573 | "source": [ 574 | "from sklearn.neighbors import KNeighborsClassifier\n", 575 | "from sklearn.tree import DecisionTreeClassifier\n", 576 | "from sklearn.ensemble import RandomForestClassifier\n", 577 | "from sklearn.linear_model import LogisticRegression, SGDClassifier\n", 578 | "from sklearn.naive_bayes import MultinomialNB\n", 579 | "from sklearn.svm import SVC\n", 580 | "from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n", 581 | "\n", 582 | "# Define models to train\n", 583 | "names = [\"K Nearest Neighbors\", \"Decision Tree\", \"Random Forest\", \"Logistic Regression\", \"SGD Classifier\",\n", 584 | " \"Naive Bayes\", \"SVM Linear\"]\n", 585 | "\n", 586 | "classifiers = [\n", 587 | " KNeighborsClassifier(),\n", 588 | " DecisionTreeClassifier(),\n", 589 | " RandomForestClassifier(),\n", 590 | " LogisticRegression(),\n", 591 | " SGDClassifier(max_iter = 100),\n", 592 | " MultinomialNB(),\n", 593 | " SVC(kernel = 'linear')\n", 594 | "]\n", 595 | "\n", 596 | "models = zip(names, classifiers)\n", 597 | "\n", 598 | "for name, model in models:\n", 599 | " nltk_model = SklearnClassifier(model)\n", 600 | " nltk_model.train(training)\n", 601 | " accuracy = nltk.classify.accuracy(nltk_model, testing)*100\n", 602 | " print(\"{} Accuracy: {}\".format(name, accuracy))" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 24, 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "name": "stdout", 612 | "output_type": "stream", 613 | "text": [ 614 | "Voting Classifier: Accuracy: 96.1234745154\n" 615 | ] 616 | } 
617 | ],
618 | "source": [
619 | "# Ensemble methods - Voting classifier\n",
620 | "from sklearn.ensemble import VotingClassifier\n",
621 | "\n",
622 | "names = [\"K Nearest Neighbors\", \"Decision Tree\", \"Random Forest\", \"Logistic Regression\", \"SGD Classifier\",\n",
623 | " \"Naive Bayes\", \"SVM Linear\"]\n",
624 | "\n",
625 | "classifiers = [\n",
626 | " KNeighborsClassifier(),\n",
627 | " DecisionTreeClassifier(),\n",
628 | " RandomForestClassifier(),\n",
629 | " LogisticRegression(),\n",
630 | " SGDClassifier(max_iter = 100),\n",
631 | " MultinomialNB(),\n",
632 | " SVC(kernel = 'linear')\n",
633 | "]\n",
634 | "\n",
635 | "models = list(zip(names, classifiers))\n",
636 | "\n",
637 | "nltk_ensemble = SklearnClassifier(VotingClassifier(estimators = models, voting = 'hard', n_jobs = -1))\n",
638 | "nltk_ensemble.train(training)\n",
639 | "accuracy = nltk.classify.accuracy(nltk_ensemble, testing)*100\n",
640 | "print(\"Voting Classifier: Accuracy: {}\".format(accuracy))"
641 | ]
642 | },
643 | {
644 | "cell_type": "code",
645 | "execution_count": null,
646 | "metadata": {},
647 | "outputs": [],
648 | "source": [
649 | "# make class label prediction for testing set\n",
650 | "txt_features, labels = zip(*testing)\n",
651 | "\n",
652 | "prediction = nltk_ensemble.classify_many(txt_features)"
653 | ]
654 | },
655 | {
656 | "cell_type": "code",
657 | "execution_count": 26,
658 | "metadata": {},
659 | "outputs": [
660 | {
661 | "name": "stdout",
662 | "output_type": "stream",
663 | "text": [
664 | " precision recall f1-score support\n",
665 | "\n",
666 | " 0 0.96 0.99 0.98 1214\n",
667 | " 1 0.92 0.75 0.83 179\n",
668 | "\n",
669 | "avg / total 0.96 0.96 0.96 1393\n",
670 | "\n"
671 | ]
672 | },
673 | {
674 | "data": {
\n", 677 | "\n", 690 | "\n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | "
predicted
hamspam
actualham120311
spam45134
\n", 718 | "
" 719 | ], 720 | "text/plain": [ 721 | " predicted \n", 722 | " ham spam\n", 723 | "actual ham 1203 11\n", 724 | " spam 45 134" 725 | ] 726 | }, 727 | "execution_count": 26, 728 | "metadata": {}, 729 | "output_type": "execute_result" 730 | } 731 | ], 732 | "source": [ 733 | "# print a confusion matrix and a classification report\n", 734 | "print(classification_report(labels, prediction))\n", 735 | "\n", 736 | "pd.DataFrame(\n", 737 | " confusion_matrix(labels, prediction),\n", 738 | " index = [['actual', 'actual'], ['ham', 'spam']],\n", 739 | " columns = [['predicted', 'predicted'], ['ham', 'spam']])" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "metadata": {}, 746 | "outputs": [], 747 | "source": [] 748 | } 749 | ], 750 | "metadata": { 751 | "kernelspec": { 752 | "display_name": "Python [default]", 753 | "language": "python", 754 | "name": "python2" 755 | }, 756 | "language_info": { 757 | "codemirror_mode": { 758 | "name": "ipython", 759 | "version": 2 760 | }, 761 | "file_extension": ".py", 762 | "mimetype": "text/x-python", 763 | "name": "python", 764 | "nbconvert_exporter": "python", 765 | "pygments_lexer": "ipython2", 766 | "version": "2.7.13" 767 | } 768 | }, 769 | "nbformat": 4, 770 | "nbformat_minor": 2 771 | } 772 | --------------------------------------------------------------------------------