├── PyTorchTweetTextClassification.ipynb
└── README.md
/PyTorchTweetTextClassification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "\n",
8 | "***\n",
9 | "\n",
10 | "
\n",
11 | "Let us look at how we can implement text classification using PyTorch\n",
12 | "
\n",
13 | "\n",
14 | "The dataset is from the Tweet Sentiment Extraction challenge from Kaggle(https://www.kaggle.com/c/tweet-sentiment-extraction/overview) \n",
15 | "
\n",
16 | "\n",
17 | "We would implement text classification using a simple embedding bag of words using PyTorch on tweet data to classify tweets as \"positive\",\"negative\" or \"neutral\"\n",
18 | "\n",
19 | "\n",
20 | "
\n",
21 | "\n",
22 | "***\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "\n",
32 | "import os\n",
33 | "import re\n",
34 | "import shutil\n",
35 | "import string\n",
36 | "\n",
37 | "\n",
38 | "from collections import Counter\n",
39 | "\n",
40 | "\n",
41 | "import pandas as pd\n",
42 | "import numpy as np\n",
43 | "\n",
44 | "import sklearn\n",
45 | "\n",
46 | "\n",
47 | "from sklearn.model_selection import train_test_split\n",
48 | "\n",
49 | "\n"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "****\n",
57 | "Let us define methods to pre-process the review data\n",
58 | "****"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 2,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "def remove_emoji(text):\n",
68 | " emoji_pattern = re.compile(\"[\"\n",
69 | " u\"\\U0001F600-\\U0001F64F\" # emoticons\n",
70 | " u\"\\U0001F300-\\U0001F5FF\" # symbols & pictographs\n",
71 | " u\"\\U0001F680-\\U0001F6FF\" # transport & map symbols\n",
72 | " u\"\\U0001F1E0-\\U0001F1FF\" # flags (iOS)\n",
73 | " u\"\\U00002702-\\U000027B0\"\n",
74 | " u\"\\U000024C2-\\U0001F251\"\n",
75 | " \"]+\", flags=re.UNICODE)\n",
76 | " return emoji_pattern.sub(r'', text)\n",
77 | "\n",
78 | "def remove_url(text): \n",
79 | " url_pattern = re.compile('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')\n",
80 | " return url_pattern.sub(r'', text)\n",
81 | " # converting return value from list to string\n",
82 | "\n",
83 | "\n",
84 | "\n",
85 | "def clean_text(text ): \n",
86 | " delete_dict = {sp_character: '' for sp_character in string.punctuation} \n",
87 | " delete_dict[' '] = ' ' \n",
88 | " table = str.maketrans(delete_dict)\n",
89 | " text1 = text.translate(table)\n",
90 | " #print('cleaned:'+text1)\n",
91 | " textArr= text1.split()\n",
92 | " text2 = ' '.join([w for w in textArr if ( not w.isdigit() and ( not w.isdigit() and len(w)>2))]) \n",
93 | " \n",
94 | " return text2.lower()\n",
95 | "\n"
96 | ]
97 | },
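   | {
   | "cell_type": "markdown",
   | "metadata": {},
   | "source": [
   | "A quick sanity check of the cleaning pipeline on a made-up tweet; the printed text should come out lowercased, with the emoji, URL, punctuation, digits and very short words removed."
   | ]
   | },
   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "sample = 'Loving the new update!! 😍 check it out http://example.com 123'\n",
   | "print(clean_text(remove_url(remove_emoji(sample))))\n"
   | ]
   | },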
98 | {
99 | "cell_type": "code",
100 | "execution_count": 3,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "def get_sentiment(sentiment):\n",
105 | " if sentiment == 'positive':\n",
106 | " return 2\n",
107 | " elif sentiment == 'negative':\n",
108 | " return 1\n",
109 | " else:\n",
110 | " return 0\n",
111 | " "
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 4,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "name": "stdout",
121 | "output_type": "stream",
122 | "text": [
123 | "-------Train data--------\n",
124 | "neutral 10704\n",
125 | "positive 8375\n",
126 | "negative 7673\n",
127 | "Name: sentiment, dtype: int64\n",
128 | "26752\n",
129 | "-------------------------\n",
130 | "-------Test data--------\n",
131 | "neutral 1376\n",
132 | "positive 1075\n",
133 | "negative 983\n",
134 | "Name: sentiment, dtype: int64\n",
135 | "3434\n",
136 | "-------------------------\n",
137 | "Train Max Sentence Length :33\n",
138 | "Test Max Sentence Length :32\n"
139 | ]
140 | }
141 | ],
142 | "source": [
143 | "train_data= pd.read_csv(\"C:\\\\TweetSenitment\\\\train.csv\")\n",
144 | "train_data.dropna(axis = 0, how ='any',inplace=True) \n",
145 | "train_data['Num_words_text'] = train_data['text'].apply(lambda x:len(str(x).split())) \n",
146 | "mask = train_data['Num_words_text'] >2\n",
147 | "train_data = train_data[mask]\n",
148 | "print('-------Train data--------')\n",
149 | "print(train_data['sentiment'].value_counts())\n",
150 | "print(len(train_data))\n",
151 | "print('-------------------------')\n",
152 | "max_train_sentence_length = train_data['Num_words_text'].max()\n",
153 | "\n",
154 | "\n",
155 | "train_data['text'] = train_data['text'].apply(remove_emoji)\n",
156 | "train_data['text'] = train_data['text'].apply(remove_url)\n",
157 | "train_data['text'] = train_data['text'].apply(clean_text)\n",
158 | "\n",
159 | "train_data['label'] = train_data['sentiment'].apply(get_sentiment)\n",
160 | "\n",
161 | "test_data= pd.read_csv(\"C:\\\\TweetSenitment\\\\test.csv\")\n",
162 | "test_data.dropna(axis = 0, how ='any',inplace=True) \n",
163 | "test_data['Num_words_text'] = test_data['text'].apply(lambda x:len(str(x).split())) \n",
164 | "\n",
165 | "max_test_sentence_length = test_data['Num_words_text'].max()\n",
166 | "\n",
167 | "mask = test_data['Num_words_text'] >2\n",
168 | "test_data = test_data[mask]\n",
169 | "\n",
170 | "print('-------Test data--------')\n",
171 | "print(test_data['sentiment'].value_counts())\n",
172 | "print(len(test_data))\n",
173 | "print('-------------------------')\n",
174 | "\n",
175 | "test_data['text'] = test_data['text'].apply(remove_emoji)\n",
176 | "test_data['text'] = test_data['text'].apply(remove_url)\n",
177 | "test_data['text'] = test_data['text'].apply(clean_text)\n",
178 | "\n",
179 | "test_data['label'] = test_data['sentiment'].apply(get_sentiment)\n",
180 | "\n",
181 | "print('Train Max Sentence Length :'+str(max_train_sentence_length))\n",
182 | "print('Test Max Sentence Length :'+str(max_test_sentence_length))\n",
183 | "\n",
184 | "\n"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 5,
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "data": {
194 | "text/html": [
195 | "
\n",
196 | "\n",
209 | "
\n",
210 | " \n",
211 | " \n",
212 | " | \n",
213 | " textID | \n",
214 | " text | \n",
215 | " selected_text | \n",
216 | " sentiment | \n",
217 | " Num_words_text | \n",
218 | " label | \n",
219 | "
\n",
220 | " \n",
221 | " \n",
222 | " \n",
223 | " 0 | \n",
224 | " cb774db0d1 | \n",
225 | " have responded were going | \n",
226 | " I`d have responded, if I were going | \n",
227 | " neutral | \n",
228 | " 7 | \n",
229 | " 0 | \n",
230 | "
\n",
231 | " \n",
232 | " 1 | \n",
233 | " 549e992a42 | \n",
234 | " sooo sad will miss you here san diego | \n",
235 | " Sooo SAD | \n",
236 | " negative | \n",
237 | " 10 | \n",
238 | " 1 | \n",
239 | "
\n",
240 | " \n",
241 | " 2 | \n",
242 | " 088c60f138 | \n",
243 | " boss bullying | \n",
244 | " bullying me | \n",
245 | " negative | \n",
246 | " 5 | \n",
247 | " 1 | \n",
248 | "
\n",
249 | " \n",
250 | " 3 | \n",
251 | " 9642c003ef | \n",
252 | " what interview leave alone | \n",
253 | " leave me alone | \n",
254 | " negative | \n",
255 | " 5 | \n",
256 | " 1 | \n",
257 | "
\n",
258 | " \n",
259 | " 4 | \n",
260 | " 358bd9e861 | \n",
261 | " sons why couldnt they put them the releases al... | \n",
262 | " Sons of ****, | \n",
263 | " negative | \n",
264 | " 14 | \n",
265 | " 1 | \n",
266 | "
\n",
267 | " \n",
268 | " 5 | \n",
269 | " 28b57f3990 | \n",
270 | " some shameless plugging for the best rangers f... | \n",
271 | " http://www.dothebouncy.com/smf - some shameles... | \n",
272 | " neutral | \n",
273 | " 12 | \n",
274 | " 0 | \n",
275 | "
\n",
276 | " \n",
277 | " 6 | \n",
278 | " 6e0c6d75b1 | \n",
279 | " 2am feedings for the baby are fun when all smi... | \n",
280 | " fun | \n",
281 | " positive | \n",
282 | " 14 | \n",
283 | " 2 | \n",
284 | "
\n",
285 | " \n",
286 | " 8 | \n",
287 | " e050245fbd | \n",
288 | " both you | \n",
289 | " Both of you | \n",
290 | " neutral | \n",
291 | " 3 | \n",
292 | " 0 | \n",
293 | "
\n",
294 | " \n",
295 | " 9 | \n",
296 | " fc2cbefa9d | \n",
297 | " journey wow just became cooler hehe that possible | \n",
298 | " Wow... u just became cooler. | \n",
299 | " positive | \n",
300 | " 10 | \n",
301 | " 2 | \n",
302 | "
\n",
303 | " \n",
304 | " 10 | \n",
305 | " 2339a9b08b | \n",
306 | " much love hopeful reckon the chances are minim... | \n",
307 | " as much as i love to be hopeful, i reckon the ... | \n",
308 | " neutral | \n",
309 | " 23 | \n",
310 | " 0 | \n",
311 | "
\n",
312 | " \n",
313 | "
\n",
314 | "
"
315 | ],
316 | "text/plain": [
317 | " textID text \\\n",
318 | "0 cb774db0d1 have responded were going \n",
319 | "1 549e992a42 sooo sad will miss you here san diego \n",
320 | "2 088c60f138 boss bullying \n",
321 | "3 9642c003ef what interview leave alone \n",
322 | "4 358bd9e861 sons why couldnt they put them the releases al... \n",
323 | "5 28b57f3990 some shameless plugging for the best rangers f... \n",
324 | "6 6e0c6d75b1 2am feedings for the baby are fun when all smi... \n",
325 | "8 e050245fbd both you \n",
326 | "9 fc2cbefa9d journey wow just became cooler hehe that possible \n",
327 | "10 2339a9b08b much love hopeful reckon the chances are minim... \n",
328 | "\n",
329 | " selected_text sentiment \\\n",
330 | "0 I`d have responded, if I were going neutral \n",
331 | "1 Sooo SAD negative \n",
332 | "2 bullying me negative \n",
333 | "3 leave me alone negative \n",
334 | "4 Sons of ****, negative \n",
335 | "5 http://www.dothebouncy.com/smf - some shameles... neutral \n",
336 | "6 fun positive \n",
337 | "8 Both of you neutral \n",
338 | "9 Wow... u just became cooler. positive \n",
339 | "10 as much as i love to be hopeful, i reckon the ... neutral \n",
340 | "\n",
341 | " Num_words_text label \n",
342 | "0 7 0 \n",
343 | "1 10 1 \n",
344 | "2 5 1 \n",
345 | "3 5 1 \n",
346 | "4 14 1 \n",
347 | "5 12 0 \n",
348 | "6 14 2 \n",
349 | "8 3 0 \n",
350 | "9 10 2 \n",
351 | "10 23 0 "
352 | ]
353 | },
354 | "execution_count": 5,
355 | "metadata": {},
356 | "output_type": "execute_result"
357 | }
358 | ],
359 | "source": [
360 | "train_data.head(10)"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 6,
366 | "metadata": {},
367 | "outputs": [
368 | {
369 | "data": {
370 | "text/html": [
371 | "\n",
372 | "\n",
385 | "
\n",
386 | " \n",
387 | " \n",
388 | " | \n",
389 | " textID | \n",
390 | " text | \n",
391 | " sentiment | \n",
392 | " Num_words_text | \n",
393 | " label | \n",
394 | "
\n",
395 | " \n",
396 | " \n",
397 | " \n",
398 | " 0 | \n",
399 | " f87dea47db | \n",
400 | " last session the day | \n",
401 | " neutral | \n",
402 | " 6 | \n",
403 | " 0 | \n",
404 | "
\n",
405 | " \n",
406 | " 1 | \n",
407 | " 96d74cb729 | \n",
408 | " shanghai also really exciting precisely skyscr... | \n",
409 | " positive | \n",
410 | " 15 | \n",
411 | " 2 | \n",
412 | "
\n",
413 | " \n",
414 | " 2 | \n",
415 | " eee518ae67 | \n",
416 | " recession hit veronique branquinho she has qui... | \n",
417 | " negative | \n",
418 | " 13 | \n",
419 | " 1 | \n",
420 | "
\n",
421 | " \n",
422 | " 4 | \n",
423 | " 33987a8ee5 | \n",
424 | " like | \n",
425 | " positive | \n",
426 | " 5 | \n",
427 | " 2 | \n",
428 | "
\n",
429 | " \n",
430 | " 5 | \n",
431 | " 726e501993 | \n",
432 | " thats great weee visitors | \n",
433 | " positive | \n",
434 | " 4 | \n",
435 | " 2 | \n",
436 | "
\n",
437 | " \n",
438 | " 6 | \n",
439 | " 261932614e | \n",
440 | " think everyone hates here lol | \n",
441 | " negative | \n",
442 | " 8 | \n",
443 | " 1 | \n",
444 | "
\n",
445 | " \n",
446 | " 7 | \n",
447 | " afa11da83f | \n",
448 | " soooooo wish could but school and myspace comp... | \n",
449 | " negative | \n",
450 | " 13 | \n",
451 | " 1 | \n",
452 | "
\n",
453 | " \n",
454 | " 8 | \n",
455 | " e64208b4ef | \n",
456 | " and within short time the last clue all them | \n",
457 | " neutral | \n",
458 | " 12 | \n",
459 | " 0 | \n",
460 | "
\n",
461 | " \n",
462 | " 9 | \n",
463 | " 37bcad24ca | \n",
464 | " what did you get day alright havent done anyth... | \n",
465 | " neutral | \n",
466 | " 18 | \n",
467 | " 0 | \n",
468 | "
\n",
469 | " \n",
470 | " 10 | \n",
471 | " 24c92644a4 | \n",
472 | " bike was put holdshould have known that argh t... | \n",
473 | " negative | \n",
474 | " 12 | \n",
475 | " 1 | \n",
476 | "
\n",
477 | " \n",
478 | "
\n",
479 | "
"
480 | ],
481 | "text/plain": [
482 | " textID text sentiment \\\n",
483 | "0 f87dea47db last session the day neutral \n",
484 | "1 96d74cb729 shanghai also really exciting precisely skyscr... positive \n",
485 | "2 eee518ae67 recession hit veronique branquinho she has qui... negative \n",
486 | "4 33987a8ee5 like positive \n",
487 | "5 726e501993 thats great weee visitors positive \n",
488 | "6 261932614e think everyone hates here lol negative \n",
489 | "7 afa11da83f soooooo wish could but school and myspace comp... negative \n",
490 | "8 e64208b4ef and within short time the last clue all them neutral \n",
491 | "9 37bcad24ca what did you get day alright havent done anyth... neutral \n",
492 | "10 24c92644a4 bike was put holdshould have known that argh t... negative \n",
493 | "\n",
494 | " Num_words_text label \n",
495 | "0 6 0 \n",
496 | "1 15 2 \n",
497 | "2 13 1 \n",
498 | "4 5 2 \n",
499 | "5 4 2 \n",
500 | "6 8 1 \n",
501 | "7 13 1 \n",
502 | "8 12 0 \n",
503 | "9 18 0 \n",
504 | "10 12 1 "
505 | ]
506 | },
507 | "execution_count": 6,
508 | "metadata": {},
509 | "output_type": "execute_result"
510 | }
511 | ],
512 | "source": [
513 | "test_data.head(10)"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | "***\n",
521 | "Let us create our train,valid and test datasets\n",
522 | "***"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": 7,
528 | "metadata": {},
529 | "outputs": [
530 | {
531 | "name": "stdout",
532 | "output_type": "stream",
533 | "text": [
534 | "Train data len:21401\n",
535 | "Class distributionCounter({0: 8563, 2: 6700, 1: 6138})\n",
536 | "Valid data len:5351\n",
537 | "Class distributionCounter({0: 2141, 2: 1675, 1: 1535})\n",
538 | "Test data len:3434\n",
539 | "Class distributionCounter({0: 1376, 2: 1075, 1: 983})\n"
540 | ]
541 | }
542 | ],
543 | "source": [
544 | "X_train, X_valid, Y_train, Y_valid= train_test_split(train_data['text'].tolist(),\\\n",
545 | " train_data['label'].tolist(),\\\n",
546 | " test_size=0.2,\\\n",
547 | " stratify = train_data['label'].tolist(),\\\n",
548 | " random_state=0)\n",
549 | "\n",
550 | "\n",
551 | "print('Train data len:'+str(len(X_train)))\n",
552 | "print('Class distribution'+str(Counter(Y_train)))\n",
553 | "\n",
554 | "\n",
555 | "print('Valid data len:'+str(len(X_valid)))\n",
556 | "print('Class distribution'+ str(Counter(Y_valid)))\n",
557 | "\n",
558 | "print('Test data len:'+str(len(test_data['text'].tolist())))\n",
559 | "print('Class distribution'+ str(Counter(test_data['label'].tolist())))\n",
560 | "\n",
561 | "\n",
562 | "train_dat =list(zip(Y_train,X_train))\n",
563 | "valid_dat =list(zip(Y_valid,X_valid))\n",
564 | "test_dat=list(zip(test_data['label'].tolist(),test_data['text'].tolist()))\n"
565 | ]
566 | },
567 | {
568 | "cell_type": "code",
569 | "execution_count": 8,
570 | "metadata": {},
571 | "outputs": [],
572 | "source": [
573 | "import torch\n",
574 | "from torch.utils.data import DataLoader\n",
575 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
576 | ]
577 | },
578 | {
579 | "cell_type": "markdown",
580 | "metadata": {},
581 | "source": [
582 | "***\n",
583 | "Let us create our vocabulary on train data\n",
584 | "***"
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": 9,
590 | "metadata": {},
591 | "outputs": [],
592 | "source": [
593 | "from torchtext.data.utils import get_tokenizer\n",
594 | "from torchtext.vocab import build_vocab_from_iterator\n",
595 | "\n",
596 | "tokenizer = get_tokenizer('basic_english')\n",
597 | "train_iter = train_dat\n",
598 | "def yield_tokens(data_iter):\n",
599 | " for _, text in data_iter:\n",
600 | " yield tokenizer(text)\n",
601 | "\n",
602 | "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"\"])\n",
603 | "vocab.set_default_index(vocab[\"\"])\n"
604 | ]
605 | },
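   | {
   | "cell_type": "markdown",
   | "metadata": {},
   | "source": [
   | "Because we set `<unk>` as the default index, any out-of-vocabulary token maps to index 0 instead of raising an error. A quick check (the vocabulary size depends on the training split):"
   | ]
   | },
   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "print(len(vocab))  # number of tokens in the vocabulary built above\n",
   | "print(vocab['<unk>'])  # 0, the index of the special token\n",
   | "print(vocab['zzzunseenwordzzz'])  # also 0, unseen tokens fall back to <unk>\n"
   | ]
   | },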
606 | {
607 | "cell_type": "markdown",
608 | "metadata": {},
609 | "source": [
610 | "Let us define our text and label preprocessing pipleines"
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": 10,
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "text_pipeline = lambda x: vocab(tokenizer(x))\n",
620 | "label_pipeline = lambda x: int(x) "
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": 11,
626 | "metadata": {},
627 | "outputs": [
628 | {
629 | "data": {
630 | "text/plain": [
631 | "[62, 0, 1, 0, 12881]"
632 | ]
633 | },
634 | "execution_count": 11,
635 | "metadata": {},
636 | "output_type": "execute_result"
637 | }
638 | ],
639 | "source": [
640 | "text_pipeline('here is the an example')\n"
641 | ]
642 | },
643 | {
644 | "cell_type": "code",
645 | "execution_count": 12,
646 | "metadata": {},
647 | "outputs": [
648 | {
649 | "data": {
650 | "text/plain": [
651 | "1"
652 | ]
653 | },
654 | "execution_count": 12,
655 | "metadata": {},
656 | "output_type": "execute_result"
657 | }
658 | ],
659 | "source": [
660 | "label_pipeline('1')"
661 | ]
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "metadata": {},
666 | "source": [
667 | "Let us define our batch collation function"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "execution_count": 13,
673 | "metadata": {},
674 | "outputs": [],
675 | "source": [
676 | "\n",
677 | "\n",
678 | "def collate_batch(batch):\n",
679 | " label_list, text_list, offsets = [], [], [0]\n",
680 | " for (_label, _text) in batch:\n",
681 | " label_list.append(label_pipeline(_label))\n",
682 | " processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n",
683 | " text_list.append(processed_text)\n",
684 | " offsets.append(processed_text.size(0))\n",
685 | " label_list = torch.tensor(label_list, dtype=torch.int64)\n",
686 | " offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n",
687 | " text_list = torch.cat(text_list)\n",
688 | " return label_list.to(device), text_list.to(device), offsets.to(device)\n",
689 | "\n",
690 | "#train_iter =train_dat\n",
691 | "#dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)"
692 | ]
693 | },
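   | {
   | "cell_type": "markdown",
   | "metadata": {},
   | "source": [
   | "To see what `collate_batch` produces, we can run it on a toy batch of two examples taken from the cleaned training data shown earlier. `nn.EmbeddingBag` consumes all token indices concatenated into a single 1-D tensor, with `offsets` marking where each example starts:"
   | ]
   | },
   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "labels, texts, offsets = collate_batch([(0, 'have responded were going'), (1, 'boss bullying')])\n",
   | "print(labels)   # tensor([0, 1])\n",
   | "print(texts)    # token indices of both tweets in one flat tensor\n",
   | "print(offsets)  # tensor([0, 4]): the first tweet has 4 tokens, so the second starts at index 4\n"
   | ]
   | },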
694 | {
695 | "cell_type": "markdown",
696 | "metadata": {},
697 | "source": [
698 | "Let us define our text classification model"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 14,
704 | "metadata": {},
705 | "outputs": [],
706 | "source": [
707 | "from torch import nn\n",
708 | "import torch.nn.functional as F\n",
709 | "\n",
710 | "class TextClassificationModel(nn.Module):\n",
711 | "\n",
712 | " def __init__(self, vocab_size, embed_dim, num_class):\n",
713 | " super(TextClassificationModel, self).__init__()\n",
714 | " self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n",
715 | " self.fc1 = nn.Linear(embed_dim,64)\n",
716 | " self.fc2 = nn.Linear(64,16)\n",
717 | " self.fc3 = nn.Linear(16, num_class)\n",
718 | " self.init_weights()\n",
719 | "\n",
720 | " def init_weights(self):\n",
721 | " initrange = 0.5\n",
722 | " self.embedding.weight.data.uniform_(-initrange, initrange)\n",
723 | " self.fc1.weight.data.uniform_(-initrange, initrange)\n",
724 | " self.fc1.bias.data.zero_()\n",
725 | " self.fc2.weight.data.uniform_(-initrange, initrange)\n",
726 | " self.fc2.bias.data.zero_()\n",
727 | " self.fc3.weight.data.uniform_(-initrange, initrange)\n",
728 | " self.fc3.bias.data.zero_()\n",
729 | "\n",
730 | " def forward(self, text, offsets):\n",
731 | " embedded = self.embedding(text, offsets)\n",
732 | " x = F.relu(self.fc1(embedded))\n",
733 | " x = F.relu(self.fc2(x))\n",
734 | " x = self.fc3(x)\n",
735 | " return x"
736 | ]
737 | },
738 | {
739 | "cell_type": "markdown",
740 | "metadata": {},
741 | "source": [
742 | "Let us create an object of our text classification class"
743 | ]
744 | },
745 | {
746 | "cell_type": "code",
747 | "execution_count": 15,
748 | "metadata": {},
749 | "outputs": [
750 | {
751 | "name": "stdout",
752 | "output_type": "stream",
753 | "text": [
754 | "3\n"
755 | ]
756 | }
757 | ],
758 | "source": [
759 | "train_iter1 = train_dat\n",
760 | "num_class = len(set([label for (label, text) in train_iter1]))\n",
761 | "print(num_class)\n",
762 | "vocab_size = len(vocab)\n",
763 | "emsize = 128\n",
764 | "model = TextClassificationModel(vocab_size, emsize, num_class).to(device)"
765 | ]
766 | },
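   | {
   | "cell_type": "markdown",
   | "metadata": {},
   | "source": [
   | "As a quick shape check, we can push a single dummy bag of three arbitrary token indices through the untrained model; the output should be one row of `num_class` logits per example:"
   | ]
   | },
   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "dummy_text = torch.tensor([1, 2, 3], dtype=torch.int64).to(device)\n",
   | "dummy_offsets = torch.tensor([0]).to(device)\n",
   | "with torch.no_grad():\n",
   | "    print(model(dummy_text, dummy_offsets).shape)  # torch.Size([1, 3]) -> 1 example, 3 class logits\n"
   | ]
   | },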
767 | {
768 | "cell_type": "markdown",
769 | "metadata": {},
770 | "source": [
771 | "Let us define our train and evaluate methods"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": 16,
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "import time\n",
781 | "\n",
782 | "def train(dataloader):\n",
783 | " model.train()\n",
784 | " total_acc, total_count = 0, 0\n",
785 | " log_interval = 500\n",
786 | " start_time = time.time()\n",
787 | "\n",
788 | " for idx, (label, text, offsets) in enumerate(dataloader):\n",
789 | " optimizer.zero_grad()\n",
790 | " predited_label = model(text, offsets)\n",
791 | " loss = criterion(predited_label, label)\n",
792 | " loss.backward()\n",
793 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n",
794 | " optimizer.step()\n",
795 | " total_acc += (predited_label.argmax(1) == label).sum().item()\n",
796 | " total_count += label.size(0)\n",
797 | " if idx % log_interval == 0 and idx > 0:\n",
798 | " elapsed = time.time() - start_time\n",
799 | " print('| epoch {:3d} | {:5d}/{:5d} batches '\n",
800 | " '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),\n",
801 | " total_acc/total_count))\n",
802 | " total_acc, total_count = 0, 0\n",
803 | " start_time = time.time()\n",
804 | "\n",
805 | "def evaluate(dataloader):\n",
806 | " model.eval()\n",
807 | " total_acc, total_count = 0, 0\n",
808 | "\n",
809 | " with torch.no_grad():\n",
810 | " for idx, (label, text, offsets) in enumerate(dataloader):\n",
811 | " predited_label = model(text, offsets)\n",
812 | " loss = criterion(predited_label, label)\n",
813 | " total_acc += (predited_label.argmax(1) == label).sum().item()\n",
814 | " total_count += label.size(0)\n",
815 | " return total_acc/total_count"
816 | ]
817 | },
818 | {
819 | "cell_type": "markdown",
820 | "metadata": {},
821 | "source": [
822 | "Let us create dataloaders for text,train and validation data iterators and then train our model"
823 | ]
824 | },
825 | {
826 | "cell_type": "code",
827 | "execution_count": 17,
828 | "metadata": {},
829 | "outputs": [
830 | {
831 | "name": "stdout",
832 | "output_type": "stream",
833 | "text": [
834 | "| epoch 1 | 500/ 1338 batches | accuracy 0.444\n",
835 | "| epoch 1 | 1000/ 1338 batches | accuracy 0.546\n",
836 | "-----------------------------------------------------------\n",
837 | "| end of epoch 1 | time: 4.41s | valid accuracy 0.627 \n",
838 | "-----------------------------------------------------------\n",
839 | "| epoch 2 | 500/ 1338 batches | accuracy 0.625\n",
840 | "| epoch 2 | 1000/ 1338 batches | accuracy 0.619\n",
841 | "-----------------------------------------------------------\n",
842 | "| end of epoch 2 | time: 3.44s | valid accuracy 0.562 \n",
843 | "-----------------------------------------------------------\n",
844 | "| epoch 3 | 500/ 1338 batches | accuracy 0.681\n",
845 | "| epoch 3 | 1000/ 1338 batches | accuracy 0.711\n",
846 | "-----------------------------------------------------------\n",
847 | "| end of epoch 3 | time: 3.26s | valid accuracy 0.677 \n",
848 | "-----------------------------------------------------------\n",
849 | "| epoch 4 | 500/ 1338 batches | accuracy 0.733\n",
850 | "| epoch 4 | 1000/ 1338 batches | accuracy 0.735\n",
851 | "-----------------------------------------------------------\n",
852 | "| end of epoch 4 | time: 3.30s | valid accuracy 0.671 \n",
853 | "-----------------------------------------------------------\n",
854 | "| epoch 5 | 500/ 1338 batches | accuracy 0.748\n",
855 | "| epoch 5 | 1000/ 1338 batches | accuracy 0.755\n",
856 | "-----------------------------------------------------------\n",
857 | "| end of epoch 5 | time: 3.24s | valid accuracy 0.675 \n",
858 | "-----------------------------------------------------------\n",
859 | "| epoch 6 | 500/ 1338 batches | accuracy 0.765\n",
860 | "| epoch 6 | 1000/ 1338 batches | accuracy 0.755\n",
861 | "-----------------------------------------------------------\n",
862 | "| end of epoch 6 | time: 3.24s | valid accuracy 0.676 \n",
863 | "-----------------------------------------------------------\n",
864 | "| epoch 7 | 500/ 1338 batches | accuracy 0.760\n",
865 | "| epoch 7 | 1000/ 1338 batches | accuracy 0.761\n",
866 | "-----------------------------------------------------------\n",
867 | "| end of epoch 7 | time: 3.37s | valid accuracy 0.676 \n",
868 | "-----------------------------------------------------------\n",
869 | "| epoch 8 | 500/ 1338 batches | accuracy 0.758\n",
870 | "| epoch 8 | 1000/ 1338 batches | accuracy 0.762\n",
871 | "-----------------------------------------------------------\n",
872 | "| end of epoch 8 | time: 3.70s | valid accuracy 0.676 \n",
873 | "-----------------------------------------------------------\n",
874 | "| epoch 9 | 500/ 1338 batches | accuracy 0.758\n",
875 | "| epoch 9 | 1000/ 1338 batches | accuracy 0.762\n",
876 | "-----------------------------------------------------------\n",
877 | "| end of epoch 9 | time: 3.49s | valid accuracy 0.676 \n",
878 | "-----------------------------------------------------------\n",
879 | "| epoch 10 | 500/ 1338 batches | accuracy 0.753\n",
880 | "| epoch 10 | 1000/ 1338 batches | accuracy 0.761\n",
881 | "-----------------------------------------------------------\n",
882 | "| end of epoch 10 | time: 3.32s | valid accuracy 0.676 \n",
883 | "-----------------------------------------------------------\n"
884 | ]
885 | }
886 | ],
887 | "source": [
888 | "from torch.utils.data.dataset import random_split\n",
889 | "from torchtext.data.functional import to_map_style_dataset\n",
890 | "# Hyperparameters\n",
891 | "EPOCHS = 10 # epoch\n",
892 | "LR =10 # learning rate\n",
893 | "BATCH_SIZE = 16 # batch size for training\n",
894 | "\n",
895 | "criterion = torch.nn.CrossEntropyLoss()\n",
896 | "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n",
897 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)\n",
898 | "total_accu = None\n",
899 | "\n",
900 | "train_iter2 = train_dat\n",
901 | "test_iter2 =test_dat \n",
902 | "valid_iter2= valid_dat\n",
903 | "\n",
904 | "\n",
905 | "\n",
906 | "\n",
907 | "train_dataloader = DataLoader(train_iter2, batch_size=BATCH_SIZE,\n",
908 | " shuffle=True, collate_fn=collate_batch)\n",
909 | "valid_dataloader = DataLoader(valid_iter2, batch_size=BATCH_SIZE,\n",
910 | " shuffle=True, collate_fn=collate_batch)\n",
911 | "test_dataloader = DataLoader(test_iter2, batch_size=BATCH_SIZE,\n",
912 | " shuffle=True, collate_fn=collate_batch)\n",
913 | "\n",
914 | "for epoch in range(1, EPOCHS + 1):\n",
915 | " epoch_start_time = time.time()\n",
916 | " train(train_dataloader)\n",
917 | " accu_val = evaluate(valid_dataloader)\n",
918 | " if total_accu is not None and total_accu > accu_val:\n",
919 | " scheduler.step()\n",
920 | " else:\n",
921 | " total_accu = accu_val\n",
922 | " print('-' * 59)\n",
923 | " print('| end of epoch {:3d} | time: {:5.2f}s | '\n",
924 | " 'valid accuracy {:8.3f} '.format(epoch,\n",
925 | " time.time() - epoch_start_time,\n",
926 | " accu_val))\n",
927 | " print('-' * 59)"
928 | ]
929 | },
930 | {
931 | "cell_type": "code",
932 | "execution_count": 18,
933 | "metadata": {},
934 | "outputs": [
935 | {
936 | "name": "stdout",
937 | "output_type": "stream",
938 | "text": [
939 | "Checking the results of test dataset.\n",
940 | "test accuracy 0.692\n"
941 | ]
942 | }
943 | ],
944 | "source": [
945 | "print('Checking the results of test dataset.')\n",
946 | "accu_test = evaluate(test_dataloader)\n",
947 | "print('test accuracy {:8.3f}'.format(accu_test))"
948 | ]
949 | },
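   | {
   | "cell_type": "markdown",
   | "metadata": {},
   | "source": [
   | "If we want to reuse the trained model later without retraining, one option is to save its `state_dict`; the file name here is just an example:"
   | ]
   | },
   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "torch.save(model.state_dict(), 'tweet_classifier.pth')  # example path\n",
   | "# To restore later, rebuild the model and load the weights:\n",
   | "# model = TextClassificationModel(vocab_size, emsize, num_class).to(device)\n",
   | "# model.load_state_dict(torch.load('tweet_classifier.pth'))\n"
   | ]
   | },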
950 | {
951 | "cell_type": "markdown",
952 | "metadata": {},
953 | "source": [
954 | "Test model on text"
955 | ]
956 | },
957 | {
958 | "cell_type": "code",
959 | "execution_count": 19,
960 | "metadata": {},
961 | "outputs": [
962 | {
963 | "name": "stdout",
964 | "output_type": "stream",
965 | "text": [
966 | "This is a Neutral tweet\n"
967 | ]
968 | }
969 | ],
970 | "source": [
971 | "sentiment_label = {2:\"Positive\",\n",
972 | " 1: \"Negative\",\n",
973 | " 0: \"Neutral\"\n",
974 | " }\n",
975 | "\n",
976 | "def predict(text, text_pipeline):\n",
977 | " with torch.no_grad():\n",
978 | " text = torch.tensor(text_pipeline(text))\n",
979 | " output = model(text, torch.tensor([0]))\n",
980 | " return output.argmax(1).item() \n",
981 | "ex_text_str = \"soooooo wish i could, but im in school and myspace is completely blocked\"\n",
982 | "model = model.to(\"cpu\")\n",
983 | "\n",
984 | "print(\"This is a %s tweet\" %sentiment_label[predict(ex_text_str, text_pipeline)])"
985 | ]
986 | },
987 | {
988 | "cell_type": "code",
989 | "execution_count": null,
990 | "metadata": {},
991 | "outputs": [],
992 | "source": []
993 | }
994 | ],
995 | "metadata": {
996 | "kernelspec": {
997 | "display_name": "Python 3",
998 | "language": "python",
999 | "name": "python3"
1000 | },
1001 | "language_info": {
1002 | "codemirror_mode": {
1003 | "name": "ipython",
1004 | "version": 3
1005 | },
1006 | "file_extension": ".py",
1007 | "mimetype": "text/x-python",
1008 | "name": "python",
1009 | "nbconvert_exporter": "python",
1010 | "pygments_lexer": "ipython3",
1011 | "version": "3.6.13"
1012 | }
1013 | },
1014 | "nbformat": 4,
1015 | "nbformat_minor": 4
1016 | }
1017 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorchTextClassificationCustomDataset
2 | In this repository, I explain how you can implement a text classifier on a custom dataset using PyTorch.
3 | This Jupyter notebook is inspired by: https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
4 |
5 | The dataset is from the Tweet Sentiment Extraction challenge on Kaggle (https://www.kaggle.com/c/tweet-sentiment-extraction/overview)
6 | 
7 | We will implement text classification using a simple embedding bag-of-words model in PyTorch on tweet data to classify tweets as "positive", "negative" or "neutral".
8 |
9 |
10 | Prerequisites:
11 |
12 | PyTorch (https://pytorch.org/)
13 |
14 | TorchText (https://anaconda.org/pytorch/torchtext)
15 |
16 | Python 3.6 and above (https://www.anaconda.com/products/individual)
17 |
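18 | For example, PyTorch and TorchText can be installed with `pip install torch torchtext`; note that the notebook uses the newer vocabulary API (`build_vocab_from_iterator` with `set_default_index`), which needs torchtext 0.10 or later.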
--------------------------------------------------------------------------------