├── .gitignore ├── models ├── model.h5 ├── scaler.pki ├── onehot_encoder_geo.pki └── label_encoder_gender.pki ├── requirements.txt ├── logs └── fit │ └── 20250114-020203 │ ├── train │ ├── events.out.tfevents.1736849161.DAVEWORLD.21368.0.v2 │ └── events.out.tfevents.1736849278.DAVEWORLD.21368.2.v2 │ └── validation │ ├── events.out.tfevents.1736849171.DAVEWORLD.21368.1.v2 │ └── events.out.tfevents.1736849282.DAVEWORLD.21368.3.v2 ├── .devcontainer └── devcontainer.json ├── README.md ├── app.py └── training ├── model_prediction.ipynb ├── hyperparametertuning.ipynb └── model_training.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ -------------------------------------------------------------------------------- /models/model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/models/model.h5 -------------------------------------------------------------------------------- /models/scaler.pki: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/models/scaler.pki -------------------------------------------------------------------------------- /models/onehot_encoder_geo.pki: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/models/onehot_encoder_geo.pki -------------------------------------------------------------------------------- /models/label_encoder_gender.pki: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/models/label_encoder_gender.pki -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # tensorflow==2.15.0 2 | tensorflow==2.16.1 #Try this version if you have issues with the 2.15.0 version 3 | numpy 4 | pandas 5 | matplotlib 6 | scikit-learn 7 | streamlit 8 | scikeras 9 | tensorboard -------------------------------------------------------------------------------- /logs/fit/20250114-020203/train/events.out.tfevents.1736849161.DAVEWORLD.21368.0.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/logs/fit/20250114-020203/train/events.out.tfevents.1736849161.DAVEWORLD.21368.0.v2 -------------------------------------------------------------------------------- /logs/fit/20250114-020203/train/events.out.tfevents.1736849278.DAVEWORLD.21368.2.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/logs/fit/20250114-020203/train/events.out.tfevents.1736849278.DAVEWORLD.21368.2.v2 -------------------------------------------------------------------------------- /logs/fit/20250114-020203/validation/events.out.tfevents.1736849171.DAVEWORLD.21368.1.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/logs/fit/20250114-020203/validation/events.out.tfevents.1736849171.DAVEWORLD.21368.1.v2 -------------------------------------------------------------------------------- /logs/fit/20250114-020203/validation/events.out.tfevents.1736849282.DAVEWORLD.21368.3.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EniolaAdemola/customer-churn-classification/HEAD/logs/fit/20250114-020203/validation/events.out.tfevents.1736849282.DAVEWORLD.21368.3.v2 -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Python 3", 3 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 4 | "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", 5 | "customizations": { 6 | "codespaces": { 7 | "openFiles": [ 8 | "README.md", 9 | "app.py" 10 | ] 11 | }, 12 | "vscode": { 13 | "settings": {}, 14 | "extensions": [ 15 | "ms-python.python", 16 | "ms-python.vscode-pylance" 17 | ] 18 | } 19 | }, 20 | "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y 0.5: 69 | st.write(f"There is a {proba_percentage:.2f}% chance the customer is likely to churn.") 70 | else: 71 | st.write(f"There is a {proba_percentage:.2f}% chance the customer is not likely to churn.") 72 | 73 | 74 | -------------------------------------------------------------------------------- /training/model_prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "from tensorflow.keras.models import load_model\n", 11 | "import pickle\n", 12 | "import numpy as np\n", 13 | "import pandas as pd" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 3, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# Load the trained model, Scaler, LabelEncoder, and one hot encoder\n", 23 | "model = load_model('../models/model.h5')\n", 24 | "\n", 25 | "# Load the encoder and scaler\n", 26 | "with open('../models/onehot_encoder_geo.pki', 'rb') as file:\n", 27 | " label_encoder_geo = pickle.load(file)\n", 28 | "\n", 29 | "with open('../models/label_encoder_gender.pki', 'rb') as file:\n", 30 | " label_encoder_gender = pickle.load(file)\n", 31 | "\n", 32 | "with open('../models/scaler.pki', 'rb') as file:\n", 33 | " scaler = pickle.load(file)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 4, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "input_data = {\n", 43 | " 'CreditScore': 600,\n", 44 | " 'Geography': 'France',\n", 45 | " 'Gender': 'Male',\n", 46 | " 'Age': 40,\n", 47 | " 'Tenure': 3,\n", 48 | " 'Balance': 60000,\n", 49 | " 'NumOfProducts': 2,\n", 50 | " 'HasCrCard': 1,\n", 51 | " 'IsActiveMember': 1,\n", 52 | " 'EstimatedSalary': 50000\n", 53 | "}" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "c:\\Users\\HP\\Documents\\appliso-genai-class\\class-project\\churn-classification\\venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:2739: UserWarning: X does not have valid feature names, but OneHotEncoder was fitted with feature names\n", 66 | " warnings.warn(\n" 67 | ] 68 | }, 69 | { 70 | "data": { 71 | "text/html": [ 72 | "
\n", 73 | "\n", 86 | "\n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | "
Geography_FranceGeography_GermanyGeography_Spain
01.00.00.0
\n", 104 | "
" 105 | ], 106 | "text/plain": [ 107 | " Geography_France Geography_Germany Geography_Spain\n", 108 | "0 1.0 0.0 0.0" 109 | ] 110 | }, 111 | "execution_count": 13, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "geo_encoded = label_encoder_geo.transform([[input_data['Geography']]]).toarray()\n", 118 | "# geo_encoded_df(pd.DataFrame(encoded, columns=onehot_encoder_geo.categories_[0]))\n", 119 | "geo_encoded_df = pd.DataFrame(geo_encoded, columns=label_encoder_geo.get_feature_names_out([\"Geography\"])) #you can also use categories_ to get the column name\n", 120 | "geo_encoded_df" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 14, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "{'CreditScore': 600,\n", 132 | " 'Geography': 'France',\n", 133 | " 'Gender': 'Male',\n", 134 | " 'Age': 40,\n", 135 | " 'Tenure': 3,\n", 136 | " 'Balance': 60000,\n", 137 | " 'NumOfProducts': 2,\n", 138 | " 'HasCrCard': 1,\n", 139 | " 'IsActiveMember': 1,\n", 140 | " 'EstimatedSalary': 50000}" 141 | ] 142 | }, 143 | "execution_count": 14, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "input_data" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/html": [ 160 | "
\n", 161 | "\n", 174 | "\n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | "
CreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalary
0600FranceMale4036000021150000
\n", 206 | "
" 207 | ], 208 | "text/plain": [ 209 | " CreditScore Geography Gender ... HasCrCard IsActiveMember EstimatedSalary\n", 210 | "0 600 France Male ... 1 1 50000\n", 211 | "\n", 212 | "[1 rows x 10 columns]" 213 | ] 214 | }, 215 | "execution_count": 23, 216 | "metadata": {}, 217 | "output_type": "execute_result" 218 | } 219 | ], 220 | "source": [ 221 | "# turn the inout to a dataframe\n", 222 | "input_df = pd.DataFrame([input_data])\n", 223 | "input_df" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 24, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/html": [ 234 | "
\n", 235 | "\n", 248 | "\n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | "
CreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalary
0600France14036000021150000
\n", 280 | "
" 281 | ], 282 | "text/plain": [ 283 | " CreditScore Geography Gender ... HasCrCard IsActiveMember EstimatedSalary\n", 284 | "0 600 France 1 ... 1 1 50000\n", 285 | "\n", 286 | "[1 rows x 10 columns]" 287 | ] 288 | }, 289 | "execution_count": 24, 290 | "metadata": {}, 291 | "output_type": "execute_result" 292 | } 293 | ], 294 | "source": [ 295 | "# Encoding the gender\n", 296 | "input_df[\"Gender\"] = label_encoder_gender.transform(input_df[\"Gender\"])\n", 297 | "input_df" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/html": [ 308 | "
\n", 309 | "\n", 322 | "\n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryGeography_FranceGeography_GermanyGeography_Spain
0600140360000211500001.00.00.0
\n", 358 | "
" 359 | ], 360 | "text/plain": [ 361 | " CreditScore Gender ... Geography_Germany Geography_Spain\n", 362 | "0 600 1 ... 0.0 0.0\n", 363 | "\n", 364 | "[1 rows x 12 columns]" 365 | ] 366 | }, 367 | "execution_count": 25, 368 | "metadata": {}, 369 | "output_type": "execute_result" 370 | } 371 | ], 372 | "source": [ 373 | "# Drop the Geography column and add the encoded geography columns\n", 374 | "input_df = pd.concat([input_df.drop(\"Geography\", axis=1), geo_encoded_df], axis=1)\n", 375 | "input_df" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 26, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "text/html": [ 386 | "
\n", 387 | "\n", 400 | "\n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryGeography_FranceGeography_GermanyGeography_Spain
0600140360000211500001.00.00.0
\n", 436 | "
" 437 | ], 438 | "text/plain": [ 439 | " CreditScore Gender ... Geography_Germany Geography_Spain\n", 440 | "0 600 1 ... 0.0 0.0\n", 441 | "\n", 442 | "[1 rows x 12 columns]" 443 | ] 444 | }, 445 | "execution_count": 26, 446 | "metadata": {}, 447 | "output_type": "execute_result" 448 | } 449 | ], 450 | "source": [ 451 | "input_df" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 27, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "array([[-0.53598516, 0.91324755, 0.10479359, -0.69539349, -0.25781119,\n", 463 | " 0.80843615, 0.64920267, 0.97481699, -0.87683221, 1.00150113,\n", 464 | " -0.57946723, -0.57638802]])" 465 | ] 466 | }, 467 | "execution_count": 27, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "# Scale the input\n", 474 | "input_scaled = scaler.transform(input_df)\n", 475 | "input_scaled" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 28, 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "1/1 [==============================] - 0s 494ms/step\n" 488 | ] 489 | } 490 | ], 491 | "source": [ 492 | "# Make a prediction\n", 493 | "prediction = model.predict(input_scaled)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 30, 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "data": { 503 | "text/plain": [ 504 | "array([[0.05843489]], dtype=float32)" 505 | ] 506 | }, 507 | "execution_count": 30, 508 | "metadata": {}, 509 | "output_type": "execute_result" 510 | } 511 | ], 512 | "source": [ 513 | "prediction" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 32, 519 | "metadata": {}, 520 | "outputs": [ 521 | { 522 | "data": { 523 | "text/plain": [ 524 | "0.05843489" 525 | ] 526 | }, 527 | "execution_count": 32, 528 | "metadata": {}, 529 | "output_type": "execute_result" 530 | } 531 | ], 532 | "source": [ 533 | "prediction_proba = prediction[0][0]\n", 534 | "prediction_proba" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 34, 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "name": "stdout", 544 | "output_type": "stream", 545 | "text": [ 546 | "The customer is likely to stay with the bank\n" 547 | ] 548 | } 549 | ], 550 | "source": [ 551 | "# Print the prediction\n", 552 | "if prediction_proba > 0.5:\n", 553 | " print(\"The customer is likely to leave the bank\")\n", 554 | "else:\n", 555 | " print(\"The customer is likely to stay with the bank\")" 556 | ] 557 | } 558 | ], 559 | "metadata": { 560 | "kernelspec": { 561 | "display_name": "Python 3", 562 | "language": "python", 563 | "name": "python3" 564 | }, 565 | "language_info": { 566 | "codemirror_mode": { 567 | "name": "ipython", 568 | "version": 3 569 | }, 570 | "file_extension": ".py", 571 | "mimetype": "text/x-python", 572 | "name": "python", 573 | "nbconvert_exporter": "python", 574 | "pygments_lexer": "ipython3", 575 | "version": "3.11.0" 576 | } 577 | }, 578 | "nbformat": 4, 579 | "nbformat_minor": 2 580 | } 581 | -------------------------------------------------------------------------------- /training/hyperparametertuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Finding the Best Layout for a Neural Network\n", 8 | "\n", 9 | "Choosing the right number of hidden layers and neurons can be tricky, but these tips can help:\n", 10 | "\n", 11 | "- **Start Simple**: Begin with a basic design and add complexity only if needed.\n", 12 | "- **Experiment**: Use methods like grid search or random search to try different architectures.\n", 13 | "- **Test & Validate**: Apply cross-validation to check the performance of your designs.\n", 14 | "- **Follow Basic Rules**:\n", 15 | " - The size of the hidden layer should be between the input and output layer sizes.\n", 16 | " - Starting with 1-2 hidden layers is often a good approach.\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "from sklearn.model_selection import train_test_split, GridSearchCV\n", 27 | "from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder\n", 28 | "from sklearn.pipeline import Pipeline\n", 29 | "from scikeras.wrappers import KerasClassifier\n", 30 | "import tensorflow as tf\n", 31 | "from tensorflow.keras.models import Sequential\n", 32 | "from tensorflow.keras.layers import Dense\n", 33 | "from tensorflow.keras.callbacks import EarlyStopping\n", 34 | "import pickle" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "data=pd.read_csv('../data/Churn_Modelling.csv')\n", 44 | "data = data.drop(['RowNumber', 'CustomerId', 'Surname'], axis=1)\n", 45 | "\n", 46 | "label_encoder_gender = LabelEncoder()\n", 47 | "data['Gender'] = label_encoder_gender.fit_transform(data['Gender'])\n", 48 | "\n", 49 | "onehot_encoder_geo = OneHotEncoder(handle_unknown='ignore')\n", 50 | "geo_encoded = onehot_encoder_geo.fit_transform(data[['Geography']]).toarray()\n", 51 | "geo_encoded_df = pd.DataFrame(geo_encoded, columns=onehot_encoder_geo.get_feature_names_out(['Geography']))\n", 52 | "\n", 53 | "data = pd.concat([data.drop('Geography', axis=1), geo_encoded_df], axis=1)\n", 54 | "\n", 55 | "X = data.drop('Exited', axis=1)\n", 56 | "y = data['Exited']\n", 57 | "\n", 58 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 59 | "\n", 60 | "scaler = StandardScaler()\n", 61 | "X_train = scaler.fit_transform(X_train)\n", 62 | "X_test = scaler.transform(X_test)\n", 63 | "\n", 64 | "# Save encoders and scaler for later use\n", 65 | "with open('../models/label_encoder_gender.pkl', 'wb') as file:\n", 66 | " pickle.dump(label_encoder_gender, file)\n", 67 | "\n", 68 | "with open('../models/onehot_encoder_geo.pkl', 'wb') as file:\n", 69 | " pickle.dump(onehot_encoder_geo, file)\n", 70 | "\n", 71 | "with open('../models/scaler.pkl', 'wb') as file:\n", 72 | " pickle.dump(scaler, file)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 19, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "## Define a function to create the model and try different parameters(KerasClassifier)\n", 82 | "\n", 83 | "def create_model(neurons=32,layers=1):\n", 84 | " model=Sequential()\n", 85 | " model.add(Dense(neurons,activation='relu',input_shape=(X_train.shape[1],)))\n", 86 | "\n", 87 | " for _ in range(layers-1):\n", 88 | " model.add(Dense(neurons,activation='relu'))\n", 89 | "\n", 90 | " model.add(Dense(1,activation='sigmoid'))\n", 91 | " model.compile(optimizer='adam',loss=\"binary_crossentropy\",metrics=['accuracy'])\n", 92 | "\n", 93 | " return model\n", 94 | "\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 20, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "## Create a Keras classifier\n", 104 | "model=KerasClassifier(layers=1,neurons=32,build_fn=create_model,verbose=1)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 21, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "\n", 114 | "# Define the grid search parameters\n", 115 | "param_grid = {\n", 116 | " 'neurons': [16, 32, 64, 128],\n", 117 | " 'layers': [1, 2],\n", 118 | " 'epochs': [50, 100]\n", 119 | "}" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 22, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", 132 | "WARNING:tensorflow:From e:\\UDemy Final\\ANN Classification\\venv\\Lib\\site-packages\\keras\\src\\backend.py:873: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 133 | "\n" 134 | ] 135 | }, 136 | { 137 | "name": "stderr", 138 | "output_type": "stream", 139 | "text": [ 140 | "e:\\UDemy Final\\ANN Classification\\venv\\Lib\\site-packages\\scikeras\\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.\n", 141 | " X, y = self._initialize(X, y)\n" 142 | ] 143 | }, 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "WARNING:tensorflow:From e:\\UDemy Final\\ANN Classification\\venv\\Lib\\site-packages\\keras\\src\\optimizers\\__init__.py:309: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", 149 | "\n", 150 | "Epoch 1/100\n", 151 | "WARNING:tensorflow:From e:\\UDemy Final\\ANN Classification\\venv\\Lib\\site-packages\\keras\\src\\utils\\tf_utils.py:492: The name tf.ragged.RaggedTensorValue is deprecated. Please use tf.compat.v1.ragged.RaggedTensorValue instead.\n", 152 | "\n", 153 | "WARNING:tensorflow:From e:\\UDemy Final\\ANN Classification\\venv\\Lib\\site-packages\\keras\\src\\engine\\base_layer_utils.py:384: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n", 154 | "\n", 155 | "250/250 [==============================] - 1s 1ms/step - loss: 0.6275 - accuracy: 0.6605\n", 156 | "Epoch 2/100\n", 157 | "250/250 [==============================] - 0s 958us/step - loss: 0.4631 - accuracy: 0.8009\n", 158 | "Epoch 3/100\n", 159 | "250/250 [==============================] - 0s 938us/step - loss: 0.4391 - accuracy: 0.8058\n", 160 | "Epoch 4/100\n", 161 | "250/250 [==============================] - 0s 972us/step - loss: 0.4250 - accuracy: 0.8140\n", 162 | "Epoch 5/100\n", 163 | "250/250 [==============================] - 0s 990us/step - loss: 0.4109 - accuracy: 0.8246\n", 164 | "Epoch 6/100\n", 165 | "250/250 [==============================] - 0s 961us/step - loss: 0.3973 - accuracy: 0.8310\n", 166 | "Epoch 7/100\n", 167 | "250/250 [==============================] - 0s 966us/step - loss: 0.3853 - accuracy: 0.8404\n", 168 | "Epoch 8/100\n", 169 | "250/250 [==============================] - 0s 978us/step - loss: 0.3755 - accuracy: 0.8457\n", 170 | "Epoch 9/100\n", 171 | "250/250 [==============================] - 0s 974us/step - loss: 0.3683 - accuracy: 0.8505\n", 172 | "Epoch 10/100\n", 173 | "250/250 [==============================] - 0s 967us/step - loss: 0.3622 - accuracy: 0.8519\n", 174 | "Epoch 11/100\n", 175 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3571 - accuracy: 0.8565\n", 176 | "Epoch 12/100\n", 177 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3532 - accuracy: 0.8560\n", 178 | "Epoch 13/100\n", 179 | "250/250 [==============================] - 0s 973us/step - loss: 0.3500 - accuracy: 0.8593\n", 180 | "Epoch 14/100\n", 181 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3474 - accuracy: 0.8600\n", 182 | "Epoch 15/100\n", 183 | "250/250 [==============================] - 0s 973us/step - loss: 0.3456 - accuracy: 0.8616\n", 184 | "Epoch 16/100\n", 185 | "250/250 [==============================] - 0s 946us/step - loss: 0.3440 - accuracy: 0.8597\n", 186 | "Epoch 17/100\n", 187 | "250/250 [==============================] - 0s 952us/step - loss: 0.3425 - accuracy: 0.8620\n", 188 | "Epoch 18/100\n", 189 | "250/250 [==============================] - 0s 955us/step - loss: 0.3414 - accuracy: 0.8626\n", 190 | "Epoch 19/100\n", 191 | "250/250 [==============================] - 0s 946us/step - loss: 0.3405 - accuracy: 0.8599\n", 192 | "Epoch 20/100\n", 193 | "250/250 [==============================] - 0s 931us/step - loss: 0.3397 - accuracy: 0.8616\n", 194 | "Epoch 21/100\n", 195 | "250/250 [==============================] - 0s 975us/step - loss: 0.3391 - accuracy: 0.8629\n", 196 | "Epoch 22/100\n", 197 | "250/250 [==============================] - 0s 990us/step - loss: 0.3382 - accuracy: 0.8626\n", 198 | "Epoch 23/100\n", 199 | "250/250 [==============================] - 0s 972us/step - loss: 0.3378 - accuracy: 0.8615\n", 200 | "Epoch 24/100\n", 201 | "250/250 [==============================] - 0s 975us/step - loss: 0.3376 - accuracy: 0.8619\n", 202 | "Epoch 25/100\n", 203 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3368 - accuracy: 0.8641\n", 204 | "Epoch 26/100\n", 205 | "250/250 [==============================] - 0s 967us/step - loss: 0.3367 - accuracy: 0.8612\n", 206 | "Epoch 27/100\n", 207 | "250/250 [==============================] - 0s 980us/step - loss: 0.3362 - accuracy: 0.8631\n", 208 | "Epoch 28/100\n", 209 | "250/250 [==============================] - 0s 980us/step - loss: 0.3360 - accuracy: 0.8627\n", 210 | "Epoch 29/100\n", 211 | "250/250 [==============================] - 0s 987us/step - loss: 0.3356 - accuracy: 0.8625\n", 212 | "Epoch 30/100\n", 213 | "250/250 [==============================] - 0s 981us/step - loss: 0.3352 - accuracy: 0.8635\n", 214 | "Epoch 31/100\n", 215 | "250/250 [==============================] - 0s 975us/step - loss: 0.3353 - accuracy: 0.8620\n", 216 | "Epoch 32/100\n", 217 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3347 - accuracy: 0.8635\n", 218 | "Epoch 33/100\n", 219 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3346 - accuracy: 0.8634\n", 220 | "Epoch 34/100\n", 221 | "250/250 [==============================] - 0s 964us/step - loss: 0.3343 - accuracy: 0.8625\n", 222 | "Epoch 35/100\n", 223 | "250/250 [==============================] - 0s 927us/step - loss: 0.3340 - accuracy: 0.8651\n", 224 | "Epoch 36/100\n", 225 | "250/250 [==============================] - 0s 952us/step - loss: 0.3338 - accuracy: 0.8629\n", 226 | "Epoch 37/100\n", 227 | "250/250 [==============================] - 0s 942us/step - loss: 0.3336 - accuracy: 0.8636\n", 228 | "Epoch 38/100\n", 229 | "250/250 [==============================] - 0s 954us/step - loss: 0.3333 - accuracy: 0.8620\n", 230 | "Epoch 39/100\n", 231 | "250/250 [==============================] - 0s 963us/step - loss: 0.3333 - accuracy: 0.8625\n", 232 | "Epoch 40/100\n", 233 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3331 - accuracy: 0.8640\n", 234 | "Epoch 41/100\n", 235 | "250/250 [==============================] - 0s 919us/step - loss: 0.3327 - accuracy: 0.8633\n", 236 | "Epoch 42/100\n", 237 | "250/250 [==============================] - 0s 939us/step - loss: 0.3325 - accuracy: 0.8641\n", 238 | "Epoch 43/100\n", 239 | "250/250 [==============================] - 0s 901us/step - loss: 0.3325 - accuracy: 0.8634\n", 240 | "Epoch 44/100\n", 241 | "250/250 [==============================] - 0s 916us/step - loss: 0.3324 - accuracy: 0.8637\n", 242 | "Epoch 45/100\n", 243 | "250/250 [==============================] - 0s 964us/step - loss: 0.3321 - accuracy: 0.8641\n", 244 | "Epoch 46/100\n", 245 | "250/250 [==============================] - 0s 940us/step - loss: 0.3323 - accuracy: 0.8640\n", 246 | "Epoch 47/100\n", 247 | "250/250 [==============================] - 0s 919us/step - loss: 0.3319 - accuracy: 0.8639\n", 248 | "Epoch 48/100\n", 249 | "250/250 [==============================] - 0s 931us/step - loss: 0.3318 - accuracy: 0.8644\n", 250 | "Epoch 49/100\n", 251 | "250/250 [==============================] - 0s 948us/step - loss: 0.3317 - accuracy: 0.8643\n", 252 | "Epoch 50/100\n", 253 | "250/250 [==============================] - 0s 922us/step - loss: 0.3313 - accuracy: 0.8644\n", 254 | "Epoch 51/100\n", 255 | "250/250 [==============================] - 0s 897us/step - loss: 0.3313 - accuracy: 0.8648\n", 256 | "Epoch 52/100\n", 257 | "250/250 [==============================] - 0s 900us/step - loss: 0.3313 - accuracy: 0.8643\n", 258 | "Epoch 53/100\n", 259 | "250/250 [==============================] - 0s 915us/step - loss: 0.3311 - accuracy: 0.8630\n", 260 | "Epoch 54/100\n", 261 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3309 - accuracy: 0.8625\n", 262 | "Epoch 55/100\n", 263 | "250/250 [==============================] - 0s 936us/step - loss: 0.3315 - accuracy: 0.8627\n", 264 | "Epoch 56/100\n", 265 | "250/250 [==============================] - 0s 934us/step - loss: 0.3309 - accuracy: 0.8639\n", 266 | "Epoch 57/100\n", 267 | "250/250 [==============================] - 0s 927us/step - loss: 0.3305 - accuracy: 0.8649\n", 268 | "Epoch 58/100\n", 269 | "250/250 [==============================] - 0s 945us/step - loss: 0.3303 - accuracy: 0.8659\n", 270 | "Epoch 59/100\n", 271 | "250/250 [==============================] - 0s 940us/step - loss: 0.3307 - accuracy: 0.8644\n", 272 | "Epoch 60/100\n", 273 | "250/250 [==============================] - 0s 960us/step - loss: 0.3305 - accuracy: 0.8635\n", 274 | "Epoch 61/100\n", 275 | "250/250 [==============================] - 0s 959us/step - loss: 0.3306 - accuracy: 0.8627\n", 276 | "Epoch 62/100\n", 277 | "250/250 [==============================] - 0s 963us/step - loss: 0.3303 - accuracy: 0.8636\n", 278 | "Epoch 63/100\n", 279 | "250/250 [==============================] - 0s 967us/step - loss: 0.3297 - accuracy: 0.8648\n", 280 | "Epoch 64/100\n", 281 | "250/250 [==============================] - 0s 940us/step - loss: 0.3299 - accuracy: 0.8629\n", 282 | "Epoch 65/100\n", 283 | "250/250 [==============================] - 0s 935us/step - loss: 0.3303 - accuracy: 0.8636\n", 284 | "Epoch 66/100\n", 285 | "250/250 [==============================] - 0s 995us/step - loss: 0.3299 - accuracy: 0.8641\n", 286 | "Epoch 67/100\n", 287 | "250/250 [==============================] - 0s 956us/step - loss: 0.3298 - accuracy: 0.8649\n", 288 | "Epoch 68/100\n", 289 | "250/250 [==============================] - 0s 922us/step - loss: 0.3298 - accuracy: 0.8634\n", 290 | "Epoch 69/100\n", 291 | "250/250 [==============================] - 0s 910us/step - loss: 0.3295 - accuracy: 0.8633\n", 292 | "Epoch 70/100\n", 293 | "250/250 [==============================] - 0s 952us/step - loss: 0.3295 - accuracy: 0.8643\n", 294 | "Epoch 71/100\n", 295 | "250/250 [==============================] - 0s 944us/step - loss: 0.3296 - accuracy: 0.8629\n", 296 | "Epoch 72/100\n", 297 | "250/250 [==============================] - 0s 929us/step - loss: 0.3292 - accuracy: 0.8640\n", 298 | "Epoch 73/100\n", 299 | "250/250 [==============================] - 0s 955us/step - loss: 0.3291 - accuracy: 0.8627\n", 300 | "Epoch 74/100\n", 301 | "250/250 [==============================] - 0s 936us/step - loss: 0.3294 - accuracy: 0.8641\n", 302 | "Epoch 75/100\n", 303 | "250/250 [==============================] - 0s 949us/step - loss: 0.3292 - accuracy: 0.8648\n", 304 | "Epoch 76/100\n", 305 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3293 - accuracy: 0.8651\n", 306 | "Epoch 77/100\n", 307 | "250/250 [==============================] - 0s 990us/step - loss: 0.3290 - accuracy: 0.8641\n", 308 | "Epoch 78/100\n", 309 | "250/250 [==============================] - 0s 937us/step - loss: 0.3293 - accuracy: 0.8634\n", 310 | "Epoch 79/100\n", 311 | "250/250 [==============================] - 0s 939us/step - loss: 0.3286 - accuracy: 0.8645\n", 312 | "Epoch 80/100\n", 313 | "250/250 [==============================] - 0s 916us/step - loss: 0.3290 - accuracy: 0.8648\n", 314 | "Epoch 81/100\n", 315 | "250/250 [==============================] - 0s 934us/step - loss: 0.3290 - accuracy: 0.8641\n", 316 | "Epoch 82/100\n", 317 | "250/250 [==============================] - 0s 936us/step - loss: 0.3288 - accuracy: 0.8651\n", 318 | "Epoch 83/100\n", 319 | "250/250 [==============================] - 0s 951us/step - loss: 0.3289 - accuracy: 0.8641\n", 320 | "Epoch 84/100\n", 321 | "250/250 [==============================] - 0s 936us/step - loss: 0.3286 - accuracy: 0.8631\n", 322 | "Epoch 85/100\n", 323 | "250/250 [==============================] - 0s 984us/step - loss: 0.3289 - accuracy: 0.8644\n", 324 | "Epoch 86/100\n", 325 | "250/250 [==============================] - 0s 935us/step - loss: 0.3287 - accuracy: 0.8636\n", 326 | "Epoch 87/100\n", 327 | "250/250 [==============================] - 0s 956us/step - loss: 0.3286 - accuracy: 0.8651\n", 328 | "Epoch 88/100\n", 329 | "250/250 [==============================] - 0s 951us/step - loss: 0.3284 - accuracy: 0.8635\n", 330 | "Epoch 89/100\n", 331 | "250/250 [==============================] - 0s 958us/step - loss: 0.3286 - accuracy: 0.8645\n", 332 | "Epoch 90/100\n", 333 | "250/250 [==============================] - 0s 956us/step - loss: 0.3288 - accuracy: 0.8648\n", 334 | "Epoch 91/100\n", 335 | "250/250 [==============================] - 0s 932us/step - loss: 0.3282 - accuracy: 0.8637\n", 336 | "Epoch 92/100\n", 337 | "250/250 [==============================] - 0s 918us/step - loss: 0.3282 - accuracy: 0.8635\n", 338 | "Epoch 93/100\n", 339 | "250/250 [==============================] - 0s 948us/step - loss: 0.3285 - accuracy: 0.8633\n", 340 | "Epoch 94/100\n", 341 | "250/250 [==============================] - 0s 952us/step - loss: 0.3285 - accuracy: 0.8627\n", 342 | "Epoch 95/100\n", 343 | "250/250 [==============================] - 0s 956us/step - loss: 0.3280 - accuracy: 0.8649\n", 344 | "Epoch 96/100\n", 345 | "250/250 [==============================] - 0s 928us/step - loss: 0.3283 - accuracy: 0.8650\n", 346 | "Epoch 97/100\n", 347 | "250/250 [==============================] - 0s 947us/step - loss: 0.3282 - accuracy: 0.8646\n", 348 | "Epoch 98/100\n", 349 | "250/250 [==============================] - 0s 1ms/step - loss: 0.3278 - accuracy: 0.8639\n", 350 | "Epoch 99/100\n", 351 | "250/250 [==============================] - 0s 929us/step - loss: 0.3278 - accuracy: 0.8630\n", 352 | "Epoch 100/100\n", 353 | "250/250 [==============================] - 0s 950us/step - loss: 0.3276 - accuracy: 0.8636\n", 354 | "Best: 0.858375 using {'epochs': 100, 'layers': 1, 'neurons': 16}\n" 355 | ] 356 | } 357 | ], 358 | "source": [ 359 | "# Perform grid search\n", 360 | "grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, cv=3,verbose=1)\n", 361 | "grid_result = grid.fit(X_train, y_train)\n", 362 | "\n", 363 | "# Print the best parameters\n", 364 | "print(\"Best: %f using %s\" % (grid_result.best_score_, grid_result.best_params_))" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.11.0" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 2 396 | } 397 | -------------------------------------------------------------------------------- /training/model_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from sklearn.model_selection import train_test_split\n", 11 | "from sklearn.preprocessing import StandardScaler, LabelEncoder" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/html": [ 22 | "
\n", 23 | "\n", 36 | "\n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | "
RowNumberCustomerIdSurnameCreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExited
0115634602Hargrave619FranceFemale4220.00111101348.881
1215647311Hill608SpainFemale41183807.86101112542.580
2315619304Onio502FranceFemale428159660.80310113931.571
3415701354Boni699FranceFemale3910.0020093826.630
4515737888Mitchell850SpainFemale432125510.8211179084.100
\n", 144 | "
" 145 | ], 146 | "text/plain": [ 147 | " RowNumber CustomerId Surname CreditScore Geography Gender Age \\\n", 148 | "0 1 15634602 Hargrave 619 France Female 42 \n", 149 | "1 2 15647311 Hill 608 Spain Female 41 \n", 150 | "2 3 15619304 Onio 502 France Female 42 \n", 151 | "3 4 15701354 Boni 699 France Female 39 \n", 152 | "4 5 15737888 Mitchell 850 Spain Female 43 \n", 153 | "\n", 154 | " Tenure Balance NumOfProducts HasCrCard IsActiveMember \\\n", 155 | "0 2 0.00 1 1 1 \n", 156 | "1 1 83807.86 1 0 1 \n", 157 | "2 8 159660.80 3 1 0 \n", 158 | "3 1 0.00 2 0 0 \n", 159 | "4 2 125510.82 1 1 1 \n", 160 | "\n", 161 | " EstimatedSalary Exited \n", 162 | "0 101348.88 1 \n", 163 | "1 112542.58 0 \n", 164 | "2 113931.57 1 \n", 165 | "3 93826.63 0 \n", 166 | "4 79084.10 0 " 167 | ] 168 | }, 169 | "execution_count": 3, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "# Load the dataset\n", 176 | "data = pd.read_csv(\"../data/Churn_Modelling.csv\")\n", 177 | "data.head()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 4, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/html": [ 188 | "
\n", 189 | "\n", 202 | "\n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | "
CreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExited
0619FranceFemale4220.00111101348.881
1608SpainFemale41183807.86101112542.580
2502FranceFemale428159660.80310113931.571
3699FranceFemale3910.0020093826.630
4850SpainFemale432125510.8211179084.100
....................................
9995771FranceMale3950.0021096270.640
9996516FranceMale351057369.61111101699.770
9997709FranceFemale3670.0010142085.581
9998772GermanyMale42375075.3121092888.521
9999792FranceFemale284130142.7911038190.780
\n", 376 | "

