├── .gitignore
├── 2017db_feature_extraction.ipynb
├── LICENSE
├── NOTICE
├── README.md
├── afdb_segmentation.ipynb
├── assets
│   └── deployment_diagram.png
├── baseline.ipynb
├── bonsaiTrainer2.py
├── bonsai_example2.py
├── feature_importance.ipynb
├── pipeline.ipynb
├── process_afdb_2017_vfdb_cudb.ipynb
└── record_inference_time.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
--------------------------------------------------------------------------------
/2017db_feature_extraction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import os\n",
11 | "import numpy as np\n",
12 | "from ecgdetectors import Detectors\n",
13 | "from time import time as time\n",
14 | "import seaborn as sns\n",
15 | "from hrvanalysis import remove_outliers, remove_ectopic_beats, interpolate_nan_values, get_time_domain_features, get_geometrical_features, get_frequency_domain_features, get_csi_cvi_features, get_poincare_plot_features, get_sampen\n",
16 | "from time import time as time\n",
17 | "import warnings\n",
18 | "warnings.filterwarnings('ignore')"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "fs = 250"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 8,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "df = pd.read_csv('finaldfs/2017final_normal_rpeak_L5_S1.csv')\n",
37 | "df.reset_index(drop = True, inplace = True)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 9,
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/html": [
48 | "
\n",
49 | "\n",
62 | "
\n",
63 | " \n",
64 | " \n",
65 | " | \n",
66 | " af | \n",
67 | " 0 | \n",
68 | " 1 | \n",
69 | " 2 | \n",
70 | " 3 | \n",
71 | " 4 | \n",
72 | " 5 | \n",
73 | " 6 | \n",
74 | " 7 | \n",
75 | " 8 | \n",
76 | " ... | \n",
77 | " 1240 | \n",
78 | " 1241 | \n",
79 | " 1242 | \n",
80 | " 1243 | \n",
81 | " 1244 | \n",
82 | " 1245 | \n",
83 | " 1246 | \n",
84 | " 1247 | \n",
85 | " 1248 | \n",
86 | " 1249 | \n",
87 | "
\n",
88 | " \n",
89 | " \n",
90 | " \n",
91 | " 0 | \n",
92 | " 0 | \n",
93 | " 0 | \n",
94 | " 0 | \n",
95 | " 0 | \n",
96 | " 0 | \n",
97 | " 0 | \n",
98 | " 0 | \n",
99 | " 0 | \n",
100 | " 0 | \n",
101 | " 0 | \n",
102 | " ... | \n",
103 | " 0 | \n",
104 | " 0 | \n",
105 | " 0 | \n",
106 | " 0 | \n",
107 | " 0 | \n",
108 | " 0 | \n",
109 | " 0 | \n",
110 | " 0 | \n",
111 | " 0 | \n",
112 | " 0 | \n",
113 | "
\n",
114 | " \n",
115 | " 1 | \n",
116 | " 0 | \n",
117 | " 0 | \n",
118 | " 0 | \n",
119 | " 0 | \n",
120 | " 0 | \n",
121 | " 0 | \n",
122 | " 0 | \n",
123 | " 0 | \n",
124 | " 0 | \n",
125 | " 0 | \n",
126 | " ... | \n",
127 | " 0 | \n",
128 | " 0 | \n",
129 | " 0 | \n",
130 | " 0 | \n",
131 | " 0 | \n",
132 | " 0 | \n",
133 | " 0 | \n",
134 | " 0 | \n",
135 | " 0 | \n",
136 | " 0 | \n",
137 | "
\n",
138 | " \n",
139 | " 2 | \n",
140 | " 0 | \n",
141 | " 0 | \n",
142 | " 0 | \n",
143 | " 0 | \n",
144 | " 0 | \n",
145 | " 0 | \n",
146 | " 0 | \n",
147 | " 0 | \n",
148 | " 0 | \n",
149 | " 0 | \n",
150 | " ... | \n",
151 | " 0 | \n",
152 | " 0 | \n",
153 | " 0 | \n",
154 | " 0 | \n",
155 | " 0 | \n",
156 | " 0 | \n",
157 | " 0 | \n",
158 | " 0 | \n",
159 | " 0 | \n",
160 | " 0 | \n",
161 | "
\n",
162 | " \n",
163 | " 3 | \n",
164 | " 0 | \n",
165 | " 0 | \n",
166 | " 0 | \n",
167 | " 0 | \n",
168 | " 0 | \n",
169 | " 0 | \n",
170 | " 0 | \n",
171 | " 0 | \n",
172 | " 0 | \n",
173 | " 0 | \n",
174 | " ... | \n",
175 | " 0 | \n",
176 | " 0 | \n",
177 | " 0 | \n",
178 | " 0 | \n",
179 | " 0 | \n",
180 | " 0 | \n",
181 | " 0 | \n",
182 | " 0 | \n",
183 | " 0 | \n",
184 | " 0 | \n",
185 | "
\n",
186 | " \n",
187 | " 4 | \n",
188 | " 0 | \n",
189 | " 0 | \n",
190 | " 0 | \n",
191 | " 0 | \n",
192 | " 0 | \n",
193 | " 0 | \n",
194 | " 0 | \n",
195 | " 0 | \n",
196 | " 0 | \n",
197 | " 0 | \n",
198 | " ... | \n",
199 | " 0 | \n",
200 | " 0 | \n",
201 | " 0 | \n",
202 | " 0 | \n",
203 | " 0 | \n",
204 | " 0 | \n",
205 | " 0 | \n",
206 | " 0 | \n",
207 | " 0 | \n",
208 | " 0 | \n",
209 | "
\n",
210 | " \n",
211 | " ... | \n",
212 | " ... | \n",
213 | " ... | \n",
214 | " ... | \n",
215 | " ... | \n",
216 | " ... | \n",
217 | " ... | \n",
218 | " ... | \n",
219 | " ... | \n",
220 | " ... | \n",
221 | " ... | \n",
222 | " ... | \n",
223 | " ... | \n",
224 | " ... | \n",
225 | " ... | \n",
226 | " ... | \n",
227 | " ... | \n",
228 | " ... | \n",
229 | " ... | \n",
230 | " ... | \n",
231 | " ... | \n",
232 | " ... | \n",
233 | "
\n",
234 | " \n",
235 | " 163588 | \n",
236 | " 0 | \n",
237 | " 0 | \n",
238 | " 0 | \n",
239 | " 0 | \n",
240 | " 0 | \n",
241 | " 0 | \n",
242 | " 0 | \n",
243 | " 0 | \n",
244 | " 0 | \n",
245 | " 0 | \n",
246 | " ... | \n",
247 | " 0 | \n",
248 | " 0 | \n",
249 | " 0 | \n",
250 | " 0 | \n",
251 | " 0 | \n",
252 | " 0 | \n",
253 | " 0 | \n",
254 | " 0 | \n",
255 | " 0 | \n",
256 | " 0 | \n",
257 | "
\n",
258 | " \n",
259 | " 163589 | \n",
260 | " 0 | \n",
261 | " 0 | \n",
262 | " 0 | \n",
263 | " 0 | \n",
264 | " 0 | \n",
265 | " 0 | \n",
266 | " 0 | \n",
267 | " 0 | \n",
268 | " 0 | \n",
269 | " 0 | \n",
270 | " ... | \n",
271 | " 0 | \n",
272 | " 0 | \n",
273 | " 0 | \n",
274 | " 0 | \n",
275 | " 0 | \n",
276 | " 0 | \n",
277 | " 0 | \n",
278 | " 0 | \n",
279 | " 0 | \n",
280 | " 0 | \n",
281 | "
\n",
282 | " \n",
283 | " 163590 | \n",
284 | " 0 | \n",
285 | " 0 | \n",
286 | " 0 | \n",
287 | " 0 | \n",
288 | " 0 | \n",
289 | " 0 | \n",
290 | " 0 | \n",
291 | " 0 | \n",
292 | " 0 | \n",
293 | " 0 | \n",
294 | " ... | \n",
295 | " 0 | \n",
296 | " 0 | \n",
297 | " 0 | \n",
298 | " 0 | \n",
299 | " 0 | \n",
300 | " 0 | \n",
301 | " 0 | \n",
302 | " 0 | \n",
303 | " 0 | \n",
304 | " 0 | \n",
305 | "
\n",
306 | " \n",
307 | " 163591 | \n",
308 | " 0 | \n",
309 | " 0 | \n",
310 | " 0 | \n",
311 | " 0 | \n",
312 | " 0 | \n",
313 | " 0 | \n",
314 | " 0 | \n",
315 | " 0 | \n",
316 | " 0 | \n",
317 | " 0 | \n",
318 | " ... | \n",
319 | " 0 | \n",
320 | " 0 | \n",
321 | " 0 | \n",
322 | " 0 | \n",
323 | " 0 | \n",
324 | " 0 | \n",
325 | " 0 | \n",
326 | " 0 | \n",
327 | " 0 | \n",
328 | " 0 | \n",
329 | "
\n",
330 | " \n",
331 | " 163592 | \n",
332 | " 0 | \n",
333 | " 0 | \n",
334 | " 0 | \n",
335 | " 0 | \n",
336 | " 0 | \n",
337 | " 0 | \n",
338 | " 0 | \n",
339 | " 0 | \n",
340 | " 0 | \n",
341 | " 0 | \n",
342 | " ... | \n",
343 | " 0 | \n",
344 | " 0 | \n",
345 | " 0 | \n",
346 | " 0 | \n",
347 | " 0 | \n",
348 | " 0 | \n",
349 | " 0 | \n",
350 | " 0 | \n",
351 | " 0 | \n",
352 | " 0 | \n",
353 | "
\n",
354 | " \n",
355 | "
\n",
356 | "
163593 rows × 1251 columns
\n",
357 | "
"
358 | ],
359 | "text/plain": [
360 | " af 0 1 2 3 4 5 6 7 8 ... 1240 1241 1242 1243 1244 \\\n",
361 | "0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
362 | "1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
363 | "2 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
364 | "3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
365 | "4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
366 | "... .. .. .. .. .. .. .. .. .. .. ... ... ... ... ... ... \n",
367 | "163588 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
368 | "163589 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
369 | "163590 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
370 | "163591 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
371 | "163592 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 \n",
372 | "\n",
373 | " 1245 1246 1247 1248 1249 \n",
374 | "0 0 0 0 0 0 \n",
375 | "1 0 0 0 0 0 \n",
376 | "2 0 0 0 0 0 \n",
377 | "3 0 0 0 0 0 \n",
378 | "4 0 0 0 0 0 \n",
379 | "... ... ... ... ... ... \n",
380 | "163588 0 0 0 0 0 \n",
381 | "163589 0 0 0 0 0 \n",
382 | "163590 0 0 0 0 0 \n",
383 | "163591 0 0 0 0 0 \n",
384 | "163592 0 0 0 0 0 \n",
385 | "\n",
386 | "[163593 rows x 1251 columns]"
387 | ]
388 | },
389 | "execution_count": 9,
390 | "metadata": {},
391 | "output_type": "execute_result"
392 | }
393 | ],
394 | "source": [
395 | "df"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 10,
401 | "metadata": {},
402 | "outputs": [
403 | {
404 | "name": "stdout",
405 | "output_type": "stream",
406 | "text": [
407 | "CPU times: user 8min 44s, sys: 1.72 s, total: 8min 46s\n",
408 | "Wall time: 8min 46s\n"
409 | ]
410 | }
411 | ],
412 | "source": [
413 | "%%time\n",
414 | "def featureExtraction(df, fs):\n",
415 | " df1 = df.af.copy()\n",
416 | " df.drop(columns = ['af'], inplace = True)\n",
417 | " final = list()\n",
418 | " for i in range(len(df)):\n",
419 | " peaks = []\n",
420 | " peaks = df.loc[i][df.loc[i] == 1].index.values\n",
421 | " peaks = [int(x) for x in peaks]\n",
422 | " if (len(peaks) == 0 or len(peaks) == 1):\n",
423 | " continue\n",
424 | " af = df1.loc[i]\n",
425 | " rr = np.diff(peaks) / fs \n",
426 | " hr = 60 / rr\n",
427 | " rmssd = np.sqrt(np.mean(np.square(rr)))\n",
428 | " sdnn = np.std(rr)\n",
429 | " mean_rr = np.mean(rr)\n",
430 | " mean_hr = np.mean(hr)\n",
431 | " std_hr = np.std(hr)\n",
432 | " min_hr = np.min(hr)\n",
433 | " max_hr = np.max(hr)\n",
434 | " single_window_df = pd.DataFrame([[rmssd, sdnn, mean_rr, mean_hr, std_hr, min_hr, max_hr]], columns = ['RMSSD', 'STDNN', 'MEAN_RR', 'MEAN_HR', 'STD_HR', 'MIN_HR', 'MAX_HR'])\n",
435 | " single_window_df.insert(0, 'af', af)\n",
436 | " final.append(single_window_df)\n",
437 | " del single_window_df\n",
438 | " return pd.concat(final)\n",
439 | "\n",
440 | "finaldf = featureExtraction(df, fs)\n",
441 | "finaldf.reset_index(drop = True, inplace = True)"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": 11,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "finaldf1 = finaldf.copy()"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 14,
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/html": [
461 | "\n",
462 | "\n",
475 | "
\n",
476 | " \n",
477 | " \n",
478 | " | \n",
479 | " af | \n",
480 | " RMSSD | \n",
481 | " STDNN | \n",
482 | " MEAN_RR | \n",
483 | " MEAN_HR | \n",
484 | " STD_HR | \n",
485 | " MIN_HR | \n",
486 | " MAX_HR | \n",
487 | "
\n",
488 | " \n",
489 | " \n",
490 | " \n",
491 | " 0 | \n",
492 | " 0 | \n",
493 | " 0.763275 | \n",
494 | " 0.036538 | \n",
495 | " 0.7624 | \n",
496 | " 78.881404 | \n",
497 | " 3.810557 | \n",
498 | " 74.257426 | \n",
499 | " 83.798883 | \n",
500 | "
\n",
501 | " \n",
502 | " 1 | \n",
503 | " 0 | \n",
504 | " 0.784564 | \n",
505 | " 0.029755 | \n",
506 | " 0.7840 | \n",
507 | " 76.646108 | \n",
508 | " 3.044678 | \n",
509 | " 73.891626 | \n",
510 | " 82.872928 | \n",
511 | "
\n",
512 | " \n",
513 | " 2 | \n",
514 | " 0 | \n",
515 | " 0.799334 | \n",
516 | " 0.014621 | \n",
517 | " 0.7992 | \n",
518 | " 75.100714 | \n",
519 | " 1.401583 | \n",
520 | " 73.891626 | \n",
521 | " 77.720207 | \n",
522 | "
\n",
523 | " \n",
524 | " 3 | \n",
525 | " 0 | \n",
526 | " 0.793752 | \n",
527 | " 0.015513 | \n",
528 | " 0.7936 | \n",
529 | " 75.633844 | \n",
530 | " 1.483975 | \n",
531 | " 73.891626 | \n",
532 | " 77.720207 | \n",
533 | "
\n",
534 | " \n",
535 | " 4 | \n",
536 | " 0 | \n",
537 | " 0.800080 | \n",
538 | " 0.011314 | \n",
539 | " 0.8000 | \n",
540 | " 75.015176 | \n",
541 | " 1.073150 | \n",
542 | " 73.891626 | \n",
543 | " 76.923077 | \n",
544 | "
\n",
545 | " \n",
546 | " ... | \n",
547 | " ... | \n",
548 | " ... | \n",
549 | " ... | \n",
550 | " ... | \n",
551 | " ... | \n",
552 | " ... | \n",
553 | " ... | \n",
554 | " ... | \n",
555 | "
\n",
556 | " \n",
557 | " 162623 | \n",
558 | " 0 | \n",
559 | " 0.922219 | \n",
560 | " 0.020100 | \n",
561 | " 0.9220 | \n",
562 | " 65.107552 | \n",
563 | " 1.451010 | \n",
564 | " 63.829787 | \n",
565 | " 67.567568 | \n",
566 | "
\n",
567 | " \n",
568 | " 162624 | \n",
569 | " 0 | \n",
570 | " 0.923231 | \n",
571 | " 0.020664 | \n",
572 | " 0.9230 | \n",
573 | " 65.038773 | \n",
574 | " 1.489974 | \n",
575 | " 63.829787 | \n",
576 | " 67.567568 | \n",
577 | "
\n",
578 | " \n",
579 | " 162625 | \n",
580 | " 0 | \n",
581 | " 0.924251 | \n",
582 | " 0.021541 | \n",
583 | " 0.9240 | \n",
584 | " 64.971157 | \n",
585 | " 1.548300 | \n",
586 | " 63.559322 | \n",
587 | " 67.567568 | \n",
588 | "
\n",
589 | " \n",
590 | " 162626 | \n",
591 | " 0 | \n",
592 | " 0.933032 | \n",
593 | " 0.007681 | \n",
594 | " 0.9330 | \n",
595 | " 64.313031 | \n",
596 | " 0.528290 | \n",
597 | " 63.559322 | \n",
598 | " 64.935065 | \n",
599 | "
\n",
600 | " \n",
601 | " 162627 | \n",
602 | " 0 | \n",
603 | " 0.918481 | \n",
604 | " 0.029732 | \n",
605 | " 0.9180 | \n",
606 | " 65.430344 | \n",
607 | " 2.188336 | \n",
608 | " 63.559322 | \n",
609 | " 69.124424 | \n",
610 | "
\n",
611 | " \n",
612 | "
\n",
613 | "
162628 rows × 8 columns
\n",
614 | "
"
615 | ],
616 | "text/plain": [
617 | " af RMSSD STDNN MEAN_RR MEAN_HR STD_HR MIN_HR \\\n",
618 | "0 0 0.763275 0.036538 0.7624 78.881404 3.810557 74.257426 \n",
619 | "1 0 0.784564 0.029755 0.7840 76.646108 3.044678 73.891626 \n",
620 | "2 0 0.799334 0.014621 0.7992 75.100714 1.401583 73.891626 \n",
621 | "3 0 0.793752 0.015513 0.7936 75.633844 1.483975 73.891626 \n",
622 | "4 0 0.800080 0.011314 0.8000 75.015176 1.073150 73.891626 \n",
623 | "... .. ... ... ... ... ... ... \n",
624 | "162623 0 0.922219 0.020100 0.9220 65.107552 1.451010 63.829787 \n",
625 | "162624 0 0.923231 0.020664 0.9230 65.038773 1.489974 63.829787 \n",
626 | "162625 0 0.924251 0.021541 0.9240 64.971157 1.548300 63.559322 \n",
627 | "162626 0 0.933032 0.007681 0.9330 64.313031 0.528290 63.559322 \n",
628 | "162627 0 0.918481 0.029732 0.9180 65.430344 2.188336 63.559322 \n",
629 | "\n",
630 | " MAX_HR \n",
631 | "0 83.798883 \n",
632 | "1 82.872928 \n",
633 | "2 77.720207 \n",
634 | "3 77.720207 \n",
635 | "4 76.923077 \n",
636 | "... ... \n",
637 | "162623 67.567568 \n",
638 | "162624 67.567568 \n",
639 | "162625 67.567568 \n",
640 | "162626 64.935065 \n",
641 | "162627 69.124424 \n",
642 | "\n",
643 | "[162628 rows x 8 columns]"
644 | ]
645 | },
646 | "execution_count": 14,
647 | "metadata": {},
648 | "output_type": "execute_result"
649 | }
650 | ],
651 | "source": [
652 | "finaldf1"
653 | ]
654 | },
655 | {
656 | "cell_type": "code",
657 | "execution_count": 13,
658 | "metadata": {},
659 | "outputs": [
660 | {
661 | "data": {
662 | "text/plain": [
663 | "0 141350\n",
664 | "1 21278\n",
665 | "Name: af, dtype: int64"
666 | ]
667 | },
668 | "execution_count": 13,
669 | "metadata": {},
670 | "output_type": "execute_result"
671 | }
672 | ],
673 | "source": [
674 | "finaldf1.af.value_counts()"
675 | ]
676 | },
677 | {
678 | "cell_type": "code",
679 | "execution_count": 64,
680 | "metadata": {},
681 | "outputs": [],
682 | "source": [
683 | "finaldf1 = finaldf1.drop(finaldf1[finaldf1['af'].eq(0)].sample(10000).index)"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": 65,
689 | "metadata": {},
690 | "outputs": [],
691 | "source": [
692 | "finaldf1.reset_index(drop = True, inplace = True)"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": 16,
698 | "metadata": {},
699 | "outputs": [],
700 | "source": [
701 | "finaldf1.to_csv('/hdd/physio/af2/finaldfs/2017final_normal_feats_L5_S1.csv')"
702 | ]
703 | }
704 | ],
705 | "metadata": {
706 | "kernelspec": {
707 | "display_name": "Python [conda env:physio]",
708 | "language": "python",
709 | "name": "conda-env-physio-py"
710 | },
711 | "language_info": {
712 | "codemirror_mode": {
713 | "name": "ipython",
714 | "version": 3
715 | },
716 | "file_extension": ".py",
717 | "mimetype": "text/x-python",
718 | "name": "python",
719 | "nbconvert_exporter": "python",
720 | "pygments_lexer": "ipython3",
721 | "version": "3.7.5"
722 | }
723 | },
724 | "nbformat": 4,
725 | "nbformat_minor": 4
726 | }
727 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Vishal Nagarajan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | The code in this repository incorporates materials from https://github.com/microsoft/EdgeML which has been licensed under the MIT License (c) Microsoft Corporation.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Optimized Arrhythmia Detection Pipeline using Machine Learning for Ultra-Edge Devices
2 |
3 | ## Description
4 |
5 | This repository contains the implementation of the paper "End-to-End Optimized Arrhythmia Detection Pipeline using Machine Learning for Ultra-Edge Devices", **published** at the [20th IEEE International Conference on Machine Learning and Applications (ICMLA21), Pasadena, CA, USA](https://www.icmla-conference.org/icmla21/). All three datasets can be downloaded from [PhysioNet](https://physionet.org/about/database/):
6 |
7 | 1. [MIT-BIH Atrial Fibrillation DataBase (AFDB)](https://physionet.org/content/afdb/1.0.0/)
8 | 2. [AF Classification from a Short Single Lead ECG Recording - The PhysioNet Computing in Cardiology Challenge 2017 (2017/CHDB)](https://physionet.org/content/challenge-2017/1.0.0/)
9 | 3. [MIT-BIH Malignant Ventricular Ectopy Database (VFDB)](https://physionet.org/content/vfdb/1.0.0/)
10 |
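11 | One convenient way to fetch and read a record straight from PhysioNet is the [wfdb](https://github.com/MIT-LCP/wfdb-python) Python package. The snippet below is a minimal sketch for illustration only: `wfdb` is not a dependency of this repository, and `04015` is simply one AFDB record name.
12 |
13 | ```python
14 | import wfdb
15 |
16 | # Stream one AFDB record and its rhythm annotations from PhysioNet
17 | record = wfdb.rdrecord('04015', pn_dir='afdb')   # ECG signals + metadata
18 | ann = wfdb.rdann('04015', 'atr', pn_dir='afdb')  # rhythm-change annotations
19 |
20 | print(record.fs)               # sampling frequency (250 Hz for AFDB)
21 | print(record.p_signal.shape)   # (n_samples, n_channels)
22 | print(ann.aux_note[:5])        # rhythm labels such as '(AFIB'
23 | ```
24 |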
11 | ## Architecture
12 |
13 | ![Deployment diagram](assets/deployment_diagram.png)
16 |
17 | ## Publication Link
18 |
19 | Link: https://ieeexplore.ieee.org/document/9680091
20 |
21 | Citation:
22 | ```
23 | @INPROCEEDINGS{9680091,
24 | author={Sideshwar, J B and Sachin Krishan, T and Nagarajan, Vishal and S, Shanthakumar and Vijayaraghavan, Vineeth},
25 | booktitle={2021 20th IEEE International Conference on Machine Learning and Applications (ICMLA)},
26 | title={End-to-End Optimized Arrhythmia Detection Pipeline using Machine Learning for Ultra-Edge Devices},
27 | year={2021},
28 | volume={},
29 | number={},
30 | pages={1501-1506},
31 | doi={10.1109/ICMLA52953.2021.00242}}
32 | ```
33 |
--------------------------------------------------------------------------------
/afdb_segmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np\n",
11 | "import os\n",
12 | "import functools\n",
13 | "import operator"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "lengthInSeconds = 5\n",
23 | "strideInSeconds = 1\n",
24 | "fs = 250\n",
25 | "length = lengthInSeconds * fs\n",
26 | "stride = strideInSeconds * fs"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "onemindf = pd.read_csv('hugedfrpeak.csv', index_col=[0])"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 4,
41 | "metadata": {},
42 | "outputs": [
43 | {
44 | "data": {
45 | "text/html": [
46 | "\n",
47 | "\n",
60 | "
\n",
61 | " \n",
62 | " \n",
63 | " | \n",
64 | " af | \n",
65 | " 0 | \n",
66 | " 1 | \n",
67 | " 2 | \n",
68 | " 3 | \n",
69 | " 4 | \n",
70 | " 5 | \n",
71 | " 6 | \n",
72 | " 7 | \n",
73 | " 8 | \n",
74 | " ... | \n",
75 | " 14990 | \n",
76 | " 14991 | \n",
77 | " 14992 | \n",
78 | " 14993 | \n",
79 | " 14994 | \n",
80 | " 14995 | \n",
81 | " 14996 | \n",
82 | " 14997 | \n",
83 | " 14998 | \n",
84 | " 14999 | \n",
85 | "
\n",
86 | " \n",
87 | " \n",
88 | " \n",
89 | " 0 | \n",
90 | " 0 | \n",
91 | " 0 | \n",
92 | " 0 | \n",
93 | " 0 | \n",
94 | " 0 | \n",
95 | " 0 | \n",
96 | " 0 | \n",
97 | " 0 | \n",
98 | " 0 | \n",
99 | " 0 | \n",
100 | " ... | \n",
101 | " 0.0 | \n",
102 | " 0.0 | \n",
103 | " 0.0 | \n",
104 | " 0.0 | \n",
105 | " 0.0 | \n",
106 | " 0.0 | \n",
107 | " 0.0 | \n",
108 | " 0.0 | \n",
109 | " 0.0 | \n",
110 | " 0.0 | \n",
111 | "
\n",
112 | " \n",
113 | " 1 | \n",
114 | " 0 | \n",
115 | " 0 | \n",
116 | " 0 | \n",
117 | " 0 | \n",
118 | " 0 | \n",
119 | " 0 | \n",
120 | " 0 | \n",
121 | " 0 | \n",
122 | " 0 | \n",
123 | " 0 | \n",
124 | " ... | \n",
125 | " 0.0 | \n",
126 | " 0.0 | \n",
127 | " 0.0 | \n",
128 | " 0.0 | \n",
129 | " 0.0 | \n",
130 | " 0.0 | \n",
131 | " 0.0 | \n",
132 | " 0.0 | \n",
133 | " 0.0 | \n",
134 | " 0.0 | \n",
135 | "
\n",
136 | " \n",
137 | " 2 | \n",
138 | " 0 | \n",
139 | " 0 | \n",
140 | " 0 | \n",
141 | " 0 | \n",
142 | " 0 | \n",
143 | " 0 | \n",
144 | " 0 | \n",
145 | " 0 | \n",
146 | " 0 | \n",
147 | " 0 | \n",
148 | " ... | \n",
149 | " 0.0 | \n",
150 | " 0.0 | \n",
151 | " 0.0 | \n",
152 | " 0.0 | \n",
153 | " 0.0 | \n",
154 | " 0.0 | \n",
155 | " 0.0 | \n",
156 | " 0.0 | \n",
157 | " 0.0 | \n",
158 | " 0.0 | \n",
159 | "
\n",
160 | " \n",
161 | " 3 | \n",
162 | " 0 | \n",
163 | " 0 | \n",
164 | " 0 | \n",
165 | " 0 | \n",
166 | " 0 | \n",
167 | " 1 | \n",
168 | " 0 | \n",
169 | " 0 | \n",
170 | " 0 | \n",
171 | " 0 | \n",
172 | " ... | \n",
173 | " 0.0 | \n",
174 | " 0.0 | \n",
175 | " 0.0 | \n",
176 | " 0.0 | \n",
177 | " 0.0 | \n",
178 | " 0.0 | \n",
179 | " 0.0 | \n",
180 | " 0.0 | \n",
181 | " 0.0 | \n",
182 | " 0.0 | \n",
183 | "
\n",
184 | " \n",
185 | " 4 | \n",
186 | " 0 | \n",
187 | " 0 | \n",
188 | " 0 | \n",
189 | " 0 | \n",
190 | " 0 | \n",
191 | " 0 | \n",
192 | " 0 | \n",
193 | " 0 | \n",
194 | " 0 | \n",
195 | " 0 | \n",
196 | " ... | \n",
197 | " 0.0 | \n",
198 | " 0.0 | \n",
199 | " 0.0 | \n",
200 | " 0.0 | \n",
201 | " 0.0 | \n",
202 | " 0.0 | \n",
203 | " 0.0 | \n",
204 | " 0.0 | \n",
205 | " 1.0 | \n",
206 | " 0.0 | \n",
207 | "
\n",
208 | " \n",
209 | " ... | \n",
210 | " ... | \n",
211 | " ... | \n",
212 | " ... | \n",
213 | " ... | \n",
214 | " ... | \n",
215 | " ... | \n",
216 | " ... | \n",
217 | " ... | \n",
218 | " ... | \n",
219 | " ... | \n",
220 | " ... | \n",
221 | " ... | \n",
222 | " ... | \n",
223 | " ... | \n",
224 | " ... | \n",
225 | " ... | \n",
226 | " ... | \n",
227 | " ... | \n",
228 | " ... | \n",
229 | " ... | \n",
230 | " ... | \n",
231 | "
\n",
232 | " \n",
233 | " 609 | \n",
234 | " 1 | \n",
235 | " 0 | \n",
236 | " 0 | \n",
237 | " 0 | \n",
238 | " 0 | \n",
239 | " 0 | \n",
240 | " 0 | \n",
241 | " 0 | \n",
242 | " 0 | \n",
243 | " 0 | \n",
244 | " ... | \n",
245 | " 0.0 | \n",
246 | " 0.0 | \n",
247 | " 0.0 | \n",
248 | " 0.0 | \n",
249 | " 0.0 | \n",
250 | " 0.0 | \n",
251 | " 0.0 | \n",
252 | " 0.0 | \n",
253 | " 0.0 | \n",
254 | " 0.0 | \n",
255 | "
\n",
256 | " \n",
257 | " 610 | \n",
258 | " 1 | \n",
259 | " 0 | \n",
260 | " 0 | \n",
261 | " 0 | \n",
262 | " 0 | \n",
263 | " 0 | \n",
264 | " 0 | \n",
265 | " 0 | \n",
266 | " 0 | \n",
267 | " 0 | \n",
268 | " ... | \n",
269 | " 0.0 | \n",
270 | " 0.0 | \n",
271 | " 0.0 | \n",
272 | " 0.0 | \n",
273 | " 0.0 | \n",
274 | " 0.0 | \n",
275 | " 0.0 | \n",
276 | " 0.0 | \n",
277 | " 0.0 | \n",
278 | " 0.0 | \n",
279 | "
\n",
280 | " \n",
281 | " 611 | \n",
282 | " 1 | \n",
283 | " 0 | \n",
284 | " 0 | \n",
285 | " 0 | \n",
286 | " 0 | \n",
287 | " 0 | \n",
288 | " 0 | \n",
289 | " 0 | \n",
290 | " 0 | \n",
291 | " 0 | \n",
292 | " ... | \n",
293 | " 0.0 | \n",
294 | " 0.0 | \n",
295 | " 0.0 | \n",
296 | " 0.0 | \n",
297 | " 0.0 | \n",
298 | " 0.0 | \n",
299 | " 0.0 | \n",
300 | " 0.0 | \n",
301 | " 0.0 | \n",
302 | " 0.0 | \n",
303 | "
\n",
304 | " \n",
305 | " 612 | \n",
306 | " 1 | \n",
307 | " 0 | \n",
308 | " 0 | \n",
309 | " 0 | \n",
310 | " 0 | \n",
311 | " 0 | \n",
312 | " 0 | \n",
313 | " 0 | \n",
314 | " 0 | \n",
315 | " 0 | \n",
316 | " ... | \n",
317 | " 0.0 | \n",
318 | " 0.0 | \n",
319 | " 0.0 | \n",
320 | " 0.0 | \n",
321 | " 0.0 | \n",
322 | " 0.0 | \n",
323 | " 0.0 | \n",
324 | " 0.0 | \n",
325 | " 0.0 | \n",
326 | " 0.0 | \n",
327 | "
\n",
328 | " \n",
329 | " 613 | \n",
330 | " 1 | \n",
331 | " 0 | \n",
332 | " 0 | \n",
333 | " 0 | \n",
334 | " 0 | \n",
335 | " 0 | \n",
336 | " 0 | \n",
337 | " 0 | \n",
338 | " 0 | \n",
339 | " 0 | \n",
340 | " ... | \n",
341 | " NaN | \n",
342 | " NaN | \n",
343 | " NaN | \n",
344 | " NaN | \n",
345 | " NaN | \n",
346 | " NaN | \n",
347 | " NaN | \n",
348 | " NaN | \n",
349 | " NaN | \n",
350 | " NaN | \n",
351 | "
\n",
352 | " \n",
353 | "
\n",
354 | "
14063 rows × 15001 columns
\n",
355 | "
"
356 | ],
357 | "text/plain": [
358 | " af 0 1 2 3 4 5 6 7 8 ... 14990 14991 14992 14993 14994 \\\n",
359 | "0 0 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
360 | "1 0 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
361 | "2 0 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
362 | "3 0 0 0 0 0 1 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
363 | "4 0 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
364 | ".. .. .. .. .. .. .. .. .. .. .. ... ... ... ... ... ... \n",
365 | "609 1 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
366 | "610 1 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
367 | "611 1 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
368 | "612 1 0 0 0 0 0 0 0 0 0 ... 0.0 0.0 0.0 0.0 0.0 \n",
369 | "613 1 0 0 0 0 0 0 0 0 0 ... NaN NaN NaN NaN NaN \n",
370 | "\n",
371 | " 14995 14996 14997 14998 14999 \n",
372 | "0 0.0 0.0 0.0 0.0 0.0 \n",
373 | "1 0.0 0.0 0.0 0.0 0.0 \n",
374 | "2 0.0 0.0 0.0 0.0 0.0 \n",
375 | "3 0.0 0.0 0.0 0.0 0.0 \n",
376 | "4 0.0 0.0 0.0 1.0 0.0 \n",
377 | ".. ... ... ... ... ... \n",
378 | "609 0.0 0.0 0.0 0.0 0.0 \n",
379 | "610 0.0 0.0 0.0 0.0 0.0 \n",
380 | "611 0.0 0.0 0.0 0.0 0.0 \n",
381 | "612 0.0 0.0 0.0 0.0 0.0 \n",
382 | "613 NaN NaN NaN NaN NaN \n",
383 | "\n",
384 | "[14063 rows x 15001 columns]"
385 | ]
386 | },
387 | "execution_count": 4,
388 | "metadata": {},
389 | "output_type": "execute_result"
390 | }
391 | ],
392 | "source": [
393 | "onemindf"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 5,
399 | "metadata": {},
400 | "outputs": [],
401 | "source": [
402 | "def strided_app(a, L, S ): \n",
403 | " nrows = ((a.size-L)//S)+1\n",
404 | "# print(nrows)\n",
405 | " n = a.strides[0]\n",
406 | " return pd.DataFrame(np.lib.stride_tricks.as_strided(a, shape=(nrows,L), strides=(S*n,n)))"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 6,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "L_arr = [length]*len(onemindf)\n",
416 | "S_arr = [stride]*len(onemindf)\n",
417 | "labels = onemindf['af'].copy()\n",
418 | "del onemindf['af']"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 7,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "final_ecg_list = list(map(strided_app, onemindf.to_numpy(), L_arr, S_arr))\n",
428 | "final_label_list = list(map(lambda x: [x]*56, labels.values))"
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": 10,
434 | "metadata": {},
435 | "outputs": [],
436 | "source": [
437 | "final_df = pd.concat(final_ecg_list)"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": 11,
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "data": {
447 | "text/html": [
448 | "\n",
449 | "\n",
462 | "
\n",
463 | " \n",
464 | " \n",
465 | " | \n",
466 | " 0 | \n",
467 | " 1 | \n",
468 | " 2 | \n",
469 | " 3 | \n",
470 | " 4 | \n",
471 | " 5 | \n",
472 | " 6 | \n",
473 | " 7 | \n",
474 | " 8 | \n",
475 | " 9 | \n",
476 | " ... | \n",
477 | " 1240 | \n",
478 | " 1241 | \n",
479 | " 1242 | \n",
480 | " 1243 | \n",
481 | " 1244 | \n",
482 | " 1245 | \n",
483 | " 1246 | \n",
484 | " 1247 | \n",
485 | " 1248 | \n",
486 | " 1249 | \n",
487 | "
\n",
488 | " \n",
489 | " \n",
490 | " \n",
491 | " 0 | \n",
492 | " 0.0 | \n",
493 | " 0.0 | \n",
494 | " 0.0 | \n",
495 | " 0.0 | \n",
496 | " 0.0 | \n",
497 | " 0.0 | \n",
498 | " 0.0 | \n",
499 | " 0.0 | \n",
500 | " 0.0 | \n",
501 | " 0.0 | \n",
502 | " ... | \n",
503 | " 0.0 | \n",
504 | " 0.0 | \n",
505 | " 0.0 | \n",
506 | " 0.0 | \n",
507 | " 0.0 | \n",
508 | " 0.0 | \n",
509 | " 0.0 | \n",
510 | " 0.0 | \n",
511 | " 0.0 | \n",
512 | " 0.0 | \n",
513 | "
\n",
514 | " \n",
515 | " 1 | \n",
516 | " 0.0 | \n",
517 | " 0.0 | \n",
518 | " 0.0 | \n",
519 | " 0.0 | \n",
520 | " 0.0 | \n",
521 | " 0.0 | \n",
522 | " 0.0 | \n",
523 | " 0.0 | \n",
524 | " 0.0 | \n",
525 | " 0.0 | \n",
526 | " ... | \n",
527 | " 0.0 | \n",
528 | " 0.0 | \n",
529 | " 0.0 | \n",
530 | " 0.0 | \n",
531 | " 0.0 | \n",
532 | " 0.0 | \n",
533 | " 0.0 | \n",
534 | " 0.0 | \n",
535 | " 0.0 | \n",
536 | " 0.0 | \n",
537 | "
\n",
538 | " \n",
539 | " 2 | \n",
540 | " 0.0 | \n",
541 | " 0.0 | \n",
542 | " 0.0 | \n",
543 | " 0.0 | \n",
544 | " 0.0 | \n",
545 | " 0.0 | \n",
546 | " 0.0 | \n",
547 | " 0.0 | \n",
548 | " 0.0 | \n",
549 | " 0.0 | \n",
550 | " ... | \n",
551 | " 0.0 | \n",
552 | " 0.0 | \n",
553 | " 0.0 | \n",
554 | " 0.0 | \n",
555 | " 0.0 | \n",
556 | " 0.0 | \n",
557 | " 0.0 | \n",
558 | " 0.0 | \n",
559 | " 0.0 | \n",
560 | " 0.0 | \n",
561 | "
\n",
562 | " \n",
563 | " 3 | \n",
564 | " 0.0 | \n",
565 | " 0.0 | \n",
566 | " 0.0 | \n",
567 | " 0.0 | \n",
568 | " 0.0 | \n",
569 | " 0.0 | \n",
570 | " 0.0 | \n",
571 | " 0.0 | \n",
572 | " 0.0 | \n",
573 | " 0.0 | \n",
574 | " ... | \n",
575 | " 0.0 | \n",
576 | " 0.0 | \n",
577 | " 0.0 | \n",
578 | " 0.0 | \n",
579 | " 0.0 | \n",
580 | " 0.0 | \n",
581 | " 0.0 | \n",
582 | " 0.0 | \n",
583 | " 0.0 | \n",
584 | " 0.0 | \n",
585 | "
\n",
586 | " \n",
587 | " 4 | \n",
588 | " 0.0 | \n",
589 | " 0.0 | \n",
590 | " 0.0 | \n",
591 | " 0.0 | \n",
592 | " 0.0 | \n",
593 | " 0.0 | \n",
594 | " 0.0 | \n",
595 | " 0.0 | \n",
596 | " 0.0 | \n",
597 | " 0.0 | \n",
598 | " ... | \n",
599 | " 0.0 | \n",
600 | " 0.0 | \n",
601 | " 0.0 | \n",
602 | " 0.0 | \n",
603 | " 0.0 | \n",
604 | " 0.0 | \n",
605 | " 0.0 | \n",
606 | " 0.0 | \n",
607 | " 0.0 | \n",
608 | " 0.0 | \n",
609 | "
\n",
610 | " \n",
611 | " ... | \n",
612 | " ... | \n",
613 | " ... | \n",
614 | " ... | \n",
615 | " ... | \n",
616 | " ... | \n",
617 | " ... | \n",
618 | " ... | \n",
619 | " ... | \n",
620 | " ... | \n",
621 | " ... | \n",
622 | " ... | \n",
623 | " ... | \n",
624 | " ... | \n",
625 | " ... | \n",
626 | " ... | \n",
627 | " ... | \n",
628 | " ... | \n",
629 | " ... | \n",
630 | " ... | \n",
631 | " ... | \n",
632 | " ... | \n",
633 | "
\n",
634 | " \n",
635 | " 51 | \n",
636 | " NaN | \n",
637 | " NaN | \n",
638 | " NaN | \n",
639 | " NaN | \n",
640 | " NaN | \n",
641 | " NaN | \n",
642 | " NaN | \n",
643 | " NaN | \n",
644 | " NaN | \n",
645 | " NaN | \n",
646 | " ... | \n",
647 | " NaN | \n",
648 | " NaN | \n",
649 | " NaN | \n",
650 | " NaN | \n",
651 | " NaN | \n",
652 | " NaN | \n",
653 | " NaN | \n",
654 | " NaN | \n",
655 | " NaN | \n",
656 | " NaN | \n",
657 | "
\n",
658 | " \n",
659 | " 52 | \n",
660 | " NaN | \n",
661 | " NaN | \n",
662 | " NaN | \n",
663 | " NaN | \n",
664 | " NaN | \n",
665 | " NaN | \n",
666 | " NaN | \n",
667 | " NaN | \n",
668 | " NaN | \n",
669 | " NaN | \n",
670 | " ... | \n",
671 | " NaN | \n",
672 | " NaN | \n",
673 | " NaN | \n",
674 | " NaN | \n",
675 | " NaN | \n",
676 | " NaN | \n",
677 | " NaN | \n",
678 | " NaN | \n",
679 | " NaN | \n",
680 | " NaN | \n",
681 | "
\n",
682 | " \n",
683 | " 53 | \n",
684 | " NaN | \n",
685 | " NaN | \n",
686 | " NaN | \n",
687 | " NaN | \n",
688 | " NaN | \n",
689 | " NaN | \n",
690 | " NaN | \n",
691 | " NaN | \n",
692 | " NaN | \n",
693 | " NaN | \n",
694 | " ... | \n",
695 | " NaN | \n",
696 | " NaN | \n",
697 | " NaN | \n",
698 | " NaN | \n",
699 | " NaN | \n",
700 | " NaN | \n",
701 | " NaN | \n",
702 | " NaN | \n",
703 | " NaN | \n",
704 | " NaN | \n",
705 | "
\n",
706 | " \n",
707 | " 54 | \n",
708 | " NaN | \n",
709 | " NaN | \n",
710 | " NaN | \n",
711 | " NaN | \n",
712 | " NaN | \n",
713 | " NaN | \n",
714 | " NaN | \n",
715 | " NaN | \n",
716 | " NaN | \n",
717 | " NaN | \n",
718 | " ... | \n",
719 | " NaN | \n",
720 | " NaN | \n",
721 | " NaN | \n",
722 | " NaN | \n",
723 | " NaN | \n",
724 | " NaN | \n",
725 | " NaN | \n",
726 | " NaN | \n",
727 | " NaN | \n",
728 | " NaN | \n",
729 | "
\n",
730 | " \n",
731 | " 55 | \n",
732 | " NaN | \n",
733 | " NaN | \n",
734 | " NaN | \n",
735 | " NaN | \n",
736 | " NaN | \n",
737 | " NaN | \n",
738 | " NaN | \n",
739 | " NaN | \n",
740 | " NaN | \n",
741 | " NaN | \n",
742 | " ... | \n",
743 | " NaN | \n",
744 | " NaN | \n",
745 | " NaN | \n",
746 | " NaN | \n",
747 | " NaN | \n",
748 | " NaN | \n",
749 | " NaN | \n",
750 | " NaN | \n",
751 | " NaN | \n",
752 | " NaN | \n",
753 | "
\n",
754 | " \n",
755 | "
\n",
756 | "
787528 rows × 1250 columns
\n",
757 | "
"
758 | ],
759 | "text/plain": [
760 | " 0 1 2 3 4 5 6 7 8 9 ... 1240 \\\n",
761 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
762 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
763 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
764 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
765 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
766 | ".. ... ... ... ... ... ... ... ... ... ... ... ... \n",
767 | "51 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
768 | "52 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
769 | "53 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
770 | "54 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
771 | "55 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
772 | "\n",
773 | " 1241 1242 1243 1244 1245 1246 1247 1248 1249 \n",
774 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
775 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
776 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
777 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
778 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
779 | ".. ... ... ... ... ... ... ... ... ... \n",
780 | "51 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
781 | "52 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
782 | "53 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
783 | "54 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
784 | "55 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
785 | "\n",
786 | "[787528 rows x 1250 columns]"
787 | ]
788 | },
789 | "execution_count": 11,
790 | "metadata": {},
791 | "output_type": "execute_result"
792 | }
793 | ],
794 | "source": [
795 | "final_df"
796 | ]
797 | },
798 | {
799 | "cell_type": "code",
800 | "execution_count": 12,
801 | "metadata": {},
802 | "outputs": [],
803 | "source": [
804 | "final_labels_list_flat = functools.reduce(operator.iconcat, final_label_list, [])"
805 | ]
806 | },
807 | {
808 | "cell_type": "code",
809 | "execution_count": 13,
810 | "metadata": {},
811 | "outputs": [],
812 | "source": [
813 | "final_df.insert(0,'af', final_labels_list_flat)"
814 | ]
815 | },
816 | {
817 | "cell_type": "code",
818 | "execution_count": 14,
819 | "metadata": {},
820 | "outputs": [],
821 | "source": [
822 | "final_df.reset_index(inplace=True, drop = True)"
823 | ]
824 | },
825 | {
826 | "cell_type": "code",
827 | "execution_count": 19,
828 | "metadata": {},
829 | "outputs": [
830 | {
831 | "data": {
832 | "text/html": [
833 | "\n",
834 | "\n",
847 | "
\n",
848 | " \n",
849 | " \n",
850 | " | \n",
851 | " af | \n",
852 | " 0 | \n",
853 | " 1 | \n",
854 | " 2 | \n",
855 | " 3 | \n",
856 | " 4 | \n",
857 | " 5 | \n",
858 | " 6 | \n",
859 | " 7 | \n",
860 | " 8 | \n",
861 | " ... | \n",
862 | " 1240 | \n",
863 | " 1241 | \n",
864 | " 1242 | \n",
865 | " 1243 | \n",
866 | " 1244 | \n",
867 | " 1245 | \n",
868 | " 1246 | \n",
869 | " 1247 | \n",
870 | " 1248 | \n",
871 | " 1249 | \n",
872 | "
\n",
873 | " \n",
874 | " \n",
875 | " \n",
876 | " 0 | \n",
877 | " 0 | \n",
878 | " 0.0 | \n",
879 | " 0.0 | \n",
880 | " 0.0 | \n",
881 | " 0.0 | \n",
882 | " 0.0 | \n",
883 | " 0.0 | \n",
884 | " 0.0 | \n",
885 | " 0.0 | \n",
886 | " 0.0 | \n",
887 | " ... | \n",
888 | " 0.0 | \n",
889 | " 0.0 | \n",
890 | " 0.0 | \n",
891 | " 0.0 | \n",
892 | " 0.0 | \n",
893 | " 0.0 | \n",
894 | " 0.0 | \n",
895 | " 0.0 | \n",
896 | " 0.0 | \n",
897 | " 0.0 | \n",
898 | "
\n",
899 | " \n",
900 | " 1 | \n",
901 | " 0 | \n",
902 | " 0.0 | \n",
903 | " 0.0 | \n",
904 | " 0.0 | \n",
905 | " 0.0 | \n",
906 | " 0.0 | \n",
907 | " 0.0 | \n",
908 | " 0.0 | \n",
909 | " 0.0 | \n",
910 | " 0.0 | \n",
911 | " ... | \n",
912 | " 0.0 | \n",
913 | " 0.0 | \n",
914 | " 0.0 | \n",
915 | " 0.0 | \n",
916 | " 0.0 | \n",
917 | " 0.0 | \n",
918 | " 0.0 | \n",
919 | " 0.0 | \n",
920 | " 0.0 | \n",
921 | " 0.0 | \n",
922 | "
\n",
923 | " \n",
924 | " 2 | \n",
925 | " 0 | \n",
926 | " 0.0 | \n",
927 | " 0.0 | \n",
928 | " 0.0 | \n",
929 | " 0.0 | \n",
930 | " 0.0 | \n",
931 | " 0.0 | \n",
932 | " 0.0 | \n",
933 | " 0.0 | \n",
934 | " 0.0 | \n",
935 | " ... | \n",
936 | " 0.0 | \n",
937 | " 0.0 | \n",
938 | " 0.0 | \n",
939 | " 0.0 | \n",
940 | " 0.0 | \n",
941 | " 0.0 | \n",
942 | " 0.0 | \n",
943 | " 0.0 | \n",
944 | " 0.0 | \n",
945 | " 0.0 | \n",
946 | "
\n",
947 | " \n",
948 | " 3 | \n",
949 | " 0 | \n",
950 | " 0.0 | \n",
951 | " 0.0 | \n",
952 | " 0.0 | \n",
953 | " 0.0 | \n",
954 | " 0.0 | \n",
955 | " 0.0 | \n",
956 | " 0.0 | \n",
957 | " 0.0 | \n",
958 | " 0.0 | \n",
959 | " ... | \n",
960 | " 0.0 | \n",
961 | " 0.0 | \n",
962 | " 0.0 | \n",
963 | " 0.0 | \n",
964 | " 0.0 | \n",
965 | " 0.0 | \n",
966 | " 0.0 | \n",
967 | " 0.0 | \n",
968 | " 0.0 | \n",
969 | " 0.0 | \n",
970 | "
\n",
971 | " \n",
972 | " 4 | \n",
973 | " 0 | \n",
974 | " 0.0 | \n",
975 | " 0.0 | \n",
976 | " 0.0 | \n",
977 | " 0.0 | \n",
978 | " 0.0 | \n",
979 | " 0.0 | \n",
980 | " 0.0 | \n",
981 | " 0.0 | \n",
982 | " 0.0 | \n",
983 | " ... | \n",
984 | " 0.0 | \n",
985 | " 0.0 | \n",
986 | " 0.0 | \n",
987 | " 0.0 | \n",
988 | " 0.0 | \n",
989 | " 0.0 | \n",
990 | " 0.0 | \n",
991 | " 0.0 | \n",
992 | " 0.0 | \n",
993 | " 0.0 | \n",
994 | "
\n",
995 | " \n",
996 | " ... | \n",
997 | " ... | \n",
998 | " ... | \n",
999 | " ... | \n",
1000 | " ... | \n",
1001 | " ... | \n",
1002 | " ... | \n",
1003 | " ... | \n",
1004 | " ... | \n",
1005 | " ... | \n",
1006 | " ... | \n",
1007 | " ... | \n",
1008 | " ... | \n",
1009 | " ... | \n",
1010 | " ... | \n",
1011 | " ... | \n",
1012 | " ... | \n",
1013 | " ... | \n",
1014 | " ... | \n",
1015 | " ... | \n",
1016 | " ... | \n",
1017 | " ... | \n",
1018 | "
\n",
1019 | " \n",
1020 | " 787523 | \n",
1021 | " 1 | \n",
1022 | " NaN | \n",
1023 | " NaN | \n",
1024 | " NaN | \n",
1025 | " NaN | \n",
1026 | " NaN | \n",
1027 | " NaN | \n",
1028 | " NaN | \n",
1029 | " NaN | \n",
1030 | " NaN | \n",
1031 | " ... | \n",
1032 | " NaN | \n",
1033 | " NaN | \n",
1034 | " NaN | \n",
1035 | " NaN | \n",
1036 | " NaN | \n",
1037 | " NaN | \n",
1038 | " NaN | \n",
1039 | " NaN | \n",
1040 | " NaN | \n",
1041 | " NaN | \n",
1042 | "
\n",
1043 | " \n",
1044 | " 787524 | \n",
1045 | " 1 | \n",
1046 | " NaN | \n",
1047 | " NaN | \n",
1048 | " NaN | \n",
1049 | " NaN | \n",
1050 | " NaN | \n",
1051 | " NaN | \n",
1052 | " NaN | \n",
1053 | " NaN | \n",
1054 | " NaN | \n",
1055 | " ... | \n",
1056 | " NaN | \n",
1057 | " NaN | \n",
1058 | " NaN | \n",
1059 | " NaN | \n",
1060 | " NaN | \n",
1061 | " NaN | \n",
1062 | " NaN | \n",
1063 | " NaN | \n",
1064 | " NaN | \n",
1065 | " NaN | \n",
1066 | "
\n",
1067 | " \n",
1068 | " 787525 | \n",
1069 | " 1 | \n",
1070 | " NaN | \n",
1071 | " NaN | \n",
1072 | " NaN | \n",
1073 | " NaN | \n",
1074 | " NaN | \n",
1075 | " NaN | \n",
1076 | " NaN | \n",
1077 | " NaN | \n",
1078 | " NaN | \n",
1079 | " ... | \n",
1080 | " NaN | \n",
1081 | " NaN | \n",
1082 | " NaN | \n",
1083 | " NaN | \n",
1084 | " NaN | \n",
1085 | " NaN | \n",
1086 | " NaN | \n",
1087 | " NaN | \n",
1088 | " NaN | \n",
1089 | " NaN | \n",
1090 | "
\n",
1091 | " \n",
1092 | " 787526 | \n",
1093 | " 1 | \n",
1094 | " NaN | \n",
1095 | " NaN | \n",
1096 | " NaN | \n",
1097 | " NaN | \n",
1098 | " NaN | \n",
1099 | " NaN | \n",
1100 | " NaN | \n",
1101 | " NaN | \n",
1102 | " NaN | \n",
1103 | " ... | \n",
1104 | " NaN | \n",
1105 | " NaN | \n",
1106 | " NaN | \n",
1107 | " NaN | \n",
1108 | " NaN | \n",
1109 | " NaN | \n",
1110 | " NaN | \n",
1111 | " NaN | \n",
1112 | " NaN | \n",
1113 | " NaN | \n",
1114 | "
\n",
1115 | " \n",
1116 | " 787527 | \n",
1117 | " 1 | \n",
1118 | " NaN | \n",
1119 | " NaN | \n",
1120 | " NaN | \n",
1121 | " NaN | \n",
1122 | " NaN | \n",
1123 | " NaN | \n",
1124 | " NaN | \n",
1125 | " NaN | \n",
1126 | " NaN | \n",
1127 | " ... | \n",
1128 | " NaN | \n",
1129 | " NaN | \n",
1130 | " NaN | \n",
1131 | " NaN | \n",
1132 | " NaN | \n",
1133 | " NaN | \n",
1134 | " NaN | \n",
1135 | " NaN | \n",
1136 | " NaN | \n",
1137 | " NaN | \n",
1138 | "
\n",
1139 | " \n",
1140 | "
\n",
1141 | "
787528 rows × 1251 columns
\n",
1142 | "
"
1143 | ],
1144 | "text/plain": [
1145 | " af 0 1 2 3 4 5 6 7 8 ... 1240 1241 \\\n",
1146 | "0 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1147 | "1 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1148 | "2 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1149 | "3 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1150 | "4 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1151 | "... .. ... ... ... ... ... ... ... ... ... ... ... ... \n",
1152 | "787523 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN \n",
1153 | "787524 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN \n",
1154 | "787525 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN \n",
1155 | "787526 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN \n",
1156 | "787527 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN \n",
1157 | "\n",
1158 | " 1242 1243 1244 1245 1246 1247 1248 1249 \n",
1159 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1160 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1161 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1162 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1163 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1164 | "... ... ... ... ... ... ... ... ... \n",
1165 | "787523 NaN NaN NaN NaN NaN NaN NaN NaN \n",
1166 | "787524 NaN NaN NaN NaN NaN NaN NaN NaN \n",
1167 | "787525 NaN NaN NaN NaN NaN NaN NaN NaN \n",
1168 | "787526 NaN NaN NaN NaN NaN NaN NaN NaN \n",
1169 | "787527 NaN NaN NaN NaN NaN NaN NaN NaN \n",
1170 | "\n",
1171 | "[787528 rows x 1251 columns]"
1172 | ]
1173 | },
1174 | "execution_count": 19,
1175 | "metadata": {},
1176 | "output_type": "execute_result"
1177 | }
1178 | ],
1179 | "source": [
1180 | "final_df"
1181 | ]
1182 | },
1183 | {
1184 | "cell_type": "code",
1185 | "execution_count": 21,
1186 | "metadata": {},
1187 | "outputs": [],
1188 | "source": [
1189 | "nan_df = final_df[final_df.isnull().any(axis=1)]"
1190 | ]
1191 | },
1192 | {
1193 | "cell_type": "code",
1194 | "execution_count": 22,
1195 | "metadata": {},
1196 | "outputs": [],
1197 | "source": [
1198 | "final_df.drop(nan_df.index, inplace=True)"
1199 | ]
1200 | },
1201 | {
1202 | "cell_type": "code",
1203 | "execution_count": 23,
1204 | "metadata": {},
1205 | "outputs": [
1206 | {
1207 | "data": {
1208 | "text/html": [
1209 | "\n",
1210 | "\n",
1223 | "
\n",
1224 | " \n",
1225 | " \n",
1226 | " | \n",
1227 | " af | \n",
1228 | " 0 | \n",
1229 | " 1 | \n",
1230 | " 2 | \n",
1231 | " 3 | \n",
1232 | " 4 | \n",
1233 | " 5 | \n",
1234 | " 6 | \n",
1235 | " 7 | \n",
1236 | " 8 | \n",
1237 | " ... | \n",
1238 | " 1240 | \n",
1239 | " 1241 | \n",
1240 | " 1242 | \n",
1241 | " 1243 | \n",
1242 | " 1244 | \n",
1243 | " 1245 | \n",
1244 | " 1246 | \n",
1245 | " 1247 | \n",
1246 | " 1248 | \n",
1247 | " 1249 | \n",
1248 | "
\n",
1249 | " \n",
1250 | " \n",
1251 | " \n",
1252 | " 0 | \n",
1253 | " 0 | \n",
1254 | " 0.0 | \n",
1255 | " 0.0 | \n",
1256 | " 0.0 | \n",
1257 | " 0.0 | \n",
1258 | " 0.0 | \n",
1259 | " 0.0 | \n",
1260 | " 0.0 | \n",
1261 | " 0.0 | \n",
1262 | " 0.0 | \n",
1263 | " ... | \n",
1264 | " 0.0 | \n",
1265 | " 0.0 | \n",
1266 | " 0.0 | \n",
1267 | " 0.0 | \n",
1268 | " 0.0 | \n",
1269 | " 0.0 | \n",
1270 | " 0.0 | \n",
1271 | " 0.0 | \n",
1272 | " 0.0 | \n",
1273 | " 0.0 | \n",
1274 | "
\n",
1275 | " \n",
1276 | " 1 | \n",
1277 | " 0 | \n",
1278 | " 0.0 | \n",
1279 | " 0.0 | \n",
1280 | " 0.0 | \n",
1281 | " 0.0 | \n",
1282 | " 0.0 | \n",
1283 | " 0.0 | \n",
1284 | " 0.0 | \n",
1285 | " 0.0 | \n",
1286 | " 0.0 | \n",
1287 | " ... | \n",
1288 | " 0.0 | \n",
1289 | " 0.0 | \n",
1290 | " 0.0 | \n",
1291 | " 0.0 | \n",
1292 | " 0.0 | \n",
1293 | " 0.0 | \n",
1294 | " 0.0 | \n",
1295 | " 0.0 | \n",
1296 | " 0.0 | \n",
1297 | " 0.0 | \n",
1298 | "
\n",
1299 | " \n",
1300 | " 2 | \n",
1301 | " 0 | \n",
1302 | " 0.0 | \n",
1303 | " 0.0 | \n",
1304 | " 0.0 | \n",
1305 | " 0.0 | \n",
1306 | " 0.0 | \n",
1307 | " 0.0 | \n",
1308 | " 0.0 | \n",
1309 | " 0.0 | \n",
1310 | " 0.0 | \n",
1311 | " ... | \n",
1312 | " 0.0 | \n",
1313 | " 0.0 | \n",
1314 | " 0.0 | \n",
1315 | " 0.0 | \n",
1316 | " 0.0 | \n",
1317 | " 0.0 | \n",
1318 | " 0.0 | \n",
1319 | " 0.0 | \n",
1320 | " 0.0 | \n",
1321 | " 0.0 | \n",
1322 | "
\n",
1323 | " \n",
1324 | " 3 | \n",
1325 | " 0 | \n",
1326 | " 0.0 | \n",
1327 | " 0.0 | \n",
1328 | " 0.0 | \n",
1329 | " 0.0 | \n",
1330 | " 0.0 | \n",
1331 | " 0.0 | \n",
1332 | " 0.0 | \n",
1333 | " 0.0 | \n",
1334 | " 0.0 | \n",
1335 | " ... | \n",
1336 | " 0.0 | \n",
1337 | " 0.0 | \n",
1338 | " 0.0 | \n",
1339 | " 0.0 | \n",
1340 | " 0.0 | \n",
1341 | " 0.0 | \n",
1342 | " 0.0 | \n",
1343 | " 0.0 | \n",
1344 | " 0.0 | \n",
1345 | " 0.0 | \n",
1346 | "
\n",
1347 | " \n",
1348 | " 4 | \n",
1349 | " 0 | \n",
1350 | " 0.0 | \n",
1351 | " 0.0 | \n",
1352 | " 0.0 | \n",
1353 | " 0.0 | \n",
1354 | " 0.0 | \n",
1355 | " 0.0 | \n",
1356 | " 0.0 | \n",
1357 | " 0.0 | \n",
1358 | " 0.0 | \n",
1359 | " ... | \n",
1360 | " 0.0 | \n",
1361 | " 0.0 | \n",
1362 | " 0.0 | \n",
1363 | " 0.0 | \n",
1364 | " 0.0 | \n",
1365 | " 0.0 | \n",
1366 | " 0.0 | \n",
1367 | " 0.0 | \n",
1368 | " 0.0 | \n",
1369 | " 0.0 | \n",
1370 | "
\n",
1371 | " \n",
1372 | " ... | \n",
1373 | " ... | \n",
1374 | " ... | \n",
1375 | " ... | \n",
1376 | " ... | \n",
1377 | " ... | \n",
1378 | " ... | \n",
1379 | " ... | \n",
1380 | " ... | \n",
1381 | " ... | \n",
1382 | " ... | \n",
1383 | " ... | \n",
1384 | " ... | \n",
1385 | " ... | \n",
1386 | " ... | \n",
1387 | " ... | \n",
1388 | " ... | \n",
1389 | " ... | \n",
1390 | " ... | \n",
1391 | " ... | \n",
1392 | " ... | \n",
1393 | " ... | \n",
1394 | "
\n",
1395 | " \n",
1396 | " 787506 | \n",
1397 | " 1 | \n",
1398 | " 0.0 | \n",
1399 | " 0.0 | \n",
1400 | " 0.0 | \n",
1401 | " 0.0 | \n",
1402 | " 0.0 | \n",
1403 | " 0.0 | \n",
1404 | " 0.0 | \n",
1405 | " 0.0 | \n",
1406 | " 0.0 | \n",
1407 | " ... | \n",
1408 | " 0.0 | \n",
1409 | " 0.0 | \n",
1410 | " 0.0 | \n",
1411 | " 0.0 | \n",
1412 | " 0.0 | \n",
1413 | " 0.0 | \n",
1414 | " 0.0 | \n",
1415 | " 0.0 | \n",
1416 | " 0.0 | \n",
1417 | " 0.0 | \n",
1418 | "
\n",
1419 | " \n",
1420 | " 787507 | \n",
1421 | " 1 | \n",
1422 | " 0.0 | \n",
1423 | " 0.0 | \n",
1424 | " 0.0 | \n",
1425 | " 0.0 | \n",
1426 | " 0.0 | \n",
1427 | " 0.0 | \n",
1428 | " 0.0 | \n",
1429 | " 0.0 | \n",
1430 | " 0.0 | \n",
1431 | " ... | \n",
1432 | " 0.0 | \n",
1433 | " 0.0 | \n",
1434 | " 0.0 | \n",
1435 | " 0.0 | \n",
1436 | " 0.0 | \n",
1437 | " 0.0 | \n",
1438 | " 0.0 | \n",
1439 | " 0.0 | \n",
1440 | " 0.0 | \n",
1441 | " 0.0 | \n",
1442 | "
\n",
1443 | " \n",
1444 | " 787508 | \n",
1445 | " 1 | \n",
1446 | " 0.0 | \n",
1447 | " 0.0 | \n",
1448 | " 0.0 | \n",
1449 | " 0.0 | \n",
1450 | " 0.0 | \n",
1451 | " 0.0 | \n",
1452 | " 0.0 | \n",
1453 | " 0.0 | \n",
1454 | " 0.0 | \n",
1455 | " ... | \n",
1456 | " 0.0 | \n",
1457 | " 0.0 | \n",
1458 | " 0.0 | \n",
1459 | " 0.0 | \n",
1460 | " 0.0 | \n",
1461 | " 0.0 | \n",
1462 | " 0.0 | \n",
1463 | " 0.0 | \n",
1464 | " 0.0 | \n",
1465 | " 0.0 | \n",
1466 | "
\n",
1467 | " \n",
1468 | " 787509 | \n",
1469 | " 1 | \n",
1470 | " 0.0 | \n",
1471 | " 0.0 | \n",
1472 | " 0.0 | \n",
1473 | " 0.0 | \n",
1474 | " 0.0 | \n",
1475 | " 0.0 | \n",
1476 | " 0.0 | \n",
1477 | " 0.0 | \n",
1478 | " 0.0 | \n",
1479 | " ... | \n",
1480 | " 0.0 | \n",
1481 | " 0.0 | \n",
1482 | " 0.0 | \n",
1483 | " 0.0 | \n",
1484 | " 0.0 | \n",
1485 | " 0.0 | \n",
1486 | " 0.0 | \n",
1487 | " 0.0 | \n",
1488 | " 0.0 | \n",
1489 | " 0.0 | \n",
1490 | "
\n",
1491 | " \n",
1492 | " 787510 | \n",
1493 | " 1 | \n",
1494 | " 0.0 | \n",
1495 | " 0.0 | \n",
1496 | " 0.0 | \n",
1497 | " 0.0 | \n",
1498 | " 0.0 | \n",
1499 | " 0.0 | \n",
1500 | " 0.0 | \n",
1501 | " 0.0 | \n",
1502 | " 0.0 | \n",
1503 | " ... | \n",
1504 | " 0.0 | \n",
1505 | " 0.0 | \n",
1506 | " 0.0 | \n",
1507 | " 0.0 | \n",
1508 | " 0.0 | \n",
1509 | " 0.0 | \n",
1510 | " 0.0 | \n",
1511 | " 0.0 | \n",
1512 | " 0.0 | \n",
1513 | " 0.0 | \n",
1514 | "
\n",
1515 | " \n",
1516 | "
\n",
1517 | "
787154 rows × 1251 columns
\n",
1518 | "
"
1519 | ],
1520 | "text/plain": [
1521 | " af 0 1 2 3 4 5 6 7 8 ... 1240 1241 \\\n",
1522 | "0 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1523 | "1 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1524 | "2 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1525 | "3 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1526 | "4 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1527 | "... .. ... ... ... ... ... ... ... ... ... ... ... ... \n",
1528 | "787506 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1529 | "787507 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1530 | "787508 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1531 | "787509 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1532 | "787510 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
1533 | "\n",
1534 | " 1242 1243 1244 1245 1246 1247 1248 1249 \n",
1535 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1536 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1537 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1538 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1539 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1540 | "... ... ... ... ... ... ... ... ... \n",
1541 | "787506 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1542 | "787507 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1543 | "787508 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1544 | "787509 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1545 | "787510 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
1546 | "\n",
1547 | "[787154 rows x 1251 columns]"
1548 | ]
1549 | },
1550 | "execution_count": 23,
1551 | "metadata": {},
1552 | "output_type": "execute_result"
1553 | }
1554 | ],
1555 | "source": [
1556 | "final_df"
1557 | ]
1558 | },
1559 | {
1560 | "cell_type": "code",
1561 | "execution_count": 24,
1562 | "metadata": {},
1563 | "outputs": [
1564 | {
1565 | "data": {
1566 | "text/html": [
1567 | "\n",
1568 | "\n",
1581 | "
\n",
1582 | " \n",
1583 | " \n",
1584 | " | \n",
1585 | " af | \n",
1586 | " 0 | \n",
1587 | " 1 | \n",
1588 | " 2 | \n",
1589 | " 3 | \n",
1590 | " 4 | \n",
1591 | " 5 | \n",
1592 | " 6 | \n",
1593 | " 7 | \n",
1594 | " 8 | \n",
1595 | " ... | \n",
1596 | " 1240 | \n",
1597 | " 1241 | \n",
1598 | " 1242 | \n",
1599 | " 1243 | \n",
1600 | " 1244 | \n",
1601 | " 1245 | \n",
1602 | " 1246 | \n",
1603 | " 1247 | \n",
1604 | " 1248 | \n",
1605 | " 1249 | \n",
1606 | "
\n",
1607 | " \n",
1608 | " \n",
1609 | " \n",
1610 | "
\n",
1611 | "
0 rows × 1251 columns
\n",
1612 | "
"
1613 | ],
1614 | "text/plain": [
1615 | "Empty DataFrame\n",
1616 | "Columns: [af, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, ...]\n",
1617 | "Index: []\n",
1618 | "\n",
1619 | "[0 rows x 1251 columns]"
1620 | ]
1621 | },
1622 | "execution_count": 24,
1623 | "metadata": {},
1624 | "output_type": "execute_result"
1625 | }
1626 | ],
1627 | "source": [
1628 | "final_df[final_df.isnull().any(axis=1)]"
1629 | ]
1630 | },
1631 | {
1632 | "cell_type": "code",
1633 | "execution_count": 25,
1634 | "metadata": {},
1635 | "outputs": [],
1636 | "source": [
1637 | "final_df.to_csv('hugedfrpeak5sec.csv')"
1638 | ]
1639 | }
1640 | ],
1641 | "metadata": {
1642 | "kernelspec": {
1643 | "display_name": "Python [conda env:physio]",
1644 | "language": "python",
1645 | "name": "conda-env-physio-py"
1646 | },
1647 | "language_info": {
1648 | "codemirror_mode": {
1649 | "name": "ipython",
1650 | "version": 3
1651 | },
1652 | "file_extension": ".py",
1653 | "mimetype": "text/x-python",
1654 | "name": "python",
1655 | "nbconvert_exporter": "python",
1656 | "pygments_lexer": "ipython3",
1657 | "version": "3.7.5"
1658 | }
1659 | },
1660 | "nbformat": 4,
1661 | "nbformat_minor": 4
1662 | }
1663 |
--------------------------------------------------------------------------------
/assets/deployment_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vishaln15/OptimizedArrhythmiaDetection/9717c2aeef6214fe3a68bd50544e764cc1a98c5c/assets/deployment_diagram.png
--------------------------------------------------------------------------------
/baseline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Decision Tree:"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 4,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "import os\n",
19 | "from sklearn.tree import DecisionTreeClassifier\n",
20 | "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
21 | "from sklearn.svm import SVC\n",
22 | "from sklearn.model_selection import train_test_split\n",
23 | "from sklearn.metrics import classification_report"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "CPU times: user 806 ms, sys: 101 ms, total: 908 ms\n",
36 | "Wall time: 1.38 s\n"
37 | ]
38 | }
39 | ],
40 | "source": [
41 | "%%time\n",
42 | "afd = pd.read_csv('/hdd/physio/af/hugedf5secfeatures.csv')\n",
43 | "chd = pd.read_csv('/hdd/physio/af2/finaldfs/2017features.csv')"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 13,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "afd = chd"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 14,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "# x = afd[afd.columns[1:]].sample(15000)\n",
62 | "# y = afd[afd.columns[0]]\n",
63 | "\n",
64 | "# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 0)\n",
65 | "\n",
66 | "_0 = afd[afd.af == 0].sample(n = 15000).reset_index(drop = True)\n",
67 | "_1 = afd[afd.af == 1].sample(n = 15000).reset_index(drop = True)\n",
68 | "train = pd.concat([_0, _1], axis = 0).sample(frac = 1).reset_index(drop = True)\n",
69 | "x_train = train[train.columns[1:]]\n",
70 | "y_train = train[train.columns[0]]\n",
71 | "train = train.to_numpy()\n",
72 | "_0 = afd[afd.af == 0].sample(n = 1000).reset_index(drop = True)\n",
73 | "_1 = afd[afd.af == 1].sample(n = 1000).reset_index(drop = True)\n",
74 | "test = pd.concat([_0, _1], axis = 0).sample(frac = 1).reset_index(drop = True)\n",
75 | "x_test = test[test.columns[1:]]\n",
76 | "y_test = test[test.columns[0]]\n",
77 | "test = test.to_numpy()"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 15,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "data": {
87 | "text/plain": [
88 | "ExtraTreesClassifier(random_state=0)"
89 | ]
90 | },
91 | "execution_count": 15,
92 | "metadata": {},
93 | "output_type": "execute_result"
94 | }
95 | ],
96 | "source": [
97 | "# Support Vector Classifier\n",
98 | "\n",
99 | "svc = SVC(random_state = 0)\n",
100 | "svc.fit(x_train, y_train)\n",
101 | "\n",
102 | "# Decision Tree Classifier\n",
103 | "\n",
104 | "dt = DecisionTreeClassifier(random_state = 0)\n",
105 | "dt.fit(x_train, y_train)\n",
106 | "\n",
107 | "# Random Forest Classifier\n",
108 | "\n",
109 | "rf = RandomForestClassifier(random_state = 0)\n",
110 | "rf.fit(x_train, y_train)\n",
111 | "\n",
112 | "# Extra Trees Classifier\n",
113 | "\n",
114 | "et = ExtraTreesClassifier(random_state = 0)\n",
115 | "et.fit(x_train, y_train)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 11,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "name": "stdout",
125 | "output_type": "stream",
126 | "text": [
127 | "SVC\n",
128 | " precision recall f1-score support\n",
129 | "\n",
130 | " 0 0.95 0.83 0.88 1000\n",
131 | " 1 0.85 0.95 0.90 1000\n",
132 | "\n",
133 | " accuracy 0.89 2000\n",
134 | " macro avg 0.90 0.89 0.89 2000\n",
135 | "weighted avg 0.90 0.89 0.89 2000\n",
136 | "\n",
137 | "DT\n",
138 | " precision recall f1-score support\n",
139 | "\n",
140 | " 0 0.89 0.87 0.88 1000\n",
141 | " 1 0.87 0.89 0.88 1000\n",
142 | "\n",
143 | " accuracy 0.88 2000\n",
144 | " macro avg 0.88 0.88 0.88 2000\n",
145 | "weighted avg 0.88 0.88 0.88 2000\n",
146 | "\n",
147 | "RF\n",
148 | " precision recall f1-score support\n",
149 | "\n",
150 | " 0 0.95 0.88 0.91 1000\n",
151 | " 1 0.89 0.95 0.92 1000\n",
152 | "\n",
153 | " accuracy 0.92 2000\n",
154 | " macro avg 0.92 0.92 0.92 2000\n",
155 | "weighted avg 0.92 0.92 0.92 2000\n",
156 | "\n",
157 | "ET\n",
158 | " precision recall f1-score support\n",
159 | "\n",
160 | " 0 0.95 0.88 0.91 1000\n",
161 | " 1 0.89 0.95 0.92 1000\n",
162 | "\n",
163 | " accuracy 0.92 2000\n",
164 | " macro avg 0.92 0.92 0.92 2000\n",
165 | "weighted avg 0.92 0.92 0.92 2000\n",
166 | "\n"
167 | ]
168 | }
169 | ],
170 | "source": [
171 | "# AFDB\n",
172 | "\n",
173 | "for i, j in zip([svc, dt, rf, et], [\"SVC\", \"DT\", \"RF\", \"ET\"]):\n",
174 | " print(j + '\\n' + classification_report(y_test, i.predict(x_test)))"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 16,
180 | "metadata": {},
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "SVC\n",
187 | " precision recall f1-score support\n",
188 | "\n",
189 | " 0 0.89 0.93 0.91 1000\n",
190 | " 1 0.93 0.89 0.91 1000\n",
191 | "\n",
192 | " accuracy 0.91 2000\n",
193 | " macro avg 0.91 0.91 0.91 2000\n",
194 | "weighted avg 0.91 0.91 0.91 2000\n",
195 | "\n",
196 | "DT\n",
197 | " precision recall f1-score support\n",
198 | "\n",
199 | " 0 0.96 0.89 0.92 1000\n",
200 | " 1 0.90 0.96 0.93 1000\n",
201 | "\n",
202 | " accuracy 0.93 2000\n",
203 | " macro avg 0.93 0.93 0.93 2000\n",
204 | "weighted avg 0.93 0.93 0.93 2000\n",
205 | "\n",
206 | "RF\n",
207 | " precision recall f1-score support\n",
208 | "\n",
209 | " 0 0.97 0.93 0.95 1000\n",
210 | " 1 0.93 0.97 0.95 1000\n",
211 | "\n",
212 | " accuracy 0.95 2000\n",
213 | " macro avg 0.95 0.95 0.95 2000\n",
214 | "weighted avg 0.95 0.95 0.95 2000\n",
215 | "\n",
216 | "ET\n",
217 | " precision recall f1-score support\n",
218 | "\n",
219 | " 0 0.97 0.93 0.95 1000\n",
220 | " 1 0.93 0.97 0.95 1000\n",
221 | "\n",
222 | " accuracy 0.95 2000\n",
223 | " macro avg 0.95 0.95 0.95 2000\n",
224 | "weighted avg 0.95 0.95 0.95 2000\n",
225 | "\n"
226 | ]
227 | }
228 | ],
229 | "source": [
230 | "# CHDB\n",
231 | "\n",
232 | "for i, j in zip([svc, dt, rf, et], [\"SVC\", \"DT\", \"RF\", \"ET\"]):\n",
233 | " print(j + '\\n' + classification_report(y_test, i.predict(x_test)))"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 17,
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "import joblib\n",
243 | "\n",
244 | "for i, j in zip([svc, dt, rf, et], [\"SV\", \"DT\", \"RF\", \"ET\"]):\n",
245 | " joblib.dump(i, \"chdb_\" + j + \".joblib\")"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 20,
251 | "metadata": {},
252 | "outputs": [
253 | {
254 | "name": "stdout",
255 | "output_type": "stream",
256 | "text": [
257 | "SV:\n",
258 | "0.55MB\n",
259 | "DT:\n",
260 | "0.37MB\n",
261 | "RF:\n",
262 | "29.98MB\n",
263 | "ET:\n",
264 | "89.66MB\n"
265 | ]
266 | }
267 | ],
268 | "source": [
269 | "for i, j in zip([svc, dt, rf, et], [\"SV\", \"DT\", \"RF\", \"ET\"]):\n",
270 | " print(j + ':\\n' + str(np.round(os.path.getsize('chdb_' + j + '.joblib') / 1024 / 1024, 2))+\"MB\")"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 50,
276 | "metadata": {},
277 | "outputs": [
278 | {
279 | "name": "stdout",
280 | "output_type": "stream",
281 | "text": [
282 | "Random Forest size: 0.58 MB\n"
283 | ]
284 | }
285 | ],
286 | "source": [
287 | "joblib.dump(svc, \"RandomForest_100_trees.joblib\") \n",
288 | "print(f\"Random Forest size: {np.round(os.path.getsize('RandomForest_100_trees.joblib') / 1024 / 1024, 2) } MB\")"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 53,
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/plain": [
299 | "DecisionTreeClassifier()"
300 | ]
301 | },
302 | "execution_count": 53,
303 | "metadata": {},
304 | "output_type": "execute_result"
305 | }
306 | ],
307 | "source": [
308 | "dt = DecisionTreeClassifier()\n",
309 | "dt.fit(x_train, y_train)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 54,
315 | "metadata": {},
316 | "outputs": [
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "Random Forest size: 0.35 MB\n"
322 | ]
323 | }
324 | ],
325 | "source": [
326 | "joblib.dump(dt, \"RandomForest_100_trees.joblib\") \n",
327 | "print(f\"Random Forest size: {np.round(os.path.getsize('RandomForest_100_trees.joblib') / 1024 / 1024, 2) } MB\")"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": 3,
333 | "metadata": {},
334 | "outputs": [],
335 | "source": [
336 | "d = pd.read_csv('/hdd/physio/af/finalscore.csv')"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": 4,
342 | "metadata": {},
343 | "outputs": [
344 | {
345 | "data": {
346 | "text/html": [
347 | "\n",
348 | "<style scoped>\n",
349 |     .dataframe tbody tr th:only-of-type {
350 |         vertical-align: middle;
351 |     }
352 |
353 |     .dataframe tbody tr th {
354 |         vertical-align: top;
355 |     }
356 |
357 |     .dataframe thead th {
358 |         text-align: right;
359 |     }
360 | </style>\n",
361 | "<table border=\"1\" class=\"dataframe\">\n",
\n",
362 | " \n",
363 | " \n",
364 | " | \n",
365 | " af | \n",
366 | " RMSSD | \n",
367 | " STDNN | \n",
368 | " MEAN_RR | \n",
369 | " MEAN_HR | \n",
370 | " STD_HR | \n",
371 | " MIN_HR | \n",
372 | " MAX_HR | \n",
373 | "    </tr>\n",
374 | " \n",
375 | " \n",
376 | " \n",
377 | " 0 | \n",
378 | " 1 | \n",
379 | " 0.799498 | \n",
380 | " 0.074705 | \n",
381 | " 0.796000 | \n",
382 | " 76.041099 | \n",
383 | " 7.091900 | \n",
384 | " 66.371681 | \n",
385 | " 84.745763 | \n",
386 | "    </tr>\n",
387 | " \n",
388 | " 1 | \n",
389 | " 1 | \n",
390 | " 0.399271 | \n",
391 | " 0.035757 | \n",
392 | " 0.397667 | \n",
393 | " 152.061961 | \n",
394 | " 13.203645 | \n",
395 | " 127.118644 | \n",
396 | " 172.413793 | \n",
397 | "    </tr>\n",
398 | " \n",
399 | " 2 | \n",
400 | " 0 | \n",
401 | " 0.610334 | \n",
402 | " 0.007667 | \n",
403 | " 0.610286 | \n",
404 | " 98.330151 | \n",
405 | " 1.237577 | \n",
406 | " 96.153846 | \n",
407 | " 100.671141 | \n",
408 | "    </tr>\n",
409 | " \n",
410 | " 3 | \n",
411 | " 0 | \n",
412 | " 0.884018 | \n",
413 | " 0.005657 | \n",
414 | " 0.884000 | \n",
415 | " 67.876083 | \n",
416 | " 0.434376 | \n",
417 | " 67.264574 | \n",
418 | " 68.493151 | \n",
419 | "    </tr>\n",
420 | " \n",
421 | " 4 | \n",
422 | " 0 | \n",
423 | " 0.597896 | \n",
424 | " 0.128156 | \n",
425 | " 0.584000 | \n",
426 | " 109.100500 | \n",
427 | " 29.368470 | \n",
428 | " 81.967213 | \n",
429 | " 159.574468 | \n",
430 | "    </tr>\n",
431 | " \n",
432 | " ... | \n",
433 | " ... | \n",
434 | " ... | \n",
435 | " ... | \n",
436 | " ... | \n",
437 | " ... | \n",
438 | " ... | \n",
439 | " ... | \n",
440 | " ... | \n",
441 | "    </tr>\n",
442 | " \n",
443 | " 35995 | \n",
444 | " 0 | \n",
445 | " 1.002106 | \n",
446 | " 0.014560 | \n",
447 | " 1.002000 | \n",
448 | " 59.892887 | \n",
449 | " 0.870390 | \n",
450 | " 58.823529 | \n",
451 | " 60.975610 | \n",
452 | "    </tr>\n",
453 | " \n",
454 | " 35996 | \n",
455 | " 0 | \n",
456 | " 0.574999 | \n",
457 | " 0.012778 | \n",
458 | " 0.574857 | \n",
459 | " 104.425519 | \n",
460 | " 2.329051 | \n",
461 | " 101.351351 | \n",
462 | " 107.913669 | \n",
463 | "    </tr>\n",
464 | " \n",
465 | " 35997 | \n",
466 | " 0 | \n",
467 | " 0.956046 | \n",
468 | " 0.009381 | \n",
469 | " 0.956000 | \n",
470 | " 62.767489 | \n",
471 | " 0.609665 | \n",
472 | " 61.728395 | \n",
473 | " 63.291139 | \n",
474 | "    </tr>\n",
475 | " \n",
476 | " 35998 | \n",
477 | " 0 | \n",
478 | " 1.087822 | \n",
479 | " 0.042273 | \n",
480 | " 1.087000 | \n",
481 | " 55.278977 | \n",
482 | " 2.088315 | \n",
483 | " 51.903114 | \n",
484 | " 57.471264 | \n",
485 | "    </tr>\n",
486 | " \n",
487 | " 35999 | \n",
488 | " 1 | \n",
489 | " 0.654437 | \n",
490 | " 0.136704 | \n",
491 | " 0.640000 | \n",
492 | " 97.858593 | \n",
493 | " 19.429063 | \n",
494 | " 69.124424 | \n",
495 | " 125.000000 | \n",
496 | "    </tr>\n",
497 | "  </tbody>\n",
498 | "</table>\n",
499 | "<p>36000 rows × 8 columns</p>\n",
500 | "</div>"
501 | ],
502 | "text/plain": [
503 | " af RMSSD STDNN MEAN_RR MEAN_HR STD_HR MIN_HR \\\n",
504 | "0 1 0.799498 0.074705 0.796000 76.041099 7.091900 66.371681 \n",
505 | "1 1 0.399271 0.035757 0.397667 152.061961 13.203645 127.118644 \n",
506 | "2 0 0.610334 0.007667 0.610286 98.330151 1.237577 96.153846 \n",
507 | "3 0 0.884018 0.005657 0.884000 67.876083 0.434376 67.264574 \n",
508 | "4 0 0.597896 0.128156 0.584000 109.100500 29.368470 81.967213 \n",
509 | "... .. ... ... ... ... ... ... \n",
510 | "35995 0 1.002106 0.014560 1.002000 59.892887 0.870390 58.823529 \n",
511 | "35996 0 0.574999 0.012778 0.574857 104.425519 2.329051 101.351351 \n",
512 | "35997 0 0.956046 0.009381 0.956000 62.767489 0.609665 61.728395 \n",
513 | "35998 0 1.087822 0.042273 1.087000 55.278977 2.088315 51.903114 \n",
514 | "35999 1 0.654437 0.136704 0.640000 97.858593 19.429063 69.124424 \n",
515 | "\n",
516 | " MAX_HR \n",
517 | "0 84.745763 \n",
518 | "1 172.413793 \n",
519 | "2 100.671141 \n",
520 | "3 68.493151 \n",
521 | "4 159.574468 \n",
522 | "... ... \n",
523 | "35995 60.975610 \n",
524 | "35996 107.913669 \n",
525 | "35997 63.291139 \n",
526 | "35998 57.471264 \n",
527 | "35999 125.000000 \n",
528 | "\n",
529 | "[36000 rows x 8 columns]"
530 | ]
531 | },
532 | "execution_count": 4,
533 | "metadata": {},
534 | "output_type": "execute_result"
535 | }
536 | ],
537 | "source": [
538 | "d"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 9,
544 | "metadata": {},
545 | "outputs": [],
546 | "source": [
547 | "x_train, x_test, y_train, y_test = train_test_split(d[d.columns[1:]], d[d.columns[0]], test_size = 0.2, random_state = 0)"
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": 10,
553 | "metadata": {},
554 | "outputs": [
555 | {
556 | "data": {
557 | "text/plain": [
558 | "DecisionTreeClassifier()"
559 | ]
560 | },
561 | "execution_count": 10,
562 | "metadata": {},
563 | "output_type": "execute_result"
564 | }
565 | ],
566 | "source": [
567 | "dt = DecisionTreeClassifier()\n",
568 | "dt.fit(x_train, y_train)"
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": 11,
574 | "metadata": {},
575 | "outputs": [
576 | {
577 | "name": "stdout",
578 | "output_type": "stream",
579 | "text": [
580 | " precision recall f1-score support\n",
581 | "\n",
582 | " 0 0.88 0.89 0.89 3573\n",
583 | " 1 0.89 0.88 0.89 3627\n",
584 | "\n",
585 | " accuracy 0.89 7200\n",
586 | " macro avg 0.89 0.89 0.89 7200\n",
587 | "weighted avg 0.89 0.89 0.89 7200\n",
588 | "\n"
589 | ]
590 | }
591 | ],
592 | "source": [
593 | "pred = dt.predict(x_test)\n",
594 | "print(classification_report(y_test, pred))"
595 | ]
596 | },
597 | {
598 | "cell_type": "code",
599 | "execution_count": null,
600 | "metadata": {},
601 | "outputs": [],
602 | "source": []
603 | }
604 | ],
605 | "metadata": {
606 | "kernelspec": {
607 | "display_name": "Python [conda env:physio]",
608 | "language": "python",
609 | "name": "conda-env-physio-py"
610 | },
611 | "language_info": {
612 | "codemirror_mode": {
613 | "name": "ipython",
614 | "version": 3
615 | },
616 | "file_extension": ".py",
617 | "mimetype": "text/x-python",
618 | "name": "python",
619 | "nbconvert_exporter": "python",
620 | "pygments_lexer": "ipython3",
621 | "version": "3.7.5"
622 | }
623 | },
624 | "nbformat": 4,
625 | "nbformat_minor": 4
626 | }
627 |
--------------------------------------------------------------------------------
/bonsaiTrainer2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
4 | import torch
5 | import numpy as np
6 | import os
7 | import sys
8 | import edgeml_pytorch.utils as utils
9 | from sklearn.metrics import f1_score
10 | from sklearn.metrics import classification_report
11 | from sklearn.metrics import confusion_matrix
12 |
13 |
14 | class BonsaiTrainer:
15 |
16 | def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
17 | learningRate, useMCHLoss=False, outFile=None, device=None):
18 | '''
19 | bonsaiObj - Initialised Bonsai Object and Graph
20 | lW, lT, lV and lZ are regularisers to Bonsai Params
21 | sW, sT, sV and sZ are sparsity factors to Bonsai Params
22 | learningRate - learningRate for optimizer
23 | useMCHLoss - For choice between HingeLoss vs CrossEntropy
24 | useMCHLoss - True - MultiClass - multiClassHingeLoss
25 | useMCHLoss - False - MultiClass - crossEntropyLoss
26 | '''
27 |
28 | self.bonsaiObj = bonsaiObj
29 |
30 | self.lW = lW
31 | self.lV = lV
32 | self.lT = lT
33 | self.lZ = lZ
34 |
35 | self.sW = sW
36 | self.sV = sV
37 | self.sT = sT
38 | self.sZ = sZ
39 |
40 | if device is None:
41 | self.device = "cpu"
42 | else:
43 | self.device = device
44 |
45 | self.useMCHLoss = useMCHLoss
46 |
47 | if outFile is not None:
48 | print("Outfile : ", outFile)
49 | self.outFile = open(outFile, 'w')
50 | else:
51 | self.outFile = sys.stdout
52 |
53 | self.learningRate = learningRate
54 |
55 | self.assertInit()
56 |
57 | self.optimizer = self.optimizer()
58 |
59 | if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99:
60 | self.isDenseTraining = True
61 | else:
62 | self.isDenseTraining = False
63 |
64 | def loss(self, logits, labels):
65 | '''
66 | Loss function for given Bonsai Obj
67 | '''
68 | regLoss = 0.5 * (self.lZ * (torch.norm(self.bonsaiObj.Z)**2) +
69 | self.lW * (torch.norm(self.bonsaiObj.W)**2) +
70 | self.lV * (torch.norm(self.bonsaiObj.V)**2) +
71 | self.lT * (torch.norm(self.bonsaiObj.T))**2)
72 |
73 | if (self.bonsaiObj.numClasses > 2):
74 | if self.useMCHLoss is True:
75 | marginLoss = utils.multiClassHingeLoss(logits, labels)
76 | else:
77 | marginLoss = utils.crossEntropyLoss(logits, labels)
78 | loss = marginLoss + regLoss
79 | else:
80 | marginLoss = utils.binaryHingeLoss(logits, labels)
81 | loss = marginLoss + regLoss
82 |
83 | return loss, marginLoss, regLoss
84 |
85 | def optimizer(self):
86 | '''
87 | Optimizer for Bonsai Params
88 | '''
89 | optimizer = torch.optim.Adam(
90 | self.bonsaiObj.parameters(), lr=self.learningRate)
91 |
92 | return optimizer
93 |
94 | def accuracy(self, logits, labels):
95 | '''
96 | Accuracy function to evaluate accuracy when needed
97 | '''
98 | if (self.bonsaiObj.numClasses > 2):
99 | correctPredictions = (logits.argmax(dim=1) == labels.argmax(dim=1))
100 | accuracy = torch.mean(correctPredictions.float())
101 | else:
102 | pred = (torch.cat((torch.zeros(logits.shape),
103 | logits), 1)).argmax(dim=1)
104 | accuracy = torch.mean((labels.view(-1).long() == pred).float())
105 |
106 | return accuracy
107 |
108 | def classificationReport(self, logits, labels):
109 | pred = (torch.cat((torch.zeros(logits.shape),
110 | logits), 1)).argmax(dim=1)
111 | return classification_report(labels, pred, output_dict=True)
112 |
113 | def confusion_matrix_FAR(self, logits, labels):
114 | pred = (torch.cat((torch.zeros(logits.shape),
115 | logits), 1)).argmax(dim=1)
116 | CM = confusion_matrix(labels, pred)
117 | TN = CM[0][0]
118 | FN = CM[1][0]
119 | TP = CM[1][1]
120 | FP = CM[0][1]
121 | FAR = FP/(FP+TN)
122 |
123 | return CM, FAR
124 |
125 | def f1(self, logits, labels):
126 | '''
127 | f1 score function to evaluate f1 when needed
128 | '''
129 | # print("logits:", logits, logits.shape)
130 | # print("labels:", labels, labels.shape)
131 | if (self.bonsaiObj.numClasses > 2):
132 | # multi-class: macro-averaged F1 over argmax predictions and one-hot labels
133 | pred = logits.argmax(dim=1)
134 | return f1_score(labels.argmax(dim=1), pred, average='macro')
135 | else:
136 | pred = (torch.cat((torch.zeros(logits.shape),
137 | logits), 1)).argmax(dim=1)
138 | # print("pred:", pred, pred.shape)
139 | f1score = f1_score(labels, pred)
140 |
141 | return f1score
142 |
143 | def runHardThrsd(self):
144 | '''
145 | Function to run the IHT routine on Bonsai Obj
146 | '''
147 | currW = self.bonsaiObj.W.data
148 | currV = self.bonsaiObj.V.data
149 | currZ = self.bonsaiObj.Z.data
150 | currT = self.bonsaiObj.T.data
151 |
152 | __thrsdW = utils.hardThreshold(currW.cpu(), self.sW)
153 | __thrsdV = utils.hardThreshold(currV.cpu(), self.sV)
154 | __thrsdZ = utils.hardThreshold(currZ.cpu(), self.sZ)
155 | __thrsdT = utils.hardThreshold(currT.cpu(), self.sT)
156 |
157 | self.bonsaiObj.W.data = torch.FloatTensor(
158 | __thrsdW).to(self.device)
159 | self.bonsaiObj.V.data = torch.FloatTensor(
160 | __thrsdV).to(self.device)
161 | self.bonsaiObj.Z.data = torch.FloatTensor(
162 | __thrsdZ).to(self.device)
163 | self.bonsaiObj.T.data = torch.FloatTensor(
164 | __thrsdT).to(self.device)
165 |
166 | self.__thrsdW = torch.FloatTensor(
167 | __thrsdW.detach().clone()).to(self.device)
168 | self.__thrsdV = torch.FloatTensor(
169 | __thrsdV.detach().clone()).to(self.device)
170 | self.__thrsdZ = torch.FloatTensor(
171 | __thrsdZ.detach().clone()).to(self.device)
172 | self.__thrsdT = torch.FloatTensor(
173 | __thrsdT.detach().clone()).to(self.device)
174 |
175 | def runSparseTraining(self):
176 | '''
177 | Function to run the Sparse Retraining routine on Bonsai Obj
178 | '''
179 | currW = self.bonsaiObj.W.data
180 | currV = self.bonsaiObj.V.data
181 | currZ = self.bonsaiObj.Z.data
182 | currT = self.bonsaiObj.T.data
183 |
184 | newW = utils.copySupport(self.__thrsdW, currW)
185 | newV = utils.copySupport(self.__thrsdV, currV)
186 | newZ = utils.copySupport(self.__thrsdZ, currZ)
187 | newT = utils.copySupport(self.__thrsdT, currT)
188 |
189 | self.bonsaiObj.W.data = newW
190 | self.bonsaiObj.V.data = newV
191 | self.bonsaiObj.Z.data = newZ
192 | self.bonsaiObj.T.data = newT
193 |
194 | def assertInit(self):
195 | err = "sparsity must be between 0 and 1"
196 | assert self.sW >= 0 and self.sW <= 1, "W " + err
197 | assert self.sV >= 0 and self.sV <= 1, "V " + err
198 | assert self.sZ >= 0 and self.sZ <= 1, "Z " + err
199 | assert self.sT >= 0 and self.sT <= 1, "T " + err
200 |
201 | def saveParams(self, currDir):
202 | '''
203 | Function to save Parameter matrices into a given folder
204 | '''
205 | paramDir = currDir + '/'
206 | np.save(paramDir + "W.npy", self.bonsaiObj.W.data.cpu())
207 | np.save(paramDir + "V.npy", self.bonsaiObj.V.data.cpu())
208 | np.save(paramDir + "T.npy", self.bonsaiObj.T.data.cpu())
209 | np.save(paramDir + "Z.npy", self.bonsaiObj.Z.data.cpu())
210 | hyperParamDict = {'dataDim': self.bonsaiObj.dataDimension,
211 | 'projDim': self.bonsaiObj.projectionDimension,
212 | 'numClasses': self.bonsaiObj.numClasses,
213 | 'depth': self.bonsaiObj.treeDepth,
214 | 'sigma': self.bonsaiObj.sigma}
215 | hyperParamFile = paramDir + 'hyperParam.npy'
216 | np.save(hyperParamFile, hyperParamDict)
217 |
218 | def saveParamsForSeeDot(self, currDir):
219 | '''
220 | Function to save Parameter matrices into a given folder for SeeDot compiler
221 | '''
222 | seeDotDir = currDir + '/SeeDot/'
223 |
224 | if os.path.isdir(seeDotDir) is False:
225 | try:
226 | os.mkdir(seeDotDir)
227 | except OSError:
228 | print("Creation of the directory %s failed" %
229 | seeDotDir)
230 |
231 | np.savetxt(seeDotDir + "W",
232 | utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.W.data.cpu(),
233 | self.bonsaiObj.numClasses,
234 | self.bonsaiObj.totalNodes),
235 | delimiter="\t")
236 | np.savetxt(seeDotDir + "V",
237 | utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.V.data.cpu(),
238 | self.bonsaiObj.numClasses,
239 | self.bonsaiObj.totalNodes),
240 | delimiter="\t")
241 | np.savetxt(seeDotDir + "T", self.bonsaiObj.T.data.cpu(), delimiter="\t")
242 | np.savetxt(seeDotDir + "Z", self.bonsaiObj.Z.data.cpu(), delimiter="\t")
243 | np.savetxt(seeDotDir + "Sigma",
244 | np.array([self.bonsaiObj.sigma]), delimiter="\t")
245 |
246 | def loadModel(self, currDir):
247 | '''
248 | Load the Saved model and load it to the model using constructor
249 | Returns two dict one for params and other for hyperParams
250 | '''
251 | paramDir = currDir + '/'
252 | paramDict = {}
253 | paramDict['W'] = np.load(paramDir + "W.npy")
254 | paramDict['V'] = np.load(paramDir + "V.npy")
255 | paramDict['T'] = np.load(paramDir + "T.npy")
256 | paramDict['Z'] = np.load(paramDir + "Z.npy")
257 | hyperParamDict = np.load(paramDir + "hyperParam.npy", allow_pickle=True).item()
258 | return paramDict, hyperParamDict
259 |
260 | # Function to get aimed model size
261 | def getModelSize(self):
262 | '''
263 | Function to get aimed model size
264 | '''
265 | nnzZ, sizeZ, sparseZ = utils.estimateNNZ(self.bonsaiObj.Z, self.sZ)
266 | nnzW, sizeW, sparseW = utils.estimateNNZ(self.bonsaiObj.W, self.sW)
267 | nnzV, sizeV, sparseV = utils.estimateNNZ(self.bonsaiObj.V, self.sV)
268 | nnzT, sizeT, sparseT = utils.estimateNNZ(self.bonsaiObj.T, self.sT)
269 |
270 | totalnnZ = (nnzZ + nnzT + nnzV + nnzW)
271 | totalSize = (sizeZ + sizeW + sizeV + sizeT)
272 | hasSparse = (sparseW or sparseV or sparseT or sparseZ)
273 | return totalnnZ, totalSize, hasSparse
274 |
275 | def train(self, batchSize, totalEpochs,
276 | Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
277 | '''
278 | The Dense - IHT - Sparse Retrain Routine for Bonsai Training
279 | '''
280 | resultFile = open(dataDir + '/PyTorchBonsaiResults.txt', 'a+')
281 | numIters = Xtrain.shape[0] / batchSize
282 |
283 | totalBatches = numIters * totalEpochs
284 |
285 | self.sigmaI = 1
286 |
287 | counter = 0
288 | if self.bonsaiObj.numClasses > 2:
289 | trimlevel = 15
290 | else:
291 | trimlevel = 5
292 | ihtDone = 0
293 |
294 | maxTestAcc = -10000
295 | finalF1 = -10000
296 | finalTrainLoss = -10000
297 | finalTrainAcc = -10000
298 | finalClassificationReport = None
299 | finalFAR = -10000
300 | finalCM = None
301 | if self.isDenseTraining is True:
302 | ihtDone = 1
303 | self.sigmaI = 1
304 | itersInPhase = 0
305 |
306 | header = '*' * 20
307 | for i in range(totalEpochs):
308 | print("\nEpoch Number: " + str(i), file=self.outFile)
309 |
310 | '''
311 | trainAcc -> For Classification, it is 'Accuracy'.
312 | '''
313 | trainAcc = 0.0
314 | trainLoss = 0.0
315 |
316 | numIters = int(numIters)
317 | for j in range(numIters):
318 |
319 | if counter == 0:
320 | msg = " Dense Training Phase Started "
321 | print("\n%s%s%s\n" %
322 | (header, msg, header), file=self.outFile)
323 |
324 | # Updating the indicator sigma
325 | if ((counter == 0) or (counter == int(totalBatches / 3.0)) or
326 | (counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False):
327 | self.sigmaI = 1
328 | itersInPhase = 0
329 |
330 | elif (itersInPhase % 100 == 0):
331 | indices = np.random.choice(Xtrain.shape[0], 100)
332 | batchX = Xtrain[indices, :]
333 | batchY = Ytrain[indices, :]
334 | batchY = np.reshape(
335 | batchY, [-1, self.bonsaiObj.numClasses])
336 |
337 | Teval = self.bonsaiObj.T.data
338 | Xcapeval = (torch.matmul(self.bonsaiObj.Z, torch.t(
339 | batchX.to(self.device))) / self.bonsaiObj.projectionDimension).data
340 |
341 | sum_tr = 0.0
342 | for k in range(0, self.bonsaiObj.internalNodes):
343 | sum_tr += (
344 | np.sum(np.abs(np.dot(Teval[k].cpu(), Xcapeval.cpu()))))
345 |
346 | if(self.bonsaiObj.internalNodes > 0):
347 | sum_tr /= (100 * self.bonsaiObj.internalNodes)
348 | sum_tr = 0.1 / sum_tr
349 | else:
350 | sum_tr = 0.1
351 | sum_tr = min(
352 | 1000, sum_tr * (2**(float(itersInPhase) /
353 | (float(totalBatches) / 30.0))))
354 |
355 | self.sigmaI = sum_tr
356 |
357 | itersInPhase += 1
358 | batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
359 | batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
360 | batchY = np.reshape(
361 | batchY, [-1, self.bonsaiObj.numClasses])
362 |
363 | self.optimizer.zero_grad()
364 |
365 | logits, _ = self.bonsaiObj(batchX.to(self.device), self.sigmaI)
366 | batchLoss, _, _ = self.loss(logits, batchY.to(self.device))
367 | batchAcc = self.accuracy(logits, batchY.to(self.device))
368 |
369 | batchLoss.backward()
370 | self.optimizer.step()
371 |
372 | # Classification.
373 |
374 | trainAcc += batchAcc.item()
375 | trainLoss += batchLoss.item()
376 |
377 | # Training routine involving IHT and sparse retraining
378 | if (counter >= int(totalBatches / 3.0) and
379 | (counter < int(2 * totalBatches / 3.0)) and
380 | counter % trimlevel == 0 and
381 | self.isDenseTraining is False):
382 | self.runHardThrsd()
383 | if ihtDone == 0:
384 | msg = " IHT Phase Started "
385 | print("\n%s%s%s\n" %
386 | (header, msg, header), file=self.outFile)
387 | ihtDone = 1
388 | elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
389 | (counter < int(2 * totalBatches / 3.0)) and
390 | counter % trimlevel != 0 and
391 | self.isDenseTraining is False) or
392 | (counter >= int(2 * totalBatches / 3.0) and
393 | self.isDenseTraining is False)):
394 | self.runSparseTraining()
395 | if counter == int(2 * totalBatches / 3.0):
396 | msg = " Sparse Retraining Phase Started "
397 | print("\n%s%s%s\n" %
398 | (header, msg, header), file=self.outFile)
399 | counter += 1
400 |
401 | print("\nClassification Train Loss: " + str(trainLoss / numIters) +
402 | "\nTraining accuracy (Classification): " +
403 | str(trainAcc / numIters),
404 | file=self.outFile)
405 |
406 | #####################################
407 | finalTrainAcc = trainAcc / numIters
408 | finalTrainLoss = trainLoss / numIters
409 |
410 |
411 | oldSigmaI = self.sigmaI
412 | self.sigmaI = 1e9
413 |
414 | # end-of-epoch evaluation on the held-out test set
415 |
416 | logits, _ = self.bonsaiObj(Xtest.to(self.device), self.sigmaI)
417 | testLoss, marginLoss, regLoss = self.loss(
418 | logits, Ytest.to(self.device))
419 | testAcc = self.accuracy(logits, Ytest.to(self.device)).item()
420 | testf1 = self.f1(logits, Ytest.to(self.device))
421 | testclass = self.classificationReport(logits, Ytest.to(self.device))
422 | CM, FAR = self.confusion_matrix_FAR(logits, Ytest.to(self.device))
423 |
424 | if ihtDone == 0:
425 | maxTestAcc = -10000
426 | maxTestAccEpoch = i
427 | else:
428 | if maxTestAcc <= testAcc:
429 | maxTestAccEpoch = i
430 | maxTestAcc = testAcc
431 | self.saveParams(currDir)
432 | self.saveParamsForSeeDot(currDir)
433 |
434 | print("Test accuracy %g" % testAcc, file=self.outFile)
435 | print("Test F1 ", testf1, file=self.outFile)
436 | print("Test False Alarm Rate ", FAR, file=self.outFile)
437 | print("Confusion Matrix \n", CM, file=self.outFile)
438 | print("Classification Report \n", testclass, file=self.outFile)
439 |
440 | #####################################
441 | # testAcc and maxTestAcc are already up to date from the block above;
442 | # carry the remaining per-epoch metrics forward as the running finals
443 |
444 | finalF1 = testf1
445 | finalClassificationReport = testclass
446 | finalFAR = FAR
447 | finalCM = CM
448 |
449 | print("MarginLoss + RegLoss: " + str(marginLoss.item()) + " + " +
450 | str(regLoss.item()) + " = " + str(testLoss.item()) + "\n",
451 | file=self.outFile)
452 | self.outFile.flush()
453 |
454 | self.sigmaI = oldSigmaI
455 |
456 | # sigmaI has to be set to infinity to ensure
457 | # only a single path is used in inference
458 | self.sigmaI = 1e9
459 | print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
460 | str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
461 | str(self.getModelSize()[2]) + "\n", file=self.outFile)
462 |
463 | print("For Classification, Maximum Test accuracy at compressed" +
464 | " model size(including early stopping): " +
465 | str(maxTestAcc) + " at Epoch: " +
466 | str(maxTestAccEpoch + 1) + "\nFinal Test" +
467 | " Accuracy: " + str(testAcc), file=self.outFile)
468 |
469 | resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
470 | " at Epoch(totalEpochs): " +
471 | str(maxTestAccEpoch + 1) +
472 | "(" + str(totalEpochs) + ")" + " ModelSize: " +
473 | str(float(self.getModelSize()[1]) / 1024.0) +
474 | " KB hasSparse: " + str(self.getModelSize()[2]) +
475 | " Param Directory: " +
476 | str(os.path.abspath(currDir)) + "\n")
477 |
478 | ##############################################################
479 | finalModelSize = float(self.getModelSize()[1]) / 1024.0
480 |
481 | print("The Model Directory: " + currDir + "\n")
482 |
483 | resultFile.close()
484 | self.outFile.flush()
485 |
486 | if self.outFile is not sys.stdout:
487 | self.outFile.close()
488 |
489 | finalClassificationReport['train loss'] = finalTrainLoss
490 | finalClassificationReport['train acc'] = finalTrainAcc
491 | finalClassificationReport['test f1'] = finalF1
492 | finalClassificationReport['model size'] = finalModelSize
493 | finalClassificationReport['test far'] = finalFAR
494 | return(finalClassificationReport, finalCM)
495 |
496 |
497 |
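498 | # ---------------------------------------------------------------------------
499 | # Usage sketch (editor's illustration, not part of the original Microsoft
500 | # sources): how this trainer is typically wired up. The hyperparameter
501 | # values below are placeholders and the function is never called here;
502 | # see bonsai_example2.py for the real driver script.
503 | def _usage_sketch(Xtrain, Xtest, Ytrain, Ytest):
504 |     from edgeml_pytorch.graph.bonsai import Bonsai
505 |     # numClasses=1 encodes the binary AF / non-AF task (cf. bonsai_example2.py)
506 |     bonsai = Bonsai(1, Xtrain.shape[1], 5, 3, 1.0)
507 |     trainer = BonsaiTrainer(bonsai,
508 |                             1e-3, 1e-3, 1e-3, 1e-5,  # lW, lT, lV, lZ
509 |                             0.3, 0.62, 0.26, 0.2,    # sW, sT, sV, sZ
510 |                             learningRate=0.01, useMCHLoss=False)
511 |     # train() runs the dense -> IHT -> sparse-retrain thirds described above
512 |     # and returns the final classification report and confusion matrix
513 |     return trainer.train(100, 300, Xtrain, Xtest, Ytrain, Ytest, '.', '.')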
--------------------------------------------------------------------------------
/bonsai_example2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
4 | import helpermethods
5 | import numpy as np
6 | import sys
7 | import pandas as pd
8 | from edgeml_pytorch.trainer.bonsaiTrainer2 import BonsaiTrainer
9 | from edgeml_pytorch.graph.bonsai import Bonsai
10 | import torch
11 | from sklearn.model_selection import KFold
12 | import json
13 |
14 | def dict_mean(dict_list):
15 | mean_dict = {}
16 | for key in dict_list[0].keys():
17 | mean_dict[key] = sum(d[key] for d in dict_list) / len(dict_list)
18 | return mean_dict
19 |
20 |
21 | def main():
22 | # change cuda:0 to cuda:gpuid for specific allocation
23 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24 | device = torch.device("cpu")
25 | # Fixing seeds for reproducibility
26 | torch.manual_seed(42)
27 | np.random.seed(42)
28 |
29 | # Hyper Param pre-processing
30 | args = helpermethods.getArgs()
31 |
32 | sigma = args.sigma
33 | depth = args.depth
34 |
35 | projectionDimension = args.proj_dim
36 | regZ = args.rZ
37 | regT = args.rT
38 | regW = args.rW
39 | regV = args.rV
40 |
41 | totalEpochs = args.epochs
42 |
43 | learningRate = args.learning_rate
44 |
45 | dataDir = args.data_dir
46 |
47 | DFPATH = args.df_path
48 |
49 | outFile = args.output_file
50 |
51 | k_folds = args.kF
52 |
53 | # k-fold cross-validation over the feature dataframe
54 |
55 | final_scores = []
56 | final_scores_with_cm = []
57 | df = pd.read_csv(DFPATH)
58 |
59 | kfold = KFold(n_splits=k_folds, shuffle=True)
60 |
61 | for train_index, test_index in kfold.split(df):
62 | train = df.iloc[train_index]
63 | test = df.iloc[test_index]
64 | # print(train)
65 | # print(test)
66 |
67 | train = train.to_numpy()
68 | test = test.to_numpy()
69 |
70 | np.save(dataDir + '/train.npy', train)
71 | np.save(dataDir + '/test.npy', test)
72 |
73 | (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
74 | mean, std) = helpermethods.preProcessData(dataDir)
75 |
76 | sparZ = args.sZ
77 |
78 | if numClasses > 2:
79 | sparW = 0.2
80 | sparV = 0.2
81 | sparT = 0.2
82 | else:
83 | sparW = 1
84 | sparV = 1
85 | sparT = 1
86 |
87 | if args.sW is not None:
88 | sparW = args.sW
89 | if args.sV is not None:
90 | sparV = args.sV
91 | if args.sT is not None:
92 | sparT = args.sT
93 |
94 | if args.batch_size is None:
95 | batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
96 | else:
97 | batchSize = args.batch_size
98 |
99 | useMCHLoss = True
100 |
101 | if numClasses == 2:
102 | numClasses = 1
103 |
104 | currDir = helpermethods.createTimeStampDir(dataDir)
105 |
106 | helpermethods.dumpCommand(sys.argv, currDir)
107 | helpermethods.saveMeanStd(mean, std, currDir)
108 |
109 | # numClasses = 1 for binary case
110 | bonsaiObj = Bonsai(numClasses, dataDimension,
111 | projectionDimension, depth, sigma).to(device)
112 |
113 | bonsaiTrainer = BonsaiTrainer(bonsaiObj,
114 | regW, regT, regV, regZ,
115 | sparW, sparT, sparV, sparZ,
116 | learningRate, useMCHLoss, outFile, device)
117 |
118 | fold_scores, CM = bonsaiTrainer.train(batchSize, totalEpochs,
119 | torch.from_numpy(Xtrain.astype(np.float32)),
120 | torch.from_numpy(Xtest.astype(np.float32)),
121 | torch.from_numpy(Ytrain.astype(np.float32)),
122 | torch.from_numpy(Ytest.astype(np.float32)),
123 | dataDir, currDir)
124 |
125 | fold_scores = pd.json_normalize(fold_scores, sep='_')
126 | fold_scores = fold_scores.to_dict(orient='records')[0]
127 | fold_scores_with_cm = fold_scores.copy()
128 | fold_scores_with_cm['CM'] = CM.tolist()
129 | final_scores_with_cm.append(fold_scores_with_cm)
130 |
131 |
132 | print('########################################## FOLD SCORES ############################################')
133 | print(fold_scores)
134 | final_scores.append(fold_scores)
135 |
136 | print('########################################## FINAL SCORES ############################################')
137 | avg_score = dict_mean(final_scores)
138 | print(avg_score)
139 |
140 | with open(dataDir + '/Fold Results tf e-300 kF-5 depth-3.txt', 'w') as file:
141 | file.write(json.dumps(final_scores_with_cm, indent=4))
142 |
143 | with open(dataDir + '/Final Results tf e-300 kF-5 depth-3.txt', 'w') as file:
144 | file.write(json.dumps(avg_score, indent=4))
145 |
146 |
147 | sys.stdout.close()
148 |
149 |
150 | if __name__ == '__main__':
151 | main()
152 |
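153 | # Worked example (editor's illustration) of dict_mean above: each entry is
154 | # one fold's flattened score dict, and the result is the element-wise mean
155 | # reported under FINAL SCORES.
156 | #
157 | #   >>> dict_mean([{'test f1': 1.0, 'train acc': 0.5},
158 | #   ...            {'test f1': 0.5, 'train acc': 1.0}])
159 | #   {'test f1': 0.75, 'train acc': 0.75}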
--------------------------------------------------------------------------------
/pipeline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import pandas as pd\n",
11 | "from sklearn.model_selection import train_test_split\n",
12 | "import os\n",
13 | "from ecgdetectors import Detectors\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import seaborn as sns\n",
16 | "from time import time as time"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "CPU times: user 1min 43s, sys: 6.69 s, total: 1min 50s\n",
29 | "Wall time: 2min\n"
30 | ]
31 | },
32 | {
33 | "data": {
34 | "text/html": [
35 | "\n",
36 | "<style scoped>\n",
37 |     .dataframe tbody tr th:only-of-type {
38 |         vertical-align: middle;
39 |     }
40 |
41 |     .dataframe tbody tr th {
42 |         vertical-align: top;
43 |     }
44 |
45 |     .dataframe thead th {
46 |         text-align: right;
47 |     }
48 | </style>\n",
49 | "<table border=\"1\" class=\"dataframe\">\n",
50 | " \n",
51 | " \n",
52 | " | \n",
53 | " af | \n",
54 | " 0 | \n",
55 | " 1 | \n",
56 | " 2 | \n",
57 | " 3 | \n",
58 | " 4 | \n",
59 | " 5 | \n",
60 | " 6 | \n",
61 | " 7 | \n",
62 | " 8 | \n",
63 | " ... | \n",
64 | " 1240 | \n",
65 | " 1241 | \n",
66 | " 1242 | \n",
67 | " 1243 | \n",
68 | " 1244 | \n",
69 | " 1245 | \n",
70 | " 1246 | \n",
71 | " 1247 | \n",
72 | " 1248 | \n",
73 | " 1249 | \n",
74 | "    </tr>\n",
75 | " \n",
76 | " \n",
77 | " \n",
78 | " 0 | \n",
79 | " 0 | \n",
80 | " 0.0 | \n",
81 | " 0.0 | \n",
82 | " 0.0 | \n",
83 | " 0.0 | \n",
84 | " 0.0 | \n",
85 | " 0.0 | \n",
86 | " 0.0 | \n",
87 | " 0.0 | \n",
88 | " 0.0 | \n",
89 | " ... | \n",
90 | " 0.0 | \n",
91 | " 0.0 | \n",
92 | " 0.0 | \n",
93 | " 0.0 | \n",
94 | " 0.0 | \n",
95 | " 0.0 | \n",
96 | " 0.0 | \n",
97 | " 0.0 | \n",
98 | " 0.0 | \n",
99 | " 0.0 | \n",
100 | "    </tr>\n",
101 | " \n",
102 | " 1 | \n",
103 | " 0 | \n",
104 | " 0.0 | \n",
105 | " 0.0 | \n",
106 | " 0.0 | \n",
107 | " 0.0 | \n",
108 | " 0.0 | \n",
109 | " 0.0 | \n",
110 | " 0.0 | \n",
111 | " 0.0 | \n",
112 | " 0.0 | \n",
113 | " ... | \n",
114 | " 0.0 | \n",
115 | " 0.0 | \n",
116 | " 0.0 | \n",
117 | " 0.0 | \n",
118 | " 0.0 | \n",
119 | " 0.0 | \n",
120 | " 0.0 | \n",
121 | " 0.0 | \n",
122 | " 0.0 | \n",
123 | " 0.0 | \n",
124 | "    </tr>\n",
125 | " \n",
126 | " 2 | \n",
127 | " 0 | \n",
128 | " 0.0 | \n",
129 | " 0.0 | \n",
130 | " 0.0 | \n",
131 | " 0.0 | \n",
132 | " 0.0 | \n",
133 | " 0.0 | \n",
134 | " 0.0 | \n",
135 | " 0.0 | \n",
136 | " 0.0 | \n",
137 | " ... | \n",
138 | " 0.0 | \n",
139 | " 0.0 | \n",
140 | " 0.0 | \n",
141 | " 0.0 | \n",
142 | " 0.0 | \n",
143 | " 0.0 | \n",
144 | " 0.0 | \n",
145 | " 0.0 | \n",
146 | " 0.0 | \n",
147 | " 0.0 | \n",
148 | "    </tr>\n",
149 | " \n",
150 | " 3 | \n",
151 | " 0 | \n",
152 | " 0.0 | \n",
153 | " 0.0 | \n",
154 | " 0.0 | \n",
155 | " 0.0 | \n",
156 | " 0.0 | \n",
157 | " 0.0 | \n",
158 | " 0.0 | \n",
159 | " 0.0 | \n",
160 | " 0.0 | \n",
161 | " ... | \n",
162 | " 0.0 | \n",
163 | " 0.0 | \n",
164 | " 0.0 | \n",
165 | " 0.0 | \n",
166 | " 0.0 | \n",
167 | " 0.0 | \n",
168 | " 0.0 | \n",
169 | " 0.0 | \n",
170 | " 0.0 | \n",
171 | " 0.0 | \n",
172 | "    </tr>\n",
173 | " \n",
174 | " 4 | \n",
175 | " 0 | \n",
176 | " 0.0 | \n",
177 | " 0.0 | \n",
178 | " 0.0 | \n",
179 | " 0.0 | \n",
180 | " 0.0 | \n",
181 | " 0.0 | \n",
182 | " 0.0 | \n",
183 | " 0.0 | \n",
184 | " 0.0 | \n",
185 | " ... | \n",
186 | " 0.0 | \n",
187 | " 0.0 | \n",
188 | " 0.0 | \n",
189 | " 0.0 | \n",
190 | " 0.0 | \n",
191 | " 0.0 | \n",
192 | " 0.0 | \n",
193 | " 0.0 | \n",
194 | " 0.0 | \n",
195 | " 0.0 | \n",
196 | "    </tr>\n",
197 | " \n",
198 | " ... | \n",
199 | " ... | \n",
200 | " ... | \n",
201 | " ... | \n",
202 | " ... | \n",
203 | " ... | \n",
204 | " ... | \n",
205 | " ... | \n",
206 | " ... | \n",
207 | " ... | \n",
208 | " ... | \n",
209 | " ... | \n",
210 | " ... | \n",
211 | " ... | \n",
212 | " ... | \n",
213 | " ... | \n",
214 | " ... | \n",
215 | " ... | \n",
216 | " ... | \n",
217 | " ... | \n",
218 | " ... | \n",
219 | " ... | \n",
220 | "    </tr>\n",
221 | " \n",
222 | " 787506 | \n",
223 | " 1 | \n",
224 | " 0.0 | \n",
225 | " 0.0 | \n",
226 | " 0.0 | \n",
227 | " 0.0 | \n",
228 | " 0.0 | \n",
229 | " 0.0 | \n",
230 | " 0.0 | \n",
231 | " 0.0 | \n",
232 | " 0.0 | \n",
233 | " ... | \n",
234 | " 0.0 | \n",
235 | " 0.0 | \n",
236 | " 0.0 | \n",
237 | " 0.0 | \n",
238 | " 0.0 | \n",
239 | " 0.0 | \n",
240 | " 0.0 | \n",
241 | " 0.0 | \n",
242 | " 0.0 | \n",
243 | " 0.0 | \n",
244 | "    </tr>\n",
245 | " \n",
246 | " 787507 | \n",
247 | " 1 | \n",
248 | " 0.0 | \n",
249 | " 0.0 | \n",
250 | " 0.0 | \n",
251 | " 0.0 | \n",
252 | " 0.0 | \n",
253 | " 0.0 | \n",
254 | " 0.0 | \n",
255 | " 0.0 | \n",
256 | " 0.0 | \n",
257 | " ... | \n",
258 | " 0.0 | \n",
259 | " 0.0 | \n",
260 | " 0.0 | \n",
261 | " 0.0 | \n",
262 | " 0.0 | \n",
263 | " 0.0 | \n",
264 | " 0.0 | \n",
265 | " 0.0 | \n",
266 | " 0.0 | \n",
267 | " 0.0 | \n",
268 | "    </tr>\n",
269 | " \n",
270 | " 787508 | \n",
271 | " 1 | \n",
272 | " 0.0 | \n",
273 | " 0.0 | \n",
274 | " 0.0 | \n",
275 | " 0.0 | \n",
276 | " 0.0 | \n",
277 | " 0.0 | \n",
278 | " 0.0 | \n",
279 | " 0.0 | \n",
280 | " 0.0 | \n",
281 | " ... | \n",
282 | " 0.0 | \n",
283 | " 0.0 | \n",
284 | " 0.0 | \n",
285 | " 0.0 | \n",
286 | " 0.0 | \n",
287 | " 0.0 | \n",
288 | " 0.0 | \n",
289 | " 0.0 | \n",
290 | " 0.0 | \n",
291 | " 0.0 | \n",
292 | "    </tr>\n",
293 | " \n",
294 | " 787509 | \n",
295 | " 1 | \n",
296 | " 0.0 | \n",
297 | " 0.0 | \n",
298 | " 0.0 | \n",
299 | " 0.0 | \n",
300 | " 0.0 | \n",
301 | " 0.0 | \n",
302 | " 0.0 | \n",
303 | " 0.0 | \n",
304 | " 0.0 | \n",
305 | " ... | \n",
306 | " 0.0 | \n",
307 | " 0.0 | \n",
308 | " 0.0 | \n",
309 | " 0.0 | \n",
310 | " 0.0 | \n",
311 | " 0.0 | \n",
312 | " 0.0 | \n",
313 | " 0.0 | \n",
314 | " 0.0 | \n",
315 | " 0.0 | \n",
316 | "    </tr>\n",
317 | " \n",
318 | " 787510 | \n",
319 | " 1 | \n",
320 | " 0.0 | \n",
321 | " 0.0 | \n",
322 | " 0.0 | \n",
323 | " 0.0 | \n",
324 | " 0.0 | \n",
325 | " 0.0 | \n",
326 | " 0.0 | \n",
327 | " 0.0 | \n",
328 | " 0.0 | \n",
329 | " ... | \n",
330 | " 0.0 | \n",
331 | " 0.0 | \n",
332 | " 0.0 | \n",
333 | " 0.0 | \n",
334 | " 0.0 | \n",
335 | " 0.0 | \n",
336 | " 0.0 | \n",
337 | " 0.0 | \n",
338 | " 0.0 | \n",
339 | " 0.0 | \n",
340 | "    </tr>\n",
341 | "  </tbody>\n",
342 | "</table>\n",
343 | "<p>787154 rows × 1251 columns</p>\n",
344 | "</div>"
345 | ],
346 | "text/plain": [
347 | " af 0 1 2 3 4 5 6 7 8 ... 1240 1241 \\\n",
348 | "0 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
349 | "1 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
350 | "2 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
351 | "3 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
352 | "4 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
353 | "... .. ... ... ... ... ... ... ... ... ... ... ... ... \n",
354 | "787506 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
355 | "787507 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
356 | "787508 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
357 | "787509 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
358 | "787510 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
359 | "\n",
360 | " 1242 1243 1244 1245 1246 1247 1248 1249 \n",
361 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
362 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
363 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
364 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
365 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
366 | "... ... ... ... ... ... ... ... ... \n",
367 | "787506 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
368 | "787507 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
369 | "787508 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
370 | "787509 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
371 | "787510 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
372 | "\n",
373 | "[787154 rows x 1251 columns]"
374 | ]
375 | },
376 | "execution_count": 2,
377 | "metadata": {},
378 | "output_type": "execute_result"
379 | }
380 | ],
381 | "source": [
382 | "%%time\n",
383 | "# df = pd.read_csv(\"/hdd/physio/af/.csv\", index_col = 0)\n",
384 | "final = pd.read_csv('/hdd/physio/af/hugedfrpeak5sec.csv', index_col = 0)\n",
385 | "final"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 93,
391 | "metadata": {},
392 | "outputs": [
393 | {
394 | "name": "stdout",
395 | "output_type": "stream",
396 | "text": [
397 | "(1250,)\n",
398 | "(5,) [188, 422, 634, 851, 1053]\n",
399 | "(7,)\n",
400 | "0.019435644149780273\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "fs = 250\n",
406 | "\n",
407 | "def pipelinedRpeakExtraction(x, fs):\n",
408 | " detectors = Detectors(fs)\n",
409 | "# x = detectors.swt_detector(x)\n",
410 | "# x = detectors.hamilton_detector(x)\n",
411 | " x = detectors.pan_tompkins_detector(x)\n",
412 | " return x\n",
413 | "\n",
414 | "def optimizedFeatureExtraction(x, fs):\n",
415 | " rr = np.diff(x) / fs \n",
416 | " hr = 60 / rr\n",
417 | " rmssd = np.sqrt(np.mean(np.square(rr)))\n",
418 | " sdnn = np.std(rr)\n",
419 | " mean_rr = np.mean(rr)\n",
420 | " mean_hr = np.mean(hr)\n",
421 | " std_hr = np.std(hr)\n",
422 | " min_hr = np.min(hr)\n",
423 | " max_hr = np.max(hr)\n",
424 | " features = np.array([rmssd, sdnn, mean_rr, mean_hr, std_hr, min_hr, max_hr])\n",
425 | " return features\n",
426 | "\n",
427 | "def helperMethod():\n",
428 | " return np.array(final.loc[5])[1:]\n",
429 | "\n",
430 | "\n",
431 | "x = helperMethod()\n",
432 | "print(x.shape)\n",
433 | "start = time()\n",
434 | "x = pipelinedRpeakExtraction(x, fs)\n",
435 | "print(np.array(x).shape, x)\n",
436 | "features = optimizedFeatureExtraction(x, fs)\n",
437 | "print(features.shape)\n",
438 | "print(time() - start)"
439 | ]
440 | },
441 | {
442 | "cell_type": "raw",
443 | "metadata": {},
444 | "source": [
445 | "Time for RPeakExtraction:\n",
446 | "\n",
447 | "0.011760711669921875 - SWT\n",
448 | "0.010072946548461914 - HS\n",
449 | "0.032416343688964844 - PT\n",
450 | "\n",
451 | "Time for RPeakExtraction & optimizedFeatureExtraction:\n",
452 | "\n",
453 | "0.0029833316802978516 - SWT\n",
454 | "0.004145383834838867 - HS\n",
455 | "0.018214702606201172 - PT"
456 | ]
457 | },
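{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not from the original run): a self-contained rerun of the\n",
"# two-stage pipeline above on a synthetic signal, so the R-peak -> RR -> HRV\n",
"# steps can be tried without the AFDB/CHDB dataframes. The 1.2 Hz sinusoid\n",
"# (a ~72 bpm surrogate) and the peak-count guard are assumptions.\n",
"sig = np.sin(2 * np.pi * 1.2 * np.arange(0, 5, 1 / fs))\n",
"peaks = pipelinedRpeakExtraction(sig, fs)\n",
"if len(peaks) >= 3:\n",
"    print(optimizedFeatureExtraction(peaks, fs))\n",
"else:\n",
"    print('detector found too few peaks on the synthetic signal')"
]
},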
458 | {
459 | "cell_type": "code",
460 | "execution_count": 87,
461 | "metadata": {},
462 | "outputs": [
463 | {
464 | "data": {
465 | "text/plain": [
466 | "(5,)"
467 | ]
468 | },
469 | "execution_count": 87,
470 | "metadata": {},
471 | "output_type": "execute_result"
472 | }
473 | ],
474 | "source": [
475 | "x = np.array(x)\n",
476 | "x.shape"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": 94,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "chdb = pd.read_csv('/hdd/physio/af2/finaldfs/2017final_normal_ecg_L5_S1.csv')"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 95,
491 | "metadata": {},
492 | "outputs": [
493 | {
494 | "data": {
495 | "text/html": [
496 | "\n",
497 | "\n",
510 | "
\n",
511 | " \n",
512 | " \n",
513 | " | \n",
514 | " af | \n",
515 | " 0 | \n",
516 | " 1 | \n",
517 | " 2 | \n",
518 | " 3 | \n",
519 | " 4 | \n",
520 | " 5 | \n",
521 | " 6 | \n",
522 | " 7 | \n",
523 | " 8 | \n",
524 | " ... | \n",
525 | " 1240 | \n",
526 | " 1241 | \n",
527 | " 1242 | \n",
528 | " 1243 | \n",
529 | " 1244 | \n",
530 | " 1245 | \n",
531 | " 1246 | \n",
532 | " 1247 | \n",
533 | " 1248 | \n",
534 | " 1249 | \n",
535 | "    </tr>\n",
536 | " \n",
537 | " \n",
538 | " \n",
539 | " 0 | \n",
540 | " 0 | \n",
541 | " -0.119671 | \n",
542 | " -0.171568 | \n",
543 | " -0.209467 | \n",
544 | " -0.241545 | \n",
545 | " -0.251298 | \n",
546 | " -0.261781 | \n",
547 | " -0.265077 | \n",
548 | " -0.268898 | \n",
549 | " -0.266928 | \n",
550 | " ... | \n",
551 | " -0.004762 | \n",
552 | " -0.011451 | \n",
553 | " -0.016941 | \n",
554 | " -0.012288 | \n",
555 | " -0.003720 | \n",
556 | " -0.000231 | \n",
557 | " 0.002896 | \n",
558 | " 0.007219 | \n",
559 | " 0.010819 | \n",
560 | " 0.015448 | \n",
561 | "    </tr>\n",
562 | " \n",
563 | " 1 | \n",
564 | " 0 | \n",
565 | " 0.179901 | \n",
566 | " 0.182646 | \n",
567 | " 0.186076 | \n",
568 | " 0.190405 | \n",
569 | " 0.194489 | \n",
570 | " 0.197974 | \n",
571 | " 0.201573 | \n",
572 | " 0.205480 | \n",
573 | " 0.210523 | \n",
574 | " ... | \n",
575 | " -0.099974 | \n",
576 | " -0.101241 | \n",
577 | " -0.102327 | \n",
578 | " -0.103782 | \n",
579 | " -0.102690 | \n",
580 | " -0.098856 | \n",
581 | " -0.094154 | \n",
582 | " -0.085563 | \n",
583 | " -0.073865 | \n",
584 | " -0.057123 | \n",
585 | "    </tr>\n",
586 | " \n",
587 | " 2 | \n",
588 | " 0 | \n",
589 | " -0.080083 | \n",
590 | " -0.073838 | \n",
591 | " -0.066394 | \n",
592 | " -0.056862 | \n",
593 | " -0.043762 | \n",
594 | " -0.025739 | \n",
595 | " -0.003682 | \n",
596 | " 0.017315 | \n",
597 | " 0.036692 | \n",
598 | " ... | \n",
599 | " -0.062881 | \n",
600 | " -0.066561 | \n",
601 | " -0.068860 | \n",
602 | " -0.073529 | \n",
603 | " -0.078563 | \n",
604 | " -0.089408 | \n",
605 | " -0.105989 | \n",
606 | " -0.113971 | \n",
607 | " -0.117699 | \n",
608 | " -0.121547 | \n",
609 | "    </tr>\n",
610 | " \n",
611 | " 3 | \n",
612 | " 0 | \n",
613 | " 0.004971 | \n",
614 | " 0.009785 | \n",
615 | " 0.012551 | \n",
616 | " 0.013169 | \n",
617 | " 0.011561 | \n",
618 | " 0.008022 | \n",
619 | " 0.005519 | \n",
620 | " 0.001601 | \n",
621 | " -0.000719 | \n",
622 | " ... | \n",
623 | " -0.027093 | \n",
624 | " -0.030426 | \n",
625 | " -0.032781 | \n",
626 | " -0.035219 | \n",
627 | " -0.037578 | \n",
628 | " -0.039177 | \n",
629 | " -0.041016 | \n",
630 | " -0.041205 | \n",
631 | " -0.043212 | \n",
632 | " -0.044915 | \n",
633 | "    </tr>\n",
634 | " \n",
635 | " 4 | \n",
636 | " 0 | \n",
637 | " 0.008039 | \n",
638 | " 0.006846 | \n",
639 | " 0.005272 | \n",
640 | " 0.002796 | \n",
641 | " 0.000349 | \n",
642 | " -0.000125 | \n",
643 | " -0.001080 | \n",
644 | " -0.000778 | \n",
645 | " -0.002517 | \n",
646 | " ... | \n",
647 | " -0.018165 | \n",
648 | " -0.026139 | \n",
649 | " -0.033500 | \n",
650 | " -0.040133 | \n",
651 | " -0.044640 | \n",
652 | " -0.047706 | \n",
653 | " -0.049499 | \n",
654 | " -0.049575 | \n",
655 | " -0.048598 | \n",
656 | " -0.045659 | \n",
657 | "    </tr>\n",
658 | " \n",
659 | " ... | \n",
660 | " ... | \n",
661 | " ... | \n",
662 | " ... | \n",
663 | " ... | \n",
664 | " ... | \n",
665 | " ... | \n",
666 | " ... | \n",
667 | " ... | \n",
668 | " ... | \n",
669 | " ... | \n",
670 | " ... | \n",
671 | " ... | \n",
672 | " ... | \n",
673 | " ... | \n",
674 | " ... | \n",
675 | " ... | \n",
676 | " ... | \n",
677 | " ... | \n",
678 | " ... | \n",
679 | " ... | \n",
680 | " ... | \n",
681 | "    </tr>\n",
682 | " \n",
683 | " 163588 | \n",
684 | " 0 | \n",
685 | " -0.026915 | \n",
686 | " -0.021031 | \n",
687 | " -0.014944 | \n",
688 | " -0.009475 | \n",
689 | " -0.004890 | \n",
690 | " 0.000104 | \n",
691 | " 0.006036 | \n",
692 | " 0.013603 | \n",
693 | " 0.019909 | \n",
694 | " ... | \n",
695 | " -0.008041 | \n",
696 | " -0.009067 | \n",
697 | " -0.008895 | \n",
698 | " -0.009538 | \n",
699 | " -0.010804 | \n",
700 | " -0.012067 | \n",
701 | " -0.014235 | \n",
702 | " -0.015384 | \n",
703 | " -0.016604 | \n",
704 | " -0.017808 | \n",
705 | "    </tr>\n",
706 | " \n",
707 | " 163589 | \n",
708 | " 0 | \n",
709 | " 0.273103 | \n",
710 | " 0.281178 | \n",
711 | " 0.288176 | \n",
712 | " 0.293863 | \n",
713 | " 0.300131 | \n",
714 | " 0.305001 | \n",
715 | " 0.310949 | \n",
716 | " 0.315261 | \n",
717 | " 0.318010 | \n",
718 | " ... | \n",
719 | " -0.041078 | \n",
720 | " -0.040873 | \n",
721 | " -0.041422 | \n",
722 | " -0.042109 | \n",
723 | " -0.041949 | \n",
724 | " -0.042029 | \n",
725 | " -0.041985 | \n",
726 | " -0.042000 | \n",
727 | " -0.042037 | \n",
728 | " -0.041325 | \n",
729 | "    </tr>\n",
730 | " \n",
731 | " 163590 | \n",
732 | " 0 | \n",
733 | " 0.083085 | \n",
734 | " 0.065010 | \n",
735 | " 0.046886 | \n",
736 | " 0.029653 | \n",
737 | " 0.010828 | \n",
738 | " -0.013283 | \n",
739 | " -0.039487 | \n",
740 | " -0.053796 | \n",
741 | " -0.064019 | \n",
742 | " ... | \n",
743 | " -0.059016 | \n",
744 | " -0.057726 | \n",
745 | " -0.056972 | \n",
746 | " -0.056561 | \n",
747 | " -0.055091 | \n",
748 | " -0.054100 | \n",
749 | " -0.052678 | \n",
750 | " -0.051961 | \n",
751 | " -0.052033 | \n",
752 | " -0.051343 | \n",
753 | "    </tr>\n",
754 | " \n",
755 | " 163591 | \n",
756 | " 0 | \n",
757 | " 0.004186 | \n",
758 | " 0.001340 | \n",
759 | " 0.000015 | \n",
760 | " -0.000567 | \n",
761 | " -0.001093 | \n",
762 | " -0.001900 | \n",
763 | " -0.003228 | \n",
764 | " -0.004483 | \n",
765 | " -0.004620 | \n",
766 | " ... | \n",
767 | " -0.105929 | \n",
768 | " -0.100147 | \n",
769 | " -0.095189 | \n",
770 | " -0.088889 | \n",
771 | " -0.082454 | \n",
772 | " -0.074770 | \n",
773 | " -0.068059 | \n",
774 | " -0.060077 | \n",
775 | " -0.052014 | \n",
776 | " -0.042487 | \n",
777 | "    </tr>\n",
778 | " \n",
779 | " 163592 | \n",
780 | " 0 | \n",
781 | " -0.006969 | \n",
782 | " -0.009485 | \n",
783 | " -0.013761 | \n",
784 | " -0.020594 | \n",
785 | " -0.027776 | \n",
786 | " -0.035079 | \n",
787 | " -0.043239 | \n",
788 | " -0.050359 | \n",
789 | " -0.057678 | \n",
790 | " ... | \n",
791 | " -0.002951 | \n",
792 | " -0.002407 | \n",
793 | " -0.006838 | \n",
794 | " -0.007050 | \n",
795 | " -0.011927 | \n",
796 | " -0.010521 | \n",
797 | " -0.015027 | \n",
798 | " -0.012242 | \n",
799 | " -0.020470 | \n",
800 | " -0.011068 | \n",
801 | "    </tr>\n",
802 | "  </tbody>\n",
803 | "</table>\n",
804 | "<p>163593 rows × 1251 columns</p>\n",
805 | "</div>"
806 | ],
807 | "text/plain": [
808 | " af 0 1 2 3 4 5 \\\n",
809 | "0 0 -0.119671 -0.171568 -0.209467 -0.241545 -0.251298 -0.261781 \n",
810 | "1 0 0.179901 0.182646 0.186076 0.190405 0.194489 0.197974 \n",
811 | "2 0 -0.080083 -0.073838 -0.066394 -0.056862 -0.043762 -0.025739 \n",
812 | "3 0 0.004971 0.009785 0.012551 0.013169 0.011561 0.008022 \n",
813 | "4 0 0.008039 0.006846 0.005272 0.002796 0.000349 -0.000125 \n",
814 | "... .. ... ... ... ... ... ... \n",
815 | "163588 0 -0.026915 -0.021031 -0.014944 -0.009475 -0.004890 0.000104 \n",
816 | "163589 0 0.273103 0.281178 0.288176 0.293863 0.300131 0.305001 \n",
817 | "163590 0 0.083085 0.065010 0.046886 0.029653 0.010828 -0.013283 \n",
818 | "163591 0 0.004186 0.001340 0.000015 -0.000567 -0.001093 -0.001900 \n",
819 | "163592 0 -0.006969 -0.009485 -0.013761 -0.020594 -0.027776 -0.035079 \n",
820 | "\n",
821 | " 6 7 8 ... 1240 1241 1242 \\\n",
822 | "0 -0.265077 -0.268898 -0.266928 ... -0.004762 -0.011451 -0.016941 \n",
823 | "1 0.201573 0.205480 0.210523 ... -0.099974 -0.101241 -0.102327 \n",
824 | "2 -0.003682 0.017315 0.036692 ... -0.062881 -0.066561 -0.068860 \n",
825 | "3 0.005519 0.001601 -0.000719 ... -0.027093 -0.030426 -0.032781 \n",
826 | "4 -0.001080 -0.000778 -0.002517 ... -0.018165 -0.026139 -0.033500 \n",
827 | "... ... ... ... ... ... ... ... \n",
828 | "163588 0.006036 0.013603 0.019909 ... -0.008041 -0.009067 -0.008895 \n",
829 | "163589 0.310949 0.315261 0.318010 ... -0.041078 -0.040873 -0.041422 \n",
830 | "163590 -0.039487 -0.053796 -0.064019 ... -0.059016 -0.057726 -0.056972 \n",
831 | "163591 -0.003228 -0.004483 -0.004620 ... -0.105929 -0.100147 -0.095189 \n",
832 | "163592 -0.043239 -0.050359 -0.057678 ... -0.002951 -0.002407 -0.006838 \n",
833 | "\n",
834 | " 1243 1244 1245 1246 1247 1248 1249 \n",
835 | "0 -0.012288 -0.003720 -0.000231 0.002896 0.007219 0.010819 0.015448 \n",
836 | "1 -0.103782 -0.102690 -0.098856 -0.094154 -0.085563 -0.073865 -0.057123 \n",
837 | "2 -0.073529 -0.078563 -0.089408 -0.105989 -0.113971 -0.117699 -0.121547 \n",
838 | "3 -0.035219 -0.037578 -0.039177 -0.041016 -0.041205 -0.043212 -0.044915 \n",
839 | "4 -0.040133 -0.044640 -0.047706 -0.049499 -0.049575 -0.048598 -0.045659 \n",
840 | "... ... ... ... ... ... ... ... \n",
841 | "163588 -0.009538 -0.010804 -0.012067 -0.014235 -0.015384 -0.016604 -0.017808 \n",
842 | "163589 -0.042109 -0.041949 -0.042029 -0.041985 -0.042000 -0.042037 -0.041325 \n",
843 | "163590 -0.056561 -0.055091 -0.054100 -0.052678 -0.051961 -0.052033 -0.051343 \n",
844 | "163591 -0.088889 -0.082454 -0.074770 -0.068059 -0.060077 -0.052014 -0.042487 \n",
845 | "163592 -0.007050 -0.011927 -0.010521 -0.015027 -0.012242 -0.020470 -0.011068 \n",
846 | "\n",
847 | "[163593 rows x 1251 columns]"
848 | ]
849 | },
850 | "execution_count": 95,
851 | "metadata": {},
852 | "output_type": "execute_result"
853 | }
854 | ],
855 | "source": [
856 | "chdb"
857 | ]
858 | },
859 | {
860 | "cell_type": "code",
861 | "execution_count": 96,
862 | "metadata": {},
863 | "outputs": [
864 | {
865 | "data": {
866 | "text/html": [
867 | "\n",
868 | "<style scoped>\n",
869 |     .dataframe tbody tr th:only-of-type {
870 |         vertical-align: middle;
871 |     }
872 |
873 |     .dataframe tbody tr th {
874 |         vertical-align: top;
875 |     }
876 |
877 |     .dataframe thead th {
878 |         text-align: right;
879 |     }
880 | </style>\n",
\n",
882 | " \n",
883 | " \n",
884 | " | \n",
885 | " Unnamed: 0 | \n",
886 | " vf | \n",
887 | " 0 | \n",
888 | " 1 | \n",
889 | " 2 | \n",
890 | " 3 | \n",
891 | " 4 | \n",
892 | " 5 | \n",
893 | " 6 | \n",
894 | " 7 | \n",
895 | " ... | \n",
896 | " 1240 | \n",
897 | " 1241 | \n",
898 | " 1242 | \n",
899 | " 1243 | \n",
900 | " 1244 | \n",
901 | " 1245 | \n",
902 | " 1246 | \n",
903 | " 1247 | \n",
904 | " 1248 | \n",
905 | " 1249 | \n",
906 | "    </tr>\n",
907 | " \n",
908 | " \n",
909 | " \n",
910 | " 0 | \n",
911 | " 0 | \n",
912 | " 0 | \n",
913 | " 0 | \n",
914 | " 0 | \n",
915 | " 0 | \n",
916 | " 0 | \n",
917 | " 0 | \n",
918 | " 0 | \n",
919 | " 0 | \n",
920 | " 0 | \n",
921 | " ... | \n",
922 | " 0 | \n",
923 | " 0 | \n",
924 | " 0 | \n",
925 | " 0 | \n",
926 | " 0 | \n",
927 | " 0 | \n",
928 | " 0 | \n",
929 | " 0 | \n",
930 | " 0 | \n",
931 | " 0 | \n",
932 | "    </tr>\n",
933 | " \n",
934 | " 1 | \n",
935 | " 0 | \n",
936 | " 0 | \n",
937 | " 0 | \n",
938 | " 0 | \n",
939 | " 0 | \n",
940 | " 0 | \n",
941 | " 0 | \n",
942 | " 0 | \n",
943 | " 0 | \n",
944 | " 0 | \n",
945 | " ... | \n",
946 | " 0 | \n",
947 | " 0 | \n",
948 | " 0 | \n",
949 | " 0 | \n",
950 | " 0 | \n",
951 | " 0 | \n",
952 | " 0 | \n",
953 | " 0 | \n",
954 | " 0 | \n",
955 | " 0 | \n",
956 | "    </tr>\n",
957 | " \n",
958 | " 2 | \n",
959 | " 0 | \n",
960 | " 0 | \n",
961 | " 0 | \n",
962 | " 0 | \n",
963 | " 0 | \n",
964 | " 0 | \n",
965 | " 0 | \n",
966 | " 0 | \n",
967 | " 0 | \n",
968 | " 0 | \n",
969 | " ... | \n",
970 | " 0 | \n",
971 | " 0 | \n",
972 | " 0 | \n",
973 | " 0 | \n",
974 | " 0 | \n",
975 | " 0 | \n",
976 | " 0 | \n",
977 | " 0 | \n",
978 | " 0 | \n",
979 | " 0 | \n",
980 | "    </tr>\n",
981 | " \n",
982 | " 3 | \n",
983 | " 0 | \n",
984 | " 0 | \n",
985 | " 0 | \n",
986 | " 0 | \n",
987 | " 0 | \n",
988 | " 0 | \n",
989 | " 0 | \n",
990 | " 0 | \n",
991 | " 0 | \n",
992 | " 0 | \n",
993 | " ... | \n",
994 | " 0 | \n",
995 | " 0 | \n",
996 | " 0 | \n",
997 | " 0 | \n",
998 | " 0 | \n",
999 | " 0 | \n",
1000 | " 0 | \n",
1001 | " 0 | \n",
1002 | " 0 | \n",
1003 | " 0 | \n",
1004 | "    </tr>\n",
1005 | " \n",
1006 | " 4 | \n",
1007 | " 0 | \n",
1008 | " 0 | \n",
1009 | " 0 | \n",
1010 | " 0 | \n",
1011 | " 0 | \n",
1012 | " 0 | \n",
1013 | " 0 | \n",
1014 | " 0 | \n",
1015 | " 0 | \n",
1016 | " 0 | \n",
1017 | " ... | \n",
1018 | " 0 | \n",
1019 | " 0 | \n",
1020 | " 0 | \n",
1021 | " 0 | \n",
1022 | " 0 | \n",
1023 | " 0 | \n",
1024 | " 0 | \n",
1025 | " 0 | \n",
1026 | " 0 | \n",
1027 | " 0 | \n",
1028 | "
\n",
1029 | " \n",
1030 | " ... | \n",
1031 | " ... | \n",
1032 | " ... | \n",
1033 | " ... | \n",
1034 | " ... | \n",
1035 | " ... | \n",
1036 | " ... | \n",
1037 | " ... | \n",
1038 | " ... | \n",
1039 | " ... | \n",
1040 | " ... | \n",
1041 | " ... | \n",
1042 | " ... | \n",
1043 | " ... | \n",
1044 | " ... | \n",
1045 | " ... | \n",
1046 | " ... | \n",
1047 | " ... | \n",
1048 | " ... | \n",
1049 | " ... | \n",
1050 | " ... | \n",
1051 | " ... | \n",
1052 | "
\n",
1053 | " \n",
1054 | " 18746 | \n",
1055 | " 0 | \n",
1056 | " 0 | \n",
1057 | " 0 | \n",
1058 | " 0 | \n",
1059 | " 0 | \n",
1060 | " 0 | \n",
1061 | " 0 | \n",
1062 | " 0 | \n",
1063 | " 0 | \n",
1064 | " 0 | \n",
1065 | " ... | \n",
1066 | " 0 | \n",
1067 | " 0 | \n",
1068 | " 0 | \n",
1069 | " 0 | \n",
1070 | " 0 | \n",
1071 | " 0 | \n",
1072 | " 0 | \n",
1073 | " 0 | \n",
1074 | " 0 | \n",
1075 | " 0 | \n",
1076 | "
\n",
1077 | " \n",
1078 | " 18747 | \n",
1079 | " 0 | \n",
1080 | " 0 | \n",
1081 | " 0 | \n",
1082 | " 1 | \n",
1083 | " 0 | \n",
1084 | " 0 | \n",
1085 | " 0 | \n",
1086 | " 0 | \n",
1087 | " 0 | \n",
1088 | " 0 | \n",
1089 | " ... | \n",
1090 | " 0 | \n",
1091 | " 0 | \n",
1092 | " 0 | \n",
1093 | " 0 | \n",
1094 | " 0 | \n",
1095 | " 0 | \n",
1096 | " 0 | \n",
1097 | " 0 | \n",
1098 | " 0 | \n",
1099 | " 0 | \n",
1100 | "
\n",
1101 | " \n",
1102 | " 18748 | \n",
1103 | " 0 | \n",
1104 | " 0 | \n",
1105 | " 0 | \n",
1106 | " 0 | \n",
1107 | " 0 | \n",
1108 | " 0 | \n",
1109 | " 0 | \n",
1110 | " 0 | \n",
1111 | " 0 | \n",
1112 | " 0 | \n",
1113 | " ... | \n",
1114 | " 0 | \n",
1115 | " 0 | \n",
1116 | " 0 | \n",
1117 | " 0 | \n",
1118 | " 1 | \n",
1119 | " 0 | \n",
1120 | " 0 | \n",
1121 | " 0 | \n",
1122 | " 0 | \n",
1123 | " 0 | \n",
1124 | "
\n",
1125 | " \n",
1126 | " 18749 | \n",
1127 | " 0 | \n",
1128 | " 0 | \n",
1129 | " 0 | \n",
1130 | " 0 | \n",
1131 | " 0 | \n",
1132 | " 0 | \n",
1133 | " 0 | \n",
1134 | " 0 | \n",
1135 | " 0 | \n",
1136 | " 0 | \n",
1137 | " ... | \n",
1138 | " 0 | \n",
1139 | " 0 | \n",
1140 | " 0 | \n",
1141 | " 0 | \n",
1142 | " 0 | \n",
1143 | " 0 | \n",
1144 | " 0 | \n",
1145 | " 0 | \n",
1146 | " 0 | \n",
1147 | " 0 | \n",
1148 | "
\n",
1149 | " \n",
1150 | " 18750 | \n",
1151 | " 0 | \n",
1152 | " 0 | \n",
1153 | " 0 | \n",
1154 | " 0 | \n",
1155 | " 0 | \n",
1156 | " 0 | \n",
1157 | " 0 | \n",
1158 | " 0 | \n",
1159 | " 0 | \n",
1160 | " 0 | \n",
1161 | " ... | \n",
1162 | " 0 | \n",
1163 | " 0 | \n",
1164 | " 0 | \n",
1165 | " 0 | \n",
1166 | " 0 | \n",
1167 | " 0 | \n",
1168 | " 0 | \n",
1169 | " 0 | \n",
1170 | " 0 | \n",
1171 | " 0 | \n",
1172 | "
\n",
1173 | " \n",
1174 | "
\n",
1175 | "
18751 rows × 1252 columns
\n",
1176 | "
"
1177 | ],
1178 | "text/plain": [
1179 | " Unnamed: 0 vf 0 1 2 3 4 5 6 7 ... 1240 1241 1242 1243 \\\n",
1180 | "0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1181 | "1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1182 | "2 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1183 | "3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1184 | "4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1185 | "... ... .. .. .. .. .. .. .. .. .. ... ... ... ... ... \n",
1186 | "18746 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1187 | "18747 0 0 0 1 0 0 0 0 0 0 ... 0 0 0 0 \n",
1188 | "18748 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1189 | "18749 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1190 | "18750 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n",
1191 | "\n",
1192 | " 1244 1245 1246 1247 1248 1249 \n",
1193 | "0 0 0 0 0 0 0 \n",
1194 | "1 0 0 0 0 0 0 \n",
1195 | "2 0 0 0 0 0 0 \n",
1196 | "3 0 0 0 0 0 0 \n",
1197 | "4 0 0 0 0 0 0 \n",
1198 | "... ... ... ... ... ... ... \n",
1199 | "18746 0 0 0 0 0 0 \n",
1200 | "18747 0 0 0 0 0 0 \n",
1201 | "18748 1 0 0 0 0 0 \n",
1202 | "18749 0 0 0 0 0 0 \n",
1203 | "18750 0 0 0 0 0 0 \n",
1204 | "\n",
1205 | "[18751 rows x 1252 columns]"
1206 | ]
1207 | },
1208 | "execution_count": 96,
1209 | "metadata": {},
1210 | "output_type": "execute_result"
1211 | }
1212 | ],
1213 | "source": [
1214 | "vf = pd.read_csv('/hdd/physio/vf2/finaldfs/final5sec.csv')\n",
1215 | "vf"
1216 | ]
1217 | },
1218 | {
1219 | "cell_type": "markdown",
1220 | "metadata": {},
1221 | "source": [
1222 | "Pi deploy baseline:
"
1223 | ]
1224 | },
1225 | {
1226 | "cell_type": "code",
1227 | "execution_count": 97,
1228 | "metadata": {},
1229 | "outputs": [],
1230 | "source": [
1231 | "import numpy as np\n",
1232 | "import pandas as pd\n",
1233 | "from sklearn.model_selection import train_test_split\n",
1234 | "import os\n",
1235 | "from ecgdetectors import Detectors\n",
1236 | "import matplotlib.pyplot as plt\n",
1237 | "import seaborn as sns\n",
1238 | "from time import time as time"
1239 | ]
1240 | },
1241 | {
1242 | "cell_type": "code",
1243 | "execution_count": 98,
1244 | "metadata": {},
1245 | "outputs": [
1246 | {
1247 | "name": "stdout",
1248 | "output_type": "stream",
1249 | "text": [
1250 | "CPU times: user 2min 26s, sys: 2min 2s, total: 4min 29s\n",
1251 | "Wall time: 6min 27s\n"
1252 | ]
1253 | }
1254 | ],
1255 | "source": [
1256 | "%%time\n",
1257 | "\n",
1258 | "afd = pd.read_csv('/hdd/physio/af/hugedfrpeak5sec.csv', index_col = 0)\n",
1259 | "chd = pd.read_csv('/hdd/physio/af2/finaldfs/2017final_normal_ecg_L5_S1.csv')"
1260 | ]
1261 | },
1262 | {
1263 | "cell_type": "code",
1264 | "execution_count": null,
1265 | "metadata": {},
1266 | "outputs": [],
1267 | "source": [
1268 | "def unpackJoblib(model):\n",
1269 | " times = []\n",
1270 | " for i in range(int(sys.argv[1])):\n",
1271 | " start = time.time()\n",
1272 | " x = pipelinedRpeakExtraction(window, fs)\n",
1273 | " features = optimizedFeatureExtraction(x, fs)\n",
1274 | " pred = model.predict(features)\n",
1275 | " end = time.time()\n",
1276 | " times.append(end - start)\n",
1277 | "\n",
1278 | " print(np.mean(times))\n",
1279 | " \n",
1280 | "def helperMethod():\n",
1281 | " for file in os.listdir('AFDB/'):\n",
1282 | " model = joblib.load(file)\n",
1283 | " print('\\n\\n' + str(file)[5:7] + \":\\n\")\n",
1284 | " unpackJoblib(model)\n",
1285 | " print('---------------')\n",
1286 | " \n",
1287 | "# for file in os.listdir('CHDB/'):\n",
1288 | "# model = joblib.load(file)\n",
1289 | "# print('\\n\\n' + str(file)[5:7] + \":\\n\")\n",
1290 | "# unpackJoblib(model)\n",
1291 | "# print('---------------')\n",
1292 | " "
1293 | ]
1294 | },
1295 | {
1296 | "cell_type": "code",
1297 | "execution_count": 100,
1298 | "metadata": {},
1299 | "outputs": [],
1300 | "source": [
1301 | "a = np.load('/hdd/physio/edgeml/examples/pytorch/Bonsai/traditional models/1window.npy')"
1302 | ]
1303 | },
1322 | {
1323 | "cell_type": "markdown",
1324 | "metadata": {},
1325 | "source": [
1326 | "# Top features to be deployed on Pi"
1327 | ]
1328 | },
1329 | {
1330 | "cell_type": "code",
1331 | "execution_count": 2,
1332 | "metadata": {},
1333 | "outputs": [],
1334 | "source": [
1335 | "import numpy as np\n",
1336 | "import pandas as pd\n",
1337 | "from sklearn.model_selection import train_test_split\n",
1338 | "import os\n",
1339 | "from ecgdetectors import Detectors\n",
1340 | "import matplotlib.pyplot as plt\n",
1341 | "import seaborn as sns\n",
1342 | "from time import time as time"
1343 | ]
1344 | },
1345 | {
1346 | "cell_type": "code",
1347 | "execution_count": null,
1348 | "metadata": {},
1349 | "outputs": [],
1350 | "source": [
1351 | "# Individual Features\n",
1352 | "\n",
1353 | "def get_nni_20(nn_intervals, fs):\n",
1354 | " \n",
1355 | "\n",
1356 | "def get_mean_nni(nn_intervals, fs):\n",
1357 | " diff_nni = np.diff(nn_intervals)\n",
1358 | " length_int = len(nn_intervals)\n",
1359 | " return np.mean(nn_intervals)"
1360 | ]
1361 | },
1362 | {
1363 | "cell_type": "raw",
1364 | "metadata": {},
1365 | "source": [
1366 | "AFDB:\n",
1367 | "4 - nni_20, nni_50, pnni_20, pnni_50\n",
1368 | "6 - cvnni, cvsd, nni_20, nni_50, pnni_20, pnni_50\n",
1369 | "8 - cvnni, cvsd, nni_20, nni_50, pnni_20, pnni_50, std_nn, std_hr\n",
1370 | "10 - max_hr, cvnni, cvsd, nni_20, nni_50, pnni_20, pnni_50, std_nn, std_hr, sdsd\n",
1371 | "12 - max_hr, cvnni, cvsd, nni_20, nni_50, pnni_20, pnni_50, std_nn, std_hr, sdsd, mean_hr, rmssd\n",
1372 | "14 - max_hr, cvnni, cvsd, nni_20, nni_50, pnni_20, pnni_50, std_nn, std_hr, sdsd, mean_hr, rmssd, min_hr, mean_nni\n",
1373 | "\n",
1374 | "2017:\n",
1375 | "4 - cvsd, nni_20, nni_50, pnni_50\n",
1376 | "6 - cvsd, nni_20, nni_50, pnni_50, cvnni, max_hr\n",
1377 | "8 - cvsd, nni_20, nni_50, pnni_50, cvnni, max_hr, mean_hr, sdsd\n",
1378 | "10 - cvsd, nni_20, nni_50, pnni_50, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20\n",
1379 | "12 - cvsd, nni_20, nni_50, pnni_50, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn\n",
1380 | "14 - cvsd, nni_20, nni_50, pnni_50, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn, min_hr, mean_nni"
1381 | ]
1382 | },
1383 | {
1384 | "cell_type": "code",
1385 | "execution_count": 6,
1386 | "metadata": {},
1387 | "outputs": [],
1388 | "source": [
1389 | "def _2017_top_4_features(nn_intervals, fs):\n",
1390 | " diff_nni = np.diff(nn_intervals)\n",
1391 | " length_int = len(nn_intervals)\n",
1392 | " \n",
1393 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1394 | " pnni_50 = 100 * nni_50 / length_int\n",
1395 | " \n",
1396 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1397 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1398 | " \n",
1399 | " return nni_50, pnni_50, nni_20, cvsd\n",
1400 | "\n",
1401 | "def _2017_top_6_features(nn_intervals, fs):\n",
1402 | " \n",
1403 | " diff_nni = np.diff(nn_intervals)\n",
1404 | " length_int = len(nn_intervals)\n",
1405 | " \n",
1406 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1407 | " pnni_50 = 100 * nni_50 / length_int\n",
1408 | " \n",
1409 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1410 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1411 | " \n",
1412 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1413 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1414 | " \n",
1415 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1416 | " max_hr = max(heart_rate_list)\n",
1417 | " \n",
1418 | " return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr\n",
1419 | "\n",
1420 | "def _2017_top_8_features(nn_intervals, fs):\n",
1421 | " \n",
1422 | " diff_nni = np.diff(nn_intervals)\n",
1423 | " length_int = len(nn_intervals)\n",
1424 | " \n",
1425 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1426 | " pnni_50 = 100 * nni_50 / length_int\n",
1427 | " \n",
1428 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1429 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1430 | " \n",
1431 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1432 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1433 | " \n",
1434 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1435 | " max_hr = max(heart_rate_list)\n",
1436 | " mean_hr = np.mean(heart_rate_list)\n",
1437 | " sdsd = np.std(diff_nni)\n",
1438 | " \n",
1439 | " return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd\n",
1440 | "\n",
1441 | "def _2017_top_10_features(nn_intervals, fs):\n",
1442 | " \n",
1443 | " diff_nni = np.diff(nn_intervals)\n",
1444 | " length_int = len(nn_intervals)\n",
1445 | " \n",
1446 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1447 | " pnni_50 = 100 * nni_50 / length_int\n",
1448 | " \n",
1449 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1450 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1451 | " \n",
1452 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1453 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1454 | " \n",
1455 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1456 | " max_hr = max(heart_rate_list)\n",
1457 | " mean_hr = np.mean(heart_rate_list)\n",
1458 | " sdsd = np.std(diff_nni)\n",
1459 | " \n",
1460 | " std_hr = np.std(heart_rate_list)\n",
1461 | " pnni_20 = 100 * nni_20 / length_int\n",
1462 | " \n",
1463 | " \n",
1464 | " return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20\n",
1465 | "\n",
1466 | "def _2017_top_12_features(nn_intervals, fs):\n",
1467 | " \n",
1468 | " diff_nni = np.diff(nn_intervals)\n",
1469 | " length_int = len(nn_intervals)\n",
1470 | " \n",
1471 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1472 | " pnni_50 = 100 * nni_50 / length_int\n",
1473 | " \n",
1474 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1475 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1476 | " \n",
1477 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1478 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1479 | " \n",
1480 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1481 | " max_hr = max(heart_rate_list)\n",
1482 | " mean_hr = np.mean(heart_rate_list)\n",
1483 | " sdsd = np.std(diff_nni)\n",
1484 | " \n",
1485 | " std_hr = np.std(heart_rate_list)\n",
1486 | " pnni_20 = 100 * nni_20 / length_int\n",
1487 | " \n",
1488 | " rmssd = np.sqrt(np.mean(diff_nni ** 2))\n",
1489 | " \n",
1490 | " return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn\n",
1491 | "\n",
1492 | "def _2017_top_14_features(nn_intervals, fs):\n",
1493 | " \n",
1494 | " diff_nni = np.diff(nn_intervals)\n",
1495 | " length_int = len(nn_intervals)\n",
1496 | " \n",
1497 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1498 | " pnni_50 = 100 * nni_50 / length_int\n",
1499 | " \n",
1500 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1501 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1502 | " \n",
1503 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1504 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1505 | " \n",
1506 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1507 | " max_hr = max(heart_rate_list)\n",
1508 | " mean_hr = np.mean(heart_rate_list)\n",
1509 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1510 | " \n",
1511 | " std_hr = np.std(heart_rate_list)\n",
1512 | " pnni_20 = 100 * nni_20 / length_int\n",
1513 | " \n",
1514 | " rmssd = np.sqrt(np.mean(diff_nni ** 2))\n",
1515 | " sdsd = np.std(diff_nni) \n",
1516 | " min_hr = min(heart_rate_list)\n",
1517 | " mean_nni = np.mean(nn_intervals)\n",
1518 | " \n",
1519 | " return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn, min_hr, mean_nni\n",
1520 | " \n",
1521 | "\n",
1522 | "def afdb_top_4_features(nn_intervals, fs): \n",
1523 | " diff_nni = np.diff(nn_intervals)\n",
1524 | " length_int = len(nn_intervals)\n",
1525 | " \n",
1526 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1527 | " pnni_50 = 100 * nni_50 / length_int\n",
1528 | " \n",
1529 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1530 | " pnni_20 = 100 * nni_20 / length_int\n",
1531 | " \n",
1532 | " return nni_20, nni_50, pnni_20, pnni_50\n",
1533 | " \n",
1534 | " \n",
1535 | "def afdb_top_6_features(nn_intervals, fs):\n",
1536 | " \n",
1537 | " diff_nni = np.diff(nn_intervals)\n",
1538 | " length_int = len(nn_intervals)\n",
1539 | " \n",
1540 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1541 | " pnni_50 = 100 * nni_50 / length_int\n",
1542 | " \n",
1543 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1544 | " pnni_20 = 100 * nni_20 / length_int\n",
1545 | " \n",
1546 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1547 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1548 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1549 | " \n",
1550 | " return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd\n",
1551 | "\n",
1552 | "def afdb_top_8_features(nn_intervals, fs):\n",
1553 | " \n",
1554 | " diff_nni = np.diff(nn_intervals)\n",
1555 | " length_int = len(nn_intervals)\n",
1556 | " \n",
1557 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1558 | " pnni_50 = 100 * nni_50 / length_int\n",
1559 | " \n",
1560 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1561 | " pnni_20 = 100 * nni_20 / length_int\n",
1562 | " \n",
1563 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1564 | " \n",
1565 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1566 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1567 | " \n",
1568 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1569 | " std_hr = np.std(heart_rate_list)\n",
1570 | " return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr \n",
1571 | "\n",
1572 | "def afdb_top_10_features(nn_intervals, fs):\n",
1573 | " \n",
1574 | " diff_nni = np.diff(nn_intervals)\n",
1575 | " length_int = len(nn_intervals)\n",
1576 | " \n",
1577 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1578 | " pnni_50 = 100 * nni_50 / length_int\n",
1579 | " \n",
1580 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1581 | " pnni_20 = 100 * nni_20 / length_int\n",
1582 | " \n",
1583 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1584 | "\n",
1585 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1586 | " cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)\n",
1587 | " \n",
1588 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1589 | " std_hr = np.std(heart_rate_list)\n",
1590 | " \n",
1591 | " max_hr = max(heart_rate_list)\n",
1592 | " sdsd = np.std(diff_nni)\n",
1593 | " \n",
1594 | " return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd\n",
1595 | "\n",
1596 | "def afdb_top_12_features(nn_intervals, fs):\n",
1597 | " \n",
1598 | " diff_nni = np.diff(nn_intervals)\n",
1599 | " length_int = len(nn_intervals)\n",
1600 | " \n",
1601 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1602 | " pnni_50 = 100 * nni_50 / length_int\n",
1603 | " \n",
1604 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1605 | " pnni_20 = 100 * nni_20 / length_int\n",
1606 | " \n",
1607 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1608 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1609 | " rmssd = np.sqrt(np.mean(diff_nni ** 2))\n",
1610 | " cvsd = rmssd / np.mean(nn_intervals)\n",
1611 | " \n",
1612 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1613 | " std_hr = np.std(heart_rate_list)\n",
1614 | " \n",
1615 | " max_hr = max(heart_rate_list)\n",
1616 | " sdsd = np.std(diff_nni)\n",
1617 | " \n",
1618 | " mean_hr = np.mean(heart_rate_list)\n",
1619 | " \n",
1620 | " \n",
1621 | " return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd, mean_hr, rmssd\n",
1622 | "\n",
1623 | "def afdb_top_14_features(nn_intervals, fs):\n",
1624 | " \n",
1625 | " diff_nni = np.diff(nn_intervals)\n",
1626 | " length_int = len(nn_intervals)\n",
1627 | " \n",
1628 | " nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))\n",
1629 | " pnni_50 = 100 * nni_50 / length_int\n",
1630 | " \n",
1631 | " nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))\n",
1632 | " pnni_20 = 100 * nni_20 / length_int\n",
1633 | " \n",
1634 | " sdnn = np.std(nn_intervals, ddof = 1)\n",
1635 | " cvnni = sdnn / np.mean(nn_intervals)\n",
1636 | " rmssd = np.sqrt(np.mean(diff_nni ** 2))\n",
1637 | " cvsd = rmssd / np.mean(nn_intervals)\n",
1638 | " \n",
1639 | " heart_rate_list = np.divide(60000, nn_intervals)\n",
1640 | " std_hr = np.std(heart_rate_list)\n",
1641 | " \n",
1642 | " max_hr = max(heart_rate_list)\n",
1643 | " sdsd = np.std(diff_nni)\n",
1644 | " \n",
1645 | " mean_hr = np.mean(heart_rate_list)\n",
1646 | " \n",
1647 | " min_hr = min(heart_rate_list)\n",
1648 | " mean_nni = np.mean(nn_intervals)\n",
1649 | " \n",
1650 | " return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd, mean_hr, rmssd, min_hr, mean_nni"
1651 | ]
1652 | },
1653 | {
1654 | "cell_type": "code",
1655 | "execution_count": 18,
1656 | "metadata": {},
1657 | "outputs": [
1658 | {
1659 | "data": {
1660 | "text/plain": [
1661 | "array([0., 1.])"
1662 | ]
1663 | },
1664 | "execution_count": 18,
1665 | "metadata": {},
1666 | "output_type": "execute_result"
1667 | }
1668 | ],
1669 | "source": [
1670 | "# sathyama therla\n",
1671 | "\n",
1672 | "def normalize(data):\n",
1673 | " return (data - np.min(data)) / (np.max(data) - np.min(data))\n",
1674 | "\n",
1675 | "normalize([1, 2])"
1676 | ]
1677 | },
1678 | {
1679 | "cell_type": "code",
1680 | "execution_count": 12,
1681 | "metadata": {},
1682 | "outputs": [],
1683 | "source": [
1684 | "from scipy.signal import butter, lfilter,filtfilt, iirnotch\n",
1685 | "from scipy.signal import freqs\n",
1686 | "import matplotlib.pyplot as plt\n",
1687 | "from scipy.signal import medfilt"
1688 | ]
1689 | },
1690 | {
1691 | "cell_type": "code",
1692 | "execution_count": 28,
1693 | "metadata": {},
1694 | "outputs": [
1695 | {
1696 | "name": "stdout",
1697 | "output_type": "stream",
1698 | "text": [
1699 | "CPU times: user 3.04 ms, sys: 0 ns, total: 3.04 ms\n",
1700 | "Wall time: 2.19 ms\n"
1701 | ]
1702 | }
1703 | ],
1704 | "source": [
1705 | "%%time\n",
1706 | "fs = 250\n",
1707 | "n = 0.5*fs\n",
1708 | "f_high = 0.5\n",
1709 | "cut_off = f_high/n\n",
1710 | "\n",
1711 | "order = 4\n",
1712 | "\n",
1713 | "data = np.arange(250)\n",
1714 | "data = normalize(data)\n",
1715 | "\n",
1716 | "b,a = butter(order, cut_off,btype='high')\n",
1717 | "high_filtered_data = filtfilt(b, a, data)"
1718 | ]
1719 | },
1720 | {
1721 | "cell_type": "code",
1722 | "execution_count": null,
1723 | "metadata": {},
1724 | "outputs": [],
1725 | "source": []
1726 | }
1727 | ],
1728 | "metadata": {
1729 | "kernelspec": {
1730 | "display_name": "Python [conda env:physio]",
1731 | "language": "python",
1732 | "name": "conda-env-physio-py"
1733 | },
1734 | "language_info": {
1735 | "codemirror_mode": {
1736 | "name": "ipython",
1737 | "version": 3
1738 | },
1739 | "file_extension": ".py",
1740 | "mimetype": "text/x-python",
1741 | "name": "python",
1742 | "nbconvert_exporter": "python",
1743 | "pygments_lexer": "ipython3",
1744 | "version": "3.7.5"
1745 | }
1746 | },
1747 | "nbformat": 4,
1748 | "nbformat_minor": 4
1749 | }
1750 |
--------------------------------------------------------------------------------
/record_inference_time.py:
--------------------------------------------------------------------------------
1 | import helpermethods
2 | import numpy as np
3 | import sys
4 | import edgeml_pytorch.utils as utils
5 | from edgeml_pytorch.graph.bonsai import Bonsai
6 | import torch
7 | import time
8 | import pandas as pd
9 | from ecgdetectors import Detectors
10 | from scipy.signal import butter, lfilter,filtfilt, iirnotch
11 | from scipy.signal import freqs
12 | import matplotlib.pyplot as plt
13 | from scipy.signal import medfilt
14 |
15 | fs = 250
16 | n = 0.5*fs          # Nyquist frequency
17 | f_high = 0.5        # high-pass corner (Hz) for baseline-wander removal
18 | cut_off = f_high/n  # cutoff normalized to Nyquist, as butter() expects
19 |
20 | order = 4
21 |
22 | def loadModel(currDir):
23 | '''
24 | Load the saved Bonsai parameters (W, V, T, Z) and hyperparameters from disk.
25 | Returns two dicts: one for params and one for hyperparameters.
26 | '''
27 | paramDir = currDir + '/'
28 | paramDict = {}
29 | paramDict['W'] = np.load(paramDir + "W.npy")
30 | paramDict['V'] = np.load(paramDir + "V.npy")
31 | paramDict['T'] = np.load(paramDir + "T.npy")
32 | paramDict['Z'] = np.load(paramDir + "Z.npy")
33 | hyperParamDict = np.load(paramDir + "hyperParam.npy", allow_pickle=True).item()
34 | return paramDict, hyperParamDict
35 |
36 |
37 | def pipelinedRpeakExtraction(x, fs):
38 |
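# SWT-based detector from py-ecg-detectors; returns R-peak sample indices.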
39 | x = detectors.swt_detector(x)
40 | # x = detectors.hamilton_detector(x)
41 | # x = detectors.pan_tompkins_detector(x)
42 | return x
43 |
44 | # def get_mean_nni(nn_intervals, fs):
45 | # diff_nni = np.diff(nn_intervals)
46 | # length_int = len(nn_intervals)
47 | # return np.mean(nn_intervals)
48 |
49 | def _2017_top_4_features(nn_intervals, fs):
50 | diff_nni = np.diff(nn_intervals)
51 | length_int = len(nn_intervals)
52 |
53 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
54 | pnni_50 = 100 * nni_50 / length_int
55 |
56 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
57 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
58 |
59 | return nni_50, pnni_50, nni_20, cvsd
60 |
61 | def _2017_top_6_features(nn_intervals, fs):
62 |
63 | diff_nni = np.diff(nn_intervals)
64 | length_int = len(nn_intervals)
65 |
66 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
67 | pnni_50 = 100 * nni_50 / length_int
68 |
69 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
70 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
71 |
72 | sdnn = np.std(nn_intervals, ddof = 1)
73 | cvnni = sdnn / np.mean(nn_intervals)
74 |
75 | heart_rate_list = np.divide(60000, nn_intervals)
76 | max_hr = max(heart_rate_list)
77 |
78 | return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr
79 |
80 | def _2017_top_8_features(nn_intervals, fs):
81 |
82 | diff_nni = np.diff(nn_intervals)
83 | length_int = len(nn_intervals)
84 |
85 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
86 | pnni_50 = 100 * nni_50 / length_int
87 |
88 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
89 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
90 |
91 | sdnn = np.std(nn_intervals, ddof = 1)
92 | cvnni = sdnn / np.mean(nn_intervals)
93 |
94 | heart_rate_list = np.divide(60000, nn_intervals)
95 | max_hr = max(heart_rate_list)
96 | mean_hr = np.mean(heart_rate_list)
97 | sdsd = np.std(diff_nni)
98 |
99 | return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd
100 |
101 | def _2017_top_10_features(nn_intervals, fs):
102 |
103 | diff_nni = np.diff(nn_intervals)
104 | length_int = len(nn_intervals)
105 |
106 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
107 | pnni_50 = 100 * nni_50 / length_int
108 |
109 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
110 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
111 |
112 | sdnn = np.std(nn_intervals, ddof = 1)
113 | cvnni = sdnn / np.mean(nn_intervals)
114 |
115 | heart_rate_list = np.divide(60000, nn_intervals)
116 | max_hr = max(heart_rate_list)
117 | mean_hr = np.mean(heart_rate_list)
118 | sdsd = np.std(diff_nni)
119 |
120 | std_hr = np.std(heart_rate_list)
121 | pnni_20 = 100 * nni_20 / length_int
122 |
123 |
124 | return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20
125 |
126 | def _2017_top_12_features(nn_intervals, fs):
127 |
128 | diff_nni = np.diff(nn_intervals)
129 | length_int = len(nn_intervals)
130 |
131 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
132 | pnni_50 = 100 * nni_50 / length_int
133 |
134 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
135 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
136 |
137 | sdnn = np.std(nn_intervals, ddof = 1)
138 | cvnni = sdnn / np.mean(nn_intervals)
139 |
140 | heart_rate_list = np.divide(60000, nn_intervals)
141 | max_hr = max(heart_rate_list)
142 | mean_hr = np.mean(heart_rate_list)
143 | sdsd = np.std(diff_nni)
144 |
145 | std_hr = np.std(heart_rate_list)
146 | pnni_20 = 100 * nni_20 / length_int
147 |
148 | rmssd = np.sqrt(np.mean(diff_nni ** 2))
150 |
151 | return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn
152 |
153 |
154 |
155 |
156 | def afdb_top_4_features(nn_intervals, fs):
157 | diff_nni = np.diff(nn_intervals)
158 | length_int = len(nn_intervals)
159 |
160 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
161 | pnni_50 = 100 * nni_50 / length_int
162 |
163 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
164 | pnni_20 = 100 * nni_20 / length_int
165 |
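# Trailing 1 matches the constant (bias) feature column Bonsai expects.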
166 | return np.array([nni_20, nni_50, pnni_20, pnni_50, 1])
167 |
168 |
169 | def afdb_top_6_features(nn_intervals, fs):
170 |
171 | diff_nni = np.diff(nn_intervals)
172 | length_int = len(nn_intervals)
173 |
174 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
175 | pnni_50 = 100 * nni_50 / length_int
176 |
177 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
178 | pnni_20 = 100 * nni_20 / length_int
179 |
180 | sdnn = np.std(nn_intervals, ddof = 1)
181 | cvnni = sdnn / np.mean(nn_intervals)
182 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
183 |
184 | return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd
185 |
186 | def afdb_top_8_features(nn_intervals, fs):
187 |
188 | diff_nni = np.diff(nn_intervals)
189 | length_int = len(nn_intervals)
190 |
191 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
192 | pnni_50 = 100 * nni_50 / length_int
193 |
194 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
195 | pnni_20 = 100 * nni_20 / length_int
196 |
197 | sdnn = np.std(nn_intervals, ddof = 1)
198 | cvnni = sdnn / np.mean(nn_intervals)
199 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
200 |
201 | heart_rate_list = np.divide(60000, nn_intervals)
202 | std_hr = np.std(heart_rate_list)
203 | return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr
204 |
205 | def afdb_top_10_features(nn_intervals, fs):
206 |
207 | diff_nni = np.diff(nn_intervals)
208 | length_int = len(nn_intervals)
209 |
210 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
211 | pnni_50 = 100 * nni_50 / length_int
212 |
213 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
214 | pnni_20 = 100 * nni_20 / length_int
215 |
216 | sdnn = np.std(nn_intervals, ddof = 1)
217 | cvnni = sdnn / np.mean(nn_intervals)
218 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
219 |
220 | heart_rate_list = np.divide(60000, nn_intervals)
221 | std_hr = np.std(heart_rate_list)
222 |
223 | max_hr = max(heart_rate_list)
224 | sdsd = np.std(diff_nni)
225 |
226 | return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd
227 |
228 | def afdb_top_12_features(nn_intervals, fs):
229 |
230 | diff_nni = np.diff(nn_intervals)
231 | length_int = len(nn_intervals)
232 |
233 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
234 | pnni_50 = 100 * nni_50 / length_int
235 |
236 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
237 | pnni_20 = 100 * nni_20 / length_int
238 |
239 | sdnn = np.std(nn_intervals, ddof = 1)
240 | cvnni = sdnn / np.mean(nn_intervals)
241 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
242 |
243 | heart_rate_list = np.divide(60000, nn_intervals)
244 | std_hr = np.std(heart_rate_list)
245 |
246 | max_hr = max(heart_rate_list)
247 | sdsd = np.std(diff_nni)
248 |
249 | mean_hr = np.mean(heart_rate_list)
250 | rmssd = np.sqrt(np.mean(diff_nni ** 2))
251 |
252 | return nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd, mean_hr, rmssd
253 |
254 | def afdb_top_14_features(nn_intervals, fs):
255 |
256 | diff_nni = np.diff(nn_intervals)
257 | abs_diff = np.abs(diff_nni)
258 | length_int = len(nn_intervals)
259 | mean_nni = np.mean(nn_intervals)
260 | rmssd = np.sqrt(np.mean(diff_nni ** 2))
261 |
262 |
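# Thresholds pre-scaled for fs = 250 Hz: 12.5 samples = 50 ms, 5 samples = 20 ms.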
263 | nni_50 = sum(abs_diff > 12.5)
264 | pnni_50 = 100 * nni_50 / length_int
265 |
266 | nni_20 = sum(abs_diff > 5)
267 | pnni_20 = 100 * nni_20 / length_int
268 |
269 | sdnn = np.std(nn_intervals, ddof = 1)
270 | cvnni = sdnn / mean_nni
271 | cvsd = rmssd / mean_nni
272 |
273 | heart_rate_list = np.divide(60, nn_intervals)
274 | std_hr = np.std(heart_rate_list)
275 |
276 | max_hr = max(heart_rate_list)
277 | sdsd = np.std(diff_nni)
278 |
279 | mean_hr = np.mean(heart_rate_list)
280 |
281 | min_hr = min(heart_rate_list)
282 |
283 |
284 | return np.array([nni_20, nni_50, pnni_20, pnni_50, cvnni, cvsd, sdnn, std_hr, max_hr, sdsd, mean_hr, rmssd, min_hr, mean_nni, 1])
285 |
286 | def _2017_top_14_features(nn_intervals, fs):
287 |
288 | diff_nni = np.diff(nn_intervals)
289 | length_int = len(nn_intervals)
290 |
291 | nni_50 = sum(np.abs(diff_nni) > (50*fs/1000))
292 | pnni_50 = 100 * nni_50 / length_int
293 |
294 | nni_20 = sum(np.abs(diff_nni) > (20*fs/1000))
295 | cvsd = np.sqrt(np.mean(diff_nni ** 2)) / np.mean(nn_intervals)
296 |
297 | sdnn = np.std(nn_intervals, ddof = 1)
298 | cvnni = sdnn / np.mean(nn_intervals)
299 |
300 | heart_rate_list = np.divide(60000, nn_intervals)
301 | max_hr = max(heart_rate_list)
302 | mean_hr = np.mean(heart_rate_list)
303 | sdsd = np.std(diff_nni)
304 |
305 | std_hr = np.std(heart_rate_list)
306 | pnni_20 = 100 * nni_20 / length_int
307 |
308 | rmssd = np.sqrt(np.mean(diff_nni ** 2))
310 |
311 | min_hr = min(heart_rate_list)
312 | mean_nni = np.mean(nn_intervals)
313 |
314 | return nni_50, pnni_50, nni_20, cvsd, cvnni, max_hr, mean_hr, sdsd, std_hr, pnni_20, rmssd, sdnn, min_hr, mean_nni
315 |
316 |
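# Rebuild the trained Bonsai tree on CPU from the saved parameter matrices and hyperparameters.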
317 | device = torch.device("cpu")
318 |
319 | MODEL_DIR = "/hdd/physio/edgeml/examples/pytorch/Bonsai/AFDB_top14/PyTorchBonsaiResults/16_50_10_09_08_21"
320 |
321 | paramDict, hyperParamDict = loadModel(MODEL_DIR)
322 |
323 | bonsai = Bonsai(hyperParamDict['numClasses'], hyperParamDict['dataDim'], hyperParamDict['projDim'],
324 | hyperParamDict['depth'], hyperParamDict['sigma'], W=paramDict['W'], T=paramDict['T'], V=paramDict['V'],
325 | Z=paramDict['Z']).to(device)
326 |
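# A very large sigmaI makes Bonsai's soft path indicators effectively hard, as used at inference time.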
327 | sigmaI = 1e9
328 |
329 | def normalize(data):
330 | return (data - np.min(data)) / (np.max(data) - np.min(data))
331 |
332 |
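# Single pre-recorded ECG window (1250 samples = 5 s at fs = 250 Hz) used as the benchmark input.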
333 | window = np.load("1window.npy")
334 |
335 | fs = 250
336 | detectors = Detectors(fs)
337 | b, a = butter(order, cut_off,btype='high')
338 |
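# Benchmark 1: full pipeline - baseline-wander removal (high-pass filter), normalization,
# R-peak detection, feature extraction and Bonsai inference.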
339 | times = []
340 | for i in range(int(sys.argv[1])):
341 | start = time.time()
342 | window = filtfilt(b, a, window)
343 | window = normalize(window)
344 | x = pipelinedRpeakExtraction(window, fs)
345 | x = np.diff(x)
346 | features = afdb_top_14_features(x, fs)
347 | _, _ = bonsai(torch.from_numpy(features.astype(np.float32)), sigmaI)
348 | end = time.time()
349 | times.append(end - start)
350 |
351 |
352 | print("features + model + baseline wander removal: ", np.mean(times)*1000, "ms")
353 |
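# Benchmark 2: skip filtering/normalization - R-peak detection, feature extraction and inference only.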
354 | times = []
355 | for i in range(int(sys.argv[1])):
356 | start = time.time()
357 | x = pipelinedRpeakExtraction(window, fs)
358 | x = np.diff(x)
359 | features = afdb_top_14_features(x, fs)
360 | _, _ = bonsai(torch.from_numpy(features.astype(np.float32)), sigmaI)
361 | end = time.time()
362 | times.append(end - start)
363 |
364 |
365 | print("features + model : ", np.mean(times)*1000, "ms")
366 | print("features + model : max", np.max(times)*1000, "ms")
367 | print("features + model : min", np.min(times)*1000, "ms")
368 |
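# Benchmark 3: Bonsai forward pass alone, reusing the last computed feature vector.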
369 | times = []
370 | for i in range(int(sys.argv[1])):
371 | start = time.time()
372 | _, _ = bonsai(torch.from_numpy(features.astype(np.float32)), sigmaI)
373 | end = time.time()
374 | times.append(end - start)
375 |
376 |
377 | print("model : ", np.mean(times)*1000, "ms")
--------------------------------------------------------------------------------