├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data
│   └── text.csv
├── model
│   └── params.yaml
├── notebook
│   └── test_churn_prediction.ipynb
├── requirements.txt
└── scripts
    ├── create_dataset.py
    ├── interpret.py
    ├── preprocess.py
    └── train.py
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can.
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Churn Prediction with Text and Interpretability 2 | 3 | Customer churn, the loss of current customers, is a problem faced by a wide range of companies. When trying to retain customers, it is in a company’s best interest to focus its efforts on the customers who are most likely to leave, which requires detecting them before they have decided to leave. Users prone to churn often leave clues to their disposition in user behavior and customer support chat logs, which can be detected and understood using Natural Language Processing (NLP) tools. 4 | 5 | Here, we demonstrate how to build a churn prediction model that leverages both text and structured (numerical and categorical) data, an approach we call a bi-modal model architecture. We use Amazon SageMaker to prepare, build, and train the model. Detecting customers who are likely to churn is only part of the battle; finding the root cause is essential to actually solving the issue. Since we are interested not only in the likelihood of a customer churning but also in the driving factors, we complement the prediction model with a feature importance analysis for both text and non-text inputs. 6 | 7 | The categorical and numerical data come from the Kaggle competition [Customer Churn Prediction 2020](https://www.kaggle.com/c/customer-churn-prediction-2020/data) and were combined with a synthetic text dataset that we created using GPT-2.
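The bi-modal idea can be illustrated in a few lines of plain Python. This is a minimal sketch under stated assumptions, not the project's implementation: a hashed bag-of-words vector stands in for the BERT-based sentence encoder, only two structured columns (`international_plan`, `number_customer_service_calls`) are used, the weights are arbitrary, and the helper names (`embed_text`, `one_hot`, `build_features`, `churn_probability`, `keyword_effect`) are hypothetical. The last helper mirrors the interpretability idea of a keyword's marginal contribution: delete the keyword from the chat, re-score, and report the change in churn probability.

```python
import hashlib
import math

# HYPOTHETICAL helpers for illustration only; they are not the repo's API.


def embed_text(text, dim=8):
    """Stand-in text encoder: L2-normalised hashed bag-of-words vector.

    Assumption: the real project uses a BERT-based sentence encoder instead.
    """
    vec = [0.0] * dim
    for token in text.lower().split():
        # md5 keeps the token-to-bucket mapping stable across Python runs
        idx = int(hashlib.md5(token.encode("utf-8")).hexdigest(), 16) % dim
        vec[idx] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by 0
    return [v / norm for v in vec]


def one_hot(value, categories):
    """Encode a categorical value as a 0/1 vector over known categories."""
    return [1.0 if value == c else 0.0 for c in categories]


def build_features(record, dim=8):
    """Bi-modal input: text vector + categorical one-hot + numerical values."""
    text_part = embed_text(record["chat_log"], dim)
    cat_part = one_hot(record["international_plan"], ("no", "yes"))
    num_part = [float(record["number_customer_service_calls"])]
    return text_part + cat_part + num_part


def churn_probability(features, weights, bias=0.0):
    """Logistic score over the joint feature vector (toy linear model)."""
    z = bias + sum(w * x for w, x in zip(weights, features))
    return 1.0 / (1.0 + math.exp(-z))


def keyword_effect(record, keyword, weights, dim=8):
    """Change in churn probability after deleting `keyword` from the chat.

    Tokens are split on whitespace only, so 'cancel,' does not match 'cancel'.
    """
    base = churn_probability(build_features(record, dim), weights)
    kept = [t for t in record["chat_log"].split() if t.lower() != keyword]
    ablated = dict(record, chat_log=" ".join(kept))
    return churn_probability(build_features(ablated, dim), weights) - base


record = {
    "chat_log": "I want to cancel my contract",
    "international_plan": "no",
    "number_customer_service_calls": 2,
}
weights = [0.5] * 11  # 8 text dims + 2 one-hot + 1 numeric; arbitrary values
print(len(build_features(record)))  # 11
print(keyword_effect(record, "cancel", weights))
```

Concatenating the modalities into one vector lets a single classifier weigh chat content against account features; the scripts in ../scripts implement the full pipeline with a real sentence encoder.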
8 | 9 | ## Blog Post 10 | 11 | [Medium / Towards Data Science blog post](https://towardsdatascience.com/customer-churn-prediction-with-text-and-interpretability-bd3d57af34b1) 12 | 13 | ## Installation 14 | 15 | ``` 16 | git clone https://github.com/aws-samples/churn-prediction-with-text-and-interpretability.git 17 | conda create -n py39 python=3.9 18 | conda activate py39 19 | cd churn-prediction-with-text-and-interpretability 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## Download categorical/numerical data and combine with synthetic text data 24 | 25 | 1. Download the categorical/numerical data from [Customer Churn Prediction 2020](https://www.kaggle.com/c/customer-churn-prediction-2020/data). 26 | This may require a Kaggle account. 27 | Download train.csv and store it in the data folder. 28 | 29 | 2. From ../scripts, run the script that combines the categorical data with the synthetic text data: 30 | ``` 31 | python create_dataset.py 32 | ``` 33 | 34 | ## Run in Notebook 35 | 36 | An example notebook that runs the entire pipeline and prints/visualizes the results is included in ../notebook. 37 | 38 | ## Run in Terminal 39 | 40 | The Python scripts to prepare the data, train and evaluate the model, and interpret the model are stored in ../scripts and should be run from that directory. 41 | The parameters used for training and interpreting the model are stored in ../model/params.yaml. 42 | 43 | 44 | 1. Prepare the data: 45 | ``` 46 | python preprocess.py 47 | ``` 48 | 2. Train and evaluate the model: 49 | ``` 50 | python train.py 51 | ``` 52 | 3.
Interpret the trained model (text): 53 | ``` 54 | python interpret.py --churn 1 --speaker Customer 55 | ``` 56 | 57 | ## Credits 58 | 59 | * Packages: 60 | * [spaCy](https://spacy.io/usage/linguistic-features/) 61 | * [PyTorch](https://pytorch.org/) 62 | * [XGBoost](https://xgboost.readthedocs.io/en/latest/) 63 | * [Hugging Face Sentence Transformers](https://huggingface.co/sentence-transformers) 64 | 65 | * Datasets: 66 | * [Customer Churn Prediction 2020](https://www.kaggle.com/c/customer-churn-prediction-2020/data) (with synthetic text dataset) 67 | 68 | * Models: 69 | * GPT-2, Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever 70 | * BERT, Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova 71 | * Sentence-BERT (Sentence Embeddings using Siamese BERT-Networks), Nils Reimers, Iryna Gurevych 72 | 73 | ## Security 74 | 75 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 76 | 77 | ## License 78 | 79 | This library is licensed under the MIT-0 License. See the LICENSE file.
80 | 81 | -------------------------------------------------------------------------------- /model/params.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ../data 2 | model_dir: ../model 3 | batch_size: 10 4 | batch_size_test: 1000 5 | epochs: 10 6 | pos_weight: 10 7 | lr: 0.001 8 | momentum: 0.9 9 | topn_relevant_keywords: 1000 10 | frac_relevant_keywords: 0.25 11 | w_marg_contr: 0.3 12 | w_count: 0.2 13 | w_sim: 0.5 -------------------------------------------------------------------------------- /notebook/test_churn_prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "df276921", 6 | "metadata": {}, 7 | "source": [ 8 | "# Churn Prediction with Text and Interpretability" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1c1f7791", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebook runs the entire churn prediction pipeline, from data preparation to model evaluation and interpretation.\n", 17 | "\n", 18 | "Alternatively, the whole pipeline can be run from the terminal (see README.md).\n", 19 | "\n", 20 | "Prerequisite: the dataset has been created (see README.md)."
21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "032bb6ce", 26 | "metadata": {}, 27 | "source": [ 28 | "### Setup" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "id": "a32ade35", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import os\n", 39 | "import pandas as pd\n", 40 | "from matplotlib import pyplot as plt\n", 41 | "\n", 42 | "os.chdir(\"../scripts\")\n", 43 | "\n", 44 | "import preprocess\n", 45 | "import train\n", 46 | "import interpret" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "id": "9d500e76", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "%load_ext autoreload\n", 57 | "%autoreload 2" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "4b8b1336", 63 | "metadata": {}, 64 | "source": [ 65 | "### Load and Prepare the Data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "94c30e10", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/html": [ 77 | "
[HTML table output of df.head(3) (3 rows × 21 columns); the same rows appear in the text/plain output below]
" 195 | ], 196 | "text/plain": [ 197 | " churn chat_log state \\\n", 198 | "0 no Customer: Well, the only thing that I'm consid... CT \n", 199 | "1 yes Customer: Well, I just want to be able to canc... WV \n", 200 | "2 no Customer: I would like data.\\nTelCom Agent: Ok... IN \n", 201 | "\n", 202 | " account_length area_code international_plan voice_mail_plan \\\n", 203 | "0 134 area_code_408 no no \n", 204 | "1 78 area_code_408 no no \n", 205 | "2 88 area_code_415 no no \n", 206 | "\n", 207 | " number_vmail_messages total_day_minutes total_day_calls ... \\\n", 208 | "0 0 177.2 91 ... \n", 209 | "1 0 226.3 88 ... \n", 210 | "2 0 183.5 93 ... \n", 211 | "\n", 212 | " total_eve_minutes total_eve_calls total_eve_charge total_night_minutes \\\n", 213 | "0 228.7 105 19.44 194.3 \n", 214 | "1 306.2 81 26.03 200.9 \n", 215 | "2 170.5 80 14.49 193.8 \n", 216 | "\n", 217 | " total_night_calls total_night_charge total_intl_minutes \\\n", 218 | "0 113 8.74 8.9 \n", 219 | "1 120 9.04 7.8 \n", 220 | "2 88 8.72 8.3 \n", 221 | "\n", 222 | " total_intl_calls total_intl_charge number_customer_service_calls \n", 223 | "0 3 2.40 2 \n", 224 | "1 11 2.11 1 \n", 225 | "2 5 2.24 3 \n", 226 | "\n", 227 | "[3 rows x 21 columns]" 228 | ] 229 | }, 230 | "execution_count": 4, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "df = pd.read_csv('../data/churn_dataset.csv')\n", 237 | "df.head(3)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "ede78f6b", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "X_train, X_test, y_train, y_test = preprocess.prep_data(df, use_existing=False, test_size=0.33)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 27, 253 | "id": "5d0f6d4f", 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "data": { 258 | "text/plain": [ 259 | "((2233, 841), (1100, 841), (2233, 1), (1100, 1))" 260 | ] 261 | }, 262 | 
"execution_count": 27, 263 | "metadata": {}, 264 | "output_type": "execute_result" 265 | } 266 | ], 267 | "source": [ 268 | "X_train.shape, X_test.shape, y_train.shape, y_test.shape" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "id": "c477f751", 274 | "metadata": {}, 275 | "source": [ 276 | "### Train and Evaluate the Model" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 28, 282 | "id": "0d68bf05", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "/home/ec2-user/SageMaker/churn_test/scripts\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "!pwd" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 29, 300 | "id": "eb2e2878", 301 | "metadata": { 302 | "collapsed": true, 303 | "jupyter": { 304 | "outputs_hidden": true 305 | } 306 | }, 307 | "outputs": [ 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "starting epoch: 1\n", 313 | "Train Epoch: 1, train-auc-score: 0.9561\n", 314 | "test_auc_score: 0.9422\n", 315 | "starting epoch: 2\n", 316 | "Train Epoch: 2, train-auc-score: 0.9571\n", 317 | "test_auc_score: 0.9453\n", 318 | "starting epoch: 3\n", 319 | "Train Epoch: 3, train-auc-score: 0.9602\n", 320 | "test_auc_score: 0.9480\n", 321 | "starting epoch: 4\n", 322 | "Train Epoch: 4, train-auc-score: 0.9594\n", 323 | "test_auc_score: 0.9467\n", 324 | "starting epoch: 5\n", 325 | "Train Epoch: 5, train-auc-score: 0.9628\n", 326 | "test_auc_score: 0.9529\n", 327 | "starting epoch: 6\n", 328 | "Train Epoch: 6, train-auc-score: 0.9711\n", 329 | "test_auc_score: 0.9555\n", 330 | "starting epoch: 7\n", 331 | "Train Epoch: 7, train-auc-score: 0.9756\n", 332 | "test_auc_score: 0.9586\n", 333 | "starting epoch: 8\n", 334 | "Train Epoch: 8, train-auc-score: 0.9804\n", 335 | "test_auc_score: 0.9598\n", 336 | "starting epoch: 9\n", 337 | "Train Epoch: 9, train-auc-score: 0.9810\n", 338 | 
"test_auc_score: 0.9618\n", 339 | "starting epoch: 10\n", 340 | "Train Epoch: 10, train-auc-score: 0.9644\n", 341 | "test_auc_score: 0.9552\n", 342 | "saving scores\n", 343 | "saving model\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "# train the model\n", 349 | "train.train(\n", 350 | " X=X_train,\n", 351 | " y=y_train,\n", 352 | " X_test=X_test,\n", 353 | " y_test=y_test\n", 354 | ")" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 30, 360 | "id": "294bd939", 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "data": { 365 | "image/png": "\n", 366 | "text/plain": [ 367 | "
" 368 | ] 369 | }, 370 | "metadata": { 371 | "needs_background": "light" 372 | }, 373 | "output_type": "display_data" 374 | } 375 | ], 376 | "source": [ 377 | "# plot training stats\n", 378 | "train.plot_train_stats()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 31, 384 | "id": "75e22bb5", 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "image/png": "\n", 390 | "text/plain": [ 391 | "
" 392 | ] 393 | }, 394 | "metadata": { 395 | "needs_background": "light" 396 | }, 397 | "output_type": "display_data" 398 | } 399 | ], 400 | "source": [ 401 | "# plot pr curve\n", 402 | "train.plot_pr_curve(X_test, y_test)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "fa34b9d3", 408 | "metadata": {}, 409 | "source": [ 410 | "### Interpret the Model" 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "id": "2ca793cf", 416 | "metadata": {}, 417 | "source": [ 418 | "#### Categorical and Numerical Features" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 32, 424 | "id": "5ad792da", 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "name": "stderr", 429 | "output_type": "stream", 430 | "text": [ 431 | "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/xgboost/sklearn.py:1146: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].\n", 432 | " warnings.warn(label_encoder_deprecation_msg, UserWarning)\n", 433 | "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sklearn/utils/validation.py:63: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", 434 | " return f(*args, **kwargs)\n" 435 | ] 436 | }, 437 | { 438 | "name": "stdout", 439 | "output_type": "stream", 440 | "text": [ 441 | "[16:15:16] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. 
Explicitly set eval_metric if you'd like to restore the old behavior.\n" 442 | ] 443 | }, 444 | { 445 | "data": { 446 | "image/png": "\n", 447 | "text/plain": [ 448 | "
" 449 | ] 450 | }, 451 | "metadata": { 452 | "needs_background": "light" 453 | }, 454 | "output_type": "display_data" 455 | } 456 | ], 457 | "source": [ 458 | "preds_xgb = interpret.train_xgb()" 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "id": "9e80253e", 464 | "metadata": {}, 465 | "source": [ 466 | "#### Textual Features (focus on customer chats that result in churn)" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 9, 472 | "id": "7f18dcf5", 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "chats, df_sub = interpret.get_chats(df=df, churn=1, speaker='Customer')" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 10, 482 | "id": "0b65267d", 483 | "metadata": {}, 484 | "outputs": [ 485 | { 486 | "data": { 487 | "text/plain": [ 488 | "[\"Well, I just want to be able to cancel the contract because I don't think that I want to stay. My local provider has been terrible and I really would like to switch. Sure, I can.\",\n", 489 | " \"Well, it's the old TelCom billing system for the last 5 years. I don't trust anymore and I think you should change to the newer billing system. I would like to give you a call back number. Okay, I can see why you need the new billing system, but I don't know if I can do that. I would like to know your cancellation policy.\",\n", 490 | " \"Well, I've been getting phone calls from a very good friend who's a TelCom agent and I have told him the same thing and the problem has not been resolved. He has offered me a $20/mo deal but that's not good enough for me because I'm getting $20 out of his pocket. Sure. 
$60 to cancel for nine months with a $20/mo bonus.\"]" 491 | ] 492 | }, 493 | "execution_count": 10, 494 | "metadata": {}, 495 | "output_type": "execute_result" 496 | } 497 | ], 498 | "source": [ 499 | "chats[:3]" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "id": "8c73e7dd", 505 | "metadata": {}, 506 | "source": [ 507 | "##### Candidate keywords (POS tagging, lower casing, lemmatization)" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 11, 513 | "id": "3db8f636", 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "['want', 'able', 'cancel', 'contract', 'think', 'want', 'stay', 'local', 'provider', 'terrible']\n" 521 | ] 522 | } 523 | ], 524 | "source": [ 525 | "# find candidate keywords\n", 526 | "keywords, tokens = interpret.get_keywords(chats)\n", 527 | "print(keywords[:10])\n", 528 | "\n", 529 | "# map keywords to original tokens\n", 530 | "keywords_dict = interpret.map_to_orig_tok(keywords, tokens)" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "id": "172eed77", 536 | "metadata": {}, 537 | "source": [ 538 | "##### Relevant keywords (semantic similarity)" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 12, 544 | "id": "07be6aa0", 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "relevant_keywords, simMat = interpret.get_relevant_keywords(\n", 549 | " text = chats, \n", 550 | " keywords_dict = keywords_dict\n", 551 | ")" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 13, 557 | "id": "267a00b0", 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "name": "stdout", 562 | "output_type": "stream", 563 | "text": [ 564 | "['voicemail', 'bored', 'spam', 'backlog', 'sick', 'frustrated', 'incompetence', 'overcharge', 'angry', 'disappointed', 'scam', 'lag', 'termination', 'resign', 'nightmare', 'incompetent', 'frustration', 'lagging', 'yesterday', 'friday', 
'monday', 'discontinue', 'cheat', 'dead', 'annoyed']\n" 565 | ] 566 | } 567 | ], 568 | "source": [ 569 | "print(relevant_keywords[:25])" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "3ba0da26", 575 | "metadata": {}, 576 | "source": [ 577 | "##### Impactful keywords (marginal contribution)" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 14, 583 | "id": "8c877937", 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "name": "stderr", 588 | "output_type": "stream", 589 | "text": [ 590 | "100%|██████████| 250/250 [11:25<00:00, 2.74s/it]\n" 591 | ] 592 | } 593 | ], 594 | "source": [ 595 | "# get marginal contribution to prediction for each keyword\n", 596 | "marg_contr_df = interpret.perform_ablation(\n", 597 | " df = df_sub,\n", 598 | " keywords = relevant_keywords,\n", 599 | " keywords_dict = keywords_dict\n", 600 | ")" 601 | ] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "id": "d43f8aa6", 606 | "metadata": {}, 607 | "source": [ 608 | "##### Create joint metric (semantic similarity + marginal contribution + count)" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "id": "c02d300a", 615 | "metadata": {}, 616 | "outputs": [], 617 | "source": [ 618 | "# load from local disk if available\n", 619 | "#results_df = pd.read_csv('model/ablation_results.csv')\n", 620 | "#results_df = results_df.rename(columns={'Unnamed: 0' : 'keyword'})" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 15, 626 | "id": "2e112e00", 627 | "metadata": {}, 628 | "outputs": [], 629 | "source": [ 630 | "results_df = interpret.get_important_keywords(\n", 631 | " simMat_df=simMat,\n", 632 | " marg_contr_df=marg_contr_df\n", 633 | ")" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 16, 639 | "id": "efd90445", 640 | "metadata": {}, 641 | "outputs": [ 642 | { 643 | "data": { 644 | "text/html": [ 645 | "
[HTML table output of results_df.head(20) (20 rows × 5 columns: keyword, sim, chg, count, joint); the same rows appear in the text/plain output below]
" 834 | ], 835 | "text/plain": [ 836 | " keyword sim chg count joint\n", 837 | "0 voicemail 90.076553 0.006695 5 0.628774\n", 838 | "1 cancel 61.187359 -0.081321 168 0.545633\n", 839 | "2 sick 74.789459 -0.127242 1 0.538919\n", 840 | "3 turnover 60.896118 -0.286630 1 0.533321\n", 841 | "4 disappointed 70.248131 -0.091740 5 0.522520\n", 842 | "5 spam 77.940460 -0.000022 3 0.506429\n", 843 | "6 bored 78.131271 -0.038932 1 0.502213\n", 844 | "7 unhappy 65.601990 -0.024782 37 0.493910\n", 845 | "8 frustrated 73.496033 -0.006023 5 0.486930\n", 846 | "9 mistake 66.247879 -0.093309 3 0.470609\n", 847 | "10 late 65.971649 -0.003873 18 0.458023\n", 848 | "11 error 63.123695 -0.065609 7 0.448412\n", 849 | "12 faulty 55.486092 -0.239817 1 0.448232\n", 850 | "13 angry 70.865860 -0.002967 3 0.444018\n", 851 | "14 backlog 74.808548 0.000020 1 0.442185\n", 852 | "15 customer 52.072552 -0.006985 480 0.439737\n", 853 | "16 lag 69.912178 -0.004650 3 0.436584\n", 854 | "17 overpay 65.573105 -0.093846 1 0.429261\n", 855 | "18 disconnect 64.002014 -0.010485 11 0.429104\n", 856 | "19 incompetence 73.107452 -0.000754 1 0.427229" 857 | ] 858 | }, 859 | "execution_count": 16, 860 | "metadata": {}, 861 | "output_type": "execute_result" 862 | } 863 | ], 864 | "source": [ 865 | "results_df.head(20)" 866 | ] 867 | }, 868 | { 869 | "cell_type": "markdown", 870 | "id": "7036dc83", 871 | "metadata": {}, 872 | "source": [ 873 | "##### Context of keywords" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": 17, 879 | "id": "5c0f0d2b", 880 | "metadata": {}, 881 | "outputs": [], 882 | "source": [ 883 | "keyword_of_interest = 'spam'" 884 | ] 885 | }, 886 | { 887 | "cell_type": "code", 888 | "execution_count": 18, 889 | "id": "041befbf", 890 | "metadata": {}, 891 | "outputs": [ 892 | { 893 | "name": "stdout", 894 | "output_type": "stream", 895 | "text": [ 896 | "Basically, I'm getting a lot of spam calls every day from a guy named Michael who's calling from a really weird 
number.\n", 897 | "TelCom started to flood me with emails and phone calls, spamming me with thousands of phony invoices.\n", 898 | "I just got some spam messages last night, and today it's been getting a lot of texts that I \"don't have my SIM card\" and \"I need my SIM card.\n" 899 | ] 900 | } 901 | ], 902 | "source": [ 903 | "interpret.obtain_context(\n", 904 | " chats_list = chats,\n", 905 | " keyword = keyword_of_interest\n", 906 | ")" 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "execution_count": null, 912 | "id": "a294f494", 913 | "metadata": {}, 914 | "outputs": [], 915 | "source": [] 916 | } 917 | ], 918 | "metadata": { 919 | "kernelspec": { 920 | "display_name": "conda_pytorch_p36", 921 | "language": "python", 922 | "name": "conda_pytorch_p36" 923 | }, 924 | "language_info": { 925 | "codemirror_mode": { 926 | "name": "ipython", 927 | "version": 3 928 | }, 929 | "file_extension": ".py", 930 | "mimetype": "text/x-python", 931 | "name": "python", 932 | "nbconvert_exporter": "python", 933 | "pygments_lexer": "ipython3", 934 | "version": "3.6.13" 935 | } 936 | }, 937 | "nbformat": 4, 938 | "nbformat_minor": 5 939 | } 940 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | matplotlib 3 | sentence_transformers==2.0.0 4 | xgboost==1.4.2 5 | spacy>=3.0.0,<4.0.0 6 | https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm -------------------------------------------------------------------------------- /scripts/create_dataset.py: -------------------------------------------------------------------------------- 1 | """ Module for creating the dataset by combining categorical/numerical and text CSV files.
2 | 3 | First, download categorical/numerical data into data folder as described in README.md, then: 4 | 5 | Run in CLI example: 6 | 'python create_dataset.py' 7 | 8 | """ 9 | 10 | import yaml 11 | import pandas as pd 12 | from pathlib import Path 13 | 14 | with open("../model/params.yaml", "r") as params_file: 15 | params = yaml.safe_load(params_file) 16 | 17 | 18 | def create_joint_dataset( 19 | df_categorical, 20 | df_text 21 | ): 22 | df_text_no = df_text[df_text.churn == 'no'].reset_index(drop=True) 23 | df_text_yes = df_text[df_text.churn == 'yes'].reset_index(drop=True) 24 | 25 | df_cat_no = df_categorical[df_categorical.churn == 'no'].reset_index(drop=True)[:len(df_text_no)] 26 | df_cat_yes = df_categorical[df_categorical.churn == 'yes'].reset_index(drop=True)[:len(df_text_yes)] 27 | 28 | df_no = pd.concat([df_text_no, df_cat_no.iloc[:,:-1]], axis=1) 29 | df_yes = pd.concat([df_text_yes, df_cat_yes.iloc[:,:-1]], axis=1) 30 | 31 | df = pd.concat([df_no, df_yes], axis=0) 32 | 33 | # shuffle data 34 | df = df.sample(frac=1).reset_index(drop=True) 35 | 36 | return df 37 | 38 | 39 | if __name__ == "__main__": 40 | data_dir = params['data_dir'] 41 | 42 | # load data 43 | df_categorical = pd.read_csv(Path(data_dir, "train.csv")) 44 | df_text = pd.read_csv(Path(data_dir, "text.csv")) 45 | 46 | # create joint dataset 47 | df = create_joint_dataset(df_categorical, df_text) 48 | 49 | # save data 50 | df.to_csv(Path(data_dir, "churn_dataset.csv"), index=False) 51 | -------------------------------------------------------------------------------- /scripts/interpret.py: -------------------------------------------------------------------------------- 1 | """ Module for interpreting the trained churn prediction model with text. 2 | 3 | The parameters for model interpretation are stored in 'params.yaml'. 
4 | 5 | Run in CLI example: 6 | 'python interpret.py --churn 1 --speaker Customer' 7 | 8 | """ 9 | 10 | 11 | import json 12 | import yaml 13 | import spacy 14 | import torch 15 | import argparse 16 | import numpy as np 17 | import pandas as pd 18 | from tqdm import tqdm 19 | from pathlib import Path 20 | from sklearn import preprocessing 21 | from xgboost import XGBClassifier 22 | from collections import defaultdict 23 | from matplotlib import pyplot as plt 24 | import preprocess 25 | from preprocess import BertEncoder 26 | import train 27 | 28 | nlp = spacy.load('en_core_web_sm') 29 | 30 | with open("../model/params.yaml", "r") as params_file: 31 | params = yaml.safe_load(params_file) 32 | 33 | model_dir = params['model_dir'] 34 | 35 | 36 | def get_chats( 37 | df, 38 | churn=None, 39 | speaker=None 40 | ): 41 | """ 42 | Args: 43 | df: dataframe with all data 44 | churn (int): 1 for churn, 0 for no churn, None (default) for all chats 45 | speaker (str): 'Customer' for only customer chats, 46 | 'Agent' for only agent chats, 47 | None (default) for all chats 48 | 49 | Returns: 50 | tuple of (list of chats (strings), dataframe filtered by churn and chat length) 51 | """ 52 | # convert labels to binary numeric 53 | df = preprocess.convert_label(df) 54 | # keep only churn/no churn chat logs 55 | if churn is not None: 56 | df = df[df.churn == churn] 57 | # drop short chat logs 58 | df = df[df.chat_log.apply(lambda x: len(str(x))>=5)] 59 | 60 | # select chats 61 | chat_logs = list(df['chat_log']) 62 | chat_logs = [i if isinstance(i, str) else "nan" for i in chat_logs] 63 | 64 | if speaker is not None: 65 | chats = [] 66 | for chat in chat_logs: 67 | sents = chat.split('\n') 68 | cchat = [] 69 | for sent in sents: 70 | if str(sent).split(':')[0] == speaker: 71 | cchat.append(sent.split(':', 1)[1].strip()) # drop the "<speaker>:" prefix, whatever its length 72 | chats.append(' '.join(cchat)) 73 | else: 74 | chats = chat_logs 75 | 76 | return chats, df 77 | 78 | 79 | def get_keywords(texts): 80 | """Returns candidate keywords based on POS as well as original form of keyword.
81 | """ 82 | candidate_pos = ['ADJ', 'VERB', 'NOUN', 'PROPN'] 83 | keywords = [] 84 | tokens = [] # for referencing keywords in original text later on 85 | for text in texts: 86 | text_keywords = [] 87 | text_tokens = [] 88 | doc = nlp(text) 89 | for token in doc: 90 | if token.pos_ in candidate_pos and token.is_stop is False: 91 | text_tokens.append(str(token)) 92 | text_keywords.append(token.lemma_.lower()) 93 | 94 | keywords.extend(text_keywords) 95 | tokens.extend(text_tokens) 96 | 97 | return keywords, tokens 98 | 99 | 100 | def map_to_orig_tok(keywords, tokens): 101 | """Create dictionary mapping keywords to original tokens. 102 | """ 103 | keywords_dict = defaultdict(list) 104 | for kw, t in zip(keywords, tokens): 105 | keywords_dict[kw].append(t) 106 | for kw, l in keywords_dict.items(): 107 | keywords_dict[kw] = list(set(l)) 108 | 109 | keywords_dict = dict(keywords_dict) 110 | 111 | return keywords_dict 112 | 113 | 114 | def get_relevant_keywords( 115 | text, 116 | keywords_dict 117 | ): 118 | """Returns relevant keywords based on semantic similarity to class embedding. 119 | """ 120 | # obtain class embedding 121 | textual_transformer = BertEncoder() 122 | textual_features = textual_transformer.transform(text) 123 | class_embedding = np.mean(np.array(textual_features), axis=0) 124 | 125 | # obtain relevant keywords 126 | unique_keywords = list(keywords_dict.keys()) 127 | topn_relevant_keywords = params['topn_relevant_keywords'] 128 | relevant_keywords, simMat = relevant_keywords_helper(unique_keywords, 129 | class_embedding, 130 | topn_relevant_keywords) 131 | return relevant_keywords, simMat 132 | 133 | 134 | def relevant_keywords_helper( 135 | keywords, 136 | class_embedding, 137 | topn_relevant_keywords 138 | ): 139 | """Helper function for obtaining embedding similarity. 
140 | """ 141 | textual_transformer = BertEncoder() 142 | keyword_embeddings = textual_transformer.transform(keywords) 143 | 144 | simMatrix = np.dot(keyword_embeddings, class_embedding.T) 145 | d = {"keyword" : keywords, "sim" : list(simMatrix)} 146 | df_simMatrix = pd.DataFrame(d).sort_values(by="sim", ascending=False).reset_index(drop=True) 147 | relevant_keywords = list(df_simMatrix['keyword'])[:topn_relevant_keywords] 148 | 149 | return relevant_keywords, df_simMatrix 150 | 151 | 152 | def prep_ablation(df): 153 | """Prepare data for making predictions.""" 154 | 155 | # convert df to list of dicts 156 | data = df.to_json(orient="records") 157 | data = json.loads(data) 158 | 159 | # load model assets from training job 160 | model_assets = train.get_train_assets() 161 | #print('extracting features') 162 | numerical_features, categorical_features, textual_features = preprocess.extract_features( 163 | data, 164 | model_assets['numerical_feature_names'], 165 | model_assets['categorical_feature_names'], 166 | model_assets['textual_feature_names'] 167 | ) 168 | 169 | # extract labels 170 | _, _, _, label_name = preprocess.get_feature_names(df) 171 | labels = preprocess.extract_labels( 172 | data, 173 | label_name 174 | ) 175 | 176 | # preprocess the data 177 | #print('transforming numerical_features') 178 | numerical_features = model_assets['numerical_transformer'].transform(numerical_features) 179 | #print('transforming categorical_features') 180 | categorical_features = model_assets['categorical_transformer'].transform(categorical_features) 181 | #print('transforming textual_features') 182 | textual_features = model_assets['textual_transformer'].transform(textual_features) 183 | 184 | #print('concatenating features') 185 | categorical_features = categorical_features.toarray() 186 | textual_features = np.array(textual_features) 187 | textual_features = textual_features.reshape(textual_features.shape[0], -1) 188 | features = np.concatenate([ 189 | numerical_features, 
190 | categorical_features, 191 | textual_features 192 | ], axis=1) 193 | 194 | return features, labels 195 | 196 | 197 | def perform_ablation( 198 | df, 199 | keywords, 200 | keywords_dict 201 | ): 202 | """Predict churn probability with and without each selected keyword ablated from the chat logs. 203 | """ 204 | # select subset of relevant keywords 205 | frac_relevant_keywords = params['frac_relevant_keywords'] 206 | topn = int(params['topn_relevant_keywords'] * frac_relevant_keywords) 207 | keywords_select = keywords[:topn] 208 | keywords_select_dict = {} 209 | for kw in keywords_select: 210 | keywords_select_dict[kw] = keywords_dict[kw] 211 | 212 | # loop through keywords and perform ablation analysis 213 | results_dict = {} 214 | for keyword, keywords_list in tqdm(keywords_select_dict.items()): 215 | # get portion of df where keywords occur 216 | df_incl = pd.DataFrame() 217 | df_excl = pd.DataFrame() 218 | for i, row in df.iterrows(): 219 | if any(kw in row['chat_log'] for kw in keywords_list): 220 | # df including keyword in chat (each matching chat is added once) 221 | df_incl = df_incl.append(row) 222 | # df excluding keyword in chat: strip every surface form of the keyword 223 | for kw in keywords_list: 224 | row['chat_log'] = row['chat_log'].replace(kw, ' ') 225 | # row now holds the ablated chat log 226 | df_excl = df_excl.append(row) 227 | 228 | # prep data incl/excl keyword in text (loads trained preprocessors) 229 | features_incl, labels_incl = prep_ablation( 230 | df_incl 231 | ) 232 | 233 | features_excl, labels_excl = prep_ablation( 234 | df_excl 235 | ) 236 | 237 | # predict using previously trained model 238 | pred_incl = train.predict(features_incl, labels_incl) 239 | pred_excl = train.predict(features_excl, labels_excl) 240 | 241 | # save results 242 | results_dict[keyword] = {'incl' : np.average(pred_incl), 243 | 'excl' : np.average(pred_excl), 244 | 'chg' : np.average(pred_excl) - np.average(pred_incl), 245 | 'count' : len(pred_incl)} 246 | 247 | # store results 248 | results_df = pd.DataFrame.from_dict(results_dict, orient='index') 249 | results_df = results_df.sort_values(by=["chg", "count"], ascending=(True,
False)) 250 | results_df.to_csv(Path(model_dir, "ablation_results.csv")) 251 | 252 | return results_df 253 | 254 | 255 | def get_important_keywords( 256 | simMat_df, 257 | marg_contr_df 258 | ): 259 | # merge dataframes 260 | marg_contr_df.insert(0, 'keyword', marg_contr_df.index) 261 | marg_contr_df = marg_contr_df.drop(['incl', 'excl'], axis=1) 262 | results_df = marg_contr_df.merge(simMat_df, how='left', on='keyword') 263 | 264 | # rescale metrics 265 | temp_df = results_df[['chg','count','sim']].copy() 266 | temp_df['chg'] = temp_df['chg'] * (-1) 267 | temp_df['count'] = np.log(temp_df['count']) # taking log transformation on keyword counts due to extreme outliers 268 | min_max_scaler = preprocessing.MinMaxScaler(feature_range=(0, 1)) 269 | temp_df = min_max_scaler.fit_transform(temp_df) 270 | temp_df = pd.DataFrame(temp_df, columns=['chg','count','sim']) 271 | 272 | # calculate weighted average 273 | w_marg_contr = params['w_marg_contr'] 274 | w_count = params['w_count'] 275 | w_sim = params['w_sim'] 276 | temp_df['joint'] = temp_df['chg'] * w_marg_contr + temp_df['count'] * w_count + temp_df['sim'] * w_sim 277 | 278 | results_df['joint'] = temp_df['joint'] 279 | results_df = results_df.sort_values(by="joint", ascending=False).reset_index(drop=True) 280 | results_df = results_df[['keyword','sim','chg','count','joint']] 281 | 282 | # store results 283 | results_df.to_csv(Path(model_dir, "important_keywords.csv")) 284 | 285 | return results_df 286 | 287 | 288 | def obtain_context( 289 | chats_list, 290 | keyword, 291 | limit=3 292 | ): 293 | """Prints up to `limit` sentences, taken from all chats combined, in which the keyword occurs.
294 | """ 295 | nlp = spacy.load('en_core_web_sm') 296 | counter = 0 297 | for chat in chats_list: 298 | doc = nlp(chat) 299 | for sent in doc.sents: 300 | if keyword in sent.text and counter < limit: 301 | print(sent.text) 302 | print('\n') 303 | counter+=1 304 | if counter >= limit: 305 | return None # stop parsing the remaining chats once the limit is reached 306 | 307 | return None 308 | 309 | 310 | def train_xgb(): 311 | """Train XGBoost model to explain categorical and numerical feature importance. 312 | """ 313 | # load one-hot feature names 314 | filepath = Path(model_dir, "one_hot_feature_names.json") 315 | numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(filepath) 316 | one_hot_feature_names = numerical_feature_names + categorical_feature_names 317 | 318 | # load train data and exclude text embeddings 319 | train = pd.read_csv(Path(model_dir, "train.csv")) 320 | train_notext = train.iloc[:, :len(one_hot_feature_names)] 321 | labels = pd.read_csv(Path(model_dir, "labels.csv")) 322 | test = pd.read_csv(Path(model_dir, "test.csv")) 323 | test_notext = test.iloc[:, :len(one_hot_feature_names)] 324 | labels_test = pd.read_csv(Path(model_dir, "labels_test.csv")) 325 | 326 | # train XGBoost model 327 | xgb = XGBClassifier() 328 | xgb.fit(train_notext, labels) 329 | 330 | y_pred = xgb.predict_proba(test_notext) 331 | y_pred = [p[1] for p in y_pred] 332 | 333 | # plot important features 334 | topn = 10 335 | sorted_idx = xgb.feature_importances_.argsort() 336 | plt.barh(np.array(one_hot_feature_names)[sorted_idx[-topn:]], 337 | xgb.feature_importances_[sorted_idx[-topn:]]) 338 | plt.title("Xgboost Feature Importance") 339 | plt.show() 340 | 341 | return y_pred 342 | 343 | 344 | if __name__ == "__main__": 345 | parser = argparse.ArgumentParser() 346 | parser.add_argument("--churn", type=int, default=1) 347 | parser.add_argument("--speaker", type=str, default='Customer') 348 | args = parser.parse_args() 349 | 350 | data_dir = params['data_dir'] 351 | df = pd.read_csv(Path(data_dir,
"churn_dataset.csv")) 352 | 353 | churn = args.churn 354 | speaker = args.speaker 355 | chats, df_sub = get_chats(df, churn, speaker) 356 | keywords, tokens = get_keywords(chats) 357 | keywords_dict = map_to_orig_tok(keywords, tokens) 358 | 359 | relevant_keywords, simMat_df = get_relevant_keywords(chats, keywords_dict) 360 | marg_contr_df = perform_ablation(df_sub, relevant_keywords, keywords_dict) 361 | 362 | get_important_keywords(simMat_df, marg_contr_df) 363 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | """ Module for preparing the data for the churn prediction model with text. 2 | 3 | Run in CLI example: 4 | 'python preprocess.py --test-size 0.33' 5 | 6 | """ 7 | 8 | 9 | import os 10 | import sys 11 | import json 12 | import yaml 13 | import joblib 14 | import logging 15 | import argparse 16 | import numpy as np 17 | import pandas as pd 18 | from pathlib import Path 19 | from matplotlib import pyplot as plt 20 | from sklearn.model_selection import train_test_split 21 | from sklearn.impute import SimpleImputer 22 | from sklearn.preprocessing import OneHotEncoder 23 | from sklearn.base import BaseEstimator, TransformerMixin 24 | from sentence_transformers import SentenceTransformer 25 | 26 | 27 | with open("../model/params.yaml", "r") as params_file: 28 | params = yaml.safe_load(params_file) 29 | 30 | model_dir = params['model_dir'] 31 | 32 | 33 | class BertEncoder(BaseEstimator, TransformerMixin): 34 | def __init__(self, model_name='bert-base-nli-mean-tokens'): 35 | self.model = SentenceTransformer(model_name) 36 | self.model.parallel_tokenization = False 37 | 38 | def fit(self, X, y=None): 39 | return self 40 | 41 | def transform(self, X): 42 | output = [] 43 | for sample in X: 44 | encodings = self.model.encode(sample) 45 | output.append(encodings) 46 | return output 47 | 48 | 49 | def extract_labels( 50 | data, 51 
| label_name 52 | ): 53 | labels = [] 54 | for sample in data: 55 | value = sample[label_name] 56 | labels.append(value) 57 | labels = np.array(labels).astype('int') 58 | 59 | return labels.reshape(labels.shape[0],-1) 60 | 61 | 62 | def convert_label( 63 | df 64 | ): 65 | df.churn = df.churn.replace("no", 0) 66 | df.churn = df.churn.replace("yes", 1) 67 | return df 68 | 69 | def extract_numerical_features( 70 | sample, 71 | numerical_feature_names 72 | ): 73 | output = [] 74 | for feature_name in numerical_feature_names: 75 | if feature_name in sample.keys(): 76 | value = sample[feature_name] 77 | if value is None: 78 | value = np.nan 79 | else: 80 | value = np.nan 81 | output.append(value) 82 | return output 83 | 84 | 85 | def extract_categorical_features( 86 | sample, 87 | categorical_feature_names 88 | ): 89 | output = [] 90 | for feature_name in categorical_feature_names: 91 | if feature_name in sample.keys(): 92 | value = sample[feature_name] 93 | if value is None: 94 | value = "" 95 | else: 96 | value = "" 97 | output.append(value) 98 | return output 99 | 100 | 101 | def extract_textual_features( 102 | sample, 103 | textual_feature_names 104 | ): 105 | output = [] 106 | for feature_name in textual_feature_names: 107 | if feature_name in sample.keys(): 108 | value = sample[feature_name] 109 | if value is None: 110 | value = "" 111 | else: 112 | value = "" 113 | output.append(value) 114 | return output 115 | 116 | 117 | def split_data( 118 | df, 119 | label_name, 120 | test_size 121 | ): 122 | """Splits data and creates json format. 
123 | """ 124 | X = df.drop(columns=[label_name], axis=1) 125 | y = df[label_name] 126 | X_train, X_test, y_train, y_test = train_test_split( 127 | X, y, test_size=test_size, random_state=123, stratify=y) 128 | 129 | train = pd.DataFrame(X_train, columns = X.columns) 130 | train[label_name] = y_train 131 | test = pd.DataFrame(X_test, columns = X.columns) 132 | test[label_name] = y_test 133 | 134 | # create list of dicts 135 | train = train.to_json(orient="records") 136 | train = json.loads(train) 137 | test = test.to_json(orient="records") 138 | test = json.loads(test) 139 | 140 | return train, test 141 | 142 | 143 | def extract_features( 144 | data, 145 | numerical_feature_names, 146 | categorical_feature_names, 147 | textual_feature_names 148 | ): 149 | """extract features by given feature names. 150 | """ 151 | numerical_features = [] 152 | categorical_features = [] 153 | textual_features = [] 154 | for sample in data: 155 | num_feat = extract_numerical_features(sample, numerical_feature_names) 156 | numerical_features.append(num_feat) 157 | cat_feat = extract_categorical_features(sample, categorical_feature_names) 158 | categorical_features.append(cat_feat) 159 | text_feat = extract_textual_features(sample, textual_feature_names) 160 | textual_features.append(text_feat) 161 | 162 | textual_features = [i if isinstance(i[0], str) else ["nan"] for i in textual_features] 163 | textual_features = [i[0] for i in textual_features] 164 | 165 | return numerical_features, categorical_features, textual_features 166 | 167 | 168 | def save_feature_names( 169 | numerical_feature_names, 170 | categorical_feature_names, 171 | textual_feature_names, 172 | filepath 173 | ): 174 | feature_names = { 175 | 'numerical': numerical_feature_names, 176 | 'categorical': categorical_feature_names, 177 | 'textual': textual_feature_names 178 | } 179 | with open(filepath, 'w') as f: 180 | json.dump(feature_names, f) 181 | 182 | 183 | def load_feature_names(filepath): 184 | with 
open(filepath, 'r') as f: 185 | feature_names = json.load(f) 186 | numerical_feature_names = feature_names['numerical'] 187 | categorical_feature_names = feature_names['categorical'] 188 | textual_feature_names = feature_names['textual'] 189 | return numerical_feature_names, categorical_feature_names, textual_feature_names 190 | 191 | 192 | def get_feature_names( 193 | df 194 | ): 195 | num_columns = df.select_dtypes(include=np.number).columns.tolist() 196 | numerical_feature_names = [i for i in num_columns if i not in ['churn']] 197 | 198 | cat_columns = df.select_dtypes(include='object').columns.tolist() 199 | categorical_feature_names = [i for i in cat_columns if i not in ['chat_log']] 200 | 201 | textual_feature_names = ['chat_log'] 202 | label_name = 'churn' 203 | 204 | return numerical_feature_names, categorical_feature_names, textual_feature_names, label_name 205 | 206 | 207 | def prep_data( 208 | df, use_existing=False, test_size=0.33 209 | ): 210 | """ 211 | Args: 212 | df: Pandas dataframe with raw data 213 | use_existing: Set to True if you want to use locally stored, 214 | already prepared train/test data. Set to False if you want 215 | to rerun the data preparation pipeline. 216 | Returns: 217 | Train and test data as well as train labels and test labels. 
218 | """ 219 | # if prepared data exists, don't prepare again if use_existing set to True 220 | train_file = Path(model_dir, 'train.csv') 221 | labels_file = Path(model_dir, 'labels.csv') 222 | test_file = Path(model_dir, 'test.csv') 223 | labels_test_file = Path(model_dir, 'labels_test.csv') 224 | feature_names_file = Path(model_dir, "feature_names.json") 225 | oh_feature_names_file = Path(model_dir, "one_hot_feature_names.json") 226 | all_file_paths = [train_file, labels_file, test_file, labels_test_file, 227 | feature_names_file, oh_feature_names_file] 228 | 229 | if use_existing and all(file.exists() for file in all_file_paths): 230 | features = np.array(pd.read_csv(train_file)) 231 | labels = np.array(pd.read_csv(labels_file)) 232 | features_test = np.array(pd.read_csv(test_file)) 233 | labels_test = np.array(pd.read_csv(labels_test_file)) 234 | print("Using already prepared data.") 235 | 236 | else: 237 | print("Running data preparation pipeline...") 238 | # convert label to binary numeric 239 | df = convert_label(df) 240 | 241 | # extract feature names 242 | numerical_feature_names, categorical_feature_names, textual_feature_names, label_name = get_feature_names( 243 | df 244 | ) 245 | # train/test split and convert to json format (list of dicts) 246 | train, test = split_data( 247 | df, 248 | label_name, 249 | test_size 250 | ) 251 | # extract features & label 252 | print('extracting features') 253 | numerical_features, categorical_features, textual_features = extract_features( 254 | train, 255 | numerical_feature_names, 256 | categorical_feature_names, 257 | textual_feature_names 258 | ) 259 | labels = extract_labels( 260 | train, 261 | label_name 262 | ) 263 | # extract features & label (for test data) 264 | numerical_features_test, categorical_features_test, textual_features_test = extract_features( 265 | test, 266 | numerical_feature_names, 267 | categorical_feature_names, 268 | textual_feature_names 269
| ) 270 | labels_test = extract_labels( 271 | test, 272 | label_name 273 | ) 274 | # define preprocessors 275 | print('defining preprocessors') 276 | numerical_transformer = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True) 277 | categorical_transformer = OneHotEncoder(handle_unknown="ignore") 278 | textual_transformer = BertEncoder() 279 | 280 | # fit preprocessors 281 | print('fitting numerical_transformer') 282 | numerical_transformer.fit(numerical_features) 283 | print('saving numerical_transformer') 284 | joblib.dump(numerical_transformer, Path(model_dir, "numerical_transformer.joblib")) 285 | print('fitting categorical_transformer') 286 | categorical_transformer.fit(categorical_features) 287 | print('saving categorical_transformer') 288 | joblib.dump(categorical_transformer, Path(model_dir, "categorical_transformer.joblib")) 289 | 290 | # transform features 291 | print('transforming numerical_features') 292 | numerical_features = numerical_transformer.transform(numerical_features) 293 | print('transforming categorical_features') 294 | categorical_features = categorical_transformer.transform(categorical_features) 295 | print('transforming textual_features') 296 | textual_features = textual_transformer.transform(textual_features) 297 | 298 | # transform features (for test data) 299 | print('transforming numerical_features_test') 300 | numerical_features_test = numerical_transformer.transform(numerical_features_test) 301 | print('transforming categorical_features_test') 302 | categorical_features_test = categorical_transformer.transform(categorical_features_test) 303 | print('transforming textual_features_test') 304 | textual_features_test = textual_transformer.transform(textual_features_test) 305 | 306 | # concat features 307 | print('concatenating features') 308 | categorical_features = categorical_features.toarray() 309 | textual_features = np.array(textual_features) 310 | textual_features = 
textual_features.reshape(textual_features.shape[0], -1) 311 | features = np.concatenate([ 312 | numerical_features, 313 | categorical_features, 314 | textual_features 315 | ], axis=1) 316 | 317 | # concat features (test data) 318 | print('concatenating features of test data') 319 | categorical_features_test = categorical_features_test.toarray() 320 | textual_features_test = np.array(textual_features_test) 321 | textual_features_test = textual_features_test.reshape(textual_features_test.shape[0], -1) 322 | features_test = np.concatenate([ 323 | numerical_features_test, 324 | categorical_features_test, 325 | textual_features_test 326 | ], axis=1) 327 | 328 | # save to disk 329 | pd.DataFrame(features).to_csv(Path(model_dir, "train.csv"), index=False) 330 | pd.DataFrame(labels).to_csv(Path(model_dir, "labels.csv"), index=False) 331 | pd.DataFrame(features_test).to_csv(Path(model_dir, "test.csv"), index=False) 332 | pd.DataFrame(labels_test).to_csv(Path(model_dir, "labels_test.csv"), index=False) 333 | 334 | save_feature_names( 335 | numerical_feature_names, 336 | categorical_feature_names, 337 | textual_feature_names, 338 | Path(model_dir, "feature_names.json") 339 | ) 340 | # one-hot encoded feature names (for feat_imp) 341 | save_feature_names( 342 | numerical_feature_names, 343 | categorical_transformer.get_feature_names().tolist(), 344 | textual_feature_names, 345 | Path(model_dir, "one_hot_feature_names.json") 346 | ) 347 | 348 | return features, features_test, labels, labels_test 349 | 350 | 351 | if __name__ == "__main__": 352 | parser = argparse.ArgumentParser() 353 | parser.add_argument("--use-existing", action="store_true") 354 | parser.add_argument("--test-size", type=float, default=0.33) 355 | args = parser.parse_args() 356 | 357 | data_dir = params['data_dir'] 358 | df = pd.read_csv(Path(data_dir, "churn_dataset.csv")) 359 | 360 | if args.use_existing: 361 | use_existing = True 362 | else: 363 | use_existing = False 364 | test_size = args.test_size 365 | 
366 | prep_data(df, use_existing, test_size) 367 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ Module for training the churn prediction model with text. 2 | 3 | Training parameters are stored in 'params.yaml'. 4 | 5 | Run in CLI example: 6 | 'python train.py' 7 | 8 | """ 9 | 10 | 11 | import os 12 | import sys 13 | import json 14 | import yaml 15 | import joblib 16 | import logging 17 | import argparse 18 | import numpy as np 19 | import pandas as pd 20 | from pathlib import Path 21 | from matplotlib import pyplot as plt 22 | from sklearn.model_selection import train_test_split 23 | from sklearn.impute import SimpleImputer 24 | from sklearn.preprocessing import OneHotEncoder 25 | from sklearn.metrics import roc_auc_score, roc_curve, auc 26 | from sklearn.metrics import precision_recall_curve 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import torch.optim as optim 31 | from torch.utils.data import DataLoader, TensorDataset 32 | from torch import Tensor 33 | import preprocess 34 | from preprocess import BertEncoder 35 | 36 | 37 | logger = logging.getLogger(__name__) 38 | logger.setLevel(logging.DEBUG) 39 | logger.addHandler(logging.StreamHandler(sys.stdout)) 40 | 41 | 42 | with open("../model/params.yaml", "r") as params_file: 43 | params = yaml.safe_load(params_file) 44 | 45 | model_dir = params['model_dir'] 46 | 47 | 48 | class Net(nn.Module): 49 | def __init__(self, x1_size, x2_size): 50 | super(Net, self).__init__() 51 | self.batch_norm = nn.BatchNorm1d(x1_size) 52 | self.fc1 = nn.Linear(x1_size, 10) 53 | self.fc2 = nn.Linear(10 + x2_size, 10) 54 | self.fc3 = nn.Linear(10, 1) 55 | 56 | def forward(self, x1, x2): 57 | x1 = self.batch_norm(x1) 58 | x1 = F.relu(self.fc1(x1)) 59 | x1 = F.dropout(x1, p=0.2, training=self.training) 60 | x12 = torch.cat((x1.view(x1.size(0), -1), 61 | 
x2.view(x2.size(0), -1)), dim=1) 62 | x12 = F.dropout(x12, p=0.1, training=self.training) 63 | x12 = self.fc2(x12) 64 | out = self.fc3(x12) 65 | return out 66 | 67 | 68 | def train( 69 | X, 70 | y, 71 | X_test, 72 | y_test 73 | ): 74 | # get parameters 75 | batch_size = params['batch_size'] 76 | batch_size_test = params['batch_size_test'] 77 | epochs = params['epochs'] 78 | pos_weight = params['pos_weight'] 79 | lr = params['lr'] 80 | momentum = params['momentum'] 81 | 82 | # prepare training job 83 | X = np.array(X) 84 | y = np.array(y) 85 | X_test = np.array(X_test) 86 | y_test = np.array(y_test) 87 | training_data = TensorDataset( Tensor(X), Tensor(y) ) 88 | train_loader = DataLoader(training_data, batch_size=batch_size, 89 | shuffle=True, 90 | num_workers=4) 91 | test_data = TensorDataset( Tensor(X_test), Tensor(y_test) ) 92 | test_loader = DataLoader(test_data, batch_size=batch_size_test, 93 | shuffle=True, 94 | num_workers=4) 95 | 96 | # get size of num/cat & text data 97 | numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json")) 98 | number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names) 99 | x1_size = X[:, :number_cat_num_features].shape[1] 100 | x2_size = X[:, number_cat_num_features:].shape[1] 101 | 102 | model = Net(x1_size, x2_size) 103 | criterion = nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(pos_weight, dtype=torch.float)) 104 | #criterion = nn.BCELoss(pos_weight=torch.as_tensor(pos_weight, dtype=torch.float)) 105 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) 106 | 107 | # train NN model 108 | scores_df = pd.DataFrame() 109 | train_scores = [] 110 | test_scores = [] 111 | for epoch in range(1, epochs + 1): # loop over the dataset multiple times 112 | model.train() 113 | print("starting epoch: ", epoch) 114 | 115 | for batch_idx, (data, target) in enumerate(train_loader, 1): 116 | # zero the parameter gradients 117 | 
            optimizer.zero_grad()

            # split data inputs into cat/num features and text features
            x1 = data[:, :number_cat_num_features]
            x2 = data[:, number_cat_num_features:]

            # forward + backward + optimize
            output = model(x1, x2)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # training performance after each epoch (eval mode disables dropout)
        model.eval()
        with torch.no_grad():
            output = model(Tensor(X[:, :number_cat_num_features]), Tensor(X[:, number_cat_num_features:])).reshape(-1)
            preds = torch.sigmoid(output)
        train_score = roc_auc_score(y.ravel(), preds.numpy())
        train_scores.append(train_score)
        logger.info('Train Epoch: {}, train-auc-score: {:.4f}'.format(epoch, train_score))

        # test performance after each epoch
        test_score = test(model, test_loader)
        test_scores.append(test_score)

    # save scores once training has finished
    print('saving scores')
    scores_df['train_scores'] = train_scores
    scores_df['test_scores'] = test_scores
    scores_df.to_csv(Path(model_dir, 'training_scores.csv'), index=False)

    # save model
    print('saving model')
    torch.save(model.state_dict(), Path(model_dir, 'model.pth'))

    return None


def test(
    model,
    test_loader
):
    model.eval()

    # split point between cat/num features and text features;
    # read the feature names once instead of reloading the JSON file for every batch
    numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json"))
    number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names)

    preds_all = []
    targets_all = []
    with torch.no_grad():
        for data, targets in test_loader:
            x1 = data[:, :number_cat_num_features]
            x2 = data[:, number_cat_num_features:]

            output = model(x1, x2).reshape(-1)
            preds_all.append(torch.sigmoid(output))
            targets_all.append(targets.reshape(-1))

    # compute the AUC over the whole test set, not batch by batch
    test_score = roc_auc_score(torch.cat(targets_all).numpy(), torch.cat(preds_all).numpy())

    logger.info('test_auc_score: {:.4f}'.format(test_score))

    return test_score


def get_train_assets(
):
    numerical_feature_names, categorical_feature_names, textual_feature_names = preprocess.load_feature_names(Path(model_dir, "feature_names.json"))
    numerical_transformer = joblib.load(Path(model_dir, "numerical_transformer.joblib"))
    categorical_transformer = joblib.load(Path(model_dir, "categorical_transformer.joblib"))
    textual_transformer = BertEncoder()

    model_assets = {
        'numerical_feature_names': numerical_feature_names,
        'numerical_transformer': numerical_transformer,
        'categorical_feature_names': categorical_feature_names,
        'categorical_transformer': categorical_transformer,
        'textual_feature_names': textual_feature_names,
        'textual_transformer': textual_transformer
    }
    return model_assets


def predict(
    features,
    labels
):
    # `labels` is not used for prediction; the argument is kept for existing callers.
    # Convert DataFrames to arrays so the positional column slicing below works.
    features = np.asarray(features)

    # get size of num/cat & text data to specify neural net
    numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json"))
    number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names)
    x1_size = features[:, :number_cat_num_features].shape[1]
    x2_size = features[:, number_cat_num_features:].shape[1]

    model = Net(x1_size, x2_size)
    model.load_state_dict(torch.load(Path(model_dir, 'model.pth')))

    model.eval()
    with torch.no_grad():
        output = model(Tensor(features[:, :number_cat_num_features]), Tensor(features[:, number_cat_num_features:])).reshape(-1)
        preds = torch.sigmoid(output)

    return preds


def plot_train_stats(
):
    # train() writes the per-epoch scores to training_scores.csv
    scores_df = pd.read_csv(Path(model_dir, 'training_scores.csv'))
    scores_df.plot(xlabel='Epochs', ylabel='AUC Score', title='AUC Score')
    return None


def plot_pr_curve(
    features,
    labels
):
    preds = predict(features, labels)
    labels = np.asarray(labels).ravel()
    precisions, recalls, thresholds = precision_recall_curve(labels, preds)
    roc_auc = roc_auc_score(labels, preds)

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(recalls, precisions, label='Model w/ text: ROC AUC = %0.2f' % roc_auc)
    # a no-skill classifier has precision equal to the positive rate at every recall
    ax.axhline(labels.mean(), linestyle='--', label='No skill')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.legend(loc='lower left')
    plt.title('Precision-Recall Curve')
    plt.show()
    return None


def plot_roc_curve(
    features,
    labels
):
    preds = predict(features, labels)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    plt.title('ROC Curve')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()
    return None


if __name__ == "__main__":

    model_dir = params['model_dir']

    X_train = pd.read_csv(Path(model_dir, "train.csv"))
    y_train = pd.read_csv(Path(model_dir, "labels.csv"))
    X_test = pd.read_csv(Path(model_dir, "test.csv"))
    y_test = pd.read_csv(Path(model_dir, "labels_test.csv"))

    train(X=X_train, y=y_train, X_test=X_test, y_test=y_test)
--------------------------------------------------------------------------------
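The training code in `scripts/train.py` feeds each row to `Net` as two slices: the first `number_cat_num_features` columns (numerical plus one-hot categorical features) and the remaining text-embedding columns. It trains with `nn.BCEWithLogitsLoss(pos_weight=...)`. The NumPy sketch below reproduces both steps outside PyTorch so the effect of `pos_weight` can be checked by hand. `split_features` and `bce_with_logits` are hypothetical helpers, not part of this repository. The loss assumes the weighted formula from the PyTorch documentation for `BCEWithLogitsLoss` with the default `reduction='mean'`.

```python
import numpy as np


def split_features(X, n_tabular):
    """Split a combined feature matrix into tabular and text-embedding parts.

    Hypothetical helper mirroring the slicing in train(), test() and predict():
    the first `n_tabular` columns are numerical + one-hot categorical features,
    the remaining columns are text embeddings.
    """
    X = np.asarray(X, dtype=np.float32)
    return X[:, :n_tabular], X[:, n_tabular:]


def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean weighted binary cross-entropy computed from raw logits.

    Per-sample loss (as documented for torch.nn.BCEWithLogitsLoss):
        -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # log(sigmoid(x)) = -log(1 + exp(-x)) and log(1 - sigmoid(x)) = -log(1 + exp(x));
    # np.logaddexp(0, z) computes log(1 + exp(z)) without overflowing for large |z|
    log_p = -np.logaddexp(0.0, -x)
    log_not_p = -np.logaddexp(0.0, x)
    losses = -(pos_weight * y * log_p + (1.0 - y) * log_not_p)
    return losses.mean()
```

With two logits of 0 (predicted probability 0.5), one positive and one negative target, and `pos_weight=3.0`, the positive sample contributes 3·log 2 and the negative sample log 2. The mean loss is therefore 2·log 2 ≈ 1.386, twice the unweighted value of log 2 ≈ 0.693. If PyTorch is installed, `nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))(torch.zeros(2), torch.tensor([1.0, 0.0]))` should return the same value. This makes it easy to see how `params['pos_weight']` shifts the gradient toward the positive class.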