10000 rows × 11 columns

\n", 377 | "
" 378 | ], 379 | "text/plain": [ 380 | " CreditScore Geography Gender Age Tenure Balance NumOfProducts \\\n", 381 | "0 619 France Female 42 2 0.00 1 \n", 382 | "1 608 Spain Female 41 1 83807.86 1 \n", 383 | "2 502 France Female 42 8 159660.80 3 \n", 384 | "3 699 France Female 39 1 0.00 2 \n", 385 | "4 850 Spain Female 43 2 125510.82 1 \n", 386 | "... ... ... ... ... ... ... ... \n", 387 | "9995 771 France Male 39 5 0.00 2 \n", 388 | "9996 516 France Male 35 10 57369.61 1 \n", 389 | "9997 709 France Female 36 7 0.00 1 \n", 390 | "9998 772 Germany Male 42 3 75075.31 2 \n", 391 | "9999 792 France Female 28 4 130142.79 1 \n", 392 | "\n", 393 | " HasCrCard IsActiveMember EstimatedSalary Exited \n", 394 | "0 1 1 101348.88 1 \n", 395 | "1 0 1 112542.58 0 \n", 396 | "2 1 0 113931.57 1 \n", 397 | "3 0 0 93826.63 0 \n", 398 | "4 1 1 79084.10 0 \n", 399 | "... ... ... ... ... \n", 400 | "9995 1 0 96270.64 0 \n", 401 | "9996 1 1 101699.77 0 \n", 402 | "9997 0 1 42085.58 1 \n", 403 | "9998 1 0 92888.52 1 \n", 404 | "9999 1 0 38190.78 0 \n", 405 | "\n", 406 | "[10000 rows x 11 columns]" 407 | ] 408 | }, 409 | "execution_count": 4, 410 | "metadata": {}, 411 | "output_type": "execute_result" 412 | } 413 | ], 414 | "source": [ 415 | "## Preprocessing the data\n", 416 | "# Drop the columns which are not required\n", 417 | "data = data.drop([\"RowNumber\", \"CustomerId\", \"Surname\"], axis=1)\n", 418 | "data" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 5, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "\n", 431 | "RangeIndex: 10000 entries, 0 to 9999\n", 432 | "Data columns (total 11 columns):\n", 433 | " # Column Non-Null Count Dtype \n", 434 | "--- ------ -------------- ----- \n", 435 | " 0 CreditScore 10000 non-null int64 \n", 436 | " 1 Geography 10000 non-null object \n", 437 | " 2 Gender 10000 non-null object \n", 438 | " 3 Age 10000 non-null int64 \n", 439 | " 4 Tenure 10000 non-null int64 \n", 440 | " 5 Balance 10000 non-null float64\n", 441 | " 6 NumOfProducts 10000 non-null int64 \n", 442 | " 7 HasCrCard 10000 non-null int64 \n", 443 | " 8 IsActiveMember 10000 non-null int64 \n", 444 | " 9 EstimatedSalary 10000 non-null float64\n", 445 | " 10 Exited 10000 non-null int64 \n", 446 | "dtypes: float64(2), int64(7), object(2)\n", 447 | "memory usage: 859.5+ KB\n" 448 | ] 449 | } 450 | ], 451 | "source": [ 452 | "data.info()" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 6, 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "data": { 462 | "text/plain": [ 463 | "array(['Female', 'Male'], dtype=object)" 464 | ] 465 | }, 466 | "execution_count": 6, 467 | "metadata": {}, 468 | "output_type": "execute_result" 469 | } 470 | ], 471 | "source": [ 472 | "# data[\"Gender\"].value_counts()\n", 473 | "data[\"Gender\"].unique()" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 7, 479 | "metadata": {}, 480 | "outputs": [ 481 | { 482 | "data": { 483 | "text/html": [ 484 | "
\n", 485 | "\n", 498 | "\n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | "
CreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExited
0619France04220.00111101348.881
1608Spain041183807.86101112542.580
2502France0428159660.80310113931.571
3699France03910.0020093826.630
4850Spain0432125510.8211179084.100
....................................
9995771France13950.0021096270.640
9996516France1351057369.61111101699.770
9997709France03670.0010142085.581
9998772Germany142375075.3121092888.521
9999792France0284130142.7911038190.780
\n", 672 | "

