├── images
└── vsn_img.png
├── README.md
├── icr_competition.ipynb
└── spam_prediction.ipynb
/images/vsn_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducnh279/Variable-Selection-Network-with-PyTorch/HEAD/images/vsn_img.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variable Selection Network with PyTorch
2 |
3 | ## Introduction
4 | Welcome to this repository, where you'll find a PyTorch implementation of the Variable Selection Network (VSN). A VSN is a deep neural network architecture designed to learn which input variables are relevant and to weight them accordingly, improving both model interpretability and performance. I implemented it in PyTorch because I could not find an existing PyTorch version online. The closest reference is the Keras example ["Classification with Gated Residual and Variable Selection Networks"](https://keras.io/examples/structured_data/classification_with_grn_and_vsn/).
5 |
6 | ## Inspiration
7 | The VSN architecture comes from the paper ["Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting"](https://arxiv.org/pdf/1912.09363), which introduced Gated Residual Networks and Variable Selection Networks. This implementation is also inspired by the winning solution of the Kaggle competition ["ICR - Identifying Age-Related Conditions"](https://www.kaggle.com/competitions/icr-identify-age-related-conditions), which used VSNs effectively.
8 |
9 | ## Dataset
10 | Experiments used the datasets from two Kaggle competitions: "ICR - Identifying Age-Related Conditions" and "Prediction of Spam with Bayesian Model". The VSN implementation performed well in these competitions.
11 |
12 | | Dataset | Access Link |
13 | |----------------------------------------------|-------------|
14 | | ICR - Identifying Age-Related Conditions | [ICR](https://www.kaggle.com/competitions/icr-identify-age-related-conditions/data) |
15 | | Prediction of Spam with Bayesian Model | [Spam](https://www.kaggle.com/competitions/lets-surpass-the-hosts-bayesian-model/data) |
16 | ---
17 |
--------------------------------------------------------------------------------
/icr_competition.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c34b377e",
6 | "metadata": {
7 | "papermill": {
8 | "duration": 0.007664,
9 | "end_time": "2024-01-01T19:29:31.284892",
10 | "exception": false,
11 | "start_time": "2024-01-01T19:29:31.277228",
12 | "status": "completed"
13 | },
14 | "tags": []
15 | },
16 | "source": [
17 | "# Imports"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "id": "2e552fd1",
24 | "metadata": {
25 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
26 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
27 | "execution": {
28 | "iopub.execute_input": "2024-01-01T19:29:31.301300Z",
29 | "iopub.status.busy": "2024-01-01T19:29:31.300886Z",
30 | "iopub.status.idle": "2024-01-01T19:29:44.882010Z",
31 | "shell.execute_reply": "2024-01-01T19:29:44.880775Z"
32 | },
33 | "papermill": {
34 | "duration": 13.592588,
35 | "end_time": "2024-01-01T19:29:44.884824",
36 | "exception": false,
37 | "start_time": "2024-01-01T19:29:31.292236",
38 | "status": "completed"
39 | },
40 | "tags": []
41 | },
42 | "outputs": [
43 | {
44 | "name": "stderr",
45 | "output_type": "stream",
46 | "text": [
47 | "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
48 | " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
49 | ]
50 | }
51 | ],
52 | "source": [
53 | "import sys\n",
54 | "sys.path.append('/kaggle/input/iterativestratification')\n",
55 | "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
56 | "import pandas as pd\n",
57 | "import numpy as np\n",
58 | "from tqdm.auto import tqdm, trange\n",
59 | "import lightgbm as lgb\n",
60 | "from sklearn.compose import make_column_transformer\n",
61 | "from sklearn.metrics import accuracy_score\n",
62 | "from sklearn.model_selection import StratifiedKFold\n",
63 | "import warnings\n",
64 | "import matplotlib.pyplot as plt\n",
65 | "warnings.filterwarnings('ignore')\n",
66 | "\n",
67 | "import pandas as pd\n",
68 | "import numpy as np\n",
69 | "\n",
70 | "import time\n",
71 | "from tqdm import tqdm\n",
72 | "from pathlib import Path\n",
73 | "import multiprocessing as mp\n",
74 | "\n",
75 | "from sklearn.model_selection import StratifiedKFold\n",
76 | "from sklearn.model_selection import cross_val_score\n",
77 | "from sklearn.preprocessing import StandardScaler\n",
78 | "from sklearn.metrics import roc_auc_score\n",
79 | "\n",
80 | "from transformers import get_cosine_schedule_with_warmup\n",
81 | "\n",
82 | "import torch \n",
83 | "from torch import nn, optim\n",
84 | "import torch.nn.functional as F\n",
85 | "from torch.cuda.amp import GradScaler, autocast\n",
86 | "from torch.utils.data import Dataset, DataLoader"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 2,
92 | "id": "6b9f5527",
93 | "metadata": {
94 | "execution": {
95 | "iopub.execute_input": "2024-01-01T19:29:44.901672Z",
96 | "iopub.status.busy": "2024-01-01T19:29:44.900768Z",
97 | "iopub.status.idle": "2024-01-01T19:29:44.910838Z",
98 | "shell.execute_reply": "2024-01-01T19:29:44.909891Z"
99 | },
100 | "papermill": {
101 | "duration": 0.020838,
102 | "end_time": "2024-01-01T19:29:44.913085",
103 | "exception": false,
104 | "start_time": "2024-01-01T19:29:44.892247",
105 | "status": "completed"
106 | },
107 | "tags": []
108 | },
109 | "outputs": [],
110 | "source": [
111 | "INPUT_PATH = Path('/kaggle/input/icr-identify-age-related-conditions')\n",
112 | "OUTPUT_PATH = Path('/kaggle/working')\n",
113 | "ORIGINAL_FEATURES = ['AB', 'AF', 'AH', 'AM', 'AR', 'AX', 'AY', 'AZ', 'BC', 'BD ', 'BN', 'BP',\n",
114 | " 'BQ', 'BR', 'BZ', 'CB', 'CC', 'CD ', 'CF', 'CH', 'CL', 'CR', 'CS', 'CU',\n",
115 | " 'CW ', 'DA', 'DE', 'DF', 'DH', 'DI', 'DL', 'DN', 'DU', 'DV', 'DY', 'EB',\n",
116 | " 'EE', 'EG', 'EH', 'EJ', 'EL', 'EP', 'EU', 'FC', 'FD ', 'FE', 'FI', 'FL',\n",
117 | " 'FR', 'FS', 'GB', 'GE', 'GF', 'GH', 'GI', 'GL']\n",
118 | "\n",
119 | "TARGET = 'Class'\n",
120 | "\n",
121 | "N_CORES = mp.cpu_count()\n",
122 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
123 | "NUM_FEATURES = len(ORIGINAL_FEATURES)\n",
124 | "BATCH_SIZE = 32\n",
125 | "N_EPOCHS = 20\n",
126 | "N_WARMUPS = 0\n",
127 | "LEARNING_RATE = 0.001\n",
128 | "WEIGHT_DECAY = 0.005\n",
129 | "SEED = 252"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "beced216",
135 | "metadata": {
136 | "papermill": {
137 | "duration": 0.006972,
138 | "end_time": "2024-01-01T19:29:44.927325",
139 | "exception": false,
140 | "start_time": "2024-01-01T19:29:44.920353",
141 | "status": "completed"
142 | },
143 | "tags": []
144 | },
145 | "source": [
146 | "# Read data"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 3,
152 | "id": "b54fe002",
153 | "metadata": {
154 | "execution": {
155 | "iopub.execute_input": "2024-01-01T19:29:44.943105Z",
156 | "iopub.status.busy": "2024-01-01T19:29:44.942730Z",
157 | "iopub.status.idle": "2024-01-01T19:29:45.005138Z",
158 | "shell.execute_reply": "2024-01-01T19:29:45.003966Z"
159 | },
160 | "papermill": {
161 | "duration": 0.073392,
162 | "end_time": "2024-01-01T19:29:45.007812",
163 | "exception": false,
164 | "start_time": "2024-01-01T19:29:44.934420",
165 | "status": "completed"
166 | },
167 | "tags": []
168 | },
169 | "outputs": [],
170 | "source": [
171 | "train = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/train.csv').drop(['Id'], axis=1)\n",
172 | "test = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/test.csv')\n",
173 | "greeks = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/greeks.csv')"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "id": "d4fc5577",
179 | "metadata": {
180 | "papermill": {
181 | "duration": 0.007123,
182 | "end_time": "2024-01-01T19:29:45.022621",
183 | "exception": false,
184 | "start_time": "2024-01-01T19:29:45.015498",
185 | "status": "completed"
186 | },
187 | "tags": []
188 | },
189 | "source": [
190 | "# Evaluation Metric"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 4,
196 | "id": "66de6b00",
197 | "metadata": {
198 | "execution": {
199 | "iopub.execute_input": "2024-01-01T19:29:45.039742Z",
200 | "iopub.status.busy": "2024-01-01T19:29:45.038897Z",
201 | "iopub.status.idle": "2024-01-01T19:29:45.046448Z",
202 | "shell.execute_reply": "2024-01-01T19:29:45.045586Z"
203 | },
204 | "papermill": {
205 | "duration": 0.018844,
206 | "end_time": "2024-01-01T19:29:45.048815",
207 | "exception": false,
208 | "start_time": "2024-01-01T19:29:45.029971",
209 | "status": "completed"
210 | },
211 | "tags": []
212 | },
213 | "outputs": [],
214 | "source": [
215 | "import torch\n",
216 | "\n",
217 | "def score(y_true, y_pred):\n",
218 | "\n",
219 | " # Calculate the number of observations for each class\n",
220 | " N_0 = torch.sum(1 - y_true)\n",
221 | " N_1 = torch.sum(y_true)\n",
222 | " \n",
223 | " # Calculate the predicted probabilities for each class\n",
224 | " p_1 = torch.clamp(y_pred, 1e-15, 1 - 1e-15)\n",
225 | " p_0 = 1 - p_1\n",
226 | " \n",
227 | " # Calculate the average log loss for each class\n",
228 | " log_loss_0 = -torch.sum((1 - y_true) * torch.log(p_0)) / N_0\n",
229 | " log_loss_1 = -torch.sum(y_true * torch.log(p_1)) / N_1\n",
230 | " \n",
231 | " # Return the (not further weighted) average of the averages\n",
232 | " a = (log_loss_0 + log_loss_1) / 2\n",
233 | " \n",
234 | " return a\n"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 5,
240 | "id": "65d76aa1",
241 | "metadata": {
242 | "execution": {
243 | "iopub.execute_input": "2024-01-01T19:29:45.065932Z",
244 | "iopub.status.busy": "2024-01-01T19:29:45.065218Z",
245 | "iopub.status.idle": "2024-01-01T19:29:45.072757Z",
246 | "shell.execute_reply": "2024-01-01T19:29:45.071985Z"
247 | },
248 | "papermill": {
249 | "duration": 0.018734,
250 | "end_time": "2024-01-01T19:29:45.075039",
251 | "exception": false,
252 | "start_time": "2024-01-01T19:29:45.056305",
253 | "status": "completed"
254 | },
255 | "tags": []
256 | },
257 | "outputs": [],
258 | "source": [
259 | "def balanced_log_loss(y_true, y_pred):\n",
260 | " y_true = np.array(y_true)\n",
261 | " y_pred = np.array(y_pred)\n",
262 | " # y_true: correct labels 0, 1\n",
263 | " # y_pred: predicted probabilities of class=1\n",
264 | " # Implements the Evaluation equation with w_0 = w_1 = 1.\n",
265 | " # Calculate the number of observations for each class\n",
266 | " N_0 = np.sum(1 - y_true)\n",
267 | " N_1 = np.sum(y_true)\n",
268 | " # Calculate the predicted probabilities for each class\n",
269 | " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
270 | " p_0 = 1 - p_1\n",
271 | " # Calculate the average log loss for each class\n",
272 | " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0)) / N_0\n",
273 | " log_loss_1 = -np.sum(y_true * np.log(p_1)) / N_1\n",
274 | " # return the (not further weighted) average of the averages\n",
275 | " a = (log_loss_0 + log_loss_1)/2\n",
276 | " return 'Loss', a, False"
277 | ]
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "id": "c0e361bb",
282 | "metadata": {
283 | "papermill": {
284 | "duration": 0.007147,
285 | "end_time": "2024-01-01T19:29:45.089796",
286 | "exception": false,
287 | "start_time": "2024-01-01T19:29:45.082649",
288 | "status": "completed"
289 | },
290 | "tags": []
291 | },
292 | "source": [
293 | "# 5-seed ensemble"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 6,
299 | "id": "b95ac841",
300 | "metadata": {
301 | "execution": {
302 | "iopub.execute_input": "2024-01-01T19:29:45.106903Z",
303 | "iopub.status.busy": "2024-01-01T19:29:45.106118Z",
304 | "iopub.status.idle": "2024-01-01T19:29:45.120541Z",
305 | "shell.execute_reply": "2024-01-01T19:29:45.119708Z"
306 | },
307 | "papermill": {
308 | "duration": 0.025846,
309 | "end_time": "2024-01-01T19:29:45.123013",
310 | "exception": false,
311 | "start_time": "2024-01-01T19:29:45.097167",
312 | "status": "completed"
313 | },
314 | "tags": []
315 | },
316 | "outputs": [],
317 | "source": [
318 | "features = ['AB', 'AF', 'AH', 'AM', 'AR', 'AX', 'AY', 'AZ', 'BC', 'BD ', 'BN', 'BP',\n",
319 | " 'BQ', 'BR', 'BZ', 'CB', 'CC', 'CD ', 'CF', 'CH', 'CL', 'CR', 'CS', 'CU',\n",
320 | " 'CW ', 'DA', 'DE', 'DF', 'DH', 'DI', 'DL', 'DN', 'DU', 'DV', 'DY', 'EB',\n",
321 | " 'EE', 'EG', 'EH', 'EJ', 'EL', 'EP', 'EU', 'FC', 'FD ', 'FE', 'FI', 'FL',\n",
322 | " 'FR', 'FS', 'GB', 'GE', 'GF', 'GH', 'GI', 'GL']\n",
323 | "\n",
324 | "train['EJ'] = train['EJ'].str.strip()\n",
325 | "test['EJ'] = test['EJ'].str.strip()\n",
326 | "\n",
327 | "train['EJ'] = train['EJ'].map({'A': 1, 'B': 0})\n",
328 | "test['EJ'] = test['EJ'].map({'A': 1, 'B': 0})"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": 7,
334 | "id": "24f2d921",
335 | "metadata": {
336 | "execution": {
337 | "iopub.execute_input": "2024-01-01T19:29:45.140285Z",
338 | "iopub.status.busy": "2024-01-01T19:29:45.139534Z",
339 | "iopub.status.idle": "2024-01-01T19:29:45.156933Z",
340 | "shell.execute_reply": "2024-01-01T19:29:45.156113Z"
341 | },
342 | "papermill": {
343 | "duration": 0.028755,
344 | "end_time": "2024-01-01T19:29:45.159384",
345 | "exception": false,
346 | "start_time": "2024-01-01T19:29:45.130629",
347 | "status": "completed"
348 | },
349 | "tags": []
350 | },
351 | "outputs": [],
352 | "source": [
353 | "skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n",
354 | "X = train[features]\n",
355 | "y = train['Class']\n",
356 | "\n",
357 | "for fold, (_, val_idx) in enumerate(skf.split(X, y)):\n",
358 | " train.loc[train.index.isin(val_idx), 'fold'] = fold + 1\n",
359 | "train['fold'] = train['fold'].astype(int)"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 8,
365 | "id": "cfeee14a",
366 | "metadata": {
367 | "execution": {
368 | "iopub.execute_input": "2024-01-01T19:29:45.176830Z",
369 | "iopub.status.busy": "2024-01-01T19:29:45.176092Z",
370 | "iopub.status.idle": "2024-01-01T19:29:45.184622Z",
371 | "shell.execute_reply": "2024-01-01T19:29:45.183872Z"
372 | },
373 | "papermill": {
374 | "duration": 0.019749,
375 | "end_time": "2024-01-01T19:29:45.186918",
376 | "exception": false,
377 | "start_time": "2024-01-01T19:29:45.167169",
378 | "status": "completed"
379 | },
380 | "tags": []
381 | },
382 | "outputs": [],
383 | "source": [
384 | "def prepare_folds(train, features, target):\n",
385 | " X = train[features]\n",
386 | " y = train[target]\n",
387 | " skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n",
388 | " scaler = StandardScaler()\n",
389 | " \n",
390 | " fold = 0\n",
391 | " for train_indices, val_indices in skf.split(X, y):\n",
392 | " fold += 1\n",
393 | " print(f'Preparing fold {fold} ...')\n",
394 | " df_train = train.loc[train.index.isin(train_indices)].reset_index(drop=True)\n",
395 | " df_val = train.loc[train.index.isin(val_indices)].reset_index(drop=True)\n",
396 | " \n",
397 | " df_train[features] = scaler.fit_transform(df_train[features])\n",
398 | " df_val[features] = scaler.transform(df_val[features])\n",
399 | " \n",
400 | " df_train.to_csv(f'df_train_fold_{fold}.csv', index=False)\n",
401 | " df_val.to_csv(f'df_val_fold_{fold}.csv', index=False)"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": 9,
407 | "id": "9d4d2037",
408 | "metadata": {
409 | "execution": {
410 | "iopub.execute_input": "2024-01-01T19:29:45.204241Z",
411 | "iopub.status.busy": "2024-01-01T19:29:45.203491Z",
412 | "iopub.status.idle": "2024-01-01T19:29:45.755298Z",
413 | "shell.execute_reply": "2024-01-01T19:29:45.753919Z"
414 | },
415 | "papermill": {
416 | "duration": 0.563143,
417 | "end_time": "2024-01-01T19:29:45.757695",
418 | "exception": false,
419 | "start_time": "2024-01-01T19:29:45.194552",
420 | "status": "completed"
421 | },
422 | "tags": []
423 | },
424 | "outputs": [
425 | {
426 | "name": "stdout",
427 | "output_type": "stream",
428 | "text": [
429 | "Preparing fold 1 ...\n",
430 | "Preparing fold 2 ...\n",
431 | "Preparing fold 3 ...\n",
432 | "Preparing fold 4 ...\n",
433 | "Preparing fold 5 ...\n"
434 | ]
435 | }
436 | ],
437 | "source": [
438 | "prepare_folds(train, ORIGINAL_FEATURES, TARGET)"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 10,
444 | "id": "71aebcb7",
445 | "metadata": {
446 | "execution": {
447 | "iopub.execute_input": "2024-01-01T19:29:45.776193Z",
448 | "iopub.status.busy": "2024-01-01T19:29:45.775356Z",
449 | "iopub.status.idle": "2024-01-01T19:29:45.991449Z",
450 | "shell.execute_reply": "2024-01-01T19:29:45.990230Z"
451 | },
452 | "papermill": {
453 | "duration": 0.228705,
454 | "end_time": "2024-01-01T19:29:45.994501",
455 | "exception": false,
456 | "start_time": "2024-01-01T19:29:45.765796",
457 | "status": "completed"
458 | },
459 | "tags": []
460 | },
461 | "outputs": [],
462 | "source": [
463 | "from sklearn.impute import SimpleImputer\n",
464 | "imp = SimpleImputer(strategy='median')"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 11,
470 | "id": "49576712",
471 | "metadata": {
472 | "execution": {
473 | "iopub.execute_input": "2024-01-01T19:29:46.012575Z",
474 | "iopub.status.busy": "2024-01-01T19:29:46.011865Z",
475 | "iopub.status.idle": "2024-01-01T19:29:46.021457Z",
476 | "shell.execute_reply": "2024-01-01T19:29:46.020754Z"
477 | },
478 | "papermill": {
479 | "duration": 0.020943,
480 | "end_time": "2024-01-01T19:29:46.023529",
481 | "exception": false,
482 | "start_time": "2024-01-01T19:29:46.002586",
483 | "status": "completed"
484 | },
485 | "tags": []
486 | },
487 | "outputs": [],
488 | "source": [
489 | "class SpamDataset(Dataset):\n",
490 | " def __init__(self, features, targets):\n",
491 | " self.features = torch.tensor(features, dtype=torch.float)\n",
492 | " self.targets = torch.tensor(targets, dtype=torch.long)\n",
493 | "\n",
494 | " def __getitem__(self, index):\n",
495 | " X = self.features[index]\n",
496 | " y = self.targets[index]\n",
497 | " return X, y\n",
498 | "\n",
499 | " def __len__(self):\n",
500 | " return self.targets.shape[0]\n",
501 | "\n",
502 | "\n",
503 | "def get_dataloader(features, targets, \n",
504 | " feature_names, \n",
505 | " target_name,\n",
506 | " batch_size,\n",
507 | " mode):\n",
508 | " if mode == 'train':\n",
509 | " shuffle = True\n",
510 | " drop_last = True\n",
511 | " else:\n",
512 | " shuffle = False\n",
513 | " drop_last = False\n",
514 | " \n",
515 | " torch.manual_seed(SEED)\n",
516 | " train_dataset = SpamDataset(\n",
517 | " features=features, \n",
518 | " targets=targets\n",
519 | " )\n",
520 | " \n",
521 | " data_loader = DataLoader(\n",
522 | " dataset=train_dataset,\n",
523 | " batch_size=batch_size,\n",
524 | " shuffle=shuffle,\n",
525 | " drop_last=drop_last,\n",
526 | " num_workers=N_CORES\n",
527 | " )\n",
528 | " \n",
529 | " return data_loader"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 12,
535 | "id": "7b204261",
536 | "metadata": {
537 | "execution": {
538 | "iopub.execute_input": "2024-01-01T19:29:46.041904Z",
539 | "iopub.status.busy": "2024-01-01T19:29:46.041063Z",
540 | "iopub.status.idle": "2024-01-01T19:29:46.066403Z",
541 | "shell.execute_reply": "2024-01-01T19:29:46.065445Z"
542 | },
543 | "papermill": {
544 | "duration": 0.037087,
545 | "end_time": "2024-01-01T19:29:46.068710",
546 | "exception": false,
547 | "start_time": "2024-01-01T19:29:46.031623",
548 | "status": "completed"
549 | },
550 | "tags": []
551 | },
552 | "outputs": [],
553 | "source": [
554 | "class GatedLinearUnit(nn.Module):\n",
555 | " def __init__(self, input_size):\n",
556 | " super(GatedLinearUnit, self).__init__()\n",
557 | " self.linear = nn.Linear(input_size, input_size)\n",
558 | " self.gate = nn.Sequential(\n",
559 | " nn.Linear(input_size, input_size),\n",
560 | " nn.Sigmoid()\n",
561 | " )\n",
562 | " \n",
563 | " def forward(self, x):\n",
564 | " return self.linear(x) * self.gate(x)\n",
565 | " \n",
566 | " \n",
567 | "class GatedResidualNetwork(nn.Module):\n",
568 | " def __init__(self, input_size, hidden_size, dropout):\n",
569 | " super(GatedResidualNetwork, self).__init__()\n",
570 | " self.input_size = input_size\n",
571 | " self.hidden_size = hidden_size\n",
572 | " \n",
573 | " self.grn = nn.Sequential(\n",
574 | " nn.Linear(input_size, hidden_size),\n",
575 | " nn.ELU(),\n",
576 | " nn.Linear(hidden_size, hidden_size),\n",
577 | " nn.Dropout(dropout),\n",
578 | " GatedLinearUnit(hidden_size),\n",
579 | " )\n",
580 | " \n",
581 | " self.layer_norm = nn.LayerNorm(hidden_size)\n",
582 | " self.feature_projection = nn.Linear(input_size, hidden_size)\n",
583 | " \n",
584 | " def forward(self, inputs):\n",
585 | " x = self.grn(inputs)\n",
586 | " if inputs.shape[-1] != self.hidden_size:\n",
587 | " inputs = self.feature_projection(inputs)\n",
588 | " x = self.layer_norm(x + inputs)\n",
589 | " return x\n",
590 | " \n",
591 | "class VariableSelectionNetwork(nn.Module):\n",
592 | " def __init__(self, num_features, dense_units, hidden_size, dropout):\n",
593 | " super(VariableSelectionNetwork, self).__init__()\n",
594 | " self.num_features = num_features\n",
595 | " self.hidden_size = hidden_size\n",
596 | " self.grns = nn.ModuleList()\n",
597 | " for _ in range(num_features):\n",
598 | " self.grns.append(GatedResidualNetwork(dense_units, hidden_size, dropout))\n",
599 | " \n",
600 | " \n",
601 | " self.grn_concat = GatedResidualNetwork(num_features*dense_units, hidden_size, dropout)\n",
602 | " self.softmax = nn.Sequential(\n",
603 | " nn.Linear(hidden_size, num_features),\n",
604 | " nn.Softmax(dim=-1)\n",
605 | " )\n",
606 | " \n",
607 | " def forward(self, inputs):\n",
608 | " v = torch.cat(inputs, dim=1)\n",
609 | " v = self.grn_concat(v)\n",
610 | " v = self.softmax(v)\n",
611 | " v = torch.unsqueeze(v, dim=-1)\n",
612 | " \n",
613 | " x = []\n",
614 | " for idx, input_ in enumerate(inputs):\n",
615 | " x.append(self.grns[idx](input_))\n",
616 | " x = torch.stack(x, dim=1)\n",
617 | " \n",
618 | " out = (v.transpose(2, 1) @ x).squeeze(dim=1)\n",
619 | " return out\n",
620 | " \n",
621 | "class VariableSelectionFlow(nn.Module):\n",
622 | " def __init__(self, num_features, hidden_size, dense_units, dropout):\n",
623 | " super(VariableSelectionFlow, self).__init__()\n",
624 | " self.variable_selection = VariableSelectionNetwork(num_features, dense_units, hidden_size, dropout)\n",
625 | " self.split = lambda x: torch.split(x, 1, dim=-1)\n",
626 | " self.dense_list = nn.ModuleList(\n",
627 | " [\n",
628 | " nn.Linear(1, dense_units) \n",
629 | " for _ in range(num_features)\n",
630 | " ]\n",
631 | " )\n",
632 | " \n",
633 | " \n",
634 | " def forward(self, inputs):\n",
635 | " split_inputs = self.split(inputs)\n",
636 | " x = []\n",
637 | " for split_input, linear in zip(split_inputs, self.dense_list):\n",
638 | " x.append(linear(split_input))\n",
639 | " \n",
640 | " return self.variable_selection(x)\n",
641 | "class Net(nn.Module):\n",
642 | " def __init__(self, num_features, dense_units, hidden_sizes, dropouts):\n",
643 | " super(Net, self).__init__()\n",
644 | " self.num_features = num_features\n",
645 | " self.dense_units = dense_units\n",
646 | " self.hidden_size_1 = hidden_sizes[0]\n",
647 | " self.hidden_size_2 = hidden_sizes[1]\n",
648 | " self.hidden_size_3 = hidden_sizes[2]\n",
649 | "\n",
650 | " self.dropout_1 = dropouts[0]\n",
651 | " self.dropout_2 = dropouts[1]\n",
652 | " self.dropout_3 = dropouts[2]\n",
653 | " \n",
654 | " self.variable_slection_flows = nn.Sequential(\n",
655 | " VariableSelectionFlow(num_features, self.hidden_size_1, dense_units, self.dropout_1),\n",
656 | " VariableSelectionFlow(self.hidden_size_1, self.hidden_size_2, dense_units, self.dropout_2),\n",
657 | " VariableSelectionFlow(self.hidden_size_2, self.hidden_size_3, dense_units, self.dropout_3),\n",
658 | " nn.Linear(self.hidden_size_3, 2)\n",
659 | " )\n",
660 | " \n",
661 | " def forward(self, x):\n",
662 | " logits = self.variable_slection_flows(x)\n",
663 | " return logits"
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": 13,
669 | "id": "6058d4e7",
670 | "metadata": {
671 | "execution": {
672 | "iopub.execute_input": "2024-01-01T19:29:46.087005Z",
673 | "iopub.status.busy": "2024-01-01T19:29:46.086630Z",
674 | "iopub.status.idle": "2024-01-01T19:29:46.100234Z",
675 | "shell.execute_reply": "2024-01-01T19:29:46.099168Z"
676 | },
677 | "papermill": {
678 | "duration": 0.025228,
679 | "end_time": "2024-01-01T19:29:46.102328",
680 | "exception": false,
681 | "start_time": "2024-01-01T19:29:46.077100",
682 | "status": "completed"
683 | },
684 | "tags": []
685 | },
686 | "outputs": [],
687 | "source": [
688 | "def fit(model, optimizer, scheduler, epochs, train_dataloader, val_dataloader):\n",
689 | "\n",
690 | " start_time = time.time()\n",
691 | " scaler = GradScaler()\n",
692 | "\n",
693 | " for epoch in range(epochs):\n",
694 | "\n",
695 | " model.train()\n",
696 | " \n",
697 | " for batch_idx, (features, targets) in enumerate(train_dataloader):\n",
698 | " features = features.to(DEVICE)\n",
699 | " targets = targets.to(DEVICE)\n",
700 | " with autocast():\n",
701 | " logits = model(features)\n",
702 | " probs = F.softmax(logits, dim=-1)[:, 1]\n",
703 | " loss = score(targets, probs)\n",
704 | "\n",
705 | " scaler.scale(loss).backward()\n",
706 | " scaler.step(optimizer)\n",
707 | " scaler.update()\n",
708 | " optimizer.zero_grad()\n",
709 | " scheduler.step()\n",
710 | "\n",
711 | " if not batch_idx % 10:\n",
712 | " print(\n",
713 | " f'Epoch: {epoch + 1}/{epochs}'\n",
714 | " f' | Batch: {batch_idx}/{len(train_dataloader)}'\n",
715 | " f' | Loss: {loss.detach().cpu().item():.4f}')\n",
716 | "\n",
717 | " if val_dataloader is not None:\n",
718 | " y_scores = torch.tensor([])\n",
719 | " y_true = torch.tensor([])\n",
720 | "\n",
721 | " with torch.inference_mode():\n",
722 | "\n",
723 | " model.eval()\n",
724 | "\n",
725 | " for batch_idx, (features, targets) in enumerate(val_dataloader):\n",
726 | " features = features.to(DEVICE)\n",
727 | " with autocast():\n",
728 | " logits = model(features).detach().cpu().type(torch.float)\n",
729 | " probs = F.softmax(logits, dim=-1)[:, 1]\n",
730 | " y_scores = torch.cat([y_scores, probs])\n",
731 | " y_true = torch.cat([y_true, targets])\n",
732 | "\n",
733 | " val_score = balanced_log_loss(y_true, y_scores)\n",
734 | " print('Validation score (AUC):', val_score)\n",
735 | "\n",
736 | " elapsed = (time.time() - start_time) / 60\n",
737 | " print(f'Total training time: {elapsed:.3f} min')\n",
738 | "\n",
739 | " model.eval()\n",
740 | "\n",
741 | " if val_dataloader is not None:\n",
742 | " return model, val_score\n",
743 | " else:\n",
744 | " return model"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 14,
750 | "id": "e7117426",
751 | "metadata": {
752 | "execution": {
753 | "iopub.execute_input": "2024-01-01T19:29:46.121104Z",
754 | "iopub.status.busy": "2024-01-01T19:29:46.120221Z",
755 | "iopub.status.idle": "2024-01-01T19:29:46.124942Z",
756 | "shell.execute_reply": "2024-01-01T19:29:46.124215Z"
757 | },
758 | "papermill": {
759 | "duration": 0.016167,
760 | "end_time": "2024-01-01T19:29:46.126843",
761 | "exception": false,
762 | "start_time": "2024-01-01T19:29:46.110676",
763 | "status": "completed"
764 | },
765 | "tags": []
766 | },
767 | "outputs": [],
768 | "source": [
769 | "N_EPOCHS = 15\n",
770 | "N_WARMUPS = 100\n",
771 | "WEIGHT_DECAY = 0.005\n",
772 | "LEARNING_RATE = 0.01"
773 | ]
774 | },
775 | {
776 | "cell_type": "code",
777 | "execution_count": 15,
778 | "id": "2cf1b342",
779 | "metadata": {
780 | "execution": {
781 | "iopub.execute_input": "2024-01-01T19:29:46.145332Z",
782 | "iopub.status.busy": "2024-01-01T19:29:46.144673Z",
783 | "iopub.status.idle": "2024-01-01T19:29:46.151028Z",
784 | "shell.execute_reply": "2024-01-01T19:29:46.150011Z"
785 | },
786 | "papermill": {
787 | "duration": 0.018249,
788 | "end_time": "2024-01-01T19:29:46.153417",
789 | "exception": false,
790 | "start_time": "2024-01-01T19:29:46.135168",
791 | "status": "completed"
792 | },
793 | "tags": []
794 | },
795 | "outputs": [],
796 | "source": [
797 | "# val_scores = []\n",
798 | "# for fold in range(1, 6):\n",
799 | "# df_train = pd.read_csv(OUTPUT_PATH / f'df_train_fold_{fold}.csv')\n",
800 | "# df_val = pd.read_csv(OUTPUT_PATH / f'df_val_fold_{fold}.csv')\n",
801 | "# imp = SimpleImputer(strategy='median')\n",
802 | "# X_train = df_train[ORIGINAL_FEATURES]\n",
803 | "# y_train = df_train[TARGET]\n",
804 | "# X_val = df_val[ORIGINAL_FEATURES]\n",
805 | "# y_val = df_val[TARGET]\n",
806 | " \n",
807 | "# X_train, X_val = imp.fit_transform(X_train), imp.transform(X_val)\n",
808 | " \n",
809 | "# train_dataloader = get_dataloader(\n",
810 | "# X_train, y_train,\n",
811 | "# feature_names=ORIGINAL_FEATURES, \n",
812 | "# target_name=TARGET,\n",
813 | "# batch_size=BATCH_SIZE,\n",
814 | "# mode='train'\n",
815 | "# )\n",
816 | " \n",
817 | "# val_dataloader = get_dataloader(\n",
818 | "# X_val, y_val, \n",
819 | "# feature_names=ORIGINAL_FEATURES, \n",
820 | "# target_name=TARGET,\n",
821 | "# batch_size=BATCH_SIZE,\n",
822 | "# mode='val'\n",
823 | "# )\n",
824 | " \n",
825 | "# torch.manual_seed(SEED)\n",
826 | "# model = Net(\n",
827 | "# num_features=NUM_FEATURES, \n",
828 | "# dense_units=8,\n",
829 | "# hidden_sizes=[32, 32, 32],\n",
830 | "# dropouts=[0.75, 0.5, 0.1]\n",
831 | "# )\n",
832 | " \n",
833 | "# model.to(DEVICE)\n",
834 | "# model.train()\n",
835 | " \n",
836 | "# optimizer = optim.AdamW(\n",
837 | "# model.parameters(), \n",
838 | "# lr=LEARNING_RATE,\n",
839 | "# weight_decay=WEIGHT_DECAY\n",
840 | "# )\n",
841 | " \n",
842 | "# scheduler = get_cosine_schedule_with_warmup(\n",
843 | "# optimizer=optimizer, \n",
844 | "# num_warmup_steps=N_WARMUPS,\n",
845 | "# num_training_steps=len(train_dataloader)*N_EPOCHS\n",
846 | "# )\n",
847 | " \n",
848 | "# model, val_score = fit(model=model,\n",
849 | "# optimizer=optimizer,\n",
850 | "# scheduler=scheduler,\n",
851 | "# epochs=N_EPOCHS,\n",
852 | "# train_dataloader=train_dataloader,\n",
853 | "# val_dataloader=val_dataloader)\n",
854 | " \n",
855 | "# val_scores.append(val_score)"
856 | ]
857 | },
858 | {
859 | "cell_type": "code",
860 | "execution_count": 16,
861 | "id": "92dc4434",
862 | "metadata": {
863 | "execution": {
864 | "iopub.execute_input": "2024-01-01T19:29:46.171440Z",
865 | "iopub.status.busy": "2024-01-01T19:29:46.171085Z",
866 | "iopub.status.idle": "2024-01-01T19:29:46.176610Z",
867 | "shell.execute_reply": "2024-01-01T19:29:46.175869Z"
868 | },
869 | "papermill": {
870 | "duration": 0.016751,
871 | "end_time": "2024-01-01T19:29:46.178598",
872 | "exception": false,
873 | "start_time": "2024-01-01T19:29:46.161847",
874 | "status": "completed"
875 | },
876 | "tags": []
877 | },
878 | "outputs": [],
879 | "source": [
880 | "test['Class'] = 0"
881 | ]
882 | },
883 | {
884 | "cell_type": "code",
885 | "execution_count": 17,
886 | "id": "6f67d8cc",
887 | "metadata": {
888 | "execution": {
889 | "iopub.execute_input": "2024-01-01T19:29:46.197203Z",
890 | "iopub.status.busy": "2024-01-01T19:29:46.196547Z",
891 | "iopub.status.idle": "2024-01-01T19:30:59.588191Z",
892 | "shell.execute_reply": "2024-01-01T19:30:59.586463Z"
893 | },
894 | "papermill": {
895 | "duration": 73.403646,
896 | "end_time": "2024-01-01T19:30:59.590753",
897 | "exception": false,
898 | "start_time": "2024-01-01T19:29:46.187107",
899 | "status": "completed"
900 | },
901 | "tags": []
902 | },
903 | "outputs": [
904 | {
905 | "name": "stdout",
906 | "output_type": "stream",
907 | "text": [
908 | "Epoch: 1/15 | Batch: 0/19 | Loss: 0.7048\n",
909 | "Epoch: 1/15 | Batch: 10/19 | Loss: 0.6967\n",
910 | "Epoch: 2/15 | Batch: 0/19 | Loss: 0.6925\n",
911 | "Epoch: 2/15 | Batch: 10/19 | Loss: 0.7254\n",
912 | "Epoch: 3/15 | Batch: 0/19 | Loss: 0.5173\n",
913 | "Epoch: 3/15 | Batch: 10/19 | Loss: 0.7123\n",
914 | "Epoch: 4/15 | Batch: 0/19 | Loss: 0.4894\n",
915 | "Epoch: 4/15 | Batch: 10/19 | Loss: 0.2931\n",
916 | "Epoch: 5/15 | Batch: 0/19 | Loss: 0.4056\n",
917 | "Epoch: 5/15 | Batch: 10/19 | Loss: 0.4076\n",
918 | "Epoch: 6/15 | Batch: 0/19 | Loss: 0.4935\n",
919 | "Epoch: 6/15 | Batch: 10/19 | Loss: 0.3445\n",
920 | "Epoch: 7/15 | Batch: 0/19 | Loss: 0.5423\n",
921 | "Epoch: 7/15 | Batch: 10/19 | Loss: 0.5617\n",
922 | "Epoch: 8/15 | Batch: 0/19 | Loss: 0.3421\n",
923 | "Epoch: 8/15 | Batch: 10/19 | Loss: 0.7682\n",
924 | "Epoch: 9/15 | Batch: 0/19 | Loss: 0.5203\n",
925 | "Epoch: 9/15 | Batch: 10/19 | Loss: 0.3110\n",
926 | "Epoch: 10/15 | Batch: 0/19 | Loss: 0.4536\n",
927 | "Epoch: 10/15 | Batch: 10/19 | Loss: 0.5030\n",
928 | "Epoch: 11/15 | Batch: 0/19 | Loss: 0.5255\n",
929 | "Epoch: 11/15 | Batch: 10/19 | Loss: 0.2635\n",
930 | "Epoch: 12/15 | Batch: 0/19 | Loss: 0.3376\n",
931 | "Epoch: 12/15 | Batch: 10/19 | Loss: 0.4615\n",
932 | "Epoch: 13/15 | Batch: 0/19 | Loss: 0.5024\n",
933 | "Epoch: 13/15 | Batch: 10/19 | Loss: 0.2413\n",
934 | "Epoch: 14/15 | Batch: 0/19 | Loss: 0.4117\n",
935 | "Epoch: 14/15 | Batch: 10/19 | Loss: 0.2432\n",
936 | "Epoch: 15/15 | Batch: 0/19 | Loss: 0.3910\n",
937 | "Epoch: 15/15 | Batch: 10/19 | Loss: 0.3647\n",
938 | "Total training time: 1.213 min\n"
939 | ]
940 | }
941 | ],
942 | "source": [
943 | "X_train = train[ORIGINAL_FEATURES]\n",
944 | "y_train = train[TARGET]\n",
945 | "X_test = test[ORIGINAL_FEATURES]\n",
946 | "y_test = test[TARGET]\n",
947 | "\n",
948 | "X_train, X_test = imp.fit_transform(X_train), imp.transform(X_test)\n",
949 | "\n",
950 | "train_dataloader = get_dataloader(\n",
951 | " X_train, y_train,\n",
952 | " feature_names=ORIGINAL_FEATURES, \n",
953 | " target_name=TARGET,\n",
954 | " batch_size=BATCH_SIZE,\n",
955 | " mode='train'\n",
956 | ")\n",
957 | "\n",
958 | "test_dataloader = get_dataloader(\n",
959 | " X_test, y_test, \n",
960 | " feature_names=ORIGINAL_FEATURES, \n",
961 | " target_name=TARGET,\n",
962 | " batch_size=BATCH_SIZE,\n",
963 | " mode='val'\n",
964 | ")\n",
965 | "\n",
966 | "torch.manual_seed(SEED)\n",
967 | "model = Net(\n",
968 | " num_features=NUM_FEATURES, \n",
969 | " dense_units=8,\n",
970 | " hidden_sizes=[32, 32, 32],\n",
971 | " dropouts=[0.75, 0.5, 0.25]\n",
972 | ")\n",
973 | "\n",
974 | "model.to(DEVICE)\n",
975 | "model.train()\n",
976 | "\n",
977 | "optimizer = optim.AdamW(\n",
978 | " model.parameters(), \n",
979 | " lr=LEARNING_RATE,\n",
980 | " weight_decay=WEIGHT_DECAY\n",
981 | ")\n",
982 | "\n",
983 | "scheduler = get_cosine_schedule_with_warmup(\n",
984 | " optimizer=optimizer, \n",
985 | " num_warmup_steps=N_WARMUPS,\n",
986 | " num_training_steps=len(train_dataloader)*N_EPOCHS\n",
987 | ")\n",
988 | "\n",
989 | "model = fit(model=model,\n",
990 | " optimizer=optimizer,\n",
991 | " scheduler=scheduler,\n",
992 | " epochs=N_EPOCHS,\n",
993 | " train_dataloader=train_dataloader,\n",
994 | " val_dataloader=None)"
995 | ]
996 | },
997 | {
998 | "cell_type": "code",
999 | "execution_count": 18,
1000 | "id": "e157b23d",
1001 | "metadata": {
1002 | "execution": {
1003 | "iopub.execute_input": "2024-01-01T19:30:59.616895Z",
1004 | "iopub.status.busy": "2024-01-01T19:30:59.616447Z",
1005 | "iopub.status.idle": "2024-01-01T19:30:59.838258Z",
1006 | "shell.execute_reply": "2024-01-01T19:30:59.837044Z"
1007 | },
1008 | "papermill": {
1009 | "duration": 0.238962,
1010 | "end_time": "2024-01-01T19:30:59.840939",
1011 | "exception": false,
1012 | "start_time": "2024-01-01T19:30:59.601977",
1013 | "status": "completed"
1014 | },
1015 | "tags": []
1016 | },
1017 | "outputs": [],
1018 | "source": [
1019 | "# Score the test set batch by batch; probabilities are gathered in loader order.\n",
1020 | "model.eval()\n",
1021 | "batch_probs = []\n",
1022 | "\n",
1023 | "with torch.inference_mode():\n",
1024 | "    for features, _ in test_dataloader:\n",
1025 | "        with autocast():\n",
1026 | "            raw = model(features.to(DEVICE))\n",
1027 | "        # Softmax runs on CPU in float32, outside the autocast region.\n",
1028 | "        batch_probs.append(F.softmax(raw.detach().cpu().type(torch.float), dim=-1))\n",
1029 | "    # One concatenation at the end; an empty loader yields an empty tensor.\n",
1030 | "    y_scores = torch.cat(batch_probs) if batch_probs else torch.tensor([])"
1031 | ]
1032 | },
1033 | {
1034 | "cell_type": "markdown",
1035 | "id": "6013174b",
1036 | "metadata": {
1037 | "papermill": {
1038 | "duration": 0.011217,
1039 | "end_time": "2024-01-01T19:30:59.864280",
1040 | "exception": false,
1041 | "start_time": "2024-01-01T19:30:59.853063",
1042 | "status": "completed"
1043 | },
1044 | "tags": []
1045 | },
1046 | "source": [
1047 | "# Submission"
1048 | ]
1049 | },
1050 | {
1051 | "cell_type": "code",
1052 | "execution_count": 19,
1053 | "id": "9ad834b5",
1054 | "metadata": {
1055 | "execution": {
1056 | "iopub.execute_input": "2024-01-01T19:30:59.889223Z",
1057 | "iopub.status.busy": "2024-01-01T19:30:59.888758Z",
1058 | "iopub.status.idle": "2024-01-01T19:30:59.932660Z",
1059 | "shell.execute_reply": "2024-01-01T19:30:59.931613Z"
1060 | },
1061 | "papermill": {
1062 | "duration": 0.059337,
1063 | "end_time": "2024-01-01T19:30:59.935071",
1064 | "exception": false,
1065 | "start_time": "2024-01-01T19:30:59.875734",
1066 | "status": "completed"
1067 | },
1068 | "tags": []
1069 | },
1070 | "outputs": [
1071 | {
1072 | "name": "stdout",
1073 | "output_type": "stream",
1074 | "text": [
1075 | "Submission file saved!\n"
1076 | ]
1077 | },
1078 | {
1079 | "data": {
1080 | "text/html": [
1081 | "
\n",
1082 | "\n",
1095 | "
\n",
1096 | " \n",
1097 | " \n",
1098 | " | \n",
1099 | " Id | \n",
1100 | " class_0 | \n",
1101 | " class_1 | \n",
1102 | "
\n",
1103 | " \n",
1104 | " \n",
1105 | " \n",
1106 | " | 0 | \n",
1107 | " 00eed32682bb | \n",
1108 | " 0.923802 | \n",
1109 | " 0.076198 | \n",
1110 | "
\n",
1111 | " \n",
1112 | " | 1 | \n",
1113 | " 010ebe33f668 | \n",
1114 | " 0.923802 | \n",
1115 | " 0.076198 | \n",
1116 | "
\n",
1117 | " \n",
1118 | " | 2 | \n",
1119 | " 02fa521e1838 | \n",
1120 | " 0.923802 | \n",
1121 | " 0.076198 | \n",
1122 | "
\n",
1123 | " \n",
1124 | " | 3 | \n",
1125 | " 040e15f562a2 | \n",
1126 | " 0.923802 | \n",
1127 | " 0.076198 | \n",
1128 | "
\n",
1129 | " \n",
1130 | " | 4 | \n",
1131 | " 046e85c7cc7f | \n",
1132 | " 0.923802 | \n",
1133 | " 0.076198 | \n",
1134 | "
\n",
1135 | " \n",
1136 | "
\n",
1137 | "
"
1138 | ],
1139 | "text/plain": [
1140 | " Id class_0 class_1\n",
1141 | "0 00eed32682bb 0.923802 0.076198\n",
1142 | "1 010ebe33f668 0.923802 0.076198\n",
1143 | "2 02fa521e1838 0.923802 0.076198\n",
1144 | "3 040e15f562a2 0.923802 0.076198\n",
1145 | "4 046e85c7cc7f 0.923802 0.076198"
1146 | ]
1147 | },
1148 | "execution_count": 19,
1149 | "metadata": {},
1150 | "output_type": "execute_result"
1151 | }
1152 | ],
1153 | "source": [
1154 | "sub = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv')\n",
1155 | "sub[['class_0', 'class_1']] = y_scores\n",
1156 | "sub.to_csv('submission.csv', index=False)\n",
1157 | "print('Submission file saved!')\n",
1158 | "sub"
1159 | ]
1160 | },
1161 | {
1162 | "cell_type": "code",
1163 | "execution_count": null,
1164 | "id": "9b514427",
1165 | "metadata": {
1166 | "papermill": {
1167 | "duration": 0.011463,
1168 | "end_time": "2024-01-01T19:30:59.958503",
1169 | "exception": false,
1170 | "start_time": "2024-01-01T19:30:59.947040",
1171 | "status": "completed"
1172 | },
1173 | "tags": []
1174 | },
1175 | "outputs": [],
1176 | "source": []
1177 | }
1178 | ],
1179 | "metadata": {
1180 | "kaggle": {
1181 | "accelerator": "none",
1182 | "dataSources": [
1183 | {
1184 | "databundleVersionId": 5687476,
1185 | "sourceId": 52784,
1186 | "sourceType": "competition"
1187 | },
1188 | {
1189 | "databundleVersionId": 5768690,
1190 | "datasetId": 3273406,
1191 | "sourceId": 5693071,
1192 | "sourceType": "datasetVersion"
1193 | },
1194 | {
1195 | "databundleVersionId": 7382017,
1196 | "datasetId": 930977,
1197 | "sourceId": 7292109,
1198 | "sourceType": "datasetVersion"
1199 | }
1200 | ],
1201 | "dockerImageVersionId": 30474,
1202 | "isGpuEnabled": false,
1203 | "isInternetEnabled": false,
1204 | "language": "python",
1205 | "sourceType": "notebook"
1206 | },
1207 | "kernelspec": {
1208 | "display_name": "Python 3",
1209 | "language": "python",
1210 | "name": "python3"
1211 | },
1212 | "language_info": {
1213 | "codemirror_mode": {
1214 | "name": "ipython",
1215 | "version": 3
1216 | },
1217 | "file_extension": ".py",
1218 | "mimetype": "text/x-python",
1219 | "name": "python",
1220 | "nbconvert_exporter": "python",
1221 | "pygments_lexer": "ipython3",
1222 | "version": "3.10.10"
1223 | },
1224 | "papermill": {
1225 | "default_parameters": {},
1226 | "duration": 104.093568,
1227 | "end_time": "2024-01-01T19:31:02.855589",
1228 | "environment_variables": {},
1229 | "exception": null,
1230 | "input_path": "__notebook__.ipynb",
1231 | "output_path": "__notebook__.ipynb",
1232 | "parameters": {},
1233 | "start_time": "2024-01-01T19:29:18.762021",
1234 | "version": "2.4.0"
1235 | }
1236 | },
1237 | "nbformat": 4,
1238 | "nbformat_minor": 5
1239 | }
1240 |
--------------------------------------------------------------------------------
/spam_prediction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "91349eca",
7 | "metadata": {
8 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10 | "execution": {
11 | "iopub.execute_input": "2024-01-05T17:14:43.763383Z",
12 | "iopub.status.busy": "2024-01-05T17:14:43.762477Z",
13 | "iopub.status.idle": "2024-01-05T17:14:50.370293Z",
14 | "shell.execute_reply": "2024-01-05T17:14:50.369488Z"
15 | },
16 | "papermill": {
17 | "duration": 6.61799,
18 | "end_time": "2024-01-05T17:14:50.372582",
19 | "exception": false,
20 | "start_time": "2024-01-05T17:14:43.754592",
21 | "status": "completed"
22 | },
23 | "tags": []
24 | },
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n",
31 | " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "import numpy as np\n",
37 | "import pandas as pd\n",
38 | "\n",
39 | "import time\n",
40 | "from tqdm import tqdm\n",
41 | "from pathlib import Path\n",
42 | "import multiprocessing as mp\n",
43 | "\n",
44 | "from sklearn.metrics import roc_auc_score\n",
45 | "from sklearn.preprocessing import StandardScaler\n",
46 | "from sklearn.model_selection import cross_val_score, StratifiedKFold\n",
47 | "\n",
48 | "from transformers import get_cosine_schedule_with_warmup\n",
49 | "\n",
50 | "import torch \n",
51 | "from torch import nn, optim\n",
52 | "import torch.nn.functional as F\n",
53 | "from torch.cuda.amp import GradScaler, autocast\n",
54 | "from torch.utils.data import Dataset, DataLoader"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "id": "b0ee803f",
60 | "metadata": {
61 | "papermill": {
62 | "duration": 0.005714,
63 | "end_time": "2024-01-05T17:14:50.384668",
64 | "exception": false,
65 | "start_time": "2024-01-05T17:14:50.378954",
66 | "status": "completed"
67 | },
68 | "tags": []
69 | },
70 | "source": [
71 | "# General Settings"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 2,
77 | "id": "919d0ef5",
78 | "metadata": {
79 | "execution": {
80 | "iopub.execute_input": "2024-01-05T17:14:50.399076Z",
81 | "iopub.status.busy": "2024-01-05T17:14:50.397834Z",
82 | "iopub.status.idle": "2024-01-05T17:14:50.456406Z",
83 | "shell.execute_reply": "2024-01-05T17:14:50.455513Z"
84 | },
85 | "papermill": {
86 | "duration": 0.067852,
87 | "end_time": "2024-01-05T17:14:50.458590",
88 | "exception": false,
89 | "start_time": "2024-01-05T17:14:50.390738",
90 | "status": "completed"
91 | },
92 | "tags": []
93 | },
94 | "outputs": [],
95 | "source": [
96 | "INPUT_PATH = Path('/kaggle/input/lets-surpass-the-hosts-bayesian-model')\n",
97 | "OUTPUT_PATH = Path('/kaggle/working')\n",
98 | "\n",
99 | "ORIGINAL_FEATURES = ['A', 'B', 'E', 'F', 'G']\n",
100 | "\n",
101 | "TARGET = 'Target'\n",
102 | "\n",
103 | "N_CORES = mp.cpu_count()\n",
104 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
105 | "NUM_FEATURES = len(ORIGINAL_FEATURES)\n",
106 | "BATCH_SIZE = 32\n",
107 | "N_EPOCHS = 5\n",
108 | "N_WARMUPS = 80\n",
109 | "LEARNING_RATE = 0.004\n",
110 | "WEIGHT_DECAY = 0.01\n",
111 | "SEED = 252"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "ccab7057",
117 | "metadata": {
118 | "papermill": {
119 | "duration": 0.005618,
120 | "end_time": "2024-01-05T17:14:50.470230",
121 | "exception": false,
122 | "start_time": "2024-01-05T17:14:50.464612",
123 | "status": "completed"
124 | },
125 | "tags": []
126 | },
127 | "source": [
128 | "# Load Data"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 3,
134 | "id": "edba1611",
135 | "metadata": {
136 | "execution": {
137 | "iopub.execute_input": "2024-01-05T17:14:50.484027Z",
138 | "iopub.status.busy": "2024-01-05T17:14:50.483173Z",
139 | "iopub.status.idle": "2024-01-05T17:14:50.504976Z",
140 | "shell.execute_reply": "2024-01-05T17:14:50.504151Z"
141 | },
142 | "papermill": {
143 | "duration": 0.031072,
144 | "end_time": "2024-01-05T17:14:50.507226",
145 | "exception": false,
146 | "start_time": "2024-01-05T17:14:50.476154",
147 | "status": "completed"
148 | },
149 | "tags": []
150 | },
151 | "outputs": [],
152 | "source": [
153 | "train = pd.read_csv(INPUT_PATH / 'train_df.csv')\n",
154 | "train['Target'] = train['Target'].astype(int)"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "id": "68a6c905",
160 | "metadata": {
161 | "papermill": {
162 | "duration": 0.005868,
163 | "end_time": "2024-01-05T17:14:50.519041",
164 | "exception": false,
165 | "start_time": "2024-01-05T17:14:50.513173",
166 | "status": "completed"
167 | },
168 | "tags": []
169 | },
170 | "source": [
171 | "# Split Folds"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 4,
177 | "id": "f8827c0e",
178 | "metadata": {
179 | "execution": {
180 | "iopub.execute_input": "2024-01-05T17:14:50.532855Z",
181 | "iopub.status.busy": "2024-01-05T17:14:50.532108Z",
182 | "iopub.status.idle": "2024-01-05T17:14:50.540434Z",
183 | "shell.execute_reply": "2024-01-05T17:14:50.539490Z"
184 | },
185 | "papermill": {
186 | "duration": 0.017352,
187 | "end_time": "2024-01-05T17:14:50.542410",
188 | "exception": false,
189 | "start_time": "2024-01-05T17:14:50.525058",
190 | "status": "completed"
191 | },
192 | "tags": []
193 | },
194 | "outputs": [],
195 | "source": [
196 | "def prepare_folds(train, features, target):\n",
197 | "    \"\"\"Write train/val/test CSVs for each of 10 stratified, shuffled CV folds.\n",
198 | "\n",
199 | "    The split is stratified on `train[target]`. Files are written to the\n",
200 | "    working directory as df_train_fold_{k}.csv, df_val_fold_{k}.csv and\n",
201 | "    test_fold_{k}.csv, with k starting at 1.\n",
202 | "    \"\"\"\n",
203 | "    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)\n",
204 | "    scaler = StandardScaler()  # referenced only by the disabled scaling lines\n",
205 | "    splits = skf.split(train[features], train[target])\n",
206 | "    for fold, (train_indices, val_indices) in enumerate(splits, start=1):\n",
207 | "        print(f'Preparing fold {fold} ...')\n",
208 | "        # skf.split yields row positions, so select with iloc, not by label.\n",
209 | "        df_train = train.iloc[train_indices].reset_index(drop=True)\n",
210 | "        df_val = train.iloc[val_indices].reset_index(drop=True)\n",
211 | "#         df_train[features] = scaler.fit_transform(df_train[features])\n",
212 | "#         df_val[features] = scaler.transform(df_val[features])\n",
213 | "        test = pd.read_csv(INPUT_PATH / 'test_df.csv')  # fresh copy per fold\n",
214 | "#         test[features] = scaler.transform(test[features])\n",
215 | "        test['Target'] = 0\n",
216 | "        df_train.to_csv(f'df_train_fold_{fold}.csv', index=False)\n",
217 | "        df_val.to_csv(f'df_val_fold_{fold}.csv', index=False)\n",
218 | "        test.to_csv(f'test_fold_{fold}.csv', index=False)"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 5,
224 | "id": "f4472a58",
225 | "metadata": {
226 | "execution": {
227 | "iopub.execute_input": "2024-01-05T17:14:50.556472Z",
228 | "iopub.status.busy": "2024-01-05T17:14:50.555601Z",
229 | "iopub.status.idle": "2024-01-05T17:14:50.725218Z",
230 | "shell.execute_reply": "2024-01-05T17:14:50.724036Z"
231 | },
232 | "papermill": {
233 | "duration": 0.179071,
234 | "end_time": "2024-01-05T17:14:50.727467",
235 | "exception": false,
236 | "start_time": "2024-01-05T17:14:50.548396",
237 | "status": "completed"
238 | },
239 | "tags": []
240 | },
241 | "outputs": [
242 | {
243 | "name": "stdout",
244 | "output_type": "stream",
245 | "text": [
246 | "Preparing fold 1 ...\n",
247 | "Preparing fold 2 ...\n",
248 | "Preparing fold 3 ...\n",
249 | "Preparing fold 4 ...\n",
250 | "Preparing fold 5 ...\n",
251 | "Preparing fold 6 ...\n",
252 | "Preparing fold 7 ...\n",
253 | "Preparing fold 8 ...\n",
254 | "Preparing fold 9 ...\n",
255 | "Preparing fold 10 ...\n"
256 | ]
257 | }
258 | ],
259 | "source": [
260 | "prepare_folds(train, ORIGINAL_FEATURES, TARGET)"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "id": "43edb98a",
266 | "metadata": {
267 | "papermill": {
268 | "duration": 0.0058,
269 | "end_time": "2024-01-05T17:14:50.739435",
270 | "exception": false,
271 | "start_time": "2024-01-05T17:14:50.733635",
272 | "status": "completed"
273 | },
274 | "tags": []
275 | },
276 | "source": [
277 | "# Dataset and DataLoader"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 6,
283 | "id": "a490c6b5",
284 | "metadata": {
285 | "execution": {
286 | "iopub.execute_input": "2024-01-05T17:14:50.753183Z",
287 | "iopub.status.busy": "2024-01-05T17:14:50.752806Z",
288 | "iopub.status.idle": "2024-01-05T17:14:50.761722Z",
289 | "shell.execute_reply": "2024-01-05T17:14:50.760854Z"
290 | },
291 | "papermill": {
292 | "duration": 0.018273,
293 | "end_time": "2024-01-05T17:14:50.763716",
294 | "exception": false,
295 | "start_time": "2024-01-05T17:14:50.745443",
296 | "status": "completed"
297 | },
298 | "tags": []
299 | },
300 | "outputs": [],
301 | "source": [
302 | "class SpamDataset(Dataset):\n",
303 | "    \"\"\"Map-style dataset holding float features and integer class labels.\"\"\"\n",
304 | "\n",
305 | "    def __init__(self, features, targets):\n",
306 | "        self.targets = torch.tensor(targets, dtype=torch.long)\n",
307 | "        self.features = torch.tensor(features, dtype=torch.float)\n",
308 | "\n",
309 | "    def __len__(self):\n",
310 | "        return self.targets.size(0)\n",
311 | "\n",
312 | "    def __getitem__(self, index):\n",
313 | "        return self.features[index], self.targets[index]\n",
314 | "\n",
315 | "\n",
316 | "def get_dataloader(df,\n",
317 | "                   feature_names,\n",
318 | "                   target_name,\n",
319 | "                   batch_size,\n",
320 | "                   mode):\n",
321 | "    \"\"\"Wrap the chosen columns of `df` in a SpamDataset and return a DataLoader.\n",
322 | "\n",
323 | "    With mode == 'train' batches are shuffled and a trailing partial batch is\n",
324 | "    dropped; any other mode keeps row order and every row. The global torch\n",
325 | "    RNG is re-seeded with SEED on each call.\n",
326 | "    \"\"\"\n",
327 | "    is_train = mode == 'train'\n",
328 | "\n",
329 | "    torch.manual_seed(SEED)\n",
330 | "    dataset = SpamDataset(\n",
331 | "        features=df[feature_names].to_numpy(),\n",
332 | "        targets=df[target_name].to_numpy(),\n",
333 | "    )\n",
334 | "\n",
335 | "    # num_workers is taken from N_CORES (mp.cpu_count()).\n",
336 | "    return DataLoader(\n",
337 | "        dataset=dataset,\n",
338 | "        batch_size=batch_size,\n",
339 | "        shuffle=is_train,\n",
340 | "        drop_last=is_train,\n",
341 | "        num_workers=N_CORES,\n",
342 | "    )"
343 | ]
344 | },
345 | {
346 | "cell_type": "markdown",
347 | "id": "fb6b3248",
348 | "metadata": {
349 | "papermill": {
350 | "duration": 0.005775,
351 | "end_time": "2024-01-05T17:14:50.775670",
352 | "exception": false,
353 | "start_time": "2024-01-05T17:14:50.769895",
354 | "status": "completed"
355 | },
356 | "tags": []
357 | },
358 | "source": [
359 | "# Model"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 7,
365 | "id": "668d7f56",
366 | "metadata": {
367 | "execution": {
368 | "iopub.execute_input": "2024-01-05T17:14:50.789282Z",
369 | "iopub.status.busy": "2024-01-05T17:14:50.788904Z",
370 | "iopub.status.idle": "2024-01-05T17:14:50.807039Z",
371 | "shell.execute_reply": "2024-01-05T17:14:50.806070Z"
372 | },
373 | "papermill": {
374 | "duration": 0.02739,
375 | "end_time": "2024-01-05T17:14:50.809009",
376 | "exception": false,
377 | "start_time": "2024-01-05T17:14:50.781619",
378 | "status": "completed"
379 | },
380 | "tags": []
381 | },
382 | "outputs": [],
383 | "source": [
384 | "class GatedLinearUnit(nn.Module):\n",
385 | "    \"\"\"Gated linear unit: a linear map scaled element-wise by a sigmoid gate.\"\"\"\n",
386 | "\n",
387 | "    def __init__(self, input_size):\n",
388 | "        super().__init__()\n",
389 | "        self.linear = nn.Linear(input_size, input_size)\n",
390 | "        self.gate = nn.Sequential(nn.Linear(input_size, input_size), nn.Sigmoid())\n",
391 | "\n",
392 | "    def forward(self, x):\n",
393 | "        gate = self.gate(x)\n",
394 | "        return self.linear(x) * gate\n",
395 | " \n",
396 | " \n",
397 | "class GatedResidualNetwork(nn.Module):\n",
398 | "    \"\"\"GRN block: Linear -> ELU -> Linear -> Dropout -> GLU, plus a residual\n",
399 | "    connection (projected when widths differ) followed by LayerNorm.\"\"\"\n",
400 | "\n",
401 | "    def __init__(self, input_size, hidden_size, dropout):\n",
402 | "        super().__init__()\n",
403 | "        self.input_size = input_size\n",
404 | "        self.hidden_size = hidden_size\n",
405 | "\n",
406 | "        self.grn = nn.Sequential(\n",
407 | "            nn.Linear(input_size, hidden_size),\n",
408 | "            nn.ELU(),\n",
409 | "            nn.Linear(hidden_size, hidden_size),\n",
410 | "            nn.Dropout(dropout),\n",
411 | "            GatedLinearUnit(hidden_size),\n",
412 | "        )\n",
413 | "        self.layer_norm = nn.LayerNorm(hidden_size)\n",
414 | "        self.feature_projection = nn.Linear(input_size, hidden_size)\n",
415 | "\n",
416 | "    def forward(self, inputs):\n",
417 | "        needs_projection = inputs.shape[-1] != self.hidden_size\n",
418 | "        residual = self.feature_projection(inputs) if needs_projection else inputs\n",
419 | "        return self.layer_norm(self.grn(inputs) + residual)\n",
420 | " \n",
421 | "class VariableSelectionNetwork(nn.Module):\n",
422 | "    \"\"\"Weights per-feature GRN outputs by softmax scores from a GRN over the\n",
423 | "    concatenation of all feature embeddings, and returns the weighted sum.\"\"\"\n",
424 | "\n",
425 | "    def __init__(self, num_features, dense_units, hidden_size, dropout):\n",
426 | "        super().__init__()\n",
427 | "        self.num_features = num_features\n",
428 | "        self.hidden_size = hidden_size\n",
429 | "        # One GRN per input feature, each mapping dense_units -> hidden_size.\n",
430 | "        self.grns = nn.ModuleList([\n",
431 | "            GatedResidualNetwork(dense_units, hidden_size, dropout)\n",
432 | "            for _ in range(num_features)\n",
433 | "        ])\n",
434 | "        self.grn_concat = GatedResidualNetwork(\n",
435 | "            num_features * dense_units, hidden_size, dropout\n",
436 | "        )\n",
437 | "        self.softmax = nn.Sequential(\n",
438 | "            nn.Linear(hidden_size, num_features),\n",
439 | "            nn.Softmax(dim=-1),\n",
440 | "        )\n",
441 | "\n",
442 | "    def forward(self, inputs):\n",
443 | "        # Selection weights: one score per feature, normalised over the last dim.\n",
444 | "        joined = torch.cat(inputs, dim=1)\n",
445 | "        weights = self.softmax(self.grn_concat(joined)).unsqueeze(dim=-1)\n",
446 | "\n",
447 | "        # Per-feature GRN outputs stacked along a new dim 1.\n",
448 | "        per_feature = torch.stack([self.grns[i](t) for i, t in enumerate(inputs)], dim=1)\n",
449 | "        return (weights.transpose(2, 1) @ per_feature).squeeze(dim=1)\n",
450 | " \n",
451 | "class VariableSelectionFlow(nn.Module):\n",
452 | "    \"\"\"Embeds each scalar feature with its own Linear(1, dense_units) layer and\n",
453 | "    passes the list of embeddings to a VariableSelectionNetwork.\"\"\"\n",
454 | "\n",
455 | "    def __init__(self, num_features, hidden_size, dense_units, dropout):\n",
456 | "        super().__init__()\n",
457 | "        self.variable_selection = VariableSelectionNetwork(\n",
458 | "            num_features, dense_units, hidden_size, dropout\n",
459 | "        )\n",
460 | "        self.dense_list = nn.ModuleList(\n",
461 | "            [nn.Linear(1, dense_units) for _ in range(num_features)]\n",
462 | "        )\n",
463 | "\n",
464 | "    def split(self, x):\n",
465 | "        # A method rather than a lambda attribute, so the module stays picklable.\n",
466 | "        return torch.split(x, 1, dim=-1)\n",
467 | "\n",
468 | "    def forward(self, inputs):\n",
469 | "        columns = self.split(inputs)\n",
470 | "        return self.variable_selection([fc(c) for c, fc in zip(columns, self.dense_list)])"
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": 8,
476 | "id": "596185e3",
477 | "metadata": {
478 | "execution": {
479 | "iopub.execute_input": "2024-01-05T17:14:50.823298Z",
480 | "iopub.status.busy": "2024-01-05T17:14:50.822450Z",
481 | "iopub.status.idle": "2024-01-05T17:14:50.829876Z",
482 | "shell.execute_reply": "2024-01-05T17:14:50.828956Z"
483 | },
484 | "papermill": {
485 | "duration": 0.016799,
486 | "end_time": "2024-01-05T17:14:50.831889",
487 | "exception": false,
488 | "start_time": "2024-01-05T17:14:50.815090",
489 | "status": "completed"
490 | },
491 | "tags": []
492 | },
493 | "outputs": [],
494 | "source": [
495 | "class Net(nn.Module):\n",
496 | "    \"\"\"Classifier: a VariableSelectionFlow followed by a two-logit Linear head.\"\"\"\n",
497 | "\n",
498 | "    def __init__(self, num_features, dense_units, hidden_size, dropout):\n",
499 | "        super().__init__()\n",
500 | "        self.num_features = num_features\n",
501 | "        self.dense_units = dense_units\n",
502 | "        self.hidden_size = hidden_size\n",
503 | "        self.dropout = dropout\n",
504 | "        # Attribute name (spelling included) is kept: it prefixes state_dict keys.\n",
505 | "        self.variable_slection_flows = nn.Sequential(\n",
506 | "            VariableSelectionFlow(num_features, hidden_size, dense_units, dropout),\n",
507 | "            nn.Linear(hidden_size, 2),\n",
508 | "        )\n",
509 | "        # Re-initialise every Linear weight with Xavier-uniform; biases are untouched.\n",
510 | "        self.apply(self._init_weights)\n",
511 | "\n",
512 | "    def _init_weights(self, module):\n",
513 | "        if isinstance(module, nn.Linear):\n",
514 | "            nn.init.xavier_uniform_(module.weight)\n",
515 | "\n",
516 | "    def forward(self, x):\n",
517 | "        return self.variable_slection_flows(x)"
518 | ]
519 | },
520 | {
521 | "cell_type": "markdown",
522 | "id": "5f7bd44f",
523 | "metadata": {
524 | "papermill": {
525 | "duration": 0.006125,
526 | "end_time": "2024-01-05T17:14:50.844108",
527 | "exception": false,
528 | "start_time": "2024-01-05T17:14:50.837983",
529 | "status": "completed"
530 | },
531 | "tags": []
532 | },
533 | "source": [
534 | "# Training"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": 9,
540 | "id": "8c508cca",
541 | "metadata": {
542 | "execution": {
543 | "iopub.execute_input": "2024-01-05T17:14:50.857821Z",
544 | "iopub.status.busy": "2024-01-05T17:14:50.857461Z",
545 | "iopub.status.idle": "2024-01-05T17:14:50.869526Z",
546 | "shell.execute_reply": "2024-01-05T17:14:50.868544Z"
547 | },
548 | "papermill": {
549 | "duration": 0.021545,
550 | "end_time": "2024-01-05T17:14:50.871731",
551 | "exception": false,
552 | "start_time": "2024-01-05T17:14:50.850186",
553 | "status": "completed"
554 | },
555 | "tags": []
556 | },
557 | "outputs": [],
558 | "source": [
559 | "def fit(model, optimizer, scheduler, epochs, train_dataloader, val_dataloader):\n",
560 | "    \"\"\"Train `model` under autocast/GradScaler; optionally print validation AUC per epoch.\n",
561 | "\n",
562 | "    Args:\n",
563 | "        model: network returning class logits; assumed to already be on DEVICE.\n",
564 | "        optimizer: optimizer over the model parameters.\n",
565 | "        scheduler: LR scheduler stepped once per batch, or None to disable.\n",
566 | "        epochs: number of passes over `train_dataloader`.\n",
567 | "        train_dataloader: yields (features, targets) training batches.\n",
568 | "        val_dataloader: yields validation batches, or None to skip validation.\n",
569 | "\n",
570 | "    Returns:\n",
571 | "        `model` in eval mode, or `(model, val_score)` when `val_dataloader` is\n",
572 | "        given; `val_score` is the last epoch's ROC AUC (None if epochs == 0).\n",
573 | "    \"\"\"\n",
574 | "    start_time = time.time()\n",
575 | "    scaler = GradScaler()\n",
576 | "    val_score = None  # stays None when no epoch runs\n",
577 | "\n",
578 | "    for epoch in range(epochs):\n",
579 | "        model.train()\n",
580 | "        for features, targets in train_dataloader:\n",
581 | "            features = features.to(DEVICE)\n",
582 | "            targets = targets.to(DEVICE)\n",
583 | "            with autocast():\n",
584 | "                logits = model(features)\n",
585 | "                loss = F.cross_entropy(logits, targets)\n",
586 | "\n",
587 | "            scaler.scale(loss).backward()\n",
588 | "            scaler.step(optimizer)\n",
589 | "            scaler.update()\n",
590 | "            optimizer.zero_grad()\n",
591 | "            # Advance the LR schedule; previously the scheduler was never stepped.\n",
592 | "            if scheduler is not None:\n",
593 | "                scheduler.step()\n",
594 | "\n",
595 | "        if val_dataloader is not None:\n",
596 | "            model.eval()\n",
597 | "            y_scores, y_true = [], []\n",
598 | "            with torch.inference_mode():\n",
599 | "                for features, targets in val_dataloader:\n",
600 | "                    with autocast():\n",
601 | "                        logits = model(features.to(DEVICE))\n",
602 | "                    logits = logits.detach().cpu().type(torch.float)\n",
603 | "                    y_scores.append(F.softmax(logits, dim=-1)[:, 1])\n",
604 | "                    y_true.append(targets)\n",
605 | "            val_score = roc_auc_score(torch.cat(y_true), torch.cat(y_scores))\n",
606 | "            print('Validation score (AUC):', float(val_score))\n",
607 | "\n",
608 | "    elapsed = (time.time() - start_time) / 60\n",
609 | "    print(f'Total training time: {elapsed:.3f} min')\n",
610 | "    model.eval()\n",
611 | "    if val_dataloader is not None:\n",
612 | "        return model, val_score\n",
613 | "    return model"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": 10,
619 | "id": "d3781f98",
620 | "metadata": {
621 | "execution": {
622 | "iopub.execute_input": "2024-01-05T17:14:50.885901Z",
623 | "iopub.status.busy": "2024-01-05T17:14:50.885209Z",
624 | "iopub.status.idle": "2024-01-05T17:14:50.890099Z",
625 | "shell.execute_reply": "2024-01-05T17:14:50.889155Z"
626 | },
627 | "papermill": {
628 | "duration": 0.01412,
629 | "end_time": "2024-01-05T17:14:50.892095",
630 | "exception": false,
631 | "start_time": "2024-01-05T17:14:50.877975",
632 | "status": "completed"
633 | },
634 | "tags": []
635 | },
636 | "outputs": [],
637 | "source": [
638 | "N_EPOCHS = 8\n",
639 | "LEARNING_RATE = 0.001\n",
640 | "WEIGHT_DECAY = 0.0\n",
641 | "SEEDS = [789, 279, 318, 2001, 1976, 1966, 1994, 252]"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 11,
647 | "id": "4bec664f",
648 | "metadata": {
649 | "execution": {
650 | "iopub.execute_input": "2024-01-05T17:14:50.906197Z",
651 | "iopub.status.busy": "2024-01-05T17:14:50.905304Z",
652 | "iopub.status.idle": "2024-01-05T17:19:00.587022Z",
653 | "shell.execute_reply": "2024-01-05T17:19:00.585910Z"
654 | },
655 | "papermill": {
656 | "duration": 249.691262,
657 | "end_time": "2024-01-05T17:19:00.589369",
658 | "exception": false,
659 | "start_time": "2024-01-05T17:14:50.898107",
660 | "status": "completed"
661 | },
662 | "tags": []
663 | },
664 | "outputs": [
665 | {
666 | "name": "stdout",
667 | "output_type": "stream",
668 | "text": [
669 | "Seed 789\n",
670 | "Starting fold 1 ...\n",
671 | "Validation score (AUC): 0.8086734693877551\n",
672 | "Validation score (AUC): 0.8418367346938775\n",
673 | "Validation score (AUC): 0.8596938775510203\n",
674 | "Validation score (AUC): 0.8571428571428572\n",
675 | "Validation score (AUC): 0.8367346938775511\n",
676 | "Validation score (AUC): 0.8443877551020408\n",
677 | "Validation score (AUC): 0.854591836734694\n",
678 | "Validation score (AUC): 0.8545918367346939\n",
679 | "Total training time: 0.054 min\n",
680 | "Starting fold 2 ...\n",
681 | "Validation score (AUC): 0.7704081632653061\n",
682 | "Validation score (AUC): 0.8061224489795918\n",
683 | "Validation score (AUC): 0.7882653061224489\n",
684 | "Validation score (AUC): 0.7755102040816326\n",
685 | "Validation score (AUC): 0.7551020408163265\n",
686 | "Validation score (AUC): 0.7576530612244898\n",
687 | "Validation score (AUC): 0.7372448979591837\n",
688 | "Validation score (AUC): 0.7423469387755102\n",
689 | "Total training time: 0.048 min\n",
690 | "Starting fold 3 ...\n",
691 | "Validation score (AUC): 0.7627551020408162\n",
692 | "Validation score (AUC): 0.8163265306122449\n",
693 | "Validation score (AUC): 0.8239795918367346\n",
694 | "Validation score (AUC): 0.826530612244898\n",
695 | "Validation score (AUC): 0.8392857142857142\n",
696 | "Validation score (AUC): 0.8443877551020408\n",
697 | "Validation score (AUC): 0.8443877551020409\n",
698 | "Validation score (AUC): 0.8392857142857143\n",
699 | "Total training time: 0.051 min\n",
700 | "Starting fold 4 ...\n",
701 | "Validation score (AUC): 0.8035714285714286\n",
702 | "Validation score (AUC): 0.8596938775510203\n",
703 | "Validation score (AUC): 0.8775510204081632\n",
704 | "Validation score (AUC): 0.8673469387755102\n",
705 | "Validation score (AUC): 0.8724489795918368\n",
706 | "Validation score (AUC): 0.875\n",
707 | "Validation score (AUC): 0.8801020408163265\n",
708 | "Validation score (AUC): 0.8801020408163265\n",
709 | "Total training time: 0.045 min\n",
710 | "Starting fold 5 ...\n",
711 | "Validation score (AUC): 0.6607142857142858\n",
712 | "Validation score (AUC): 0.75\n",
713 | "Validation score (AUC): 0.7908163265306123\n",
714 | "Validation score (AUC): 0.8163265306122449\n",
715 | "Validation score (AUC): 0.826530612244898\n",
716 | "Validation score (AUC): 0.8316326530612245\n",
717 | "Validation score (AUC): 0.8443877551020408\n",
718 | "Validation score (AUC): 0.8469387755102041\n",
719 | "Total training time: 0.053 min\n",
720 | "Starting fold 6 ...\n",
721 | "Validation score (AUC): 0.6505102040816327\n",
722 | "Validation score (AUC): 0.6760204081632653\n",
723 | "Validation score (AUC): 0.7066326530612245\n",
724 | "Validation score (AUC): 0.729591836734694\n",
725 | "Validation score (AUC): 0.7627551020408163\n",
726 | "Validation score (AUC): 0.7704081632653061\n",
727 | "Validation score (AUC): 0.7729591836734694\n",
728 | "Validation score (AUC): 0.7831632653061226\n",
729 | "Total training time: 0.050 min\n",
730 | "Starting fold 7 ...\n",
731 | "Validation score (AUC): 0.6962962962962962\n",
732 | "Validation score (AUC): 0.7382716049382716\n",
733 | "Validation score (AUC): 0.7679012345679012\n",
734 | "Validation score (AUC): 0.8074074074074075\n",
735 | "Validation score (AUC): 0.8296296296296296\n",
736 | "Validation score (AUC): 0.8320987654320987\n",
737 | "Validation score (AUC): 0.8370370370370369\n",
738 | "Validation score (AUC): 0.8419753086419752\n",
739 | "Total training time: 0.046 min\n",
740 | "Starting fold 8 ...\n",
741 | "Validation score (AUC): 0.7135802469135802\n",
742 | "Validation score (AUC): 0.7703703703703704\n",
743 | "Validation score (AUC): 0.8197530864197531\n",
744 | "Validation score (AUC): 0.8592592592592592\n",
745 | "Validation score (AUC): 0.8814814814814814\n",
746 | "Validation score (AUC): 0.9012345679012346\n",
747 | "Validation score (AUC): 0.9037037037037037\n",
748 | "Validation score (AUC): 0.9061728395061728\n",
749 | "Total training time: 0.050 min\n",
750 | "Starting fold 9 ...\n",
751 | "Validation score (AUC): 0.7851851851851851\n",
752 | "Validation score (AUC): 0.7876543209876543\n",
753 | "Validation score (AUC): 0.8074074074074074\n",
754 | "Validation score (AUC): 0.8222222222222222\n",
755 | "Validation score (AUC): 0.8098765432098765\n",
756 | "Validation score (AUC): 0.802469135802469\n",
757 | "Validation score (AUC): 0.8049382716049382\n",
758 | "Validation score (AUC): 0.8197530864197531\n",
759 | "Total training time: 0.049 min\n",
760 | "Starting fold 10 ...\n",
761 | "Validation score (AUC): 0.7925925925925925\n",
762 | "Validation score (AUC): 0.8345679012345679\n",
763 | "Validation score (AUC): 0.8320987654320987\n",
764 | "Validation score (AUC): 0.8395061728395061\n",
765 | "Validation score (AUC): 0.8222222222222222\n",
766 | "Validation score (AUC): 0.8222222222222222\n",
767 | "Validation score (AUC): 0.8469135802469135\n",
768 | "Validation score (AUC): 0.8469135802469135\n",
769 | "Total training time: 0.048 min\n",
770 | "Seed 279\n",
771 | "Starting fold 1 ...\n",
772 | "Validation score (AUC): 0.37244897959183676\n",
773 | "Validation score (AUC): 0.42857142857142855\n",
774 | "Validation score (AUC): 0.49744897959183676\n",
775 | "Validation score (AUC): 0.5663265306122449\n",
776 | "Validation score (AUC): 0.625\n",
777 | "Validation score (AUC): 0.75\n",
778 | "Validation score (AUC): 0.8239795918367347\n",
779 | "Validation score (AUC): 0.8418367346938775\n",
780 | "Total training time: 0.052 min\n",
781 | "Starting fold 2 ...\n",
782 | "Validation score (AUC): 0.3903061224489796\n",
783 | "Validation score (AUC): 0.45918367346938777\n",
784 | "Validation score (AUC): 0.5790816326530612\n",
785 | "Validation score (AUC): 0.6479591836734694\n",
786 | "Validation score (AUC): 0.7372448979591837\n",
787 | "Validation score (AUC): 0.7716836734693878\n",
788 | "Validation score (AUC): 0.7959183673469388\n",
789 | "Validation score (AUC): 0.778061224489796\n",
790 | "Total training time: 0.050 min\n",
791 | "Starting fold 3 ...\n",
792 | "Validation score (AUC): 0.5102040816326531\n",
793 | "Validation score (AUC): 0.6224489795918368\n",
794 | "Validation score (AUC): 0.7372448979591837\n",
795 | "Validation score (AUC): 0.7857142857142857\n",
796 | "Validation score (AUC): 0.8290816326530613\n",
797 | "Validation score (AUC): 0.8341836734693877\n",
798 | "Validation score (AUC): 0.8392857142857142\n",
799 | "Validation score (AUC): 0.8520408163265307\n",
800 | "Total training time: 0.050 min\n",
801 | "Starting fold 4 ...\n",
802 | "Validation score (AUC): 0.5\n",
803 | "Validation score (AUC): 0.6071428571428572\n",
804 | "Validation score (AUC): 0.7346938775510204\n",
805 | "Validation score (AUC): 0.8010204081632653\n",
806 | "Validation score (AUC): 0.8571428571428571\n",
807 | "Validation score (AUC): 0.8673469387755102\n",
808 | "Validation score (AUC): 0.8367346938775511\n",
809 | "Validation score (AUC): 0.8265306122448979\n",
810 | "Total training time: 0.048 min\n",
811 | "Starting fold 5 ...\n",
812 | "Validation score (AUC): 0.5191326530612245\n",
813 | "Validation score (AUC): 0.6326530612244898\n",
814 | "Validation score (AUC): 0.7295918367346939\n",
815 | "Validation score (AUC): 0.7602040816326532\n",
816 | "Validation score (AUC): 0.7423469387755103\n",
817 | "Validation score (AUC): 0.7474489795918368\n",
818 | "Validation score (AUC): 0.7474489795918368\n",
819 | "Validation score (AUC): 0.788265306122449\n",
820 | "Total training time: 0.047 min\n",
821 | "Starting fold 6 ...\n",
822 | "Validation score (AUC): 0.4948979591836734\n",
823 | "Validation score (AUC): 0.6045918367346939\n",
824 | "Validation score (AUC): 0.6505102040816326\n",
825 | "Validation score (AUC): 0.6658163265306123\n",
826 | "Validation score (AUC): 0.7117346938775511\n",
827 | "Validation score (AUC): 0.6785714285714286\n",
828 | "Validation score (AUC): 0.6913265306122449\n",
829 | "Validation score (AUC): 0.7142857142857143\n",
830 | "Total training time: 0.050 min\n",
831 | "Starting fold 7 ...\n",
832 | "Validation score (AUC): 0.48148148148148145\n",
833 | "Validation score (AUC): 0.6148148148148148\n",
834 | "Validation score (AUC): 0.7555555555555555\n",
835 | "Validation score (AUC): 0.8469135802469135\n",
836 | "Validation score (AUC): 0.851851851851852\n",
837 | "Validation score (AUC): 0.817283950617284\n",
838 | "Validation score (AUC): 0.8\n",
839 | "Validation score (AUC): 0.7950617283950616\n",
840 | "Total training time: 0.046 min\n",
841 | "Starting fold 8 ...\n",
842 | "Validation score (AUC): 0.41481481481481475\n",
843 | "Validation score (AUC): 0.5975308641975308\n",
844 | "Validation score (AUC): 0.6518518518518518\n",
845 | "Validation score (AUC): 0.6790123456790123\n",
846 | "Validation score (AUC): 0.6938271604938272\n",
847 | "Validation score (AUC): 0.7530864197530864\n",
848 | "Validation score (AUC): 0.8\n",
849 | "Validation score (AUC): 0.8469135802469137\n",
850 | "Total training time: 0.047 min\n",
851 | "Starting fold 9 ...\n",
852 | "Validation score (AUC): 0.2888888888888889\n",
853 | "Validation score (AUC): 0.362962962962963\n",
854 | "Validation score (AUC): 0.46172839506172836\n",
855 | "Validation score (AUC): 0.5506172839506173\n",
856 | "Validation score (AUC): 0.6444444444444444\n",
857 | "Validation score (AUC): 0.6839506172839506\n",
858 | "Validation score (AUC): 0.7358024691358025\n",
859 | "Validation score (AUC): 0.7530864197530864\n",
860 | "Total training time: 0.053 min\n",
861 | "Starting fold 10 ...\n",
862 | "Validation score (AUC): 0.4419753086419753\n",
863 | "Validation score (AUC): 0.46419753086419757\n",
864 | "Validation score (AUC): 0.5037037037037037\n",
865 | "Validation score (AUC): 0.528395061728395\n",
866 | "Validation score (AUC): 0.5876543209876544\n",
867 | "Validation score (AUC): 0.7234567901234568\n",
868 | "Validation score (AUC): 0.7901234567901234\n",
869 | "Validation score (AUC): 0.8123456790123456\n",
870 | "Total training time: 0.047 min\n",
871 | "Seed 318\n",
872 | "Starting fold 1 ...\n",
873 | "Validation score (AUC): 0.691326530612245\n",
874 | "Validation score (AUC): 0.7372448979591837\n",
875 | "Validation score (AUC): 0.7346938775510203\n",
876 | "Validation score (AUC): 0.7346938775510204\n",
877 | "Validation score (AUC): 0.7346938775510204\n",
878 | "Validation score (AUC): 0.7346938775510204\n",
879 | "Validation score (AUC): 0.7372448979591837\n",
880 | "Validation score (AUC): 0.7372448979591837\n",
881 | "Total training time: 0.049 min\n",
882 | "Starting fold 2 ...\n",
883 | "Validation score (AUC): 0.7397959183673469\n",
884 | "Validation score (AUC): 0.7040816326530612\n",
885 | "Validation score (AUC): 0.7066326530612245\n",
886 | "Validation score (AUC): 0.7015306122448979\n",
887 | "Validation score (AUC): 0.6989795918367346\n",
888 | "Validation score (AUC): 0.6989795918367346\n",
889 | "Validation score (AUC): 0.6989795918367346\n",
890 | "Validation score (AUC): 0.701530612244898\n",
891 | "Total training time: 0.052 min\n",
892 | "Starting fold 3 ...\n",
893 | "Validation score (AUC): 0.8647959183673468\n",
894 | "Validation score (AUC): 0.8520408163265306\n",
895 | "Validation score (AUC): 0.8469387755102041\n",
896 | "Validation score (AUC): 0.8469387755102041\n",
897 | "Validation score (AUC): 0.8469387755102041\n",
898 | "Validation score (AUC): 0.8469387755102041\n",
899 | "Validation score (AUC): 0.8469387755102041\n",
900 | "Validation score (AUC): 0.8520408163265306\n",
901 | "Total training time: 0.048 min\n",
902 | "Starting fold 4 ...\n",
903 | "Validation score (AUC): 0.7908163265306123\n",
904 | "Validation score (AUC): 0.8188775510204082\n",
905 | "Validation score (AUC): 0.7933673469387755\n",
906 | "Validation score (AUC): 0.7908163265306123\n",
907 | "Validation score (AUC): 0.7908163265306123\n",
908 | "Validation score (AUC): 0.7908163265306123\n",
909 | "Validation score (AUC): 0.7933673469387755\n",
910 | "Validation score (AUC): 0.7908163265306122\n",
911 | "Total training time: 0.051 min\n",
912 | "Starting fold 5 ...\n",
913 | "Validation score (AUC): 0.8418367346938775\n",
914 | "Validation score (AUC): 0.8494897959183673\n",
915 | "Validation score (AUC): 0.8520408163265305\n",
916 | "Validation score (AUC): 0.8558673469387754\n",
917 | "Validation score (AUC): 0.8571428571428572\n",
918 | "Validation score (AUC): 0.8545918367346939\n",
919 | "Validation score (AUC): 0.8520408163265305\n",
920 | "Validation score (AUC): 0.8520408163265306\n",
921 | "Total training time: 0.046 min\n",
922 | "Starting fold 6 ...\n",
923 | "Validation score (AUC): 0.701530612244898\n",
924 | "Validation score (AUC): 0.7066326530612245\n",
925 | "Validation score (AUC): 0.701530612244898\n",
926 | "Validation score (AUC): 0.701530612244898\n",
927 | "Validation score (AUC): 0.701530612244898\n",
928 | "Validation score (AUC): 0.7066326530612245\n",
929 | "Validation score (AUC): 0.7040816326530612\n",
930 | "Validation score (AUC): 0.7066326530612245\n",
931 | "Total training time: 0.047 min\n",
932 | "Starting fold 7 ...\n",
933 | "Validation score (AUC): 0.8222222222222222\n",
934 | "Validation score (AUC): 0.7999999999999999\n",
935 | "Validation score (AUC): 0.7753086419753085\n",
936 | "Validation score (AUC): 0.7679012345679012\n",
937 | "Validation score (AUC): 0.7703703703703703\n",
938 | "Validation score (AUC): 0.7703703703703703\n",
939 | "Validation score (AUC): 0.7753086419753085\n",
940 | "Validation score (AUC): 0.7777777777777778\n",
941 | "Total training time: 0.052 min\n",
942 | "Starting fold 8 ...\n",
943 | "Validation score (AUC): 0.7975308641975309\n",
944 | "Validation score (AUC): 0.8567901234567902\n",
945 | "Validation score (AUC): 0.8592592592592593\n",
946 | "Validation score (AUC): 0.8592592592592593\n",
947 | "Validation score (AUC): 0.8592592592592593\n",
948 | "Validation score (AUC): 0.8592592592592593\n",
949 | "Validation score (AUC): 0.8592592592592593\n",
950 | "Validation score (AUC): 0.8592592592592593\n",
951 | "Total training time: 0.046 min\n",
952 | "Starting fold 9 ...\n",
953 | "Validation score (AUC): 0.7407407407407407\n",
954 | "Validation score (AUC): 0.7530864197530864\n",
955 | "Validation score (AUC): 0.7358024691358025\n",
956 | "Validation score (AUC): 0.7382716049382716\n",
957 | "Validation score (AUC): 0.7382716049382716\n",
958 | "Validation score (AUC): 0.7382716049382716\n",
959 | "Validation score (AUC): 0.7382716049382716\n",
960 | "Validation score (AUC): 0.7407407407407407\n",
961 | "Total training time: 0.048 min\n",
962 | "Starting fold 10 ...\n",
963 | "Validation score (AUC): 0.6913580246913581\n",
964 | "Validation score (AUC): 0.8123456790123457\n",
965 | "Validation score (AUC): 0.8148148148148147\n",
966 | "Validation score (AUC): 0.8148148148148148\n",
967 | "Validation score (AUC): 0.817283950617284\n",
968 | "Validation score (AUC): 0.817283950617284\n",
969 | "Validation score (AUC): 0.817283950617284\n",
970 | "Validation score (AUC): 0.8197530864197531\n",
971 | "Total training time: 0.047 min\n",
972 | "Seed 2001\n",
973 | "Starting fold 1 ...\n",
974 | "Validation score (AUC): 0.7551020408163266\n",
975 | "Validation score (AUC): 0.7525510204081634\n",
976 | "Validation score (AUC): 0.7602040816326531\n",
977 | "Validation score (AUC): 0.798469387755102\n",
978 | "Validation score (AUC): 0.8214285714285715\n",
979 | "Validation score (AUC): 0.8367346938775511\n",
980 | "Validation score (AUC): 0.826530612244898\n",
981 | "Validation score (AUC): 0.8278061224489796\n",
982 | "Total training time: 0.046 min\n",
983 | "Starting fold 2 ...\n",
984 | "Validation score (AUC): 0.7091836734693877\n",
985 | "Validation score (AUC): 0.7193877551020409\n",
986 | "Validation score (AUC): 0.7117346938775511\n",
987 | "Validation score (AUC): 0.7117346938775511\n",
988 | "Validation score (AUC): 0.7244897959183673\n",
989 | "Validation score (AUC): 0.7219387755102041\n",
990 | "Validation score (AUC): 0.7244897959183675\n",
991 | "Validation score (AUC): 0.7219387755102041\n",
992 | "Total training time: 0.055 min\n",
993 | "Starting fold 3 ...\n",
994 | "Validation score (AUC): 0.8341836734693878\n",
995 | "Validation score (AUC): 0.8596938775510203\n",
996 | "Validation score (AUC): 0.8418367346938775\n",
997 | "Validation score (AUC): 0.8290816326530612\n",
998 | "Validation score (AUC): 0.8112244897959184\n",
999 | "Validation score (AUC): 0.8086734693877551\n",
1000 | "Validation score (AUC): 0.8112244897959184\n",
1001 | "Validation score (AUC): 0.8188775510204082\n",
1002 | "Total training time: 0.048 min\n",
1003 | "Starting fold 4 ...\n",
1004 | "Validation score (AUC): 0.8418367346938775\n",
1005 | "Validation score (AUC): 0.8341836734693877\n",
1006 | "Validation score (AUC): 0.8188775510204082\n",
1007 | "Validation score (AUC): 0.826530612244898\n",
1008 | "Validation score (AUC): 0.8392857142857143\n",
1009 | "Validation score (AUC): 0.8469387755102042\n",
1010 | "Validation score (AUC): 0.8520408163265306\n",
1011 | "Validation score (AUC): 0.8571428571428572\n",
1012 | "Total training time: 0.048 min\n",
1013 | "Starting fold 5 ...\n",
1014 | "Validation score (AUC): 0.7780612244897959\n",
1015 | "Validation score (AUC): 0.7959183673469388\n",
1016 | "Validation score (AUC): 0.8520408163265306\n",
1017 | "Validation score (AUC): 0.8673469387755102\n",
1018 | "Validation score (AUC): 0.8622448979591837\n",
1019 | "Validation score (AUC): 0.8596938775510204\n",
1020 | "Validation score (AUC): 0.8545918367346939\n",
1021 | "Validation score (AUC): 0.8571428571428571\n",
1022 | "Total training time: 0.051 min\n",
1023 | "Starting fold 6 ...\n",
1024 | "Validation score (AUC): 0.7716836734693877\n",
1025 | "Validation score (AUC): 0.7806122448979592\n",
1026 | "Validation score (AUC): 0.7602040816326531\n",
1027 | "Validation score (AUC): 0.7448979591836735\n",
1028 | "Validation score (AUC): 0.7602040816326531\n",
1029 | "Validation score (AUC): 0.7602040816326532\n",
1030 | "Validation score (AUC): 0.7602040816326531\n",
1031 | "Validation score (AUC): 0.7678571428571429\n",
1032 | "Total training time: 0.046 min\n",
1033 | "Starting fold 7 ...\n",
1034 | "Validation score (AUC): 0.8222222222222222\n",
1035 | "Validation score (AUC): 0.8074074074074074\n",
1036 | "Validation score (AUC): 0.7851851851851851\n",
1037 | "Validation score (AUC): 0.7925925925925926\n",
1038 | "Validation score (AUC): 0.8024691358024691\n",
1039 | "Validation score (AUC): 0.817283950617284\n",
1040 | "Validation score (AUC): 0.8246913580246913\n",
1041 | "Validation score (AUC): 0.8222222222222222\n",
1042 | "Total training time: 0.049 min\n",
1043 | "Starting fold 8 ...\n",
1044 | "Validation score (AUC): 0.8493827160493828\n",
1045 | "Validation score (AUC): 0.9111111111111111\n",
1046 | "Validation score (AUC): 0.9111111111111111\n",
1047 | "Validation score (AUC): 0.9111111111111111\n",
1048 | "Validation score (AUC): 0.9407407407407407\n",
1049 | "Validation score (AUC): 0.9432098765432099\n",
1050 | "Validation score (AUC): 0.9481481481481482\n",
1051 | "Validation score (AUC): 0.9407407407407407\n",
1052 | "Total training time: 0.052 min\n",
1053 | "Starting fold 9 ...\n",
1054 | "Validation score (AUC): 0.7530864197530864\n",
1055 | "Validation score (AUC): 0.7679012345679013\n",
1056 | "Validation score (AUC): 0.7950617283950617\n",
1057 | "Validation score (AUC): 0.7802469135802469\n",
1058 | "Validation score (AUC): 0.8\n",
1059 | "Validation score (AUC): 0.8098765432098766\n",
1060 | "Validation score (AUC): 0.7975308641975308\n",
1061 | "Validation score (AUC): 0.8024691358024691\n",
1062 | "Total training time: 0.047 min\n",
1063 | "Starting fold 10 ...\n",
1064 | "Validation score (AUC): 0.837037037037037\n",
1065 | "Validation score (AUC): 0.8469135802469135\n",
1066 | "Validation score (AUC): 0.8493827160493828\n",
1067 | "Validation score (AUC): 0.8567901234567901\n",
1068 | "Validation score (AUC): 0.8567901234567901\n",
1069 | "Validation score (AUC): 0.8567901234567901\n",
1070 | "Validation score (AUC): 0.8641975308641976\n",
1071 | "Validation score (AUC): 0.8691358024691358\n",
1072 | "Total training time: 0.051 min\n",
1073 | "Seed 1976\n",
1074 | "Starting fold 1 ...\n",
1075 | "Validation score (AUC): 0.19642857142857142\n",
1076 | "Validation score (AUC): 0.3979591836734694\n",
1077 | "Validation score (AUC): 0.7066326530612245\n",
1078 | "Validation score (AUC): 0.7653061224489796\n",
1079 | "Validation score (AUC): 0.8163265306122449\n",
1080 | "Validation score (AUC): 0.8035714285714286\n",
1081 | "Validation score (AUC): 0.8112244897959183\n",
1082 | "Validation score (AUC): 0.8061224489795918\n",
1083 | "Total training time: 0.046 min\n",
1084 | "Starting fold 2 ...\n",
1085 | "Validation score (AUC): 0.29974489795918363\n",
1086 | "Validation score (AUC): 0.451530612244898\n",
1087 | "Validation score (AUC): 0.6683673469387755\n",
1088 | "Validation score (AUC): 0.7193877551020408\n",
1089 | "Validation score (AUC): 0.7321428571428571\n",
1090 | "Validation score (AUC): 0.7397959183673469\n",
1091 | "Validation score (AUC): 0.7372448979591837\n",
1092 | "Validation score (AUC): 0.7372448979591837\n",
1093 | "Total training time: 0.047 min\n",
1094 | "Starting fold 3 ...\n",
1095 | "Validation score (AUC): 0.2755102040816326\n",
1096 | "Validation score (AUC): 0.4872448979591837\n",
1097 | "Validation score (AUC): 0.7219387755102041\n",
1098 | "Validation score (AUC): 0.7627551020408163\n",
1099 | "Validation score (AUC): 0.7933673469387755\n",
1100 | "Validation score (AUC): 0.8214285714285714\n",
1101 | "Validation score (AUC): 0.8392857142857142\n",
1102 | "Validation score (AUC): 0.846938775510204\n",
1103 | "Total training time: 0.053 min\n",
1104 | "Starting fold 4 ...\n",
1105 | "Validation score (AUC): 0.32142857142857145\n",
1106 | "Validation score (AUC): 0.6045918367346939\n",
1107 | "Validation score (AUC): 0.8660714285714286\n",
1108 | "Validation score (AUC): 0.8622448979591838\n",
1109 | "Validation score (AUC): 0.8545918367346939\n",
1110 | "Validation score (AUC): 0.8418367346938774\n",
1111 | "Validation score (AUC): 0.826530612244898\n",
1112 | "Validation score (AUC): 0.8290816326530612\n",
1113 | "Total training time: 0.046 min\n",
1114 | "Starting fold 5 ...\n",
1115 | "Validation score (AUC): 0.2576530612244898\n",
1116 | "Validation score (AUC): 0.4362244897959183\n",
1117 | "Validation score (AUC): 0.8214285714285714\n",
1118 | "Validation score (AUC): 0.8418367346938775\n",
1119 | "Validation score (AUC): 0.864795918367347\n",
1120 | "Validation score (AUC): 0.8596938775510204\n",
1121 | "Validation score (AUC): 0.8571428571428571\n",
1122 | "Validation score (AUC): 0.8596938775510203\n",
1123 | "Total training time: 0.049 min\n",
1124 | "Starting fold 6 ...\n",
1125 | "Validation score (AUC): 0.3622448979591837\n",
1126 | "Validation score (AUC): 0.5816326530612245\n",
1127 | "Validation score (AUC): 0.7576530612244898\n",
1128 | "Validation score (AUC): 0.7602040816326531\n",
1129 | "Validation score (AUC): 0.7678571428571429\n",
1130 | "Validation score (AUC): 0.7627551020408162\n",
1131 | "Validation score (AUC): 0.7538265306122449\n",
1132 | "Validation score (AUC): 0.760204081632653\n",
1133 | "Total training time: 0.048 min\n",
1134 | "Starting fold 7 ...\n",
1135 | "Validation score (AUC): 0.24074074074074076\n",
1136 | "Validation score (AUC): 0.5432098765432098\n",
1137 | "Validation score (AUC): 0.8444444444444444\n",
1138 | "Validation score (AUC): 0.8740740740740741\n",
1139 | "Validation score (AUC): 0.8617283950617284\n",
1140 | "Validation score (AUC): 0.8395061728395062\n",
1141 | "Validation score (AUC): 0.8320987654320987\n",
1142 | "Validation score (AUC): 0.8320987654320987\n",
1143 | "Total training time: 0.047 min\n",
1144 | "Starting fold 8 ...\n",
1145 | "Validation score (AUC): 0.25185185185185177\n",
1146 | "Validation score (AUC): 0.3851851851851852\n",
1147 | "Validation score (AUC): 0.8074074074074072\n",
1148 | "Validation score (AUC): 0.8790123456790123\n",
1149 | "Validation score (AUC): 0.9160493827160494\n",
1150 | "Validation score (AUC): 0.928395061728395\n",
1151 | "Validation score (AUC): 0.9308641975308641\n",
1152 | "Validation score (AUC): 0.9234567901234567\n",
1153 | "Total training time: 0.052 min\n",
1154 | "Starting fold 9 ...\n",
1155 | "Validation score (AUC): 0.2716049382716049\n",
1156 | "Validation score (AUC): 0.4493827160493827\n",
1157 | "Validation score (AUC): 0.6691358024691358\n",
1158 | "Validation score (AUC): 0.7407407407407407\n",
1159 | "Validation score (AUC): 0.7753086419753087\n",
1160 | "Validation score (AUC): 0.8\n",
1161 | "Validation score (AUC): 0.7728395061728395\n",
1162 | "Validation score (AUC): 0.7777777777777778\n",
1163 | "Total training time: 0.046 min\n",
1164 | "Starting fold 10 ...\n",
1165 | "Validation score (AUC): 0.31851851851851853\n",
1166 | "Validation score (AUC): 0.6691358024691358\n",
1167 | "Validation score (AUC): 0.8049382716049382\n",
1168 | "Validation score (AUC): 0.8395061728395062\n",
1169 | "Validation score (AUC): 0.8444444444444444\n",
1170 | "Validation score (AUC): 0.854320987654321\n",
1171 | "Validation score (AUC): 0.837037037037037\n",
1172 | "Validation score (AUC): 0.8469135802469135\n",
1173 | "Total training time: 0.048 min\n",
1174 | "Seed 1966\n",
1175 | "Starting fold 1 ...\n",
1176 | "Validation score (AUC): 0.4591836734693877\n",
1177 | "Validation score (AUC): 0.5612244897959183\n",
1178 | "Validation score (AUC): 0.7678571428571428\n",
1179 | "Validation score (AUC): 0.8367346938775511\n",
1180 | "Validation score (AUC): 0.8392857142857143\n",
1181 | "Validation score (AUC): 0.8443877551020408\n",
1182 | "Validation score (AUC): 0.8469387755102041\n",
1183 | "Validation score (AUC): 0.8418367346938777\n",
1184 | "Total training time: 0.050 min\n",
1185 | "Starting fold 2 ...\n",
1186 | "Validation score (AUC): 0.5586734693877551\n",
1187 | "Validation score (AUC): 0.6275510204081632\n",
1188 | "Validation score (AUC): 0.6989795918367346\n",
1189 | "Validation score (AUC): 0.7423469387755102\n",
1190 | "Validation score (AUC): 0.7551020408163265\n",
1191 | "Validation score (AUC): 0.7704081632653061\n",
1192 | "Validation score (AUC): 0.7704081632653061\n",
1193 | "Validation score (AUC): 0.7551020408163265\n",
1194 | "Total training time: 0.046 min\n",
1195 | "Starting fold 3 ...\n",
1196 | "Validation score (AUC): 0.7168367346938777\n",
1197 | "Validation score (AUC): 0.7908163265306123\n",
1198 | "Validation score (AUC): 0.8214285714285714\n",
1199 | "Validation score (AUC): 0.8112244897959184\n",
1200 | "Validation score (AUC): 0.7959183673469388\n",
1201 | "Validation score (AUC): 0.8061224489795918\n",
1202 | "Validation score (AUC): 0.8112244897959183\n",
1203 | "Validation score (AUC): 0.8188775510204082\n",
1204 | "Total training time: 0.055 min\n",
1205 | "Starting fold 4 ...\n",
1206 | "Validation score (AUC): 0.6556122448979592\n",
1207 | "Validation score (AUC): 0.7806122448979592\n",
1208 | "Validation score (AUC): 0.8724489795918366\n",
1209 | "Validation score (AUC): 0.8801020408163266\n",
1210 | "Validation score (AUC): 0.8673469387755102\n",
1211 | "Validation score (AUC): 0.8596938775510204\n",
1212 | "Validation score (AUC): 0.8673469387755102\n",
1213 | "Validation score (AUC): 0.8698979591836735\n",
1214 | "Total training time: 0.048 min\n",
1215 | "Starting fold 5 ...\n",
1216 | "Validation score (AUC): 0.75\n",
1217 | "Validation score (AUC): 0.778061224489796\n",
1218 | "Validation score (AUC): 0.7908163265306122\n",
1219 | "Validation score (AUC): 0.788265306122449\n",
1220 | "Validation score (AUC): 0.7678571428571429\n",
1221 | "Validation score (AUC): 0.7857142857142857\n",
1222 | "Validation score (AUC): 0.7908163265306123\n",
1223 | "Validation score (AUC): 0.8010204081632653\n",
1224 | "Total training time: 0.048 min\n",
1225 | "Starting fold 6 ...\n",
1226 | "Validation score (AUC): 0.5841836734693878\n",
1227 | "Validation score (AUC): 0.6913265306122449\n",
1228 | "Validation score (AUC): 0.7346938775510206\n",
1229 | "Validation score (AUC): 0.7448979591836735\n",
1230 | "Validation score (AUC): 0.7270408163265306\n",
1231 | "Validation score (AUC): 0.7372448979591836\n",
1232 | "Validation score (AUC): 0.7627551020408164\n",
1233 | "Validation score (AUC): 0.7653061224489797\n",
1234 | "Total training time: 0.051 min\n",
1235 | "Starting fold 7 ...\n",
1236 | "Validation score (AUC): 0.6987654320987654\n",
1237 | "Validation score (AUC): 0.7851851851851852\n",
1238 | "Validation score (AUC): 0.8592592592592592\n",
1239 | "Validation score (AUC): 0.8148148148148149\n",
1240 | "Validation score (AUC): 0.7950617283950617\n",
1241 | "Validation score (AUC): 0.8024691358024691\n",
1242 | "Validation score (AUC): 0.8074074074074075\n",
1243 | "Validation score (AUC): 0.8197530864197531\n",
1244 | "Total training time: 0.046 min\n",
1245 | "Starting fold 8 ...\n",
1246 | "Validation score (AUC): 0.6222222222222222\n",
1247 | "Validation score (AUC): 0.6617283950617284\n",
1248 | "Validation score (AUC): 0.7259259259259259\n",
1249 | "Validation score (AUC): 0.8469135802469135\n",
1250 | "Validation score (AUC): 0.9037037037037037\n",
1251 | "Validation score (AUC): 0.8987654320987654\n",
1252 | "Validation score (AUC): 0.9012345679012346\n",
1253 | "Validation score (AUC): 0.9061728395061728\n",
1254 | "Total training time: 0.050 min\n",
1255 | "Starting fold 9 ...\n",
1256 | "Validation score (AUC): 0.5135802469135802\n",
1257 | "Validation score (AUC): 0.6098765432098765\n",
1258 | "Validation score (AUC): 0.7185185185185186\n",
1259 | "Validation score (AUC): 0.7950617283950616\n",
1260 | "Validation score (AUC): 0.8222222222222222\n",
1261 | "Validation score (AUC): 0.8320987654320988\n",
1262 | "Validation score (AUC): 0.8345679012345679\n",
1263 | "Validation score (AUC): 0.8296296296296296\n",
1264 | "Total training time: 0.049 min\n",
1265 | "Starting fold 10 ...\n",
1266 | "Validation score (AUC): 0.5135802469135802\n",
1267 | "Validation score (AUC): 0.6148148148148148\n",
1268 | "Validation score (AUC): 0.711111111111111\n",
1269 | "Validation score (AUC): 0.8222222222222222\n",
1270 | "Validation score (AUC): 0.8246913580246913\n",
1271 | "Validation score (AUC): 0.8320987654320987\n",
1272 | "Validation score (AUC): 0.8345679012345679\n",
1273 | "Validation score (AUC): 0.837037037037037\n",
1274 | "Total training time: 0.046 min\n",
1275 | "Seed 1994\n",
1276 | "Starting fold 1 ...\n",
1277 | "Validation score (AUC): 0.6951530612244897\n",
1278 | "Validation score (AUC): 0.7295918367346939\n",
1279 | "Validation score (AUC): 0.7704081632653061\n",
1280 | "Validation score (AUC): 0.8061224489795918\n",
1281 | "Validation score (AUC): 0.8163265306122449\n",
1282 | "Validation score (AUC): 0.8239795918367347\n",
1283 | "Validation score (AUC): 0.854591836734694\n",
1284 | "Validation score (AUC): 0.8571428571428571\n",
1285 | "Total training time: 0.052 min\n",
1286 | "Starting fold 2 ...\n",
1287 | "Validation score (AUC): 0.6709183673469388\n",
1288 | "Validation score (AUC): 0.7040816326530612\n",
1289 | "Validation score (AUC): 0.7193877551020408\n",
1290 | "Validation score (AUC): 0.7321428571428571\n",
1291 | "Validation score (AUC): 0.7219387755102041\n",
1292 | "Validation score (AUC): 0.7040816326530612\n",
1293 | "Validation score (AUC): 0.7244897959183674\n",
1294 | "Validation score (AUC): 0.7321428571428572\n",
1295 | "Total training time: 0.046 min\n",
1296 | "Starting fold 3 ...\n",
1297 | "Validation score (AUC): 0.5255102040816327\n",
1298 | "Validation score (AUC): 0.5688775510204082\n",
1299 | "Validation score (AUC): 0.6377551020408163\n",
1300 | "Validation score (AUC): 0.6938775510204082\n",
1301 | "Validation score (AUC): 0.7142857142857143\n",
1302 | "Validation score (AUC): 0.7423469387755102\n",
1303 | "Validation score (AUC): 0.7653061224489797\n",
1304 | "Validation score (AUC): 0.7755102040816326\n",
1305 | "Total training time: 0.052 min\n",
1306 | "Starting fold 4 ...\n",
1307 | "Validation score (AUC): 0.6454081632653061\n",
1308 | "Validation score (AUC): 0.6709183673469388\n",
1309 | "Validation score (AUC): 0.7295918367346939\n",
1310 | "Validation score (AUC): 0.8086734693877551\n",
1311 | "Validation score (AUC): 0.8520408163265306\n",
1312 | "Validation score (AUC): 0.8520408163265307\n",
1313 | "Validation score (AUC): 0.8622448979591837\n",
1314 | "Validation score (AUC): 0.8673469387755104\n",
1315 | "Total training time: 0.052 min\n",
1316 | "Starting fold 5 ...\n",
1317 | "Validation score (AUC): 0.43367346938775514\n",
1318 | "Validation score (AUC): 0.4897959183673469\n",
1319 | "Validation score (AUC): 0.5586734693877551\n",
1320 | "Validation score (AUC): 0.639030612244898\n",
1321 | "Validation score (AUC): 0.6862244897959184\n",
1322 | "Validation score (AUC): 0.7270408163265305\n",
1323 | "Validation score (AUC): 0.7448979591836735\n",
1324 | "Validation score (AUC): 0.7704081632653061\n",
1325 | "Total training time: 0.047 min\n",
1326 | "Starting fold 6 ...\n",
1327 | "Validation score (AUC): 0.5816326530612246\n",
1328 | "Validation score (AUC): 0.6147959183673469\n",
1329 | "Validation score (AUC): 0.6479591836734694\n",
1330 | "Validation score (AUC): 0.6785714285714286\n",
1331 | "Validation score (AUC): 0.6989795918367347\n",
1332 | "Validation score (AUC): 0.7117346938775511\n",
1333 | "Validation score (AUC): 0.7270408163265306\n",
1334 | "Validation score (AUC): 0.7295918367346939\n",
1335 | "Total training time: 0.052 min\n",
1336 | "Starting fold 7 ...\n",
1337 | "Validation score (AUC): 0.5753086419753086\n",
1338 | "Validation score (AUC): 0.6296296296296297\n",
1339 | "Validation score (AUC): 0.674074074074074\n",
1340 | "Validation score (AUC): 0.7061728395061728\n",
1341 | "Validation score (AUC): 0.7259259259259259\n",
1342 | "Validation score (AUC): 0.7555555555555555\n",
1343 | "Validation score (AUC): 0.7580246913580246\n",
1344 | "Validation score (AUC): 0.782716049382716\n",
1345 | "Total training time: 0.048 min\n",
1346 | "Starting fold 8 ...\n",
1347 | "Validation score (AUC): 0.6641975308641975\n",
1348 | "Validation score (AUC): 0.688888888888889\n",
1349 | "Validation score (AUC): 0.7308641975308642\n",
1350 | "Validation score (AUC): 0.7876543209876542\n",
1351 | "Validation score (AUC): 0.8197530864197531\n",
1352 | "Validation score (AUC): 0.8320987654320988\n",
1353 | "Validation score (AUC): 0.8493827160493826\n",
1354 | "Validation score (AUC): 0.854320987654321\n",
1355 | "Total training time: 0.049 min\n",
1356 | "Starting fold 9 ...\n",
1357 | "Validation score (AUC): 0.7901234567901234\n",
1358 | "Validation score (AUC): 0.8123456790123458\n",
1359 | "Validation score (AUC): 0.8444444444444443\n",
1360 | "Validation score (AUC): 0.8666666666666666\n",
1361 | "Validation score (AUC): 0.8666666666666667\n",
1362 | "Validation score (AUC): 0.8716049382716049\n",
1363 | "Validation score (AUC): 0.8641975308641975\n",
1364 | "Validation score (AUC): 0.8814814814814814\n",
1365 | "Total training time: 0.052 min\n",
1366 | "Starting fold 10 ...\n",
1367 | "Validation score (AUC): 0.7407407407407407\n",
1368 | "Validation score (AUC): 0.7876543209876543\n",
1369 | "Validation score (AUC): 0.8320987654320988\n",
1370 | "Validation score (AUC): 0.8419753086419752\n",
1371 | "Validation score (AUC): 0.854320987654321\n",
1372 | "Validation score (AUC): 0.8567901234567901\n",
1373 | "Validation score (AUC): 0.8617283950617284\n",
1374 | "Validation score (AUC): 0.8567901234567902\n",
1375 | "Total training time: 0.046 min\n",
1376 | "Seed 252\n",
1377 | "Starting fold 1 ...\n",
1378 | "Validation score (AUC): 0.75\n",
1379 | "Validation score (AUC): 0.7448979591836735\n",
1380 | "Validation score (AUC): 0.7780612244897959\n",
1381 | "Validation score (AUC): 0.8341836734693878\n",
1382 | "Validation score (AUC): 0.8418367346938775\n",
1383 | "Validation score (AUC): 0.8571428571428572\n",
1384 | "Validation score (AUC): 0.8520408163265306\n",
1385 | "Validation score (AUC): 0.8622448979591838\n",
1386 | "Total training time: 0.049 min\n",
1387 | "Starting fold 2 ...\n",
1388 | "Validation score (AUC): 0.7142857142857142\n",
1389 | "Validation score (AUC): 0.6989795918367347\n",
1390 | "Validation score (AUC): 0.6964285714285714\n",
1391 | "Validation score (AUC): 0.7193877551020409\n",
1392 | "Validation score (AUC): 0.7448979591836734\n",
1393 | "Validation score (AUC): 0.7448979591836735\n",
1394 | "Validation score (AUC): 0.7423469387755102\n",
1395 | "Validation score (AUC): 0.7397959183673469\n",
1396 | "Total training time: 0.050 min\n",
1397 | "Starting fold 3 ...\n",
1398 | "Validation score (AUC): 0.7091836734693878\n",
1399 | "Validation score (AUC): 0.7295918367346939\n",
1400 | "Validation score (AUC): 0.7372448979591838\n",
1401 | "Validation score (AUC): 0.7755102040816326\n",
1402 | "Validation score (AUC): 0.8035714285714285\n",
1403 | "Validation score (AUC): 0.798469387755102\n",
1404 | "Validation score (AUC): 0.8163265306122449\n",
1405 | "Validation score (AUC): 0.8137755102040816\n",
1406 | "Total training time: 0.051 min\n",
1407 | "Starting fold 4 ...\n",
1408 | "Validation score (AUC): 0.8673469387755102\n",
1409 | "Validation score (AUC): 0.8622448979591837\n",
1410 | "Validation score (AUC): 0.8494897959183674\n",
1411 | "Validation score (AUC): 0.8494897959183674\n",
1412 | "Validation score (AUC): 0.8571428571428572\n",
1413 | "Validation score (AUC): 0.8673469387755102\n",
1414 | "Validation score (AUC): 0.8647959183673469\n",
1415 | "Validation score (AUC): 0.8673469387755103\n",
1416 | "Total training time: 0.052 min\n",
1417 | "Starting fold 5 ...\n",
1418 | "Validation score (AUC): 0.673469387755102\n",
1419 | "Validation score (AUC): 0.7270408163265306\n",
1420 | "Validation score (AUC): 0.7525510204081632\n",
1421 | "Validation score (AUC): 0.7716836734693878\n",
1422 | "Validation score (AUC): 0.7959183673469388\n",
1423 | "Validation score (AUC): 0.8035714285714286\n",
1424 | "Validation score (AUC): 0.8239795918367347\n",
1425 | "Validation score (AUC): 0.8137755102040816\n",
1426 | "Total training time: 0.047 min\n",
1427 | "Starting fold 6 ...\n",
1428 | "Validation score (AUC): 0.6938775510204082\n",
1429 | "Validation score (AUC): 0.6989795918367346\n",
1430 | "Validation score (AUC): 0.7117346938775511\n",
1431 | "Validation score (AUC): 0.7193877551020409\n",
1432 | "Validation score (AUC): 0.7295918367346939\n",
1433 | "Validation score (AUC): 0.7448979591836734\n",
1434 | "Validation score (AUC): 0.75\n",
1435 | "Validation score (AUC): 0.7474489795918368\n",
1436 | "Total training time: 0.049 min\n",
1437 | "Starting fold 7 ...\n",
1438 | "Validation score (AUC): 0.7382716049382716\n",
1439 | "Validation score (AUC): 0.7654320987654321\n",
1440 | "Validation score (AUC): 0.7703703703703703\n",
1441 | "Validation score (AUC): 0.8148148148148148\n",
1442 | "Validation score (AUC): 0.8197530864197531\n",
1443 | "Validation score (AUC): 0.8345679012345678\n",
1444 | "Validation score (AUC): 0.8296296296296296\n",
1445 | "Validation score (AUC): 0.8345679012345678\n",
1446 | "Total training time: 0.051 min\n",
1447 | "Starting fold 8 ...\n",
1448 | "Validation score (AUC): 0.8049382716049382\n",
1449 | "Validation score (AUC): 0.8123456790123457\n",
1450 | "Validation score (AUC): 0.8395061728395061\n",
1451 | "Validation score (AUC): 0.8666666666666667\n",
1452 | "Validation score (AUC): 0.8814814814814815\n",
1453 | "Validation score (AUC): 0.8987654320987655\n",
1454 | "Validation score (AUC): 0.9160493827160494\n",
1455 | "Validation score (AUC): 0.9209876543209876\n",
1456 | "Total training time: 0.047 min\n",
1457 | "Starting fold 9 ...\n",
1458 | "Validation score (AUC): 0.834567901234568\n",
1459 | "Validation score (AUC): 0.837037037037037\n",
1460 | "Validation score (AUC): 0.8592592592592593\n",
1461 | "Validation score (AUC): 0.8567901234567902\n",
1462 | "Validation score (AUC): 0.8790123456790123\n",
1463 | "Validation score (AUC): 0.8691358024691358\n",
1464 | "Validation score (AUC): 0.8617283950617284\n",
1465 | "Validation score (AUC): 0.8518518518518519\n",
1466 | "Total training time: 0.050 min\n",
1467 | "Starting fold 10 ...\n",
1468 | "Validation score (AUC): 0.8271604938271605\n",
1469 | "Validation score (AUC): 0.8395061728395062\n",
1470 | "Validation score (AUC): 0.8395061728395062\n",
1471 | "Validation score (AUC): 0.8592592592592592\n",
1472 | "Validation score (AUC): 0.8518518518518519\n",
1473 | "Validation score (AUC): 0.8691358024691358\n",
1474 | "Validation score (AUC): 0.8567901234567901\n",
1475 | "Validation score (AUC): 0.8567901234567901\n",
1476 | "Total training time: 0.049 min\n"
1477 | ]
1478 | }
1479 | ],
1480 | "source": [
1481 | "SEED_PREDS = 0\n",
1482 | "SEED_SCORES = []\n",
1483 | "\n",
1484 | "for seed in SEEDS:\n",
1485 | "    print('Seed', seed)\n",
1486 | "\n",
1487 | "    val_scores = []\n",
1488 | "    Y_SCORES = 0\n",
1489 | "    \n",
1490 | "    for fold in range(1, 11):\n",
1491 | "        print(f'Starting fold {fold} ...')\n",
1492 | "        df_train = pd.read_csv(OUTPUT_PATH / f'df_train_fold_{fold}.csv')\n",
1493 | "        df_val = pd.read_csv(OUTPUT_PATH / f'df_val_fold_{fold}.csv')\n",
1494 | "        test = pd.read_csv(OUTPUT_PATH / f'test_fold_{fold}.csv')\n",
1495 | "\n",
1496 | "        train_dataloader = get_dataloader(\n",
1497 | "            df_train, \n",
1498 | "            feature_names=ORIGINAL_FEATURES, \n",
1499 | "            target_name=TARGET,\n",
1500 | "            batch_size=BATCH_SIZE,\n",
1501 | "            mode='train'\n",
1502 | "        )\n",
1503 | "\n",
1504 | "        val_dataloader = get_dataloader(\n",
1505 | "            df_val, \n",
1506 | "            feature_names=ORIGINAL_FEATURES, \n",
1507 | "            target_name=TARGET,\n",
1508 | "            batch_size=BATCH_SIZE,\n",
1509 | "            mode='val'\n",
1510 | "        )\n",
1511 | "\n",
1512 | "        test_dataloader = get_dataloader(\n",
1513 | "            test, \n",
1514 | "            feature_names=ORIGINAL_FEATURES, \n",
1515 | "            target_name=TARGET,\n",
1516 | "            batch_size=BATCH_SIZE,\n",
1517 | "            mode='test'\n",
1518 | "        )\n",
1519 | "\n",
1520 | "        torch.manual_seed(seed)\n",
1521 | "        model = Net(\n",
1522 | "            num_features=NUM_FEATURES, \n",
1523 | "            dense_units=8,\n",
1524 | "            hidden_size=16,\n",
1525 | "            dropout=0.05\n",
1526 | "        )\n",
1527 | "\n",
1528 | "        model.to(DEVICE)\n",
1529 | "\n",
1530 | "        model.train()\n",
1531 | "\n",
1532 | "        optimizer = optim.AdamW(\n",
1533 | "            model.parameters(), \n",
1534 | "            lr=LEARNING_RATE,\n",
1535 | "            weight_decay=WEIGHT_DECAY\n",
1536 | "        )\n",
1537 | "\n",
1538 | "        model, val_score = fit(model=model,\n",
1539 | "                               optimizer=optimizer,\n",
1540 | "                               scheduler=None,\n",
1541 | "                               epochs=N_EPOCHS,\n",
1542 | "                               train_dataloader=train_dataloader,\n",
1543 | "                               val_dataloader=val_dataloader)\n",
1544 | "\n",
1545 | "        val_scores.append(val_score)\n",
1546 | "\n",
1547 | "        y_scores = torch.tensor([])\n",
1548 | "        with torch.inference_mode():\n",
1549 | "            model.eval()\n",
1550 | "            for batch_idx, (features, targets) in enumerate(test_dataloader):\n",
1551 | "                features = features.to(DEVICE)\n",
1552 | "                with autocast():\n",
1553 | "                    logits = model(features).detach().cpu().type(torch.float)\n",
1554 | "                probs = F.softmax(logits, dim=-1)[:, 1]\n",
1555 | "                y_scores = torch.cat([y_scores, probs])\n",
1556 | "\n",
1557 | "        Y_SCORES += y_scores / 10\n",
1558 | "    \n",
1559 | "    SEED_PREDS += Y_SCORES / len(SEEDS)\n",
1560 | "    SEED_SCORES.append(np.mean(val_scores))"
1561 | ]
1562 | },
1563 | {
1564 | "cell_type": "code",
1565 | "execution_count": 12,
1566 | "id": "a8c664c5",
1567 | "metadata": {
1568 | "execution": {
1569 | "iopub.execute_input": "2024-01-05T17:19:00.712075Z",
1570 | "iopub.status.busy": "2024-01-05T17:19:00.711736Z",
1571 | "iopub.status.idle": "2024-01-05T17:19:00.720775Z",
1572 | "shell.execute_reply": "2024-01-05T17:19:00.719820Z"
1573 | },
1574 | "papermill": {
1575 | "duration": 0.071775,
1576 | "end_time": "2024-01-05T17:19:00.722771",
1577 | "exception": false,
1578 | "start_time": "2024-01-05T17:19:00.650996",
1579 | "status": "completed"
1580 | },
1581 | "tags": []
1582 | },
1583 | "outputs": [
1584 | {
1585 | "name": "stdout",
1586 | "output_type": "stream",
1587 | "text": [
1588 | "0.8171630527210885\n",
1589 | "0.016486137137366824\n"
1590 | ]
1591 | },
1592 | {
1593 | "data": {
1594 | "text/plain": [
1595 | "[0.8361243386243388,\n",
1596 | " 0.8008427815570673,\n",
1597 | " 0.7837836986646509,\n",
1598 | " 0.8285333207357016,\n",
1599 | " 0.8219532627865961,\n",
1600 | " 0.8244633408919124,\n",
1601 | " 0.8107451499118167,\n",
1602 | " 0.830858528596624]"
1603 | ]
1604 | },
1605 | "execution_count": 12,
1606 | "metadata": {},
1607 | "output_type": "execute_result"
1608 | }
1609 | ],
1610 | "source": [
1611 | "print(np.mean(SEED_SCORES))\n",
1612 | "print(np.std(SEED_SCORES))\n",
1613 | "\n",
1614 | "SEED_SCORES"
1615 | ]
1616 | },
1617 | {
1618 | "cell_type": "code",
1619 | "execution_count": 13,
1620 | "id": "050de0df",
1621 | "metadata": {
1622 | "execution": {
1623 | "iopub.execute_input": "2024-01-05T17:19:00.843891Z",
1624 | "iopub.status.busy": "2024-01-05T17:19:00.843061Z",
1625 | "iopub.status.idle": "2024-01-05T17:19:00.866928Z",
1626 | "shell.execute_reply": "2024-01-05T17:19:00.866068Z"
1627 | },
1628 | "papermill": {
1629 | "duration": 0.086833,
1630 | "end_time": "2024-01-05T17:19:00.869048",
1631 | "exception": false,
1632 | "start_time": "2024-01-05T17:19:00.782215",
1633 | "status": "completed"
1634 | },
1635 | "tags": []
1636 | },
1637 | "outputs": [
1638 | {
1639 | "data": {
1640 | "text/html": [
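1641 | "<div>\n",
1642 | "<style scoped>\n",
1643 | "    .dataframe tbody tr th:only-of-type {\n",
1644 | "        vertical-align: middle;\n",
1645 | "    }\n",
1646 | "\n",
1647 | "    .dataframe tbody tr th {\n",
1648 | "        vertical-align: top;\n",
1649 | "    }\n",
1650 | "\n",
1651 | "    .dataframe thead th {\n",
1652 | "        text-align: right;\n",
1653 | "    }\n",
1654 | "</style>\n",
1655 | "<table border=\"1\" class=\"dataframe\">\n",
1656 | "  <thead>\n",
1657 | "    <tr style=\"text-align: right;\">\n",
1658 | "      <th></th>\n",
1659 | "      <th>id</th>\n",
1660 | "      <th>Target</th>\n",
1661 | "    </tr>\n",
1662 | "  </thead>\n",
1663 | "  <tbody>\n",
1664 | "    <tr>\n",
1665 | "      <th>0</th>\n",
1666 | "      <td>420</td>\n",
1667 | "      <td>0.553706</td>\n",
1668 | "    </tr>\n",
1669 | "    <tr>\n",
1670 | "      <th>1</th>\n",
1671 | "      <td>421</td>\n",
1672 | "      <td>0.758907</td>\n",
1673 | "    </tr>\n",
1674 | "    <tr>\n",
1675 | "      <th>2</th>\n",
1676 | "      <td>422</td>\n",
1677 | "      <td>0.904092</td>\n",
1678 | "    </tr>\n",
1679 | "    <tr>\n",
1680 | "      <th>3</th>\n",
1681 | "      <td>423</td>\n",
1682 | "      <td>0.927545</td>\n",
1683 | "    </tr>\n",
1684 | "    <tr>\n",
1685 | "      <th>4</th>\n",
1686 | "      <td>424</td>\n",
1687 | "      <td>0.916706</td>\n",
1688 | "    </tr>\n",
1689 | "    <tr>\n",
1690 | "      <th>...</th>\n",
1691 | "      <td>...</td>\n",
1692 | "      <td>...</td>\n",
1693 | "    </tr>\n",
1694 | "    <tr>\n",
1695 | "      <th>275</th>\n",
1696 | "      <td>695</td>\n",
1697 | "      <td>0.913177</td>\n",
1698 | "    </tr>\n",
1699 | "    <tr>\n",
1700 | "      <th>276</th>\n",
1701 | "      <td>696</td>\n",
1702 | "      <td>0.619613</td>\n",
1703 | "    </tr>\n",
1704 | "    <tr>\n",
1705 | "      <th>277</th>\n",
1706 | "      <td>697</td>\n",
1707 | "      <td>0.914858</td>\n",
1708 | "    </tr>\n",
1709 | "    <tr>\n",
1710 | "      <th>278</th>\n",
1711 | "      <td>698</td>\n",
1712 | "      <td>0.600448</td>\n",
1713 | "    </tr>\n",
1714 | "    <tr>\n",
1715 | "      <th>279</th>\n",
1716 | "      <td>699</td>\n",
1717 | "      <td>0.834674</td>\n",
1718 | "    </tr>\n",
1719 | "  </tbody>\n",
1720 | "</table>\n",
1721 | "<p>280 rows × 2 columns</p>\n",
1722 | "</div>"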
1723 | ],
1724 | "text/plain": [
1725 | " id Target\n",
1726 | "0 420 0.553706\n",
1727 | "1 421 0.758907\n",
1728 | "2 422 0.904092\n",
1729 | "3 423 0.927545\n",
1730 | "4 424 0.916706\n",
1731 | ".. ... ...\n",
1732 | "275 695 0.913177\n",
1733 | "276 696 0.619613\n",
1734 | "277 697 0.914858\n",
1735 | "278 698 0.600448\n",
1736 | "279 699 0.834674\n",
1737 | "\n",
1738 | "[280 rows x 2 columns]"
1739 | ]
1740 | },
1741 | "execution_count": 13,
1742 | "metadata": {},
1743 | "output_type": "execute_result"
1744 | }
1745 | ],
1746 | "source": [
1747 | "sub = pd.read_csv(INPUT_PATH / 'submission.csv')\n",
1748 | "sub[TARGET] = SEED_PREDS\n",
1749 | "sub.to_csv('submission.csv', index=False)\n",
1750 | "sub"
1751 | ]
1752 | },
1753 | {
1754 | "cell_type": "code",
1755 | "execution_count": null,
1756 | "id": "ff634bac",
1757 | "metadata": {
1758 | "papermill": {
1759 | "duration": 0.060313,
1760 | "end_time": "2024-01-05T17:19:00.990556",
1761 | "exception": false,
1762 | "start_time": "2024-01-05T17:19:00.930243",
1763 | "status": "completed"
1764 | },
1765 | "tags": []
1766 | },
1767 | "outputs": [],
1768 | "source": []
1769 | }
1770 | ],
1771 | "metadata": {
1772 | "kaggle": {
1773 | "accelerator": "gpu",
1774 | "dataSources": [
1775 | {
1776 | "databundleVersionId": 7397785,
1777 | "sourceId": 66702,
1778 | "sourceType": "competition"
1779 | }
1780 | ],
1781 | "dockerImageVersionId": 30626,
1782 | "isGpuEnabled": true,
1783 | "isInternetEnabled": false,
1784 | "language": "python",
1785 | "sourceType": "notebook"
1786 | },
1787 | "kernelspec": {
1788 | "display_name": "Python 3",
1789 | "language": "python",
1790 | "name": "python3"
1791 | },
1792 | "language_info": {
1793 | "codemirror_mode": {
1794 | "name": "ipython",
1795 | "version": 3
1796 | },
1797 | "file_extension": ".py",
1798 | "mimetype": "text/x-python",
1799 | "name": "python",
1800 | "nbconvert_exporter": "python",
1801 | "pygments_lexer": "ipython3",
1802 | "version": "3.10.12"
1803 | },
1804 | "papermill": {
1805 | "default_parameters": {},
1806 | "duration": 262.280111,
1807 | "end_time": "2024-01-05T17:19:02.573435",
1808 | "environment_variables": {},
1809 | "exception": null,
1810 | "input_path": "__notebook__.ipynb",
1811 | "output_path": "__notebook__.ipynb",
1812 | "parameters": {},
1813 | "start_time": "2024-01-05T17:14:40.293324",
1814 | "version": "2.4.0"
1815 | }
1816 | },
1817 | "nbformat": 4,
1818 | "nbformat_minor": 5
1819 | }
1820 |
--------------------------------------------------------------------------------