├── images └── vsn_img.png ├── README.md ├── icr_competition.ipynb └── spam_prediction.ipynb /images/vsn_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducnh279/Variable-Selection-Network-with-PyTorch/HEAD/images/vsn_img.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variable Selection Network with PyTorch 2 | 3 | ## Introduction 4 | Welcome to this repository, which contains a PyTorch implementation of the Variable Selection Network (VSN). VSN is a deep neural network architecture designed to learn and select relevant input variables, enhancing both model interpretability and performance. I implemented VSN in PyTorch because I could not find an existing PyTorch version online. The closest reference is the Keras example ["Classification with Gated Residual and Variable Selection Networks"](https://keras.io/examples/structured_data/classification_with_grn_and_vsn/). 5 | 6 | ## Inspiration 7 | The VSN architecture comes from the paper ["Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting"](https://arxiv.org/pdf/1912.09363) (Lim et al.), which introduced Gated Residual Networks (GRNs) and Variable Selection Networks. This implementation was also inspired by the winning solution of the Kaggle competition ["ICR - Identifying Age-Related Conditions"](https://www.kaggle.com/competitions/icr-identify-age-related-conditions), which used VSNs effectively. 8 | 9 | ## Dataset 10 | Experiments use the datasets from two Kaggle competitions: "ICR - Identifying Age-Related Conditions" and "Prediction of Spam with Bayesian Model". The VSN implementation performed well in the ICR competition. 
11 | 12 | | Dataset | Access Link | 13 | |----------------------------------------------|-------------| 14 | | ICR - Identifying Age-Related Conditions | [ICR](https://www.kaggle.com/competitions/icr-identify-age-related-conditions/data) | 15 | | Prediction of Spam with Bayesian Model | [Spam](https://www.kaggle.com/competitions/lets-surpass-the-hosts-bayesian-model/data) | 16 | --- 17 | -------------------------------------------------------------------------------- /icr_competition.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c34b377e", 6 | "metadata": { 7 | "papermill": { 8 | "duration": 0.007664, 9 | "end_time": "2024-01-01T19:29:31.284892", 10 | "exception": false, 11 | "start_time": "2024-01-01T19:29:31.277228", 12 | "status": "completed" 13 | }, 14 | "tags": [] 15 | }, 16 | "source": [ 17 | "# Imports" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "2e552fd1", 24 | "metadata": { 25 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 26 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 27 | "execution": { 28 | "iopub.execute_input": "2024-01-01T19:29:31.301300Z", 29 | "iopub.status.busy": "2024-01-01T19:29:31.300886Z", 30 | "iopub.status.idle": "2024-01-01T19:29:44.882010Z", 31 | "shell.execute_reply": "2024-01-01T19:29:44.880775Z" 32 | }, 33 | "papermill": { 34 | "duration": 13.592588, 35 | "end_time": "2024-01-01T19:29:44.884824", 36 | "exception": false, 37 | "start_time": "2024-01-01T19:29:31.292236", 38 | "status": "completed" 39 | }, 40 | "tags": [] 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n", 48 | " warnings.warn(f\"A NumPy version >={np_minversion} and 
<{np_maxversion}\"\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "import sys\n", 54 | "sys.path.append('/kaggle/input/iterativestratification')\n", 55 | "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n", 56 | "import pandas as pd\n", 57 | "import numpy as np\n", 58 | "from tqdm.auto import tqdm, trange\n", 59 | "import lightgbm as lgb\n", 60 | "from sklearn.compose import make_column_transformer\n", 61 | "from sklearn.metrics import accuracy_score\n", 62 | "from sklearn.model_selection import StratifiedKFold\n", 63 | "import warnings\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "warnings.filterwarnings('ignore')\n", 66 | "\n", 67 | "import pandas as pd\n", 68 | "import numpy as np\n", 69 | "\n", 70 | "import time\n", 71 | "from tqdm import tqdm\n", 72 | "from pathlib import Path\n", 73 | "import multiprocessing as mp\n", 74 | "\n", 75 | "from sklearn.model_selection import StratifiedKFold\n", 76 | "from sklearn.model_selection import cross_val_score\n", 77 | "from sklearn.preprocessing import StandardScaler\n", 78 | "from sklearn.metrics import roc_auc_score\n", 79 | "\n", 80 | "from transformers import get_cosine_schedule_with_warmup\n", 81 | "\n", 82 | "import torch \n", 83 | "from torch import nn, optim\n", 84 | "import torch.nn.functional as F\n", 85 | "from torch.cuda.amp import GradScaler, autocast\n", 86 | "from torch.utils.data import Dataset, DataLoader" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 2, 92 | "id": "6b9f5527", 93 | "metadata": { 94 | "execution": { 95 | "iopub.execute_input": "2024-01-01T19:29:44.901672Z", 96 | "iopub.status.busy": "2024-01-01T19:29:44.900768Z", 97 | "iopub.status.idle": "2024-01-01T19:29:44.910838Z", 98 | "shell.execute_reply": "2024-01-01T19:29:44.909891Z" 99 | }, 100 | "papermill": { 101 | "duration": 0.020838, 102 | "end_time": "2024-01-01T19:29:44.913085", 103 | "exception": false, 104 | "start_time": "2024-01-01T19:29:44.892247", 105 | "status": "completed" 106 | }, 107 
| "tags": [] 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "INPUT_PATH = Path('/kaggle/input/icr-identify-age-related-conditions')\n", 112 | "OUTPUT_PATH = Path('/kaggle/working')\n", 113 | "ORIGINAL_FEATURES = ['AB', 'AF', 'AH', 'AM', 'AR', 'AX', 'AY', 'AZ', 'BC', 'BD ', 'BN', 'BP',\n", 114 | " 'BQ', 'BR', 'BZ', 'CB', 'CC', 'CD ', 'CF', 'CH', 'CL', 'CR', 'CS', 'CU',\n", 115 | " 'CW ', 'DA', 'DE', 'DF', 'DH', 'DI', 'DL', 'DN', 'DU', 'DV', 'DY', 'EB',\n", 116 | " 'EE', 'EG', 'EH', 'EJ', 'EL', 'EP', 'EU', 'FC', 'FD ', 'FE', 'FI', 'FL',\n", 117 | " 'FR', 'FS', 'GB', 'GE', 'GF', 'GH', 'GI', 'GL']\n", 118 | "\n", 119 | "TARGET = 'Class'\n", 120 | "\n", 121 | "N_CORES = mp.cpu_count()\n", 122 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 123 | "NUM_FEATURES = len(ORIGINAL_FEATURES)\n", 124 | "BATCH_SIZE = 32\n", 125 | "N_EPOCHS = 20\n", 126 | "N_WARMUPS = 0\n", 127 | "LEARNING_RATE = 0.001\n", 128 | "WEIGHT_DECAY = 0.005\n", 129 | "SEED = 252" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "beced216", 135 | "metadata": { 136 | "papermill": { 137 | "duration": 0.006972, 138 | "end_time": "2024-01-01T19:29:44.927325", 139 | "exception": false, 140 | "start_time": "2024-01-01T19:29:44.920353", 141 | "status": "completed" 142 | }, 143 | "tags": [] 144 | }, 145 | "source": [ 146 | "# Read data" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 3, 152 | "id": "b54fe002", 153 | "metadata": { 154 | "execution": { 155 | "iopub.execute_input": "2024-01-01T19:29:44.943105Z", 156 | "iopub.status.busy": "2024-01-01T19:29:44.942730Z", 157 | "iopub.status.idle": "2024-01-01T19:29:45.005138Z", 158 | "shell.execute_reply": "2024-01-01T19:29:45.003966Z" 159 | }, 160 | "papermill": { 161 | "duration": 0.073392, 162 | "end_time": "2024-01-01T19:29:45.007812", 163 | "exception": false, 164 | "start_time": "2024-01-01T19:29:44.934420", 165 | "status": "completed" 166 | }, 167 | "tags": [] 168 | }, 169 | "outputs": [], 
170 | "source": [ 171 | "train = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/train.csv').drop(['Id'], axis=1)\n", 172 | "test = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/test.csv')\n", 173 | "greeks = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/greeks.csv')" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "d4fc5577", 179 | "metadata": { 180 | "papermill": { 181 | "duration": 0.007123, 182 | "end_time": "2024-01-01T19:29:45.022621", 183 | "exception": false, 184 | "start_time": "2024-01-01T19:29:45.015498", 185 | "status": "completed" 186 | }, 187 | "tags": [] 188 | }, 189 | "source": [ 190 | "# Evaluation Metric" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 4, 196 | "id": "66de6b00", 197 | "metadata": { 198 | "execution": { 199 | "iopub.execute_input": "2024-01-01T19:29:45.039742Z", 200 | "iopub.status.busy": "2024-01-01T19:29:45.038897Z", 201 | "iopub.status.idle": "2024-01-01T19:29:45.046448Z", 202 | "shell.execute_reply": "2024-01-01T19:29:45.045586Z" 203 | }, 204 | "papermill": { 205 | "duration": 0.018844, 206 | "end_time": "2024-01-01T19:29:45.048815", 207 | "exception": false, 208 | "start_time": "2024-01-01T19:29:45.029971", 209 | "status": "completed" 210 | }, 211 | "tags": [] 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "import torch\n", 216 | "\n", 217 | "def score(y_true, y_pred):\n", 218 | "\n", 219 | " # Calculate the number of observations for each class\n", 220 | " N_0 = torch.sum(1 - y_true)\n", 221 | " N_1 = torch.sum(y_true)\n", 222 | " \n", 223 | " # Calculate the predicted probabilities for each class\n", 224 | " p_1 = torch.clamp(y_pred, 1e-15, 1 - 1e-15)\n", 225 | " p_0 = 1 - p_1\n", 226 | " \n", 227 | " # Calculate the average log loss for each class\n", 228 | " log_loss_0 = -torch.sum((1 - y_true) * torch.log(p_0)) / N_0\n", 229 | " log_loss_1 = -torch.sum(y_true * torch.log(p_1)) / N_1\n", 230 | " \n", 231 | " # Return 
the (not further weighted) average of the averages\n", 232 | " a = (log_loss_0 + log_loss_1) / 2\n", 233 | " \n", 234 | " return a\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 5, 240 | "id": "65d76aa1", 241 | "metadata": { 242 | "execution": { 243 | "iopub.execute_input": "2024-01-01T19:29:45.065932Z", 244 | "iopub.status.busy": "2024-01-01T19:29:45.065218Z", 245 | "iopub.status.idle": "2024-01-01T19:29:45.072757Z", 246 | "shell.execute_reply": "2024-01-01T19:29:45.071985Z" 247 | }, 248 | "papermill": { 249 | "duration": 0.018734, 250 | "end_time": "2024-01-01T19:29:45.075039", 251 | "exception": false, 252 | "start_time": "2024-01-01T19:29:45.056305", 253 | "status": "completed" 254 | }, 255 | "tags": [] 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "def balanced_log_loss(y_true, y_pred):\n", 260 | " y_true = np.array(y_true)\n", 261 | " y_pred = np.array(y_pred)\n", 262 | " # y_true: correct labels 0, 1\n", 263 | " # y_pred: predicted probabilities of class=1\n", 264 | " # Implements the Evaluation equation with w_0 = w_1 = 1.\n", 265 | " # Calculate the number of observations for each class\n", 266 | " N_0 = np.sum(1 - y_true)\n", 267 | " N_1 = np.sum(y_true)\n", 268 | " # Calculate the predicted probabilities for each class\n", 269 | " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n", 270 | " p_0 = 1 - p_1\n", 271 | " # Calculate the average log loss for each class\n", 272 | " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0)) / N_0\n", 273 | " log_loss_1 = -np.sum(y_true * np.log(p_1)) / N_1\n", 274 | " # return the (not further weighted) average of the averages\n", 275 | " a = (log_loss_0 + log_loss_1)/2\n", 276 | " return 'Loss', a, False" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "c0e361bb", 282 | "metadata": { 283 | "papermill": { 284 | "duration": 0.007147, 285 | "end_time": "2024-01-01T19:29:45.089796", 286 | "exception": false, 287 | "start_time": "2024-01-01T19:29:45.082649", 288 | "status": 
"completed" 289 | }, 290 | "tags": [] 291 | }, 292 | "source": [ 293 | "# 5-seed ensemble" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 6, 299 | "id": "b95ac841", 300 | "metadata": { 301 | "execution": { 302 | "iopub.execute_input": "2024-01-01T19:29:45.106903Z", 303 | "iopub.status.busy": "2024-01-01T19:29:45.106118Z", 304 | "iopub.status.idle": "2024-01-01T19:29:45.120541Z", 305 | "shell.execute_reply": "2024-01-01T19:29:45.119708Z" 306 | }, 307 | "papermill": { 308 | "duration": 0.025846, 309 | "end_time": "2024-01-01T19:29:45.123013", 310 | "exception": false, 311 | "start_time": "2024-01-01T19:29:45.097167", 312 | "status": "completed" 313 | }, 314 | "tags": [] 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "features = ['AB', 'AF', 'AH', 'AM', 'AR', 'AX', 'AY', 'AZ', 'BC', 'BD ', 'BN', 'BP',\n", 319 | " 'BQ', 'BR', 'BZ', 'CB', 'CC', 'CD ', 'CF', 'CH', 'CL', 'CR', 'CS', 'CU',\n", 320 | " 'CW ', 'DA', 'DE', 'DF', 'DH', 'DI', 'DL', 'DN', 'DU', 'DV', 'DY', 'EB',\n", 321 | " 'EE', 'EG', 'EH', 'EJ', 'EL', 'EP', 'EU', 'FC', 'FD ', 'FE', 'FI', 'FL',\n", 322 | " 'FR', 'FS', 'GB', 'GE', 'GF', 'GH', 'GI', 'GL']\n", 323 | "\n", 324 | "train['EJ'] = train['EJ'].str.strip()\n", 325 | "test['EJ'] = test['EJ'].str.strip()\n", 326 | "\n", 327 | "train['EJ'] = train['EJ'].map({'A': 1, 'B': 0})\n", 328 | "test['EJ'] = test['EJ'].map({'A': 1, 'B': 0})" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 7, 334 | "id": "24f2d921", 335 | "metadata": { 336 | "execution": { 337 | "iopub.execute_input": "2024-01-01T19:29:45.140285Z", 338 | "iopub.status.busy": "2024-01-01T19:29:45.139534Z", 339 | "iopub.status.idle": "2024-01-01T19:29:45.156933Z", 340 | "shell.execute_reply": "2024-01-01T19:29:45.156113Z" 341 | }, 342 | "papermill": { 343 | "duration": 0.028755, 344 | "end_time": "2024-01-01T19:29:45.159384", 345 | "exception": false, 346 | "start_time": "2024-01-01T19:29:45.130629", 347 | "status": "completed" 348 | }, 
349 | "tags": [] 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n", 354 | "X = train[features]\n", 355 | "y = train['Class']\n", 356 | "\n", 357 | "for fold, (_, val_idx) in enumerate(skf.split(X, y)):\n", 358 | " train.loc[train.index.isin(val_idx), 'fold'] = fold + 1\n", 359 | "train['fold'] = train['fold'].astype(int)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 8, 365 | "id": "cfeee14a", 366 | "metadata": { 367 | "execution": { 368 | "iopub.execute_input": "2024-01-01T19:29:45.176830Z", 369 | "iopub.status.busy": "2024-01-01T19:29:45.176092Z", 370 | "iopub.status.idle": "2024-01-01T19:29:45.184622Z", 371 | "shell.execute_reply": "2024-01-01T19:29:45.183872Z" 372 | }, 373 | "papermill": { 374 | "duration": 0.019749, 375 | "end_time": "2024-01-01T19:29:45.186918", 376 | "exception": false, 377 | "start_time": "2024-01-01T19:29:45.167169", 378 | "status": "completed" 379 | }, 380 | "tags": [] 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "def prepare_folds(train, features, target):\n", 385 | " X = train[features]\n", 386 | " y = train[target]\n", 387 | " skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n", 388 | " scaler = StandardScaler()\n", 389 | " \n", 390 | " fold = 0\n", 391 | " for train_indices, val_indices in skf.split(X, y):\n", 392 | " fold += 1\n", 393 | " print(f'Preparing fold {fold} ...')\n", 394 | " df_train = train.loc[train.index.isin(train_indices)].reset_index(drop=True)\n", 395 | " df_val = train.loc[train.index.isin(val_indices)].reset_index(drop=True)\n", 396 | " \n", 397 | " df_train[features] = scaler.fit_transform(df_train[features])\n", 398 | " df_val[features] = scaler.transform(df_val[features])\n", 399 | " \n", 400 | " df_train.to_csv(f'df_train_fold_{fold}.csv', index=False)\n", 401 | " df_val.to_csv(f'df_val_fold_{fold}.csv', index=False)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | 
"execution_count": 9, 407 | "id": "9d4d2037", 408 | "metadata": { 409 | "execution": { 410 | "iopub.execute_input": "2024-01-01T19:29:45.204241Z", 411 | "iopub.status.busy": "2024-01-01T19:29:45.203491Z", 412 | "iopub.status.idle": "2024-01-01T19:29:45.755298Z", 413 | "shell.execute_reply": "2024-01-01T19:29:45.753919Z" 414 | }, 415 | "papermill": { 416 | "duration": 0.563143, 417 | "end_time": "2024-01-01T19:29:45.757695", 418 | "exception": false, 419 | "start_time": "2024-01-01T19:29:45.194552", 420 | "status": "completed" 421 | }, 422 | "tags": [] 423 | }, 424 | "outputs": [ 425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "Preparing fold 1 ...\n", 430 | "Preparing fold 2 ...\n", 431 | "Preparing fold 3 ...\n", 432 | "Preparing fold 4 ...\n", 433 | "Preparing fold 5 ...\n" 434 | ] 435 | } 436 | ], 437 | "source": [ 438 | "prepare_folds(train, ORIGINAL_FEATURES, TARGET)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 10, 444 | "id": "71aebcb7", 445 | "metadata": { 446 | "execution": { 447 | "iopub.execute_input": "2024-01-01T19:29:45.776193Z", 448 | "iopub.status.busy": "2024-01-01T19:29:45.775356Z", 449 | "iopub.status.idle": "2024-01-01T19:29:45.991449Z", 450 | "shell.execute_reply": "2024-01-01T19:29:45.990230Z" 451 | }, 452 | "papermill": { 453 | "duration": 0.228705, 454 | "end_time": "2024-01-01T19:29:45.994501", 455 | "exception": false, 456 | "start_time": "2024-01-01T19:29:45.765796", 457 | "status": "completed" 458 | }, 459 | "tags": [] 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "from sklearn.impute import SimpleImputer\n", 464 | "imp = SimpleImputer(strategy='median')" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 11, 470 | "id": "49576712", 471 | "metadata": { 472 | "execution": { 473 | "iopub.execute_input": "2024-01-01T19:29:46.012575Z", 474 | "iopub.status.busy": "2024-01-01T19:29:46.011865Z", 475 | "iopub.status.idle": 
"2024-01-01T19:29:46.021457Z", 476 | "shell.execute_reply": "2024-01-01T19:29:46.020754Z" 477 | }, 478 | "papermill": { 479 | "duration": 0.020943, 480 | "end_time": "2024-01-01T19:29:46.023529", 481 | "exception": false, 482 | "start_time": "2024-01-01T19:29:46.002586", 483 | "status": "completed" 484 | }, 485 | "tags": [] 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "class SpamDataset(Dataset):\n", 490 | " def __init__(self, features, targets):\n", 491 | " self.features = torch.tensor(features, dtype=torch.float)\n", 492 | " self.targets = torch.tensor(targets, dtype=torch.long)\n", 493 | "\n", 494 | " def __getitem__(self, index):\n", 495 | " X = self.features[index]\n", 496 | " y = self.targets[index]\n", 497 | " return X, y\n", 498 | "\n", 499 | " def __len__(self):\n", 500 | " return self.targets.shape[0]\n", 501 | "\n", 502 | "\n", 503 | "def get_dataloader(features, targets, \n", 504 | " feature_names, \n", 505 | " target_name,\n", 506 | " batch_size,\n", 507 | " mode):\n", 508 | " if mode == 'train':\n", 509 | " shuffle = True\n", 510 | " drop_last = True\n", 511 | " else:\n", 512 | " shuffle = False\n", 513 | " drop_last = False\n", 514 | " \n", 515 | " torch.manual_seed(SEED)\n", 516 | " train_dataset = SpamDataset(\n", 517 | " features=features, \n", 518 | " targets=targets\n", 519 | " )\n", 520 | " \n", 521 | " data_loader = DataLoader(\n", 522 | " dataset=train_dataset,\n", 523 | " batch_size=batch_size,\n", 524 | " shuffle=shuffle,\n", 525 | " drop_last=drop_last,\n", 526 | " num_workers=N_CORES\n", 527 | " )\n", 528 | " \n", 529 | " return data_loader" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 12, 535 | "id": "7b204261", 536 | "metadata": { 537 | "execution": { 538 | "iopub.execute_input": "2024-01-01T19:29:46.041904Z", 539 | "iopub.status.busy": "2024-01-01T19:29:46.041063Z", 540 | "iopub.status.idle": "2024-01-01T19:29:46.066403Z", 541 | "shell.execute_reply": "2024-01-01T19:29:46.065445Z" 542 | }, 543 | 
"papermill": { 544 | "duration": 0.037087, 545 | "end_time": "2024-01-01T19:29:46.068710", 546 | "exception": false, 547 | "start_time": "2024-01-01T19:29:46.031623", 548 | "status": "completed" 549 | }, 550 | "tags": [] 551 | }, 552 | "outputs": [], 553 | "source": [ 554 | "class GatedLinearUnit(nn.Module):\n", 555 | " def __init__(self, input_size):\n", 556 | " super(GatedLinearUnit, self).__init__()\n", 557 | " self.linear = nn.Linear(input_size, input_size)\n", 558 | " self.gate = nn.Sequential(\n", 559 | " nn.Linear(input_size, input_size),\n", 560 | " nn.Sigmoid()\n", 561 | " )\n", 562 | " \n", 563 | " def forward(self, x):\n", 564 | " return self.linear(x) * self.gate(x)\n", 565 | " \n", 566 | " \n", 567 | "class GatedResidualNetwork(nn.Module):\n", 568 | " def __init__(self, input_size, hidden_size, dropout):\n", 569 | " super(GatedResidualNetwork, self).__init__()\n", 570 | " self.input_size = input_size\n", 571 | " self.hidden_size = hidden_size\n", 572 | " \n", 573 | " self.grn = nn.Sequential(\n", 574 | " nn.Linear(input_size, hidden_size),\n", 575 | " nn.ELU(),\n", 576 | " nn.Linear(hidden_size, hidden_size),\n", 577 | " nn.Dropout(dropout),\n", 578 | " GatedLinearUnit(hidden_size),\n", 579 | " )\n", 580 | " \n", 581 | " self.layer_norm = nn.LayerNorm(hidden_size)\n", 582 | " self.feature_projection = nn.Linear(input_size, hidden_size)\n", 583 | " \n", 584 | " def forward(self, inputs):\n", 585 | " x = self.grn(inputs)\n", 586 | " if inputs.shape[-1] != self.hidden_size:\n", 587 | " inputs = self.feature_projection(inputs)\n", 588 | " x = self.layer_norm(x + inputs)\n", 589 | " return x\n", 590 | " \n", 591 | "class VariableSelectionNetwork(nn.Module):\n", 592 | " def __init__(self, num_features, dense_units, hidden_size, dropout):\n", 593 | " super(VariableSelectionNetwork, self).__init__()\n", 594 | " self.num_features = num_features\n", 595 | " self.hidden_size = hidden_size\n", 596 | " self.grns = nn.ModuleList()\n", 597 | " for _ in 
range(num_features):\n", 598 | " self.grns.append(GatedResidualNetwork(dense_units, hidden_size, dropout))\n", 599 | " \n", 600 | " \n", 601 | " self.grn_concat = GatedResidualNetwork(num_features*dense_units, hidden_size, dropout)\n", 602 | " self.softmax = nn.Sequential(\n", 603 | " nn.Linear(hidden_size, num_features),\n", 604 | " nn.Softmax(dim=-1)\n", 605 | " )\n", 606 | " \n", 607 | " def forward(self, inputs):\n", 608 | " v = torch.cat(inputs, dim=1)\n", 609 | " v = self.grn_concat(v)\n", 610 | " v = self.softmax(v)\n", 611 | " v = torch.unsqueeze(v, dim=-1)\n", 612 | " \n", 613 | " x = []\n", 614 | " for idx, input_ in enumerate(inputs):\n", 615 | " x.append(self.grns[idx](input_))\n", 616 | " x = torch.stack(x, dim=1)\n", 617 | " \n", 618 | " out = (v.transpose(2, 1) @ x).squeeze(dim=1)\n", 619 | " return out\n", 620 | " \n", 621 | "class VariableSelectionFlow(nn.Module):\n", 622 | " def __init__(self, num_features, hidden_size, dense_units, dropout):\n", 623 | " super(VariableSelectionFlow, self).__init__()\n", 624 | " self.variable_selection = VariableSelectionNetwork(num_features, dense_units, hidden_size, dropout)\n", 625 | " self.split = lambda x: torch.split(x, 1, dim=-1)\n", 626 | " self.dense_list = nn.ModuleList(\n", 627 | " [\n", 628 | " nn.Linear(1, dense_units) \n", 629 | " for _ in range(num_features)\n", 630 | " ]\n", 631 | " )\n", 632 | " \n", 633 | " \n", 634 | " def forward(self, inputs):\n", 635 | " split_inputs = self.split(inputs)\n", 636 | " x = []\n", 637 | " for split_input, linear in zip(split_inputs, self.dense_list):\n", 638 | " x.append(linear(split_input))\n", 639 | " \n", 640 | " return self.variable_selection(x)\n", 641 | "class Net(nn.Module):\n", 642 | " def __init__(self, num_features, dense_units, hidden_sizes, dropouts):\n", 643 | " super(Net, self).__init__()\n", 644 | " self.num_features = num_features\n", 645 | " self.dense_units = dense_units\n", 646 | " self.hidden_size_1 = hidden_sizes[0]\n", 647 | " 
self.hidden_size_2 = hidden_sizes[1]\n", 648 | " self.hidden_size_3 = hidden_sizes[2]\n", 649 | "\n", 650 | " self.dropout_1 = dropouts[0]\n", 651 | " self.dropout_2 = dropouts[1]\n", 652 | " self.dropout_3 = dropouts[2]\n", 653 | " \n", 654 | " self.variable_slection_flows = nn.Sequential(\n", 655 | " VariableSelectionFlow(num_features, self.hidden_size_1, dense_units, self.dropout_1),\n", 656 | " VariableSelectionFlow(self.hidden_size_1, self.hidden_size_2, dense_units, self.dropout_2),\n", 657 | " VariableSelectionFlow(self.hidden_size_2, self.hidden_size_3, dense_units, self.dropout_3),\n", 658 | " nn.Linear(self.hidden_size_3, 2)\n", 659 | " )\n", 660 | " \n", 661 | " def forward(self, x):\n", 662 | " logits = self.variable_slection_flows(x)\n", 663 | " return logits" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": 13, 669 | "id": "6058d4e7", 670 | "metadata": { 671 | "execution": { 672 | "iopub.execute_input": "2024-01-01T19:29:46.087005Z", 673 | "iopub.status.busy": "2024-01-01T19:29:46.086630Z", 674 | "iopub.status.idle": "2024-01-01T19:29:46.100234Z", 675 | "shell.execute_reply": "2024-01-01T19:29:46.099168Z" 676 | }, 677 | "papermill": { 678 | "duration": 0.025228, 679 | "end_time": "2024-01-01T19:29:46.102328", 680 | "exception": false, 681 | "start_time": "2024-01-01T19:29:46.077100", 682 | "status": "completed" 683 | }, 684 | "tags": [] 685 | }, 686 | "outputs": [], 687 | "source": [ 688 | "def fit(model, optimizer, scheduler, epochs, train_dataloader, val_dataloader):\n", 689 | "\n", 690 | " start_time = time.time()\n", 691 | " scaler = GradScaler()\n", 692 | "\n", 693 | " for epoch in range(epochs):\n", 694 | "\n", 695 | " model.train()\n", 696 | " \n", 697 | " for batch_idx, (features, targets) in enumerate(train_dataloader):\n", 698 | " features = features.to(DEVICE)\n", 699 | " targets = targets.to(DEVICE)\n", 700 | " with autocast():\n", 701 | " logits = model(features)\n", 702 | " probs = F.softmax(logits, dim=-1)[:, 
1]\n", 703 | " loss = score(targets, probs)\n", 704 | "\n", 705 | " scaler.scale(loss).backward()\n", 706 | " scaler.step(optimizer)\n", 707 | " scaler.update()\n", 708 | " optimizer.zero_grad()\n", 709 | " scheduler.step()\n", 710 | "\n", 711 | " if not batch_idx % 10:\n", 712 | " print(\n", 713 | " f'Epoch: {epoch + 1}/{epochs}'\n", 714 | " f' | Batch: {batch_idx}/{len(train_dataloader)}'\n", 715 | " f' | Loss: {loss.detach().cpu().item():.4f}')\n", 716 | "\n", 717 | " if val_dataloader is not None:\n", 718 | " y_scores = torch.tensor([])\n", 719 | " y_true = torch.tensor([])\n", 720 | "\n", 721 | " with torch.inference_mode():\n", 722 | "\n", 723 | " model.eval()\n", 724 | "\n", 725 | " for batch_idx, (features, targets) in enumerate(val_dataloader):\n", 726 | " features = features.to(DEVICE)\n", 727 | " with autocast():\n", 728 | " logits = model(features).detach().cpu().type(torch.float)\n", 729 | " probs = F.softmax(logits, dim=-1)[:, 1]\n", 730 | " y_scores = torch.cat([y_scores, probs])\n", 731 | " y_true = torch.cat([y_true, targets])\n", 732 | "\n", 733 | " val_score = balanced_log_loss(y_true, y_scores)\n", 734 | " print('Validation score (AUC):', val_score)\n", 735 | "\n", 736 | " elapsed = (time.time() - start_time) / 60\n", 737 | " print(f'Total training time: {elapsed:.3f} min')\n", 738 | "\n", 739 | " model.eval()\n", 740 | "\n", 741 | " if val_dataloader is not None:\n", 742 | " return model, val_score\n", 743 | " else:\n", 744 | " return model" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 14, 750 | "id": "e7117426", 751 | "metadata": { 752 | "execution": { 753 | "iopub.execute_input": "2024-01-01T19:29:46.121104Z", 754 | "iopub.status.busy": "2024-01-01T19:29:46.120221Z", 755 | "iopub.status.idle": "2024-01-01T19:29:46.124942Z", 756 | "shell.execute_reply": "2024-01-01T19:29:46.124215Z" 757 | }, 758 | "papermill": { 759 | "duration": 0.016167, 760 | "end_time": "2024-01-01T19:29:46.126843", 761 | "exception": false, 
762 | "start_time": "2024-01-01T19:29:46.110676", 763 | "status": "completed" 764 | }, 765 | "tags": [] 766 | }, 767 | "outputs": [], 768 | "source": [ 769 | "N_EPOCHS = 15\n", 770 | "N_WARMUPS = 100\n", 771 | "WEIGHT_DECAY = 0.005\n", 772 | "LEARNING_RATE = 0.01" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": 15, 778 | "id": "2cf1b342", 779 | "metadata": { 780 | "execution": { 781 | "iopub.execute_input": "2024-01-01T19:29:46.145332Z", 782 | "iopub.status.busy": "2024-01-01T19:29:46.144673Z", 783 | "iopub.status.idle": "2024-01-01T19:29:46.151028Z", 784 | "shell.execute_reply": "2024-01-01T19:29:46.150011Z" 785 | }, 786 | "papermill": { 787 | "duration": 0.018249, 788 | "end_time": "2024-01-01T19:29:46.153417", 789 | "exception": false, 790 | "start_time": "2024-01-01T19:29:46.135168", 791 | "status": "completed" 792 | }, 793 | "tags": [] 794 | }, 795 | "outputs": [], 796 | "source": [ 797 | "# val_scores = []\n", 798 | "# for fold in range(1, 6):\n", 799 | "# df_train = pd.read_csv(OUTPUT_PATH / f'df_train_fold_{fold}.csv')\n", 800 | "# df_val = pd.read_csv(OUTPUT_PATH / f'df_val_fold_{fold}.csv')\n", 801 | "# imp = SimpleImputer(strategy='median')\n", 802 | "# X_train = df_train[ORIGINAL_FEATURES]\n", 803 | "# y_train = df_train[TARGET]\n", 804 | "# X_val = df_val[ORIGINAL_FEATURES]\n", 805 | "# y_val = df_val[TARGET]\n", 806 | " \n", 807 | "# X_train, X_val = imp.fit_transform(X_train), imp.transform(X_val)\n", 808 | " \n", 809 | "# train_dataloader = get_dataloader(\n", 810 | "# X_train, y_train,\n", 811 | "# feature_names=ORIGINAL_FEATURES, \n", 812 | "# target_name=TARGET,\n", 813 | "# batch_size=BATCH_SIZE,\n", 814 | "# mode='train'\n", 815 | "# )\n", 816 | " \n", 817 | "# val_dataloader = get_dataloader(\n", 818 | "# X_val, y_val, \n", 819 | "# feature_names=ORIGINAL_FEATURES, \n", 820 | "# target_name=TARGET,\n", 821 | "# batch_size=BATCH_SIZE,\n", 822 | "# mode='val'\n", 823 | "# )\n", 824 | " \n", 825 | "# 
torch.manual_seed(SEED)\n", 826 | "# model = Net(\n", 827 | "# num_features=NUM_FEATURES, \n", 828 | "# dense_units=8,\n", 829 | "# hidden_sizes=[32, 32, 32],\n", 830 | "# dropouts=[0.75, 0.5, 0.1]\n", 831 | "# )\n", 832 | " \n", 833 | "# model.to(DEVICE)\n", 834 | "# model.train()\n", 835 | " \n", 836 | "# optimizer = optim.AdamW(\n", 837 | "# model.parameters(), \n", 838 | "# lr=LEARNING_RATE,\n", 839 | "# weight_decay=WEIGHT_DECAY\n", 840 | "# )\n", 841 | " \n", 842 | "# scheduler = get_cosine_schedule_with_warmup(\n", 843 | "# optimizer=optimizer, \n", 844 | "# num_warmup_steps=N_WARMUPS,\n", 845 | "# num_training_steps=len(train_dataloader)*N_EPOCHS\n", 846 | "# )\n", 847 | " \n", 848 | "# model, val_score = fit(model=model,\n", 849 | "# optimizer=optimizer,\n", 850 | "# scheduler=scheduler,\n", 851 | "# epochs=N_EPOCHS,\n", 852 | "# train_dataloader=train_dataloader,\n", 853 | "# val_dataloader=val_dataloader)\n", 854 | " \n", 855 | "# val_scores.append(val_score)" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": 16, 861 | "id": "92dc4434", 862 | "metadata": { 863 | "execution": { 864 | "iopub.execute_input": "2024-01-01T19:29:46.171440Z", 865 | "iopub.status.busy": "2024-01-01T19:29:46.171085Z", 866 | "iopub.status.idle": "2024-01-01T19:29:46.176610Z", 867 | "shell.execute_reply": "2024-01-01T19:29:46.175869Z" 868 | }, 869 | "papermill": { 870 | "duration": 0.016751, 871 | "end_time": "2024-01-01T19:29:46.178598", 872 | "exception": false, 873 | "start_time": "2024-01-01T19:29:46.161847", 874 | "status": "completed" 875 | }, 876 | "tags": [] 877 | }, 878 | "outputs": [], 879 | "source": [ 880 | "test['Class'] = 0" 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": 17, 886 | "id": "6f67d8cc", 887 | "metadata": { 888 | "execution": { 889 | "iopub.execute_input": "2024-01-01T19:29:46.197203Z", 890 | "iopub.status.busy": "2024-01-01T19:29:46.196547Z", 891 | "iopub.status.idle": "2024-01-01T19:30:59.588191Z", 
892 | "shell.execute_reply": "2024-01-01T19:30:59.586463Z" 893 | }, 894 | "papermill": { 895 | "duration": 73.403646, 896 | "end_time": "2024-01-01T19:30:59.590753", 897 | "exception": false, 898 | "start_time": "2024-01-01T19:29:46.187107", 899 | "status": "completed" 900 | }, 901 | "tags": [] 902 | }, 903 | "outputs": [ 904 | { 905 | "name": "stdout", 906 | "output_type": "stream", 907 | "text": [ 908 | "Epoch: 1/15 | Batch: 0/19 | Loss: 0.7048\n", 909 | "Epoch: 1/15 | Batch: 10/19 | Loss: 0.6967\n", 910 | "Epoch: 2/15 | Batch: 0/19 | Loss: 0.6925\n", 911 | "Epoch: 2/15 | Batch: 10/19 | Loss: 0.7254\n", 912 | "Epoch: 3/15 | Batch: 0/19 | Loss: 0.5173\n", 913 | "Epoch: 3/15 | Batch: 10/19 | Loss: 0.7123\n", 914 | "Epoch: 4/15 | Batch: 0/19 | Loss: 0.4894\n", 915 | "Epoch: 4/15 | Batch: 10/19 | Loss: 0.2931\n", 916 | "Epoch: 5/15 | Batch: 0/19 | Loss: 0.4056\n", 917 | "Epoch: 5/15 | Batch: 10/19 | Loss: 0.4076\n", 918 | "Epoch: 6/15 | Batch: 0/19 | Loss: 0.4935\n", 919 | "Epoch: 6/15 | Batch: 10/19 | Loss: 0.3445\n", 920 | "Epoch: 7/15 | Batch: 0/19 | Loss: 0.5423\n", 921 | "Epoch: 7/15 | Batch: 10/19 | Loss: 0.5617\n", 922 | "Epoch: 8/15 | Batch: 0/19 | Loss: 0.3421\n", 923 | "Epoch: 8/15 | Batch: 10/19 | Loss: 0.7682\n", 924 | "Epoch: 9/15 | Batch: 0/19 | Loss: 0.5203\n", 925 | "Epoch: 9/15 | Batch: 10/19 | Loss: 0.3110\n", 926 | "Epoch: 10/15 | Batch: 0/19 | Loss: 0.4536\n", 927 | "Epoch: 10/15 | Batch: 10/19 | Loss: 0.5030\n", 928 | "Epoch: 11/15 | Batch: 0/19 | Loss: 0.5255\n", 929 | "Epoch: 11/15 | Batch: 10/19 | Loss: 0.2635\n", 930 | "Epoch: 12/15 | Batch: 0/19 | Loss: 0.3376\n", 931 | "Epoch: 12/15 | Batch: 10/19 | Loss: 0.4615\n", 932 | "Epoch: 13/15 | Batch: 0/19 | Loss: 0.5024\n", 933 | "Epoch: 13/15 | Batch: 10/19 | Loss: 0.2413\n", 934 | "Epoch: 14/15 | Batch: 0/19 | Loss: 0.4117\n", 935 | "Epoch: 14/15 | Batch: 10/19 | Loss: 0.2432\n", 936 | "Epoch: 15/15 | Batch: 0/19 | Loss: 0.3910\n", 937 | "Epoch: 15/15 | Batch: 10/19 | Loss: 0.3647\n", 938 |
"Total training time: 1.213 min\n" 939 | ] 940 | } 941 | ], 942 | "source": [ 943 | "X_train = train[ORIGINAL_FEATURES]  # final fit on the full training set; no validation loader is used below\n", 944 | "y_train = train[TARGET]\n", 945 | "X_test = test[ORIGINAL_FEATURES]\n", 946 | "y_test = test[TARGET]  # placeholder labels only (test['Class'] is zero-filled in the previous cell; presumably TARGET == 'Class')\n", 947 | "\n", 948 | "X_train, X_test = imp.fit_transform(X_train), imp.transform(X_test)  # imp (defined earlier, presumably an imputer) is fit on train only, then reused on test\n", 949 | "\n", 950 | "train_dataloader = get_dataloader(\n", 951 | " X_train, y_train,\n", 952 | " feature_names=ORIGINAL_FEATURES, \n", 953 | " target_name=TARGET,\n", 954 | " batch_size=BATCH_SIZE,\n", 955 | " mode='train'\n", 956 | ")\n", 957 | "\n", 958 | "test_dataloader = get_dataloader(\n", 959 | " X_test, y_test, \n", 960 | " feature_names=ORIGINAL_FEATURES, \n", 961 | " target_name=TARGET,\n", 962 | " batch_size=BATCH_SIZE,\n", 963 | " mode='val'  # presumably unshuffled, so prediction order matches the test rows - confirm in get_dataloader\n", 964 | ")\n", 965 | "\n", 966 | "torch.manual_seed(SEED)  # reseed right before Net() so weight initialisation is reproducible\n", 967 | "model = Net(\n", 968 | " num_features=NUM_FEATURES, \n", 969 | " dense_units=8,\n", 970 | " hidden_sizes=[32, 32, 32],\n", 971 | " dropouts=[0.75, 0.5, 0.25]\n", 972 | ")\n", 973 | "\n", 974 | "model.to(DEVICE)\n", 975 | "model.train()\n", 976 | "\n", 977 | "optimizer = optim.AdamW(\n", 978 | " model.parameters(), \n", 979 | " lr=LEARNING_RATE,\n", 980 | " weight_decay=WEIGHT_DECAY\n", 981 | ")\n", 982 | "\n", 983 | "scheduler = get_cosine_schedule_with_warmup(\n", 984 | " optimizer=optimizer, \n", 985 | " num_warmup_steps=N_WARMUPS,\n", 986 | " num_training_steps=len(train_dataloader)*N_EPOCHS  # batches per epoch x epochs; assumes fit steps the scheduler once per batch\n", 987 | ")\n", 988 | "\n", 989 | "model = fit(model=model,\n", 990 | " optimizer=optimizer,\n", 991 | " scheduler=scheduler,\n", 992 | " epochs=N_EPOCHS,\n", 993 | " train_dataloader=train_dataloader,\n", 994 | " val_dataloader=None)  # no held-out data here, so only the trained model is kept" 995 | ] 996 | }, 997 | { 998 | "cell_type": "code", 999 | "execution_count": 18, 1000 | "id": "e157b23d", 1001 | "metadata": { 1002 | "execution": { 1003 | "iopub.execute_input": "2024-01-01T19:30:59.616895Z", 1004 | "iopub.status.busy": "2024-01-01T19:30:59.616447Z", 1005 | "iopub.status.idle":
"2024-01-01T19:30:59.838258Z", 1006 | "shell.execute_reply": "2024-01-01T19:30:59.837044Z" 1007 | }, 1008 | "papermill": { 1009 | "duration": 0.238962, 1010 | "end_time": "2024-01-01T19:30:59.840939", 1011 | "exception": false, 1012 | "start_time": "2024-01-01T19:30:59.601977", 1013 | "status": "completed" 1014 | }, 1015 | "tags": [] 1016 | }, 1017 | "outputs": [], 1018 | "source": [ 1019 | "y_scores = torch.tensor([])  # per-row class probabilities for the test set, appended batch by batch\n", 1020 | "\n", 1021 | "with torch.inference_mode():  # no autograd tracking while predicting\n", 1022 | "\n", 1023 | " model.eval()\n", 1024 | "\n", 1025 | " for batch_idx, (features, targets) in enumerate(test_dataloader):  # targets are the placeholder labels and are not used\n", 1026 | " features = features.to(DEVICE)\n", 1027 | " with autocast():\n", 1028 | " logits = model(features).detach().cpu().type(torch.float)  # float32 on CPU so it can be concatenated below\n", 1029 | " probs = F.softmax(logits, dim=-1)  # one probability column per class (class_0, class_1 in the submission)\n", 1030 | " y_scores = torch.cat([y_scores, probs])" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "markdown", 1035 | "id": "6013174b", 1036 | "metadata": { 1037 | "papermill": { 1038 | "duration": 0.011217, 1039 | "end_time": "2024-01-01T19:30:59.864280", 1040 | "exception": false, 1041 | "start_time": "2024-01-01T19:30:59.853063", 1042 | "status": "completed" 1043 | }, 1044 | "tags": [] 1045 | }, 1046 | "source": [ 1047 | "# Submission" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "execution_count": 19, 1053 | "id": "9ad834b5", 1054 | "metadata": { 1055 | "execution": { 1056 | "iopub.execute_input": "2024-01-01T19:30:59.889223Z", 1057 | "iopub.status.busy": "2024-01-01T19:30:59.888758Z", 1058 | "iopub.status.idle": "2024-01-01T19:30:59.932660Z", 1059 | "shell.execute_reply": "2024-01-01T19:30:59.931613Z" 1060 | }, 1061 | "papermill": { 1062 | "duration": 0.059337, 1063 | "end_time": "2024-01-01T19:30:59.935071", 1064 | "exception": false, 1065 | "start_time": "2024-01-01T19:30:59.875734", 1066 | "status": "completed" 1067 | }, 1068 | "tags": [] 1069 | }, 1070 | "outputs": [ 1071 | { 1072 | "name": "stdout", 1073 | "output_type": "stream", 1074 | "text": [
1075 | "Submission file saved!\n" 1076 | ] 1077 | }, 1078 | { 1079 | "data": { 1080 | "text/html": [ 1081 | "
\n", 1082 | "\n", 1095 | "\n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | "
Idclass_0class_1
000eed32682bb0.9238020.076198
1010ebe33f6680.9238020.076198
202fa521e18380.9238020.076198
3040e15f562a20.9238020.076198
4046e85c7cc7f0.9238020.076198
\n", 1137 | "
" 1138 | ], 1139 | "text/plain": [ 1140 | " Id class_0 class_1\n", 1141 | "0 00eed32682bb 0.923802 0.076198\n", 1142 | "1 010ebe33f668 0.923802 0.076198\n", 1143 | "2 02fa521e1838 0.923802 0.076198\n", 1144 | "3 040e15f562a2 0.923802 0.076198\n", 1145 | "4 046e85c7cc7f 0.923802 0.076198" 1146 | ] 1147 | }, 1148 | "execution_count": 19, 1149 | "metadata": {}, 1150 | "output_type": "execute_result" 1151 | } 1152 | ], 1153 | "source": [ 1154 | "sub = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv')  # supplies the Id column\n", 1155 | "sub[['class_0', 'class_1']] = y_scores  # positional assignment: assumes y_scores rows follow the same order as sample_submission\n", 1156 | "sub.to_csv('submission.csv', index=False)\n", 1157 | "print('Submission file saved!')\n", 1158 | "sub" 1159 | ] 1160 | }, 1161 | { 1162 | "cell_type": "code", 1163 | "execution_count": null, 1164 | "id": "9b514427", 1165 | "metadata": { 1166 | "papermill": { 1167 | "duration": 0.011463, 1168 | "end_time": "2024-01-01T19:30:59.958503", 1169 | "exception": false, 1170 | "start_time": "2024-01-01T19:30:59.947040", 1171 | "status": "completed" 1172 | }, 1173 | "tags": [] 1174 | }, 1175 | "outputs": [], 1176 | "source": [] 1177 | } 1178 | ], 1179 | "metadata": { 1180 | "kaggle": { 1181 | "accelerator": "none", 1182 | "dataSources": [ 1183 | { 1184 | "databundleVersionId": 5687476, 1185 | "sourceId": 52784, 1186 | "sourceType": "competition" 1187 | }, 1188 | { 1189 | "databundleVersionId": 5768690, 1190 | "datasetId": 3273406, 1191 | "sourceId": 5693071, 1192 | "sourceType": "datasetVersion" 1193 | }, 1194 | { 1195 | "databundleVersionId": 7382017, 1196 | "datasetId": 930977, 1197 | "sourceId": 7292109, 1198 | "sourceType": "datasetVersion" 1199 | } 1200 | ], 1201 | "dockerImageVersionId": 30474, 1202 | "isGpuEnabled": false, 1203 | "isInternetEnabled": false, 1204 | "language": "python", 1205 | "sourceType": "notebook" 1206 | }, 1207 | "kernelspec": { 1208 | "display_name": "Python 3", 1209 | "language": "python", 1210 | "name": "python3" 1211 | }, 1212 | "language_info": { 1213
| "codemirror_mode": { 1214 | "name": "ipython", 1215 | "version": 3 1216 | }, 1217 | "file_extension": ".py", 1218 | "mimetype": "text/x-python", 1219 | "name": "python", 1220 | "nbconvert_exporter": "python", 1221 | "pygments_lexer": "ipython3", 1222 | "version": "3.10.10" 1223 | }, 1224 | "papermill": { 1225 | "default_parameters": {}, 1226 | "duration": 104.093568, 1227 | "end_time": "2024-01-01T19:31:02.855589", 1228 | "environment_variables": {}, 1229 | "exception": null, 1230 | "input_path": "__notebook__.ipynb", 1231 | "output_path": "__notebook__.ipynb", 1232 | "parameters": {}, 1233 | "start_time": "2024-01-01T19:29:18.762021", 1234 | "version": "2.4.0" 1235 | } 1236 | }, 1237 | "nbformat": 4, 1238 | "nbformat_minor": 5 1239 | } 1240 | -------------------------------------------------------------------------------- /spam_prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "91349eca", 7 | "metadata": { 8 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 9 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 10 | "execution": { 11 | "iopub.execute_input": "2024-01-05T17:14:43.763383Z", 12 | "iopub.status.busy": "2024-01-05T17:14:43.762477Z", 13 | "iopub.status.idle": "2024-01-05T17:14:50.370293Z", 14 | "shell.execute_reply": "2024-01-05T17:14:50.369488Z" 15 | }, 16 | "papermill": { 17 | "duration": 6.61799, 18 | "end_time": "2024-01-05T17:14:50.372582", 19 | "exception": false, 20 | "start_time": "2024-01-05T17:14:43.754592", 21 | "status": "completed" 22 | }, 23 | "tags": [] 24 | }, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n", 31 | " warnings.warn(f\"A NumPy version >={np_minversion} 
and <{np_maxversion}\"\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import pandas as pd\n", 38 | "\n", 39 | "import time\n", 40 | "from tqdm import tqdm\n", 41 | "from pathlib import Path\n", 42 | "import multiprocessing as mp\n", 43 | "\n", 44 | "from sklearn.metrics import roc_auc_score\n", 45 | "from sklearn.preprocessing import StandardScaler\n", 46 | "from sklearn.model_selection import cross_val_score, StratifiedKFold\n", 47 | "\n", 48 | "from transformers import get_cosine_schedule_with_warmup\n", 49 | "\n", 50 | "import torch \n", 51 | "from torch import nn, optim\n", 52 | "import torch.nn.functional as F\n", 53 | "from torch.cuda.amp import GradScaler, autocast\n", 54 | "from torch.utils.data import Dataset, DataLoader" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "b0ee803f", 60 | "metadata": { 61 | "papermill": { 62 | "duration": 0.005714, 63 | "end_time": "2024-01-05T17:14:50.384668", 64 | "exception": false, 65 | "start_time": "2024-01-05T17:14:50.378954", 66 | "status": "completed" 67 | }, 68 | "tags": [] 69 | }, 70 | "source": [ 71 | "# General Settings" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 2, 77 | "id": "919d0ef5", 78 | "metadata": { 79 | "execution": { 80 | "iopub.execute_input": "2024-01-05T17:14:50.399076Z", 81 | "iopub.status.busy": "2024-01-05T17:14:50.397834Z", 82 | "iopub.status.idle": "2024-01-05T17:14:50.456406Z", 83 | "shell.execute_reply": "2024-01-05T17:14:50.455513Z" 84 | }, 85 | "papermill": { 86 | "duration": 0.067852, 87 | "end_time": "2024-01-05T17:14:50.458590", 88 | "exception": false, 89 | "start_time": "2024-01-05T17:14:50.390738", 90 | "status": "completed" 91 | }, 92 | "tags": [] 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "INPUT_PATH = Path('/kaggle/input/lets-surpass-the-hosts-bayesian-model')  # competition data directory (train_df.csv, test_df.csv)\n", 97 | "OUTPUT_PATH = Path('/kaggle/working')  # NOTE(review): prepare_folds writes its CSVs to the current directory, not OUTPUT_PATH\n", 98 | "\n", 99 | "ORIGINAL_FEATURES = ['A', 'B', 'E', 'F', 'G']  # model input columns\n", 100 | "\n", 101 | "TARGET
= 'Target'\n", 102 | "\n", 103 | "N_CORES = mp.cpu_count()  # also used as DataLoader num_workers\n", 104 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 105 | "NUM_FEATURES = len(ORIGINAL_FEATURES)\n", 106 | "BATCH_SIZE = 32\n", 107 | "N_EPOCHS = 5  # N_EPOCHS, LEARNING_RATE and WEIGHT_DECAY are reassigned in the Training section\n", 108 | "N_WARMUPS = 80  # presumably LR-scheduler warmup steps; its use is not visible here\n", 109 | "LEARNING_RATE = 0.004\n", 110 | "WEIGHT_DECAY = 0.01\n", 111 | "SEED = 252  # also reseeds torch inside get_dataloader" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "ccab7057", 117 | "metadata": { 118 | "papermill": { 119 | "duration": 0.005618, 120 | "end_time": "2024-01-05T17:14:50.470230", 121 | "exception": false, 122 | "start_time": "2024-01-05T17:14:50.464612", 123 | "status": "completed" 124 | }, 125 | "tags": [] 126 | }, 127 | "source": [ 128 | "# Load Data" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "id": "edba1611", 135 | "metadata": { 136 | "execution": { 137 | "iopub.execute_input": "2024-01-05T17:14:50.484027Z", 138 | "iopub.status.busy": "2024-01-05T17:14:50.483173Z", 139 | "iopub.status.idle": "2024-01-05T17:14:50.504976Z", 140 | "shell.execute_reply": "2024-01-05T17:14:50.504151Z" 141 | }, 142 | "papermill": { 143 | "duration": 0.031072, 144 | "end_time": "2024-01-05T17:14:50.507226", 145 | "exception": false, 146 | "start_time": "2024-01-05T17:14:50.476154", 147 | "status": "completed" 148 | }, 149 | "tags": [] 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "train = pd.read_csv(INPUT_PATH / 'train_df.csv')  # default RangeIndex, which prepare_folds relies on\n", 154 | "train['Target'] = train['Target'].astype(int)  # integer labels; SpamDataset converts them to torch.long" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "68a6c905", 160 | "metadata": { 161 | "papermill": { 162 | "duration": 0.005868, 163 | "end_time": "2024-01-05T17:14:50.519041", 164 | "exception": false, 165 | "start_time": "2024-01-05T17:14:50.513173", 166 | "status": "completed" 167 | }, 168 | "tags": [] 169 | }, 170 | "source": [ 171 | "# Split Folds" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 4, 177 | "id": "f8827c0e", 178 | "metadata": { 179 |
"execution": { 180 | "iopub.execute_input": "2024-01-05T17:14:50.532855Z", 181 | "iopub.status.busy": "2024-01-05T17:14:50.532108Z", 182 | "iopub.status.idle": "2024-01-05T17:14:50.540434Z", 183 | "shell.execute_reply": "2024-01-05T17:14:50.539490Z" 184 | }, 185 | "papermill": { 186 | "duration": 0.017352, 187 | "end_time": "2024-01-05T17:14:50.542410", 188 | "exception": false, 189 | "start_time": "2024-01-05T17:14:50.525058", 190 | "status": "completed" 191 | }, 192 | "tags": [] 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "def prepare_folds(train, features, target):  # write 10 stratified train/val fold CSVs plus a per-fold copy of the test CSV to the current directory\n", 197 | " X = train[features]\n", 198 | " y = train[target]\n", 199 | " skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)  # fixed random_state keeps the split reproducible\n", 200 | " scaler = StandardScaler()  # only referenced by the commented-out scaling lines below\n", 201 | " \n", 202 | " fold = 0\n", 203 | " for train_indices, val_indices in skf.split(X, y):\n", 204 | " fold += 1  # folds are numbered from 1 in the file names\n", 205 | " print(f'Preparing fold {fold} ...')\n", 206 | " df_train = train.loc[train.index.isin(train_indices)].reset_index(drop=True)  # compares positional split indices with index labels: assumes a default RangeIndex\n", 207 | " df_val = train.loc[train.index.isin(val_indices)].reset_index(drop=True)\n", 208 | " \n", 209 | "# df_train[features] = scaler.fit_transform(df_train[features])\n", 210 | "# df_val[features] = scaler.transform(df_val[features])\n", 211 | " \n", 212 | " test = pd.read_csv(INPUT_PATH / 'test_df.csv')  # NOTE(review): identical on every fold; could be read once before the loop\n", 213 | " \n", 214 | "# test[features] = scaler.transform(test[features])\n", 215 | " test['Target'] = 0  # dummy target, presumably so test can be loaded via get_dataloader, which reads df[target_name]\n", 216 | " df_train.to_csv(f'df_train_fold_{fold}.csv', index=False)\n", 217 | " df_val.to_csv(f'df_val_fold_{fold}.csv', index=False)\n", 218 | " test.to_csv(f'test_fold_{fold}.csv', index=False)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 5, 224 | "id": "f4472a58", 225 | "metadata": { 226 | "execution": { 227 | "iopub.execute_input": "2024-01-05T17:14:50.556472Z", 228 | "iopub.status.busy": "2024-01-05T17:14:50.555601Z", 229 | "iopub.status.idle": "2024-01-05T17:14:50.725218Z", 230 | "shell.execute_reply":
"2024-01-05T17:14:50.724036Z" 231 | }, 232 | "papermill": { 233 | "duration": 0.179071, 234 | "end_time": "2024-01-05T17:14:50.727467", 235 | "exception": false, 236 | "start_time": "2024-01-05T17:14:50.548396", 237 | "status": "completed" 238 | }, 239 | "tags": [] 240 | }, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "Preparing fold 1 ...\n", 247 | "Preparing fold 2 ...\n", 248 | "Preparing fold 3 ...\n", 249 | "Preparing fold 4 ...\n", 250 | "Preparing fold 5 ...\n", 251 | "Preparing fold 6 ...\n", 252 | "Preparing fold 7 ...\n", 253 | "Preparing fold 8 ...\n", 254 | "Preparing fold 9 ...\n", 255 | "Preparing fold 10 ...\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "prepare_folds(train, ORIGINAL_FEATURES, TARGET)  # writes df_train_fold_k, df_val_fold_k and test_fold_k CSVs for k = 1..10" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "id": "43edb98a", 266 | "metadata": { 267 | "papermill": { 268 | "duration": 0.0058, 269 | "end_time": "2024-01-05T17:14:50.739435", 270 | "exception": false, 271 | "start_time": "2024-01-05T17:14:50.733635", 272 | "status": "completed" 273 | }, 274 | "tags": [] 275 | }, 276 | "source": [ 277 | "# Dataset and DataLoader" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 6, 283 | "id": "a490c6b5", 284 | "metadata": { 285 | "execution": { 286 | "iopub.execute_input": "2024-01-05T17:14:50.753183Z", 287 | "iopub.status.busy": "2024-01-05T17:14:50.752806Z", 288 | "iopub.status.idle": "2024-01-05T17:14:50.761722Z", 289 | "shell.execute_reply": "2024-01-05T17:14:50.760854Z" 290 | }, 291 | "papermill": { 292 | "duration": 0.018273, 293 | "end_time": "2024-01-05T17:14:50.763716", 294 | "exception": false, 295 | "start_time": "2024-01-05T17:14:50.745443", 296 | "status": "completed" 297 | }, 298 | "tags": [] 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "class SpamDataset(Dataset):  # map-style dataset over in-memory feature and target arrays\n", 303 | " def __init__(self, features, targets):\n", 304 | " self.features = torch.tensor(features,
dtype=torch.float)  # float32 features\n", 305 | " self.targets = torch.tensor(targets, dtype=torch.long)  # int64 labels, as F.cross_entropy expects\n", 306 | "\n", 307 | " def __getitem__(self, index):\n", 308 | " X = self.features[index]\n", 309 | " y = self.targets[index]\n", 310 | " return X, y\n", 311 | "\n", 312 | " def __len__(self):\n", 313 | " return self.targets.shape[0]\n", 314 | "\n", 315 | "\n", 316 | "def get_dataloader(df,  # wrap df[feature_names] / df[target_name] in a SpamDataset and a DataLoader\n", 317 | " feature_names, \n", 318 | " target_name,\n", 319 | " batch_size,\n", 320 | " mode):\n", 321 | " if mode == 'train':  # train: shuffle and drop the last partial batch; any other mode keeps order and every row\n", 322 | " shuffle = True\n", 323 | " drop_last = True\n", 324 | " else:\n", 325 | " shuffle = False\n", 326 | " drop_last = False\n", 327 | " \n", 328 | " torch.manual_seed(SEED)  # reseeds the global torch RNG (presumably for a reproducible shuffle order)\n", 329 | " train_dataset = SpamDataset(  # built for every mode, not only training\n", 330 | " features=df[feature_names].to_numpy(), \n", 331 | " targets=df[target_name].to_numpy()\n", 332 | " )\n", 333 | " \n", 334 | " data_loader = DataLoader(\n", 335 | " dataset=train_dataset,\n", 336 | " batch_size=batch_size,\n", 337 | " shuffle=shuffle,\n", 338 | " drop_last=drop_last,\n", 339 | " num_workers=N_CORES  # one worker process per CPU core\n", 340 | " )\n", 341 | " \n", 342 | " return data_loader" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "id": "fb6b3248", 348 | "metadata": { 349 | "papermill": { 350 | "duration": 0.005775, 351 | "end_time": "2024-01-05T17:14:50.775670", 352 | "exception": false, 353 | "start_time": "2024-01-05T17:14:50.769895", 354 | "status": "completed" 355 | }, 356 | "tags": [] 357 | }, 358 | "source": [ 359 | "# Model" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 7, 365 | "id": "668d7f56", 366 | "metadata": { 367 | "execution": { 368 | "iopub.execute_input": "2024-01-05T17:14:50.789282Z", 369 | "iopub.status.busy": "2024-01-05T17:14:50.788904Z", 370 | "iopub.status.idle": "2024-01-05T17:14:50.807039Z", 371 | "shell.execute_reply": "2024-01-05T17:14:50.806070Z" 372 | }, 373 | "papermill": { 374 | "duration": 0.02739, 375 | "end_time": "2024-01-05T17:14:50.809009", 376 | "exception": false,
377 | "start_time": "2024-01-05T17:14:50.781619", 378 | "status": "completed" 379 | }, 380 | "tags": [] 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "class GatedLinearUnit(nn.Module):  # linear projection gated elementwise by a sigmoid\n", 385 | " def __init__(self, input_size):\n", 386 | " super(GatedLinearUnit, self).__init__()\n", 387 | " self.linear = nn.Linear(input_size, input_size)\n", 388 | " self.gate = nn.Sequential(\n", 389 | " nn.Linear(input_size, input_size),\n", 390 | " nn.Sigmoid()\n", 391 | " )\n", 392 | " \n", 393 | " def forward(self, x):\n", 394 | " return self.linear(x) * self.gate(x)\n", 395 | " \n", 396 | " \n", 397 | "class GatedResidualNetwork(nn.Module):  # Linear-ELU-Linear-Dropout-GLU path plus a residual connection and LayerNorm\n", 398 | " def __init__(self, input_size, hidden_size, dropout):\n", 399 | " super(GatedResidualNetwork, self).__init__()\n", 400 | " self.input_size = input_size\n", 401 | " self.hidden_size = hidden_size\n", 402 | " \n", 403 | " self.grn = nn.Sequential(\n", 404 | " nn.Linear(input_size, hidden_size),\n", 405 | " nn.ELU(),\n", 406 | " nn.Linear(hidden_size, hidden_size),\n", 407 | " nn.Dropout(dropout),\n", 408 | " GatedLinearUnit(hidden_size),\n", 409 | " )\n", 410 | " \n", 411 | " self.layer_norm = nn.LayerNorm(hidden_size)\n", 412 | " self.feature_projection = nn.Linear(input_size, hidden_size)  # always created, but only applied when input and hidden widths differ\n", 413 | " \n", 414 | " def forward(self, inputs):\n", 415 | " x = self.grn(inputs)\n", 416 | " if inputs.shape[-1] != self.hidden_size:\n", 417 | " inputs = self.feature_projection(inputs)\n", 418 | " x = self.layer_norm(x + inputs)  # residual add, then normalise\n", 419 | " return x\n", 420 | " \n", 421 | "class VariableSelectionNetwork(nn.Module):  # softmax-weights per-feature GRN outputs using a GRN over all features jointly\n", 422 | " def __init__(self, num_features, dense_units, hidden_size, dropout):\n", 423 | " super(VariableSelectionNetwork, self).__init__()\n", 424 | " self.num_features = num_features\n", 425 | " self.hidden_size = hidden_size\n", 426 | " self.grns = nn.ModuleList()\n", 427 | " for _ in range(num_features):\n", 428 | " self.grns.append(GatedResidualNetwork(dense_units, hidden_size, dropout))  # one GRN per feature embedding\n", 429 | " \n", 430 | "
\n", 431 | " self.grn_concat = GatedResidualNetwork(num_features*dense_units, hidden_size, dropout)  # input is all feature embeddings concatenated\n", 432 | " self.softmax = nn.Sequential(\n", 433 | " nn.Linear(hidden_size, num_features),\n", 434 | " nn.Softmax(dim=-1)\n", 435 | " )\n", 436 | " \n", 437 | " def forward(self, inputs):\n", 438 | " v = torch.cat(inputs, dim=1)  # inputs: list of per-feature tensors, assumed (batch, dense_units) each\n", 439 | " v = self.grn_concat(v)\n", 440 | " v = self.softmax(v)\n", 441 | " v = torch.unsqueeze(v, dim=-1)  # feature weights, (batch, num_features, 1) under that assumption\n", 442 | " \n", 443 | " x = []\n", 444 | " for idx, input_ in enumerate(inputs):\n", 445 | " x.append(self.grns[idx](input_))\n", 446 | " x = torch.stack(x, dim=1)  # (batch, num_features, hidden_size) under the same assumption\n", 447 | " \n", 448 | " out = (v.transpose(2, 1) @ x).squeeze(dim=1)  # weighted sum over features -> (batch, hidden_size)\n", 449 | " return out\n", 450 | " \n", 451 | "class VariableSelectionFlow(nn.Module):  # embeds each input column separately, then applies the VariableSelectionNetwork\n", 452 | " def __init__(self, num_features, hidden_size, dense_units, dropout):\n", 453 | " super(VariableSelectionFlow, self).__init__()\n", 454 | " self.variable_selection = VariableSelectionNetwork(num_features, dense_units, hidden_size, dropout)  # note: dense_units/hidden_size positional order differs from this class's signature\n", 455 | " self.split = lambda x: torch.split(x, 1, dim=-1)  # width-1 slices along the last dim; NOTE(review): a lambda attribute prevents pickling the whole module (state_dict is unaffected)\n", 456 | " self.dense_list = nn.ModuleList(\n", 457 | " [\n", 458 | " nn.Linear(1, dense_units)  # per-feature embedding of one scalar column\n", 459 | " for _ in range(num_features)\n", 460 | " ]\n", 461 | " )\n", 462 | " \n", 463 | " \n", 464 | " def forward(self, inputs):\n", 465 | " split_inputs = self.split(inputs)\n", 466 | " x = []\n", 467 | " for split_input, linear in zip(split_inputs, self.dense_list):\n", 468 | " x.append(linear(split_input))\n", 469 | " \n", 470 | " return self.variable_selection(x)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 8, 476 | "id": "596185e3", 477 | "metadata": { 478 | "execution": { 479 | "iopub.execute_input": "2024-01-05T17:14:50.823298Z", 480 | "iopub.status.busy": "2024-01-05T17:14:50.822450Z", 481 | "iopub.status.idle": "2024-01-05T17:14:50.829876Z", 482 | "shell.execute_reply": "2024-01-05T17:14:50.828956Z" 483 | }, 484 | "papermill": { 485 | "duration": 0.016799, 486 |
"end_time": "2024-01-05T17:14:50.831889", 487 | "exception": false, 488 | "start_time": "2024-01-05T17:14:50.815090", 489 | "status": "completed" 490 | }, 491 | "tags": [] 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "class Net(nn.Module):  # VariableSelectionFlow feature extractor followed by a linear head producing 2 class logits\n", 496 | " def __init__(self, num_features, dense_units, hidden_size, dropout):\n", 497 | " super(Net, self).__init__()\n", 498 | " self.num_features = num_features\n", 499 | " self.dense_units = dense_units\n", 500 | " self.hidden_size = hidden_size\n", 501 | "\n", 502 | " self.dropout = dropout\n", 503 | " \n", 504 | " self.variable_slection_flows = nn.Sequential(  # NOTE(review): misspelled attribute; renaming it would change state_dict keys\n", 505 | " VariableSelectionFlow(num_features, self.hidden_size, dense_units, self.dropout),\n", 506 | " nn.Linear(self.hidden_size, 2)\n", 507 | " )\n", 508 | " \n", 509 | " self.apply(self._init_weights)  # visits every submodule recursively\n", 510 | "\n", 511 | " def _init_weights(self, module):\n", 512 | " if isinstance(module, nn.Linear):\n", 513 | " torch.nn.init.xavier_uniform_(module.weight)  # Xavier-uniform weights; biases keep their default init\n", 514 | " \n", 515 | " def forward(self, x):\n", 516 | " logits = self.variable_slection_flows(x)  # unnormalised; fit applies F.cross_entropy / F.softmax to them\n", 517 | " return logits" 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "id": "5f7bd44f", 523 | "metadata": { 524 | "papermill": { 525 | "duration": 0.006125, 526 | "end_time": "2024-01-05T17:14:50.844108", 527 | "exception": false, 528 | "start_time": "2024-01-05T17:14:50.837983", 529 | "status": "completed" 530 | }, 531 | "tags": [] 532 | }, 533 | "source": [ 534 | "# Training" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 9, 540 | "id": "8c508cca", 541 | "metadata": { 542 | "execution": { 543 | "iopub.execute_input": "2024-01-05T17:14:50.857821Z", 544 | "iopub.status.busy": "2024-01-05T17:14:50.857461Z", 545 | "iopub.status.idle": "2024-01-05T17:14:50.869526Z", 546 | "shell.execute_reply": "2024-01-05T17:14:50.868544Z" 547 | }, 548 | "papermill": { 549 | "duration": 0.021545, 550 | "end_time": "2024-01-05T17:14:50.871731", 551 | "exception": false,
552 | "start_time": "2024-01-05T17:14:50.850186", 553 | "status": "completed" 554 | }, 555 | "tags": [] 556 | }, 557 | "outputs": [], 558 | "source": [ 559 | "def fit(model, optimizer, scheduler, epochs, train_dataloader, val_dataloader):\n", "    \"\"\"Train model for the given number of epochs with AMP loss scaling.\n", "\n", "    scheduler (may be None) is stepped once per optimizer step, i.e. once per batch.\n", "    When val_dataloader is given, the ROC AUC of the class-1 softmax probability is\n", "    printed after every epoch and the last epoch's score is returned.\n", "\n", "    Returns model, or (model, val_score) when val_dataloader is not None.\n", "    \"\"\"\n", 560 | "\n", 561 | "    start_time = time.time()\n", 562 | "    scaler = GradScaler()\n", "    val_score = None  # stays None (instead of unbound) when epochs == 0\n", 563 | "\n", 564 | "    for epoch in range(epochs):\n", 565 | "\n", 566 | "        model.train()\n", 567 | "\n", 568 | "        for batch_idx, (features, targets) in enumerate(train_dataloader):\n", 569 | "            features = features.to(DEVICE)\n", 570 | "            targets = targets.to(DEVICE)\n", 571 | "            with autocast():\n", 572 | "                logits = model(features)\n", 573 | "                loss = F.cross_entropy(logits, targets)\n", 574 | "\n", 575 | "            scaler.scale(loss).backward()\n", 576 | "            scaler.step(optimizer)\n", 577 | "            scaler.update()\n", 578 | "            optimizer.zero_grad()\n", "            # Advance the LR schedule once per optimizer step. Without this call the\n", "            # scheduler argument had no effect and the LR stayed at its step-0 value.\n", "            if scheduler is not None:\n", "                scheduler.step()\n", 579 | "\n", 580 | "# if not batch_idx % 10:\n", 581 | "# print(\n", 582 | "# f'Epoch: {epoch + 1}/{epochs}'\n", 583 | "# f' | Batch: {batch_idx}/{len(train_dataloader)}'\n", 584 | "# f' | Loss: {loss.detach().cpu().item():.4f}')\n", 585 | "\n", 586 | "        if val_dataloader is not None:\n", 587 | "            y_scores = torch.tensor([])\n", 588 | "            y_true = torch.tensor([])\n", 589 | "\n", 590 | "            with torch.inference_mode():\n", 591 | "\n", 592 | "                model.eval()\n", 593 | "\n", 594 | "                for batch_idx, (features, targets) in enumerate(val_dataloader):\n", 595 | "                    features = features.to(DEVICE)\n", 596 | "                    with autocast():\n", 597 | "                        logits = model(features).detach().cpu().type(torch.float)\n", 598 | "                    probs = F.softmax(logits, dim=-1)[:, 1]\n", 599 | "                    y_scores = torch.cat([y_scores, probs])\n", 600 | "                    y_true = torch.cat([y_true, targets])\n", 601 | "\n", 602 | "            val_score = roc_auc_score(y_true, y_scores)\n", 603 | "            # float() works whether roc_auc_score returns a NumPy scalar or a Python float\n", "            print('Validation score (AUC):', float(val_score))\n", 604 | "\n", 605 | "    elapsed = (time.time() - start_time) / 60\n", 606 | "    print(f'Total training time: {elapsed:.3f} min')\n", 607 | "\n", 608 | "    model.eval()\n",
609 | "\n", 610 | "    if val_dataloader is not None:\n", 611 | "        return model, val_score\n", 612 | "    else:\n", 613 | "        return model" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 10, 619 | "id": "d3781f98", 620 | "metadata": { 621 | "execution": { 622 | "iopub.execute_input": "2024-01-05T17:14:50.885901Z", 623 | "iopub.status.busy": "2024-01-05T17:14:50.885209Z", 624 | "iopub.status.idle": "2024-01-05T17:14:50.890099Z", 625 | "shell.execute_reply": "2024-01-05T17:14:50.889155Z" 626 | }, 627 | "papermill": { 628 | "duration": 0.01412, 629 | "end_time": "2024-01-05T17:14:50.892095", 630 | "exception": false, 631 | "start_time": "2024-01-05T17:14:50.877975", 632 | "status": "completed" 633 | }, 634 | "tags": [] 635 | }, 636 | "outputs": [], 637 | "source": [ 638 | "N_EPOCHS = 8  # overrides the General Settings values for the runs below\n", 639 | "LEARNING_RATE = 0.001\n", 640 | "WEIGHT_DECAY = 0.0\n", 641 | "SEEDS = [789, 279, 318, 2001, 1976, 1966, 1994, 252]" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 11, 647 | "id": "4bec664f", 648 | "metadata": { 649 | "execution": { 650 | "iopub.execute_input": "2024-01-05T17:14:50.906197Z", 651 | "iopub.status.busy": "2024-01-05T17:14:50.905304Z", 652 | "iopub.status.idle": "2024-01-05T17:19:00.587022Z", 653 | "shell.execute_reply": "2024-01-05T17:19:00.585910Z" 654 | }, 655 | "papermill": { 656 | "duration": 249.691262, 657 | "end_time": "2024-01-05T17:19:00.589369", 658 | "exception": false, 659 | "start_time": "2024-01-05T17:14:50.898107", 660 | "status": "completed" 661 | }, 662 | "tags": [] 663 | }, 664 | "outputs": [ 665 | { 666 | "name": "stdout", 667 | "output_type": "stream", 668 | "text": [ 669 | "Seed 789\n", 670 | "Starting fold 1 ...\n", 671 | "Validation score (AUC): 0.8086734693877551\n", 672 | "Validation score (AUC): 0.8418367346938775\n", 673 | "Validation score (AUC): 0.8596938775510203\n", 674 | "Validation score (AUC): 0.8571428571428572\n", 675 | "Validation score (AUC): 0.8367346938775511\n", 676 |
"Validation score (AUC): 0.8443877551020408\n", 677 | "Validation score (AUC): 0.854591836734694\n", 678 | "Validation score (AUC): 0.8545918367346939\n", 679 | "Total training time: 0.054 min\n", 680 | "Starting fold 2 ...\n", 681 | "Validation score (AUC): 0.7704081632653061\n", 682 | "Validation score (AUC): 0.8061224489795918\n", 683 | "Validation score (AUC): 0.7882653061224489\n", 684 | "Validation score (AUC): 0.7755102040816326\n", 685 | "Validation score (AUC): 0.7551020408163265\n", 686 | "Validation score (AUC): 0.7576530612244898\n", 687 | "Validation score (AUC): 0.7372448979591837\n", 688 | "Validation score (AUC): 0.7423469387755102\n", 689 | "Total training time: 0.048 min\n", 690 | "Starting fold 3 ...\n", 691 | "Validation score (AUC): 0.7627551020408162\n", 692 | "Validation score (AUC): 0.8163265306122449\n", 693 | "Validation score (AUC): 0.8239795918367346\n", 694 | "Validation score (AUC): 0.826530612244898\n", 695 | "Validation score (AUC): 0.8392857142857142\n", 696 | "Validation score (AUC): 0.8443877551020408\n", 697 | "Validation score (AUC): 0.8443877551020409\n", 698 | "Validation score (AUC): 0.8392857142857143\n", 699 | "Total training time: 0.051 min\n", 700 | "Starting fold 4 ...\n", 701 | "Validation score (AUC): 0.8035714285714286\n", 702 | "Validation score (AUC): 0.8596938775510203\n", 703 | "Validation score (AUC): 0.8775510204081632\n", 704 | "Validation score (AUC): 0.8673469387755102\n", 705 | "Validation score (AUC): 0.8724489795918368\n", 706 | "Validation score (AUC): 0.875\n", 707 | "Validation score (AUC): 0.8801020408163265\n", 708 | "Validation score (AUC): 0.8801020408163265\n", 709 | "Total training time: 0.045 min\n", 710 | "Starting fold 5 ...\n", 711 | "Validation score (AUC): 0.6607142857142858\n", 712 | "Validation score (AUC): 0.75\n", 713 | "Validation score (AUC): 0.7908163265306123\n", 714 | "Validation score (AUC): 0.8163265306122449\n", 715 | "Validation score (AUC): 0.826530612244898\n", 716 | 
"Validation score (AUC): 0.8316326530612245\n", 717 | "Validation score (AUC): 0.8443877551020408\n", 718 | "Validation score (AUC): 0.8469387755102041\n", 719 | "Total training time: 0.053 min\n", 720 | "Starting fold 6 ...\n", 721 | "Validation score (AUC): 0.6505102040816327\n", 722 | "Validation score (AUC): 0.6760204081632653\n", 723 | "Validation score (AUC): 0.7066326530612245\n", 724 | "Validation score (AUC): 0.729591836734694\n", 725 | "Validation score (AUC): 0.7627551020408163\n", 726 | "Validation score (AUC): 0.7704081632653061\n", 727 | "Validation score (AUC): 0.7729591836734694\n", 728 | "Validation score (AUC): 0.7831632653061226\n", 729 | "Total training time: 0.050 min\n", 730 | "Starting fold 7 ...\n", 731 | "Validation score (AUC): 0.6962962962962962\n", 732 | "Validation score (AUC): 0.7382716049382716\n", 733 | "Validation score (AUC): 0.7679012345679012\n", 734 | "Validation score (AUC): 0.8074074074074075\n", 735 | "Validation score (AUC): 0.8296296296296296\n", 736 | "Validation score (AUC): 0.8320987654320987\n", 737 | "Validation score (AUC): 0.8370370370370369\n", 738 | "Validation score (AUC): 0.8419753086419752\n", 739 | "Total training time: 0.046 min\n", 740 | "Starting fold 8 ...\n", 741 | "Validation score (AUC): 0.7135802469135802\n", 742 | "Validation score (AUC): 0.7703703703703704\n", 743 | "Validation score (AUC): 0.8197530864197531\n", 744 | "Validation score (AUC): 0.8592592592592592\n", 745 | "Validation score (AUC): 0.8814814814814814\n", 746 | "Validation score (AUC): 0.9012345679012346\n", 747 | "Validation score (AUC): 0.9037037037037037\n", 748 | "Validation score (AUC): 0.9061728395061728\n", 749 | "Total training time: 0.050 min\n", 750 | "Starting fold 9 ...\n", 751 | "Validation score (AUC): 0.7851851851851851\n", 752 | "Validation score (AUC): 0.7876543209876543\n", 753 | "Validation score (AUC): 0.8074074074074074\n", 754 | "Validation score (AUC): 0.8222222222222222\n", 755 | "Validation score (AUC): 
0.8098765432098765\n", 756 | "Validation score (AUC): 0.802469135802469\n", 757 | "Validation score (AUC): 0.8049382716049382\n", 758 | "Validation score (AUC): 0.8197530864197531\n", 759 | "Total training time: 0.049 min\n", 760 | "Starting fold 10 ...\n", 761 | "Validation score (AUC): 0.7925925925925925\n", 762 | "Validation score (AUC): 0.8345679012345679\n", 763 | "Validation score (AUC): 0.8320987654320987\n", 764 | "Validation score (AUC): 0.8395061728395061\n", 765 | "Validation score (AUC): 0.8222222222222222\n", 766 | "Validation score (AUC): 0.8222222222222222\n", 767 | "Validation score (AUC): 0.8469135802469135\n", 768 | "Validation score (AUC): 0.8469135802469135\n", 769 | "Total training time: 0.048 min\n", 770 | "Seed 279\n", 771 | "Starting fold 1 ...\n", 772 | "Validation score (AUC): 0.37244897959183676\n", 773 | "Validation score (AUC): 0.42857142857142855\n", 774 | "Validation score (AUC): 0.49744897959183676\n", 775 | "Validation score (AUC): 0.5663265306122449\n", 776 | "Validation score (AUC): 0.625\n", 777 | "Validation score (AUC): 0.75\n", 778 | "Validation score (AUC): 0.8239795918367347\n", 779 | "Validation score (AUC): 0.8418367346938775\n", 780 | "Total training time: 0.052 min\n", 781 | "Starting fold 2 ...\n", 782 | "Validation score (AUC): 0.3903061224489796\n", 783 | "Validation score (AUC): 0.45918367346938777\n", 784 | "Validation score (AUC): 0.5790816326530612\n", 785 | "Validation score (AUC): 0.6479591836734694\n", 786 | "Validation score (AUC): 0.7372448979591837\n", 787 | "Validation score (AUC): 0.7716836734693878\n", 788 | "Validation score (AUC): 0.7959183673469388\n", 789 | "Validation score (AUC): 0.778061224489796\n", 790 | "Total training time: 0.050 min\n", 791 | "Starting fold 3 ...\n", 792 | "Validation score (AUC): 0.5102040816326531\n", 793 | "Validation score (AUC): 0.6224489795918368\n", 794 | "Validation score (AUC): 0.7372448979591837\n", 795 | "Validation score (AUC): 0.7857142857142857\n", 796 | 
"Validation score (AUC): 0.8290816326530613\n", 797 | "Validation score (AUC): 0.8341836734693877\n", 798 | "Validation score (AUC): 0.8392857142857142\n", 799 | "Validation score (AUC): 0.8520408163265307\n", 800 | "Total training time: 0.050 min\n", 801 | "Starting fold 4 ...\n", 802 | "Validation score (AUC): 0.5\n", 803 | "Validation score (AUC): 0.6071428571428572\n", 804 | "Validation score (AUC): 0.7346938775510204\n", 805 | "Validation score (AUC): 0.8010204081632653\n", 806 | "Validation score (AUC): 0.8571428571428571\n", 807 | "Validation score (AUC): 0.8673469387755102\n", 808 | "Validation score (AUC): 0.8367346938775511\n", 809 | "Validation score (AUC): 0.8265306122448979\n", 810 | "Total training time: 0.048 min\n", 811 | "Starting fold 5 ...\n", 812 | "Validation score (AUC): 0.5191326530612245\n", 813 | "Validation score (AUC): 0.6326530612244898\n", 814 | "Validation score (AUC): 0.7295918367346939\n", 815 | "Validation score (AUC): 0.7602040816326532\n", 816 | "Validation score (AUC): 0.7423469387755103\n", 817 | "Validation score (AUC): 0.7474489795918368\n", 818 | "Validation score (AUC): 0.7474489795918368\n", 819 | "Validation score (AUC): 0.788265306122449\n", 820 | "Total training time: 0.047 min\n", 821 | "Starting fold 6 ...\n", 822 | "Validation score (AUC): 0.4948979591836734\n", 823 | "Validation score (AUC): 0.6045918367346939\n", 824 | "Validation score (AUC): 0.6505102040816326\n", 825 | "Validation score (AUC): 0.6658163265306123\n", 826 | "Validation score (AUC): 0.7117346938775511\n", 827 | "Validation score (AUC): 0.6785714285714286\n", 828 | "Validation score (AUC): 0.6913265306122449\n", 829 | "Validation score (AUC): 0.7142857142857143\n", 830 | "Total training time: 0.050 min\n", 831 | "Starting fold 7 ...\n", 832 | "Validation score (AUC): 0.48148148148148145\n", 833 | "Validation score (AUC): 0.6148148148148148\n", 834 | "Validation score (AUC): 0.7555555555555555\n", 835 | "Validation score (AUC): 0.8469135802469135\n", 
836 | "Validation score (AUC): 0.851851851851852\n", 837 | "Validation score (AUC): 0.817283950617284\n", 838 | "Validation score (AUC): 0.8\n", 839 | "Validation score (AUC): 0.7950617283950616\n", 840 | "Total training time: 0.046 min\n", 841 | "Starting fold 8 ...\n", 842 | "Validation score (AUC): 0.41481481481481475\n", 843 | "Validation score (AUC): 0.5975308641975308\n", 844 | "Validation score (AUC): 0.6518518518518518\n", 845 | "Validation score (AUC): 0.6790123456790123\n", 846 | "Validation score (AUC): 0.6938271604938272\n", 847 | "Validation score (AUC): 0.7530864197530864\n", 848 | "Validation score (AUC): 0.8\n", 849 | "Validation score (AUC): 0.8469135802469137\n", 850 | "Total training time: 0.047 min\n", 851 | "Starting fold 9 ...\n", 852 | "Validation score (AUC): 0.2888888888888889\n", 853 | "Validation score (AUC): 0.362962962962963\n", 854 | "Validation score (AUC): 0.46172839506172836\n", 855 | "Validation score (AUC): 0.5506172839506173\n", 856 | "Validation score (AUC): 0.6444444444444444\n", 857 | "Validation score (AUC): 0.6839506172839506\n", 858 | "Validation score (AUC): 0.7358024691358025\n", 859 | "Validation score (AUC): 0.7530864197530864\n", 860 | "Total training time: 0.053 min\n", 861 | "Starting fold 10 ...\n", 862 | "Validation score (AUC): 0.4419753086419753\n", 863 | "Validation score (AUC): 0.46419753086419757\n", 864 | "Validation score (AUC): 0.5037037037037037\n", 865 | "Validation score (AUC): 0.528395061728395\n", 866 | "Validation score (AUC): 0.5876543209876544\n", 867 | "Validation score (AUC): 0.7234567901234568\n", 868 | "Validation score (AUC): 0.7901234567901234\n", 869 | "Validation score (AUC): 0.8123456790123456\n", 870 | "Total training time: 0.047 min\n", 871 | "Seed 318\n", 872 | "Starting fold 1 ...\n", 873 | "Validation score (AUC): 0.691326530612245\n", 874 | "Validation score (AUC): 0.7372448979591837\n", 875 | "Validation score (AUC): 0.7346938775510203\n", 876 | "Validation score (AUC): 
0.7346938775510204\n", 877 | "Validation score (AUC): 0.7346938775510204\n", 878 | "Validation score (AUC): 0.7346938775510204\n", 879 | "Validation score (AUC): 0.7372448979591837\n", 880 | "Validation score (AUC): 0.7372448979591837\n", 881 | "Total training time: 0.049 min\n", 882 | "Starting fold 2 ...\n", 883 | "Validation score (AUC): 0.7397959183673469\n", 884 | "Validation score (AUC): 0.7040816326530612\n", 885 | "Validation score (AUC): 0.7066326530612245\n", 886 | "Validation score (AUC): 0.7015306122448979\n", 887 | "Validation score (AUC): 0.6989795918367346\n", 888 | "Validation score (AUC): 0.6989795918367346\n", 889 | "Validation score (AUC): 0.6989795918367346\n", 890 | "Validation score (AUC): 0.701530612244898\n", 891 | "Total training time: 0.052 min\n", 892 | "Starting fold 3 ...\n", 893 | "Validation score (AUC): 0.8647959183673468\n", 894 | "Validation score (AUC): 0.8520408163265306\n", 895 | "Validation score (AUC): 0.8469387755102041\n", 896 | "Validation score (AUC): 0.8469387755102041\n", 897 | "Validation score (AUC): 0.8469387755102041\n", 898 | "Validation score (AUC): 0.8469387755102041\n", 899 | "Validation score (AUC): 0.8469387755102041\n", 900 | "Validation score (AUC): 0.8520408163265306\n", 901 | "Total training time: 0.048 min\n", 902 | "Starting fold 4 ...\n", 903 | "Validation score (AUC): 0.7908163265306123\n", 904 | "Validation score (AUC): 0.8188775510204082\n", 905 | "Validation score (AUC): 0.7933673469387755\n", 906 | "Validation score (AUC): 0.7908163265306123\n", 907 | "Validation score (AUC): 0.7908163265306123\n", 908 | "Validation score (AUC): 0.7908163265306123\n", 909 | "Validation score (AUC): 0.7933673469387755\n", 910 | "Validation score (AUC): 0.7908163265306122\n", 911 | "Total training time: 0.051 min\n", 912 | "Starting fold 5 ...\n", 913 | "Validation score (AUC): 0.8418367346938775\n", 914 | "Validation score (AUC): 0.8494897959183673\n", 915 | "Validation score (AUC): 0.8520408163265305\n", 916 | 
"Validation score (AUC): 0.8558673469387754\n", 917 | "Validation score (AUC): 0.8571428571428572\n", 918 | "Validation score (AUC): 0.8545918367346939\n", 919 | "Validation score (AUC): 0.8520408163265305\n", 920 | "Validation score (AUC): 0.8520408163265306\n", 921 | "Total training time: 0.046 min\n", 922 | "Starting fold 6 ...\n", 923 | "Validation score (AUC): 0.701530612244898\n", 924 | "Validation score (AUC): 0.7066326530612245\n", 925 | "Validation score (AUC): 0.701530612244898\n", 926 | "Validation score (AUC): 0.701530612244898\n", 927 | "Validation score (AUC): 0.701530612244898\n", 928 | "Validation score (AUC): 0.7066326530612245\n", 929 | "Validation score (AUC): 0.7040816326530612\n", 930 | "Validation score (AUC): 0.7066326530612245\n", 931 | "Total training time: 0.047 min\n", 932 | "Starting fold 7 ...\n", 933 | "Validation score (AUC): 0.8222222222222222\n", 934 | "Validation score (AUC): 0.7999999999999999\n", 935 | "Validation score (AUC): 0.7753086419753085\n", 936 | "Validation score (AUC): 0.7679012345679012\n", 937 | "Validation score (AUC): 0.7703703703703703\n", 938 | "Validation score (AUC): 0.7703703703703703\n", 939 | "Validation score (AUC): 0.7753086419753085\n", 940 | "Validation score (AUC): 0.7777777777777778\n", 941 | "Total training time: 0.052 min\n", 942 | "Starting fold 8 ...\n", 943 | "Validation score (AUC): 0.7975308641975309\n", 944 | "Validation score (AUC): 0.8567901234567902\n", 945 | "Validation score (AUC): 0.8592592592592593\n", 946 | "Validation score (AUC): 0.8592592592592593\n", 947 | "Validation score (AUC): 0.8592592592592593\n", 948 | "Validation score (AUC): 0.8592592592592593\n", 949 | "Validation score (AUC): 0.8592592592592593\n", 950 | "Validation score (AUC): 0.8592592592592593\n", 951 | "Total training time: 0.046 min\n", 952 | "Starting fold 9 ...\n", 953 | "Validation score (AUC): 0.7407407407407407\n", 954 | "Validation score (AUC): 0.7530864197530864\n", 955 | "Validation score (AUC): 
0.7358024691358025\n", 956 | "Validation score (AUC): 0.7382716049382716\n", 957 | "Validation score (AUC): 0.7382716049382716\n", 958 | "Validation score (AUC): 0.7382716049382716\n", 959 | "Validation score (AUC): 0.7382716049382716\n", 960 | "Validation score (AUC): 0.7407407407407407\n", 961 | "Total training time: 0.048 min\n", 962 | "Starting fold 10 ...\n", 963 | "Validation score (AUC): 0.6913580246913581\n", 964 | "Validation score (AUC): 0.8123456790123457\n", 965 | "Validation score (AUC): 0.8148148148148147\n", 966 | "Validation score (AUC): 0.8148148148148148\n", 967 | "Validation score (AUC): 0.817283950617284\n", 968 | "Validation score (AUC): 0.817283950617284\n", 969 | "Validation score (AUC): 0.817283950617284\n", 970 | "Validation score (AUC): 0.8197530864197531\n", 971 | "Total training time: 0.047 min\n", 972 | "Seed 2001\n", 973 | "Starting fold 1 ...\n", 974 | "Validation score (AUC): 0.7551020408163266\n", 975 | "Validation score (AUC): 0.7525510204081634\n", 976 | "Validation score (AUC): 0.7602040816326531\n", 977 | "Validation score (AUC): 0.798469387755102\n", 978 | "Validation score (AUC): 0.8214285714285715\n", 979 | "Validation score (AUC): 0.8367346938775511\n", 980 | "Validation score (AUC): 0.826530612244898\n", 981 | "Validation score (AUC): 0.8278061224489796\n", 982 | "Total training time: 0.046 min\n", 983 | "Starting fold 2 ...\n", 984 | "Validation score (AUC): 0.7091836734693877\n", 985 | "Validation score (AUC): 0.7193877551020409\n", 986 | "Validation score (AUC): 0.7117346938775511\n", 987 | "Validation score (AUC): 0.7117346938775511\n", 988 | "Validation score (AUC): 0.7244897959183673\n", 989 | "Validation score (AUC): 0.7219387755102041\n", 990 | "Validation score (AUC): 0.7244897959183675\n", 991 | "Validation score (AUC): 0.7219387755102041\n", 992 | "Total training time: 0.055 min\n", 993 | "Starting fold 3 ...\n", 994 | "Validation score (AUC): 0.8341836734693878\n", 995 | "Validation score (AUC): 
0.8596938775510203\n", 996 | "Validation score (AUC): 0.8418367346938775\n", 997 | "Validation score (AUC): 0.8290816326530612\n", 998 | "Validation score (AUC): 0.8112244897959184\n", 999 | "Validation score (AUC): 0.8086734693877551\n", 1000 | "Validation score (AUC): 0.8112244897959184\n", 1001 | "Validation score (AUC): 0.8188775510204082\n", 1002 | "Total training time: 0.048 min\n", 1003 | "Starting fold 4 ...\n", 1004 | "Validation score (AUC): 0.8418367346938775\n", 1005 | "Validation score (AUC): 0.8341836734693877\n", 1006 | "Validation score (AUC): 0.8188775510204082\n", 1007 | "Validation score (AUC): 0.826530612244898\n", 1008 | "Validation score (AUC): 0.8392857142857143\n", 1009 | "Validation score (AUC): 0.8469387755102042\n", 1010 | "Validation score (AUC): 0.8520408163265306\n", 1011 | "Validation score (AUC): 0.8571428571428572\n", 1012 | "Total training time: 0.048 min\n", 1013 | "Starting fold 5 ...\n", 1014 | "Validation score (AUC): 0.7780612244897959\n", 1015 | "Validation score (AUC): 0.7959183673469388\n", 1016 | "Validation score (AUC): 0.8520408163265306\n", 1017 | "Validation score (AUC): 0.8673469387755102\n", 1018 | "Validation score (AUC): 0.8622448979591837\n", 1019 | "Validation score (AUC): 0.8596938775510204\n", 1020 | "Validation score (AUC): 0.8545918367346939\n", 1021 | "Validation score (AUC): 0.8571428571428571\n", 1022 | "Total training time: 0.051 min\n", 1023 | "Starting fold 6 ...\n", 1024 | "Validation score (AUC): 0.7716836734693877\n", 1025 | "Validation score (AUC): 0.7806122448979592\n", 1026 | "Validation score (AUC): 0.7602040816326531\n", 1027 | "Validation score (AUC): 0.7448979591836735\n", 1028 | "Validation score (AUC): 0.7602040816326531\n", 1029 | "Validation score (AUC): 0.7602040816326532\n", 1030 | "Validation score (AUC): 0.7602040816326531\n", 1031 | "Validation score (AUC): 0.7678571428571429\n", 1032 | "Total training time: 0.046 min\n", 1033 | "Starting fold 7 ...\n", 1034 | "Validation score (AUC): 
0.8222222222222222\n", 1035 | "Validation score (AUC): 0.8074074074074074\n", 1036 | "Validation score (AUC): 0.7851851851851851\n", 1037 | "Validation score (AUC): 0.7925925925925926\n", 1038 | "Validation score (AUC): 0.8024691358024691\n", 1039 | "Validation score (AUC): 0.817283950617284\n", 1040 | "Validation score (AUC): 0.8246913580246913\n", 1041 | "Validation score (AUC): 0.8222222222222222\n", 1042 | "Total training time: 0.049 min\n", 1043 | "Starting fold 8 ...\n", 1044 | "Validation score (AUC): 0.8493827160493828\n", 1045 | "Validation score (AUC): 0.9111111111111111\n", 1046 | "Validation score (AUC): 0.9111111111111111\n", 1047 | "Validation score (AUC): 0.9111111111111111\n", 1048 | "Validation score (AUC): 0.9407407407407407\n", 1049 | "Validation score (AUC): 0.9432098765432099\n", 1050 | "Validation score (AUC): 0.9481481481481482\n", 1051 | "Validation score (AUC): 0.9407407407407407\n", 1052 | "Total training time: 0.052 min\n", 1053 | "Starting fold 9 ...\n", 1054 | "Validation score (AUC): 0.7530864197530864\n", 1055 | "Validation score (AUC): 0.7679012345679013\n", 1056 | "Validation score (AUC): 0.7950617283950617\n", 1057 | "Validation score (AUC): 0.7802469135802469\n", 1058 | "Validation score (AUC): 0.8\n", 1059 | "Validation score (AUC): 0.8098765432098766\n", 1060 | "Validation score (AUC): 0.7975308641975308\n", 1061 | "Validation score (AUC): 0.8024691358024691\n", 1062 | "Total training time: 0.047 min\n", 1063 | "Starting fold 10 ...\n", 1064 | "Validation score (AUC): 0.837037037037037\n", 1065 | "Validation score (AUC): 0.8469135802469135\n", 1066 | "Validation score (AUC): 0.8493827160493828\n", 1067 | "Validation score (AUC): 0.8567901234567901\n", 1068 | "Validation score (AUC): 0.8567901234567901\n", 1069 | "Validation score (AUC): 0.8567901234567901\n", 1070 | "Validation score (AUC): 0.8641975308641976\n", 1071 | "Validation score (AUC): 0.8691358024691358\n", 1072 | "Total training time: 0.051 min\n", 1073 | "Seed 
1976\n", 1074 | "Starting fold 1 ...\n", 1075 | "Validation score (AUC): 0.19642857142857142\n", 1076 | "Validation score (AUC): 0.3979591836734694\n", 1077 | "Validation score (AUC): 0.7066326530612245\n", 1078 | "Validation score (AUC): 0.7653061224489796\n", 1079 | "Validation score (AUC): 0.8163265306122449\n", 1080 | "Validation score (AUC): 0.8035714285714286\n", 1081 | "Validation score (AUC): 0.8112244897959183\n", 1082 | "Validation score (AUC): 0.8061224489795918\n", 1083 | "Total training time: 0.046 min\n", 1084 | "Starting fold 2 ...\n", 1085 | "Validation score (AUC): 0.29974489795918363\n", 1086 | "Validation score (AUC): 0.451530612244898\n", 1087 | "Validation score (AUC): 0.6683673469387755\n", 1088 | "Validation score (AUC): 0.7193877551020408\n", 1089 | "Validation score (AUC): 0.7321428571428571\n", 1090 | "Validation score (AUC): 0.7397959183673469\n", 1091 | "Validation score (AUC): 0.7372448979591837\n", 1092 | "Validation score (AUC): 0.7372448979591837\n", 1093 | "Total training time: 0.047 min\n", 1094 | "Starting fold 3 ...\n", 1095 | "Validation score (AUC): 0.2755102040816326\n", 1096 | "Validation score (AUC): 0.4872448979591837\n", 1097 | "Validation score (AUC): 0.7219387755102041\n", 1098 | "Validation score (AUC): 0.7627551020408163\n", 1099 | "Validation score (AUC): 0.7933673469387755\n", 1100 | "Validation score (AUC): 0.8214285714285714\n", 1101 | "Validation score (AUC): 0.8392857142857142\n", 1102 | "Validation score (AUC): 0.846938775510204\n", 1103 | "Total training time: 0.053 min\n", 1104 | "Starting fold 4 ...\n", 1105 | "Validation score (AUC): 0.32142857142857145\n", 1106 | "Validation score (AUC): 0.6045918367346939\n", 1107 | "Validation score (AUC): 0.8660714285714286\n", 1108 | "Validation score (AUC): 0.8622448979591838\n", 1109 | "Validation score (AUC): 0.8545918367346939\n", 1110 | "Validation score (AUC): 0.8418367346938774\n", 1111 | "Validation score (AUC): 0.826530612244898\n", 1112 | "Validation score 
(AUC): 0.8290816326530612\n", 1113 | "Total training time: 0.046 min\n", 1114 | "Starting fold 5 ...\n", 1115 | "Validation score (AUC): 0.2576530612244898\n", 1116 | "Validation score (AUC): 0.4362244897959183\n", 1117 | "Validation score (AUC): 0.8214285714285714\n", 1118 | "Validation score (AUC): 0.8418367346938775\n", 1119 | "Validation score (AUC): 0.864795918367347\n", 1120 | "Validation score (AUC): 0.8596938775510204\n", 1121 | "Validation score (AUC): 0.8571428571428571\n", 1122 | "Validation score (AUC): 0.8596938775510203\n", 1123 | "Total training time: 0.049 min\n", 1124 | "Starting fold 6 ...\n", 1125 | "Validation score (AUC): 0.3622448979591837\n", 1126 | "Validation score (AUC): 0.5816326530612245\n", 1127 | "Validation score (AUC): 0.7576530612244898\n", 1128 | "Validation score (AUC): 0.7602040816326531\n", 1129 | "Validation score (AUC): 0.7678571428571429\n", 1130 | "Validation score (AUC): 0.7627551020408162\n", 1131 | "Validation score (AUC): 0.7538265306122449\n", 1132 | "Validation score (AUC): 0.760204081632653\n", 1133 | "Total training time: 0.048 min\n", 1134 | "Starting fold 7 ...\n", 1135 | "Validation score (AUC): 0.24074074074074076\n", 1136 | "Validation score (AUC): 0.5432098765432098\n", 1137 | "Validation score (AUC): 0.8444444444444444\n", 1138 | "Validation score (AUC): 0.8740740740740741\n", 1139 | "Validation score (AUC): 0.8617283950617284\n", 1140 | "Validation score (AUC): 0.8395061728395062\n", 1141 | "Validation score (AUC): 0.8320987654320987\n", 1142 | "Validation score (AUC): 0.8320987654320987\n", 1143 | "Total training time: 0.047 min\n", 1144 | "Starting fold 8 ...\n", 1145 | "Validation score (AUC): 0.25185185185185177\n", 1146 | "Validation score (AUC): 0.3851851851851852\n", 1147 | "Validation score (AUC): 0.8074074074074072\n", 1148 | "Validation score (AUC): 0.8790123456790123\n", 1149 | "Validation score (AUC): 0.9160493827160494\n", 1150 | "Validation score (AUC): 0.928395061728395\n", 1151 | "Validation 
score (AUC): 0.9308641975308641\n", 1152 | "Validation score (AUC): 0.9234567901234567\n", 1153 | "Total training time: 0.052 min\n", 1154 | "Starting fold 9 ...\n", 1155 | "Validation score (AUC): 0.2716049382716049\n", 1156 | "Validation score (AUC): 0.4493827160493827\n", 1157 | "Validation score (AUC): 0.6691358024691358\n", 1158 | "Validation score (AUC): 0.7407407407407407\n", 1159 | "Validation score (AUC): 0.7753086419753087\n", 1160 | "Validation score (AUC): 0.8\n", 1161 | "Validation score (AUC): 0.7728395061728395\n", 1162 | "Validation score (AUC): 0.7777777777777778\n", 1163 | "Total training time: 0.046 min\n", 1164 | "Starting fold 10 ...\n", 1165 | "Validation score (AUC): 0.31851851851851853\n", 1166 | "Validation score (AUC): 0.6691358024691358\n", 1167 | "Validation score (AUC): 0.8049382716049382\n", 1168 | "Validation score (AUC): 0.8395061728395062\n", 1169 | "Validation score (AUC): 0.8444444444444444\n", 1170 | "Validation score (AUC): 0.854320987654321\n", 1171 | "Validation score (AUC): 0.837037037037037\n", 1172 | "Validation score (AUC): 0.8469135802469135\n", 1173 | "Total training time: 0.048 min\n", 1174 | "Seed 1966\n", 1175 | "Starting fold 1 ...\n", 1176 | "Validation score (AUC): 0.4591836734693877\n", 1177 | "Validation score (AUC): 0.5612244897959183\n", 1178 | "Validation score (AUC): 0.7678571428571428\n", 1179 | "Validation score (AUC): 0.8367346938775511\n", 1180 | "Validation score (AUC): 0.8392857142857143\n", 1181 | "Validation score (AUC): 0.8443877551020408\n", 1182 | "Validation score (AUC): 0.8469387755102041\n", 1183 | "Validation score (AUC): 0.8418367346938777\n", 1184 | "Total training time: 0.050 min\n", 1185 | "Starting fold 2 ...\n", 1186 | "Validation score (AUC): 0.5586734693877551\n", 1187 | "Validation score (AUC): 0.6275510204081632\n", 1188 | "Validation score (AUC): 0.6989795918367346\n", 1189 | "Validation score (AUC): 0.7423469387755102\n", 1190 | "Validation score (AUC): 0.7551020408163265\n", 1191 | 
"Validation score (AUC): 0.7704081632653061\n", 1192 | "Validation score (AUC): 0.7704081632653061\n", 1193 | "Validation score (AUC): 0.7551020408163265\n", 1194 | "Total training time: 0.046 min\n", 1195 | "Starting fold 3 ...\n", 1196 | "Validation score (AUC): 0.7168367346938777\n", 1197 | "Validation score (AUC): 0.7908163265306123\n", 1198 | "Validation score (AUC): 0.8214285714285714\n", 1199 | "Validation score (AUC): 0.8112244897959184\n", 1200 | "Validation score (AUC): 0.7959183673469388\n", 1201 | "Validation score (AUC): 0.8061224489795918\n", 1202 | "Validation score (AUC): 0.8112244897959183\n", 1203 | "Validation score (AUC): 0.8188775510204082\n", 1204 | "Total training time: 0.055 min\n", 1205 | "Starting fold 4 ...\n", 1206 | "Validation score (AUC): 0.6556122448979592\n", 1207 | "Validation score (AUC): 0.7806122448979592\n", 1208 | "Validation score (AUC): 0.8724489795918366\n", 1209 | "Validation score (AUC): 0.8801020408163266\n", 1210 | "Validation score (AUC): 0.8673469387755102\n", 1211 | "Validation score (AUC): 0.8596938775510204\n", 1212 | "Validation score (AUC): 0.8673469387755102\n", 1213 | "Validation score (AUC): 0.8698979591836735\n", 1214 | "Total training time: 0.048 min\n", 1215 | "Starting fold 5 ...\n", 1216 | "Validation score (AUC): 0.75\n", 1217 | "Validation score (AUC): 0.778061224489796\n", 1218 | "Validation score (AUC): 0.7908163265306122\n", 1219 | "Validation score (AUC): 0.788265306122449\n", 1220 | "Validation score (AUC): 0.7678571428571429\n", 1221 | "Validation score (AUC): 0.7857142857142857\n", 1222 | "Validation score (AUC): 0.7908163265306123\n", 1223 | "Validation score (AUC): 0.8010204081632653\n", 1224 | "Total training time: 0.048 min\n", 1225 | "Starting fold 6 ...\n", 1226 | "Validation score (AUC): 0.5841836734693878\n", 1227 | "Validation score (AUC): 0.6913265306122449\n", 1228 | "Validation score (AUC): 0.7346938775510206\n", 1229 | "Validation score (AUC): 0.7448979591836735\n", 1230 | 
"Validation score (AUC): 0.7270408163265306\n", 1231 | "Validation score (AUC): 0.7372448979591836\n", 1232 | "Validation score (AUC): 0.7627551020408164\n", 1233 | "Validation score (AUC): 0.7653061224489797\n", 1234 | "Total training time: 0.051 min\n", 1235 | "Starting fold 7 ...\n", 1236 | "Validation score (AUC): 0.6987654320987654\n", 1237 | "Validation score (AUC): 0.7851851851851852\n", 1238 | "Validation score (AUC): 0.8592592592592592\n", 1239 | "Validation score (AUC): 0.8148148148148149\n", 1240 | "Validation score (AUC): 0.7950617283950617\n", 1241 | "Validation score (AUC): 0.8024691358024691\n", 1242 | "Validation score (AUC): 0.8074074074074075\n", 1243 | "Validation score (AUC): 0.8197530864197531\n", 1244 | "Total training time: 0.046 min\n", 1245 | "Starting fold 8 ...\n", 1246 | "Validation score (AUC): 0.6222222222222222\n", 1247 | "Validation score (AUC): 0.6617283950617284\n", 1248 | "Validation score (AUC): 0.7259259259259259\n", 1249 | "Validation score (AUC): 0.8469135802469135\n", 1250 | "Validation score (AUC): 0.9037037037037037\n", 1251 | "Validation score (AUC): 0.8987654320987654\n", 1252 | "Validation score (AUC): 0.9012345679012346\n", 1253 | "Validation score (AUC): 0.9061728395061728\n", 1254 | "Total training time: 0.050 min\n", 1255 | "Starting fold 9 ...\n", 1256 | "Validation score (AUC): 0.5135802469135802\n", 1257 | "Validation score (AUC): 0.6098765432098765\n", 1258 | "Validation score (AUC): 0.7185185185185186\n", 1259 | "Validation score (AUC): 0.7950617283950616\n", 1260 | "Validation score (AUC): 0.8222222222222222\n", 1261 | "Validation score (AUC): 0.8320987654320988\n", 1262 | "Validation score (AUC): 0.8345679012345679\n", 1263 | "Validation score (AUC): 0.8296296296296296\n", 1264 | "Total training time: 0.049 min\n", 1265 | "Starting fold 10 ...\n", 1266 | "Validation score (AUC): 0.5135802469135802\n", 1267 | "Validation score (AUC): 0.6148148148148148\n", 1268 | "Validation score (AUC): 0.711111111111111\n", 
1269 | "Validation score (AUC): 0.8222222222222222\n", 1270 | "Validation score (AUC): 0.8246913580246913\n", 1271 | "Validation score (AUC): 0.8320987654320987\n", 1272 | "Validation score (AUC): 0.8345679012345679\n", 1273 | "Validation score (AUC): 0.837037037037037\n", 1274 | "Total training time: 0.046 min\n", 1275 | "Seed 1994\n", 1276 | "Starting fold 1 ...\n", 1277 | "Validation score (AUC): 0.6951530612244897\n", 1278 | "Validation score (AUC): 0.7295918367346939\n", 1279 | "Validation score (AUC): 0.7704081632653061\n", 1280 | "Validation score (AUC): 0.8061224489795918\n", 1281 | "Validation score (AUC): 0.8163265306122449\n", 1282 | "Validation score (AUC): 0.8239795918367347\n", 1283 | "Validation score (AUC): 0.854591836734694\n", 1284 | "Validation score (AUC): 0.8571428571428571\n", 1285 | "Total training time: 0.052 min\n", 1286 | "Starting fold 2 ...\n", 1287 | "Validation score (AUC): 0.6709183673469388\n", 1288 | "Validation score (AUC): 0.7040816326530612\n", 1289 | "Validation score (AUC): 0.7193877551020408\n", 1290 | "Validation score (AUC): 0.7321428571428571\n", 1291 | "Validation score (AUC): 0.7219387755102041\n", 1292 | "Validation score (AUC): 0.7040816326530612\n", 1293 | "Validation score (AUC): 0.7244897959183674\n", 1294 | "Validation score (AUC): 0.7321428571428572\n", 1295 | "Total training time: 0.046 min\n", 1296 | "Starting fold 3 ...\n", 1297 | "Validation score (AUC): 0.5255102040816327\n", 1298 | "Validation score (AUC): 0.5688775510204082\n", 1299 | "Validation score (AUC): 0.6377551020408163\n", 1300 | "Validation score (AUC): 0.6938775510204082\n", 1301 | "Validation score (AUC): 0.7142857142857143\n", 1302 | "Validation score (AUC): 0.7423469387755102\n", 1303 | "Validation score (AUC): 0.7653061224489797\n", 1304 | "Validation score (AUC): 0.7755102040816326\n", 1305 | "Total training time: 0.052 min\n", 1306 | "Starting fold 4 ...\n", 1307 | "Validation score (AUC): 0.6454081632653061\n", 1308 | "Validation score 
(AUC): 0.6709183673469388\n", 1309 | "Validation score (AUC): 0.7295918367346939\n", 1310 | "Validation score (AUC): 0.8086734693877551\n", 1311 | "Validation score (AUC): 0.8520408163265306\n", 1312 | "Validation score (AUC): 0.8520408163265307\n", 1313 | "Validation score (AUC): 0.8622448979591837\n", 1314 | "Validation score (AUC): 0.8673469387755104\n", 1315 | "Total training time: 0.052 min\n", 1316 | "Starting fold 5 ...\n", 1317 | "Validation score (AUC): 0.43367346938775514\n", 1318 | "Validation score (AUC): 0.4897959183673469\n", 1319 | "Validation score (AUC): 0.5586734693877551\n", 1320 | "Validation score (AUC): 0.639030612244898\n", 1321 | "Validation score (AUC): 0.6862244897959184\n", 1322 | "Validation score (AUC): 0.7270408163265305\n", 1323 | "Validation score (AUC): 0.7448979591836735\n", 1324 | "Validation score (AUC): 0.7704081632653061\n", 1325 | "Total training time: 0.047 min\n", 1326 | "Starting fold 6 ...\n", 1327 | "Validation score (AUC): 0.5816326530612246\n", 1328 | "Validation score (AUC): 0.6147959183673469\n", 1329 | "Validation score (AUC): 0.6479591836734694\n", 1330 | "Validation score (AUC): 0.6785714285714286\n", 1331 | "Validation score (AUC): 0.6989795918367347\n", 1332 | "Validation score (AUC): 0.7117346938775511\n", 1333 | "Validation score (AUC): 0.7270408163265306\n", 1334 | "Validation score (AUC): 0.7295918367346939\n", 1335 | "Total training time: 0.052 min\n", 1336 | "Starting fold 7 ...\n", 1337 | "Validation score (AUC): 0.5753086419753086\n", 1338 | "Validation score (AUC): 0.6296296296296297\n", 1339 | "Validation score (AUC): 0.674074074074074\n", 1340 | "Validation score (AUC): 0.7061728395061728\n", 1341 | "Validation score (AUC): 0.7259259259259259\n", 1342 | "Validation score (AUC): 0.7555555555555555\n", 1343 | "Validation score (AUC): 0.7580246913580246\n", 1344 | "Validation score (AUC): 0.782716049382716\n", 1345 | "Total training time: 0.048 min\n", 1346 | "Starting fold 8 ...\n", 1347 | "Validation 
score (AUC): 0.6641975308641975\n", 1348 | "Validation score (AUC): 0.688888888888889\n", 1349 | "Validation score (AUC): 0.7308641975308642\n", 1350 | "Validation score (AUC): 0.7876543209876542\n", 1351 | "Validation score (AUC): 0.8197530864197531\n", 1352 | "Validation score (AUC): 0.8320987654320988\n", 1353 | "Validation score (AUC): 0.8493827160493826\n", 1354 | "Validation score (AUC): 0.854320987654321\n", 1355 | "Total training time: 0.049 min\n", 1356 | "Starting fold 9 ...\n", 1357 | "Validation score (AUC): 0.7901234567901234\n", 1358 | "Validation score (AUC): 0.8123456790123458\n", 1359 | "Validation score (AUC): 0.8444444444444443\n", 1360 | "Validation score (AUC): 0.8666666666666666\n", 1361 | "Validation score (AUC): 0.8666666666666667\n", 1362 | "Validation score (AUC): 0.8716049382716049\n", 1363 | "Validation score (AUC): 0.8641975308641975\n", 1364 | "Validation score (AUC): 0.8814814814814814\n", 1365 | "Total training time: 0.052 min\n", 1366 | "Starting fold 10 ...\n", 1367 | "Validation score (AUC): 0.7407407407407407\n", 1368 | "Validation score (AUC): 0.7876543209876543\n", 1369 | "Validation score (AUC): 0.8320987654320988\n", 1370 | "Validation score (AUC): 0.8419753086419752\n", 1371 | "Validation score (AUC): 0.854320987654321\n", 1372 | "Validation score (AUC): 0.8567901234567901\n", 1373 | "Validation score (AUC): 0.8617283950617284\n", 1374 | "Validation score (AUC): 0.8567901234567902\n", 1375 | "Total training time: 0.046 min\n", 1376 | "Seed 252\n", 1377 | "Starting fold 1 ...\n", 1378 | "Validation score (AUC): 0.75\n", 1379 | "Validation score (AUC): 0.7448979591836735\n", 1380 | "Validation score (AUC): 0.7780612244897959\n", 1381 | "Validation score (AUC): 0.8341836734693878\n", 1382 | "Validation score (AUC): 0.8418367346938775\n", 1383 | "Validation score (AUC): 0.8571428571428572\n", 1384 | "Validation score (AUC): 0.8520408163265306\n", 1385 | "Validation score (AUC): 0.8622448979591838\n", 1386 | "Total training time: 
0.049 min\n", 1387 | "Starting fold 2 ...\n", 1388 | "Validation score (AUC): 0.7142857142857142\n", 1389 | "Validation score (AUC): 0.6989795918367347\n", 1390 | "Validation score (AUC): 0.6964285714285714\n", 1391 | "Validation score (AUC): 0.7193877551020409\n", 1392 | "Validation score (AUC): 0.7448979591836734\n", 1393 | "Validation score (AUC): 0.7448979591836735\n", 1394 | "Validation score (AUC): 0.7423469387755102\n", 1395 | "Validation score (AUC): 0.7397959183673469\n", 1396 | "Total training time: 0.050 min\n", 1397 | "Starting fold 3 ...\n", 1398 | "Validation score (AUC): 0.7091836734693878\n", 1399 | "Validation score (AUC): 0.7295918367346939\n", 1400 | "Validation score (AUC): 0.7372448979591838\n", 1401 | "Validation score (AUC): 0.7755102040816326\n", 1402 | "Validation score (AUC): 0.8035714285714285\n", 1403 | "Validation score (AUC): 0.798469387755102\n", 1404 | "Validation score (AUC): 0.8163265306122449\n", 1405 | "Validation score (AUC): 0.8137755102040816\n", 1406 | "Total training time: 0.051 min\n", 1407 | "Starting fold 4 ...\n", 1408 | "Validation score (AUC): 0.8673469387755102\n", 1409 | "Validation score (AUC): 0.8622448979591837\n", 1410 | "Validation score (AUC): 0.8494897959183674\n", 1411 | "Validation score (AUC): 0.8494897959183674\n", 1412 | "Validation score (AUC): 0.8571428571428572\n", 1413 | "Validation score (AUC): 0.8673469387755102\n", 1414 | "Validation score (AUC): 0.8647959183673469\n", 1415 | "Validation score (AUC): 0.8673469387755103\n", 1416 | "Total training time: 0.052 min\n", 1417 | "Starting fold 5 ...\n", 1418 | "Validation score (AUC): 0.673469387755102\n", 1419 | "Validation score (AUC): 0.7270408163265306\n", 1420 | "Validation score (AUC): 0.7525510204081632\n", 1421 | "Validation score (AUC): 0.7716836734693878\n", 1422 | "Validation score (AUC): 0.7959183673469388\n", 1423 | "Validation score (AUC): 0.8035714285714286\n", 1424 | "Validation score (AUC): 0.8239795918367347\n", 1425 | "Validation score 
(AUC): 0.8137755102040816\n", 1426 | "Total training time: 0.047 min\n", 1427 | "Starting fold 6 ...\n", 1428 | "Validation score (AUC): 0.6938775510204082\n", 1429 | "Validation score (AUC): 0.6989795918367346\n", 1430 | "Validation score (AUC): 0.7117346938775511\n", 1431 | "Validation score (AUC): 0.7193877551020409\n", 1432 | "Validation score (AUC): 0.7295918367346939\n", 1433 | "Validation score (AUC): 0.7448979591836734\n", 1434 | "Validation score (AUC): 0.75\n", 1435 | "Validation score (AUC): 0.7474489795918368\n", 1436 | "Total training time: 0.049 min\n", 1437 | "Starting fold 7 ...\n", 1438 | "Validation score (AUC): 0.7382716049382716\n", 1439 | "Validation score (AUC): 0.7654320987654321\n", 1440 | "Validation score (AUC): 0.7703703703703703\n", 1441 | "Validation score (AUC): 0.8148148148148148\n", 1442 | "Validation score (AUC): 0.8197530864197531\n", 1443 | "Validation score (AUC): 0.8345679012345678\n", 1444 | "Validation score (AUC): 0.8296296296296296\n", 1445 | "Validation score (AUC): 0.8345679012345678\n", 1446 | "Total training time: 0.051 min\n", 1447 | "Starting fold 8 ...\n", 1448 | "Validation score (AUC): 0.8049382716049382\n", 1449 | "Validation score (AUC): 0.8123456790123457\n", 1450 | "Validation score (AUC): 0.8395061728395061\n", 1451 | "Validation score (AUC): 0.8666666666666667\n", 1452 | "Validation score (AUC): 0.8814814814814815\n", 1453 | "Validation score (AUC): 0.8987654320987655\n", 1454 | "Validation score (AUC): 0.9160493827160494\n", 1455 | "Validation score (AUC): 0.9209876543209876\n", 1456 | "Total training time: 0.047 min\n", 1457 | "Starting fold 9 ...\n", 1458 | "Validation score (AUC): 0.834567901234568\n", 1459 | "Validation score (AUC): 0.837037037037037\n", 1460 | "Validation score (AUC): 0.8592592592592593\n", 1461 | "Validation score (AUC): 0.8567901234567902\n", 1462 | "Validation score (AUC): 0.8790123456790123\n", 1463 | "Validation score (AUC): 0.8691358024691358\n", 1464 | "Validation score (AUC): 
0.8617283950617284\n", 1465 | "Validation score (AUC): 0.8518518518518519\n", 1466 | "Total training time: 0.050 min\n", 1467 | "Starting fold 10 ...\n", 1468 | "Validation score (AUC): 0.8271604938271605\n", 1469 | "Validation score (AUC): 0.8395061728395062\n", 1470 | "Validation score (AUC): 0.8395061728395062\n", 1471 | "Validation score (AUC): 0.8592592592592592\n", 1472 | "Validation score (AUC): 0.8518518518518519\n", 1473 | "Validation score (AUC): 0.8691358024691358\n", 1474 | "Validation score (AUC): 0.8567901234567901\n", 1475 | "Validation score (AUC): 0.8567901234567901\n", 1476 | "Total training time: 0.049 min\n" 1477 | ] 1478 | } 1479 | ], 1480 | "source": [ 1481 | "SEED_PREDS = 0\n", 1482 | "SEED_SCORES = []\n", 1483 | "\n", 1484 | "for seed in SEEDS:\n", 1485 | " print('Seed', seed)\n", 1486 | "\n", 1487 | " val_scores = []\n", 1488 | " Y_SCORES = 0\n", 1489 | " \n", 1490 | " for fold in range(1, 11):\n", 1491 | " print(f'Starting fold {fold} ...')\n", 1492 | " df_train = pd.read_csv(OUTPUT_PATH / f'df_train_fold_{fold}.csv')\n", 1493 | " df_val = pd.read_csv(OUTPUT_PATH / f'df_val_fold_{fold}.csv')\n", 1494 | " test = pd.read_csv(OUTPUT_PATH / f'test_fold_{fold}.csv')\n", 1495 | "\n", 1496 | " train_dataloader = get_dataloader(\n", 1497 | " df_train, \n", 1498 | " feature_names=ORIGINAL_FEATURES, \n", 1499 | " target_name=TARGET,\n", 1500 | " batch_size=BATCH_SIZE,\n", 1501 | " mode='train'\n", 1502 | " )\n", 1503 | "\n", 1504 | " val_dataloader = get_dataloader(\n", 1505 | " df_val, \n", 1506 | " feature_names=ORIGINAL_FEATURES, \n", 1507 | " target_name=TARGET,\n", 1508 | " batch_size=BATCH_SIZE,\n", 1509 | " mode='val'\n", 1510 | " )\n", 1511 | "\n", 1512 | " test_dataloader = get_dataloader(\n", 1513 | " test, \n", 1514 | " feature_names=ORIGINAL_FEATURES, \n", 1515 | " target_name=TARGET,\n", 1516 | " batch_size=BATCH_SIZE,\n", 1517 | " mode='test'\n", 1518 | " )\n", 1519 | "\n", 1520 | " torch.manual_seed(seed)\n", 1521 | " model = Net(\n", 
1522 | " num_features=NUM_FEATURES, \n", 1523 | " dense_units=8,\n", 1524 | " hidden_size=16,\n", 1525 | " dropout=0.05\n", 1526 | " )\n", 1527 | "\n", 1528 | " model.to(DEVICE)\n", 1529 | "\n", 1530 | " model.train()\n", 1531 | "\n", 1532 | " optimizer = optim.AdamW(\n", 1533 | " model.parameters(), \n", 1534 | " lr=LEARNING_RATE,\n", 1535 | " weight_decay=WEIGHT_DECAY\n", 1536 | " )\n", 1537 | "\n", 1538 | " model, val_score = fit(model=model,\n", 1539 | " optimizer=optimizer,\n", 1540 | " scheduler=None,\n", 1541 | " epochs=N_EPOCHS,\n", 1542 | " train_dataloader=train_dataloader,\n", 1543 | " val_dataloader=val_dataloader)\n", 1544 | "\n", 1545 | " val_scores.append(val_score)\n", 1546 | "\n", 1547 | " y_scores = torch.tensor([])\n", 1548 | " with torch.inference_mode():\n", 1549 | " model.eval()\n", 1550 | " for batch_idx, (features, targets) in enumerate(test_dataloader):\n", 1551 | " features = features.to(DEVICE)\n", 1552 | " with autocast():\n", 1553 | " logits = model(features).detach().cpu().type(torch.float)\n", 1554 | " probs = F.softmax(logits, dim=-1)[:, 1]\n", 1555 | " y_scores = torch.cat([y_scores, probs])\n", 1556 | "\n", 1557 | " Y_SCORES += y_scores / 10\n", 1558 | " \n", 1559 | " SEED_PREDS += Y_SCORES / len(SEEDS)\n", 1560 | " SEED_SCORES.append(np.mean(val_scores))" 1561 | ] 1562 | }, 1563 | { 1564 | "cell_type": "code", 1565 | "execution_count": 12, 1566 | "id": "a8c664c5", 1567 | "metadata": { 1568 | "execution": { 1569 | "iopub.execute_input": "2024-01-05T17:19:00.712075Z", 1570 | "iopub.status.busy": "2024-01-05T17:19:00.711736Z", 1571 | "iopub.status.idle": "2024-01-05T17:19:00.720775Z", 1572 | "shell.execute_reply": "2024-01-05T17:19:00.719820Z" 1573 | }, 1574 | "papermill": { 1575 | "duration": 0.071775, 1576 | "end_time": "2024-01-05T17:19:00.722771", 1577 | "exception": false, 1578 | "start_time": "2024-01-05T17:19:00.650996", 1579 | "status": "completed" 1580 | }, 1581 | "tags": [] 1582 | }, 1583 | "outputs": [ 1584 | { 1585 | 
"name": "stdout", 1586 | "output_type": "stream", 1587 | "text": [ 1588 | "0.8171630527210885\n", 1589 | "0.016486137137366824\n" 1590 | ] 1591 | }, 1592 | { 1593 | "data": { 1594 | "text/plain": [ 1595 | "[0.8361243386243388,\n", 1596 | " 0.8008427815570673,\n", 1597 | " 0.7837836986646509,\n", 1598 | " 0.8285333207357016,\n", 1599 | " 0.8219532627865961,\n", 1600 | " 0.8244633408919124,\n", 1601 | " 0.8107451499118167,\n", 1602 | " 0.830858528596624]" 1603 | ] 1604 | }, 1605 | "execution_count": 12, 1606 | "metadata": {}, 1607 | "output_type": "execute_result" 1608 | } 1609 | ], 1610 | "source": [ 1611 | "print(np.mean(SEED_SCORES))\n", 1612 | "print(np.std(SEED_SCORES))\n", 1613 | "\n", 1614 | "SEED_SCORES" 1615 | ] 1616 | }, 1617 | { 1618 | "cell_type": "code", 1619 | "execution_count": 13, 1620 | "id": "050de0df", 1621 | "metadata": { 1622 | "execution": { 1623 | "iopub.execute_input": "2024-01-05T17:19:00.843891Z", 1624 | "iopub.status.busy": "2024-01-05T17:19:00.843061Z", 1625 | "iopub.status.idle": "2024-01-05T17:19:00.866928Z", 1626 | "shell.execute_reply": "2024-01-05T17:19:00.866068Z" 1627 | }, 1628 | "papermill": { 1629 | "duration": 0.086833, 1630 | "end_time": "2024-01-05T17:19:00.869048", 1631 | "exception": false, 1632 | "start_time": "2024-01-05T17:19:00.782215", 1633 | "status": "completed" 1634 | }, 1635 | "tags": [] 1636 | }, 1637 | "outputs": [ 1638 | { 1639 | "data": { 1640 | "text/html": [ 1641 | "
\n", 1642 | "\n", 1655 | "\n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | "
idTarget
04200.553706
14210.758907
24220.904092
34230.927545
44240.916706
.........
2756950.913177
2766960.619613
2776970.914858
2786980.600448
2796990.834674
\n", 1721 | "

280 rows × 2 columns

\n", 1722 | "
" 1723 | ], 1724 | "text/plain": [ 1725 | " id Target\n", 1726 | "0 420 0.553706\n", 1727 | "1 421 0.758907\n", 1728 | "2 422 0.904092\n", 1729 | "3 423 0.927545\n", 1730 | "4 424 0.916706\n", 1731 | ".. ... ...\n", 1732 | "275 695 0.913177\n", 1733 | "276 696 0.619613\n", 1734 | "277 697 0.914858\n", 1735 | "278 698 0.600448\n", 1736 | "279 699 0.834674\n", 1737 | "\n", 1738 | "[280 rows x 2 columns]" 1739 | ] 1740 | }, 1741 | "execution_count": 13, 1742 | "metadata": {}, 1743 | "output_type": "execute_result" 1744 | } 1745 | ], 1746 | "source": [ 1747 | "sub = pd.read_csv(INPUT_PATH / 'submission.csv')\n", 1748 | "sub[TARGET] = SEED_PREDS\n", 1749 | "sub.to_csv('submission.csv', index=False)\n", 1750 | "sub" 1751 | ] 1752 | }, 1753 | { 1754 | "cell_type": "code", 1755 | "execution_count": null, 1756 | "id": "ff634bac", 1757 | "metadata": { 1758 | "papermill": { 1759 | "duration": 0.060313, 1760 | "end_time": "2024-01-05T17:19:00.990556", 1761 | "exception": false, 1762 | "start_time": "2024-01-05T17:19:00.930243", 1763 | "status": "completed" 1764 | }, 1765 | "tags": [] 1766 | }, 1767 | "outputs": [], 1768 | "source": [] 1769 | } 1770 | ], 1771 | "metadata": { 1772 | "kaggle": { 1773 | "accelerator": "gpu", 1774 | "dataSources": [ 1775 | { 1776 | "databundleVersionId": 7397785, 1777 | "sourceId": 66702, 1778 | "sourceType": "competition" 1779 | } 1780 | ], 1781 | "dockerImageVersionId": 30626, 1782 | "isGpuEnabled": true, 1783 | "isInternetEnabled": false, 1784 | "language": "python", 1785 | "sourceType": "notebook" 1786 | }, 1787 | "kernelspec": { 1788 | "display_name": "Python 3", 1789 | "language": "python", 1790 | "name": "python3" 1791 | }, 1792 | "language_info": { 1793 | "codemirror_mode": { 1794 | "name": "ipython", 1795 | "version": 3 1796 | }, 1797 | "file_extension": ".py", 1798 | "mimetype": "text/x-python", 1799 | "name": "python", 1800 | "nbconvert_exporter": "python", 1801 | "pygments_lexer": "ipython3", 1802 | "version": "3.10.12" 1803 | }, 1804 | 
"papermill": { 1805 | "default_parameters": {}, 1806 | "duration": 262.280111, 1807 | "end_time": "2024-01-05T17:19:02.573435", 1808 | "environment_variables": {}, 1809 | "exception": null, 1810 | "input_path": "__notebook__.ipynb", 1811 | "output_path": "__notebook__.ipynb", 1812 | "parameters": {}, 1813 | "start_time": "2024-01-05T17:14:40.293324", 1814 | "version": "2.4.0" 1815 | } 1816 | }, 1817 | "nbformat": 4, 1818 | "nbformat_minor": 5 1819 | } 1820 | --------------------------------------------------------------------------------