10000 rows × 11 columns

\n", 673 | "
" 674 | ], 675 | "text/plain": [ 676 | " CreditScore Geography Gender Age Tenure Balance NumOfProducts \\\n", 677 | "0 619 France 0 42 2 0.00 1 \n", 678 | "1 608 Spain 0 41 1 83807.86 1 \n", 679 | "2 502 France 0 42 8 159660.80 3 \n", 680 | "3 699 France 0 39 1 0.00 2 \n", 681 | "4 850 Spain 0 43 2 125510.82 1 \n", 682 | "... ... ... ... ... ... ... ... \n", 683 | "9995 771 France 1 39 5 0.00 2 \n", 684 | "9996 516 France 1 35 10 57369.61 1 \n", 685 | "9997 709 France 0 36 7 0.00 1 \n", 686 | "9998 772 Germany 1 42 3 75075.31 2 \n", 687 | "9999 792 France 0 28 4 130142.79 1 \n", 688 | "\n", 689 | " HasCrCard IsActiveMember EstimatedSalary Exited \n", 690 | "0 1 1 101348.88 1 \n", 691 | "1 0 1 112542.58 0 \n", 692 | "2 1 0 113931.57 1 \n", 693 | "3 0 0 93826.63 0 \n", 694 | "4 1 1 79084.10 0 \n", 695 | "... ... ... ... ... \n", 696 | "9995 1 0 96270.64 0 \n", 697 | "9996 1 1 101699.77 0 \n", 698 | "9997 0 1 42085.58 1 \n", 699 | "9998 1 0 92888.52 1 \n", 700 | "9999 1 0 38190.78 0 \n", 701 | "\n", 702 | "[10000 rows x 11 columns]" 703 | ] 704 | }, 705 | "execution_count": 7, 706 | "metadata": {}, 707 | "output_type": "execute_result" 708 | } 709 | ], 710 | "source": [ 711 | "# Encode the categorical data\n", 712 | "label_encoder_gender = LabelEncoder()\n", 713 | "data[\"Gender\"] = label_encoder_gender.fit_transform(data[\"Gender\"])\n", 714 | "data" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 8, 720 | "metadata": {}, 721 | "outputs": [ 722 | { 723 | "data": { 724 | "text/plain": [ 725 | "array([0, 1])" 726 | ] 727 | }, 728 | "execution_count": 8, 729 | "metadata": {}, 730 | "output_type": "execute_result" 731 | } 732 | ], 733 | "source": [ 734 | "data[\"Gender\"].unique()" 735 | ] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "execution_count": 9, 740 | "metadata": {}, 741 | "outputs": [ 742 | { 743 | "data": { 744 | "text/plain": [ 745 | "array([[1., 0., 0.],\n", 746 | " [0., 0., 1.],\n", 747 | " [1., 0., 0.],\n", 748 | " ...,\n", 749 | " [1., 0., 0.],\n", 750 | " [0., 1., 0.],\n", 751 | " [1., 0., 0.]])" 752 | ] 753 | }, 754 | "execution_count": 9, 755 | "metadata": {}, 756 | "output_type": "execute_result" 757 | } 758 | ], 759 | "source": [ 760 | "# one hot encoding for \"Geography\" column\n", 761 | "from sklearn.preprocessing import OneHotEncoder\n", 762 | "onehot_encoder_geo = OneHotEncoder()\n", 763 | "# geo_encoder = onehot_encoder_geo.fit_transform(data[[\"Geography\"]]) # returns a sparse matrix\n", 764 | "geo_encoder = onehot_encoder_geo.fit_transform(data[[\"Geography\"]]).toarray() # returns a numpy array\n", 765 | "geo_encoder" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": 10, 771 | "metadata": {}, 772 | "outputs": [ 773 | { 774 | "data": { 775 | "text/plain": [ 776 | "array(['Geography_France', 'Geography_Germany', 'Geography_Spain'],\n", 777 | " dtype=object)" 778 | ] 779 | }, 780 | "execution_count": 10, 781 | "metadata": {}, 782 | "output_type": "execute_result" 783 | } 784 | ], 785 | "source": [ 786 | "onehot_encoder_geo.get_feature_names_out([\"Geography\"])" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": 11, 792 | "metadata": {}, 793 | "outputs": [], 794 | "source": [ 795 | "geo_encoded_df = pd.DataFrame(geo_encoder, columns=onehot_encoder_geo.get_feature_names_out([\"Geography\"]))" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 12, 801 | "metadata": {}, 802 | "outputs": [ 803 | { 804 | "data": { 805 | "text/html": [ 806 | "
\n", 807 | "\n", 820 | "\n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExitedGeography_FranceGeography_GermanyGeography_Spain
061904220.00111101348.8811.00.00.0
1608041183807.86101112542.5800.00.01.0
25020428159660.80310113931.5711.00.00.0
369903910.0020093826.6301.00.00.0
48500432125510.8211179084.1000.00.01.0
..........................................
999577113950.0021096270.6401.00.00.0
99965161351057369.61111101699.7701.00.00.0
999770903670.0010142085.5811.00.00.0
9998772142375075.3121092888.5210.01.00.0
99997920284130142.7911038190.7801.00.00.0
\n", 1018 | "

