├── .gitignore ├── README.md ├── churn_predictor.ipynb ├── churn_predictor.py ├── dataset │   └── .gitignore └── download_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Churn Predictor Sample Project 2 | 3 | > This is a Turi sample for churn prediction. Explore the [gallery](https://turi.com/learn/gallery/) to see other examples. 4 | 5 | This sample code shows how to build, evaluate, and deploy a 6 | model to predict customer churn. You can use this model to find 7 | customers who are likely to stop using your product or service. 8 | 9 | 10 | ## Get started 11 | 12 | 1. Before you begin, make sure you have [installed GraphLab Create 1.9](https://turi.com/download/), 13 | a Python package for machine learning. 14 | 15 | 2. [Download and extract the example code](https://github.com/turi-code/sample-churn-predictor/archive/master.zip) 16 | to a directory on your machine, or clone it with the following command: 17 | 18 | ```bash 19 | git clone https://github.com/turi-code/sample-churn-predictor 20 | cd sample-churn-predictor 21 | ``` 22 | 23 | 3. While in the `sample-churn-predictor` directory, run the following script 24 | to download the sample project data: 25 | 26 | ```bash 27 | python download_data.py 28 | ``` 29 | 30 | 4. From a Python environment with GraphLab Create installed, 31 | run the `churn_predictor.py` script to build and explore the model on your machine: 32 | 33 | ```bash 34 | python -i churn_predictor.py 35 | ``` 36 | 37 | The `-i` flag causes Python to drop into an interactive interpreter 38 | after the script executes. 39 | 40 | Once the model has been created, a browser window should open 41 | to let you explore and interact with your model. 42 | 43 | Alternatively, you can run the provided IPython Notebook: 44 | 45 | ```bash 46 | ipython notebook churn_predictor.ipynb 47 | ``` 48 | 49 | 50 | ## Troubleshooting 51 | 52 | If you are having trouble, please [create a GitHub issue](https://github.com/turi-code/sample-churn-predictor/issues/new) 53 | or start a discussion on the [user forum](http://forum.turi.com/).
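A quick first check is whether GraphLab Create imports at all. Assuming a standard installation (the package exposes a `graphlab.version` string), this one-liner prints the installed version:

```bash
python -c "import graphlab; print graphlab.version"
```

## Next steps

The training output printed by `churn_predictor.py` suggests two follow-ups: evaluating the model at a timestamp in the past, and making a churn forecast as of a later one. Here is a minimal sketch, meant to be run from the interactive session started above so that `model`, `valid`, and `timeseries` are already defined; the dates are illustrative choices, not prescribed by the sample:

```python
import datetime

# Evaluate the model at the training boundary (a timestamp in the past)
metrics = model.evaluate(valid, datetime.datetime(2011, 8, 1))

# Make a churn forecast as of a later date
predictions = model.predict(timeseries, datetime.datetime(2011, 10, 1))
```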
54 | -------------------------------------------------------------------------------- /churn_predictor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "2016-04-19 14:02:12,212 [INFO] graphlab.deploy._session, 584: Using session dir: /Users/zach/graphlab-dev2/local_scripts/../artifacts\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import graphlab as gl\n", 20 | "import datetime\n", 21 | "from dateutil import parser as datetime_parser" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Load Data" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": false 36 | }, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "This commercial license of GraphLab Create is assigned to engr@turi.com.\n" 43 | ] 44 | }, 45 | { 46 | "name": "stderr", 47 | "output_type": "stream", 48 | "text": [ 49 | "2016-04-19 14:02:15,541 [INFO] graphlab.cython.cy_server, 176: GraphLab Create v1.9 started. Logging: /tmp/graphlab_server_1461099733.log\n" 50 | ] 51 | }, 52 | { 53 | "data": { 54 | "text/html": [ 55 | "
Finished parsing file /Users/zach/sample-churn-predictor/dataset/online_retail.csv
" 56 | ], 57 | "text/plain": [ 58 | "Finished parsing file /Users/zach/sample-churn-predictor/dataset/online_retail.csv" 59 | ] 60 | }, 61 | "metadata": {}, 62 | "output_type": "display_data" 63 | }, 64 | { 65 | "data": { 66 | "text/html": [ 67 | "
Parsing completed. Parsed 100 lines in 8.38356 secs.
" 68 | ], 69 | "text/plain": [ 70 | "Parsing completed. Parsed 100 lines in 8.38356 secs." 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | }, 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "------------------------------------------------------\n", 81 | "Inferred types from first line of file as \n", 82 | "column_type_hints=[int,str,str,int,str,float,int,str]\n", 83 | "If parsing fails due to incorrect types, you can correct\n", 84 | "the inferred type list above and pass it to read_csv in\n", 85 | "the column_type_hints argument\n", 86 | "------------------------------------------------------\n" 87 | ] 88 | }, 89 | { 90 | "data": { 91 | "text/html": [ 92 | "
Finished parsing file /Users/zach/sample-churn-predictor/dataset/online_retail.csv
" 93 | ], 94 | "text/plain": [ 95 | "Finished parsing file /Users/zach/sample-churn-predictor/dataset/online_retail.csv" 96 | ] 97 | }, 98 | "metadata": {}, 99 | "output_type": "display_data" 100 | }, 101 | { 102 | "data": { 103 | "text/html": [ 104 | "
Parsing completed. Parsed 541909 lines in 4.73211 secs.
" 105 | ], 106 | "text/plain": [ 107 | "Parsing completed. Parsed 541909 lines in 4.73211 secs." 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | } 113 | ], 114 | "source": [ 115 | "# Table of product purchases\n", 116 | "purchases = gl.SFrame.read_csv('dataset/online_retail.csv')" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "# Prepare Data\n", 124 | "\n", 125 | "Convert the datetime strings to Python datetimes and create a GraphLab Create TimeSeries\n", 126 | "from the `InvoiceDate` column." 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 3, 132 | "metadata": { 133 | "collapsed": true 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# Convert InvoiceDate strings (e.g. \"12/1/10 8:26\") to datetimes\n", 138 | "purchases['InvoiceDate'] = purchases['InvoiceDate'].apply(datetime_parser.parse)\n", 139 | "\n", 140 | "# Create a TimeSeries\n", 141 | "timeseries = gl.TimeSeries(purchases, 'InvoiceDate')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# Train Churn Predictor Model\n", 149 | "\n", 150 | "A churn forecast requires a time boundary and a churn period.\n", 151 | "Activity before the boundary is used to train the model.\n", 152 | "After the boundary, activity (or lack of activity)\n", 153 | "during the churn period is used to define whether the\n", 154 | "user churned." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 4, 160 | "metadata": { 161 | "collapsed": false 162 | }, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "PROGRESS: Grouping observation_data by user.\n", 169 | "PROGRESS: Resampling grouped observation_data by time-period 1 day, 0:00:00.\n", 170 | "PROGRESS: Generating features for time-boundary.\n", 171 | "PROGRESS: --------------------------------------------------\n", 172 | "PROGRESS: Features for 2011-07-31 17:00:00.\n", 173 | "PROGRESS: Training a classifier model.\n" 174 | ] 175 | }, 176 | { 177 | "data": { 178 | "text/html": [ 179 | "
Boosted trees classifier:
" 180 | ], 181 | "text/plain": [ 182 | "Boosted trees classifier:" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | }, 188 | { 189 | "data": { 190 | "text/html": [ 191 | "
--------------------------------------------------------
" 192 | ], 193 | "text/plain": [ 194 | "--------------------------------------------------------" 195 | ] 196 | }, 197 | "metadata": {}, 198 | "output_type": "display_data" 199 | }, 200 | { 201 | "data": { 202 | "text/html": [ 203 | "
Number of examples          : 2555
" 204 | ], 205 | "text/plain": [ 206 | "Number of examples : 2555" 207 | ] 208 | }, 209 | "metadata": {}, 210 | "output_type": "display_data" 211 | }, 212 | { 213 | "data": { 214 | "text/html": [ 215 | "
Number of classes           : 2
" 216 | ], 217 | "text/plain": [ 218 | "Number of classes : 2" 219 | ] 220 | }, 221 | "metadata": {}, 222 | "output_type": "display_data" 223 | }, 224 | { 225 | "data": { 226 | "text/html": [ 227 | "
Number of feature columns   : 10
" 228 | ], 229 | "text/plain": [ 230 | "Number of feature columns : 10" 231 | ] 232 | }, 233 | "metadata": {}, 234 | "output_type": "display_data" 235 | }, 236 | { 237 | "data": { 238 | "text/html": [ 239 | "
Number of unpacked features : 145
" 240 | ], 241 | "text/plain": [ 242 | "Number of unpacked features : 145" 243 | ] 244 | }, 245 | "metadata": {}, 246 | "output_type": "display_data" 247 | }, 248 | { 249 | "data": { 250 | "text/html": [ 251 | "
+-----------+--------------+-------------------+-------------------+
" 252 | ], 253 | "text/plain": [ 254 | "+-----------+--------------+-------------------+-------------------+" 255 | ] 256 | }, 257 | "metadata": {}, 258 | "output_type": "display_data" 259 | }, 260 | { 261 | "data": { 262 | "text/html": [ 263 | "
| Iteration | Elapsed Time | Training-accuracy | Training-log_loss |
" 264 | ], 265 | "text/plain": [ 266 | "| Iteration | Elapsed Time | Training-accuracy | Training-log_loss |" 267 | ] 268 | }, 269 | "metadata": {}, 270 | "output_type": "display_data" 271 | }, 272 | { 273 | "data": { 274 | "text/html": [ 275 | "
+-----------+--------------+-------------------+-------------------+
" 276 | ], 277 | "text/plain": [ 278 | "+-----------+--------------+-------------------+-------------------+" 279 | ] 280 | }, 281 | "metadata": {}, 282 | "output_type": "display_data" 283 | }, 284 | { 285 | "data": { 286 | "text/html": [ 287 | "
| 1         | 0.027765     | 0.841879          | 0.568897          |
" 288 | ], 289 | "text/plain": [ 290 | "| 1 | 0.027765 | 0.841879 | 0.568897 |" 291 | ] 292 | }, 293 | "metadata": {}, 294 | "output_type": "display_data" 295 | }, 296 | { 297 | "data": { 298 | "text/html": [ 299 | "
| 2         | 0.052122     | 0.852055          | 0.498716          |
" 300 | ], 301 | "text/plain": [ 302 | "| 2 | 0.052122 | 0.852055 | 0.498716 |" 303 | ] 304 | }, 305 | "metadata": {}, 306 | "output_type": "display_data" 307 | }, 308 | { 309 | "data": { 310 | "text/html": [ 311 | "
| 3         | 0.074933     | 0.856360          | 0.454076          |
" 312 | ], 313 | "text/plain": [ 314 | "| 3 | 0.074933 | 0.856360 | 0.454076 |" 315 | ] 316 | }, 317 | "metadata": {}, 318 | "output_type": "display_data" 319 | }, 320 | { 321 | "data": { 322 | "text/html": [ 323 | "
| 4         | 0.113640     | 0.858317          | 0.424838          |
" 324 | ], 325 | "text/plain": [ 326 | "| 4 | 0.113640 | 0.858317 | 0.424838 |" 327 | ] 328 | }, 329 | "metadata": {}, 330 | "output_type": "display_data" 331 | }, 332 | { 333 | "data": { 334 | "text/html": [ 335 | "
| 5         | 0.136425     | 0.858708          | 0.404621          |
" 336 | ], 337 | "text/plain": [ 338 | "| 5 | 0.136425 | 0.858708 | 0.404621 |" 339 | ] 340 | }, 341 | "metadata": {}, 342 | "output_type": "display_data" 343 | }, 344 | { 345 | "data": { 346 | "text/html": [ 347 | "
| 6         | 0.163361     | 0.861448          | 0.391001          |
" 348 | ], 349 | "text/plain": [ 350 | "| 6 | 0.163361 | 0.861448 | 0.391001 |" 351 | ] 352 | }, 353 | "metadata": {}, 354 | "output_type": "display_data" 355 | }, 356 | { 357 | "data": { 358 | "text/html": [ 359 | "
+-----------+--------------+-------------------+-------------------+
" 360 | ], 361 | "text/plain": [ 362 | "+-----------+--------------+-------------------+-------------------+" 363 | ] 364 | }, 365 | "metadata": {}, 366 | "output_type": "display_data" 367 | }, 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "PROGRESS: --------------------------------------------------\n", 373 | "PROGRESS: Model training complete: Next steps\n", 374 | "PROGRESS: --------------------------------------------------\n", 375 | "PROGRESS: (1) Evaluate the model at various timestamps in the past:\n", 376 | "PROGRESS: metrics = model.evaluate(data, time_in_past)\n", 377 | "PROGRESS: (2) Make a churn forecast for a timestamp in the future:\n", 378 | "PROGRESS: predictions = model.predict(data, time_in_future)\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "# Split the data into train and validation\n", 384 | "train, valid = gl.churn_predictor.random_split(timeseries, user_id='CustomerID', fraction=0.8, seed = 1)\n", 385 | "\n", 386 | "# Train the model using data before August\n", 387 | "churn_boundary_oct = datetime.datetime(year = 2011, month = 8, day = 1)\n", 388 | "# Define churn as \"inactive for 30 days after August 1st 2011\"\n", 389 | "churn_period = datetime.timedelta(days = 30)\n", 390 | "\n", 391 | "model = gl.churn_predictor.create(train, user_id='CustomerID',\n", 392 | " features = ['Quantity'],\n", 393 | " churn_period = churn_period,\n", 394 | " time_boundaries = [churn_boundary_oct])" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "# Explore and Evaluate the Model" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 5, 407 | "metadata": { 408 | "collapsed": false, 409 | "scrolled": false 410 | }, 411 | "outputs": [ 412 | { 413 | "name": "stderr", 414 | "output_type": "stream", 415 | "text": [ 416 | "2016-04-19 14:04:16,633 [WARNING] graphlab.toolkits.churn_predictor._churn_predictor, 1881: This feature is currently in beta. Please use with caution and not in mission-critical applications. For feedback or suggestions on this feature, please e-mail product-feedback@turi.com.\n" 417 | ] 418 | }, 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "PROGRESS: Making a churn forecast for the time window:\n", 424 | "PROGRESS: --------------------------------------------------\n", 425 | "PROGRESS: Start : 2011-08-01 00:00:00\n", 426 | "PROGRESS: End : 2011-08-31 00:00:00\n", 427 | "PROGRESS: --------------------------------------------------\n", 428 | "PROGRESS: Grouping dataset by user.\n", 429 | "PROGRESS: Resampling grouped observation_data by time-period 1 day, 0:00:00.\n", 430 | "PROGRESS: Generating features for boundary 2011-08-01 00:00:00.\n", 431 | "PROGRESS: Not enough data to make predictions for 1170 user(s). \n", 432 | "PROGRESS: Making a churn forecast for the time window:\n", 433 | "PROGRESS: --------------------------------------------------\n", 434 | "PROGRESS: Start : 2011-08-01 00:00:00\n", 435 | "PROGRESS: End : 2011-08-31 00:00:00\n", 436 | "PROGRESS: --------------------------------------------------\n", 437 | "PROGRESS: Grouping dataset by user.\n", 438 | "PROGRESS: Resampling grouped observation_data by time-period 1 day, 0:00:00.\n", 439 | "PROGRESS: Generating features for boundary 2011-08-01 00:00:00.\n", 440 | "PROGRESS: Not enough data to make predictions for 254 user(s). 
\n", 441 | "PROGRESS: Making a churn forecast for the time window:\n", 442 | "PROGRESS: --------------------------------------------------\n", 443 | "PROGRESS: Start : 2011-08-01 00:00:00\n", 444 | "PROGRESS: End : 2011-08-31 00:00:00\n", 445 | "PROGRESS: --------------------------------------------------\n", 446 | "PROGRESS: Grouping dataset by user.\n", 447 | "PROGRESS: Resampling grouped observation_data by time-period 1 day, 0:00:00.\n", 448 | "PROGRESS: Generating features for boundary 2011-08-01 00:00:00.\n", 449 | "PROGRESS: Not enough data to make predictions for 254 user(s). \n" 450 | ] 451 | } 452 | ], 453 | "source": [ 454 | "# Interactively explore churn predictions\n", 455 | "view = model.views.overview(exploration_set=timeseries,\n", 456 | " validation_set=valid,\n", 457 | " exploration_time=churn_boundary_oct,\n", 458 | " validation_time=churn_boundary_oct)\n", 459 | "view.show()" 460 | ] 461 | } 462 | ], 463 | "metadata": { 464 | "kernelspec": { 465 | "display_name": "Python 2", 466 | "language": "python", 467 | "name": "python2" 468 | }, 469 | "language_info": { 470 | "codemirror_mode": { 471 | "name": "ipython", 472 | "version": 2 473 | }, 474 | "file_extension": ".py", 475 | "mimetype": "text/x-python", 476 | "name": "python", 477 | "nbconvert_exporter": "python", 478 | "pygments_lexer": "ipython2", 479 | "version": "2.7.11" 480 | } 481 | }, 482 | "nbformat": 4, 483 | "nbformat_minor": 0 484 | } 485 | -------------------------------------------------------------------------------- /churn_predictor.py: -------------------------------------------------------------------------------- 1 | import graphlab as gl 2 | import datetime 3 | from dateutil import parser as datetime_parser 4 | 5 | 6 | ### Load Data ### 7 | 8 | # Table of product purchases 9 | purchases = gl.SFrame.read_csv('dataset/online_retail.csv') 10 | 11 | 12 | ### Prepare Data ### 13 | 14 | # Convert InvoiceDate strings (e.g. "12/1/10 8:26") to datetimes 15 | purchases['InvoiceDate'] = purchases['InvoiceDate'].apply(datetime_parser.parse) 16 | 17 | # Create a TimeSeries 18 | timeseries = gl.TimeSeries(purchases, 'InvoiceDate') 19 | 20 | 21 | ### Train the churn predictor model ### 22 | 23 | # Split the data into train and validation 24 | train, valid = gl.churn_predictor.random_split(timeseries, user_id='CustomerID', fraction=0.8, seed = 1) 25 | 26 | # A churn forecast requires a time boundary and a churn period. 27 | # Activity before the boundary is used to train the model. 28 | # After the boundary, activity (or lack of activity) 29 | # during the churn period is used to define whether the 30 | # user churned. 
31 | 32 | # Train the model using data before August 1, 2011 33 | churn_boundary_aug = datetime.datetime(year=2011, month=8, day=1) 34 | # Define churn as "inactive for 30 days after August 1st 2011" 35 | churn_period = datetime.timedelta(days=30) 36 | 37 | model = gl.churn_predictor.create(train, user_id='CustomerID', 38 | features=['Quantity'], 39 | churn_period=churn_period, 40 | time_boundaries=[churn_boundary_aug]) 41 | 42 | 43 | ### Explore the Model ### 44 | 45 | # Interactively explore churn predictions 46 | view = model.views.overview(exploration_set=timeseries, 47 | validation_set=valid, 48 | exploration_time=churn_boundary_aug, 49 | validation_time=churn_boundary_aug) 50 | view.show() 51 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !.gitignore 4 | 5 | -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # Downloads datasets for sample projects. 5 | # 6 | 7 | from os import path 8 | import urllib 9 | import zipfile 10 | 11 | def download(url, target_file): 12 | """Download a file from the url to the given target filename.""" 13 | 14 | def report(blocks, block_size, total_size): 15 | percent = min((100 * blocks * block_size) / total_size, 100) # clamp: the final block can overshoot total_size 16 | last_percent = min((100 * (blocks - 1) * block_size) / total_size, 100) 17 | 18 | if percent > last_percent: 19 | print " {0:.0f}% complete".format(percent) 20 | 21 | print "Downloading %s to %s..." % (url, target_file) 22 | urllib.urlretrieve(url, target_file, report) 23 | print "Download complete." 24 | 25 | def unzip(archive, target_dir): 26 | """Unzip a zip archive to the given target directory.""" 27 | 28 | print "Extracting %s to %s..." % (archive, target_dir) 29 | with zipfile.ZipFile(archive, 'r') as zfile: 30 | zfile.extractall(target_dir) 31 | print "Files extracted." 32 | 33 | if __name__ == '__main__': 34 | 35 | download("https://static.turi.com/datasets/churn-prediction/online_retail.csv", path.join('dataset', 'online_retail.csv')) 36 | --------------------------------------------------------------------------------