├── .idea
    ├── .gitignore
    ├── depression-detection.iml
    ├── inspectionProfiles
    │   ├── Project_Default.xml
    │   └── profiles_settings.xml
    ├── misc.xml
    ├── modules.xml
    └── vcs.xml
├── .ipynb_checkpoints
    └── 3_depression_detector-checkpoint.ipynb
├── 0_depression_data_scrapper.ipynb
├── 1_non-depressive-data-gathering.ipynb
├── 2_combine_datasets.ipynb
├── 3_depression_detector.ipynb
├── Depression detection
├── Project Motivation.odt
├── README.md
├── data
    ├── 2018-EI-reg-En-anger-test-gold.txt
    ├── 2018-EI-reg-En-fear-test-gold.txt
    ├── 2018-EI-reg-En-joy-test-gold.txt
    ├── 2018-EI-reg-En-sadness-test-gold.txt
    ├── general_tweets.csv
    ├── tweets_combined.csv
    ├── tweets_final_1_clean.csv
    ├── tweets_final_2_clean.csv
    ├── tweets_final_3_clean.csv
    ├── tweets_final_4_clean.csv
    ├── tweets_final_5_clean.csv
    └── tweets_final_6_clean.csv
└── embedding
    └── glove.twitter.27B.100d.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/depression-detection.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/2_combine_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Combining Datasets\n",