10000 rows × 13 columns

\n", 1019 | "
" 1020 | ], 1021 | "text/plain": [ 1022 | " CreditScore Gender Age Tenure Balance NumOfProducts HasCrCard \\\n", 1023 | "0 619 0 42 2 0.00 1 1 \n", 1024 | "1 608 0 41 1 83807.86 1 0 \n", 1025 | "2 502 0 42 8 159660.80 3 1 \n", 1026 | "3 699 0 39 1 0.00 2 0 \n", 1027 | "4 850 0 43 2 125510.82 1 1 \n", 1028 | "... ... ... ... ... ... ... ... \n", 1029 | "9995 771 1 39 5 0.00 2 1 \n", 1030 | "9996 516 1 35 10 57369.61 1 1 \n", 1031 | "9997 709 0 36 7 0.00 1 0 \n", 1032 | "9998 772 1 42 3 75075.31 2 1 \n", 1033 | "9999 792 0 28 4 130142.79 1 1 \n", 1034 | "\n", 1035 | " IsActiveMember EstimatedSalary Exited Geography_France \\\n", 1036 | "0 1 101348.88 1 1.0 \n", 1037 | "1 1 112542.58 0 0.0 \n", 1038 | "2 0 113931.57 1 1.0 \n", 1039 | "3 0 93826.63 0 1.0 \n", 1040 | "4 1 79084.10 0 0.0 \n", 1041 | "... ... ... ... ... \n", 1042 | "9995 0 96270.64 0 1.0 \n", 1043 | "9996 1 101699.77 0 1.0 \n", 1044 | "9997 1 42085.58 1 1.0 \n", 1045 | "9998 0 92888.52 1 0.0 \n", 1046 | "9999 0 38190.78 0 1.0 \n", 1047 | "\n", 1048 | " Geography_Germany Geography_Spain \n", 1049 | "0 0.0 0.0 \n", 1050 | "1 0.0 1.0 \n", 1051 | "2 0.0 0.0 \n", 1052 | "3 0.0 0.0 \n", 1053 | "4 0.0 1.0 \n", 1054 | "... ... ... \n", 1055 | "9995 0.0 0.0 \n", 1056 | "9996 0.0 0.0 \n", 1057 | "9997 0.0 0.0 \n", 1058 | "9998 1.0 0.0 \n", 1059 | "9999 0.0 0.0 \n", 1060 | "\n", 1061 | "[10000 rows x 13 columns]" 1062 | ] 1063 | }, 1064 | "execution_count": 12, 1065 | "metadata": {}, 1066 | "output_type": "execute_result" 1067 | } 1068 | ], 1069 | "source": [ 1070 | "# Combine one hot encoded data with the original data\n", 1071 | "data = pd.concat([data.drop(\"Geography\", axis=1), geo_encoded_df], axis=1)\n", 1072 | "data" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "code", 1077 | "execution_count": null, 1078 | "metadata": {}, 1079 | "outputs": [ 1080 | { 1081 | "data": { 1082 | "text/plain": [ 1083 | "NumOfProducts\n", 1084 | "1 5084\n", 1085 | "2 4590\n", 1086 | "3 266\n", 1087 | "4 60\n", 1088 | "Name: count, dtype: int64" 1089 | ] 1090 | }, 1091 | "execution_count": 51, 1092 | "metadata": {}, 1093 | "output_type": "execute_result" 1094 | } 1095 | ], 1096 | "source": [ 1097 | "data[\"NumOfProducts\"].value_counts()" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 52, 1103 | "metadata": {}, 1104 | "outputs": [ 1105 | { 1106 | "data": { 1107 | "text/plain": [ 1108 | "Tenure\n", 1109 | "2 1048\n", 1110 | "1 1035\n", 1111 | "7 1028\n", 1112 | "8 1025\n", 1113 | "5 1012\n", 1114 | "3 1009\n", 1115 | "4 989\n", 1116 | "9 984\n", 1117 | "6 967\n", 1118 | "10 490\n", 1119 | "0 413\n", 1120 | "Name: count, dtype: int64" 1121 | ] 1122 | }, 1123 | "execution_count": 52, 1124 | "metadata": {}, 1125 | "output_type": "execute_result" 1126 | } 1127 | ], 1128 | "source": [ 1129 | "data[\"Tenure\"].value_counts()" 1130 | ] 1131 | }, 1132 | { 1133 | "cell_type": "code", 1134 | "execution_count": 49, 1135 | "metadata": {}, 1136 | "outputs": [ 1137 | { 1138 | "data": { 1139 | "text/html": [ 1140 | "
\n", 1141 | "\n", 1154 | "\n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExitedGeography_FranceGeography_GermanyGeography_Spain
061904220.0111101348.8811.00.00.0
\n", 1192 | "
" 1193 | ], 1194 | "text/plain": [ 1195 | " CreditScore Gender Age Tenure Balance NumOfProducts HasCrCard \\\n", 1196 | "0 619 0 42 2 0.0 1 1 \n", 1197 | "\n", 1198 | " IsActiveMember EstimatedSalary Exited Geography_France \\\n", 1199 | "0 1 101348.88 1 1.0 \n", 1200 | "\n", 1201 | " Geography_Germany Geography_Spain \n", 1202 | "0 0.0 0.0 " 1203 | ] 1204 | }, 1205 | "execution_count": 49, 1206 | "metadata": {}, 1207 | "output_type": "execute_result" 1208 | } 1209 | ], 1210 | "source": [ 1211 | "data[0:1]\n", 1212 | "# data.iloc[0]" 1213 | ] 1214 | }, 1215 | { 1216 | "cell_type": "code", 1217 | "execution_count": 13, 1218 | "metadata": {}, 1219 | "outputs": [], 1220 | "source": [ 1221 | "# Save the encoders and scaler\n", 1222 | "import pickle\n", 1223 | "with open(\"../models/label_encoder_gender.pki\", \"wb\") as file:\n", 1224 | " pickle.dump(label_encoder_gender, file)\n", 1225 | "\n", 1226 | "with open(\"../models/onehot_encoder_geo.pki\", \"wb\") as file:\n", 1227 | " pickle.dump(onehot_encoder_geo, file)" 1228 | ] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": 14, 1233 | "metadata": {}, 1234 | "outputs": [ 1235 | { 1236 | "data": { 1237 | "text/html": [ 1238 | "
\n", 1239 | "\n", 1252 | "\n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | " \n", 1293 | " \n", 1294 | " \n", 1295 | " \n", 1296 | " \n", 1297 | " \n", 1298 | " \n", 1299 | " \n", 1300 | " \n", 1301 | " \n", 1302 | " \n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | " \n", 1315 | " \n", 1316 | " \n", 1317 | " \n", 1318 | " \n", 1319 | " \n", 1320 | " \n", 1321 | " \n", 1322 | " \n", 1323 | " \n", 1324 | " \n", 1325 | " \n", 1326 | " \n", 1327 | " \n", 1328 | " \n", 1329 | " \n", 1330 | " \n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExitedGeography_FranceGeography_GermanyGeography_Spain
061904220.00111101348.8811.00.00.0
1608041183807.86101112542.5800.00.01.0
25020428159660.80310113931.5711.00.00.0
369903910.0020093826.6301.00.00.0
48500432125510.8211179084.1000.00.01.0
\n", 1354 | "
" 1355 | ], 1356 | "text/plain": [ 1357 | " CreditScore Gender Age Tenure Balance NumOfProducts HasCrCard \\\n", 1358 | "0 619 0 42 2 0.00 1 1 \n", 1359 | "1 608 0 41 1 83807.86 1 0 \n", 1360 | "2 502 0 42 8 159660.80 3 1 \n", 1361 | "3 699 0 39 1 0.00 2 0 \n", 1362 | "4 850 0 43 2 125510.82 1 1 \n", 1363 | "\n", 1364 | " IsActiveMember EstimatedSalary Exited Geography_France \\\n", 1365 | "0 1 101348.88 1 1.0 \n", 1366 | "1 1 112542.58 0 0.0 \n", 1367 | "2 0 113931.57 1 1.0 \n", 1368 | "3 0 93826.63 0 1.0 \n", 1369 | "4 1 79084.10 0 0.0 \n", 1370 | "\n", 1371 | " Geography_Germany Geography_Spain \n", 1372 | "0 0.0 0.0 \n", 1373 | "1 0.0 1.0 \n", 1374 | "2 0.0 0.0 \n", 1375 | "3 0.0 0.0 \n", 1376 | "4 0.0 1.0 " 1377 | ] 1378 | }, 1379 | "execution_count": 14, 1380 | "metadata": {}, 1381 | "output_type": "execute_result" 1382 | } 1383 | ], 1384 | "source": [ 1385 | "data.head()" 1386 | ] 1387 | }, 1388 | { 1389 | "cell_type": "code", 1390 | "execution_count": 15, 1391 | "metadata": {}, 1392 | "outputs": [], 1393 | "source": [ 1394 | "# Divide the dataset into independent and dependent features\n", 1395 | "X = data.drop(\"Exited\", axis=1)\n", 1396 | "y=data[\"Exited\"]\n", 1397 | "\n", 1398 | "# split the data into training and testing data\n", 1399 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 1400 | "\n", 1401 | "# Scale the feature\n", 1402 | "scaler = StandardScaler()\n", 1403 | "X_train = scaler.fit_transform(X_train)\n", 1404 | "X_test = scaler.transform(X_test)" 1405 | ] 1406 | }, 1407 | { 1408 | "cell_type": "code", 1409 | "execution_count": 16, 1410 | "metadata": {}, 1411 | "outputs": [ 1412 | { 1413 | "data": { 1414 | "text/plain": [ 1415 | "array([[ 0.35649971, 0.91324755, -0.6557859 , ..., 1.00150113,\n", 1416 | " -0.57946723, -0.57638802],\n", 1417 | " [-0.20389777, 0.91324755, 0.29493847, ..., -0.99850112,\n", 1418 | " 1.72572313, -0.57638802],\n", 1419 | " [-0.96147213, 0.91324755, -1.41636539, ..., -0.99850112,\n", 1420 | " -0.57946723, 1.73494238],\n", 1421 | " ...,\n", 1422 | " [ 0.86500853, -1.09499335, -0.08535128, ..., 1.00150113,\n", 1423 | " -0.57946723, -0.57638802],\n", 1424 | " [ 0.15932282, 0.91324755, 0.3900109 , ..., 1.00150113,\n", 1425 | " -0.57946723, -0.57638802],\n", 1426 | " [ 0.47065475, 0.91324755, 1.15059039, ..., -0.99850112,\n", 1427 | " 1.72572313, -0.57638802]])" 1428 | ] 1429 | }, 1430 | "execution_count": 16, 1431 | "metadata": {}, 1432 | "output_type": "execute_result" 1433 | } 1434 | ], 1435 | "source": [ 1436 | "X_train" 1437 | ] 1438 | }, 1439 | { 1440 | "cell_type": "code", 1441 | "execution_count": 17, 1442 | "metadata": {}, 1443 | "outputs": [ 1444 | { 1445 | "name": "stdout", 1446 | "output_type": "stream", 1447 | "text": [ 1448 | "(8000, 12)\n" 1449 | ] 1450 | } 1451 | ], 1452 | "source": [ 1453 | "print(X_train.shape)" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "code", 1458 | "execution_count": 18, 1459 | "metadata": {}, 1460 | "outputs": [], 1461 | "source": [ 1462 | "with open(\"../models/scaler.pki\", \"wb\") as file:\n", 1463 | " pickle.dump(scaler, file)" 1464 | ] 1465 | }, 1466 | { 1467 | "cell_type": "code", 1468 | "execution_count": 19, 1469 | "metadata": {}, 1470 | "outputs": [ 1471 | { 1472 | "data": { 1473 | "text/html": [ 1474 | "
\n", 1475 | "\n", 1488 | "\n", 1489 | " \n", 1490 | " \n", 1491 | " \n", 1492 | " \n", 1493 | " \n", 1494 | " \n", 1495 | " \n", 1496 | " \n", 1497 | " \n", 1498 | " \n", 1499 | " \n", 1500 | " \n", 1501 | " \n", 1502 | " \n", 1503 | " \n", 1504 | " \n", 1505 | " \n", 1506 | " \n", 1507 | " \n", 1508 | " \n", 1509 | " \n", 1510 | " \n", 1511 | " \n", 1512 | " \n", 1513 | " \n", 1514 | " \n", 1515 | " \n", 1516 | " \n", 1517 | " \n", 1518 | " \n", 1519 | " \n", 1520 | " \n", 1521 | " \n", 1522 | " \n", 1523 | " \n", 1524 | " \n", 1525 | " \n", 1526 | " \n", 1527 | " \n", 1528 | " \n", 1529 | " \n", 1530 | " \n", 1531 | " \n", 1532 | " \n", 1533 | " \n", 1534 | " \n", 1535 | " \n", 1536 | " \n", 1537 | " \n", 1538 | " \n", 1539 | " \n", 1540 | " \n", 1541 | " \n", 1542 | " \n", 1543 | " \n", 1544 | " \n", 1545 | " \n", 1546 | " \n", 1547 | " \n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | " \n", 1569 | " \n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | " \n", 1613 | " \n", 1614 | " \n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | "
CreditScoreGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExitedGeography_FranceGeography_GermanyGeography_Spain
061904220.00111101348.8811.00.00.0
1608041183807.86101112542.5800.00.01.0
25020428159660.80310113931.5711.00.00.0
369903910.0020093826.6301.00.00.0
48500432125510.8211179084.1000.00.01.0
..........................................
999577113950.0021096270.6401.00.00.0
99965161351057369.61111101699.7701.00.00.0
999770903670.0010142085.5811.00.00.0
9998772142375075.3121092888.5210.01.00.0
99997920284130142.7911038190.7801.00.00.0
\n", 1686 | "

