├── .ipynb_checkpoints ├── data_exploration-checkpoint.ipynb └── model_evaluation-checkpoint.ipynb ├── README.md ├── assets ├── class_distribution.png ├── loss.svg ├── lr.svg └── nlp_report.pdf ├── cache ├── cached_bert_dev_multi_label_512_nlp_valid.csv └── cached_bert_train_multi_label_512_nlp_train.csv ├── data_exploration.ipynb ├── data_generator.py ├── find_threshold.py ├── inference.py ├── labels.csv ├── nlp_test.csv ├── nlp_train.csv ├── nlp_valid.csv ├── requirements.txt └── train_bert.py /.ipynb_checkpoints/data_exploration-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib\n", 11 | "from matplotlib import pyplot as plt\n", 12 | "%matplotlib inline" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/html": [ 23 | "
\n", 24 | "\n", 37 | "\n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | "
idtextangeranticipationdisgustfearjoyloveoptimismpessimismsadnesssurprisetrustneutral
00He was answering a question about the criticis...101000010000
11I'm going to start today's discussion thread w...111100010000
22By announcing the 395 self-quarantined, it pai...111100010000
33Likewise, sorry if I offended you. I’m not act...101100010000
44People infected by experience high fever, coug...000000000001
\n", 145 | "
" 146 | ], 147 | "text/plain": [ 148 | " id text anger anticipation \\\n", 149 | "0 0 He was answering a question about the criticis... 1 0 \n", 150 | "1 1 I'm going to start today's discussion thread w... 1 1 \n", 151 | "2 2 By announcing the 395 self-quarantined, it pai... 1 1 \n", 152 | "3 3 Likewise, sorry if I offended you. I’m not act... 1 0 \n", 153 | "4 4 People infected by experience high fever, coug... 0 0 \n", 154 | "\n", 155 | " disgust fear joy love optimism pessimism sadness surprise trust \\\n", 156 | "0 1 0 0 0 0 1 0 0 0 \n", 157 | "1 1 1 0 0 0 1 0 0 0 \n", 158 | "2 1 1 0 0 0 1 0 0 0 \n", 159 | "3 1 1 0 0 0 1 0 0 0 \n", 160 | "4 0 0 0 0 0 0 0 0 0 \n", 161 | "\n", 162 | " neutral \n", 163 | "0 0 \n", 164 | "1 0 \n", 165 | "2 0 \n", 166 | "3 0 \n", 167 | "4 1 " 168 | ] 169 | }, 170 | "execution_count": 2, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "# read dataset\n", 177 | "# we will be using trained dataset to understand how people are reacting\n", 178 | "data = pd.read_csv(\"nlp_train.csv\")\n", 179 | "data.head()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 3, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/html": [ 190 | "
\n", 191 | "\n", 204 | "\n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | "
idangeranticipationdisgustfearjoyloveoptimismpessimismsadnesssurprisetrustneutral
count1493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.000000
mean746.0000000.3643670.5030140.4541190.4541190.1239120.0924310.3281980.4326860.2772940.1085060.1681180.113195
std431.1362890.4814130.5001580.4980570.4980570.3295910.2897310.4697150.4956140.4478130.3111230.3740960.316937
min0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
25%373.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
50%746.0000000.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
75%1119.0000001.0000001.0000001.0000001.0000000.0000000.0000001.0000001.0000001.0000000.0000000.0000000.000000
max1492.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
\n", 354 | "
" 355 | ], 356 | "text/plain": [ 357 | " id anger anticipation disgust fear \\\n", 358 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n", 359 | "mean 746.000000 0.364367 0.503014 0.454119 0.454119 \n", 360 | "std 431.136289 0.481413 0.500158 0.498057 0.498057 \n", 361 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 362 | "25% 373.000000 0.000000 0.000000 0.000000 0.000000 \n", 363 | "50% 746.000000 0.000000 1.000000 0.000000 0.000000 \n", 364 | "75% 1119.000000 1.000000 1.000000 1.000000 1.000000 \n", 365 | "max 1492.000000 1.000000 1.000000 1.000000 1.000000 \n", 366 | "\n", 367 | " joy love optimism pessimism sadness \\\n", 368 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n", 369 | "mean 0.123912 0.092431 0.328198 0.432686 0.277294 \n", 370 | "std 0.329591 0.289731 0.469715 0.495614 0.447813 \n", 371 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 372 | "25% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 373 | "50% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 374 | "75% 0.000000 0.000000 1.000000 1.000000 1.000000 \n", 375 | "max 1.000000 1.000000 1.000000 1.000000 1.000000 \n", 376 | "\n", 377 | " surprise trust neutral \n", 378 | "count 1493.000000 1493.000000 1493.000000 \n", 379 | "mean 0.108506 0.168118 0.113195 \n", 380 | "std 0.311123 0.374096 0.316937 \n", 381 | "min 0.000000 0.000000 0.000000 \n", 382 | "25% 0.000000 0.000000 0.000000 \n", 383 | "50% 0.000000 0.000000 0.000000 \n", 384 | "75% 0.000000 0.000000 0.000000 \n", 385 | "max 1.000000 1.000000 1.000000 " 386 | ] 387 | }, 388 | "execution_count": 3, 389 | "metadata": {}, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "#basic stats\n", 395 | "data.describe()" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 11, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "text/plain": [ 406 | "0 949\n", 407 | "1 544\n", 408 | "Name: anger, dtype: int64" 
409 | ] 410 | }, 411 | "execution_count": 11, 412 | "metadata": {}, 413 | "output_type": "execute_result" 414 | } 415 | ], 416 | "source": [ 417 | "data['anger'].value_counts()" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 15, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 16, 430 | "metadata": {}, 431 | "outputs": [ 432 | { 433 | "data": { 434 | "text/plain": [ 435 | "0 1003\n", 436 | "1 490\n", 437 | "Name: optimism, dtype: int64" 438 | ] 439 | }, 440 | "execution_count": 16, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | } 444 | ], 445 | "source": [ 446 | "freqs = {\"anger\":data['anger'].value_counts()[1]\n", 447 | "anticipation = data['anticipation'].value_counts()[1]\n", 448 | "data['disgust'].value_counts()[1]\n", 449 | "data['fear'].value_counts()[1]\n", 450 | "data['joy'].value_counts()[1]\n", 451 | "data['love'].value_counts()[1]\n", 452 | "data['optimism'].value_counts()[1]\n", 453 | "data['pessimism'].value_counts()[1]\n", 454 | "data['sadness'].value_counts()[1]\n", 455 | "data['surprise'].value_counts()[1]\n", 456 | "data['trust'].value_counts()[1]\n", 457 | "data['neutral'].value_counts()[1]" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [] 466 | } 467 | ], 468 | "metadata": { 469 | "kernelspec": { 470 | "display_name": "Python 3", 471 | "language": "python", 472 | "name": "python3" 473 | }, 474 | "language_info": { 475 | "codemirror_mode": { 476 | "name": "ipython", 477 | "version": 3 478 | }, 479 | "file_extension": ".py", 480 | "mimetype": "text/x-python", 481 | "name": "python", 482 | "nbconvert_exporter": "python", 483 | "pygments_lexer": "ipython3", 484 | "version": "3.7.4" 485 | } 486 | }, 487 | "nbformat": 4, 488 | "nbformat_minor": 2 489 | } 490 | 
-------------------------------------------------------------------------------- /.ipynb_checkpoints/model_evaluation-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Multi Emotion Detection from COVID-19 Text using BERT ## 2 | 3 | ### Requirements ### 4 | 5 | The code was tested with Python 3.8 and PyTorch 1.5.0. Requirements can be installed with the following command. 6 | ``` 7 | pip install -r requirements.txt 8 | ``` 9 | You may have to change the package name (append +cpu) for the non-CUDA version. 10 | 11 | ### Data Preparation ### 12 | The model expects the data to be in CSV format. So, you first need to convert the JSON files into CSV files. If you have data in any other format, you may need to modify the code. 13 | 14 | To generate the CSV files, run the following command. 15 | 16 | ``` 17 | python .\data_generator.py --file=D:\UTD\Assignment\NLP\project\nlp_test.json --csvfile=D:\UTD\Assignment\NLP\project\nlp_test.csv 18 | ``` 19 | 20 | Here ```--csvfile``` specifies where to store the converted file. 21 | 22 | 23 | ### Training ### 24 | Once you have the files in the required format, you can start training. You may want to change the parameters. I tried multiple parameter settings, and the file contains the ones that gave the best result. 25 | ``` 26 | python train_bert.py --epochs=15 27 | ``` 28 | 29 | You can find all available options by running the following command. 30 | 31 | ``` 32 | python train_bert.py --help 33 | ``` 34 | The following graphs show loss (1) and learning rate (2) over time. 35 | 36 | 37 | 38 | 39 | 40 | 41 | ### Inference ### 42 | Once you have the trained model, you can run the inference on test csv files. 
Note that as of now, this script requires annotated data to compute the metrics. But it can easily be modified to generate output only. 43 | 44 | Pretrained model can be found [here](https://utdallas.box.com/s/sqqb0n9qe7txb6j3725aiz76gwlmszuw) 45 | 46 | ``` 47 | python inference.py --test_csv=D:\\UTD\\Assignment\\NLP\\project\\nlp_valid.csv --model_dir=D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20 48 | ``` 49 | 50 | If ```--evaluation``` is set to true, it will output various metrics. 51 | 52 | **Threshold** 53 | One important factor here is to find the optimal threshold for the confidence score. I tested various thresholds and found 0.0017 to give the best results for the model specified above. If you train your own model, you may want to run ```find_threshold.py``` to find the best threshold. 54 | 55 | ### Model Evaluation ### 56 | 57 | If you want to evaluate your model on the test or train set, you can do so by running the following command. Note that the file must be in the csv format. 58 | 59 | ``` 60 | python inference.py --test_csv=D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv --evaluation=True 61 | ``` 62 | 63 | I ran the evaluation on the train set and obtained the following results. 64 | 65 | 66 | Emotion | Precision | Recall | f1-score 67 | ---------------|---------------|------------|--------------- 68 | Anger | 0.97 | 0.95 | 0.96 69 | Anticipation | 0.98 | 1.00 | 0.99 70 | Disgust | 0.97 | 0.96 | 0.97 71 | Fear | 1.00 | 1.00 | 1.00 72 | Joy | 0.93 | 0.93 | 0.93 73 | Love | 1.00 | 0.73 | 0.85 74 | Optimism | 0.91 | 1.00 | 0.95 75 | Pessimism | 0.98 | 0.94 | 0.96 76 | Sadness | 0.99 | 0.92 | 0.95 77 | Surprise | 0.98 | 0.92 | 0.95 78 | Trust | 0.95 | 0.88 | 0.91 79 | Neutral | 0.92 | 1.00 | 0.96 80 | Average | 0.98 | 0.97 | 0.97 81 | 82 | The following table shows the results on the test set. 
83 | 84 | Emotion | Precision | Recall | f1-score 85 | ---------------|---------------|------------|--------------- 86 | Anger | 0.53 | 0.67 | 0.59 87 | Anticipation | 0.67 | 0.68 | 0.68 88 | Disgust | 0.62 | 0.78 | 0.69 89 | Fear | 0.69 | 0.72 | 0.71 90 | Joy | 0.50 | 0.27 | 0.35 91 | Love | 0.32 | 0.37 | 0.34 92 | Optimism | 0.37 | 0.59 | 0.45 93 | Pessimism | 0.44 | 0.73 | 0.55 94 | Sadness | 0.42 | 0.53 | 0.47 95 | Surprise | 0.55 | 0.38 | 0.45 96 | Trust | 0.12 | 0.12 | 0.12 97 | Neutral | 0.37 | 0.37 | 0.37 98 | Average | 0.60 | 0.62 | 0.55 99 | 100 | If we set the threshold to 0.02, the average accuracy is 0.66. 101 | 102 | ### Possible Improvements 103 | 104 | ![Class Imbalance](/assets/class_distribution.png) 105 | 106 | The biggest caveat here was the class imbalance. It has been established that class imbalance can negatively affect our model. So, it is always a good idea to balance our data before training the model. Due to limited time, I didn't do any of that. But ideally, we want to oversample from minority classes or pass weights to the loss function. I implemented both approaches for image classification [here](https://github.com/savan77/Transfer-Learning). 107 | 108 | As we can see above, negative emotions such as anger or pessimism have a bigger representation in the data compared to happy emotions. This makes sense, but in order to train a good model we should even it out. 109 | 110 | Moreover, we can play with the network. Here, the standard off-the-shelf text classification network was used. Maybe adding more fully connected layers on top of BERT may help. 
-------------------------------------------------------------------------------- /assets/class_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/assets/class_distribution.png -------------------------------------------------------------------------------- /assets/loss.svg: -------------------------------------------------------------------------------- 1 | -0.100.10.20.30.40.50.60.7-50005001k1.5k2k2.5k3k3.5k -------------------------------------------------------------------------------- /assets/lr.svg: -------------------------------------------------------------------------------- 1 | -1e-401e-42e-43e-44e-45e-46e-47e-48e-4-50005001k1.5k2k2.5k3k3.5k -------------------------------------------------------------------------------- /assets/nlp_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/assets/nlp_report.pdf -------------------------------------------------------------------------------- /cache/cached_bert_dev_multi_label_512_nlp_valid.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/cache/cached_bert_dev_multi_label_512_nlp_valid.csv -------------------------------------------------------------------------------- /cache/cached_bert_train_multi_label_512_nlp_train.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/cache/cached_bert_train_multi_label_512_nlp_train.csv -------------------------------------------------------------------------------- /data_exploration.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib\n", 11 | "from matplotlib import pyplot as plt\n", 12 | "%matplotlib inline" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/html": [ 23 | "
\n", 24 | "\n", 37 | "\n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | "
idtextangeranticipationdisgustfearjoyloveoptimismpessimismsadnesssurprisetrustneutral
00He was answering a question about the criticis...101000010000
11I'm going to start today's discussion thread w...111100010000
22By announcing the 395 self-quarantined, it pai...111100010000
33Likewise, sorry if I offended you. I’m not act...101100010000
44People infected by experience high fever, coug...000000000001
\n", 145 | "
" 146 | ], 147 | "text/plain": [ 148 | " id text anger anticipation \\\n", 149 | "0 0 He was answering a question about the criticis... 1 0 \n", 150 | "1 1 I'm going to start today's discussion thread w... 1 1 \n", 151 | "2 2 By announcing the 395 self-quarantined, it pai... 1 1 \n", 152 | "3 3 Likewise, sorry if I offended you. I’m not act... 1 0 \n", 153 | "4 4 People infected by experience high fever, coug... 0 0 \n", 154 | "\n", 155 | " disgust fear joy love optimism pessimism sadness surprise trust \\\n", 156 | "0 1 0 0 0 0 1 0 0 0 \n", 157 | "1 1 1 0 0 0 1 0 0 0 \n", 158 | "2 1 1 0 0 0 1 0 0 0 \n", 159 | "3 1 1 0 0 0 1 0 0 0 \n", 160 | "4 0 0 0 0 0 0 0 0 0 \n", 161 | "\n", 162 | " neutral \n", 163 | "0 0 \n", 164 | "1 0 \n", 165 | "2 0 \n", 166 | "3 0 \n", 167 | "4 1 " 168 | ] 169 | }, 170 | "execution_count": 2, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "# read dataset\n", 177 | "# we will be using trained dataset to understand how people are reacting\n", 178 | "data = pd.read_csv(\"nlp_train.csv\")\n", 179 | "data.head()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 3, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/html": [ 190 | "
\n", 191 | "\n", 204 | "\n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | "
idangeranticipationdisgustfearjoyloveoptimismpessimismsadnesssurprisetrustneutral
count1493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.0000001493.000000
mean746.0000000.3643670.5030140.4541190.4541190.1239120.0924310.3281980.4326860.2772940.1085060.1681180.113195
std431.1362890.4814130.5001580.4980570.4980570.3295910.2897310.4697150.4956140.4478130.3111230.3740960.316937
min0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
25%373.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
50%746.0000000.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
75%1119.0000001.0000001.0000001.0000001.0000000.0000000.0000001.0000001.0000001.0000000.0000000.0000000.000000
max1492.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
\n", 354 | "
" 355 | ], 356 | "text/plain": [ 357 | " id anger anticipation disgust fear \\\n", 358 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n", 359 | "mean 746.000000 0.364367 0.503014 0.454119 0.454119 \n", 360 | "std 431.136289 0.481413 0.500158 0.498057 0.498057 \n", 361 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 362 | "25% 373.000000 0.000000 0.000000 0.000000 0.000000 \n", 363 | "50% 746.000000 0.000000 1.000000 0.000000 0.000000 \n", 364 | "75% 1119.000000 1.000000 1.000000 1.000000 1.000000 \n", 365 | "max 1492.000000 1.000000 1.000000 1.000000 1.000000 \n", 366 | "\n", 367 | " joy love optimism pessimism sadness \\\n", 368 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n", 369 | "mean 0.123912 0.092431 0.328198 0.432686 0.277294 \n", 370 | "std 0.329591 0.289731 0.469715 0.495614 0.447813 \n", 371 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 372 | "25% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 373 | "50% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 374 | "75% 0.000000 0.000000 1.000000 1.000000 1.000000 \n", 375 | "max 1.000000 1.000000 1.000000 1.000000 1.000000 \n", 376 | "\n", 377 | " surprise trust neutral \n", 378 | "count 1493.000000 1493.000000 1493.000000 \n", 379 | "mean 0.108506 0.168118 0.113195 \n", 380 | "std 0.311123 0.374096 0.316937 \n", 381 | "min 0.000000 0.000000 0.000000 \n", 382 | "25% 0.000000 0.000000 0.000000 \n", 383 | "50% 0.000000 0.000000 0.000000 \n", 384 | "75% 0.000000 0.000000 0.000000 \n", 385 | "max 1.000000 1.000000 1.000000 " 386 | ] 387 | }, 388 | "execution_count": 3, 389 | "metadata": {}, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "#basic stats\n", 395 | "data.describe()" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 11, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "text/plain": [ 406 | "0 949\n", 407 | "1 544\n", 408 | "Name: anger, dtype: int64" 
409 | ] 410 | }, 411 | "execution_count": 11, 412 | "metadata": {}, 413 | "output_type": "execute_result" 414 | } 415 | ], 416 | "source": [ 417 | "data['anger'].value_counts()" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 15, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 17, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "freqs = {\"anger\":data['anger'].value_counts()[1],\n", 434 | "\"anticipation\": data['anticipation'].value_counts()[1],\n", 435 | "\"disgust\":data['disgust'].value_counts()[1],\n", 436 | "\"fear\":data['fear'].value_counts()[1],\n", 437 | "\"joy\":data['joy'].value_counts()[1],\n", 438 | "\"love\":data['love'].value_counts()[1],\n", 439 | "\"optimism\":data['optimism'].value_counts()[1],\n", 440 | "\"pessimism\":data['pessimism'].value_counts()[1],\n", 441 | "\"sadness\":data['sadness'].value_counts()[1],\n", 442 | "\"surprise\":data['surprise'].value_counts()[1],\n", 443 | "\"trust\":data['trust'].value_counts()[1],\n", 444 | "\"neutral\":data['neutral'].value_counts()[1]}" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 18, 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "data": { 454 | "text/plain": [ 455 | "{'anger': 544,\n", 456 | " 'anticipation': 751,\n", 457 | " 'disgust': 678,\n", 458 | " 'fear': 678,\n", 459 | " 'joy': 185,\n", 460 | " 'love': 138,\n", 461 | " 'optimism': 490,\n", 462 | " 'pessimism': 646,\n", 463 | " 'sadness': 414,\n", 464 | " 'surprise': 162,\n", 465 | " 'trust': 251,\n", 466 | " 'neutral': 169}" 467 | ] 468 | }, 469 | "execution_count": 18, 470 | "metadata": {}, 471 | "output_type": "execute_result" 472 | } 473 | ], 474 | "source": [ 475 | "freqs" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 85, 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "data": { 485 | "image/png": "\n", 486 | "text/plain": [ 487 | "
" 488 | ] 489 | }, 490 | "metadata": { 491 | "needs_background": "light" 492 | }, 493 | "output_type": "display_data" 494 | } 495 | ], 496 | "source": [ 497 | "#plot class distribution\n", 498 | "plt.rcParams['figure.figsize'] = [13, 5]\n", 499 | "plt.bar(freqs.keys(), freqs.values(), width=0.3,color='skyblue')\n", 500 | "plt.text(10,700,\"Target Class Distribution\", fontsize=15, ha='center', va='center')\n", 501 | "t = plt.title(\"Emotions from COVID-19 Related Text\", fontsize=18)\n", 502 | "# t.set_color(\"m\")\n", 503 | "x = plt.xlabel(\"Emotion\", fontsize=13)\n", 504 | "x.set_color('g')\n", 505 | "y = plt.ylabel(\"Frequency\", fontsize=13)\n", 506 | "y.set_color('g')\n", 507 | "plt.savefig('class_distribution.png')\n", 508 | "# [i.set_color(\"c\") for i in plt.gca().get_xticklabels()]\n", 509 | "plt.show()" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 86, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "# there is a clear class imbalance" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [] 527 | } 528 | ], 529 | "metadata": { 530 | "kernelspec": { 531 | "display_name": "Python 3", 532 | "language": "python", 533 | "name": "python3" 534 | }, 535 | "language_info": { 536 | "codemirror_mode": { 537 | "name": "ipython", 538 | "version": 3 539 | }, 540 | "file_extension": ".py", 541 | "mimetype": "text/x-python", 542 | "name": "python", 543 | "nbconvert_exporter": "python", 544 | "pygments_lexer": "ipython3", 545 | "version": "3.7.4" 546 | } 547 | }, 548 | "nbformat": 4, 549 | "nbformat_minor": 2 550 | } 551 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | ### 2 | # Generate data (in csv format) to train the BERT model 3 | 4 | import csv 5 | import json 6 | import argparse 7 | import 
import os

# Column order of the generated CSV: a running id, the text, then one binary
# column per emotion.
_CSV_HEADER = [
    "id", "text", "anger", "anticipation", "disgust", "fear", "joy", "love",
    "optimism", "pessimism", "sadness", "surprise", "trust", "neutral",
]


# store emotions as a one hot encodings
def create_model(llabels):
    """Return a 12-slot binary vector built from the values of *llabels*.

    Slot ``k`` is 1 when the k-th value (in the mapping's iteration order) is
    truthy and 0 otherwise; slots beyond the number of values stay 0.

    NOTE(review): values are matched to CSV columns purely by position, so the
    mapping is presumably expected to list the emotions in the same order as
    the CSV header -- confirm against the JSON producer.

    Args:
        llabels: mapping whose values flag which emotions are present.

    Returns:
        A list of twelve ints, each 0 or 1.

    Raises:
        IndexError: if *llabels* holds more than 12 values.
    """
    llist = [0] * 12
    for position, val in enumerate(llabels.values()):
        llist[position] = 1 if val else 0
    return llist


# generate and write to csv file
def generate_csv(file, csvfile):
    """Convert the JSON file at *file* into a CSV file at *csvfile*.

    The JSON document must be an object; each of its values must provide a
    ``body`` entry (the text) and an ``emotion`` mapping. One CSV row is
    written per value, numbered from 0 in document order, after a header row.

    Args:
        file: path of the JSON file to read.
        csvfile: path of the CSV file to create (overwritten if it exists).

    Raises:
        KeyError: if an entry lacks ``body`` or ``emotion``.
        IndexError: if an ``emotion`` mapping holds more than 12 values.
    """
    # Binary mode lets ``json`` detect the encoding (UTF-8/16/32) itself
    # instead of depending on the platform's default text encoding.
    # The input is parsed before the output is opened, so a malformed JSON
    # file no longer truncates an existing CSV.
    with open(file, "rb") as source:
        data = json.load(source)

    # ``with`` guarantees the CSV is flushed and closed even if a row fails.
    with open(csvfile, "w", encoding="utf-8", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(_CSV_HEADER)
        for idd, v in enumerate(data.values()):
            writer.writerow([idd, v['body']] + create_model(v['emotion']))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", default="nlp_train.json", type=str)
    parser.add_argument("--csvfile", default="nlp_train.csv", type=str)
    args = parser.parse_args()
    generate_csv(args.file, args.csvfile)

# --------------------------------------------------------------------------------
# /find_threshold.py:
# --------------------------------------------------------------------------------
######
# Script to find best possible threshold for the confidence.
3 | # 4 | 5 | from fast_bert.prediction import BertClassificationPredictor 6 | import argparse 7 | import csv 8 | import pandas as pd 9 | 10 | def threshold(model, csvs): 11 | labels = ["anger", "anticipation","disgust","fear","joy","love","optimism","pessimism","sadness","surprise","trust","neutral"] 12 | 13 | predictor = BertClassificationPredictor( 14 | model_path=args.model_dir, 15 | label_path="D:\\UTD\\Assignment\\NLP\\project\\", # location for labels.csv file 16 | multi_label=False, 17 | model_type='bert', 18 | do_lower_case=False) 19 | thresholds = [0.0005,0.00077,0.00079,0.00083,0.00087,0.0009,0.00093,0.00095,0.00099,0.001,0.0012,0.0015,0.00155,0.0016,0.00166,0.0017,0.0019,0.002,0.0021,0.0023,0.0025,0.0028,0.003,0.0035,0.0032,0.0037,0.004,0.0045,0.0047,0.0041,0.005,0.0053,0.0055,0.0062,0.009, 0.007, 0.01, 0.011,0.013,0.014,0.012, 0.015, 0.02, 0.25, 0.03,0.035,0.039] 20 | # targets = [] 21 | inputs = {} 22 | data = pd.read_csv(csvs) 23 | # print(data.head()) 24 | for idx, row in data.iterrows(): 25 | temp = [] 26 | for label in labels: 27 | if row[label] == 1: 28 | temp.append(label) 29 | inputs[row['text']] = temp 30 | 31 | multiple_predictions = predictor.predict_batch(list(inputs.keys())) 32 | threshold_accs = {} 33 | 34 | for th in thresholds: 35 | correct = 0 36 | # print(list(inputs.values())[0]) 37 | outputs = [] 38 | for out in multiple_predictions: 39 | temp = [] 40 | for emotion in out: 41 | if emotion[1] >= th: # greater than threshold 42 | temp.append(emotion[0]) 43 | outputs.append(temp) 44 | # print(outputs[0]) 45 | for i in range(len(inputs)): 46 | if (set(outputs[i]) == set(list(inputs.values())[i])): 47 | correct += 1 48 | print("Threshold: ", th, "Correct: ", correct) 49 | threshold_accs[str(th)] = correct/len(inputs) 50 | print(threshold_accs) 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--model_dir",default="D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20", 
def _str2bool(value):
    """Parse a command-line boolean so that ``--flag False`` is really False."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


# run inference on the csv file provided using the trained model
def run(model, csvs, threshold, evaluation, predictor=None,
        label_path="D:\\UTD\\Assignment\\NLP\\project\\"):
    """Predict emotions for every unique text in a CSV and save the result.

    Writes ``model_output.csv`` next to the input CSV with the columns
    ``id, text, emotions, target``. Optionally prints per-label
    classification reports.

    Args:
        model: directory of the fine-tuned model. Used only when
            ``predictor`` is not supplied.
        csvs: path of a CSV with ``id`` and ``text`` columns and one 0/1
            column per emotion label.
        threshold: a label is predicted when its score is ``>=`` this
            value. The original used ``>``; ``>=`` matches
            find_threshold.py, where the value is tuned.
        evaluation: when truthy, print classification reports comparing
            predictions with the gold labels.
        predictor: optional pre-built object exposing
            ``predict_batch(texts)`` that returns, per text, an iterable of
            ``(label, score)`` pairs. Built from ``model`` when omitted.
        label_path: directory holding ``labels.csv``. Used only when the
            predictor is built here.

    Returns:
        The predicted label lists, one per unique text in first-seen order.
    """
    labels = ["anger", "anticipation", "disgust", "fear", "joy", "love",
              "optimism", "pessimism", "sadness", "surprise", "trust", "neutral"]

    if predictor is None:
        # NOTE(review): train_bert.py trains with multi_label=True while the
        # predictor is built with multi_label=False -- confirm intended.
        predictor = BertClassificationPredictor(
            model_path=model,  # was the global ``args.model_dir``
            label_path=label_path,  # location for labels.csv file
            multi_label=False,
            model_type='bert',
            do_lower_case=False)

    # Texts are de-duplicated by keying on the text; a repeated text keeps
    # the labels of its last row. The id is stored under the same key so it
    # stays aligned with that text. The original appended ids per row, which
    # shifted every id after the first duplicate.
    inputs = {}
    ids = {}
    data = pd.read_csv(csvs)
    for _, row in data.iterrows():
        text = row['text']
        inputs[text] = [label for label in labels if row[label] == 1]
        ids[text] = row['id']

    texts = list(inputs)
    multiple_predictions = predictor.predict_batch(texts)

    out_path = os.path.join(os.path.dirname(csvs), "model_output.csv")
    outputs = []
    # ``with`` ensures the file is flushed before it is reported as saved.
    with open(out_path, "w", encoding="utf-8", newline="") as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerow(["id", "text", "emotions", "target"])
        for text, out in zip(texts, multiple_predictions):
            predicted = [emotion[0] for emotion in out if emotion[1] >= threshold]
            csv_writer.writerow([ids[text], text, predicted, inputs[text]])
            outputs.append(predicted)

    print("****************\n")
    print("Predictions saved in a file: ", out_path)

    if evaluation:
        print("\n\n Running Model Evaluation\n")
        y_true = list(inputs.values())
        y_pred = outputs
        # One binarizer with an explicit class order, so truth and
        # predictions share the same columns and the columns line up with
        # ``target_names``. Two independent fits sort the classes each one
        # happens to see, which can differ and never matches ``labels``.
        binarizer = MultiLabelBinarizer(classes=labels)
        y_true_encoded = binarizer.fit_transform(y_true)
        y_pred_encoded = binarizer.transform(y_pred)
        # print, not pprint: the report is a pre-formatted multi-line string.
        print(classification_report(y_true_encoded, y_pred_encoded))
        print(classification_report(y_true_encoded, y_pred_encoded, target_names=labels))

    return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", default="D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20", help="path to output dir")
    parser.add_argument("--test_csv", default="D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv")
    parser.add_argument("--threshold", default=0.0017, type=float)
    # Without a converter argparse hands back the raw string, and any
    # non-empty string (including "False") is truthy.
    parser.add_argument("--writeto_file", default=True, type=_str2bool)
    parser.add_argument("--evaluation", default=True, type=_str2bool)
    args = parser.parse_args()
    run(args.model_dir, args.test_csv, args.threshold, args.evaluation)
OUTPUT_DIR = "model_output/"


def _str2bool(value):
    """Parse a command-line boolean so that ``--flag False`` is really False."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def train(args):
    """Fine-tune a BERT multi-label emotion classifier and save the model.

    Reads ``nlp_train.csv``, ``nlp_valid.csv`` and ``labels.csv`` from the
    current working directory, trains for ``args.epochs`` epochs and saves
    the model under ``args.out_dir``.

    Args:
        args: namespace providing ``pretrained_model``, ``out_dir``,
            ``is_onepanel``, ``epochs`` and ``batch_size``. ``args.out_dir``
            is rewritten in place to live under ``/onepanel/output/`` when
            ``args.is_onepanel`` is truthy.
    """
    if args.is_onepanel:
        args.out_dir = os.path.join("/onepanel/output/", args.out_dir)
    # makedirs creates missing parents (the onepanel path is nested) and
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(args.out_dir, exist_ok=True)

    logger = logging.getLogger()
    labels = ["anger", "anticipation", "disgust", "fear", "joy", "love",
              "optimism", "pessimism", "sadness", "surprise", "trust", "neutral"]
    databunch = BertDataBunch(".", ".",
                              tokenizer=args.pretrained_model,
                              train_file='nlp_train.csv',
                              label_file='labels.csv',
                              val_file="nlp_valid.csv",
                              text_col='text',
                              label_col=labels,
                              batch_size_per_gpu=args.batch_size,
                              max_seq_length=512,
                              multi_gpu=False,
                              multi_label=True,
                              model_type='bert')

    # Fall back to CPU rather than failing later when no GPU is present.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        logger.warning("CUDA is not available; training on CPU will be slow")
        device = torch.device("cpu")
    metrics = [{'name': 'accuracy', 'function': accuracy}]

    learner = BertLearner.from_pretrained_model(
        databunch,
        pretrained_path=args.pretrained_model,
        metrics=metrics,
        device=device,
        logger=logger,
        output_dir=args.out_dir,
        finetuned_wgts_path=None,
        warmup_steps=200,
        multi_gpu=False,
        is_fp16=False,
        multi_label=True,
        logging_steps=10)

    learner.fit(epochs=args.epochs,
                lr=2e-3,
                schedule_type="warmup_cosine_hard_restarts",
                optimizer_type="lamb")
    learner.save_model()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model", default="bert-base-uncased", help="path to a pretrained model")
    parser.add_argument("--out_dir", default=OUTPUT_DIR, help="path to output dir")
    # type=bool would turn any non-empty string, including "False", into True.
    parser.add_argument("--is_onepanel", default=False, type=_str2bool, help="train on onepanel cloud")
    parser.add_argument("--epochs", default=15, type=int)
    parser.add_argument("--batch_size", default=10, type=int)
    args = parser.parse_args()
    train(args)