├── README.md
├── .gitignore
├── mess-the-data.ipynb
├── sklearn-pandas-catboost.ipynb
└── sklearn-pandas.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # catboost-with-pipelines
2 | Using a CatBoost model with scikit-learn pipelines.
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/mess-the-data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "## Read data"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "data": {
27 | "text/html": [
28 | "
\n",
29 | "\n",
42 | "
\n",
43 | " \n",
44 | " \n",
45 | " | \n",
46 | " age | \n",
47 | " gender | \n",
48 | " height | \n",
49 | " weight | \n",
50 | " ap_hi | \n",
51 | " ap_lo | \n",
52 | " cholesterol | \n",
53 | " gluc | \n",
54 | " smoke | \n",
55 | " alco | \n",
56 | " active | \n",
57 | " cardio | \n",
58 | "
\n",
59 | " \n",
60 | " \n",
61 | " \n",
62 | " | count | \n",
63 | " 70000.000000 | \n",
64 | " 70000.000000 | \n",
65 | " 70000.000000 | \n",
66 | " 70000.000000 | \n",
67 | " 70000.000000 | \n",
68 | " 70000.000000 | \n",
69 | " 70000.000000 | \n",
70 | " 70000.000000 | \n",
71 | " 70000.000000 | \n",
72 | " 70000.000000 | \n",
73 | " 70000.000000 | \n",
74 | " 70000.000000 | \n",
75 | "
\n",
76 | " \n",
77 | " | mean | \n",
78 | " 19468.865814 | \n",
79 | " 1.349571 | \n",
80 | " 164.359229 | \n",
81 | " 74.205690 | \n",
82 | " 128.817286 | \n",
83 | " 96.630414 | \n",
84 | " 1.366871 | \n",
85 | " 1.226457 | \n",
86 | " 0.088129 | \n",
87 | " 0.053771 | \n",
88 | " 0.803729 | \n",
89 | " 0.499700 | \n",
90 | "
\n",
91 | " \n",
92 | " | std | \n",
93 | " 2467.251667 | \n",
94 | " 0.476838 | \n",
95 | " 8.210126 | \n",
96 | " 14.395757 | \n",
97 | " 154.011419 | \n",
98 | " 188.472530 | \n",
99 | " 0.680250 | \n",
100 | " 0.572270 | \n",
101 | " 0.283484 | \n",
102 | " 0.225568 | \n",
103 | " 0.397179 | \n",
104 | " 0.500003 | \n",
105 | "
\n",
106 | " \n",
107 | " | min | \n",
108 | " 10798.000000 | \n",
109 | " 1.000000 | \n",
110 | " 55.000000 | \n",
111 | " 10.000000 | \n",
112 | " -150.000000 | \n",
113 | " -70.000000 | \n",
114 | " 1.000000 | \n",
115 | " 1.000000 | \n",
116 | " 0.000000 | \n",
117 | " 0.000000 | \n",
118 | " 0.000000 | \n",
119 | " 0.000000 | \n",
120 | "
\n",
121 | " \n",
122 | " | 25% | \n",
123 | " 17664.000000 | \n",
124 | " 1.000000 | \n",
125 | " 159.000000 | \n",
126 | " 65.000000 | \n",
127 | " 120.000000 | \n",
128 | " 80.000000 | \n",
129 | " 1.000000 | \n",
130 | " 1.000000 | \n",
131 | " 0.000000 | \n",
132 | " 0.000000 | \n",
133 | " 1.000000 | \n",
134 | " 0.000000 | \n",
135 | "
\n",
136 | " \n",
137 | " | 50% | \n",
138 | " 19703.000000 | \n",
139 | " 1.000000 | \n",
140 | " 165.000000 | \n",
141 | " 72.000000 | \n",
142 | " 120.000000 | \n",
143 | " 80.000000 | \n",
144 | " 1.000000 | \n",
145 | " 1.000000 | \n",
146 | " 0.000000 | \n",
147 | " 0.000000 | \n",
148 | " 1.000000 | \n",
149 | " 0.000000 | \n",
150 | "
\n",
151 | " \n",
152 | " | 75% | \n",
153 | " 21327.000000 | \n",
154 | " 2.000000 | \n",
155 | " 170.000000 | \n",
156 | " 82.000000 | \n",
157 | " 140.000000 | \n",
158 | " 90.000000 | \n",
159 | " 2.000000 | \n",
160 | " 1.000000 | \n",
161 | " 0.000000 | \n",
162 | " 0.000000 | \n",
163 | " 1.000000 | \n",
164 | " 1.000000 | \n",
165 | "
\n",
166 | " \n",
167 | " | max | \n",
168 | " 23713.000000 | \n",
169 | " 2.000000 | \n",
170 | " 250.000000 | \n",
171 | " 200.000000 | \n",
172 | " 16020.000000 | \n",
173 | " 11000.000000 | \n",
174 | " 3.000000 | \n",
175 | " 3.000000 | \n",
176 | " 1.000000 | \n",
177 | " 1.000000 | \n",
178 | " 1.000000 | \n",
179 | " 1.000000 | \n",
180 | "
\n",
181 | " \n",
182 | "
\n",
183 | "
"
184 | ],
185 | "text/plain": [
186 | " age gender height weight ap_hi \\\n",
187 | "count 70000.000000 70000.000000 70000.000000 70000.000000 70000.000000 \n",
188 | "mean 19468.865814 1.349571 164.359229 74.205690 128.817286 \n",
189 | "std 2467.251667 0.476838 8.210126 14.395757 154.011419 \n",
190 | "min 10798.000000 1.000000 55.000000 10.000000 -150.000000 \n",
191 | "25% 17664.000000 1.000000 159.000000 65.000000 120.000000 \n",
192 | "50% 19703.000000 1.000000 165.000000 72.000000 120.000000 \n",
193 | "75% 21327.000000 2.000000 170.000000 82.000000 140.000000 \n",
194 | "max 23713.000000 2.000000 250.000000 200.000000 16020.000000 \n",
195 | "\n",
196 | " ap_lo cholesterol gluc smoke alco \\\n",
197 | "count 70000.000000 70000.000000 70000.000000 70000.000000 70000.000000 \n",
198 | "mean 96.630414 1.366871 1.226457 0.088129 0.053771 \n",
199 | "std 188.472530 0.680250 0.572270 0.283484 0.225568 \n",
200 | "min -70.000000 1.000000 1.000000 0.000000 0.000000 \n",
201 | "25% 80.000000 1.000000 1.000000 0.000000 0.000000 \n",
202 | "50% 80.000000 1.000000 1.000000 0.000000 0.000000 \n",
203 | "75% 90.000000 2.000000 1.000000 0.000000 0.000000 \n",
204 | "max 11000.000000 3.000000 3.000000 1.000000 1.000000 \n",
205 | "\n",
206 | " active cardio \n",
207 | "count 70000.000000 70000.000000 \n",
208 | "mean 0.803729 0.499700 \n",
209 | "std 0.397179 0.500003 \n",
210 | "min 0.000000 0.000000 \n",
211 | "25% 1.000000 0.000000 \n",
212 | "50% 1.000000 0.000000 \n",
213 | "75% 1.000000 1.000000 \n",
214 | "max 1.000000 1.000000 "
215 | ]
216 | },
217 | "metadata": {},
218 | "output_type": "display_data"
219 | },
220 | {
221 | "data": {
222 | "text/html": [
223 | "\n",
224 | "\n",
237 | "
\n",
238 | " \n",
239 | " \n",
240 | " | \n",
241 | " age | \n",
242 | " gender | \n",
243 | " height | \n",
244 | " weight | \n",
245 | " ap_hi | \n",
246 | " ap_lo | \n",
247 | " cholesterol | \n",
248 | " gluc | \n",
249 | " smoke | \n",
250 | " alco | \n",
251 | " active | \n",
252 | " cardio | \n",
253 | "
\n",
254 | " \n",
255 | " | id | \n",
256 | " | \n",
257 | " | \n",
258 | " | \n",
259 | " | \n",
260 | " | \n",
261 | " | \n",
262 | " | \n",
263 | " | \n",
264 | " | \n",
265 | " | \n",
266 | " | \n",
267 | " | \n",
268 | "
\n",
269 | " \n",
270 | " \n",
271 | " \n",
272 | " | 0 | \n",
273 | " 18393 | \n",
274 | " 2 | \n",
275 | " 168 | \n",
276 | " 62.0 | \n",
277 | " 110 | \n",
278 | " 80 | \n",
279 | " 1 | \n",
280 | " 1 | \n",
281 | " 0 | \n",
282 | " 0 | \n",
283 | " 1 | \n",
284 | " 0 | \n",
285 | "
\n",
286 | " \n",
287 | " | 1 | \n",
288 | " 20228 | \n",
289 | " 1 | \n",
290 | " 156 | \n",
291 | " 85.0 | \n",
292 | " 140 | \n",
293 | " 90 | \n",
294 | " 3 | \n",
295 | " 1 | \n",
296 | " 0 | \n",
297 | " 0 | \n",
298 | " 1 | \n",
299 | " 1 | \n",
300 | "
\n",
301 | " \n",
302 | " | 2 | \n",
303 | " 18857 | \n",
304 | " 1 | \n",
305 | " 165 | \n",
306 | " 64.0 | \n",
307 | " 130 | \n",
308 | " 70 | \n",
309 | " 3 | \n",
310 | " 1 | \n",
311 | " 0 | \n",
312 | " 0 | \n",
313 | " 0 | \n",
314 | " 1 | \n",
315 | "
\n",
316 | " \n",
317 | " | 3 | \n",
318 | " 17623 | \n",
319 | " 2 | \n",
320 | " 169 | \n",
321 | " 82.0 | \n",
322 | " 150 | \n",
323 | " 100 | \n",
324 | " 1 | \n",
325 | " 1 | \n",
326 | " 0 | \n",
327 | " 0 | \n",
328 | " 1 | \n",
329 | " 1 | \n",
330 | "
\n",
331 | " \n",
332 | " | 4 | \n",
333 | " 17474 | \n",
334 | " 1 | \n",
335 | " 156 | \n",
336 | " 56.0 | \n",
337 | " 100 | \n",
338 | " 60 | \n",
339 | " 1 | \n",
340 | " 1 | \n",
341 | " 0 | \n",
342 | " 0 | \n",
343 | " 0 | \n",
344 | " 0 | \n",
345 | "
\n",
346 | " \n",
347 | "
\n",
348 | "
"
349 | ],
350 | "text/plain": [
351 | " age gender height weight ap_hi ap_lo cholesterol gluc smoke \\\n",
352 | "id \n",
353 | "0 18393 2 168 62.0 110 80 1 1 0 \n",
354 | "1 20228 1 156 85.0 140 90 3 1 0 \n",
355 | "2 18857 1 165 64.0 130 70 3 1 0 \n",
356 | "3 17623 2 169 82.0 150 100 1 1 0 \n",
357 | "4 17474 1 156 56.0 100 60 1 1 0 \n",
358 | "\n",
359 | " alco active cardio \n",
360 | "id \n",
361 | "0 0 1 0 \n",
362 | "1 0 1 1 \n",
363 | "2 0 0 1 \n",
364 | "3 0 1 1 \n",
365 | "4 0 0 0 "
366 | ]
367 | },
368 | "metadata": {},
369 | "output_type": "display_data"
370 | }
371 | ],
372 | "source": [
373 | "np.random.seed(seed=42)\n",
374 | "df_data = pd.read_csv(\"./cardiovascular-disease-dataset/original/cardio_train.csv\", sep=';', index_col=\"id\")\n",
375 | "display(df_data.describe())\n",
376 | "display(df_data.head())"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {},
382 | "source": [
383 | "## Fill 5% random values with NaNs"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 3,
389 | "metadata": {},
390 | "outputs": [],
391 | "source": [
392 | "to_na_indices = np.random.randint(low=0, high=df_data.shape[0], size=int(0.05 * df_data.shape[0]))\n",
393 | "df_data.iloc[to_na_indices, df_data.columns.get_loc(\"height\")] = np.nan\n",
394 | "\n",
395 | "to_na_indices = np.random.randint(low=0, high=df_data.shape[0], size=int(0.05 * df_data.shape[0]))\n",
396 | "df_data.iloc[to_na_indices, df_data.columns.get_loc(\"weight\")] = np.nan\n",
397 | "\n",
398 | "to_na_indices = np.random.randint(low=0, high=df_data.shape[0], size=int(0.05 * df_data.shape[0]))\n",
399 | "df_data.iloc[to_na_indices, df_data.columns.get_loc(\"cholesterol\")] = np.nan"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "## Decode categorical values with text"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 4,
412 | "metadata": {},
413 | "outputs": [
414 | {
415 | "data": {
416 | "text/html": [
417 | "\n",
418 | "\n",
431 | "
\n",
432 | " \n",
433 | " \n",
434 | " | \n",
435 | " age | \n",
436 | " gender | \n",
437 | " height | \n",
438 | " weight | \n",
439 | " ap_hi | \n",
440 | " ap_lo | \n",
441 | " cholesterol | \n",
442 | " gluc | \n",
443 | " smoke | \n",
444 | " alco | \n",
445 | " active | \n",
446 | " cardio | \n",
447 | "
\n",
448 | " \n",
449 | " \n",
450 | " \n",
451 | " | count | \n",
452 | " 70000.000000 | \n",
453 | " 70000 | \n",
454 | " 66591.000000 | \n",
455 | " 66578.000000 | \n",
456 | " 70000.000000 | \n",
457 | " 70000.000000 | \n",
458 | " 66589 | \n",
459 | " 70000 | \n",
460 | " 70000.000000 | \n",
461 | " 70000.000000 | \n",
462 | " 70000.000000 | \n",
463 | " 70000.000000 | \n",
464 | "
\n",
465 | " \n",
466 | " | unique | \n",
467 | " NaN | \n",
468 | " 2 | \n",
469 | " NaN | \n",
470 | " NaN | \n",
471 | " NaN | \n",
472 | " NaN | \n",
473 | " 3 | \n",
474 | " 3 | \n",
475 | " NaN | \n",
476 | " NaN | \n",
477 | " NaN | \n",
478 | " NaN | \n",
479 | "
\n",
480 | " \n",
481 | " | top | \n",
482 | " NaN | \n",
483 | " women | \n",
484 | " NaN | \n",
485 | " NaN | \n",
486 | " NaN | \n",
487 | " NaN | \n",
488 | " normal | \n",
489 | " normal | \n",
490 | " NaN | \n",
491 | " NaN | \n",
492 | " NaN | \n",
493 | " NaN | \n",
494 | "
\n",
495 | " \n",
496 | " | freq | \n",
497 | " NaN | \n",
498 | " 45530 | \n",
499 | " NaN | \n",
500 | " NaN | \n",
501 | " NaN | \n",
502 | " NaN | \n",
503 | " 49789 | \n",
504 | " 59479 | \n",
505 | " NaN | \n",
506 | " NaN | \n",
507 | " NaN | \n",
508 | " NaN | \n",
509 | "
\n",
510 | " \n",
511 | " | mean | \n",
512 | " 19468.865814 | \n",
513 | " NaN | \n",
514 | " 164.361205 | \n",
515 | " 74.210467 | \n",
516 | " 128.817286 | \n",
517 | " 96.630414 | \n",
518 | " NaN | \n",
519 | " NaN | \n",
520 | " 0.088129 | \n",
521 | " 0.053771 | \n",
522 | " 0.803729 | \n",
523 | " 0.499700 | \n",
524 | "
\n",
525 | " \n",
526 | " | std | \n",
527 | " 2467.251667 | \n",
528 | " NaN | \n",
529 | " 8.226411 | \n",
530 | " 14.397678 | \n",
531 | " 154.011419 | \n",
532 | " 188.472530 | \n",
533 | " NaN | \n",
534 | " NaN | \n",
535 | " 0.283484 | \n",
536 | " 0.225568 | \n",
537 | " 0.397179 | \n",
538 | " 0.500003 | \n",
539 | "
\n",
540 | " \n",
541 | " | min | \n",
542 | " 10798.000000 | \n",
543 | " NaN | \n",
544 | " 55.000000 | \n",
545 | " 10.000000 | \n",
546 | " -150.000000 | \n",
547 | " -70.000000 | \n",
548 | " NaN | \n",
549 | " NaN | \n",
550 | " 0.000000 | \n",
551 | " 0.000000 | \n",
552 | " 0.000000 | \n",
553 | " 0.000000 | \n",
554 | "
\n",
555 | " \n",
556 | " | 25% | \n",
557 | " 17664.000000 | \n",
558 | " NaN | \n",
559 | " 159.000000 | \n",
560 | " 65.000000 | \n",
561 | " 120.000000 | \n",
562 | " 80.000000 | \n",
563 | " NaN | \n",
564 | " NaN | \n",
565 | " 0.000000 | \n",
566 | " 0.000000 | \n",
567 | " 1.000000 | \n",
568 | " 0.000000 | \n",
569 | "
\n",
570 | " \n",
571 | " | 50% | \n",
572 | " 19703.000000 | \n",
573 | " NaN | \n",
574 | " 165.000000 | \n",
575 | " 72.000000 | \n",
576 | " 120.000000 | \n",
577 | " 80.000000 | \n",
578 | " NaN | \n",
579 | " NaN | \n",
580 | " 0.000000 | \n",
581 | " 0.000000 | \n",
582 | " 1.000000 | \n",
583 | " 0.000000 | \n",
584 | "
\n",
585 | " \n",
586 | " | 75% | \n",
587 | " 21327.000000 | \n",
588 | " NaN | \n",
589 | " 170.000000 | \n",
590 | " 82.000000 | \n",
591 | " 140.000000 | \n",
592 | " 90.000000 | \n",
593 | " NaN | \n",
594 | " NaN | \n",
595 | " 0.000000 | \n",
596 | " 0.000000 | \n",
597 | " 1.000000 | \n",
598 | " 1.000000 | \n",
599 | "
\n",
600 | " \n",
601 | " | max | \n",
602 | " 23713.000000 | \n",
603 | " NaN | \n",
604 | " 250.000000 | \n",
605 | " 200.000000 | \n",
606 | " 16020.000000 | \n",
607 | " 11000.000000 | \n",
608 | " NaN | \n",
609 | " NaN | \n",
610 | " 1.000000 | \n",
611 | " 1.000000 | \n",
612 | " 1.000000 | \n",
613 | " 1.000000 | \n",
614 | "
\n",
615 | " \n",
616 | "
\n",
617 | "
"
618 | ],
619 | "text/plain": [
620 | " age gender height weight ap_hi \\\n",
621 | "count 70000.000000 70000 66591.000000 66578.000000 70000.000000 \n",
622 | "unique NaN 2 NaN NaN NaN \n",
623 | "top NaN women NaN NaN NaN \n",
624 | "freq NaN 45530 NaN NaN NaN \n",
625 | "mean 19468.865814 NaN 164.361205 74.210467 128.817286 \n",
626 | "std 2467.251667 NaN 8.226411 14.397678 154.011419 \n",
627 | "min 10798.000000 NaN 55.000000 10.000000 -150.000000 \n",
628 | "25% 17664.000000 NaN 159.000000 65.000000 120.000000 \n",
629 | "50% 19703.000000 NaN 165.000000 72.000000 120.000000 \n",
630 | "75% 21327.000000 NaN 170.000000 82.000000 140.000000 \n",
631 | "max 23713.000000 NaN 250.000000 200.000000 16020.000000 \n",
632 | "\n",
633 | " ap_lo cholesterol gluc smoke alco \\\n",
634 | "count 70000.000000 66589 70000 70000.000000 70000.000000 \n",
635 | "unique NaN 3 3 NaN NaN \n",
636 | "top NaN normal normal NaN NaN \n",
637 | "freq NaN 49789 59479 NaN NaN \n",
638 | "mean 96.630414 NaN NaN 0.088129 0.053771 \n",
639 | "std 188.472530 NaN NaN 0.283484 0.225568 \n",
640 | "min -70.000000 NaN NaN 0.000000 0.000000 \n",
641 | "25% 80.000000 NaN NaN 0.000000 0.000000 \n",
642 | "50% 80.000000 NaN NaN 0.000000 0.000000 \n",
643 | "75% 90.000000 NaN NaN 0.000000 0.000000 \n",
644 | "max 11000.000000 NaN NaN 1.000000 1.000000 \n",
645 | "\n",
646 | " active cardio \n",
647 | "count 70000.000000 70000.000000 \n",
648 | "unique NaN NaN \n",
649 | "top NaN NaN \n",
650 | "freq NaN NaN \n",
651 | "mean 0.803729 0.499700 \n",
652 | "std 0.397179 0.500003 \n",
653 | "min 0.000000 0.000000 \n",
654 | "25% 1.000000 0.000000 \n",
655 | "50% 1.000000 0.000000 \n",
656 | "75% 1.000000 1.000000 \n",
657 | "max 1.000000 1.000000 "
658 | ]
659 | },
660 | "metadata": {},
661 | "output_type": "display_data"
662 | },
663 | {
664 | "data": {
665 | "text/html": [
666 | "\n",
667 | "\n",
680 | "
\n",
681 | " \n",
682 | " \n",
683 | " | \n",
684 | " age | \n",
685 | " gender | \n",
686 | " height | \n",
687 | " weight | \n",
688 | " ap_hi | \n",
689 | " ap_lo | \n",
690 | " cholesterol | \n",
691 | " gluc | \n",
692 | " smoke | \n",
693 | " alco | \n",
694 | " active | \n",
695 | " cardio | \n",
696 | "
\n",
697 | " \n",
698 | " | id | \n",
699 | " | \n",
700 | " | \n",
701 | " | \n",
702 | " | \n",
703 | " | \n",
704 | " | \n",
705 | " | \n",
706 | " | \n",
707 | " | \n",
708 | " | \n",
709 | " | \n",
710 | " | \n",
711 | "
\n",
712 | " \n",
713 | " \n",
714 | " \n",
715 | " | 0 | \n",
716 | " 18393 | \n",
717 | " men | \n",
718 | " 168.0 | \n",
719 | " 62.0 | \n",
720 | " 110 | \n",
721 | " 80 | \n",
722 | " normal | \n",
723 | " normal | \n",
724 | " 0 | \n",
725 | " 0 | \n",
726 | " 1 | \n",
727 | " 0 | \n",
728 | "
\n",
729 | " \n",
730 | " | 1 | \n",
731 | " 20228 | \n",
732 | " women | \n",
733 | " 156.0 | \n",
734 | " 85.0 | \n",
735 | " 140 | \n",
736 | " 90 | \n",
737 | " well_above_normal | \n",
738 | " normal | \n",
739 | " 0 | \n",
740 | " 0 | \n",
741 | " 1 | \n",
742 | " 1 | \n",
743 | "
\n",
744 | " \n",
745 | " | 2 | \n",
746 | " 18857 | \n",
747 | " women | \n",
748 | " 165.0 | \n",
749 | " 64.0 | \n",
750 | " 130 | \n",
751 | " 70 | \n",
752 | " NaN | \n",
753 | " normal | \n",
754 | " 0 | \n",
755 | " 0 | \n",
756 | " 0 | \n",
757 | " 1 | \n",
758 | "
\n",
759 | " \n",
760 | " | 3 | \n",
761 | " 17623 | \n",
762 | " men | \n",
763 | " 169.0 | \n",
764 | " 82.0 | \n",
765 | " 150 | \n",
766 | " 100 | \n",
767 | " normal | \n",
768 | " normal | \n",
769 | " 0 | \n",
770 | " 0 | \n",
771 | " 1 | \n",
772 | " 1 | \n",
773 | "
\n",
774 | " \n",
775 | " | 4 | \n",
776 | " 17474 | \n",
777 | " women | \n",
778 | " 156.0 | \n",
779 | " 56.0 | \n",
780 | " 100 | \n",
781 | " 60 | \n",
782 | " normal | \n",
783 | " normal | \n",
784 | " 0 | \n",
785 | " 0 | \n",
786 | " 0 | \n",
787 | " 0 | \n",
788 | "
\n",
789 | " \n",
790 | "
\n",
791 | "
"
792 | ],
793 | "text/plain": [
794 | " age gender height weight ap_hi ap_lo cholesterol gluc \\\n",
795 | "id \n",
796 | "0 18393 men 168.0 62.0 110 80 normal normal \n",
797 | "1 20228 women 156.0 85.0 140 90 well_above_normal normal \n",
798 | "2 18857 women 165.0 64.0 130 70 NaN normal \n",
799 | "3 17623 men 169.0 82.0 150 100 normal normal \n",
800 | "4 17474 women 156.0 56.0 100 60 normal normal \n",
801 | "\n",
802 | " smoke alco active cardio \n",
803 | "id \n",
804 | "0 0 0 1 0 \n",
805 | "1 0 0 1 1 \n",
806 | "2 0 0 0 1 \n",
807 | "3 0 0 1 1 \n",
808 | "4 0 0 0 0 "
809 | ]
810 | },
811 | "metadata": {},
812 | "output_type": "display_data"
813 | }
814 | ],
815 | "source": [
816 | "df_data[\"gender\"] = df_data[\"gender\"].replace({\n",
817 | " 1: \"women\",\n",
818 | " 2: \"men\"\n",
819 | "})\n",
820 | "\n",
821 | "df_data[\"cholesterol\"] = df_data[\"cholesterol\"].replace({\n",
822 | " 1: \"normal\",\n",
823 | " 2: \"above_normal\",\n",
824 | " 3: \"well_above_normal\"\n",
825 | "})\n",
826 | "\n",
827 | "df_data[\"gluc\"] = df_data[\"gluc\"].replace({\n",
828 | " 1: \"normal\",\n",
829 | " 2: \"above_normal\",\n",
830 | " 3: \"well_above_normal\"\n",
831 | "})\n",
832 | "\n",
833 | "display(df_data.describe(include=\"all\"))\n",
834 | "display(df_data.head())"
835 | ]
836 | },
837 | {
838 | "cell_type": "markdown",
839 | "metadata": {},
840 | "source": [
841 |     "## Write messy CSV to disk"
842 | ]
843 | },
844 | {
845 | "cell_type": "code",
846 | "execution_count": 5,
847 | "metadata": {},
848 | "outputs": [],
849 | "source": [
850 | "df_data.to_csv(\"./cardiovascular-disease-dataset/messy/cardio_train.csv\", sep=';', index_label=\"id\")"
851 | ]
852 | }
853 | ],
854 | "metadata": {
855 | "kernelspec": {
856 | "display_name": "Python 3",
857 | "language": "python",
858 | "name": "python3"
859 | },
860 | "language_info": {
861 | "codemirror_mode": {
862 | "name": "ipython",
863 | "version": 3
864 | },
865 | "file_extension": ".py",
866 | "mimetype": "text/x-python",
867 | "name": "python",
868 | "nbconvert_exporter": "python",
869 | "pygments_lexer": "ipython3",
870 | "version": "3.7.1"
871 | }
872 | },
873 | "nbformat": 4,
874 | "nbformat_minor": 2
875 | }
876 |
--------------------------------------------------------------------------------
/sklearn-pandas-catboost.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np\n",
11 | "from sklearn_pandas import DataFrameMapper, gen_features\n",
12 | "from sklearn.model_selection import train_test_split\n",
13 | "from sklearn.impute import SimpleImputer\n",
14 | "from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler\n",
15 | "from sklearn.base import BaseEstimator, TransformerMixin\n",
16 | "from sklearn.pipeline import Pipeline\n",
17 | "from sklearn.compose import ColumnTransformer\n",
18 | "from sklearn.feature_selection import SelectFromModel\n",
19 | "from catboost import CatBoostClassifier\n",
20 | "from sklearn.model_selection import GridSearchCV"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "## Read data\n",
28 |     "We are not reading the original data from Kaggle. \n",
29 | "The data we are reading is based on the data from Kaggle with small changes:\n",
30 | "* Random replacement of values with NaNs\n",
31 | "* Transformations of categorical (int) features to text categories"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "data": {
41 | "text/html": [
42 | "\n",
43 | "\n",
56 | "
\n",
57 | " \n",
58 | " \n",
59 | " | \n",
60 | " age | \n",
61 | " gender | \n",
62 | " height | \n",
63 | " weight | \n",
64 | " ap_hi | \n",
65 | " ap_lo | \n",
66 | " cholesterol | \n",
67 | " gluc | \n",
68 | " smoke | \n",
69 | " alco | \n",
70 | " active | \n",
71 | " cardio | \n",
72 | "
\n",
73 | " \n",
74 | " \n",
75 | " \n",
76 | " | count | \n",
77 | " 70000.000000 | \n",
78 | " 70000 | \n",
79 | " 66591.000000 | \n",
80 | " 66578.000000 | \n",
81 | " 70000.000000 | \n",
82 | " 70000.000000 | \n",
83 | " 66589 | \n",
84 | " 70000 | \n",
85 | " 70000.000000 | \n",
86 | " 70000.000000 | \n",
87 | " 70000.000000 | \n",
88 | " 70000.000000 | \n",
89 | "
\n",
90 | " \n",
91 | " | unique | \n",
92 | " NaN | \n",
93 | " 2 | \n",
94 | " NaN | \n",
95 | " NaN | \n",
96 | " NaN | \n",
97 | " NaN | \n",
98 | " 3 | \n",
99 | " 3 | \n",
100 | " NaN | \n",
101 | " NaN | \n",
102 | " NaN | \n",
103 | " NaN | \n",
104 | "
\n",
105 | " \n",
106 | " | top | \n",
107 | " NaN | \n",
108 | " women | \n",
109 | " NaN | \n",
110 | " NaN | \n",
111 | " NaN | \n",
112 | " NaN | \n",
113 | " normal | \n",
114 | " normal | \n",
115 | " NaN | \n",
116 | " NaN | \n",
117 | " NaN | \n",
118 | " NaN | \n",
119 | "
\n",
120 | " \n",
121 | " | freq | \n",
122 | " NaN | \n",
123 | " 45530 | \n",
124 | " NaN | \n",
125 | " NaN | \n",
126 | " NaN | \n",
127 | " NaN | \n",
128 | " 49789 | \n",
129 | " 59479 | \n",
130 | " NaN | \n",
131 | " NaN | \n",
132 | " NaN | \n",
133 | " NaN | \n",
134 | "
\n",
135 | " \n",
136 | " | mean | \n",
137 | " 19468.865814 | \n",
138 | " NaN | \n",
139 | " 164.361205 | \n",
140 | " 74.210467 | \n",
141 | " 128.817286 | \n",
142 | " 96.630414 | \n",
143 | " NaN | \n",
144 | " NaN | \n",
145 | " 0.088129 | \n",
146 | " 0.053771 | \n",
147 | " 0.803729 | \n",
148 | " 0.499700 | \n",
149 | "
\n",
150 | " \n",
151 | " | std | \n",
152 | " 2467.251667 | \n",
153 | " NaN | \n",
154 | " 8.226411 | \n",
155 | " 14.397678 | \n",
156 | " 154.011419 | \n",
157 | " 188.472530 | \n",
158 | " NaN | \n",
159 | " NaN | \n",
160 | " 0.283484 | \n",
161 | " 0.225568 | \n",
162 | " 0.397179 | \n",
163 | " 0.500003 | \n",
164 | "
\n",
165 | " \n",
166 | " | min | \n",
167 | " 10798.000000 | \n",
168 | " NaN | \n",
169 | " 55.000000 | \n",
170 | " 10.000000 | \n",
171 | " -150.000000 | \n",
172 | " -70.000000 | \n",
173 | " NaN | \n",
174 | " NaN | \n",
175 | " 0.000000 | \n",
176 | " 0.000000 | \n",
177 | " 0.000000 | \n",
178 | " 0.000000 | \n",
179 | "
\n",
180 | " \n",
181 | " | 25% | \n",
182 | " 17664.000000 | \n",
183 | " NaN | \n",
184 | " 159.000000 | \n",
185 | " 65.000000 | \n",
186 | " 120.000000 | \n",
187 | " 80.000000 | \n",
188 | " NaN | \n",
189 | " NaN | \n",
190 | " 0.000000 | \n",
191 | " 0.000000 | \n",
192 | " 1.000000 | \n",
193 | " 0.000000 | \n",
194 | "
\n",
195 | " \n",
196 | " | 50% | \n",
197 | " 19703.000000 | \n",
198 | " NaN | \n",
199 | " 165.000000 | \n",
200 | " 72.000000 | \n",
201 | " 120.000000 | \n",
202 | " 80.000000 | \n",
203 | " NaN | \n",
204 | " NaN | \n",
205 | " 0.000000 | \n",
206 | " 0.000000 | \n",
207 | " 1.000000 | \n",
208 | " 0.000000 | \n",
209 | "
\n",
210 | " \n",
211 | " | 75% | \n",
212 | " 21327.000000 | \n",
213 | " NaN | \n",
214 | " 170.000000 | \n",
215 | " 82.000000 | \n",
216 | " 140.000000 | \n",
217 | " 90.000000 | \n",
218 | " NaN | \n",
219 | " NaN | \n",
220 | " 0.000000 | \n",
221 | " 0.000000 | \n",
222 | " 1.000000 | \n",
223 | " 1.000000 | \n",
224 | "
\n",
225 | " \n",
226 | " | max | \n",
227 | " 23713.000000 | \n",
228 | " NaN | \n",
229 | " 250.000000 | \n",
230 | " 200.000000 | \n",
231 | " 16020.000000 | \n",
232 | " 11000.000000 | \n",
233 | " NaN | \n",
234 | " NaN | \n",
235 | " 1.000000 | \n",
236 | " 1.000000 | \n",
237 | " 1.000000 | \n",
238 | " 1.000000 | \n",
239 | "
\n",
240 | " \n",
241 | "
\n",
242 | "
"
243 | ],
244 | "text/plain": [
245 | " age gender height weight ap_hi \\\n",
246 | "count 70000.000000 70000 66591.000000 66578.000000 70000.000000 \n",
247 | "unique NaN 2 NaN NaN NaN \n",
248 | "top NaN women NaN NaN NaN \n",
249 | "freq NaN 45530 NaN NaN NaN \n",
250 | "mean 19468.865814 NaN 164.361205 74.210467 128.817286 \n",
251 | "std 2467.251667 NaN 8.226411 14.397678 154.011419 \n",
252 | "min 10798.000000 NaN 55.000000 10.000000 -150.000000 \n",
253 | "25% 17664.000000 NaN 159.000000 65.000000 120.000000 \n",
254 | "50% 19703.000000 NaN 165.000000 72.000000 120.000000 \n",
255 | "75% 21327.000000 NaN 170.000000 82.000000 140.000000 \n",
256 | "max 23713.000000 NaN 250.000000 200.000000 16020.000000 \n",
257 | "\n",
258 | " ap_lo cholesterol gluc smoke alco \\\n",
259 | "count 70000.000000 66589 70000 70000.000000 70000.000000 \n",
260 | "unique NaN 3 3 NaN NaN \n",
261 | "top NaN normal normal NaN NaN \n",
262 | "freq NaN 49789 59479 NaN NaN \n",
263 | "mean 96.630414 NaN NaN 0.088129 0.053771 \n",
264 | "std 188.472530 NaN NaN 0.283484 0.225568 \n",
265 | "min -70.000000 NaN NaN 0.000000 0.000000 \n",
266 | "25% 80.000000 NaN NaN 0.000000 0.000000 \n",
267 | "50% 80.000000 NaN NaN 0.000000 0.000000 \n",
268 | "75% 90.000000 NaN NaN 0.000000 0.000000 \n",
269 | "max 11000.000000 NaN NaN 1.000000 1.000000 \n",
270 | "\n",
271 | " active cardio \n",
272 | "count 70000.000000 70000.000000 \n",
273 | "unique NaN NaN \n",
274 | "top NaN NaN \n",
275 | "freq NaN NaN \n",
276 | "mean 0.803729 0.499700 \n",
277 | "std 0.397179 0.500003 \n",
278 | "min 0.000000 0.000000 \n",
279 | "25% 1.000000 0.000000 \n",
280 | "50% 1.000000 0.000000 \n",
281 | "75% 1.000000 1.000000 \n",
282 | "max 1.000000 1.000000 "
283 | ]
284 | },
285 | "metadata": {},
286 | "output_type": "display_data"
287 | },
288 | {
289 | "data": {
290 | "text/html": [
291 | "\n",
292 | "\n",
305 | "
\n",
306 | " \n",
307 | " \n",
308 | " | \n",
309 | " age | \n",
310 | " gender | \n",
311 | " height | \n",
312 | " weight | \n",
313 | " ap_hi | \n",
314 | " ap_lo | \n",
315 | " cholesterol | \n",
316 | " gluc | \n",
317 | " smoke | \n",
318 | " alco | \n",
319 | " active | \n",
320 | " cardio | \n",
321 | "
\n",
322 | " \n",
323 | " | id | \n",
324 | " | \n",
325 | " | \n",
326 | " | \n",
327 | " | \n",
328 | " | \n",
329 | " | \n",
330 | " | \n",
331 | " | \n",
332 | " | \n",
333 | " | \n",
334 | " | \n",
335 | " | \n",
336 | "
\n",
337 | " \n",
338 | " \n",
339 | " \n",
340 | " | 0 | \n",
341 | " 18393 | \n",
342 | " men | \n",
343 | " 168.0 | \n",
344 | " 62.0 | \n",
345 | " 110 | \n",
346 | " 80 | \n",
347 | " normal | \n",
348 | " normal | \n",
349 | " 0 | \n",
350 | " 0 | \n",
351 | " 1 | \n",
352 | " 0 | \n",
353 | "
\n",
354 | " \n",
355 | " | 1 | \n",
356 | " 20228 | \n",
357 | " women | \n",
358 | " 156.0 | \n",
359 | " 85.0 | \n",
360 | " 140 | \n",
361 | " 90 | \n",
362 | " well_above_normal | \n",
363 | " normal | \n",
364 | " 0 | \n",
365 | " 0 | \n",
366 | " 1 | \n",
367 | " 1 | \n",
368 | "
\n",
369 | " \n",
370 | " | 2 | \n",
371 | " 18857 | \n",
372 | " women | \n",
373 | " 165.0 | \n",
374 | " 64.0 | \n",
375 | " 130 | \n",
376 | " 70 | \n",
377 | " NaN | \n",
378 | " normal | \n",
379 | " 0 | \n",
380 | " 0 | \n",
381 | " 0 | \n",
382 | " 1 | \n",
383 | "
\n",
384 | " \n",
385 | " | 3 | \n",
386 | " 17623 | \n",
387 | " men | \n",
388 | " 169.0 | \n",
389 | " 82.0 | \n",
390 | " 150 | \n",
391 | " 100 | \n",
392 | " normal | \n",
393 | " normal | \n",
394 | " 0 | \n",
395 | " 0 | \n",
396 | " 1 | \n",
397 | " 1 | \n",
398 | "
\n",
399 | " \n",
400 | " | 4 | \n",
401 | " 17474 | \n",
402 | " women | \n",
403 | " 156.0 | \n",
404 | " 56.0 | \n",
405 | " 100 | \n",
406 | " 60 | \n",
407 | " normal | \n",
408 | " normal | \n",
409 | " 0 | \n",
410 | " 0 | \n",
411 | " 0 | \n",
412 | " 0 | \n",
413 | "
\n",
414 | " \n",
415 | "
\n",
416 | "
"
417 | ],
418 | "text/plain": [
419 | " age gender height weight ap_hi ap_lo cholesterol gluc \\\n",
420 | "id \n",
421 | "0 18393 men 168.0 62.0 110 80 normal normal \n",
422 | "1 20228 women 156.0 85.0 140 90 well_above_normal normal \n",
423 | "2 18857 women 165.0 64.0 130 70 NaN normal \n",
424 | "3 17623 men 169.0 82.0 150 100 normal normal \n",
425 | "4 17474 women 156.0 56.0 100 60 normal normal \n",
426 | "\n",
427 | " smoke alco active cardio \n",
428 | "id \n",
429 | "0 0 0 1 0 \n",
430 | "1 0 0 1 1 \n",
431 | "2 0 0 0 1 \n",
432 | "3 0 0 1 1 \n",
433 | "4 0 0 0 0 "
434 | ]
435 | },
436 | "metadata": {},
437 | "output_type": "display_data"
438 | }
439 | ],
440 | "source": [
441 | "np.random.seed(seed=42)\n",
442 | "df_data = pd.read_csv(\"./cardiovascular-disease-dataset/messy/cardio_train.csv\", sep=';', index_col=\"id\")\n",
443 | "display(df_data.describe(include=\"all\"))\n",
444 | "display(df_data.head())"
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "metadata": {},
450 | "source": [
451 | "## Initializations "
452 | ]
453 | },
454 | {
455 | "cell_type": "markdown",
456 | "metadata": {},
457 | "source": [
458 | "Declarations of all column types and target column"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 3,
464 | "metadata": {},
465 | "outputs": [],
466 | "source": [
467 | "category_features = [[\"cholesterol\"], [\"gluc\"]]\n",
468 | "binary_features = [[\"gender\"], [\"smoke\"], [\"alco\"], [\"active\"]]\n",
469 | "numeric_features = [[\"age\"], [\"height\"], [\"weight\"], [\"ap_hi\"], [\"ap_lo\"]]\n",
470 | "target = \"cardio\"\n",
471 | "\n",
472 | "categorical_suffix = \"_#CAT#\""
473 | ]
474 | },
475 | {
476 | "cell_type": "markdown",
477 | "metadata": {},
478 | "source": [
479 | "Split the data into features and labels"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": 4,
485 | "metadata": {},
486 | "outputs": [],
487 | "source": [
488 | "X = df_data.copy()\n",
489 | "y = X.pop(target)"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "metadata": {},
495 | "source": [
496 | "## Split data to train test datasets"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": null,
502 | "metadata": {},
503 | "outputs": [],
504 | "source": [
505 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {},
511 | "source": [
512 | "## Custom Transformers"
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "metadata": {},
518 | "source": [
519 | "We create custom transformers for feature engineering"
520 | ]
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "metadata": {},
525 | "source": [
526 | "### Blood Pressure Transformer"
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {},
532 | "source": [
533 | "Custom transformer responsible for the creation of a new blood pressure categorical feature based on systolic (the number at the top) and diastolic (the number at the bottom) values. \n",
534 | "The transformer will create a new categorical feature with values according to the American Heart Association ranges of blood pressure:\n",
535 | "* normal \n",
536 | "* elevated \n",
537 | "* high_pressure_stage_1 \n",
538 | "* high_pressure_stage_2 \n",
539 | "* hypertensive_crisis \n",
540 | " \n",
541 | " \n",
542 | " \n",
543 | "Photo from [American Heart Association](https://www.heart.org/-/media/data-import/downloadables/pe-abh-what-is-high-blood-pressure-ucm_300310.pdf?la=en&hash=CAC0F1D377BDB7BC3870993918226869524AAC3D)"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": null,
549 | "metadata": {},
550 | "outputs": [],
551 | "source": [
552 | "class BloodPressureTransformer(BaseEstimator, TransformerMixin):\n",
553 | " \n",
554 | " def __init__(self):\n",
555 | " \n",
556 | " # Systolic and diastolic blood pressure ranges based on the American Heart Association\n",
557 | " self.systolic_ranges = [-np.inf, 119, 129, 139, 180, np.inf]\n",
558 | " self.diastolic_ranges = [-np.inf, 79, 89, 120, np.inf]\n",
559 | " \n",
560 | " # Blood pressure categories\n",
561 | " self.blood_pressure_category = [\"normal\", \"elevated\", \"high_pressure_stage_1\", \"high_pressure_stage_2\", \"hypertensive_crisis\"]\n",
562 | " \n",
563 | " def fit(self, X, y=None):\n",
564 | " return self\n",
565 | " \n",
566 | " def transform(self, X):\n",
567 | " # Copy the data so we will not change the original instance\n",
568 | " df_blood_pressure = X.copy()\n",
569 | " \n",
570 | " # Break down ranges of systolic values to categories\n",
571 | " df_blood_pressure[\"systolic\"] = pd.cut(df_blood_pressure[\"ap_hi\"], self.systolic_ranges, labels=[\"<120\", \"120-129\", \"130-139\", \"140-180\", \">180\"])\n",
572 | " \n",
573 | " # Break down ranges of diastolic values to categories\n",
574 | " df_blood_pressure[\"diastolic\"] = pd.cut(df_blood_pressure[\"ap_lo\"], self.diastolic_ranges, labels=[\"<79\", \"80-89\", \"90-120\", \">120\"])\n",
575 | " \n",
576 | " # Combine ranges from systolic and diastolic features to determine the category of the blood pressure feature\n",
577 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"<120\") &\n",
578 | " (df_blood_pressure[\"diastolic\"] == \"<79\"), \"blood_pressure\"] = self.blood_pressure_category[0]\n",
579 | " \n",
580 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"120-129\") &\n",
581 | " (df_blood_pressure[\"diastolic\"] == \"<79\"), \"blood_pressure\"] = self.blood_pressure_category[1]\n",
582 | " \n",
583 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"130-139\") |\n",
584 | " (df_blood_pressure[\"diastolic\"] == \"80-89\"), \"blood_pressure\"] = self.blood_pressure_category[2]\n",
585 | " \n",
586 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"140-180\") |\n",
587 | " (df_blood_pressure[\"diastolic\"] == \"90-120\"), \"blood_pressure\"] = self.blood_pressure_category[3]\n",
588 | " \n",
589 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \">180\") |\n",
590 | " (df_blood_pressure[\"diastolic\"] == \">120\"), \"blood_pressure\"] = self.blood_pressure_category[4]\n",
591 | " \n",
592 | " # Return blood pressure feature as a dataframe with one column\n",
593 | " return df_blood_pressure[[\"blood_pressure\"]]"
594 | ]
595 | },
596 | {
597 | "cell_type": "markdown",
598 | "metadata": {},
599 | "source": [
600 | "### Unhealty Lifestyle Transformer"
601 | ]
602 | },
603 | {
604 | "cell_type": "markdown",
605 | "metadata": {},
606 | "source": [
607 | "Custom transformer responsible for the creation of a new \"unhealty lifestyle\" feature. \n",
608 | "This is a boolean feature representing the use of cigarettes, alcohol, and physical inactivity. "
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "class UnhealtyLifestyleTransformer(BaseEstimator, TransformerMixin):\n",
618 | " \n",
619 | " def fit(self, X, y=None):\n",
620 | " return self\n",
621 | " \n",
622 | " def transform(self, X):\n",
623 | " # Copy the data so we will not change the original instance\n",
624 | " df_unhealty_lifestyle = X.copy()\n",
625 | " \n",
626 | " # If you smoke or use alcohol or don't do physical activity, you maintain an unhealty lifestyle!\n",
627 | " df_unhealty_lifestyle[\"unhealty_lifestyle\"] = df_unhealty_lifestyle[\"smoke\"] | df_unhealty_lifestyle[\"alco\"] | (1 - df_unhealty_lifestyle[\"active\"])\n",
628 | " \n",
629 | " # Return unhealty lifestyle feature as a dataframe with one column\n",
630 | " return df_unhealty_lifestyle[[\"unhealty_lifestyle\"]]"
631 | ]
632 | },
633 | {
634 | "cell_type": "markdown",
635 | "metadata": {},
636 | "source": [
637 | "## Definition of DataFrameMapper transformers"
638 | ]
639 | },
640 | {
641 | "cell_type": "markdown",
642 | "metadata": {},
643 | "source": [
644 | "Now we will define the pipeline of transformations and the raw features we need to complete the creation and processing of the new features and the original features. \n",
645 | "We will pass this to the DataFrameMapper class of the sklearn-pandas package."
646 | ]
647 | },
648 | {
649 | "cell_type": "code",
650 | "execution_count": null,
651 | "metadata": {},
652 | "outputs": [
653 | {
654 | "data": {
655 | "text/plain": [
656 | "(['ap_hi', 'ap_lo'],\n",
657 | " [BloodPressureTransformer(),\n",
658 | " SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
659 | " strategy='most_frequent', verbose=0)],\n",
660 | " {'alias': 'blood_pressure_#CAT#'})"
661 | ]
662 | },
663 | "execution_count": 8,
664 | "metadata": {},
665 | "output_type": "execute_result"
666 | }
667 | ],
668 | "source": [
669 | "# Input features \"ap_hi\", \"ap_lo\".\n",
670 | "# Steps:\n",
671 | "# BloodPressureTransformer - create blood pressure feature based on \"ap_hi\", \"ap_lo\".\n",
672 | "# SimpleImputer - fill nans with the most frequent value.\n",
673 | "# Note: no OneHotEncoder step is needed here - CatBoost handles categorical features natively.\n",
674 | "gen_blood_pressure = (\n",
675 | " [\"ap_hi\", \"ap_lo\"],\n",
676 | " [\n",
677 | " BloodPressureTransformer(),\n",
678 | " SimpleImputer(strategy=\"most_frequent\")\n",
679 | " ],\n",
680 | " {\"alias\": f\"blood_pressure{categorical_suffix}\"}\n",
681 | ")\n",
682 | "\n",
683 | "gen_blood_pressure"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": null,
689 | "metadata": {},
690 | "outputs": [
691 | {
692 | "data": {
693 | "text/plain": [
694 | "(['smoke', 'alco', 'active'],\n",
695 | " [UnhealtyLifestyleTransformer(),\n",
696 | " SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
697 | " strategy='most_frequent', verbose=0)],\n",
698 | " {'alias': 'unhealty_lifestyle_#CAT#'})"
699 | ]
700 | },
701 | "execution_count": 9,
702 | "metadata": {},
703 | "output_type": "execute_result"
704 | }
705 | ],
706 | "source": [
707 | "# Input features [\"smoke\", \"alco\", \"active\"].\n",
708 | "# Steps:\n",
709 | "# UnhealtyLifestyleTransformer - create unhealty lifestyle feature based on \"smoke\", \"alco\", \"active\".\n",
710 | "# SimpleImputer - fill nans with the most frequent value.\n",
711 | "gen_unhealty_lifestyle = (\n",
712 | " [\"smoke\", \"alco\", \"active\"],\n",
713 | " [\n",
714 | " UnhealtyLifestyleTransformer(),\n",
715 | " SimpleImputer(strategy=\"most_frequent\")\n",
716 | " ],\n",
717 | " {\"alias\": f\"unhealty_lifestyle{categorical_suffix}\"}\n",
718 | ")\n",
719 | "\n",
720 | "gen_unhealty_lifestyle"
721 | ]
722 | },
723 | {
724 | "cell_type": "markdown",
725 | "metadata": {},
726 | "source": [
727 | "### Apply the same transformers for multiple columns with gen_features"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": null,
733 | "metadata": {},
734 | "outputs": [
735 | {
736 | "data": {
737 | "text/plain": [
738 | "[(['cholesterol'],\n",
739 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
740 | " strategy='most_frequent', verbose=0)],\n",
741 | " {'alias': 'cholesterol_#CAT#'}),\n",
742 | " (['gluc'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
743 | " strategy='most_frequent', verbose=0)], {'alias': 'gluc_#CAT#'})]"
744 | ]
745 | },
746 | "execution_count": 10,
747 | "metadata": {},
748 | "output_type": "execute_result"
749 | }
750 | ],
751 | "source": [
752 | "# Input features [[\"cholesterol\"], [\"gluc\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
753 | "# Steps:\n",
754 | "# SimpleImputer - fill nans with the most frequent value.\n",
755 | "# Note: no OneHotEncoder step is needed here - CatBoost handles categorical features natively.\n",
756 | "gen_category = gen_features(\n",
757 | " columns=category_features,\n",
758 | " classes=[\n",
759 | " {\n",
760 | " \"class\": SimpleImputer,\n",
761 | " \"strategy\": \"most_frequent\"\n",
762 | " }\n",
763 | " ]\n",
764 | ")\n",
765 | "\n",
766 | "gen_category = [(col_name, transformer, {\"alias\": col_name[0] + categorical_suffix}) for col_name, transformer in gen_category]\n",
767 | "gen_category"
768 | ]
769 | },
770 | {
771 | "cell_type": "code",
772 | "execution_count": null,
773 | "metadata": {},
774 | "outputs": [
775 | {
776 | "data": {
777 | "text/plain": [
778 | "[(['gender'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
779 | " strategy='most_frequent', verbose=0)], {'alias': 'gender_#CAT#'}),\n",
780 | " (['smoke'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
781 | " strategy='most_frequent', verbose=0)], {'alias': 'smoke_#CAT#'}),\n",
782 | " (['alco'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
783 | " strategy='most_frequent', verbose=0)], {'alias': 'alco_#CAT#'}),\n",
784 | " (['active'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
785 | " strategy='most_frequent', verbose=0)], {'alias': 'active_#CAT#'})]"
786 | ]
787 | },
788 | "execution_count": 11,
789 | "metadata": {},
790 | "output_type": "execute_result"
791 | }
792 | ],
793 | "source": [
794 | "# Input features [[\"gender\"], [\"smoke\"], [\"alco\"], [\"active\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
795 | "# Steps:\n",
796 | "# SimpleImputer - fill nans with the most frequent value.\n",
797 | "gen_binary = gen_features(\n",
798 | " columns=binary_features,\n",
799 | " classes=[\n",
800 | " {\n",
801 | " \"class\": SimpleImputer,\n",
802 | " \"strategy\": \"most_frequent\"\n",
803 | " }\n",
804 | " ]\n",
805 | ")\n",
806 | "\n",
807 | "gen_binary = [(col_name, transformer, {\"alias\": col_name[0] + categorical_suffix}) for col_name, transformer in gen_binary]\n",
808 | "gen_binary"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "execution_count": null,
814 | "metadata": {},
815 | "outputs": [
816 | {
817 | "data": {
818 | "text/plain": [
819 | "[(['age'],\n",
820 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
821 | " verbose=0),\n",
822 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
823 | " (['height'],\n",
824 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
825 | " verbose=0),\n",
826 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
827 | " (['weight'],\n",
828 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
829 | " verbose=0),\n",
830 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
831 | " (['ap_hi'],\n",
832 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
833 | " verbose=0),\n",
834 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
835 | " (['ap_lo'],\n",
836 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
837 | " verbose=0),\n",
838 | " StandardScaler(copy=True, with_mean=True, with_std=True)])]"
839 | ]
840 | },
841 | "execution_count": 12,
842 | "metadata": {},
843 | "output_type": "execute_result"
844 | }
845 | ],
846 | "source": [
847 | "# Input features [[\"age\"], [\"height\"], [\"weight\"], [\"ap_hi\"], [\"ap_lo\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
848 | "# Steps:\n",
849 | "# SimpleImputer - fill nans with the mean value.\n",
850 | "# StandardScaler - standardize features by removing the mean and scaling to unit variance.\n",
851 | "gen_numeric = gen_features(\n",
852 | " columns=numeric_features,\n",
853 | " classes=[\n",
854 | " {\n",
855 | " \"class\": SimpleImputer,\n",
856 | " \"strategy\": \"mean\"\n",
857 | " },\n",
858 | " {\n",
859 | " \"class\": StandardScaler\n",
860 | " }\n",
861 | " ]\n",
862 | ")\n",
863 | "\n",
864 | "gen_numeric"
865 | ]
866 | },
867 | {
868 | "cell_type": "markdown",
869 | "metadata": {},
870 | "source": [
871 | "### DataFrameMapper Construction"
872 | ]
873 | },
874 | {
875 | "cell_type": "markdown",
876 | "metadata": {},
877 | "source": [
878 | "Now we will define the course of action of the DataFrameMapper and indicate that the input and output will be Pandas Dataframe."
879 | ]
880 | },
881 | {
882 | "cell_type": "code",
883 | "execution_count": null,
884 | "metadata": {},
885 | "outputs": [],
886 | "source": [
887 | "preprocess_mapper = DataFrameMapper(\n",
888 | " [\n",
889 | " gen_blood_pressure,\n",
890 | " gen_unhealty_lifestyle,\n",
891 | " *gen_category,\n",
892 | " *gen_binary,\n",
893 | " *gen_numeric,\n",
894 | " ],\n",
895 | " input_df=True,\n",
896 | " df_out=True\n",
897 | ")"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": null,
903 | "metadata": {},
904 | "outputs": [
905 | {
906 | "data": {
907 | "text/html": [
908 | "\n",
909 | "\n",
922 | "
\n",
923 | " \n",
924 | " \n",
925 | " | \n",
926 | " blood_pressure_#CAT# | \n",
927 | " unhealty_lifestyle_#CAT# | \n",
928 | " cholesterol_#CAT# | \n",
929 | " gluc_#CAT# | \n",
930 | " gender_#CAT# | \n",
931 | " smoke_#CAT# | \n",
932 | " alco_#CAT# | \n",
933 | " active_#CAT# | \n",
934 | " age | \n",
935 | " height | \n",
936 | " weight | \n",
937 | " ap_hi | \n",
938 | " ap_lo | \n",
939 | "
\n",
940 | " \n",
941 | " | id | \n",
942 | " | \n",
943 | " | \n",
944 | " | \n",
945 | " | \n",
946 | " | \n",
947 | " | \n",
948 | " | \n",
949 | " | \n",
950 | " | \n",
951 | " | \n",
952 | " | \n",
953 | " | \n",
954 | " | \n",
955 | "
\n",
956 | " \n",
957 | " \n",
958 | " \n",
959 | " | 98125 | \n",
960 | " high_pressure_stage_2 | \n",
961 | " 0 | \n",
962 | " well_above_normal | \n",
963 | " normal | \n",
964 | " women | \n",
965 | " 0 | \n",
966 | " 0 | \n",
967 | " 1 | \n",
968 | " 0.388175 | \n",
969 | " -0.549199 | \n",
970 | " 1.024035e-15 | \n",
971 | " -0.056261 | \n",
972 | " -0.035475 | \n",
973 | "
\n",
974 | " \n",
975 | " | 28510 | \n",
976 | " high_pressure_stage_1 | \n",
977 | " 1 | \n",
978 | " well_above_normal | \n",
979 | " well_above_normal | \n",
980 | " men | \n",
981 | " 0 | \n",
982 | " 0 | \n",
983 | " 0 | \n",
984 | " 1.308468 | \n",
985 | " 0.333478 | \n",
986 | " -6.621198e-01 | \n",
987 | " -0.056261 | \n",
988 | " -0.085423 | \n",
989 | "
\n",
990 | " \n",
991 | " | 15795 | \n",
992 | " high_pressure_stage_2 | \n",
993 | " 0 | \n",
994 | " normal | \n",
995 | " normal | \n",
996 | " women | \n",
997 | " 0 | \n",
998 | " 0 | \n",
999 | " 1 | \n",
1000 | " 1.346527 | \n",
1001 | " -0.549199 | \n",
1002 | " -5.900598e-01 | \n",
1003 | " -0.056261 | \n",
1004 | " -0.035475 | \n",
1005 | "
\n",
1006 | " \n",
1007 | " | 39560 | \n",
1008 | " high_pressure_stage_2 | \n",
1009 | " 0 | \n",
1010 | " well_above_normal | \n",
1011 | " normal | \n",
1012 | " women | \n",
1013 | " 0 | \n",
1014 | " 0 | \n",
1015 | " 1 | \n",
1016 | " 1.291463 | \n",
1017 | " -0.170909 | \n",
1018 | " -1.382720e+00 | \n",
1019 | " -0.024632 | \n",
1020 | " -0.035475 | \n",
1021 | "
\n",
1022 | " \n",
1023 | " | 32677 | \n",
1024 | " high_pressure_stage_2 | \n",
1025 | " 0 | \n",
1026 | " well_above_normal | \n",
1027 | " normal | \n",
1028 | " women | \n",
1029 | " 0 | \n",
1030 | " 0 | \n",
1031 | " 1 | \n",
1032 | " 0.912495 | \n",
1033 | " -0.801392 | \n",
1034 | " 7.790809e-01 | \n",
1035 | " 0.133513 | \n",
1036 | " -0.085423 | \n",
1037 | "
\n",
1038 | " \n",
1039 | "
\n",
1040 | "
"
1041 | ],
1042 | "text/plain": [
1043 | " blood_pressure_#CAT# unhealty_lifestyle_#CAT# cholesterol_#CAT# \\\n",
1044 | "id \n",
1045 | "98125 high_pressure_stage_2 0 well_above_normal \n",
1046 | "28510 high_pressure_stage_1 1 well_above_normal \n",
1047 | "15795 high_pressure_stage_2 0 normal \n",
1048 | "39560 high_pressure_stage_2 0 well_above_normal \n",
1049 | "32677 high_pressure_stage_2 0 well_above_normal \n",
1050 | "\n",
1051 | " gluc_#CAT# gender_#CAT# smoke_#CAT# alco_#CAT# active_#CAT# \\\n",
1052 | "id \n",
1053 | "98125 normal women 0 0 1 \n",
1054 | "28510 well_above_normal men 0 0 0 \n",
1055 | "15795 normal women 0 0 1 \n",
1056 | "39560 normal women 0 0 1 \n",
1057 | "32677 normal women 0 0 1 \n",
1058 | "\n",
1059 | " age height weight ap_hi ap_lo \n",
1060 | "id \n",
1061 | "98125 0.388175 -0.549199 1.024035e-15 -0.056261 -0.035475 \n",
1062 | "28510 1.308468 0.333478 -6.621198e-01 -0.056261 -0.085423 \n",
1063 | "15795 1.346527 -0.549199 -5.900598e-01 -0.056261 -0.035475 \n",
1064 | "39560 1.291463 -0.170909 -1.382720e+00 -0.024632 -0.035475 \n",
1065 | "32677 0.912495 -0.801392 7.790809e-01 0.133513 -0.085423 "
1066 | ]
1067 | },
1068 | "execution_count": 14,
1069 | "metadata": {},
1070 | "output_type": "execute_result"
1071 | }
1072 | ],
1073 | "source": [
1074 | "preprocess_mapper.fit_transform(X_train, y_train).head()"
1075 | ]
1076 | },
1077 | {
1078 | "cell_type": "markdown",
1079 | "metadata": {},
1080 | "source": [
1081 | "### Custom CatBoost Classifier"
1082 | ]
1083 | },
1084 | {
1085 | "cell_type": "markdown",
1086 | "metadata": {},
1087 | "source": [
1088 | "We need to implement our own catboost classifier so we can track our categorical features at runtime"
1089 | ]
1090 | },
1091 | {
1092 | "cell_type": "code",
1093 | "execution_count": null,
1094 | "metadata": {},
1095 | "outputs": [],
1096 | "source": [
1097 | "class CustomCatBoostClassifier(CatBoostClassifier):\n",
1098 | "\n",
1099 | " def fit(self, X, y=None, **fit_params):\n",
1100 | " print(X.filter(regex=f\"{categorical_suffix}$\").columns.to_list())\n",
1101 | "\n",
1102 | " return super().fit(\n",
1103 | " X,\n",
1104 | " y=y,\n",
1105 | " cat_features=X.filter(regex=f\"{categorical_suffix}$\").columns,\n",
1106 | " **fit_params\n",
1107 | " )"
1108 | ]
1109 | },
1110 | {
1111 | "cell_type": "markdown",
1112 | "metadata": {},
1113 | "source": [
1114 | "### Feature Selection"
1115 | ]
1116 | },
1117 | {
1118 | "cell_type": "markdown",
1119 | "metadata": {},
1120 | "source": [
1121 | "In the feature selection step, we create our own custom feature selection so the input dataframe will stay a dataframe and not a numpy array"
1122 | ]
1123 | },
1124 | {
1125 | "cell_type": "code",
1126 | "execution_count": null,
1127 | "metadata": {},
1128 | "outputs": [],
1129 | "source": [
1130 | "class CustomFeatureSelection(SelectFromModel):\n",
1131 | "\n",
1132 | " def transform(self, X):\n",
1133 | " \n",
1134 | " # Get indices of important features\n",
1135 | " important_features_indices = list(self.get_support(indices=True))\n",
1136 | "\n",
1137 | " # Select important features\n",
1138 | " _X = X.iloc[:, important_features_indices].copy()\n",
1139 | "\n",
1140 | " return _X"
1141 | ]
1142 | },
1143 | {
1144 | "cell_type": "markdown",
1145 | "metadata": {},
1146 | "source": [
1147 | "### Sklearn Pipeline"
1148 | ]
1149 | },
1150 | {
1151 | "cell_type": "markdown",
1152 | "metadata": {},
1153 | "source": [
1154 | "Now we piece together all previous definitions to define the full pipeline:\n",
1155 | "* preprocessing\n",
1156 | "* feature selection\n",
1157 | "* estimator"
1158 | ]
1159 | },
1160 | {
1161 | "cell_type": "code",
1162 | "execution_count": null,
1163 | "metadata": {},
1164 | "outputs": [],
1165 | "source": [
1166 | "pipeline = Pipeline(steps=[\n",
1167 | " (\"preprocess\", preprocess_mapper),\n",
1168 | " (\"feature_selection\", CustomFeatureSelection(CustomCatBoostClassifier(logging_level=\"Silent\"))),\n",
1169 | " (\"estimator\", CustomCatBoostClassifier(logging_level=\"Silent\"))\n",
1170 | "])"
1171 | ]
1172 | },
1173 | {
1174 | "cell_type": "code",
1175 | "execution_count": null,
1176 | "metadata": {},
1177 | "outputs": [
1178 | {
1179 | "name": "stdout",
1180 | "output_type": "stream",
1181 | "text": [
1182 | "['blood_pressure_#CAT#', 'unhealty_lifestyle_#CAT#', 'cholesterol_#CAT#', 'gluc_#CAT#', 'gender_#CAT#', 'smoke_#CAT#', 'alco_#CAT#', 'active_#CAT#']\n"
1183 | ]
1184 | }
1185 | ],
1186 | "source": [
1187 | "pipeline.fit(X_train, y_train)\n",
1188 | "preds = pipeline.predict(X_test)"
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "markdown",
1193 | "metadata": {},
1194 | "source": [
1195 | "## Grid Search"
1196 | ]
1197 | },
1198 | {
1199 | "cell_type": "markdown",
1200 | "metadata": {},
1201 | "source": [
1202 | "Next I created a grid search object which includes the original pipeline. \n",
1203 | "When I then call fit, the transformations are applied to the data, before a cross-validated grid-search is performed over the parameter grid."
1204 | ]
1205 | },
1206 | {
1207 | "cell_type": "code",
1208 | "execution_count": null,
1209 | "metadata": {},
1210 | "outputs": [],
1211 | "source": [
1212 | "pipeline = Pipeline(steps=[\n",
1213 | " (\"preprocess\", preprocess_mapper),\n",
1214 | " (\"feature_selection\", CustomFeatureSelection(CustomCatBoostClassifier(logging_level=\"Silent\"))),\n",
1215 | " (\"estimator\", CustomCatBoostClassifier(logging_level=\"Silent\"))\n",
1216 | "])"
1217 | ]
1218 | },
1219 | {
1220 | "cell_type": "code",
1221 | "execution_count": null,
1222 | "metadata": {},
1223 | "outputs": [],
1224 | "source": [
1225 | "param_grid = {\n",
1226 | " \"estimator__depth\": [4, 5, 6],\n",
1227 | " \"estimator__iterations\": [500, 1000],\n",
1228 | " \"estimator__learning_rate\": [0.001, 0.01, 0.1], \n",
1229 | " \"estimator__l2_leaf_reg\": [3, 5, 100],\n",
1230 | " \"estimator__border_count\": [10, 50, 200],\n",
1231 | " \"estimator__ctr_border_count\": [10, 50, 200],\n",
1232 | " \"estimator__thread_count\": [4]\n",
1233 | "}"
1234 | ]
1235 | },
1236 | {
1237 | "cell_type": "code",
1238 | "execution_count": null,
1239 | "metadata": {},
1240 | "outputs": [],
1241 | "source": [
1242 | "gscv_estimator = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1)\n",
1243 | "gscv_estimator.fit(X_train, y_train)"
1244 | ]
1245 | },
1246 | {
1247 | "cell_type": "code",
1248 | "execution_count": null,
1249 | "metadata": {},
1250 | "outputs": [],
1251 | "source": [
1252 | "display(gscv_estimator.best_params_)"
1253 | ]
1254 | },
1255 | {
1256 | "cell_type": "code",
1257 | "execution_count": null,
1258 | "metadata": {},
1259 | "outputs": [],
1260 | "source": [
1261 | "preds = gscv_estimator.predict(X_test)"
1262 | ]
1263 | }
1264 | ],
1265 | "metadata": {
1266 | "kernelspec": {
1267 | "display_name": "Python 3",
1268 | "language": "python",
1269 | "name": "python3"
1270 | },
1271 | "language_info": {
1272 | "codemirror_mode": {
1273 | "name": "ipython",
1274 | "version": 3
1275 | },
1276 | "file_extension": ".py",
1277 | "mimetype": "text/x-python",
1278 | "name": "python",
1279 | "nbconvert_exporter": "python",
1280 | "pygments_lexer": "ipython3",
1281 | "version": "3.7.1"
1282 | }
1283 | },
1284 | "nbformat": 4,
1285 | "nbformat_minor": 2
1286 | }
1287 |
--------------------------------------------------------------------------------
/sklearn-pandas.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np\n",
11 | "from sklearn_pandas import DataFrameMapper, gen_features\n",
12 | "from sklearn.model_selection import train_test_split\n",
13 | "from sklearn.impute import SimpleImputer\n",
14 | "from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler\n",
15 | "from sklearn.base import BaseEstimator, TransformerMixin\n",
16 | "from sklearn.pipeline import Pipeline\n",
17 | "from sklearn.feature_selection import SelectFromModel\n",
18 | "from sklearn.ensemble import RandomForestClassifier\n",
19 | "from sklearn.model_selection import GridSearchCV"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "## Read data\n",
27 | "We are not reading the original data from Kaggle. \n",
28 | "The data we are reading is based on the data from Kaggle with small changes:\n",
29 | "* Random replacement of values with NaNs\n",
30 | "* Transformations of categorical (int) features to text categories"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "data": {
40 | "text/html": [
41 | "\n",
42 | "\n",
55 | "
\n",
56 | " \n",
57 | " \n",
58 | " | \n",
59 | " age | \n",
60 | " gender | \n",
61 | " height | \n",
62 | " weight | \n",
63 | " ap_hi | \n",
64 | " ap_lo | \n",
65 | " cholesterol | \n",
66 | " gluc | \n",
67 | " smoke | \n",
68 | " alco | \n",
69 | " active | \n",
70 | " cardio | \n",
71 | "
\n",
72 | " \n",
73 | " \n",
74 | " \n",
75 | " | count | \n",
76 | " 70000.000000 | \n",
77 | " 70000 | \n",
78 | " 66591.000000 | \n",
79 | " 66578.000000 | \n",
80 | " 70000.000000 | \n",
81 | " 70000.000000 | \n",
82 | " 66589 | \n",
83 | " 70000 | \n",
84 | " 70000.000000 | \n",
85 | " 70000.000000 | \n",
86 | " 70000.000000 | \n",
87 | " 70000.000000 | \n",
88 | "
\n",
89 | " \n",
90 | " | unique | \n",
91 | " NaN | \n",
92 | " 2 | \n",
93 | " NaN | \n",
94 | " NaN | \n",
95 | " NaN | \n",
96 | " NaN | \n",
97 | " 3 | \n",
98 | " 3 | \n",
99 | " NaN | \n",
100 | " NaN | \n",
101 | " NaN | \n",
102 | " NaN | \n",
103 | "
\n",
104 | " \n",
105 | " | top | \n",
106 | " NaN | \n",
107 | " women | \n",
108 | " NaN | \n",
109 | " NaN | \n",
110 | " NaN | \n",
111 | " NaN | \n",
112 | " normal | \n",
113 | " normal | \n",
114 | " NaN | \n",
115 | " NaN | \n",
116 | " NaN | \n",
117 | " NaN | \n",
118 | "
\n",
119 | " \n",
120 | " | freq | \n",
121 | " NaN | \n",
122 | " 45530 | \n",
123 | " NaN | \n",
124 | " NaN | \n",
125 | " NaN | \n",
126 | " NaN | \n",
127 | " 49789 | \n",
128 | " 59479 | \n",
129 | " NaN | \n",
130 | " NaN | \n",
131 | " NaN | \n",
132 | " NaN | \n",
133 | "
\n",
134 | " \n",
135 | " | mean | \n",
136 | " 19468.865814 | \n",
137 | " NaN | \n",
138 | " 164.361205 | \n",
139 | " 74.210467 | \n",
140 | " 128.817286 | \n",
141 | " 96.630414 | \n",
142 | " NaN | \n",
143 | " NaN | \n",
144 | " 0.088129 | \n",
145 | " 0.053771 | \n",
146 | " 0.803729 | \n",
147 | " 0.499700 | \n",
148 | "
\n",
149 | " \n",
150 | " | std | \n",
151 | " 2467.251667 | \n",
152 | " NaN | \n",
153 | " 8.226411 | \n",
154 | " 14.397678 | \n",
155 | " 154.011419 | \n",
156 | " 188.472530 | \n",
157 | " NaN | \n",
158 | " NaN | \n",
159 | " 0.283484 | \n",
160 | " 0.225568 | \n",
161 | " 0.397179 | \n",
162 | " 0.500003 | \n",
163 | "
\n",
164 | " \n",
165 | " | min | \n",
166 | " 10798.000000 | \n",
167 | " NaN | \n",
168 | " 55.000000 | \n",
169 | " 10.000000 | \n",
170 | " -150.000000 | \n",
171 | " -70.000000 | \n",
172 | " NaN | \n",
173 | " NaN | \n",
174 | " 0.000000 | \n",
175 | " 0.000000 | \n",
176 | " 0.000000 | \n",
177 | " 0.000000 | \n",
178 | "
\n",
179 | " \n",
180 | " | 25% | \n",
181 | " 17664.000000 | \n",
182 | " NaN | \n",
183 | " 159.000000 | \n",
184 | " 65.000000 | \n",
185 | " 120.000000 | \n",
186 | " 80.000000 | \n",
187 | " NaN | \n",
188 | " NaN | \n",
189 | " 0.000000 | \n",
190 | " 0.000000 | \n",
191 | " 1.000000 | \n",
192 | " 0.000000 | \n",
193 | "
\n",
194 | " \n",
195 | " | 50% | \n",
196 | " 19703.000000 | \n",
197 | " NaN | \n",
198 | " 165.000000 | \n",
199 | " 72.000000 | \n",
200 | " 120.000000 | \n",
201 | " 80.000000 | \n",
202 | " NaN | \n",
203 | " NaN | \n",
204 | " 0.000000 | \n",
205 | " 0.000000 | \n",
206 | " 1.000000 | \n",
207 | " 0.000000 | \n",
208 | "
\n",
209 | " \n",
210 | " | 75% | \n",
211 | " 21327.000000 | \n",
212 | " NaN | \n",
213 | " 170.000000 | \n",
214 | " 82.000000 | \n",
215 | " 140.000000 | \n",
216 | " 90.000000 | \n",
217 | " NaN | \n",
218 | " NaN | \n",
219 | " 0.000000 | \n",
220 | " 0.000000 | \n",
221 | " 1.000000 | \n",
222 | " 1.000000 | \n",
223 | "
\n",
224 | " \n",
225 | " | max | \n",
226 | " 23713.000000 | \n",
227 | " NaN | \n",
228 | " 250.000000 | \n",
229 | " 200.000000 | \n",
230 | " 16020.000000 | \n",
231 | " 11000.000000 | \n",
232 | " NaN | \n",
233 | " NaN | \n",
234 | " 1.000000 | \n",
235 | " 1.000000 | \n",
236 | " 1.000000 | \n",
237 | " 1.000000 | \n",
238 | "
\n",
239 | " \n",
240 | "
\n",
241 | "
"
242 | ],
243 | "text/plain": [
244 | " age gender height weight ap_hi \\\n",
245 | "count 70000.000000 70000 66591.000000 66578.000000 70000.000000 \n",
246 | "unique NaN 2 NaN NaN NaN \n",
247 | "top NaN women NaN NaN NaN \n",
248 | "freq NaN 45530 NaN NaN NaN \n",
249 | "mean 19468.865814 NaN 164.361205 74.210467 128.817286 \n",
250 | "std 2467.251667 NaN 8.226411 14.397678 154.011419 \n",
251 | "min 10798.000000 NaN 55.000000 10.000000 -150.000000 \n",
252 | "25% 17664.000000 NaN 159.000000 65.000000 120.000000 \n",
253 | "50% 19703.000000 NaN 165.000000 72.000000 120.000000 \n",
254 | "75% 21327.000000 NaN 170.000000 82.000000 140.000000 \n",
255 | "max 23713.000000 NaN 250.000000 200.000000 16020.000000 \n",
256 | "\n",
257 | " ap_lo cholesterol gluc smoke alco \\\n",
258 | "count 70000.000000 66589 70000 70000.000000 70000.000000 \n",
259 | "unique NaN 3 3 NaN NaN \n",
260 | "top NaN normal normal NaN NaN \n",
261 | "freq NaN 49789 59479 NaN NaN \n",
262 | "mean 96.630414 NaN NaN 0.088129 0.053771 \n",
263 | "std 188.472530 NaN NaN 0.283484 0.225568 \n",
264 | "min -70.000000 NaN NaN 0.000000 0.000000 \n",
265 | "25% 80.000000 NaN NaN 0.000000 0.000000 \n",
266 | "50% 80.000000 NaN NaN 0.000000 0.000000 \n",
267 | "75% 90.000000 NaN NaN 0.000000 0.000000 \n",
268 | "max 11000.000000 NaN NaN 1.000000 1.000000 \n",
269 | "\n",
270 | " active cardio \n",
271 | "count 70000.000000 70000.000000 \n",
272 | "unique NaN NaN \n",
273 | "top NaN NaN \n",
274 | "freq NaN NaN \n",
275 | "mean 0.803729 0.499700 \n",
276 | "std 0.397179 0.500003 \n",
277 | "min 0.000000 0.000000 \n",
278 | "25% 1.000000 0.000000 \n",
279 | "50% 1.000000 0.000000 \n",
280 | "75% 1.000000 1.000000 \n",
281 | "max 1.000000 1.000000 "
282 | ]
283 | },
284 | "metadata": {},
285 | "output_type": "display_data"
286 | },
287 | {
288 | "data": {
289 | "text/html": [
290 | "\n",
291 | "\n",
304 | "
\n",
305 | " \n",
306 | " \n",
307 | " | \n",
308 | " age | \n",
309 | " gender | \n",
310 | " height | \n",
311 | " weight | \n",
312 | " ap_hi | \n",
313 | " ap_lo | \n",
314 | " cholesterol | \n",
315 | " gluc | \n",
316 | " smoke | \n",
317 | " alco | \n",
318 | " active | \n",
319 | " cardio | \n",
320 | "
\n",
321 | " \n",
322 | " | id | \n",
323 | " | \n",
324 | " | \n",
325 | " | \n",
326 | " | \n",
327 | " | \n",
328 | " | \n",
329 | " | \n",
330 | " | \n",
331 | " | \n",
332 | " | \n",
333 | " | \n",
334 | " | \n",
335 | "
\n",
336 | " \n",
337 | " \n",
338 | " \n",
339 | " | 0 | \n",
340 | " 18393 | \n",
341 | " men | \n",
342 | " 168.0 | \n",
343 | " 62.0 | \n",
344 | " 110 | \n",
345 | " 80 | \n",
346 | " normal | \n",
347 | " normal | \n",
348 | " 0 | \n",
349 | " 0 | \n",
350 | " 1 | \n",
351 | " 0 | \n",
352 | "
\n",
353 | " \n",
354 | " | 1 | \n",
355 | " 20228 | \n",
356 | " women | \n",
357 | " 156.0 | \n",
358 | " 85.0 | \n",
359 | " 140 | \n",
360 | " 90 | \n",
361 | " well_above_normal | \n",
362 | " normal | \n",
363 | " 0 | \n",
364 | " 0 | \n",
365 | " 1 | \n",
366 | " 1 | \n",
367 | "
\n",
368 | " \n",
369 | " | 2 | \n",
370 | " 18857 | \n",
371 | " women | \n",
372 | " 165.0 | \n",
373 | " 64.0 | \n",
374 | " 130 | \n",
375 | " 70 | \n",
376 | " NaN | \n",
377 | " normal | \n",
378 | " 0 | \n",
379 | " 0 | \n",
380 | " 0 | \n",
381 | " 1 | \n",
382 | "
\n",
383 | " \n",
384 | " | 3 | \n",
385 | " 17623 | \n",
386 | " men | \n",
387 | " 169.0 | \n",
388 | " 82.0 | \n",
389 | " 150 | \n",
390 | " 100 | \n",
391 | " normal | \n",
392 | " normal | \n",
393 | " 0 | \n",
394 | " 0 | \n",
395 | " 1 | \n",
396 | " 1 | \n",
397 | "
\n",
398 | " \n",
399 | " | 4 | \n",
400 | " 17474 | \n",
401 | " women | \n",
402 | " 156.0 | \n",
403 | " 56.0 | \n",
404 | " 100 | \n",
405 | " 60 | \n",
406 | " normal | \n",
407 | " normal | \n",
408 | " 0 | \n",
409 | " 0 | \n",
410 | " 0 | \n",
411 | " 0 | \n",
412 | "
\n",
413 | " \n",
414 | "
\n",
415 | "
"
416 | ],
417 | "text/plain": [
418 | " age gender height weight ap_hi ap_lo cholesterol gluc \\\n",
419 | "id \n",
420 | "0 18393 men 168.0 62.0 110 80 normal normal \n",
421 | "1 20228 women 156.0 85.0 140 90 well_above_normal normal \n",
422 | "2 18857 women 165.0 64.0 130 70 NaN normal \n",
423 | "3 17623 men 169.0 82.0 150 100 normal normal \n",
424 | "4 17474 women 156.0 56.0 100 60 normal normal \n",
425 | "\n",
426 | " smoke alco active cardio \n",
427 | "id \n",
428 | "0 0 0 1 0 \n",
429 | "1 0 0 1 1 \n",
430 | "2 0 0 0 1 \n",
431 | "3 0 0 1 1 \n",
432 | "4 0 0 0 0 "
433 | ]
434 | },
435 | "metadata": {},
436 | "output_type": "display_data"
437 | }
438 | ],
439 | "source": [
440 | "np.random.seed(seed=42)\n",
441 | "df_data = pd.read_csv(\"./cardiovascular-disease-dataset/messy/cardio_train.csv\", sep=';', index_col=\"id\")\n",
442 | "display(df_data.describe(include=\"all\"))\n",
443 | "display(df_data.head())"
444 | ]
445 | },
446 | {
447 | "cell_type": "markdown",
448 | "metadata": {},
449 | "source": [
450 | "## Initializations "
451 | ]
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "metadata": {},
456 | "source": [
457 | "Declarations of all column types and target column"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": 3,
463 | "metadata": {},
464 | "outputs": [],
465 | "source": [
466 | "category_features = [[\"cholesterol\"], [\"gluc\"]]\n",
467 | "binary_features = [[\"gender\"], [\"smoke\"], [\"alco\"], [\"active\"]]\n",
468 | "numeric_features = [[\"age\"], [\"height\"], [\"weight\"], [\"ap_hi\"], [\"ap_lo\"]]\n",
469 | "target = \"cardio\""
470 | ]
471 | },
472 | {
473 | "cell_type": "markdown",
474 | "metadata": {},
475 | "source": [
476 | "Split the data into features and labels"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": 4,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "X = df_data.copy()\n",
486 | "y = X.pop(target)"
487 | ]
488 | },
489 | {
490 | "cell_type": "markdown",
491 | "metadata": {},
492 | "source": [
493 | "## Split data to train test datasets"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 5,
499 | "metadata": {},
500 | "outputs": [],
501 | "source": [
502 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
503 | ]
504 | },
505 | {
506 | "cell_type": "markdown",
507 | "metadata": {},
508 | "source": [
509 | "## Custom Transformers"
510 | ]
511 | },
512 | {
513 | "cell_type": "markdown",
514 | "metadata": {},
515 | "source": [
516 | "We create custom transformers for feature engineering"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {},
522 | "source": [
523 | "### Blood Pressure Transformer"
524 | ]
525 | },
526 | {
527 | "cell_type": "markdown",
528 | "metadata": {},
529 | "source": [
530 | "Custom transformer responsible for the creation of a new blood pressure categorical feature based on systolic (the number at the top) and diastolic (the number at the bottom) values. \n",
531 | "The transformer will create a new categorical feature with values according to the American Heart Association ranges of blood pressure:\n",
532 | "* normal \n",
533 | "* elevated \n",
534 | "* high_pressure_stage_1 \n",
535 | "* high_pressure_stage_2 \n",
536 | "* hypertensive_crisis \n",
537 | " \n",
538 | " \n",
539 | " \n",
540 | "Photo from [American Heart Association](https://www.heart.org/-/media/data-import/downloadables/pe-abh-what-is-high-blood-pressure-ucm_300310.pdf?la=en&hash=CAC0F1D377BDB7BC3870993918226869524AAC3D)"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 6,
546 | "metadata": {},
547 | "outputs": [],
548 | "source": [
549 | "class BloodPressureTransformer(BaseEstimator, TransformerMixin):\n",
550 | " \n",
551 | " def __init__(self):\n",
552 | " \n",
553 | " # Systolic and diastolic blood pressure ranges based on the American Heart Association\n",
554 | " self.systolic_ranges = [-np.inf, 119, 129, 139, 180, np.inf]\n",
555 | " self.diastolic_ranges = [-np.inf, 79, 89, 120, np.inf]\n",
556 | " \n",
557 | " # Blood pressure categories\n",
558 | " self.blood_pressure_category = [\"normal\", \"elevated\", \"high_pressure_stage_1\", \"high_pressure_stage_2\", \"hypertensive_crisis\"]\n",
559 | " \n",
560 | " def fit(self, X, y=None):\n",
561 | " return self\n",
562 | " \n",
563 | " def transform(self, X):\n",
564 | " # Copy the data so we will not change the original instance\n",
565 | " df_blood_pressure = X.copy()\n",
566 | " \n",
567 | " # Break down ranges of systolic values to categories\n",
568 | " df_blood_pressure[\"systolic\"] = pd.cut(df_blood_pressure[\"ap_hi\"], self.systolic_ranges, labels=[\"<120\", \"120-129\", \"130-139\", \"140-180\", \">180\"])\n",
569 | " \n",
570 | " # Break down ranges of diastolic values to categories\n",
571 | " df_blood_pressure[\"diastolic\"] = pd.cut(df_blood_pressure[\"ap_lo\"], self.diastolic_ranges, labels=[\"<79\", \"80-89\", \"90-120\", \">120\"])\n",
572 | " \n",
573 | " # Combine ranges from systolic and diastolic features to determine the category of the blood pressure feature\n",
574 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"<120\") &\n",
575 | " (df_blood_pressure[\"diastolic\"] == \"<79\"), \"blood_pressure\"] = self.blood_pressure_category[0]\n",
576 | " \n",
577 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"120-129\") &\n",
578 | " (df_blood_pressure[\"diastolic\"] == \"<79\"), \"blood_pressure\"] = self.blood_pressure_category[1]\n",
579 | " \n",
580 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"130-139\") |\n",
581 | " (df_blood_pressure[\"diastolic\"] == \"80-89\"), \"blood_pressure\"] = self.blood_pressure_category[2]\n",
582 | " \n",
583 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \"140-180\") |\n",
584 | " (df_blood_pressure[\"diastolic\"] == \"90-120\"), \"blood_pressure\"] = self.blood_pressure_category[3]\n",
585 | " \n",
586 | " df_blood_pressure.loc[(df_blood_pressure[\"systolic\"] == \">180\") |\n",
587 | " (df_blood_pressure[\"diastolic\"] == \">120\"), \"blood_pressure\"] = self.blood_pressure_category[4]\n",
588 | " \n",
589 | " # Return blood pressure feature as a dataframe with one column\n",
590 | " return df_blood_pressure[[\"blood_pressure\"]]"
591 | ]
592 | },
593 | {
594 | "cell_type": "markdown",
595 | "metadata": {},
596 | "source": [
597 | "### Unhealty Lifestyle Transformer"
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "metadata": {},
603 | "source": [
604 | "Custom transformer responsible for the creation of a new \"unhealty lifestyle\" feature. \n",
605 | "This is a boolean feature representing the use of cigarettes, alcohol, and physical inactivity. "
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": 7,
611 | "metadata": {},
612 | "outputs": [],
613 | "source": [
614 | "class UnhealtyLifestyleTransformer(BaseEstimator, TransformerMixin):\n",
615 | " \n",
616 | " def fit(self, X, y=None):\n",
617 | " return self\n",
618 | " \n",
619 | " def transform(self, X):\n",
620 | " # Copy the data so we will not change the original instance\n",
621 | " df_unhealty_lifestyle = X.copy()\n",
622 | " \n",
623 | " # If you smoke or use alcohol or don't do physical activity, you maintain an unhealty lifestyle!\n",
624 | " df_unhealty_lifestyle[\"unhealty_lifestyle\"] = df_unhealty_lifestyle[\"smoke\"] | df_unhealty_lifestyle[\"alco\"] | (1 - df_unhealty_lifestyle[\"active\"])\n",
625 | " \n",
626 | " # Return unhealty lifestyle feature as a dataframe with one column\n",
627 | " return df_unhealty_lifestyle[[\"unhealty_lifestyle\"]]"
628 | ]
629 | },
630 | {
631 | "cell_type": "markdown",
632 | "metadata": {},
633 | "source": [
634 | "## Definition of DataFrameMapper transformers"
635 | ]
636 | },
637 | {
638 | "cell_type": "markdown",
639 | "metadata": {},
640 | "source": [
641 | "Now we will define the pipeline of transformations and the raw features we need to complete the creation and processing of the new features and the original features. \n",
642 | "We will pass this to the DataFrameMapper class of the sklearn-pandas package."
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": 8,
648 | "metadata": {},
649 | "outputs": [
650 | {
651 | "data": {
652 | "text/plain": [
653 | "(['ap_hi', 'ap_lo'],\n",
654 | " [BloodPressureTransformer(),\n",
655 | " SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
656 | " strategy='most_frequent', verbose=0),\n",
657 | " OneHotEncoder(categorical_features=None, categories=None,\n",
658 | " dtype=, handle_unknown='error',\n",
659 | " n_values=None, sparse=True)],\n",
660 | " {'alias': 'blood_pressure'})"
661 | ]
662 | },
663 | "execution_count": 8,
664 | "metadata": {},
665 | "output_type": "execute_result"
666 | }
667 | ],
668 | "source": [
669 | "# Input features \"ap_hi\", \"ap_lo\".\n",
670 | "# Steps:\n",
671 | "# BloodPressureTransformer - create blood pressure feature based on \"ap_hi\", \"ap_lo\".\n",
672 | "# SimpleImputer - fill nans with the most frequent value.\n",
673 | "# OneHotEncoder - encode categorical values as a one-hot numeric array.\n",
674 | "gen_blood_pressure = (\n",
675 | " [\"ap_hi\", \"ap_lo\"],\n",
676 | " [\n",
677 | " BloodPressureTransformer(),\n",
678 | " SimpleImputer(strategy=\"most_frequent\"),\n",
679 | " OneHotEncoder()\n",
680 | " ],\n",
681 | " {\"alias\": \"blood_pressure\"}\n",
682 | ")\n",
683 | "\n",
684 | "gen_blood_pressure"
685 | ]
686 | },
687 | {
688 | "cell_type": "code",
689 | "execution_count": 9,
690 | "metadata": {},
691 | "outputs": [
692 | {
693 | "data": {
694 | "text/plain": [
695 | "(['smoke', 'alco', 'active'],\n",
696 | " [UnhealtyLifestyleTransformer(),\n",
697 | " SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
698 | " strategy='most_frequent', verbose=0)],\n",
699 | " {'alias': 'unhealty_lifestyle'})"
700 | ]
701 | },
702 | "execution_count": 9,
703 | "metadata": {},
704 | "output_type": "execute_result"
705 | }
706 | ],
707 | "source": [
708 | "# Input features [\"smoke\", \"alco\", \"active\"].\n",
709 | "# Steps:\n",
710 | "# UnhealtyLifestyleTransformer - create unhealty lifestyle feature based on \"smoke\", \"alco\", \"active\".\n",
711 | "# SimpleImputer - fill nans with the most frequent value.\n",
712 | "gen_unhealty_lifestyle = (\n",
713 | " [\"smoke\", \"alco\", \"active\"],\n",
714 | " [\n",
715 | " UnhealtyLifestyleTransformer(),\n",
716 | " SimpleImputer(strategy=\"most_frequent\")\n",
717 | " ],\n",
718 | " {\"alias\": \"unhealty_lifestyle\"}\n",
719 | ")\n",
720 | "\n",
721 | "gen_unhealty_lifestyle"
722 | ]
723 | },
724 | {
725 | "cell_type": "markdown",
726 | "metadata": {},
727 | "source": [
728 | "### Apply the same transformers for multiple columns with gen_features"
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": 10,
734 | "metadata": {},
735 | "outputs": [
736 | {
737 | "data": {
738 | "text/plain": [
739 | "[(['cholesterol'],\n",
740 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
741 | " strategy='most_frequent', verbose=0),\n",
742 | " OneHotEncoder(categorical_features=None, categories=None,\n",
743 | " dtype=, handle_unknown='error',\n",
744 | " n_values=None, sparse=True)]),\n",
745 | " (['gluc'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
746 | " strategy='most_frequent', verbose=0),\n",
747 | " OneHotEncoder(categorical_features=None, categories=None,\n",
748 | " dtype=, handle_unknown='error',\n",
749 | " n_values=None, sparse=True)])]"
750 | ]
751 | },
752 | "execution_count": 10,
753 | "metadata": {},
754 | "output_type": "execute_result"
755 | }
756 | ],
757 | "source": [
758 | "# Input features [[\"cholesterol\"], [\"gluc\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
759 | "# Steps:\n",
760 | "# SimpleImputer - fill nans with the most frequent value.\n",
761 | "# OneHotEncoder - encode categorical values as a one-hot numeric array.\n",
762 | "gen_category = gen_features(\n",
763 | " columns=category_features,\n",
764 | " classes=[\n",
765 | " {\n",
766 | " \"class\": SimpleImputer,\n",
767 | " \"strategy\": \"most_frequent\"\n",
768 | " },\n",
769 | " {\n",
770 | " \"class\": OneHotEncoder\n",
771 | " }\n",
772 | " ]\n",
773 | ")\n",
774 | "\n",
775 | "gen_category"
776 | ]
777 | },
778 | {
779 | "cell_type": "code",
780 | "execution_count": 11,
781 | "metadata": {},
782 | "outputs": [
783 | {
784 | "data": {
785 | "text/plain": [
786 | "[(['gender'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
787 | " strategy='most_frequent', verbose=0),\n",
788 | " OrdinalEncoder(categories='auto', dtype=)]),\n",
789 | " (['smoke'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
790 | " strategy='most_frequent', verbose=0),\n",
791 | " OrdinalEncoder(categories='auto', dtype=)]),\n",
792 | " (['alco'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
793 | " strategy='most_frequent', verbose=0),\n",
794 | " OrdinalEncoder(categories='auto', dtype=)]),\n",
795 | " (['active'], [SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
796 | " strategy='most_frequent', verbose=0),\n",
797 | " OrdinalEncoder(categories='auto', dtype=)])]"
798 | ]
799 | },
800 | "execution_count": 11,
801 | "metadata": {},
802 | "output_type": "execute_result"
803 | }
804 | ],
805 | "source": [
806 | "# Input features [[\"gender\"], [\"smoke\"], [\"alco\"], [\"active\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
807 | "# Steps:\n",
808 | "# SimpleImputer - fill nans with the most frequent value.\n",
809 | "# OrdinalEncoder - encode categorical features as an integer array.\n",
810 | "gen_binary = gen_features(\n",
811 | " columns=binary_features,\n",
812 | " classes=[\n",
813 | " {\n",
814 | " \"class\": SimpleImputer,\n",
815 | " \"strategy\": \"most_frequent\"\n",
816 | " },\n",
817 | " {\n",
818 | " \"class\": OrdinalEncoder\n",
819 | " }\n",
820 | " ]\n",
821 | ")\n",
822 | "\n",
823 | "gen_binary"
824 | ]
825 | },
826 | {
827 | "cell_type": "code",
828 | "execution_count": 12,
829 | "metadata": {},
830 | "outputs": [
831 | {
832 | "data": {
833 | "text/plain": [
834 | "[(['age'],\n",
835 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
836 | " verbose=0),\n",
837 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
838 | " (['height'],\n",
839 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
840 | " verbose=0),\n",
841 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
842 | " (['weight'],\n",
843 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
844 | " verbose=0),\n",
845 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
846 | " (['ap_hi'],\n",
847 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
848 | " verbose=0),\n",
849 | " StandardScaler(copy=True, with_mean=True, with_std=True)]),\n",
850 | " (['ap_lo'],\n",
851 | " [SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
852 | " verbose=0),\n",
853 | " StandardScaler(copy=True, with_mean=True, with_std=True)])]"
854 | ]
855 | },
856 | "execution_count": 12,
857 | "metadata": {},
858 | "output_type": "execute_result"
859 | }
860 | ],
861 | "source": [
862 | "# Input features [[\"age\"], [\"height\"], [\"weight\"], [\"ap_hi\"], [\"ap_lo\"]] (The columns are now list of lists because we want to send 2-dimensional DataFrame to each of the transformers).\n",
863 | "# Steps:\n",
864 | "# SimpleImputer - fill nans with the mean value.\n",
865 | "# StandardScaler - standardize features by removing the mean and scaling to unit variance.\n",
866 | "gen_numeric = gen_features(\n",
867 | " columns=numeric_features,\n",
868 | " classes=[\n",
869 | " {\n",
870 | " \"class\": SimpleImputer,\n",
871 | " \"strategy\": \"mean\"\n",
872 | " },\n",
873 | " {\n",
874 | " \"class\": StandardScaler\n",
875 | " }\n",
876 | " ]\n",
877 | ")\n",
878 | "\n",
879 | "gen_numeric"
880 | ]
881 | },
882 | {
883 | "cell_type": "markdown",
884 | "metadata": {},
885 | "source": [
886 | "### DataFrameMapper Construction"
887 | ]
888 | },
889 | {
890 | "cell_type": "markdown",
891 | "metadata": {},
892 | "source": [
893 | "Now we will define the course of action of the DataFrameMapper and indicate that the input and output will be Pandas Dataframe."
894 | ]
895 | },
896 | {
897 | "cell_type": "code",
898 | "execution_count": 13,
899 | "metadata": {},
900 | "outputs": [],
901 | "source": [
902 | "preprocess_mapper = DataFrameMapper(\n",
903 | " [\n",
904 | " gen_blood_pressure,\n",
905 | " gen_unhealty_lifestyle,\n",
906 | " *gen_category,\n",
907 | " *gen_binary,\n",
908 | " *gen_numeric,\n",
909 | " ],\n",
910 | " input_df=True,\n",
911 | " df_out=True\n",
912 | ")"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "execution_count": 14,
918 | "metadata": {},
919 | "outputs": [
920 | {
921 | "data": {
922 | "text/html": [
923 | "\n",
924 | "\n",
937 | "
\n",
938 | " \n",
939 | " \n",
940 | " | \n",
941 | " blood_pressure_x0_elevated | \n",
942 | " blood_pressure_x0_high_pressure_stage_1 | \n",
943 | " blood_pressure_x0_high_pressure_stage_2 | \n",
944 | " blood_pressure_x0_hypertensive_crisis | \n",
945 | " blood_pressure_x0_normal | \n",
946 | " unhealty_lifestyle | \n",
947 | " cholesterol_x0_above_normal | \n",
948 | " cholesterol_x0_normal | \n",
949 | " cholesterol_x0_well_above_normal | \n",
950 | " gluc_x0_above_normal | \n",
951 | " ... | \n",
952 | " gluc_x0_well_above_normal | \n",
953 | " gender | \n",
954 | " smoke | \n",
955 | " alco | \n",
956 | " active | \n",
957 | " age | \n",
958 | " height | \n",
959 | " weight | \n",
960 | " ap_hi | \n",
961 | " ap_lo | \n",
962 | "
\n",
963 | " \n",
964 | " | id | \n",
965 | " | \n",
966 | " | \n",
967 | " | \n",
968 | " | \n",
969 | " | \n",
970 | " | \n",
971 | " | \n",
972 | " | \n",
973 | " | \n",
974 | " | \n",
975 | " | \n",
976 | " | \n",
977 | " | \n",
978 | " | \n",
979 | " | \n",
980 | " | \n",
981 | " | \n",
982 | " | \n",
983 | " | \n",
984 | " | \n",
985 | " | \n",
986 | "
\n",
987 | " \n",
988 | " \n",
989 | " \n",
990 | " | 98125 | \n",
991 | " 0.0 | \n",
992 | " 0.0 | \n",
993 | " 1.0 | \n",
994 | " 0.0 | \n",
995 | " 0.0 | \n",
996 | " 0 | \n",
997 | " 0.0 | \n",
998 | " 0.0 | \n",
999 | " 1.0 | \n",
1000 | " 0.0 | \n",
1001 | " ... | \n",
1002 | " 0.0 | \n",
1003 | " 1.0 | \n",
1004 | " 0.0 | \n",
1005 | " 0.0 | \n",
1006 | " 1.0 | \n",
1007 | " 0.388175 | \n",
1008 | " -0.549199 | \n",
1009 | " 1.024035e-15 | \n",
1010 | " -0.056261 | \n",
1011 | " -0.035475 | \n",
1012 | "
\n",
1013 | " \n",
1014 | " | 28510 | \n",
1015 | " 0.0 | \n",
1016 | " 1.0 | \n",
1017 | " 0.0 | \n",
1018 | " 0.0 | \n",
1019 | " 0.0 | \n",
1020 | " 1 | \n",
1021 | " 0.0 | \n",
1022 | " 0.0 | \n",
1023 | " 1.0 | \n",
1024 | " 0.0 | \n",
1025 | " ... | \n",
1026 | " 1.0 | \n",
1027 | " 0.0 | \n",
1028 | " 0.0 | \n",
1029 | " 0.0 | \n",
1030 | " 0.0 | \n",
1031 | " 1.308468 | \n",
1032 | " 0.333478 | \n",
1033 | " -6.621198e-01 | \n",
1034 | " -0.056261 | \n",
1035 | " -0.085423 | \n",
1036 | "
\n",
1037 | " \n",
1038 | " | 15795 | \n",
1039 | " 0.0 | \n",
1040 | " 0.0 | \n",
1041 | " 1.0 | \n",
1042 | " 0.0 | \n",
1043 | " 0.0 | \n",
1044 | " 0 | \n",
1045 | " 0.0 | \n",
1046 | " 1.0 | \n",
1047 | " 0.0 | \n",
1048 | " 0.0 | \n",
1049 | " ... | \n",
1050 | " 0.0 | \n",
1051 | " 1.0 | \n",
1052 | " 0.0 | \n",
1053 | " 0.0 | \n",
1054 | " 1.0 | \n",
1055 | " 1.346527 | \n",
1056 | " -0.549199 | \n",
1057 | " -5.900598e-01 | \n",
1058 | " -0.056261 | \n",
1059 | " -0.035475 | \n",
1060 | "
\n",
1061 | " \n",
1062 | " | 39560 | \n",
1063 | " 0.0 | \n",
1064 | " 0.0 | \n",
1065 | " 1.0 | \n",
1066 | " 0.0 | \n",
1067 | " 0.0 | \n",
1068 | " 0 | \n",
1069 | " 0.0 | \n",
1070 | " 0.0 | \n",
1071 | " 1.0 | \n",
1072 | " 0.0 | \n",
1073 | " ... | \n",
1074 | " 0.0 | \n",
1075 | " 1.0 | \n",
1076 | " 0.0 | \n",
1077 | " 0.0 | \n",
1078 | " 1.0 | \n",
1079 | " 1.291463 | \n",
1080 | " -0.170909 | \n",
1081 | " -1.382720e+00 | \n",
1082 | " -0.024632 | \n",
1083 | " -0.035475 | \n",
1084 | "
\n",
1085 | " \n",
1086 | " | 32677 | \n",
1087 | " 0.0 | \n",
1088 | " 0.0 | \n",
1089 | " 1.0 | \n",
1090 | " 0.0 | \n",
1091 | " 0.0 | \n",
1092 | " 0 | \n",
1093 | " 0.0 | \n",
1094 | " 0.0 | \n",
1095 | " 1.0 | \n",
1096 | " 0.0 | \n",
1097 | " ... | \n",
1098 | " 0.0 | \n",
1099 | " 1.0 | \n",
1100 | " 0.0 | \n",
1101 | " 0.0 | \n",
1102 | " 1.0 | \n",
1103 | " 0.912495 | \n",
1104 | " -0.801392 | \n",
1105 | " 7.790809e-01 | \n",
1106 | " 0.133513 | \n",
1107 | " -0.085423 | \n",
1108 | "
\n",
1109 | " \n",
1110 | "
\n",
1111 | "
5 rows × 21 columns
\n",
1112 | "
"
1113 | ],
1114 | "text/plain": [
1115 | " blood_pressure_x0_elevated blood_pressure_x0_high_pressure_stage_1 \\\n",
1116 | "id \n",
1117 | "98125 0.0 0.0 \n",
1118 | "28510 0.0 1.0 \n",
1119 | "15795 0.0 0.0 \n",
1120 | "39560 0.0 0.0 \n",
1121 | "32677 0.0 0.0 \n",
1122 | "\n",
1123 | " blood_pressure_x0_high_pressure_stage_2 \\\n",
1124 | "id \n",
1125 | "98125 1.0 \n",
1126 | "28510 0.0 \n",
1127 | "15795 1.0 \n",
1128 | "39560 1.0 \n",
1129 | "32677 1.0 \n",
1130 | "\n",
1131 | " blood_pressure_x0_hypertensive_crisis blood_pressure_x0_normal \\\n",
1132 | "id \n",
1133 | "98125 0.0 0.0 \n",
1134 | "28510 0.0 0.0 \n",
1135 | "15795 0.0 0.0 \n",
1136 | "39560 0.0 0.0 \n",
1137 | "32677 0.0 0.0 \n",
1138 | "\n",
1139 | " unhealty_lifestyle cholesterol_x0_above_normal cholesterol_x0_normal \\\n",
1140 | "id \n",
1141 | "98125 0 0.0 0.0 \n",
1142 | "28510 1 0.0 0.0 \n",
1143 | "15795 0 0.0 1.0 \n",
1144 | "39560 0 0.0 0.0 \n",
1145 | "32677 0 0.0 0.0 \n",
1146 | "\n",
1147 | " cholesterol_x0_well_above_normal gluc_x0_above_normal ... \\\n",
1148 | "id ... \n",
1149 | "98125 1.0 0.0 ... \n",
1150 | "28510 1.0 0.0 ... \n",
1151 | "15795 0.0 0.0 ... \n",
1152 | "39560 1.0 0.0 ... \n",
1153 | "32677 1.0 0.0 ... \n",
1154 | "\n",
1155 | " gluc_x0_well_above_normal gender smoke alco active age \\\n",
1156 | "id \n",
1157 | "98125 0.0 1.0 0.0 0.0 1.0 0.388175 \n",
1158 | "28510 1.0 0.0 0.0 0.0 0.0 1.308468 \n",
1159 | "15795 0.0 1.0 0.0 0.0 1.0 1.346527 \n",
1160 | "39560 0.0 1.0 0.0 0.0 1.0 1.291463 \n",
1161 | "32677 0.0 1.0 0.0 0.0 1.0 0.912495 \n",
1162 | "\n",
1163 | " height weight ap_hi ap_lo \n",
1164 | "id \n",
1165 | "98125 -0.549199 1.024035e-15 -0.056261 -0.035475 \n",
1166 | "28510 0.333478 -6.621198e-01 -0.056261 -0.085423 \n",
1167 | "15795 -0.549199 -5.900598e-01 -0.056261 -0.035475 \n",
1168 | "39560 -0.170909 -1.382720e+00 -0.024632 -0.035475 \n",
1169 | "32677 -0.801392 7.790809e-01 0.133513 -0.085423 \n",
1170 | "\n",
1171 | "[5 rows x 21 columns]"
1172 | ]
1173 | },
1174 | "execution_count": 14,
1175 | "metadata": {},
1176 | "output_type": "execute_result"
1177 | }
1178 | ],
1179 | "source": [
1180 | "preprocess_mapper.fit_transform(X_train, y_train).head()"
1181 | ]
1182 | },
1183 | {
1184 | "cell_type": "markdown",
1185 | "metadata": {},
1186 | "source": [
1187 | "### Feature Selection"
1188 | ]
1189 | },
1190 | {
1191 | "cell_type": "markdown",
1192 | "metadata": {},
1193 | "source": [
1194 | "In the feature selection step, we specify that the input columns are the transform columns of the previous preprocessing step and then specify the feature selection transformer we want to use. \n",
1195 | "In our case I choose SelectFromModel with RandomForestClassifier to select the most important features based on feature_importances_ attribute of RandomForestClassifier. "
1196 | ]
1197 | },
1198 | {
1199 | "cell_type": "code",
1200 | "execution_count": 15,
1201 | "metadata": {},
1202 | "outputs": [
1203 | {
1204 | "data": {
1205 | "text/plain": [
1206 | "['blood_pressure_x0_elevated',\n",
1207 | " 'blood_pressure_x0_high_pressure_stage_1',\n",
1208 | " 'blood_pressure_x0_high_pressure_stage_2',\n",
1209 | " 'blood_pressure_x0_hypertensive_crisis',\n",
1210 | " 'blood_pressure_x0_normal',\n",
1211 | " 'unhealty_lifestyle',\n",
1212 | " 'cholesterol_x0_above_normal',\n",
1213 | " 'cholesterol_x0_normal',\n",
1214 | " 'cholesterol_x0_well_above_normal',\n",
1215 | " 'gluc_x0_above_normal',\n",
1216 | " 'gluc_x0_normal',\n",
1217 | " 'gluc_x0_well_above_normal',\n",
1218 | " 'gender',\n",
1219 | " 'smoke',\n",
1220 | " 'alco',\n",
1221 | " 'active',\n",
1222 | " 'age',\n",
1223 | " 'height',\n",
1224 | " 'weight',\n",
1225 | " 'ap_hi',\n",
1226 | " 'ap_lo']"
1227 | ]
1228 | },
1229 | "execution_count": 15,
1230 | "metadata": {},
1231 | "output_type": "execute_result"
1232 | }
1233 | ],
1234 | "source": [
1235 | "preprocess_mapper.transformed_names_"
1236 | ]
1237 | },
1238 | {
1239 | "cell_type": "code",
1240 | "execution_count": 16,
1241 | "metadata": {},
1242 | "outputs": [],
1243 | "source": [
1244 | "feature_selection = DataFrameMapper(\n",
1245 | " [(\n",
1246 | " preprocess_mapper.transformed_names_,\n",
1247 | " SelectFromModel(RandomForestClassifier(n_estimators=100))\n",
1248 | " )]\n",
1249 | ")"
1250 | ]
1251 | },
1252 | {
1253 | "cell_type": "markdown",
1254 | "metadata": {},
1255 | "source": [
1256 | "### Sklearn Pipeline"
1257 | ]
1258 | },
1259 | {
1260 | "cell_type": "markdown",
1261 | "metadata": {},
1262 | "source": [
1263 | "Now we piece together all previous definitions to define the full pipeline:\n",
1264 | "* preprocessing\n",
1265 | "* feature selection\n",
1266 | "* estimator"
1267 | ]
1268 | },
1269 | {
1270 | "cell_type": "code",
1271 | "execution_count": 17,
1272 | "metadata": {},
1273 | "outputs": [],
1274 | "source": [
1275 | "pipeline = Pipeline(steps=[\n",
1276 | " (\"preprocess\", preprocess_mapper),\n",
1277 | " (\"feature_selection\", feature_selection),\n",
1278 | " (\"estimator\", RandomForestClassifier(n_estimators=100, max_depth=6))\n",
1279 | "])"
1280 | ]
1281 | },
1282 | {
1283 | "cell_type": "code",
1284 | "execution_count": 18,
1285 | "metadata": {},
1286 | "outputs": [],
1287 | "source": [
1288 | "pipeline.fit(X_train, y_train)\n",
1289 | "preds = pipeline.predict(X_test)"
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "markdown",
1294 | "metadata": {},
1295 | "source": [
1296 | "## Grid Search"
1297 | ]
1298 | },
1299 | {
1300 | "cell_type": "markdown",
1301 | "metadata": {},
1302 | "source": [
1303 | "Next I created a grid search object which includes the original pipeline. \n",
1304 | "When I then call fit, the transformations are applied to the data, before a cross-validated grid-search is performed over the parameter grid."
1305 | ]
1306 | },
1307 | {
1308 | "cell_type": "code",
1309 | "execution_count": 19,
1310 | "metadata": {},
1311 | "outputs": [],
1312 | "source": [
1313 | "pipeline = Pipeline(steps=[\n",
1314 | " (\"preprocess\", preprocess_mapper),\n",
1315 | " (\"feature_selection\", feature_selection),\n",
1316 | " (\"estimator\", RandomForestClassifier())\n",
1317 | "])"
1318 | ]
1319 | },
1320 | {
1321 | "cell_type": "code",
1322 | "execution_count": 20,
1323 | "metadata": {},
1324 | "outputs": [],
1325 | "source": [
1326 | "param_grid = { \n",
1327 | " \"estimator__n_estimators\": [200, 500],\n",
1328 | " \"estimator__max_features\": ['auto', 'sqrt', 'log2'],\n",
1329 | " \"estimator__max_depth\": [4, 5, 6, 7, 8],\n",
1330 | " \"estimator__criterion\":['gini', 'entropy']\n",
1331 | "}"
1332 | ]
1333 | },
1334 | {
1335 | "cell_type": "code",
1336 | "execution_count": 21,
1337 | "metadata": {},
1338 | "outputs": [
1339 | {
1340 | "data": {
1341 | "text/plain": [
1342 | "GridSearchCV(cv=5, error_score='raise-deprecating',\n",
1343 | " estimator=Pipeline(memory=None,\n",
1344 | " steps=[('preprocess', DataFrameMapper(default=False, df_out=True,\n",
1345 | " features=[(['ap_hi', 'ap_lo'], [BloodPressureTransformer(), SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
1346 | " strategy='most_frequent', verbose=0), OneHotEncoder(categorical_features=None, categories=None,\n",
1347 | " ...obs=None,\n",
1348 | " oob_score=False, random_state=None, verbose=0,\n",
1349 | " warm_start=False))]),\n",
1350 | " fit_params=None, iid='warn', n_jobs=-1,\n",
1351 | " param_grid={'estimator__n_estimators': [200, 500], 'estimator__max_features': ['auto', 'sqrt', 'log2'], 'estimator__max_depth': [4, 5, 6, 7, 8], 'estimator__criterion': ['gini', 'entropy']},\n",
1352 | " pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
1353 | " scoring=None, verbose=0)"
1354 | ]
1355 | },
1356 | "execution_count": 21,
1357 | "metadata": {},
1358 | "output_type": "execute_result"
1359 | }
1360 | ],
1361 | "source": [
1362 | "gscv_estimator = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1)\n",
1363 | "gscv_estimator.fit(X_train, y_train)"
1364 | ]
1365 | },
1366 | {
1367 | "cell_type": "code",
1368 | "execution_count": 22,
1369 | "metadata": {},
1370 | "outputs": [
1371 | {
1372 | "data": {
1373 | "text/plain": [
1374 | "{'estimator__criterion': 'gini',\n",
1375 | " 'estimator__max_depth': 7,\n",
1376 | " 'estimator__max_features': 'sqrt',\n",
1377 | " 'estimator__n_estimators': 500}"
1378 | ]
1379 | },
1380 | "metadata": {},
1381 | "output_type": "display_data"
1382 | }
1383 | ],
1384 | "source": [
1385 | "display(gscv_estimator.best_params_)"
1386 | ]
1387 | },
1388 | {
1389 | "cell_type": "code",
1390 | "execution_count": 23,
1391 | "metadata": {},
1392 | "outputs": [],
1393 | "source": [
1394 | "preds = gscv_estimator.predict(X_test)"
1395 | ]
1396 | }
1397 | ],
1398 | "metadata": {
1399 | "kernelspec": {
1400 | "display_name": "Python 3",
1401 | "language": "python",
1402 | "name": "python3"
1403 | },
1404 | "language_info": {
1405 | "codemirror_mode": {
1406 | "name": "ipython",
1407 | "version": 3
1408 | },
1409 | "file_extension": ".py",
1410 | "mimetype": "text/x-python",
1411 | "name": "python",
1412 | "nbconvert_exporter": "python",
1413 | "pygments_lexer": "ipython3",
1414 | "version": "3.7.1"
1415 | }
1416 | },
1417 | "nbformat": 4,
1418 | "nbformat_minor": 2
1419 | }
1420 |
--------------------------------------------------------------------------------