10000 rows × 13 columns

\n", 1687 | "
" 1688 | ], 1689 | "text/plain": [ 1690 | " CreditScore Gender Age Tenure Balance NumOfProducts HasCrCard \\\n", 1691 | "0 619 0 42 2 0.00 1 1 \n", 1692 | "1 608 0 41 1 83807.86 1 0 \n", 1693 | "2 502 0 42 8 159660.80 3 1 \n", 1694 | "3 699 0 39 1 0.00 2 0 \n", 1695 | "4 850 0 43 2 125510.82 1 1 \n", 1696 | "... ... ... ... ... ... ... ... \n", 1697 | "9995 771 1 39 5 0.00 2 1 \n", 1698 | "9996 516 1 35 10 57369.61 1 1 \n", 1699 | "9997 709 0 36 7 0.00 1 0 \n", 1700 | "9998 772 1 42 3 75075.31 2 1 \n", 1701 | "9999 792 0 28 4 130142.79 1 1 \n", 1702 | "\n", 1703 | " IsActiveMember EstimatedSalary Exited Geography_France \\\n", 1704 | "0 1 101348.88 1 1.0 \n", 1705 | "1 1 112542.58 0 0.0 \n", 1706 | "2 0 113931.57 1 1.0 \n", 1707 | "3 0 93826.63 0 1.0 \n", 1708 | "4 1 79084.10 0 0.0 \n", 1709 | "... ... ... ... ... \n", 1710 | "9995 0 96270.64 0 1.0 \n", 1711 | "9996 1 101699.77 0 1.0 \n", 1712 | "9997 1 42085.58 1 1.0 \n", 1713 | "9998 0 92888.52 1 0.0 \n", 1714 | "9999 0 38190.78 0 1.0 \n", 1715 | "\n", 1716 | " Geography_Germany Geography_Spain \n", 1717 | "0 0.0 0.0 \n", 1718 | "1 0.0 1.0 \n", 1719 | "2 0.0 0.0 \n", 1720 | "3 0.0 0.0 \n", 1721 | "4 0.0 1.0 \n", 1722 | "... ... ... \n", 1723 | "9995 0.0 0.0 \n", 1724 | "9996 0.0 0.0 \n", 1725 | "9997 0.0 0.0 \n", 1726 | "9998 1.0 0.0 \n", 1727 | "9999 0.0 0.0 \n", 1728 | "\n", 1729 | "[10000 rows x 13 columns]" 1730 | ] 1731 | }, 1732 | "execution_count": 19, 1733 | "metadata": {}, 1734 | "output_type": "execute_result" 1735 | } 1736 | ], 1737 | "source": [ 1738 | "data" 1739 | ] 1740 | }, 1741 | { 1742 | "cell_type": "code", 1743 | "execution_count": 20, 1744 | "metadata": {}, 1745 | "outputs": [], 1746 | "source": [ 1747 | "import tensorflow\n", 1748 | "from tensorflow.keras.models import Sequential\n", 1749 | "from tensorflow.keras.layers import Dense\n", 1750 | "from tensorflow.keras.callbacks import EarlyStopping, TensorBoard\n", 1751 | "import datetime" 1752 | ] 1753 | }, 1754 | { 1755 | "cell_type": "code", 1756 | "execution_count": 22, 1757 | "metadata": {}, 1758 | "outputs": [ 1759 | { 1760 | "data": { 1761 | "text/plain": [ 1762 | "12" 1763 | ] 1764 | }, 1765 | "execution_count": 22, 1766 | "metadata": {}, 1767 | "output_type": "execute_result" 1768 | } 1769 | ], 1770 | "source": [ 1771 | "X_train.shape[1]" 1772 | ] 1773 | }, 1774 | { 1775 | "cell_type": "code", 1776 | "execution_count": 23, 1777 | "metadata": {}, 1778 | "outputs": [ 1779 | { 1780 | "name": "stdout", 1781 | "output_type": "stream", 1782 | "text": [ 1783 | "WARNING:tensorflow:From c:\\Users\\HP\\Documents\\appliso-genai-class\\class-project\\churn-classification\\venv\\Lib\\site-packages\\keras\\src\\backend.py:873: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 1784 | "\n" 1785 | ] 1786 | } 1787 | ], 1788 | "source": [ 1789 | "# HL -> Hidden Layer\n", 1790 | "model = Sequential([\n", 1791 | " Dense(64, activation=\"relu\", input_shape=(X_train.shape[1],)), #HL1 connected with the input layer\n", 1792 | " Dense(32, activation=\"relu\"), #HL2\n", 1793 | " Dense(1, activation=\"sigmoid\") #Output layer\n", 1794 | "])" 1795 | ] 1796 | }, 1797 | { 1798 | "cell_type": "code", 1799 | "execution_count": 24, 1800 | "metadata": {}, 1801 | "outputs": [ 1802 | { 1803 | "name": "stdout", 1804 | "output_type": "stream", 1805 | "text": [ 1806 | "Model: \"sequential\"\n", 1807 | "_________________________________________________________________\n", 1808 | " Layer (type) Output Shape Param # \n", 1809 | "=================================================================\n", 1810 | " dense (Dense) (None, 64) 832 \n", 1811 | " \n", 1812 | " dense_1 (Dense) (None, 32) 2080 \n", 1813 | " \n", 1814 | " dense_2 (Dense) (None, 1) 33 \n", 1815 | " \n", 1816 | "=================================================================\n", 1817 | "Total params: 2945 (11.50 KB)\n", 1818 | "Trainable params: 2945 (11.50 KB)\n", 1819 | "Non-trainable params: 0 (0.00 Byte)\n", 1820 | "_________________________________________________________________\n" 1821 | ] 1822 | } 1823 | ], 1824 | "source": [ 1825 | "model.summary()" 1826 | ] 1827 | }, 1828 | { 1829 | "cell_type": "code", 1830 | "execution_count": 25, 1831 | "metadata": {}, 1832 | "outputs": [], 1833 | "source": [ 1834 | "opt = tensorflow.keras.optimizers.Adam(learning_rate=0.01)" 1835 | ] 1836 | }, 1837 | { 1838 | "cell_type": "code", 1839 | "execution_count": 26, 1840 | "metadata": {}, 1841 | "outputs": [], 1842 | "source": [ 1843 | "# Compile the model\n", 1844 | "model.compile(optimizer=opt, loss=\"binary_crossentropy\", metrics=[\"accuracy\"])" 1845 | ] 1846 | }, 1847 | { 1848 | "cell_type": "code", 1849 | "execution_count": 28, 1850 | "metadata": {}, 1851 | "outputs": [ 1852 | { 1853 | "data": { 1854 | "text/plain": [ 1855 | "'20250114-020126'" 1856 | ] 1857 | }, 1858 | "execution_count": 28, 1859 | "metadata": {}, 1860 | "output_type": "execute_result" 1861 | } 1862 | ], 1863 | "source": [ 1864 | "datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")" 1865 | ] 1866 | }, 1867 | { 1868 | "cell_type": "code", 1869 | "execution_count": 29, 1870 | "metadata": {}, 1871 | "outputs": [], 1872 | "source": [ 1873 | "log_dir = \"../logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")" 1874 | ] 1875 | }, 1876 | { 1877 | "cell_type": "code", 1878 | "execution_count": 30, 1879 | "metadata": {}, 1880 | "outputs": [], 1881 | "source": [ 1882 | "tensorflow_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)" 1883 | ] 1884 | }, 1885 | { 1886 | "cell_type": "code", 1887 | "execution_count": null, 1888 | "metadata": {}, 1889 | "outputs": [], 1890 | "source": [ 1891 | "early_stopping_callback = EarlyStopping(monitor=\"val_loss\", patience=10, restore_best_weights=True)" 1892 | ] 1893 | }, 1894 | { 1895 | "cell_type": "code", 1896 | "execution_count": 34, 1897 | "metadata": {}, 1898 | "outputs": [ 1899 | { 1900 | "name": "stdout", 1901 | "output_type": "stream", 1902 | "text": [ 1903 | "Epoch 1/100\n", 1904 | "640/640 [==============================] - 4s 7ms/step - loss: 0.3505 - accuracy: 0.8606 - val_loss: 0.3562 - val_accuracy: 0.8569\n", 1905 | "Epoch 2/100\n", 1906 | "640/640 [==============================] - 4s 7ms/step - loss: 0.3436 - accuracy: 0.8620 - val_loss: 0.3523 - val_accuracy: 0.8500\n", 1907 | "Epoch 3/100\n", 1908 | "640/640 [==============================] - 5s 7ms/step - loss: 0.3407 - accuracy: 0.8623 - val_loss: 0.3449 - val_accuracy: 0.8537\n", 1909 | "Epoch 4/100\n", 1910 | "640/640 [==============================] - 5s 8ms/step - loss: 0.3380 - accuracy: 0.8630 - val_loss: 0.3552 - val_accuracy: 0.8537\n", 1911 | "Epoch 5/100\n", 1912 | "640/640 [==============================] - 7s 10ms/step - loss: 0.3365 - accuracy: 0.8648 - val_loss: 0.3616 - val_accuracy: 0.8519\n", 1913 | "Epoch 6/100\n", 1914 | "640/640 [==============================] - 6s 10ms/step - loss: 0.3343 - accuracy: 0.8652 - val_loss: 0.3675 - val_accuracy: 0.8487\n", 1915 | "Epoch 7/100\n", 1916 | "640/640 [==============================] - 7s 10ms/step - loss: 0.3329 - accuracy: 0.8647 - val_loss: 0.3581 - val_accuracy: 0.8544\n", 1917 | "Epoch 8/100\n", 1918 | "640/640 [==============================] - 6s 10ms/step - loss: 0.3300 - accuracy: 0.8664 - val_loss: 0.3573 - val_accuracy: 0.8512\n" 1919 | ] 1920 | } 1921 | ], 1922 | "source": [ 1923 | "history = model.fit(X_train, y_train, validation_split=0.2, epochs=100, batch_size=10, callbacks=[tensorflow_callback, early_stopping_callback])" 1924 | ] 1925 | }, 1926 | { 1927 | "cell_type": "code", 1928 | "execution_count": 37, 1929 | "metadata": {}, 1930 | "outputs": [ 1931 | { 1932 | "name": "stderr", 1933 | "output_type": "stream", 1934 | "text": [ 1935 | "c:\\Users\\HP\\Documents\\appliso-genai-class\\class-project\\churn-classification\\venv\\Lib\\site-packages\\keras\\src\\engine\\training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.\n", 1936 | " saving_api.save_model(\n" 1937 | ] 1938 | } 1939 | ], 1940 | "source": [ 1941 | "model.save(\"../models/model.h5\")" 1942 | ] 1943 | }, 1944 | { 1945 | "cell_type": "code", 1946 | "execution_count": 38, 1947 | "metadata": {}, 1948 | "outputs": [], 1949 | "source": [ 1950 | "## Load Tensorboard Extension\n", 1951 | "%load_ext tensorboard" 1952 | ] 1953 | }, 1954 | { 1955 | "cell_type": "code", 1956 | "execution_count": 40, 1957 | "metadata": {}, 1958 | "outputs": [ 1959 | { 1960 | "data": { 1961 | "text/plain": [ 1962 | "Reusing TensorBoard on port 6006 (pid 20132), started 0:00:53 ago. (Use '!kill 20132' to kill it.)" 1963 | ] 1964 | }, 1965 | "metadata": {}, 1966 | "output_type": "display_data" 1967 | }, 1968 | { 1969 | "data": { 1970 | "text/html": [ 1971 | "\n", 1972 | " \n", 1974 | " \n", 1985 | " " 1986 | ], 1987 | "text/plain": [ 1988 | "" 1989 | ] 1990 | }, 1991 | "metadata": {}, 1992 | "output_type": "display_data" 1993 | } 1994 | ], 1995 | "source": [ 1996 | "%tensorboard --logdir logs/fit" 1997 | ] 1998 | } 1999 | ], 2000 | "metadata": { 2001 | "kernelspec": { 2002 | "display_name": "Python 3", 2003 | "language": "python", 2004 | "name": "python3" 2005 | }, 2006 | "language_info": { 2007 | "codemirror_mode": { 2008 | "name": "ipython", 2009 | "version": 3 2010 | }, 2011 | "file_extension": ".py", 2012 | "mimetype": "text/x-python", 2013 | "name": "python", 2014 | "nbconvert_exporter": "python", 2015 | "pygments_lexer": "ipython3", 2016 | "version": "3.11.0" 2017 | } 2018 | }, 2019 | "nbformat": 4, 2020 | "nbformat_minor": 2 2021 | } 2022 | --------------------------------------------------------------------------------