└── Arabizi_KT_CV82091_LB_8286G.ipynb /Arabizi_KT_CV82091_LB_8286G.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Arabizi_KT_CV82091_LB_8286G.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "ojSAlH9W_gfv", 24 | "outputId": "6bed4e5e-d511-4ccb-e09e-179f561162cc" 25 | }, 26 | "source": [ 27 | "# Check GPU type\r\n", 28 | "!nvidia-smi" 29 | ], 30 | "execution_count": 1, 31 | "outputs": [ 32 | { 33 | "output_type": "stream", 34 | "text": [ 35 | "Tue Mar 2 07:08:47 2021 \n", 36 | "+-----------------------------------------------------------------------------+\n", 37 | "| NVIDIA-SMI 460.39 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 38 | "|-------------------------------+----------------------+----------------------+\n", 39 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 40 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 41 | "| | | MIG M. |\n", 42 | "|===============================+======================+======================|\n", 43 | "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n", 44 | "| N/A 34C P0 24W / 300W | 0MiB / 16160MiB | 0% Default |\n", 45 | "| | | N/A |\n", 46 | "+-------------------------------+----------------------+----------------------+\n", 47 | " \n", 48 | "+-----------------------------------------------------------------------------+\n", 49 | "| Processes: |\n", 50 | "| GPU GI CI PID Type Process name GPU Memory |\n", 51 | "| ID ID Usage |\n", 52 | "|=============================================================================|\n", 53 | "| No running processes found |\n", 54 | "+-----------------------------------------------------------------------------+\n" 55 | ], 56 | "name": "stdout" 57 | } 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "colab": { 64 | "base_uri": "https://localhost:8080/" 65 | }, 66 | "id": "5016OlJnp2kC", 67 | "outputId": "1310ee54-5823-4d7d-a882-3ff37756f37e" 68 | }, 69 | "source": [ 70 | "# Upgrade pip and install ktrain\r\n", 71 | "!pip -qq install -U pip\r\n", 72 | "!pip -qq install ktrain" 73 | ], 74 | "execution_count": 2, 75 | "outputs": [ 76 | { 77 | "output_type": "stream", 78 | "text": [ 79 | "\u001b[K |████████████████████████████████| 1.5MB 5.4MB/s \n", 80 | "\u001b[K |████████████████████████████████| 25.3 MB 94.2 MB/s \n", 81 | "\u001b[K |████████████████████████████████| 6.8 MB 61.9 MB/s \n", 82 | "\u001b[K |████████████████████████████████| 981 kB 56.6 MB/s \n", 83 | "\u001b[K |████████████████████████████████| 263 kB 58.2 MB/s \n", 84 | "\u001b[K |████████████████████████████████| 1.3 MB 58.2 MB/s \n", 85 | "\u001b[K |████████████████████████████████| 1.2 MB 60.2 MB/s \n", 86 | "\u001b[K |████████████████████████████████| 468 kB 27.6 MB/s \n", 87 | "\u001b[K |████████████████████████████████| 1.1 MB 60.4 MB/s \n", 88 | "\u001b[K |████████████████████████████████| 883 kB 60.1 MB/s \n", 89 | "\u001b[K |████████████████████████████████| 2.9 MB 65.3 MB/s \n", 90 | "\u001b[?25h Building wheel for ktrain (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 91 | " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 92 | " Building wheel for keras-bert (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 93 | " Building wheel for keras-embed-sim (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 94 | " Building wheel for keras-layer-normalization (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 95 | " Building wheel for keras-multi-head (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 96 | " Building wheel for keras-self-attention (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 97 | " Building wheel for keras-pos-embd (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 98 | " Building wheel for keras-position-wise-feed-forward (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 99 | " Building wheel for langdetect (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 100 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 101 | " Building wheel for syntok (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 102 | ], 103 | "name": "stdout" 104 | } 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "metadata": { 110 | "id": "qDHDhHzWrDmm", 111 | "colab": { 112 | "base_uri": "https://localhost:8080/" 113 | }, 114 | "outputId": "1c15ae83-3a35-49b0-a6e7-bd9088176a8f" 115 | }, 116 | "source": [ 117 | "!gdown --id 1LZBMbdMAr8iwmNfN2JkiFw-uGPleBOsf\r\n", 118 | "!unzip -q '/content/Arabizi_data.zip'" 119 | ], 120 | "execution_count": 3, 121 | "outputs": [ 122 | { 123 | "output_type": "stream", 124 | "text": [ 125 | "Downloading...\n", 126 | "From: https://drive.google.com/uc?id=1LZBMbdMAr8iwmNfN2JkiFw-uGPleBOsf\n", 127 | "To: /content/Arabizi_data.zip\n", 128 | "\r0.00B [00:00, ?B/s]\r3.67MB [00:00, 106MB/s]\n" 129 | ], 130 | "name": "stdout" 131 | } 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "metadata": { 137 | "id": "TWw-1GHGqVI1", 138 | "colab": { 139 | "base_uri": "https://localhost:8080/" 140 | }, 141 | "outputId": "eeb20ec7-d532-422d-b491-6ae9932a9d66" 142 | }, 143 | "source": [ 144 | "# Import libaries\r\n", 145 | "import numpy as np \r\n", 146 | "import pandas as pd\r\n", 147 | "from tqdm import tqdm\r\n", 148 | "import random\r\n", 149 | "import os\r\n", 150 | "import re\r\n", 151 | "import ktrain\r\n", 152 | "from ktrain import text\r\n", 153 | "import tensorflow as tf\r\n", 154 | "from sklearn.model_selection import StratifiedKFold\r\n", 155 | "import string\r\n", 156 | "import nltk\r\n", 157 | "from nltk.tokenize import word_tokenize\r\n", 158 | "from nltk.corpus import stopwords\r\n", 159 | "nltk.download('punkt')\r\n", 160 | "import warnings\r\n", 161 | "warnings.filterwarnings('ignore')" 162 | ], 163 | "execution_count": 95, 164 | "outputs": [ 165 | { 166 | "output_type": "stream", 167 | "text": [ 168 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n", 169 | "[nltk_data] Unzipping tokenizers/punkt.zip.\n" 170 | ], 171 | "name": "stdout" 172 | } 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "metadata": { 178 | "id": "gdIMbg9vnM9b" 179 | }, 180 | "source": [ 181 | "# Set seed\r\n", 182 | "SEED = 3031\r\n", 183 | "\r\n", 184 | "# def set_seeds(seed=SEED):\r\n", 185 | "# os.environ['PYTHONHASHSEED'] = str(seed)\r\n", 186 | "# random.seed(seed)\r\n", 187 | "# tf.random.set_seed(seed)\r\n", 188 | "# np.random.seed(seed)\r\n", 189 | "\r\n", 190 | "# def set_global_determinism(seed=SEED):\r\n", 191 | "# set_seeds(seed=seed)\r\n", 192 | "\r\n", 193 | "# os.environ['TF_DETERMINISTIC_OPS'] = '1'\r\n", 194 | "# os.environ['TF_CUDNN_DETERMINISTIC'] = '1'\r\n", 195 | " \r\n", 196 | "# tf.config.threading.set_inter_op_parallelism_threads(1)\r\n", 197 
| "# tf.config.threading.set_intra_op_parallelism_threads(1)\r\n", 198 | "\r\n", 199 | "# set_global_determinism(seed=SEED)" 200 | ], 201 | "execution_count": 5, 202 | "outputs": [] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "metadata": { 207 | "id": "7rmqE-HZR3FK" 208 | }, 209 | "source": [ 210 | "def clean_text(text):\r\n", 211 | " '''Make text lowercase, remove text in square brackets,remove links,remove punctuation\r\n", 212 | " and remove words containing numbers.'''\r\n", 213 | " text = text.lower()\r\n", 214 | " text = re.sub('\\[.*?\\]', '', text)\r\n", 215 | " text = re.sub('https?://\\S+|www\\.\\S+', '', text)\r\n", 216 | " text = re.sub('<.*?>+', '', text)\r\n", 217 | " text = re.sub('[%s]' % re.escape(string.punctuation), '', text)\r\n", 218 | " text = re.sub('\\n', '', text)\r\n", 219 | " text = re.sub('\\w*\\d\\w*', '', text)\r\n", 220 | " return text\r\n", 221 | "\r\n", 222 | "def text_preprocessing(text):\r\n", 223 | " \"\"\"\r\n", 224 | " Cleaning and parsing the text.\r\n", 225 | "\r\n", 226 | " \"\"\"\r\n", 227 | " tokenizer = nltk.tokenize.RegexpTokenizer(r'\\w+')\r\n", 228 | " nopunc = clean_text(text)\r\n", 229 | " tokenized_text = tokenizer.tokenize(nopunc)\r\n", 230 | " #remove_stopwords = [w for w in tokenized_text if w not in stopwords.words('english')]\r\n", 231 | " combined_text = ' '.join(tokenized_text)\r\n", 232 | " return combined_text\r\n", 233 | "\r\n", 234 | "def text_cleaner(text):\r\n", 235 | " text = re.sub('\\s+',' ', text)\r\n", 236 | " text = text.strip()\r\n", 237 | " text = re.sub(r'(.)\\1+', r'\\1\\1', text)\r\n", 238 | " return text" 239 | ], 240 | "execution_count": 77, 241 | "outputs": [] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "metadata": { 246 | "id": "q0ECuC-bqVGi", 247 | "colab": { 248 | "base_uri": "https://localhost:8080/", 249 | "height": 195 250 | }, 251 | "outputId": "0d391ed6-5450-4592-837e-bc5401578125" 252 | }, 253 | "source": [ 254 | "train = pd.read_csv('/content/Arabizi_data/Train.csv')\r\n", 255 | "test = pd.read_csv('/content/Arabizi_data/Test.csv')\r\n", 256 | "sample = pd.read_csv('/content/Arabizi_data/SampleSubmission.csv')\r\n", 257 | "train.head()" 258 | ], 259 | "execution_count": 122, 260 | "outputs": [ 261 | { 262 | "output_type": "execute_result", 263 | "data": { 264 | "text/html": [ 265 | "
\n", 266 | "\n", 279 | "\n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | "
IDtextlabel
013P0QT03sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o...-1
1SKCLXCJcha3eb fey9elkoum menghir ta7ayoul ou kressi-1
2V1TVXIJbereau degage nathef ya slim walahi ya7chiw fi...-1
3U0TTYY8ak slouma1
468DX797entom titmanou lina a7na 3iid moubarik a7na ch...-1
\n", 321 | "
" 322 | ], 323 | "text/plain": [ 324 | " ID text label\n", 325 | "0 13P0QT0 3sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o... -1\n", 326 | "1 SKCLXCJ cha3eb fey9elkoum menghir ta7ayoul ou kressi -1\n", 327 | "2 V1TVXIJ bereau degage nathef ya slim walahi ya7chiw fi... -1\n", 328 | "3 U0TTYY8 ak slouma 1\n", 329 | "4 68DX797 entom titmanou lina a7na 3iid moubarik a7na ch... -1" 330 | ] 331 | }, 332 | "metadata": { 333 | "tags": [] 334 | }, 335 | "execution_count": 122 336 | } 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "metadata": { 342 | "colab": { 343 | "base_uri": "https://localhost:8080/" 344 | }, 345 | "id": "Zmf9Co_h3Thg", 346 | "outputId": "caf8dd00-8415-42ed-b8a9-1c5a1d710501" 347 | }, 348 | "source": [ 349 | "train.label.value_counts()" 350 | ], 351 | "execution_count": 123, 352 | "outputs": [ 353 | { 354 | "output_type": "execute_result", 355 | "data": { 356 | "text/plain": [ 357 | " 1 38239\n", 358 | "-1 29295\n", 359 | " 0 2466\n", 360 | "Name: label, dtype: int64" 361 | ] 362 | }, 363 | "metadata": { 364 | "tags": [] 365 | }, 366 | "execution_count": 123 367 | } 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "id": "X6GUtTyX3cwv" 374 | }, 375 | "source": [ 376 | "train.label = train.label.astype(str)" 377 | ], 378 | "execution_count": 124, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "metadata": { 384 | "colab": { 385 | "base_uri": "https://localhost:8080/", 386 | "height": 195 387 | }, 388 | "id": "iJu-Skwq_JaM", 389 | "outputId": "8a4b1f2c-6b25-42a6-d177-e311e7591243" 390 | }, 391 | "source": [ 392 | "# Preview last five rows in test\r\n", 393 | "test.tail()" 394 | ], 395 | "execution_count": 125, 396 | "outputs": [ 397 | { 398 | "output_type": "execute_result", 399 | "data": { 400 | "text/html": [ 401 | "
\n", 402 | "\n", 415 | "\n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | "
IDtext
29995NHXTL3Rme ihebekch raw
29996U1YWB2Onchallah rabi m3ak w iwaf9ek mais just 7abit n...
29997O3KYLM0slim rabi m3ak w e5edem w 5alli l7ossed lemnay...
29998W4C38TYbara 5alis rouhik yizi mitbal3it jam3iya hlaki...
299994NNX5QErabi m3aaaak ya khawlaaa n7ebouuuuk rana barsh...
\n", 451 | "
" 452 | ], 453 | "text/plain": [ 454 | " ID text\n", 455 | "29995 NHXTL3R me ihebekch raw\n", 456 | "29996 U1YWB2O nchallah rabi m3ak w iwaf9ek mais just 7abit n...\n", 457 | "29997 O3KYLM0 slim rabi m3ak w e5edem w 5alli l7ossed lemnay...\n", 458 | "29998 W4C38TY bara 5alis rouhik yizi mitbal3it jam3iya hlaki...\n", 459 | "29999 4NNX5QE rabi m3aaaak ya khawlaaa n7ebouuuuk rana barsh..." 460 | ] 461 | }, 462 | "metadata": { 463 | "tags": [] 464 | }, 465 | "execution_count": 125 466 | } 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "metadata": { 472 | "colab": { 473 | "base_uri": "https://localhost:8080/" 474 | }, 475 | "id": "gz_vbjQj3A94", 476 | "outputId": "455d7e45-ce09-42d8-9746-748bcfbaec6c" 477 | }, 478 | "source": [ 479 | "train.shape, test.shape, sample.shape" 480 | ], 481 | "execution_count": 126, 482 | "outputs": [ 483 | { 484 | "output_type": "execute_result", 485 | "data": { 486 | "text/plain": [ 487 | "((70000, 3), (30000, 2), (30000, 2))" 488 | ] 489 | }, 490 | "metadata": { 491 | "tags": [] 492 | }, 493 | "execution_count": 126 494 | } 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "metadata": { 500 | "colab": { 501 | "base_uri": "https://localhost:8080/", 502 | "height": 195 503 | }, 504 | "id": "bm4MnMIBSHfr", 505 | "outputId": "248b33be-5dff-4c5a-f7c4-5606395c68e3" 506 | }, 507 | "source": [ 508 | "tqdm.pandas()\r\n", 509 | "train['clean_text'] = train.text.apply(lambda x: text_cleaner(x))\r\n", 510 | "train.head()" 511 | ], 512 | "execution_count": 127, 513 | "outputs": [ 514 | { 515 | "output_type": "execute_result", 516 | "data": { 517 | "text/html": [ 518 | "
\n", 519 | "\n", 532 | "\n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | "
IDtextlabelclean_text
013P0QT03sbaaaaaaaaaaaaaaaaaaaa lek ou le seim riahi o...-13sbaa lek ou le seim riahi ou 3sbaa le ca
1SKCLXCJcha3eb fey9elkoum menghir ta7ayoul ou kressi-1cha3eb fey9elkoum menghir ta7ayoul ou kressi
2V1TVXIJbereau degage nathef ya slim walahi ya7chiw fi...-1bereau degage nathef ya slim walahi ya7chiw fi...
3U0TTYY8ak slouma1ak slouma
468DX797entom titmanou lina a7na 3iid moubarik a7na ch...-1entom titmanou lina a7na 3iid moubarik a7na ch...
\n", 580 | "
" 581 | ], 582 | "text/plain": [ 583 | " ID ... clean_text\n", 584 | "0 13P0QT0 ... 3sbaa lek ou le seim riahi ou 3sbaa le ca\n", 585 | "1 SKCLXCJ ... cha3eb fey9elkoum menghir ta7ayoul ou kressi\n", 586 | "2 V1TVXIJ ... bereau degage nathef ya slim walahi ya7chiw fi...\n", 587 | "3 U0TTYY8 ... ak slouma\n", 588 | "4 68DX797 ... entom titmanou lina a7na 3iid moubarik a7na ch...\n", 589 | "\n", 590 | "[5 rows x 4 columns]" 591 | ] 592 | }, 593 | "metadata": { 594 | "tags": [] 595 | }, 596 | "execution_count": 127 597 | } 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/", 605 | "height": 195 606 | }, 607 | "id": "ohmweWGMS6eD", 608 | "outputId": "d9e7865a-60ea-4a51-893e-e6f360521b31" 609 | }, 610 | "source": [ 611 | "test['clean_text'] = test.text.apply(lambda x: text_cleaner(x))\r\n", 612 | "test.head()" 613 | ], 614 | "execution_count": 128, 615 | "outputs": [ 616 | { 617 | "output_type": "execute_result", 618 | "data": { 619 | "text/html": [ 620 | "
\n", 621 | "\n", 634 | "\n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | "
IDtextclean_text
02DDHQW9barcha aaindou fiha hak w barcha teflim kadhalikbarcha aaindou fiha hak w barcha teflim kadhalik
15HY6UEYye gernabou ye 9a7baye gernabou ye 9a7ba
2ATNVUJXsaber w barra rabbi m3ak 5ouyasaber w barra rabbi m3ak 5ouya
3Q9XYVOQcha3ébbb ta7aaaaannnnnnnnnnn tfouuhhcha3ébb ta7aann tfouuhh
4TOAHLRHrabi y5alihoulek w yfar7ek bih w inchallah itc...rabi y5alihoulek w yfar7ek bih w inchallah itc...
\n", 676 | "
" 677 | ], 678 | "text/plain": [ 679 | " ID ... clean_text\n", 680 | "0 2DDHQW9 ... barcha aaindou fiha hak w barcha teflim kadhalik\n", 681 | "1 5HY6UEY ... ye gernabou ye 9a7ba\n", 682 | "2 ATNVUJX ... saber w barra rabbi m3ak 5ouya\n", 683 | "3 Q9XYVOQ ... cha3ébb ta7aann tfouuhh\n", 684 | "4 TOAHLRH ... rabi y5alihoulek w yfar7ek bih w inchallah itc...\n", 685 | "\n", 686 | "[5 rows x 3 columns]" 687 | ] 688 | }, 689 | "metadata": { 690 | "tags": [] 691 | }, 692 | "execution_count": 128 693 | } 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "metadata": { 699 | "id": "h5dscE0Vre-J", 700 | "colab": { 701 | "base_uri": "https://localhost:8080/" 702 | }, 703 | "outputId": "9f896752-11e1-4e1a-efac-ab3d03d9d7d6" 704 | }, 705 | "source": [ 706 | "MODEL_NAME = 'bert-base-uncased'\r\n", 707 | "MAX_LEN = 64\r\n", 708 | "BATCH_SIZE = 64\r\n", 709 | "FOLDS = 5\r\n", 710 | "LR = 3e-5\r\n", 711 | "EPOCHS = 3\r\n", 712 | "\r\n", 713 | "# List of class names\r\n", 714 | "CLASS_NAMES = sorted(train.label.unique().tolist()) # ['afya', 'burudani', 'kimataifa', 'kitaifa', 'michezo', 'uchumi']\r\n", 715 | "\r\n", 716 | "# Instantiate transformer with the provided parameters\r\n", 717 | "t = text.Transformer(model_name=MODEL_NAME, maxlen=MAX_LEN, class_names=CLASS_NAMES, batch_size=BATCH_SIZE)\r\n", 718 | "CLASS_NAMES" 719 | ], 720 | "execution_count": 129, 721 | "outputs": [ 722 | { 723 | "output_type": "execute_result", 724 | "data": { 725 | "text/plain": [ 726 | "['-1', '0', '1']" 727 | ] 728 | }, 729 | "metadata": { 730 | "tags": [] 731 | }, 732 | "execution_count": 129 733 | } 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "metadata": { 739 | "id": "chBN-wZiy1QL", 740 | "colab": { 741 | "base_uri": "https://localhost:8080/" 742 | }, 743 | "outputId": "ff63f6ef-2b95-47b3-c33d-391eb8bac3b8" 744 | }, 745 | "source": [ 746 | "%%time\r\n", 747 | "# Prepare test data\r\n", 748 | "test_data = np.asarray(test.clean_text)\r\n", 749 | "\r\n", 750 | "# Set number of folds to 3\r\n", 751 | "folds = StratifiedKFold(n_splits=FOLDS, random_state=SEED, shuffle=False)\r\n", 752 | "\r\n", 753 | "# List to store predictions and loss-score per fold\r\n", 754 | "oof_preds = []\r\n", 755 | "oof_loss_score = []\r\n", 756 | "\r\n", 757 | "for i, (train_index, test_index) in enumerate(folds.split(train.clean_text, train.label)):\r\n", 758 | " X_train, X_test = list(train.loc[train_index, 'clean_text']), list(train.loc[test_index, 'clean_text'])\r\n", 759 | " y_train, y_test = np.asarray(train.loc[train_index, 'label']), np.asarray(train.loc[test_index, 'label'])\r\n", 760 | "\r\n", 761 | " # Preprocess training and validation data\r\n", 762 | " train_set = t.preprocess_train(X_train, y_train, verbose = 0)\r\n", 763 | " val_set = t.preprocess_test(X_test, y_test, verbose = 0)\r\n", 764 | "\r\n", 765 | " # Instantiate model\r\n", 766 | " model = t.get_classifier()\r\n", 767 | " learner = ktrain.get_learner(model, train_data=train_set, val_data=val_set, batch_size=BATCH_SIZE)\r\n", 768 | "\r\n", 769 | " history = learner.fit(LR, n_cycles=EPOCHS, checkpoint_folder='/tmp')\r\n", 770 | " fold_accuracies = history.history['val_accuracy'] \r\n", 771 | " best_score, best_epoch = max(fold_accuracies), np.array(fold_accuracies).argmax() + 1\r\n", 772 | " oof_loss_score.append(best_score)\r\n", 773 | " print(f'\\033[1m\\033[92m Fold {i+1}: {best_score}\\33[0m\\n')\r\n", 774 | "\r\n", 775 | " #Load best weights\r\n", 776 | " model = t.get_classifier()\r\n", 777 | " model.load_weights('../tmp/weights-0' + 
str(best_epoch) + '.hdf5')\r\n", 778 | " learner = ktrain.get_learner(model, train_data=train_set, val_data=val_set, batch_size=BATCH_SIZE)\r\n", 779 | "\r\n", 780 | " # Make predictions\r\n", 781 | " preds = ktrain.get_predictor(learner.model, preproc=t).predict(test_data, return_proba=True)\r\n", 782 | "\r\n", 783 | " # Append preds to oof_preds list\r\n", 784 | " oof_preds.append(preds)\r\n", 785 | "\r\n", 786 | "# Check cv score and prepare submission file\r\n", 787 | "LOSS = np.round(np.mean(oof_loss_score), 5)\r\n", 788 | "print(f'\\n\\33[96m\\33[1m\\33[4m Mean Loss: {LOSS}\\33[0m')\r\n", 789 | "\r\n", 790 | "name = f'{MODEL_NAME}_ML{MAX_LEN}_BS{BATCH_SIZE}_FD{FOLDS}_LR{LR}_EP{EPOCHS}_LS{LOSS}'\r\n", 791 | "sub = pd.DataFrame(np.mean(oof_preds, axis=0), columns = t.get_classes())\r\n", 792 | "sub.to_csv(name + '.csv', index = False)\r\n", 793 | "ss = pd.DataFrame({'ID':test.ID, 'label': sub.idxmax(axis = 1)})\r\n", 794 | "ss.to_csv(f'KT_bert{LOSS}.csv', index = False)" 795 | ], 796 | "execution_count": 130, 797 | "outputs": [ 798 | { 799 | "output_type": "stream", 800 | "text": [ 801 | "Epoch 1/3\n", 802 | "875/875 [==============================] - 292s 305ms/step - loss: 0.6345 - accuracy: 0.7121 - val_loss: 0.4663 - val_accuracy: 0.8023\n", 803 | "Epoch 2/3\n", 804 | "875/875 [==============================] - 273s 301ms/step - loss: 0.4152 - accuracy: 0.8300 - val_loss: 0.4754 - val_accuracy: 0.8016\n", 805 | "Epoch 3/3\n", 806 | "875/875 [==============================] - 274s 301ms/step - loss: 0.3116 - accuracy: 0.8760 - val_loss: 0.4404 - val_accuracy: 0.8218\n", 807 | "\u001b[1m\u001b[92m Fold 1: 0.8217856884002686\u001b[0m\n", 808 | "\n", 809 | "Epoch 1/3\n", 810 | "875/875 [==============================] - 291s 305ms/step - loss: 0.6398 - accuracy: 0.7110 - val_loss: 0.4818 - val_accuracy: 0.8003\n", 811 | "Epoch 2/3\n", 812 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4155 - accuracy: 0.8297 - val_loss: 0.4434 - val_accuracy: 0.8172\n", 813 | "Epoch 3/3\n", 814 | "875/875 [==============================] - 274s 302ms/step - loss: 0.3036 - accuracy: 0.8799 - val_loss: 0.4486 - val_accuracy: 0.8184\n", 815 | "\u001b[1m\u001b[92m Fold 2: 0.8184285759925842\u001b[0m\n", 816 | "\n", 817 | "Epoch 1/3\n", 818 | "875/875 [==============================] - 292s 305ms/step - loss: 0.6423 - accuracy: 0.7103 - val_loss: 0.4753 - val_accuracy: 0.8006\n", 819 | "Epoch 2/3\n", 820 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4082 - accuracy: 0.8323 - val_loss: 0.4544 - val_accuracy: 0.8154\n", 821 | "Epoch 3/3\n", 822 | "875/875 [==============================] - 275s 303ms/step - loss: 0.3061 - accuracy: 0.8820 - val_loss: 0.5234 - val_accuracy: 0.8090\n", 823 | "\u001b[1m\u001b[92m Fold 3: 0.8153571486473083\u001b[0m\n", 824 | "\n", 825 | "Epoch 1/3\n", 826 | "875/875 [==============================] - 291s 305ms/step - loss: 0.6313 - accuracy: 0.7151 - val_loss: 0.4684 - val_accuracy: 0.7984\n", 827 | "Epoch 2/3\n", 828 | "875/875 [==============================] - 274s 302ms/step - loss: 0.4189 - accuracy: 0.8271 - val_loss: 0.4292 - val_accuracy: 0.8226\n", 829 | "Epoch 3/3\n", 830 | "875/875 [==============================] - 274s 302ms/step - loss: 0.3045 - accuracy: 0.8777 - val_loss: 0.4774 - val_accuracy: 0.8154\n", 831 | "\u001b[1m\u001b[92m Fold 4: 0.8226428627967834\u001b[0m\n", 832 | "\n", 833 | "Epoch 1/3\n", 834 | "875/875 [==============================] - 294s 306ms/step - loss: 0.6402 - accuracy: 0.7073 - 
val_loss: 0.4774 - val_accuracy: 0.8046\n", 835 | "Epoch 2/3\n", 836 | "875/875 [==============================] - 275s 303ms/step - loss: 0.4161 - accuracy: 0.8290 - val_loss: 0.4452 - val_accuracy: 0.8175\n", 837 | "Epoch 3/3\n", 838 | "875/875 [==============================] - 275s 303ms/step - loss: 0.3071 - accuracy: 0.8792 - val_loss: 0.4636 - val_accuracy: 0.8264\n", 839 | "\u001b[1m\u001b[92m Fold 5: 0.8263571262359619\u001b[0m\n", 840 | "\n", 841 | "\n", 842 | "\u001b[96m\u001b[1m\u001b[4m Mean Loss: 0.82091\u001b[0m\n", 843 | "CPU times: user 35min 7s, sys: 9min 30s, total: 44min 38s\n", 844 | "Wall time: 1h 17min 35s\n" 845 | ], 846 | "name": "stdout" 847 | } 848 | ] 849 | } 850 | ] 851 | } --------------------------------------------------------------------------------
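The seed cell sets only SEED = 3031 and leaves the determinism helpers commented out, so repeated runs of the cross-validation loop can drift slightly. If fully reproducible runs are wanted, re-enabling that block would look roughly like the sketch below; it reuses the environment flags and TensorFlow calls from the commented code, and the single-threading lines trade training speed for determinism.

```python
import os
import random

import numpy as np
import tensorflow as tf

SEED = 3031

def set_global_determinism(seed=SEED):
    # Seed every RNG the training loop touches.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # Ask TensorFlow / cuDNN for deterministic kernels.
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    # Single-threaded execution removes one more source of nondeterminism (slower).
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

set_global_determinism(SEED)
```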
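Of the three helpers in the preprocessing cell, only text_cleaner is actually applied to the data; clean_text and text_preprocessing are defined but not used downstream. Its behaviour is easiest to see on an elongated Arabizi comment: whitespace runs collapse to a single space and any character repeated three or more times is cut back to two, which is how the first training row's "3sbaaaaaaaaaaaaaaaaaaaa" becomes "3sbaa". A self-contained sketch on a made-up string:

```python
import re

def text_cleaner(text):
    # Collapse whitespace runs to a single space and trim the ends.
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    # Cap any repeated character at two occurrences (handles elongation).
    text = re.sub(r'(.)\1+', r'\1\1', text)
    return text

# Made-up elongated comment in the style of the training data.
print(text_cleaner('3sbaaaaaaaaaaaaaaaaaaaa   lek  ou le   caaaa'))
# -> '3sbaa lek ou le caa'
```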
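Inside each fold, learner.fit writes per-epoch checkpoints to checkpoint_folder='/tmp', the best epoch is picked from the validation-accuracy history, and the matching weights file is reloaded. The '../tmp/...' path used for loading resolves to the same /tmp directory because the Colab working directory is /content. A tiny sketch of that selection step, using the rounded fold-1 values from the training log and the 'weights-0N.hdf5' naming implied by the notebook's own load call:

```python
import numpy as np

# Per-epoch validation accuracy for one fold (rounded fold-1 values from the log).
fold_accuracies = [0.8023, 0.8016, 0.8218]

best_score = max(fold_accuracies)
best_epoch = int(np.argmax(fold_accuracies)) + 1   # epoch numbers in the filename are 1-indexed

weights_path = '/tmp/weights-0' + str(best_epoch) + '.hdf5'
print(best_score, best_epoch, weights_path)        # 0.8218 3 /tmp/weights-03.hdf5
```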
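The final cell builds the submission by averaging the per-fold class probabilities and taking the column with the highest mean as the predicted label; the '-1', '0' and '1' column names come from t.get_classes(). A small sketch of that aggregation, with made-up probabilities for two test rows and two folds:

```python
import numpy as np
import pandas as pd

# Made-up per-fold probabilities for two test rows (columns: '-1', '0', '1').
fold_1 = np.array([[0.70, 0.10, 0.20],
                   [0.20, 0.15, 0.65]])
fold_2 = np.array([[0.60, 0.05, 0.35],
                   [0.30, 0.10, 0.60]])
oof_preds = [fold_1, fold_2]

# Average probabilities across folds, then pick the column with the highest mean.
sub = pd.DataFrame(np.mean(oof_preds, axis=0), columns=['-1', '0', '1'])
labels = sub.idxmax(axis=1)
print(labels.tolist())   # -> ['-1', '1']
```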