├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data
│   └── text.csv
├── model
│   └── params.yaml
├── notebook
│   └── test_churn_prediction.ipynb
├── requirements.txt
└── scripts
    ├── create_dataset.py
    ├── interpret.py
    ├── preprocess.py
    └── train.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. Because our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software is furnished to do so.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15 |
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Churn Prediction with Text and Interpretability
2 |
3 | Customer churn, the loss of current customers, is a problem faced by a wide range of companies. When trying to retain customers, it is in a company’s best interest to focus its efforts on customers who are more likely to leave, but companies need a way to detect such customers before they have decided to leave. Customers prone to churn often leave clues to their disposition in their behavior and in customer support chat logs, which can be detected and understood using Natural Language Processing (NLP) tools.
4 |
5 | Here, we demonstrate how to build a churn prediction model that leverages both text and structured (numerical and categorical) data, an approach we call a bi-modal model architecture. We use Amazon SageMaker to prepare, build, and train the model. Detecting customers who are likely to churn is only part of the battle; finding the root cause is essential to actually solving the issue. Since we are interested not only in the likelihood of a customer churning but also in its driving factors, we complement the prediction model with an analysis of feature importance for both text and non-text inputs.
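As a rough illustration of what "bi-modal" can mean at the input level, the sketch below assembles one feature row by concatenating a text embedding of the chat log with one-hot encoded categorical columns and the numerical columns. It is a minimal sketch under assumptions: the helper names are hypothetical, the 4-value embedding stands in for the output of a Sentence Transformers encoder, and the category levels are illustrative. The repository's actual preprocessing lives in scripts/preprocess.py and may differ.

```python
from typing import Dict, List, Sequence


def one_hot(value: str, levels: Sequence[str]) -> List[float]:
    """Encode a categorical value as a one-hot vector over a fixed list of levels."""
    return [1.0 if value == level else 0.0 for level in levels]


def build_bimodal_row(
    text_embedding: Sequence[float],
    categorical: Dict[str, str],
    category_levels: Dict[str, Sequence[str]],
    numerical: Sequence[float],
) -> List[float]:
    """Concatenate [text embedding | one-hot categoricals | numerical features]."""
    row = [float(x) for x in text_embedding]
    for column, levels in category_levels.items():  # dicts keep insertion order
        row.extend(one_hot(categorical[column], levels))
    row.extend(float(x) for x in numerical)
    return row


# Hypothetical inputs: a 4-dimensional stand-in for a real sentence embedding.
embedding = [0.12, -0.40, 0.33, 0.05]
levels = {
    "area_code": ["area_code_408", "area_code_415", "area_code_510"],
    "international_plan": ["no", "yes"],
}
row = build_bimodal_row(
    embedding,
    {"area_code": "area_code_415", "international_plan": "no"},
    levels,
    [88, 183.5, 93],  # account_length, total_day_minutes, total_day_calls
)
print(len(row))  # 4 embedding + 3 + 2 one-hot + 3 numerical values
```

Whatever the exact encoding, a single classifier then sees both modalities in one vector.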
6 |
7 | The categorical and numerical data comes from the Kaggle competition Customer Churn Prediction 2020 and was combined with a synthetic text dataset that we created using GPT-2.
8 |
9 | ## Blog Post
10 |
11 | [Medium / Towards Data Science blog post](https://towardsdatascience.com/customer-churn-prediction-with-text-and-interpretability-bd3d57af34b1)
12 |
13 | ## Installation
14 |
15 | ```
16 | git clone https://github.com/aws-samples/churn-prediction-with-text-and-interpretability.git
17 | conda create -n py39 python=3.9
18 | conda activate py39
19 | cd churn-prediction-with-text-and-interpretability
20 | pip install -r requirements.txt
21 | ```
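After installation, one quick way to confirm that the packages from requirements.txt are importable in the active environment is to look them up with the standard library's importlib. This is a hypothetical helper, not part of the repository; the module names are assumed to match the entries in requirements.txt (the spaCy model installs as the importable package en_core_web_sm).

```python
import importlib.util
from typing import Dict, Iterable

# Import names assumed to correspond to the entries in requirements.txt.
REQUIRED_MODULES = [
    "pandas",
    "matplotlib",
    "sentence_transformers",
    "xgboost",
    "spacy",
    "en_core_web_sm",
]


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found on the import path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED_MODULES).items():
        print(f"{name:<24} {'ok' if found else 'MISSING'}")
```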
22 |
23 | ## Download categorical/numerical data and combine with synthetic text data
24 |
25 | 1. Download the categorical/numerical data: [Customer Churn Prediction 2020](https://www.kaggle.com/c/customer-churn-prediction-2020/data).
26 | This may require a Kaggle account.
27 | Download train.csv and store it in the data folder.
28 |
29 | 2. From the scripts folder, run the script that combines the categorical/numerical data with the synthetic text data:
30 | ```
31 | python create_dataset.py
32 | ```
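The combination itself is done by create_joint_dataset in scripts/create_dataset.py, whose body is not reproduced here. Purely as a hedged illustration, and assuming that the synthetic text file carries its own churn label, one way to combine the two sources is to attach to each structured row a chat log drawn from the pool with the same label, so that the text is consistent with the outcome. The function and the toy rows below are hypothetical and not necessarily what the script does.

```python
from collections import defaultdict
from typing import Dict, List


def pair_by_label(
    structured: List[Dict[str, str]],
    texts: List[Dict[str, str]],
    label: str = "churn",
) -> List[Dict[str, str]]:
    """Attach a chat_log to each structured row, taken from texts with the same label.

    Chat logs are consumed in order; rows whose label pool is exhausted are dropped.
    """
    pools: Dict[str, List[str]] = defaultdict(list)
    for t in texts:
        pools[t[label]].append(t["chat_log"])

    used: Dict[str, int] = defaultdict(int)
    combined = []
    for row in structured:
        key = row[label]
        if used[key] < len(pools[key]):
            combined.append({**row, "chat_log": pools[key][used[key]]})
            used[key] += 1
    return combined


# Toy rows for illustration only.
structured = [
    {"churn": "no", "state": "CT"},
    {"churn": "yes", "state": "WV"},
    {"churn": "no", "state": "IN"},
]
texts = [
    {"churn": "yes", "chat_log": "Customer: I want to cancel my contract."},
    {"churn": "no", "chat_log": "Customer: I would like data."},
]
combined = pair_by_label(structured, texts)
print(len(combined))  # the second "no" row is dropped: only one "no" chat log exists
```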
33 |
34 | ## Run in Notebook
35 |
36 | An example notebook that runs the entire pipeline and prints/visualizes the results is included in the notebook folder.
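In the interpretation part of the notebook, keywords are ranked by a joint metric built from semantic similarity (sim), marginal contribution (chg) and frequency (count), and model/params.yaml holds the weights w_sim: 0.5, w_marg_contr: 0.3 and w_count: 0.2, which sum to 1. The exact formula lives in scripts/interpret.py; the sketch below assumes min-max scaling of each component and treats a more negative chg as more impactful, so its scores will not reproduce the notebook's joint column.

```python
from typing import List


def min_max(values: List[float]) -> List[float]:
    """Scale values to [0, 1]; a constant column maps to all zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0] * len(values)
    return [(v - lo) / (hi - lo) for v in values]


def joint_scores(
    rows: List[dict],
    w_sim: float = 0.5,
    w_marg_contr: float = 0.3,
    w_count: float = 0.2,
) -> List[float]:
    """Weighted sum of scaled similarity, scaled impact and scaled count per keyword."""
    sim = min_max([r["sim"] for r in rows])
    # Assumption: removing an impactful keyword lowers the churn prediction (negative chg).
    impact = min_max([-r["chg"] for r in rows])
    count = min_max([float(r["count"]) for r in rows])
    return [w_sim * s + w_marg_contr * m + w_count * c for s, m, c in zip(sim, impact, count)]


# sim, chg and count for three keywords, copied from the notebook's results table.
rows = [
    {"keyword": "voicemail", "sim": 90.076553, "chg": 0.006695, "count": 5},
    {"keyword": "cancel", "sim": 61.187359, "chg": -0.081321, "count": 168},
    {"keyword": "turnover", "sim": 60.896118, "chg": -0.286630, "count": 1},
]
for r, score in zip(rows, joint_scores(rows)):
    print(f"{r['keyword']:<10} {score:.3f}")
```

Because the weights sum to 1 and every component is scaled to [0, 1], the joint score in this sketch also stays in [0, 1].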
37 |
38 | ## Run in Terminal
39 |
40 | The Python scripts to prepare the data, train and evaluate the model, and interpret the model are stored in the scripts folder and should be run from inside it, since they use relative paths such as ../data.
41 | The parameters used for training and interpreting the model are stored in ../model/params.yaml.
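One of those training parameters, pos_weight: 10, presumably compensates for churners being the minority class. A common way to use such a value is the pos_weight argument of PyTorch's torch.nn.BCEWithLogitsLoss, which multiplies the loss on positive examples. Whether scripts/train.py uses it exactly like this is an assumption; the pure-Python sketch below only reproduces that loss formula so the effect of the weight is easy to see.

```python
import math
from typing import Sequence


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def weighted_bce_with_logits(
    logits: Sequence[float], targets: Sequence[float], pos_weight: float = 10.0
) -> float:
    """Mean binary cross-entropy on raw logits, with positive targets up-weighted.

    Per example: -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    """
    total = 0.0
    for x, y in zip(logits, targets):
        # log(sigmoid(x)) == -softplus(-x) and log(1 - sigmoid(x)) == -softplus(x)
        total += pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
    return total / len(logits)


# A missed churner (logit -2, target 1) costs 10x more with pos_weight=10.
print(weighted_bce_with_logits([-2.0], [1.0], pos_weight=1.0))
print(weighted_bce_with_logits([-2.0], [1.0], pos_weight=10.0))
```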
42 |
43 |
44 | 1. Prepare the data:
45 | ```
46 | python preprocess.py
47 | ```
48 | 2. Train and evaluate the model:
49 | ```
50 | python train.py
51 | ```
52 | 3. Interpret the trained model (text features), here for customer utterances in chats that ended in churn:
53 | ```
54 | python interpret.py --churn 1 --speaker Customer
55 | ```
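For the text interpretation, the notebook measures each keyword's marginal contribution (the chg column) with perform_ablation. Assuming this means the average change in the predicted churn probability when a keyword, together with its original-token variants, is removed from the chats that contain it, the idea can be sketched with a toy predictor in place of the trained model. All names here are hypothetical.

```python
from typing import Callable, Dict, List, Sequence

PUNCT = ".,!?;:\"'"


def tokens_of(text: str) -> List[str]:
    """Lowercased, punctuation-stripped whitespace tokens."""
    return [t.lower().strip(PUNCT) for t in text.split()]


def remove_keyword(text: str, variants: Sequence[str]) -> str:
    """Drop whitespace-separated tokens whose normalized form is one of the variants."""
    return " ".join(t for t in text.split() if t.lower().strip(PUNCT) not in variants)


def marginal_contribution(
    chats: List[str],
    keyword_variants: Dict[str, Sequence[str]],
    predict: Callable[[str], float],
) -> Dict[str, float]:
    """Average change in predicted churn probability when each keyword is removed."""
    base = [predict(chat) for chat in chats]
    result = {}
    for keyword, variants in keyword_variants.items():
        deltas = []
        for chat, p_orig in zip(chats, base):
            if any(v in tokens_of(chat) for v in variants):  # chat contains the keyword
                deltas.append(predict(remove_keyword(chat, variants)) - p_orig)
        result[keyword] = sum(deltas) / len(deltas) if deltas else 0.0
    return result


def toy_predict(text: str) -> float:
    """Stand-in for the trained model: high churn probability if 'cancel' appears."""
    return 0.75 if "cancel" in tokens_of(text) else 0.25


chats = ["I want to cancel my contract.", "Billing is fine.", "Please cancel it now!"]
variants = {"cancel": ["cancel", "cancels", "cancelled"], "billing": ["billing"]}
print(marginal_contribution(chats, variants, toy_predict))
```

Under this reading, a negative value means the prediction drops when the keyword is removed, that is, the keyword pushes the model towards churn.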
56 |
57 | ## Credits
58 |
59 | * Packages:
60 |     * [spaCy](https://spacy.io/usage/linguistic-features/)
61 |     * [PyTorch](https://pytorch.org/)
62 |     * [XGBoost](https://xgboost.readthedocs.io/en/latest/)
63 |     * [Hugging Face Sentence Transformers](https://huggingface.co/sentence-transformers)
64 |
65 | * Datasets:
66 |     * [Customer Churn Prediction 2020](https://www.kaggle.com/c/customer-churn-prediction-2020/data) (with synthetic text dataset)
67 |
68 | * Models:
69 |     * GPT-2, Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever
70 |     * BERT, Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
71 |     * Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks, Reimers, Nils and Gurevych, Iryna
72 |
73 | ## Security
74 |
75 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
76 |
77 | ## License
78 |
79 | This library is licensed under the MIT-0 License. See the LICENSE file.
80 |
81 |
--------------------------------------------------------------------------------
/model/params.yaml:
--------------------------------------------------------------------------------
1 | data_dir: ../data
2 | model_dir: ../model
3 | batch_size: 10
4 | batch_size_test: 1000
5 | epochs: 10
6 | pos_weight: 10
7 | lr: 0.001
8 | momentum: 0.9
9 | topn_relevant_keywords: 1000
10 | frac_relevant_keywords: 0.25
11 | w_marg_contr: 0.3
12 | w_count: 0.2
13 | w_sim: 0.5
--------------------------------------------------------------------------------
/notebook/test_churn_prediction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "df276921",
6 | "metadata": {},
7 | "source": [
8 | "# Churn Prediction with Text and Interpretability"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "1c1f7791",
14 | "metadata": {},
15 | "source": [
16 | "This notebook runs the entire churn prediction pipeline from data preparation to model evaluation and interpretation.\n",
17 | "\n",
18 | "Alternatively, everything can be run from the terminal as well (see README.md).\n",
19 | "\n",
20 | "Prerequisite: the dataset has already been created (see README.md)."
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "id": "032bb6ce",
26 | "metadata": {},
27 | "source": [
28 | "### Setup"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 1,
34 | "id": "a32ade35",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "import os\n",
39 | "import pandas as pd\n",
40 | "from matplotlib import pyplot as plt\n",
41 | "\n",
42 | "os.chdir(\"../scripts\")\n",
43 | "\n",
44 | "import preprocess\n",
45 | "import train\n",
46 | "import interpret"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "id": "9d500e76",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "%load_ext autoreload\n",
57 | "%autoreload 2"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "id": "4b8b1336",
63 | "metadata": {},
64 | "source": [
65 | "### Load and Prepare the Data"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 4,
71 | "id": "94c30e10",
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
196 | "text/plain": [
197 | " churn chat_log state \\\n",
198 | "0 no Customer: Well, the only thing that I'm consid... CT \n",
199 | "1 yes Customer: Well, I just want to be able to canc... WV \n",
200 | "2 no Customer: I would like data.\\nTelCom Agent: Ok... IN \n",
201 | "\n",
202 | " account_length area_code international_plan voice_mail_plan \\\n",
203 | "0 134 area_code_408 no no \n",
204 | "1 78 area_code_408 no no \n",
205 | "2 88 area_code_415 no no \n",
206 | "\n",
207 | " number_vmail_messages total_day_minutes total_day_calls ... \\\n",
208 | "0 0 177.2 91 ... \n",
209 | "1 0 226.3 88 ... \n",
210 | "2 0 183.5 93 ... \n",
211 | "\n",
212 | " total_eve_minutes total_eve_calls total_eve_charge total_night_minutes \\\n",
213 | "0 228.7 105 19.44 194.3 \n",
214 | "1 306.2 81 26.03 200.9 \n",
215 | "2 170.5 80 14.49 193.8 \n",
216 | "\n",
217 | " total_night_calls total_night_charge total_intl_minutes \\\n",
218 | "0 113 8.74 8.9 \n",
219 | "1 120 9.04 7.8 \n",
220 | "2 88 8.72 8.3 \n",
221 | "\n",
222 | " total_intl_calls total_intl_charge number_customer_service_calls \n",
223 | "0 3 2.40 2 \n",
224 | "1 11 2.11 1 \n",
225 | "2 5 2.24 3 \n",
226 | "\n",
227 | "[3 rows x 21 columns]"
228 | ]
229 | },
230 | "execution_count": 4,
231 | "metadata": {},
232 | "output_type": "execute_result"
233 | }
234 | ],
235 | "source": [
236 | "df = pd.read_csv('../data/churn_dataset.csv')\n",
237 | "df.head(3)"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "id": "ede78f6b",
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "X_train, X_test, y_train, y_test = preprocess.prep_data(df, use_existing=False, test_size=0.33)"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 27,
253 | "id": "5d0f6d4f",
254 | "metadata": {},
255 | "outputs": [
256 | {
257 | "data": {
258 | "text/plain": [
259 | "((2233, 841), (1100, 841), (2233, 1), (1100, 1))"
260 | ]
261 | },
262 | "execution_count": 27,
263 | "metadata": {},
264 | "output_type": "execute_result"
265 | }
266 | ],
267 | "source": [
268 | "X_train.shape, X_test.shape, y_train.shape, y_test.shape"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "id": "c477f751",
274 | "metadata": {},
275 | "source": [
276 | "### Train and Evaluate the Model"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 28,
282 | "id": "0d68bf05",
283 | "metadata": {},
284 | "outputs": [
285 | {
286 | "name": "stdout",
287 | "output_type": "stream",
288 | "text": [
289 | "/home/ec2-user/SageMaker/churn_test/scripts\n"
290 | ]
291 | }
292 | ],
293 | "source": [
294 | "!pwd"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 29,
300 | "id": "eb2e2878",
301 | "metadata": {
302 | "collapsed": true,
303 | "jupyter": {
304 | "outputs_hidden": true
305 | }
306 | },
307 | "outputs": [
308 | {
309 | "name": "stdout",
310 | "output_type": "stream",
311 | "text": [
312 | "starting epoch: 1\n",
313 | "Train Epoch: 1, train-auc-score: 0.9561\n",
314 | "test_auc_score: 0.9422\n",
315 | "starting epoch: 2\n",
316 | "Train Epoch: 2, train-auc-score: 0.9571\n",
317 | "test_auc_score: 0.9453\n",
318 | "starting epoch: 3\n",
319 | "Train Epoch: 3, train-auc-score: 0.9602\n",
320 | "test_auc_score: 0.9480\n",
321 | "starting epoch: 4\n",
322 | "Train Epoch: 4, train-auc-score: 0.9594\n",
323 | "test_auc_score: 0.9467\n",
324 | "starting epoch: 5\n",
325 | "Train Epoch: 5, train-auc-score: 0.9628\n",
326 | "test_auc_score: 0.9529\n",
327 | "starting epoch: 6\n",
328 | "Train Epoch: 6, train-auc-score: 0.9711\n",
329 | "test_auc_score: 0.9555\n",
330 | "starting epoch: 7\n",
331 | "Train Epoch: 7, train-auc-score: 0.9756\n",
332 | "test_auc_score: 0.9586\n",
333 | "starting epoch: 8\n",
334 | "Train Epoch: 8, train-auc-score: 0.9804\n",
335 | "test_auc_score: 0.9598\n",
336 | "starting epoch: 9\n",
337 | "Train Epoch: 9, train-auc-score: 0.9810\n",
338 | "test_auc_score: 0.9618\n",
339 | "starting epoch: 10\n",
340 | "Train Epoch: 10, train-auc-score: 0.9644\n",
341 | "test_auc_score: 0.9552\n",
342 | "saving scores\n",
343 | "saving model\n"
344 | ]
345 | }
346 | ],
347 | "source": [
348 | "# train the model\n",
349 | "train.train(\n",
350 | " X=X_train,\n",
351 | " y=y_train,\n",
352 | " X_test=X_test,\n",
353 | " y_test=y_test\n",
354 | ")"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": 30,
360 | "id": "294bd939",
361 | "metadata": {},
362 | "outputs": [
363 | {
364 | "data": {
366 | "text/plain": [
367 | ""
368 | ]
369 | },
370 | "metadata": {
371 | "needs_background": "light"
372 | },
373 | "output_type": "display_data"
374 | }
375 | ],
376 | "source": [
377 | "# plot training stats\n",
378 | "train.plot_train_stats()"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 31,
384 | "id": "75e22bb5",
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "data": {
390 | "text/plain": [
391 | ""
392 | ]
393 | },
394 | "metadata": {
395 | "needs_background": "light"
396 | },
397 | "output_type": "display_data"
398 | }
399 | ],
400 | "source": [
401 | "# plot pr curve\n",
402 | "train.plot_pr_curve(X_test, y_test)"
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "id": "fa34b9d3",
408 | "metadata": {},
409 | "source": [
410 | "### Interpret the Model"
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "id": "2ca793cf",
416 | "metadata": {},
417 | "source": [
418 | "#### Categorical and Numerical Features"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 32,
424 | "id": "5ad792da",
425 | "metadata": {},
426 | "outputs": [
427 | {
428 | "name": "stderr",
429 | "output_type": "stream",
430 | "text": [
431 | "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/xgboost/sklearn.py:1146: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].\n",
432 | " warnings.warn(label_encoder_deprecation_msg, UserWarning)\n",
433 | "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sklearn/utils/validation.py:63: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
434 | " return f(*args, **kwargs)\n"
435 | ]
436 | },
437 | {
438 | "name": "stdout",
439 | "output_type": "stream",
440 | "text": [
441 | "[16:15:16] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
442 | ]
443 | },
444 | {
445 | "data": {
447 | "text/plain": [
448 | ""
449 | ]
450 | },
451 | "metadata": {
452 | "needs_background": "light"
453 | },
454 | "output_type": "display_data"
455 | }
456 | ],
457 | "source": [
458 | "preds_xgb = interpret.train_xgb()"
459 | ]
460 | },
461 | {
462 | "cell_type": "markdown",
463 | "id": "9e80253e",
464 | "metadata": {},
465 | "source": [
466 | "#### Textual Features (focus on customer chats that result in churn)"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 9,
472 | "id": "7f18dcf5",
473 | "metadata": {},
474 | "outputs": [],
475 | "source": [
476 | "chats, df_sub = interpret.get_chats(df=df, churn=1, speaker='Customer')"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": 10,
482 | "id": "0b65267d",
483 | "metadata": {},
484 | "outputs": [
485 | {
486 | "data": {
487 | "text/plain": [
488 | "[\"Well, I just want to be able to cancel the contract because I don't think that I want to stay. My local provider has been terrible and I really would like to switch. Sure, I can.\",\n",
489 | " \"Well, it's the old TelCom billing system for the last 5 years. I don't trust anymore and I think you should change to the newer billing system. I would like to give you a call back number. Okay, I can see why you need the new billing system, but I don't know if I can do that. I would like to know your cancellation policy.\",\n",
490 | " \"Well, I've been getting phone calls from a very good friend who's a TelCom agent and I have told him the same thing and the problem has not been resolved. He has offered me a $20/mo deal but that's not good enough for me because I'm getting $20 out of his pocket. Sure. $60 to cancel for nine months with a $20/mo bonus.\"]"
491 | ]
492 | },
493 | "execution_count": 10,
494 | "metadata": {},
495 | "output_type": "execute_result"
496 | }
497 | ],
498 | "source": [
499 | "chats[:3]"
500 | ]
501 | },
502 | {
503 | "cell_type": "markdown",
504 | "id": "8c73e7dd",
505 | "metadata": {},
506 | "source": [
507 | "##### Candidate keywords (POS tagging, lower casing, lemmatization)"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": 11,
513 | "id": "3db8f636",
514 | "metadata": {},
515 | "outputs": [
516 | {
517 | "name": "stdout",
518 | "output_type": "stream",
519 | "text": [
520 | "['want', 'able', 'cancel', 'contract', 'think', 'want', 'stay', 'local', 'provider', 'terrible']\n"
521 | ]
522 | }
523 | ],
524 | "source": [
525 | "# find candidate keywords\n",
526 | "keywords, tokens = interpret.get_keywords(chats)\n",
527 | "print(keywords[:10])\n",
528 | "\n",
529 | "# map keywords to original tokens\n",
530 | "keywords_dict = interpret.map_to_orig_tok(keywords, tokens)"
531 | ]
532 | },
533 | {
534 | "cell_type": "markdown",
535 | "id": "172eed77",
536 | "metadata": {},
537 | "source": [
538 | "##### Relevant keywords (semantic similarity)"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 12,
544 | "id": "07be6aa0",
545 | "metadata": {},
546 | "outputs": [],
547 | "source": [
548 | "relevant_keywords, simMat = interpret.get_relevant_keywords(\n",
549 | " text = chats, \n",
550 | " keywords_dict = keywords_dict\n",
551 | ")"
552 | ]
553 | },
554 | {
555 | "cell_type": "code",
556 | "execution_count": 13,
557 | "id": "267a00b0",
558 | "metadata": {},
559 | "outputs": [
560 | {
561 | "name": "stdout",
562 | "output_type": "stream",
563 | "text": [
564 | "['voicemail', 'bored', 'spam', 'backlog', 'sick', 'frustrated', 'incompetence', 'overcharge', 'angry', 'disappointed', 'scam', 'lag', 'termination', 'resign', 'nightmare', 'incompetent', 'frustration', 'lagging', 'yesterday', 'friday', 'monday', 'discontinue', 'cheat', 'dead', 'annoyed']\n"
565 | ]
566 | }
567 | ],
568 | "source": [
569 | "print(relevant_keywords[:25])"
570 | ]
571 | },
572 | {
573 | "cell_type": "markdown",
574 | "id": "3ba0da26",
575 | "metadata": {},
576 | "source": [
577 | "##### Impactful keywords (marginal contribution)"
578 | ]
579 | },
580 | {
581 | "cell_type": "code",
582 | "execution_count": 14,
583 | "id": "8c877937",
584 | "metadata": {},
585 | "outputs": [
586 | {
587 | "name": "stderr",
588 | "output_type": "stream",
589 | "text": [
590 | "100%|██████████| 250/250 [11:25<00:00, 2.74s/it]\n"
591 | ]
592 | }
593 | ],
594 | "source": [
595 | "# get marginal contribution to prediction for each keyword\n",
596 | "marg_contr_df = interpret.perform_ablation(\n",
597 | " df = df_sub,\n",
598 | " keywords = relevant_keywords,\n",
599 | " keywords_dict = keywords_dict\n",
600 | ")"
601 | ]
602 | },
603 | {
604 | "cell_type": "markdown",
605 | "id": "d43f8aa6",
606 | "metadata": {},
607 | "source": [
608 | "##### Create joint metric (semantic similarity + marginal contribution + count)"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "id": "c02d300a",
615 | "metadata": {},
616 | "outputs": [],
617 | "source": [
618 | "# load from local disk if available\n",
619 | "#results_df = pd.read_csv('../model/ablation_results.csv')\n",
620 | "#results_df = results_df.rename(columns={'Unnamed: 0' : 'keyword'})"
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": 15,
626 | "id": "2e112e00",
627 | "metadata": {},
628 | "outputs": [],
629 | "source": [
630 | "results_df = interpret.get_important_keywords(\n",
631 | " simMat_df=simMat,\n",
632 | " marg_contr_df=marg_contr_df\n",
633 | ")"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": 16,
639 | "id": "efd90445",
640 | "metadata": {},
641 | "outputs": [
642 | {
643 | "data": {
835 | "text/plain": [
836 | " keyword sim chg count joint\n",
837 | "0 voicemail 90.076553 0.006695 5 0.628774\n",
838 | "1 cancel 61.187359 -0.081321 168 0.545633\n",
839 | "2 sick 74.789459 -0.127242 1 0.538919\n",
840 | "3 turnover 60.896118 -0.286630 1 0.533321\n",
841 | "4 disappointed 70.248131 -0.091740 5 0.522520\n",
842 | "5 spam 77.940460 -0.000022 3 0.506429\n",
843 | "6 bored 78.131271 -0.038932 1 0.502213\n",
844 | "7 unhappy 65.601990 -0.024782 37 0.493910\n",
845 | "8 frustrated 73.496033 -0.006023 5 0.486930\n",
846 | "9 mistake 66.247879 -0.093309 3 0.470609\n",
847 | "10 late 65.971649 -0.003873 18 0.458023\n",
848 | "11 error 63.123695 -0.065609 7 0.448412\n",
849 | "12 faulty 55.486092 -0.239817 1 0.448232\n",
850 | "13 angry 70.865860 -0.002967 3 0.444018\n",
851 | "14 backlog 74.808548 0.000020 1 0.442185\n",
852 | "15 customer 52.072552 -0.006985 480 0.439737\n",
853 | "16 lag 69.912178 -0.004650 3 0.436584\n",
854 | "17 overpay 65.573105 -0.093846 1 0.429261\n",
855 | "18 disconnect 64.002014 -0.010485 11 0.429104\n",
856 | "19 incompetence 73.107452 -0.000754 1 0.427229"
857 | ]
858 | },
859 | "execution_count": 16,
860 | "metadata": {},
861 | "output_type": "execute_result"
862 | }
863 | ],
864 | "source": [
865 | "results_df.head(20)"
866 | ]
867 | },
868 | {
869 | "cell_type": "markdown",
870 | "id": "7036dc83",
871 | "metadata": {},
872 | "source": [
873 | "##### Context of keywords"
874 | ]
875 | },
876 | {
877 | "cell_type": "code",
878 | "execution_count": 17,
879 | "id": "5c0f0d2b",
880 | "metadata": {},
881 | "outputs": [],
882 | "source": [
883 | "keyword_of_interest = 'spam'"
884 | ]
885 | },
886 | {
887 | "cell_type": "code",
888 | "execution_count": 18,
889 | "id": "041befbf",
890 | "metadata": {},
891 | "outputs": [
892 | {
893 | "name": "stdout",
894 | "output_type": "stream",
895 | "text": [
896 | "Basically, I'm getting a lot of spam calls every day from a guy named Michael who's calling from a really weird number.\n",
897 | "TelCom started to flood me with emails and phone calls, spamming me with thousands of phony invoices.\n",
898 | "I just got some spam messages last night, and today it's been getting a lot of texts that I \"don't have my SIM card\" and \"I need my SIM card.\n"
899 | ]
900 | }
901 | ],
902 | "source": [
903 | "interpret.obtain_context(\n",
904 | " chats_list = chats,\n",
905 | " keyword = keyword_of_interest\n",
906 | ")"
907 | ]
908 | },
909 | {
910 | "cell_type": "code",
911 | "execution_count": null,
912 | "id": "a294f494",
913 | "metadata": {},
914 | "outputs": [],
915 | "source": []
916 | }
917 | ],
918 | "metadata": {
919 | "kernelspec": {
920 | "display_name": "conda_pytorch_p36",
921 | "language": "python",
922 | "name": "conda_pytorch_p36"
923 | },
924 | "language_info": {
925 | "codemirror_mode": {
926 | "name": "ipython",
927 | "version": 3
928 | },
929 | "file_extension": ".py",
930 | "mimetype": "text/x-python",
931 | "name": "python",
932 | "nbconvert_exporter": "python",
933 | "pygments_lexer": "ipython3",
934 | "version": "3.6.13"
935 | }
936 | },
937 | "nbformat": 4,
938 | "nbformat_minor": 5
939 | }
940 |
--------------------------------------------------------------------------------
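The `joint` column shown by `results_df.head(20)` comes from `get_important_keywords` in `scripts/interpret.py`. That function negates the ablation change `chg`, log-transforms the keyword `count`, min-max scales all three metrics to [0, 1], and then takes a weighted sum. Below is a minimal pure-Python sketch of that scoring step. It assumes equal weights, because the real `w_marg_contr`, `w_count` and `w_sim` values live in `params.yaml`, which is not shown. `joint_scores` and the toy rows are hypothetical and are not part of the repository.

```python
import math


def joint_scores(rows, w_marg_contr, w_count, w_sim):
    """Rank keyword rows ({'keyword', 'sim', 'chg', 'count'}) by a weighted joint score.

    Sketch of the scoring in get_important_keywords: negate chg (a larger drop in
    predicted churn after removal means a more important keyword), log-transform
    count, min-max scale each metric to [0, 1], then combine with the weights.
    """
    def min_max(values):
        lo, hi = min(values), max(values)
        span = hi - lo
        # a constant column scales to 0, matching sklearn's MinMaxScaler default range
        return [(v - lo) / span if span else 0.0 for v in values]

    chg = min_max([-r["chg"] for r in rows])
    count = min_max([math.log(r["count"]) for r in rows])
    sim = min_max([r["sim"] for r in rows])
    scored = [
        dict(r, joint=c * w_marg_contr + n * w_count + s * w_sim)
        for r, c, n, s in zip(rows, chg, count, sim)
    ]
    return sorted(scored, key=lambda r: r["joint"], reverse=True)


# toy rows (hypothetical numbers, not taken from the notebook run)
rows = [
    {"keyword": "cancel", "sim": 80.0, "chg": -0.30, "count": 20},
    {"keyword": "late", "sim": 66.0, "chg": -0.01, "count": 18},
    {"keyword": "customer", "sim": 52.0, "chg": 0.0, "count": 480},
]
ranked = joint_scores(rows, 1 / 3, 1 / 3, 1 / 3)
print([(r["keyword"], round(r["joint"], 3)) for r in ranked])
```

The log transform keeps a single extreme count (480 in the toy rows) from squashing every other keyword's scaled count toward 0.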
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | matplotlib
3 | sentence_transformers==2.0.0
4 | xgboost==1.4.2
5 | spacy>=3.0.0,<4.0.0
6 | https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm
7 | scikit-learn
8 | torch
9 | tqdm
10 | pyyaml
--------------------------------------------------------------------------------
/scripts/create_dataset.py:
--------------------------------------------------------------------------------
1 | """ Module for creating the dataset by combining the categorical/numerical and text CSV files.
2 |
3 | First, download the categorical/numerical data into the data folder as described in README.md, then:
4 |
5 | Run in CLI example:
6 | 'python create_dataset.py'
7 |
8 | """
9 |
10 | import yaml
11 | import pandas as pd
12 | from pathlib import Path
13 |
14 | with open("../model/params.yaml", "r") as params_file:
15 | params = yaml.safe_load(params_file)
16 |
17 |
18 | def create_joint_dataset(
19 | df_categorical,
20 | df_text
21 | ):
22 | df_text_no = df_text[df_text.churn == 'no'].reset_index(drop=True)
23 | df_text_yes = df_text[df_text.churn == 'yes'].reset_index(drop=True)
24 |
25 | df_cat_no = df_categorical[df_categorical.churn == 'no'].reset_index(drop=True)[:len(df_text_no)]
26 | df_cat_yes = df_categorical[df_categorical.churn == 'yes'].reset_index(drop=True)[:len(df_text_yes)]
27 |
28 | df_no = pd.concat([df_text_no, df_cat_no.iloc[:,:-1]], axis=1)
29 | df_yes = pd.concat([df_text_yes, df_cat_yes.iloc[:,:-1]], axis=1)
30 |
31 | df = pd.concat([df_no, df_yes], axis=0)
32 |
33 | # shuffle data
34 | df = df.sample(frac=1).reset_index(drop=True)
35 |
36 | return df
37 |
38 |
39 | if __name__ == "__main__":
40 | data_dir = params['data_dir']
41 |
42 | # load data
43 | df_categorical = pd.read_csv(Path(data_dir, "train.csv"))
44 | df_text = pd.read_csv(Path(data_dir, "text.csv"))
45 |
46 | # create joint dataset
47 | df = create_joint_dataset(df_categorical, df_text)
48 |
49 | # save data
50 | df.to_csv(Path(data_dir, "churn_dataset.csv"), index=False)
51 |
--------------------------------------------------------------------------------
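`create_joint_dataset` joins the two sources by position within each churn class. The i-th text row labelled `no` is paired with the i-th categorical/numerical row labelled `no`, and likewise for `yes`. Extra tabular rows are dropped, and the tabular label column is not copied (it is assumed to be the last column, hence `iloc[:, :-1]`). The pure-Python sketch below mirrors that pairing on toy dictionaries. `pair_by_label` and the `plan` field are hypothetical. Unlike `pd.concat`, which would fill missing tabular values with NaN, the sketch assumes each class has at least as many tabular rows as text rows.

```python
import random


def pair_by_label(text_rows, tabular_rows, label_key="churn", seed=None):
    """Pair the i-th text row of each class with the i-th tabular row of that class.

    Assumes every class has at least as many tabular rows as text rows; zip()
    silently drops unmatched text rows otherwise.
    """
    joined = []
    for label in ("no", "yes"):
        texts = [r for r in text_rows if r[label_key] == label]
        tabs = [r for r in tabular_rows if r[label_key] == label][: len(texts)]
        for text_row, tab_row in zip(texts, tabs):
            # keep the label from the text row only, as iloc[:, :-1] does
            extra = {k: v for k, v in tab_row.items() if k != label_key}
            joined.append({**text_row, **extra})
    # seeded shuffle, analogous to df.sample(frac=1, random_state=seed)
    random.Random(seed).shuffle(joined)
    return joined


text_rows = [
    {"chat_log": "a", "churn": "no"},
    {"chat_log": "b", "churn": "yes"},
    {"chat_log": "c", "churn": "no"},
]
tabular_rows = [
    {"plan": "basic", "churn": "no"},
    {"plan": "pro", "churn": "yes"},
    {"plan": "pro", "churn": "no"},
    {"plan": "basic", "churn": "no"},
]
joined = pair_by_label(text_rows, tabular_rows, seed=0)
print(joined)
```

The pairing is positional, so a joined row is not matched on any customer identifier. Also note that `df.sample(frac=1)` in the script is unseeded; passing `random_state` would make the shuffled output reproducible.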
/scripts/interpret.py:
--------------------------------------------------------------------------------
1 | """ Module for interpreting the trained churn prediction model with text.
2 |
3 | The parameters for model interpretation are stored in 'params.yaml'.
4 |
5 | Run in CLI example:
6 | 'python interpret.py --churn 1 --speaker Customer'
7 |
8 | """
9 |
10 |
11 | import json
12 | import yaml
13 | import spacy
14 | import torch
15 | import argparse
16 | import numpy as np
17 | import pandas as pd
18 | from tqdm import tqdm
19 | from pathlib import Path
20 | from sklearn import preprocessing
21 | from xgboost import XGBClassifier
22 | from collections import defaultdict
23 | from matplotlib import pyplot as plt
24 | import preprocess
25 | from preprocess import BertEncoder
26 | import train
27 |
28 | nlp = spacy.load('en_core_web_sm')
29 |
30 | with open("../model/params.yaml", "r") as params_file:
31 | params = yaml.safe_load(params_file)
32 |
33 | model_dir = params['model_dir']
34 |
35 |
36 | def get_chats(
37 | df,
38 | churn=None,
39 | speaker=None
40 | ):
41 | """
42 | Args:
43 | df: dataframe with all data
44 | churn (int): 1 for churn, 0 for no churn, None (default) for all chats
45 | speaker (str): 'Customer' for only customer chats,
46 | 'Agent' for only agent chats,
47 | None (default) for all chats
48 |
49 | Returns:
50 | list of chats (strings) and the filtered dataframe
51 | """
52 | # convert labels to binary numeric
53 | df = preprocess.convert_label(df)
54 | # keep only churn/no churn chat logs
55 | if churn is not None:
56 | df = df[df.churn == churn]
57 | # drop short chat logs
58 | df = df[df.chat_log.apply(lambda x: len(str(x))>=5)]
59 |
60 | # select chats
61 | chat_logs = list(df['chat_log'])
62 | chat_logs = [i if isinstance(i, str) else "nan" for i in chat_logs]
63 |
64 | if speaker is not None:
65 | chats = []
66 | for chat in chat_logs:
67 | sents = chat.split('\n')
68 | cchat = []
69 | for sent in sents:
70 | if str(sent).split(':')[0] == speaker:
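71 | cchat.append(sent.split(':', 1)[1].strip()) # drop the 'Speaker:' prefix whatever its length ('Agent:' is shorter than 'Customer:')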
72 | chats.append(' '.join(cchat))
73 | else:
74 | chats = chat_logs
75 |
76 | return chats, df
77 |
78 |
79 | def get_keywords(texts):
80 | """Returns candidate keywords (lowercased lemmas of non-stop-word adjectives, verbs, nouns and proper nouns) and the original tokens they came from.
81 | """
82 | candidate_pos = ['ADJ', 'VERB', 'NOUN', 'PROPN']
83 | keywords = []
84 | tokens = [] # for referencing keywords in original text later on
85 | for text in texts:
86 | text_keywords = []
87 | text_tokens = []
88 | doc = nlp(text)
89 | for token in doc:
90 | if token.pos_ in candidate_pos and not token.is_stop:
91 | text_tokens.append(str(token))
92 | text_keywords.append(token.lemma_.lower())
93 |
94 | keywords.extend(text_keywords)
95 | tokens.extend(text_tokens)
96 |
97 | return keywords, tokens
98 |
99 |
100 | def map_to_orig_tok(keywords, tokens):
101 | """Create dictionary mapping keywords to original tokens.
102 | """
103 | keywords_dict = defaultdict(list)
104 | for kw, t in zip(keywords, tokens):
105 | keywords_dict[kw].append(t)
106 | for kw, l in keywords_dict.items():
107 | keywords_dict[kw] = list(set(l))
108 |
109 | keywords_dict = dict(keywords_dict)
110 |
111 | return keywords_dict
112 |
113 |
114 | def get_relevant_keywords(
115 | text,
116 | keywords_dict
117 | ):
118 | """Returns the top-n keywords by dot-product similarity to the class embedding (the mean embedding of all chats) and the full similarity table.
119 | """
120 | # obtain class embedding
121 | textual_transformer = BertEncoder()
122 | textual_features = textual_transformer.transform(text)
123 | class_embedding = np.mean(np.array(textual_features), axis=0)
124 |
125 | # obtain relevant keywords
126 | unique_keywords = list(keywords_dict.keys())
127 | topn_relevant_keywords = params['topn_relevant_keywords']
128 | relevant_keywords, simMat = relevant_keywords_helper(unique_keywords,
129 | class_embedding,
130 | topn_relevant_keywords)
131 | return relevant_keywords, simMat
132 |
133 |
134 | def relevant_keywords_helper(
135 | keywords,
136 | class_embedding,
137 | topn_relevant_keywords
138 | ):
139 | """Helper function for obtaining embedding similarity.
140 | """
141 | textual_transformer = BertEncoder()
142 | keyword_embeddings = textual_transformer.transform(keywords)
143 |
144 | simMatrix = np.dot(keyword_embeddings, class_embedding.T)
145 | d = {"keyword" : keywords, "sim" : list(simMatrix)}
146 | df_simMatrix = pd.DataFrame(d).sort_values(by="sim", ascending=False).reset_index(drop=True)
147 | relevant_keywords = list(df_simMatrix['keyword'])[:topn_relevant_keywords]
148 |
149 | return relevant_keywords, df_simMatrix
150 |
151 |
152 | def prep_ablation(df):
153 | """Prepare data for making predictions."""
154 |
155 | # convert df to list of dicts
156 | data = df.to_json(orient="records")
157 | data = json.loads(data)
158 |
159 | # load model assets from training job
160 | model_assets = train.get_train_assets()
161 | #print('extracting features')
162 | numerical_features, categorical_features, textual_features = preprocess.extract_features(
163 | data,
164 | model_assets['numerical_feature_names'],
165 | model_assets['categorical_feature_names'],
166 | model_assets['textual_feature_names']
167 | )
168 |
169 | # extract labels
170 | _, _, _, label_name = preprocess.get_feature_names(df)
171 | labels = preprocess.extract_labels(
172 | data,
173 | label_name
174 | )
175 |
176 | # preprocess the data
177 | #print('transforming numerical_features')
178 | numerical_features = model_assets['numerical_transformer'].transform(numerical_features)
179 | #print('transforming categorical_features')
180 | categorical_features = model_assets['categorical_transformer'].transform(categorical_features)
181 | #print('transforming textual_features')
182 | textual_features = model_assets['textual_transformer'].transform(textual_features)
183 |
184 | #print('concatenating features')
185 | categorical_features = categorical_features.toarray()
186 | textual_features = np.array(textual_features)
187 | textual_features = textual_features.reshape(textual_features.shape[0], -1)
188 | features = np.concatenate([
189 | numerical_features,
190 | categorical_features,
191 | textual_features
192 | ], axis=1)
193 |
194 | return features, labels
195 |
196 |
197 | def perform_ablation(
198 | df,
199 | keywords,
200 | keywords_dict
201 | ):
202 | """Predicts churn with and without each selected keyword removed from the chats and records the change in mean prediction.
203 | """
204 | # select subset of relevant keywords
205 | frac_relevant_keywords = params['frac_relevant_keywords']
206 | topn = int(params['topn_relevant_keywords'] * frac_relevant_keywords)
207 | keywords_select = keywords[:topn]
208 | keywords_select_dict = {}
209 | for kw in keywords_select:
210 | keywords_select_dict[kw] = keywords_dict[kw]
211 |
212 | # loop through keywords and perform ablation analysis
213 | results_dict = {}
214 | for keyword, keywords_list in tqdm(keywords_select_dict.items()):
215 | # collect each chat containing any surface form of the keyword exactly once
216 | rows_incl, rows_excl = [], []
217 | for _, row in df.iterrows():
218 | if any(kw in row['chat_log'] for kw in keywords_list):
219 | # df including keyword in chat (original text)
220 | rows_incl.append(row)
221 | # df excluding keyword in chat (all surface forms removed at once)
222 | row_excl = row.copy()
223 | for kw in keywords_list:
224 | row_excl['chat_log'] = row_excl['chat_log'].replace(kw, ' ')
225 | rows_excl.append(row_excl)
226 | # DataFrame.append was removed in pandas 2.0, so build the frames from lists
227 | df_incl, df_excl = pd.DataFrame(rows_incl), pd.DataFrame(rows_excl)
228 | # prep data incl/excl keyword in text (loads trained preprocessors)
229 | features_incl, labels_incl = prep_ablation(
230 | df_incl
231 | )
232 |
233 | features_excl, labels_excl = prep_ablation(
234 | df_excl
235 | )
236 |
237 | # predict using previously trained model
238 | pred_incl = train.predict(features_incl, labels_incl)
239 | pred_excl = train.predict(features_excl, labels_excl)
240 |
241 | # save results
242 | results_dict[keyword] = {'incl' : np.average(pred_incl),
243 | 'excl' : np.average(pred_excl),
244 | 'chg' : np.average(pred_excl) - np.average(pred_incl),
245 | 'count' : len(pred_incl)}
246 |
247 | # store results
248 | results_df = pd.DataFrame.from_dict(results_dict, orient='index')
249 | results_df = results_df.sort_values(by=["chg", "count"], ascending=(True, False))
250 | results_df.to_csv(Path(model_dir, "ablation_results.csv"))
251 |
252 | return results_df
253 |
254 |
255 | def get_important_keywords(
256 | simMat_df,
257 | marg_contr_df
258 | ):
259 | # merge dataframes
260 | marg_contr_df.insert(0, 'keyword', marg_contr_df.index)
261 | marg_contr_df = marg_contr_df.drop(['incl', 'excl'], axis=1)
262 | results_df = marg_contr_df.merge(simMat_df, how='left', on='keyword')
263 |
264 | # rescale metrics
265 | temp_df = results_df[['chg','count','sim']].copy()
266 | temp_df['chg'] = temp_df['chg'] * (-1)
267 | temp_df['count'] = np.log(temp_df['count']) # taking log transformation on keyword counts due to extreme outliers
268 | min_max_scaler = preprocessing.MinMaxScaler(feature_range=(0, 1))
269 | temp_df = min_max_scaler.fit_transform(temp_df)
270 | temp_df = pd.DataFrame(temp_df, columns=['chg','count','sim'])
271 |
272 | # calculate weighted average
273 | w_marg_contr = params['w_marg_contr']
274 | w_count = params['w_count']
275 | w_sim = params['w_sim']
276 | temp_df['joint'] = temp_df['chg'] * w_marg_contr + temp_df['count'] * w_count + temp_df['sim'] * w_sim
277 |
278 | results_df['joint'] = temp_df['joint']
279 | results_df = results_df.sort_values(by="joint", ascending=False).reset_index(drop=True)
280 | results_df = results_df[['keyword','sim','chg','count','joint']]
281 |
282 | # store results
283 | results_df.to_csv(Path(model_dir, "important_keywords.csv"))
284 |
285 | return results_df
286 |
287 |
288 | def obtain_context(
289 | chats_list,
290 | keyword,
291 | limit=3
292 | ):
293 | """Prints up to `limit` sentences (across all chats) that contain `keyword` as a substring.
294 | """
295 | # reuses the module-level spaCy pipeline `nlp` instead of reloading the model on every call
296 | counter = 0
297 | for chat in chats_list:
298 | doc = nlp(chat)
299 | for sent in doc.sents:
300 | if keyword in sent.text and counter < limit:
301 | print(sent.text)
302 | print('\n')
303 | counter+=1
304 | if counter >= limit:
305 | break
306 |
307 | return None
308 |
309 |
310 | def train_xgb():
311 | """Train XGBoost model to explain categorical and numerical feature importance.
312 | """
313 | # load one-hot feature names
314 | filepath = Path(model_dir, "one_hot_feature_names.json")
315 | numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(filepath)
316 | one_hot_feature_names = numerical_feature_names + categorical_feature_names
317 |
318 | # load train data and exclude text embeddings (train_df/test_df avoid shadowing the imported train module)
319 | train_df = pd.read_csv(Path(model_dir, "train.csv"))
320 | train_notext = train_df.iloc[:, :len(one_hot_feature_names)]
321 | labels = pd.read_csv(Path(model_dir, "labels.csv"))
322 | test_df = pd.read_csv(Path(model_dir, "test.csv"))
323 | test_notext = test_df.iloc[:, :len(one_hot_feature_names)]
324 | labels_test = pd.read_csv(Path(model_dir, "labels_test.csv"))
325 |
326 | # train XGBoost model
327 | xgb = XGBClassifier()
328 | xgb.fit(train_notext, labels)
329 |
330 | y_pred = xgb.predict_proba(test_notext)
331 | y_pred = [p[1] for p in y_pred]
332 |
333 | # plot important features
334 | topn = 10
335 | sorted_idx = xgb.feature_importances_.argsort()
336 | plt.barh(np.array(one_hot_feature_names)[sorted_idx[-topn:]],
337 | xgb.feature_importances_[sorted_idx[-topn:]])
338 | plt.title("XGBoost Feature Importance")
339 | plt.show()
340 |
341 | return y_pred
342 |
343 |
344 | if __name__ == "__main__":
345 | parser = argparse.ArgumentParser()
346 | parser.add_argument("--churn", type=int, default=1)
347 | parser.add_argument("--speaker", type=str, default='Customer')
348 | args = parser.parse_args()
349 |
350 | data_dir = params['data_dir']
351 | df = pd.read_csv(Path(data_dir, "churn_dataset.csv"))
352 |
353 | churn = args.churn
354 | speaker = args.speaker
355 | chats, df_sub = get_chats(df, churn, speaker)
356 | keywords, tokens = get_keywords(chats)
357 | keywords_dict = map_to_orig_tok(keywords, tokens)
358 |
359 | relevant_keywords, simMat_df = get_relevant_keywords(chats, keywords_dict)
360 | marg_contr_df = perform_ablation(df_sub, relevant_keywords, keywords_dict)
361 |
362 | get_important_keywords(simMat_df, marg_contr_df)
363 |
--------------------------------------------------------------------------------
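The ablation in `perform_ablation` asks how the model's mean predicted churn probability changes when every surface form of a keyword is removed from the chats that contain it. The change is `chg = mean(excl) - mean(incl)`, so a negative value means the keyword pushes predictions toward churn. Below is a minimal, self-contained sketch of that computation for one keyword. `ablate_keyword` and `toy_predict` are hypothetical stand-ins; the real pipeline runs `prep_ablation` and `train.predict`. The optional whole-word matching with `re` word boundaries is a variant, not what the scripts do. The scripts use plain substring tests, which is why the notebook's context output for `spam` includes a sentence containing `spamming`.

```python
import re
from statistics import mean


def ablate_keyword(chats, surface_forms, predict, whole_word=True):
    """Mean prediction with vs. without one keyword, over chats that contain it."""
    alternatives = "|".join(map(re.escape, surface_forms))
    pattern = re.compile(r"\b(?:%s)\b" % alternatives if whole_word else alternatives)
    incl = [c for c in chats if pattern.search(c)]  # each chat counted once
    if not incl:
        return None
    excl = [pattern.sub(" ", c) for c in incl]  # all surface forms removed together
    p_incl, p_excl = mean(predict(incl)), mean(predict(excl))
    return {"incl": p_incl, "excl": p_excl, "chg": p_excl - p_incl, "count": len(incl)}


def toy_predict(chats):
    # hypothetical model: high churn probability whenever "cancel" appears
    return [0.9 if "cancel" in c.lower() else 0.2 for c in chats]


chats = [
    "I want to cancel now",
    "Cancel my plan, cancel it",
    "Is there a cancellation fee?",
    "All good, thanks",
]
result = ablate_keyword(chats, ["cancel", "Cancel"], toy_predict)
print(result)
```

The usual replacement for `DataFrame.append`, which pandas removed in version 2.0, is to collect the matching rows in Python lists and build one DataFrame afterwards.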
/scripts/preprocess.py:
--------------------------------------------------------------------------------
1 | """ Module for preparing the data for the churn prediction model with text.
2 |
3 | Run in CLI example:
4 | 'python preprocess.py --test-size 0.33'
5 |
6 | """
7 |
8 |
9 | import os
10 | import sys
11 | import json
12 | import yaml
13 | import joblib
14 | import logging
15 | import argparse
16 | import numpy as np
17 | import pandas as pd
18 | from pathlib import Path
19 | from matplotlib import pyplot as plt
20 | from sklearn.model_selection import train_test_split
21 | from sklearn.impute import SimpleImputer
22 | from sklearn.preprocessing import OneHotEncoder
23 | from sklearn.base import BaseEstimator, TransformerMixin
24 | from sentence_transformers import SentenceTransformer
25 |
26 |
27 | with open("../model/params.yaml", "r") as params_file:
28 | params = yaml.safe_load(params_file)
29 |
30 | model_dir = params['model_dir']
31 |
32 |
33 | class BertEncoder(BaseEstimator, TransformerMixin):
34 | def __init__(self, model_name='bert-base-nli-mean-tokens'):
35 | self.model = SentenceTransformer(model_name)
36 | self.model.parallel_tokenization = False
37 |
38 | def fit(self, X, y=None):
39 | return self
40 |
41 | def transform(self, X):
42 | output = []
43 | for sample in X:
44 | encodings = self.model.encode(sample)
45 | output.append(encodings)
46 | return output
47 |
48 |
49 | def extract_labels(
50 | data,
51 | label_name
52 | ):
53 | labels = []
54 | for sample in data:
55 | value = sample[label_name]
56 | labels.append(value)
57 | labels = np.array(labels).astype('int')
58 |
59 | return labels.reshape(labels.shape[0],-1)
60 |
61 |
62 | def convert_label(
63 | df
64 | ):
65 | df.churn = df.churn.replace("no", 0)
66 | df.churn = df.churn.replace("yes", 1)
67 | return df
68 |
69 | def extract_numerical_features(
70 | sample,
71 | numerical_feature_names
72 | ):
73 | output = []
74 | for feature_name in numerical_feature_names:
75 | if feature_name in sample.keys():
76 | value = sample[feature_name]
77 | if value is None:
78 | value = np.nan
79 | else:
80 | value = np.nan
81 | output.append(value)
82 | return output
83 |
84 |
85 | def extract_categorical_features(
86 | sample,
87 | categorical_feature_names
88 | ):
89 | output = []
90 | for feature_name in categorical_feature_names:
91 | if feature_name in sample.keys():
92 | value = sample[feature_name]
93 | if value is None:
94 | value = ""
95 | else:
96 | value = ""
97 | output.append(value)
98 | return output
99 |
100 |
101 | def extract_textual_features(
102 | sample,
103 | textual_feature_names
104 | ):
105 | output = []
106 | for feature_name in textual_feature_names:
107 | if feature_name in sample.keys():
108 | value = sample[feature_name]
109 | if value is None:
110 | value = ""
111 | else:
112 | value = ""
113 | output.append(value)
114 | return output
115 |
116 |
117 | def split_data(
118 | df,
119 | label_name,
120 | test_size
121 | ):
122 | """Splits data and creates json format.
123 | """
124 | X = df.drop(columns=[label_name])
125 | y = df[label_name]
126 | X_train, X_test, y_train, y_test = train_test_split(
127 | X, y, test_size=test_size, random_state=123, stratify=y)
128 |
129 | train = pd.DataFrame(X_train, columns = X.columns)
130 | train[label_name] = y_train
131 | test = pd.DataFrame(X_test, columns = X.columns)
132 | test[label_name] = y_test
133 |
134 | # create list of dicts
135 | train = train.to_json(orient="records")
136 | train = json.loads(train)
137 | test = test.to_json(orient="records")
138 | test = json.loads(test)
139 |
140 | return train, test
141 |
142 |
143 | def extract_features(
144 | data,
145 | numerical_feature_names,
146 | categorical_feature_names,
147 | textual_feature_names
148 | ):
149 | """Extracts numerical, categorical and textual features by the given feature names.
150 | """
151 | numerical_features = []
152 | categorical_features = []
153 | textual_features = []
154 | for sample in data:
155 | num_feat = extract_numerical_features(sample, numerical_feature_names)
156 | numerical_features.append(num_feat)
157 | cat_feat = extract_categorical_features(sample, categorical_feature_names)
158 | categorical_features.append(cat_feat)
159 | text_feat = extract_textual_features(sample, textual_feature_names)
160 | textual_features.append(text_feat)
161 |
162 | textual_features = [i if isinstance(i[0], str) else ["nan"] for i in textual_features]
163 | textual_features = [i[0] for i in textual_features]
164 |
165 | return numerical_features, categorical_features, textual_features
166 |
167 |
168 | def save_feature_names(
169 | numerical_feature_names,
170 | categorical_feature_names,
171 | textual_feature_names,
172 | filepath
173 | ):
174 | feature_names = {
175 | 'numerical': numerical_feature_names,
176 | 'categorical': categorical_feature_names,
177 | 'textual': textual_feature_names
178 | }
179 | with open(filepath, 'w') as f:
180 | json.dump(feature_names, f)
181 |
182 |
183 | def load_feature_names(filepath):
184 | with open(filepath, 'r') as f:
185 | feature_names = json.load(f)
186 | numerical_feature_names = feature_names['numerical']
187 | categorical_feature_names = feature_names['categorical']
188 | textual_feature_names = feature_names['textual']
189 | return numerical_feature_names, categorical_feature_names, textual_feature_names
190 |
191 |
192 | def get_feature_names(
193 | df
194 | ):
195 | num_columns = df.select_dtypes(include=np.number).columns.tolist()
196 | numerical_feature_names = [i for i in num_columns if i not in ['churn']]
197 |
198 | cat_columns = df.select_dtypes(include='object').columns.tolist()
199 | categorical_feature_names = [i for i in cat_columns if i not in ['chat_log']]
200 |
201 | textual_feature_names = ['chat_log']
202 | label_name = 'churn'
203 |
204 | return numerical_feature_names, categorical_feature_names, textual_feature_names, label_name
205 |
206 |
207 | def prep_data(
208 | df, use_existing=False, test_size=0.33
209 | ):
210 | """
211 | Args:
212 | df: Pandas dataframe with raw data
213 | use_existing: Set to True if you want to use locally stored,
214 | already prepared train/test data. Set to False if you want
215 | to rerun the data preparation pipeline.
216 | Returns:
217 | Train and test data as well as train labels and test labels.
218 | """
219 | # if prepared data exists, don't prepare again if use_existing set to True
220 | train_file = Path(model_dir, 'train.csv')
221 | labels_file = Path(model_dir, 'labels.csv')
222 | test_file = Path(model_dir, 'test.csv')
223 | labels_test_file = Path(model_dir, 'labels_test.csv')
224 | feature_names_file = Path(model_dir, "feature_names.json")
225 | oh_feature_names_file = Path(model_dir, "one_hot_feature_names.json")
226 | all_file_paths = [train_file, labels_file, test_file, labels_test_file,
227 | feature_names_file, oh_feature_names_file]
228 |
229 | if use_existing and all(file.exists() for file in all_file_paths):
230 | features = np.array(pd.read_csv(train_file))
231 | labels = np.array(pd.read_csv(labels_file))
232 | features_test = np.array(pd.read_csv(test_file))
233 | labels_test = np.array(pd.read_csv(labels_test_file))
234 | print("Using already prepared data.")
235 |
236 | else:
237 | print("Running data preparation pipeline...")
238 | # convert label to binary numeric
239 | df = convert_label(df)
240 |
241 | # extract feature names
242 | numerical_feature_names, categorical_feature_names, textual_feature_names, label_name = get_feature_names(
243 | df
244 | )
245 | # train/test split and convert to json format (list of dicts)
246 | train, test = split_data(
247 | df,
248 | label_name,
249 | test_size
250 | )
251 | # extract features & label
252 | print('extracting features')
253 | numerical_features, categorical_features, textual_features = extract_features(
254 | train,
255 | numerical_feature_names,
256 | categorical_feature_names,
257 | textual_feature_names
258 | )
259 | labels = extract_labels(
260 | train,
261 | label_name
262 | )
263 | # extract features & label (for test data)
264 | numerical_features_test, categorical_features_test, textual_features_test = extract_features(
265 | test,
266 | numerical_feature_names,
267 | categorical_feature_names,
268 | textual_feature_names
269 | )
270 | labels_test = extract_labels(
271 | test,
272 | label_name
273 | )
274 | # define preprocessors
275 | print('defining preprocessors')
276 | numerical_transformer = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True)
277 | categorical_transformer = OneHotEncoder(handle_unknown="ignore")
278 | textual_transformer = BertEncoder()
279 |
280 | # fit preprocessors
281 | print('fitting numerical_transformer')
282 | numerical_transformer.fit(numerical_features)
283 | print('saving numerical_transformer')
284 | joblib.dump(numerical_transformer, Path(model_dir, "numerical_transformer.joblib"))
285 | print('fitting categorical_transformer')
286 | categorical_transformer.fit(categorical_features)
287 | print('saving categorical_transformer')
288 | joblib.dump(categorical_transformer, Path(model_dir, "categorical_transformer.joblib"))
289 |
290 | # transform features
291 | print('transforming numerical_features')
292 | numerical_features = numerical_transformer.transform(numerical_features)
293 | print('transforming categorical_features')
294 | categorical_features = categorical_transformer.transform(categorical_features)
295 | print('transforming textual_features')
296 | textual_features = textual_transformer.transform(textual_features)
297 |
298 | # transform features (for test data)
299 | print('transforming numerical_features_test')
300 | numerical_features_test = numerical_transformer.transform(numerical_features_test)
301 | print('transforming categorical_features_test')
302 | categorical_features_test = categorical_transformer.transform(categorical_features_test)
303 | print('transforming textual_features_test')
304 | textual_features_test = textual_transformer.transform(textual_features_test)
305 |
306 | # concat features
307 | print('concatenating features')
308 | categorical_features = categorical_features.toarray()
309 | textual_features = np.array(textual_features)
310 | textual_features = textual_features.reshape(textual_features.shape[0], -1)
311 | features = np.concatenate([
312 | numerical_features,
313 | categorical_features,
314 | textual_features
315 | ], axis=1)
316 |
317 | # concat features (test data)
318 | print('concatenating features of test data')
319 | categorical_features_test = categorical_features_test.toarray()
320 | textual_features_test = np.array(textual_features_test)
321 | textual_features_test = textual_features_test.reshape(textual_features_test.shape[0], -1)
322 | features_test = np.concatenate([
323 | numerical_features_test,
324 | categorical_features_test,
325 | textual_features_test
326 | ], axis=1)
327 |
328 | # save to disk
329 | pd.DataFrame(features).to_csv(Path(model_dir, "train.csv"), index=False)
330 | pd.DataFrame(labels).to_csv(Path(model_dir, "labels.csv"), index=False)
331 | pd.DataFrame(features_test).to_csv(Path(model_dir, "test.csv"), index=False)
332 | pd.DataFrame(labels_test).to_csv(Path(model_dir, "labels_test.csv"), index=False)
333 |
334 | save_feature_names(
335 | numerical_feature_names,
336 | categorical_feature_names,
337 | textual_feature_names,
338 | Path(model_dir, "feature_names.json")
339 | )
340 | # one-hot encoded feature names (for feat_imp)
341 | save_feature_names(
342 | numerical_feature_names,
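343 | (categorical_transformer.get_feature_names_out() if hasattr(categorical_transformer, 'get_feature_names_out') else categorical_transformer.get_feature_names()).tolist(),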
344 | textual_feature_names,
345 | Path(model_dir, "one_hot_feature_names.json")
346 | )
347 |
348 | return features, features_test, labels, labels_test
349 |
350 |
351 | if __name__ == "__main__":
352 | parser = argparse.ArgumentParser()
353 | parser.add_argument("--use-existing", action="store_true")
354 | parser.add_argument("--test-size", type=float, default=0.33)
355 | args = parser.parse_args()
356 |
357 | data_dir = params['data_dir']
358 | df = pd.read_csv(Path(data_dir, "churn_dataset.csv"))
359 |
360 | if args.use_existing:
361 | use_existing = True
362 | else:
363 | use_existing = False
364 | test_size = args.test_size
365 |
366 | prep_data(df, use_existing, test_size)
367 |
--------------------------------------------------------------------------------
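`prep_data` fits `SimpleImputer(strategy='mean', add_indicator=True)` on the numerical columns. Under scikit-learn's default `features='missing-only'`, the imputer appends one 0/1 indicator column for every numerical feature that had missing values at fit time. The numerical block can therefore be wider than `len(numerical_feature_names)`. `train.py` (and `train_xgb` in `interpret.py`) find the boundary between the tabular features and the text embedding by counting feature names. If the training data contains missing numerical values, that boundary would shift by the number of indicator columns. The pure-Python sketch below shows the width change on a toy matrix. `mean_impute_with_indicator` and `n_one_hot` are hypothetical, and the sketch assumes every column has at least one observed value.

```python
import math


def mean_impute_with_indicator(train_rows, rows):
    """Mimic SimpleImputer(strategy='mean', add_indicator=True) on lists of floats.

    Missing values (NaN) are replaced by the training-column mean, and one 0/1
    indicator column is appended for each column that had NaNs in train_rows.
    Assumes every training column has at least one non-NaN value.
    """
    n_cols = len(train_rows[0])
    means, missing_cols = [], []
    for j in range(n_cols):
        column = [r[j] for r in train_rows]
        observed = [v for v in column if not math.isnan(v)]
        means.append(sum(observed) / len(observed))
        if len(observed) < len(column):
            missing_cols.append(j)
    out = []
    for r in rows:
        filled = [means[j] if math.isnan(v) else v for j, v in enumerate(r)]
        flags = [1.0 if math.isnan(r[j]) else 0.0 for j in missing_cols]
        out.append(filled + flags)
    return out


nan = float("nan")
toy_train = [[1.0, 10.0], [nan, 20.0], [3.0, 30.0]]
imputed = mean_impute_with_indicator(toy_train, toy_train)
# 2 numerical feature names but 3 output columns, so the tabular/text boundary
# would be computed from transformed widths, e.g. len(imputed[0]) + n_one_hot
print(imputed)
```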
/scripts/train.py:
--------------------------------------------------------------------------------
1 | """ Module for training the churn prediction model with text.
2 |
3 | Training parameters are stored in 'params.yaml'.
4 |
5 | Run in CLI example:
6 | 'python train.py'
7 |
8 | """
9 |
10 |
11 | import os
12 | import sys
13 | import json
14 | import yaml
15 | import joblib
16 | import logging
17 | import argparse
18 | import numpy as np
19 | import pandas as pd
20 | from pathlib import Path
21 | from matplotlib import pyplot as plt
22 | from sklearn.model_selection import train_test_split
23 | from sklearn.impute import SimpleImputer
24 | from sklearn.preprocessing import OneHotEncoder
25 | from sklearn.metrics import roc_auc_score, roc_curve, auc
26 | from sklearn.metrics import precision_recall_curve
27 | import torch
28 | import torch.nn as nn
29 | import torch.nn.functional as F
30 | import torch.optim as optim
31 | from torch.utils.data import DataLoader, TensorDataset
32 | from torch import Tensor
33 | import preprocess
34 | from preprocess import BertEncoder
35 |
36 |
37 | logger = logging.getLogger(__name__)
38 | logger.setLevel(logging.DEBUG)
39 | logger.addHandler(logging.StreamHandler(sys.stdout))
40 |
41 |
42 | with open("../model/params.yaml", "r") as params_file:
43 | params = yaml.safe_load(params_file)
44 |
45 | model_dir = params['model_dir']
46 |
47 |
48 | class Net(nn.Module):
49 | def __init__(self, x1_size, x2_size):
50 | super(Net, self).__init__()
51 | self.batch_norm = nn.BatchNorm1d(x1_size)
52 | self.fc1 = nn.Linear(x1_size, 10)
53 | self.fc2 = nn.Linear(10 + x2_size, 10)
54 | self.fc3 = nn.Linear(10, 1)
55 |
56 | def forward(self, x1, x2):
57 | x1 = self.batch_norm(x1)
58 | x1 = F.relu(self.fc1(x1))
59 | x1 = F.dropout(x1, p=0.2, training=self.training)
60 | x12 = torch.cat((x1.view(x1.size(0), -1),
61 | x2.view(x2.size(0), -1)), dim=1)
62 | x12 = F.dropout(x12, p=0.1, training=self.training)
63 | x12 = self.fc2(x12)
64 | out = self.fc3(x12)
65 | return out
66 |
67 |
68 | def train(
69 | X,
70 | y,
71 | X_test,
72 | y_test
73 | ):
74 | # get parameters
75 | batch_size = params['batch_size']
76 | batch_size_test = params['batch_size_test']
77 | epochs = params['epochs']
78 | pos_weight = params['pos_weight']
79 | lr = params['lr']
80 | momentum = params['momentum']
81 |
82 | # prepare training job
83 | X = np.array(X)
84 | y = np.array(y)
85 | X_test = np.array(X_test)
86 | y_test = np.array(y_test)
87 | training_data = TensorDataset( Tensor(X), Tensor(y) )
88 | train_loader = DataLoader(training_data, batch_size=batch_size,
89 | shuffle=True,
90 | num_workers=4)
91 | test_data = TensorDataset( Tensor(X_test), Tensor(y_test) )
92 | test_loader = DataLoader(test_data, batch_size=batch_size_test,
93 | shuffle=True,
94 | num_workers=4)
95 |
96 | # get size of num/cat & text data
97 | numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json"))
98 | number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names)
99 | x1_size = X[:, :number_cat_num_features].shape[1]
100 | x2_size = X[:, number_cat_num_features:].shape[1]
101 |
102 | model = Net(x1_size, x2_size)
103 | criterion = nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(pos_weight, dtype=torch.float))
104 |
105 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
106 |
107 | # train NN model
108 | scores_df = pd.DataFrame()
109 | train_scores = []
110 | test_scores = []
111 | for epoch in range(1, epochs + 1): # loop over the dataset multiple times
112 | model.train()
113 | print("starting epoch: ", epoch)
114 |
115 | for batch_idx, (data, target) in enumerate(train_loader, 1):
116 | # zero the parameter gradients
117 | optimizer.zero_grad()
118 |
119 | # split data inputs into cat/num features and text features
120 | x1 = data[:, :number_cat_num_features]
121 | x2 = data[:, number_cat_num_features:]
122 |
123 | # forward + backward + optimize
124 | output = model(x1, x2)
125 | loss = criterion(output, target)
126 | loss.backward()
127 | optimizer.step()
128 |
129 | model.eval() # training performance after each epoch (dropout off, BatchNorm running stats not updated)
130 | with torch.no_grad():
131 | output = model(Tensor(X[:, :number_cat_num_features]), Tensor(X[:, number_cat_num_features:])).reshape(-1)
132 | preds = torch.sigmoid(output)
133 | train_score = roc_auc_score(Tensor(y), preds)
134 | train_scores.append(train_score)
135 | logger.info('Train Epoch: {}, train-auc-score: {:.4f}'.format(epoch, train_score))
136 |
137 | # test performance after each epoch
138 | test_score = test(model, test_loader)
139 | test_scores.append(test_score)
140 |
141 | # save scores
142 | print('saving scores')
143 | scores_df['train_scores'] = train_scores
144 | scores_df['test_scores'] = test_scores
145 | scores_df.to_csv(Path(model_dir, 'training_scores.csv'), index=False)
146 |
147 | # save model
148 | print('saving model')
149 | torch.save(model.state_dict(), Path(model_dir, 'model.pth'))
150 |
151 | return None
152 |
153 |
154 | def test(
155 | model,
156 | test_loader
157 | ):
158 |     model.eval()
159 |     # load the feature split once, not once per batch
160 |     numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json"))
161 |     number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names)
162 |     preds_all = []
163 |     targets_all = []
164 |     with torch.no_grad():
165 |         for data, targets in test_loader:
166 |             # split data inputs into cat/num features and text features
167 |             x1 = data[:, :number_cat_num_features]
168 |             x2 = data[:, number_cat_num_features:]
169 |
170 |             output = model(x1, x2).reshape(-1)
171 |             preds = torch.sigmoid(output)
172 |             preds_all.append(preds)
173 |             targets_all.append(targets.reshape(-1))
174 |
175 |     # AUC is computed once over the whole test set, not averaged per batch
176 |     test_score = roc_auc_score(torch.cat(targets_all).numpy(), torch.cat(preds_all).numpy())
177 |     logger.info('test_auc_score: {:.4f}'.format(test_score))
178 |
179 |     return test_score
180 |
181 |
182 | def get_train_assets(
183 | ):
184 |     # load feature names
185 |     numerical_feature_names, categorical_feature_names, textual_feature_names = preprocess.load_feature_names(Path(model_dir, "feature_names.json"))
186 |     # load numerical transformer
187 |     numerical_transformer = joblib.load(Path(model_dir, "numerical_transformer.joblib"))
188 |     # load categorical transformer
189 |     categorical_transformer = joblib.load(Path(model_dir, "categorical_transformer.joblib"))
190 |     # create textual transformer
191 |     textual_transformer = BertEncoder()
192 |
193 |     model_assets = {
194 |         'numerical_feature_names': numerical_feature_names,
195 |         'numerical_transformer': numerical_transformer,
196 |         'categorical_feature_names': categorical_feature_names,
197 |         'categorical_transformer': categorical_transformer,
198 |         'textual_feature_names': textual_feature_names,
199 |         'textual_transformer': textual_transformer
200 |     }
201 |     return model_assets
202 |
203 |
204 | def predict(
205 | features,
206 | labels
207 | ):
208 |     # get size of num/cat & text data to specify neural net
209 |     features = np.asarray(features)  # accept DataFrames as well as arrays; labels is unused here
210 |     numerical_feature_names, categorical_feature_names, _ = preprocess.load_feature_names(Path(model_dir, "one_hot_feature_names.json"))
211 |     number_cat_num_features = len(numerical_feature_names) + len(categorical_feature_names)
212 |     x1_size = features[:, :number_cat_num_features].shape[1]
213 |     x2_size = features[:, number_cat_num_features:].shape[1]
214 |
215 |     model = Net(x1_size, x2_size)
216 |     model.load_state_dict(torch.load(Path(model_dir, 'model.pth'), map_location='cpu'))
217 |
218 |     model.eval()
219 |     with torch.no_grad():
220 |         output = model(Tensor(features[:, :number_cat_num_features]), Tensor(features[:, number_cat_num_features:])).reshape(-1)
221 |         preds = torch.sigmoid(output)
222 |
223 |     return preds
224 |
225 |
226 | def plot_train_stats(
227 | ):
228 |     scores_df = pd.read_csv(Path(model_dir, 'training_scores.csv'))  # file written by train()
229 |     scores_df.plot(xlabel='Epochs', ylabel='AUC Score', title='AUC Score')
230 |     return None
231 |
232 |
233 | def plot_pr_curve(
234 | features,
235 | labels
236 | ):
237 |     preds = predict(features, labels)
238 |     precisions, recalls, thresholds = precision_recall_curve(labels, preds)
239 |     pr_auc = auc(recalls, precisions)  # area under the plotted PR curve
240 |
241 |     fig, ax = plt.subplots(figsize=(6, 5))
242 |     ax.plot(recalls, precisions, label='Model w/ text: PR AUC = %0.2f' % pr_auc)
243 |     ax.axhline(np.asarray(labels).mean(), linestyle='--', label='No-skill baseline')  # positive rate
244 |     ax.set_xlabel('Recall')
245 |     ax.set_ylabel('Precision')
246 |     ax.legend(loc='lower left')
247 |     ax.set_title('Precision-Recall Curve')
248 |     plt.show()
249 |     return None
250 |
251 |
252 | def plot_roc_curve(
253 | features,
254 | labels
255 | ):
256 |     preds = predict(features, labels)
257 |     fpr, tpr, threshold = roc_curve(labels, preds)
258 |     roc_auc = auc(fpr, tpr)
259 |
260 |     plt.title('ROC Curve')
261 |     plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
262 |     plt.legend(loc='lower right')
263 |     plt.plot([0, 1], [0, 1], 'r--')
264 |     plt.xlim([0, 1])
265 |     plt.ylim([0, 1])
266 |     plt.ylabel('True Positive Rate')
267 |     plt.xlabel('False Positive Rate')
268 |     plt.show()
269 |     return None
270 |
271 |
272 | if __name__ == "__main__":
273 |
274 |     model_dir = params['model_dir']
275 |
276 |     X_train = pd.read_csv(Path(model_dir, "train.csv"))
277 |     y_train = pd.read_csv(Path(model_dir, "labels.csv"))
278 |     X_test = pd.read_csv(Path(model_dir, "test.csv"))
279 |     y_test = pd.read_csv(Path(model_dir, "labels_test.csv"))
280 |
281 |     train(X=X_train, y=y_train, X_test=X_test, y_test=y_test)
282 |
--------------------------------------------------------------------------------
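The training loop above hands the raw network outputs (logits) to `nn.BCEWithLogitsLoss` and applies `torch.sigmoid` only when it needs probabilities for the AUC scores. As a rough illustration of what the `pos_weight` setting changes, here is a NumPy sketch of the per-sample formula the PyTorch documentation gives for `BCEWithLogitsLoss` with mean reduction. `bce_with_logits` is a hypothetical helper for illustration, not part of this repository.

```python
import numpy as np


def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean weighted binary cross-entropy computed directly on logits.

    Sketch of: loss = -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    """
    x = np.asarray(logits, dtype=float)
    y = np.asarray(targets, dtype=float)
    # log(sigmoid(x)) = -log(1 + exp(-x)), evaluated stably with logaddexp
    log_sig = -np.logaddexp(0.0, -x)
    # log(1 - sigmoid(x)) = log(sigmoid(-x)) = -log(1 + exp(x))
    log_one_minus_sig = -np.logaddexp(0.0, x)
    # pos_weight scales only the positive-label term
    losses = -(pos_weight * y * log_sig + (1.0 - y) * log_one_minus_sig)
    return losses.mean()


# Two undecided predictions (logit 0): each term is log(2) before weighting.
print(bce_with_logits([0.0, 0.0], [1.0, 0.0]))                  # mean of [log 2, log 2]
print(bce_with_logits([0.0, 0.0], [1.0, 0.0], pos_weight=3.0))  # mean of [3 log 2, log 2]
```

With `pos_weight=3.0`, a positive example contributes three times its unweighted loss, a common way to compensate for a rare positive class. Working in log space is also why the PyTorch docs describe the fused logits loss as more numerically stable than a separate sigmoid followed by `BCELoss`: the sketch stays finite even for logits of magnitude 1000.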
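`test()` gathers the sigmoid outputs of every test batch and scores them together. ROC AUC is a ranking statistic over the whole set, so per-batch values cannot simply be averaged, and a batch that contains only one class has no AUC at all. The plain-NumPy sketch below (hypothetical names `rank_auc` and `batched_auc`) computes AUC as the fraction of positive/negative pairs ranked in the right order, with ties counting one half; for binary labels this equals the value reported by `sklearn.metrics.roc_auc_score`.

```python
import numpy as np


def rank_auc(targets, scores):
    """ROC AUC as P(score of a random positive > score of a random negative)."""
    y = np.asarray(targets).reshape(-1).astype(bool)
    s = np.asarray(scores, dtype=float).reshape(-1)
    pos, neg = s[y], s[~y]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC needs at least one positive and one negative")
    # Compare every positive with every negative: O(n_pos * n_neg), fine for a sketch
    greater = np.count_nonzero(pos[:, None] > neg[None, :])
    ties = np.count_nonzero(pos[:, None] == neg[None, :])
    return (greater + 0.5 * ties) / (pos.size * neg.size)


def batched_auc(batches):
    """Concatenate per-batch (targets, scores) pairs, then score them once."""
    targets = np.concatenate([np.asarray(t).reshape(-1) for t, _ in batches])
    scores = np.concatenate([np.asarray(s).reshape(-1) for _, s in batches])
    return rank_auc(targets, scores)


# The first batch holds only negatives, so it has no AUC of its own,
# yet the concatenated predictions score 0.75.
print(batched_auc([([0, 0], [0.1, 0.4]), ([1, 1], [0.35, 0.8])]))
```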
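On a precision-recall plot, a classifier with no skill has precision equal to the positive rate at every recall level, so a horizontal line at that rate is the usual reference, unlike the diagonal used on the ROC plot. The sketch below (hypothetical helper `precision_recall_points`, not the scikit-learn implementation) lists precision and recall for the rule "score >= t" at each distinct score, which makes the baseline easy to check by hand.

```python
import numpy as np


def precision_recall_points(targets, scores):
    """(precision, recall) for 'predict positive if score >= t' at each distinct score t.

    Thresholds run from the highest score down, so recall never decreases.
    """
    y = np.asarray(targets).reshape(-1).astype(bool)
    s = np.asarray(scores, dtype=float).reshape(-1)
    if not y.any():
        raise ValueError("recall is undefined without positive labels")
    points = []
    for t in np.unique(s)[::-1]:
        predicted = s >= t  # always flags at least the sample(s) scoring exactly t
        tp = np.count_nonzero(predicted & y)
        points.append((tp / np.count_nonzero(predicted), tp / np.count_nonzero(y)))
    return points


# A constant score carries no ranking information: every sample is flagged,
# so precision equals the positive rate (1 positive out of 4).
print(precision_recall_points([1, 0, 0, 0], [0.5, 0.5, 0.5, 0.5]))  # [(0.25, 1.0)]
```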