├── .DS_Store
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── images
│   ├── .DS_Store
│   ├── AutoEncoder.png
│   ├── Autoencoder_structure.png
│   ├── function.png
│   ├── image-01.png
│   ├── image-02.png
│   ├── image-03.png
│   ├── image-04.png
│   ├── image-05.png
│   ├── image-06.png
│   ├── image-07.png
│   ├── image-08.png
│   ├── image-09.png
│   ├── image-10.png
│   ├── image-11.png
│   ├── image-12.png
│   ├── image-13.png
│   └── image-14.png
└── notebooks
    ├── AutoEncoders.ipynb
    ├── NeuralNetOverSampling.ipynb
    ├── sagemaker_fraud_detection_xgb.ipynb
    └── train_nn.py
/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/.DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution.
8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on.
As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so.
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## SageMaker Fraud Detection Workshop 2 | 3 | ### Lab description 4 | 5 | This lab demonstrates three different ML algorithms for identifying fraudulent transactions in the same dataset: 6 | - SageMaker XGBoost 7 | - AutoEncoders 8 | - Neural Networks 9 | 10 | ### Steps for launching the workshop environment using EVENT ENGINE 11 | Note: these steps were tested in the Chrome browser on macOS 12 | #### Open a browser and navigate to https://dashboard.eventengine.run/login 13 | #### Enter the 12-character "hash" provided to you by the workshop organizer. 14 | #### Click on "Accept Terms & Login" 15 | ![Navigate to Sagemaker Service](/images/image-01.png) 16 | 17 | #### Click on "AWS Console" 18 | ![Navigate to Sagemaker Service](/images/image-02.png) 19 | 20 | #### Please log out of any other AWS accounts you are currently logged into 21 | 22 | #### Click on "Open AWS Console" 23 | ![Navigate to Sagemaker Service](/images/image-03.png) 24 | 25 | #### You should see a screen like this. 26 | #### We now need to select the correct identity role for the workshop 27 | #### Type "IAM" into the search bar and click on IAM 28 | (Identity and Access Management).
29 | ![Navigate to Sagemaker Service](/images/image-04.png) 30 | 31 | #### Click on "Roles" 32 | ![Navigate to Sagemaker Service](/images/image-05.png) 33 | 34 | #### Scroll down past "Create Role" and click on "TeamRole" 35 | ![Navigate to Sagemaker Service](/images/image-06.png) 36 | 37 | #### Copy the "Role ARN" by selecting the copy icon on the right 38 | #### You may want to temporarily paste this role ARN into a notepad 39 | #### Once you have copied the TeamRole ARN, click on "Services" in the upper left corner 40 | ![Navigate to Sagemaker Service](/images/image-07.png) 41 | 42 | #### Enter "SageMaker" in the search bar and click on it 43 | ![Navigate to Sagemaker Service](/images/image-08.png) 44 | 45 | #### You should see a screen like this. 46 | #### Click on the orange button "Create Notebook Instance" 47 | ![Navigate to Sagemaker Service](/images/image-09.png) 48 | 49 | #### On the next page: 50 | #### - Give your notebook a name (no underscores, please) 51 | #### - Under "Notebook instance type", select "ml.c5.2xlarge" 52 | #### - Under "Permission and encryption", select "Enter a custom IAM role ARN" 53 | #### - Paste your TeamRole ARN into the field below labeled "Custom IAM role ARN" 54 | #### Note: your TeamRole ARN will have a different AWS account number than the one you see here 55 | #### - Scroll down to the bottom of the page and click on "Create Notebook instance" 56 | ![Navigate to Sagemaker Service](/images/image-10.png) 57 | 58 | #### You should see your notebook being created.
In a couple of minutes, its status will change 59 | #### from "Pending" to "In Service", at which point click on "Open Jupyter" 60 | ![Navigate to Sagemaker Service](/images/image-11.png) 61 | 62 | #### In the Jupyter console, click on 'New' -> 'Terminal' on the right-hand side 63 | ![Navigate to Sagemaker Service](/images/image-12.png) 64 | 65 | #### A new Chrome browser tab will open, displaying a command prompt terminal 66 | #### In the terminal tab, run these two commands: 67 | #### $ cd SageMaker 68 | #### $ git clone https://github.com/aws-samples/amazon-sagemaker-fraud-detection 69 | #### You should see output similar to this: 70 | ![Navigate to Sagemaker Service](/images/image-13.png) 71 | 72 | #### You may now close the terminal browser tab, 73 | #### return to the Jupyter console, and navigate the created folder structure to 74 | #### amazon-sagemaker-fraud-detection -> notebooks 75 | #### Then open and run each of the three Jupyter notebooks 76 | ![Navigate to Sagemaker Service](/images/image-14.png) 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | #### Open the SageMaker console by clicking on "Services" and searching for SageMaker 89 | ![Navigate to Sagemaker Service](/images/image-08.png) 90 | 91 | ## License 92 | 93 | This library is licensed under the MIT-0 License. See the LICENSE file.
94 | 95 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/.DS_Store -------------------------------------------------------------------------------- /images/AutoEncoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/AutoEncoder.png -------------------------------------------------------------------------------- /images/Autoencoder_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/Autoencoder_structure.png -------------------------------------------------------------------------------- /images/function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/function.png -------------------------------------------------------------------------------- /images/image-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-01.png -------------------------------------------------------------------------------- /images/image-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-02.png 
-------------------------------------------------------------------------------- /images/image-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-03.png -------------------------------------------------------------------------------- /images/image-04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-04.png -------------------------------------------------------------------------------- /images/image-05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-05.png -------------------------------------------------------------------------------- /images/image-06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-06.png -------------------------------------------------------------------------------- /images/image-07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-07.png -------------------------------------------------------------------------------- /images/image-08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-08.png 
-------------------------------------------------------------------------------- /images/image-09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-09.png -------------------------------------------------------------------------------- /images/image-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-10.png -------------------------------------------------------------------------------- /images/image-11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-11.png -------------------------------------------------------------------------------- /images/image-12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-12.png -------------------------------------------------------------------------------- /images/image-13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-13.png -------------------------------------------------------------------------------- /images/image-14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-fraud-detection/66c88ab1f2b63686d052fe2febb9324b7847607d/images/image-14.png 
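The `NeuralNetOverSampling.ipynb` notebook that follows balances its training split with imbalanced-learn's `SMOTE` before training. As a rough, hedged illustration of the idea (not the imbalanced-learn implementation), the sketch below creates each synthetic minority sample by picking a random minority point, picking one of its `k` nearest minority neighbours, and placing the new point at a random position on the segment between the two. The function name `smote_sketch` and its parameters are hypothetical; `k=5` is chosen to mirror imbalanced-learn's default `k_neighbors=5`, and only the Python standard library is used so the sketch runs without NumPy.

```python
import math
import random


def smote_sketch(minority, n_synthetic, k=5, seed=0):
    """Return n_synthetic SMOTE-style samples interpolated from `minority`.

    Hypothetical helper for illustration only; imbalanced-learn's SMOTE uses a
    scikit-learn nearest-neighbours search and vectorised NumPy code instead.
    `minority` is a list of equal-length numeric tuples (feature vectors).
    """
    if len(minority) < 2:
        raise ValueError("need at least two minority samples to interpolate")
    if k < 1:
        raise ValueError("k must be at least 1")
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_synthetic):
        # Pick a random minority sample as the base point.
        i = rng.randrange(len(minority))
        base = minority[i]
        # Its k nearest minority neighbours by Euclidean distance, excluding itself.
        others = [p for j, p in enumerate(minority) if j != i]
        neighbours = sorted(others, key=lambda p: math.dist(base, p))[:k]
        neighbour = rng.choice(neighbours)
        # New point = base + gap * (neighbour - base), with gap drawn from [0, 1).
        gap = rng.random()
        synthetic.append(tuple(b + gap * (n - b) for b, n in zip(base, neighbour)))
    return synthetic


# Usage: four toy "fraud" points at the corners of the unit square.
fraud = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
new_points = smote_sketch(fraud, n_synthetic=10, k=2)
print(len(new_points))  # 10
```

Because every synthetic point lies between two real minority points, SMOTE fills in the region the minority class already occupies rather than duplicating rows. With its default `sampling_strategy='auto'`, imbalanced-learn oversamples the minority class until it matches the majority class, which is consistent with the notebook's printed shapes: the 454,936 resampled rows are 2 × 227,468, so the 227,845-row training split held 227,468 legitimate and 377 fraudulent transactions before resampling.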
-------------------------------------------------------------------------------- /notebooks/NeuralNetOverSampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Fraud Detection Using a Neural Network - A Supervised Deep Learning Method" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Introduction\n", 15 | "In this lab, we are going to use a neural network to perform fraud detection on the same credit card dataset used in the previous labs. \n", 16 | "\n", 17 | "From the previous labs, we know that our dataset is highly imbalanced. The `Class` column indicates whether or not a transaction is fraudulent. The large majority of the data is non-fraudulent, with only $492$ ($0.173\\%$) of the examples corresponding to fraudulent transactions.\n", 18 | "\n", 19 | "For imbalanced datasets like ours, where the positive (fraudulent) examples occur much less frequently than the negative (legitimate) examples, we may try “over-sampling” the minority class by generating synthetic data (read about SMOTE in *Data Mining for Imbalanced Datasets: An Overview*, https://link.springer.com/chapter/10.1007%2F0-387-25465-X_40) or undersampling the majority class by using ensemble methods (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.68.6858&rep=rep1&type=pdf).\n", 20 | "\n", 21 | "Let's start by installing `imbalanced-learn`, one of the libraries that implement the SMOTE technique."
22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "!pip install -U imbalanced-learn" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# create the local directory structure (-p: no error if a directory already exists)\n", 40 | "!mkdir -p ../data\n", 41 | "!mkdir -p ../model\n", 42 | "!mkdir -p ../logs" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "Using TensorFlow backend.\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "# libraries used throughout this notebook\n", 60 | "from numpy import loadtxt\n", 61 | "from keras.models import Sequential\n", 62 | "from keras.layers import Dense\n", 63 | "import pandas as pd\n", 64 | "import numpy as np\n", 65 | "from imblearn.over_sampling import SMOTE, ADASYN\n", 66 | "from sagemaker.tensorflow import TensorFlow\n", 67 | "from collections import Counter\n", 68 | "import matplotlib.pyplot as plt\n", 69 | "import seaborn as sns" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "## Downloading data" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "!curl https://s3-us-west-2.amazonaws.com/sagemaker-e2e-solutions/fraud-detection/creditcardfraud.zip -o ../data/creditcardfraud.zip" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Archive: ../data/creditcardfraud.zip\n", 98 | " inflating: ../data/creditcard.csv \n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "!unzip -o ../data/creditcardfraud.zip -d ../data/" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {},
109 | "source": [ 110 | "## Load and Visualize" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "data = pd.read_csv('../data/creditcard.csv', delimiter=',')" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 7, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/html": [ 130 | "
" 416 | ], 417 | "text/plain": [ 418 | " Time V1 V2 V3 V4 V5 V6 V7 \\\n", 419 | "0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n", 420 | "1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n", 421 | "2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n", 422 | "3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n", 423 | "4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n", 424 | "5 2.0 -0.425966 0.960523 1.141109 -0.168252 0.420987 -0.029728 0.476201 \n", 425 | "6 4.0 1.229658 0.141004 0.045371 1.202613 0.191881 0.272708 -0.005159 \n", 426 | "7 7.0 -0.644269 1.417964 1.074380 -0.492199 0.948934 0.428118 1.120631 \n", 427 | "8 7.0 -0.894286 0.286157 -0.113192 -0.271526 2.669599 3.721818 0.370145 \n", 428 | "9 9.0 -0.338262 1.119593 1.044367 -0.222187 0.499361 -0.246761 0.651583 \n", 429 | "\n", 430 | " V8 V9 ... V21 V22 V23 V24 V25 \\\n", 431 | "0 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 0.128539 \n", 432 | "1 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 0.167170 \n", 433 | "2 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 -0.327642 \n", 434 | "3 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 0.647376 \n", 435 | "4 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 -0.206010 \n", 436 | "5 0.260314 -0.568671 ... -0.208254 -0.559825 -0.026398 -0.371427 -0.232794 \n", 437 | "6 0.081213 0.464960 ... -0.167716 -0.270710 -0.154104 -0.780055 0.750137 \n", 438 | "7 -3.807864 0.615375 ... 1.943465 -1.015455 0.057504 -0.649709 -0.415267 \n", 439 | "8 0.851084 -0.392048 ... -0.073425 -0.268092 -0.204233 1.011592 0.373205 \n", 440 | "9 0.069539 -0.736727 ... 
-0.246914 -0.633753 -0.120794 -0.385050 -0.069733 \n", 441 | "\n", 442 | " V26 V27 V28 Amount Class \n", 443 | "0 -0.189115 0.133558 -0.021053 149.62 0 \n", 444 | "1 0.125895 -0.008983 0.014724 2.69 0 \n", 445 | "2 -0.139097 -0.055353 -0.059752 378.66 0 \n", 446 | "3 -0.221929 0.062723 0.061458 123.50 0 \n", 447 | "4 0.502292 0.219422 0.215153 69.99 0 \n", 448 | "5 0.105915 0.253844 0.081080 3.67 0 \n", 449 | "6 -0.257237 0.034507 0.005168 4.99 0 \n", 450 | "7 -0.051634 -1.206921 -1.085339 40.80 0 \n", 451 | "8 -0.384157 0.011747 0.142404 93.20 0 \n", 452 | "9 0.094199 0.246219 0.083076 3.68 0 \n", 453 | "\n", 454 | "[10 rows x 31 columns]" 455 | ] 456 | }, 457 | "execution_count": 7, 458 | "metadata": {}, 459 | "output_type": "execute_result" 460 | } 461 | ], 462 | "source": [ 463 | "data.head(10)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 8, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "data": { 473 | "text/plain": [ 474 | "Text(0, 0.5, 'Frequency')" 475 | ] 476 | }, 477 | "execution_count": 8, 478 | "metadata": {}, 479 | "output_type": "execute_result" 480 | }, 481 | { 482 | "data": { 483 | "image/png": "\n", 484 | "text/plain": [ 485 | "
" 486 | ] 487 | }, 488 | "metadata": {}, 489 | "output_type": "display_data" 490 | } 491 | ], 492 | "source": [ 493 | "labels = ['normal','fraud']\n", 494 | "classes = data['Class'].value_counts(sort=True)\n", 495 | "classes.plot(kind='bar', rot=0)\n", 496 | "plt.title(\"Transaction class distribution\")\n", 497 | "plt.xticks(range(2), labels)\n", 498 | "plt.xlabel(\"Class\")\n", 499 | "plt.ylabel(\"Frequency\")" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "As we saw in the previous labs, the 'Time' and 'Amount' features are not very relevant, so let's drop these fields." 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 9, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "data = data.drop(['Time','Amount'],axis=1)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "## Split dataset into Train, Validation, and Test" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 10, 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "import boto3\n", 532 | "import os\n", 533 | "import sagemaker\n", 534 | "from sklearn.model_selection import train_test_split\n", 535 | "\n", 536 | "session = sagemaker.Session()\n", 537 | "\n", 538 | "bucket = session.default_bucket()\n", 539 | "sagemaker_iam_role = sagemaker.get_execution_role()\n", 540 | "\n", 541 | "prefix = 'sagemaker/NeuralNetwork-fraud'\n", 542 | "\n", 543 | "RANDOM_SEED = 314  # fixed seed so the random split is reproducible\n", 544 | "TEST_PCT = 0.2  # hold out 20% of the data, split evenly below into validation and test\n", 545 | "\n", 546 | "train_data, test_data = train_test_split(data, test_size=TEST_PCT,random_state=RANDOM_SEED)\n", 547 | "validation_data, test_data = train_test_split(test_data, test_size=0.5,random_state=RANDOM_SEED)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "## Balancing our Training
dataset" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 11, 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "name": "stdout", 564 | "output_type": "stream", 565 | "text": [ 566 | "(227845, 28)\n", 567 | "(227845, 1)\n" 568 | ] 569 | }, 570 | { 571 | "name": "stderr", 572 | "output_type": "stream", 573 | "text": [ 574 | "/home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/sklearn/utils/validation.py:724: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", 575 | " y = column_or_1d(y, warn=True)\n" 576 | ] 577 | }, 578 | { 579 | "name": "stdout", 580 | "output_type": "stream", 581 | "text": [ 582 | "(454936, 28)\n", 583 | "(454936, 1)\n" 584 | ] 585 | }, 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "(454936, 29)" 590 | ] 591 | }, 592 | "execution_count": 11, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "X = train_data.iloc[:, 0:28]\n", 599 | "y = train_data.iloc[:,28:29]\n", 600 | "print(X.shape)\n", 601 | "print(y.shape)\n", 602 | "\n", 603 | "X_resampled, y_resampled = SMOTE().fit_resample(X, y)\n", 604 | "y_resampled = np.asarray(y_resampled).reshape(-1, 1)  # works whether SMOTE returns NumPy arrays or pandas objects\n", 605 | "\n", 606 | "X = X_resampled\n", 607 | "y = y_resampled\n", 608 | "\n", 609 | "print(X.shape)\n", 610 | "print(y.shape)\n", 611 | "train_data = np.concatenate((X, y),axis=1)\n", 612 | "train_data = pd.DataFrame(train_data)\n", 613 | "train_data.shape" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 12, 619 | "metadata": {}, 620 | "outputs": [ 621 | { 622 | "data": { 623 | "text/html": [ 624 | "
\n", 625 | "\n", 638 | "\n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " 
\n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | "
 \n", 908 | "
\n", 909 | "
" 910 | ], 911 | "text/plain": [ 912 | " 0 1 2 3 4 5 6 \\\n", 913 | "0 0.260538 0.644354 -0.649921 -0.053463 0.883230 -0.757973 0.787879 \n", 914 | "1 -1.382266 0.112162 0.541746 1.459137 -1.320120 0.881505 1.274878 \n", 915 | "2 1.998524 -0.228993 -0.276857 0.572004 -0.305767 0.025961 -0.706267 \n", 916 | "3 0.736055 -0.423105 0.280255 1.521043 -0.739431 -0.982784 0.463617 \n", 917 | "4 2.063104 -0.064287 -0.955170 0.639229 -0.229669 -1.213497 -0.025961 \n", 918 | "5 1.938746 -1.348070 -0.571826 -0.898448 -1.164698 -0.275187 -0.997756 \n", 919 | "6 -0.766821 1.168886 0.951680 0.717384 1.039082 -0.346663 1.765473 \n", 920 | "7 1.295694 0.400421 0.098359 0.558956 -0.036970 -0.746547 0.129645 \n", 921 | "8 1.246454 0.038172 -0.059950 -0.096151 -0.195872 -0.812450 0.183660 \n", 922 | "9 2.016774 0.007711 -3.298051 0.125037 3.076508 3.037594 0.199125 \n", 923 | "\n", 924 | " 7 8 9 ... 19 20 21 22 \\\n", 925 | "0 -0.142056 0.291431 -0.926765 ... -0.126591 0.124600 0.572313 -0.245022 \n", 926 | "1 0.692477 -0.523981 -0.419198 ... 0.683867 0.246174 0.160448 0.774357 \n", 927 | "2 -0.068867 2.603357 -0.577862 ... -0.184725 0.043219 0.657304 0.150440 \n", 928 | "3 -0.284605 -0.061795 -0.048532 ... 0.346534 0.240533 0.182956 -0.285509 \n", 929 | "4 -0.317086 2.089984 -0.246140 ... -0.379353 -0.459128 -1.026646 0.373861 \n", 930 | "5 0.074171 -0.038386 0.851980 ... 0.126071 0.041904 -0.151810 0.264565 \n", 931 | "6 -0.569001 -0.578734 0.621377 ... 0.109518 -0.040673 0.372062 -0.705867 \n", 932 | "7 -0.215352 -0.053382 -0.305466 ... -0.012401 -0.313920 -0.862230 0.053621 \n", 933 | "8 -0.112738 -0.199824 0.046081 ... -0.061652 -0.411634 -1.348499 0.148395 \n", 934 | "9 0.576568 -0.306699 0.383016 ... -0.166750 0.131568 0.328318 0.007529 \n", 935 | "\n", 936 | " 23 24 25 26 27 28 \n", 937 | "0 -0.698144 -0.375306 0.715218 0.081376 0.108298 0.0 \n", 938 | "1 -0.064849 -0.023380 -0.242420 0.158846 0.153550 0.0 \n", 939 | "2 0.592132 -0.064754 -0.277614 0.020307 -0.036267 0.0 \n", 940 | "3 0.727483 0.538996 -0.326837 -0.027806 0.068051 0.0 \n", 941 | "4 -0.062467 -0.381783 0.162610 -0.108893 -0.069974 0.0 \n", 942 | "5 -0.440844 -0.564571 -0.486713 -0.012370 -0.038735 0.0 \n", 943 | "6 0.010709 1.003730 -0.266130 -0.277632 -0.368166 0.0 \n", 944 | "7 -0.170914 0.315438 0.125078 -0.018383 0.027028 0.0 \n", 945 | "8 -0.000230 0.084480 0.625746 -0.113259 -0.008212 0.0 \n", 946 | "9 0.701816 0.502323 -0.465696 -0.016936 -0.066347 0.0 \n", 947 | "\n", 948 | "[10 rows x 29 columns]" 949 | ] 950 | }, 951 | "execution_count": 12, 952 | "metadata": {}, 953 | "output_type": "execute_result" 954 | } 955 | ], 956 | "source": [ 957 | "train_data.head(10)" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": 13, 963 | "metadata": {}, 964 | "outputs": [ 965 | { 966 | "data": { 967 | "text/plain": [ 968 | "Text(0, 0.5, 'Frequency')" 969 | ] 970 | }, 971 | "execution_count": 13, 972 | "metadata": {}, 973 | "output_type": "execute_result" 974 | }, 975 | { 976 | "data": { 977 | "image/png": "\n", 978 | "text/plain": [ 979 | "
" 980 | ] 981 | }, 982 | "metadata": {}, 983 | "output_type": "display_data" 984 | } 985 | ], 986 | "source": [ 987 | "labels = ['normal','fraud']\n", 988 | "classes = pd.value_counts(train_data.iloc[:,28], sort = True)\n", 989 | "classes.plot(kind = 'bar', rot=0)\n", 990 | "plt.title(\"Transaction class distribution\")\n", 991 | "plt.xticks(range(2), labels)\n", 992 | "plt.xlabel(\"Class\")\n", 993 | "plt.ylabel(\"Frequency\")" 994 | ] 995 | }, 996 | { 997 | "cell_type": "markdown", 998 | "metadata": {}, 999 | "source": [ 1000 | "After balancing the training set with SMOTE, it contains an equal number of fraud and non-fraud records." 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "markdown", 1005 | "metadata": {}, 1006 | "source": [ 1007 | "## Uploading our dataset to S3" 1008 | ] 1009 | }, 1010 | { 1011 | "cell_type": "code", 1012 | "execution_count": 14, 1013 | "metadata": {}, 1014 | "outputs": [ 1015 | { 1016 | "name": "stdout", 1017 | "output_type": "stream", 1018 | "text": [ 1019 | "Training artifacts will be uploaded to: s3://sagemaker-us-east-1-343208833149/sagemaker/NeuralNetwork-fraud/output\n" 1020 | ] 1021 | } 1022 | ], 1023 | "source": [ 1024 | "train_data.to_csv('train.csv', header=False, index=False)\n", 1025 | "validation_data.to_csv('validation.csv', header=False, index=False)\n", 1026 | "test_data.to_csv('test.csv', header=False, index=False)\n", 1027 | "\n", 1028 | "\n", 1029 | "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')) \\\n", 1030 | " .upload_file('train.csv')\n", 1031 | "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'validation/validation.csv')) \\\n", 1032 | " .upload_file('validation.csv')\n", 1033 | "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'test/test.csv')) \\\n", 1034 | " .upload_file('test.csv')\n", 1035 | "\n", 1036 | "s3_train_data = 's3://{}/{}/train/train.csv'.format(bucket, prefix)\n", 1037 | 
"s3_validation_data = 's3://{}/{}/validation/validation.csv'.format(bucket, prefix)\n", 1038 | "\n", 1039 | "\n", 1040 | "output_location = 's3://{}/{}/output'.format(bucket, prefix)\n", 1041 | "\n", 1042 | "print('Training artifacts will be uploaded to: {}'.format(output_location))" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "markdown", 1047 | "metadata": {}, 1048 | "source": [ 1049 | "## Set Up and Launch Training" 1050 | ] 1051 | }, 1052 | { 1053 | "cell_type": "code", 1054 | "execution_count": 15, 1055 | "metadata": {}, 1056 | "outputs": [], 1057 | "source": [ 1058 | "role = sagemaker.get_execution_role()" 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 16, 1064 | "metadata": {}, 1065 | "outputs": [], 1066 | "source": [ 1067 | "epochs = 50\n", 1068 | "batchsize = 1000\n", 1069 | "key = \"data\"\n", 1070 | "key_output = \"output\" # Path from the bucket's root to the dataset\n", 1071 | "train_instance_type='ml.m4.xlarge' # The type of EC2 instance which will be used for training\n", 1072 | "deploy_instance_type='ml.m4.xlarge' # The type of EC2 instance which will be used for deployment\n", 1073 | "hyperparameters={\n", 1074 | " \"learning_rate\": 1e-4,\n", 1075 | " \"decay\": 1e-6,\n", 1076 | " \"epochs\": epochs, \n", 1077 | " \"batch_size\": batchsize\n", 1078 | "}" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": 17, 1084 | "metadata": {}, 1085 | "outputs": [], 1086 | "source": [ 1087 | "inputs = {'training': s3_train_data, 'validation': s3_validation_data}" 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "code", 1092 | "execution_count": 18, 1093 | "metadata": {}, 1094 | "outputs": [], 1095 | "source": [ 1096 | "model_dir = '/opt/ml/model'" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": null, 1102 | "metadata": { 1103 | "scrolled": false 1104 | }, 1105 | "outputs": [], 1106 | "source": [ 1107 | "my_estimator = TensorFlow(entry_point='train_nn.py',\n", 
1108 | " role=role,\n", 1109 | " model_dir=model_dir,\n", 1110 | " framework_version='1.13', \n", 1111 | " train_instance_count=1,\n", 1112 | " train_instance_type=train_instance_type,\n", 1113 | " #train_instance_type='local',\n", 1114 | " py_version='py3',\n", 1115 | " script_mode=True,\n", 1116 | " base_job_name='Neural-Net-Fraud-Detection',\n", 1117 | " hyperparameters=hyperparameters\n", 1118 | " )\n", 1119 | "my_estimator.fit(inputs=inputs,logs=True)" 1120 | ] 1121 | }, 1122 | { 1123 | "cell_type": "markdown", 1124 | "metadata": {}, 1125 | "source": [ 1126 | "## Deploy our Trained Model on SageMaker Instances" 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "markdown", 1131 | "metadata": {}, 1132 | "source": [ 1133 | "Note: The deployment process may take 5-10 minutes." 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": 20, 1139 | "metadata": { 1140 | "scrolled": true 1141 | }, 1142 | "outputs": [ 1143 | { 1144 | "name": "stdout", 1145 | "output_type": "stream", 1146 | "text": [ 1147 | "---------------------------------------------------------------------------------------------------!"
1148 | ] 1149 | } 1150 | ], 1151 | "source": [ 1152 | "my_estimator.name = 'deployed-neural-net-prediction'\n", 1153 | "my_predictor = my_estimator.deploy(initial_instance_count = 1, instance_type = deploy_instance_type)" 1154 | ] 1155 | }, 1156 | { 1157 | "cell_type": "markdown", 1158 | "metadata": {}, 1159 | "source": [ 1160 | "## Prediction on Test Dataset" 1161 | ] 1162 | }, 1163 | { 1164 | "cell_type": "code", 1165 | "execution_count": 21, 1166 | "metadata": {}, 1167 | "outputs": [], 1168 | "source": [ 1169 | "from sagemaker.predictor import csv_serializer \n", 1170 | "\n", 1171 | "my_predictor.content_type = 'text/csv'\n", 1172 | "my_predictor.serializer = csv_serializer\n", 1173 | "my_predictor.deserializer = None" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": 22, 1179 | "metadata": {}, 1180 | "outputs": [], 1181 | "source": [ 1182 | "import json" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "code", 1187 | "execution_count": 23, 1188 | "metadata": {}, 1189 | "outputs": [ 1190 | { 1191 | "name": "stdout", 1192 | "output_type": "stream", 1193 | "text": [ 1194 | "No of rows to predict = 28481\n" 1195 | ] 1196 | } 1197 | ], 1198 | "source": [ 1199 | "print('No of rows to predict =',len(test_data))" 1200 | ] 1201 | }, 1202 | { 1203 | "cell_type": "code", 1204 | "execution_count": 24, 1205 | "metadata": {}, 1206 | "outputs": [], 1207 | "source": [ 1208 | "y_true = test_data.iloc[:,28:29]" 1209 | ] 1210 | }, 1211 | { 1212 | "cell_type": "code", 1213 | "execution_count": 25, 1214 | "metadata": {}, 1215 | "outputs": [ 1216 | { 1217 | "name": "stderr", 1218 | "output_type": "stream", 1219 | "text": [ 1220 | "/home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ipykernel/__main__.py:12: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n" 1221 | ] 1222 | } 1223 | ], 1224 | "source": [ 1225 | "def predict(data, rows=500):\n", 1226 | " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", 1227 | " predictions = []\n", 1228 | " for array in split_array:\n", 1229 | " x = my_predictor.predict(array).decode('utf-8')\n", 1230 | " x = json.loads(x)\n", 1231 | " x = np.array(x[\"predictions\"])\n", 1232 | " y = len(x)\n", 1233 | " predictions = np.append(predictions,x)\n", 1234 | " return predictions\n", 1235 | "\n", 1236 | "result_out = predict(test_data.as_matrix()[:,0:28])" 1237 | ] 1238 | }, 1239 | { 1240 | "cell_type": "code", 1241 | "execution_count": 26, 1242 | "metadata": {}, 1243 | "outputs": [], 1244 | "source": [ 1245 | "y_pred = (result_out>0.5)" 1246 | ] 1247 | }, 1248 | { 1249 | "cell_type": "code", 1250 | "execution_count": 27, 1251 | "metadata": {}, 1252 | "outputs": [], 1253 | "source": [ 1254 | "y_true = test_data.iloc[:,28:29]" 1255 | ] 1256 | }, 1257 | { 1258 | "cell_type": "code", 1259 | "execution_count": 28, 1260 | "metadata": {}, 1261 | "outputs": [ 1262 | { 1263 | "name": "stdout", 1264 | "output_type": "stream", 1265 | "text": [ 1266 | "No. of rows predicted = 28481\n" 1267 | ] 1268 | } 1269 | ], 1270 | "source": [ 1271 | "print('No. of rows predicted =',len(y_pred))" 1272 | ] 1273 | }, 1274 | { 1275 | "cell_type": "markdown", 1276 | "metadata": {}, 1277 | "source": [ 1278 | "## Analyzing our Results" 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "code", 1283 | "execution_count": 29, 1284 | "metadata": {}, 1285 | "outputs": [ 1286 | { 1287 | "name": "stdout", 1288 | "output_type": "stream", 1289 | "text": [ 1290 | "Number of frauds: 46\n", 1291 | "Number of non-frauds: 28435\n", 1292 | "Percentage of fraudulent data: 0.16151118289385907\n" 1293 | ] 1294 | } 1295 | ], 1296 | "source": [ 1297 | "test_nonfrauds, test_frauds = test_data.groupby('Class').size()\n", 1298 | "print('Number of frauds: ', test_frauds)\n", 1299 | "print('Number of non-frauds: ', test_nonfrauds)\n", 1300 | "print('Percentage of fraudulent data:', 100.*test_frauds/(test_frauds + test_nonfrauds))" 1301 | ] 1302 | }, 1303 | { 1304 | "cell_type": "code", 1305 | "execution_count": 30, 1306 | "metadata": {}, 1307 | "outputs": [ 1308 | { 1309 | "data": { 1310 | "image/png": "\n", 1311 | "text/plain": [ 1312 | "
" 1313 | ] 1314 | }, 1315 | "metadata": {}, 1316 | "output_type": "display_data" 1317 | } 1318 | ], 1319 | "source": [ 1320 | "from sklearn.metrics import confusion_matrix, precision_recall_curve\n", 1321 | "import seaborn as sns\n", 1322 | "import matplotlib.pyplot as plt\n", 1323 | "LABELS = [\"Normal\",\"Fraud\"]\n", 1324 | "conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred)\n", 1325 | "tn, fp, fn, tp = conf_matrix.ravel() \n", 1326 | "plt.figure(figsize=(12, 12))\n", 1327 | "sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt=\"d\");\n", 1328 | "plt.title(\"Confusion matrix\")\n", 1329 | "plt.ylabel('True class')\n", 1330 | "plt.xlabel('Predicted class')\n", 1331 | "plt.show()" 1332 | ] 1333 | }, 1334 | { 1335 | "cell_type": "code", 1336 | "execution_count": 33, 1337 | "metadata": {}, 1338 | "outputs": [ 1339 | { 1340 | "name": "stdout", 1341 | "output_type": "stream", 1342 | "text": [ 1343 | " precision recall f1-score support\n", 1344 | "\n", 1345 | " 0 1.00 1.00 1.00 28435\n", 1346 | " 1 0.61 0.76 0.68 46\n", 1347 | "\n", 1348 | " accuracy 1.00 28481\n", 1349 | " macro avg 0.81 0.88 0.84 28481\n", 1350 | "weighted avg 1.00 1.00 1.00 28481\n", 1351 | "\n" 1352 | ] 1353 | } 1354 | ], 1355 | "source": [ 1356 | "from sklearn.metrics import classification_report\n", 1357 | "print(classification_report(y_true, y_pred))" 1358 | ] 1359 | }, 1360 | { 1361 | "cell_type": "code", 1362 | "execution_count": 36, 1363 | "metadata": {}, 1364 | "outputs": [], 1365 | "source": [ 1366 | "from sklearn.metrics import precision_score,accuracy_score, recall_score\n", 1367 | "accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)\n", 1368 | "precision = precision_score(y_true=y_true, y_pred=y_pred)\n", 1369 | "recall = recall_score(y_true=y_true, y_pred=y_pred)" 1370 | ] 1371 | }, 1372 | { 1373 | "cell_type": "code", 1374 | "execution_count": 37, 1375 | "metadata": {}, 1376 | "outputs": [ 1377 | { 1378 | "name": "stdout", 1379 | 
"output_type": "stream", 1380 | "text": [ 1381 | "\n", 1382 | "Accuracy Score: 1.0\n", 1383 | "\n", 1384 | "Precision Score: 0.61\n", 1385 | "\n", 1386 | "Recall Score: 0.76\n" 1387 | ] 1388 | } 1389 | ], 1390 | "source": [ 1391 | "print (\"\")\n", 1392 | "print (\"Accuracy Score: \", round(accuracy, 2))\n", 1393 | "print (\"\")\n", 1394 | "print (\"Precision Score: \", round(precision, 2))\n", 1395 | "print (\"\")\n", 1396 | "print (\"Recall Score: \", round(recall, 2))" 1397 | ] 1398 | }, 1399 | { 1400 | "cell_type": "markdown", 1401 | "metadata": {}, 1402 | "source": [ 1403 | "## Data Acknowledgements\n", 1404 | "The dataset used to demonstrate the fraud detection solution was collected and analysed during a research collaboration of Worldline and the Machine Learning Group (http://mlg.ulb.ac.be) of ULB (Université Libre de Bruxelles) on big data mining and fraud detection. More details on current and past projects on related topics are available at https://www.researchgate.net/project/Fraud-detection-5 and on the page of the DefeatFraud project. We cite the following works:\n", 1405 | "\n", 1406 | "- Dal Pozzolo, Andrea; Caelen, Olivier; Johnson, Reid A.; Bontempi, Gianluca. Calibrating Probability with Undersampling for Unbalanced Classification. In Symposium on Computational Intelligence and Data Mining (CIDM), IEEE, 2015\n", 1407 | "- Dal Pozzolo, Andrea; Caelen, Olivier; Le Borgne, Yann-Aël; Waterschoot, Serge; Bontempi, Gianluca. Learned lessons in credit card fraud detection from a practitioner perspective, Expert systems with applications,41,10,4915-4928,2014, Pergamon\n", 1408 | "- Dal Pozzolo, Andrea; Boracchi, Giacomo; Caelen, Olivier; Alippi, Cesare; Bontempi, Gianluca. Credit card fraud detection: a realistic modeling and a novel learning strategy, IEEE transactions on neural networks and learning systems,29,8,3784-3797,2018,IEEE\n", 1409 | "- Dal Pozzolo, Andrea. Adaptive Machine learning for credit card fraud detection, ULB MLG PhD thesis (supervised by G. Bontempi)\n", 1410 | "- Carcillo, Fabrizio; Dal Pozzolo, Andrea; Le Borgne, Yann-Aël; Caelen, Olivier; Mazzer, Yannis; Bontempi, Gianluca. Scarff: a scalable framework for streaming credit card fraud detection with Spark, Information fusion,41, 182-194,2018,Elsevier\n", 1411 | "- Carcillo, Fabrizio; Le Borgne, Yann-Aël; Caelen, Olivier; Bontempi, Gianluca. Streaming active learning strategies for real-life credit card fraud detection: assessment and visualization, International Journal of Data Science and Analytics, 5,4,285-300,2018,Springer International Publishing" 1412 | ] 1413 | } 1414 | ], 1415 | "metadata": { 1416 | "kernelspec": { 1417 | "display_name": "conda_tensorflow_p36", 1418 | "language": "python", 1419 | "name": "conda_tensorflow_p36" 1420 | }, 1421 | "language_info": { 1422 | "codemirror_mode": { 1423 | "name": "ipython", 1424 | "version": 3 1425 | }, 1426 | "file_extension": ".py", 1427 | "mimetype": "text/x-python", 1428 | "name": "python", 1429 | "nbconvert_exporter": "python", 1430 | "pygments_lexer": "ipython3", 1431 | "version": "3.6.5" 1432 | } 1433 | }, 1434 | "nbformat": 4, 1435 | "nbformat_minor": 2 1436 | } 1437 | -------------------------------------------------------------------------------- /notebooks/sagemaker_fraud_detection_xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Credit card fraud detector" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Investigate and process the data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | 
"source": [ 21 | "Let's start by downloading and reading in the credit card fraud data set." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "%%bash\n", 31 | "wget https://s3-us-west-2.amazonaws.com/sagemaker-e2e-solutions/fraud-detection/creditcardfraud.zip\n", 32 | "unzip creditcardfraud.zip" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 26, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import numpy as np \n", 42 | "import pandas as pd\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "\n", 45 | "data = pd.read_csv('creditcard.csv', delimiter=',')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "Let's take a peek at our data (we only show a subset of the columns in the table):" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 27, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "Index(['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10',\n", 65 | " 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20',\n", 66 | " 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount',\n", 67 | " 'Class'],\n", 68 | " dtype='object')\n" 69 | ] 70 | }, 71 | { 72 | "data": { 73 | "text/html": [ 74 | "
\n", 75 | "\n", 88 | "\n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " 
\n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | "
 \n", 358 | "
\n", 359 | "
" 360 | ], 361 | "text/plain": [ 362 | " Time V1 V2 V3 V4 V5 V6 V7 \\\n", 363 | "0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n", 364 | "1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n", 365 | "2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n", 366 | "3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n", 367 | "4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n", 368 | "5 2.0 -0.425966 0.960523 1.141109 -0.168252 0.420987 -0.029728 0.476201 \n", 369 | "6 4.0 1.229658 0.141004 0.045371 1.202613 0.191881 0.272708 -0.005159 \n", 370 | "7 7.0 -0.644269 1.417964 1.074380 -0.492199 0.948934 0.428118 1.120631 \n", 371 | "8 7.0 -0.894286 0.286157 -0.113192 -0.271526 2.669599 3.721818 0.370145 \n", 372 | "9 9.0 -0.338262 1.119593 1.044367 -0.222187 0.499361 -0.246761 0.651583 \n", 373 | "\n", 374 | " V8 V9 ... V21 V22 V23 V24 V25 \\\n", 375 | "0 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 0.128539 \n", 376 | "1 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 0.167170 \n", 377 | "2 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 -0.327642 \n", 378 | "3 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 0.647376 \n", 379 | "4 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 -0.206010 \n", 380 | "5 0.260314 -0.568671 ... -0.208254 -0.559825 -0.026398 -0.371427 -0.232794 \n", 381 | "6 0.081213 0.464960 ... -0.167716 -0.270710 -0.154104 -0.780055 0.750137 \n", 382 | "7 -3.807864 0.615375 ... 1.943465 -1.015455 0.057504 -0.649709 -0.415267 \n", 383 | "8 0.851084 -0.392048 ... -0.073425 -0.268092 -0.204233 1.011592 0.373205 \n", 384 | "9 0.069539 -0.736727 ... -0.246914 -0.633753 -0.120794 -0.385050 -0.069733 \n", 385 | "\n", 386 | " V26 V27 V28 Amount Class \n", 387 | "0 -0.189115 0.133558 -0.021053 149.62 0 \n", 388 | "1 0.125895 -0.008983 0.014724 2.69 0 \n", 389 | "2 -0.139097 -0.055353 -0.059752 378.66 0 \n", 390 | "3 -0.221929 0.062723 0.061458 123.50 0 \n", 391 | "4 0.502292 0.219422 0.215153 69.99 0 \n", 392 | "5 0.105915 0.253844 0.081080 3.67 0 \n", 393 | "6 -0.257237 0.034507 0.005168 4.99 0 \n", 394 | "7 -0.051634 -1.206921 -1.085339 40.80 0 \n", 395 | "8 -0.384157 0.011747 0.142404 93.20 0 \n", 396 | "9 0.094199 0.246219 0.083076 3.68 0 \n", 397 | "\n", 398 | "[10 rows x 31 columns]" 399 | ] 400 | }, 401 | "execution_count": 27, 402 | "metadata": {}, 403 | "output_type": "execute_result" 404 | } 405 | ], 406 | "source": [ 407 | "print(data.columns)\n", 408 | "data[['Time', 'V1', 'V2', 'V27', 'V28', 'Amount', 'Class']].describe()\n", 409 | "data.head(10)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": {}, 415 | "source": [ 416 | "The `Class` column indicates whether a transaction is fraudulent. The vast majority of the data is non-fraudulent: only $492$ transactions ($0.173\\%$) are fraudulent."
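To make the imbalance concrete, the sketch below (not part of the original notebook) rebuilds a stand-in label column from the counts reported above, computes the fraud percentage, and shows the accuracy a trivial model would reach by labelling every transaction as legitimate. The `class_labels` Series is a hypothetical substitute for `data['Class']`; with the real data you would use that column directly.

```python
import pandas as pd

# Hypothetical stand-in for data['Class'], rebuilt from the counts reported
# above (284,315 legitimate and 492 fraudulent transactions).
class_labels = pd.Series([0] * 284_315 + [1] * 492, name="Class")

# Equivalent to data.groupby('Class').size(): counts indexed by class 0, 1.
counts = class_labels.value_counts().sort_index()
nonfrauds, frauds = int(counts.loc[0]), int(counts.loc[1])
total = nonfrauds + frauds

fraud_pct = 100.0 * frauds / total
# Accuracy of a trivial model that predicts "non-fraud" for every transaction.
baseline_accuracy = nonfrauds / total

print(f"Fraud percentage: {fraud_pct:.3f}%")                   # 0.173%
print(f"Always-non-fraud accuracy: {baseline_accuracy:.4f}")  # 0.9983
```

A model that never flags fraud already scores about 99.8% accuracy, which is why precision and recall on the fraud class are the more informative metrics for this dataset.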
417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 28, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "Number of frauds: 492\n", 429 | "Number of non-frauds: 284315\n", 430 | "Percentage of fraudulent data: 0.1727485630620034\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "nonfrauds, frauds = data.groupby('Class').size()\n", 436 | "print('Number of frauds: ', frauds)\n", 437 | "print('Number of non-frauds: ', nonfrauds)\n", 438 | "print('Percentage of fraudulent data:', 100.*frauds/(frauds + nonfrauds))" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "This dataset has 31 columns: 28 anonymized features $V_i$ for $i = 1, \\dots, 28$, along with columns for time, amount, and class. The $V_i$ columns are the output of a PCA transformation, so they are already centered at $0$ mean, although their standard deviations differ from column to column. You can read more about PCA here:. \n", 446 | "\n", 447 | "Tip: For our dataset, this amount of preprocessing gives reasonable accuracy, but there are further preprocessing steps one can use to improve it. For imbalanced datasets like ours, where the positive (fraudulent) examples occur much less frequently than the negative (legitimate) examples, we may try “over-sampling” the minority class by generating synthetic data (read about SMOTE in Data Mining for Imbalanced Datasets: An Overview, https://link.springer.com/chapter/10.1007%2F0-387-25465-X_40) or “under-sampling” the majority class by using ensemble methods (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.68.6858&rep=rep1&type=pdf)."
448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 29, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "feature_columns = data.columns[:-1]\n", 457 | "label_column = data.columns[-1]\n", 458 | "\n", 459 | "features = data[feature_columns].values.astype('float32')\n", 460 | "labels = (data[label_column].values).astype('float32')" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "Before training, recall how this data was preprocessed: the $V_i$ features are the output of PCA, while `Time` and `Amount` remain in their original units." 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": {}, 473 | "source": [ 474 | "## SageMaker XGB" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": {}, 480 | "source": [ 481 | "### Prepare Data and Upload to S3" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "The Amazon common libraries (`sagemaker.amazon.common` in the SageMaker Python SDK) provide utilities to convert NumPy n-dimensional arrays into the RecordIO-protobuf format, which SageMaker uses as a compact representation of features and labels. Because RecordIO is implemented with protocol buffers, serialization is very efficient. For XGBoost, however, we train on CSV input: the built-in algorithm expects the label in the first column and no header row, so the next cell moves the `Class` column to the front." 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 30, 494 | "metadata": {}, 495 | "outputs": [ 496 | { 497 | "data": { 498 | "text/html": [ 499 | "
\n", 500 | "\n", 513 | "\n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | "
\n", 663 | "
\n", 664 | "
" 665 | ], 666 | "text/plain": [ 667 | " Class Time V1 V2 V3 V4 V5 V6 \\\n", 668 | "0 0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 \n", 669 | "1 0 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 \n", 670 | "2 0 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 \n", 671 | "3 0 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 \n", 672 | "4 0 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 \n", 673 | "\n", 674 | " V7 V8 ... V20 V21 V22 V23 V24 \\\n", 675 | "0 0.239599 0.098698 ... 0.251412 -0.018307 0.277838 -0.110474 0.066928 \n", 676 | "1 -0.078803 0.085102 ... -0.069083 -0.225775 -0.638672 0.101288 -0.339846 \n", 677 | "2 0.791461 0.247676 ... 0.524980 0.247998 0.771679 0.909412 -0.689281 \n", 678 | "3 0.237609 0.377436 ... -0.208038 -0.108300 0.005274 -0.190321 -1.175575 \n", 679 | "4 0.592941 -0.270533 ... 0.408542 -0.009431 0.798278 -0.137458 0.141267 \n", 680 | "\n", 681 | " V25 V26 V27 V28 Amount \n", 682 | "0 0.128539 -0.189115 0.133558 -0.021053 149.62 \n", 683 | "1 0.167170 0.125895 -0.008983 0.014724 2.69 \n", 684 | "2 -0.327642 -0.139097 -0.055353 -0.059752 378.66 \n", 685 | "3 0.647376 -0.221929 0.062723 0.061458 123.50 \n", 686 | "4 -0.206010 0.502292 0.219422 0.215153 69.99 \n", 687 | "\n", 688 | "[5 rows x 31 columns]" 689 | ] 690 | }, 691 | "execution_count": 30, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "model_data = data\n", 698 | "model_data.head()\n", 699 | "model_data = pd.concat([model_data['Class'], model_data.drop(['Class'], axis=1)], axis=1)\n", 700 | "model_data.head()\n" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": {}, 706 | "source": [ 707 | "### Now we upload the data to S3 using boto3." 
708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 32, 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "Uploaded training data location: s3://sagemaker-us-east-1-282128611277/sagemaker/DEMO-xgboost-fraud/train/train.csv\n", 720 | "Uploaded validation data location: s3://sagemaker-us-east-1-282128611277/sagemaker/DEMO-xgboost-fraud/validation/validation.csv\n", 721 | "Training artifacts will be uploaded to: s3://sagemaker-us-east-1-282128611277/sagemaker/DEMO-xgboost-fraud/output\n" 722 | ] 723 | } 724 | ], 725 | "source": [ 726 | "import boto3\n", 727 | "import os\n", 728 | "import sagemaker\n", 729 | "\n", 730 | "session = sagemaker.Session()\n", 731 | "\n", 732 | "bucket = session.default_bucket()\n", 733 | "sagemaker_iam_role = sagemaker.get_execution_role()\n", 734 | "\n", 735 | "prefix = 'sagemaker/DEMO-xgboost-fraud'\n", 736 | "\n", 737 | "train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), \n", 738 | " [int(0.7 * len(model_data)), int(0.9 * len(model_data))])\n", 739 | "train_data.to_csv('train.csv', header=False, index=False)\n", 740 | "validation_data.to_csv('validation.csv', header=False, index=False)\n", 741 | "\n", 742 | "\n", 743 | "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')) \\\n", 744 | " .upload_file('train.csv')\n", 745 | "boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'validation/validation.csv')) \\\n", 746 | " .upload_file('validation.csv')\n", 747 | "s3_train_data = 's3://{}/{}/train/train.csv'.format(bucket, prefix)\n", 748 | "s3_validation_data = 's3://{}/{}/validation/validation.csv'.format(bucket, prefix)\n", 749 | "print('Uploaded training data location: {}'.format(s3_train_data))\n", 750 | "print('Uploaded validation data location: {}'.format(s3_validation_data))\n", 751 | "\n", 752 | "output_location = 
's3://{}/{}/output'.format(bucket, prefix)\n", 753 | "print('Training artifacts will be uploaded to: {}'.format(output_location))" 754 | ] 755 | }, 756 | { 757 | "cell_type": "markdown", 758 | "metadata": {}, 759 | "source": [ 760 | "---\n", 761 | "## Train\n", 762 | "\n", 763 | "Moving onto training, first we'll need to specify the locations of the XGBoost algorithm containers.\n", 764 | "To specify the Linear Learner algorithm, we use a utility function to obtain it's URI. A complete list of build-in algorithms is found here: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html" 765 | ] 766 | }, 767 | { 768 | "cell_type": "code", 769 | "execution_count": 33, 770 | "metadata": {}, 771 | "outputs": [ 772 | { 773 | "name": "stderr", 774 | "output_type": "stream", 775 | "text": [ 776 | "WARNING:root:There is a more up to date SageMaker XGBoost image.To use the newer image, please set 'repo_version'='0.90-1. For example:\n", 777 | "\tget_image_uri(region, 'xgboost', '0.90-1').\n" 778 | ] 779 | } 780 | ], 781 | "source": [ 782 | "from sagemaker.amazon.amazon_estimator import get_image_uri\n", 783 | "container = get_image_uri(boto3.Session().region_name, 'xgboost')" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "metadata": {}, 789 | "source": [ 790 | "Then, because we're training with the CSV file format, we'll create s3_inputs that our training function can use as a pointer to the files in S3." 
791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": 34, 796 | "metadata": {}, 797 | "outputs": [], 798 | "source": [ 799 | "s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n", 800 | "s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')" 801 | ] 802 | }, 803 | { 804 | "cell_type": "markdown", 805 | "metadata": {}, 806 | "source": [ 807 | "Now we can specify a few parameters, such as the type and number of training instances, as well as our XGBoost hyperparameters. A few key hyperparameters are:\n", 808 | "- `max_depth` controls how deep each tree within the algorithm can be built. Deeper trees can lead to a better fit, but are more computationally expensive and can lead to overfitting. There is typically a trade-off in model performance to explore between a large number of shallow trees and a smaller number of deeper trees.\n", 809 | "- `subsample` controls the fraction of the training data sampled for each tree. This can help reduce overfitting, but setting it too low can starve the model of data.\n", 810 | "- `num_round` controls the number of boosting rounds. Each round adds a tree that is fit to the residual errors of the previous rounds. Again, more rounds should produce a better fit on the training data, but can be computationally expensive or lead to overfitting.\n", 811 | "- `eta` is the step-size shrinkage applied at each boosting round. Smaller values shrink each new tree's contribution and make boosting more conservative; larger values make each round more aggressive.\n", 812 | "- `gamma` sets the minimum loss reduction required to make a further split, so it controls how aggressively trees are grown. Larger values lead to more conservative models.\n", 813 | "\n", 814 | "More detail on XGBoost's hyperparameters can be found on their GitHub [page](https://github.com/dmlc/xgboost/blob/master/doc/parameter.md)."
815 | ] 816 | }, 817 | { 818 | "cell_type": "markdown", 819 | "metadata": {}, 820 | "source": [ 821 | "SageMaker abstracts training with Estimators. We pass the container URI and the training parameters to the estimator, set the XGBoost hyperparameters, and then fit the estimator to the data in S3.\n", 822 | "Note: For IP protection reasons, SageMaker built-in algorithms, such as XGBoost, can't be run locally, i.e. on the same instance where this Jupyter notebook is running." 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": 35, 828 | "metadata": {}, 829 | "outputs": [], 830 | "source": [ 831 | "xgb = sagemaker.estimator.Estimator(container,\n", 832 | " role=sagemaker_iam_role, \n", 833 | " train_instance_count=1, \n", 834 | " train_instance_type='ml.m4.xlarge',\n", 835 | " output_path=output_location,\n", 836 | " sagemaker_session=session)\n", 837 | "xgb.set_hyperparameters(max_depth=5,\n", 838 | " eta=0.2,\n", 839 | " gamma=4,\n", 840 | " min_child_weight=6,\n", 841 | " subsample=0.8,\n", 842 | " silent=0,\n", 843 | " objective='binary:logistic',\n", 844 | " num_round=100)" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": null, 850 | "metadata": {}, 851 | "outputs": [], 852 | "source": [ 853 | "xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})" 854 | ] 855 | }, 856 | { 857 | "cell_type": "markdown", 858 | "metadata": {}, 859 | "source": [ 860 | "### Host XGBoost Model" 861 | ] 862 | }, 863 | { 864 | "cell_type": "markdown", 865 | "metadata": {}, 866 | "source": [ 867 | "Now we deploy the estimator to an endpoint."
868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "execution_count": 38, 873 | "metadata": {}, 874 | "outputs": [ 875 | { 876 | "name": "stderr", 877 | "output_type": "stream", 878 | "text": [ 879 | "WARNING:sagemaker:Using already existing model: xgboost-2019-12-03-21-58-34-726\n" 880 | ] 881 | }, 882 | { 883 | "name": "stdout", 884 | "output_type": "stream", 885 | "text": [ 886 | "---------------------------------------------------------------------------------------------------------------!" 887 | ] 888 | } 889 | ], 890 | "source": [ 891 | "xgb.name = 'deployed-xgboost-fraud-prediction'\n", 892 | "xgb_predictor = xgb.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge',\n", 893 | " endpoint_name='deployed-xgboost-fraud-prediction')" 894 | ] 895 | }, 896 | { 897 | "cell_type": "markdown", 898 | "metadata": {}, 899 | "source": [ 900 | "### Evaluate\n", 901 | "\n", 902 | "Now that we have a hosted endpoint running, we can get real-time predictions from our model \n", 903 | "by making an HTTP POST request. First, though, we need to set up serializers and deserializers for passing our `test_data` NumPy arrays to the model behind the endpoint." 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": 39, 909 | "metadata": {}, 910 | "outputs": [], 911 | "source": [ 912 | "from sagemaker.predictor import csv_serializer \n", 913 | "\n", 914 | "xgb_predictor.content_type = 'text/csv'\n", 915 | "xgb_predictor.serializer = csv_serializer\n", 916 | "xgb_predictor.deserializer = None" 917 | ] 918 | }, 919 | { 920 | "cell_type": "markdown", 921 | "metadata": {}, 922 | "source": [ 923 | "Now, we'll use a simple function to:\n", 924 | "1. Loop over our test dataset\n", 925 | "1. Split it into mini-batches of rows\n", 926 | "1. Convert those mini-batches to CSV string payloads\n", 927 | "1. Retrieve mini-batch predictions by invoking the XGBoost endpoint\n", 928 | "1. 
Collect the predictions and convert the CSV output of our model into a NumPy array" 929 | ] 930 | }, 931 | { 932 | "cell_type": "code", 933 | "execution_count": 40, 934 | "metadata": {}, 935 | "outputs": [], 944 | "source": [ 945 | "def predict(data, rows=500):\n", 946 | " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", 947 | " predictions = ''\n", 948 | " for array in split_array:\n", 949 | " predictions = ','.join([predictions, xgb_predictor.predict(array).decode('utf-8')])\n", 950 | "\n", 951 | " return np.fromstring(predictions[1:], sep=',')\n", 952 | "\n", 953 | "predictions = predict(test_data.values[:, 1:])" 954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": {}, 959 | "source": [ 960 | "There are many ways to evaluate a machine learning model, but let's start by comparing actual to predicted values. In this case, we're predicting whether a transaction is fraudulent (1) or not (0), which produces a simple confusion matrix."
961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": 41, 966 | "metadata": {}, 967 | "outputs": [ 968 | { 969 | "name": "stdout", 970 | "output_type": "stream", 971 | "text": [ 972 | "Number of frauds: 54\n", 973 | "Number of non-frauds: 28427\n", 974 | "Percentage of fraudulent data: 0.1896000842667041\n" 975 | ] 976 | } 977 | ], 978 | "source": [ 979 | "test_nonfrauds, test_frauds = test_data.groupby('Class').size()\n", 980 | "print('Number of frauds: ', test_frauds)\n", 981 | "print('Number of non-frauds: ', test_nonfrauds)\n", 982 | "print('Percentage of fraudulent data:', 100.*test_frauds/(test_frauds + test_nonfrauds))" 983 | ] 984 | }, 985 | { 986 | "cell_type": "code", 987 | "execution_count": 42, 988 | "metadata": {}, 989 | "outputs": [ 990 | { 991 | "data": { 992 | "text/html": [ 993 | "
\n", 994 | "\n", 1007 | "\n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | "
\n", 1033 | "
" 1034 | ], 1035 | "text/plain": [ 1036 | "predictions 0.0 1.0\n", 1037 | "actual \n", 1038 | "0 28426 1\n", 1039 | "1 14 40" 1040 | ] 1041 | }, 1042 | "execution_count": 42, 1043 | "metadata": {}, 1044 | "output_type": "execute_result" 1045 | } 1046 | ], 1047 | "source": [ 1048 | "pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "execution_count": 43, 1054 | "metadata": {}, 1055 | "outputs": [ 1056 | { 1057 | "name": "stdout", 1058 | "output_type": "stream", 1059 | "text": [ 1060 | "precision: 0.98\n", 1061 | "recall: 0.74\n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "# precision: tp / (tp + fp)\n", 1067 | "# recall: tp / (tp + fn)\n", 1068 | "from sklearn.metrics import precision_recall_fscore_support\n", 1069 | "results = precision_recall_fscore_support(test_data.iloc[:, 0],\n", 1070 | " np.round(predictions))\n", 1071 | "print('precision: ', round(results[0][1], 2))\n", 1072 | "print('recall: ', round(results[1][1], 2))" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "markdown", 1077 | "metadata": {}, 1078 | "source": [ 1079 | "Note: due to randomized elements of the algorithm, your results may differ slightly.\n", 1080 | "\n", 1081 | "Of the 54 fraudulent transactions, we correctly predicted 40 (true positives). We also flagged 1 legitimate transaction as fraud (false positive). The remaining 14 fraudulent transactions were classified as legitimate (false negatives), which can be very expensive.\n", 1082 | "\n", 1083 | "An important point here is that, because of the `np.round()` call above, we are using a simple threshold (or cutoff) of 0.5. Our predictions from XGBoost are continuous values between 0 and 1, and we force them into the binary classes that we began with. So, we should consider adjusting this cutoff. 
That will almost certainly increase the number of false positives, but it can also be expected to increase the number of true positives and reduce the number of false negatives.\n", 1084 | "\n", 1085 | "To get a rough intuition here, let's look at the continuous values of our predictions." 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "code", 1090 | "execution_count": 44, 1091 | "metadata": {}, 1092 | "outputs": [ 1093 | { 1094 | "data": { 1095 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEOpJREFUeJzt3H+s3XV9x/HnSyrOTR3VVkKgW5nWZJVliA12cdlQFigssZgZAolSCbFGYdHNLKL7AwOSSBY1IUFcDQ1lUYH5YzSxrmsYC3FZkTth/BzjDlHaVegogguZDnzvj/OpHvq5l3u49/ae3vb5SE7u97y/n+/3+/60hdf9/jgnVYUkScNeNu4GJEmHHsNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJnSXjbmC2li1bVitXrhx3G5K0aCxbtozt27dvr6p1M41dtOGwcuVKJiYmxt2GJC0qSZaNMs7LSpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkzqL9hPRcrLz0W2M57qOf+eOxHFeSXirPHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktSZMRySrEhyW5IHktyf5COt/qkku5Pc3V5nD23ziSSTSR5KcuZQfV2rTSa5dKh+YpI7Wv2mJEfP90QlSaMb5czhOeBjVbUaWAtcnGR1W/f5qjq5vbYBtHXnAW8G1gFfSHJUkqOAa4CzgNXA+UP7uart643AU8BF8zQ/SdIszBgOVbWnqr7Xln8CPAgc/yKbrAdurKqfVtX3gUng1PaarKpHqupnwI3A+iQB3gl8rW2/BThnthOSJM3dS7rnkGQl8Bbgjla6JMk9STYnWdpqxwOPDW22q9Wmq78O+HFVPXdAXZI0JiOHQ5JXAV8HPlpVzwDXAm8ATgb2AJ89KB2+sIeNSSaSTOzdu/dgH06SjlgjhUOSlzMIhi9X1TcAqurxqnq+qn4OfInBZSOA3cCKoc1PaLXp6k8CxyRZckC9U1WbqmpNVa1Zvnz5KK1LkmZhlKeVAlwHPFhVnxuqHzc07N3AfW15K3BeklckORFYBXwXuBNY1Z5MOprBTeutVVXAbcB72vYbgFvmNi1J0lwsmXkIbwfeB9yb5O5W+ySDp41OBgp4FPggQFXdn+Rm4AEGTzpdXFXPAyS5BNgOHAVsrqr72/4+DtyY5NPAXQzCSJI0JjOGQ1V9B8gUq7a9yDZXAldOUd821XZV9Qi/vCwlSRozPyEtSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSerMGA5JV
iS5LckDSe5P8pFWf22SHUkebj+XtnqSXJ1kMsk9SU4Z2teGNv7hJBuG6m9Ncm/b5uokORiTlSSNZpQzh+eAj1XVamAtcHGS1cClwK1VtQq4tb0HOAtY1V4bgWthECbAZcDbgFOBy/YHShvzgaHt1s19apKk2ZoxHKpqT1V9ry3/BHgQOB5YD2xpw7YA57Tl9cANNbATOCbJccCZwI6q2ldVTwE7gHVt3WuqamdVFXDD0L4kSWPwku45JFkJvAW4Azi2qva0VT8Cjm3LxwOPDW22q9VerL5rirokaUxGDockrwK+Dny0qp4ZXtd+46957m2qHjYmmUgysXfv3oN9OEk6Yo0UDkleziAYvlxV32jlx9slIdrPJ1p9N7BiaPMTWu3F6idMUe9U1aaqWlNVa5YvXz5K65KkWRjlaaUA1wEPVtXnhlZtBfY/cbQBuGWofkF7amkt8HS7/LQdOCPJ0nYj+gxge1v3TJK17VgXDO1LkjQGS0YY83bgfcC9Se5utU8CnwFuTnIR8APg3LZuG3A2MAk8C1wIUFX7klwB3NnGXV5V+9ryh4HrgVcC324vSdKYzBgOVfUdYLrPHZw+xfgCLp5mX5uBzVPUJ4CTZupFkrQw/IS0JKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOjOGQ5LNSZ5Ict9Q7VNJdie5u73OHlr3iSSTSR5KcuZQfV2rTSa5dKh+YpI7Wv2mJEfP5wQlSS/dKGcO1wPrpqh/vqpObq9tAElWA+cBb27bfCHJUUmOAq4BzgJWA+e3sQBXtX29EXgKuGguE5Ikzd2M4VBVtwP7RtzfeuDGqvppVX0fmAROba/Jqnqkqn4G3AisTxLgncDX2vZbgHNe4hwkSfNsLvccLklyT7vstLTVjgceGxqzq9Wmq78O+HFVPXdAXZI0RrMNh2uBNwAnA3uAz85bRy8iycYkE0km9u7duxCHlKQj0qzCoaoer6rnq+rnwJcYXDYC2A2sGBp6QqtNV38SOCbJkgPq0x13U1Wtqao1y5cvn03rkqQRzCockhw39PbdwP4nmbYC5yV5RZITgVXAd4E7gVXtyaSjGdy03lpVBdwGvKdtvwG4ZTY9SZLmz5KZBiT5KnAasCzJLuAy4LQkJwMFPAp8EKCq7k9yM/AA8BxwcVU93/ZzCbAdOArYXFX3t0N8HLgxyaeBu4Dr5m12kqRZmTEcqur8KcrT/g+8qq4Erpyivg3YNkX9EX55WUqSdAjwE9KSpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpI7hIEnqGA6SpM6M4ZBkc5Inktw3VHttkh1JHm4/l7Z6klydZDLJPUlOGdpmQxv/cJINQ/W3Jrm3bXN1ksz3JCVJL80oZw7XA+sOqF0K3FpVq4Bb23uAs4BV7bURuBYGYQJcBrwNOBW4bH+gtDEfGNruwGNJkhbYjOFQVbcD+w4orwe2tOUtwDlD9RtqYCdwTJLjgDOBHVW1r6qeAnYA69q611TVzqoq4IahfUmSxmS29xyOrao9bflHwLFt+XjgsaFxu1rtxeq7pqhLksZozjek22/8NQ+9zCjJxiQTSSb27t27EIeUpCPSbMPh8XZJiPbziVbfDawYGndCq71Y/YQp6lOqqk1Vtaaq1ixfvnyWrUuSZjLbcNgK7H/iaANwy1D9gvbU0lrg6Xb5aTtwRpKl7Ub0GcD2tu6ZJGvbU0oXDO1LkjQmS2YakOSrwGnAsiS7GDx19Bng5iQXAT8Azm3Dt
wFnA5PAs8CFAFW1L8kVwJ1t3OVVtf8m94cZPBH1SuDb7SVJGqMZw6Gqzp9m1elTjC3g4mn2sxnYPEV9Ajhppj4kSQvHT0hLkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpM6dwSPJoknuT3J1kotVem2RHkofbz6WtniRXJ5lMck+SU4b2s6GNfzjJhrlNSZI0V/Nx5vCOqjq5qta095cCt1bVKuDW9h7gLGBVe20EroVBmACXAW8DTgUu2x8okqTxOBiXldYDW9ryFuCcofoNNbATOCbJccCZwI6q2ldVTwE7gHUHoS9J0ojmGg4F/EOSf02ysdWOrao9bflHwLFt+XjgsaFtd7XadPVOko1JJpJM7N27d46tS5Kms2SO2/9+Ve1O8npgR5J/H15ZVZWk5niM4f1tAjYBrFmzZt72K0l6oTmdOVTV7vbzCeCbDO4ZPN4uF9F+PtGG7wZWDG1+QqtNV5ckjcmswyHJryV59f5l4AzgPmArsP+Jow3ALW15K3BBe2ppLfB0u/y0HTgjydJ2I/qMVpMkjclcLisdC3wzyf79fKWq/j7JncDNSS4CfgCc28ZvA84GJoFngQsBqmpfkiuAO9u4y6tq3xz6kiTN0azDoaoeAX53ivqTwOlT1Au4eJp9bQY2z7YXSdL88hPSkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqTOIRMOSdYleSjJZJJLx92PJB3JDolwSHIUcA1wFrAaOD/J6vF2JUlHrkMiHIBTgcmqeqSqfgbcCKwfc0+SdMQ6VMLheOCxofe7Wk2SNAZLxt3AS5FkI7Cxvf2fJA/NclfLgP+en65Gl6sW+ogvMJY5j5lzPvwdafOFuc155O0OlXDYDawYen9Cq71AVW0CNs31YEkmqmrNXPezmDjnI8ORNucjbb6wcHM+VC4r3QmsSnJikqOB84CtY+5Jko5Yh8SZQ1U9l+QSYDtwFLC5qu4fc1uSdMQ6JMIBoKq2AdsW6HBzvjS1CDnnI8ORNucjbb6wQHNOVS3EcSRJi8ihcs9BknQIOazDYaav5EjyiiQ3tfV3JFm58F3OnxHm++dJHkhyT5Jbk/zmOPqcT6N+7UqSP0lSSRb9ky2jzDnJue3v+v4kX1noHufbCP+2fyPJbUnuav++zx5Hn/MpyeYkTyS5b5r1SXJ1+zO5J8kp89pAVR2WLwY3tv8T+C3gaODfgNUHjPkw8MW2fB5w07j7PsjzfQfwq235Q4t5vqPOuY17NXA7sBNYM+6+F+DveRVwF7C0vX/9uPtegDlvAj7UllcDj46773mY9x8ApwD3TbP+bODbQIC1wB3zefzD+cxhlK/kWA9sactfA05PkgXscT7NON+quq2qnm1vdzL4PMliNurXrlwBXAX870I2d5CMMucPANdU1VMAVfXEAvc430aZcwGvacu/DvzXAvZ3UFTV7cC+FxmyHrihBnYCxyQ5br6OfziHwyhfyfGLMVX1HPA08LoF6W7+vdSvILmIwW8di9mMc26n2iuq6lsL2dhBNMrf85uANyX55yQ7k6xbsO4OjlHm/CngvUl2MXjq8U8XprWxOqhfO3TIPMqqhZPkvcAa4A/H3cvBlORlwOeA94+5lYW2hMGlpdMYnB3enuR3qurHY+3q4DofuL6qPpvk94C/SXJSVf183I0tVofzmcMoX8nxizFJljA4HX1yQbqbfyN9BUmSPwL+EnhXVf10gXo7WGaa86uBk4B/SvIog+uyW
xf5TelR/p53AVur6v+q6vvAfzAIi8VqlDlfBNwMUFX/AvwKg+8gOpyN9N/8bB3O4TDKV3JsBTa05fcA/1jtTs8iNON8k7wF+GsGwbDYr0PDDHOuqqerallVrayqlQzus7yrqibG0+68GOXf9d8xOGsgyTIGl5keWcgm59koc/4hcDpAkt9mEA57F7TLhbcVuKA9tbQWeLqq9szXzg/by0o1zVdyJLkcmKiqrcB1DE4/Jxnc+DlvfB3PzYjz/SvgVcDftvvuP6yqd42t6Tkacc6HlRHnvB04I8kDwPPAX1TVYj0jHnXOHwO+lOTPGNycfv8i/kUPgCRfZRDyy9q9lMuAlwNU1RcZ3Fs5G5gEngUunNfjL/I/P0nSQXA4X1aSJM2S4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6vw/asn5cMei2roAAAAASUVORK5CYII=\n", 1096 | "text/plain": [ 1097 | "
" 1098 | ] 1099 | }, 1100 | "metadata": {}, 1101 | "output_type": "display_data" 1102 | } 1103 | ], 1104 | "source": [ 1105 | "plt.hist(predictions)\n", 1106 | "plt.show()" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "metadata": {}, 1112 | "source": [ 1113 | "By varying the cutoff threshold, we can trade false positives for false negatives. " 1114 | ] 1115 | }, 1116 | { 1117 | "cell_type": "code", 1118 | "execution_count": 45, 1119 | "metadata": {}, 1120 | "outputs": [ 1121 | { 1122 | "data": { 1123 | "text/html": [ 1124 | "
\n", 1125 | "\n", 1138 | "\n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | "
\n", 1164 | "
" 1165 | ], 1166 | "text/plain": [ 1167 | "col_0 0 1\n", 1168 | "Class \n", 1169 | "0 28411 16\n", 1170 | "1 7 47" 1171 | ] 1172 | }, 1173 | "execution_count": 45, 1174 | "metadata": {}, 1175 | "output_type": "execute_result" 1176 | } 1177 | ], 1178 | "source": [ 1179 | "pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.04, 1, 0))" 1180 | ] 1181 | }, 1182 | { 1183 | "cell_type": "code", 1184 | "execution_count": 46, 1185 | "metadata": {}, 1186 | "outputs": [ 1187 | { 1188 | "name": "stdout", 1189 | "output_type": "stream", 1190 | "text": [ 1191 | "precision: 0.75\n", 1192 | "recall: 0.87\n" 1193 | ] 1194 | } 1195 | ], 1196 | "source": [ 1197 | "results = precision_recall_fscore_support(test_data.iloc[:, 0],\n", 1198 | " np.where(predictions > 0.04, 1, 0))\n", 1199 | "print('precision: ', round(results[0][1], 2))\n", 1200 | "print('recall: ', round(results[1][1], 2))" 1201 | ] 1202 | }, 1203 | { 1204 | "cell_type": "code", 1205 | "execution_count": 47, 1206 | "metadata": {}, 1207 | "outputs": [ 1208 | { 1209 | "data": { 1210 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEOpJREFUeJzt3H+s3XV9x/HnSyrOTR3VVkKgW5nWZJVliA12cdlQFigssZgZAolSCbFGYdHNLKL7AwOSSBY1IUFcDQ1lUYH5YzSxrmsYC3FZkTth/BzjDlHaVegogguZDnzvj/OpHvq5l3u49/ae3vb5SE7u97y/n+/3+/60hdf9/jgnVYUkScNeNu4GJEmHHsNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJnSXjbmC2li1bVitXrhx3G5K0aCxbtozt27dvr6p1M41dtOGwcuVKJiYmxt2GJC0qSZaNMs7LSpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkzqL9hPRcrLz0W2M57qOf+eOxHFeSXirPHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktSZMRySrEhyW5IHktyf5COt/qkku5Pc3V5nD23ziSSTSR5KcuZQfV2rTSa5dKh+YpI7Wv2mJEfP90QlSaMb5czhOeBjVbUaWAtcnGR1W/f5qjq5vbYBtHXnAW8G1gFfSHJUkqOAa4CzgNXA+UP7uart643AU8BF8zQ/SdIszBgOVbWnqr7Xln8CPAgc/yKbrAdurKqfVtX3gUng1PaarKpHqupnwI3A+iQB3gl8rW2/BThnthOSJM3dS7rnkGQl8Bbgjla6JMk9STYnWdpqxwOPDW22q9Wmq78O+HFVPXdAXZI0JiOHQ5JXAV8HPlpVzwDXAm8ATgb2AJ89KB2+sIeNSSaSTOzdu/dgH06SjlgjhUOSlzMIhi9X1TcAqurxqnq+qn4OfInBZSOA3cCKoc1PaLXp6k8CxyRZckC9U1WbqmpNVa1Zvnz5KK1LkmZhlKeVAlwHPFhVnxuqHzc07N3AfW15K3BeklckORFYBXwXuBNY1Z5MOprBTeutVVXAbcB72vYbgFvmNi1J0lwsmXkIbwfeB9yb5O5W+ySDp41OBgp4FPggQFXdn+Rm4AEGTzpdXFXPAyS5BNgOHAVsrqr72/4+DtyY5NPAXQzCSJI0JjOGQ1V9B8gUq7a9yDZXAldOUd821XZV9Qi/vCwlSRozPyEtSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSerMGA5JViS5LckDSe5P8pFWf22SHUkebj+XtnqSXJ1kMsk9SU4Z2teGNv7hJBuG6m9Ncm/b5uokORiTlSSNZpQzh+eAj1XVamAtcHGS1cClwK1VtQq4tb0HOAtY1V4bgWthECbAZcDbgFOBy/YHShvzgaHt1s19apKk2ZoxHKpqT1V9ry3/BHgQOB5YD2xpw7YA57Tl9cANNbATOCbJccCZwI6q2ldVTwE7gHVt3WuqamdVFXDD0L4kSWPwku45JFkJvAW4Azi2qva0VT8Cjm3LxwOPDW22q9VerL5rirokaUxGDockrwK+Dny0qp4ZXtd+46957m2qHjYmmUgysXfv3oN9OEk6Yo0UDkleziAYvlxV32jlx9slIdrPJ1p9N7BiaPMTWu3F6idMUe9U1aaqWlNVa5YvXz5K65KkWRjlaaUA1wEPVtXnhlZtBfY/cbQBuGWofkF7amkt8HS7/LQdOCP
\n", 
1211 | "text/plain": [ 1212 | "
" 1213 | "] 1214 | }, 1215 | "metadata": {}, 1216 | "output_type": "display_data" 1217 | } 1218 | ], 1219 | "source": [ 1220 | "plt.hist(predictions)\n", 1221 | "plt.show()" 1222 | ] 1223 | }, 1224 | { 1225 | "cell_type": "markdown", 1226 | "metadata": {}, 1227 | "source": [ 1228 | "### Relative cost of errors\n", 1229 | "\n", 1230 | "Any practical binary classification problem is likely to produce a similarly sensitive cutoff. \n", 1231 | "If we put an ML model into production, there are costs associated with the model's false positives and false negatives. Because the choice of cutoff affects all four of these statistics (true and false positives and negatives), we need to consider the relative cost to the business of each of these four outcomes for each prediction.\n", 1232 | "\n", 1233 | "#### Assigning costs\n", 1234 | "\n", 1235 | "What are the costs for our fraud detection problem? The costs, of course, depend on the specific actions that the business takes. Let's make some assumptions here.\n", 1236 | "\n", 1237 | "First, assign the cost of \\$0.00 to both the true negatives (correctly recognized benign transactions) and the true positives (correctly recognized fraudulent transactions), because the model made the correct call in both cases. One could assign a benefit (i.e. a negative cost) to correctly detected fraud, but we do not do that here.\n", 1238 | "\n", 1239 | "False negatives are the most problematic, because they represent fraudulent transactions that slipped through our model. Based on some Internet research (see sources below), we assign a cost of \\$450.00 to each false negative.\n", 1240 | "\n", 1241 | "Finally, false positives are genuine transactions that our model would block as fraud. Each one results in an annoyed customer who might close the credit card account and move to another bank. We assume that it costs a \\$500.00 sign-on bonus to acquire a new
credit card customer and that 5 percent of annoyed customers would defect. \n", 1242 | "\n", 1243 | "Sources:\n", 1244 | "\n", 1245 | "https://www.creditcards.com/credit-card-news/credit-card-security-id-theft-fraud-statistics-1276.php\n", 1246 | "https://wallethub.com/edu/cc/credit-debit-card-fraud-statistics/25725/\n" 1247 | ] 1248 | }, 1249 | { 1250 | "cell_type": "markdown", 1251 | "metadata": {}, 1252 | "source": [ 1253 | "#### Finding the optimal cutoff\n", 1254 | "\n", 1255 | "It’s clear that false negatives are substantially more costly than false positives. We want to minimize a cost function that looks like this:\n", 1256 | "\n", 1257 | "```txt\n", 1258 | "$450 * FN(C) + $0 * TN(C) + 0.05*$500 * FP(C) + $0 * TP(C)\n", 1259 | "```\n", 1260 | "\n", 1261 | "FN(C) means that the number of false negatives is a function of the cutoff, C, and similarly for TN, FP, and TP. We need to find the cutoff, C, where the result of the expression is smallest.\n", 1262 | "\n", 1263 | "A straightforward way to do this is to run a simulation over a large number of possible cutoffs. The for loop below tests every cutoff from 0.01 to 0.99 in steps of 0.01." 1264 | ] 1265 | }, 1266 | { 1267 | "cell_type": "code", 1268 | "execution_count": 48, 1269 | "metadata": {}, 1270 | "outputs": [ 1271 | { 1272 | "data": { 1273 | "image/png": "\n", 1274 | "text/plain": [ 1275 | "
" 1276 | ] 1277 | }, 1278 | "metadata": {}, 1279 | "output_type": "display_data" 1280 | }, 1281 | { 1282 | "name": "stdout", 1283 | "output_type": "stream", 1284 | "text": [ 1285 | "Cost is minimized near a cutoff of: 0.04\n" 1286 | ] 1287 | } 1288 | ], 1289 | "source": [ 1290 | "TN_cost = 0\n", 1291 | "TP_cost = 0\n", 1292 | "FP_cost = 0.05*500 # cost of losing an annoyed customer (assuming 5% defection and a $500 sign-on bonus)\n", 1293 | "FN_cost = 450 # cost of letting a fraudulent transaction slip through\n", 1294 | "\n", 1295 | "cutoffs = np.arange(0.01, 1, 0.01)\n", 1296 | "costs = []\n", 1297 | "for c in cutoffs:\n", 1298 | " costs.append(np.sum(np.sum(np.array([[TN_cost, FP_cost], [FN_cost, TP_cost]]) * \n", 1299 | " pd.crosstab(index=test_data.iloc[:, 0], \n", 1300 | " columns=np.where(predictions > c, 1, 0)).reindex(columns=[0, 1], fill_value=0))))\n", 1301 | "\n", 1302 | "costs = np.array(costs)\n", 1303 | "plt.plot(cutoffs, costs)\n", 1304 | "plt.show()\n", 1305 | "print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)])" 1306 | ] 1307 | }, 1308 | { 1309 | "cell_type": "markdown", 1310 | "metadata": {}, 1311 | "source": [ 1312 | "## Clean up\n", 1313 | "\n", 1314 | "We will leave the prediction endpoint running at the end of this notebook so we can handle incoming event streams. However, don't forget to delete the prediction endpoint when you're done. You can do that on the Endpoints page of the Amazon SageMaker console.
Or you can run `xgb_predictor.delete_endpoint()`" 1315 | ] 1316 | }, 1317 | { 1318 | "cell_type": "code", 1319 | "execution_count": 49, 1320 | "metadata": {}, 1321 | "outputs": [], 1322 | "source": [ 1323 | "#xgb_predictor.delete_endpoint()" 1324 | ] 1325 | }, 1326 | { 1327 | "cell_type": "markdown", 1328 | "metadata": {}, 1329 | "source": [ 1330 | "\n", 1331 | "## Data Acknowledgements\n", 1332 | "\n", 1333 | "The dataset used to demonstrate the fraud detection solution was collected and analysed during a research collaboration of Worldline and the Machine Learning Group (http://mlg.ulb.ac.be) of ULB (Université Libre de Bruxelles) on big data mining and fraud detection. More details on current and past projects on related topics are available at https://www.researchgate.net/project/Fraud-detection-5 and on the page of the [DefeatFraud](https://mlg.ulb.ac.be/wordpress/portfolio_page/defeatfraud-assessment-and-validation-of-deep-feature-engineering-and-learning-solutions-for-fraud-detection/) project.\n", 1334 | "We cite the following works:\n", 1335 | "* Andrea Dal Pozzolo, Olivier Caelen, Reid A. Johnson and Gianluca Bontempi. Calibrating Probability with Undersampling for Unbalanced Classification. In Symposium on Computational Intelligence and Data Mining (CIDM), IEEE, 2015\n", 1336 | "* Dal Pozzolo, Andrea; Caelen, Olivier; Le Borgne, Yann-Ael; Waterschoot, Serge; Bontempi, Gianluca. Learned lessons in credit card fraud detection from a practitioner perspective, Expert systems with applications,41,10,4915-4928,2014, Pergamon\n", 1337 | "* Dal Pozzolo, Andrea; Boracchi, Giacomo; Caelen, Olivier; Alippi, Cesare; Bontempi, Gianluca. Credit card fraud detection: a realistic modeling and a novel learning strategy, IEEE transactions on neural networks and learning systems,29,8,3784-3797,2018,IEEE\n", 1338 | "* Dal Pozzolo, Andrea. Adaptive Machine learning for credit card fraud detection. ULB MLG PhD thesis (supervised by G.
Bontempi)\n", 1339 | "* Carcillo, Fabrizio; Dal Pozzolo, Andrea; Le Borgne, Yann-Aël; Caelen, Olivier; Mazzer, Yannis; Bontempi, Gianluca. Scarff: a scalable framework for streaming credit card fraud detection with Spark, Information fusion,41, 182-194,2018,Elsevier\n", 1340 | "* Carcillo, Fabrizio; Le Borgne, Yann-Aël; Caelen, Olivier; Bontempi, Gianluca. Streaming active learning strategies for real-life credit card fraud detection: assessment and visualization, International Journal of Data Science and Analytics, 5,4,285-300,2018,Springer International Publishing" 1341 | ] 1342 | } 1343 | ], 1344 | "metadata": { 1345 | "kernelspec": { 1346 | "display_name": "conda_mxnet_p36", 1347 | "language": "python", 1348 | "name": "conda_mxnet_p36" 1349 | }, 1350 | "language_info": { 1351 | "codemirror_mode": { 1352 | "name": "ipython", 1353 | "version": 3 1354 | }, 1355 | "file_extension": ".py", 1356 | "mimetype": "text/x-python", 1357 | "name": "python", 1358 | "nbconvert_exporter": "python", 1359 | "pygments_lexer": "ipython3", 1360 | "version": "3.6.5" 1361 | } 1362 | }, 1363 | "nbformat": 4, 1364 | "nbformat_minor": 2 1365 | } 1366 | -------------------------------------------------------------------------------- /notebooks/train_nn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | from tensorflow.contrib.eager.python import tfe 6 | from keras.models import Sequential 7 | from keras.layers import Dense, Dropout, Activation 8 | from keras.layers import Embedding 9 | from keras.layers import Conv1D, GlobalMaxPooling1D 10 | import pandas as pd 11 | from sklearn.metrics import classification_report 12 | from sklearn.metrics import precision_score,accuracy_score 13 | 14 | tf.logging.set_verbosity(tf.logging.ERROR) 15 | 16 | max_features = 20000 17 | maxlen = 400 18 | embedding_dims = 300 19 | filters = 250 20 | kernel_size = 3 21 | hidden_dims = 250 
22 | 23 | 24 | def parse_args(): 25 | 26 | parser = argparse.ArgumentParser() 27 | 28 | # hyperparameters sent by the client are passed as command-line arguments to the script 29 | parser.add_argument('--epochs', type=int, default=1) 30 | parser.add_argument('--batch_size', type=int, default=64) 31 | 32 | # data directories 33 | parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAINING')) 34 | parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION')) 35 | 36 | # model directory: we will use the default set by SageMaker, /opt/ml/model 37 | parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR')) 38 | 39 | return parser.parse_known_args() 40 | 41 | 42 | def get_train_data(train_dir): 43 | 44 | #print('train_dir = ',train_dir) 45 | 46 | train = pd.read_csv(os.path.join(train_dir,'train.csv'), delimiter=',') 47 | #print(train.head(10)) 48 | 49 | x_train = train.iloc[:, 0:28] # columns 0-27: input features 50 | y_train = train.iloc[:, 28:29] # column 28: binary class label 51 | 52 | #print('x train', x_train.shape,'y train', y_train.shape) 53 | 54 | 55 | return x_train, y_train 56 | 57 | 58 | def get_test_data(test_dir): 59 | 60 | test = pd.read_csv(os.path.join(test_dir,'validation.csv'), delimiter=',') 61 | #print(test.head(10)) 62 | 63 | x_test = test.iloc[:, 0:28] # columns 0-27: input features 64 | y_test = test.iloc[:, 28:29] # column 28: binary class label 65 | #print('x test', x_test.shape,'y test', y_test.shape) 66 | 67 | return x_test, y_test 68 | 69 | 70 | def get_model(): 71 | 72 | model = tf.keras.models.Sequential([ 73 | tf.keras.layers.Dense(28, input_dim=28, activation='relu'), 74 | tf.keras.layers.Dense(8, activation='relu'), 75 | tf.keras.layers.Dense(1, activation='sigmoid') 76 | ]) 77 | 78 | 79 | 80 | model.compile(optimizer='adam', 81 | loss='binary_crossentropy', 82 | metrics=['accuracy']) 83 | 84 | 85 | return model 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | args, _ = parse_args() 91 | 92 | print(args) 93 | 94 | x_train, y_train = get_train_data(args.train) 95
| x_test, y_test = get_test_data(args.test) 96 | 97 | model = get_model() 98 | 99 | model.summary() 100 | 101 | model.fit( 102 | x_train, 103 | y_train, 104 | epochs=args.epochs, 105 | batch_size=args.batch_size, 106 | validation_data=(x_test, y_test) 107 | ) 108 | 109 | # create a TensorFlow SavedModel for deployment to a SageMaker endpoint with TensorFlow Serving 110 | 111 | tf.contrib.saved_model.save_keras_model(model, args.model_dir) --------------------------------------------------------------------------------
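The cost-based cutoff search in `sagemaker_fraud_detection_xgb.ipynb` above can be restated as a small standalone function. What follows is a sketch under stated assumptions, not code from this repository:

* It assumes 0/1 labels and probability-like scores.
* It reuses the notebook's dollar figures: $450 per false negative, and 5% of a $500 sign-on bonus per false positive.
* The names `total_cost` and `best_cutoff` are hypothetical.

The sketch counts each confusion-matrix cell directly with NumPy rather than multiplying a cost matrix by a `pd.crosstab`. As a result, a cutoff that yields only one predicted class still produces all four counts, with the missing cells set to zero.

```python
import numpy as np

# Hypothetical helpers (not part of the repository); dollar figures follow
# the assumptions stated in the notebook's "Assigning costs" section.
TN_COST = 0.0
TP_COST = 0.0
FP_COST = 0.05 * 500.0  # 5% of blocked genuine customers defect; $500 to replace each
FN_COST = 450.0         # a fraudulent transaction that slips through


def total_cost(y_true, scores, cutoff):
    """Total dollar cost of flagging every transaction whose score exceeds `cutoff`."""
    y_true = np.asarray(y_true).astype(int)
    y_pred = (np.asarray(scores) > cutoff).astype(int)
    # Count each confusion-matrix cell directly, so a cutoff that yields
    # only one predicted class still produces all four counts (some zero).
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tp = np.sum((y_true == 1) & (y_pred == 1))
    return float(TN_COST * tn + FP_COST * fp + FN_COST * fn + TP_COST * tp)


def best_cutoff(y_true, scores, cutoffs=None):
    """Grid-search the cutoff that minimizes total_cost; returns (cutoff, cost)."""
    if cutoffs is None:
        cutoffs = np.arange(0.01, 1.0, 0.01)  # same grid as the notebook's loop
    costs = np.array([total_cost(y_true, scores, c) for c in cutoffs])
    best = int(np.argmin(costs))  # first cutoff that reaches the minimum cost
    return float(cutoffs[best]), float(costs[best])


if __name__ == "__main__":
    # Hypothetical toy data: three genuine (0) and two fraudulent (1) transactions.
    labels = np.array([0, 0, 0, 1, 1])
    scores = np.array([0.015, 0.105, 0.305, 0.055, 0.905])
    print(best_cutoff(labels, scores))
```

On the toy data this should print `(0.02, 50.0)`. At that cutoff, two genuine transactions are blocked at $25 each and no fraud slips through. Any higher cutoff lets the 0.055-scored fraud through at $450. With the notebook's variables, the analogous call would presumably be `best_cutoff(test_data.iloc[:, 0], predictions)`.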