├── Insight_Circulatory_Analysis.ipynb
├── LICENSE
├── README.md
├── SimpleNN_Baseline.ipynb
├── UserGuideMimicII.pdf
├── XGBoost_Baseline_Models.ipynb
└── weights.best.hdf5
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Randy (Jimmy) Giedt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cardiovascular Death Prediction
2 | Files for an Insight Data Science project focused on predicting death in atherosclerosis patients from the MIMIC III database.
3 |
4 | #### Summary
5 | Cardiovascular disease is one of the leading causes of death in the developed world. In this project, I use the MIMIC III dataset to explore how well traditional machine learning and state-of-the-art deep learning techniques predict mortality in patients diagnosed with coronary atherosclerosis. While this project focuses on a specific disease, the methods illustrated here apply broadly to the analysis of medical and temporal datasets. Source code for this project can be found on GitHub.
6 |
7 | #### Problem Background
8 | An estimated 16.3 million Americans have coronary artery disease, which is ~7% of all U.S. citizens over the age of 20. Coronary artery disease occurs when atherosclerotic plaques, which are made up of cholesterol and other fatty substances, significantly occlude the arteries that supply blood to the heart muscle. The consequences of this disease can range from angina (chest pain) to rupture of a plaque, leading to arterial clot formation and heart tissue death (myocardial infarction). For more information see:
9 |
10 | #### Data Set
11 | For this project I used the freely available MIMIC III database: https://mimic.physionet.org/ . The dataset is made up of ~47,000 unique patients with over 650,000 diagnoses. Each patient and diagnosis comes with a rich list of attributes, including demographic information, lists of medications, a history of diagnoses, and other potentially predictive medical characteristics.
12 |
13 | #### An Overview of the MIMIC III database and the information included.
14 | To begin analysis of this data, we want to find all of the patients in the database with a diagnosis of coronary atherosclerosis; this corresponds to an ICD9 code of 414.01. Furthermore, we want to separate the patients who survived through the study from those who died. Additionally, we want to remove patients who were more than 89 years old at any time in the study because (1) physicians will more than likely treat these patients differently due to their age, and (2) exact ages for these patients aren’t recorded in the MIMIC III database. Results from this analysis, including the size of our final dataset, can be seen below.
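
The cohort selection described above can be sketched in plain Python. This is an illustrative sketch, not the project's actual query: the record layout (`icd9_codes`, `max_age`, `died`) and the helper `select_cohort` are hypothetical, and the ICD9 code is written as it appears in the text (`414.01`), which may not match how the database stores it.

```python
# Hypothetical record layout for illustration; not the MIMIC III schema.
ATHERO_ICD9 = "414.01"  # coronary atherosclerosis, as written in the text
MAX_AGE = 89            # patients older than this are excluded


def select_cohort(patients):
    """Return (survivors, deceased) for patients who qualify for the study."""
    cohort = [
        p for p in patients
        if ATHERO_ICD9 in p["icd9_codes"] and p["max_age"] <= MAX_AGE
    ]
    survivors = [p for p in cohort if not p["died"]]
    deceased = [p for p in cohort if p["died"]]
    return survivors, deceased


# Toy records, invented purely to exercise the filter.
toy_patients = [
    {"id": 1, "icd9_codes": {"414.01", "401.9"}, "max_age": 72.3, "died": True},
    {"id": 2, "icd9_codes": {"414.01"}, "max_age": 91.0, "died": True},
    {"id": 3, "icd9_codes": {"250.00"}, "max_age": 55.2, "died": False},
    {"id": 4, "icd9_codes": {"414.01"}, "max_age": 64.0, "died": False},
]

survivors, deceased = select_cohort(toy_patients)
print([p["id"] for p in survivors], [p["id"] for p in deceased])
```

Patient 2 is dropped for age and patient 3 for lacking the diagnosis, so only patients 4 and 1 remain in the two outcome groups.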
15 |
16 | #### Analysis of the MIMIC III Database to understand the patient population with coronary atherosclerosis diagnoses.
17 | With these datasets we can begin to conduct preliminary data analysis on the demographic characteristics of our population, comparing the deceased patients with those who survived. For a more complete data analysis, please see the linked GitHub Jupyter notebooks.
18 |
19 | #### A few examples of demographic information illustrating differences in death rates.
20 | Looking at just these three examples, we can see that race seems to have some effect on outcome, as patients who die are more likely to not self-classify as white. Similarly, patients who die are far more likely to have been admitted to the ICU via the ER and are more likely to be on Medicare insurance. Of course there are confounders in this data (for example, Medicare patients are also more likely to be older), but a seemingly reasonable hypothesis is that applying a machine learning approach to demographic, admission, and other data would allow us to predict which patients are most likely to pass away.
21 |
22 | #### Machine Learning
23 | To employ a machine learning approach, I created a feature set from all available demographic data (race, religion, language, insurance type, marital status, age, and sex, among others) as well as a naive summary of medical history, including the number of appointments each patient has attended. After one-hot encoding, the combined feature set has ~120 features. Training a random forest classifier on this data and analyzing the effect of the number of features, it appears that, without additional hyperparameter optimization, our best cross-validation score is ~0.69. Further optimization via grid search yields a maximum score of 0.72.
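
To illustrate how a handful of categorical columns expands into a much wider feature set, here is a minimal standard-library sketch of one-hot encoding. The notebook in this repository uses `pd.get_dummies` for this step; the helper `one_hot_encode` and the toy rows below are hypothetical and only mirror a few of the column names and values shown in the notebook output.

```python
def one_hot_encode(rows, categorical_columns):
    """Replace each categorical column with one 0/1 column per observed value."""
    categories = {
        col: sorted({row[col] for row in rows}) for col in categorical_columns
    }
    encoded = []
    for row in rows:
        # Keep numeric columns as they are.
        out = {k: v for k, v in row.items() if k not in categorical_columns}
        for col in categorical_columns:
            for value in categories[col]:
                out[f"{col}_{value}"] = 1 if row[col] == value else 0
        encoded.append(out)
    return encoded


# Toy rows, invented for illustration.
toy_rows = [
    {"ADMIT_AGE": 69.6, "GENDER": "M", "INSURANCE": "Private"},
    {"ADMIT_AGE": 69.0, "GENDER": "F", "INSURANCE": "Medicare"},
    {"ADMIT_AGE": 87.9, "GENDER": "M", "INSURANCE": "Medicare"},
]

encoded_rows = one_hot_encode(toy_rows, ["GENDER", "INSURANCE"])
print(sorted(encoded_rows[0]))
```

Three input columns become five here. The same expansion across every category of ethnicity, religion, language, and so on would plausibly account for the ~120 features reported above.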
24 |
25 | #### Deep Learning
26 | While 0.72 accuracy may be helpful for some applications, in medical and patient care it unfortunately does not provide predictions accurate enough for physicians to use in the clinic. Indeed, a model at this level would yield a large proportion of false positives in the data set.
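
A short calculation helps put a score of 0.72 in context. The class counts below are the `DEATH_FLAG` counts printed in `SimpleNN_Baseline.ipynb`. Everything else is an assumption made for illustration: that the 0.72 score is an accuracy on this same outcome, and the two confusion matrices, which are invented to show that a single accuracy value does not pin down the number of false positives.

```python
# Class counts for DEATH_FLAG as printed in SimpleNN_Baseline.ipynb.
survived, died = 6576, 3627
total = survived + died

# A trivial model that always predicts "survived" already reaches this accuracy.
majority_baseline = survived / total
print(f"majority-class baseline accuracy: {majority_baseline:.3f}")


def accuracy(tp, fp, tn, fn):
    return (tp + tn) / (tp + fp + tn + fn)


def precision(tp, fp):
    return tp / (tp + fp)


# Two invented confusion matrices with identical accuracy (about 0.72)
# but very different numbers of false positives.
matrix_a = {"tp": 1500, "fp": 730, "tn": 5846, "fn": 2127}
matrix_b = {"tp": 2500, "fp": 1730, "tn": 4846, "fn": 1127}

for name, m in (("A", matrix_a), ("B", matrix_b)):
    print(
        f"{name}: accuracy={accuracy(**m):.3f}, "
        f"precision={precision(m['tp'], m['fp']):.3f}, false positives={m['fp']}"
    )
```

Under these assumptions, 0.72 is only modestly above the majority-class baseline, which is one reason metrics such as precision, recall, or AUC are worth reporting alongside accuracy.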
27 |
28 | #### Simplified version of a typical RNN operating on medical visits.
29 | To get a better prediction from this data, an ideal solution would take into account a patient's past appointments, i.e. what has happened in the patient's history, in addition to demographic information. A natural approach for this type of data is a recurrent neural net (RNN). Briefly, RNNs are a type of neural net designed for temporal data. As seen above, an RNN is structured so that we can conduct operations on individual time points and use information from each time point to predict data at the next, a good fit for working with sequences of patient diagnoses. Unlike the above diagram, though, we only care about a final output, not necessarily outputs at each diagnosis.
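
The "many visits in, one prediction out" idea can be sketched as the forward pass of a simple (Elman-style) RNN in plain Python. This is an untrained toy, not the model used in this project: the function name, the layer sizes, the random weights, and the multi-hot visit vectors are all assumptions made for illustration.

```python
import math
import random


def rnn_final_probability(visits, w_in, w_rec, b_hidden, w_out, b_out):
    """Run a simple RNN over the visits and return one final probability."""
    hidden = [0.0] * len(b_hidden)
    for visit in visits:  # one time step per visit, oldest first
        hidden = [
            math.tanh(
                b_hidden[j]
                + sum(w * x for w, x in zip(w_in[j], visit))
                + sum(w * h for w, h in zip(w_rec[j], hidden))
            )
            for j in range(len(b_hidden))
        ]
    # Only the last hidden state feeds the output layer.
    logit = b_out + sum(w * h for w, h in zip(w_out, hidden))
    return 1.0 / (1.0 + math.exp(-logit))


random.seed(0)
n_codes, n_hidden = 4, 3  # toy sizes, chosen arbitrarily
w_in = [[random.uniform(-1, 1) for _ in range(n_codes)] for _ in range(n_hidden)]
w_rec = [[random.uniform(-1, 1) for _ in range(n_hidden)] for _ in range(n_hidden)]
b_hidden = [0.0] * n_hidden
w_out = [random.uniform(-1, 1) for _ in range(n_hidden)]

# Each visit is a multi-hot vector over four made-up diagnosis codes.
visits = [[1, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1]]
probability = rnn_final_probability(visits, w_in, w_rec, b_hidden, w_out, 0.0)
print(f"predicted risk from untrained weights: {probability:.3f}")
```

The hidden state carries information forward from visit to visit, and the sigmoid is applied once, after the last visit, which matches the single final output described above.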
30 |
31 | Setting up an RNN for our data (see GitHub for implementation details), I was able to incorporate the diagnoses of individual patients over time, yielding a cross-validated AUC of 0.82, a substantial improvement over the 0.72 cross-validation score. While it's great that our score increased, RNNs (and neural nets in general), while more sophisticated than typical machine learning methods, are frustratingly close to uninterpretable black boxes because of their hidden layers and other under-the-hood characteristics.
32 |
33 | #### Attention RNN Model.
34 | As a solution to this, the final analysis I implemented in this project was based on a great paper out of Georgia Tech describing an interpretable attention RNN for medical data (RETAIN). The idea of an attention RNN is that we can recover the information about the hidden layers that we could not interpret in the RNN above, while still benefiting from the sophistication and, in this case, the incorporation of time data found in an RNN. Adapting this model for our purposes and implementing it yielded a cross-validated AUC of 0.81, nearly as good as our RNN. In addition, this model can tell us which medical codes were significant when analyzing our data. In checking the most significant medical codes, I found that, unsurprisingly, heart attacks were one of the biggest contributors to death.
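
The interpretability gain from attention can be sketched with a simplified, visit-level attention readout. This is not the RETAIN architecture (which also uses variable-level attention and learned embeddings) and not this project's implementation: the function names, per-visit vectors, attention scores, and output weights below are all made up for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_readout(visit_states, attention_scores, w_out, b_out):
    """Weight each visit by attention and report its share of the final logit."""
    weights = softmax(attention_scores)
    contributions = [
        a * sum(w * h for w, h in zip(w_out, state))
        for a, state in zip(weights, visit_states)
    ]
    logit = b_out + sum(contributions)
    probability = 1.0 / (1.0 + math.exp(-logit))
    return probability, weights, contributions


# Made-up per-visit representations and attention scores for three visits.
visit_states = [[0.2, -0.1], [0.9, 0.4], [0.1, 0.0]]
attention_scores = [0.5, 2.0, -1.0]
probability, weights, contributions = attention_readout(
    visit_states, attention_scores, w_out=[1.5, 0.5], b_out=-0.3
)
most_influential = max(range(len(contributions)), key=lambda i: abs(contributions[i]))
print(f"risk={probability:.3f}, most influential visit index={most_influential}")
```

Because the logit is a plain sum of per-visit contributions plus a bias, each visit's influence on the prediction can be read off directly, which is the property that makes attention models easier to interpret than a plain RNN.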
35 |
36 | #### Conclusions
37 | Here I implemented several different machine learning and deep learning methods in an attempt to predict mortality in patients with coronary atherosclerosis. Using patient demographic information and limited information about their interactions with physicians, we were able to come to a reasonable prediction of their mortality risk. By using an RNN, we could improve on this prediction, at the cost of making the model uninterpretable. By using an attention RNN, I was able to reach a similar predictive value while better understanding the variables that lead to risk for the patient.
38 |
39 | The ideal users of this system would be care providers, hospitals, and insurance companies. Using machine learning and/or deep learning techniques, it should be possible, as shown here, to better identify patients who would benefit from aggressive physician intervention in order to save lives. Furthermore, such a system could bring substantial cost savings, as heart attacks and other cardiovascular events are nearly always acute, requiring expensive emergency transportation and care for patients.
40 |
--------------------------------------------------------------------------------
/SimpleNN_Baseline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Simple Neural network based on demographic data for cardiovascular death prediction"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 32,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import tensorflow as tf\n",
17 | "import keras \n",
18 | "import pandas as pd"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "#### Load data and clean data from demographics file"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 60,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
203 | "text/plain": [
204 | " Unnamed: 0 SUBJECT_ID GENDER DOB DOD DOA \\\n",
205 | "0 0 31 M 2036-05-17 2108-08-30 2108-08-22 23:27:00 \n",
206 | "1 1 56 F 1804-01-02 2104-01-08 2104-01-02 02:01:00 \n",
207 | "2 2 61 M 2063-10-21 2119-02-03 2119-01-04 18:12:00 \n",
208 | "3 3 67 M 2084-06-05 2157-12-02 2157-12-02 00:45:00 \n",
209 | "4 4 84 F 2151-10-21 2196-04-17 2196-04-14 04:02:00 \n",
210 | "\n",
211 | " ADMIT_AGE ETHNICITY MARITAL_STATUS LANGUAGE \\\n",
212 | "0 72.312329 WHITE MARRIED UKNOWN \n",
213 | "1 NaN WHITE UKNOWN UKNOWN \n",
214 | "2 55.241096 WHITE MARRIED UKNOWN \n",
215 | "3 73.539726 WHITE SINGLE UKNOWN \n",
216 | "4 44.512329 WHITE MARRIED UKNOWN \n",
217 | "\n",
218 | " ... INSURANCE \\\n",
219 | "0 ... Medicare \n",
220 | "1 ... Medicare \n",
221 | "2 ... Private \n",
222 | "3 ... Medicare \n",
223 | "4 ... Private \n",
224 | "\n",
225 | " ADMISSION_LOCATION #ADMISSIONS OUTSIDE_DEATH_FLAG DEATH_FLAG \\\n",
226 | "0 TRANSFER FROM HOSP/EXTRAM 1 0 1 \n",
227 | "1 EMERGENCY ROOM ADMIT 1 0 1 \n",
228 | "2 CLINIC REFERRAL/PREMATURE 2 0 1 \n",
229 | "3 EMERGENCY ROOM ADMIT 2 0 1 \n",
230 | "4 EMERGENCY ROOM ADMIT 2 0 1 \n",
231 | "\n",
232 | " OLD_FLAG HEART_ATTACK_FLAG ATHERO_DIAGNOSIS_FLAG HEART_DEATH_FLAG \\\n",
233 | "0 0 0 0 0 \n",
234 | "1 1 0 0 0 \n",
235 | "2 0 0 0 0 \n",
236 | "3 0 0 0 0 \n",
237 | "4 0 0 0 0 \n",
238 | "\n",
239 | " CAUSE \n",
240 | "0 STATUS EPILEPTICUS \n",
241 | "1 HEAD BLEED \n",
242 | "2 NON-HODGKINS LYMPHOMA;FEBRILE;NEUTROPENIA \n",
243 | "3 SUBARACHNOID HEMORRHAGE \n",
244 | "4 GLIOBLASTOMA,NAUSEA \n",
245 | "\n",
246 | "[5 rows x 21 columns]"
247 | ]
248 | },
249 | "execution_count": 60,
250 | "metadata": {},
251 | "output_type": "execute_result"
252 | }
253 | ],
254 | "source": [
255 | "demographics = pd.read_csv('Demographics.csv')\n",
256 | "demographics.head()"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 61,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "# First, exclude patients with OLD_FLAG set, then separate atherosclerosis diagnoses from non-atherosclerosis diagnoses\n",
266 | "athero_pre = demographics[demographics['OLD_FLAG']==0]\n",
267 | "athero_pos = athero_pre[athero_pre['ATHERO_DIAGNOSIS_FLAG']== 1]\n",
268 | "athero_neg = athero_pre[athero_pre['ATHERO_DIAGNOSIS_FLAG']==0]\n",
269 | "\n",
270 | "# Clean data sets\n",
271 | "del athero_neg['CAUSE']\n",
272 | "del athero_pos['CAUSE']\n",
273 | "\n",
274 | "del athero_neg['ATHERO_DIAGNOSIS_FLAG']\n",
275 | "del athero_pos['ATHERO_DIAGNOSIS_FLAG']\n",
276 | "\n",
277 | "del athero_neg['OLD_FLAG']\n",
278 | "del athero_pos['OLD_FLAG']\n",
279 | "\n",
280 | "del athero_neg['OUTSIDE_DEATH_FLAG']\n",
281 | "del athero_pos['OUTSIDE_DEATH_FLAG']\n",
282 | "\n",
283 | "del athero_neg['SUBJECT_ID']\n",
284 | "del athero_pos['SUBJECT_ID']\n",
285 | "\n",
286 | "del athero_neg['DOB']\n",
287 | "del athero_pos['DOB']\n",
288 | "\n",
289 | "del athero_neg['DOD']\n",
290 | "del athero_pos['DOD']"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 62,
296 | "metadata": {},
297 | "outputs": [
298 | {
299 | "data": {
407 | "text/plain": [
408 | " GENDER ADMIT_AGE ETHNICITY MARITAL_STATUS LANGUAGE \\\n",
409 | "9 M 69.641096 WHITE MARRIED UKNOWN \n",
410 | "12 F 69.005479 WHITE MARRIED ENGL \n",
411 | "17 M 87.882192 WHITE MARRIED UKNOWN \n",
412 | "19 F 76.871233 WHITE MARRIED PORT \n",
413 | "22 F 85.726027 BLACK/AFRICAN AMERICAN WIDOWED UKNOWN \n",
414 | "\n",
415 | " RELIGION INSURANCE ADMISSION_LOCATION #ADMISSIONS \\\n",
416 | "9 CATHOLIC Private TRANSFER FROM HOSP/EXTRAM 4 \n",
417 | "12 PROTESTANT QUAKER Medicare EMERGENCY ROOM ADMIT 2 \n",
418 | "17 JEWISH Medicare EMERGENCY ROOM ADMIT 2 \n",
419 | "19 CATHOLIC Medicare TRANSFER FROM HOSP/EXTRAM 4 \n",
420 | "22 CATHOLIC Medicare EMERGENCY ROOM ADMIT 2 \n",
421 | "\n",
422 | " DEATH_FLAG HEART_DEATH_FLAG \n",
423 | "9 1 0 \n",
424 | "12 1 0 \n",
425 | "17 1 0 \n",
426 | "19 1 1 \n",
427 | "22 1 0 "
428 | ]
429 | },
430 | "execution_count": 62,
431 | "metadata": {},
432 | "output_type": "execute_result"
433 | }
434 | ],
435 | "source": [
436 | "# Drop the admission timestamp; it is not used as a feature\n",
437 | "del athero_pos['DOA']\n",
438 | "del athero_neg['DOA']\n",
439 | "\n",
440 | "# Drop the heart-attack flag; it is not used as a feature\n",
441 | "del athero_neg['HEART_ATTACK_FLAG']\n",
442 | "del athero_pos['HEART_ATTACK_FLAG']\n",
443 | "\n",
444 | "del athero_pos['Unnamed: 0']\n",
445 | "athero_pos.head()"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 63,
451 | "metadata": {},
452 | "outputs": [],
453 | "source": [
454 | "# Create Outcome data sets\n",
455 | "athero_heartdeath = pd.Series(athero_pos['HEART_DEATH_FLAG'])\n",
456 | "athero_death = pd.Series(athero_pos['DEATH_FLAG'])"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 64,
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "del athero_pos['HEART_DEATH_FLAG']\n",
466 | "del athero_pos['DEATH_FLAG']"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 65,
472 | "metadata": {},
473 | "outputs": [],
474 | "source": [
475 | "# One-hot encode the categorical columns\n",
476 | "athero_pos = pd.get_dummies(athero_pos, columns=['GENDER','ETHNICITY','MARITAL_STATUS', 'LANGUAGE', 'RELIGION', 'INSURANCE', 'ADMISSION_LOCATION'])"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": 66,
482 | "metadata": {},
483 | "outputs": [
484 | {
485 | "name": "stdout",
486 | "output_type": "stream",
487 | "text": [
488 | "0 10001\n",
489 | "1 202\n",
490 | "Name: HEART_DEATH_FLAG, dtype: int64\n",
491 | "0 6576\n",
492 | "1 3627\n",
493 | "Name: DEATH_FLAG, dtype: int64\n"
494 | ]
495 | }
496 | ],
497 | "source": [
498 | "# Check outcome numbers\n",
499 | "print(athero_heartdeath.value_counts())\n",
500 | "print(athero_death.value_counts())"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 67,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "# Standardize features to zero mean and unit variance\n",
510 | "from sklearn import preprocessing\n",
511 | "athero_pos = preprocessing.scale(athero_pos)"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 68,
517 | "metadata": {},
518 | "outputs": [],
519 | "source": [
520 | "# Train/test split (80% train, 20% test)\n",
521 | "from sklearn.model_selection import train_test_split\n",
522 | "X_train, X_test, y_train, y_test = train_test_split(athero_pos, athero_death, test_size=0.20, random_state=42)"
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {},
528 | "source": [
529 | "#### Create simple neural network as a baseline model"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 69,
535 | "metadata": {},
536 | "outputs": [],
537 | "source": [
538 | "from keras.models import Sequential\n",
539 | "from keras.layers import Dense\n",
540 | "import numpy\n",
541 | "# fix random seed for reproducibility\n",
542 | "numpy.random.seed(0)"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": 70,
548 | "metadata": {},
549 | "outputs": [],
550 | "source": [
551 | "# Create model\n",
552 | "model = Sequential()\n",
553 | "model.add(Dense(80, input_dim=X_train.shape[1], activation = 'relu'))\n",
554 | "model.add(Dense(60, activation = 'relu'))\n",
555 | "model.add(Dense(40, activation = 'relu'))\n",
556 | "model.add(Dense(20, activation = 'relu'))\n",
557 | "model.add(Dense(1, activation = 'sigmoid'))"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 71,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "# Compile model\n",
567 | "model.compile(loss = 'binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 72,
573 | "metadata": {},
574 | "outputs": [],
575 | "source": [
576 | "# Checkpoint: save the weights whenever validation accuracy improves\n",
577 | "from keras.callbacks import ModelCheckpoint\n",
578 | "filepath=\"weights.best.hdf5\"\n",
579 | "checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')\n",
580 | "callbacks_list = [checkpoint]"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 73,
586 | "metadata": {},
587 | "outputs": [
588 | {
589 | "name": "stderr",
590 | "output_type": "stream",
591 | "text": [
592 | "/Users/jimmy/anaconda3/envs/Tensorflow/lib/python3.6/site-packages/keras/models.py:942: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
593 | " warnings.warn('The `nb_epoch` argument in `fit` '\n"
594 | ]
595 | },
596 | {
597 | "name": "stdout",
598 | "output_type": "stream",
599 | "text": [
600 | "Train on 8162 samples, validate on 2041 samples\n",
601 | "Epoch 1/150\n",
602 | "8162/8162 [==============================] - 1s 125us/step - loss: 0.5769 - acc: 0.7031 - val_loss: 0.5374 - val_acc: 0.7364\n",
603 | "\n",
604 | "Epoch 00001: val_acc improved from -inf to 0.73640, saving model to weights.best.hdf5\n",
605 | "Epoch 2/150\n",
606 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.5295 - acc: 0.7362 - val_loss: 0.5353 - val_acc: 0.7398\n",
607 | "\n",
608 | "Epoch 00002: val_acc improved from 0.73640 to 0.73983, saving model to weights.best.hdf5\n",
609 | "Epoch 3/150\n",
610 | "8162/8162 [==============================] - 1s 81us/step - loss: 0.5155 - acc: 0.7488 - val_loss: 0.5490 - val_acc: 0.7310\n",
611 | "\n",
612 | "Epoch 00003: val_acc did not improve\n",
613 | "Epoch 4/150\n",
614 | "8162/8162 [==============================] - 1s 81us/step - loss: 0.5099 - acc: 0.7517 - val_loss: 0.5363 - val_acc: 0.7447\n",
615 | "\n",
616 | "Epoch 00004: val_acc improved from 0.73983 to 0.74473, saving model to weights.best.hdf5\n",
617 | "Epoch 5/150\n",
618 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.5017 - acc: 0.7546 - val_loss: 0.5461 - val_acc: 0.7403\n",
619 | "\n",
620 | "Epoch 00005: val_acc did not improve\n",
621 | "Epoch 6/150\n",
622 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.4941 - acc: 0.7599 - val_loss: 0.5490 - val_acc: 0.7403\n",
623 | "\n",
624 | "Epoch 00006: val_acc did not improve\n",
625 | "Epoch 7/150\n",
626 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.4886 - acc: 0.7659 - val_loss: 0.5498 - val_acc: 0.7379\n",
627 | "\n",
628 | "Epoch 00007: val_acc did not improve\n",
629 | "Epoch 8/150\n",
630 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.4821 - acc: 0.7677 - val_loss: 0.5577 - val_acc: 0.7354\n",
631 | "\n",
632 | "Epoch 00008: val_acc did not improve\n",
633 | "Epoch 9/150\n",
634 | "8162/8162 [==============================] - 1s 93us/step - loss: 0.4778 - acc: 0.7704 - val_loss: 0.5612 - val_acc: 0.7261\n",
635 | "\n",
636 | "Epoch 00009: val_acc did not improve\n",
637 | "Epoch 10/150\n",
638 | "8162/8162 [==============================] - 1s 92us/step - loss: 0.4735 - acc: 0.7764 - val_loss: 0.5594 - val_acc: 0.7379\n",
639 | "\n",
640 | "Epoch 00010: val_acc did not improve\n",
641 | "Epoch 11/150\n",
642 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.4648 - acc: 0.7780 - val_loss: 0.5794 - val_acc: 0.7266\n",
643 | "\n",
644 | "Epoch 00011: val_acc did not improve\n",
645 | "Epoch 12/150\n",
646 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.4610 - acc: 0.7789 - val_loss: 0.5671 - val_acc: 0.7403\n",
647 | "\n",
648 | "Epoch 00012: val_acc did not improve\n",
649 | "Epoch 13/150\n",
650 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.4552 - acc: 0.7839 - val_loss: 0.5864 - val_acc: 0.7315\n",
651 | "\n",
652 | "Epoch 00013: val_acc did not improve\n",
653 | "Epoch 14/150\n",
654 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.4485 - acc: 0.7862 - val_loss: 0.5913 - val_acc: 0.7281\n",
655 | "\n",
656 | "Epoch 00014: val_acc did not improve\n",
657 | "Epoch 15/150\n",
658 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.4403 - acc: 0.7887 - val_loss: 0.6029 - val_acc: 0.7300\n",
659 | "\n",
660 | "Epoch 00015: val_acc did not improve\n",
661 | "Epoch 16/150\n",
662 | "8162/8162 [==============================] - 1s 92us/step - loss: 0.4339 - acc: 0.7909 - val_loss: 0.6046 - val_acc: 0.7266\n",
663 | "\n",
664 | "Epoch 00016: val_acc did not improve\n",
665 | "Epoch 17/150\n",
666 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.4264 - acc: 0.7998 - val_loss: 0.6200 - val_acc: 0.7237\n",
667 | "\n",
668 | "Epoch 00017: val_acc did not improve\n",
669 | "Epoch 18/150\n",
670 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.4202 - acc: 0.8020 - val_loss: 0.6216 - val_acc: 0.7188\n",
671 | "\n",
672 | "Epoch 00018: val_acc did not improve\n",
673 | "Epoch 19/150\n",
674 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.4140 - acc: 0.8027 - val_loss: 0.6633 - val_acc: 0.7291\n",
675 | "\n",
676 | "Epoch 00019: val_acc did not improve\n",
677 | "Epoch 20/150\n",
678 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.4060 - acc: 0.8063 - val_loss: 0.6282 - val_acc: 0.7276\n",
679 | "\n",
680 | "Epoch 00020: val_acc did not improve\n",
681 | "Epoch 21/150\n",
682 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.3985 - acc: 0.8099 - val_loss: 0.6368 - val_acc: 0.7124\n",
683 | "\n",
684 | "Epoch 00021: val_acc did not improve\n",
685 | "Epoch 22/150\n",
686 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3955 - acc: 0.8118 - val_loss: 0.7164 - val_acc: 0.7286\n",
687 | "\n",
688 | "Epoch 00022: val_acc did not improve\n",
689 | "Epoch 23/150\n",
690 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3896 - acc: 0.8185 - val_loss: 0.6900 - val_acc: 0.7207\n",
691 | "\n",
692 | "Epoch 00023: val_acc did not improve\n",
693 | "Epoch 24/150\n",
694 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3815 - acc: 0.8219 - val_loss: 0.7278 - val_acc: 0.7286\n",
695 | "\n",
696 | "Epoch 00024: val_acc did not improve\n",
697 | "Epoch 25/150\n",
698 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3825 - acc: 0.8166 - val_loss: 0.6930 - val_acc: 0.7227\n",
699 | "\n",
700 | "Epoch 00025: val_acc did not improve\n",
701 | "Epoch 26/150\n",
702 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3766 - acc: 0.8244 - val_loss: 0.7103 - val_acc: 0.7075\n",
703 | "\n",
704 | "Epoch 00026: val_acc did not improve\n",
705 | "Epoch 27/150\n",
706 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3673 - acc: 0.8276 - val_loss: 0.7401 - val_acc: 0.7119\n",
707 | "\n",
708 | "Epoch 00027: val_acc did not improve\n",
709 | "Epoch 28/150\n",
710 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.3612 - acc: 0.8290 - val_loss: 0.7527 - val_acc: 0.6997\n",
711 | "\n",
712 | "Epoch 00028: val_acc did not improve\n",
713 | "Epoch 29/150\n",
714 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.3567 - acc: 0.8296 - val_loss: 0.7599 - val_acc: 0.7109\n",
715 | "\n",
716 | "Epoch 00029: val_acc did not improve\n",
717 | "Epoch 30/150\n",
718 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.3562 - acc: 0.8303 - val_loss: 0.7882 - val_acc: 0.7148\n",
719 | "\n",
720 | "Epoch 00030: val_acc did not improve\n",
721 | "Epoch 31/150\n",
722 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3502 - acc: 0.8356 - val_loss: 0.8316 - val_acc: 0.6913\n",
723 | "\n",
724 | "Epoch 00031: val_acc did not improve\n",
725 | "Epoch 32/150\n",
726 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.3468 - acc: 0.8393 - val_loss: 0.8371 - val_acc: 0.6997\n",
727 | "\n",
728 | "Epoch 00032: val_acc did not improve\n",
729 | "Epoch 33/150\n",
730 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3448 - acc: 0.8351 - val_loss: 0.8517 - val_acc: 0.6801\n",
731 | "\n",
732 | "Epoch 00033: val_acc did not improve\n",
733 | "Epoch 34/150\n",
734 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3495 - acc: 0.8308 - val_loss: 0.8518 - val_acc: 0.7085\n",
735 | "\n",
736 | "Epoch 00034: val_acc did not improve\n",
737 | "Epoch 35/150\n",
738 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3366 - acc: 0.8444 - val_loss: 0.9066 - val_acc: 0.6894\n",
739 | "\n",
740 | "Epoch 00035: val_acc did not improve\n",
741 | "Epoch 36/150\n",
742 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.3370 - acc: 0.8389 - val_loss: 0.8553 - val_acc: 0.6982\n",
743 | "\n",
744 | "Epoch 00036: val_acc did not improve\n",
745 | "Epoch 37/150\n",
746 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3385 - acc: 0.8433 - val_loss: 0.9141 - val_acc: 0.7046\n",
747 | "\n",
748 | "Epoch 00037: val_acc did not improve\n",
749 | "Epoch 38/150\n",
750 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3302 - acc: 0.8450 - val_loss: 0.9367 - val_acc: 0.7060\n",
751 | "\n",
752 | "Epoch 00038: val_acc did not improve\n",
753 | "Epoch 39/150\n",
754 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3303 - acc: 0.8443 - val_loss: 0.9484 - val_acc: 0.6850\n",
755 | "\n",
756 | "Epoch 00039: val_acc did not improve\n",
757 | "Epoch 40/150\n",
758 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3271 - acc: 0.8472 - val_loss: 0.9598 - val_acc: 0.6943\n",
759 | "\n",
760 | "Epoch 00040: val_acc did not improve\n",
761 | "Epoch 41/150\n",
762 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.3361 - acc: 0.8420 - val_loss: 0.9120 - val_acc: 0.7104\n",
763 | "\n",
764 | "Epoch 00041: val_acc did not improve\n",
765 | "Epoch 42/150\n",
766 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3290 - acc: 0.8520 - val_loss: 0.9555 - val_acc: 0.7153\n",
767 | "\n",
768 | "Epoch 00042: val_acc did not improve\n",
769 | "Epoch 43/150\n",
770 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.3165 - acc: 0.8522 - val_loss: 0.9639 - val_acc: 0.6967\n",
771 | "\n",
772 | "Epoch 00043: val_acc did not improve\n",
773 | "Epoch 44/150\n",
774 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.3140 - acc: 0.8491 - val_loss: 0.9528 - val_acc: 0.7041\n",
775 | "\n",
776 | "Epoch 00044: val_acc did not improve\n",
777 | "Epoch 45/150\n",
778 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.3117 - acc: 0.8537 - val_loss: 1.0267 - val_acc: 0.6948\n",
779 | "\n",
780 | "Epoch 00045: val_acc did not improve\n",
781 | "Epoch 46/150\n",
782 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3081 - acc: 0.8581 - val_loss: 0.9870 - val_acc: 0.6879\n"
783 | ]
784 | },
785 | {
786 | "name": "stdout",
787 | "output_type": "stream",
788 | "text": [
789 | "\n",
790 | "Epoch 00046: val_acc did not improve\n",
791 | "Epoch 47/150\n",
792 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3055 - acc: 0.8540 - val_loss: 1.1067 - val_acc: 0.7036\n",
793 | "\n",
794 | "Epoch 00047: val_acc did not improve\n",
795 | "Epoch 48/150\n",
796 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3215 - acc: 0.8478 - val_loss: 1.0187 - val_acc: 0.7036\n",
797 | "\n",
798 | "Epoch 00048: val_acc did not improve\n",
799 | "Epoch 49/150\n",
800 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.3151 - acc: 0.8502 - val_loss: 1.0717 - val_acc: 0.7026\n",
801 | "\n",
802 | "Epoch 00049: val_acc did not improve\n",
803 | "Epoch 50/150\n",
804 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.3164 - acc: 0.8542 - val_loss: 1.0011 - val_acc: 0.6923\n",
805 | "\n",
806 | "Epoch 00050: val_acc did not improve\n",
807 | "Epoch 51/150\n",
808 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3031 - acc: 0.8580 - val_loss: 1.0531 - val_acc: 0.6948\n",
809 | "\n",
810 | "Epoch 00051: val_acc did not improve\n",
811 | "Epoch 52/150\n",
812 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.3011 - acc: 0.8581 - val_loss: 1.0498 - val_acc: 0.7060\n",
813 | "\n",
814 | "Epoch 00052: val_acc did not improve\n",
815 | "Epoch 53/150\n",
816 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2919 - acc: 0.8628 - val_loss: 1.1354 - val_acc: 0.6908\n",
817 | "\n",
818 | "Epoch 00053: val_acc did not improve\n",
819 | "Epoch 54/150\n",
820 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.3005 - acc: 0.8602 - val_loss: 1.0698 - val_acc: 0.6948\n",
821 | "\n",
822 | "Epoch 00054: val_acc did not improve\n",
823 | "Epoch 55/150\n",
824 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2966 - acc: 0.8568 - val_loss: 1.1420 - val_acc: 0.7011\n",
825 | "\n",
826 | "Epoch 00055: val_acc did not improve\n",
827 | "Epoch 56/150\n",
828 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2967 - acc: 0.8590 - val_loss: 1.1139 - val_acc: 0.6972\n",
829 | "\n",
830 | "Epoch 00056: val_acc did not improve\n",
831 | "Epoch 57/150\n",
832 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2978 - acc: 0.8606 - val_loss: 1.1177 - val_acc: 0.6801\n",
833 | "\n",
834 | "Epoch 00057: val_acc did not improve\n",
835 | "Epoch 58/150\n",
836 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2958 - acc: 0.8609 - val_loss: 1.1598 - val_acc: 0.6977\n",
837 | "\n",
838 | "Epoch 00058: val_acc did not improve\n",
839 | "Epoch 59/150\n",
840 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2879 - acc: 0.8631 - val_loss: 1.1615 - val_acc: 0.6928\n",
841 | "\n",
842 | "Epoch 00059: val_acc did not improve\n",
843 | "Epoch 60/150\n",
844 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2976 - acc: 0.8567 - val_loss: 1.1657 - val_acc: 0.6859\n",
845 | "\n",
846 | "Epoch 00060: val_acc did not improve\n",
847 | "Epoch 61/150\n",
848 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2874 - acc: 0.8631 - val_loss: 1.1946 - val_acc: 0.6972\n",
849 | "\n",
850 | "Epoch 00061: val_acc did not improve\n",
851 | "Epoch 62/150\n",
852 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2920 - acc: 0.8595 - val_loss: 1.1680 - val_acc: 0.6894\n",
853 | "\n",
854 | "Epoch 00062: val_acc did not improve\n",
855 | "Epoch 63/150\n",
856 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2843 - acc: 0.8619 - val_loss: 1.1463 - val_acc: 0.6864\n",
857 | "\n",
858 | "Epoch 00063: val_acc did not improve\n",
859 | "Epoch 64/150\n",
860 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2823 - acc: 0.8618 - val_loss: 1.1966 - val_acc: 0.6967\n",
861 | "\n",
862 | "Epoch 00064: val_acc did not improve\n",
863 | "Epoch 65/150\n",
864 | "8162/8162 [==============================] - 1s 81us/step - loss: 0.2764 - acc: 0.8690 - val_loss: 1.1927 - val_acc: 0.6933\n",
865 | "\n",
866 | "Epoch 00065: val_acc did not improve\n",
867 | "Epoch 66/150\n",
868 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2833 - acc: 0.8666 - val_loss: 1.1904 - val_acc: 0.6899\n",
869 | "\n",
870 | "Epoch 00066: val_acc did not improve\n",
871 | "Epoch 67/150\n",
872 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2816 - acc: 0.8665 - val_loss: 1.2546 - val_acc: 0.7144\n",
873 | "\n",
874 | "Epoch 00067: val_acc did not improve\n",
875 | "Epoch 68/150\n",
876 | "8162/8162 [==============================] - 1s 79us/step - loss: 0.2859 - acc: 0.8650 - val_loss: 1.1555 - val_acc: 0.6982\n",
877 | "\n",
878 | "Epoch 00068: val_acc did not improve\n",
879 | "Epoch 69/150\n",
880 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2815 - acc: 0.8665 - val_loss: 1.2112 - val_acc: 0.7006\n",
881 | "\n",
882 | "Epoch 00069: val_acc did not improve\n",
883 | "Epoch 70/150\n",
884 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.2841 - acc: 0.8678 - val_loss: 1.2186 - val_acc: 0.6894\n",
885 | "\n",
886 | "Epoch 00070: val_acc did not improve\n",
887 | "Epoch 71/150\n",
888 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2812 - acc: 0.8688 - val_loss: 1.2108 - val_acc: 0.6918\n",
889 | "\n",
890 | "Epoch 00071: val_acc did not improve\n",
891 | "Epoch 72/150\n",
892 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.2740 - acc: 0.8694 - val_loss: 1.3209 - val_acc: 0.6889\n",
893 | "\n",
894 | "Epoch 00072: val_acc did not improve\n",
895 | "Epoch 73/150\n",
896 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.2769 - acc: 0.8689 - val_loss: 1.2432 - val_acc: 0.6869\n",
897 | "\n",
898 | "Epoch 00073: val_acc did not improve\n",
899 | "Epoch 74/150\n",
900 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2721 - acc: 0.8673 - val_loss: 1.3216 - val_acc: 0.7006\n",
901 | "\n",
902 | "Epoch 00074: val_acc did not improve\n",
903 | "Epoch 75/150\n",
904 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2747 - acc: 0.8714 - val_loss: 1.2756 - val_acc: 0.6923\n",
905 | "\n",
906 | "Epoch 00075: val_acc did not improve\n",
907 | "Epoch 76/150\n",
908 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.2849 - acc: 0.8645 - val_loss: 1.2655 - val_acc: 0.7021\n",
909 | "\n",
910 | "Epoch 00076: val_acc did not improve\n",
911 | "Epoch 77/150\n",
912 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2797 - acc: 0.8679 - val_loss: 1.2343 - val_acc: 0.6903\n",
913 | "\n",
914 | "Epoch 00077: val_acc did not improve\n",
915 | "Epoch 78/150\n",
916 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.2753 - acc: 0.8691 - val_loss: 1.2848 - val_acc: 0.6928\n",
917 | "\n",
918 | "Epoch 00078: val_acc did not improve\n",
919 | "Epoch 79/150\n",
920 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2751 - acc: 0.8725 - val_loss: 1.3156 - val_acc: 0.7085\n",
921 | "\n",
922 | "Epoch 00079: val_acc did not improve\n",
923 | "Epoch 80/150\n",
924 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2701 - acc: 0.8725 - val_loss: 1.3206 - val_acc: 0.6913\n",
925 | "\n",
926 | "Epoch 00080: val_acc did not improve\n",
927 | "Epoch 81/150\n",
928 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2714 - acc: 0.8712 - val_loss: 1.3172 - val_acc: 0.6933\n",
929 | "\n",
930 | "Epoch 00081: val_acc did not improve\n",
931 | "Epoch 82/150\n",
932 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.2739 - acc: 0.8721 - val_loss: 1.4071 - val_acc: 0.6938\n",
933 | "\n",
934 | "Epoch 00082: val_acc did not improve\n",
935 | "Epoch 83/150\n",
936 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2768 - acc: 0.8684 - val_loss: 1.2627 - val_acc: 0.6943\n",
937 | "\n",
938 | "Epoch 00083: val_acc did not improve\n",
939 | "Epoch 84/150\n",
940 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.2923 - acc: 0.8662 - val_loss: 1.2891 - val_acc: 0.6997\n",
941 | "\n",
942 | "Epoch 00084: val_acc did not improve\n",
943 | "Epoch 85/150\n",
944 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2763 - acc: 0.8737 - val_loss: 1.3326 - val_acc: 0.7080\n",
945 | "\n",
946 | "Epoch 00085: val_acc did not improve\n",
947 | "Epoch 86/150\n",
948 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.2626 - acc: 0.8769 - val_loss: 1.3409 - val_acc: 0.6943\n",
949 | "\n",
950 | "Epoch 00086: val_acc did not improve\n",
951 | "Epoch 87/150\n",
952 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.2650 - acc: 0.8727 - val_loss: 1.3358 - val_acc: 0.7026\n",
953 | "\n",
954 | "Epoch 00087: val_acc did not improve\n",
955 | "Epoch 88/150\n",
956 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2637 - acc: 0.8767 - val_loss: 1.3283 - val_acc: 0.6938\n",
957 | "\n",
958 | "Epoch 00088: val_acc did not improve\n",
959 | "Epoch 89/150\n",
960 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2664 - acc: 0.8743 - val_loss: 1.2922 - val_acc: 0.6908\n",
961 | "\n",
962 | "Epoch 00089: val_acc did not improve\n",
963 | "Epoch 90/150\n",
964 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2603 - acc: 0.8750 - val_loss: 1.3863 - val_acc: 0.6957\n",
965 | "\n",
966 | "Epoch 00090: val_acc did not improve\n",
967 | "Epoch 91/150\n",
968 | "8162/8162 [==============================] - 1s 88us/step - loss: 0.2701 - acc: 0.8780 - val_loss: 1.3177 - val_acc: 0.7006\n",
969 | "\n",
970 | "Epoch 00091: val_acc did not improve\n",
971 | "Epoch 92/150\n",
972 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2652 - acc: 0.8798 - val_loss: 1.3475 - val_acc: 0.6854\n",
973 | "\n",
974 | "Epoch 00092: val_acc did not improve\n",
975 | "Epoch 93/150\n",
976 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.2563 - acc: 0.8775 - val_loss: 1.3728 - val_acc: 0.6766\n"
977 | ]
978 | },
979 | {
980 | "name": "stdout",
981 | "output_type": "stream",
982 | "text": [
983 | "\n",
984 | "Epoch 00093: val_acc did not improve\n",
985 | "Epoch 94/150\n",
986 | "8162/8162 [==============================] - 1s 92us/step - loss: 0.2635 - acc: 0.8734 - val_loss: 1.3640 - val_acc: 0.6840\n",
987 | "\n",
988 | "Epoch 00094: val_acc did not improve\n",
989 | "Epoch 95/150\n",
990 | "8162/8162 [==============================] - 1s 92us/step - loss: 0.2556 - acc: 0.8787 - val_loss: 1.4000 - val_acc: 0.6903\n",
991 | "\n",
992 | "Epoch 00095: val_acc did not improve\n",
993 | "Epoch 96/150\n",
994 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2630 - acc: 0.8790 - val_loss: 1.3762 - val_acc: 0.6992\n",
995 | "\n",
996 | "Epoch 00096: val_acc did not improve\n",
997 | "Epoch 97/150\n",
998 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2558 - acc: 0.8763 - val_loss: 1.4405 - val_acc: 0.7016\n",
999 | "\n",
1000 | "Epoch 00097: val_acc did not improve\n",
1001 | "Epoch 98/150\n",
1002 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2553 - acc: 0.8805 - val_loss: 1.3749 - val_acc: 0.6948\n",
1003 | "\n",
1004 | "Epoch 00098: val_acc did not improve\n",
1005 | "Epoch 99/150\n",
1006 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2701 - acc: 0.8743 - val_loss: 1.4121 - val_acc: 0.6982\n",
1007 | "\n",
1008 | "Epoch 00099: val_acc did not improve\n",
1009 | "Epoch 100/150\n",
1010 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.2668 - acc: 0.8760 - val_loss: 1.4003 - val_acc: 0.6957\n",
1011 | "\n",
1012 | "Epoch 00100: val_acc did not improve\n",
1013 | "Epoch 101/150\n",
1014 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2564 - acc: 0.8785 - val_loss: 1.4701 - val_acc: 0.6967\n",
1015 | "\n",
1016 | "Epoch 00101: val_acc did not improve\n",
1017 | "Epoch 102/150\n",
1018 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2606 - acc: 0.8801 - val_loss: 1.4411 - val_acc: 0.6952\n",
1019 | "\n",
1020 | "Epoch 00102: val_acc did not improve\n",
1021 | "Epoch 103/150\n",
1022 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2564 - acc: 0.8737 - val_loss: 1.3929 - val_acc: 0.6962\n",
1023 | "\n",
1024 | "Epoch 00103: val_acc did not improve\n",
1025 | "Epoch 104/150\n",
1026 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2539 - acc: 0.8783 - val_loss: 1.5250 - val_acc: 0.6943\n",
1027 | "\n",
1028 | "Epoch 00104: val_acc did not improve\n",
1029 | "Epoch 105/150\n",
1030 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.2542 - acc: 0.8791 - val_loss: 1.4034 - val_acc: 0.7011\n",
1031 | "\n",
1032 | "Epoch 00105: val_acc did not improve\n",
1033 | "Epoch 106/150\n",
1034 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.2532 - acc: 0.8796 - val_loss: 1.4905 - val_acc: 0.6933\n",
1035 | "\n",
1036 | "Epoch 00106: val_acc did not improve\n",
1037 | "Epoch 107/150\n",
1038 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2579 - acc: 0.8776 - val_loss: 1.4361 - val_acc: 0.6967\n",
1039 | "\n",
1040 | "Epoch 00107: val_acc did not improve\n",
1041 | "Epoch 108/150\n",
1042 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2586 - acc: 0.8785 - val_loss: 1.4582 - val_acc: 0.6889\n",
1043 | "\n",
1044 | "Epoch 00108: val_acc did not improve\n",
1045 | "Epoch 109/150\n",
1046 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2582 - acc: 0.8775 - val_loss: 1.4093 - val_acc: 0.6825\n",
1047 | "\n",
1048 | "Epoch 00109: val_acc did not improve\n",
1049 | "Epoch 110/150\n",
1050 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2541 - acc: 0.8819 - val_loss: 1.4957 - val_acc: 0.6957\n",
1051 | "\n",
1052 | "Epoch 00110: val_acc did not improve\n",
1053 | "Epoch 111/150\n",
1054 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2534 - acc: 0.8814 - val_loss: 1.4838 - val_acc: 0.6899\n",
1055 | "\n",
1056 | "Epoch 00111: val_acc did not improve\n",
1057 | "Epoch 112/150\n",
1058 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2544 - acc: 0.8793 - val_loss: 1.4498 - val_acc: 0.6923\n",
1059 | "\n",
1060 | "Epoch 00112: val_acc did not improve\n",
1061 | "Epoch 113/150\n",
1062 | "8162/8162 [==============================] - 1s 98us/step - loss: 0.2460 - acc: 0.8831 - val_loss: 1.5025 - val_acc: 0.6815\n",
1063 | "\n",
1064 | "Epoch 00113: val_acc did not improve\n",
1065 | "Epoch 114/150\n",
1066 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2433 - acc: 0.8829 - val_loss: 1.3985 - val_acc: 0.6805\n",
1067 | "\n",
1068 | "Epoch 00114: val_acc did not improve\n",
1069 | "Epoch 115/150\n",
1070 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2399 - acc: 0.8869 - val_loss: 1.5132 - val_acc: 0.6850\n",
1071 | "\n",
1072 | "Epoch 00115: val_acc did not improve\n",
1073 | "Epoch 116/150\n",
1074 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2390 - acc: 0.8859 - val_loss: 1.5096 - val_acc: 0.6903\n",
1075 | "\n",
1076 | "Epoch 00116: val_acc did not improve\n",
1077 | "Epoch 117/150\n",
1078 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2387 - acc: 0.8824 - val_loss: 1.5236 - val_acc: 0.6894\n",
1079 | "\n",
1080 | "Epoch 00117: val_acc did not improve\n",
1081 | "Epoch 118/150\n",
1082 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.2418 - acc: 0.8815 - val_loss: 1.5059 - val_acc: 0.6854\n",
1083 | "\n",
1084 | "Epoch 00118: val_acc did not improve\n",
1085 | "Epoch 119/150\n",
1086 | "8162/8162 [==============================] - 1s 77us/step - loss: 0.2457 - acc: 0.8825 - val_loss: 1.5490 - val_acc: 0.6835\n",
1087 | "\n",
1088 | "Epoch 00119: val_acc did not improve\n",
1089 | "Epoch 120/150\n",
1090 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2483 - acc: 0.8804 - val_loss: 1.5204 - val_acc: 0.6928\n",
1091 | "\n",
1092 | "Epoch 00120: val_acc did not improve\n",
1093 | "Epoch 121/150\n",
1094 | "8162/8162 [==============================] - 1s 79us/step - loss: 0.2588 - acc: 0.8805 - val_loss: 1.5191 - val_acc: 0.6874\n",
1095 | "\n",
1096 | "Epoch 00121: val_acc did not improve\n",
1097 | "Epoch 122/150\n",
1098 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2460 - acc: 0.8814 - val_loss: 1.4850 - val_acc: 0.6810\n",
1099 | "\n",
1100 | "Epoch 00122: val_acc did not improve\n",
1101 | "Epoch 123/150\n",
1102 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2451 - acc: 0.8813 - val_loss: 1.5511 - val_acc: 0.6918\n",
1103 | "\n",
1104 | "Epoch 00123: val_acc did not improve\n",
1105 | "Epoch 124/150\n",
1106 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2384 - acc: 0.8864 - val_loss: 1.5619 - val_acc: 0.6864\n",
1107 | "\n",
1108 | "Epoch 00124: val_acc did not improve\n",
1109 | "Epoch 125/150\n",
1110 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2405 - acc: 0.8851 - val_loss: 1.5573 - val_acc: 0.6933\n",
1111 | "\n",
1112 | "Epoch 00125: val_acc did not improve\n",
1113 | "Epoch 126/150\n",
1114 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2343 - acc: 0.8851 - val_loss: 1.6384 - val_acc: 0.6972\n",
1115 | "\n",
1116 | "Epoch 00126: val_acc did not improve\n",
1117 | "Epoch 127/150\n",
1118 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2415 - acc: 0.8853 - val_loss: 1.5703 - val_acc: 0.6948\n",
1119 | "\n",
1120 | "Epoch 00127: val_acc did not improve\n",
1121 | "Epoch 128/150\n",
1122 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2368 - acc: 0.8836 - val_loss: 1.5816 - val_acc: 0.6869\n",
1123 | "\n",
1124 | "Epoch 00128: val_acc did not improve\n",
1125 | "Epoch 129/150\n",
1126 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.2383 - acc: 0.8859 - val_loss: 1.5480 - val_acc: 0.6899\n",
1127 | "\n",
1128 | "Epoch 00129: val_acc did not improve\n",
1129 | "Epoch 130/150\n",
1130 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2436 - acc: 0.8836 - val_loss: 1.5502 - val_acc: 0.6943\n",
1131 | "\n",
1132 | "Epoch 00130: val_acc did not improve\n",
1133 | "Epoch 131/150\n",
1134 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.2392 - acc: 0.8877 - val_loss: 1.5478 - val_acc: 0.6957\n",
1135 | "\n",
1136 | "Epoch 00131: val_acc did not improve\n",
1137 | "Epoch 132/150\n",
1138 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2411 - acc: 0.8850 - val_loss: 1.4816 - val_acc: 0.6874\n",
1139 | "\n",
1140 | "Epoch 00132: val_acc did not improve\n",
1141 | "Epoch 133/150\n",
1142 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2377 - acc: 0.8878 - val_loss: 1.5740 - val_acc: 0.6825\n",
1143 | "\n",
1144 | "Epoch 00133: val_acc did not improve\n",
1145 | "Epoch 134/150\n",
1146 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2435 - acc: 0.8852 - val_loss: 1.7086 - val_acc: 0.6884\n",
1147 | "\n",
1148 | "Epoch 00134: val_acc did not improve\n",
1149 | "Epoch 135/150\n",
1150 | "8162/8162 [==============================] - 1s 89us/step - loss: 0.2335 - acc: 0.8892 - val_loss: 1.5845 - val_acc: 0.6884\n",
1151 | "\n",
1152 | "Epoch 00135: val_acc did not improve\n",
1153 | "Epoch 136/150\n",
1154 | "8162/8162 [==============================] - 1s 83us/step - loss: 0.2461 - acc: 0.8823 - val_loss: 1.5247 - val_acc: 0.7011\n",
1155 | "\n",
1156 | "Epoch 00136: val_acc did not improve\n",
1157 | "Epoch 137/150\n",
1158 | "8162/8162 [==============================] - 1s 84us/step - loss: 0.2321 - acc: 0.8885 - val_loss: 1.5387 - val_acc: 0.6859\n",
1159 | "\n",
1160 | "Epoch 00137: val_acc did not improve\n",
1161 | "Epoch 138/150\n",
1162 | "8162/8162 [==============================] - 1s 79us/step - loss: 0.2338 - acc: 0.8892 - val_loss: 1.5845 - val_acc: 0.6791\n",
1163 | "\n",
1164 | "Epoch 00138: val_acc did not improve\n",
1165 | "Epoch 139/150\n",
1166 | "8162/8162 [==============================] - 1s 82us/step - loss: 0.2365 - acc: 0.8859 - val_loss: 1.5788 - val_acc: 0.6903\n",
1167 | "\n",
1168 | "Epoch 00139: val_acc did not improve\n",
1169 | "Epoch 140/150\n"
1170 | ]
1171 | },
1172 | {
1173 | "name": "stdout",
1174 | "output_type": "stream",
1175 | "text": [
1176 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2325 - acc: 0.8852 - val_loss: 1.6379 - val_acc: 0.6952\n",
1177 | "\n",
1178 | "Epoch 00140: val_acc did not improve\n",
1179 | "Epoch 141/150\n",
1180 | "8162/8162 [==============================] - 1s 80us/step - loss: 0.2449 - acc: 0.8868 - val_loss: 1.6677 - val_acc: 0.6874\n",
1181 | "\n",
1182 | "Epoch 00141: val_acc did not improve\n",
1183 | "Epoch 142/150\n",
1184 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2446 - acc: 0.8823 - val_loss: 1.5542 - val_acc: 0.6830\n",
1185 | "\n",
1186 | "Epoch 00142: val_acc did not improve\n",
1187 | "Epoch 143/150\n",
1188 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.2318 - acc: 0.8878 - val_loss: 1.5618 - val_acc: 0.6928\n",
1189 | "\n",
1190 | "Epoch 00143: val_acc did not improve\n",
1191 | "Epoch 144/150\n",
1192 | "8162/8162 [==============================] - 1s 92us/step - loss: 0.2256 - acc: 0.8917 - val_loss: 1.6059 - val_acc: 0.6850\n",
1193 | "\n",
1194 | "Epoch 00144: val_acc did not improve\n",
1195 | "Epoch 145/150\n",
1196 | "8162/8162 [==============================] - 1s 91us/step - loss: 0.2291 - acc: 0.8899 - val_loss: 1.6069 - val_acc: 0.6933\n",
1197 | "\n",
1198 | "Epoch 00145: val_acc did not improve\n",
1199 | "Epoch 146/150\n",
1200 | "8162/8162 [==============================] - 1s 86us/step - loss: 0.2277 - acc: 0.8877 - val_loss: 1.5481 - val_acc: 0.6840\n",
1201 | "\n",
1202 | "Epoch 00146: val_acc did not improve\n",
1203 | "Epoch 147/150\n",
1204 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2282 - acc: 0.8873 - val_loss: 1.6756 - val_acc: 0.6903\n",
1205 | "\n",
1206 | "Epoch 00147: val_acc did not improve\n",
1207 | "Epoch 148/150\n",
1208 | "8162/8162 [==============================] - 1s 85us/step - loss: 0.2336 - acc: 0.8917 - val_loss: 1.6334 - val_acc: 0.6992\n",
1209 | "\n",
1210 | "Epoch 00148: val_acc did not improve\n",
1211 | "Epoch 149/150\n",
1212 | "8162/8162 [==============================] - 1s 87us/step - loss: 0.2335 - acc: 0.8891 - val_loss: 1.7145 - val_acc: 0.6874\n",
1213 | "\n",
1214 | "Epoch 00149: val_acc did not improve\n",
1215 | "Epoch 150/150\n",
1216 | "8162/8162 [==============================] - 1s 90us/step - loss: 0.2500 - acc: 0.8827 - val_loss: 1.6869 - val_acc: 0.6908\n",
1217 | "\n",
1218 | "Epoch 00150: val_acc did not improve\n"
1219 | ]
1220 | },
1221 | {
1222 | "data": {
1223 | "text/plain": [
1224 | ""
1225 | ]
1226 | },
1227 | "execution_count": 73,
1228 | "metadata": {},
1229 | "output_type": "execute_result"
1230 | }
1231 | ],
1232 | "source": [
1233 | "model.fit(X_train, y_train, batch_size=20, epochs=150, verbose=1, callbacks=callbacks_list, validation_data=(X_test, y_test), shuffle=True)"
1234 | ]
1235 | },
1236 | {
1237 | "cell_type": "code",
1238 | "execution_count": 74,
1239 | "metadata": {},
1240 | "outputs": [
1241 | {
1242 | "name": "stdout",
1243 | "output_type": "stream",
1244 | "text": [
1245 | "acc: 74.47%\n"
1246 | ]
1247 | }
1248 | ],
1249 | "source": [
1250 | "# Load the best saved weights\n",
1251 | "model.load_weights(\"weights.best.hdf5\")\n",
1252 | "\n",
1253 | "# Estimate accuracy on the test set using the loaded weights\n",
1254 | "scores = model.evaluate(X_test, y_test, verbose=0)\n",
1255 | "print(\"%s: %.2f%%\" % (model.metrics_names[1], scores[1]*100))"
1256 | ]
1257 | },
1258 | {
1259 | "cell_type": "markdown",
1260 | "metadata": {},
1261 | "source": [
1262 | "#### With a simple neural network trained on demographic data only, we reach an accuracy of 74%, roughly a 5% improvement over XGBoost"
1263 | ]
1264 | },
1265 | {
1266 | "cell_type": "code",
1267 | "execution_count": null,
1268 | "metadata": {},
1269 | "outputs": [],
1270 | "source": []
1271 | }
1272 | ],
1273 | "metadata": {
1274 | "kernelspec": {
1275 | "display_name": "Python 3",
1276 | "language": "python",
1277 | "name": "python3"
1278 | },
1279 | "language_info": {
1280 | "codemirror_mode": {
1281 | "name": "ipython",
1282 | "version": 3
1283 | },
1284 | "file_extension": ".py",
1285 | "mimetype": "text/x-python",
1286 | "name": "python",
1287 | "nbconvert_exporter": "python",
1288 | "pygments_lexer": "ipython3",
1289 | "version": "3.6.4"
1290 | }
1291 | },
1292 | "nbformat": 4,
1293 | "nbformat_minor": 2
1294 | }
1295 |
--------------------------------------------------------------------------------
/UserGuideMimicII.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rjgiedt/Cardiovascular_Death_Prediction/90f74112b30af528d1351bfd8daa0cc1bf6ae46a/UserGuideMimicII.pdf
--------------------------------------------------------------------------------
/XGBoost_Baseline_Models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Implementing a baseline machine learning model"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 40,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 41,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "demographics = pd.read_csv('Demographics.csv')"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 42,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
166 | "text/plain": [
167 | " GENDER ADMIT_AGE ETHNICITY MARITAL_STATUS LANGUAGE RELIGION \\\n",
168 | "0 M 72.312329 WHITE MARRIED UKNOWN CATHOLIC \n",
169 | "1 F NaN WHITE UKNOWN UKNOWN NOT SPECIFIED \n",
170 | "2 M 55.241096 WHITE MARRIED UKNOWN CATHOLIC \n",
171 | "3 M 73.539726 WHITE SINGLE UKNOWN JEWISH \n",
172 | "4 F 44.512329 WHITE MARRIED UKNOWN OTHER \n",
173 | "\n",
174 | " INSURANCE ADMISSION_LOCATION #ADMISSIONS OUTSIDE_DEATH_FLAG \\\n",
175 | "0 Medicare TRANSFER FROM HOSP/EXTRAM 1 0 \n",
176 | "1 Medicare EMERGENCY ROOM ADMIT 1 0 \n",
177 | "2 Private CLINIC REFERRAL/PREMATURE 2 0 \n",
178 | "3 Medicare EMERGENCY ROOM ADMIT 2 0 \n",
179 | "4 Private EMERGENCY ROOM ADMIT 2 0 \n",
180 | "\n",
181 | " DEATH_FLAG OLD_FLAG HEART_ATTACK_FLAG ATHERO_DIAGNOSIS_FLAG \\\n",
182 | "0 1 0 0 0 \n",
183 | "1 1 1 0 0 \n",
184 | "2 1 0 0 0 \n",
185 | "3 1 0 0 0 \n",
186 | "4 1 0 0 0 \n",
187 | "\n",
188 | " HEART_DEATH_FLAG \n",
189 | "0 0 \n",
190 | "1 0 \n",
191 | "2 0 \n",
192 | "3 0 \n",
193 | "4 0 "
194 | ]
195 | },
196 | "execution_count": 42,
197 | "metadata": {},
198 | "output_type": "execute_result"
199 | }
200 | ],
201 | "source": [
202 | "ml = demographics.loc[:,'GENDER': 'HEART_DEATH_FLAG']\n",
203 | "del ml['DOB']\n",
204 | "del ml['DOD']\n",
205 | "del ml['DOA']\n",
206 | "ml.head()"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 43,
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "# One-hot encode the categorical columns so that every feature is numerical\n",
216 | "ml_data = pd.get_dummies(ml, columns=['GENDER','ETHNICITY','MARITAL_STATUS', 'LANGUAGE', 'RELIGION', 'INSURANCE', 'ADMISSION_LOCATION'])\n",
217 | "ml_data = ml_data[ml_data['OLD_FLAG']==0]"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 44,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "# Keep only patients with a usable admission age (OLD_FLAG == 0); the previous cell already applied this filter, so this is a no-op\n",
227 | "ml_data = ml_data[ml_data['OLD_FLAG']==0]"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 45,
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "# Extract the outcome (label) columns used as targets for the models\n",
237 | "heart_attacks = ml_data['HEART_ATTACK_FLAG']\n",
238 | "athero_diagnosis = ml_data['ATHERO_DIAGNOSIS_FLAG']\n",
239 | "deaths = ml_data['DEATH_FLAG']\n",
240 | "heart_deaths = ml_data['HEART_DEATH_FLAG']"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 46,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "data": {
250 | "text/html": [
251 | "\n",
252 | "\n",
265 | "
\n",
266 | " \n",
267 | " \n",
268 | " | \n",
269 | " ADMIT_AGE | \n",
270 | " #ADMISSIONS | \n",
271 | " GENDER_F | \n",
272 | " GENDER_M | \n",
273 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE | \n",
274 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE FEDERALLY RECOGNIZED TRIBE | \n",
275 | " ETHNICITY_ASIAN | \n",
276 | " ETHNICITY_ASIAN - ASIAN INDIAN | \n",
277 | " ETHNICITY_ASIAN - CAMBODIAN | \n",
278 | " ETHNICITY_ASIAN - CHINESE | \n",
279 | " ... | \n",
280 | " INSURANCE_Self Pay | \n",
281 | " ADMISSION_LOCATION_** INFO NOT AVAILABLE ** | \n",
282 | " ADMISSION_LOCATION_CLINIC REFERRAL/PREMATURE | \n",
283 | " ADMISSION_LOCATION_EMERGENCY ROOM ADMIT | \n",
284 | " ADMISSION_LOCATION_HMO REFERRAL/SICK | \n",
285 | " ADMISSION_LOCATION_PHYS REFERRAL/NORMAL DELI | \n",
286 | " ADMISSION_LOCATION_TRANSFER FROM HOSP/EXTRAM | \n",
287 | " ADMISSION_LOCATION_TRANSFER FROM OTHER HEALT | \n",
288 | " ADMISSION_LOCATION_TRANSFER FROM SKILLED NUR | \n",
289 | " ADMISSION_LOCATION_TRSF WITHIN THIS FACILITY | \n",
290 | "
\n",
291 | " \n",
292 | " \n",
293 | " \n",
294 | " 0 | \n",
295 | " 72.312329 | \n",
296 | " 1 | \n",
297 | " 0 | \n",
298 | " 1 | \n",
299 | " 0 | \n",
300 | " 0 | \n",
301 | " 0 | \n",
302 | " 0 | \n",
303 | " 0 | \n",
304 | " 0 | \n",
305 | " ... | \n",
306 | " 0 | \n",
307 | " 0 | \n",
308 | " 0 | \n",
309 | " 0 | \n",
310 | " 0 | \n",
311 | " 0 | \n",
312 | " 1 | \n",
313 | " 0 | \n",
314 | " 0 | \n",
315 | " 0 | \n",
316 | "
\n",
317 | " \n",
318 | " 2 | \n",
319 | " 55.241096 | \n",
320 | " 2 | \n",
321 | " 0 | \n",
322 | " 1 | \n",
323 | " 0 | \n",
324 | " 0 | \n",
325 | " 0 | \n",
326 | " 0 | \n",
327 | " 0 | \n",
328 | " 0 | \n",
329 | " ... | \n",
330 | " 0 | \n",
331 | " 0 | \n",
332 | " 1 | \n",
333 | " 0 | \n",
334 | " 0 | \n",
335 | " 0 | \n",
336 | " 0 | \n",
337 | " 0 | \n",
338 | " 0 | \n",
339 | " 0 | \n",
340 | "
\n",
341 | " \n",
342 | " 3 | \n",
343 | " 73.539726 | \n",
344 | " 2 | \n",
345 | " 0 | \n",
346 | " 1 | \n",
347 | " 0 | \n",
348 | " 0 | \n",
349 | " 0 | \n",
350 | " 0 | \n",
351 | " 0 | \n",
352 | " 0 | \n",
353 | " ... | \n",
354 | " 0 | \n",
355 | " 0 | \n",
356 | " 0 | \n",
357 | " 1 | \n",
358 | " 0 | \n",
359 | " 0 | \n",
360 | " 0 | \n",
361 | " 0 | \n",
362 | " 0 | \n",
363 | " 0 | \n",
364 | "
\n",
365 | " \n",
366 | " 4 | \n",
367 | " 44.512329 | \n",
368 | " 2 | \n",
369 | " 1 | \n",
370 | " 0 | \n",
371 | " 0 | \n",
372 | " 0 | \n",
373 | " 0 | \n",
374 | " 0 | \n",
375 | " 0 | \n",
376 | " 0 | \n",
377 | " ... | \n",
378 | " 0 | \n",
379 | " 0 | \n",
380 | " 0 | \n",
381 | " 1 | \n",
382 | " 0 | \n",
383 | " 0 | \n",
384 | " 0 | \n",
385 | " 0 | \n",
386 | " 0 | \n",
387 | " 0 | \n",
388 | "
\n",
389 | " \n",
390 | " 5 | \n",
391 | " 81.627397 | \n",
392 | " 1 | \n",
393 | " 1 | \n",
394 | " 0 | \n",
395 | " 0 | \n",
396 | " 0 | \n",
397 | " 0 | \n",
398 | " 0 | \n",
399 | " 0 | \n",
400 | " 0 | \n",
401 | " ... | \n",
402 | " 0 | \n",
403 | " 0 | \n",
404 | " 0 | \n",
405 | " 1 | \n",
406 | " 0 | \n",
407 | " 0 | \n",
408 | " 0 | \n",
409 | " 0 | \n",
410 | " 0 | \n",
411 | " 0 | \n",
412 | "
\n",
413 | " \n",
414 | "
\n",
415 | "
5 rows × 164 columns
\n",
416 | "
"
417 | ],
418 | "text/plain": [
419 | " ADMIT_AGE #ADMISSIONS GENDER_F GENDER_M \\\n",
420 | "0 72.312329 1 0 1 \n",
421 | "2 55.241096 2 0 1 \n",
422 | "3 73.539726 2 0 1 \n",
423 | "4 44.512329 2 1 0 \n",
424 | "5 81.627397 1 1 0 \n",
425 | "\n",
426 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE \\\n",
427 | "0 0 \n",
428 | "2 0 \n",
429 | "3 0 \n",
430 | "4 0 \n",
431 | "5 0 \n",
432 | "\n",
433 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE FEDERALLY RECOGNIZED TRIBE \\\n",
434 | "0 0 \n",
435 | "2 0 \n",
436 | "3 0 \n",
437 | "4 0 \n",
438 | "5 0 \n",
439 | "\n",
440 | " ETHNICITY_ASIAN ETHNICITY_ASIAN - ASIAN INDIAN \\\n",
441 | "0 0 0 \n",
442 | "2 0 0 \n",
443 | "3 0 0 \n",
444 | "4 0 0 \n",
445 | "5 0 0 \n",
446 | "\n",
447 | " ETHNICITY_ASIAN - CAMBODIAN ETHNICITY_ASIAN - CHINESE \\\n",
448 | "0 0 0 \n",
449 | "2 0 0 \n",
450 | "3 0 0 \n",
451 | "4 0 0 \n",
452 | "5 0 0 \n",
453 | "\n",
454 | " ... INSURANCE_Self Pay \\\n",
455 | "0 ... 0 \n",
456 | "2 ... 0 \n",
457 | "3 ... 0 \n",
458 | "4 ... 0 \n",
459 | "5 ... 0 \n",
460 | "\n",
461 | " ADMISSION_LOCATION_** INFO NOT AVAILABLE ** \\\n",
462 | "0 0 \n",
463 | "2 0 \n",
464 | "3 0 \n",
465 | "4 0 \n",
466 | "5 0 \n",
467 | "\n",
468 | " ADMISSION_LOCATION_CLINIC REFERRAL/PREMATURE \\\n",
469 | "0 0 \n",
470 | "2 1 \n",
471 | "3 0 \n",
472 | "4 0 \n",
473 | "5 0 \n",
474 | "\n",
475 | " ADMISSION_LOCATION_EMERGENCY ROOM ADMIT \\\n",
476 | "0 0 \n",
477 | "2 0 \n",
478 | "3 1 \n",
479 | "4 1 \n",
480 | "5 1 \n",
481 | "\n",
482 | " ADMISSION_LOCATION_HMO REFERRAL/SICK \\\n",
483 | "0 0 \n",
484 | "2 0 \n",
485 | "3 0 \n",
486 | "4 0 \n",
487 | "5 0 \n",
488 | "\n",
489 | " ADMISSION_LOCATION_PHYS REFERRAL/NORMAL DELI \\\n",
490 | "0 0 \n",
491 | "2 0 \n",
492 | "3 0 \n",
493 | "4 0 \n",
494 | "5 0 \n",
495 | "\n",
496 | " ADMISSION_LOCATION_TRANSFER FROM HOSP/EXTRAM \\\n",
497 | "0 1 \n",
498 | "2 0 \n",
499 | "3 0 \n",
500 | "4 0 \n",
501 | "5 0 \n",
502 | "\n",
503 | " ADMISSION_LOCATION_TRANSFER FROM OTHER HEALT \\\n",
504 | "0 0 \n",
505 | "2 0 \n",
506 | "3 0 \n",
507 | "4 0 \n",
508 | "5 0 \n",
509 | "\n",
510 | " ADMISSION_LOCATION_TRANSFER FROM SKILLED NUR \\\n",
511 | "0 0 \n",
512 | "2 0 \n",
513 | "3 0 \n",
514 | "4 0 \n",
515 | "5 0 \n",
516 | "\n",
517 | " ADMISSION_LOCATION_TRSF WITHIN THIS FACILITY \n",
518 | "0 0 \n",
519 | "2 0 \n",
520 | "3 0 \n",
521 | "4 0 \n",
522 | "5 0 \n",
523 | "\n",
524 | "[5 rows x 164 columns]"
525 | ]
526 | },
527 | "execution_count": 46,
528 | "metadata": {},
529 | "output_type": "execute_result"
530 | }
531 | ],
532 | "source": [
533 | "# Drop the outcome and flag columns so the features contain only non-diagnostic data\n",
534 | "del ml_data['HEART_ATTACK_FLAG']\n",
535 | "del ml_data['ATHERO_DIAGNOSIS_FLAG']\n",
536 | "del ml_data['DEATH_FLAG']\n",
537 | "del ml_data['HEART_DEATH_FLAG']\n",
538 | "del ml_data['OLD_FLAG']\n",
539 | "del ml_data['OUTSIDE_DEATH_FLAG']\n",
540 | "\n",
541 | "ml_data.head()"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 47,
547 | "metadata": {},
548 | "outputs": [],
549 | "source": [
550 | "# Create randomly undersampled data set\n",
551 | "from imblearn.under_sampling import RandomUnderSampler\n",
552 | "rus = RandomUnderSampler(return_indices=True)\n",
553 | "X_resampled, y_resampled, idx_resampled = rus.fit_sample(ml_data, deaths)"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": 48,
559 | "metadata": {},
560 | "outputs": [
561 | {
562 | "name": "stdout",
563 | "output_type": "stream",
564 | "text": [
565 | "1 14320\n",
566 | "0 14320\n",
567 | "dtype: int64\n",
568 | "0 30250\n",
569 | "1 14320\n",
570 | "Name: DEATH_FLAG, dtype: int64\n"
571 | ]
572 | }
573 | ],
574 | "source": [
575 | "# Check sampling numbers\n",
576 | "y_resampled = pd.Series(y_resampled)\n",
577 | "print(y_resampled.value_counts())\n",
578 | "print(deaths.value_counts())"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": 49,
584 | "metadata": {},
585 | "outputs": [
586 | {
587 | "data": {
588 | "text/plain": [
589 | "0.7068435754189945"
590 | ]
591 | },
592 | "execution_count": 49,
593 | "metadata": {},
594 | "output_type": "execute_result"
595 | }
596 | ],
597 | "source": [
598 | "# Create and test a gradient boosting baseline (scikit-learn's GradientBoostingClassifier, not the xgboost library)\n",
599 | "# with 5-fold cross-validation for predicting death; no hyperparameter optimization\n",
600 | "from sklearn.ensemble import GradientBoostingClassifier\n",
601 | "from sklearn.model_selection import cross_val_score\n",
602 | "\n",
603 | "base_model_XG = GradientBoostingClassifier()\n",
604 | "scores = cross_val_score(base_model_XG, X_resampled, y_resampled, cv=5)\n",
605 | "scores.mean()"
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": 34,
611 | "metadata": {},
612 | "outputs": [],
613 | "source": [
614 | "# Same process for predicting heart attacks\n",
615 | "rus = RandomUnderSampler(return_indices=True)\n",
616 | "X_resampled, y_resampled, idx_resampled = rus.fit_sample(ml_data, heart_attacks)"
617 | ]
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": 35,
622 | "metadata": {},
623 | "outputs": [
624 | {
625 | "name": "stdout",
626 | "output_type": "stream",
627 | "text": [
628 | "1 121\n",
629 | "0 121\n",
630 | "dtype: int64\n",
631 | "0 30250\n",
632 | "1 14320\n",
633 | "Name: DEATH_FLAG, dtype: int64\n"
634 | ]
635 | }
636 | ],
637 | "source": [
638 | "# Check sampling numbers\n",
639 | "y_resampled = pd.Series(y_resampled)\n",
640 | "print(y_resampled.value_counts())\n",
641 | "print(deaths.value_counts())"
642 | ]
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {},
647 | "source": [
648 | "With only 121 positive cases, undersampling leaves too few samples to model heart attacks alone, so instead of undersampling we use SMOTE to oversample the minority class."
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 36,
654 | "metadata": {},
655 | "outputs": [],
656 | "source": [
657 | "from imblearn.over_sampling import SMOTE \n",
658 | "sm = SMOTE(random_state=42)\n",
659 | "X_resampled, y_resampled = sm.fit_sample(ml_data, heart_attacks)"
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": 37,
665 | "metadata": {},
666 | "outputs": [
667 | {
668 | "name": "stdout",
669 | "output_type": "stream",
670 | "text": [
671 | "1 44449\n",
672 | "0 44449\n",
673 | "dtype: int64\n"
674 | ]
675 | }
676 | ],
677 | "source": [
678 | "y_resampled = pd.Series(y_resampled)\n",
679 | "print(y_resampled.value_counts())"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": 77,
685 | "metadata": {},
686 | "outputs": [
687 | {
688 | "data": {
689 | "text/plain": [
690 | "0.7068435754189945"
691 | ]
692 | },
693 | "execution_count": 77,
694 | "metadata": {},
695 | "output_type": "execute_result"
696 | }
697 | ],
698 | "source": [
699 | "base_model_XG = GradientBoostingClassifier()\n",
700 | "scores = cross_val_score(base_model_XG, X_resampled, y_resampled, cv=5)\n",
701 | "scores.mean()"
702 | ]
703 | },
704 | {
705 | "cell_type": "markdown",
706 | "metadata": {},
707 | "source": [
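708 | "Accuracy with SMOTE is high, but this likely reflects overfitting: the synthetic samples were generated before cross-validation, so near-copies of training points can end up in the validation folds."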
709 | ]
710 | },
711 | {
712 | "cell_type": "markdown",
713 | "metadata": {},
714 | "source": [
715 | "#### Next, let's look at a model for predicting death in atherosclerosis patients"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": 82,
721 | "metadata": {},
722 | "outputs": [],
723 | "source": [
724 | "# First, separate patients with an atherosclerosis diagnosis from those without one\n",
725 | "athero_pre = demographics[demographics['OLD_FLAG']==0]\n",
726 | "athero_pos = athero_pre[athero_pre['ATHERO_DIAGNOSIS_FLAG']== 1]\n",
727 | "athero_neg = athero_pre[athero_pre['ATHERO_DIAGNOSIS_FLAG']==0]\n",
728 | "\n",
729 | "# Clean data sets\n",
730 | "del athero_neg['CAUSE']\n",
731 | "del athero_pos['CAUSE']\n",
732 | "\n",
733 | "del athero_neg['ATHERO_DIAGNOSIS_FLAG']\n",
734 | "del athero_pos['ATHERO_DIAGNOSIS_FLAG']\n",
735 | "\n",
736 | "del athero_neg['OLD_FLAG']\n",
737 | "del athero_pos['OLD_FLAG']\n",
738 | "\n",
739 | "del athero_neg['OUTSIDE_DEATH_FLAG']\n",
740 | "del athero_pos['OUTSIDE_DEATH_FLAG']\n",
741 | "\n",
742 | "del athero_neg['SUBJECT_ID']\n",
743 | "del athero_pos['SUBJECT_ID']\n",
744 | "\n",
745 | "del athero_neg['DOB']\n",
746 | "del athero_pos['DOB']\n",
747 | "\n",
748 | "del athero_neg['DOD']\n",
749 | "del athero_pos['DOD']"
750 | ]
751 | },
752 | {
753 | "cell_type": "code",
754 | "execution_count": 83,
755 | "metadata": {},
756 | "outputs": [],
757 | "source": [
758 | "# Drop the admission date; it is not used as a feature\n",
759 | "del athero_pos['DOA']\n",
760 | "del athero_neg['DOA']\n",
761 | "\n",
762 | "# Drop the heart attack flag; it is diagnostic information\n",
763 | "del athero_neg['HEART_ATTACK_FLAG']\n",
764 | "del athero_pos['HEART_ATTACK_FLAG']\n",
765 | "\n",
766 | "del athero_pos['Unnamed: 0']"
767 | ]
768 | },
769 | {
770 | "cell_type": "code",
771 | "execution_count": 84,
772 | "metadata": {},
773 | "outputs": [
774 | {
775 | "data": {
776 | "text/plain": [
777 | "10203"
778 | ]
779 | },
780 | "execution_count": 84,
781 | "metadata": {},
782 | "output_type": "execute_result"
783 | }
784 | ],
785 | "source": [
786 | "len(athero_pos)"
787 | ]
788 | },
789 | {
790 | "cell_type": "code",
791 | "execution_count": 85,
792 | "metadata": {},
793 | "outputs": [
794 | {
795 | "data": {
796 | "text/html": [
797 | "\n",
798 | "\n",
811 | "
\n",
812 | " \n",
813 | " \n",
814 | " | \n",
815 | " GENDER | \n",
816 | " ADMIT_AGE | \n",
817 | " ETHNICITY | \n",
818 | " MARITAL_STATUS | \n",
819 | " LANGUAGE | \n",
820 | " RELIGION | \n",
821 | " INSURANCE | \n",
822 | " ADMISSION_LOCATION | \n",
823 | " #ADMISSIONS | \n",
824 | " DEATH_FLAG | \n",
825 | " HEART_DEATH_FLAG | \n",
826 | "
\n",
827 | " \n",
828 | " \n",
829 | " \n",
830 | " 9 | \n",
831 | " M | \n",
832 | " 69.641096 | \n",
833 | " WHITE | \n",
834 | " MARRIED | \n",
835 | " UKNOWN | \n",
836 | " CATHOLIC | \n",
837 | " Private | \n",
838 | " TRANSFER FROM HOSP/EXTRAM | \n",
839 | " 4 | \n",
840 | " 1 | \n",
841 | " 0 | \n",
842 | "
\n",
843 | " \n",
844 | " 12 | \n",
845 | " F | \n",
846 | " 69.005479 | \n",
847 | " WHITE | \n",
848 | " MARRIED | \n",
849 | " ENGL | \n",
850 | " PROTESTANT QUAKER | \n",
851 | " Medicare | \n",
852 | " EMERGENCY ROOM ADMIT | \n",
853 | " 2 | \n",
854 | " 1 | \n",
855 | " 0 | \n",
856 | "
\n",
857 | " \n",
858 | " 17 | \n",
859 | " M | \n",
860 | " 87.882192 | \n",
861 | " WHITE | \n",
862 | " MARRIED | \n",
863 | " UKNOWN | \n",
864 | " JEWISH | \n",
865 | " Medicare | \n",
866 | " EMERGENCY ROOM ADMIT | \n",
867 | " 2 | \n",
868 | " 1 | \n",
869 | " 0 | \n",
870 | "
\n",
871 | " \n",
872 | " 19 | \n",
873 | " F | \n",
874 | " 76.871233 | \n",
875 | " WHITE | \n",
876 | " MARRIED | \n",
877 | " PORT | \n",
878 | " CATHOLIC | \n",
879 | " Medicare | \n",
880 | " TRANSFER FROM HOSP/EXTRAM | \n",
881 | " 4 | \n",
882 | " 1 | \n",
883 | " 1 | \n",
884 | "
\n",
885 | " \n",
886 | " 22 | \n",
887 | " F | \n",
888 | " 85.726027 | \n",
889 | " BLACK/AFRICAN AMERICAN | \n",
890 | " WIDOWED | \n",
891 | " UKNOWN | \n",
892 | " CATHOLIC | \n",
893 | " Medicare | \n",
894 | " EMERGENCY ROOM ADMIT | \n",
895 | " 2 | \n",
896 | " 1 | \n",
897 | " 0 | \n",
898 | "
\n",
899 | " \n",
900 | "
\n",
901 | "
"
902 | ],
903 | "text/plain": [
904 | " GENDER ADMIT_AGE ETHNICITY MARITAL_STATUS LANGUAGE \\\n",
905 | "9 M 69.641096 WHITE MARRIED UKNOWN \n",
906 | "12 F 69.005479 WHITE MARRIED ENGL \n",
907 | "17 M 87.882192 WHITE MARRIED UKNOWN \n",
908 | "19 F 76.871233 WHITE MARRIED PORT \n",
909 | "22 F 85.726027 BLACK/AFRICAN AMERICAN WIDOWED UKNOWN \n",
910 | "\n",
911 | " RELIGION INSURANCE ADMISSION_LOCATION #ADMISSIONS \\\n",
912 | "9 CATHOLIC Private TRANSFER FROM HOSP/EXTRAM 4 \n",
913 | "12 PROTESTANT QUAKER Medicare EMERGENCY ROOM ADMIT 2 \n",
914 | "17 JEWISH Medicare EMERGENCY ROOM ADMIT 2 \n",
915 | "19 CATHOLIC Medicare TRANSFER FROM HOSP/EXTRAM 4 \n",
916 | "22 CATHOLIC Medicare EMERGENCY ROOM ADMIT 2 \n",
917 | "\n",
918 | " DEATH_FLAG HEART_DEATH_FLAG \n",
919 | "9 1 0 \n",
920 | "12 1 0 \n",
921 | "17 1 0 \n",
922 | "19 1 1 \n",
923 | "22 1 0 "
924 | ]
925 | },
926 | "execution_count": 85,
927 | "metadata": {},
928 | "output_type": "execute_result"
929 | }
930 | ],
931 | "source": [
932 | "athero_pos.head()"
933 | ]
934 | },
935 | {
936 | "cell_type": "code",
937 | "execution_count": 65,
938 | "metadata": {},
939 | "outputs": [],
940 | "source": [
941 | "# Create Outcome data sets\n",
942 | "athero_heartdeath = pd.Series(athero_pos['HEART_DEATH_FLAG'])\n",
943 | "athero_death = pd.Series(athero_pos['DEATH_FLAG'])"
944 | ]
945 | },
946 | {
947 | "cell_type": "code",
948 | "execution_count": 86,
949 | "metadata": {},
950 | "outputs": [],
951 | "source": [
952 | "del athero_pos['HEART_DEATH_FLAG']\n",
953 | "del athero_pos['DEATH_FLAG']"
954 | ]
955 | },
956 | {
957 | "cell_type": "code",
958 | "execution_count": 87,
959 | "metadata": {},
960 | "outputs": [],
961 | "source": [
962 | "# One-hot encode the categorical columns\n",
963 | "athero_pos = pd.get_dummies(athero_pos, columns=['GENDER','ETHNICITY','MARITAL_STATUS', 'LANGUAGE', 'RELIGION', 'INSURANCE', 'ADMISSION_LOCATION'])"
964 | ]
965 | },
966 | {
967 | "cell_type": "code",
968 | "execution_count": 88,
969 | "metadata": {},
970 | "outputs": [
971 | {
972 | "name": "stdout",
973 | "output_type": "stream",
974 | "text": [
975 | "0 10001\n",
976 | "1 202\n",
977 | "Name: HEART_DEATH_FLAG, dtype: int64\n",
978 | "0 6576\n",
979 | "1 3627\n",
980 | "Name: DEATH_FLAG, dtype: int64\n"
981 | ]
982 | }
983 | ],
984 | "source": [
985 | "# Check outcome numbers\n",
986 | "print(athero_heartdeath.value_counts())\n",
987 | "print(athero_death.value_counts())"
988 | ]
989 | },
990 | {
991 | "cell_type": "code",
992 | "execution_count": 89,
993 | "metadata": {},
994 | "outputs": [
995 | {
996 | "data": {
997 | "text/plain": [
998 | "0.6961760077878063"
999 | ]
1000 | },
1001 | "execution_count": 89,
1002 | "metadata": {},
1003 | "output_type": "execute_result"
1004 | }
1005 | ],
1006 | "source": [
1007 | "base_model_XG = GradientBoostingClassifier()\n",
1008 | "scores = cross_val_score(base_model_XG, athero_pos, athero_death, cv=5)\n",
1009 | "scores.mean()"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "metadata": {},
1015 | "source": [
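1016 | "#### So for patients diagnosed with atherosclerosis, we can predict death with a mean 5-fold CV accuracy of about 0.696 (always predicting DEATH_FLAG = 0 would score about 0.645, since 6576 of the 10203 patients fall in that class)"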
1017 | ]
1018 | },
1019 | {
1020 | "cell_type": "code",
1021 | "execution_count": 90,
1022 | "metadata": {},
1023 | "outputs": [
1024 | {
1025 | "data": {
1026 | "text/html": [
1027 | "\n",
1028 | "\n",
1041 | "
\n",
1042 | " \n",
1043 | " \n",
1044 | " | \n",
1045 | " ADMIT_AGE | \n",
1046 | " #ADMISSIONS | \n",
1047 | " GENDER_F | \n",
1048 | " GENDER_M | \n",
1049 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE | \n",
1050 | " ETHNICITY_ASIAN | \n",
1051 | " ETHNICITY_ASIAN - ASIAN INDIAN | \n",
1052 | " ETHNICITY_ASIAN - CAMBODIAN | \n",
1053 | " ETHNICITY_ASIAN - CHINESE | \n",
1054 | " ETHNICITY_ASIAN - FILIPINO | \n",
1055 | " ... | \n",
1056 | " INSURANCE_Medicaid | \n",
1057 | " INSURANCE_Medicare | \n",
1058 | " INSURANCE_Private | \n",
1059 | " INSURANCE_Self Pay | \n",
1060 | " ADMISSION_LOCATION_CLINIC REFERRAL/PREMATURE | \n",
1061 | " ADMISSION_LOCATION_EMERGENCY ROOM ADMIT | \n",
1062 | " ADMISSION_LOCATION_PHYS REFERRAL/NORMAL DELI | \n",
1063 | " ADMISSION_LOCATION_TRANSFER FROM HOSP/EXTRAM | \n",
1064 | " ADMISSION_LOCATION_TRANSFER FROM OTHER HEALT | \n",
1065 | " ADMISSION_LOCATION_TRANSFER FROM SKILLED NUR | \n",
1066 | "
\n",
1067 | " \n",
1068 | " \n",
1069 | " \n",
1070 | " 9 | \n",
1071 | " 69.641096 | \n",
1072 | " 4 | \n",
1073 | " 0 | \n",
1074 | " 1 | \n",
1075 | " 0 | \n",
1076 | " 0 | \n",
1077 | " 0 | \n",
1078 | " 0 | \n",
1079 | " 0 | \n",
1080 | " 0 | \n",
1081 | " ... | \n",
1082 | " 0 | \n",
1083 | " 0 | \n",
1084 | " 1 | \n",
1085 | " 0 | \n",
1086 | " 0 | \n",
1087 | " 0 | \n",
1088 | " 0 | \n",
1089 | " 1 | \n",
1090 | " 0 | \n",
1091 | " 0 | \n",
1092 | "
\n",
1093 | " \n",
1094 | " 12 | \n",
1095 | " 69.005479 | \n",
1096 | " 2 | \n",
1097 | " 1 | \n",
1098 | " 0 | \n",
1099 | " 0 | \n",
1100 | " 0 | \n",
1101 | " 0 | \n",
1102 | " 0 | \n",
1103 | " 0 | \n",
1104 | " 0 | \n",
1105 | " ... | \n",
1106 | " 0 | \n",
1107 | " 1 | \n",
1108 | " 0 | \n",
1109 | " 0 | \n",
1110 | " 0 | \n",
1111 | " 1 | \n",
1112 | " 0 | \n",
1113 | " 0 | \n",
1114 | " 0 | \n",
1115 | " 0 | \n",
1116 | "
\n",
1117 | " \n",
1118 | " 17 | \n",
1119 | " 87.882192 | \n",
1120 | " 2 | \n",
1121 | " 0 | \n",
1122 | " 1 | \n",
1123 | " 0 | \n",
1124 | " 0 | \n",
1125 | " 0 | \n",
1126 | " 0 | \n",
1127 | " 0 | \n",
1128 | " 0 | \n",
1129 | " ... | \n",
1130 | " 0 | \n",
1131 | " 1 | \n",
1132 | " 0 | \n",
1133 | " 0 | \n",
1134 | " 0 | \n",
1135 | " 1 | \n",
1136 | " 0 | \n",
1137 | " 0 | \n",
1138 | " 0 | \n",
1139 | " 0 | \n",
1140 | "
\n",
1141 | " \n",
1142 | " 19 | \n",
1143 | " 76.871233 | \n",
1144 | " 4 | \n",
1145 | " 1 | \n",
1146 | " 0 | \n",
1147 | " 0 | \n",
1148 | " 0 | \n",
1149 | " 0 | \n",
1150 | " 0 | \n",
1151 | " 0 | \n",
1152 | " 0 | \n",
1153 | " ... | \n",
1154 | " 0 | \n",
1155 | " 1 | \n",
1156 | " 0 | \n",
1157 | " 0 | \n",
1158 | " 0 | \n",
1159 | " 0 | \n",
1160 | " 0 | \n",
1161 | " 1 | \n",
1162 | " 0 | \n",
1163 | " 0 | \n",
1164 | "
\n",
1165 | " \n",
1166 | " 22 | \n",
1167 | " 85.726027 | \n",
1168 | " 2 | \n",
1169 | " 1 | \n",
1170 | " 0 | \n",
1171 | " 0 | \n",
1172 | " 0 | \n",
1173 | " 0 | \n",
1174 | " 0 | \n",
1175 | " 0 | \n",
1176 | " 0 | \n",
1177 | " ... | \n",
1178 | " 0 | \n",
1179 | " 1 | \n",
1180 | " 0 | \n",
1181 | " 0 | \n",
1182 | " 0 | \n",
1183 | " 1 | \n",
1184 | " 0 | \n",
1185 | " 0 | \n",
1186 | " 0 | \n",
1187 | " 0 | \n",
1188 | "
\n",
1189 | " \n",
1190 | "
\n",
1191 | "
5 rows × 121 columns
\n",
1192 | "
"
1193 | ],
1194 | "text/plain": [
1195 | " ADMIT_AGE #ADMISSIONS GENDER_F GENDER_M \\\n",
1196 | "9 69.641096 4 0 1 \n",
1197 | "12 69.005479 2 1 0 \n",
1198 | "17 87.882192 2 0 1 \n",
1199 | "19 76.871233 4 1 0 \n",
1200 | "22 85.726027 2 1 0 \n",
1201 | "\n",
1202 | " ETHNICITY_AMERICAN INDIAN/ALASKA NATIVE ETHNICITY_ASIAN \\\n",
1203 | "9 0 0 \n",
1204 | "12 0 0 \n",
1205 | "17 0 0 \n",
1206 | "19 0 0 \n",
1207 | "22 0 0 \n",
1208 | "\n",
1209 | " ETHNICITY_ASIAN - ASIAN INDIAN ETHNICITY_ASIAN - CAMBODIAN \\\n",
1210 | "9 0 0 \n",
1211 | "12 0 0 \n",
1212 | "17 0 0 \n",
1213 | "19 0 0 \n",
1214 | "22 0 0 \n",
1215 | "\n",
1216 | " ETHNICITY_ASIAN - CHINESE ETHNICITY_ASIAN - FILIPINO \\\n",
1217 | "9 0 0 \n",
1218 | "12 0 0 \n",
1219 | "17 0 0 \n",
1220 | "19 0 0 \n",
1221 | "22 0 0 \n",
1222 | "\n",
1223 | " ... INSURANCE_Medicaid \\\n",
1224 | "9 ... 0 \n",
1225 | "12 ... 0 \n",
1226 | "17 ... 0 \n",
1227 | "19 ... 0 \n",
1228 | "22 ... 0 \n",
1229 | "\n",
1230 | " INSURANCE_Medicare INSURANCE_Private INSURANCE_Self Pay \\\n",
1231 | "9 0 1 0 \n",
1232 | "12 1 0 0 \n",
1233 | "17 1 0 0 \n",
1234 | "19 1 0 0 \n",
1235 | "22 1 0 0 \n",
1236 | "\n",
1237 | " ADMISSION_LOCATION_CLINIC REFERRAL/PREMATURE \\\n",
1238 | "9 0 \n",
1239 | "12 0 \n",
1240 | "17 0 \n",
1241 | "19 0 \n",
1242 | "22 0 \n",
1243 | "\n",
1244 | " ADMISSION_LOCATION_EMERGENCY ROOM ADMIT \\\n",
1245 | "9 0 \n",
1246 | "12 1 \n",
1247 | "17 1 \n",
1248 | "19 0 \n",
1249 | "22 1 \n",
1250 | "\n",
1251 | " ADMISSION_LOCATION_PHYS REFERRAL/NORMAL DELI \\\n",
1252 | "9 0 \n",
1253 | "12 0 \n",
1254 | "17 0 \n",
1255 | "19 0 \n",
1256 | "22 0 \n",
1257 | "\n",
1258 | " ADMISSION_LOCATION_TRANSFER FROM HOSP/EXTRAM \\\n",
1259 | "9 1 \n",
1260 | "12 0 \n",
1261 | "17 0 \n",
1262 | "19 1 \n",
1263 | "22 0 \n",
1264 | "\n",
1265 | " ADMISSION_LOCATION_TRANSFER FROM OTHER HEALT \\\n",
1266 | "9 0 \n",
1267 | "12 0 \n",
1268 | "17 0 \n",
1269 | "19 0 \n",
1270 | "22 0 \n",
1271 | "\n",
1272 | " ADMISSION_LOCATION_TRANSFER FROM SKILLED NUR \n",
1273 | "9 0 \n",
1274 | "12 0 \n",
1275 | "17 0 \n",
1276 | "19 0 \n",
1277 | "22 0 \n",
1278 | "\n",
1279 | "[5 rows x 121 columns]"
1280 | ]
1281 | },
1282 | "execution_count": 90,
1283 | "metadata": {},
1284 | "output_type": "execute_result"
1285 | }
1286 | ],
1287 | "source": [
1288 | "athero_pos.head()"
1289 | ]
1290 | },
1291 | {
1292 | "cell_type": "code",
1293 | "execution_count": null,
1294 | "metadata": {},
1295 | "outputs": [],
1296 | "source": []
1297 | }
1298 | ],
1299 | "metadata": {
1300 | "kernelspec": {
1301 | "display_name": "Python 2",
1302 | "language": "python",
1303 | "name": "python2"
1304 | },
1305 | "language_info": {
1306 | "codemirror_mode": {
1307 | "name": "ipython",
1308 | "version": 2
1309 | },
1310 | "file_extension": ".py",
1311 | "mimetype": "text/x-python",
1312 | "name": "python",
1313 | "nbconvert_exporter": "python",
1314 | "pygments_lexer": "ipython2",
1315 | "version": "2.7.14"
1316 | }
1317 | },
1318 | "nbformat": 4,
1319 | "nbformat_minor": 2
1320 | }
1321 |
--------------------------------------------------------------------------------
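The notebook above balances the classes with `RandomUnderSampler(return_indices=True)` before cross-validation. As a rough illustration of what random undersampling does, here is a minimal plain-Python sketch. It is not imbalanced-learn's implementation: `random_undersample` is a hypothetical helper written for this example, and the toy labels are made up. Note also that, to the best of my knowledge, newer imbalanced-learn releases name the method `fit_resample` instead of `fit_sample` and expose the kept indices through the `sample_indices_` attribute instead of `return_indices=True`; check the documentation for the installed version.

```python
from __future__ import print_function

import random
from collections import Counter, defaultdict


def random_undersample(labels, seed=None):
    """Return sorted row indices that keep an equal number of rows per class.

    Hypothetical helper for illustration only; it mimics the idea behind
    random undersampling, not any library's exact behavior.
    """
    rng = random.Random(seed)

    # Group the row indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Every class is cut down to the size of the smallest class.
    n_keep = min(len(idxs) for idxs in by_class.values())

    kept = []
    for label in sorted(by_class):
        # Sample without replacement so no row is kept twice.
        kept.extend(rng.sample(by_class[label], n_keep))
    return sorted(kept)


# Toy labels with a class imbalance (made-up data, not from MIMIC-III).
labels = [0] * 30 + [1] * 5
idx = random_undersample(labels, seed=42)
y_resampled = [labels[i] for i in idx]
print(Counter(y_resampled))
```

The returned indices can be used to select the matching feature rows, which is the role `idx_resampled` plays in the notebook.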
/weights.best.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rjgiedt/Cardiovascular_Death_Prediction/90f74112b30af528d1351bfd8daa0cc1bf6ae46a/weights.best.hdf5
--------------------------------------------------------------------------------
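To put the cross-validated accuracies reported in the notebook into context, it can help to compare them with a majority-class baseline, that is, the accuracy of always predicting the most common label. The sketch below is not part of the original notebook. It uses only the class counts and the mean CV score printed there, and `majority_baseline` is a hypothetical helper name chosen for illustration.

```python
from __future__ import division, print_function


def majority_baseline(class_counts):
    """Return the accuracy of always predicting the most frequent class."""
    total = sum(class_counts.values())
    if total == 0:
        raise ValueError("class_counts must contain at least one sample")
    return max(class_counts.values()) / total


# Class counts copied from the notebook's printed value_counts() outputs.
athero_death_counts = {0: 6576, 1: 3627}          # DEATH_FLAG, atherosclerosis patients
athero_heartdeath_counts = {0: 10001, 1: 202}     # HEART_DEATH_FLAG, same patients
undersampled_death_counts = {0: 14320, 1: 14320}  # DEATH_FLAG after undersampling

# Mean 5-fold CV accuracy reported in the notebook for the atherosclerosis model.
cv_accuracy_athero_death = 0.6961760077878063

baseline = majority_baseline(athero_death_counts)
lift = cv_accuracy_athero_death - baseline
print("majority-class baseline: {:.4f}".format(baseline))
print("lift over baseline:      {:.4f}".format(lift))
```

On the undersampled data the baseline is 0.5 by construction, while for `HEART_DEATH_FLAG` it is above 0.98, so plain accuracy says little about a model for that outcome unless the classes are rebalanced or a different metric is used.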