8 | "\n",
9 | "In this script we will combine the data gathered from the depressive scraping script with the data from the mostly non-depressive scripts,\n",
10 | "analyse their target distribution, standardise their formats and then construct a final dataset containing all of them.\n",
11 | "\n",
12 | "The data from the depressive script has been manually reviewed and annotated by several different team members, and then reviewed one more\n",
13 | "time by another member, to ensure consistency."
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import numpy as np\n",
23 | "import pandas as pd\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "import seaborn as sns"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "**1. Read in the non-depressive and depressive datasets**"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "df0 = pd.read_csv(\"./data/general_tweets.csv\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "df1 = pd.read_csv(\"./data/tweets_final_1_clean.csv\", engine='python')\n",
51 | "df2 = pd.read_csv(\"./data/tweets_final_2_clean.csv\", engine='python')\n",
52 | "df3 = pd.read_csv(\"./data/tweets_final_3_clean.csv\", engine='python')"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 4,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "df4 = pd.read_csv(\"./data/tweets_final_4_clean.csv\", engine='python')\n",
62 | "df5 = pd.read_csv(\"./data/tweets_final_5_clean.csv\", engine='python')\n",
63 | "df6 = pd.read_csv(\"./data/tweets_final_6_clean.csv\", engine='python')"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 5,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "pd.set_option('display.max_colwidth', -1)"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "**2. Briefly analyse their content**"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 43,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/html": [
90 | "
\n",
91 | "\n",
104 | "
\n",
105 | " \n",
106 | " \n",
107 | " \n",
108 | " tweet \n",
109 | " target \n",
110 | " \n",
111 | " \n",
112 | " \n",
113 | " \n",
114 | " 0 \n",
115 | " Be happy. Be confident. Be kind.\\n\\n #KissablesLoveSMShopmag\\nAllOutDenimFor KISSMARC \n",
116 | " 0 \n",
117 | " \n",
118 | " \n",
119 | " 1 \n",
120 | " @queenjlouise @mamaw_gereck awe thanks 😊 \n",
121 | " 0 \n",
122 | " \n",
123 | " \n",
124 | "
\n",
125 | "
"
126 | ],
127 | "text/plain": [
128 | " tweet \\\n",
129 | "0 Be happy. Be confident. Be kind.\\n\\n #KissablesLoveSMShopmag\\nAllOutDenimFor KISSMARC \n",
130 | "1 @queenjlouise @mamaw_gereck awe thanks 😊 \n",
131 | "\n",
132 | " target \n",
133 | "0 0 \n",
134 | "1 0 "
135 | ]
136 | },
137 | "execution_count": 43,
138 | "metadata": {},
139 | "output_type": "execute_result"
140 | }
141 | ],
142 | "source": [
143 | "df0.head(2)"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 44,
149 | "metadata": {},
150 | "outputs": [
151 | {
152 | "data": {
153 | "text/html": [
154 | "\n",
155 | "\n",
168 | "
\n",
169 | " \n",
170 | " \n",
171 | " \n",
172 | " tweet \n",
173 | " target \n",
174 | " \n",
175 | " \n",
176 | " \n",
177 | " \n",
178 | " 0 \n",
179 | " Be happy. Be confident. Be kind.\\n\\n #KissablesLoveSMShopmag\\nAllOutDenimFor KISSMARC \n",
180 | " 0 \n",
181 | " \n",
182 | " \n",
183 | " 1 \n",
184 | " @queenjlouise @mamaw_gereck awe thanks 😊 \n",
185 | " 0 \n",
186 | " \n",
187 | " \n",
188 | "
\n",
189 | "
"
190 | ],
191 | "text/plain": [
192 | " tweet \\\n",
193 | "0 Be happy. Be confident. Be kind.\\n\\n #KissablesLoveSMShopmag\\nAllOutDenimFor KISSMARC \n",
194 | "1 @queenjlouise @mamaw_gereck awe thanks 😊 \n",
195 | "\n",
196 | " target \n",
197 | "0 0 \n",
198 | "1 0 "
199 | ]
200 | },
201 | "execution_count": 44,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "df0.head(2)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 45,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "data": {
217 | "text/html": [
218 | "\n",
219 | "\n",
232 | "
\n",
233 | " \n",
234 | " \n",
235 | " \n",
236 | " tweet \n",
237 | " target \n",
238 | " \n",
239 | " \n",
240 | " \n",
241 | " \n",
242 | " 0 \n",
243 | " mood can be caused by infectious diseases, nutritional deficiencies, neurological conditions, and physiological problems. \n",
244 | " 0 \n",
245 | " \n",
246 | " \n",
247 | " 1 \n",
248 | " With all of this unnessary family drama, I feel like moving far away and starting over again. From one thing to another I just feel . Hope I get through this \n",
249 | " 1 \n",
250 | " \n",
251 | " \n",
252 | "
\n",
253 | "
"
254 | ],
255 | "text/plain": [
256 | " tweet \\\n",
257 | "0 mood can be caused by infectious diseases, nutritional deficiencies, neurological conditions, and physiological problems. \n",
258 | "1 With all of this unnessary family drama, I feel like moving far away and starting over again. From one thing to another I just feel . Hope I get through this \n",
259 | "\n",
260 | " target \n",
261 | "0 0 \n",
262 | "1 1 "
263 | ]
264 | },
265 | "execution_count": 45,
266 | "metadata": {},
267 | "output_type": "execute_result"
268 | }
269 | ],
270 | "source": [
271 | "df1.head(2)"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 46,
277 | "metadata": {},
278 | "outputs": [
279 | {
280 | "data": {
281 | "text/html": [
282 | "\n",
283 | "\n",
296 | "
\n",
297 | " \n",
298 | " \n",
299 | " \n",
300 | " tweet \n",
301 | " target \n",
302 | " \n",
303 | " \n",
304 | " \n",
305 | " \n",
306 | " 0 \n",
307 | " Looking for some beautiful images. Having a really crappy week, and need some support. . Post some great beauty. Thank you. \n",
308 | " 1 \n",
309 | " \n",
310 | " \n",
311 | " 1 \n",
312 | " When teens feel down, there are ways they can cope with these feelings to avoid serious , such as make new friends or social connections and participate in sports or school activities. Call Focus & Balance at 210-858-9980 to learn more and see how we can help. pic.twitter.com/ydrYcTBD2r \n",
313 | " 0 \n",
314 | " \n",
315 | " \n",
316 | "
\n",
317 | "
"
318 | ],
319 | "text/plain": [
320 | " tweet \\\n",
321 | "0 Looking for some beautiful images. Having a really crappy week, and need some support. . Post some great beauty. Thank you. \n",
322 | "1 When teens feel down, there are ways they can cope with these feelings to avoid serious , such as make new friends or social connections and participate in sports or school activities. Call Focus & Balance at 210-858-9980 to learn more and see how we can help. pic.twitter.com/ydrYcTBD2r \n",
323 | "\n",
324 | " target \n",
325 | "0 1 \n",
326 | "1 0 "
327 | ]
328 | },
329 | "execution_count": 46,
330 | "metadata": {},
331 | "output_type": "execute_result"
332 | }
333 | ],
334 | "source": [
335 | "df2.head(2)"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 47,
341 | "metadata": {},
342 | "outputs": [
343 | {
344 | "data": {
345 | "text/html": [
346 | "\n",
347 | "\n",
360 | "
\n",
361 | " \n",
362 | " \n",
363 | " \n",
364 | " tweet \n",
365 | " target \n",
366 | " \n",
367 | " \n",
368 | " \n",
369 | " \n",
370 | " 0 \n",
371 | " Why is tomorrow Monday??? & LOL \n",
372 | " 0 \n",
373 | " \n",
374 | " \n",
375 | " 1 \n",
376 | " I feel like I'm annoying like Always… \n",
377 | " 1 \n",
378 | " \n",
379 | " \n",
380 | "
\n",
381 | "
"
382 | ],
383 | "text/plain": [
384 | " tweet target\n",
385 | "0 Why is tomorrow Monday??? & LOL 0 \n",
386 | "1 I feel like I'm annoying like Always… 1 "
387 | ]
388 | },
389 | "execution_count": 47,
390 | "metadata": {},
391 | "output_type": "execute_result"
392 | }
393 | ],
394 | "source": [
395 | "df3.head(2)"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 48,
401 | "metadata": {},
402 | "outputs": [
403 | {
404 | "data": {
405 | "text/html": [
406 | "\n",
407 | "\n",
420 | "
\n",
421 | " \n",
422 | " \n",
423 | " \n",
424 | " tweet \n",
425 | " target \n",
426 | " \n",
427 | " \n",
428 | " \n",
429 | " \n",
430 | " 0 \n",
431 | " 2016 was the best year so far \n",
432 | " 0 \n",
433 | " \n",
434 | " \n",
435 | " 1 \n",
436 | " So over new years think I'm just going to sleep into the new year. \n",
437 | " 0 \n",
438 | " \n",
439 | " \n",
440 | "
\n",
441 | "
"
442 | ],
443 | "text/plain": [
444 | " tweet \\\n",
445 | "0 2016 was the best year so far \n",
446 | "1 So over new years think I'm just going to sleep into the new year. \n",
447 | "\n",
448 | " target \n",
449 | "0 0 \n",
450 | "1 0 "
451 | ]
452 | },
453 | "execution_count": 48,
454 | "metadata": {},
455 | "output_type": "execute_result"
456 | }
457 | ],
458 | "source": [
459 | "df4.head(2)"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 49,
465 | "metadata": {},
466 | "outputs": [
467 | {
468 | "data": {
469 | "text/html": [
470 | "\n",
471 | "\n",
484 | "
\n",
485 | " \n",
486 | " \n",
487 | " \n",
488 | " tweet \n",
489 | " target \n",
490 | " \n",
491 | " \n",
492 | " \n",
493 | " \n",
494 | " 0 \n",
495 | " On topic of famous people , apparently stuff needs to be put in place for mental illness. What about non famous people \n",
496 | " 0 \n",
497 | " \n",
498 | " \n",
499 | " 1 \n",
500 | " Tripped hard today! This is promising! \n",
501 | " 1 \n",
502 | " \n",
503 | " \n",
504 | "
\n",
505 | "
"
506 | ],
507 | "text/plain": [
508 | " tweet \\\n",
509 | "0 On topic of famous people , apparently stuff needs to be put in place for mental illness. What about non famous people \n",
510 | "1 Tripped hard today! This is promising! \n",
511 | "\n",
512 | " target \n",
513 | "0 0 \n",
514 | "1 1 "
515 | ]
516 | },
517 | "execution_count": 49,
518 | "metadata": {},
519 | "output_type": "execute_result"
520 | }
521 | ],
522 | "source": [
523 | "df5.head(2)"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": 50,
529 | "metadata": {},
530 | "outputs": [
531 | {
532 | "data": {
533 | "text/html": [
534 | "\n",
535 | "\n",
548 | "
\n",
549 | " \n",
550 | " \n",
551 | " \n",
552 | " tweet \n",
553 | " target \n",
554 | " \n",
555 | " \n",
556 | " \n",
557 | " \n",
558 | " 0 \n",
559 | " Why is tomorrow Monday??? & LOL \n",
560 | " 0 \n",
561 | " \n",
562 | " \n",
563 | " 1 \n",
564 | " I feel like I'm annoying like Always… \n",
565 | " 1 \n",
566 | " \n",
567 | " \n",
568 | "
\n",
569 | "
"
570 | ],
571 | "text/plain": [
572 | " tweet target\n",
573 | "0 Why is tomorrow Monday??? & LOL 0 \n",
574 | "1 I feel like I'm annoying like Always… 1 "
575 | ]
576 | },
577 | "execution_count": 50,
578 | "metadata": {},
579 | "output_type": "execute_result"
580 | }
581 | ],
582 | "source": [
583 | "df3.head(2)"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 17,
589 | "metadata": {},
590 | "outputs": [
591 | {
592 | "data": {
593 | "text/plain": [
594 | "[Text(0.5, 1.0, 'target distribution - df6')]"
595 | ]
596 | },
597 | "execution_count": 17,
598 | "metadata": {},
599 | "output_type": "execute_result"
600 | },
601 | {
602 | "data": {
603 | "image/png": "\n",
604 | "text/plain": [
605 | ""
606 | ]
607 | },
608 | "metadata": {
609 | "needs_background": "light"
610 | },
611 | "output_type": "display_data"
612 | }
613 | ],
614 | "source": [
615 | "fig = plt.figure(figsize=(15,8))\n",
616 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
617 | "\n",
618 | "ax1 = fig.add_subplot(2,3,1)\n",
619 | "sns.countplot(x='target', data=df1)\n",
620 | "ax1.set(title=\"target distribution - df1\")\n",
621 | "\n",
622 | "ax2 = fig.add_subplot(2,3,2)\n",
623 | "sns.countplot(x='target', data=df2)\n",
624 | "ax2.set(title=\"target distribution - df2\")\n",
625 | "\n",
626 | "ax3 = fig.add_subplot(2,3,3)\n",
627 | "sns.countplot(x='target', data=df3)\n",
628 | "ax3.set(title=\"target distribution - df3\")\n",
629 | "\n",
630 | "ax4 = fig.add_subplot(2,3,4)\n",
631 | "sns.countplot(x='target', data=df4)\n",
632 | "ax4.set(title=\"target distribution - df4\")\n",
633 | "\n",
634 | "ax5 = fig.add_subplot(2,3,5)\n",
635 | "sns.countplot(x='target', data=df5)\n",
636 | "ax5.set(title=\"target distribution - df5\")\n",
637 | "\n",
638 | "ax6 = fig.add_subplot(2,3,6)\n",
639 | "sns.countplot(x='target', data=df6)\n",
640 | "ax6.set(title=\"target distribution - df6\")"
641 | ]
642 | },
643 | {
644 | "cell_type": "code",
645 | "execution_count": 18,
646 | "metadata": {},
647 | "outputs": [
648 | {
649 | "data": {
650 | "image/png": "\n",
651 | "text/plain": [
652 | ""
653 | ]
654 | },
655 | "metadata": {
656 | "needs_background": "light"
657 | },
658 | "output_type": "display_data"
659 | }
660 | ],
661 | "source": [
662 | "fig = plt.figure(figsize=(5,3))\n",
663 | "plt.title('distribution of non-depressive dataset')\n",
664 | "ax = sns.barplot(x=df0.target.unique(),y=df0.target.value_counts());\n",
665 | "ax.set(xlabel='Labels');"
666 | ]
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "metadata": {},
671 | "source": [
672 | "**3. Standardise dataset format**"
673 | ]
674 | },
675 | {
676 | "cell_type": "code",
677 | "execution_count": 20,
678 | "metadata": {},
679 | "outputs": [],
680 | "source": [
681 | "df0 = df0[['Tweet','target']].copy()\n",
682 | "\n",
683 | "df1 = df1[['tweet_processed','target']].copy()\n",
684 | "df2 = df2[['tweet_processed','target']].copy()\n",
685 | "df3 = df3[['tweet_processed','target']].copy()\n",
686 | "df4 = df4[['tweet_processed','target']].copy()\n",
687 | "df5 = df5[['tweet_processed','target']].copy()\n",
688 | "df6 = df6[['tweet_processed','target']].copy()"
689 | ]
690 | },
691 | {
692 | "cell_type": "code",
693 | "execution_count": 22,
694 | "metadata": {},
695 | "outputs": [],
696 | "source": [
697 | "df0 = df0.rename(columns = {\"Tweet\": \"tweet\"}) \n",
698 | "\n",
699 | "df1 = df1.rename(columns = {\"tweet_processed\": \"tweet\"}) \n",
700 | "df2 = df2.rename(columns = {\"tweet_processed\": \"tweet\"}) \n",
701 | "df3 = df3.rename(columns = {\"tweet_processed\": \"tweet\"}) \n",
702 | "df4 = df4.rename(columns = {\"tweet_processed\": \"tweet\"}) \n",
703 | "df5 = df5.rename(columns = {\"tweet_processed\": \"tweet\"}) \n",
704 | "df6 = df6.rename(columns = {\"tweet_processed\": \"tweet\"}) "
705 | ]
706 | },
707 | {
708 | "cell_type": "markdown",
709 | "metadata": {},
710 | "source": [
711 | "Let's skip df1 and df2 for the time being, as the labels are a little inconsistent. (Note: the `pd.concat` cell below still includes df1 and df2.)"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": 29,
717 | "metadata": {},
718 | "outputs": [],
719 | "source": [
720 | "df_all = pd.concat([df0, df1, df2, df3, df4, df5, df6])"
721 | ]
722 | },
723 | {
724 | "cell_type": "code",
725 | "execution_count": 56,
726 | "metadata": {},
727 | "outputs": [
728 | {
729 | "data": {
730 | "text/html": [
731 | "\n",
732 | "\n",
745 | "
\n",
746 | " \n",
747 | " \n",
748 | " \n",
749 | " tweet \n",
750 | " target \n",
751 | " \n",
752 | " \n",
753 | " \n",
754 | " \n",
755 | " 0 \n",
756 | " Today in Selfcare: beauty & laughs Kung Fu Panda 3 #Wellness #joy #laughter #selfcare #therapist #philadelphia \n",
757 | " 0 \n",
758 | " \n",
759 | " \n",
760 | " 1 \n",
761 | " I get to spend New Year's home again alone and lonely. ???• \n",
762 | " 1 \n",
763 | " \n",
764 | " \n",
765 | " 2 \n",
766 | " Depressed and lonely /: Stuck in a deep, never ending hole :( \n",
767 | " 1 \n",
768 | " \n",
769 | " \n",
770 | " 3 \n",
771 | " If this is your response to someone saying they're dealing with , you're a terrible person. \n",
772 | " 0 \n",
773 | " \n",
774 | " \n",
775 | " 4 \n",
776 | " Apparently you get a free pass just by mentioning Where was I on the free badge day??!! \n",
777 | " 0 \n",
778 | " \n",
779 | " \n",
780 | " 5 \n",
781 | " When you will never again give birth to violent men.. pic.twitter.com/pkdPhhlUuZ \n",
782 | " 0 \n",
783 | " \n",
784 | " \n",
785 | " 6 \n",
786 | " Learning to pretend to have a good time had become a natural skill. I hope one day it is genuine \n",
787 | " 1 \n",
788 | " \n",
789 | " \n",
790 | " 7 \n",
791 | " Aw man im outta pizza rolls \n",
792 | " 0 \n",
793 | " \n",
794 | " \n",
795 | " 8 \n",
796 | " When you go out and try to be a part of life & end up feeling like you are less a part of it then when you started. pic.twitter.com/J625NXrWDb \n",
797 | " 0 \n",
798 | " \n",
799 | " \n",
800 | " 9 \n",
801 | " So far he stop texting me…after I said something…so hopefully he doesn't show up at my house… \n",
802 | " 1 \n",
803 | " \n",
804 | " \n",
805 | "
\n",
806 | "
"
807 | ],
808 | "text/plain": [
809 | " tweet \\\n",
810 | "0 Today in Selfcare: beauty & laughs Kung Fu Panda 3 #Wellness #joy #laughter #selfcare #therapist #philadelphia \n",
811 | "1 I get to spend New Year's home again alone and lonely. ???• \n",
812 | "2 Depressed and lonely /: Stuck in a deep, never ending hole :( \n",
813 | "3 If this is your response to someone saying they're dealing with , you're a terrible person. \n",
814 | "4 Apparently you get a free pass just by mentioning Where was I on the free badge day??!! \n",
815 | "5 When you will never again give birth to violent men.. pic.twitter.com/pkdPhhlUuZ \n",
816 | "6 Learning to pretend to have a good time had become a natural skill. I hope one day it is genuine \n",
817 | "7 Aw man im outta pizza rolls \n",
818 | "8 When you go out and try to be a part of life & end up feeling like you are less a part of it then when you started. pic.twitter.com/J625NXrWDb \n",
819 | "9 So far he stop texting me…after I said something…so hopefully he doesn't show up at my house… \n",
820 | "\n",
821 | " target \n",
822 | "0 0 \n",
823 | "1 1 \n",
824 | "2 1 \n",
825 | "3 0 \n",
826 | "4 0 \n",
827 | "5 0 \n",
828 | "6 1 \n",
829 | "7 0 \n",
830 | "8 0 \n",
831 | "9 1 "
832 | ]
833 | },
834 | "execution_count": 56,
835 | "metadata": {},
836 | "output_type": "execute_result"
837 | }
838 | ],
839 | "source": [
840 | "df_all = df_all.sample(frac=1).reset_index(drop=True)\n",
841 | "df_all.head(10)"
842 | ]
843 | },
844 | {
845 | "cell_type": "code",
846 | "execution_count": 31,
847 | "metadata": {},
848 | "outputs": [
849 | {
850 | "data": {
851 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVcAAADgCAYAAAC3iSVhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFapJREFUeJzt3Xu4VXWdx/H3B0Qt0MI4GlfxUabU8YbkJdNQi5SmsJksTYXUCZ2wmpmyzGceQc1saiqz8joSWCpRU8kUhXhJs7yAI6J4CVKQAygoKnhJhb7zx/odXR73PmcfOL+zzz58Xs+zn7P3b/3WWt+99j6fs9Zvrb2PIgIzM+tcvepdgJlZT+RwNTPLwOFqZpaBw9XMLAOHq5lZBg5XM7MMHK5WlaRPS7q9nT5/lLRfV9XURh3t1tqdSFokafRmLmMbSQ9L2rED84yW1Lw567XaOFxtk0n6CLA+Iu6tdy2NJiL2jIjfb+YyXgamAl/plKJa6ao/WI32h7FWDtctlKStOmExpwM/7oTldFudtJ1yuhaYIGmbehdib+RwrQNJIyXdK2m9pJ9J+qmkr5Wm/4OkBZKelfQnSXuXpi2V9CVJCyU9l+bdtgPzfkXSQuAFSVtJOkvSX1ItD0r6WI3PYWvgCODWUtsUSTMlXZ2Wt0jSqNL03SX9PtW2SNJHS9OmSfqhpN+kee+StGsb63+HpFmS1km6G9i11fR3S5oraa2kRyR9otW6LkvT10u6VdLOpekhaZKkxcDiGpY3Nm279ZJWSPpSah8g6dfp+a6V9AdJvUqvxQckDZL0kqQdSsvbT9JTkvqkx6dIekjSM5LmlGuNiGbgGeCgKtvpLen5PiPpQeA9raZXfP0l7Q5cBhws6XlJz6b2D6f37jpJyyVNKS1rW0k/kfR0es7zJO2Upr1N0lWSVqVt9DVJvautp0eICN+68AZsDSwDvgD0Af4ReAX4Wpo+ElgNHAj0BiYAS4Ft0vSlwN3AIGAH4CHg9A7MuwAYCrwltR2bltUL+CTwAjAwTfs0cHuV57En8EKrtinAX4Gxaf0XAnemaX2AJcDZaRscAawH3pWmTwPWAgcAWwHXADPa2I4zgJlAX+DvgRUttaa25cDJaVkjgaeAPUvrWg8cBmwDfK/8PIEA5qbt+5YalrcKODTd7w+MTPcvpAiOPul2KKDSa/GBdP9m4DOl9X8LuCzdPyZtt93Tuv8D+FOrbTEL+HyV7fQN4A/puQwFHgCaS9M79PoDo4G9Uv+9gSeBY9K004D/Bd6aXv/9ge3TtF8Bl6dtuSPFe/i09t5njXyrewFb2i39Qq9o+SVLbbfzerheCpzfap5HgPen+0uBE0vTvln6Raxl3lPaqW8BMC7dr/qmBw4BnmjVNgW4sfR4D+CldP9Q4AmgV2n6dcCUdH8a8N+laWOBh6usuzfwKvDuUtvXeT1cPwn8odU8lwOTS+uaUZrWD9gIDE2PAziiNL295T2egmX7Vn3OA64HdqvwHJbyerj+M3Bzui+KID8sPf4tcGppvl7Ai8DOpbZrgHOqbKtHgaNKjydSCtdNff1L/S8CvpvunwL8Cdi7VZ+dgJdJf9BT2/HALbWupxFvHhboeoOAFZHeVcny0v2dgS+mw6pn02HS0DRfiydK91+kCIda5y2vC0njS8MIz1LsBQ6o4Xk8A2xXob11bduqGLccBCyPiL+Vpi8DBrf3vCSdnQ4Zn5d0GdBEsRdXfi7LSvd3Bg5stR1OAN5Z6vPavBHxPMVec7Xt1N7y/onij8GyNMRwcGr/FsVe5w2SHpV0FpX9nOKweBDFH9+g2NtsWff3SutdSxHA5e22HVDtcHoQ1bdTh19/SQdKukXSGknPUYy7t/T/MTAHmCFppaRvpqGNnSn23FeV1nM5xR5sj9XdB+t7olXAYEkqBexQ4C/p/nLggoi4YBOWXcu8r4V6Gru7EjgSuCMiNkpa
QPHL257FxSI0OCJW1NB/JTBUUq9SwA4D/tzejBHxdYo905a6ewMbKLbbw6VltVgO3BoRH2xjsUNLy+tHcdi8srzaWpcXEfOAcSlIzqAYrhgaEeuBL1L8wdsTuEXSvIi4qdX8z0q6AfgExeH/daX3Rstrek0bz2V34NtVpq1Kz3VRevzadqrh9a/0lXnXAj8Ajo6Iv0q6iBSuEfEqcC5wrqThwGyKI6fZFHuuAyJiQ4Vl9siv5vOea9e7g+IQ9AwVJ5TGUYwztrgSOD3tIUhS33QSodJeYmsdnbcvxRt7DYCkkyn2XNqVfpFuBN5fS3/gLorxvC9L6qPiGs+PUIyddkhEbAR+AUyR9FZJe1CML7f4NfB3kk5K6+oj6T3p5EmLsZLep+LE3PnAXRHxhr36WpYnaWtJJ0h6W9om6yhe35aTi7tJUql9Y5V1XAuMp9gLvrbUfhnw1RTOLSeGjm2ZKGkwxR+GO6ssd2aav7+kIcDnStPae/2fBIakbdRiO2BtCtYDgE+Vajlc0l7pj986iqGbjRGxCrgB+Lak7SX1krSrpPe3sZ6G53DtYhHxCsVJrFMpDuVOpPjlfTlNnw98hmLv4BmKw8pP17jsDs0bEQ9S7PHcQfEG3wv4YweezuXASTXW9grwUeBoipNBlwDjI+LhNmes7gyKYYMnKMZQf1Ra13pgDHAcxd7oE8B/Upy8anEtMJniMHt/isP8arW3t7yTgKWS1lEcJp+Y2kdQ/AF6nmIbXxLVr22dlfo/GRH3ldb9y7SuGWn5D1BswxafAqZHcc1rJedSDAU8RhFwr106V8PrfzPFHu8Tkp5KbZ8FzpO0HjiHIrxbvJNiiGMdxYnWW4GfpGnjKU5kPkjx3vw5MLCN9TQ8RfTIPfKGIukuipNSP2q3czej4uLvz0UDfZBA0jSKkzr/Ue9aNoeKa1vvozj5tbre9dgbecy1DtLh0CMUe3AnUFzS8ru6FrWJIuJ99a5hS5X2Vt9d7zqsModrfbyL4nCqH8WJrI+ncSkz6yE8LGBmloFPaJmZZeBwNTPLoEeOuQ4YMCCGDx9e7zLMrIe55557noqIplr69shwHT58OPPnz693GWbWw0ha1n6vgocFzMwycLiamWXgcDUzy8DhamaWgcPVzCyDHnm1wKba/8yr612CbYZ7vjW+3iWYvcZ7rmZmGThczcwycLiamWXgcDUzy8DhamaWgcPVzCwDh6uZWQYOVzOzDByuZmYZOFzNzDJwuJqZZeBwNTPLwOFqZpaBw9XMLAOHq5lZBg5XM7MMHK5mZhk4XM3MMnC4mpllkC1cJQ2VdIukhyQtkvSF1L6DpLmSFqef/VO7JF0saYmkhZJGlpY1IfVfLGlCrprNzDpLzj3XDcAXI2J34CBgkqQ9gLOAmyJiBHBTegxwNDAi3SYCl0IRxsBk4EDgAGBySyCbmXVX2cI1IlZFxP+l++uBh4DBwDhgeuo2HTgm3R8HXB2FO4G3SxoIfAiYGxFrI+IZYC5wVK66zcw6Q5eMuUoaDuwH3AXsFBGroAhgYMfUbTCwvDRbc2qr1m5m1m1lD1dJ/YD/Af41Ita11bVCW7TR3no9EyXNlzR/zZo1m1asmVknyRqukvpQBOs1EfGL1PxkOtwn/Vyd2puBoaXZhwAr22h/g4i4IiJGRcSopqamzn0iZmYdlPNqAQFXAQ9FxHdKk2YBLWf8JwDXl9rHp6sGDgKeS8MGc4AxkvqnE1ljUpuZWbe1VcZlHwKcBNwvaUFqOxv4BjBT0qnA48CxadpsYCywBHgROBkgItZKOh+Yl/qdFxFrM9ZtZrbZsoVrRNxO5fFSgCMr9A9gUpVlTQWmdl51ZmZ5+RNaZmYZOFzNzDJwuJqZZeBwNTPLwOFqZpaBw9XMLAOHq5lZBg5XM7MMHK5mZhk4XM3MMnC4mpll4HA1M8vA4WpmloHD1cwsA4ermVkGDlczswwcrmZmGThczcwycLiamWXgcDUzy8DhamaWgcPV
zCwDh6uZWQYOVzOzDLKFq6SpklZLeqDUNkXSCkkL0m1sadpXJS2R9IikD5Xaj0ptSySdlateM7POlHPPdRpwVIX270bEvuk2G0DSHsBxwJ5pnksk9ZbUG/ghcDSwB3B86mtm1q1tlWvBEXGbpOE1dh8HzIiIl4HHJC0BDkjTlkTEowCSZqS+D3ZyuWZmnaoeY65nSFqYhg36p7bBwPJSn+bUVq3dzKxb6+pwvRTYFdgXWAV8O7WrQt9oo/1NJE2UNF/S/DVr1nRGrWZmm6xLwzUinoyIjRHxN+BKXj/0bwaGlroOAVa20V5p2VdExKiIGNXU1NT5xZuZdUCXhqukgaWHHwNariSYBRwnaRtJuwAjgLuBecAISbtI2pripNesrqzZzGxTZDuhJek6YDQwQFIzMBkYLWlfikP7pcBpABGxSNJMihNVG4BJEbExLecMYA7QG5gaEYty1Wxm1lnaDVdJh0TEH9tray0ijq/QfFUb/S8ALqjQPhuY3V6dZmbdSS3DAt+vsc3MzJKqe66SDgbeCzRJ+vfSpO0pDtHNzKyKtoYFtgb6pT7bldrXAR/PWZSZWaOrGq4RcStwq6RpEbFMUt+IeKELazMza1i1jLkOkvQg8BCApH0kXZK3LDOzxlZLuF4EfAh4GiAi7gMOy1mUmVmjq+lDBBGxvFXTxgy1mJn1GLV8iGC5pPcCkT4l9XnSEIGZmVVWy57r6cAkim+jaqb40pVJOYsyM2t07e65RsRTwAldUIuZWY9Ry8dfL67Q/BwwPyKu7/ySzMwaXy3DAttSDAUsTre9gR2AUyVdlLE2M7OGVcsJrd2AIyJiA4CkS4EbgA8C92eszcysYdWy5zoY6Ft63BcYlL4S8OUsVZmZNbha9ly/CSyQ9HuKf7tyGPB1SX2BGzPWZmbWsNoMV0miGAKYTfEvWQScHREt/2rlzLzlmZk1pjbDNSJC0q8iYn/AVwaYmdWoljHXOyW9J3slZmY9SC1jrocDp0laBrxAMTQQEbF31srMzBpYLeF6dPYqzMx6mFo+/roMQNKOFB8oMDOzdrQ75irpo5IWA48Bt1L8S+zfZq7LzKyh1XJC63zgIODPEbELcCTQ5r/VNjPb0tUSrq9GxNNAL0m9IuIWiu8aMDOzKmoJ12cl9QNuA66R9D3g1fZmkjRV0mpJD5TadpA0V9Li9LN/apekiyUtkbRQ0sjSPBNS/8WSJnT8KZqZdb1awvU+4EXg34DfAX8BHq5hvmnAUa3azgJuiogRwE3pMRRXJIxIt4nApVCEMTAZOJDiE2KTWwLZzKw7qyVcD4+Iv0XEhoiYHhEXA+1+qCAibgPWtmoeB0xP96cDx5Tar47CncDbJQ2k+MeIcyNibUQ8A8zlzYFtZtbtVL0US9K/AJ8FdpW0sDRpOzb9hNZOEbEKICJWpcu7oPjmrfI/QWxObdXazcy6tbauc72W4pKrC3n98B1gfUS03iPdXKrQFm20v3kB0kSKIQWGDRvWeZWZmW2CqsMCEfFcRCyNiOMjYlnptjnB+mQ63Cf9XJ3am4GhpX5DgJVttFeq94qIGBURo5qamjajRDOzzVfLmGtnmgW0nPGfwOvftDULGJ+uGjgIeC4NH8wBxkjqn05kjUltZmbdWi3fLbBJJF0HjAYGSGqmOOv/DWCmpFOBx4FjU/fZwFhgCcWVCScDRMRaSecD81K/8zIMSZiZdbps4RoRx1eZdGSFvgFMqrKcqcDUTizNzCy7rh4WMDPbIjhczcwycLiamWXgcDUzy8DhamaWQbarBcx6usfP26veJdhmGHbO/VmX7z1XM7MMHK5mZhk4XM3MMnC4mpll4HA1M8vA4WpmloHD1cwsA4ermVkGDlczswwcrmZmGThczcwycLiamWXgcDUzy8DhamaWgcPVzCwDh6uZWQYOVzOzDByuZmYZ1CVcJS2VdL+kBZLmp7YdJM2VtDj97J/aJeliSUskLZQ0sh41m5l1RD33XA+PiH0jYlR6
fBZwU0SMAG5KjwGOBkak20Tg0i6v1Mysg7rTsMA4YHq6Px04ptR+dRTuBN4uaWA9CjQzq1W9wjWAGyTdI2liatspIlYBpJ87pvbBwPLSvM2pzcys26rXv9Y+JCJWStoRmCvp4Tb6qkJbvKlTEdITAYYNG9Y5VZqZbaK67LlGxMr0czXwS+AA4MmWw/30c3Xq3gwMLc0+BFhZYZlXRMSoiBjV1NSUs3wzs3Z1ebhK6itpu5b7wBjgAWAWMCF1mwBcn+7PAsanqwYOAp5rGT4wM+uu6jEssBPwS0kt6782In4naR4wU9KpwOPAsan/bGAssAR4ETi560s2M+uYLg/XiHgU2KdC+9PAkRXaA5jUBaWZmXWa7nQplplZj+FwNTPLwOFqZpaBw9XMLAOHq5lZBg5XM7MMHK5mZhk4XM3MMnC4mpll4HA1M8vA4WpmloHD1cwsA4ermVkGDlczswwcrmZmGThczcwycLiamWXgcDUzy8DhamaWgcPVzCwDh6uZWQYOVzOzDByuZmYZOFzNzDJomHCVdJSkRyQtkXRWvesxM2tLQ4SrpN7AD4GjgT2A4yXtUd+qzMyqa4hwBQ4AlkTEoxHxCjADGFfnmszMqmqUcB0MLC89bk5tZmbd0lb1LqBGqtAWb+ggTQQmpofPS3oke1WNZwDwVL2LyEX/NaHeJfQ0Pfr9wuRKsdKunWvt2Cjh2gwMLT0eAqwsd4iIK4ArurKoRiNpfkSMqncd1hj8ftk8jTIsMA8YIWkXSVsDxwGz6lyTmVlVDbHnGhEbJJ0BzAF6A1MjYlGdyzIzq6ohwhUgImYDs+tdR4PzsIl1hN8vm0ER0X4vMzPrkEYZczUzaygO1y2EPz5stZI0VdJqSQ/Uu5ZG5nDdAvjjw9ZB04Cj6l1Eo3O4bhn88WGrWUTcBqytdx2NzuG6ZfDHh826mMN1y9Dux4fNrHM5XLcM7X582Mw6l8N1y+CPD5t1MYfrFiAiNgAtHx9+CJjpjw9bNZKuA+4A3iWpWdKp9a6pEfkTWmZmGXjP1cwsA4ermVkGDlczswwcrmZmGThczcwycLhajyDp+Q70nSLpS7mWbwYOVzOzLByu1mNJ+oikuyTdK+lGSTuVJu8j6WZJiyV9pjTPmZLmSVoo6dwKyxwo6TZJCyQ9IOnQLnky1nAcrtaT3Q4cFBH7UXzN4pdL0/YGPgwcDJwjaZCkMcAIiq9o3BfYX9JhrZb5KWBOROwL7AMsyPwcrEE1zD8oNNsEQ4CfShoIbA08Vpp2fUS8BLwk6RaKQH0fMAa4N/XpRxG2t5XmmwdMldQH+FVEOFytIu+5Wk/2feAHEbEXcBqwbWla6899B8VXM14YEfum224RcdUbOhVfJH0YsAL4saTx+cq3RuZwtZ7sbRQhCDCh1bRxkraV9A5gNMUe6RzgFEn9ACQNlrRjeSZJOwOrI+JK4CpgZMb6rYF5WMB6irdKai49/g4wBfiZpBXAncAupel3A78BhgHnR8RKYKWk3YE7JAE8D5wIrC7NNxo4U9Krabr3XK0ifyuWmVkGHhYwM8vA4WpmloHD1cwsA4ermVkGDlczswwcrmZmGThczcwycLiamWXw/5y3Msfd/d/oAAAAAElFTkSuQmCC\n",
852 | "text/plain": [
853 | ""
854 | ]
855 | },
856 | "metadata": {
857 | "needs_background": "light"
858 | },
859 | "output_type": "display_data"
860 | }
861 | ],
862 | "source": [
863 | "fig = plt.figure(figsize=(5,3))\n",
864 | "plt.title('general (non-depressive) dataset')\n",
865 | "ax = sns.barplot(x=df_all.target.unique(),y=df_all.target.value_counts());\n",
866 | "ax.set(xlabel='Labels');"
867 | ]
868 | },
869 | {
870 | "cell_type": "code",
871 | "execution_count": 57,
872 | "metadata": {},
873 | "outputs": [
874 | {
875 | "data": {
876 | "text/plain": [
877 | "0 2357\n",
878 | "1 843 \n",
879 | "Name: target, dtype: int64"
880 | ]
881 | },
882 | "execution_count": 57,
883 | "metadata": {},
884 | "output_type": "execute_result"
885 | }
886 | ],
887 | "source": [
888 | "df_all.target.value_counts()"
889 | ]
890 | },
891 | {
892 | "cell_type": "code",
893 | "execution_count": 58,
894 | "metadata": {},
895 | "outputs": [],
896 | "source": [
897 | "df_all.to_csv(\"./data/tweets_combined.csv\")"
898 | ]
899 | },
900 | {
901 | "cell_type": "markdown",
902 | "metadata": {},
903 | "source": [
904 | "## Conclusion:\n",
905 | "Even though the target is not balanced, as there are more non-depressive tweets than depressive ones, we believe the dataset covers a good range of depressive and non-depressive tweets and would be a good basis for developing a depression detection classifier."
906 | ]
907 | },
908 | {
909 | "cell_type": "code",
910 | "execution_count": null,
911 | "metadata": {},
912 | "outputs": [],
913 | "source": []
914 | }
915 | ],
916 | "metadata": {
917 | "kernelspec": {
918 | "display_name": "Python 3",
919 | "language": "python",
920 | "name": "python3"
921 | },
922 | "language_info": {
923 | "codemirror_mode": {
924 | "name": "ipython",
925 | "version": 3
926 | },
927 | "file_extension": ".py",
928 | "mimetype": "text/x-python",
929 | "name": "python",
930 | "nbconvert_exporter": "python",
931 | "pygments_lexer": "ipython3",
932 | "version": "3.6.8"
933 | }
934 | },
935 | "nbformat": 4,
936 | "nbformat_minor": 2
937 | }
938 |
--------------------------------------------------------------------------------
/3_depression_detector.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "Nm1Ba7FR5nKO"
8 | },
9 | "source": [
10 | "### **Project Showcase - Depression Detection using Twitter Data**"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "colab_type": "text",
17 | "id": "cFIRyYt97knb"
18 | },
19 | "source": [
20 | "This script contains pre-processing for the Twitter data and model creation using PyTorch. \n",
21 | "\n",
22 | "We used TorchText, a PyTorch library that made pre-processing both simple and efficient, and applied custom techniques to work with our unique Twitter data.\n",
23 | "\n",
24 | "A lot of the preprocessing and model section was inspired by code from this article, as we are new to NLP using PyTorch - https://medium.com/@sonicboom8/sentiment-analysis-torchtext-55fb57b1fab8"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 0,
30 | "metadata": {
31 | "colab": {},
32 | "colab_type": "code",
33 | "id": "KlrX5NJiBE8G"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import numpy as np\n",
38 | "import matplotlib.pyplot as plt\n",
39 | "import pandas as pd"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 0,
45 | "metadata": {
46 | "colab": {},
47 | "colab_type": "code",
48 | "id": "MzadFKTlP0wU"
49 | },
50 | "outputs": [],
51 | "source": [
52 | "import matplotlib.pyplot as plt\n",
53 | "import seaborn as sns\n",
54 | "import spacy\n",
55 | "from tqdm import tqdm, tqdm_notebook, tnrange\n",
56 | "tqdm.pandas(desc='Progress')"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 0,
62 | "metadata": {
63 | "colab": {},
64 | "colab_type": "code",
65 | "id": "nCm2NnqrP4fL"
66 | },
67 | "outputs": [],
68 | "source": [
69 | "import torchtext\n",
70 | "from torchtext.data import Field, BucketIterator, TabularDataset"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 0,
76 | "metadata": {
77 | "colab": {},
78 | "colab_type": "code",
79 | "id": "MSy7O88mQMhU"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "from sklearn.model_selection import train_test_split\n",
84 | "from sklearn.metrics import accuracy_score"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 0,
90 | "metadata": {
91 | "colab": {},
92 | "colab_type": "code",
93 | "id": "6DkH-8NbQUiU"
94 | },
95 | "outputs": [],
96 | "source": [
97 | "import os, sys\n",
98 | "import re\n",
99 | "import string\n",
100 | "import itertools"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 0,
106 | "metadata": {
107 | "colab": {},
108 | "colab_type": "code",
109 | "id": "p1fswumgQZGs"
110 | },
111 | "outputs": [],
112 | "source": [
113 | "import torch\n",
114 | "import torch.nn as nn\n",
115 | "import torch.optim as optim\n",
116 | "from torch.autograd import Variable\n",
117 | "import torch.nn.functional as F\n",
118 | "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 0,
124 | "metadata": {
125 | "colab": {},
126 | "colab_type": "code",
127 | "id": "jfAbtdwOQa6s"
128 | },
129 | "outputs": [],
130 | "source": [
131 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 8,
137 | "metadata": {
138 | "colab": {
139 | "base_uri": "https://localhost:8080/",
140 | "height": 126
141 | },
142 | "colab_type": "code",
143 | "id": "ht17_lvKQpun",
144 | "outputId": "1587fb32-637b-442a-bbf4-965d8a633913"
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Python version: 3.6.8 (default, Jan 14 2019, 11:02:34) \n",
152 | "[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]\n",
153 | "Pandas version: 0.24.2\n",
154 | "Pytorch version: 1.1.0\n",
155 | "Torch Text version: 0.3.1\n",
156 | "Spacy version: 2.1.8\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "print('Python version:',sys.version)\n",
162 | "print('Pandas version:',pd.__version__)\n",
163 | "print('Pytorch version:', torch.__version__)\n",
164 | "print('Torch Text version:', torchtext.__version__)\n",
165 | "print('Spacy version:', spacy.__version__)"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {
171 | "colab_type": "text",
172 | "id": "Y9Cy0aZfvPE0"
173 | },
174 | "source": [
175 | "## **1. Load Data**"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 9,
181 | "metadata": {
182 | "colab": {
183 | "base_uri": "https://localhost:8080/",
184 | "height": 35
185 | },
186 | "colab_type": "code",
187 | "id": "cl-sPuUTAPDF",
188 | "outputId": "5bf14b17-94e6-40cb-dc2a-505e592246bf"
189 | },
190 | "outputs": [
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
196 | ]
197 | }
198 | ],
199 | "source": [
200 | "from google.colab import drive\n",
201 | "drive.mount('/content/gdrive')"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 0,
207 | "metadata": {
208 | "colab": {},
209 | "colab_type": "code",
210 | "id": "2lvtAgdT7gPN"
211 | },
212 | "outputs": [],
213 | "source": [
214 | "df = pd.read_csv(\"/content/gdrive/My Drive/Colab Notebooks/data/tweets_combined.csv\")"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 0,
220 | "metadata": {
221 | "colab": {},
222 | "colab_type": "code",
223 | "id": "hPVE4THcCGId"
224 | },
225 | "outputs": [],
226 | "source": [
227 | "pd.set_option('display.max_colwidth', -1)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 12,
233 | "metadata": {
234 | "colab": {
235 | "base_uri": "https://localhost:8080/",
236 | "height": 198
237 | },
238 | "colab_type": "code",
239 | "id": "yXa79WdbBDPh",
240 | "outputId": "ee177313-ada5-482b-8755-d7a565788127"
241 | },
242 | "outputs": [
243 | {
244 | "data": {
245 | "text/html": [
246 | "\n",
247 | "\n",
260 | "
\n",
261 | " \n",
262 | " \n",
263 | " \n",
264 | " Unnamed: 0 \n",
265 | " tweet \n",
266 | " target \n",
267 | " \n",
268 | " \n",
269 | " \n",
270 | " \n",
271 | " 0 \n",
272 | " 0 \n",
273 | " Wanted to watch the TNF game but I forgot my parents cancelled the NFL Network package ???? \n",
274 | " 0 \n",
275 | " \n",
276 | " \n",
277 | " 1 \n",
278 | " 1 \n",
279 | " For all the years I took the lynx box sets for granted and this year I didn't get one ???¢???‚ \n",
280 | " 0 \n",
281 | " \n",
282 | " \n",
283 | " 2 \n",
284 | " 2 \n",
285 | " Hon. Miss Dashwood, whose manners very pretty face, she offended you have done before. \n",
286 | " 0 \n",
287 | " \n",
288 | " \n",
289 | " 3 \n",
290 | " 3 \n",
291 | " ebullient [ih BUL yunt] adj.boiling; bubbling with excitement; exuberant. A boiling liquid can be called ebullient. Excited or enthusiastic. \n",
292 | " 0 \n",
293 | " \n",
294 | " \n",
295 | " 4 \n",
296 | " 4 \n",
297 | " All the proud parents on fb about their kids school report and am shitting myself for Graces arriving 😂😂😂 #troublemaker \n",
298 | " 0 \n",
299 | " \n",
300 | " \n",
301 | "
\n",
302 | "
"
303 | ],
304 | "text/plain": [
305 | " Unnamed: 0 ... target\n",
306 | "0 0 ... 0 \n",
307 | "1 1 ... 0 \n",
308 | "2 2 ... 0 \n",
309 | "3 3 ... 0 \n",
310 | "4 4 ... 0 \n",
311 | "\n",
312 | "[5 rows x 3 columns]"
313 | ]
314 | },
315 | "execution_count": 12,
316 | "metadata": {
317 | "tags": []
318 | },
319 | "output_type": "execute_result"
320 | }
321 | ],
322 | "source": [
323 | "df.head()"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 13,
329 | "metadata": {
330 | "colab": {
331 | "base_uri": "https://localhost:8080/",
332 | "height": 72
333 | },
334 | "colab_type": "code",
335 | "id": "wMpkcMGEBqO5",
336 | "outputId": "03c6a2a4-11bf-4037-c8bc-528ce5489924"
337 | },
338 | "outputs": [
339 | {
340 | "data": {
341 | "text/plain": [
342 | "0 1936\n",
343 | "1 585 \n",
344 | "Name: target, dtype: int64"
345 | ]
346 | },
347 | "execution_count": 13,
348 | "metadata": {
349 | "tags": []
350 | },
351 | "output_type": "execute_result"
352 | }
353 | ],
354 | "source": [
355 | "df.target.value_counts()"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 14,
361 | "metadata": {
362 | "colab": {
363 | "base_uri": "https://localhost:8080/",
364 | "height": 228
365 | },
366 | "colab_type": "code",
367 | "id": "3x_LwHt-Pqm9",
368 | "outputId": "ec4799bc-84d9-4e98-dfe6-0407e4a603dc"
369 | },
370 | "outputs": [
371 | {
372 | "data": {
373 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVcAAADTCAYAAAA1Z1BiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADaJJREFUeJzt3X+s3XV9x/HnS5ARgUlZS1PbYolr\ntrEIFe4Af8zUudRCslWTxcDc6BhZ3QaL2zIm8w9xEKfZD+fAH0kdDZAojmVzbZZutTKEbAHXW8UC\nImsDdLQUWqxDGcYpvvfH+d54KPeW03o/99xz+3wkJ+d73t/P93veN7l55Xs+3+/3nFQVkqTp9bJh\nNyBJc5HhKkkNGK6S1IDhKkkNGK6S1IDhKkkNGK6S1IDhKkkNNAvXJEuT3Jnka0keTPKern5akq1J\ndnbP87p6ktyQZFeSHUnO7dvX2m78ziRrW/UsSdMlre7QSrIIWFRVX05yCrAdeDvwG8DBqvpwkmuA\neVX13iQXA78HXAxcAPxNVV2Q5DRgHBgDqtvPeVX1zanee/78+bVs2bImf5ekY9f27dufrqoFg4w9\nvlUTVbUP2NctfzvJQ8BiYA2wsht2C/BF4L1d/dbqpf29SU7tAnolsLWqDgIk2QqsBm6b6r2XLVvG\n+Ph4g79K0rEsye5Bx87InGuSZcDrgC8BC7vgBXgSWNgtLwYe79tsT1ebqn7oe6xLMp5k/MCBA9Pa\nvyQdqebhmuRk4B+A36+qb/Wv645Sp2VeoqrWV9VYVY0tWDDQUbskNdM0XJO8nF6wfrqq/rErP9V9\n3J+Yl93f1fcCS/s2X9LVpqpL0qzV8mqBADcBD1XVR/pWbQImzvivBTb21S/rrhq4EHimmz7YAqxK\nMq+7smBVV5OkWavZCS3gjcCvA/cnua+rvQ/4MHB7kiuA3cA7u3Wb6V0psAt4DrgcoKoOJrke2NaN\nu27i5JYkzVbNLsUaprGxsTqaqwXOu/rWBt1opmz/i8uG3YLmuCTbq2pskLHeoSVJDRiuktSA4SpJ\nDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiu\nktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA4SpJDRiuktSA\n4SpJDRiuktSA4SpJDTQL1yQbkuxP8kBf7QNJ9ia5r3tc3LfuT5LsSvJwkrf11Vd3tV1JrmnVryRN\np5ZHrjcDqyep/3VVregemwGSnAVcAvxst80nkhyX5Djg48BFwFnApd1YSZrVjm+146q6O8myAYev\nAT5bVd8FHk2yCzi/W7erqh4BSPLZbuzXprldSZpWw5hzvSrJjm7aYF5XWww83jdmT1ebqv4iSdYl\nGU8yfuDAgRZ9S9LAZjpcPwm8BlgB7AP+arp2XFXrq2qsqsYWLFgwXbuVpKPSbFpgMlX11MRykk8B\n/9y93Ass7Ru6pKtxmLokzVozeuSaZFHfy3cAE1cSbAIuSfJjSc4ElgP/CWwDlic5M8kJ9E56bZrJ\nniXpaDQ7ck1yG7ASmJ9kD3AtsDLJCqCAx4B3A1TVg0lup3ei6vvAlVX1fLefq4AtwHHAhqp6sFXP\nkjRdWl4tcOkk5ZsOM/6DwAcnqW8GNk9ja5LUnHdoSVIDhqskNWC4SlIDhqskNWC4SlIDhqskNWC4\nSlIDhqskNWC4SlIDhqskNWC4SlIDhqskNWC4SlIDhqskNWC4SlIDhqskNWC4SlIDhqskNWC4SlID\nhqskNfCS4ZrkjYPUJEk/NMiR640D1iRJnSl/WjvJ64E3AAuS/GHfqh8HjmvdmCSNsinDFTgBOLkb\nc0pf/VvAr7RsSpJG3ZThWlV3AXclub
mqdid5RVU9N4O9SdLIGmTO9VVJvgZ8HSDJOUk+0bYtSRpt\ng4TrR4G3Ad8AqKqvAm9u2ZQkjbqBrnOtqscPKT3foBdJmjMOd0JrwuNJ3gBUkpcD7wEeatuWJI22\nQY5cfxu4ElgM7AVWdK8lSVN4ySPXqnoaeNcM9CJJc8ZLhmuSGyYpPwOMV9XG6W9JkkbfINMCJ9Kb\nCtjZPc4GlgBXJPnoVBsl2ZBkf5IH+mqnJdmaZGf3PK+rJ8kNSXYl2ZHk3L5t1nbjdyZZe5R/pyTN\nqEHC9WzgLVV1Y1XdCPwi8NPAO4BVh9nuZmD1IbVrgDuqajlwR/ca4CJgefdYB3wSemEMXAtcAJwP\nXDsRyJI0mw0SrvPo3QY74STgtKp6HvjuVBtV1d3AwUPKa4BbuuVbgLf31W+tnnuBU5Msond97daq\nOlhV3wS28uLAlqRZZ5BLsf4cuC/JF4HQu4Hgz5KcBHzhCN9vYVXt65afBBZ2y4uB/mtp93S1qeov\nkmQdvaNezjjjjCNsS5Km12HDNUmAzwOb6X0sB3hfVT3RLV99tG9cVZWkjnb7Sfa3HlgPMDY2Nm37\nlaSjcdhpgaoqYHNV7auqjd3jicNt8xKe6j7u0z3v7+p7gaV945Z0tanqkjSrDTLn+uUkPzdN77cJ\nmDjjvxbY2Fe/rLtq4ELgmW76YAuwKsm87kTWqq4mSbPaIHOuFwDvSrIb+F96865VVWcfbqMktwEr\ngflJ9tA76/9h4PYkVwC7gXd2wzcDFwO7gOeAy+m9ycEk1wPbunHXVdWhJ8kkadYZJFzfdjQ7rqpL\np1j11knGFlPcUltVG4ANR9ODJA3LILe/7gZIcjq9GwokSS9hkF9//eUkO4FHgbuAx4B/adyXJI20\nQU5oXQ9cCPxXVZ1J72P9vU27kqQRN0i4fq+qvgG8LMnLqupOYKxxX5I00gY5ofU/SU4G7gY+nWQ/\n8GzbtiRptA0Srl+ld3nUH9D7XtdX8sLvGpAkHWKQcH1LVf0A+AHdl64k2dG0K0kacVOGa5LfAX4X\neM0hYXoK8B+tG5OkUXa4I9fP0Lvk6kP88HtXAb7tXVKSdHhThmtVPUPv51ymutNKkjSFQS7FkiQd\nIcNVkhowXCWpAcNVkhowXCWpAcNVkhowXCWpAcNVkhowXCWpAcNVkhowXCWpAcNVkhowXCWpAcNV\nkhowXCWpAcNVkhowXCWpgUF+oFDSJP77utcOuwX9CM54//1N9++RqyQ1YLhKUgOGqyQ1MJRwTfJY\nkvuT3JdkvKudlmRrkp3d87yuniQ3JNmVZEeSc4fRsyQdiWEeub6lqlZU1Vj3+hrgjqpaDtzRvQa4\nCFjePdYBn5zxTiXpCM2maYE1wC3d8i3A2/vqt1bPvcCpSRYNo0FJGtSwwrWAzyfZnmRdV1tYVfu6\n5SeBhd3yYuDxvm33dLUXSLIuyXiS8QMHDrTqW5IGMqzrXN9UVXuTnA5sTfL1/pVVVUnqSHZYVeuB\n9QBjY2NHtK0kTbehHLlW1d7ueT/wOeB84KmJj/vd8/5u+F5gad/mS7qaJM1aMx6uSU5KcsrEMrAK\neADYBKzthq0FNnbLm4DLuqsGLgSe6Zs+kKRZaRjTAguBzyWZeP/PVNW/JtkG3J7kCmA38M5u/Gbg\nYmAX8Bxw+cy3LElHZsbDtaoeAc6ZpP4N4K2T1Au4cgZak6RpM5suxZKkOcNwlaQGDFdJasBwlaQG\nDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJ\nasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBwlaQGDFdJasBw\nlaQGDFdJamBkwjXJ6iQPJ9mV5Jph9yNJhzMS4ZrkOODjwEXAWcClSc4ableSNLWRCFfgfGBXVT1S\nVf
8HfBZYM+SeJGlKxw+7gQEtBh7ve70HuKB/QJJ1wLru5bNJHp6h3kbJfODpYTfRSv5y7bBbmGvm\n9P8L1+Zotnr1oANHJVxfUlWtB9YPu4/ZLMl4VY0Nuw+NBv9ffjSjMi2wF1ja93pJV5OkWWlUwnUb\nsDzJmUlOAC4BNg25J0ma0khMC1TV95NcBWwBjgM2VNWDQ25rFDltoiPh/8uPIFU17B4kac4ZlWkB\nSRophqskNWC4HiO8fViDSrIhyf4kDwy7l1FmuB4DvH1YR+hmYPWwmxh1huuxwduHNbCquhs4OOw+\nRp3hemyY7PbhxUPqRTomGK6S1IDhemzw9mFphhmuxwZvH5ZmmOF6DKiq7wMTtw8/BNzu7cOaSpLb\ngHuAn0qyJ8kVw+5pFHn7qyQ14JGrJDVguEpSA4arJDVguEpSA4arJDVguGpOSPLsEYz9QJI/arV/\nCQxXSWrCcNWcleSXknwpyVeSfCHJwr7V5yS5J8nOJL/Vt83VSbYl2ZHkTyfZ56Ikdye5L8kDSX5+\nRv4YjRzDVXPZvwMXVtXr6H3N4h/3rTsb+AXg9cD7k7wqySpgOb2vaFwBnJfkzYfs81eBLVW1AjgH\nuK/x36ARNRK//iodpSXA3yVZBJwAPNq3bmNVfQf4TpI76QXqm4BVwFe6MSfTC9u7+7bbBmxI8nLg\nn6rKcNWkPHLVXHYj8LGqei3wbuDEvnWH3vddQIAPVdWK7vGTVXXTCwb1vkj6zfS+VezmJJe1a1+j\nzHDVXPZKfvjVimsPWbcmyYlJfgJYSe+IdAvwm0lOBkiyOMnp/RsleTXwVFV9Cvhb4NyG/WuEOS2g\nueIVSfb0vf4I8AHg75N8E/g34My+9TuAO4H5wPVV9QTwRJKfAe5JAvAs8GvA/r7tVgJXJ/let94j\nV03Kb8WSpAacFpCkBgxXSWrAcJWkBgxXSWrAcJWkBgxXSWrAcJWkBv4frDxfrP70iVoAAAAASUVO\nRK5CYII=\n",
374 | "text/plain": [
375 | ""
376 | ]
377 | },
378 | "metadata": {
379 | "tags": []
380 | },
381 | "output_type": "display_data"
382 | }
383 | ],
384 | "source": [
385 | "fig = plt.figure(figsize=(5,3))\n",
386 | "ax = sns.barplot(x=df.target.unique(),y=df.target.value_counts());\n",
387 | "ax.set(xlabel='Labels');"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 15,
393 | "metadata": {
394 | "colab": {
395 | "base_uri": "https://localhost:8080/",
396 | "height": 417
397 | },
398 | "colab_type": "code",
399 | "id": "aQfsoDz286Tz",
400 | "outputId": "db8002df-548e-48a6-e91a-c79e2ea88ea8"
401 | },
402 | "outputs": [
403 | {
404 | "data": {
405 | "text/plain": [
406 | "(0 Wanted to watch the TNF game but I forgot my parents cancelled the NFL Network package ???? \n",
407 | " 1 For all the years I took the lynx box sets for granted and this year I didn't get one ???¢???‚ \n",
408 | " 2 Hon. Miss Dashwood, whose manners very pretty face, she offended you have done before. \n",
409 | " 3 ebullient [ih BUL yunt] adj.boiling; bubbling with excitement; exuberant. A boiling liquid can be called ebullient. Excited or enthusiastic.\n",
410 | " 4 All the proud parents on fb about their kids school report and am shitting myself for Graces arriving 😂😂😂 #troublemaker \n",
411 | " 5 \"I a temporary, contract job but I am hoping for permanent\" - that's heard all too often amongst GenY. \n",
412 | " 6 He aint flinch tho, I'll give him that lol \n",
413 | " 7 Scratched 2 inches ice off the car 2 go sledging.After being done we had enough snow and stayed at home pic.twitter.com/RaRHNljLMt \n",
414 | " 8 Your perseveres. Good luck in Mercury Awards, darling @ANOHNItweets pic.twitter.com/is67imvghq \n",
415 | " 9 @Jesus_Luvs_Us Hello Jillian & JC - a nice wisg from Betty for blessings to us all - I wish you the same - have a wonderful Tuesday! 😀 \n",
416 | " Name: tweet, dtype: object,\n",
417 | " 2511 I need to work on being a stronger person emotionally \n",
418 | " 2512 Yay... a big shout out for my friend @GuptaJuhi90 . Welcoming her back on twitter. She a terrific writer. Stay tuned for tweets. \n",
419 | " 2513 @SaidTheSky @DasEnergiFest @mitchell_thayne Bro, you should be #smiling you're bring all of us smiles. Share the experience #together\n",
420 | " 2514 Over the past 2 days I've drank 27 bottles of beer, 2 pints of cider and half a bottle of prosecco. \n",
421 | " 2515 @thedramble Tried it once. Poured into the sink. Didn't bother to make tasting notes :-) Looks like cheap sake, tasted like it too...\n",
422 | " 2516 honestly just wanna watch OTH & have no Netflix to do that, so \n",
423 | " 2517 Mayweather's trash talking is on par with Nate Diaz #terrible would love to see McGregor KO him #MayweatherVsMcGregor \n",
424 | " 2518 Shakespeare Dictionary: The word 'knotty-pated' means 'block-headed, dull-witted' \n",
425 | " 2519 @help_dms___ So good lunch 😊 \n",
426 | " 2520 So excited for tonight, I actually feel a bit sick 😅 \n",
427 | " Name: tweet, dtype: object)"
428 | ]
429 | },
430 | "execution_count": 15,
431 | "metadata": {
432 | "tags": []
433 | },
434 | "output_type": "execute_result"
435 | }
436 | ],
437 | "source": [
438 | "df.tweet.head(10), df.tweet.tail(10)"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 16,
444 | "metadata": {
445 | "colab": {
446 | "base_uri": "https://localhost:8080/",
447 | "height": 126
448 | },
449 | "colab_type": "code",
450 | "id": "Xopy9Qr2i6eD",
451 | "outputId": "2dab77be-e45b-4032-e49c-ec1c7dbf757f"
452 | },
453 | "outputs": [
454 | {
455 | "data": {
456 | "text/plain": [
457 | "0 Wanted to watch the TNF game but I forgot my parents cancelled the NFL Network package ???? \n",
458 | "1 For all the years I took the lynx box sets for granted and this year I didn't get one ???¢???‚ \n",
459 | "2 Hon. Miss Dashwood, whose manners very pretty face, she offended you have done before. \n",
460 | "3 ebullient [ih BUL yunt] adj.boiling; bubbling with excitement; exuberant. A boiling liquid can be called ebullient. Excited or enthusiastic.\n",
461 | "4 All the proud parents on fb about their kids school report and am shitting myself for Graces arriving 😂😂😂 #troublemaker \n",
462 | "Name: tweet, dtype: object"
463 | ]
464 | },
465 | "execution_count": 16,
466 | "metadata": {
467 | "tags": []
468 | },
469 | "output_type": "execute_result"
470 | }
471 | ],
472 | "source": [
473 | "# check non-depressive tweets\n",
474 | "df[df[\"target\"]==0].tweet.head()"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 17,
480 | "metadata": {
481 | "colab": {
482 | "base_uri": "https://localhost:8080/",
483 | "height": 126
484 | },
485 | "colab_type": "code",
486 | "id": "C-OqFJlvBs2B",
487 | "outputId": "67291707-6791-4bcb-93d5-bb44b224c9d7"
488 | },
489 | "outputs": [
490 | {
491 | "data": {
492 | "text/plain": [
493 | "16 It's weird that I get 300x more depressed over Christmas break every year. \n",
494 | "20 It's so sad when you talk so highly of someone then they end up disappointing u and making u look like a pendeja \n",
495 | "25 A whopping 9% of adults in America are depressed. Another 3.4% are deeply depressed. Let's Talk! \n",
496 | "27 Not only have I posted a blog post daily for 2 months I've also filmed and uploaded a daily vlog whilst being severely \n",
497 | "33 HealthTap: I haven't eaten anything all day because I feel and low What should I do? \n",
498 | "Name: tweet, dtype: object"
499 | ]
500 | },
501 | "execution_count": 17,
502 | "metadata": {
503 | "tags": []
504 | },
505 | "output_type": "execute_result"
506 | }
507 | ],
508 | "source": [
509 | "# check depressive tweets\n",
510 | "df[df[\"target\"]==1].tweet.head()"
511 | ]
512 | },
513 | {
514 | "cell_type": "markdown",
515 | "metadata": {
516 | "colab_type": "text",
517 | "id": "Ydm2HSMOwBSQ"
518 | },
519 | "source": [
520 | "## **2. Define How to Preprocess Data**"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": 18,
526 | "metadata": {
527 | "colab": {
528 | "base_uri": "https://localhost:8080/",
529 | "height": 35
530 | },
531 | "colab_type": "code",
532 | "id": "Ts4ioqHeRCC8",
533 | "outputId": "1963ad0a-0b6b-484f-f356-afb6d2b9edf3"
534 | },
535 | "outputs": [
536 | {
537 | "name": "stderr",
538 | "output_type": "stream",
539 | "text": [
540 | "Progress: 100%|██████████| 2521/2521 [00:00<00:00, 177030.26it/s]\n"
541 | ]
542 | }
543 | ],
544 | "source": [
545 | "# torchtext has trouble handling \\n, so replace every \\n character with a space\n",
546 | "df['tweet'] = df.tweet.progress_apply(lambda x: re.sub('\\n', ' ', x))"
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": 0,
552 | "metadata": {
553 | "colab": {},
554 | "colab_type": "code",
555 | "id": "yo57ODfhRWO0"
556 | },
557 | "outputs": [],
558 | "source": [
559 | "contraction_dict = {\"ain't\": \"is not\", \"aren't\": \"are not\",\"can't\": \"cannot\", \"'cause\": \"because\", \"could've\": \"could have\", \"couldn't\": \"could not\", \"didn't\": \"did not\", \"doesn't\": \"does not\", \"don't\": \"do not\", \"hadn't\": \"had not\", \"hasn't\": \"has not\", \"haven't\": \"have not\", \"he'd\": \"he would\",\"he'll\": \"he will\", \"he's\": \"he is\", \"how'd\": \"how did\", \"how'd'y\": \"how do you\", \"how'll\": \"how will\", \"how's\": \"how is\", \"I'd\": \"I would\", \"I'd've\": \"I would have\", \"I'll\": \"I will\", \"I'll've\": \"I will have\",\"I'm\": \"I am\", \"I've\": \"I have\", \"i'd\": \"i would\", \"i'd've\": \"i would have\", \"i'll\": \"i will\", \"i'll've\": \"i will have\",\"i'm\": \"i am\", \"i've\": \"i have\", \"isn't\": \"is not\", \"it'd\": \"it would\", \"it'd've\": \"it would have\", \"it'll\": \"it will\", \"it'll've\": \"it will have\",\"it's\": \"it is\", \"let's\": \"let us\", \"ma'am\": \"madam\", \"mayn't\": \"may not\", \"might've\": \"might have\",\"mightn't\": \"might not\",\"mightn't've\": \"might not have\", \"must've\": \"must have\", \"mustn't\": \"must not\", \"mustn't've\": \"must not have\", \"needn't\": \"need not\", \"needn't've\": \"need not have\",\"o'clock\": \"of the clock\", \"oughtn't\": \"ought not\", \"oughtn't've\": \"ought not have\", \"shan't\": \"shall not\", \"sha'n't\": \"shall not\", \"shan't've\": \"shall not have\", \"she'd\": \"she would\", \"she'd've\": \"she would have\", \"she'll\": \"she will\", \"she'll've\": \"she will have\", \"she's\": \"she is\", \"should've\": \"should have\", \"shouldn't\": \"should not\", \"shouldn't've\": \"should not have\", \"so've\": \"so have\",\"so's\": \"so as\", \"this's\": \"this is\",\"that'd\": \"that would\", \"that'd've\": \"that would have\", \"that's\": \"that is\", \"there'd\": \"there would\", \"there'd've\": \"there would have\", \"there's\": \"there is\", \"here's\": \"here is\",\"they'd\": \"they would\", 
\"they'd've\": \"they would have\", \"they'll\": \"they will\", \"they'll've\": \"they will have\", \"they're\": \"they are\", \"they've\": \"they have\", \"to've\": \"to have\", \"wasn't\": \"was not\", \"we'd\": \"we would\", \"we'd've\": \"we would have\", \"we'll\": \"we will\", \"we'll've\": \"we will have\", \"we're\": \"we are\", \"we've\": \"we have\", \"weren't\": \"were not\", \"what'll\": \"what will\", \"what'll've\": \"what will have\", \"what're\": \"what are\", \"what's\": \"what is\", \"what've\": \"what have\", \"when's\": \"when is\", \"when've\": \"when have\", \"where'd\": \"where did\", \"where's\": \"where is\", \"where've\": \"where have\", \"who'll\": \"who will\", \"who'll've\": \"who will have\", \"who's\": \"who is\", \"who've\": \"who have\", \"why's\": \"why is\", \"why've\": \"why have\", \"will've\": \"will have\", \"won't\": \"will not\", \"won't've\": \"will not have\", \"would've\": \"would have\", \"wouldn't\": \"would not\", \"wouldn't've\": \"would not have\", \"y'all\": \"you all\", \"y'all'd\": \"you all would\",\"y'all'd've\": \"you all would have\",\"y'all're\": \"you all are\",\"y'all've\": \"you all have\",\"you'd\": \"you would\", \"you'd've\": \"you would have\", \"you'll\": \"you will\", \"you'll've\": \"you will have\", \"you're\": \"you are\", \"you've\": \"you have\"}"
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": 0,
565 | "metadata": {
566 | "colab": {},
567 | "colab_type": "code",
568 | "id": "dy7D7zhvpsvh"
569 | },
570 | "outputs": [],
571 | "source": [
572 | "def _get_contractions(contraction_dict):\n",
573 | " contraction_re = re.compile('(%s)' % '|'.join(contraction_dict.keys()))\n",
574 | " return contraction_dict, contraction_re\n",
575 | "\n",
576 | "contractions, contractions_re = _get_contractions(contraction_dict)\n",
577 | "\n",
578 | "def replace_contractions(text):\n",
579 | " def replace(match):\n",
580 | " return contractions[match.group(0)]\n",
581 | " return contractions_re.sub(replace, text)"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 0,
587 | "metadata": {
588 | "colab": {},
589 | "colab_type": "code",
590 | "id": "hc_P0uFHkYJh"
591 | },
592 | "outputs": [],
593 | "source": [
594 | "def tweet_clean(text):\n",
595 | " text = re.sub(r'https?:/\\/\\S+', ' ', text) # remove urls\n",
596 | " text = re.sub(r'<([^>]*)>', ' ', text) # remove emojis\n",
597 | " text = re.sub(r'@\\w+', ' ', text) # remove at mentions\n",
598 | " text = re.sub(r'#', '', text) # remove hashtag symbol\n",
599 | " text = re.sub(r'[0-9]+', ' ', text) # remove numbers\n",
600 | " text = replace_contractions(text)\n",
601 | " pattern = re.compile(r\"[ \\n\\t]+\")\n",
602 | " text = pattern.sub(\" \", text) \n",
603 | " text = \"\".join(\"\".join(s)[:2] for _, s in itertools.groupby(text)) \n",
604 | " text = re.sub(r'[^A-Za-z0-9,?.!]+', ' ', text) # remove all symbols and punctuation except for . , ! and ?\n",
605 | " return text.strip()"
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": 0,
611 | "metadata": {
612 | "colab": {},
613 | "colab_type": "code",
614 | "id": "iCE-ztcy0RFq"
615 | },
616 | "outputs": [],
617 | "source": [
618 | "nlp = spacy.load('en',disable=['parser', 'tagger', 'ner'])\n",
619 | "def tokenizer(s): return [w.text.lower() for w in nlp(tweet_clean(s))]"
620 | ]
621 | },
622 | {
623 | "cell_type": "markdown",
624 | "metadata": {
625 | "colab_type": "text",
626 | "id": "l5c4RdL90Wgu"
627 | },
628 | "source": [
629 | "**Define fields**"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": 0,
635 | "metadata": {
636 | "colab": {},
637 | "colab_type": "code",
638 | "id": "gFlQw-Pt05eH"
639 | },
640 | "outputs": [],
641 | "source": [
642 | "TEXT = Field(sequential=True, tokenize=tokenizer, include_lengths=True, use_vocab=True)\n",
643 | "TARGET = Field(sequential=False, use_vocab=False, pad_token=None, unk_token=None, is_target =False)"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": 0,
649 | "metadata": {
650 | "colab": {},
651 | "colab_type": "code",
652 | "id": "LXS_2wDd0qQ9"
653 | },
654 | "outputs": [],
655 | "source": [
656 | "data_fields = [\n",
657 | " (None, None),\n",
658 | " (\"tweet\", TEXT), \n",
659 | " (\"target\", TARGET)\n",
660 | "]"
661 | ]
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "metadata": {
666 | "colab_type": "text",
667 | "id": "SJHSyAjrxsEc"
668 | },
669 | "source": [
670 | "## **3. Create Train, Valid and Test datasets**"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 0,
676 | "metadata": {
677 | "colab": {},
678 | "colab_type": "code",
679 | "id": "MnN_75qpgMHN"
680 | },
681 | "outputs": [],
682 | "source": [
683 | "def split_train_test(df, test_size=0.2):\n",
684 | " train, val = train_test_split(df, test_size=test_size,random_state=42)\n",
685 | " return train.reset_index(drop=True), val.reset_index(drop=True)"
686 | ]
687 | },
688 | {
689 | "cell_type": "code",
690 | "execution_count": 0,
691 | "metadata": {
692 | "colab": {},
693 | "colab_type": "code",
694 | "id": "HEFA4z2CE_Sq"
695 | },
696 | "outputs": [],
697 | "source": [
698 | "# create train, validation and test sets\n",
699 | "train_val, test = split_train_test(df, test_size=0.2)\n",
700 | "train, val = split_train_test(train_val, test_size=0.2)"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 0,
706 | "metadata": {
707 | "colab": {},
708 | "colab_type": "code",
709 | "id": "19vn3wuLFeae"
710 | },
711 | "outputs": [],
712 | "source": [
713 | "train.to_csv(\"train.csv\", index=False)\n",
714 | "val.to_csv(\"val.csv\", index=False)\n",
715 | "test.to_csv(\"test.csv\", index=False)"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": 28,
721 | "metadata": {
722 | "colab": {
723 | "base_uri": "https://localhost:8080/",
724 | "height": 35
725 | },
726 | "colab_type": "code",
727 | "id": "VFi9p4ibRvX0",
728 | "outputId": "6775feee-0dec-49c8-ac99-58ee4cd4daf3"
729 | },
730 | "outputs": [
731 | {
732 | "data": {
733 | "text/plain": [
734 | "((1612, 3), (404, 3), (505, 3))"
735 | ]
736 | },
737 | "execution_count": 28,
738 | "metadata": {
739 | "tags": []
740 | },
741 | "output_type": "execute_result"
742 | }
743 | ],
744 | "source": [
745 | "train.shape, val.shape, test.shape"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": 29,
751 | "metadata": {
752 | "colab": {
753 | "base_uri": "https://localhost:8080/",
754 | "height": 314
755 | },
756 | "colab_type": "code",
757 | "id": "HC87TsYER-0Q",
758 | "outputId": "429f32a6-a9f8-4b3c-fab2-4e8cca19b0d1"
759 | },
760 | "outputs": [
761 | {
762 | "data": {
763 | "text/plain": [
764 | "[Text(0, 0.5, 'counts'), Text(0.5, 0, 'Labels'), Text(0.5, 1.0, 'test')]"
765 | ]
766 | },
767 | "execution_count": 29,
768 | "metadata": {
769 | "tags": []
770 | },
771 | "output_type": "execute_result"
772 | },
773 | {
774 | "data": {
775 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAAEWCAYAAADfMRsiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3X2UZXV95/v3R0BQUXnqMC2Naa72\nNUGjLVNDyJjJEFB5GJPWuWogiXSQNa03mNHRUcF7Z8AoE3JHZSQTyWoEAUdFRDMyhsQQwGF5hwcb\n7CAP8dICSncauuVJiJGk8Xv/OL+CQ1HVnOquU6f2qfdrrbNq799+qG/Vqu/a3/rtvX+/VBWSJEla\n+J416gAkSZI0GAs3SZKkjrBwkyRJ6ggLN0mSpI6wcJMkSeoICzdJkqSOsHATAEn+JMl/GHUc0rAl\nOTzJxr71W5McPsi+O/C9zCtJc8rCbUwkuTvJa3f0+Kp6Z1V9ZC5jkrqgql5eVd/Y2fMk+Z0k35xy\nbvNKY2tnrzvtHE/LG22fhdsikGTXUccgSZJ2noXbGEjyWeDFwP9I8miSDySpJCcl+QFwVdvvS0nu\nTfJwkmuSvLzvHBck+WhbPjzJxiTvS7IlyeYkJ47kh5NmkOSDSS6d0vbJJGcnOTHJ7UkeSXJnknds\n5zxP9BokeU7LhQeT3Ab8syn7npLke+28tyV5U2v/eeBPgF9qOfhQa38ir9r6v0myIckDSS5L8qK+\nbZXknUnuSPJQkj9Okjn4VUlzbobrzmFJ/lf7+/3r/kcQWs/anS137kryWzPljbbPwm0MVNXbgB8A\nv1ZVewKXtE3/Evh54Ki2/ufACuBngJuAz23ntP8EeCFwAHAS8MdJ9p776KUddjFwbJLnAyTZBXgr\n8HlgC/AG4AXAicBZSQ4Z4JynAS9pn6OA1VO2fw/4F/Ry48PAf0uytKpuB94JXFtVe1bVXlNPnOQI\n4A9ajEuB77efod8b6BWLr2z7HYW0AE1z3fkc8GfAR4F9gH8PfDnJkiTPA84Gjqmq5wP/HFg/SN7o\n6SzcxtvpVfV3VfX3AFV1flU9UlWPAacDr0rywhmO/Ufg96vqH6vqcuBR4GXzErU0gKr6Pr1/QN7U\nmo4AflxV11XVn1XV96rnfwJ/Sa/geiZvBc6oqgeq6h56F5v+7/mlqvrbqvppVX0RuAM4dMCQfws4\nv6puajl4Kr2ehuV9+5xZVQ9V1Q+Aq4GVA55bGrXfBi6vqstbflwBrAOObdt/CrwiyXOqanNV3Tqy\nSDvOwm283TO5kGSXJGe22zw/Au5um/ab4dj7q2pb3/qPgT2HE6a0wz4PHN+Wf7Otk+SYJNe1W5IP\n0bt4zPS33u9F9OUNvV6xJyQ5Icn6divoIeAVA5538txPnK+qHgXup9erPenevmVzTl3ys8BbJnOj\n5ccvA0ur6u+A36DXu7Y5yZ8l+blRBttlFm7jo56h7TeBVcBr6d3mWd7afYZGXfYl4PAky+j1vH0+\nye7Al4GPAfu32y+XM9jf+mbgwL71F08uJPlZ4FzgXcC+7by39J13uhzs97f0Lm6T53sesC+waYC4\npIWo/2/+HuCzVbVX3+d5VXUmQFV9vapeR+8xgb+hl0tTz6EBWLiNj/uA/207258PPEbvP/znAv9p\nPoKShqmqtgLfAD4D3NWemXk2sDuwFdiW5Bjg9QOe8hLg1CR7t2Lw9/q2PY/eRWYrQHth5xV92+8D\nliV59gzn/gJwYpKVrbj8T8D1VXX3gLFJC03/dee/Ab+W5Kh2h2eP9qLbsiT7J1nV/ll5jN6jNz/t\nO8f28kZTWLiNjz8A/u/WPf3mabZfRO82zSbgNuC6eYxNGqbP0+tJ/jxAVT0C/Ft6RdiD9HqbLxvw\nXB+mlyd30Xsu7rOTG6rqNuDjwLX0Lj
a/APy/fcdeBdwK3Jvkh1NPXFV/BfwHer2Bm+m9AHHcgHFJ\nC1H/dec36N3V+RC9f27uAd5Pr854FvBeer3OD9B7ce7/bOfYbt7o6VJlL6UkSVIX2OMmSZLUERZu\nkiRJHWHhJkmS1BEWbpIkSR0xlpOP77fffrV8+fJRh6ExcOONN/6wqpaMOo75YN5orpg30uzMJmfG\nsnBbvnw569atG3UYGgNJvv/Me40H80ZzxbyRZmc2OeOtUknS2GqDwX47ydfa+kFJrk+yIckXJwd+\nTbJ7W9/Qti8fZdzSTCzcJEnj7N3A7X3rfwicVVUvpTdA80mt/STgwdZ+VttPWnAs3CRJY6lNW/av\ngE+39QBHAJe2XS4E3tiWV7V12vYj2/7SgmLhJkkaV/8F+ABPzou5L/BQVW1r6xuBA9ryAfSmaaJt\nf7jt/zRJ1iRZl2Td1q1bhxW7NC0LN0nS2EnyBmBLVd041+euqrVVNVFVE0uWLIqXZ7WAjOVbpZKk\nRe81wK8nORbYA3gB8ElgryS7tl61ZcCmtv8m4EBgY5JdgRcC989/2NL22eMmSRo7VXVqVS2rquXA\nccBVVfVbwNXAm9tuq4GvtuXL2jpt+1VVVfMYsjQQCzdJ0mLyQeC9STbQe4btvNZ+HrBva38vcMqI\n4pO2y1ulkqSxVlXfAL7Rlu8EDp1mn58Ab5nXwKQdsCgLt3/6/otGHcK8ufE/nzDqEDQGzBlp9swb\nDYO3SqUOSLJHkhuS/HWSW5N8uLU7CrwkLSIWblI3PAYcUVWvAlYCRyc5DEeBl6RFZWiFW5Lzk2xJ\ncktf239O8jdJbk7yp0n26tt2ausd+G6So/raj25tG5L4sKgWpep5tK3u1j6Fo8BL0qIyzB63C4Cj\np7RdAbyiql4J/H/AqQBJDqb3uvbL2zGfahMD7wL8MXAMcDBwfNtXWnRaTqwHttDLpe+xk6PAOwK8\nJHXL0Aq3qroGeGBK21/2XWSuozf4IfR6By6uqseq6i5gA723fg4FNlTVnVX1D8DFbV9p0amqx6tq\nJb28ORT4uTk4pyPAS1KHjPIZt7cDf96Wn+gdaCZ7DmZqlxatqnqI3iCiv0QbBb5tmm4UeBwFXpLG\nx0gKtyT/F7AN+NwcntNbPhpbSZZMPhOa5DnA64DbcRR4SVpU5n0ctyS/A7wBOLLvQvJE70DT33Mw\nU/tTVNVaYC3AxMSEFyiNm6XAhe25z2cBl1TV15LcBlyc5KPAt3nqKPCfbaPAP0DvGVJJUsfNa+GW\n5GjgA8C/rKof9226DPh8kk8ALwJWADcAAVYkOYhewXYc8JvzGbO0EFTVzcCrp2l3FHhJWkSGVrgl\n+QJwOLBfko3AafTeIt0duKKNTHBdVb2zqm5NcglwG71bqCdX1ePtPO8Cvg7sApxfVbcOK2ZJkqSF\nbGiFW1UdP03zedO0Te5/BnDGNO2XA5fPYWiSJEmd5MwJkiRJHWHhJkmS1BEWbpIkSR1h4SZJktQR\nFm6SJEkdYeEmSRo7SfZIckOSv05ya5IPt/YLktyVZH37rGztSXJ2kg1Jbk5yyGh/Aml68z5zgiRJ\n8+Ax4IiqejTJbsA3k0zOj/3+qrp0yv7H0Bv8fQXwi8A57au0oNjjJkkaO9XzaFvdrX22Nx3iKuCi\ndtx1wF5Jlg47Tmm2LNwkSWMpyS5J1gNbgCuq6vq26Yx2O/SsJLu3tgOAe/oO39japjvvmiTrkqzb\nunXr0OKXpmPhJkkaS1X1eFWtBJYBhyZ5Bb2pF38O+GfAPsAHd+C8a6tqoqomlixZMqcxS8/Ewk2S\nNNaq6iHgauDoqtrcboc+BnwGOLTttgk4sO+wZa1NWlAs3CRJYyfJkiR7teXnAK8D/mbyubUkAd4I\n3NIOuQw4ob1dehjwcFVtHkHo0nb5VqkkaRwtBS5Msgu9TopLquprSa5KsgQIsB54Z9v/cuBYYAPw\nY+
DEEcQsPSMLN0nS2Kmqm4FXT9N+xAz7F3DysOOSdpa3SiVJkjrCwk2SJKkjLNwkSZI6wsJNkiSp\nIyzcJEmSOsLCTZIkqSMs3CRJkjrCwk2SJKkjLNykDkhyYJKrk9yW5NYk727tpyfZlGR9+xzbd8yp\nSTYk+W6So0YXvSRprjhzgtQN24D3VdVNSZ4P3JjkirbtrKr6WP/OSQ4GjgNeDrwI+Ksk/3tVPT6v\nUUuS5pQ9blIHVNXmqrqpLT8C3A4csJ1DVgEXV9VjVXUXvfkXDx1+pJKkYRpa4Zbk/CRbktzS17ZP\nkiuS3NG+7t3ak+Tsdlvn5iSH9B2zuu1/R5LVw4pX6ooky+nNwXh9a3pXy5vzJ3OKXlF3T99hG9l+\noSdJ6oBh9rhdABw9pe0U4MqqWgFc2dYBjgFWtM8a4BzoFXrAacAv0ustOK3vwiQtOkn2BL4MvKeq\nfkQvV14CrAQ2Ax+f5fnWJFmXZN3WrVvnPF5J0twaWuFWVdcAD0xpXgVc2JYvBN7Y135R9VwH7JVk\nKXAUcEVVPVBVDwJX8PRiUFoUkuxGr2j7XFV9BaCq7quqx6vqp8C5PHk7dBNwYN/hy1rbU1TV2qqa\nqKqJJUuWDPcHkCTttPl+xm3/qtrclu8F9m/LM93WGfh2jz0HGmdJApwH3F5Vn+hrX9q325uAyUcT\nLgOOS7J7koPo9WbfMF/xSpKGY2RvlVZVJak5PN9aYC3AxMTEnJ1XWiBeA7wN+E6S9a3tQ8DxSVYC\nBdwNvAOgqm5NcglwG703Uk/2jVJJ6r75LtzuS7K0qja3noItrX2m2zqbgMOntH9jHuKUFpSq+iaQ\naTZdvp1jzgDOGFpQkqR5N9+3Si8DJt8MXQ18ta/9hPZ26WHAw+2W6teB1yfZu72U8PrWJkmStOgM\ncziQLwDXAi9LsjHJScCZwOuS3AG8tq1Dr9fgTnpjTZ0L/C5AVT0AfAT4Vvv8fmuTJGm7kuyR5IYk\nf91mHPlwaz8oyfVtCKovJnl2a9+9rW9o25ePMn5pOkO7VVpVx8+w6chp9i3g5BnOcz5w/hyGJkla\nHB4DjqiqR9tb2d9M8ufAe+nNOHJxkj8BTqI3tM5JwINV9dIkxwF/CPzGqIKXpuPMCZKksdSGmHq0\nre7WPgUcAVza2qcOTTU5ZNWlwJHtjW5pwbBwkySNrSS7tDext9AbC/R7wENVta3t0j/M1BNDULXt\nDwP7TnNOh5/SyFi4SZLGVhugeiW9UQkOBX5uDs7pwNUaGQs3SdLYq6qHgKuBX6I3O8/kM979s4o8\nMTRV2/5C4P55DlXaLgs3SdJYSrIkyV5t+TnA64Db6RVwb267TR2aanLIqjcDV7WX56QFY2QzJ0iS\nNGRLgQuT7EKvo+KSqvpaktuAi5N8FPg2venkaF8/m2QDvbm2jxtF0NL2WLhJksZSVd0MvHqa9jvp\nPe82tf0nwFvmITRph3mrVJIkqSMs3CRJkjrCwk2SJKkjLNwkSZI6wsJNkiSpIyzcJEmSOsLCTZIk\nqSMs3CRJkjrCwk2SJKkjLNwkSZI6wsJNkiSpIyzcJEmSOsLCTZIkqSMs3CRJkjrCwk2SJKkjLNyk\nDkhyYJKrk9yW5NYk727t+yS5Iskd7everT1Jzk6yIcnNSQ4Z7U8gSZoLFm5SN2wD3ldVBwOHAScn\nORg4BbiyqlYAV7Z1gGOAFe2zBjhn/kOWJM21kRRuSf5d6zW4JckXkuyR5KAk17cegi8meXbbd/e2\nvqFtXz6KmKVRqqrNVXVTW34EuB04AFgFXNh2uxB4Y1teBVxUPdcBeyVZOs9hS5Lm2LwXbkkOAP4t\nMFFVrwB2AY4D/hA4q6peCjwInNQOOQl4sLWf1faTFq32z8urgeuB/atqc9t0L7B/Wz4AuKfvsI2t\nbeq51iRZl2Td1q1bhxazNN+283jB6Uk2JVnfPsf2HXNq6yT4bpKj
Rhe9NLNR3SrdFXhOkl2B5wKb\ngSOAS9v2qT0Hkz0KlwJHJsk8xiotGEn2BL4MvKeqftS/raoKqNmcr6rWVtVEVU0sWbJkDiOVRm6m\nxwug10mwsn0uB2jbjgNeDhwNfCrJLqMIXNqeeS/cqmoT8DHgB/QKtoeBG4GHqmpb262/d+CJnoO2\n/WFg36nntedA4y7JbvSKts9V1Vda832Tt0Db1y2tfRNwYN/hy1qbtChs5/GCmawCLq6qx6rqLmAD\ncOjwI5VmZxS3SvemlyAHAS8Cnkfvv5udYs+BxlnrZT4PuL2qPtG36TJgdVteDXy1r/2E9nbpYcDD\nfbdUpUVlyuMFAO9qb1ufP/kmNgM+XtDOZ0eBRmYUt0pfC9xVVVur6h+BrwCvoffw9K5tn/7egSd6\nDtr2FwL3z2/I0si9BngbcMSUZ3POBF6X5A56uXVm2/9y4E56vQbnAr87gpilkZvm8YJzgJcAK+nd\n9fn4bM9pR4FGaddn3mXO/QA4LMlzgb8HjgTWAVcDbwYu5uk9B6uBa9v2q9qzPNKiUVXfBGZ6tvPI\nafYv4OShBiUtcNM9XlBV9/VtPxf4Wlv18QJ1wiiecbue3ksGNwHfaTGsBT4IvDfJBnrPsJ3XDjkP\n2Le1v5cnx6mSJGlaMz1eMGVYnDcBt7Tly4Dj2hBUB9EbA/GG+YpXGtQoetyoqtOA06Y038k0D4JW\n1U+At8xHXJKksTH5eMF3kqxvbR8Cjk+ykt4b2HcD7wCoqluTXALcRu+N1JOr6vF5j1p6BiMp3CRJ\nGqbtPF5w+XaOOQM4Y2hBSXPAKa8kSZI6wsJNkiSpIyzcJEmSOsLCTZIkqSMs3CRJkjpioMItybuT\nvKBNn3NekpuSvH7YwUnjyHySZseckZ40aI/b29tUIa8H9qY3Ns6Z2z9E0gzMJ2l2zBmpGbRwmxwL\n51jgs1V1KzNPvyNp+8wnaXbMGakZtHC7Mclf0kuaryd5PvDT4YUljTXzSZodc0ZqBp054SRgJXBn\nVf04yb7AicMLSxpr5pM0O+aM1Aza43ZFVd1UVQ8BVNX9wFnDC0saa+aTNDvmjNRst8ctyR7Ac4H9\nkuzNk88UvAA4YMixSWPFfJJmx5yRnu6ZbpW+A3gP8CLgRp5Mmh8B/3WIcUnjyHySZseckabYbuFW\nVZ8EPpnk96rqj+YpJmksmU/S7Jgz0tMN9HJCVf1Rkn8OLO8/pqouGlJc0tgyn6TZMWekJw1UuCX5\nLPASYD3weGsuwKSRZsl8kmbHnJGeNOhwIBPAwVVVwwxGWiTMJ2l2zBmpGXQ4kFuAfzLMQKRFxHyS\nZmfWOZPkwCRXJ7ktya1J3t3a90lyRZI72te9W3uSnJ1kQ5KbkxwyhJ9D2mmD9rjtB9yW5AbgscnG\nqvr1oUQljTfzSZqdHcmZbcD7quqmNtPCjUmuAH4HuLKqzkxyCnAK8EHgGGBF+/wicE77Ki0ogxZu\npw8zCGmROX3UAUgdc/psD6iqzcDmtvxIktvpjf22Cji87XYh8A16hdsq4KJ2O/a6JHslWdrOIy0Y\ng75V+j+HHYi0WJhP0uzsbM4kWQ68Grge2L+vGLsX2L8tHwDc03fYxtb2tMItyRpgDcCLX/zinQlN\nmrWBnnFL8kiSH7XPT5I8nuRHww5OGkfmkzQ7O5MzSfYEvgy8p6qeckzrXZv1Cw9VtbaqJqpqYsmS\nJbM9XNopAxVuVfX8qnpBVb0AeA7wfwCfGmpk0pjakXxKcn6SLUlu6Ws7PcmmJOvb59i+bae2h6y/\nm+Soof0w0jzY0WtQkt3oFW2fq6qvtOb7kixt25cCW1r7JuDAvsOXtTZpQRn0rdInVM9/B7wYSDtp\nFvl0AXD0NO1nVdXK9rkcIMnBwHHAy9sxn0qyyxyGLY3MoDmTJMB5wO1V9Ym+TZcBq9vyauCrfe0n\ntLdLDwMe9vk2LUSDDsD7r/tW
n0VvTJ2f7Og3TbIX8GngFfS6qd8OfBf4Ir2Rse8G3lpVD7bk+yRw\nLPBj4Heq6qYd/d7SqO1IPlXVNe05nUGsAi6uqseAu5JsAA4Frp19tNLo7eA16DXA24DvJFnf2j4E\nnAlckuQk4PvAW9u2y+ldZzbQu9acODfRS3Nr0LdKf61veRu9wmrVTnzfTwJ/UVVvTvJs4Ln0EspX\ntLUYzGU+vSvJCcA6ekMfPEjvgerr+vaZfMj6aXzIWh0x65ypqm/y5KT0Ux05zf4FnLyD8UnzZtC3\nSufsP48kLwR+hd5YOlTVPwD/kMRXtLUozGE+nQN8hF6v9UeAj9PrvZ5NLGuBtQATExOOSq8FaS6v\nQVLXDfpW6bIkf9oejt6S5MtJlu3g9zwI2Ap8Jsm3k3w6yfOY/SvaU2Nck2RdknVbt27dwdCk4Zur\nfKqq+6rq8ar6KXAuvduh4EPWGjNzfA2SOm3QlxM+Q+/BzRe1z/9obTtiV+AQ4JyqejXwd/Ruiz5h\nR17R9vVsdcic5NPkm3HNm+hNC0Q793FJdk9yEL3HDG7YqYil0ZrLa5DUaYMWbkuq6jNVta19LgB2\ntDraCGysquvb+qX0Cjlf0dZiMet8SvIFei8XvCzJxvZg9f+T5DtJbgZ+Ffh3AFV1K3AJcBvwF8DJ\nVfX4EH8eadjm8hokddqghdv9SX47yS7t89vA/TvyDavqXuCeJC9rTUfSu8D4irYWi1nnU1UdX1VL\nq2q3qlpWVedV1duq6heq6pVV9ev9eVFVZ1TVS6rqZVX150P/iaThmrNrkNR1g75V+nbgj4Cz6N3C\n/F+0lwt20O8Bn2tvlN5J77XrZ+Er2loc5jqfpHFnzkjNoIXb7wOr21ADJNkH+BizfINtUlWtpzcO\nz1S+oq3FYE7zSVoEzBmpGfRW6SsnEwagqh6gN2GvpNkzn6TZMWekZtDC7VlJ9p5caf/tDNpbJ+mp\nzCdpdswZqRn0D//jwLVJvtTW3wKcMZyQpLFnPkmzY85IzaAzJ1yUZB1wRGv611V12/DCksaX+STN\njjkjPWngruaWJCaKNAfMJ2l2zBmpZ9Bn3CRJkjRiFm6SJEkdYeEmSZLUERZukiRJHWHhJkmS1BEW\nbpIkSR1h4SZJGktJzk+yJcktfW2nJ9mUZH37HNu37dQkG5J8N8lRo4la2j4LN0nSuLoAOHqa9rOq\namX7XA6Q5GDgOODl7ZhPJdll3iKVBmThJkkaS1V1DfDAgLuvAi6uqseq6i5gA3Do0IKTdpCFmyRp\nsXlXkpvbrdTJyesPAO7p22dja3uaJGuSrEuybuvWrcOOVXoKCzdJ0mJyDvASYCWwmd4E9rNSVWur\naqKqJpYsWTLX8UnbZeEmSVo0quq+qnq8qn4KnMuTt0M3AQf27bqstUkLioWbJGnRSLK0b/VNwOQb\np5cBxyXZPclBwArghvmOT3omu446AEmShiHJF4DDgf2SbAROAw5PshIo4G7gHQBVdWuSS4DbgG3A\nyVX1+CjilrbHwk2SNJaq6vhpms/bzv5nAGcMLyJp53mrVJIkqSMs3CRJkjrCwk2SJKkjLNykDphh\nzsV9klyR5I72de/WniRntzkXb05yyOgilyTNpZEVbkl2SfLtJF9r6wclub5dbL6Y5Nmtffe2vqFt\nXz6qmKURuoCnz7l4CnBlVa0ArmzrAMfQG8pgBbCG3oCjkqQxMMoet3cDt/et/yG9iX9fCjwInNTa\nTwIebO1ntf2kRWWGORdXARe25QuBN/a1X1Q91wF7TRm7SpLUUSMp3JIsA/4V8Om2HuAI4NK2y9SL\n0OTF6VLgyLa/tNjtX1Wb2/K9wP5t2TkXJWlMjWoct/8CfAB4flvfF3ioqra19f4LzRMXoaraluTh\ntv8P+0+YZA2920K8+MUvHmrw0kJTVZWkduC4tcBagImJiVkfr6f6we//wqhDmDcv/o/fGXUI0q
I0\n7z1uSd4AbKmqG+fyvE76q0XovslboO3rltbunIuSNKZGcav0NcCvJ7kbuJjeLdJP0nsOZ7IHsP9C\n88RFqG1/IXD/fAYsLVCXAavb8mrgq33tJ7S3Sw8DHu67pSpJ6rB5L9yq6tSqWlZVy4HjgKuq6reA\nq4E3t92mXoQmL05vbvt7S0eLSptz8VrgZUk2JjkJOBN4XZI7gNe2dYDLgTuBDcC5wO+OIGRJ0hAs\npLlKPwhcnOSjwLd5cj6584DPJtlA762640YUnzQyM8y5CHDkNPsWcPJwI5IkjcJIC7eq+gbwjbZ8\nJ3DoNPv8BHjLvAYmH7KWJA2d15rZc+YESZKkjrBwkyRJ6ggLN0mSpI6wcJMkSeoICzdJ0lhKcn6S\nLUlu6WvbJ8kVSe5oX/du7UlydpINSW5OcsjoIpdmZuEmSRpXFwBHT2k7BbiyqlYAV7Z1gGOAFe2z\nBjhnnmKUZsXCTZI0lqrqGnrjf/ZbBVzYli8E3tjXflH1XEdvNp+l8xOpNDgLN0nSYrJ/3xRw9wL7\nt+UDgHv69tvY2qQFxcJNkrQotVlGZj2FYpI1SdYlWbd169YhRCbNzMJNkrSY3Dd5C7R93dLaNwEH\n9u23rLU9TVWtraqJqppYsmTJUIOVprJwkyQtJpcBq9vyauCrfe0ntLdLDwMe7rulKi0YC2mSeUmS\n5kySLwCHA/sl2QicBpwJXJLkJOD7wFvb7pcDxwIbgB8DJ857wNIALNwkSWOpqo6fYdOR0+xbwMnD\njUjaed4qlSRJ6ggLN0mSpI6wcJMkSeoICzdJkqSOsHCTJEnqCAs3SZKkjrBwkyRJ6ggLN0mSpI6w\ncJMkSeoICzdJkqSOsHCTJEnqiHkv3JIcmOTqJLcluTXJu1v7PkmuSHJH+7p3a0+Ss5NsSHJzkkPm\nO2ZpIUtyd5LvJFmfZF1rmzafJEndNooet23A+6rqYOAw4OQkBwOnAFdW1QrgyrYOcAywon3WAOfM\nf8jSgverVbWyqiba+kz5JEnqsHkv3Kpqc1Xd1JYfAW4HDgBWARe23S4E3tiWVwEXVc91wF5Jls5z\n2FLXzJRPkqQOG+kzbkmWA68Grgf2r6rNbdO9wP5t+QDgnr7DNra2qedak2RdknVbt24dWszSAlTA\nXya5Mcma1jZTPj2FeSNJ3TKywi3JnsCXgfdU1Y/6t1VV0bsYDayq1lbVRFVNLFmyZA4jlRa8X66q\nQ+g9VnBykl/p37i9fDJvJKlbRlK4JdmNXtH2uar6Smu+b/IWaPu6pbVvAg7sO3xZa5MEVNWm9nUL\n8KfAocycT5KkDhvFW6UBzgNur6pP9G26DFjdllcDX+1rP6G9XXoY8HDfLSBpUUvyvCTPn1wGXg/c\nwsz5JEnqsF1H8D1fA7wN+E7xxDOMAAAFeUlEQVSS9a3tQ8CZwCVJTgK+D7y1bbscOBbYAPwYOHF+\nw5UWtP2BP+39P8SuwOer6i+SfIvp80kSvWF0gEeAx4FtVTWRZB/gi8By4G7grVX14KhilKYz74Vb\nVX0TyAybj5xm/wJOHmpQUkdV1Z3Aq6Zpv59p8knSU/xqVf2wb31yGJ0zk5zS1j84mtCk6TlzgiRJ\nPQ6jowXPwk2StBg5jI46aRTPuEmSNGq/XFWbkvwMcEWSv+nfWFWVZMZhdIC1ABMTE7MaukraWfa4\nSZIWHYfRUVdZuEmSFhWH0VGXeatUkrTYOIyOOsvCTZK0qDiMjrrMW6WSJEkdYeEmSZLUERZukiRJ\nHWHhJkmS1BEWbpIkSR1h4SZJktQRFm6SJEkdYeEmSZLUERZukiRJHWHhJkmS1BEWbpIkSR1h4SZJ\nktQRFm6SJEkdYeEmSZLUERZukiRJHWHhJkmS1BEWbpIkSR1h4SZJktQRnSnckhyd5LtJNiQ5ZdTx\nSAudOSPNnnmjha4ThVuSXYA/Bo4BDgaOT3LwaKOSFi5zRp
o980Zd0InCDTgU2FBVd1bVPwAXA6tG\nHJO0kJkz0uyZN1rwdh11AAM6ALinb30j8Iv9OyRZA6xpq48m+e48xTao/YAfzvc3zcdWz/e33Fnz\n/3s6Ldvb+rPzFcYce8acAfNmOubMgMybhZo3XmsG09lrTVcKt2dUVWuBtaOOYyZJ1lXVxKjjWOj8\nPc0v86b7/B3Nv4WcN/49DKbLv6eu3CrdBBzYt76stUmanjkjzZ55owWvK4Xbt4AVSQ5K8mzgOOCy\nEcckLWTmjDR75o0WvE7cKq2qbUneBXwd2AU4v6puHXFYs7Ugu9UXIH9Pc2BMcgb8exiEv6M5MiZ5\n49/DYDr7e0pVjToGSZIkDaArt0olSZIWPQs3SZKkjrBwmwdOofLMkpyfZEuSW0Ydi0bPnHlm5oym\nMm+e2TjkjYXbkDmFysAuAI4edRAaPXNmYBdgzqgxbwZ2AR3PGwu34XMKlQFU1TXAA6OOQwuCOTMA\nc0ZTmDcDGIe8sXAbvummUDlgRLFIXWDOSLNn3iwSFm6SJEkdYeE2fE6hIs2OOSPNnnmzSFi4DZ9T\nqEizY85Is2feLBIWbkNWVduAySlUbgcu6eAUKkOX5AvAtcDLkmxMctKoY9JomDODMWfUz7wZzDjk\njVNeSZIkdYQ9bpIkSR1h4SZJktQRFm6SJEkdYeEmSZLUERZukiRJHWHh1nFJHp3Fvqcn+ffDOr/U\nBeaMNHvmzcJh4SZJktQRFm5jKMmvJbk+ybeT/FWS/fs2vyrJtUnuSPJv+o55f5JvJbk5yYenOefS\nJNckWZ/kliT/Yl5+GGkemDPS7Jk3o2HhNp6+CRxWVa8GLgY+0LftlcARwC8B/zHJi5K8HlgBHAqs\nBP5pkl+Zcs7fBL5eVSuBVwHrh/wzSPPJnJFmz7wZgV1HHYCGYhnwxSRLgWcDd/Vt+2pV/T3w90mu\nppdAvwy8Hvh222dPesl1Td9x3wLOT7Ib8N+rymTSODFnpNkzb0bAHrfx9EfAf62qXwDeAezRt23q\nHGcFBPiDqlrZPi+tqvOeslPVNcCvAJuAC5KcMLzwpXlnzkizZ96MgIXbeHohvT96gNVTtq1KskeS\nfYHD6f1383Xg7Un2BEhyQJKf6T8oyc8C91XVucCngUOGGL8038wZafbMmxHwVmn3PTfJxr71TwCn\nA19K8iBwFXBQ3/abgauB/YCPVNXfAn+b5OeBa5MAPAr8NrCl77jDgfcn+ce23f+C1FXmjDR75s0C\nkaqpvZmSJElaiLxVKkmS1BEWbpIkSR1h4SZJktQRFm6SJEkdYeEmSZLUERZukiRJHWHhJkmS1BH/\nP71ova2JZQVRAAAAAElFTkSuQmCC\n",
776 | "text/plain": [
777 | ""
778 | ]
779 | },
780 | "metadata": {
781 | "tags": []
782 | },
783 | "output_type": "display_data"
784 | }
785 | ],
786 | "source": [
787 | "fig = plt.figure(figsize=(10,4))\n",
788 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
789 | "\n",
790 | "ax = fig.add_subplot(1,3,1)\n",
791 | "ax = sns.barplot(x=train.target.unique(),y=train.target.value_counts())\n",
792 | "ax.set(xlabel='Labels', ylabel=\"counts\", title=\"train\")\n",
793 | "\n",
794 | "ax1 = fig.add_subplot(1,3,2)\n",
795 | "ax1 = sns.barplot(x=val.target.unique(),y=val.target.value_counts())\n",
796 | "ax1.set(xlabel='Labels', ylabel=\"counts\", title=\"validation\")\n",
797 | "\n",
798 | "ax2 = fig.add_subplot(1,3,3)\n",
799 | "ax2 = sns.barplot(x=test.target.unique(),y=test.target.value_counts())\n",
800 | "ax2.set(xlabel='Labels', ylabel=\"counts\", title=\"test\")"
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": 30,
806 | "metadata": {
807 | "colab": {
808 | "base_uri": "https://localhost:8080/",
809 | "height": 54
810 | },
811 | "colab_type": "code",
812 | "id": "uswTGPRH2v6J",
813 | "outputId": "c0867882-56cf-4cc6-f566-7b06a9e0fe66"
814 | },
815 | "outputs": [
816 | {
817 | "name": "stdout",
818 | "output_type": "stream",
819 | "text": [
820 | "CPU times: user 1.75 s, sys: 52 ms, total: 1.8 s\n",
821 | "Wall time: 1.76 s\n"
822 | ]
823 | }
824 | ],
825 | "source": [
826 | "%%time\n",
827 | "train_data, val_data, test_data = TabularDataset.splits(path='./', format='csv', train='train.csv', validation='val.csv', test='test.csv', fields=data_fields, skip_header=True)"
828 | ]
829 | },
830 | {
831 | "cell_type": "code",
832 | "execution_count": 31,
833 | "metadata": {
834 | "colab": {
835 | "base_uri": "https://localhost:8080/",
836 | "height": 35
837 | },
838 | "colab_type": "code",
839 | "id": "eSgxCIC23keA",
840 | "outputId": "e9018a7f-76c2-4d40-b148-5d05c4e3865f"
841 | },
842 | "outputs": [
843 | {
844 | "data": {
845 | "text/plain": [
846 | "(1612, 404, 505)"
847 | ]
848 | },
849 | "execution_count": 31,
850 | "metadata": {
851 | "tags": []
852 | },
853 | "output_type": "execute_result"
854 | }
855 | ],
856 | "source": [
857 | "len(train_data), len(val_data), len(test_data)"
858 | ]
859 | },
860 | {
861 | "cell_type": "markdown",
862 | "metadata": {
863 | "colab_type": "text",
864 | "id": "z54cb2_B4b79"
865 | },
866 | "source": [
867 | "## **4. Load pretrained embeddings and build vocab**"
868 | ]
869 | },
870 | {
871 | "cell_type": "code",
872 | "execution_count": 32,
873 | "metadata": {
874 | "colab": {
875 | "base_uri": "https://localhost:8080/",
876 | "height": 54
877 | },
878 | "colab_type": "code",
879 | "id": "n-tsm29k_WTq",
880 | "outputId": "58cded06-7fee-4c04-cb03-92531d227e4e"
881 | },
882 | "outputs": [
883 | {
884 | "name": "stdout",
885 | "output_type": "stream",
886 | "text": [
887 | "glove.6B.100d.txt glove.twitter.27B.100d.txt\t uncased_L-12_H-768_A-12.zip\n",
888 | "glove.6B.200d.txt glove.twitter.27B.100d.txt.pt\n"
889 | ]
890 | }
891 | ],
892 | "source": [
893 | "!ls '/content/gdrive/My Drive/embedding'"
894 | ]
895 | },
896 | {
897 | "cell_type": "code",
898 | "execution_count": 33,
899 | "metadata": {
900 | "colab": {
901 | "base_uri": "https://localhost:8080/",
902 | "height": 54
903 | },
904 | "colab_type": "code",
905 | "id": "9GfaJCDJEBK7",
906 | "outputId": "9a37baf2-ae28-448b-a672-ac30d9241f6c"
907 | },
908 | "outputs": [
909 | {
910 | "name": "stdout",
911 | "output_type": "stream",
912 | "text": [
913 | "CPU times: user 532 ms, sys: 644 ms, total: 1.18 s\n",
914 | "Wall time: 1.45 s\n"
915 | ]
916 | }
917 | ],
918 | "source": [
919 | "%%time\n",
920 | "vec = torchtext.vocab.Vectors('glove.twitter.27B.100d.txt', '/content/gdrive/My Drive/embedding')"
921 | ]
922 | },
923 | {
924 | "cell_type": "code",
925 | "execution_count": 34,
926 | "metadata": {
927 | "colab": {
928 | "base_uri": "https://localhost:8080/",
929 | "height": 35
930 | },
931 | "colab_type": "code",
932 | "id": "w8HuaHBEod81",
933 | "outputId": "c3fbff7a-33e2-4562-af45-3711f353e421"
934 | },
935 | "outputs": [
936 | {
937 | "data": {
938 | "text/plain": [
939 | "1612"
940 | ]
941 | },
942 | "execution_count": 34,
943 | "metadata": {
944 | "tags": []
945 | },
946 | "output_type": "execute_result"
947 | }
948 | ],
949 | "source": [
950 | "len(train_data)"
951 | ]
952 | },
953 | {
954 | "cell_type": "code",
955 | "execution_count": 35,
956 | "metadata": {
957 | "colab": {
958 | "base_uri": "https://localhost:8080/",
959 | "height": 54
960 | },
961 | "colab_type": "code",
962 | "id": "jGAbh_g2QxAt",
963 | "outputId": "b14ed9d5-7bb8-41a3-b803-e3aa2084116d"
964 | },
965 | "outputs": [
966 | {
967 | "name": "stdout",
968 | "output_type": "stream",
969 | "text": [
970 | "CPU times: user 98.7 ms, sys: 2.02 ms, total: 101 ms\n",
971 | "Wall time: 106 ms\n"
972 | ]
973 | }
974 | ],
975 | "source": [
976 | "%%time\n",
977 | "MAX_VOCAB_SIZE = 100_000\n",
978 | "\n",
979 | "TEXT.build_vocab(train_data, \n",
980 | " max_size = MAX_VOCAB_SIZE,\n",
981 | " vectors=vec)\n",
982 | "\n",
983 | "TARGET.build_vocab(train_data)"
984 | ]
985 | },
986 | {
987 | "cell_type": "code",
988 | "execution_count": 36,
989 | "metadata": {
990 | "colab": {
991 | "base_uri": "https://localhost:8080/",
992 | "height": 35
993 | },
994 | "colab_type": "code",
995 | "id": "IaAJsJviSrr2",
996 | "outputId": "2bd84a30-0f59-4364-848e-83243bfd0110"
997 | },
998 | "outputs": [
999 | {
1000 | "data": {
1001 | "text/plain": [
1002 | "torch.Size([4577, 100])"
1003 | ]
1004 | },
1005 | "execution_count": 36,
1006 | "metadata": {
1007 | "tags": []
1008 | },
1009 | "output_type": "execute_result"
1010 | }
1011 | ],
1012 | "source": [
1013 | "TEXT.vocab.vectors.shape"
1014 | ]
1015 | },
1016 | {
1017 | "cell_type": "code",
1018 | "execution_count": 37,
1019 | "metadata": {
1020 | "colab": {
1021 | "base_uri": "https://localhost:8080/",
1022 | "height": 35
1023 | },
1024 | "colab_type": "code",
1025 | "id": "g3vwetTcG9mx",
1026 | "outputId": "8b888140-294f-4b80-d79e-899224ac777a"
1027 | },
1028 | "outputs": [
1029 | {
1030 | "data": {
1031 | "text/plain": [
1032 | ""
1033 | ]
1034 | },
1035 | "execution_count": 37,
1036 | "metadata": {
1037 | "tags": []
1038 | },
1039 | "output_type": "execute_result"
1040 | }
1041 | ],
1042 | "source": [
1043 | "train_data"
1044 | ]
1045 | },
1046 | {
1047 | "cell_type": "markdown",
1048 | "metadata": {
1049 | "colab_type": "text",
1050 | "id": "PrF-Wlbi5czS"
1051 | },
1052 | "source": [
1053 | "## **5. Load data in batches**"
1054 | ]
1055 | },
1056 | {
1057 | "cell_type": "markdown",
1058 | "metadata": {
1059 | "colab_type": "text",
1060 | "id": "uwBCojaA-bk7"
1061 | },
1062 | "source": [
1063 | "We will use the `BucketIterator` to load the data in batches. It sorts the data by text length and groups texts of similar length into the same batch, which reduces the amount of padding required. Each batch is padded to the length of the longest text in that particular batch."
1064 | ]
1065 | },
1066 | {
1067 | "cell_type": "code",
1068 | "execution_count": 0,
1069 | "metadata": {
1070 | "colab": {},
1071 | "colab_type": "code",
1072 | "id": "25yseUiWKwfC"
1073 | },
1074 | "outputs": [],
1075 | "source": [
1076 | "\n",
1077 | "train_loader, val_loader, test_loader = BucketIterator.splits(datasets=(train_data, val_data, test_data), \n",
1078 | " batch_sizes=(3,3,3), \n",
1079 | " sort_key=lambda x: len(x.tweet), \n",
1080 | " device=None, \n",
1081 | " sort_within_batch=True, \n",
1082 | " repeat=False)"
1083 | ]
1084 | },
1085 | {
1086 | "cell_type": "code",
1087 | "execution_count": 39,
1088 | "metadata": {
1089 | "colab": {
1090 | "base_uri": "https://localhost:8080/",
1091 | "height": 35
1092 | },
1093 | "colab_type": "code",
1094 | "id": "B-xV8SAC7dZ7",
1095 | "outputId": "6b23d716-da32-4398-be1a-be5d3a6d1715"
1096 | },
1097 | "outputs": [
1098 | {
1099 | "data": {
1100 | "text/plain": [
1101 | "(538, 135, 169)"
1102 | ]
1103 | },
1104 | "execution_count": 39,
1105 | "metadata": {
1106 | "tags": []
1107 | },
1108 | "output_type": "execute_result"
1109 | }
1110 | ],
1111 | "source": [
1112 | "len(train_loader), len(val_loader), len(test_loader)"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "code",
1117 | "execution_count": 40,
1118 | "metadata": {
1119 | "colab": {
1120 | "base_uri": "https://localhost:8080/",
1121 | "height": 35
1122 | },
1123 | "colab_type": "code",
1124 | "id": "d_p56mLY71w7",
1125 | "outputId": "66aa47a0-9693-401c-cf34-e97b695b373e"
1126 | },
1127 | "outputs": [
1128 | {
1129 | "data": {
1130 | "text/plain": [
1131 | "torchtext.data.batch.Batch"
1132 | ]
1133 | },
1134 | "execution_count": 40,
1135 | "metadata": {
1136 | "tags": []
1137 | },
1138 | "output_type": "execute_result"
1139 | }
1140 | ],
1141 | "source": [
1142 | "batch = next(iter(train_loader))\n",
1143 | "type(batch)"
1144 | ]
1145 | },
1146 | {
1147 | "cell_type": "code",
1148 | "execution_count": 41,
1149 | "metadata": {
1150 | "colab": {
1151 | "base_uri": "https://localhost:8080/",
1152 | "height": 35
1153 | },
1154 | "colab_type": "code",
1155 | "id": "dMp72yoo7-70",
1156 | "outputId": "3dea4de7-b939-42ce-b948-30e3b4424a03"
1157 | },
1158 | "outputs": [
1159 | {
1160 | "data": {
1161 | "text/plain": [
1162 | "tensor([0, 0, 0])"
1163 | ]
1164 | },
1165 | "execution_count": 41,
1166 | "metadata": {
1167 | "tags": []
1168 | },
1169 | "output_type": "execute_result"
1170 | }
1171 | ],
1172 | "source": [
1173 | "batch.target"
1174 | ]
1175 | },
1176 | {
1177 | "cell_type": "code",
1178 | "execution_count": 42,
1179 | "metadata": {
1180 | "colab": {
1181 | "base_uri": "https://localhost:8080/",
1182 | "height": 126
1183 | },
1184 | "colab_type": "code",
1185 | "id": "0W5vLVNy81Lj",
1186 | "outputId": "2bdede4f-5d0b-4db3-8d8b-e00ba13f5753"
1187 | },
1188 | "outputs": [
1189 | {
1190 | "data": {
1191 | "text/plain": [
1192 | "(tensor([[ 28, 1512, 113],\n",
1193 | " [ 143, 118, 682],\n",
1194 | " [ 392, 670, 838],\n",
1195 | " [1771, 5, 13],\n",
1196 | " [ 14, 1414, 57],\n",
1197 | " [1137, 1258, 160]]), tensor([6, 6, 6]))"
1198 | ]
1199 | },
1200 | "execution_count": 42,
1201 | "metadata": {
1202 | "tags": []
1203 | },
1204 | "output_type": "execute_result"
1205 | }
1206 | ],
1207 | "source": [
1208 | "batch.tweet"
1209 | ]
1210 | },
1211 | {
1212 | "cell_type": "code",
1213 | "execution_count": 43,
1214 | "metadata": {
1215 | "colab": {
1216 | "base_uri": "https://localhost:8080/",
1217 | "height": 35
1218 | },
1219 | "colab_type": "code",
1220 | "id": "d3KJ6U2R9Lbz",
1221 | "outputId": "929d6c62-62d1-43af-80cc-8aeb17bcc554"
1222 | },
1223 | "outputs": [
1224 | {
1225 | "data": {
1226 | "text/plain": [
1227 | "''"
1228 | ]
1229 | },
1230 | "execution_count": 43,
1231 | "metadata": {
1232 | "tags": []
1233 | },
1234 | "output_type": "execute_result"
1235 | }
1236 | ],
1237 | "source": [
1238 | "TEXT.vocab.itos[1]"
1239 | ]
1240 | },
1241 | {
1242 | "cell_type": "code",
1243 | "execution_count": 0,
1244 | "metadata": {
1245 | "colab": {},
1246 | "colab_type": "code",
1247 | "id": "NfABZ4WR9nZT"
1248 | },
1249 | "outputs": [],
1250 | "source": [
1251 | "def idxtosent(batch, idx):\n",
1252 | " return ' '.join([TEXT.vocab.itos[i] for i in batch.tweet[0][:,idx].cpu().data.numpy()])"
1253 | ]
1254 | },
1255 | {
1256 | "cell_type": "code",
1257 | "execution_count": 45,
1258 | "metadata": {
1259 | "colab": {
1260 | "base_uri": "https://localhost:8080/",
1261 | "height": 35
1262 | },
1263 | "colab_type": "code",
1264 | "id": "fj5ZIbUC9ydr",
1265 | "outputId": "a03b209e-4a03-4421-9d36-9330120a0b7c"
1266 | },
1267 | "outputs": [
1268 | {
1269 | "data": {
1270 | "text/plain": [
1271 | "'do nt worry abt it bb'"
1272 | ]
1273 | },
1274 | "execution_count": 45,
1275 | "metadata": {
1276 | "tags": []
1277 | },
1278 | "output_type": "execute_result"
1279 | }
1280 | ],
1281 | "source": [
1282 | "idxtosent(batch,0)"
1283 | ]
1284 | },
1285 | {
1286 | "cell_type": "code",
1287 | "execution_count": 46,
1288 | "metadata": {
1289 | "colab": {
1290 | "base_uri": "https://localhost:8080/",
1291 | "height": 235
1292 | },
1293 | "colab_type": "code",
1294 | "id": "kzTDu3aR90zC",
1295 | "outputId": "732bd3a0-29be-436e-d4dc-ef6aa654f56b"
1296 | },
1297 | "outputs": [
1298 | {
1299 | "data": {
1300 | "text/plain": [
1301 | "{'batch_size': 3,\n",
1302 | " 'dataset': ,\n",
1303 | " 'fields': dict_keys([None, 'tweet', 'target']),\n",
1304 | " 'input_fields': ['tweet', 'target'],\n",
1305 | " 'target': tensor([0, 0, 0]),\n",
1306 | " 'target_fields': [],\n",
1307 | " 'tweet': (tensor([[ 28, 1512, 113],\n",
1308 | " [ 143, 118, 682],\n",
1309 | " [ 392, 670, 838],\n",
1310 | " [1771, 5, 13],\n",
1311 | " [ 14, 1414, 57],\n",
1312 | " [1137, 1258, 160]]), tensor([6, 6, 6]))}"
1313 | ]
1314 | },
1315 | "execution_count": 46,
1316 | "metadata": {
1317 | "tags": []
1318 | },
1319 | "output_type": "execute_result"
1320 | }
1321 | ],
1322 | "source": [
1323 | "batch.__dict__"
1324 | ]
1325 | },
1326 | {
1327 | "cell_type": "code",
1328 | "execution_count": 0,
1329 | "metadata": {
1330 | "colab": {},
1331 | "colab_type": "code",
1332 | "id": "ClH37CUB-GXT"
1333 | },
1334 | "outputs": [],
1335 | "source": [
1336 | "class BatchGenerator:\n",
1337 | " def __init__(self, dl, x_field, y_field):\n",
1338 | " self.dl, self.x_field, self.y_field = dl, x_field, y_field\n",
1339 | " \n",
1340 | " def __len__(self):\n",
1341 | " return len(self.dl)\n",
1342 | " \n",
1343 | " def __iter__(self):\n",
1344 | " for batch in self.dl:\n",
1345 | " X = getattr(batch, self.x_field)\n",
1346 | " y = getattr(batch, self.y_field)\n",
1347 | " yield (X,y)"
1348 | ]
1349 | },
1350 | {
1351 | "cell_type": "code",
1352 | "execution_count": 48,
1353 | "metadata": {
1354 | "colab": {
1355 | "base_uri": "https://localhost:8080/",
1356 | "height": 417
1357 | },
1358 | "colab_type": "code",
1359 | "id": "bYWXjqv--UEF",
1360 | "outputId": "ce4c846d-d688-4d8b-c2f3-5a135c62cde9"
1361 | },
1362 | "outputs": [
1363 | {
1364 | "data": {
1365 | "text/plain": [
1366 | "((tensor([[ 227, 3, 665],\n",
1367 | " [2795, 64, 1387],\n",
1368 | " [4003, 3, 1181],\n",
1369 | " [ 94, 1038, 1760],\n",
1370 | " [3846, 42, 1095],\n",
1371 | " [ 4, 7, 1179],\n",
1372 | " [3867, 554, 30],\n",
1373 | " [ 430, 3337, 230],\n",
1374 | " [ 4, 2, 1316],\n",
1375 | " [1140, 37, 2],\n",
1376 | " [ 6, 3, 8],\n",
1377 | " [1752, 35, 401],\n",
1378 | " [ 16, 28, 1433],\n",
1379 | " [ 267, 15, 29],\n",
1380 | " [2268, 64, 20],\n",
1381 | " [3689, 3, 404],\n",
1382 | " [ 169, 23, 665],\n",
1383 | " [ 385, 3854, 2],\n",
1384 | " [ 9, 4, 168],\n",
1385 | " [1368, 222, 49],\n",
1386 | " [ 727, 964, 1305],\n",
1387 | " [ 720, 2, 2]]), tensor([22, 22, 22])), tensor([0, 1, 0]))"
1388 | ]
1389 | },
1390 | "execution_count": 48,
1391 | "metadata": {
1392 | "tags": []
1393 | },
1394 | "output_type": "execute_result"
1395 | }
1396 | ],
1397 | "source": [
1398 | "train_batch_it = BatchGenerator(train_loader, 'tweet', 'target')\n",
1399 | "next(iter(train_batch_it))"
1400 | ]
1401 | },
1402 | {
1403 | "cell_type": "markdown",
1404 | "metadata": {
1405 | "colab_type": "text",
1406 | "id": "Tj_aJPx__mv5"
1407 | },
1408 | "source": [
1409 | "## **6. Models and Training**\n",
1410 | "\n",
1411 | "For the model, we decided to follow the example in https://medium.com/@sonicboom8/sentiment-analysis-torchtext-55fb57b1fab8, but made small modifications, such as adding some dropout, to reduce overfitting.\n",
1412 | "\n",
1413 | "The model uses a pre-trained embedding layer from GloVe, a bidirectional GRU, and a concat-pooling step in which we perform average pooling and max pooling and then concatenate the results.\n",
1414 | "\n",
1415 | "The final result was acceptable, with around 80% test accuracy. It was clear that the model was overfitting, but we had run out of time to make further adjustments. This has been a very educational experience, as it was our first time implementing NLP using PyTorch. We plan to experiment with and improve the model using a variety of methods in the future."
1416 | ]
1417 | },
1418 | {
1419 | "cell_type": "code",
1420 | "execution_count": 0,
1421 | "metadata": {
1422 | "colab": {},
1423 | "colab_type": "code",
1424 | "id": "eMfwXhUUGBcD"
1425 | },
1426 | "outputs": [],
1427 | "source": [
1428 | "vocab_size = len(TEXT.vocab)\n",
1429 | "embedding_dim = 100\n",
1430 | "n_hidden = 64\n",
1431 | "n_out = 2"
1432 | ]
1433 | },
1434 | {
1435 | "cell_type": "code",
1436 | "execution_count": 0,
1437 | "metadata": {
1438 | "colab": {},
1439 | "colab_type": "code",
1440 | "id": "-puOJHXXkcL0"
1441 | },
1442 | "outputs": [],
1443 | "source": [
1444 | "# Define a PyTorch module named ConcatPoolingGRUAdaptive\n",
1445 | "class ConcatPoolingGRUAdaptive(nn.Module):\n",
1446 | " \n",
1447 | " # Constructor with the following arguments\n",
1448 | " def __init__(self, vocab_size, embedding_dim, n_hidden, n_out, pretrained_vec, dropout, bidirectional=True):\n",
1449 | " \n",
1450 | " # Call the constructor of the nn.Module class\n",
1451 | " super().__init__()\n",
1452 | " \n",
1453 | " # Initialize instance variables\n",
1454 | " self.vocab_size = vocab_size\n",
1455 | " self.embedding_dim = embedding_dim\n",
1456 | " self.n_hidden = n_hidden\n",
1457 | " self.n_out = n_out\n",
1458 | " self.bidirectional = bidirectional\n",
1459 | " \n",
1460 | " # Create an embedding layer with a size of vocab_size x embedding_dim\n",
1461 | " self.emb = nn.Embedding(self.vocab_size, self.embedding_dim)\n",
1462 | " \n",
1463 | " # Load pre-trained word embeddings into the embedding layer\n",
1464 | " self.emb.weight.data.copy_(pretrained_vec)\n",
1465 | " \n",
1466 | " # Freeze the embedding layer during training\n",
1467 | " self.emb.weight.requires_grad = False\n",
1468 | " \n",
1469 | " # Create a GRU layer with input size of embedding_dim and hidden size of n_hidden\n",
1470 | " self.gru = nn.GRU(self.embedding_dim, self.n_hidden, bidirectional=bidirectional)\n",
1471 | " \n",
1472 | " # Create a fully-connected linear layer to map GRU output to class scores\n",
1473 | " if bidirectional:\n",
1474 | " self.fc = nn.Linear(self.n_hidden*2*2, self.n_out)\n",
1475 | " else:\n",
1476 | " self.fc = nn.Linear(self.n_hidden*2, self.n_out)\n",
1477 | " \n",
1478 | " # Create a dropout layer with dropout probability of dropout\n",
1479 | " self.dropout = nn.Dropout(dropout)\n",
1480 | " \n",
1481 | " # Define the forward method for the module\n",
1482 | " def forward(self, seq, lengths):\n",
1483 | " bs = seq.size(1)\n",
1484 | " \n",
1485 | " # Initialize the hidden state of the GRU\n",
1486 | " self.h = self.init_hidden(bs)\n",
1487 | " \n",
1488 | " # Transpose input sequence to batch-first format\n",
1489 | " seq = seq.transpose(0,1)\n",
1490 | " \n",
1491 | " # Pass input sequence through the embedding layer\n",
1492 | " embs = self.emb(seq)\n",
1493 | " \n",
1494 | " # Transpose embeddings back to sequence-first format\n",
1495 | " embs = embs.transpose(0,1)\n",
1496 | " \n",
1497 | " # Pack the sequence of embeddings using lengths to avoid computing on padded elements\n",
1498 | " embs = pack_padded_sequence(embs, lengths)\n",
1499 | " \n",
1500 | " # Pass packed embeddings through the GRU\n",
1501 | " gru_out, self.h = self.gru(embs, self.h)\n",
1502 | " \n",
1503 | " # Unpack the packed GRU output back into a padded tensor\n",
1504 | " gru_out, lengths = pad_packed_sequence(gru_out) \n",
1505 | " \n",
1506 | " # Apply adaptive max and average pooling to the GRU output along the time dimension\n",
1507 | " avg_pool = F.adaptive_avg_pool1d(gru_out.permute(1,2,0),1).view(bs,-1)\n",
1508 | " max_pool = F.adaptive_max_pool1d(gru_out.permute(1,2,0),1).view(bs,-1) \n",
1509 | " \n",
1510 | " # Concatenate the average and max pooled outputs and apply dropout\n",
1511 | " cat = self.dropout(torch.cat([avg_pool,max_pool],dim=1))\n",
1512 | " \n",
1513 | " # Map the concatenated output to class scores\n",
1514 | " outp = self.fc(cat)\n",
1515 | " \n",
1516 | " # Apply log softmax to class scores and return\n",
1517 | " return F.log_softmax(outp)\n",
1518 | " \n",
1519 | " # Helper method to initialize the hidden state of the GRU\n",
1520 | " def init_hidden(self, batch_size): \n",
1521 | " if self.bidirectional:\n",
1522 | " return torch.zeros((2,batch_size,self.n_hidden)).to(device)\n",
1523 | " else:\n",
1524 | " return torch.zeros((1,batch_size,self.n_hidden)).cuda().to(device)"
1525 | ]
1526 | },
1527 | {
1528 | "cell_type": "code",
1529 | "execution_count": 0,
1530 | "metadata": {
1531 | "colab": {},
1532 | "colab_type": "code",
1533 | "id": "WlycNBIGGtO5"
1534 | },
1535 | "outputs": [],
1536 | "source": [
1537 | "def train(model, iterator, optimizer, criterion, num_batch):\n",
1538 | " y_true_train = list()\n",
1539 | " y_pred_train = list()\n",
1540 | " total_loss_train = 0 \n",
1541 | " \n",
1542 | " #t = tqdm_notebook(iterator, leave=False, total=num_batch)\n",
1543 | " \n",
1544 | " for (X,lengths),y in iterator:\n",
1545 | "\n",
1546 | " #t.set_description(f'Epoch {epoch}')\n",
1547 | " lengths = lengths.cpu().numpy()\n",
1548 | "\n",
1549 | " opt.zero_grad()\n",
1550 | " pred = model(X, lengths)\n",
1551 | " loss = criterion(pred, y)\n",
1552 | " loss.backward()\n",
1553 | " opt.step()\n",
1554 | "\n",
1555 | " #t.set_postfix(loss=loss.item())\n",
1556 | " pred_idx = torch.max(pred, dim=1)[1]\n",
1557 | "\n",
1558 | " y_true_train += list(y.cpu().data.numpy())\n",
1559 | " y_pred_train += list(pred_idx.cpu().data.numpy())\n",
1560 | " total_loss_train += loss.item()\n",
1561 | " \n",
1562 | " train_acc = accuracy_score(y_true_train, y_pred_train)\n",
1563 | " train_loss = total_loss_train/num_batch\n",
1564 | " return train_loss, train_acc"
1565 | ]
1566 | },
1567 | {
1568 | "cell_type": "code",
1569 | "execution_count": 0,
1570 | "metadata": {
1571 | "colab": {},
1572 | "colab_type": "code",
1573 | "id": "j74LaFc6HlRU"
1574 | },
1575 | "outputs": [],
1576 | "source": [
1577 | "def evaluate(model, iterator, criterion, num_batch):\n",
1578 | " y_true_val = list()\n",
1579 | " y_pred_val = list()\n",
1580 | " total_loss_val = 0\n",
1581 | " for (X,lengths),y in iterator: #tqdm_notebook(iterator, leave=False): \n",
1582 | " \n",
1583 | " pred = model(X, lengths.cpu().numpy())\n",
1584 | " loss = criterion(pred, y)\n",
1585 | " pred_idx = torch.max(pred, 1)[1]\n",
1586 | " y_true_val += list(y.cpu().data.numpy())\n",
1587 | " y_pred_val += list(pred_idx.cpu().data.numpy())\n",
1588 | " total_loss_val += loss.item()\n",
1589 | " valacc = accuracy_score(y_true_val, y_pred_val)\n",
1590 | " valloss = total_loss_val/num_batch\n",
1591 | " return valloss, valacc\n",
1592 | " "
1593 | ]
1594 | },
1595 | {
1596 | "cell_type": "code",
1597 | "execution_count": 0,
1598 | "metadata": {
1599 | "colab": {},
1600 | "colab_type": "code",
1601 | "id": "oq402kPY_086"
1602 | },
1603 | "outputs": [],
1604 | "source": [
1605 | "train_loader, val_loader, test_loader = BucketIterator.splits(datasets=(train_data, val_data, test_data), \n",
1606 | " batch_sizes=(32,32,32), \n",
1607 | " sort_key=lambda x: len(x.tweet), \n",
1608 | " device=device, \n",
1609 | " sort_within_batch=True, \n",
1610 | " repeat=False)"
1611 | ]
1612 | },
1613 | {
1614 | "cell_type": "code",
1615 | "execution_count": 0,
1616 | "metadata": {
1617 | "colab": {},
1618 | "colab_type": "code",
1619 | "id": "qOoo-vYpFLOi"
1620 | },
1621 | "outputs": [],
1622 | "source": [
1623 | "train_batch_it = BatchGenerator(train_loader, 'tweet', 'target')\n",
1624 | "val_batch_it = BatchGenerator(val_loader, 'tweet', 'target')\n",
1625 | "test_batch_it = BatchGenerator(test_loader, 'tweet', 'target')"
1626 | ]
1627 | },
1628 | {
1629 | "cell_type": "code",
1630 | "execution_count": 0,
1631 | "metadata": {
1632 | "colab": {},
1633 | "colab_type": "code",
1634 | "id": "7c8x62WYvD4o"
1635 | },
1636 | "outputs": [],
1637 | "source": [
1638 | "m = ConcatPoolingGRUAdaptive(vocab_size, embedding_dim, n_hidden, n_out, train_data.fields['tweet'].vocab.vectors, 0.5).to(device)\n",
1639 | "opt = optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), 1e-3)\n"
1640 | ]
1641 | },
1642 | {
1643 | "cell_type": "code",
1644 | "execution_count": 0,
1645 | "metadata": {
1646 | "colab": {},
1647 | "colab_type": "code",
1648 | "id": "APwJlE7QPSO3"
1649 | },
1650 | "outputs": [],
1651 | "source": [
1652 | "loss_fn=F.nll_loss\n",
1653 | "epochs=10"
1654 | ]
1655 | },
1656 | {
1657 | "cell_type": "code",
1658 | "execution_count": 0,
1659 | "metadata": {
1660 | "colab": {},
1661 | "colab_type": "code",
1662 | "id": "EGtekmstQVWo"
1663 | },
1664 | "outputs": [],
1665 | "source": [
1666 | "import time\n",
1667 | "\n",
1668 | "def epoch_time(start_time, end_time):\n",
1669 | " elapsed_time = end_time - start_time\n",
1670 | " elapsed_mins = int(elapsed_time / 60)\n",
1671 | " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
1672 | " return elapsed_mins, elapsed_secs"
1673 | ]
1674 | },
1675 | {
1676 | "cell_type": "code",
1677 | "execution_count": 58,
1678 | "metadata": {
1679 | "colab": {
1680 | "base_uri": "https://localhost:8080/",
1681 | "height": 217
1682 | },
1683 | "colab_type": "code",
1684 | "id": "5KK44r2qHDZV",
1685 | "outputId": "04cffe04-f4d3-4c87-b296-fa2967d91e3d"
1686 | },
1687 | "outputs": [
1688 | {
1689 | "name": "stderr",
1690 | "output_type": "stream",
1691 | "text": [
1692 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:36: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
1693 | ]
1694 | },
1695 | {
1696 | "name": "stdout",
1697 | "output_type": "stream",
1698 | "text": [
1699 | "Epoch 0: train_loss: 0.5438 train_acc: 0.7599 | val_loss: 0.5334 val_acc: 0.7574\n",
1700 | "Epoch 1: train_loss: 0.4733 train_acc: 0.7829 | val_loss: 0.4746 val_acc: 0.7797\n",
1701 | "Epoch 2: train_loss: 0.4218 train_acc: 0.8015 | val_loss: 0.4234 val_acc: 0.8020\n",
1702 | "Epoch 3: train_loss: 0.3920 train_acc: 0.8139 | val_loss: 0.4133 val_acc: 0.8045\n",
1703 | "Epoch 4: train_loss: 0.3640 train_acc: 0.8319 | val_loss: 0.4217 val_acc: 0.7921\n",
1704 | "Epoch 5: train_loss: 0.3424 train_acc: 0.8406 | val_loss: 0.3862 val_acc: 0.8342\n",
1705 | "Epoch 6: train_loss: 0.3226 train_acc: 0.8474 | val_loss: 0.4097 val_acc: 0.8045\n",
1706 | "Epoch 7: train_loss: 0.3181 train_acc: 0.8524 | val_loss: 0.4272 val_acc: 0.7970\n",
1707 | "Epoch 8: train_loss: 0.2910 train_acc: 0.8716 | val_loss: 0.3975 val_acc: 0.8094\n",
1708 | "Epoch 9: train_loss: 0.2587 train_acc: 0.8877 | val_loss: 0.4128 val_acc: 0.8069\n"
1709 | ]
1710 | }
1711 | ],
1712 | "source": [
1713 | "best_valid_loss = float('inf')\n",
1714 | "\n",
1715 | "epochs=10\n",
1716 | "\n",
1717 | "for epoch in range(epochs): \n",
1718 | "\n",
1719 | " start_time = time.time()\n",
1720 | " \n",
1721 | " train_loss, train_acc = train(m, iter(train_batch_it), opt, loss_fn, len(train_batch_it))\n",
1722 | " valid_loss, valid_acc = evaluate(m, iter(val_batch_it), loss_fn, len(val_batch_it))\n",
1723 | "\n",
1724 | " end_time = time.time()\n",
1725 | "\n",
1726 | " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
1727 | " \n",
1728 | " if valid_loss < best_valid_loss:\n",
1729 | " best_valid_loss = valid_loss\n",
1730 | " torch.save(m.state_dict(), 'tut4-model.pt')\n",
1731 | " \n",
1732 | " \n",
1733 | " print(f'Epoch {epoch}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {valid_loss:.4f} val_acc: {valid_acc:.4f}')\n"
1734 | ]
1735 | },
1736 | {
1737 | "cell_type": "code",
1738 | "execution_count": 59,
1739 | "metadata": {
1740 | "colab": {
1741 | "base_uri": "https://localhost:8080/",
1742 | "height": 54
1743 | },
1744 | "colab_type": "code",
1745 | "id": "m0FwUXj1KfaN",
1746 | "outputId": "916acc69-3a29-41c7-ec67-a6f894a40b85"
1747 | },
1748 | "outputs": [
1749 | {
1750 | "name": "stdout",
1751 | "output_type": "stream",
1752 | "text": [
1753 | "Test Loss: 0.409 | Test Acc: 80.59%\n"
1754 | ]
1755 | },
1756 | {
1757 | "name": "stderr",
1758 | "output_type": "stream",
1759 | "text": [
1760 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:36: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
1761 | ]
1762 | }
1763 | ],
1764 | "source": [
1765 | "test_loss, test_acc = evaluate(m, iter(test_batch_it), loss_fn, len(test_batch_it))\n",
1766 | "\n",
1767 | "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
1768 | ]
1769 | },
1770 | {
1771 | "cell_type": "code",
1772 | "execution_count": 0,
1773 | "metadata": {
1774 | "colab": {},
1775 | "colab_type": "code",
1776 | "id": "qDsz_cIGSuKO"
1777 | },
1778 | "outputs": [],
1779 | "source": []
1780 | }
1781 | ],
1782 | "metadata": {
1783 | "accelerator": "GPU",
1784 | "colab": {
1785 | "name": "depression_detector.ipynb",
1786 | "provenance": [],
1787 | "version": "0.3.2"
1788 | },
1789 | "hide_input": false,
1790 | "kernelspec": {
1791 | "display_name": "Python 3",
1792 | "language": "python",
1793 | "name": "python3"
1794 | },
1795 | "language_info": {
1796 | "codemirror_mode": {
1797 | "name": "ipython",
1798 | "version": 3
1799 | },
1800 | "file_extension": ".py",
1801 | "mimetype": "text/x-python",
1802 | "name": "python",
1803 | "nbconvert_exporter": "python",
1804 | "pygments_lexer": "ipython3",
1805 | "version": "3.9.1"
1806 | }
1807 | },
1808 | "nbformat": 4,
1809 | "nbformat_minor": 1
1810 | }
1811 |
--------------------------------------------------------------------------------
/Depression detection:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Project Motivation.odt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/Project Motivation.odt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Depression Detection Using Twitter Data
2 | Depression Detection using Twitter data
3 |
4 | 1. [Project Motivation](#motivation)
5 | 2. [Proposed Model](#proposedModel)
6 | 3. [Dataset Construction](#dataset)
7 | 4. [Project Phases](#phases)
8 | 5. [Future Plan](#futureplan)
9 | 6. [Contributors](#contributors)
10 | 7. [Resources](#recources)
11 |
12 |
13 |
14 | ## Project Motivation
15 |
16 | As society becomes more and more technologically advanced, we are tuning in more and more to the online world. But as a result, some people can feel more isolated, and this can in turn increase the occurrence of anxiety and depression.
17 |
18 | Ironically, some people turn to social media as an outlet for their thoughts and emotions, and platforms like Twitter offer users a level of anonymity, which can entice them to be more uninhibited in their expressions.
19 |
20 | This offers an opportunity for emotion detection of users based on their tweets. From a medical perspective, it provides a good opportunity to identify potential depression in users and, from there, to offer suitable support and help them onto a path of recovery by introducing them to self-care chatbots such as Woebot, which uses Cognitive Behavioral Therapy to help its users change their negative thought patterns, as well as providing friendship in their time of need.
21 |
22 | We also aim to provide a dataset that is specifically designed for depression identification based on tweets. Our research so far shows such data is not readily available, and proves to be a major stumbling block in the development of this project.
23 |
24 | ## Proposed Model
25 |
26 | Our depression detector can be incorporated into existing products such as GBoard on Android, the Google Keyboard, which uses federated learning to improve the user experience based on their search history. As GBoard already collects user history, it is conceivable that it can be extended to incorporate our model and identify instances of depression, especially based on the user's textual input over a period of time.
27 |
28 | If signs of depression have been detected, then it would be desirable to suggest that the user try a self-care chatbot. (We need to think of a way to do this without infringing on the user's privacy; for example, we do not want to send this diagnosis back to the server in its raw form.) Perhaps the suggestion of a self-care bot can be an automatic feature that is integrated into GBoard upon depression detection, so that the raw data does not need to go back to the centralised server, and does not require revelation of the user's identity. Here we may implement federated learning along with local differential privacy, to ensure the user's privacy is protected.
29 |
30 | We also aim to provide a dataset that is specifically designed for depression identification based on tweets. Our research so far shows such data is not readily available, and proves to be a major stumbling block in the development of this project.
31 |
32 |
33 |
34 |
35 |
36 |
37 | ## Dataset Construction
38 |
39 | Dataset construction turned out to be the biggest piece of work in our project.
40 |
41 | Initially we used data from an existing github repository [Detecting-Depression-in-Tweets](https://github.com/viritaromero/Detecting-Depression-in-Tweets), which has the same purpose of detecting depression in tweets. However, our initial model was able to achieve over 99% accuracy during validation, because the data was too simplistic and most of the labelled depressive entries contain the word **depression**.
42 |
43 | So we decided to create our own twitter dataset for depression by using the third party tool, [TWINT](https://github.com/twintproject/twint).
44 |
45 |
46 | ### Collecting data
47 |
48 | We constructed a script (0_depression_data_scrapper.ipynb) to apply the following steps to identify potentially depressive tweets:
49 | 1. Used twint to collect tweets based on the following hashtags
50 | - #depressed
51 | - #depression
52 | - #loneliness
53 | - #hopelessness
54 |
55 | 2. Remove duplicated entries based on tweet id.
56 |
57 | 3. Remove entries that contain these positive or medical or educational sounding hashtags
58 | - #mentalhealth
59 | - #health
60 | - #happiness
61 | - #mentalillness
62 | - #happy
63 | - #joy
64 | - #wellbeing
65 |
66 | 4. Remove entries with any of the following characteristics, as they are more likely to be promotional messages
67 | - Containing more than three hashtags
68 | - Containing @mentions
69 | - Containing URLs
70 |
71 | 5. Remove entries with fewer than 25 characters or fewer than 5 words.
72 |
73 | 6. Lastly, remove all hashtags from the tweets. This is because the hashtags themselves are an obvious indicator of depressive text, and we would like to train our model to focus on the content of the tweet rather than the existence of depressive hashtags.
74 |
75 |
76 | The results are saved into csv files and allocated to our team members for review.
77 |
78 |
79 | ### Reviewing dataset
80 |
81 | We manually reviewed csv files generated by the previous script, which contain filtered tweets that originally contained depressive hashtags. The csv files have a target column set to 1 by default, and we manually set the non-depressive entries to have target of 0, and also removed non-English tweets from the file.
82 |
83 | The resulting csv files contain a roughly 50-50 split of depressive and non-depressive tweets. This is a good resource to train our model on, as all of the tweets originally had depressive hashtags, so by distinguishing the tweets based on their content rather than their tags, we are training the model to be more sensitive to the content and more precise in its predictions.
84 |
85 |
86 | ### More CLearning
87 |
88 | We also collected additional tweets from other sources, to represent non-depressive texts. These contain other emotions, such as joy, love, surprise, as well as some emotionally neutral tweets.
89 |
90 | ### Finalizing Dataset
91 | Combine the datasets from Part 3 and Part 4 to create a final dataset
92 |
93 |
94 |
95 | Even though the dataset creation took a lot of time and effort, we believe that it was really important to get this right, as this is the basis for creating an accurate depression detector and potentially differentiating it from just a sentiment detector.
96 |
97 |
98 | 
99 | 
100 |
101 |
102 |
103 |
104 | ## Future Plan
105 |
106 | - Use external software such as LIWC (http://liwc.wpengine.com/) to review the linguistic and emotional content of the tweets, and verify that the labels are correct.
107 |
108 |
109 |
110 | ## Contributors
111 |
112 | Contributor | Slack Handle
113 | ------------ | -------------
114 | Susan Wang | @SusanW
115 | Labiba Kanij Rupty | @Labiba
116 | Mahfuza Humayra Mohona | @Mohona
117 | Aarthi Alagammai | @Aarthi Alagammai
118 | Munira Omar | @Munira Omar
119 | Marwa Qabeel | @Marwa
120 |
121 |
122 |
123 |
124 |
125 | ## Resources
126 | - Anne Bonner's Medium article [You Are What You Tweet](https://towardsdatascience.com/you-are-what-you-tweet-7e23fb84f4ed).
127 | - [Sentiment Analysis — TorchText](https://medium.com/@sonicboom8/sentiment-analysis-torchtext-55fb57b1fab8)
128 | - Pranjal Chaubey repo [Sixty AI](https://github.com/pranjalchaubey/Sixty-AI)
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/data/tweets_final_1_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_1_clean.csv
--------------------------------------------------------------------------------
/data/tweets_final_2_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_2_clean.csv
--------------------------------------------------------------------------------
/data/tweets_final_3_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_3_clean.csv
--------------------------------------------------------------------------------
/data/tweets_final_4_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_4_clean.csv
--------------------------------------------------------------------------------
/data/tweets_final_5_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_5_clean.csv
--------------------------------------------------------------------------------
/data/tweets_final_6_clean.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swcwang/depression-detection/c857d3c350c2c78323b441cb6f3ac2a806c88aef/data/tweets_final_6_clean.csv
--------------------------------------------------------------------------------
/embedding/glove.twitter.27B.100d.md:
--------------------------------------------------------------------------------
1 | **Use this link to download the file: [glove.twitter.27B.100d.txt](https://drive.google.com/file/d/10vMvdp_7A07lQ1D1cRt0ZhVBycOnRUeM/view?usp=share_link)**.
2 |
--------------------------------------------------------------------------------