├── .ipynb_checkpoints
├── data_exploration-checkpoint.ipynb
└── model_evaluation-checkpoint.ipynb
├── README.md
├── assets
├── class_distribution.png
├── loss.svg
├── lr.svg
└── nlp_report.pdf
├── cache
├── cached_bert_dev_multi_label_512_nlp_valid.csv
└── cached_bert_train_multi_label_512_nlp_train.csv
├── data_exploration.ipynb
├── data_generator.py
├── find_threshold.py
├── inference.py
├── labels.csv
├── nlp_test.csv
├── nlp_train.csv
├── nlp_valid.csv
├── requirements.txt
└── train_bert.py
/.ipynb_checkpoints/data_exploration-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import matplotlib\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "%matplotlib inline"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "data": {
22 | "text/html": [
23 | "
\n",
24 | "\n",
37 | "
\n",
38 | " \n",
39 | " \n",
40 | " | \n",
41 | " id | \n",
42 | " text | \n",
43 | " anger | \n",
44 | " anticipation | \n",
45 | " disgust | \n",
46 | " fear | \n",
47 | " joy | \n",
48 | " love | \n",
49 | " optimism | \n",
50 | " pessimism | \n",
51 | " sadness | \n",
52 | " surprise | \n",
53 | " trust | \n",
54 | " neutral | \n",
55 | "
\n",
56 | " \n",
57 | " \n",
58 | " \n",
59 | " 0 | \n",
60 | " 0 | \n",
61 | " He was answering a question about the criticis... | \n",
62 | " 1 | \n",
63 | " 0 | \n",
64 | " 1 | \n",
65 | " 0 | \n",
66 | " 0 | \n",
67 | " 0 | \n",
68 | " 0 | \n",
69 | " 1 | \n",
70 | " 0 | \n",
71 | " 0 | \n",
72 | " 0 | \n",
73 | " 0 | \n",
74 | "
\n",
75 | " \n",
76 | " 1 | \n",
77 | " 1 | \n",
78 | " I'm going to start today's discussion thread w... | \n",
79 | " 1 | \n",
80 | " 1 | \n",
81 | " 1 | \n",
82 | " 1 | \n",
83 | " 0 | \n",
84 | " 0 | \n",
85 | " 0 | \n",
86 | " 1 | \n",
87 | " 0 | \n",
88 | " 0 | \n",
89 | " 0 | \n",
90 | " 0 | \n",
91 | "
\n",
92 | " \n",
93 | " 2 | \n",
94 | " 2 | \n",
95 | " By announcing the 395 self-quarantined, it pai... | \n",
96 | " 1 | \n",
97 | " 1 | \n",
98 | " 1 | \n",
99 | " 1 | \n",
100 | " 0 | \n",
101 | " 0 | \n",
102 | " 0 | \n",
103 | " 1 | \n",
104 | " 0 | \n",
105 | " 0 | \n",
106 | " 0 | \n",
107 | " 0 | \n",
108 | "
\n",
109 | " \n",
110 | " 3 | \n",
111 | " 3 | \n",
112 | " Likewise, sorry if I offended you. I’m not act... | \n",
113 | " 1 | \n",
114 | " 0 | \n",
115 | " 1 | \n",
116 | " 1 | \n",
117 | " 0 | \n",
118 | " 0 | \n",
119 | " 0 | \n",
120 | " 1 | \n",
121 | " 0 | \n",
122 | " 0 | \n",
123 | " 0 | \n",
124 | " 0 | \n",
125 | "
\n",
126 | " \n",
127 | " 4 | \n",
128 | " 4 | \n",
129 | " People infected by experience high fever, coug... | \n",
130 | " 0 | \n",
131 | " 0 | \n",
132 | " 0 | \n",
133 | " 0 | \n",
134 | " 0 | \n",
135 | " 0 | \n",
136 | " 0 | \n",
137 | " 0 | \n",
138 | " 0 | \n",
139 | " 0 | \n",
140 | " 0 | \n",
141 | " 1 | \n",
142 | "
\n",
143 | " \n",
144 | "
\n",
145 | "
"
146 | ],
147 | "text/plain": [
148 | " id text anger anticipation \\\n",
149 | "0 0 He was answering a question about the criticis... 1 0 \n",
150 | "1 1 I'm going to start today's discussion thread w... 1 1 \n",
151 | "2 2 By announcing the 395 self-quarantined, it pai... 1 1 \n",
152 | "3 3 Likewise, sorry if I offended you. I’m not act... 1 0 \n",
153 | "4 4 People infected by experience high fever, coug... 0 0 \n",
154 | "\n",
155 | " disgust fear joy love optimism pessimism sadness surprise trust \\\n",
156 | "0 1 0 0 0 0 1 0 0 0 \n",
157 | "1 1 1 0 0 0 1 0 0 0 \n",
158 | "2 1 1 0 0 0 1 0 0 0 \n",
159 | "3 1 1 0 0 0 1 0 0 0 \n",
160 | "4 0 0 0 0 0 0 0 0 0 \n",
161 | "\n",
162 | " neutral \n",
163 | "0 0 \n",
164 | "1 0 \n",
165 | "2 0 \n",
166 | "3 0 \n",
167 | "4 1 "
168 | ]
169 | },
170 | "execution_count": 2,
171 | "metadata": {},
172 | "output_type": "execute_result"
173 | }
174 | ],
175 | "source": [
176 | "# read dataset\n",
177 | "# we will be using trained dataset to understand how people are reacting\n",
178 | "data = pd.read_csv(\"nlp_train.csv\")\n",
179 | "data.head()"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 3,
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "data": {
189 | "text/html": [
190 | "\n",
191 | "\n",
204 | "
\n",
205 | " \n",
206 | " \n",
207 | " | \n",
208 | " id | \n",
209 | " anger | \n",
210 | " anticipation | \n",
211 | " disgust | \n",
212 | " fear | \n",
213 | " joy | \n",
214 | " love | \n",
215 | " optimism | \n",
216 | " pessimism | \n",
217 | " sadness | \n",
218 | " surprise | \n",
219 | " trust | \n",
220 | " neutral | \n",
221 | "
\n",
222 | " \n",
223 | " \n",
224 | " \n",
225 | " count | \n",
226 | " 1493.000000 | \n",
227 | " 1493.000000 | \n",
228 | " 1493.000000 | \n",
229 | " 1493.000000 | \n",
230 | " 1493.000000 | \n",
231 | " 1493.000000 | \n",
232 | " 1493.000000 | \n",
233 | " 1493.000000 | \n",
234 | " 1493.000000 | \n",
235 | " 1493.000000 | \n",
236 | " 1493.000000 | \n",
237 | " 1493.000000 | \n",
238 | " 1493.000000 | \n",
239 | "
\n",
240 | " \n",
241 | " mean | \n",
242 | " 746.000000 | \n",
243 | " 0.364367 | \n",
244 | " 0.503014 | \n",
245 | " 0.454119 | \n",
246 | " 0.454119 | \n",
247 | " 0.123912 | \n",
248 | " 0.092431 | \n",
249 | " 0.328198 | \n",
250 | " 0.432686 | \n",
251 | " 0.277294 | \n",
252 | " 0.108506 | \n",
253 | " 0.168118 | \n",
254 | " 0.113195 | \n",
255 | "
\n",
256 | " \n",
257 | " std | \n",
258 | " 431.136289 | \n",
259 | " 0.481413 | \n",
260 | " 0.500158 | \n",
261 | " 0.498057 | \n",
262 | " 0.498057 | \n",
263 | " 0.329591 | \n",
264 | " 0.289731 | \n",
265 | " 0.469715 | \n",
266 | " 0.495614 | \n",
267 | " 0.447813 | \n",
268 | " 0.311123 | \n",
269 | " 0.374096 | \n",
270 | " 0.316937 | \n",
271 | "
\n",
272 | " \n",
273 | " min | \n",
274 | " 0.000000 | \n",
275 | " 0.000000 | \n",
276 | " 0.000000 | \n",
277 | " 0.000000 | \n",
278 | " 0.000000 | \n",
279 | " 0.000000 | \n",
280 | " 0.000000 | \n",
281 | " 0.000000 | \n",
282 | " 0.000000 | \n",
283 | " 0.000000 | \n",
284 | " 0.000000 | \n",
285 | " 0.000000 | \n",
286 | " 0.000000 | \n",
287 | "
\n",
288 | " \n",
289 | " 25% | \n",
290 | " 373.000000 | \n",
291 | " 0.000000 | \n",
292 | " 0.000000 | \n",
293 | " 0.000000 | \n",
294 | " 0.000000 | \n",
295 | " 0.000000 | \n",
296 | " 0.000000 | \n",
297 | " 0.000000 | \n",
298 | " 0.000000 | \n",
299 | " 0.000000 | \n",
300 | " 0.000000 | \n",
301 | " 0.000000 | \n",
302 | " 0.000000 | \n",
303 | "
\n",
304 | " \n",
305 | " 50% | \n",
306 | " 746.000000 | \n",
307 | " 0.000000 | \n",
308 | " 1.000000 | \n",
309 | " 0.000000 | \n",
310 | " 0.000000 | \n",
311 | " 0.000000 | \n",
312 | " 0.000000 | \n",
313 | " 0.000000 | \n",
314 | " 0.000000 | \n",
315 | " 0.000000 | \n",
316 | " 0.000000 | \n",
317 | " 0.000000 | \n",
318 | " 0.000000 | \n",
319 | "
\n",
320 | " \n",
321 | " 75% | \n",
322 | " 1119.000000 | \n",
323 | " 1.000000 | \n",
324 | " 1.000000 | \n",
325 | " 1.000000 | \n",
326 | " 1.000000 | \n",
327 | " 0.000000 | \n",
328 | " 0.000000 | \n",
329 | " 1.000000 | \n",
330 | " 1.000000 | \n",
331 | " 1.000000 | \n",
332 | " 0.000000 | \n",
333 | " 0.000000 | \n",
334 | " 0.000000 | \n",
335 | "
\n",
336 | " \n",
337 | " max | \n",
338 | " 1492.000000 | \n",
339 | " 1.000000 | \n",
340 | " 1.000000 | \n",
341 | " 1.000000 | \n",
342 | " 1.000000 | \n",
343 | " 1.000000 | \n",
344 | " 1.000000 | \n",
345 | " 1.000000 | \n",
346 | " 1.000000 | \n",
347 | " 1.000000 | \n",
348 | " 1.000000 | \n",
349 | " 1.000000 | \n",
350 | " 1.000000 | \n",
351 | "
\n",
352 | " \n",
353 | "
\n",
354 | "
"
355 | ],
356 | "text/plain": [
357 | " id anger anticipation disgust fear \\\n",
358 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n",
359 | "mean 746.000000 0.364367 0.503014 0.454119 0.454119 \n",
360 | "std 431.136289 0.481413 0.500158 0.498057 0.498057 \n",
361 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
362 | "25% 373.000000 0.000000 0.000000 0.000000 0.000000 \n",
363 | "50% 746.000000 0.000000 1.000000 0.000000 0.000000 \n",
364 | "75% 1119.000000 1.000000 1.000000 1.000000 1.000000 \n",
365 | "max 1492.000000 1.000000 1.000000 1.000000 1.000000 \n",
366 | "\n",
367 | " joy love optimism pessimism sadness \\\n",
368 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n",
369 | "mean 0.123912 0.092431 0.328198 0.432686 0.277294 \n",
370 | "std 0.329591 0.289731 0.469715 0.495614 0.447813 \n",
371 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
372 | "25% 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
373 | "50% 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
374 | "75% 0.000000 0.000000 1.000000 1.000000 1.000000 \n",
375 | "max 1.000000 1.000000 1.000000 1.000000 1.000000 \n",
376 | "\n",
377 | " surprise trust neutral \n",
378 | "count 1493.000000 1493.000000 1493.000000 \n",
379 | "mean 0.108506 0.168118 0.113195 \n",
380 | "std 0.311123 0.374096 0.316937 \n",
381 | "min 0.000000 0.000000 0.000000 \n",
382 | "25% 0.000000 0.000000 0.000000 \n",
383 | "50% 0.000000 0.000000 0.000000 \n",
384 | "75% 0.000000 0.000000 0.000000 \n",
385 | "max 1.000000 1.000000 1.000000 "
386 | ]
387 | },
388 | "execution_count": 3,
389 | "metadata": {},
390 | "output_type": "execute_result"
391 | }
392 | ],
393 | "source": [
394 | "#basic stats\n",
395 | "data.describe()"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 11,
401 | "metadata": {},
402 | "outputs": [
403 | {
404 | "data": {
405 | "text/plain": [
406 | "0 949\n",
407 | "1 544\n",
408 | "Name: anger, dtype: int64"
409 | ]
410 | },
411 | "execution_count": 11,
412 | "metadata": {},
413 | "output_type": "execute_result"
414 | }
415 | ],
416 | "source": [
417 | "data['anger'].value_counts()"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 15,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": []
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 16,
430 | "metadata": {},
431 | "outputs": [
432 | {
433 | "data": {
434 | "text/plain": [
435 | "0 1003\n",
436 | "1 490\n",
437 | "Name: optimism, dtype: int64"
438 | ]
439 | },
440 | "execution_count": 16,
441 | "metadata": {},
442 | "output_type": "execute_result"
443 | }
444 | ],
445 | "source": [
446 | "freqs = {\"anger\":data['anger'].value_counts()[1]\n",
447 | "anticipation = data['anticipation'].value_counts()[1]\n",
448 | "data['disgust'].value_counts()[1]\n",
449 | "data['fear'].value_counts()[1]\n",
450 | "data['joy'].value_counts()[1]\n",
451 | "data['love'].value_counts()[1]\n",
452 | "data['optimism'].value_counts()[1]\n",
453 | "data['pessimism'].value_counts()[1]\n",
454 | "data['sadness'].value_counts()[1]\n",
455 | "data['surprise'].value_counts()[1]\n",
456 | "data['trust'].value_counts()[1]\n",
457 | "data['neutral'].value_counts()[1]"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "metadata": {},
464 | "outputs": [],
465 | "source": []
466 | }
467 | ],
468 | "metadata": {
469 | "kernelspec": {
470 | "display_name": "Python 3",
471 | "language": "python",
472 | "name": "python3"
473 | },
474 | "language_info": {
475 | "codemirror_mode": {
476 | "name": "ipython",
477 | "version": 3
478 | },
479 | "file_extension": ".py",
480 | "mimetype": "text/x-python",
481 | "name": "python",
482 | "nbconvert_exporter": "python",
483 | "pygments_lexer": "ipython3",
484 | "version": "3.7.4"
485 | }
486 | },
487 | "nbformat": 4,
488 | "nbformat_minor": 2
489 | }
490 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/model_evaluation-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Multi Emotion Detection from COVID-19 Text using BERT ##
2 |
3 | ### Requirements ###
4 |
5 | The code was tested with Python 3.8 and PyTorch 1.5.0. The requirements can be installed with the following command.
6 | ```
7 | pip install -r requirements.txt
8 | ```
9 | You may have to change the package names (append `+cpu`) for the non-CUDA version.
10 |
11 | ### Data Preparation ###
12 | The model expects its data in CSV format, so you first need to convert the JSON files into CSV files. If your data is in any other format, you may need to modify the code.
13 |
14 | To generate the CSV files, run the following command.
15 |
16 | ```
17 | python .\data_generator.py --file=D:\UTD\Assignment\NLP\project\nlp_test.json --csvfile=D:\UTD\Assignment\NLP\project\nlp_test.csv
18 | ```
19 |
20 | Here, ```--file``` is the input JSON file and ```--csvfile``` is the path where the converted file is stored.
21 |
22 |
23 | ### Training ###
24 | Once you have the files in the required format, you can start training. You may want to change the parameters; I tried several settings, and the script's defaults are the ones that gave the best results.
25 | ```
26 | python train_bert.py --epochs=15
27 | ```
28 |
29 | You can find all available options by running the following command.
30 |
31 | ```
32 | python train_bert.py --help
33 | ```
34 | The following graphs show the loss (1) and the learning rate (2) over time.
35 |
36 |
37 |
38 |
39 |
40 |
41 | ### Inference ###
42 | Once you have the trained model, you can run inference on the test CSV files. Note that, as of now, this script requires annotated data to compute the metrics, but it can easily be modified to generate predictions only.
43 |
44 | A pretrained model can be found [here](https://utdallas.box.com/s/sqqb0n9qe7txb6j3725aiz76gwlmszuw).
45 |
46 | ```
47 | python inference.py --test_csv=D:\\UTD\\Assignment\\NLP\\project\\nlp_valid.csv --model_dir=D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20
48 | ```
49 |
50 | If ```--evaluation``` is set to true, it will output various metrics.
51 |
52 | **Threshold**
53 | One important factor is the threshold applied to the confidence score. I tested various thresholds and found that 0.0017 gives the best results for the model linked above. If you train your own model, you may want to run ```find_threshold.py``` to find the best threshold.
54 |
55 | ### Model Evaluation ###
56 |
57 | If you want to evaluate your model on the test or train set, you can do so by running the following command. Note that the file must be in CSV format.
58 |
59 | ```
60 | python inference.py --test_csv=D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv --evaluation=True
61 | ```
62 |
63 | I ran the evaluation on the train set and obtained the following results.
64 |
65 |
66 | Emotion | Precision | Recall | f1-score
67 | ---------------|---------------|------------|---------------
68 | Anger | 0.97 | 0.95 | 0.96
69 | Anticipation | 0.98 | 1.00 | 0.99
70 | Disgust | 0.97 | 0.96 | 0.97
71 | Fear | 1.00 | 1.00 | 1.00
72 | Joy | 0.93 | 0.93 | 0.93
73 | Love | 1.00 | 0.73 | 0.85
74 | Optimism | 0.91 | 1.00 | 0.95
75 | Pessimism | 0.98 | 0.94 | 0.96
76 | Sadness | 0.99 | 0.92 | 0.95
77 | Surprise | 0.98 | 0.92 | 0.95
78 | Trust | 0.95 | 0.88 | 0.91
79 | Neutral | 0.92 | 1.00 | 0.96
80 | Average | 0.98 | 0.97 | 0.97
81 |
82 | The following table shows the results on the test set.
83 |
84 | Emotion | Precision | Recall | f1-score
85 | ---------------|---------------|------------|---------------
86 | Anger | 0.53 | 0.67 | 0.59
87 | Anticipation | 0.67 | 0.68 | 0.68
88 | Disgust | 0.62 | 0.78 | 0.69
89 | Fear | 0.69 | 0.72 | 0.71
90 | Joy | 0.50 | 0.27 | 0.35
91 | Love | 0.32 | 0.37 | 0.34
92 | Optimism | 0.37 | 0.59 | 0.45
93 | Pessimism | 0.44 | 0.73 | 0.55
94 | Sadness | 0.42 | 0.53 | 0.47
95 | Surprise | 0.55 | 0.38 | 0.45
96 | Trust | 0.12 | 0.12 | 0.12
97 | Neutral | 0.37 | 0.37 | 0.37
98 | Average | 0.60 | 0.62 | 0.55
99 |
100 | If we set the threshold to 0.02, the average accuracy is 0.66.
101 |
102 | ### Possible Improvements
103 |
104 | 
105 |
106 | The biggest caveat here was the class imbalance. It is well established that class imbalance can negatively affect a model, so it is a good idea to balance the data before training. Due to limited time, I did not do that here, but ideally we would oversample the minority classes or pass class weights to the loss function. I implemented both approaches for image classification [here](https://github.com/savan77/Transfer-Learning).
107 |
108 | As we can see above, negative emotions such as anger or pessimism are much better represented in the data than positive emotions. This makes sense, but to train a good model we should even out the distribution.
109 |
110 | Moreover, we can experiment with the network. Here, a standard off-the-shelf text classification network was used; adding more fully connected layers on top of BERT may help.
--------------------------------------------------------------------------------
/assets/class_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/assets/class_distribution.png
--------------------------------------------------------------------------------
/assets/loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/lr.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/nlp_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/assets/nlp_report.pdf
--------------------------------------------------------------------------------
/cache/cached_bert_dev_multi_label_512_nlp_valid.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/cache/cached_bert_dev_multi_label_512_nlp_valid.csv
--------------------------------------------------------------------------------
/cache/cached_bert_train_multi_label_512_nlp_train.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/savan77/EmotionDetectionBERT/39f859d48250d84e0cef7d1fb9163c37afe6dbfa/cache/cached_bert_train_multi_label_512_nlp_train.csv
--------------------------------------------------------------------------------
/data_exploration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import matplotlib\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "%matplotlib inline"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "data": {
22 | "text/html": [
23 | "\n",
24 | "\n",
37 | "
\n",
38 | " \n",
39 | " \n",
40 | " | \n",
41 | " id | \n",
42 | " text | \n",
43 | " anger | \n",
44 | " anticipation | \n",
45 | " disgust | \n",
46 | " fear | \n",
47 | " joy | \n",
48 | " love | \n",
49 | " optimism | \n",
50 | " pessimism | \n",
51 | " sadness | \n",
52 | " surprise | \n",
53 | " trust | \n",
54 | " neutral | \n",
55 | "
\n",
56 | " \n",
57 | " \n",
58 | " \n",
59 | " 0 | \n",
60 | " 0 | \n",
61 | " He was answering a question about the criticis... | \n",
62 | " 1 | \n",
63 | " 0 | \n",
64 | " 1 | \n",
65 | " 0 | \n",
66 | " 0 | \n",
67 | " 0 | \n",
68 | " 0 | \n",
69 | " 1 | \n",
70 | " 0 | \n",
71 | " 0 | \n",
72 | " 0 | \n",
73 | " 0 | \n",
74 | "
\n",
75 | " \n",
76 | " 1 | \n",
77 | " 1 | \n",
78 | " I'm going to start today's discussion thread w... | \n",
79 | " 1 | \n",
80 | " 1 | \n",
81 | " 1 | \n",
82 | " 1 | \n",
83 | " 0 | \n",
84 | " 0 | \n",
85 | " 0 | \n",
86 | " 1 | \n",
87 | " 0 | \n",
88 | " 0 | \n",
89 | " 0 | \n",
90 | " 0 | \n",
91 | "
\n",
92 | " \n",
93 | " 2 | \n",
94 | " 2 | \n",
95 | " By announcing the 395 self-quarantined, it pai... | \n",
96 | " 1 | \n",
97 | " 1 | \n",
98 | " 1 | \n",
99 | " 1 | \n",
100 | " 0 | \n",
101 | " 0 | \n",
102 | " 0 | \n",
103 | " 1 | \n",
104 | " 0 | \n",
105 | " 0 | \n",
106 | " 0 | \n",
107 | " 0 | \n",
108 | "
\n",
109 | " \n",
110 | " 3 | \n",
111 | " 3 | \n",
112 | " Likewise, sorry if I offended you. I’m not act... | \n",
113 | " 1 | \n",
114 | " 0 | \n",
115 | " 1 | \n",
116 | " 1 | \n",
117 | " 0 | \n",
118 | " 0 | \n",
119 | " 0 | \n",
120 | " 1 | \n",
121 | " 0 | \n",
122 | " 0 | \n",
123 | " 0 | \n",
124 | " 0 | \n",
125 | "
\n",
126 | " \n",
127 | " 4 | \n",
128 | " 4 | \n",
129 | " People infected by experience high fever, coug... | \n",
130 | " 0 | \n",
131 | " 0 | \n",
132 | " 0 | \n",
133 | " 0 | \n",
134 | " 0 | \n",
135 | " 0 | \n",
136 | " 0 | \n",
137 | " 0 | \n",
138 | " 0 | \n",
139 | " 0 | \n",
140 | " 0 | \n",
141 | " 1 | \n",
142 | "
\n",
143 | " \n",
144 | "
\n",
145 | "
"
146 | ],
147 | "text/plain": [
148 | " id text anger anticipation \\\n",
149 | "0 0 He was answering a question about the criticis... 1 0 \n",
150 | "1 1 I'm going to start today's discussion thread w... 1 1 \n",
151 | "2 2 By announcing the 395 self-quarantined, it pai... 1 1 \n",
152 | "3 3 Likewise, sorry if I offended you. I’m not act... 1 0 \n",
153 | "4 4 People infected by experience high fever, coug... 0 0 \n",
154 | "\n",
155 | " disgust fear joy love optimism pessimism sadness surprise trust \\\n",
156 | "0 1 0 0 0 0 1 0 0 0 \n",
157 | "1 1 1 0 0 0 1 0 0 0 \n",
158 | "2 1 1 0 0 0 1 0 0 0 \n",
159 | "3 1 1 0 0 0 1 0 0 0 \n",
160 | "4 0 0 0 0 0 0 0 0 0 \n",
161 | "\n",
162 | " neutral \n",
163 | "0 0 \n",
164 | "1 0 \n",
165 | "2 0 \n",
166 | "3 0 \n",
167 | "4 1 "
168 | ]
169 | },
170 | "execution_count": 2,
171 | "metadata": {},
172 | "output_type": "execute_result"
173 | }
174 | ],
175 | "source": [
176 | "# read dataset\n",
177 | "# we will be using trained dataset to understand how people are reacting\n",
178 | "data = pd.read_csv(\"nlp_train.csv\")\n",
179 | "data.head()"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 3,
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "data": {
189 | "text/html": [
190 | "\n",
191 | "\n",
204 | "
\n",
205 | " \n",
206 | " \n",
207 | " | \n",
208 | " id | \n",
209 | " anger | \n",
210 | " anticipation | \n",
211 | " disgust | \n",
212 | " fear | \n",
213 | " joy | \n",
214 | " love | \n",
215 | " optimism | \n",
216 | " pessimism | \n",
217 | " sadness | \n",
218 | " surprise | \n",
219 | " trust | \n",
220 | " neutral | \n",
221 | "
\n",
222 | " \n",
223 | " \n",
224 | " \n",
225 | " count | \n",
226 | " 1493.000000 | \n",
227 | " 1493.000000 | \n",
228 | " 1493.000000 | \n",
229 | " 1493.000000 | \n",
230 | " 1493.000000 | \n",
231 | " 1493.000000 | \n",
232 | " 1493.000000 | \n",
233 | " 1493.000000 | \n",
234 | " 1493.000000 | \n",
235 | " 1493.000000 | \n",
236 | " 1493.000000 | \n",
237 | " 1493.000000 | \n",
238 | " 1493.000000 | \n",
239 | "
\n",
240 | " \n",
241 | " mean | \n",
242 | " 746.000000 | \n",
243 | " 0.364367 | \n",
244 | " 0.503014 | \n",
245 | " 0.454119 | \n",
246 | " 0.454119 | \n",
247 | " 0.123912 | \n",
248 | " 0.092431 | \n",
249 | " 0.328198 | \n",
250 | " 0.432686 | \n",
251 | " 0.277294 | \n",
252 | " 0.108506 | \n",
253 | " 0.168118 | \n",
254 | " 0.113195 | \n",
255 | "
\n",
256 | " \n",
257 | " std | \n",
258 | " 431.136289 | \n",
259 | " 0.481413 | \n",
260 | " 0.500158 | \n",
261 | " 0.498057 | \n",
262 | " 0.498057 | \n",
263 | " 0.329591 | \n",
264 | " 0.289731 | \n",
265 | " 0.469715 | \n",
266 | " 0.495614 | \n",
267 | " 0.447813 | \n",
268 | " 0.311123 | \n",
269 | " 0.374096 | \n",
270 | " 0.316937 | \n",
271 | "
\n",
272 | " \n",
273 | " min | \n",
274 | " 0.000000 | \n",
275 | " 0.000000 | \n",
276 | " 0.000000 | \n",
277 | " 0.000000 | \n",
278 | " 0.000000 | \n",
279 | " 0.000000 | \n",
280 | " 0.000000 | \n",
281 | " 0.000000 | \n",
282 | " 0.000000 | \n",
283 | " 0.000000 | \n",
284 | " 0.000000 | \n",
285 | " 0.000000 | \n",
286 | " 0.000000 | \n",
287 | "
\n",
288 | " \n",
289 | " 25% | \n",
290 | " 373.000000 | \n",
291 | " 0.000000 | \n",
292 | " 0.000000 | \n",
293 | " 0.000000 | \n",
294 | " 0.000000 | \n",
295 | " 0.000000 | \n",
296 | " 0.000000 | \n",
297 | " 0.000000 | \n",
298 | " 0.000000 | \n",
299 | " 0.000000 | \n",
300 | " 0.000000 | \n",
301 | " 0.000000 | \n",
302 | " 0.000000 | \n",
303 | "
\n",
304 | " \n",
305 | " 50% | \n",
306 | " 746.000000 | \n",
307 | " 0.000000 | \n",
308 | " 1.000000 | \n",
309 | " 0.000000 | \n",
310 | " 0.000000 | \n",
311 | " 0.000000 | \n",
312 | " 0.000000 | \n",
313 | " 0.000000 | \n",
314 | " 0.000000 | \n",
315 | " 0.000000 | \n",
316 | " 0.000000 | \n",
317 | " 0.000000 | \n",
318 | " 0.000000 | \n",
319 | "
\n",
320 | " \n",
321 | " 75% | \n",
322 | " 1119.000000 | \n",
323 | " 1.000000 | \n",
324 | " 1.000000 | \n",
325 | " 1.000000 | \n",
326 | " 1.000000 | \n",
327 | " 0.000000 | \n",
328 | " 0.000000 | \n",
329 | " 1.000000 | \n",
330 | " 1.000000 | \n",
331 | " 1.000000 | \n",
332 | " 0.000000 | \n",
333 | " 0.000000 | \n",
334 | " 0.000000 | \n",
335 | "
\n",
336 | " \n",
337 | " max | \n",
338 | " 1492.000000 | \n",
339 | " 1.000000 | \n",
340 | " 1.000000 | \n",
341 | " 1.000000 | \n",
342 | " 1.000000 | \n",
343 | " 1.000000 | \n",
344 | " 1.000000 | \n",
345 | " 1.000000 | \n",
346 | " 1.000000 | \n",
347 | " 1.000000 | \n",
348 | " 1.000000 | \n",
349 | " 1.000000 | \n",
350 | " 1.000000 | \n",
351 | "
\n",
352 | " \n",
353 | "
\n",
354 | "
"
355 | ],
356 | "text/plain": [
357 | " id anger anticipation disgust fear \\\n",
358 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n",
359 | "mean 746.000000 0.364367 0.503014 0.454119 0.454119 \n",
360 | "std 431.136289 0.481413 0.500158 0.498057 0.498057 \n",
361 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
362 | "25% 373.000000 0.000000 0.000000 0.000000 0.000000 \n",
363 | "50% 746.000000 0.000000 1.000000 0.000000 0.000000 \n",
364 | "75% 1119.000000 1.000000 1.000000 1.000000 1.000000 \n",
365 | "max 1492.000000 1.000000 1.000000 1.000000 1.000000 \n",
366 | "\n",
367 | " joy love optimism pessimism sadness \\\n",
368 | "count 1493.000000 1493.000000 1493.000000 1493.000000 1493.000000 \n",
369 | "mean 0.123912 0.092431 0.328198 0.432686 0.277294 \n",
370 | "std 0.329591 0.289731 0.469715 0.495614 0.447813 \n",
371 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
372 | "25% 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
373 | "50% 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
374 | "75% 0.000000 0.000000 1.000000 1.000000 1.000000 \n",
375 | "max 1.000000 1.000000 1.000000 1.000000 1.000000 \n",
376 | "\n",
377 | " surprise trust neutral \n",
378 | "count 1493.000000 1493.000000 1493.000000 \n",
379 | "mean 0.108506 0.168118 0.113195 \n",
380 | "std 0.311123 0.374096 0.316937 \n",
381 | "min 0.000000 0.000000 0.000000 \n",
382 | "25% 0.000000 0.000000 0.000000 \n",
383 | "50% 0.000000 0.000000 0.000000 \n",
384 | "75% 0.000000 0.000000 0.000000 \n",
385 | "max 1.000000 1.000000 1.000000 "
386 | ]
387 | },
388 | "execution_count": 3,
389 | "metadata": {},
390 | "output_type": "execute_result"
391 | }
392 | ],
393 | "source": [
394 | "#basic stats\n",
395 | "data.describe()"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 11,
401 | "metadata": {},
402 | "outputs": [
403 | {
404 | "data": {
405 | "text/plain": [
406 | "0 949\n",
407 | "1 544\n",
408 | "Name: anger, dtype: int64"
409 | ]
410 | },
411 | "execution_count": 11,
412 | "metadata": {},
413 | "output_type": "execute_result"
414 | }
415 | ],
416 | "source": [
417 | "data['anger'].value_counts()"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 15,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": []
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 17,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": [
433 | "freqs = {\"anger\":data['anger'].value_counts()[1],\n",
434 | "\"anticipation\": data['anticipation'].value_counts()[1],\n",
435 | "\"disgust\":data['disgust'].value_counts()[1],\n",
436 | "\"fear\":data['fear'].value_counts()[1],\n",
437 | "\"joy\":data['joy'].value_counts()[1],\n",
438 | "\"love\":data['love'].value_counts()[1],\n",
439 | "\"optimism\":data['optimism'].value_counts()[1],\n",
440 | "\"pessimism\":data['pessimism'].value_counts()[1],\n",
441 | "\"sadness\":data['sadness'].value_counts()[1],\n",
442 | "\"surprise\":data['surprise'].value_counts()[1],\n",
443 | "\"trust\":data['trust'].value_counts()[1],\n",
444 | "\"neutral\":data['neutral'].value_counts()[1]}"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 18,
450 | "metadata": {},
451 | "outputs": [
452 | {
453 | "data": {
454 | "text/plain": [
455 | "{'anger': 544,\n",
456 | " 'anticipation': 751,\n",
457 | " 'disgust': 678,\n",
458 | " 'fear': 678,\n",
459 | " 'joy': 185,\n",
460 | " 'love': 138,\n",
461 | " 'optimism': 490,\n",
462 | " 'pessimism': 646,\n",
463 | " 'sadness': 414,\n",
464 | " 'surprise': 162,\n",
465 | " 'trust': 251,\n",
466 | " 'neutral': 169}"
467 | ]
468 | },
469 | "execution_count": 18,
470 | "metadata": {},
471 | "output_type": "execute_result"
472 | }
473 | ],
474 | "source": [
475 | "freqs"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": 85,
481 | "metadata": {},
482 | "outputs": [
483 | {
484 | "data": {
485 | "image/png": "\n",
486 | "text/plain": [
487 | ""
488 | ]
489 | },
490 | "metadata": {
491 | "needs_background": "light"
492 | },
493 | "output_type": "display_data"
494 | }
495 | ],
496 | "source": [
497 | "#plot class distribution\n",
498 | "plt.rcParams['figure.figsize'] = [13, 5]\n",
499 | "plt.bar(freqs.keys(), freqs.values(), width=0.3,color='skyblue')\n",
500 | "plt.text(10,700,\"Target Class Distribution\", fontsize=15, ha='center', va='center')\n",
501 | "t = plt.title(\"Emotions from COVID-19 Related Text\", fontsize=18)\n",
502 | "# t.set_color(\"m\")\n",
503 | "x = plt.xlabel(\"Emotion\", fontsize=13)\n",
504 | "x.set_color('g')\n",
505 | "y = plt.ylabel(\"Frequency\", fontsize=13)\n",
506 | "y.set_color('g')\n",
507 | "plt.savefig('class_distribution.png')\n",
508 | "# [i.set_color(\"c\") for i in plt.gca().get_xticklabels()]\n",
509 | "plt.show()"
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "execution_count": 86,
515 | "metadata": {},
516 | "outputs": [],
517 | "source": [
518 | "# there is a clear class imbalance"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": null,
524 | "metadata": {},
525 | "outputs": [],
526 | "source": []
527 | }
528 | ],
529 | "metadata": {
530 | "kernelspec": {
531 | "display_name": "Python 3",
532 | "language": "python",
533 | "name": "python3"
534 | },
535 | "language_info": {
536 | "codemirror_mode": {
537 | "name": "ipython",
538 | "version": 3
539 | },
540 | "file_extension": ".py",
541 | "mimetype": "text/x-python",
542 | "name": "python",
543 | "nbconvert_exporter": "python",
544 | "pygments_lexer": "ipython3",
545 | "version": "3.7.4"
546 | }
547 | },
548 | "nbformat": 4,
549 | "nbformat_minor": 2
550 | }
551 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | ###
2 | # Generate data (in csv format) to train the BERT model
3 |
4 | import csv
5 | import json
6 | import argparse
7 | import os
8 |
def create_model(llabels, num_labels=12):
    """Convert an emotion annotation mapping into a fixed-length 0/1 vector.

    Despite its name this does not build a model: every *value* of
    ``llabels`` (taken in the mapping's iteration order; keys are ignored)
    becomes 1 when truthy and 0 otherwise, and the result is zero-padded up
    to ``num_labels`` entries.

    Args:
        llabels: Mapping whose values flag each emotion.
            NOTE(review): the value order is assumed to match the emotion
            column order of the CSV header written by ``generate_csv`` --
            confirm against the source JSON.
        num_labels: Length of the returned vector (12 emotions by default).

    Returns:
        list[int]: ``num_labels`` entries, each 0 or 1.

    Raises:
        ValueError: if ``llabels`` has more than ``num_labels`` values.
    """
    flags = [1 if val else 0 for val in llabels.values()]
    if len(flags) > num_labels:
        raise ValueError(
            "got %d emotion values but the vector only has %d slots"
            % (len(flags), num_labels)
        )
    # Missing trailing emotions are treated as absent (0).
    return flags + [0] * (num_labels - len(flags))
15 |
def generate_csv(file, csvfile):
    """Convert a JSON annotation file into the CSV layout used for BERT training.

    The JSON must be an object whose values each provide a ``'body'`` (the
    text) and an ``'emotion'`` mapping; the outer keys are not written. Each
    entry becomes one CSV row: a running integer id starting at 0, the text,
    then the 12 one-hot emotion columns produced by ``create_model``.

    Args:
        file: Path of the input JSON file (decoded as UTF-8).
        csvfile: Path of the CSV file to create or overwrite (UTF-8).
    """
    # Explicit UTF-8: JSON text is UTF-8, while the platform default
    # (e.g. cp1252 on Windows) can fail on non-ASCII characters.
    # Parsing first also means a malformed JSON no longer truncates csvfile.
    with open(file, "r", encoding="utf-8") as src:
        data = json.load(src)

    # newline="" lets the csv module control line endings; the context
    # manager guarantees the rows are flushed and the handle is closed.
    with open(csvfile, "w", encoding="utf-8", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(["id", "text", "anger", "anticipation","disgust","fear","joy","love","optimism","pessimism","sadness","surprise","trust","neutral"])
        for idd, entry in enumerate(data.values()):
            writer.writerow([idd, entry['body']] + create_model(entry['emotion']))
28 |
29 |
if __name__ == "__main__":
    # Command-line entry point: convert one JSON annotation file into one CSV.
    cli = argparse.ArgumentParser()
    for flag, default_value in (("--file", "nlp_train.json"),
                                ("--csvfile", "nlp_train.csv")):
        cli.add_argument(flag, default=default_value, type=str)
    options = cli.parse_args()
    generate_csv(options.file, options.csvfile)
--------------------------------------------------------------------------------
/find_threshold.py:
--------------------------------------------------------------------------------
1 | ######
2 | # Script to find best possible threshold for the confidence.
3 | #
4 |
5 | from fast_bert.prediction import BertClassificationPredictor
6 | import argparse
7 | import csv
8 | import pandas as pd
9 |
def threshold(model, csvs, label_path="D:\\UTD\\Assignment\\NLP\\project\\", thresholds=None):
    """Sweep confidence thresholds and report exact-match accuracy for each.

    Every row of ``csvs`` is scored once by the trained model. Then, for each
    candidate threshold, the labels whose score is >= the threshold form the
    predicted set. A row counts as correct only when that set equals the set
    of gold labels (columns equal to 1) exactly.

    Args:
        model: Directory of the trained fast-bert model.
        csvs: CSV file with a ``text`` column plus one 0/1 column per emotion.
        label_path: Directory containing ``labels.csv`` (defaults to the
            original hard-coded location).
        thresholds: Candidate thresholds to try; ``None`` uses the built-in list.

    Returns:
        dict[str, float]: threshold (as a string) -> fraction of rows predicted
        exactly right. Empty if the CSV has no rows. The dict is also printed.
    """
    labels = ["anger", "anticipation","disgust","fear","joy","love","optimism","pessimism","sadness","surprise","trust","neutral"]
    if thresholds is None:
        # NOTE(review): 0.25 sits between 0.02 and 0.03 and may be a typo for
        # 0.025; kept unchanged so earlier sweeps remain reproducible.
        thresholds = [0.0005,0.00077,0.00079,0.00083,0.00087,0.0009,0.00093,0.00095,0.00099,0.001,0.0012,0.0015,0.00155,0.0016,0.00166,0.0017,0.0019,0.002,0.0021,0.0023,0.0025,0.0028,0.003,0.0035,0.0032,0.0037,0.004,0.0045,0.0047,0.0041,0.005,0.0053,0.0055,0.0062,0.009, 0.007, 0.01, 0.011,0.013,0.014,0.012, 0.015, 0.02, 0.25, 0.03,0.035,0.039]

    data = pd.read_csv(csvs)
    # Parallel per-row lists: a dict keyed by text would silently merge rows
    # that share the same text and skew the accuracy.
    texts = data['text'].tolist()
    targets = [
        {label for label in labels if row[label] == 1}
        for _, row in data.iterrows()
    ]
    if not texts:
        print("No rows to evaluate in", csvs)
        return {}

    predictor = BertClassificationPredictor(
        model_path=model,  # honour the parameter instead of a global ``args``
        label_path=label_path,  # location for labels.csv file
        # NOTE(review): train_bert.py trains with multi_label=True. With
        # multi_label=False the predictor may normalise scores differently,
        # which would explain the very small thresholds. Kept unchanged so the
        # tuned thresholds stay valid -- confirm against fast-bert.
        multi_label=False,
        model_type='bert',
        do_lower_case=False)
    # Each prediction is a sequence of (label, score) pairs.
    multiple_predictions = predictor.predict_batch(texts)

    threshold_accs = {}
    for th in thresholds:
        correct = sum(
            {pred[0] for pred in out if pred[1] >= th} == target
            for out, target in zip(multiple_predictions, targets)
        )
        print("Threshold: ", th, "Correct: ", correct)
        threshold_accs[str(th)] = correct / len(texts)
    print(threshold_accs)
    return threshold_accs
51 |
if __name__ == "__main__":
    # CLI: python find_threshold.py [--model_dir DIR] [--test_csv CSV]
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--model_dir",
        default="D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20",
        help="path to output dir",
    )
    cli_parser.add_argument(
        "--test_csv",
        default="D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv",
    )
    # Keep the module-level name ``args``: code in this module may read it.
    args = cli_parser.parse_args()
    threshold(model=args.model_dir, csvs=args.test_csv)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # Script to run inference on a given csv file
2 |
3 | from fast_bert.prediction import BertClassificationPredictor
4 | import argparse
5 | import csv
6 | import pandas as pd
7 | import os
8 | from sklearn.metrics import classification_report
9 | from sklearn.preprocessing import MultiLabelBinarizer
10 | from pprint import pprint
11 |
def run(model, csvs, threshold, evaluation, label_path="D:\\UTD\\Assignment\\NLP\\project\\"):
    """Predict emotions for every CSV row, save them, and optionally evaluate.

    Writes ``model_output.csv`` next to ``csvs`` with columns
    ``id, text, emotions, target``, where predicted and gold labels are
    written as Python list literals. If ``evaluation`` is truthy, a per-label
    sklearn classification report is printed.

    Args:
        model: Directory of the trained fast-bert model.
        csvs: CSV with ``id`` and ``text`` columns plus one 0/1 column per emotion.
        threshold: A label is predicted when its score is >= this value
            (the same rule find_threshold.py uses to pick it).
        evaluation: Whether to print the classification report.
        label_path: Directory containing ``labels.csv``.

    Returns:
        list[list[str]]: predicted labels for each row, in CSV order.
    """
    labels = ["anger", "anticipation","disgust","fear","joy","love","optimism","pessimism","sadness","surprise","trust","neutral"]

    data = pd.read_csv(csvs)
    # Parallel per-row lists. Keying by text merged duplicate texts and left
    # ``ids`` out of step with the predictions.
    ids = data['id'].tolist()
    texts = data['text'].tolist()
    targets = [
        [label for label in labels if row[label] == 1]
        for _, row in data.iterrows()
    ]

    predictor = BertClassificationPredictor(
        model_path=model,  # honour the parameter instead of a global ``args``
        label_path=label_path,  # location for labels.csv file
        # NOTE(review): train_bert.py trains with multi_label=True; kept as-is
        # because the default threshold was tuned with this setting.
        multi_label=False,
        model_type='bert',
        do_lower_case=False)
    # Each prediction is a sequence of (label, score) pairs.
    multiple_predictions = predictor.predict_batch(texts) if texts else []

    out_path = os.path.join(os.path.dirname(csvs), "model_output.csv")
    outputs = []
    with open(out_path, "w", encoding="utf-8", newline="") as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerow(["id","text", "emotions", "target"])
        for row_id, text, target, out in zip(ids, texts, targets, multiple_predictions):
            predicted = [pred[0] for pred in out if pred[1] >= threshold]
            csv_writer.writerow([row_id, text, predicted, target])
            outputs.append(predicted)

    print("****************\n")
    print("Predictions saved in a file: ", out_path)
    if evaluation and outputs:
        print("\n\n Running Model Evaluation\n")
        # Pin the column order to ``labels``. This keeps target_names aligned
        # (a default binarizer sorts classes alphabetically). It also gives
        # y_true and y_pred identical columns even if a label never occurs.
        binarizer = MultiLabelBinarizer(classes=labels)
        y_true_encoded = binarizer.fit_transform(targets)
        y_pred_encoded = binarizer.transform(outputs)
        # print, not pprint: pprint renders the report as an escaped string.
        print(classification_report(y_true_encoded, y_pred_encoded, target_names=labels))
    return outputs
59 |
def str_to_bool(value):
    """Parse a command-line boolean such as ``true``, ``False``, ``1`` or ``no``.

    argparse has no boolean type. An untyped option (or ``type=bool``) turns
    the string ``"False"`` into a truthy value. Actual ``bool`` values, such
    as defaults, are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir",default="D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20", help="path to output dir")
    parser.add_argument("--test_csv", default="D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv")
    parser.add_argument("--threshold", default=0.0017, type=float)
    # NOTE(review): parsed but not read in this block.
    parser.add_argument("--writeto_file", default=True, type=str_to_bool)
    parser.add_argument("--evaluation", default=True, type=str_to_bool)
    # Keep the module-level name ``args``: code in this module may read it.
    args = parser.parse_args()
    run(args.model_dir, args.test_csv, args.threshold, args.evaluation)
--------------------------------------------------------------------------------
/labels.csv:
--------------------------------------------------------------------------------
1 | anger
2 | anticipation
3 | disgust
4 | fear
5 | joy
6 | love
7 | optimism
8 | pessimism
9 | sadness
10 | surprise
11 | trust
12 | neutral
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch===1.5.0
2 | torchvision===0.6.0
3 | fast-bert
--------------------------------------------------------------------------------
/train_bert.py:
--------------------------------------------------------------------------------
1 | # Training script for bert
2 |
3 | from fast_bert.data_cls import BertDataBunch
4 | from fast_bert.learner_cls import BertLearner
5 | from fast_bert.metrics import accuracy
6 | import logging
7 | import torch
8 | import os
9 | import argparse
10 |
# Default output location for trained models.
# NOTE(review): not referenced in this file; the --out_dir option hard-codes
# the same "model_output/" default.
OUTPUT_DIR = "model_output/"


def train(args):
    """Fine-tune a BERT model for 12-label multi-label emotion classification.

    Reads ``nlp_train.csv``, ``nlp_valid.csv`` and ``labels.csv`` from the
    current directory and saves the fine-tuned model into ``args.out_dir``.

    Args:
        args: Namespace providing ``pretrained_model``, ``out_dir``,
            ``is_onepanel``, ``epochs`` and ``batch_size``. ``args.out_dir``
            is rewritten in place when ``is_onepanel`` is set.
    """
    if args.is_onepanel:
        args.out_dir = os.path.join("/onepanel/output/",args.out_dir)
    # makedirs also creates missing parent directories and tolerates an
    # existing directory (os.mkdir failed on both a missing parent and a race).
    os.makedirs(args.out_dir, exist_ok=True)

    logger = logging.getLogger()
    labels = ["anger", "anticipation","disgust","fear","joy","love","optimism","pessimism","sadness","surprise","trust","neutral"]
    databunch = BertDataBunch(".", ".",
                              tokenizer=args.pretrained_model,
                              train_file='nlp_train.csv',
                              label_file='labels.csv',
                              val_file="nlp_valid.csv",
                              text_col='text',
                              label_col=labels,
                              batch_size_per_gpu=args.batch_size,
                              max_seq_length=512,
                              multi_gpu=False,
                              multi_label=True,
                              model_type='bert')

    # Fall back to CPU so the script still runs (slowly) without a GPU;
    # torch.device("cuda") alone fails at first use on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    metrics = [{'name': 'accuracy', 'function': accuracy}]

    learner = BertLearner.from_pretrained_model(
        databunch,
        pretrained_path=args.pretrained_model,
        metrics=metrics,
        device=device,
        logger=logger,
        output_dir=args.out_dir,
        finetuned_wgts_path=None,
        warmup_steps=200,
        multi_gpu=False,
        is_fp16=False,
        multi_label=True,
        logging_steps=10)

    learner.fit(epochs=args.epochs,
                lr=2e-3,
                schedule_type="warmup_cosine_hard_restarts",
                optimizer_type="lamb")
    learner.save_model()
57 |
58 |
def str_to_bool(value):
    """Parse a command-line boolean such as ``true``, ``False``, ``1`` or ``no``.

    ``type=bool`` is a trap in argparse: ``bool("False")`` is True, so
    ``--is_onepanel False`` used to enable OnePanel mode. Actual ``bool``
    values, such as defaults, are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model", default="bert-base-uncased", help="path to a pretrained model")
    parser.add_argument("--out_dir",default="model_output/", help="path to output dir")
    parser.add_argument("--is_onepanel", default=False, type=str_to_bool, help="train on onepanel cloud")
    parser.add_argument("--epochs", default=15, type=int)
    parser.add_argument("--batch_size", default=10, type=int)
    args = parser.parse_args()
    train(args)
68 |
--------------------------------------------------------------------------------