├── PyTorchTweetTextClassification.ipynb └── README.md /PyTorchTweetTextClassification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\n", 8 | "***\n", 9 | "\n", 10 | "
\n", 11 | "Let us look at how we can implement text classification using PyTorch\n", 12 | "
\n", 13 | "\n", 14 | "The dataset is from the Tweet Sentiment Extraction challenge from Kaggle(https://www.kaggle.com/c/tweet-sentiment-extraction/overview) \n", 15 | "
\n", 16 | "\n", 17 | "We would implement text classification using a simple embedding bag of words using PyTorch on tweet data to classify tweets as \"positive\",\"negative\" or \"neutral\"\n", 18 | "\n", 19 | "\n", 20 | "
\n", 21 | "\n", 22 | "***\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "\n", 32 | "import os\n", 33 | "import re\n", 34 | "import shutil\n", 35 | "import string\n", 36 | "\n", 37 | "\n", 38 | "from collections import Counter\n", 39 | "\n", 40 | "\n", 41 | "import pandas as pd\n", 42 | "import numpy as np\n", 43 | "\n", 44 | "import sklearn\n", 45 | "\n", 46 | "\n", 47 | "from sklearn.model_selection import train_test_split\n", 48 | "\n", 49 | "\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "****\n", 57 | "Let us define methods to pre-process the review data\n", 58 | "****" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def remove_emoji(text):\n", 68 | " emoji_pattern = re.compile(\"[\"\n", 69 | " u\"\\U0001F600-\\U0001F64F\" # emoticons\n", 70 | " u\"\\U0001F300-\\U0001F5FF\" # symbols & pictographs\n", 71 | " u\"\\U0001F680-\\U0001F6FF\" # transport & map symbols\n", 72 | " u\"\\U0001F1E0-\\U0001F1FF\" # flags (iOS)\n", 73 | " u\"\\U00002702-\\U000027B0\"\n", 74 | " u\"\\U000024C2-\\U0001F251\"\n", 75 | " \"]+\", flags=re.UNICODE)\n", 76 | " return emoji_pattern.sub(r'', text)\n", 77 | "\n", 78 | "def remove_url(text): \n", 79 | " url_pattern = re.compile('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')\n", 80 | " return url_pattern.sub(r'', text)\n", 81 | " # converting return value from list to string\n", 82 | "\n", 83 | "\n", 84 | "\n", 85 | "def clean_text(text ): \n", 86 | " delete_dict = {sp_character: '' for sp_character in string.punctuation} \n", 87 | " delete_dict[' '] = ' ' \n", 88 | " table = str.maketrans(delete_dict)\n", 89 | " text1 = text.translate(table)\n", 90 | " #print('cleaned:'+text1)\n", 91 | " textArr= text1.split()\n", 92 | " text2 = ' '.join([w for w in textArr if ( not w.isdigit() and ( not w.isdigit() and len(w)>2))]) \n", 93 | " \n", 94 | " return text2.lower()\n", 95 | "\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "def get_sentiment(sentiment):\n", 105 | " if sentiment == 'positive':\n", 106 | " return 2\n", 107 | " elif sentiment == 'negative':\n", 108 | " return 1\n", 109 | " else:\n", 110 | " return 0\n", 111 | " " 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "-------Train data--------\n", 124 | "neutral 10704\n", 125 | "positive 8375\n", 126 | "negative 7673\n", 127 | "Name: sentiment, dtype: int64\n", 128 | "26752\n", 129 | "-------------------------\n", 130 | "-------Test data--------\n", 131 | "neutral 1376\n", 132 | "positive 1075\n", 133 | "negative 983\n", 134 | "Name: sentiment, dtype: int64\n", 135 | "3434\n", 136 | "-------------------------\n", 137 | "Train Max Sentence Length :33\n", 138 | "Test Max Sentence Length :32\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "train_data= pd.read_csv(\"C:\\\\TweetSenitment\\\\train.csv\")\n", 144 | "train_data.dropna(axis = 0, how ='any',inplace=True) \n", 145 | "train_data['Num_words_text'] = train_data['text'].apply(lambda x:len(str(x).split())) \n", 146 | "mask = train_data['Num_words_text'] >2\n", 147 | "train_data = train_data[mask]\n", 148 | "print('-------Train data--------')\n", 149 | "print(train_data['sentiment'].value_counts())\n", 150 | "print(len(train_data))\n", 151 | "print('-------------------------')\n", 152 | "max_train_sentence_length = train_data['Num_words_text'].max()\n", 153 | "\n", 154 | "\n", 155 | "train_data['text'] = train_data['text'].apply(remove_emoji)\n", 156 | "train_data['text'] = train_data['text'].apply(remove_url)\n", 157 | "train_data['text'] = train_data['text'].apply(clean_text)\n", 158 | "\n", 159 | "train_data['label'] = train_data['sentiment'].apply(get_sentiment)\n", 160 | "\n", 161 | "test_data= pd.read_csv(\"C:\\\\TweetSenitment\\\\test.csv\")\n", 162 | "test_data.dropna(axis = 0, how ='any',inplace=True) \n", 163 | "test_data['Num_words_text'] = test_data['text'].apply(lambda x:len(str(x).split())) \n", 164 | "\n", 165 | "max_test_sentence_length = test_data['Num_words_text'].max()\n", 166 | "\n", 167 | "mask = test_data['Num_words_text'] >2\n", 168 | "test_data = test_data[mask]\n", 169 | "\n", 170 | "print('-------Test data--------')\n", 171 | "print(test_data['sentiment'].value_counts())\n", 172 | "print(len(test_data))\n", 173 | "print('-------------------------')\n", 174 | "\n", 175 | "test_data['text'] = test_data['text'].apply(remove_emoji)\n", 176 | "test_data['text'] = test_data['text'].apply(remove_url)\n", 177 | "test_data['text'] = test_data['text'].apply(clean_text)\n", 178 | "\n", 179 | "test_data['label'] = test_data['sentiment'].apply(get_sentiment)\n", 180 | "\n", 181 | "print('Train Max Sentence Length :'+str(max_train_sentence_length))\n", 182 | "print('Test Max Sentence Length :'+str(max_test_sentence_length))\n", 183 | "\n", 184 | "\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "data": { 194 | "text/html": [ 195 | "
\n", 196 | "\n", 209 | "\n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | "
textIDtextselected_textsentimentNum_words_textlabel
0cb774db0d1have responded were goingI`d have responded, if I were goingneutral70
1549e992a42sooo sad will miss you here san diegoSooo SADnegative101
2088c60f138boss bullyingbullying menegative51
39642c003efwhat interview leave aloneleave me alonenegative51
4358bd9e861sons why couldnt they put them the releases al...Sons of ****,negative141
528b57f3990some shameless plugging for the best rangers f...http://www.dothebouncy.com/smf - some shameles...neutral120
66e0c6d75b12am feedings for the baby are fun when all smi...funpositive142
8e050245fbdboth youBoth of youneutral30
9fc2cbefa9djourney wow just became cooler hehe that possibleWow... u just became cooler.positive102
102339a9b08bmuch love hopeful reckon the chances are minim...as much as i love to be hopeful, i reckon the ...neutral230
\n", 314 | "
" 315 | ], 316 | "text/plain": [ 317 | " textID text \\\n", 318 | "0 cb774db0d1 have responded were going \n", 319 | "1 549e992a42 sooo sad will miss you here san diego \n", 320 | "2 088c60f138 boss bullying \n", 321 | "3 9642c003ef what interview leave alone \n", 322 | "4 358bd9e861 sons why couldnt they put them the releases al... \n", 323 | "5 28b57f3990 some shameless plugging for the best rangers f... \n", 324 | "6 6e0c6d75b1 2am feedings for the baby are fun when all smi... \n", 325 | "8 e050245fbd both you \n", 326 | "9 fc2cbefa9d journey wow just became cooler hehe that possible \n", 327 | "10 2339a9b08b much love hopeful reckon the chances are minim... \n", 328 | "\n", 329 | " selected_text sentiment \\\n", 330 | "0 I`d have responded, if I were going neutral \n", 331 | "1 Sooo SAD negative \n", 332 | "2 bullying me negative \n", 333 | "3 leave me alone negative \n", 334 | "4 Sons of ****, negative \n", 335 | "5 http://www.dothebouncy.com/smf - some shameles... neutral \n", 336 | "6 fun positive \n", 337 | "8 Both of you neutral \n", 338 | "9 Wow... u just became cooler. positive \n", 339 | "10 as much as i love to be hopeful, i reckon the ... neutral \n", 340 | "\n", 341 | " Num_words_text label \n", 342 | "0 7 0 \n", 343 | "1 10 1 \n", 344 | "2 5 1 \n", 345 | "3 5 1 \n", 346 | "4 14 1 \n", 347 | "5 12 0 \n", 348 | "6 14 2 \n", 349 | "8 3 0 \n", 350 | "9 10 2 \n", 351 | "10 23 0 " 352 | ] 353 | }, 354 | "execution_count": 5, 355 | "metadata": {}, 356 | "output_type": "execute_result" 357 | } 358 | ], 359 | "source": [ 360 | "train_data.head(10)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 6, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "text/html": [ 371 | "
\n", 372 | "\n", 385 | "\n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | "
textIDtextsentimentNum_words_textlabel
0f87dea47dblast session the dayneutral60
196d74cb729shanghai also really exciting precisely skyscr...positive152
2eee518ae67recession hit veronique branquinho she has qui...negative131
433987a8ee5likepositive52
5726e501993thats great weee visitorspositive42
6261932614ethink everyone hates here lolnegative81
7afa11da83fsoooooo wish could but school and myspace comp...negative131
8e64208b4efand within short time the last clue all themneutral120
937bcad24cawhat did you get day alright havent done anyth...neutral180
1024c92644a4bike was put holdshould have known that argh t...negative121
\n", 479 | "
" 480 | ], 481 | "text/plain": [ 482 | " textID text sentiment \\\n", 483 | "0 f87dea47db last session the day neutral \n", 484 | "1 96d74cb729 shanghai also really exciting precisely skyscr... positive \n", 485 | "2 eee518ae67 recession hit veronique branquinho she has qui... negative \n", 486 | "4 33987a8ee5 like positive \n", 487 | "5 726e501993 thats great weee visitors positive \n", 488 | "6 261932614e think everyone hates here lol negative \n", 489 | "7 afa11da83f soooooo wish could but school and myspace comp... negative \n", 490 | "8 e64208b4ef and within short time the last clue all them neutral \n", 491 | "9 37bcad24ca what did you get day alright havent done anyth... neutral \n", 492 | "10 24c92644a4 bike was put holdshould have known that argh t... negative \n", 493 | "\n", 494 | " Num_words_text label \n", 495 | "0 6 0 \n", 496 | "1 15 2 \n", 497 | "2 13 1 \n", 498 | "4 5 2 \n", 499 | "5 4 2 \n", 500 | "6 8 1 \n", 501 | "7 13 1 \n", 502 | "8 12 0 \n", 503 | "9 18 0 \n", 504 | "10 12 1 " 505 | ] 506 | }, 507 | "execution_count": 6, 508 | "metadata": {}, 509 | "output_type": "execute_result" 510 | } 511 | ], 512 | "source": [ 513 | "test_data.head(10)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "***\n", 521 | "Let us create our train,valid and test datasets\n", 522 | "***" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 7, 528 | "metadata": {}, 529 | "outputs": [ 530 | { 531 | "name": "stdout", 532 | "output_type": "stream", 533 | "text": [ 534 | "Train data len:21401\n", 535 | "Class distributionCounter({0: 8563, 2: 6700, 1: 6138})\n", 536 | "Valid data len:5351\n", 537 | "Class distributionCounter({0: 2141, 2: 1675, 1: 1535})\n", 538 | "Test data len:3434\n", 539 | "Class distributionCounter({0: 1376, 2: 1075, 1: 983})\n" 540 | ] 541 | } 542 | ], 543 | "source": [ 544 | "X_train, X_valid, Y_train, Y_valid= train_test_split(train_data['text'].tolist(),\\\n", 545 | " train_data['label'].tolist(),\\\n", 546 | " test_size=0.2,\\\n", 547 | " stratify = train_data['label'].tolist(),\\\n", 548 | " random_state=0)\n", 549 | "\n", 550 | "\n", 551 | "print('Train data len:'+str(len(X_train)))\n", 552 | "print('Class distribution'+str(Counter(Y_train)))\n", 553 | "\n", 554 | "\n", 555 | "print('Valid data len:'+str(len(X_valid)))\n", 556 | "print('Class distribution'+ str(Counter(Y_valid)))\n", 557 | "\n", 558 | "print('Test data len:'+str(len(test_data['text'].tolist())))\n", 559 | "print('Class distribution'+ str(Counter(test_data['label'].tolist())))\n", 560 | "\n", 561 | "\n", 562 | "train_dat =list(zip(Y_train,X_train))\n", 563 | "valid_dat =list(zip(Y_valid,X_valid))\n", 564 | "test_dat=list(zip(test_data['label'].tolist(),test_data['text'].tolist()))\n" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 8, 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "import torch\n", 574 | "from torch.utils.data import DataLoader\n", 575 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "metadata": {}, 581 | "source": [ 582 | "***\n", 583 | "Let us create our vocabulary on train data\n", 584 | "***" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 9, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "from torchtext.data.utils import get_tokenizer\n", 594 | "from torchtext.vocab import build_vocab_from_iterator\n", 595 | "\n", 596 | "tokenizer = get_tokenizer('basic_english')\n", 597 | "train_iter = train_dat\n", 598 | "def yield_tokens(data_iter):\n", 599 | " for _, text in data_iter:\n", 600 | " yield tokenizer(text)\n", 601 | "\n", 602 | "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"\"])\n", 603 | "vocab.set_default_index(vocab[\"\"])\n" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": {}, 609 | "source": [ 610 | "Let us define our text and label preprocessing pipleines" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 10, 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "text_pipeline = lambda x: vocab(tokenizer(x))\n", 620 | "label_pipeline = lambda x: int(x) " 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 11, 626 | "metadata": {}, 627 | "outputs": [ 628 | { 629 | "data": { 630 | "text/plain": [ 631 | "[62, 0, 1, 0, 12881]" 632 | ] 633 | }, 634 | "execution_count": 11, 635 | "metadata": {}, 636 | "output_type": "execute_result" 637 | } 638 | ], 639 | "source": [ 640 | "text_pipeline('here is the an example')\n" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": 12, 646 | "metadata": {}, 647 | "outputs": [ 648 | { 649 | "data": { 650 | "text/plain": [ 651 | "1" 652 | ] 653 | }, 654 | "execution_count": 12, 655 | "metadata": {}, 656 | "output_type": "execute_result" 657 | } 658 | ], 659 | "source": [ 660 | "label_pipeline('1')" 661 | ] 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "metadata": {}, 666 | "source": [ 667 | "Let us define our batch collation function" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 13, 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [ 676 | "\n", 677 | "\n", 678 | "def collate_batch(batch):\n", 679 | " label_list, text_list, offsets = [], [], [0]\n", 680 | " for (_label, _text) in batch:\n", 681 | " label_list.append(label_pipeline(_label))\n", 682 | " processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n", 683 | " text_list.append(processed_text)\n", 684 | " offsets.append(processed_text.size(0))\n", 685 | " label_list = torch.tensor(label_list, dtype=torch.int64)\n", 686 | " offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n", 687 | " text_list = torch.cat(text_list)\n", 688 | " return label_list.to(device), text_list.to(device), offsets.to(device)\n", 689 | "\n", 690 | "#train_iter =train_dat\n", 691 | "#dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)" 692 | ] 693 | }, 694 | { 695 | "cell_type": "markdown", 696 | "metadata": {}, 697 | "source": [ 698 | "Let us define our text classification model" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 14, 704 | "metadata": {}, 705 | "outputs": [], 706 | "source": [ 707 | "from torch import nn\n", 708 | "import torch.nn.functional as F\n", 709 | "\n", 710 | "class TextClassificationModel(nn.Module):\n", 711 | "\n", 712 | " def __init__(self, vocab_size, embed_dim, num_class):\n", 713 | " super(TextClassificationModel, self).__init__()\n", 714 | " self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n", 715 | " self.fc1 = nn.Linear(embed_dim,64)\n", 716 | " self.fc2 = nn.Linear(64,16)\n", 717 | " self.fc3 = nn.Linear(16, num_class)\n", 718 | " self.init_weights()\n", 719 | "\n", 720 | " def init_weights(self):\n", 721 | " initrange = 0.5\n", 722 | " self.embedding.weight.data.uniform_(-initrange, initrange)\n", 723 | " self.fc1.weight.data.uniform_(-initrange, initrange)\n", 724 | " self.fc1.bias.data.zero_()\n", 725 | " self.fc2.weight.data.uniform_(-initrange, initrange)\n", 726 | " self.fc2.bias.data.zero_()\n", 727 | " self.fc3.weight.data.uniform_(-initrange, initrange)\n", 728 | " self.fc3.bias.data.zero_()\n", 729 | "\n", 730 | " def forward(self, text, offsets):\n", 731 | " embedded = self.embedding(text, offsets)\n", 732 | " x = F.relu(self.fc1(embedded))\n", 733 | " x = F.relu(self.fc2(x))\n", 734 | " x = self.fc3(x)\n", 735 | " return x" 736 | ] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "metadata": {}, 741 | "source": [ 742 | "Let us create an object of our text classification class" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 15, 748 | "metadata": {}, 749 | "outputs": [ 750 | { 751 | "name": "stdout", 752 | "output_type": "stream", 753 | "text": [ 754 | "3\n" 755 | ] 756 | } 757 | ], 758 | "source": [ 759 | "train_iter1 = train_dat\n", 760 | "num_class = len(set([label for (label, text) in train_iter1]))\n", 761 | "print(num_class)\n", 762 | "vocab_size = len(vocab)\n", 763 | "emsize = 128\n", 764 | "model = TextClassificationModel(vocab_size, emsize, num_class).to(device)" 765 | ] 766 | }, 767 | { 768 | "cell_type": "markdown", 769 | "metadata": {}, 770 | "source": [ 771 | "Let us define our train and evaluate methods" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 16, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "import time\n", 781 | "\n", 782 | "def train(dataloader):\n", 783 | " model.train()\n", 784 | " total_acc, total_count = 0, 0\n", 785 | " log_interval = 500\n", 786 | " start_time = time.time()\n", 787 | "\n", 788 | " for idx, (label, text, offsets) in enumerate(dataloader):\n", 789 | " optimizer.zero_grad()\n", 790 | " predited_label = model(text, offsets)\n", 791 | " loss = criterion(predited_label, label)\n", 792 | " loss.backward()\n", 793 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", 794 | " optimizer.step()\n", 795 | " total_acc += (predited_label.argmax(1) == label).sum().item()\n", 796 | " total_count += label.size(0)\n", 797 | " if idx % log_interval == 0 and idx > 0:\n", 798 | " elapsed = time.time() - start_time\n", 799 | " print('| epoch {:3d} | {:5d}/{:5d} batches '\n", 800 | " '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),\n", 801 | " total_acc/total_count))\n", 802 | " total_acc, total_count = 0, 0\n", 803 | " start_time = time.time()\n", 804 | "\n", 805 | "def evaluate(dataloader):\n", 806 | " model.eval()\n", 807 | " total_acc, total_count = 0, 0\n", 808 | "\n", 809 | " with torch.no_grad():\n", 810 | " for idx, (label, text, offsets) in enumerate(dataloader):\n", 811 | " predited_label = model(text, offsets)\n", 812 | " loss = criterion(predited_label, label)\n", 813 | " total_acc += (predited_label.argmax(1) == label).sum().item()\n", 814 | " total_count += label.size(0)\n", 815 | " return total_acc/total_count" 816 | ] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "metadata": {}, 821 | "source": [ 822 | "Let us create dataloaders for text,train and validation data iterators and then train our model" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": 17, 828 | "metadata": {}, 829 | "outputs": [ 830 | { 831 | "name": "stdout", 832 | "output_type": "stream", 833 | "text": [ 834 | "| epoch 1 | 500/ 1338 batches | accuracy 0.444\n", 835 | "| epoch 1 | 1000/ 1338 batches | accuracy 0.546\n", 836 | "-----------------------------------------------------------\n", 837 | "| end of epoch 1 | time: 4.41s | valid accuracy 0.627 \n", 838 | "-----------------------------------------------------------\n", 839 | "| epoch 2 | 500/ 1338 batches | accuracy 0.625\n", 840 | "| epoch 2 | 1000/ 1338 batches | accuracy 0.619\n", 841 | "-----------------------------------------------------------\n", 842 | "| end of epoch 2 | time: 3.44s | valid accuracy 0.562 \n", 843 | "-----------------------------------------------------------\n", 844 | "| epoch 3 | 500/ 1338 batches | accuracy 0.681\n", 845 | "| epoch 3 | 1000/ 1338 batches | accuracy 0.711\n", 846 | "-----------------------------------------------------------\n", 847 | "| end of epoch 3 | time: 3.26s | valid accuracy 0.677 \n", 848 | "-----------------------------------------------------------\n", 849 | "| epoch 4 | 500/ 1338 batches | accuracy 0.733\n", 850 | "| epoch 4 | 1000/ 1338 batches | accuracy 0.735\n", 851 | "-----------------------------------------------------------\n", 852 | "| end of epoch 4 | time: 3.30s | valid accuracy 0.671 \n", 853 | "-----------------------------------------------------------\n", 854 | "| epoch 5 | 500/ 1338 batches | accuracy 0.748\n", 855 | "| epoch 5 | 1000/ 1338 batches | accuracy 0.755\n", 856 | "-----------------------------------------------------------\n", 857 | "| end of epoch 5 | time: 3.24s | valid accuracy 0.675 \n", 858 | "-----------------------------------------------------------\n", 859 | "| epoch 6 | 500/ 1338 batches | accuracy 0.765\n", 860 | "| epoch 6 | 1000/ 1338 batches | accuracy 0.755\n", 861 | "-----------------------------------------------------------\n", 862 | "| end of epoch 6 | time: 3.24s | valid accuracy 0.676 \n", 863 | "-----------------------------------------------------------\n", 864 | "| epoch 7 | 500/ 1338 batches | accuracy 0.760\n", 865 | "| epoch 7 | 1000/ 1338 batches | accuracy 0.761\n", 866 | "-----------------------------------------------------------\n", 867 | "| end of epoch 7 | time: 3.37s | valid accuracy 0.676 \n", 868 | "-----------------------------------------------------------\n", 869 | "| epoch 8 | 500/ 1338 batches | accuracy 0.758\n", 870 | "| epoch 8 | 1000/ 1338 batches | accuracy 0.762\n", 871 | "-----------------------------------------------------------\n", 872 | "| end of epoch 8 | time: 3.70s | valid accuracy 0.676 \n", 873 | "-----------------------------------------------------------\n", 874 | "| epoch 9 | 500/ 1338 batches | accuracy 0.758\n", 875 | "| epoch 9 | 1000/ 1338 batches | accuracy 0.762\n", 876 | "-----------------------------------------------------------\n", 877 | "| end of epoch 9 | time: 3.49s | valid accuracy 0.676 \n", 878 | "-----------------------------------------------------------\n", 879 | "| epoch 10 | 500/ 1338 batches | accuracy 0.753\n", 880 | "| epoch 10 | 1000/ 1338 batches | accuracy 0.761\n", 881 | "-----------------------------------------------------------\n", 882 | "| end of epoch 10 | time: 3.32s | valid accuracy 0.676 \n", 883 | "-----------------------------------------------------------\n" 884 | ] 885 | } 886 | ], 887 | "source": [ 888 | "from torch.utils.data.dataset import random_split\n", 889 | "from torchtext.data.functional import to_map_style_dataset\n", 890 | "# Hyperparameters\n", 891 | "EPOCHS = 10 # epoch\n", 892 | "LR =10 # learning rate\n", 893 | "BATCH_SIZE = 16 # batch size for training\n", 894 | "\n", 895 | "criterion = torch.nn.CrossEntropyLoss()\n", 896 | "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n", 897 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)\n", 898 | "total_accu = None\n", 899 | "\n", 900 | "train_iter2 = train_dat\n", 901 | "test_iter2 =test_dat \n", 902 | "valid_iter2= valid_dat\n", 903 | "\n", 904 | "\n", 905 | "\n", 906 | "\n", 907 | "train_dataloader = DataLoader(train_iter2, batch_size=BATCH_SIZE,\n", 908 | " shuffle=True, collate_fn=collate_batch)\n", 909 | "valid_dataloader = DataLoader(valid_iter2, batch_size=BATCH_SIZE,\n", 910 | " shuffle=True, collate_fn=collate_batch)\n", 911 | "test_dataloader = DataLoader(test_iter2, batch_size=BATCH_SIZE,\n", 912 | " shuffle=True, collate_fn=collate_batch)\n", 913 | "\n", 914 | "for epoch in range(1, EPOCHS + 1):\n", 915 | " epoch_start_time = time.time()\n", 916 | " train(train_dataloader)\n", 917 | " accu_val = evaluate(valid_dataloader)\n", 918 | " if total_accu is not None and total_accu > accu_val:\n", 919 | " scheduler.step()\n", 920 | " else:\n", 921 | " total_accu = accu_val\n", 922 | " print('-' * 59)\n", 923 | " print('| end of epoch {:3d} | time: {:5.2f}s | '\n", 924 | " 'valid accuracy {:8.3f} '.format(epoch,\n", 925 | " time.time() - epoch_start_time,\n", 926 | " accu_val))\n", 927 | " print('-' * 59)" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": 18, 933 | "metadata": {}, 934 | "outputs": [ 935 | { 936 | "name": "stdout", 937 | "output_type": "stream", 938 | "text": [ 939 | "Checking the results of test dataset.\n", 940 | "test accuracy 0.692\n" 941 | ] 942 | } 943 | ], 944 | "source": [ 945 | "print('Checking the results of test dataset.')\n", 946 | "accu_test = evaluate(test_dataloader)\n", 947 | "print('test accuracy {:8.3f}'.format(accu_test))" 948 | ] 949 | }, 950 | { 951 | "cell_type": "markdown", 952 | "metadata": {}, 953 | "source": [ 954 | "Test model on text" 955 | ] 956 | }, 957 | { 958 | "cell_type": "code", 959 | "execution_count": 19, 960 | "metadata": {}, 961 | "outputs": [ 962 | { 963 | "name": "stdout", 964 | "output_type": "stream", 965 | "text": [ 966 | "This is a Neutral tweet\n" 967 | ] 968 | } 969 | ], 970 | "source": [ 971 | "sentiment_label = {2:\"Positive\",\n", 972 | " 1: \"Negative\",\n", 973 | " 0: \"Neutral\"\n", 974 | " }\n", 975 | "\n", 976 | "def predict(text, text_pipeline):\n", 977 | " with torch.no_grad():\n", 978 | " text = torch.tensor(text_pipeline(text))\n", 979 | " output = model(text, torch.tensor([0]))\n", 980 | " return output.argmax(1).item() \n", 981 | "ex_text_str = \"soooooo wish i could, but im in school and myspace is completely blocked\"\n", 982 | "model = model.to(\"cpu\")\n", 983 | "\n", 984 | "print(\"This is a %s tweet\" %sentiment_label[predict(ex_text_str, text_pipeline)])" 985 | ] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "execution_count": null, 990 | "metadata": {}, 991 | "outputs": [], 992 | "source": [] 993 | } 994 | ], 995 | "metadata": { 996 | "kernelspec": { 997 | "display_name": "Python 3", 998 | "language": "python", 999 | "name": "python3" 1000 | }, 1001 | "language_info": { 1002 | "codemirror_mode": { 1003 | "name": "ipython", 1004 | "version": 3 1005 | }, 1006 | "file_extension": ".py", 1007 | "mimetype": "text/x-python", 1008 | "name": "python", 1009 | "nbconvert_exporter": "python", 1010 | "pygments_lexer": "ipython3", 1011 | "version": "3.6.13" 1012 | } 1013 | }, 1014 | "nbformat": 4, 1015 | "nbformat_minor": 4 1016 | } 1017 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorchTextClassificationCustomDataset 2 | In this repository i explain how you can implement a text classifier on custom dataset using PyTorch 3 | This jupyter notebook is inspired from: https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html 4 | 5 | The dataset is from the Tweet Sentiment Extraction challenge from Kaggle(https://www.kaggle.com/c/tweet-sentiment-extraction/overview) 6 | 7 | We would implement text classification using a simple embedding bag of words using PyTorch on tweet data to classify tweets as "positive","negative" or "neutral" 8 | 9 | 10 | Pre-requisites: 11 | 12 | PyTorch (https://pytorch.org/) 13 | 14 | TorchText (https://anaconda.org/pytorch/torchtext) 15 | 16 | Python 3.6 and above (https://www.anaconda.com/products/individual) 17 | --------------------------------------------------------------------------------