└── Arabizi_KT_CV82091_LB_8286G.ipynb
/Arabizi_KT_CV82091_LB_8286G.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Arabizi_KT_CV82091_LB_8286G.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU"
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "code",
19 | "metadata": {
20 | "colab": {
21 | "base_uri": "https://localhost:8080/"
22 | },
23 | "id": "ojSAlH9W_gfv",
24 | "outputId": "6bed4e5e-d511-4ccb-e09e-179f561162cc"
25 | },
26 | "source": [
27 | "# Check GPU type\r\n",
28 | "!nvidia-smi"
29 | ],
30 | "execution_count": 1,
31 | "outputs": [
32 | {
33 | "output_type": "stream",
34 | "text": [
35 | "Tue Mar 2 07:08:47 2021 \n",
36 | "+-----------------------------------------------------------------------------+\n",
37 | "| NVIDIA-SMI 460.39 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
38 | "|-------------------------------+----------------------+----------------------+\n",
39 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
40 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
41 | "| | | MIG M. |\n",
42 | "|===============================+======================+======================|\n",
43 | "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n",
44 | "| N/A 34C P0 24W / 300W | 0MiB / 16160MiB | 0% Default |\n",
45 | "| | | N/A |\n",
46 | "+-------------------------------+----------------------+----------------------+\n",
47 | " \n",
48 | "+-----------------------------------------------------------------------------+\n",
49 | "| Processes: |\n",
50 | "| GPU GI CI PID Type Process name GPU Memory |\n",
51 | "| ID ID Usage |\n",
52 | "|=============================================================================|\n",
53 | "| No running processes found |\n",
54 | "+-----------------------------------------------------------------------------+\n"
55 | ],
56 | "name": "stdout"
57 | }
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "metadata": {
63 | "colab": {
64 | "base_uri": "https://localhost:8080/"
65 | },
66 | "id": "5016OlJnp2kC",
67 | "outputId": "1310ee54-5823-4d7d-a882-3ff37756f37e"
68 | },
69 | "source": [
70 | "# Upgrade pip and install ktrain\r\n",
71 | "!pip -qq install -U pip\r\n",
72 | "!pip -qq install ktrain"
73 | ],
74 | "execution_count": 2,
75 | "outputs": [
76 | {
77 | "output_type": "stream",
78 | "text": [
79 | "\u001b[K |████████████████████████████████| 1.5MB 5.4MB/s \n",
80 | "\u001b[K |████████████████████████████████| 25.3 MB 94.2 MB/s \n",
81 | "\u001b[K |████████████████████████████████| 6.8 MB 61.9 MB/s \n",
82 | "\u001b[K |████████████████████████████████| 981 kB 56.6 MB/s \n",
83 | "\u001b[K |████████████████████████████████| 263 kB 58.2 MB/s \n",
84 | "\u001b[K |████████████████████████████████| 1.3 MB 58.2 MB/s \n",
85 | "\u001b[K |████████████████████████████████| 1.2 MB 60.2 MB/s \n",
86 | "\u001b[K |████████████████████████████████| 468 kB 27.6 MB/s \n",
87 | "\u001b[K |████████████████████████████████| 1.1 MB 60.4 MB/s \n",
88 | "\u001b[K |████████████████████████████████| 883 kB 60.1 MB/s \n",
89 | "\u001b[K |████████████████████████████████| 2.9 MB 65.3 MB/s \n",
90 | "\u001b[?25h Building wheel for ktrain (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
91 | " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
92 | " Building wheel for keras-bert (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
93 | " Building wheel for keras-embed-sim (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
94 | " Building wheel for keras-layer-normalization (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
95 | " Building wheel for keras-multi-head (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
96 | " Building wheel for keras-self-attention (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
97 | " Building wheel for keras-pos-embd (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
98 | " Building wheel for keras-position-wise-feed-forward (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
99 | " Building wheel for langdetect (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
100 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
101 | " Building wheel for syntok (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
102 | ],
103 | "name": "stdout"
104 | }
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "metadata": {
110 | "id": "qDHDhHzWrDmm",
111 | "colab": {
112 | "base_uri": "https://localhost:8080/"
113 | },
114 | "outputId": "1c15ae83-3a35-49b0-a6e7-bd9088176a8f"
115 | },
116 | "source": [
117 | "!gdown --id 1LZBMbdMAr8iwmNfN2JkiFw-uGPleBOsf\r\n",
118 | "!unzip -q '/content/Arabizi_data.zip'"
119 | ],
120 | "execution_count": 3,
121 | "outputs": [
122 | {
123 | "output_type": "stream",
124 | "text": [
125 | "Downloading...\n",
126 | "From: https://drive.google.com/uc?id=1LZBMbdMAr8iwmNfN2JkiFw-uGPleBOsf\n",
127 | "To: /content/Arabizi_data.zip\n",
128 | "\r0.00B [00:00, ?B/s]\r3.67MB [00:00, 106MB/s]\n"
129 | ],
130 | "name": "stdout"
131 | }
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "metadata": {
137 | "id": "TWw-1GHGqVI1",
138 | "colab": {
139 | "base_uri": "https://localhost:8080/"
140 | },
141 | "outputId": "eeb20ec7-d532-422d-b491-6ae9932a9d66"
142 | },
143 | "source": [
144 | "# Import libaries\r\n",
145 | "import numpy as np \r\n",
146 | "import pandas as pd\r\n",
147 | "from tqdm import tqdm\r\n",
148 | "import random\r\n",
149 | "import os\r\n",
150 | "import re\r\n",
151 | "import ktrain\r\n",
152 | "from ktrain import text\r\n",
153 | "import tensorflow as tf\r\n",
154 | "from sklearn.model_selection import StratifiedKFold\r\n",
155 | "import string\r\n",
156 | "import nltk\r\n",
157 | "from nltk.tokenize import word_tokenize\r\n",
158 | "from nltk.corpus import stopwords\r\n",
159 | "nltk.download('punkt')\r\n",
160 | "import warnings\r\n",
161 | "warnings.filterwarnings('ignore')"
162 | ],
163 | "execution_count": 95,
164 | "outputs": [
165 | {
166 | "output_type": "stream",
167 | "text": [
168 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
169 | "[nltk_data] Unzipping tokenizers/punkt.zip.\n"
170 | ],
171 | "name": "stdout"
172 | }
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "metadata": {
178 | "id": "gdIMbg9vnM9b"
179 | },
180 | "source": [
181 | "# Set seed\r\n",
182 | "SEED = 3031\r\n",
183 | "\r\n",
184 | "# def set_seeds(seed=SEED):\r\n",
185 | "# os.environ['PYTHONHASHSEED'] = str(seed)\r\n",
186 | "# random.seed(seed)\r\n",
187 | "# tf.random.set_seed(seed)\r\n",
188 | "# np.random.seed(seed)\r\n",
189 | "\r\n",
190 | "# def set_global_determinism(seed=SEED):\r\n",
191 | "# set_seeds(seed=seed)\r\n",
192 | "\r\n",
193 | "# os.environ['TF_DETERMINISTIC_OPS'] = '1'\r\n",
194 | "# os.environ['TF_CUDNN_DETERMINISTIC'] = '1'\r\n",
195 | " \r\n",
196 | "# tf.config.threading.set_inter_op_parallelism_threads(1)\r\n",
197 | "# tf.config.threading.set_intra_op_parallelism_threads(1)\r\n",
198 | "\r\n",
199 | "# set_global_determinism(seed=SEED)"
200 | ],
201 | "execution_count": 5,
202 | "outputs": []
203 | },
204 | {
205 | "cell_type": "code",
206 | "metadata": {
207 | "id": "7rmqE-HZR3FK"
208 | },
209 | "source": [
210 | "def clean_text(text):\r\n",
211 | " '''Make text lowercase, remove text in square brackets,remove links,remove punctuation\r\n",
212 | " and remove words containing numbers.'''\r\n",
213 | " text = text.lower()\r\n",
214 | " text = re.sub('\\[.*?\\]', '', text)\r\n",
215 | " text = re.sub('https?://\\S+|www\\.\\S+', '', text)\r\n",
216 | " text = re.sub('<.*?>+', '', text)\r\n",
217 | " text = re.sub('[%s]' % re.escape(string.punctuation), '', text)\r\n",
218 | " text = re.sub('\\n', '', text)\r\n",
219 | " text = re.sub('\\w*\\d\\w*', '', text)\r\n",
220 | " return text\r\n",
221 | "\r\n",
222 | "def text_preprocessing(text):\r\n",
223 | " \"\"\"\r\n",
224 | " Cleaning and parsing the text.\r\n",
225 | "\r\n",
226 | " \"\"\"\r\n",
227 | " tokenizer = nltk.tokenize.RegexpTokenizer(r'\\w+')\r\n",
228 | " nopunc = clean_text(text)\r\n",
229 | " tokenized_text = tokenizer.tokenize(nopunc)\r\n",
230 | " #remove_stopwords = [w for w in tokenized_text if w not in stopwords.words('english')]\r\n",
231 | " combined_text = ' '.join(tokenized_text)\r\n",
232 | " return combined_text\r\n",
233 | "\r\n",
234 | "def text_cleaner(text):\r\n",
235 | " text = re.sub('\\s+',' ', text)\r\n",
236 | " text = text.strip()\r\n",
237 | " text = re.sub(r'(.)\\1+', r'\\1\\1', text)\r\n",
238 | " return text"
239 | ],
240 | "execution_count": 77,
241 | "outputs": []
242 | },
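{
"cell_type": "code",
"metadata": {},
"source": [
"# Editor-added sketch (not executed in the original run): a quick sanity check of the\r\n",
"# helpers above on a hypothetical Arabizi comment. Note that text_preprocessing drops\r\n",
"# any word containing a digit (via clean_text), which would delete common Arabizi\r\n",
"# tokens such as '3sbaa' or '7abit' -- presumably why only text_cleaner is applied to\r\n",
"# the data below.\r\n",
"sample_comment = '3sbaaaaaaaa  lek   ou le   seim riahi!!!'\r\n",
"print(text_cleaner(sample_comment))        # '3sbaa lek ou le seim riahi!!'\r\n",
"print(text_preprocessing(sample_comment))  # 'lek ou le seim riahi' (digit-bearing word removed)"
],
"execution_count": null,
"outputs": []
},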
243 | {
244 | "cell_type": "code",
245 | "metadata": {
246 | "id": "q0ECuC-bqVGi",
247 | "colab": {
248 | "base_uri": "https://localhost:8080/",
249 | "height": 195
250 | },
251 | "outputId": "0d391ed6-5450-4592-837e-bc5401578125"
252 | },
253 | "source": [
254 | "train = pd.read_csv('/content/Arabizi_data/Train.csv')\r\n",
255 | "test = pd.read_csv('/content/Arabizi_data/Test.csv')\r\n",
256 | "sample = pd.read_csv('/content/Arabizi_data/SampleSubmission.csv')\r\n",
257 | "train.head()"
258 | ],
259 | "execution_count": 122,
260 | "outputs": [
261 | {
262 | "output_type": "execute_result",
263 | "data": {
264 | "text/html": [
265 | "
\n",
266 | "\n",
279 | "
\n",
280 | " \n",
281 | " \n",
282 | " | \n",
283 | " ID | \n",
284 | " text | \n",
285 | " label | \n",
286 | "
\n",
287 | " \n",
288 | " \n",
289 | " \n",
290 | " | 0 | \n",
291 | " 13P0QT0 | \n",
292 | " 3sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o... | \n",
293 | " -1 | \n",
294 | "
\n",
295 | " \n",
296 | " | 1 | \n",
297 | " SKCLXCJ | \n",
298 | " cha3eb fey9elkoum menghir ta7ayoul ou kressi | \n",
299 | " -1 | \n",
300 | "
\n",
301 | " \n",
302 | " | 2 | \n",
303 | " V1TVXIJ | \n",
304 | " bereau degage nathef ya slim walahi ya7chiw fi... | \n",
305 | " -1 | \n",
306 | "
\n",
307 | " \n",
308 | " | 3 | \n",
309 | " U0TTYY8 | \n",
310 | " ak slouma | \n",
311 | " 1 | \n",
312 | "
\n",
313 | " \n",
314 | " | 4 | \n",
315 | " 68DX797 | \n",
316 | " entom titmanou lina a7na 3iid moubarik a7na ch... | \n",
317 | " -1 | \n",
318 | "
\n",
319 | " \n",
320 | "
\n",
321 | "
"
322 | ],
323 | "text/plain": [
324 | " ID text label\n",
325 | "0 13P0QT0 3sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o... -1\n",
326 | "1 SKCLXCJ cha3eb fey9elkoum menghir ta7ayoul ou kressi -1\n",
327 | "2 V1TVXIJ bereau degage nathef ya slim walahi ya7chiw fi... -1\n",
328 | "3 U0TTYY8 ak slouma 1\n",
329 | "4 68DX797 entom titmanou lina a7na 3iid moubarik a7na ch... -1"
330 | ]
331 | },
332 | "metadata": {
333 | "tags": []
334 | },
335 | "execution_count": 122
336 | }
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "metadata": {
342 | "colab": {
343 | "base_uri": "https://localhost:8080/"
344 | },
345 | "id": "Zmf9Co_h3Thg",
346 | "outputId": "caf8dd00-8415-42ed-b8a9-1c5a1d710501"
347 | },
348 | "source": [
349 | "train.label.value_counts()"
350 | ],
351 | "execution_count": 123,
352 | "outputs": [
353 | {
354 | "output_type": "execute_result",
355 | "data": {
356 | "text/plain": [
357 | " 1 38239\n",
358 | "-1 29295\n",
359 | " 0 2466\n",
360 | "Name: label, dtype: int64"
361 | ]
362 | },
363 | "metadata": {
364 | "tags": []
365 | },
366 | "execution_count": 123
367 | }
368 | ]
369 | },
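{
"cell_type": "code",
"metadata": {},
"source": [
"# Editor-added sketch (not executed in the original run): express the class counts\r\n",
"# above as proportions to make the imbalance explicit -- class '0' is only about\r\n",
"# 3.5% of the training data (2466 / 70000).\r\n",
"print(train.label.value_counts(normalize=True).round(4))"
],
"execution_count": null,
"outputs": []
},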
370 | {
371 | "cell_type": "code",
372 | "metadata": {
373 | "id": "X6GUtTyX3cwv"
374 | },
375 | "source": [
376 | "train.label = train.label.astype(str)"
377 | ],
378 | "execution_count": 124,
379 | "outputs": []
380 | },
381 | {
382 | "cell_type": "code",
383 | "metadata": {
384 | "colab": {
385 | "base_uri": "https://localhost:8080/",
386 | "height": 195
387 | },
388 | "id": "iJu-Skwq_JaM",
389 | "outputId": "8a4b1f2c-6b25-42a6-d177-e311e7591243"
390 | },
391 | "source": [
392 | "# Preview last five rows in test\r\n",
393 | "test.tail()"
394 | ],
395 | "execution_count": 125,
396 | "outputs": [
397 | {
398 | "output_type": "execute_result",
399 | "data": {
400 | "text/html": [
401 | "\n",
402 | "\n",
415 | "
\n",
416 | " \n",
417 | " \n",
418 | " | \n",
419 | " ID | \n",
420 | " text | \n",
421 | "
\n",
422 | " \n",
423 | " \n",
424 | " \n",
425 | " | 29995 | \n",
426 | " NHXTL3R | \n",
427 | " me ihebekch raw | \n",
428 | "
\n",
429 | " \n",
430 | " | 29996 | \n",
431 | " U1YWB2O | \n",
432 | " nchallah rabi m3ak w iwaf9ek mais just 7abit n... | \n",
433 | "
\n",
434 | " \n",
435 | " | 29997 | \n",
436 | " O3KYLM0 | \n",
437 | " slim rabi m3ak w e5edem w 5alli l7ossed lemnay... | \n",
438 | "
\n",
439 | " \n",
440 | " | 29998 | \n",
441 | " W4C38TY | \n",
442 | " bara 5alis rouhik yizi mitbal3it jam3iya hlaki... | \n",
443 | "
\n",
444 | " \n",
445 | " | 29999 | \n",
446 | " 4NNX5QE | \n",
447 | " rabi m3aaaak ya khawlaaa n7ebouuuuk rana barsh... | \n",
448 | "
\n",
449 | " \n",
450 | "
\n",
451 | "
"
452 | ],
453 | "text/plain": [
454 | " ID text\n",
455 | "29995 NHXTL3R me ihebekch raw\n",
456 | "29996 U1YWB2O nchallah rabi m3ak w iwaf9ek mais just 7abit n...\n",
457 | "29997 O3KYLM0 slim rabi m3ak w e5edem w 5alli l7ossed lemnay...\n",
458 | "29998 W4C38TY bara 5alis rouhik yizi mitbal3it jam3iya hlaki...\n",
459 | "29999 4NNX5QE rabi m3aaaak ya khawlaaa n7ebouuuuk rana barsh..."
460 | ]
461 | },
462 | "metadata": {
463 | "tags": []
464 | },
465 | "execution_count": 125
466 | }
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "metadata": {
472 | "colab": {
473 | "base_uri": "https://localhost:8080/"
474 | },
475 | "id": "gz_vbjQj3A94",
476 | "outputId": "455d7e45-ce09-42d8-9746-748bcfbaec6c"
477 | },
478 | "source": [
479 | "train.shape, test.shape, sample.shape"
480 | ],
481 | "execution_count": 126,
482 | "outputs": [
483 | {
484 | "output_type": "execute_result",
485 | "data": {
486 | "text/plain": [
487 | "((70000, 3), (30000, 2), (30000, 2))"
488 | ]
489 | },
490 | "metadata": {
491 | "tags": []
492 | },
493 | "execution_count": 126
494 | }
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "metadata": {
500 | "colab": {
501 | "base_uri": "https://localhost:8080/",
502 | "height": 195
503 | },
504 | "id": "bm4MnMIBSHfr",
505 | "outputId": "248b33be-5dff-4c5a-f7c4-5606395c68e3"
506 | },
507 | "source": [
508 | "tqdm.pandas()\r\n",
509 | "train['clean_text'] = train.text.apply(lambda x: text_cleaner(x))\r\n",
510 | "train.head()"
511 | ],
512 | "execution_count": 127,
513 | "outputs": [
514 | {
515 | "output_type": "execute_result",
516 | "data": {
517 | "text/html": [
518 | "\n",
519 | "\n",
532 | "
\n",
533 | " \n",
534 | " \n",
535 | " | \n",
536 | " ID | \n",
537 | " text | \n",
538 | " label | \n",
539 | " clean_text | \n",
540 | "
\n",
541 | " \n",
542 | " \n",
543 | " \n",
544 | " | 0 | \n",
545 | " 13P0QT0 | \n",
546 | " 3sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o... | \n",
547 | " -1 | \n",
548 | " 3sbaa lek ou le seim riahi ou 3sbaa le ca | \n",
549 | "
\n",
550 | " \n",
551 | " | 1 | \n",
552 | " SKCLXCJ | \n",
553 | " cha3eb fey9elkoum menghir ta7ayoul ou kressi | \n",
554 | " -1 | \n",
555 | " cha3eb fey9elkoum menghir ta7ayoul ou kressi | \n",
556 | "
\n",
557 | " \n",
558 | " | 2 | \n",
559 | " V1TVXIJ | \n",
560 | " bereau degage nathef ya slim walahi ya7chiw fi... | \n",
561 | " -1 | \n",
562 | " bereau degage nathef ya slim walahi ya7chiw fi... | \n",
563 | "
\n",
564 | " \n",
565 | " | 3 | \n",
566 | " U0TTYY8 | \n",
567 | " ak slouma | \n",
568 | " 1 | \n",
569 | " ak slouma | \n",
570 | "
\n",
571 | " \n",
572 | " | 4 | \n",
573 | " 68DX797 | \n",
574 | " entom titmanou lina a7na 3iid moubarik a7na ch... | \n",
575 | " -1 | \n",
576 | " entom titmanou lina a7na 3iid moubarik a7na ch... | \n",
577 | "
\n",
578 | " \n",
579 | "
\n",
580 | "
"
581 | ],
582 | "text/plain": [
583 | " ID ... clean_text\n",
584 | "0 13P0QT0 ... 3sbaa lek ou le seim riahi ou 3sbaa le ca\n",
585 | "1 SKCLXCJ ... cha3eb fey9elkoum menghir ta7ayoul ou kressi\n",
586 | "2 V1TVXIJ ... bereau degage nathef ya slim walahi ya7chiw fi...\n",
587 | "3 U0TTYY8 ... ak slouma\n",
588 | "4 68DX797 ... entom titmanou lina a7na 3iid moubarik a7na ch...\n",
589 | "\n",
590 | "[5 rows x 4 columns]"
591 | ]
592 | },
593 | "metadata": {
594 | "tags": []
595 | },
596 | "execution_count": 127
597 | }
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "metadata": {
603 | "colab": {
604 | "base_uri": "https://localhost:8080/",
605 | "height": 195
606 | },
607 | "id": "ohmweWGMS6eD",
608 | "outputId": "d9e7865a-60ea-4a51-893e-e6f360521b31"
609 | },
610 | "source": [
611 | "test['clean_text'] = test.text.apply(lambda x: text_cleaner(x))\r\n",
612 | "test.head()"
613 | ],
614 | "execution_count": 128,
615 | "outputs": [
616 | {
617 | "output_type": "execute_result",
618 | "data": {
619 | "text/html": [
620 | "\n",
621 | "\n",
634 | "
\n",
635 | " \n",
636 | " \n",
637 | " | \n",
638 | " ID | \n",
639 | " text | \n",
640 | " clean_text | \n",
641 | "
\n",
642 | " \n",
643 | " \n",
644 | " \n",
645 | " | 0 | \n",
646 | " 2DDHQW9 | \n",
647 | " barcha aaindou fiha hak w barcha teflim kadhalik | \n",
648 | " barcha aaindou fiha hak w barcha teflim kadhalik | \n",
649 | "
\n",
650 | " \n",
651 | " | 1 | \n",
652 | " 5HY6UEY | \n",
653 | " ye gernabou ye 9a7ba | \n",
654 | " ye gernabou ye 9a7ba | \n",
655 | "
\n",
656 | " \n",
657 | " | 2 | \n",
658 | " ATNVUJX | \n",
659 | " saber w barra rabbi m3ak 5ouya | \n",
660 | " saber w barra rabbi m3ak 5ouya | \n",
661 | "
\n",
662 | " \n",
663 | " | 3 | \n",
664 | " Q9XYVOQ | \n",
665 | " cha3ébbb ta7aaaaannnnnnnnnnn tfouuhh | \n",
666 | " cha3ébb ta7aann tfouuhh | \n",
667 | "
\n",
668 | " \n",
669 | " | 4 | \n",
670 | " TOAHLRH | \n",
671 | " rabi y5alihoulek w yfar7ek bih w inchallah itc... | \n",
672 | " rabi y5alihoulek w yfar7ek bih w inchallah itc... | \n",
673 | "
\n",
674 | " \n",
675 | "
\n",
676 | "
"
677 | ],
678 | "text/plain": [
679 | " ID ... clean_text\n",
680 | "0 2DDHQW9 ... barcha aaindou fiha hak w barcha teflim kadhalik\n",
681 | "1 5HY6UEY ... ye gernabou ye 9a7ba\n",
682 | "2 ATNVUJX ... saber w barra rabbi m3ak 5ouya\n",
683 | "3 Q9XYVOQ ... cha3ébb ta7aann tfouuhh\n",
684 | "4 TOAHLRH ... rabi y5alihoulek w yfar7ek bih w inchallah itc...\n",
685 | "\n",
686 | "[5 rows x 3 columns]"
687 | ]
688 | },
689 | "metadata": {
690 | "tags": []
691 | },
692 | "execution_count": 128
693 | }
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "metadata": {
699 | "id": "h5dscE0Vre-J",
700 | "colab": {
701 | "base_uri": "https://localhost:8080/"
702 | },
703 | "outputId": "9f896752-11e1-4e1a-efac-ab3d03d9d7d6"
704 | },
705 | "source": [
706 | "MODEL_NAME = 'bert-base-uncased'\r\n",
707 | "MAX_LEN = 64\r\n",
708 | "BATCH_SIZE = 64\r\n",
709 | "FOLDS = 5\r\n",
710 | "LR = 3e-5\r\n",
711 | "EPOCHS = 3\r\n",
712 | "\r\n",
713 | "# List of class names\r\n",
714 | "CLASS_NAMES = sorted(train.label.unique().tolist()) # ['afya', 'burudani', 'kimataifa', 'kitaifa', 'michezo', 'uchumi']\r\n",
715 | "\r\n",
716 | "# Instantiate transformer with the provided parameters\r\n",
717 | "t = text.Transformer(model_name=MODEL_NAME, maxlen=MAX_LEN, class_names=CLASS_NAMES, batch_size=BATCH_SIZE)\r\n",
718 | "CLASS_NAMES"
719 | ],
720 | "execution_count": 129,
721 | "outputs": [
722 | {
723 | "output_type": "execute_result",
724 | "data": {
725 | "text/plain": [
726 | "['-1', '0', '1']"
727 | ]
728 | },
729 | "metadata": {
730 | "tags": []
731 | },
732 | "execution_count": 129
733 | }
734 | ]
735 | },
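{
"cell_type": "code",
"metadata": {},
"source": [
"# Editor-added sketch (not executed in the original run): ktrain's learning-rate\r\n",
"# range test, which could be used to sanity-check the fixed LR = 3e-5 before the\r\n",
"# cross-validation below. The 10% probe split and the probe_* names are\r\n",
"# illustrative, not part of the original pipeline.\r\n",
"from sklearn.model_selection import train_test_split\r\n",
"X_tr, X_val, y_tr, y_val = train_test_split(\r\n",
"    list(train.clean_text), list(train.label),\r\n",
"    test_size=0.1, stratify=train.label, random_state=SEED)\r\n",
"probe_train = t.preprocess_train(X_tr, y_tr, verbose=0)\r\n",
"probe_val = t.preprocess_test(X_val, y_val, verbose=0)\r\n",
"probe_learner = ktrain.get_learner(t.get_classifier(), train_data=probe_train,\r\n",
"                                   val_data=probe_val, batch_size=BATCH_SIZE)\r\n",
"probe_learner.lr_find(show_plot=True, max_epochs=1)  # plots loss vs. learning rate"
],
"execution_count": null,
"outputs": []
},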
736 | {
737 | "cell_type": "code",
738 | "metadata": {
739 | "id": "chBN-wZiy1QL",
740 | "colab": {
741 | "base_uri": "https://localhost:8080/"
742 | },
743 | "outputId": "ff63f6ef-2b95-47b3-c33d-391eb8bac3b8"
744 | },
745 | "source": [
746 | "%%time\r\n",
747 | "# Prepare test data\r\n",
748 | "test_data = np.asarray(test.clean_text)\r\n",
749 | "\r\n",
750 | "# Set number of folds to 3\r\n",
751 | "folds = StratifiedKFold(n_splits=FOLDS, random_state=SEED, shuffle=False)\r\n",
752 | "\r\n",
753 | "# List to store predictions and loss-score per fold\r\n",
754 | "oof_preds = []\r\n",
755 | "oof_loss_score = []\r\n",
756 | "\r\n",
757 | "for i, (train_index, test_index) in enumerate(folds.split(train.clean_text, train.label)):\r\n",
758 | " X_train, X_test = list(train.loc[train_index, 'clean_text']), list(train.loc[test_index, 'clean_text'])\r\n",
759 | " y_train, y_test = np.asarray(train.loc[train_index, 'label']), np.asarray(train.loc[test_index, 'label'])\r\n",
760 | "\r\n",
761 | " # Preprocess training and validation data\r\n",
762 | " train_set = t.preprocess_train(X_train, y_train, verbose = 0)\r\n",
763 | " val_set = t.preprocess_test(X_test, y_test, verbose = 0)\r\n",
764 | "\r\n",
765 | " # Instantiate model\r\n",
766 | " model = t.get_classifier()\r\n",
767 | " learner = ktrain.get_learner(model, train_data=train_set, val_data=val_set, batch_size=BATCH_SIZE)\r\n",
768 | "\r\n",
769 | " history = learner.fit(LR, n_cycles=EPOCHS, checkpoint_folder='/tmp')\r\n",
770 | " fold_accuracies = history.history['val_accuracy'] \r\n",
771 | " best_score, best_epoch = max(fold_accuracies), np.array(fold_accuracies).argmax() + 1\r\n",
772 | " oof_loss_score.append(best_score)\r\n",
773 | " print(f'\\033[1m\\033[92m Fold {i+1}: {best_score}\\33[0m\\n')\r\n",
774 | "\r\n",
775 | " #Load best weights\r\n",
776 | " model = t.get_classifier()\r\n",
777 | " model.load_weights('../tmp/weights-0' + str(best_epoch) + '.hdf5')\r\n",
778 | " learner = ktrain.get_learner(model, train_data=train_set, val_data=val_set, batch_size=BATCH_SIZE)\r\n",
779 | "\r\n",
780 | " # Make predictions\r\n",
781 | " preds = ktrain.get_predictor(learner.model, preproc=t).predict(test_data, return_proba=True)\r\n",
782 | "\r\n",
783 | " # Append preds to oof_preds list\r\n",
784 | " oof_preds.append(preds)\r\n",
785 | "\r\n",
786 | "# Check cv score and prepare submission file\r\n",
787 | "LOSS = np.round(np.mean(oof_loss_score), 5)\r\n",
788 | "print(f'\\n\\33[96m\\33[1m\\33[4m Mean Loss: {LOSS}\\33[0m')\r\n",
789 | "\r\n",
790 | "name = f'{MODEL_NAME}_ML{MAX_LEN}_BS{BATCH_SIZE}_FD{FOLDS}_LR{LR}_EP{EPOCHS}_LS{LOSS}'\r\n",
791 | "sub = pd.DataFrame(np.mean(oof_preds, axis=0), columns = t.get_classes())\r\n",
792 | "sub.to_csv(name + '.csv', index = False)\r\n",
793 | "ss = pd.DataFrame({'ID':test.ID, 'label': sub.idxmax(axis = 1)})\r\n",
794 | "ss.to_csv(f'KT_bert{LOSS}.csv', index = False)"
795 | ],
796 | "execution_count": 130,
797 | "outputs": [
798 | {
799 | "output_type": "stream",
800 | "text": [
801 | "Epoch 1/3\n",
802 | "875/875 [==============================] - 292s 305ms/step - loss: 0.6345 - accuracy: 0.7121 - val_loss: 0.4663 - val_accuracy: 0.8023\n",
803 | "Epoch 2/3\n",
804 | "875/875 [==============================] - 273s 301ms/step - loss: 0.4152 - accuracy: 0.8300 - val_loss: 0.4754 - val_accuracy: 0.8016\n",
805 | "Epoch 3/3\n",
806 | "875/875 [==============================] - 274s 301ms/step - loss: 0.3116 - accuracy: 0.8760 - val_loss: 0.4404 - val_accuracy: 0.8218\n",
807 | "\u001b[1m\u001b[92m Fold 1: 0.8217856884002686\u001b[0m\n",
808 | "\n",
809 | "Epoch 1/3\n",
810 | "875/875 [==============================] - 291s 305ms/step - loss: 0.6398 - accuracy: 0.7110 - val_loss: 0.4818 - val_accuracy: 0.8003\n",
811 | "Epoch 2/3\n",
812 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4155 - accuracy: 0.8297 - val_loss: 0.4434 - val_accuracy: 0.8172\n",
813 | "Epoch 3/3\n",
814 | "875/875 [==============================] - 274s 302ms/step - loss: 0.3036 - accuracy: 0.8799 - val_loss: 0.4486 - val_accuracy: 0.8184\n",
815 | "\u001b[1m\u001b[92m Fold 2: 0.8184285759925842\u001b[0m\n",
816 | "\n",
817 | "Epoch 1/3\n",
818 | "875/875 [==============================] - 292s 305ms/step - loss: 0.6423 - accuracy: 0.7103 - val_loss: 0.4753 - val_accuracy: 0.8006\n",
819 | "Epoch 2/3\n",
820 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4082 - accuracy: 0.8323 - val_loss: 0.4544 - val_accuracy: 0.8154\n",
821 | "Epoch 3/3\n",
822 | "875/875 [==============================] - 275s 303ms/step - loss: 0.3061 - accuracy: 0.8820 - val_loss: 0.5234 - val_accuracy: 0.8090\n",
823 | "\u001b[1m\u001b[92m Fold 3: 0.8153571486473083\u001b[0m\n",
824 | "\n",
825 | "Epoch 1/3\n",
826 | "875/875 [==============================] - 291s 305ms/step - loss: 0.6313 - accuracy: 0.7151 - val_loss: 0.4684 - val_accuracy: 0.7984\n",
827 | "Epoch 2/3\n",
828 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4189 - accuracy: 0.8271 - val_loss: 0.4292 - val_accuracy: 0.8226\n",
829 | "Epoch 3/3\n",
830 | "875/875 [==============================] - 274s 302ms/step - loss: 0.3045 - accuracy: 0.8777 - val_loss: 0.4774 - val_accuracy: 0.8154\n",
831 | "\u001b[1m\u001b[92m Fold 4: 0.8226428627967834\u001b[0m\n",
832 | "\n",
833 | "Epoch 1/3\n",
834 | "875/875 [==============================] - 294s 306ms/step - loss: 0.6402 - accuracy: 0.7073 - val_loss: 0.4774 - val_accuracy: 0.8046\n",
835 | "Epoch 2/3\n",
836 | "875/875 [==============================] - 275s 303ms/step - loss: 0.4161 - accuracy: 0.8290 - val_loss: 0.4452 - val_accuracy: 0.8175\n",
837 | "Epoch 3/3\n",
838 | "875/875 [==============================] - 275s 303ms/step - loss: 0.3071 - accuracy: 0.8792 - val_loss: 0.4636 - val_accuracy: 0.8264\n",
839 | "\u001b[1m\u001b[92m Fold 5: 0.8263571262359619\u001b[0m\n",
840 | "\n",
841 | "\n",
842 | "\u001b[96m\u001b[1m\u001b[4m Mean Loss: 0.82091\u001b[0m\n",
843 | "CPU times: user 35min 7s, sys: 9min 30s, total: 44min 38s\n",
844 | "Wall time: 1h 17min 35s\n"
845 | ],
846 | "name": "stdout"
847 | }
848 | ]
849 | }
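,
{
"cell_type": "code",
"metadata": {},
"source": [
"# Editor-added sketch (not executed in the original run): persist the last fold's\r\n",
"# predictor so it can be reloaded for inference without retraining. The save path\r\n",
"# and the example sentence are illustrative.\r\n",
"predictor = ktrain.get_predictor(learner.model, preproc=t)\r\n",
"predictor.save('/content/arabizi_bert_predictor')\r\n",
"reloaded = ktrain.load_predictor('/content/arabizi_bert_predictor')\r\n",
"print(reloaded.predict(['rabi m3ak w iwaf9ek']))  # prints a one-element list with the predicted label ('-1', '0' or '1')"
],
"execution_count": null,
"outputs": []
}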
850 | ]
851 | }
--------------------------------------------------------------------------------