├── Report.pdf ├── data ├── sample_data.csv ├── sample_label.csv └── sample_submission.csv ├── model_1-49-0.9795.hdf5 ├── model_2-13-0.9792.hdf5 ├── model_3-30-0.9789.hdf5 ├── model_4-18-0.9791.hdf5 ├── model_99198_45.h5 ├── paper ├── cnn_classifier_using_bytes.pdf ├── malconv.pdf └── mlp_and_rnn.pdf ├── readme.md ├── requirements.txt ├── result.csv ├── test.py └── train.py /Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/Report.pdf -------------------------------------------------------------------------------- /data/sample_label.csv: -------------------------------------------------------------------------------- 1 | sample_id,category 2 | 0,1 3 | 1,1 4 | 2,1 5 | 3,1 6 | 4,1 7 | 5,1 8 | 6,0 9 | 7,0 10 | 8,1 11 | 9,1 12 | 10,0 13 | 11,1 14 | 12,0 15 | 13,0 16 | 14,0 17 | 15,1 18 | 16,0 19 | 17,1 20 | 18,1 21 | 19,0 22 | 20,0 23 | 21,1 24 | 22,1 25 | 23,1 26 | 24,1 27 | 25,0 28 | 26,1 29 | 27,1 30 | 28,1 31 | 29,1 32 | 30,1 33 | 31,1 34 | 32,0 35 | 33,0 36 | 34,0 37 | 35,1 38 | 36,1 39 | 37,1 40 | 38,0 41 | 39,1 42 | 40,1 43 | 41,1 44 | 42,1 45 | 43,1 46 | 44,1 47 | 45,1 48 | 46,1 49 | 47,0 50 | 48,0 51 | 49,1 52 | 50,0 53 | 51,1 54 | 52,0 55 | 53,1 56 | 54,1 57 | 55,1 58 | 56,1 59 | 57,1 60 | 58,0 61 | 59,1 62 | 60,0 63 | 61,0 64 | 62,1 65 | 63,1 66 | 64,1 67 | 65,1 68 | 66,0 69 | 67,1 70 | 68,1 71 | 69,0 72 | 70,1 73 | 71,1 74 | 72,1 75 | 73,0 76 | 74,1 77 | 75,1 78 | 76,0 79 | 77,0 80 | 78,0 81 | 79,1 82 | 80,1 83 | 81,0 84 | 82,0 85 | 83,1 86 | 84,0 87 | 85,0 88 | 86,0 89 | 87,0 90 | 88,1 91 | 89,1 92 | 90,1 93 | 91,1 94 | 92,1 95 | 93,1 96 | 94,1 97 | 95,0 98 | 96,1 99 | 97,1 100 | 98,0 101 | 99,1 102 | 100,1 103 | 101,0 104 | 102,0 105 | 103,1 106 | 104,1 107 | 105,1 108 | 106,1 109 | 107,1 110 | 108,0 111 | 109,1 112 | 110,1 113 | 111,1 114 | 112,1 115 | 113,1 116 | 114,0 117 | 115,0 118 | 116,0 119 | 117,1 120 | 
118,1 121 | 119,0 122 | 120,1 123 | 121,1 124 | 122,1 125 | 123,1 126 | 124,1 127 | 125,1 128 | 126,1 129 | 127,1 130 | 128,1 131 | 129,0 132 | 130,0 133 | 131,1 134 | 132,1 135 | 133,1 136 | 134,0 137 | 135,1 138 | 136,0 139 | 137,1 140 | 138,0 141 | 139,1 142 | 140,1 143 | 141,1 144 | 142,0 145 | 143,1 146 | 144,1 147 | 145,0 148 | 146,1 149 | 147,1 150 | 148,0 151 | 149,1 152 | 150,1 153 | 151,1 154 | 152,1 155 | 153,0 156 | 154,1 157 | 155,1 158 | 156,0 159 | 157,1 160 | 158,1 161 | 159,1 162 | 160,1 163 | 161,0 164 | 162,0 165 | 163,1 166 | 164,1 167 | 165,1 168 | 166,1 169 | 167,1 170 | 168,1 171 | 169,1 172 | 170,0 173 | 171,1 174 | 172,1 175 | 173,0 176 | 174,0 177 | 175,1 178 | 176,1 179 | 177,1 180 | 178,1 181 | 179,0 182 | 180,1 183 | 181,1 184 | 182,0 185 | 183,1 186 | 184,0 187 | 185,1 188 | 186,0 189 | 187,1 190 | 188,0 191 | 189,1 192 | 190,0 193 | 191,1 194 | 192,1 195 | 193,0 196 | 194,0 197 | 195,0 198 | 196,1 199 | 197,1 200 | 198,1 201 | 199,1 202 | 200,0 203 | 201,1 204 | 202,1 205 | 203,1 206 | 204,1 207 | 205,1 208 | 206,1 209 | 207,1 210 | 208,1 211 | 209,0 212 | 210,1 213 | 211,1 214 | 212,0 215 | 213,1 216 | 214,0 217 | 215,0 218 | 216,0 219 | 217,1 220 | 218,0 221 | 219,0 222 | 220,1 223 | 221,1 224 | 222,0 225 | 223,0 226 | 224,1 227 | 225,0 228 | 226,0 229 | 227,1 230 | 228,1 231 | 229,0 232 | 230,0 233 | 231,1 234 | 232,1 235 | 233,0 236 | 234,1 237 | 235,1 238 | 236,1 239 | 237,1 240 | 238,1 241 | 239,0 242 | 240,1 243 | 241,1 244 | 242,1 245 | 243,0 246 | 244,0 247 | 245,1 248 | 246,0 249 | 247,0 250 | 248,0 251 | 249,1 252 | 250,1 253 | 251,1 254 | 252,1 255 | 253,1 256 | 254,1 257 | 255,1 258 | 256,1 259 | 257,1 260 | 258,1 261 | 259,1 262 | 260,0 263 | 261,1 264 | 262,0 265 | 263,1 266 | 264,1 267 | 265,1 268 | 266,1 269 | 267,1 270 | 268,1 271 | 269,0 272 | 270,0 273 | 271,1 274 | 272,1 275 | 273,0 276 | 274,1 277 | 275,1 278 | 276,0 279 | 277,1 280 | 278,1 281 | 279,1 282 | 280,1 283 | 281,1 284 | 282,1 285 | 283,1 286 | 284,0 
287 | 285,1 288 | 286,1 289 | 287,0 290 | 288,0 291 | 289,1 292 | 290,1 293 | 291,0 294 | 292,1 295 | 293,0 296 | 294,0 297 | 295,0 298 | 296,0 299 | 297,1 300 | 298,0 301 | 299,1 302 | 300,1 303 | 301,0 304 | 302,0 305 | 303,0 306 | 304,0 307 | 305,0 308 | 306,1 309 | 307,0 310 | 308,1 311 | 309,1 312 | 310,0 313 | 311,1 314 | 312,1 315 | 313,1 316 | 314,1 317 | 315,1 318 | 316,0 319 | 317,0 320 | 318,1 321 | 319,0 322 | 320,1 323 | 321,1 324 | 322,1 325 | 323,1 326 | 324,0 327 | 325,1 328 | 326,0 329 | 327,1 330 | 328,1 331 | 329,1 332 | 330,1 333 | 331,0 334 | 332,1 335 | 333,1 336 | 334,1 337 | 335,1 338 | 336,1 339 | 337,1 340 | 338,1 341 | 339,0 342 | 340,0 343 | 341,0 344 | 342,1 345 | 343,0 346 | 344,1 347 | 345,0 348 | 346,0 349 | 347,0 350 | 348,1 351 | 349,0 352 | 350,1 353 | 351,0 354 | 352,1 355 | 353,0 356 | 354,0 357 | 355,0 358 | 356,1 359 | 357,1 360 | 358,0 361 | 359,1 362 | 360,0 363 | 361,1 364 | 362,0 365 | 363,0 366 | 364,1 367 | 365,0 368 | 366,1 369 | 367,1 370 | 368,1 371 | 369,1 372 | 370,1 373 | 371,0 374 | 372,1 375 | 373,0 376 | 374,0 377 | 375,1 378 | 376,0 379 | 377,1 380 | 378,0 381 | 379,1 382 | 380,1 383 | 381,0 384 | 382,0 385 | 383,1 386 | 384,1 387 | 385,0 388 | 386,1 389 | 387,0 390 | 388,1 391 | 389,0 392 | 390,1 393 | 391,1 394 | 392,1 395 | 393,1 396 | 394,0 397 | 395,0 398 | 396,1 399 | 397,1 400 | 398,1 401 | 399,1 402 | 400,1 403 | 401,1 404 | 402,1 405 | 403,0 406 | 404,1 407 | 405,0 408 | 406,1 409 | 407,0 410 | 408,1 411 | 409,0 412 | 410,1 413 | 411,0 414 | 412,0 415 | 413,1 416 | 414,1 417 | 415,1 418 | 416,0 419 | 417,0 420 | 418,1 421 | 419,1 422 | 420,1 423 | 421,0 424 | 422,1 425 | 423,1 426 | 424,0 427 | 425,0 428 | 426,1 429 | 427,1 430 | 428,1 431 | 429,1 432 | 430,1 433 | 431,1 434 | 432,1 435 | 433,0 436 | 434,1 437 | 435,1 438 | 436,1 439 | 437,0 440 | 438,1 441 | 439,1 442 | 440,0 443 | 441,1 444 | 442,0 445 | 443,1 446 | 444,0 447 | 445,1 448 | 446,1 449 | 447,1 450 | 448,1 451 | 449,1 452 | 450,1 453 | 
451,0 454 | 452,1 455 | 453,0 456 | 454,1 457 | 455,0 458 | 456,1 459 | 457,1 460 | 458,1 461 | 459,1 462 | 460,0 463 | 461,1 464 | 462,1 465 | 463,0 466 | 464,1 467 | 465,0 468 | 466,0 469 | 467,1 470 | 468,1 471 | 469,1 472 | 470,0 473 | 471,1 474 | 472,1 475 | 473,0 476 | 474,0 477 | 475,1 478 | 476,1 479 | 477,1 480 | 478,1 481 | 479,0 482 | 480,0 483 | 481,0 484 | 482,0 485 | 483,1 486 | 484,0 487 | 485,0 488 | 486,0 489 | 487,1 490 | 488,0 491 | 489,1 492 | 490,0 493 | 491,0 494 | 492,1 495 | 493,1 496 | 494,1 497 | 495,1 498 | 496,1 499 | 497,0 500 | 498,1 501 | 499,1 502 | -------------------------------------------------------------------------------- /model_1-49-0.9795.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/model_1-49-0.9795.hdf5 -------------------------------------------------------------------------------- /model_2-13-0.9792.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/model_2-13-0.9792.hdf5 -------------------------------------------------------------------------------- /model_3-30-0.9789.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/model_3-30-0.9789.hdf5 -------------------------------------------------------------------------------- /model_4-18-0.9791.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/model_4-18-0.9791.hdf5 -------------------------------------------------------------------------------- /model_99198_45.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/model_99198_45.h5 -------------------------------------------------------------------------------- /paper/cnn_classifier_using_bytes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/paper/cnn_classifier_using_bytes.pdf -------------------------------------------------------------------------------- /paper/malconv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/paper/malconv.pdf -------------------------------------------------------------------------------- /paper/mlp_and_rnn.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SayHiRay/malware-detection/d1afc49f9f6f8feecc8ff3299d72ebd27250af29/paper/mlp_and_rnn.pdf -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | This repository is the Master's course project for CS5242 at NUS. The project is an in-class Kaggle competition, and details of the competition can be found [here on Kaggle](https://www.kaggle.com/c/cs5242-malware-detection). 2 | 3 | The training and test procedures are as follows: 4 | 5 | 1. Run `train.py` to train a Keras model on a stratified train/validation split; repeat the run with different `model_number` and `random_state` values (the arguments of `train_a_model`) to obtain models trained on different splits. Each model is trained for 50 epochs. The **accuracy** and validation **AUC score** are reported after each epoch, and a checkpoint of the weights is saved after every epoch. Based on these reported metrics, we selected the **3 best models** to use, each trained on a different train/validation split. 6 | 2. Run `test.py`.
The three models selected in the previous step are loaded, and each one makes predictions on the test set, giving 3 sets of predictions. The final result is the **mean** of these 3 predictions. 7 | 8 | Note that the model files of the 3 models used for our best submission are included with our code. `test.py` loads them by default, so our result can be easily reproduced. Due to storage limits, the data files are not included in the repo, but they can be found on the Kaggle page of the competition. 9 | 10 | Despite its simplicity, our final result ranked 7th out of 68 teams. There is still room for improvement, such as better hyperparameter tuning, other architectures (e.g. ResNet), and better ensemble techniques. For more information on this project, please refer to **Report.pdf**. 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras==2.1.5 2 | tensorflow>=1.4.0 3 | pandas>=0.19 4 | numpy>=1.14 5 | scikit-learn>=0.19 6 | h5py>=2.7.0 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from keras.models import Sequential 4 | from keras.layers import Dense, Activation, Flatten, Conv1D, MaxPooling1D, GlobalAveragePooling1D, Dropout, Embedding 5 | 6 | 7 | def get_empty_model(params): 8 | print(params) 9 | conv_dropout_1 = params['conv_dropout_1'] 10 | conv_dropout_2 = params['conv_dropout_2'] 11 | conv_dropout_3 = params['conv_dropout_3'] 12 | conv_dropout_4 = params['conv_dropout_4'] 13 | conv_dropout_5 = params['conv_dropout_5'] 14 | dense_dropout = params['dense_dropout'] 15 | dense_dim = params['dense_dim'] 16 | optimizer = params['optimizer'] 17 | batch_size = params['batch_size'] 18 | epochs
= params['epochs'] 19 | 20 | input_length = 4096 21 | 22 | model = Sequential() 23 | model.add(Embedding(256, 16, input_length=input_length)) 24 | model.add(Dropout(conv_dropout_1)) 25 | model.add(Conv1D(48, 32, strides=4, padding='same', dilation_rate=1, activation='relu', use_bias=True, 26 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 27 | model.add(Dropout(conv_dropout_2)) 28 | model.add(Conv1D(96, 32, strides=4, padding='same', dilation_rate=1, activation='relu', use_bias=True, 29 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 30 | model.add(Dropout(conv_dropout_3)) 31 | model.add(MaxPooling1D(pool_size=4, strides=None, padding='valid')) 32 | model.add(Conv1D(128, 16, strides=8, padding='same', dilation_rate=1, activation='relu', use_bias=True, 33 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 34 | model.add(Dropout(conv_dropout_4)) 35 | model.add(Conv1D(192, 16, strides=8, padding='same', dilation_rate=1, activation='relu', use_bias=True, 36 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 37 | model.add(Dropout(conv_dropout_5)) 38 | 39 | model.add(Flatten()) 40 | 41 | model.add(Dense(dense_dim, activation='selu')) 42 | model.add(Dropout(dense_dropout)) 43 | 44 | model.add(Dense(1, activation='sigmoid')) 45 | model.compile(optimizer=optimizer, 46 | loss='binary_crossentropy', 47 | metrics=['accuracy']) 48 | 49 | return model 50 | 51 | 52 | if __name__ == "__main__": 53 | params = { 54 | 'batch_size': 128, 55 | 'conv_dropout_1': 0.2, 56 | 'conv_dropout_2': 0.2, 57 | 'conv_dropout_3': 0.2, 58 | 'conv_dropout_4': 0.2, 59 | 'conv_dropout_5': 0.2, 60 | 'dense_dim': 64, 61 | 'dense_dropout': 0.5, 62 | 'epochs': 40, 63 | 'optimizer': 'adam' 64 | } 65 | 66 | test = pd.read_csv(r'test.csv', header=None, names=list(range(4096)), usecols=list(range(4096)), 67 | dtype=np.float16) 68 | test.fillna(0, inplace=True) 69 | test = test.astype(np.int16) 70 | 71 | # Load trained models from hdf5 files, and 
ensemble results by 72 | # Taking mean of the predictions of 3 models. 73 | final_model_1 = get_empty_model(params) 74 | final_model_1.load_weights('model_1-49-0.9795.hdf5') 75 | X_test = test.values 76 | pred = final_model_1.predict(X_test) 77 | result_array_1 = pred.flatten() 78 | 79 | final_model_2 = get_empty_model(params) 80 | final_model_2.load_weights('model_2-13-0.9792.hdf5') 81 | X_test = test.values 82 | pred = final_model_2.predict(X_test) 83 | result_array_2 = pred.flatten() 84 | 85 | final_model_3 = get_empty_model(params) 86 | final_model_3.load_weights('model_3-30-0.9789.hdf5') 87 | X_test = test.values 88 | pred = final_model_3.predict(X_test) 89 | result_array_3 = pred.flatten() 90 | 91 | result_array = (result_array_1 + result_array_2 + result_array_3) / 3 92 | 93 | sample_id = range(pred.shape[0]) 94 | df_pred = pd.DataFrame( 95 | {'sample_id': sample_id, 96 | 'malware': result_array 97 | }) 98 | df_pred.to_csv('result.csv', columns=['sample_id', 'malware'], index=False) 99 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from keras.models import Sequential 4 | from keras.layers import Dense, Activation, Flatten, Conv1D, MaxPooling1D, GlobalAveragePooling1D, Dropout, Embedding, AlphaDropout 5 | from keras.callbacks import Callback, ModelCheckpoint 6 | from sklearn.model_selection import StratifiedShuffleSplit 7 | from sklearn.metrics import roc_auc_score 8 | 9 | 10 | def read_and_preprocess_training_data(file_name_train, file_name_label): 11 | """ 12 | Read in training features and labels, and preprocess them. 
13 | """ 14 | train = pd.read_csv(file_name_train, header=None, usecols=list(range(4096)), dtype=np.float16) 15 | X_raw = train.values 16 | input_dim_X = (~np.isnan(X_raw)).sum( 17 | 1) # Used later to sort the training samples, so that all samples in one batch have similar dimensions 18 | 19 | train.fillna(0, inplace=True) 20 | train = train.astype(np.int16) 21 | labels = pd.read_csv(file_name_label, dtype={'sample_id': np.int32, 'category': np.int8}) 22 | train_labels = pd.concat([train, labels], axis=1) 23 | train_labels = train_labels.assign(num_dim=pd.Series(input_dim_X)) 24 | 25 | train_labels.sort_values(by='num_dim', ascending=False, 26 | inplace=True) # Now we have all training data sorted by the number of their dimensions 27 | 28 | train = train_labels.drop(['sample_id', 'category', 'num_dim'], axis=1) 29 | X = train.values 30 | y = train_labels['category'].values 31 | 32 | return X, y 33 | 34 | 35 | class roc_callback(Callback): 36 | """ 37 | A Keras callback function for calculating the auc_score on validation set. 38 | """ 39 | 40 | def __init__(self, training_data, validation_data): 41 | self.x_val = validation_data[0] 42 | self.y_val = validation_data[1] 43 | 44 | def on_train_begin(self, logs={}): 45 | return 46 | 47 | def on_train_end(self, logs={}): 48 | return 49 | 50 | def on_epoch_begin(self, epoch, logs={}): 51 | return 52 | 53 | def on_epoch_end(self, epoch, logs={}): 54 | y_pred_val = self.model.predict(self.x_val) 55 | roc_val = roc_auc_score(self.y_val, y_pred_val) 56 | print('\rroc-auc_val: %s \n' % (str(round(roc_val, 5))), end=100 * ' ' + '\n') 57 | return 58 | 59 | def on_batch_begin(self, batch, logs={}): 60 | return 61 | 62 | def on_batch_end(self, batch, logs={}): 63 | return 64 | 65 | 66 | def f_for_validating_model(x_train, y_train, x_valid, y_valid, params): 67 | """ 68 | Construct and train a Keras model according to the input params, 69 | and return the model after training is finished. 
70 | """ 71 | print(params) 72 | conv_dropout_1 = params['conv_dropout_1'] 73 | conv_dropout_2 = params['conv_dropout_2'] 74 | conv_dropout_3 = params['conv_dropout_3'] 75 | conv_dropout_4 = params['conv_dropout_4'] 76 | conv_dropout_5 = params['conv_dropout_5'] 77 | dense_dropout = params['dense_dropout'] 78 | dense_dim = params['dense_dim'] 79 | optimizer = params['optimizer'] 80 | batch_size = params['batch_size'] 81 | epochs = params['epochs'] 82 | callbacks = params['callbacks'] 83 | 84 | input_length = 4096 85 | 86 | model = Sequential() 87 | model.add(Embedding(256, 16, input_length=input_length)) 88 | model.add(Dropout(conv_dropout_1)) 89 | model.add(Conv1D(48, 32, strides=4, padding='same', dilation_rate=1, activation='relu', use_bias=True, 90 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 91 | model.add(Dropout(conv_dropout_2)) 92 | model.add(Conv1D(96, 32, strides=4, padding='same', dilation_rate=1, activation='relu', use_bias=True, 93 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 94 | model.add(Dropout(conv_dropout_3)) 95 | model.add(MaxPooling1D(pool_size=4, strides=None, padding='valid')) 96 | model.add(Conv1D(128, 16, strides=8, padding='same', dilation_rate=1, activation='relu', use_bias=True, 97 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 98 | model.add(Dropout(conv_dropout_4)) 99 | model.add(Conv1D(192, 16, strides=8, padding='same', dilation_rate=1, activation='relu', use_bias=True, 100 | kernel_initializer='glorot_uniform', bias_initializer='zeros')) 101 | model.add(Dropout(conv_dropout_5)) 102 | 103 | model.add(Flatten()) 104 | 105 | model.add(Dense(dense_dim, activation='selu')) 106 | model.add(Dropout(dense_dropout)) 107 | 108 | model.add(Dense(1, activation='sigmoid')) 109 | model.compile(optimizer=optimizer, 110 | loss='binary_crossentropy', 111 | metrics=['accuracy']) 112 | 113 | model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_valid, y_valid), 
114 | callbacks=callbacks) 115 | # score, acc = model.evaluate(x_valid, y_valid, verbose=0) 116 | # print('Test accuracy after training finished: ', acc) 117 | 118 | return model 119 | 120 | 121 | def train_a_model(X, y, model_number=0, random_state=0): 122 | """ 123 | Randomly divide train and validation set according to the input random_state, 124 | and then use the divided dataset to train a model. Prediction accuracy and 125 | AUC score is reported, and the model is saved after each epoch during the 126 | training process. The reported information will then be used to select which 127 | model we want to use for our prediction on the test dataset. 128 | """ 129 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=random_state) 130 | for train_index, valid_index in sss.split(X, y): 131 | x_train, x_valid = X[train_index], X[valid_index] 132 | y_train, y_valid = y[train_index], y[valid_index] 133 | 134 | checkpoint_filepath = "model_{}".format(model_number) + "-{epoch:02d}-{val_acc:.4f}.hdf5" 135 | checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_acc', verbose=1, save_weights_only=True, mode='max') 136 | 137 | callbacks_list = [checkpoint, roc_callback(training_data=(x_train, y_train), validation_data=(x_valid, y_valid))] 138 | 139 | params_for_validating_model = { 140 | 'batch_size': 128, 141 | 'conv_dropout_1': 0.2, 142 | 'conv_dropout_2': 0.2, 143 | 'conv_dropout_3': 0.2, 144 | 'conv_dropout_4': 0.2, 145 | 'conv_dropout_5': 0.2, 146 | 'dense_dim': 64, 147 | 'dense_dropout': 0.5, 148 | 'epochs': 50, 149 | 'optimizer': 'adam', 150 | 'callbacks': callbacks_list 151 | } 152 | 153 | validated_model = f_for_validating_model(x_train, y_train, x_valid, y_valid, params_for_validating_model) 154 | 155 | 156 | if __name__ == "__main__": 157 | file_name_train = r'train.csv' 158 | file_name_label = r'train_label.csv' 159 | 160 | X, y = read_and_preprocess_training_data(file_name_train, file_name_label) 161 | 162 | train_a_model(X, y, 
model_number=0, random_state=0) 163 | --------------------------------------------------------------------------------