├── data └── .gitignore ├── output └── .gitignore ├── annotations └── .gitignore ├── pyfar ├── __init__.py ├── vtach_beats.csv ├── parameters.py ├── utils.py ├── pipeline.py ├── sigtypes ├── ventricular_beat_bank.py ├── classifier.py ├── dtw.py ├── ventricular_beat_stdev.py └── baseline_algorithm.py ├── requirements.txt ├── README.rst ├── download ├── download_data.sh └── download_annotations.sh ├── matlab ├── README.md └── run_on_challenge_data.m ├── docs ├── index.rst ├── pyfar.rst ├── Makefile ├── contributing.rst ├── install.rst ├── quickstart.rst └── conf.py ├── LICENSE ├── .gitignore ├── paper.bib ├── codemeta.json ├── paper.md └── setup.py /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /annotations/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /pyfar/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["baseline_algorithm", "classifier", "dtw", "parameters", "pipeline", 2 | "utils","ventricular_beat_bank","ventricular_beat_stdev"] 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastdtw>=0.3.2 2 | matplotlib>=2.0.2 3 | numpy>=1.13.1 4 | pandas>=0.20.3 5 | PeakUtils>=1.1.0 6 | scikit-learn>=0.19.0 7 | scipy>=0.19.0 8 | sklearn>=0.0 9 | spectrum>=0.7.1 10 | virtualenv>=15.0.1 11 | wfdb>=1.2.2 12 | 13 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | false-alarm-reduction 2 | ===================== 3 | 4 | |DOI| 5 | 6 | Code for building a model to reduce false alarms in the intensive care 7 | unit. 8 | 9 | .. |DOI| image:: https://zenodo.org/badge/59120353.svg 10 | :target: https://zenodo.org/badge/latestdoi/59120353 11 | -------------------------------------------------------------------------------- /download/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # downloads the data to the data subfolder 3 | # run this script from the root folder of the cloned repository 4 | 5 | mkdir -p data 6 | 7 | cd data 8 | wget -O training.zip https://physionet.org/challenge/2015/training.zip 9 | unzip training.zip 10 | cd .. 11 | -------------------------------------------------------------------------------- /matlab/README.md: -------------------------------------------------------------------------------- 1 | This folder contains code which was used to generate the jqrs annotations. 
In order to run the code, you need: 2 | 3 | * the peak-detector repository (www.github.com/alistairewj/peak-detector) 4 | * the WFDB Matlab toolbox (https://github.com/ikarosilva/wfdb-app-toolbox) 5 | 6 | You may also need to modify the paths so that the code locates the data properly. 7 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. False Alarm Reduction documentation master file, created by 2 | sphinx-quickstart on Tue Feb 13 09:57:24 2018. 3 | 4 | False Alarm Reduction 5 | ===================== 6 | 7 | False Alarm Reduction is a library for reducing the number of false alarms 8 | when detecting events in physiologic waveforms. 9 | 10 | See :doc:`the quickstart <quickstart>` to get started. 11 | 12 | Contents 13 | -------- 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Getting Started 18 | 19 | quickstart 20 | install 21 | contributing 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Package Reference 26 | 27 | pyfar 28 | -------------------------------------------------------------------------------- /docs/pyfar.rst: -------------------------------------------------------------------------------- 1 | pyfar 2 | =================== 3 | 4 | Classifier 5 | ------------------- 6 | .. automodule:: baseline_algorithm 7 | :members: 8 | 9 | .. automodule:: classifier 10 | :members: 11 | 12 | Utilities 13 | ------------------- 14 | .. automodule:: utils 15 | :members: 16 | 17 | DTW 18 | ------------------- 19 | .. automodule:: dtw 20 | :members: 21 | 22 | Pipeline 23 | ------------------- 24 | .. automodule:: pipeline 25 | :members: 26 | 27 | Ventricular beat modelling 28 | -------------------------- 29 | 30 | .. automodule:: ventricular_beat_stdev 31 | :members: 32 | 33 | .. automodule:: ventricular_beat_bank 34 | :members: 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = FalseAlarmReduction 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /download/download_annotations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # downloads the annotations (R peak detections) to the annotations subfolder 3 | # run this script from the root folder of the cloned repository 4 | 5 | mkdir -p annotations 6 | 7 | wget -O annotations/ann_gqrs0.zip https://www.dropbox.com/sh/hv4uat0ihwlygq8/AABvdXbSGZi3COPG-O_-nBGxa/ann_gqrs0.zip?dl=1 8 | wget -O annotations/ann_gqrs1.zip https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAAb14a_NN8iKojXEoInXCGa/ann_gqrs1.zip?dl=1 9 | wget -O annotations/ann_wabp.zip https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAALSmteHaL0gQovwXj8CXV4a/ann_wabp.zip?dl=1 10 | wget -O annotations/ann_wpleth.zip https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAko1RNvgmdhWF7lNux-Ob3a/ann_wpleth.zip?dl=1 11 | 12 | cd annotations 13 | unzip ann_gqrs0.zip 14 | unzip ann_gqrs1.zip 15 | unzip ann_wabp.zip 16 | unzip ann_wpleth.zip 17 | cd .. 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Alistair Johnson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore virtual env files 2 | include/ 3 | bin/ 4 | local/ 5 | 6 | 7 | # Mac OS-X 8 | .DS_Store 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sampledata/ 31 | sample_data/challenge_training_ann/ 32 | sample_data/challenge_training_data/ 33 | sample_data/challenge_training_multiann/ 34 | sample_data/fplesinger_data/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *,cover 60 | .hypothesis/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | #Ipython Notebook 76 | .ipynb_checkpoints 77 | -------------------------------------------------------------------------------- /pyfar/vtach_beats.csv: -------------------------------------------------------------------------------- 1 | sample_name,lead,is_vtach,peak_time,start_time,end_time,gqrs,jqrs,fp 2 | v131l,II,1,297,296.8,297.2,-2,0,1 3 | v132s,II,1,297.1,296.8,297.3,-2,0,1 4 | v133l,II,1,297.4,297,297.8,-2,0,0 5 | v158s,II,1,296.5,296.3,296.9,0,, 6 | v159l,II,1,294,293.6,294.2,0,, 7 | v194s,II,1,297.1,296.7,297.1,0,, 8 | v197l,II,1,296.5,296.4,297,0,, 9 | v199l,II,1,296.3,296,296.7,0,, 10 | v206s,II,1,287,286.8,287.3,-1,, 11 | v221l,II,1,293.5,293,293.65,-1,, 12 | v253l,II,1,291.5,291.2,291.8,-1,, 13 | v254s,II,1,291.3,291,291.5,0,, 14 | v255l,II,1,291.8,291.5,292,-1,, 15 | v275l,II,1,294.3,294.1,294.5,1,, 16 | v309l,II,1,295.3,295.2,295.7,-2,, 17 | v318s,II,1,292,291.7,292.4,-2,, 18 | v329l,II,1,296.7,296.6,296.95,-2,, 19 | v369l,II,1,298.1,298,298.4,0,, 20 | v404s,II,1,299.1,299,299.7,0,, 21 | v100s,II,0,294.2,294,294.6,-2,, 22 | v101l,II,0,292.2,292,292.6,-2,, 23 | v102s,II,0,292,291.7,292.2,-2,, 24 | v111l,II,0,294.2,294,294.9,-1,, 25 | v113l,II,0,294.2,294,294.5,-2,, 26 | v119l,II,0,292.2,292,292.6,-2,, 27 | v122s,II,0,293.4,293,293.8,-2,, 28 | v127l,II,0,293.6,293.2,294.2,-2,, 29 | v128s,II,0,293.1,292.5,293.5,-2,, 30 | v135l,II,0,294,293.7,294.4,-2,, 31 | v136s,II,0,293.6,293.3,294,-2,, 32 | v141l,II,0,294.3,294.15,294.7,-2,, 33 | v148s,II,0,292.5,292.1,293.2,-2,, 34 | v153l,II,0,294,293.8,294.4,-2,, 35 | v154s,II,0,292,291.9,292.4,-2,, 36 | v155l,II,0,296,295.8,296.3,-2,, 37 | v160s,II,0,296.7,296.6,296.9,0,, 38 | v162s,II,0,294.9,294.7,295,-2,0,0 39 | v164s,II,0,292.2,292,292.7,-2,, 40 | -------------------------------------------------------------------------------- /paper.bib: -------------------------------------------------------------------------------- 1 | 2 | @article{siebig, 3 | author = {Siebig, S and Kuhls, S and Imhoff, M and Langgartner, J and Reng, M and Sch\"olmerich, J and et al.}, 4 | title = {Collection of annotated data in a clinical validation study for alarm algorithms in intensive care -- A methodologic framework}, 5 | journal = {J. Crit. Care.}, 6 | volume = {25}, 7 | pages = {128-135}, 8 | year = {2010}, 9 | } 10 | 11 | @inproceedings{challenge, 12 | title={The PhysioNet/Computing in Cardiology Challenge 2015: reducing false arrhythmia alarms in the ICU}, 13 | author={Clifford, Gari D and Silva, Ikaro and Moody, Benjamin and Li, Qiao and Kella, Danesh and Shahin, Abdullah and Kooistra, Tristan and Perry, Diane and Mark, Roger G}, 14 | booktitle={Computing in Cardiology Conference (CinC), 2015}, 15 | pages={273--276}, 16 | year={2015}, 17 | organization={IEEE} 18 | } 19 | 20 | @article{plesinger, 21 | author = {Plesinger, F and Klimes, P and Halamet, J and Jurak, P.}, 22 | title = {Taming of the monitors: Reducing false alarms in intensive care units}, 23 | journal = {Physiol. 
Meas.}, 24 | volume = {37}, 25 | pages = {1313-1325}, 26 | year = {2016}, 27 | } 28 | 29 | @article{fareduction, 30 | title={MIT-LCP/false-alarm-reduction: False Alarm Reduction v1.0.0}, 31 | DOI={10.5281/zenodo.889036}, 32 | abstractNote={This is the initial release of the False Alarm Reduction code used to detect false arrhythmia alarms in the intensive care unit. The algorithm was based upon that of Plesinger et al. 2015 with various modifications.
}, 33 | publisher={Zenodo}, 34 | author={Alistair Johnson and Andrea Li}, 35 | year={2017}, 36 | month={Sep} 37 | } 38 | -------------------------------------------------------------------------------- /codemeta.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": "https://raw.githubusercontent.com/codemeta/codemeta/master/codemeta.jsonld", 3 | "@type": "Code", 4 | "author": [ 5 | { 6 | "@id": "http://orcid.org/0000-0001-8419-5527", 7 | "@type": "Person", 8 | "email": "liandrea@mit.edu", 9 | "name": "Andrea S. Li", 10 | "affiliation": "Massachusetts Institute of Technology" 11 | }, 12 | { 13 | "@id": "http://orcid.org/0000-0002-8735-3014", 14 | "@type": "Person", 15 | "email": "aewj@mit.edu", 16 | "name": "Alistair E. W. Johnson", 17 | "affiliation": "Massachusetts Institute of Technology" 18 | }, 19 | { 20 | "@id": "http://orcid.org/0000-0002-6318-2978", 21 | "@type": "Person", 22 | "email": "rgmark@mit.edu", 23 | "name": "Roger G. Mark", 24 | "affiliation": "Massachusetts Institute of Technology" 25 | } 26 | ], 27 | "identifier": "http://dx.doi.org/10.5281/zenodo.889036", 28 | "codeRepository": "https://github.com/MIT-LCP/false-alarm-reduction", 29 | "datePublished": "2017-09-11", 30 | "dateModified": "2017-09-11", 31 | "dateCreated": "2017-09-11", 32 | "description": "This code provides tools to decrease the false alarm rate for cardiac arrhythmias in the intensive care unit. It explores a baseline algorithm based on an algorithm published by Plesinger, et al. (2016), as well as dynamic time warping as a technique to identify false alarms.", 33 | "keywords": "false alarm reduction, signal processing, dynamic time warping, intensive care unit, arrhythmia", 34 | "license": "MIT", 35 | "title": "False alarm reduction in the intensive care unit", 36 | "version": "v1.0.1" 37 | } 38 | -------------------------------------------------------------------------------- /paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'False alarm reduction in the intensive care unit' 3 | tags: 4 | - false alarm reduction 5 | - signal processing 6 | - intensive care unit 7 | - dynamic time warping 8 | - arrhythmia 9 | authors: 10 | - name: Andrea S. Li 11 | orcid: 0000-0001-8419-5527 12 | affiliation: 1 13 | - name: Alistair E. W. Johnson 14 | orcid: 0000-0002-8735-3014 15 | affiliation: 1 16 | - name: Roger G. Mark 17 | orcid: 0000-0002-6318-2978 18 | affiliation: 1 19 | affiliations: 20 | - name: Massachusetts Institute of Technology 21 | index: 1 22 | date: 11 September 2017 23 | bibliography: paper.bib 24 | --- 25 | 26 | # Summary 27 | 28 | This is an algorithm for reducing the number of false arrhythmia alarms reported by intensive care unit monitors. 29 | Research has shown that only 17\% of alarms in the intensive care unit (ICU) are clinically relevant [@siebig]. 30 | The high false arrhythmia alarm rate has severe implications such as disruption of patient care, caregiver alarm fatigue, and desensitization from clinical staff to real life-threatening alarms [@imhoff]. 31 | A method to reduce the false alarm rate would therefore greatly benefit patients as well as nurses in their ability to provide care. We here develop and describe a robust false arrhythmia alarm reduction system for use in the ICU. 32 | We utilize the PhysioNet/Computing in Cardiology (CinC) Challenge 2015 dataset for development and validation of our approach [@challenge]. 
33 | Building off of work previously described in the literature [@plesinger], we make use of signal processing and machine learning techniques to identify true and false alarms for five arrhythmia types. 34 | This baseline algorithm alone is able to perform remarkably well, with a sensitivity of 0.908, a specificity of 0.838, and a PhysioNet/CinC challenge score of 0.756, where the score is (TP + TN) / (TP + TN + FP + 5 × FN) and thus penalizes missed true alarms five-fold [@challenge]. 35 | We additionally explore dynamic time warping techniques on both the entire alarm signal as well as on a beat-by-beat basis in an effort to improve performance on ventricular tachycardia alarms, which have in the literature been among the hardest to classify. Such an algorithm with strong performance and efficiency could potentially be translated for use in the ICU to promote overall patient care and recovery. 36 | The software is published to Zenodo with DOI 'http://dx.doi.org/10.5281/zenodo.889036' [@fareduction]. 37 | 38 | # References 39 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | ********************* 2 | Contributing to pyfar 3 | ********************* 4 | 5 | We welcome all contributions to the package! 6 | 7 | .. contents:: Table of contents: 8 | :local: 9 | 10 | 11 | Where to start? 12 | =============== 13 | 14 | Bug reports, bug fixes, documentation improvements, and other contributions 15 | are welcome. 16 | 17 | For reporting bugs or suggesting improvements, please use the `GitHub issues 18 | tab `_. 19 | 20 | Bug reports 21 | =========== 22 | 23 | Bug reports are core to ensuring the package remains useful for all users. 24 | A complete bug report greatly improves the ability of others to understand and 25 | fix it. For information on how to make a complete bug report, we recommend 26 | you review `this helpful StackOverflow article `_. 27 | 28 | Contributing improvements 29 | ========================= 30 | 31 | Bug fixes or other enhancements are welcome via pull requests. You can `read more 32 | about pull requests on GitHub's website `_. 33 | 34 | Contributing to the documentation 35 | ================================= 36 | 37 | Rewriting small pieces of the documentation as you read through it is a 38 | surefire way of improving it for the next user. 39 | 40 | About the documentation 41 | ----------------------- 42 | 43 | The documentation is written in *reStructuredText*, and subsequently built 44 | using the Python package `Sphinx `__. The Sphinx 45 | documentation provides `a gentle introduction to 46 | reStructuredText `__. 47 | 48 | The documentation follows the 49 | `NumPy Docstring Standard `__, 50 | which is parsed using the 51 | `napoleon extension for sphinx `_. 52 | 53 | How to build the documentation 54 | ------------------------------ 55 | 56 | Requirements 57 | ^^^^^^^^^^^^ 58 | 59 | To build the documentation you will need to additionally install ``sphinx``. 60 | Furthermore, you'll also need to install the readthedocs theme. 61 | This is easily done using pip:: 62 | 63 | pip install sphinx sphinx_rtd_theme 64 | 65 | Building the documentation 66 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 67 | 68 | Navigate to the ``docs`` subfolder and run:: 69 | 70 | sphinx-build -b html . _build 71 | 72 | This will build the documentation in the subfolder ``_build``. 
73 | Alternatively, you can run the Makefile provided:: 74 | 75 | make html 76 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | from setuptools import setup, find_packages 3 | 4 | # To use a consistent encoding 5 | from codecs import open 6 | from os import path 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # Get the long description from the README file 11 | with open(path.join(here, 'README.rst'), encoding='utf-8') as f: 12 | long_description = f.read() 13 | 14 | setup( 15 | name='pyfar', 16 | 17 | # Versions should comply with PEP440. For a discussion on single-sourcing 18 | # the version across setup.py and the project code, see 19 | # https://packaging.python.org/en/latest/single_source_version.html 20 | version='0.1.0', 21 | 22 | description='False Alarm Reduction for physiologic waveform alarms', 23 | long_description=long_description, 24 | 25 | # The project's main homepage. 26 | url='https://github.com/MIT-LCP/false-alarm-reduction', 27 | 28 | # Author details 29 | author='Andrea Li', 30 | author_email='liandrea@mit.edu', 31 | 32 | # Choose your license 33 | license='MIT', 34 | 35 | # What does your project relate to? 36 | keywords='false alarm reduction ICU', 37 | 38 | # You can just specify the packages manually here if your project is 39 | # simple. Or you can use find_packages(). 40 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), 41 | 42 | # List run-time dependencies here. These will be installed by pip when 43 | # your project is installed. For an analysis of "install_requires" vs pip's 44 | # requirements files see: 45 | # https://packaging.python.org/en/latest/requirements.html 46 | install_requires=[ 47 | 'fastdtw>=0.3.2' 48 | ,'matplotlib>=2.0.2' 49 | ,'numpy>=1.13.1' 50 | ,'pandas>=0.20.3' 51 | ,'PeakUtils>=1.1.0' 52 | ,'scikit-learn>=0.19.0' 53 | ,'scipy>=0.19.0' 54 | ,'sklearn>=0.0' 55 | ,'spectrum>=0.7.1' 56 | ,'wfdb==1.2.2' 57 | ], 58 | 59 | 60 | 61 | # List additional groups of dependencies here (e.g. development 62 | # dependencies). You can install these using the following syntax, 63 | # for example: 64 | # $ pip install -e .[dev,test] 65 | # extras_require={ 66 | # 'dev': ['check-manifest'], 67 | # 'test': ['coverage'], 68 | # }, 69 | 70 | # If there are data files included in your packages that need to be 71 | # installed, specify them here. If using Python 2.6 or less, then these 72 | # have to be included in MANIFEST.in as well. 73 | package_data={'pyfar': ['sigtypes', 'vtach_beats.csv']}, 74 | 75 | # Although 'package_data' is the preferred approach, in some case you may 76 | # need to place data files outside of your packages. See: 77 | # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa 78 | # In this case, 'data_file' will be installed into '/my_data' 79 | # data_files=[('my_data', ['data/data_file'])], 80 | 81 | # To provide executable scripts, use entry points in preference to the 82 | # "scripts" keyword. Entry points provide cross-platform support and allow 83 | # pip to create the appropriate form of executable for the target platform. 
84 | # entry_points={ 85 | # 'console_scripts': [ 86 | # 'sample=sample:main', 87 | # ], 88 | # }, 89 | 90 | ) 91 | -------------------------------------------------------------------------------- /matlab/run_on_challenge_data.m: -------------------------------------------------------------------------------- 1 | data_path = '/data/challenge-2015/data/'; 2 | addpath(genpath('/home/alistairewj/git/peak-detector')); 3 | addpath('/data/challenge-2015/wfdb-app-toolbox-0-9-9/mcode/'); 4 | fp = fopen([data_path 'ALARMS'],'r'); 5 | alarms=textscan(fp,'%s%s%d','delimiter',','); 6 | fclose(fp); 7 | records=alarms{1}; 8 | targets=alarms{3}; 9 | alarms=alarms{2}; 10 | 11 | % define input options for the peak detector 12 | % all of the options listed here are the default values, and are optionally omitted 13 | opt = struct(... 14 | 'SIZE_WIND',10,... % define the window for the bSQI check on the ECG 15 | 'LG_MED',3,... % take the median SQI using X nearby values, so if LG_MED = 3, we take the median of the 3 prior and 3 posterior windows 16 | 'REG_WIN',1,... % how frequently to check the SQI for switching - i.e., if REG_WIN = 1, then we check the signals every second to switch 17 | 'THR',0.150,... % the width, in seconds, used when comparing peaks in the F1 based ECG SQI 18 | 'SQI_THR',0.8,... % the SQI threshold - we switch signals if SQI < this value 19 | 'USE_PACING',1,... % flag turning on/off the pacing detection/correction 20 | 'ABPMethod','wabp',... % ABP peak detection method (wabp, delineator) 21 | 'SIMPLEMODE', 0,... % simple mode only uses the first ABP and ECG signal, and ignores all others 22 | 'DELAYALG', 'map',... % algorithm used to determine the delay between the ABP and the ECG signal 23 | 'SAVE_STUFF', 0,... % leave temporary files in working directory 24 | ... % jqrs parameters - the custom peak detector implemented herein 25 | 'JQRS_THRESH', 0.3,... % energy threshold for defining peaks 26 | 'JQRS_REFRAC', 0.25,... % refractory period in seconds 27 | 'JQRS_INTWIN_SZ', 7,... 
28 | 'JQRS_WINDOW', 15); 29 | 30 | % copy file to this folder 31 | for i = 1:numel(records) 32 | recordName = records{i}; 33 | 34 | % create symlinks for data 35 | system(['ln -frs ' data_path recordName '.mat ' recordName '.mat']); 36 | system(['ln -frs ' data_path recordName '.hea ' recordName '.hea']); 37 | 38 | % load data 39 | [t,data] = rdsamp(recordName); 40 | [siginfo,fs] = wfdbdesc(recordName); 41 | 42 | % extract info from structure output by wfdbdesc 43 | header = arrayfun(@(x) x.Description, siginfo, 'UniformOutput', false); 44 | 45 | % run SQI based switching 46 | [ qrs, sqi, qrs_comp, qrs_header ] = detect_sqi_nowriteout(recordName, data, header, fs, opt); 47 | if isempty(qrs) 48 | wrann(recordName,'aqrs',0); 49 | else 50 | wrann(recordName,'aqrs',floor(qrs*fs(1))); 51 | end 52 | system(['mv ' recordName '.aqrs /data/challenge-2015/ann/' recordName '.aqrs']); 53 | 54 | [ idxECG, idxABP, idxPPG, idxSV ] = getSignalIndices(header); 55 | idxECG = idxECG(:)'; 56 | if ~isempty(idxECG) 57 | for m = idxECG 58 | opt.LG_REC = size(data,1) ./ fs(m); % length of the record in seconds 59 | opt.N_WIN = ceil(opt.LG_REC/opt.REG_WIN); % number of windows in the signal 60 | ann_jqrs = run_qrsdet_by_seg_ali(data(:,m),fs(m),opt); 61 | if isempty(ann_jqrs) 62 | fprintf('%s - empty signal.\n',recordName); 63 | %system('rm tmp; touch tmp'); 64 | % make an empty annotation file for jqrs 65 | 66 | %system(['wrann -r ' recordName ' -a jqrs -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- To install pyfar with `conda `_, run:: 15 | 16 | conda install pyfar -c conda-forge 17 | 18 | Pip 19 | ~~~ 20 | 21 | Or install pyfar with ``pip``:: 22 | 23 | pip install pyfar --upgrade 24 | 25 | Source 26 | ~~~~~~ 27 | 28 | To install pyfar from source, clone the repository from `GitHub 29 | <https://github.com/MIT-LCP/false-alarm-reduction>`_:: 30 | 31 | git clone https://github.com/MIT-LCP/false-alarm-reduction.git 32 | cd false-alarm-reduction 33 | python setup.py install 34 | 35 | 36 | Detailed instructions 37 | --------------------- 38 | 39 | These instructions were tested on Ubuntu 16.04. 40 | 41 | Install Virtual Environment (Python 2.7) 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | 44 | We recommend installing the package in a virtual environment. If not 45 | using a virtual environment, skip ahead to the 46 | ``pip install pip --upgrade`` step. 47 | 48 | First, install pip and virtualenv as follows: 49 | 50 | :: 51 | 52 | sudo apt-get install python-pip python-dev python-virtualenv 53 | 54 | Create a virtual environment 55 | 56 | :: 57 | 58 | virtualenv --system-site-packages TARGET_DIRECTORY 59 | 60 | Where TARGET\_DIRECTORY is the desired location of the virtual 61 | environment. Here, we assume it is ``~/false-alarm-reduction``. 62 | 63 | Activate the virtual environment 64 | 65 | :: 66 | 67 | source ~/false-alarm-reduction/bin/activate 68 | 69 | Now you should be working in the virtual environment. Verify pip is 70 | installed: 71 | 72 | :: 73 | 74 | (false-alarm-reduction)$ easy_install -U pip 75 | 76 | Upgrade pip. 77 | 78 | :: 79 | 80 | (false-alarm-reduction)$ pip install pip --upgrade 81 | 82 | Now install all the necessary packages using the requirements file (for 83 | reference, these are: ``wfdb``, ``numpy``, ``scipy``, ``matplotlib``, 84 | ``sklearn``, ``fastdtw``, ``spectrum``, ``peakutils``). 85 | 86 | :: 87 | 88 | pip install -r requirements.txt 89 | 
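As a quick sanity check that the requirements installed correctly, the following imports should all succeed inside the virtual environment (this check is illustrative and not part of the package; note that ``PeakUtils`` imports as ``peakutils``)::

    # each package below is listed in requirements.txt; an ImportError
    # here means the corresponding package did not install correctly
    import fastdtw
    import matplotlib
    import numpy
    import peakutils
    import scipy
    import sklearn
    import spectrum
    import wfdb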
90 | Download data 91 | ~~~~~~~~~~~~~ 92 | 93 | Two convenience scripts are provided to download the data. Run these 94 | from the main folder as follows: 95 | 96 | :: 97 | 98 | sh download/download_annotations.sh 99 | sh download/download_data.sh 100 | 101 | These scripts download the following data: 102 | 103 | - training.zip https://physionet.org/challenge/2015/training.zip 104 | - ann\_gqrs0.zip 105 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AABvdXbSGZi3COPG-O\_-nBGxa/ann\_gqrs0.zip?dl=1 106 | - ann\_gqrs1.zip 107 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAAb14a\_NN8iKojXEoInXCGa/ann\_gqrs1.zip?dl=1 108 | - ann\_wabp.zip 109 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAALSmteHaL0gQovwXj8CXV4a/ann\_wabp.zip?dl=1 110 | - ann\_wpleth.zip 111 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAko1RNvgmdhWF7lNux-Ob3a/ann\_wpleth.zip?dl=1 112 | 113 | Afterward, you should have the following directory structure: 114 | 115 | - ``annotations`` subfolder with all annotations (*.gqrs0, *.gqrs1, 116 | *.wabp, *.wpleth) 117 | - ``data/training`` subfolder with all data (*.mat and *.hea) and a 118 | RECORDS file 119 | 120 | Brief instructions for other OS 121 | ------------------------------- 122 | 123 | 1. Download and install the following packages: ``wfdb``, ``numpy``, 124 | ``scipy``, ``matplotlib``, ``sklearn``, ``fastdtw``, ``spectrum``, 125 | ``peakutils``. 126 | 2. Download data and annotations 127 | 128 | - training.zip https://physionet.org/challenge/2015/training.zip 129 | - ann\_gqrs0.zip 130 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AABvdXbSGZi3COPG-O\_-nBGxa/ann\_gqrs0.zip?dl=1 131 | - ann\_gqrs1.zip 132 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAAb14a\_NN8iKojXEoInXCGa/ann\_gqrs1.zip?dl=1 133 | - ann\_wabp.zip 134 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAALSmteHaL0gQovwXj8CXV4a/ann\_wabp.zip?dl=1 135 | - ann\_wpleth.zip 136 | https://www.dropbox.com/sh/hv4uat0ihwlygq8/AAAko1RNvgmdhWF7lNux-Ob3a/ann\_wpleth.zip?dl=1 137 | 138 | 3. Data should be unzipped into ``data/`` (ultimately the files will be 139 | in ``data/training/``) 140 | 4. Annotations should be unzipped into ``annotations/`` 141 | -------------------------------------------------------------------------------- /docs/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | Install 5 | ------- 6 | 7 | :: 8 | 9 | $ pip install pyfar 10 | 11 | See the :doc:`installation <install>` document for more information. 12 | 13 | 14 | Acquire data 15 | ------------ 16 | 17 | The easiest way to understand what this package does is to evaluate it on 18 | physiologic waveforms. The 2015 PhysioNet/Computing in Cardiology Challenge 19 | focused on false alarm reduction and provides a useful dataset to work with. 20 | 21 | To download this dataset, run the download shell script: 22 | 23 | .. code-block:: bash 24 | 25 | bash download/download_data.sh 26 | 27 | This will download data into the ``data`` subfolder using ``wget`` and decompress 28 | the individual files. 29 | 30 | See the :doc:`Challenge 2015 data ` document for more 31 | information on the dataset. 32 | 33 | More detail on the dataset can be found on the `Challenge 2015 website`__. 34 | 35 | .. _challenge2015: https://physionet.org/challenge/2015/ 36 | 37 | __ challenge2015_ 38 | 39 | Acquire R-peak annotations 40 | -------------------------- 41 | 42 | R-peak annotations indicate where in the electrocardiogram (ECG) heart beat 43 | cycle the "R" peak is estimated to be. You can read about the different waves 44 | in the ECG from `ecgpedia basics`__. 45 | 46 | .. _ecgpediabasics: http://en.ecgpedia.org/index.php?title=Basics#The_different_ECG_waves 47 | 48 | __ ecgpediabasics_ 49 | 50 | R-peak annotations can be acquired in two ways: (1) downloading them directly, 51 | or (2) generating them from the data. 52 | 53 | Downloading R-peak annotations 54 | ------------------------------ 55 | 56 | Run the bash script to download annotations: 57 | 58 | .. code-block:: bash 59 | 60 | bash download/download_annotations.sh 61 |
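Once downloaded, annotations can be read with the ``wfdb`` Python package. A minimal sketch follows; the record name ``v131l`` and the ``gqrs0`` annotator suffix are illustrative, and attribute names on the returned annotation object differ between ``wfdb`` releases:

.. code-block:: python

    import wfdb

    # Read annotations/v131l.gqrs0; the second argument selects which
    # annotation file (annotator suffix) to read for the record.
    annotation = wfdb.rdann("annotations/v131l", "gqrs0")

    # Attribute names vary across wfdb versions, so dump all fields.
    print(annotation.__dict__)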
62 | Generating R-peak annotations 63 | ----------------------------- 64 | 65 | Annotations can be regenerated from the data itself. Two annotation software 66 | tools are required: GQRS and WABP. GQRS uses the ECG to identify R-peaks, while 67 | WABP uses pulsatile waveforms such as the arterial blood pressure (ABP) waveform 68 | or the photoplethysmogram (PPG) to identify these peaks. 69 | 70 | First, install the WFDB toolbox from `PhysioNet`__. 71 | 72 | .. _wfdb: https://physionet.org/tools 73 | 74 | __ wfdb_ 75 | 76 | Then, run the annotation script. This iterates through all data records and generates annotations. 77 | 78 | .. code-block:: bash 79 | 80 | cd annotations 81 | bash annotate_data.sh 82 | 83 | 84 | Running the code 85 | ---------------- 86 | 87 | Baseline algorithm 88 | ~~~~~~~~~~~~~~~~~~ 89 | 90 | You can run the main "baseline" algorithm using: 91 | 92 | ``python pipeline.py`` 93 | 94 | This will run through all the data files and associated annotation files 95 | to detect/flag false alarms. The output files are written at the end of 96 | the algorithm to the current directory. For the baseline algorithm, this 97 | is by default ``results.json``. 98 | 99 | Note: To run using a different QRS detector (e.g. JQRS instead of GQRS), 100 | change ``ecg_ann_type``, e.g. ``ecg_ann_type = "jqrs"``. See the 101 | ``matlab/`` subfolder for code to generate JQRS (note this code is 102 | untested!). 103 | 104 | Dynamic time warping (DTW) 105 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 106 | 107 | To run the DTW algorithm on the alarm signal, update the 108 | ``matrix_filename`` and ``distances_filename`` variables in 109 | ``parameters.py`` to be the filenames to output the final confusion 110 | matrix and corresponding distance results, respectively. Then, call the 111 | algorithm as ``python dtw.py``. 112 | 113 | (Experimental) Using ventricular/normal beat banks 114 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 115 | 116 | - DTW algorithm beat-by-beat (bank): update ``output_path_bank`` in 117 | ``parameters.py`` to be the path to the desired folder for 118 | ventricular beat annotations via standard beat comparisons. Then run 119 | ``ventricular_beat_bank.py``. In ``baseline_algorithm.py``, comment 120 | out lines 916-917 and uncomment line 918. Make sure that the 121 | ``output_path`` on line 869 in the 122 | ``read_ventricular_beat_annotations`` function is set to 123 | ``parameters.output_path_bank`` in ``baseline_algorithm.py``. 124 | - DTW algorithm beat-by-beat (standard deviation): update 125 | ``output_path_std`` in ``parameters.py`` to be the path to the desired folder 126 | for ventricular beat annotations via standard deviation 127 | calculations. Then run ``ventricular_beat_stdev.py``. In 128 | ``baseline_algorithm.py``, comment out lines 916-917 and uncomment 129 | line 918. Make sure that the ``output_path`` on line 869 in the 130 | ``read_ventricular_beat_annotations`` function is set to 131 | ``parameters.output_path_std`` in ``baseline_algorithm.py``. 
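Whichever variant you run, the output JSON maps each confusion-matrix cell (``"TP"``, ``"FP"``, ``"FN"``, ``"TN"``) to a list of sample names. A minimal sketch for summarizing a results file (the filename assumes the baseline default; the same statistics are printed by ``print_stats`` in ``pipeline.py``):

.. code-block:: python

    import json

    # Confusion matrix written by pipeline.py (default output name).
    with open("results.json") as f:
        matrix = json.load(f)

    # Each cell holds a list of sample names; reduce to counts.
    counts = {key: len(samples) for key, samples in matrix.items()}

    sensitivity = counts["TP"] / float(counts["TP"] + counts["FN"])
    specificity = counts["TN"] / float(counts["TN"] + counts["FP"])

    # PhysioNet/CinC 2015 challenge score: false negatives carry a 5x penalty.
    score = float(counts["TP"] + counts["TN"]) / (
        counts["TP"] + counts["FP"] + counts["TN"] + 5 * counts["FN"])

    print("counts:", counts)
    print("sensitivity:", sensitivity)
    print("specificity:", specificity)
    print("score:", score)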
132 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # False Alarm Reduction documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Feb 13 09:57:24 2018. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 18 | import os 19 | import sys 20 | sys.path.insert(0, os.path.abspath('../pyfar/')) 21 | 22 | 23 | # -- General configuration ------------------------------------------------ 24 | 25 | # If your documentation needs a minimal Sphinx version, state it here. 26 | # 27 | # needs_sphinx = '1.0' 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = ['sphinx.ext.autodoc', 33 | 'sphinx.ext.coverage', 34 | 'sphinx.ext.githubpages', 35 | 'sphinx.ext.napoleon'] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # The suffix(es) of source filenames. 41 | # You can specify multiple suffix as a list of string: 42 | # 43 | # source_suffix = ['.rst', '.md'] 44 | source_suffix = '.rst' 45 | 46 | # The master toctree document. 47 | master_doc = 'index' 48 | 49 | # General information about the project. 50 | project = u'False Alarm Reduction' 51 | copyright = u'2018, Andrea Li, Alistair E. W. Johnson, Roger G. Mark' 52 | author = u'Andrea Li, Alistair E. W. Johnson, Roger G. Mark' 53 | 54 | # The version info for the project you're documenting, acts as replacement for 55 | # |version| and |release|, also used in various other places throughout the 56 | # built documents. 57 | # 58 | # The short X.Y version. 59 | version = u'0.1' 60 | # The full version, including alpha/beta/rc tags. 61 | release = u'0' 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This patterns also effect to html_static_path and html_extra_path 73 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 74 | 75 | # The name of the Pygments (syntax highlighting) style to use. 76 | pygments_style = 'sphinx' 77 | 78 | # If true, `todo` and `todoList` produce output, else they produce nothing. 79 | todo_include_todos = False 80 | 81 | # -- Options for HTML output ---------------------------------------------- 82 | 83 | # The theme to use for HTML and HTML Help pages. See the documentation for 84 | # a list of builtin themes. 
85 | html_theme = 'sphinx_rtd_theme' 86 | html_theme_path = ["_themes", ] 87 | 88 | # Theme options are theme-specific and customize the look and feel of a theme 89 | # further. For a list of options available for each theme, see the 90 | # documentation. 91 | # 92 | html_theme_options = { 93 | 'canonical_url': '', 94 | 'analytics_id': '', 95 | 'logo_only': False, 96 | 'display_version': True, 97 | 'prev_next_buttons_location': 'bottom', 98 | # Toc options 99 | 'collapse_navigation': False, 100 | 'sticky_navigation': True, 101 | 'navigation_depth': 4 102 | } 103 | 104 | # Add any paths that contain custom static files (such as style sheets) here, 105 | # relative to this directory. They are copied after the builtin static files, 106 | # so a file named "default.css" will overwrite the builtin "default.css". 107 | html_static_path = ['_static'] 108 | 109 | 110 | # -- Options for HTMLHelp output ------------------------------------------ 111 | 112 | # Output file base name for HTML help builder. 113 | htmlhelp_basename = 'FalseAlarmReductiondoc' 114 | 115 | 116 | # -- Options for LaTeX output --------------------------------------------- 117 | 118 | latex_elements = { 119 | # The paper size ('letterpaper' or 'a4paper'). 120 | # 121 | # 'papersize': 'letterpaper', 122 | 123 | # The font size ('10pt', '11pt' or '12pt'). 124 | # 125 | # 'pointsize': '10pt', 126 | 127 | # Additional stuff for the LaTeX preamble. 128 | # 129 | # 'preamble': '', 130 | 131 | # Latex figure (float) alignment 132 | # 133 | # 'figure_align': 'htbp', 134 | } 135 | 136 | # Grouping the document tree into LaTeX files. List of tuples 137 | # (source start file, target name, title, 138 | # author, documentclass [howto, manual, or own class]). 139 | latex_documents = [ 140 | (master_doc, 'FalseAlarmReduction.tex', u'False Alarm Reduction Documentation', 141 | u'Andrea Li, Alistair E. W. Johnson, Roger G. Mark', 'manual'), 142 | ] 143 | 144 | 145 | # -- Options for manual page output --------------------------------------- 146 | 147 | # One entry per manual page. List of tuples 148 | # (source start file, name, description, authors, manual section). 149 | man_pages = [ 150 | (master_doc, 'falsealarmreduction', u'False Alarm Reduction Documentation', 151 | [author], 1) 152 | ] 153 | 154 | 155 | # -- Options for Texinfo output ------------------------------------------- 156 | 157 | # Grouping the document tree into Texinfo files. 
List of tuples 158 | # (source start file, target name, title, author, 159 | # dir menu entry, description, category) 160 | texinfo_documents = [ 161 | (master_doc, 'FalseAlarmReduction', u'False Alarm Reduction Documentation', 162 | author, 'FalseAlarmReduction', 'One line description of project.', 163 | 'Miscellaneous'), 164 | ] 165 | -------------------------------------------------------------------------------- /pyfar/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import wfdb 4 | import json 5 | 6 | def abs_value(x, y): 7 | return abs(x-y) 8 | 9 | def is_true_alarm_fields(fields): 10 | return fields['comments'][1] == 'True alarm' 11 | 12 | 13 | def is_true_alarm(data_path, sample_name): 14 | sig, fields = wfdb.srdsamp(data_path + sample_name) 15 | return is_true_alarm_fields(fields) 16 | 17 | # start and end in seconds 18 | def get_annotation(sample, ann_type, ann_fs, start, end): 19 | try: 20 | annotation = wfdb.rdann(sample, ann_type, sampfrom=start*ann_fs, sampto=end*ann_fs) 21 | except Exception as e: 22 | annotation = [] 23 | print(e) 24 | 25 | return annotation 26 | 27 | ## Returns type of arrhythmia alarm 28 | # output types include: 'a', 'b', 't', 'v', 'f' 29 | def get_arrhythmia_type(fields): 30 | """Returns type of arrhythmia based on fields of the sample 31 | 32 | Arguments 33 | --------- 34 | fields: fields of sample read from wfdb.rdsamp 35 | 36 | Returns 37 | ------- 38 | Type of arrhythmia 39 | 'a': asystole 40 | 'b': bradycardia 41 | 't': tachycardia 42 | 'f': ventricular fibrillation 43 | 'v': ventricular tachycardia 44 | """ 45 | 46 | arrhythmias = { 47 | 'Asystole': 'a', 48 | 'Bradycardia': 'b', 49 | 'Tachycardia': 't', 50 | 'Ventricular_Tachycardia': 'v', 51 | 'Ventricular_Flutter_Fib': 'f' 52 | } 53 | 54 | arrhythmia_type = fields['comments'][0] 55 | return arrhythmias[arrhythmia_type] 56 | 57 | 58 | def get_channel_type(channel_name, sigtypes_filename): 59 | """Returns type of channel 60 | 61 | Arguments 62 | --------- 63 | channel_name: name of channel (e.g. "II", "V", etc.) 64 | 65 | sigtypes_filename: file mapping channel names to channel 66 | types 67 | 68 | Returns 69 | ------- 70 | Type of channel (e.g. 
"ECG", "BP", "PLETH", "Resp") 71 | """ 72 | 73 | channel_types_dict = {} 74 | with open(sigtypes_filename, "r") as f: 75 | for line in f: 76 | splitted_line = line.split("\t") 77 | channel = splitted_line[-1].rstrip() 78 | channel_type = splitted_line[0] 79 | channel_types_dict[channel] = channel_type 80 | 81 | if channel_name in channel_types_dict.keys(): 82 | return channel_types_dict[channel_name] 83 | 84 | raise Exception("Unknown channel name") 85 | 86 | 87 | def get_samples_of_type(samples_dict, arrhythmia_type): 88 | """Returns a sub-dictionary of only the given arrhythmia type 89 | 90 | Arguments 91 | --------- 92 | samples_dict: dictionary mapping sample names to data associated 93 | with the given sample 94 | 95 | arrhythmia_type: 96 | 'a': asystole 97 | 'b': bradycardia 98 | 't': tachycardia 99 | 'f': ventricular fibrillation 100 | 'v': ventricular tachycardia 101 | 102 | Returns 103 | ------- 104 | a sub-dictionary with keys of only the given arrhythmia 105 | """ 106 | 107 | subdict = {} 108 | 109 | for sample_name in samples_dict.keys(): 110 | if sample_name[0] == arrhythmia_type: 111 | subdict[sample_name] = samples_dict[sample_name] 112 | 113 | return subdict 114 | 115 | 116 | def write_json(dictionary, filename): 117 | with open(filename, "w") as f: 118 | json.dump(dictionary, f) 119 | 120 | 121 | def read_json(filename): 122 | with open(filename, "r") as f: 123 | dictionary = json.load(f) 124 | return dictionary 125 | 126 | 127 | def get_classification_accuracy(matrix): 128 | num_correct = len(matrix["TP"]) + len(matrix["TN"]) 129 | num_total = len(matrix["FP"]) + len(matrix["FN"]) + num_correct 130 | 131 | return float(num_correct) / num_total 132 | 133 | 134 | def calc_sensitivity(counts): 135 | tp = counts["TP"] 136 | fn = counts["FN"] 137 | return tp / float(tp + fn) 138 | 139 | 140 | def calc_specificity(counts): 141 | tn = counts["TN"] 142 | fp = counts["FP"] 143 | 144 | return tn / float(tn + fp) 145 | 146 | 147 | def calc_ppv(counts): 148 | tp = counts["TP"] 149 | fp = counts["FP"] 150 | return tp / float(tp + fp) 151 | 152 | 153 | def calc_npv(counts): 154 | tn = counts["TN"] 155 | fn = counts["FN"] 156 | return tn / float(tn + fn) 157 | 158 | 159 | def calc_f1(counts): 160 | sensitivity = calc_sensitivity(counts) 161 | ppv = calc_ppv(counts) 162 | 163 | return 2 * sensitivity * ppv / float(sensitivity + ppv) 164 | 165 | 166 | def print_stats(counts): 167 | try: 168 | sensitivity = calc_sensitivity(counts) 169 | specificity = calc_specificity(counts) 170 | ppv = calc_ppv(counts) 171 | npv = calc_npv(counts) 172 | f1 = calc_f1(counts) 173 | except Exception as e: 174 | print(e) 175 | 176 | print("counts: ", counts) 177 | print("sensitivity: ", sensitivity) 178 | print("specificity: ", specificity) 179 | print("ppv: ", ppv) 180 | print("npv: ", npv) 181 | print("f1: ", f1) 182 | 183 | 184 | def get_matrix_classification(actual, predicted): 185 | if actual and predicted: 186 | return "TP" 187 | elif actual and not predicted: 188 | return "FN" 189 | elif not actual and predicted: 190 | return "FP" 191 | return "TN" 192 | 193 | 194 | def get_score(matrix): 195 | numerator = len(matrix["TP"]) + len(matrix["TN"]) 196 | denominator = len(matrix["FP"]) + 5*len(matrix["FN"]) + numerator 197 | 198 | return float(numerator) / denominator 199 | 200 | 201 | def get_by_arrhythmia(confusion_matrix, arrhythmia_prefix): 202 | counts_by_arrhythmia = {} 203 | matrix_by_arrhythmia = {} 204 | for classification_type in confusion_matrix.keys(): 205 | sample_list = [ sample for sample in 
confusion_matrix[classification_type] if sample[0] == arrhythmia_prefix] 206 | counts_by_arrhythmia[classification_type] = len(sample_list) 207 | matrix_by_arrhythmia[classification_type] = sample_list 208 | 209 | return counts_by_arrhythmia, matrix_by_arrhythmia 210 | -------------------------------------------------------------------------------- /pyfar/pipeline.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from datetime import datetime 4 | import numpy as np 5 | from baseline_algorithm import * 6 | from parameters import * 7 | import os 8 | import csv 9 | import json 10 | import wfdb 11 | 12 | ## Classifying arrhythmia alarms 13 | 14 | # Returns true if alarm is a true alarm 15 | # Only for samples with known classification 16 | def is_true_alarm(data_path, sample_name): 17 | sig, fields = wfdb.srdsamp(data_path + sample_name) 18 | true_alarm = fields['comments'][1] == 'True alarm' 19 | return true_alarm 20 | 21 | # Generate confusion matrix for all samples given sample name/directory 22 | def generate_confusion_matrix_dir(data_path, ann_path, ecg_ann_type): 23 | confusion_matrix = { 24 | "TP": [], 25 | "FP": [], 26 | "FN": [], 27 | "TN": [] 28 | } 29 | 30 | for filename in os.listdir(data_path): 31 | if filename.endswith(HEADER_EXTENSION): 32 | sample_name = filename.rstrip(HEADER_EXTENSION) 33 | 34 | if sample_name[0] != 'v': 35 | continue 36 | 37 | print("sample name: {}".format(sample_name)) 38 | 39 | # sig, fields = wfdb.srdsamp(data_path + sample_name) 40 | # if "II" not in fields['signame']: 41 | # continue 42 | 43 | true_alarm = is_true_alarm(data_path, sample_name) 44 | classified_true_alarm = classify_alarm(data_path, ann_path, sample_name, ecg_ann_type) 45 | 46 | matrix_classification = get_confusion_matrix_classification(true_alarm, classified_true_alarm) 47 | confusion_matrix[matrix_classification].append(sample_name) 48 | if matrix_classification == "FN": 49 | print("FALSE NEGATIVE: {}".format(filename)) 50 | 51 | return confusion_matrix 52 | 53 | 54 | def get_confusion_matrix_classification(true_alarm, classified_true_alarm): 55 | if true_alarm and classified_true_alarm: 56 | matrix_classification = "TP" 57 | 58 | elif true_alarm and not classified_true_alarm: 59 | matrix_classification = "FN" 60 | 61 | elif not true_alarm and classified_true_alarm: 62 | matrix_classification = "FP" 63 | 64 | else: 65 | matrix_classification = "TN" 66 | 67 | return matrix_classification 68 | 69 | 70 | ## Printing and calculating counts 71 | 72 | def print_by_type(false_negatives): 73 | counts_by_type = {} 74 | for false_negative in false_negatives: 75 | first = false_negative[0] 76 | if first not in counts_by_type.keys(): 77 | counts_by_type[first] = 0 78 | counts_by_type[first] += 1 79 | 80 | print(counts_by_type) 81 | 82 | 83 | def print_by_arrhythmia(confusion_matrix, arrhythmia_prefix): 84 | counts_by_arrhythmia = {} 85 | for classification_type in confusion_matrix.keys(): 86 | sample_list = [ sample for sample in confusion_matrix[classification_type] if sample[0] == arrhythmia_prefix] 87 | counts_by_arrhythmia[classification_type] = (len(sample_list), sample_list) 88 | 89 | print(counts_by_arrhythmia) 90 | 91 | def get_counts(confusion_matrix): 92 | return { key : len(confusion_matrix[key]) for key in confusion_matrix.keys() } 93 | 94 | 95 | def calc_sensitivity(counts): 96 | tp = counts["TP"] 97 | fn = counts["FN"] 98 | return tp / float(tp + fn) 99 | 100 | def calc_specificity(counts): 101 | tn 
= counts["TN"] 102 | fp = counts["FP"] 103 | 104 | return tn / float(tn + fp) 105 | 106 | def calc_ppv(counts): 107 | tp = counts["TP"] 108 | fp = counts["FP"] 109 | return tp / float(tp + fp) 110 | 111 | def calc_f1(counts): 112 | sensitivity = calc_sensitivity(counts) 113 | ppv = calc_ppv(counts) 114 | 115 | return 2 * sensitivity * ppv / float(sensitivity + ppv) 116 | 117 | def print_stats(counts): 118 | sensitivity = calc_sensitivity(counts) 119 | specificity = calc_specificity(counts) 120 | ppv = calc_ppv(counts) 121 | f1 = calc_f1(counts) 122 | score = float(counts["TP"] + counts["TN"])/(counts["TP"] + counts["FP"] + counts["TN"] + counts["FN"] * 5) 123 | 124 | print("counts: {}".format(counts)) 125 | print("sensitivity: {}".format(sensitivity)) 126 | print("specificity: {}".format(specificity)) 127 | print("ppv: {}".format(ppv)) 128 | print("f1: {}".format(f1)) 129 | print("score: {}".format(score)) 130 | 131 | 132 | ## Run pipeline 133 | 134 | def run(data_path, ann_path, filename, ecg_ann_type): 135 | print("ecg_ann_type: {}".format(ecg_ann_type)) 136 | print(" ann_path: {}".format(ann_path)) 137 | 138 | start = datetime.now() 139 | matrix = generate_confusion_matrix_dir(data_path, ann_path, ecg_ann_type) 140 | print("confusion matrix: {}".format(matrix)) 141 | print("total time: {}".format(datetime.now() - start)) 142 | 143 | with open(filename, "w") as f: 144 | json.dump(matrix, f) 145 | 146 | def read_json(filename): 147 | with open(filename, "r") as f: 148 | dictionary = json.load(f) 149 | 150 | return dictionary 151 | 152 | # print(datetime.now()) 153 | # write_filename = "sample_data/pipeline_fpinvalids_vtachfpann_nancheck.json" 154 | # ecg_ann_type = "gqrs" 155 | # run(data_path, ann_path, write_filename, ecg_ann_type) 156 | 157 | 158 | if __name__ == '__main__': 159 | run(data_path, ann_path, write_filename, ecg_ann_type) 160 | 161 | matrix = read_json(write_filename) 162 | counts = get_counts(matrix) 163 | print_stats(counts) 164 | 165 | # matrix = read_json("../sample_data/baseline_performance/vtach_gqrs.json") 166 | # counts = get_counts(matrix) 167 | # print_stats(counts) 168 | 169 | # fplesinger_confusion_matrix = others_confusion_matrices['fplesinger-210'] 170 | # print("missed true positives: {}".format(get_missed(gqrs_matrix, fplesinger_confusion_matrix, "TP"))) 171 | # print("missed true negatives: {}".format(get_missed(gqrs_matrix, fplesinger_confusion_matrix, "TN"))) 172 | 173 | 174 | # print("\nFP") 175 | # fp_matrix = read_json("sample_data/pipeline_fp.json") 176 | # counts_fp = get_counts(fp_matrix) 177 | # evaluate.print_stats(counts_fp) 178 | # print_by_type(fp_matrix['FN']) 179 | 180 | 181 | # print("\nFP invalids with GQRS") 182 | # fpinvalids_matrix = read_json("sample_data/pipeline_fpinvalids.json") 183 | # counts_fpinvalids = get_counts(fpinvalids_matrix) 184 | # evaluate.print_stats(counts_fpinvalids) 185 | # print_by_type(fpinvalids_matrix['FN']) 186 | 187 | # missed_true_negatives = get_missed(fpinvalids_matrix, fplesinger_confusion_matrix, "TN") 188 | # print("missed true positives: {}".format(get_missed(fpinvalids_matrix, fplesinger_confusion_matrix, "TP"))) 189 | # print("missed true negatives: {}".format(missed_true_negatives)) 190 | # print_by_type(missed_true_negatives) 191 | # print(len(missed_true_negatives)) 192 | 193 | 194 | # print("\nFP invalids with GQRS without abp test in vtach") 195 | # fpinvalids_without_vtach_abp = read_json("sample_data/pipeline_fpinvalids_novtachabp.json") 196 | # counts_fpinvalids_without_vtach_abp = 
get_counts(fpinvalids_without_vtach_abp) 197 | # evaluate.print_stats(counts_fpinvalids_without_vtach_abp) 198 | # print_by_type(fpinvalids_without_vtach_abp['FN']) 199 | 200 | # print_by_type(gqrs_matrix['FN']) 201 | # print_by_arrhythmia(confusion_matrix_gqrs, 'v') 202 | 203 | # fplesinger_confusion_matrix = others_confusion_matrices['fplesinger-210'] 204 | # print("missed true positives: {}".format(get_missed(gqrs_matrix, fplesinger_confusion_matrix, "TP"))) 205 | # print("missed true negatives: {}".format(get_missed(gqrs_matrix, fplesinger_confusion_matrix, "TN"))) 206 | 207 | 208 | ## Comparing classification with other algorithms 209 | 210 | # In[21]: 211 | 212 | # def generate_others_confusion_matrices(filename, data_path): 213 | # others_confusion_matrices = {} 214 | 215 | # with open(filename, "r") as f: 216 | # reader = csv.DictReader(f) 217 | # authors = reader.fieldnames[1:] 218 | # for author in authors: 219 | # others_confusion_matrices[author] = { "TP": [], "FP": [], "FN": [], "TN": [] } 220 | 221 | # for line in reader: 222 | # sample_name = line['record name'] 223 | # true_alarm = is_true_alarm(data_path, sample_name) 224 | 225 | # for author in authors: 226 | # classified_true_alarm = line[author] == '1' 227 | # matrix_classification = get_confusion_matrix_classification(true_alarm, classified_true_alarm) 228 | 229 | # others_confusion_matrices[author][matrix_classification].append(sample_name) 230 | 231 | # return others_confusion_matrices 232 | 233 | 234 | # filename = "sample_data/answers.csv" 235 | # others_confusion_matrices = generate_others_confusion_matrices(filename, data_path) 236 | 237 | 238 | # In[7]: 239 | 240 | # for author in others_confusion_matrices.keys(): 241 | # other_confusion_matrix = others_confusion_matrices[author] 242 | # print(author) 243 | # counts = get_counts(other_confusion_matrix) 244 | # evaluate.print_stats(counts) 245 | # print_by_type(other_confusion_matrix['FN']) 246 | 247 | 248 | # In[23]: 249 | 250 | # def get_missed(confusion_matrix, other_confusion_matrix, classification): 251 | # missed = [] 252 | 253 | # for sample in other_confusion_matrix[classification]: 254 | # if sample not in confusion_matrix[classification]: 255 | # missed.append(sample) 256 | 257 | # return missed 258 | 259 | # fplesinger_confusion_matrix = others_confusion_matrices['fplesinger-210'] 260 | # print("missed true positives: {}".format(get_missed(confusion_matrix_gqrs, fplesinger_confusion_matrix, "TP"))) 261 | # print("missed true negatives: {}".format(get_missed(confusion_matrix_gqrs, fplesinger_confusion_matrix, "TN"))) 262 | 263 | 264 | # In[ ]: 265 | -------------------------------------------------------------------------------- /pyfar/sigtypes: -------------------------------------------------------------------------------- 1 | BP ABP 2 | BP ABP 3 | BP ABP_1/2 4 | BP ABP_2/2 5 | BP ABPdias 6 | BP ABPDias 7 | BP ABPDiasTrend 8 | BP ABPmean 9 | BP ABPMean 10 | BP ABPMeanTrend 11 | BP ABPSQI 12 | BP ABPsys 13 | BP ABPSys 14 | BP ABPSysTrend 15 | BP AOBP_1/2 16 | BP ART 17 | BP ART 18 | BP ART 1 19 | BP ART1 20 | BP ART_1/2 21 | BP ART1dias 22 | BP ART1mean 23 | BP ART1sys 24 | BP ART_2/2 25 | BP ARTdias 26 | BP ART^M 27 | BP ARTmean 28 | BP ARTsys 29 | BP BP 30 | BP BP 31 | BP CPP 32 | BP CVP 33 | BP CVP_1/2 34 | BP CVP_1/3 35 | BP CVP_2/2 36 | BP CVP_2/3 37 | BP CVP 3 38 | BP CVP_3/3 39 | BP FAP 40 | BP FAP_1/3 41 | BP FAP_2/3 42 | BP FAP_3/3 43 | BP IC1 44 | BP IC1_1/3 45 | BP IC1_2/3 46 | BP IC1_3/3 47 | BP IC2 48 | BP IC2_1/3 49 | BP IC2_2/3 50 | BP 
IC2_3/3 51 | BP ICP_1/3 52 | BP ICP_2/3 53 | BP ICP_3/3 54 | BP LAP 55 | BP LAP_1/3 56 | BP LAP_2/3 57 | BP LAP_3/3 58 | BP LOC-ROC 59 | BP NBP_1/2 60 | BP NBP_2/2 61 | BP NBPdias 62 | BP NBPDias 63 | BP NBPmean 64 | BP NBPMean 65 | BP NBPsys 66 | BP NBPSys 67 | BP P4dias 68 | BP P4mean 69 | BP P4sys 70 | BP PA 71 | BP PAP 72 | BP PAP 73 | BP PAP_1/2 74 | BP PAP 2 75 | BP PAP_2/2 76 | BP PAPdias 77 | BP PAPDias 78 | BP PAPmean 79 | BP PAPMean 80 | BP PAPsys 81 | BP PAPSys 82 | BP PAWP 83 | BP PAWP_1/2 84 | BP PAWP_1/3 85 | BP PAWP_2/3 86 | BP PAWP_3/3 87 | BP Pressure 88 | BP Pressure 89 | BP Pressure1 90 | BP Pressure 1 91 | BP Pressure 2 92 | BP Pressure 3 93 | BP Pressure 4 94 | BP RAP_1/2 95 | BP RAP_1/3 96 | BP RAP_2/3 97 | BP RAP_3/3 98 | BP UAP 99 | BP UAP_1/2 100 | BP UAP_2/2 101 | BP UAPdias 102 | BP UAPmean 103 | BP UAPsys 104 | CO2 C02 105 | CO2 Co2 106 | CO2 CO2 107 | CO2 CO2^M 108 | CO2 CPCO2 109 | CO2 CPCO2_1/3 110 | CO2 CPCO2_2/3 111 | CO2 CPCO2_3/3 112 | CO CI 113 | CO CO 114 | CO C.O. 115 | CO CO_1/2 116 | CO CO_1/3 117 | CO CO_2/2 118 | CO CO_2/3 119 | CO CO_3/3 120 | ECG Abdomen_1 121 | ECG Abdomen_2 122 | ECG Abdomen_3 123 | ECG Abdomen_4 124 | ECG A-I 125 | ECG A-S 126 | ECG avf 127 | ECG aVF 128 | ECG AVF 129 | ECG AVF+ 130 | ECG avl 131 | ECG aVL 132 | ECG AVL 133 | ECG avr 134 | ECG aVR 135 | ECG AVR 136 | ECG CC5 137 | ECG chan 1 138 | ECG chan 2 139 | ECG chan 3 140 | ECG CM2 141 | ECG CM4 142 | ECG CM5 143 | ECG CS12 144 | ECG CS34 145 | ECG CS56 146 | ECG CS78 147 | ECG CS90 148 | ECG D3 149 | ECG D4 150 | ECG ECG 151 | ECG ECG 152 | ECG ECG0 153 | ECG ECG1 154 | ECG ECG 1 155 | ECG ECG 2 156 | ECG ECG 3 157 | ECG ECG 4 158 | ECG ECG AVF 159 | ECG ECG [ECG1] 160 | ECG ECG F 161 | ECG ECG I 162 | ECG ECG II 163 | ECG ECG III 164 | ECG ECG lead 1 165 | ECG ECG lead 2 166 | ECG ECG lead 3 167 | ECG ECG Lead AVF 168 | ECG ECG lead AVL 169 | ECG ECG lead I 170 | ECG ECG Lead I 171 | ECG ECG lead II 172 | ECG ECG lead II 173 | ECG ECG LeadII 174 | ECG ECG Lead II 175 | ECG ECG lead III 176 | ECG ECG Lead III 177 | ECG ECG lead V 178 | ECG ECG Lead V 179 | ECG ECG Lead V 180 | ECG ECG Lead V3 181 | ECG ECG Lead V4 182 | ECG ECG lead V5 183 | ECG ECG Lead V5 184 | ECG ECG lead V6 185 | ECG ECG signal 0 186 | ECG ECG signal 1 187 | ECG ECG V 188 | ECG ECG V3 189 | ECG ECG V Lead 190 | ECG EKG1-CHIN 191 | ECG EKG1-EKG2 192 | ECG E-S 193 | ECG i 194 | ECG I 195 | ECG I 196 | ECG I+ 197 | ECG ii 198 | ECG II 199 | ECG II 200 | ECG II+ 201 | ECG iii 202 | ECG III 203 | ECG III 204 | ECG III+ 205 | ECG lead I 206 | ECG lead II 207 | ECG lead V 208 | ECG MCL 209 | ECG MCL1 210 | ECG MCL1 211 | ECG MCL1+ 212 | ECG ML2 213 | ECG ML5 214 | ECG MLI 215 | ECG MLII 216 | ECG MLIII 217 | ECG mod.V1 218 | ECG MV2 219 | ECG MV2 220 | ECG V 221 | ECG V 222 | ECG V+ 223 | ECG v1 224 | ECG V1 225 | ECG V1-V2 226 | ECG v2 227 | ECG V2 228 | ECG V2-V3 229 | ECG v3 230 | ECG V3 231 | ECG v4 232 | ECG V4 233 | ECG V4-V5 234 | ECG v5 235 | ECG V5 236 | ECG v6 237 | ECG V6 238 | ECG VCGMAG 239 | ECG vx 240 | ECG vy 241 | ECG vz 242 | EEG Af3. 243 | EEG Af4. 244 | EEG Af7. 245 | EEG Af8. 246 | EEG Afz. 247 | EEG C1.. 248 | EEG C2 249 | EEG C2.. 250 | EEG C2-CS2 251 | EEG C3 252 | EEG C3.. 253 | EEG C3A2 254 | EEG C3-CS2 255 | EEG C3-P3 256 | EEG C4 257 | EEG C4.. 258 | EEG C4A1 259 | EEG C4-CS2 260 | EEG C4-P4 261 | EEG C5.. 262 | EEG C6 263 | EEG C6.. 264 | EEG C6-CS2 265 | EEG Cp1. 266 | EEG CP1-Ref 267 | EEG Cp2. 268 | EEG CP2 269 | EEG CP2-CS2 270 | EEG CP2-Ref 271 | EEG Cp3. 272 | EEG Cp4. 
273 | EEG CP4 274 | EEG CP4-CS2 275 | EEG Cp5. 276 | EEG CP5-Ref 277 | EEG Cp6. 278 | EEG CP6 279 | EEG CP6-CS2 280 | EEG CP6-Ref 281 | EEG Cpz. 282 | EEG Cz.. 283 | EEG CZ 284 | EEG CZ-CS2 285 | EEG CZ-PZ 286 | EEG EEG 287 | EEG EEG C3-A2 [C3A2] 288 | EEG EEG (C3-O1) 289 | EEG EEG (C4-A1) 290 | EEG EEG C4-A1 [C4A1] 291 | EEG EEG Fpz-Cz 292 | EEG EEG (O2-A1) 293 | EEG EEG Pz-Oz 294 | EEG EEG(sec) 295 | EEG F1.. 296 | EEG F2.. 297 | EEG F3 298 | EEG F3.. 299 | EEG F3-C3 300 | EEG F3-CS2 301 | EEG F4 302 | EEG F4.. 303 | EEG F4-C4 304 | EEG F4-CS2 305 | EEG F5.. 306 | EEG F6.. 307 | EEG F7 308 | EEG F7.. 309 | EEG F7-CS2 310 | EEG F7-T7 311 | EEG F8 312 | EEG F8.. 313 | EEG F8-CS2 314 | EEG F8-T8 315 | EEG Fc1. 316 | EEG FC1-Ref 317 | EEG Fc2. 318 | EEG FC2-Ref 319 | EEG Fc3. 320 | EEG Fc4. 321 | EEG Fc5. 322 | EEG FC5-Ref 323 | EEG Fc6. 324 | EEG FC6-Ref 325 | EEG Fcz. 326 | EEG Fp1. 327 | EEG FP1 328 | EEG FP1-CS2 329 | EEG FP1-F3 330 | EEG FP1-F7 331 | EEG Fp2. 332 | EEG FP2 333 | EEG FP2-CS2 334 | EEG FP2-F4 335 | EEG FP2-F8 336 | EEG Fpz. 337 | EEG FT10-T8 338 | EEG Ft7. 339 | EEG Ft8. 340 | EEG FT9-FT10 341 | EEG Fz.. 342 | EEG FZ 343 | EEG FZ-CS2 344 | EEG FZ-CZ 345 | EEG Iz.. 346 | EEG LUE-RAE 347 | EEG O1.. 348 | EEG O1-CS2 349 | EEG O2 350 | EEG O2.. 351 | EEG O2-CS2 352 | EEG Oz.. 353 | EEG P1.. 354 | EEG P2.. 355 | EEG P3 356 | EEG P3.. 357 | EEG P3-CS2 358 | EEG P3-O1 359 | EEG P4 360 | EEG P4.. 361 | EEG P4-CS2 362 | EEG P4-O2 363 | EEG P5.. 364 | EEG P6.. 365 | EEG P7 366 | EEG P7.. 367 | EEG P7-CS2 368 | EEG P7-O1 369 | EEG P7-T7 370 | EEG P8 371 | EEG P8.. 372 | EEG P8-CS2 373 | EEG P8-O2 374 | EEG Po3. 375 | EEG Po4. 376 | EEG Po7. 377 | EEG Po8. 378 | EEG Poz. 379 | EEG Pz.. 380 | EEG PZ 381 | EEG PZ-CS2 382 | EEG PZ-OZ 383 | EEG T10. 384 | EEG T7 385 | EEG T7.. 386 | EEG T7-CS2 387 | EEG T7-FT9 388 | EEG T7-P7 389 | EEG T8 390 | EEG T8.. 391 | EEG T8-CS2 392 | EEG T8-P8 393 | EEG T9.. 394 | EEG Tp7. 395 | EEG Tp8. 
396 | EMG EHG1 397 | EMG EHG2 398 | EMG EHG3 399 | EMG EHG4 400 | EMG EHG5 401 | EMG EHG6 402 | EMG EHG7 403 | EMG EHG8 404 | EMG EHG9 405 | EMG EHG10 406 | EMG EHG11 407 | EMG EHG12 408 | EMG EHG13 409 | EMG EHG14 410 | EMG EHG15 411 | EMG EHG16 412 | EMG EMG 413 | EMG EMG-Chin [EMYG] 414 | EMG EMG submental 415 | EMG EMG Submental 416 | EMG Left leg 417 | EMG Right leg 418 | EMG UC 419 | EOG EOG 420 | EOG EOG E1-A1 [EOGL] 421 | EOG EOG E2-A1 [EOGR] 422 | EOG EOG horizontal 423 | EOG EOG(L) 424 | EOG EOG(R) 425 | EOG EOG (right) 426 | EOG Lefteye 427 | EOG RightEye 428 | EP ABR 429 | HR FHR 430 | Flow Flow 431 | HR Ectopic Count 432 | HR HR 433 | HR HR_1/2 434 | HR HR_1/3 435 | HR HR_2/2 436 | HR HR_2/3 437 | HR HR_3/3 438 | HR HRABP 439 | HR HRABPo 440 | HR HRABPsh 441 | HR HRABPSQI 442 | HR HRECG 443 | HR HRECGSQI 444 | HR HRepltd 445 | HR HRSH1 446 | HR HRSH1SQI 447 | HR HRSHm 448 | HR HRSHmSQI 449 | HR HRSQI 450 | HR HRTrend 451 | HR HRwqrs 452 | HR PR 453 | HR Pulse 454 | HR PULSE 455 | HR PULSE_1/2 456 | HR PULSE_1/3 457 | HR PULSE_2/2 458 | HR PULSE_2/3 459 | HR PULSE_3/3 460 | HR PVC Rate per Minute 461 | HR PVC Rate per Minute_1/2 462 | HR PVC Rate per Minute_1/3 463 | HR PVC Rate per Minute_2/2 464 | HR PVC Rate per Minute_2/3 465 | HR PVC Rate per Minute_3/3 466 | Noise BW noise, signal 0 467 | Noise BW noise, signal 1 468 | Noise EM noise, signal 0 469 | Noise EM noise, signal 1 470 | Noise MA noise, signal 0 471 | Noise MA noise, signal 1 472 | O2 CPO2 473 | O2 dSpO2 474 | O2 SaO2 475 | O2 SaO2 [OSAT] 476 | O2 SO2 477 | O2 SpO2 478 | O2 SpO2_1/2 479 | O2 SpO2_1/3 480 | O2 SpO2_2/2 481 | O2 SpO2_2/3 482 | O2 SpO2_3/3 483 | O2 SpO2_AP 484 | O2 SpO2 Aperiodic 485 | O2 SpO2 L 486 | O2 SpO2 R 487 | PLETH PLETH 488 | PLETH PLETH 489 | PLETH PLETH L 490 | PLETH PLETH R 491 | Pos BodyPos 492 | Pos Position 493 | Resp abdo 494 | Resp Abdomen [ABMV] 495 | Resp ABDO RES 496 | Resp AIRFLOW 497 | Resp AWRR 498 | Resp ETCO2 499 | Resp Flow [AFLO] 500 | Resp IMCO2 501 | Resp RESP 502 | Resp RESP 503 | Resp RESP 504 | Resp RESP_1/2 505 | Resp RESP_1/3 506 | Resp RESP_2/2 507 | Resp RESP_2/3 508 | Resp RESP_3/3 509 | Resp Resp A 510 | Resp Resp (abdomen) 511 | Resp Resp (abdominal) 512 | Resp Resp C 513 | Resp Resp (chest) 514 | Resp Resp Imp 515 | Resp Resp. Imp. 516 | Resp Resp.Imp. 517 | Resp RESP IMP 518 | Resp Resp Inp. 
519 | Resp Resp N 520 | Resp Resp (nasal) 521 | Resp Resp oro-nasal 522 | Resp Resp (sum) 523 | Resp ribcage 524 | Resp Sum 525 | Resp Thorax_1 526 | Resp Thorax_2 527 | Resp Thorax [CHMV] 528 | Resp THOR RES 529 | SCG SCG (seismocardiogram) 530 | Sound OAE 531 | Sound Soud 532 | Sound Sound 533 | Status CO2 off 534 | Status CVP off 535 | Status ECG No signal 536 | Status Ectopic Status 537 | Status Ectopic Status_1/2 538 | Status Ectopic Status_1/3 539 | Status Ectopic Status_2/2 540 | Status Ectopic Status_2/3 541 | Status Ectopic Status_3/3 542 | Status EDF Annotations 543 | Status Event marker 544 | Status Hypnogram 545 | Status ID+Sync+Error 546 | Status No CO2 547 | Status No ECG 548 | Status No ECG 549 | Status No PAP 550 | Status No Pressure 551 | Status no signal 552 | Status No signal 553 | Status No signal 554 | Status No Signal 555 | Status Nothing 556 | Status OFF 557 | Status OFF 558 | Status PAP off 559 | Status Rhythm Status 560 | Status Rhythm Status_1/2 561 | Status Rhythm Status_1/3 562 | Status Rhythm Status_2/2 563 | Status Rhythm Status_2/3 564 | Status Rhythm Status_3/3 565 | Stim VNS 566 | ST ST1 567 | ST ST2 568 | ST ST3 569 | ST ST_I 570 | ST ST_II 571 | ST ST_III 572 | ST ST_AVF 573 | ST ST_AVL 574 | ST ST_AVR 575 | ST ST_V1 576 | ST ST_V2 577 | ST ST_V3 578 | ST ST_V4 579 | ST ST_V5 580 | ST ST_V6 581 | ST ST AVF 582 | ST ST AVF_1/2 583 | ST ST AVF_1/3 584 | ST ST AVF_2/2 585 | ST ST AVF_2/3 586 | ST ST AVF_3/3 587 | ST ST AVL 588 | ST ST AVL_1/2 589 | ST ST AVL_1/3 590 | ST ST AVL_2/2 591 | ST ST AVL_2/3 592 | ST ST AVL_3/3 593 | ST ST AVR 594 | ST ST AVR_1/2 595 | ST ST AVR_1/3 596 | ST ST AVR_2/2 597 | ST ST AVR_2/3 598 | ST ST AVR_3/3 599 | ST ST I 600 | ST ST I_1/2 601 | ST ST I_1/3 602 | ST ST I_2/2 603 | ST ST I_2/3 604 | ST ST I_3/3 605 | ST ST II 606 | ST ST II_1/2 607 | ST ST II_1/3 608 | ST ST II_2/3 609 | ST ST II_3/3 610 | ST ST III 611 | ST ST III_1/2 612 | ST ST III_1/3 613 | ST ST III_2/2 614 | ST ST III_2/3 615 | ST ST III_3/3 616 | ST ST MCL 617 | ST ST MCL_1/2 618 | ST ST MCL_1/3 619 | ST ST MCL_2/2 620 | ST ST MCL_2/3 621 | ST ST MCL_3/3 622 | ST ST V 623 | ST ST V_1/2 624 | ST ST V_1/3 625 | ST ST V_2/2 626 | ST ST V_2/3 627 | ST ST V_3/3 628 | ST ST V2 629 | ST ST V5 630 | SV SV 631 | Temp BLOODT 632 | Temp BLOODT_1/2 633 | Temp BLOODT_1/3 634 | Temp BLOODT_2/2 635 | Temp BLOODT_2/3 636 | Temp BLOODT_3/3 637 | Temp Tblood 638 | Temp TEMP 639 | Temp Temp Art 640 | Temp Temp body 641 | Temp Temp Core 642 | Temp TEMPDIFF 643 | Temp Temp Esoph 644 | Temp Temp Nasal 645 | Temp Temp Skin 646 | Temp Temp Rect 647 | Temp Temp Vent 648 | -------------------------------------------------------------------------------- /pyfar/ventricular_beat_bank.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from fastdtw import fastdtw 4 | from scipy.spatial.distance import euclidean 5 | from datetime import datetime 6 | from utils import * 7 | from parameters import * 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import wfdb 11 | import csv 12 | import os 13 | 14 | 15 | AVERAGE_START_DIFF = 0.25 16 | AVERAGE_END_DIFF = 0.35 17 | 18 | headers = [ 19 | 'num', 20 | 'sample_name', 21 | 'arrhythmia', 22 | 'is_true_beat', 23 | 'start_time', 24 | 'end_time', 25 | 'comments' 26 | ] 27 | 28 | def generate_training(filename): 29 | training = [] 30 | 31 | with open(filename, 'r') as f: 32 | reader = csv.DictReader(f) 33 | 34 | for row in reader: 35 | sample_name = 
row['sample_name'] 36 | is_true_beat = int(row['is_vtach']) == 1 37 | 38 | start_time = float(row['start_time']) 39 | end_time = float(row['end_time']) 40 | 41 | # peak_time = float(row['peak_time']) 42 | # start_time = peak_time - AVERAGE_START_DIFF 43 | # end_time = peak_time + AVERAGE_END_DIFF 44 | 45 | sig, fields = wfdb.srdsamp(data_path + sample_name) 46 | start_index = int(start_time*250.) 47 | end_index = int(end_time*250.) 48 | channel_index = fields['signame'].index(row['lead']) 49 | 50 | beat_sig = sig[start_index:end_index,channel_index] 51 | 52 | training.append((beat_sig, is_true_beat, sample_name)) 53 | 54 | return training 55 | 56 | def get_self_beats(channel_sig, annotation, sample_name): 57 | ### TODO: add quality check before adding training beats 58 | 59 | training_beats = [] 60 | 61 | for ann_index in range(1, len(annotation)-1): 62 | start_index = (annotation[ann_index-1] + annotation[ann_index]) / 2 63 | end_index = (annotation[ann_index] + annotation[ann_index+1]) / 2 64 | beat_sig = channel_sig[start_index:end_index] 65 | 66 | training_beats.append((beat_sig, False, sample_name)) 67 | 68 | if len(training_beats) >= 5: 69 | break 70 | 71 | return training_beats 72 | 73 | 74 | def normalize_sig(sig): 75 | return (sig - np.mean(sig)) / np.std(sig) 76 | 77 | def is_ventricular_beat(beat_sig, training_beats): 78 | # Euclidean distance between beat sig and a flat line to represent noise --> not a ventricular beat 79 | min_distance = sum([val**2 for val in normalize_sig(beat_sig)]) 80 | classification = False 81 | min_sample_name = "" 82 | min_training_beat = [] 83 | 84 | # figure_num = 1 85 | # plt.figure(figsize=[12, 12]) 86 | 87 | for beat_tuple in training_beats: 88 | training_beat = beat_tuple[0] 89 | is_true_beat = beat_tuple[1] 90 | sample_name = beat_tuple[2] 91 | 92 | training_beat_normalized = normalize_sig(training_beat) 93 | beat_sig_normalized = normalize_sig(beat_sig) 94 | 95 | # if len(training_beat) > len(beat_sig): 96 | # training_beat_normalized = normalize_sig(training_beat[:len(beat_sig)]) 97 | 98 | # elif len(beat_sig) > len(training_beat): 99 | # beat_sig_normalized = normalize_sig(beat_sig[:len(training_beat)]) 100 | # distance = sum([val**2 for val in (training_beat_normalized - beat_sig_normalized)]) 101 | 102 | try: 103 | distance, path = fastdtw(beat_sig_normalized, training_beat_normalized, radius=250, dist=euclidean) 104 | except Exception as e: 105 | distance = float('inf') 106 | print("Error with training sample: {}".format(sample_name)) 107 | print(e) 108 | 109 | # print(sample_name, distance, is_true_beat) 110 | # plt.subplot(9, 5, figure_num) 111 | # plt.title(str(int(distance)) + " " + str(is_true_beat)) 112 | # plt.plot(training_beat_normalized, 'b-') 113 | # plt.axis('off') 114 | # plt.plot(beat_sig_normalized, 'r-') 115 | # figure_num += 1 116 | 117 | if distance < min_distance: 118 | min_distance = distance 119 | classification = is_true_beat 120 | min_sample_name = sample_name 121 | min_training_beat = training_beat_normalized 122 | 123 | # print("min: ", min_sample_name, min_distance, classification) 124 | # plt.show() 125 | 126 | # if classification: 127 | # plt.figure() 128 | # plt.plot(min_training_beat, 'b-') 129 | # plt.plot(beat_sig_normalized, 'r-') 130 | # plt.title(min_sample_name + " " + str(int(min_distance)) + " " + str(classification)) 131 | # plt.show() 132 | 133 | return classification 134 | 135 | 136 | def get_ventricular_beats(beats, training_beats): 137 | ventricular_beats = [] 138 | nonventricular_beats = [] 139 
| 140 | for beat in beats: 141 | beat_sig = beat[2] 142 | 143 | if is_ventricular_beat(beat_sig, training_beats): 144 | ventricular_beats.append(beat) 145 | else: 146 | nonventricular_beats.append(beat) 147 | 148 | return ventricular_beats, nonventricular_beats 149 | 150 | 151 | ## 152 | # Returns beats (list of tuples): 153 | # annotation of beat QRS 154 | # start and end indices 155 | # sig of beat 156 | ## 157 | def get_beats(channel_sig, annotation): 158 | beats = [] 159 | for ann_index in range(1, len(annotation)-1): 160 | # Assumes a beat starts halfway between annotations and ends halfway between annotations 161 | start_index = (annotation[ann_index-1] + annotation[ann_index]) / 2 162 | end_index = (annotation[ann_index] + annotation[ann_index+1]) / 2 163 | 164 | # Assumes a beat starts 250ms before the annotation and ends 250 ms after the annotation 165 | # start_index = annotation[ann_index] - int(AVERAGE_START_DIFF * 250.) 166 | # end_index = annotation[ann_index] + int(AVERAGE_END_DIFF * 250.) + 1 167 | 168 | indices = (start_index, end_index) 169 | beat_sig = channel_sig[indices[0]:indices[1]] 170 | beat = (annotation[ann_index], indices, beat_sig) 171 | 172 | beats.append(beat) 173 | 174 | return beats 175 | 176 | def ventricular_beat_annotations_dtw( 177 | channel_sig, 178 | ann_path, 179 | sample_name, 180 | start_time, 181 | end_time, 182 | ann_type, 183 | ann_fs=250., 184 | training_filename="vtach_beats.csv"): 185 | 186 | training_beats = generate_training(training_filename) 187 | 188 | annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, start_time, end_time).annsamp 189 | full_annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, 0, start_time).annsamp 190 | sample_training_beats = get_self_beats(channel_sig, full_annotation, sample_name) 191 | 192 | beats = get_beats(channel_sig, annotation) 193 | ventricular_beats, nonventricular_beats = get_ventricular_beats(beats, sample_training_beats + training_beats) 194 | # ventricular_beats, nonventricular_beats = get_ventricular_beats(beats, training_beats) 195 | 196 | ventricular_beat_annotations = [ beat[0] for beat in ventricular_beats ] 197 | nonventricular_beat_annotations = [ beat[0] for beat in nonventricular_beats ] 198 | 199 | return ventricular_beat_annotations, nonventricular_beat_annotations 200 | 201 | 202 | def write_vtach_beats_files(data_path, ann_path, output_path, ecg_ann_type, start_time, end_time): 203 | for filename in os.listdir(data_path): 204 | if filename.endswith(HEADER_EXTENSION): 205 | sample_name = filename.rstrip(HEADER_EXTENSION) 206 | 207 | if sample_name[0] != 'v': 208 | continue 209 | 210 | sig, fields = wfdb.srdsamp(data_path + sample_name) 211 | if "II" not in fields['signame']: 212 | print("Lead II not found for sample: {}".format(sample_name)) 213 | continue 214 | 215 | output_filename = output_path + sample_name + ".csv" 216 | 217 | if os.path.isfile(output_filename): 218 | continue 219 | 220 | channel_index = fields['signame'].index("II") 221 | ann_type = ecg_ann_type + str(channel_index) 222 | 223 | start = datetime.now() 224 | 225 | with open(output_filename, "w") as f: 226 | channel_sig = sig[:,channel_index] 227 | 228 | vtach_beats, nonvtach_beats = ventricular_beat_annotations_dtw(channel_sig, ann_path, sample_name, start_time, end_time, ann_type) 229 | 230 | writer = csv.writer(f) 231 | writer.writerow(['ann_index', 'is_true_beat']) 232 | 233 | for beat in vtach_beats: 234 | writer.writerow([beat, 1]) 235 | for beat in nonvtach_beats: 236 | 
writer.writerow([beat, 0]) 237 | 238 | print("sample_name: {}".format(sample_name)) 239 | print(" elapsed: {}".format(datetime.now() - start)) 240 | 241 | 242 | if __name__ == '__main__': 243 | # sample_name = "v127l" 244 | # sample_name = "v141l" 245 | # ecg_ann_type = 'gqrs' 246 | # start_time = 296 247 | # end_time = 300 248 | # channel_index = 0 249 | # ann_fs = 250. 250 | # ann_type = 'gqrs' + str(channel_index) 251 | 252 | # sig, fields = wfdb.srdsamp(data_path + sample_name) 253 | # channel_sig = sig[:,channel_index] 254 | 255 | # vtach_beats, nonvtach_beats = ventricular_beat_annotations_dtw(channel_sig, ann_path, sample_name, start_time, end_time, ann_type) 256 | # vtach_indices = [ ann - start_time * 250. for ann in vtach_beats ] 257 | # nonvtach_indices = [ ann - start_time * 250. for ann in nonvtach_beats ] 258 | 259 | # plt.figure(figsize=[8,5]) 260 | # plt.plot(channel_sig[int(start_time*250.):int(end_time*250.)],'b-') 261 | # plt.plot(nonvtach_indices, [channel_sig[int(index)] for index in nonvtach_indices], 'bo', markersize=8) 262 | # plt.plot(vtach_indices, [ channel_sig[int(index)] for index in vtach_indices ], 'ro', markersize=8) 263 | # plt.show() 264 | 265 | start_time = 290 266 | end_time = 300 267 | write_vtach_beats_files(data_path, ann_path, output_path_bank, ecg_ann_type, start_time, end_time) 268 | 269 | 270 | # sig, fields = wfdb.srdsamp(data_path + sample_name) 271 | # channel_sig = sig[:,channel_index] 272 | 273 | # annotation = wfdb.rdann(ann_path + sample_name, ann_type, sampfrom=start*ann_fs, sampto=end*ann_fs)[0] 274 | # print(annotation) 275 | 276 | # beats = get_beats(channel_sig, annotation) 277 | 278 | 279 | # for beat in beats: 280 | # indices = beat[0] 281 | # beat_sig = beat[1] 282 | # time_vector = np.linspace(indices[0], indices[1], len(beat_sig)) 283 | 284 | # whole_sig = channel_sig[250*start:250*end] 285 | # sig_time_vector = np.linspace(250*start, 250*end, len(whole_sig)) 286 | 287 | # annotation_y = [ channel_sig[ann_t] for ann_t in annotation ] 288 | 289 | # plt.figure() 290 | # plt.plot(sig_time_vector, whole_sig, 'b') 291 | # plt.plot(time_vector, beat_sig, 'r') 292 | # plt.plot(annotation, annotation_y, 'go') 293 | # plt.show() 294 | 295 | 296 | 297 | # print("") 298 | # print(annotation[0] / float(250.)) 299 | -------------------------------------------------------------------------------- /pyfar/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from spectrum import * 4 | from scipy.stats import kurtosis 5 | from sklearn.linear_model import LogisticRegression, LassoCV 6 | from sklearn.metrics import auc, roc_curve 7 | from datetime import datetime 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import csv 11 | import wfdb 12 | 13 | data_path = "sample_data/challenge_training_data/" 14 | answers_filename = "sample_data/answers.csv" 15 | features_filename = "sample_data/features.csv" 16 | 17 | start_time = 290 18 | end_time = 300 19 | fs = 250. 
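# Example sketch: computing the four per-channel signal-quality features
# defined below on a single 10-second ECG window. The record name "v131l"
# is illustrative only; any challenge record under data_path should work.
# Uncomment to run after the definitions below are loaded.
#
# sig, fields = wfdb.rdsamp(data_path + "v131l")
# channel_subsig = sig[int(start_time * fs):int(end_time * fs), fields['signame'].index("II")]
# print(get_baseline(channel_subsig), get_power(channel_subsig),
#       get_ksqi(channel_subsig), get_pursqi(channel_subsig))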
20 | 
21 | TRAINING_THRESHOLD = 600
22 | 
23 | 
24 | def get_psd(channel_subsig, order, nfft):
25 |     channel_subsig = channel_subsig - np.mean(channel_subsig)
26 |     ar, rho, ref = arburg(channel_subsig, order)
27 |     psd = arma2psd(ar, rho=rho, NFFT=nfft)
28 |     psd = psd[len(psd):len(psd)/2:-1]  # keep one half of the two-sided PSD, reversed into ascending frequency order
29 | 
30 |     # plt.figure()
31 |     # plt.plot(linspace(0, 1, len(psd)), abs(psd)*2./(2.*np.pi))
32 |     # plt.title('PSD')
33 |     # plt.ylabel('Log of PSD')
34 |     # plt.xlabel('Normalized Frequency')
35 |     # plt.show()
36 | 
37 |     # print(len(psd))
38 | 
39 |     return psd
40 | 
41 | 
42 | def get_baseline(channel_subsig, order=30, nfft=1024):
43 |     psd = get_psd(channel_subsig, order, nfft)
44 | 
45 |     numerator_min_freq = int(0 * nfft / 125.)  # Hz-to-bin mapping assumes the 250 Hz challenge sampling rate (125 Hz Nyquist)
46 |     numerator_max_freq = int(1 * nfft / 125.)
47 |     denominator_min_freq = int(0 * nfft / 125.)
48 |     denominator_max_freq = int(40 * nfft / 125.)
49 | 
50 |     numerator = sum(psd[numerator_min_freq:numerator_max_freq+1])
51 |     denominator = sum(psd[denominator_min_freq:denominator_max_freq+1])
52 | 
53 |     baseline = float(numerator) / denominator
54 |     return 1 - baseline
55 | 
56 | 
57 | def get_power(channel_subsig, order=30, nfft=1024):
58 |     psd = get_psd(channel_subsig, order, nfft)
59 | 
60 |     numerator_min_freq = int(5 * nfft / 125.)
61 |     numerator_max_freq = int(15 * nfft / 125.)
62 |     denominator_min_freq = int(5 * nfft / 125.)
63 |     denominator_max_freq = int(40 * nfft / 125.)
64 | 
65 |     numerator = sum(psd[numerator_min_freq:numerator_max_freq+1])
66 |     denominator = sum(psd[denominator_min_freq:denominator_max_freq+1])
67 | 
68 |     power = float(numerator) / denominator
69 |     return power
70 | 
71 | 
72 | def get_ksqi(channel_subsig):
73 |     ksqi = kurtosis(channel_subsig)
74 | 
75 |     if abs(ksqi) >= 25:
76 |         return 25
77 |     return ksqi
78 | 
79 | 
80 | def get_pursqi(channel_subsig):
81 |     s = channel_subsig
82 |     sd = np.diff(channel_subsig)
83 |     sdd = np.zeros(len(channel_subsig))
84 |     for i in range(len(channel_subsig)):
85 |         if i == 0:
86 |             sdd[i] = channel_subsig[2] - 2*channel_subsig[1] + channel_subsig[0]
87 | 
88 |         elif i == len(channel_subsig) - 1:
89 |             sdd[i] = channel_subsig[-1] - 2*channel_subsig[-2] + channel_subsig[-3]
90 | 
91 |         else:
92 |             sdd[i] = channel_subsig[i+1] - 2*channel_subsig[i] + channel_subsig[i-1]
93 | 
94 |     w0 = (2*np.pi/len(s))*sum(np.square(s))    # 2pi E[s^2] = 2pi Rs(0)
95 |     w2 = (2*np.pi/len(s))*sum(np.square(sd))   # 2pi Ts^2 E[sd^2]
96 |     w4 = (2*np.pi/len(s))*sum(np.square(sdd))  # 2pi Ts^4 E[sdd^2]
97 | 
98 |     pursqi = (w2**2)/(w0*w4)
99 |     return pursqi
100 | 
101 | 
102 | def get_channel_type(channel_name):
103 |     channel_types_dict = {}
104 |     with open("sample_data/sigtypes", "r") as f:
105 |         for line in f:
106 |             splitted_line = line.split("\t")
107 |             channel = splitted_line[-1].rstrip()
108 |             channel_type = splitted_line[0]
109 | 
110 |             if channel_name == channel:
111 |                 return channel_type
112 | 
113 |     raise Exception("Unknown channel name")
114 | 
115 | 
116 | # Return list of channel indices for channels of type channel_type
117 | def get_channels_of_type(channels, channel_type):
118 |     channel_indices = np.array([])
119 | 
120 |     for channel_index in range(len(channels)):
121 |         channel_name = channels[channel_index]
122 |         if channel_type == get_channel_type(channel_name):
123 |             channel_indices = np.append(channel_indices, channel_index)
124 | 
125 |     return channel_indices
126 | 
127 | 
128 | # x includes sample names --> exclude for classification
129 | # training = sample num < 600
130 | # testing = sample num > 600
131 | def generate_features(features_filename):
132 |     training_x, 
training_y = [], [] 133 | testing_x, testing_y = [], [] 134 | 135 | with open(features_filename, 'w') as fo: 136 | writer = csv.writer(fo) 137 | writer.writerow(['sample_name', 'is_training', 'is_true', 'baseline', 'dtw', 'psd', 'power', 'ksqi', 'pursqi']) 138 | 139 | with open(answers_filename, 'r') as f: 140 | reader = csv.reader(f) 141 | headers = reader.next() 142 | 143 | reader = csv.DictReader(f, fieldnames=headers) 144 | 145 | for row in reader: 146 | sample_name = row['sample_name'] 147 | sample_number = sample_name[1:-1] 148 | 149 | sig, fields = wfdb.rdsamp(data_path + sample_name) 150 | subsig = sig[int(start_time*fs):int(end_time*fs),:] 151 | ecg_channels = get_channels_of_type(fields['signame'], "ECG") 152 | 153 | if len(ecg_channels) == 0: 154 | print("NO ECG CHANNELS FOR SAMPLE: {}".format(sample_name)) 155 | continue 156 | 157 | channel_subsig = subsig[:,int(ecg_channels[0])] 158 | 159 | try: 160 | baseline = get_baseline(channel_subsig) 161 | power = get_power(channel_subsig) 162 | ksqi = get_ksqi(channel_subsig) 163 | pursqi = get_pursqi(channel_subsig) 164 | 165 | except Exception as e: 166 | print("sample_name: {}\n{}".format(sample_name, e)) 167 | continue 168 | 169 | if np.isnan([baseline, power, ksqi, pursqi]).any(): 170 | print("sample containing nan: {}\n{}".format(sample_name, [baseline, power, ksqi, pursqi])) 171 | continue 172 | 173 | if int(sample_number) < TRAINING_THRESHOLD: 174 | is_training = 1 175 | else: 176 | is_training = 0 177 | 178 | x_val = [ 179 | row['sample_name'], 180 | is_training, 181 | int(row['is_true']), 182 | int(row['baseline_is_classified_true']), 183 | int(row['dtw_is_classified_true']), 184 | baseline, 185 | power, 186 | ksqi, 187 | pursqi 188 | ] 189 | 190 | writer.writerow(x_val) 191 | 192 | def generate_datasets(features_filename): 193 | training_x, training_y, testing_x, testing_y = [], [], [], [] 194 | 195 | with open(features_filename, 'r') as f: 196 | reader = csv.reader(f) 197 | headers = reader.next() 198 | 199 | reader = csv.DictReader(f, fieldnames=headers) 200 | 201 | for row in reader: 202 | x_val = [ 203 | int(row['baseline']), 204 | int(row['dtw']), 205 | float(row['psd']), 206 | float(row['power']), 207 | float(row['ksqi']), 208 | float(row['pursqi']) 209 | ] 210 | y_val = int(row['is_true']) 211 | 212 | if int(row['is_training']) == 1 and row['sample_name'][0] == 'v': 213 | training_x.append(x_val) 214 | training_y.append(y_val) 215 | elif row['sample_name'][0] == 'v': 216 | testing_x.append(x_val) 217 | testing_y.append(y_val) 218 | return training_x, training_y, testing_x, testing_y 219 | 220 | 221 | def get_score(prediction, true): 222 | 223 | TP = np.sum([(prediction[i] == 1) and (true[i] == 1) for i in range(len(prediction))]) 224 | TN = np.sum([(prediction[i] == 0) and (true[i] == 0) for i in range(len(prediction))]) 225 | FP = np.sum([(prediction[i] == 1) and (true[i] == 0) for i in range(len(prediction))]) 226 | FN = np.sum([(prediction[i] == 0) and (true[i] == 1) for i in range(len(prediction))]) 227 | 228 | # print('{} {} {} {}'.format(TP, TN, FP, FN)) 229 | 230 | numerator = TP + TN 231 | denominator = FP + 5*FN + numerator 232 | 233 | return float(numerator) / denominator 234 | 235 | 236 | 237 | if __name__ == '__main__': 238 | print("Nothing to do!") 239 | # print("Generating datasets...") 240 | # generate_features(features_filename) 241 | # training_x, training_y, testing_x, testing_y = generate_datasets(features_filename) 242 | 243 | # print("{} {}".format(len(training_y), len(testing_y))) 244 | 245 | 246 
| # # start = datetime.now()
247 | # # print("Starting at {}".format(start))
248 | # # print("Generating datasets...")
249 | # # training_x, training_y, testing_x, testing_y = generate_training_testing()
250 | 
251 | 
252 | 
253 | # print("Running classifier...")
254 | # classifier = LogisticRegression(penalty='l1')
255 | # lasso = LassoCV()
256 | # classifier.fit(training_x, training_y)
257 | 
258 | # # probability of class 1 (versus 0)
259 | # # predictions_y = classifier.predict_proba(testing_x)[:,1]
260 | # # score = classifier.score(testing_x, testing_y)
261 | 
262 | # # fpr, tpr, thresholds = roc_curve(testing_y, predictions_y)
263 | # # auc = auc(fpr, tpr)
264 | 
265 | # # print("auc: {}".format(auc))
266 | # # print("score: {}".format(score))
267 | # # print("fpr: {}".format(fpr), end=" ")
268 | # # print("tpr: {}".format(tpr))
269 | 
270 | # # plt.figure()
271 | # # plt.title("ROC curve for DTW-only classifier")
272 | # # plt.xlabel("False positive rate")
273 | # # plt.ylabel("True positive rate")
274 | # # plt.plot(fpr, tpr)
275 | # # plt.show()
276 | # lasso.fit(training_x, training_y)
277 | # predictions_y = lasso.predict(testing_x)
278 | 
279 | # fpr, tpr, thresholds = roc_curve(testing_y, predictions_y)
280 | 
281 | # chall_score = list()
282 | # for th in thresholds:
283 | #     chall_score.append(get_score([x >= th for x in predictions_y], testing_y))
284 | 
285 | 
286 | # auc = auc(fpr, tpr)
287 | 
288 | # print(classifier.coef_)
289 | # print("auc: {}".format(auc))
290 | # print("score: {}".format(score))
291 | # print("fpr: {}".format(fpr))
292 | # print("tpr: {}".format(tpr))
293 | 
294 | 
295 | # plt.figure()
296 | # plt.title("ROC curve for top-level classifier with challenge scores")
297 | # plt.xlabel("False positive rate")
298 | # plt.ylabel("True positive rate")
299 | # plt.plot(fpr, tpr, label='ROC Curve')
300 | # plt.plot(fpr, chall_score, label='Challenge score')
301 | # plt.show()
302 | 
303 | # # DTW only
304 | # # auc: 0.461675144589
305 | # # score: 0.529166666667
306 | 
307 | # # Baseline only
308 | # # auc: 0.877012054909
309 | # # score: 0.875
310 | 
311 | # # Combined
312 | # # auc: 0.910041112118
313 | # # score: 0.841666666667
314 | 
--------------------------------------------------------------------------------
/pyfar/dtw.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | from utils import *
4 | from parameters import *
5 | from datetime import datetime
6 | from scipy.signal import resample
7 | from scipy.spatial.distance import euclidean
8 | from scipy.stats.mstats import zscore
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import sklearn
12 | import fastdtw
13 | import wfdb
14 | import json
15 | import os
16 | import glob
17 | import csv
18 | 
19 | def read_signals(data_path):
20 |     signals_dict = {}
21 |     fields_dict = {}
22 | 
23 |     for filename in os.listdir(data_path):
24 |         if filename.endswith(HEADER_EXTENSION):
25 |             sample_name = filename.rstrip(HEADER_EXTENSION)
26 | 
27 |             sig, fields = wfdb.srdsamp(data_path + sample_name)
28 | 
29 |             signals_dict[sample_name] = sig
30 |             fields_dict[sample_name] = fields
31 | 
32 |     return signals_dict, fields_dict
33 | 
34 | 
35 | def get_data(sig_dict, fields_dict, num_training):
36 |     training_keys = list(sig_dict.keys())[:num_training]
37 |     testing_keys = list(sig_dict.keys())[num_training:]
38 | 
39 |     sig_training = { key : sig_dict[key] for key in training_keys }
40 |     fields_training = { key : fields_dict[key] for key in training_keys }
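# Note: the train/test split simply takes the first num_training records in
# dict key order (arbitrary in Python 2) as training data; the remaining
# records are held out for testing below.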
41 |     sig_testing = { key : sig_dict[key] for key in testing_keys }
42 |     fields_testing = { key : fields_dict[key] for key in testing_keys }
43 | 
44 |     return sig_training, fields_training, sig_testing, fields_testing
45 | 
46 | 
47 | def downsample_signal(sig, fields, Fnew=125):
48 |     Fs = fields['fs']
49 | 
50 |     # downsample if needed
51 |     if Fnew < Fs:
52 |         # Fourier-domain resampling to the target rate (scipy.signal.resample, imported above)
53 |         sig_new = resample(sig, int(len(sig) * Fnew / float(Fs)))
54 |     else:
55 |         sig_new = sig
56 |         print('{} is higher than sampling frequency ({}) - not resampling.'.format(Fnew, Fs))
57 |     return sig_new
58 | 
59 | 
60 | def alt_dtw():  # note: relies on module-level signals_dict/fields_dict and the external mlpy package (mlpy.dtw_std)
61 |     T_START=290
62 |     T_END=300
63 |     Fnew=125
64 | 
65 |     for s in fields_dict:
66 |         print("\n\n" + s)
67 |         for j, lead in enumerate(fields_dict[s]['signame']):
68 |             # get signal lead name
69 |             if lead in ['I','II','III','V']:
70 |                 # print('\t' + lead, end=' ')
71 |                 sig1 = (np.copy(signals_dict[s][:,j])*fields_dict[s]['gain'][j]).astype(int)
72 | 
73 |                 # downsample the signal
74 |                 sig1 = downsample_signal(sig1, fields_dict[s], Fnew=Fnew)
75 | 
76 |                 # extract the 10 seconds of interest
77 |                 sig1 = sig1[T_START*Fnew:T_END*Fnew]
78 | 
79 |                 # normalize
80 |                 mu1 = np.mean(sig1)
81 |                 sd1 = np.std(sig1)
82 | 
83 |                 sig1 = (sig1 - mu1) / sd1
84 | 
85 |                 # compare to all other signals with that lead
86 |                 for s2 in fields_dict:
87 |                     if s==s2:
88 |                         continue
89 | 
90 |                     # print(s2, end=' ')
91 |                     if lead in fields_dict[s2]['signame']:
92 |                         # get index of lead in 2nd signal
93 |                         m = [i for i, val in enumerate(fields_dict[s2]['signame']) if val==lead][0]
94 | 
95 |                         sig2 = (np.copy(signals_dict[s2][:,m])*fields_dict[s2]['gain'][m]).astype(int)
96 | 
97 |                         # downsample the signal
98 |                         sig2 = downsample_signal(sig2, fields_dict[s2], Fnew=Fnew)
99 | 
100 |                         # extract the 10 seconds of interest
101 |                         sig2 = sig2[T_START*Fnew:T_END*Fnew]
102 | 
103 |                         # normalize
104 |                         sig2 = (sig2 - np.mean(sig2)) / np.std(sig2)
105 | 
106 |                         # run DTW
107 |                         dist, cost, path = mlpy.dtw_std(sig1, sig2, dist_only=False, squared=False)
108 | 
109 |                         #path[0], sig2[path[1]]
110 |                         #path[0], sig2[path[1]]
111 |                         sig_out = np.array( [path[1], (sig1[path[0]]*sd1)+mu1] ).T
112 | 
113 |                         np.savetxt('dtw/' + lead + '_' + s + '_to_' + s2 + '.csv',
114 |                                    sig_out, fmt=['%4d','%8.2f'], delimiter=',')
115 |                     else:
116 |                         # the comparison signal does not have the same lead
117 |                         continue
118 |                 print() # newline to go to a new signal
119 | 
120 |             else:
121 |                 continue
122 | 
123 | def normalize_sig(sig):
124 |     return (sig - np.mean(sig)) / np.std(sig)
125 | 
126 | 
127 | def sig_distance(sig1, fields1, sig2, fields2, radius, new_fs, max_channels=1, num_secs=10):
128 |     channels_dists = {}
129 |     channels1 = fields1['signame']
130 |     channels2 = fields2['signame']
131 | 
132 |     common_channels = list(set(channels1).intersection(set(channels2)))
133 |     if len(common_channels) > max_channels:
134 |         common_channels = common_channels[:max_channels]
135 | 
136 |     start_index = int(DEFAULT_ECG_FS * (ALARM_TIME-num_secs))
137 |     end_index = int(DEFAULT_ECG_FS * ALARM_TIME)
138 | 
139 |     for channel in common_channels:
140 |         if channel == "RESP":
141 |             continue
142 | 
143 |         try:
144 |             channel_index1 = channels1.index(channel)
145 |             channel_index2 = channels2.index(channel)
146 | 
147 |         except Exception as e:
148 |             print(" channels1: {}".format(channels1), end=" ")
149 |             print(" channels2: {}".format(channels2), end=" ")
150 |             print(" common_channels:{}".format(common_channels), end=" ")
151 |             print(e)
152 |             continue
153 | 
154 |         channel1 = sig1[start_index:end_index,channel_index1]
155 |         channel2 = sig2[start_index:end_index,channel_index2]
156 | 
157 |         # Downsample
158 |         channel1_sampled = resample(channel1, num_secs*new_fs)
159 |         channel2_sampled = resample(channel2, num_secs*new_fs)
160 | 
161 |         # Normalize
162 |         channel1_normalized = normalize_sig(channel1_sampled)
163 |         channel2_normalized = normalize_sig(channel2_sampled)
164 | 
165 |         try:
166 |             if radius > 0:
167 |                 distance, path = fastdtw.fastdtw(channel1_normalized, channel2_normalized, radius=radius, dist=euclidean)
168 |             else:
169 |                 distance = sum([val**2 for val in (channel1_normalized - channel2_normalized)])
170 | 
171 |         except Exception as e:
172 |             continue
173 | 
174 |         channels_dists[channel] = distance
175 | 
176 |     return channels_dists
177 | 
178 | 
179 | def sig_distance_from_file(sig1, fields1, sig2, fields2, new_fs, num_secs=10):
180 |     channels_dists = {}
181 | 
182 |     sample_name1 = fields1['filename'][0][:-len('.mat')]  # str.strip('.mat') would strip characters, not the suffix
183 |     sample_name2 = fields2['filename'][0][:-len('.mat')]
184 | 
185 |     start_index = int(DEFAULT_ECG_FS * (ALARM_TIME-num_secs))
186 |     end_index = int(DEFAULT_ECG_FS * ALARM_TIME)
187 | 
188 |     pathname = 'dtw_data/*_{}_to_{}.csv'.format(sample_name1, sample_name2)
189 |     matched_filenames = glob.glob(pathname)
190 | 
191 |     for filename in matched_filenames:
192 |         channel = os.path.basename(filename).split('_')[0]  # lead name is the prefix of '<LEAD>_<s1>_to_<s2>.csv'
193 |         channel_index1 = fields1['signame'].index(channel)
194 |         channel_index2 = fields2['signame'].index(channel)
195 | 
196 |         channel1 = sig1[start_index:end_index,channel_index1]
197 |         channel2 = sig2[start_index:end_index,channel_index2]
198 | 
199 |         # Downsample
200 |         channel1_sampled = resample(channel1, num_secs*new_fs)
201 |         channel2_sampled = resample(channel2, num_secs*new_fs)
202 | 
203 |         with open(filename, 'r') as f:
204 |             reader = csv.DictReader(f, fieldnames=['path0', 'path1'])
205 | 
206 |             indices = [ [int(row['path0']), int(row['path1'])] for row in reader ]
207 | 
208 |         channel1_indices = [ index[0] for index in indices ]
209 |         channel2_indices = [ index[1] for index in indices ]
210 | 
211 |         channel1_warped = [ channel1_sampled[index] for index in channel1_indices ]
212 |         channel2_warped = [ channel2_sampled[index] for index in channel2_indices ]
213 | 
214 |         # Normalize
215 |         channel1_normalized = normalize_sig(channel1_warped)
216 |         channel2_normalized = normalize_sig(channel2_warped)
217 | 
218 |         distance = sum([val**2 for val in (channel1_normalized - channel2_normalized)])
219 | 
220 |         channels_dists[channel] = distance
221 | 
222 |     return channels_dists
223 | 
224 | 
225 | def normalize_distances(channels_dists, normalization='ecg_average', sigtypes=sigtypes_filename):
226 |     if len(channels_dists.keys()) == 0:
227 |         return float('inf')
228 | 
229 |     if len(channels_dists.keys()) == 1:
230 |         return channels_dists.values().pop()
231 | 
232 |     ecg_channels = [ channel for channel in channels_dists if get_channel_type(channel, sigtypes) == "ECG" ]
233 |     ecg_dists = [ channels_dists[channel] for channel in ecg_channels ]
234 | 
235 |     if normalization == 'ecg_average':
236 |         return np.mean(ecg_dists)
237 | 
238 |     if normalization == 'ecg_min':
239 |         return min(ecg_dists)
240 | 
241 |     if normalization == 'ecg_max':
242 |         return max(ecg_dists)
243 | 
244 |     if normalization == 'average':
245 |         return np.mean(channels_dists.values())
246 | 
247 |     if normalization == 'min':
248 |         return min(channels_dists.values())
249 | 
250 |     elif normalization == 'max':
251 |         return max(channels_dists.values())
252 | 
253 |     raise Exception("Unrecognized normalization")
254 | 
255 | 
256 | def predict(test_sig, test_fields, sig_training_by_arrhythmia, fields_training_by_arrhythmia, radius, new_fs, 
weighting): 257 | min_distance = float("inf") 258 | min_label = "" 259 | min_sample = "" 260 | 261 | arrhythmia = get_arrhythmia_type(test_fields) 262 | sig_training = sig_training_by_arrhythmia[arrhythmia] 263 | fields_training = fields_training_by_arrhythmia[arrhythmia] 264 | 265 | for sample_name, train_sig in sig_training.items(): 266 | train_fields = fields_training[sample_name] 267 | 268 | # channels_dists = sig_distance_from_file(test_sig, test_fields, train_sig, train_fields, new_fs) 269 | # print("sample_name: {}".format(sample_name)) 270 | # print("channels_dists: {}".format(channels_dists)) 271 | # if len(channels_dists.keys()) == 0: 272 | # print("Processing sample {} from scratch".format(sample_name)) 273 | channels_dists = sig_distance(test_sig, test_fields, train_sig, train_fields, radius, new_fs) 274 | 275 | distance = normalize_distances(channels_dists) 276 | 277 | if distance < min_distance: 278 | min_distance = distance 279 | min_label = is_true_alarm_fields(train_fields) 280 | min_sample = sample_name 281 | 282 | return min_label, min_distance, min_sample 283 | 284 | 285 | ## Get classification accuracy of testing based on training set 286 | def run_classification(sig_training_by_arrhythmia, fields_training_by_arrhythmia, sig_testing, fields_testing, radius, new_fs, weighting): 287 | num_correct = 0 288 | matrix = { 289 | "TP": [], 290 | "FP": [], 291 | "TN": [], 292 | "FN": [] 293 | } 294 | min_distances = {} 295 | 296 | for sample_name, test_sig in sig_testing.items(): 297 | start = datetime.now() 298 | test_fields = fields_testing[sample_name] 299 | 300 | predicted, distance, sample = predict(test_sig, test_fields, sig_training_by_arrhythmia, fields_training_by_arrhythmia, radius, new_fs, weighting) 301 | actual = is_true_alarm_fields(test_fields) 302 | # print("sample: {}".format(sample_name)) 303 | # print(" predicted: {}".format(predicted)) 304 | # print(" actual: {}".format(actual)) 305 | # print("elapsed: {}".format(datetime.now() - start)) 306 | 307 | min_distances[sample_name] = (distance, sample, predicted == actual) 308 | 309 | classification = get_matrix_classification(actual, predicted) 310 | matrix[classification].append(sample_name) 311 | 312 | return matrix, min_distances 313 | 314 | 315 | def run(data_path, num_training, arrhythmias, matrix_filename, distances_filename, radius=0, new_fs=DEFAULT_ECG_FS, weighting=1): 316 | print("Generating sig and fields dicts...") 317 | sig_dict, fields_dict = read_signals(data_path) 318 | sig_training, fields_training, sig_testing, fields_testing = \ 319 | get_data(sig_dict, fields_dict, num_training) 320 | sig_training_by_arrhythmia = { arrhythmia : get_samples_of_type(sig_training, arrhythmia) \ 321 | for arrhythmia in arrhythmias } 322 | fields_training_by_arrhythmia = { arrhythmia : get_samples_of_type(fields_training, arrhythmia) \ 323 | for arrhythmia in arrhythmias } 324 | 325 | sig_testing_temp = { sample_name : sig_testing[sample_name] for sample_name in sig_testing.keys() if sample_name[0] == 'v' } 326 | fields_testing_temp = { sample_name : fields_testing[sample_name] for sample_name in fields_testing.keys() if sample_name[0] == 'v' } 327 | 328 | print("Calculating classification accuracy...") 329 | matrix, min_distances = run_classification( \ 330 | sig_training_by_arrhythmia, fields_training_by_arrhythmia, sig_testing_temp, fields_testing_temp, radius, new_fs, weighting) 331 | 332 | write_json(matrix, matrix_filename) 333 | write_json(min_distances, distances_filename) 334 | 335 | 336 | 337 | if __name__ 
== '__main__': 338 | start = datetime.now() 339 | 340 | new_fs = 125 341 | num_training = 500 342 | arrhythmias = ['a', 'b', 't', 'v', 'f'] 343 | 344 | run(data_path, num_training, arrhythmias, matrix_filename, distances_filename, radius=0, new_fs=new_fs) 345 | 346 | matrix = read_json(matrix_filename) 347 | min_distances = read_json(distances_filename) 348 | 349 | counts = { key : len(matrix[key]) for key in matrix.keys() } 350 | # vtach_counts, vtach_matrix = get_by_arrhythmia(matrix, 'v') 351 | 352 | print("accuracy: {}".format(get_classification_accuracy(matrix))) 353 | print("score: {}".format(get_score(matrix))) 354 | print_stats(counts) 355 | 356 | # print("accuracy: {}".format(get_classification_accuracy(vtach_matrix))) 357 | # print("score: {}".format(get_score(vtach_matrix))) 358 | # print_stats(vtach_counts) 359 | 360 | 361 | # print("\nVTACH STATS") 362 | # arrhythmia_dict = get_counts_by_arrhythmia(matrix, "v") 363 | # arrhythmia_counts = { key : arrhythmia_dict[key][0] for key in arrhythmia_dict.keys() } 364 | # arrhythmia_matrix = { key : arrhythmia_dict[key][1] for key in arrhythmia_dict.keys() } 365 | # print("accuracy: {}".format(get_classification_accuracy(arrhythmia_matrix))) 366 | # print("score: {}".format(get_score(arrhythmia_matrix))) 367 | # print_stats(arrhythmia_counts) 368 | -------------------------------------------------------------------------------- /pyfar/ventricular_beat_stdev.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from classifier import get_baseline, get_power, get_ksqi, get_pursqi 4 | from fastdtw import fastdtw 5 | from scipy.spatial.distance import euclidean 6 | from scipy.stats import entropy 7 | from datetime import datetime 8 | from copy import deepcopy 9 | from utils import * 10 | from parameters import * 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import wfdb 14 | import peakutils 15 | import csv 16 | import os 17 | import json 18 | 19 | 20 | STD_MULTIPLIER = 1 21 | MIN_DISTANCE_DIFF = 5 22 | # Assuming a max physiologically possible HR of 300 23 | MIN_PEAK_DIST = 60. / 300 * 250 24 | # Assuming a min physiologically possible HR of 30 25 | MAX_PEAK_DIST = 60. 
/ 30 * 250
26 | 
27 | DEBUG = False
28 | 
29 | def dprint(*args):
30 |     if DEBUG:
31 |         for arg in args:
32 |             print(arg)
33 |         print("")
34 | 
35 | 
36 | def is_noisy(
37 |         channel_subsig,
38 |         checks_to_use,
39 |         baseline_threshold=0.75,
40 |         power_threshold=0.9,
41 |         ksqi_threshold=4,
42 |         pursqi_threshold=5
43 |     ):
44 | 
45 |     checks = []
46 |     dprint(get_baseline(channel_subsig), get_power(channel_subsig), get_ksqi(channel_subsig))
47 | 
48 |     # True if passes check
49 |     baseline_check = get_baseline(channel_subsig) > baseline_threshold
50 |     power_check = get_power(channel_subsig) > power_threshold
51 |     ksqi_check = get_ksqi(channel_subsig) > ksqi_threshold
52 |     # pursqi_check = get_pursqi(channel_subsig) > pursqi_threshold
53 |     # checks = [baseline_check, power_check, ksqi_check, pursqi_check]
54 | 
55 |     # TODO: maybe high pass filter instead of using baseline check as a check
56 |     if 'baseline' in checks_to_use:
57 |         checks.append(baseline_check)
58 | 
59 |     if 'power' in checks_to_use:
60 |         checks.append(power_check)
61 | 
62 |     if 'ksqi' in checks_to_use:
63 |         checks.append(ksqi_check)
64 | 
65 |     return not all(checks)
66 | 
67 | 
68 | def get_adjusted_ann_indices(annotation, ann_index, start_ratio=1/3.):
69 |     a = annotation[ann_index-1]
70 |     b = annotation[ann_index]
71 |     c = annotation[ann_index+1]
72 | 
73 |     end_ratio = 1-start_ratio
74 | 
75 |     ann_start_index = b - start_ratio*(b-a)
76 |     ann_end_index = b + end_ratio*(c-b)
77 | 
78 |     return ann_start_index, ann_end_index
79 | 
80 | 
81 | ##
82 | # Returns self_beats, a list of:
83 | #     annotation index
84 | #     beat_sig
85 | # for regular beats detected in own patient's signal
86 | ##
87 | def get_self_beats(
88 |         channel_sig,
89 |         annotation,
90 |         sample_name,
91 |         checks_to_use=['baseline', 'power', 'ksqi'],
92 |         num_self_beats=20,
93 |         window_increment=10,
94 |         fs=250.):
95 | 
96 |     self_beats = []
97 | 
98 |     # Get self beats in first 2 min
99 |     for start_time in range(0, 120-window_increment+1, window_increment):
100 |         end_time = start_time + window_increment
101 |         start_index = int(start_time * fs)
102 |         end_index = int(end_time * fs)
103 | 
104 |         channel_subsig = channel_sig[start_index:end_index]
105 |         # print(start_index, end_index,)
106 | 
107 |         if not is_noisy(channel_subsig, checks_to_use):
108 |             for ann_index in range(1, len(annotation)-1):
109 |                 # TODO: update to have the start and end index be smoothed over past values
110 |                 ann_start_index, ann_end_index = get_adjusted_ann_indices(annotation, ann_index)
111 | 
112 |                 # If beat annotation in clean (not noisy) data range
113 |                 if ann_start_index > start_index and ann_end_index < end_index:
114 |                     beat_sig = channel_sig[int(ann_start_index):int(ann_end_index)]
115 | 
116 |                     peaks = peakutils.indexes(beat_sig, thres=0.75, min_dist=MIN_PEAK_DIST)  # peakutils expects a normalized (0-1) threshold, not an absolute amplitude
117 | 
118 |                     # if DEBUG:
119 |                     #     plt.figure()
120 |                     #     plt.plot(peaks, [beat_sig[index] for index in peaks], 'ro')
121 |                     #     plt.plot(beat_sig)
122 |                     #     plt.show()
123 | 
124 |                     if len(peaks) < 2:
125 |                         self_beats.append((annotation[ann_index], beat_sig))
126 | 
127 |         if len(self_beats) >= num_self_beats:
128 |             break
129 | 
130 |     dprint("Found", len(self_beats), "self beats.")
131 | 
132 |     if DEBUG:
133 |         plt.figure()
134 |         for i, beat in enumerate(self_beats):
135 |             plt.subplot(5, 4, i+1)
136 |             plt.plot(beat[1])
137 |         plt.show()
138 | 
139 |     return self_beats
140 | 
141 | 
142 | def get_best_self_beats(channel_sig, full_annotation, sample_name):
143 |     self_beats = get_self_beats(channel_sig, full_annotation, sample_name)
144 | 
145 |     if len(self_beats) == 0:
146 
| self_beats = get_self_beats(channel_sig, full_annotation, sample_name, ['power', 'ksqi']) 147 | 148 | if len(self_beats) == 0: 149 | self_beats = get_self_beats(channel_sig, full_annotation, sample_name, ['power']) 150 | 151 | if len(self_beats) == 0: 152 | dprint("No self beats found for", sample_name) 153 | 154 | return self_beats 155 | 156 | 157 | def normalize_sig(sig): 158 | return (sig - np.mean(sig)) / np.std(sig) 159 | 160 | 161 | ## 162 | # Returns mean and stdev comparing against every other self beat in bank 163 | ## 164 | def get_baseline_distances(self_beats, radius=250): 165 | 166 | # if DEBUG: 167 | # plt.figure() 168 | # for i, beat in enumerate(self_beats): 169 | # plt.subplot(5, 4, i+1) 170 | # plt.plot(beat[1]) 171 | # plt.show() 172 | 173 | # Pairwise compare with every other self beat 174 | all_distances = [] 175 | 176 | for i in range(len(self_beats)): 177 | distances = [] 178 | 179 | for j in range(len(self_beats)): 180 | if i != j: 181 | i_beat = self_beats[i][1] 182 | j_beat = self_beats[j][1] 183 | 184 | distance, path = fastdtw(normalize_sig(i_beat), normalize_sig(j_beat), radius=radius, dist=euclidean) 185 | distances.append(distance) 186 | 187 | all_distances.append(distances) 188 | 189 | return all_distances 190 | 191 | def get_kl_dist(distances): 192 | return [ val if val > 0 else 0.000001 for val in np.histogram(distances, bins=2000)[0] ] 193 | 194 | 195 | def get_baseline_metrics(metric, baseline_distances): 196 | top_level_distances = [] 197 | 198 | if metric == 'kl': 199 | flat_distances = [ item for sublist in baseline_distances for item in sublist ] 200 | flat_hist = get_kl_dist(flat_distances) 201 | 202 | for sublist in baseline_distances: 203 | sublist_hist = get_kl_dist(sublist) 204 | kl_distance = entropy(sublist_hist, flat_hist) 205 | top_level_distances.append(kl_distance) 206 | 207 | elif metric == 'min': 208 | top_level_distances = [ min(sublist) for sublist in baseline_distances ] 209 | 210 | elif metric == 'mean': 211 | top_level_distances = [ np.mean(sublist) for sublist in baseline_distances ] 212 | 213 | else: 214 | raise Exception("Unrecognized metric: ", metric) 215 | 216 | metric_info = [ np.mean(top_level_distances), np.std(top_level_distances) ] 217 | if metric == 'kl': 218 | metric_info.append(deepcopy(baseline_distances)) 219 | 220 | return metric_info 221 | 222 | 223 | def get_dtw_distances(beat_sig, self_beats, radius=250): 224 | distances = [] 225 | beat_sig_normalized = normalize_sig(beat_sig) 226 | 227 | # figure_num = 1 228 | 229 | for self_beat in self_beats: 230 | self_beat_normalized = normalize_sig(self_beat[1]) 231 | 232 | try: 233 | distance, path = fastdtw(beat_sig_normalized, self_beat_normalized, radius=radius, dist=euclidean) 234 | distances.append(distance) 235 | 236 | # plt.subplot(5, 4, figure_num) 237 | # plt.title(str(int(distance))) 238 | # plt.plot(self_beat_normalized, 'b-') 239 | # plt.plot(beat_sig_normalized, 'r-') 240 | # plt.axis('off') 241 | # figure_num += 1 242 | 243 | except Exception as e: 244 | print(e) 245 | 246 | # plt.show() 247 | return distances 248 | 249 | 250 | ## 251 | # Determine if ventricular beat is stdev or not 252 | # metric: string indicating metric ('kl', 'min', 'mean') 253 | # metric info: list of relevant metric info 254 | # if 'kl': [ mean, std, baseline_distances ] 255 | # else: [ mean, std ] 256 | ## 257 | def is_ventricular_beat_stdev(beat_sig, self_beats, metric, metric_info, threshold): 258 | plt.figure(figsize=[12, 8]) 259 | plt.title(str(metric_info[0]) + " " + 
str(metric_info[1])) 260 | beat_distances = get_dtw_distances(beat_sig, self_beats) 261 | 262 | if len(beat_distances) == 0: 263 | # TODO: maybe return false because probably contains inf/nan which is invalid data 264 | return True 265 | 266 | if metric == 'kl': 267 | baseline_distances = metric_info[2] 268 | flat_distances = [ item for sublist in baseline_distances for item in sublist ] 269 | flat_hist = get_kl_dist(flat_distances) 270 | beat_hist = get_kl_dist(beat_distances) 271 | 272 | metric_distance = entropy(beat_hist, flat_hist) 273 | 274 | elif metric == "min": 275 | metric_distance = min(beat_distances) 276 | 277 | elif metric == 'mean': 278 | metric_distance = np.mean(beat_distances) 279 | 280 | else: 281 | raise Exception("Unrecognized metric type: ", metric) 282 | 283 | dprint("distance: ", metric_distance, metric_distance > threshold) 284 | 285 | if metric_distance > threshold: 286 | return True 287 | 288 | return False 289 | 290 | 291 | ## 292 | # beats is a list of tuples containing: 293 | # annotation of beat QRS 294 | # start and end indices 295 | # sig of beat 296 | ## 297 | def get_ventricular_beats(beats, self_beats, metric, metric_info): 298 | ventricular_beats = [] 299 | nonventricular_beats = [] 300 | 301 | mean = metric_info[0] 302 | std = metric_info[1] 303 | 304 | # TODO: optimize hyperparameter STD_MULTIPLIER and MIN_DISTANCE_DIFF 305 | threshold = max(mean + std * STD_MULTIPLIER, mean + MIN_DISTANCE_DIFF) 306 | dprint("mean: ", metric_info[0], "std: ", metric_info[1], "threshold: ", threshold) 307 | 308 | for beat in beats: 309 | beat_sig = beat[1] 310 | 311 | if is_ventricular_beat_stdev(beat_sig, self_beats, metric, metric_info, threshold): 312 | ventricular_beats.append(beat) 313 | else: 314 | nonventricular_beats.append(beat) 315 | 316 | return ventricular_beats, nonventricular_beats 317 | 318 | 319 | ## 320 | # Returns beats (list of tuples): 321 | # annotation of beat QRS 322 | # start and end indices 323 | # sig of beat 324 | ## 325 | def get_alarm_beats(channel_sig, annotation): 326 | beats = [] 327 | for ann_index in range(1, len(annotation)-1): 328 | # Assumes a beat starts start_ratio (default 1/3) before the annotation 329 | # and ends end_ratio (default 2/3) after annotation 330 | # TODO: update this to update dynamically based on past values 331 | start_index, end_index = get_adjusted_ann_indices(annotation, ann_index) 332 | 333 | indices = (start_index, end_index) 334 | beat_sig = channel_sig[int(indices[0]):int(indices[1])] 335 | beat = (annotation[ann_index], beat_sig) 336 | 337 | if len(beat_sig) > MIN_PEAK_DIST and len(beat_sig) < MAX_PEAK_DIST: 338 | beats.append(beat) 339 | 340 | if DEBUG: 341 | plt.figure() 342 | for i, beat in enumerate(beats): 343 | plt.subplot(5, 4, i+1) 344 | plt.plot(beat[1]) 345 | plt.show() 346 | 347 | return beats 348 | 349 | 350 | ## 351 | # Plot histogram of all pairwise distances between self beatst 352 | ## 353 | def plot_metrics(baseline_distances, metric, metric_info): 354 | flat_distances = [ item for sublist in baseline_distances for item in sublist ] 355 | mean = metric_info[0] 356 | std = metric_info[1] 357 | multipliers = [0.5, 1, 2, 3, 4, 5] 358 | 359 | # Plot all flat distances with mean + std + various multipliers 360 | plt.figure() 361 | plt.hist(flat_distances, edgecolor='black') 362 | plt.axvline(mean, color='r') 363 | for multiplier in multipliers: 364 | plt.axvline(x=mean + std*multiplier, color='g') 365 | 366 | plt.show() 367 | 368 | 369 | # Plot individual distance distributions against flat 
distances 370 | plt.figure(figsize=[12, 8]) 371 | for index, distances in enumerate(baseline_distances): 372 | plt.subplot(5, 4, index+1) 373 | plt.hist(flat_distances, color='blue', edgecolor='black') 374 | plt.hist(distances, color='red', edgecolor='black') 375 | if metric == 'min': 376 | plt.axvline(x=min(distances), color='r') 377 | elif metric == 'mean': 378 | plt.axvline(x=np.mean(distances), color='r') 379 | 380 | plt.show() 381 | 382 | 383 | def plot_self_beat_comparison(self_beats): 384 | for i in range(len(self_beats)): 385 | plt.figure(figsize=[12, 8]) 386 | figure_num = 1 387 | 388 | for j in range(len(self_beats)): 389 | if i != j: 390 | i_beat = self_beats[i][1] 391 | j_beat = self_beats[j][1] 392 | 393 | plt.subplot(5, 4, figure_num) 394 | plt.plot(normalize_sig(i_beat), 'b-') 395 | plt.plot(normalize_sig(j_beat), 'r-') 396 | plt.axis('off') 397 | figure_num += 1 398 | 399 | plt.show() 400 | 401 | 402 | def filter_out_nan(beats): 403 | filtered = [] 404 | 405 | for beat in beats: 406 | beat_sig = beat[1] 407 | if not np.isnan(np.sum(beat_sig)): 408 | filtered.append(beat) 409 | 410 | return filtered 411 | 412 | 413 | def ventricular_beat_annotations_dtw( 414 | channel_sig, 415 | ann_path, 416 | sample_name, 417 | metric, 418 | start_time, 419 | end_time, 420 | ann_type, 421 | force=False, 422 | file_prefix=output_path_std_distances, 423 | ann_fs=250.): 424 | 425 | baseline_dist_filename = file_prefix + sample_name + ".json" 426 | 427 | dprint("Finding alarm beats...") 428 | annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, start_time, end_time).annsamp 429 | alarm_beats = get_alarm_beats(channel_sig, annotation) 430 | 431 | dprint("Finding self beats...") 432 | # Full annotation except for when the alarm signal starts (usually last 10 seconds) 433 | full_annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, 0, start_time).annsamp 434 | self_beats = get_best_self_beats(channel_sig, full_annotation, sample_name) 435 | 436 | if os.path.isfile(baseline_dist_filename) and not force: 437 | dprint("Loading baseline distances from file...") 438 | with open(baseline_dist_filename, 'r') as f: 439 | baseline_distances = json.load(f) 440 | else: 441 | dprint("Calculating baseline distances...") 442 | baseline_distances = get_baseline_distances(self_beats) 443 | 444 | dprint("Writing baseline distances to file...") 445 | with open(baseline_dist_filename, 'w') as f: 446 | json.dump(baseline_distances, f) 447 | 448 | try: 449 | dprint("Calculating baseline metrics...") 450 | metric_info = get_baseline_metrics(metric, baseline_distances) 451 | except Exception as e: 452 | print("sample_name: {}".format(sample_name)) 453 | print(e) 454 | return [], [] 455 | 456 | # plot_metrics(baseline_distances, metric, metric_info) 457 | # plot_self_beat_comparison(self_beats) 458 | 459 | dprint("Classifying alarm beats...") 460 | ventricular_beats, nonventricular_beats = get_ventricular_beats(alarm_beats, self_beats, metric, metric_info) 461 | vtach_beats = filter_out_nan(ventricular_beats) 462 | 463 | # Only find distances if ventricular beats were found 464 | if len(vtach_beats) > 1: 465 | ventricular_distances = get_baseline_distances(vtach_beats) 466 | ventricular_mean, ventricular_std = get_baseline_metrics('min', ventricular_distances) 467 | 468 | # If ventricular beats don't look very similar, mark as noise instead 469 | if ventricular_mean > 20 and ventricular_std > 15 and ventricular_mean > ventricular_std: 470 | vtach_beats = [] 471 | 472 | 
ventricular_beat_anns = [ beat[0] for beat in vtach_beats ]
473 |     nonventricular_beat_anns = [ beat[0] for beat in nonventricular_beats ]
474 | 
475 |     return ventricular_beat_anns, nonventricular_beat_anns
476 | 
477 | 
478 | 
479 | def write_vtach_beats_files(
480 |         data_path,
481 |         ann_path,
482 |         output_path,
483 |         ecg_ann_type,
484 |         start_time,
485 |         end_time,
486 |         metric):
487 | 
488 |     for filename in os.listdir(data_path):
489 |         if filename.endswith(HEADER_EXTENSION):
490 |             sample_name = filename[:-len(HEADER_EXTENSION)]  # strip the extension suffix (rstrip strips a character set, not a suffix)
491 | 
492 |             if sample_name[0] != 'v':
493 |                 continue
494 | 
495 |             sig, fields = wfdb.srdsamp(data_path + sample_name)
496 |             if "II" not in fields['signame']:
497 |                 print("Lead II not found for sample: {}".format(sample_name))
498 |                 continue
499 | 
500 |             output_filename = output_path + sample_name + "_1peak_" + metric + ".csv"
501 | 
502 |             if os.path.isfile(output_filename):
503 |                 continue
504 | 
505 |             channel_index = fields['signame'].index("II")
506 |             ann_type = ecg_ann_type + str(channel_index)
507 | 
508 |             start = datetime.now()
509 | 
510 |             with open(output_filename, "w") as f:
511 |                 channel_sig = sig[:,channel_index]
512 | 
513 |                 vtach, nonvtach = ventricular_beat_annotations_dtw(channel_sig, ann_path, sample_name, metric, start_time, end_time, ann_type)
514 | 
515 |                 writer = csv.writer(f)
516 |                 writer.writerow(['ann_index', 'is_true_beat'])
517 | 
518 |                 for beat in vtach:
519 |                     writer.writerow([beat, 1])
520 |                 for beat in nonvtach:
521 |                     writer.writerow([beat, 0])
522 | 
523 |             print("sample_name: {}".format(sample_name), end=" ")
524 |             print(" elapsed: {}".format(datetime.now() - start))
525 | 
526 | def run_one_sample():
527 |     # sample_name = "v100s"  # false alarm
528 |     # sample_name = "v141l"  # noisy at beginning
529 |     # sample_name = "v159l"  # quite clean
530 |     # sample_name = "v206s"  # high baseline
531 |     # sample_name = "v143l"
532 |     # sample_name = "v696s"
533 |     sample_name = "v837l"
534 |     metric = "min"
535 |     channel_index = 0
536 |     ann_fs = 250.
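    # Note: data_path, ann_path, start_time and end_time are expected to already be
    # in scope at module level here (e.g. via parameters); 250 Hz matches the
    # sampling rate hard-coded throughout this file.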
537 |     ann_type = 'gqrs' + str(channel_index)
538 | 
539 |     sig, fields = wfdb.srdsamp(data_path + sample_name)
540 |     channel_sig = sig[:,channel_index]
541 | 
542 |     vtach_beats, nonvtach_beats = ventricular_beat_annotations_dtw(channel_sig, ann_path, sample_name, metric, start_time, end_time, ann_type)
543 | 
544 |     plt.figure(figsize=[8,5])
545 |     plt.plot(channel_sig[int(start_time*250.):int(end_time*250.)],'b-')
546 |     plt.plot([ int(index-250.*start_time) for index in nonvtach_beats ], [ channel_sig[int(index)] for index in nonvtach_beats ], 'bo', markersize=8)
547 |     plt.plot([ int(index-250.*start_time) for index in vtach_beats ], [ channel_sig[int(index)] for index in vtach_beats ], 'ro', markersize=8)
548 |     plt.show()
549 | 
550 | 
551 | 
552 | if __name__ == '__main__':
553 |     start_time = 290
554 |     end_time = 300
555 |     metric = 'min'
556 |     write_vtach_beats_files(data_path, ann_path, output_path_std, ecg_ann_type, start_time, end_time, metric)
557 | 
558 | 
559 |     # sig, fields = wfdb.rdsamp(data_path + sample_name)
560 |     # channel_sig = sig[:,channel_index]
561 | 
562 |     # annotation = wfdb.rdann(ann_path + sample_name, ann_type, sampfrom=start*ann_fs, sampto=end*ann_fs).annsamp
563 |     # print(annotation)
564 | 
565 |     # beats = get_beats(channel_sig, annotation)
566 | 
567 | 
568 |     # for beat in beats:
569 |     #     indices = beat[0]
570 |     #     beat_sig = beat[1]
571 |     #     time_vector = np.linspace(indices[0], indices[1], len(beat_sig))
572 | 
573 |     # whole_sig = channel_sig[250*start:250*end]
574 |     # sig_time_vector = np.linspace(250*start, 250*end, len(whole_sig))
575 | 
576 |     # annotation_y = [ channel_sig[ann_t] for ann_t in annotation ]
577 | 
578 |     # plt.figure()
579 |     # plt.plot(sig_time_vector, whole_sig, 'b')
580 |     # plt.plot(time_vector, beat_sig, 'r')
581 |     # plt.plot(annotation, annotation_y, 'go')
582 |     # plt.show()
583 | 
584 | 
585 | 
586 |     # print("")
587 |     # print(annotation[0] / float(250.))
588 | 
--------------------------------------------------------------------------------
/pyfar/baseline_algorithm.py:
--------------------------------------------------------------------------------
1 | # from ventricular_beat_detection import ventricular_beat_annotations_dtw
2 | from __future__ import print_function
3 | 
4 | import scipy.signal as scipy_signal
5 | import scipy.fftpack as scipy_fftpack
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import parameters
9 | import csv
10 | import wfdb
11 | 
12 | 
13 | ##############################
14 | ##### Invalids detection #####
15 | ##############################
16 | 
17 | def band_pass_filter(signal, f_low, f_high, order, fs):
18 |     window = scipy_signal.firwin(order+1, [f_low, f_high], nyq=np.floor(fs/2), pass_zero=False,
19 |                                  window='hamming', scale=False)
20 |     A = scipy_fftpack.fft(window, 2048) / (len(window)/2.0)
21 |     freq = np.linspace(-0.5, 0.5, len(A))
22 |     response = 20 * np.log10(np.abs(scipy_fftpack.fftshift(A / abs(A).max())))  # filter response, kept for inspection but otherwise unused
23 | 
24 |     if np.size(signal) < 153:  # too short for filtfilt's default padding (presumably 3 * (order + 1) = 153 samples)
25 |         return
26 |     return scipy_signal.filtfilt(window, 1, signal)
27 | 
28 | 
29 | def get_signal_fft(signal, signal_duration, fs):
30 |     # Number of sample points
31 |     N = signal_duration * fs
32 |     # Sample spacing
33 |     T = 1.0 / fs
34 | 
35 |     xf = np.linspace(0.0, 1.0/(2.0*T), int(N/2))
36 |     signal_fft = scipy_fftpack.fft(signal)
37 | 
38 |     return (xf, 2.0/N * np.abs(signal_fft[:int(N/2)]))
39 | 
40 | 
41 | # Check that the signal amplitude in the given (invalid) frequency band stays below the cutoff
42 | def is_amplitude_within_cutoff(signal, f_low, f_high, cutoff, order, fs):
43 |     filtered_signal = 
band_pass_filter(signal, f_low, f_high, order, fs) 44 | if filtered_signal is not None: 45 | # Return False if any value in the filtered_signal is greater than cutoff 46 | return not (filtered_signal > cutoff).any() 47 | return True 48 | 49 | 50 | # Check signal statistics to be within minimum and maximum values 51 | def check_stats_within_cutoff(signal, channel_type, stats_cutoffs): 52 | signal_min = np.amin(signal) 53 | signal_max = np.amax(signal) 54 | var_range = signal_max - signal_min 55 | channel_stats_cutoffs = stats_cutoffs[channel_type] 56 | 57 | # Check minimum and maximum signal values 58 | if signal_min < channel_stats_cutoffs["val_min"] or signal_max > channel_stats_cutoffs["val_max"]: 59 | return False 60 | 61 | # Check signal range in value 62 | if var_range > channel_stats_cutoffs["var_range_max"] or var_range < channel_stats_cutoffs["var_range_min"]: 63 | return False 64 | 65 | return True 66 | 67 | 68 | # Check if signal contains NaN values 69 | def contains_nan(signal): 70 | return np.isnan(signal).any() 71 | 72 | 73 | # Check borders between histogram buckets so the difference is within a cutoff value 74 | def histogram_test(signal, histogram_cutoff): 75 | top_percentile = np.percentile(signal, parameters.TOP_PERCENTILE) 76 | bottom_percentile = np.percentile(signal, parameters.BOTTOM_PERCENTILE) 77 | 78 | # Filter out top and bottom 1% for data on which to generate histogram 79 | adjusted_signal = signal[(signal >= bottom_percentile) & (signal <= top_percentile)] 80 | 81 | # Generate histogram with 10 buckets by default 82 | histogram = np.histogram(adjusted_signal)[0] 83 | 84 | # Calculate frequency diffs between neighboring buckets and return True if all diffs within cutoff 85 | diffs = np.diff(histogram) 86 | return not (diffs > histogram_cutoff).any() 87 | 88 | 89 | def get_channel_type(channel_name): 90 | channel_types_dict = {} 91 | with open(parameters.sigtypes_filename, "r") as f: 92 | for line in f: 93 | splitted_line = line.split("\t") 94 | channel = splitted_line[-1].rstrip() 95 | channel_type = splitted_line[0] 96 | channel_types_dict[channel] = channel_type 97 | 98 | if channel_name in channel_types_dict.keys(): 99 | return channel_types_dict[channel_name] 100 | 101 | raise Exception("Unknown channel name") 102 | 103 | 104 | # Return list of channel indices for channels of type channel_type 105 | def get_channels_of_type(channels, channel_type): 106 | channel_indices = np.array([]) 107 | 108 | for channel_index in range(len(channels)): 109 | channel_name = channels[channel_index] 110 | if channel_type == get_channel_type(channel_name): 111 | channel_indices = np.append(channel_indices, channel_index) 112 | 113 | return channel_indices 114 | 115 | 116 | # Get start and end points (in seconds) to check depending on type of alarm signaled 117 | def get_start_and_end(fields): 118 | alarm_type = fields['comments'][0] 119 | if alarm_type not in parameters.TESTED_BLOCK_LENGTHS: 120 | raise Exception("Unrecognized alarm type") 121 | 122 | tested_block_length = parameters.TESTED_BLOCK_LENGTHS[alarm_type] 123 | 124 | end = parameters.ALARM_TIME # in seconds, alarm always sounded at 300th second 125 | start = end - tested_block_length # in seconds 126 | 127 | return (start, end, tested_block_length) 128 | 129 | # Returns whether signal is valid or not 130 | def is_valid(signal, channel_type, f_low, f_high, histogram_cutoff, freq_amplitude_cutoff, stats_cutoffs, order, fs): 131 | if channel_type == "Resp": 132 | return True 133 | 134 | # Checks which return True if 
passing the test, False if not
135 |     histogram_check = histogram_test(signal, histogram_cutoff)
136 |     stats_check = check_stats_within_cutoff(signal, channel_type, stats_cutoffs)
137 |     nan_check = not contains_nan(signal)
138 |     checks = np.array([histogram_check, stats_check, nan_check])
139 | 
140 |     # If ECG signal, also check that the signal amplitude in the frequency range is within limits
141 |     if channel_type == "ECG":
142 |         signal_amplitude_check = is_amplitude_within_cutoff(signal, f_low, f_high, freq_amplitude_cutoff, order, fs)
143 |         checks = np.append(checks, signal_amplitude_check)
144 | 
145 |     return all(checks)
146 | 
147 | 
148 | # Return invalids list given sig for a single channel
149 | def calculate_channel_invalids(channel_sig,
150 |                                channel_type,
151 |                                fs=parameters.DEFAULT_ECG_FS,
152 |                                block_length=parameters.BLOCK_LENGTH,
153 |                                order=parameters.ORDER,
154 |                                f_low=parameters.F_LOW,
155 |                                f_high=parameters.F_HIGH,
156 |                                hist_cutoff=parameters.HIST_CUTOFF,
157 |                                ampl_cutoff=parameters.AMPL_CUTOFF,
158 |                                stats_cutoffs=parameters.STATS_CUTOFFS):
159 |     invalids = np.array([])
160 |     start = 0  # in sample number
161 | 
162 |     # Check validity of the signal for each block_length-long block
163 |     while start < len(channel_sig):
164 |         signal = channel_sig[int(start):int(start + block_length*fs)]
165 |         start += (block_length * fs)
166 | 
167 |         is_data_valid = is_valid(signal, channel_type, f_low, f_high, hist_cutoff, ampl_cutoff, stats_cutoffs, order, fs)
168 | 
169 |         if is_data_valid:
170 |             invalids = np.append(invalids, 0)
171 |         else:
172 |             invalids = np.append(invalids, 1)
173 | 
174 |     return invalids
175 | 
176 | 
177 | # Returns invalids dictionary mapping each channel to an invalids array representing the validity of 0.8-second blocks
178 | # Takes in sig and fields after already reading the sample file
179 | def calculate_invalids_sig(sig, fields,
180 |                            start=None,
181 |                            end=None,
182 |                            block_length=parameters.BLOCK_LENGTH,
183 |                            order=parameters.ORDER,
184 |                            f_low=parameters.F_LOW,
185 |                            f_high=parameters.F_HIGH,
186 |                            hist_cutoff=parameters.HIST_CUTOFF,
187 |                            ampl_cutoff=parameters.AMPL_CUTOFF,
188 |                            stats_cutoffs=parameters.STATS_CUTOFFS):
189 | 
190 |     channels = fields['signame']
191 |     fs = fields['fs']
192 |     if start is None or end is None:
193 |         start, end, alarm_duration = get_start_and_end(fields)
194 |     window_start, window_end = start * fs, end * fs  # in sample number
195 | 
196 |     invalids = {}
197 | 
198 |     # Generate invalids array for each channel
199 |     for channel_num in range(len(channels)):
200 |         # Note: window_start/window_end are currently not applied; validity is computed over the full sig passed in
201 |         channel_name = channels[channel_num]
202 |         channel_type = get_channel_type(channel_name)
203 |         channel_sig = sig[:,channel_num]
204 | 
205 |         invalids_array = calculate_channel_invalids(channel_sig, channel_type)
206 |         invalids[channel_name] = invalids_array
207 | 
208 |     return invalids
209 | 
210 | 
211 | # Calculate overall c_val of invalids list for a single channel (0 = invalid, 1 = valid)
212 | def calculate_cval_channel(channel_invalids):
213 |     if len(channel_invalids) > 0:
214 |         return 1 - float(sum(channel_invalids)) / len(channel_invalids)
215 |     return None
216 | 
217 | 
218 | #######################
219 | ##### Annotations #####
220 | #######################
221 | 
222 | # Get annotation file type based on channel type and index
223 | def get_ann_type(channel, channel_index, ecg_ann_type):
224 |     channel_type = get_channel_type(channel)
225 |     if channel_type == "Resp":
226 |         return ""
227 | 
228 |     if ecg_ann_type == "fp":
229 |         return ann_type_fplesinger(channel_index)
230 | 
231 | else: 232 | return ann_type_qrs(channel_type, channel_index, ecg_ann_type) 233 | 234 | 235 | # Get annotation file type for fplesinger ann files 236 | def ann_type_fplesinger(channel_index): 237 | return "fp" + str(channel_index) 238 | 239 | 240 | # Get annotation file type for non-fplesinger ann files 241 | def ann_type_qrs(channel_type, channel_index, ecg_ann_type): 242 | if channel_type == "ECG": 243 | ann_type = ecg_ann_type + str(channel_index) 244 | elif channel_type == "BP": 245 | ann_type = 'wabp' 246 | elif channel_type == "PLETH": 247 | ann_type = "wpleth" 248 | elif channel_type == "Resp": 249 | ann_type = "" 250 | else: 251 | raise Exception("Unrecognized ann type") 252 | 253 | return ann_type 254 | 255 | 256 | def get_ann_fs(channel_type, ecg_ann_type): 257 | if channel_type == "ECG" or ecg_ann_type.startswith("fp"): 258 | return parameters.DEFAULT_ECG_FS 259 | return parameters.DEFAULT_OTHER_FS 260 | 261 | 262 | # start and end in seconds 263 | def get_annotation(sample, ann_type, ann_fs, start, end): 264 | try: 265 | annotation = wfdb.rdann(sample, ann_type, sampfrom=start*ann_fs, sampto=end*ann_fs) 266 | except Exception as e: 267 | annotation = wfdb.Annotation(sample, ann_type, [], []) 268 | print("Error getting annotation for sample ", sample, ann_type, e) 269 | 270 | return annotation 271 | 272 | 273 | # start and end in seconds 274 | def get_channel_rr_intervals(ann_path, sample_name, channel_index, fields, ecg_ann_type, start=None, end=None): 275 | if start is None or end is None: 276 | # Start and end given in seconds 277 | start, end, alarm_duration = get_start_and_end(fields) 278 | 279 | channels = fields['signame'] 280 | channel = channels[channel_index] 281 | channel_type = get_channel_type(channel) 282 | channel_rr_intervals = np.array([]) 283 | 284 | ann_type = get_ann_type(channel, channel_index, ecg_ann_type) 285 | 286 | try: 287 | ann_fs = get_ann_fs(channel_type, ecg_ann_type) 288 | annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, start, end) 289 | 290 | # Convert annotations sample numbers into seconds if >0 annotations in signal 291 | if len(annotation.annsamp) > 0: 292 | ann_seconds = np.array(annotation.annsamp) / float(ann_fs) 293 | else: 294 | return np.array([0.0]) 295 | 296 | for index in range(1, np.size(ann_seconds)): 297 | channel_rr_intervals = np.append(channel_rr_intervals, round(ann_seconds[index] - ann_seconds[index - 1], 4)) 298 | 299 | except Exception as e: 300 | print("Error getting channel RR intervals for sample", sample_name, e) 301 | 302 | return channel_rr_intervals 303 | 304 | 305 | # Start and end given in seconds 306 | def get_rr_dict(ann_path, sample_name, fields, ecg_ann_type, start=None, end=None): 307 | rr_dict = {} 308 | if start is None or end is None: 309 | # Start and end given in seconds 310 | start, end, alarm_duration = get_start_and_end(fields) 311 | 312 | channels = fields['signame'] 313 | for channel_index in range(len(channels)): 314 | channel_name = channels[channel_index] 315 | channel_type = get_channel_type(channel_name) 316 | if channel_type == "Resp": 317 | continue 318 | 319 | rr_intervals = get_channel_rr_intervals(ann_path, sample_name, channel_index, fields, ecg_ann_type, start, end) 320 | 321 | rr_dict[channel_name] = rr_intervals 322 | 323 | return rr_dict 324 | 325 | 326 | ############################ 327 | ##### Regular activity ##### 328 | ############################ 329 | 330 | # Check if standard deviation of RR intervals of signal are within limits 331 | def 
check_rr_stdev(rr_intervals):
332 |     stdev = np.std(rr_intervals)
333 | 
334 |     if stdev > parameters.RR_STDEV:
335 |         return False
336 |     return True
337 | 
338 | # Check if the heart rate, estimated from the number of RR intervals in the signal, is within limits
339 | def check_heart_rate(rr_intervals, alarm_duration):
340 |     hr = (len(rr_intervals) + 1.) / alarm_duration * parameters.NUM_SECS_IN_MIN
341 | 
342 |     if hr > parameters.HR_MAX or hr < parameters.HR_MIN:
343 |         return False
344 |     return True
345 | 
346 | # Check if the sum of RR intervals is within limit of the total duration, to ensure beats are evenly spaced throughout
347 | def check_sum_rr_intervals(rr_intervals, alarm_duration):
348 |     min_sum = alarm_duration - parameters.RR_MIN_SUM_DIFF
349 | 
350 |     rr_sum = sum(rr_intervals)
351 | 
352 |     if rr_sum < min_sum:
353 |         return False
354 |     return True
355 | 
356 | # Check if the total number of RR intervals is greater than a minimum
357 | def check_num_rr_intervals(rr_intervals):
358 |     if len(rr_intervals) < parameters.MIN_NUM_RR_INTERVALS:
359 |         return False
360 |     return True
361 | 
362 | 
363 | # Returns False if any block within the signal is identified as invalid (invalid sample detection)
364 | def check_invalids(invalids, channel):
365 |     if channel not in invalids.keys():
366 |         raise Exception("Unknown channel")
367 | 
368 |     block_invalids_sum = sum(invalids[channel])
369 |     if block_invalids_sum > parameters.INVALIDS_SUM:
370 |         return False
371 |     return True
372 | 
373 | # Returns True for a given channel if all regular activity checks pass
374 | def check_interval_regular_activity(rr_intervals, invalids, alarm_duration, channel):
375 |     all_checks = np.array([])
376 | 
377 |     # If the RR intervals should be checked but all annotations are missing, automatically fail
378 |     if len(rr_intervals) == 0:
379 |         return False
380 | 
381 |     stdev_check = check_rr_stdev(rr_intervals)
382 |     hr_check = check_heart_rate(rr_intervals, alarm_duration)
383 |     sum_check = check_sum_rr_intervals(rr_intervals, alarm_duration)
384 |     num_check = check_num_rr_intervals(rr_intervals)
385 |     invalids_check = check_invalids(invalids, channel)
386 | 
387 |     all_checks = np.append(all_checks, [stdev_check, hr_check, sum_check, num_check, invalids_check])
388 | 
389 |     return np.all(all_checks)
390 | 
391 | 
392 | # Determines regular activity of sample based on RR intervals and invalids array:
393 | # param: rr_dict as a dictionary of the form:
394 | #     { channel0: [rr_intervals], channel1: [rr_intervals], ...}
395 | # param: alarm_duration duration of alarm in seconds
396 | def is_rr_invalids_regular(rr_dict, invalids, alarm_duration, arrhythmia_type,
397 |                            should_check_invalids=True, should_check_rr=True, should_num_check=True):  # note: these flags are currently unused
398 | 
399 |     for channel in rr_dict.keys():
400 |         channel_type = get_channel_type(channel)
401 | 
402 |         if arrhythmia_type == "Ventricular_Flutter_Fib" and channel_type != "ECG":
403 |             continue
404 | 
405 |         rr_intervals = rr_dict[channel]
406 |         is_regular = check_interval_regular_activity(rr_intervals, invalids, alarm_duration, channel)
407 | 
408 |         # If any channel is regular, the alarm is rejected as a false alarm
409 |         if is_regular:
410 |             return True
411 |     return False
412 | 
413 | 
414 | # Check the overall sample for regular activity by iterating through each channel.
415 | # If any channel exhibits regular activity, the alarm is indicated as a false alarm.
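# Respiratory channels are skipped entirely, and for 'f'-prefixed (ventricular
# flutter/fib) samples only ECG channels are considered, mirroring the channel
# filtering in is_rr_invalids_regular above.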
416 | def is_sample_regular(data_path, 417 | ann_path, 418 | sample_name, 419 | ecg_ann_type, 420 | start=None, 421 | end=None, 422 | verbose=False): 423 | sig, fields = wfdb.srdsamp(data_path + sample_name) 424 | channels = fields['signame'] 425 | nonresp_channels = [ channels.index(channel) for channel in channels if channel != "RESP" ] 426 | 427 | if start is None or end is None: 428 | start, end, alarm_duration = get_start_and_end(fields) 429 | else: 430 | alarm_duration = end - start 431 | 432 | # try: 433 | # invalids = {} 434 | # for channel_index in nonresp_channels: 435 | # channel = channels[channel_index] 436 | 437 | # with open(ann_path + sample_name + "-invalids.csv", "r") as f: 438 | # reader = csv.reader(f) 439 | # channel_invalids = [ int(float(row[channel_index])) for row in reader] 440 | # invalids[channel] = channel_invalids[start*250:end*250] 441 | 442 | # except Exception as e: 443 | # print("Error finding invalids for sample " + sample_name, e) 444 | # invalids = calculate_invalids_sig(sig, fields, start, end) 445 | 446 | invalids = calculate_invalids_sig(sig, fields, start, end) 447 | 448 | for channel_index in range(len(channels)): 449 | channel = channels[channel_index] 450 | channel_type = get_channel_type(channel) 451 | 452 | # Ignore respiratory channel 453 | if channel_type == "Resp": 454 | continue 455 | 456 | alarm_prefix = sample_name[0] 457 | # Only use ECG channels for ventricular fib 458 | if alarm_prefix == "f": 459 | if channel_type != "ECG": 460 | continue 461 | 462 | rr = get_channel_rr_intervals(ann_path, sample_name, channel_index, fields, ecg_ann_type) 463 | 464 | is_regular = check_interval_regular_activity(rr, invalids, alarm_duration, channel) 465 | 466 | # If any channel exhibits regular activity, deem signal as regular activity 467 | if is_regular: 468 | return True 469 | return False 470 | 471 | 472 | 473 | ################################ 474 | ##### Specific arrhythmias ##### 475 | ################################ 476 | 477 | def test_asystole(data_path, ann_path, sample_name, ecg_ann_type, verbose=False): 478 | sig, fields = wfdb.srdsamp(data_path + sample_name) 479 | channels = fields['signame'] 480 | fs = fields['fs'] 481 | 482 | # Start and end given in seconds 483 | start, end, alarm_duration = get_start_and_end(fields) 484 | current_start = start 485 | current_end = current_start + parameters.ASYSTOLE_WINDOW_SIZE 486 | 487 | max_score = 0 488 | 489 | while current_end < end: 490 | start_index, end_index = int(current_start*fs), int(current_end*fs) 491 | subsig = sig[start_index:end_index,:] 492 | summed_asystole_score = calc_summed_asystole_score(ann_path, sample_name, subsig, channels, ecg_ann_type, 493 | current_start, current_end, verbose) 494 | max_score = max(max_score, summed_asystole_score) 495 | 496 | current_start += parameters.ASYSTOLE_ROLLING_INCREMENT 497 | current_end = current_start + parameters.ASYSTOLE_WINDOW_SIZE 498 | 499 | if verbose: 500 | print(sample_name + " has max asystole score: " + str(max_score)) 501 | 502 | return max_score > 0 503 | 504 | 505 | def test_bradycardia(data_path, ann_path, sample_name, ecg_ann_type, verbose=False): 506 | sig, fields = wfdb.srdsamp(data_path + sample_name) 507 | channels = fields['signame'] 508 | 509 | # Start and end given in seconds 510 | start, end, alarm_duration = get_start_and_end(fields) 511 | 512 | rr_intervals_list = get_rr_intervals_list(ann_path, sample_name, ecg_ann_type, fields, start, end) 513 | best_channel_rr = find_best_channel(rr_intervals_list, 
alarm_duration)
514 |     min_hr = get_min_hr(best_channel_rr, parameters.BRADYCARDIA_NUM_BEATS)
515 | 
516 |     if verbose:
517 |         print(sample_name + " with min HR: " + str(min_hr))
518 | 
519 |     return min_hr < parameters.BRADYCARDIA_HR_MIN
520 | 
521 | 
522 | def test_tachycardia(data_path, ann_path, sample_name, ecg_ann_type, verbose=False):
523 |     sig, fields = wfdb.srdsamp(data_path + sample_name)
524 |     channels = fields['signame']
525 | 
526 |     # Start and end given in seconds
527 |     start, end, alarm_duration = get_start_and_end(fields)
528 | 
529 |     rr_intervals_list = get_rr_intervals_list(ann_path, sample_name, ecg_ann_type, fields, start, end)
530 |     if check_tachycardia_channel(rr_intervals_list, alarm_duration):
531 |         return True
532 | 
533 |     best_channel_rr = find_best_channel(rr_intervals_list, alarm_duration)
534 |     max_hr = get_max_hr(best_channel_rr, parameters.TACHYCARDIA_NUM_BEATS)
535 | 
536 |     if verbose:
537 |         print(sample_name + " with max HR: " + str(max_hr))
538 | 
539 |     return max_hr > parameters.TACHYCARDIA_HR_MAX
540 | 
541 | 
542 | def test_ventricular_tachycardia(data_path,
543 |                                  ann_path,
544 |                                  sample_name,
545 |                                  ecg_ann_type,
546 |                                  verbose=False,
547 |                                  fs=parameters.DEFAULT_ECG_FS,
548 |                                  order=parameters.ORDER,
549 |                                  num_beats=parameters.VTACH_NUM_BEATS,
550 |                                  std_threshold=parameters.VTACH_ABP_THRESHOLD,
551 |                                  window_size=parameters.VTACH_WINDOW_SIZE,
552 |                                  rolling_increment=parameters.VTACH_ROLLING_INCREMENT):
553 | 
554 |     sig, fields = wfdb.srdsamp(data_path + sample_name)
555 |     channels = fields['signame']
556 | 
557 |     # Start and end given in seconds
558 |     start_time, end_time, alarm_duration = get_start_and_end(fields)
559 |     alarm_sig = sig[int(start_time*fs):int(end_time*fs),:]
560 | 
561 |     ecg_channels = get_channels_of_type(channels, "ECG")
562 |     abp_channels = get_channels_of_type(channels, "BP")
563 | 
564 |     # Initialize R vector
565 |     size = int((alarm_duration - window_size) / rolling_increment) + 1
566 |     r_vector = [0.] 
* size 567 | 568 | # index = int(channels.index("II")) 569 | # ann_type = get_ann_type("II", index, ecg_ann_type) 570 | # r_delta = get_ventricular_beats_scores(alarm_sig[:,int(index)], ann_path, sample_name, ann_type, start_time, end_time, "II") 571 | # r_vector = r_vector + r_delta 572 | 573 | # Adjust R vector based on ventricular beats in signal 574 | for channel_index in ecg_channels: 575 | index = int(channel_index) 576 | channel_name = channels[index] 577 | ann_type = get_ann_type(channel_name, index, ecg_ann_type) 578 | 579 | r_delta = get_ventricular_beats_scores(alarm_sig[:,int(index)], ann_path, sample_name, ann_type, start_time, end_time, channel_name) 580 | r_vector = r_vector + r_delta 581 | 582 | # if verbose: 583 | # channel_sig = alarm_sig[:,index] 584 | # lf, sub = get_lf_sub(channel_sig, order) 585 | # ventricular_beats = ventricular_beat_annotations(lf, sub, ann_path + sample_name, ann_type, start_time, end_time, verbose) 586 | # max_hr = max_ventricular_hr(ventricular_beats, num_beats, fs) 587 | # print(str(sample_name) + " on channel " + str(channels[int(channel_index)]) + " with max ventricular HR: ", str(max_hr)) 588 | 589 | return any([ r_value > 0 for r_value in r_vector ]) 590 | 591 | 592 | def test_ventricular_flutter_fibrillation(data_path, 593 | ann_path, 594 | sample_name, 595 | ecg_ann_type, 596 | verbose=False, 597 | fs=parameters.DEFAULT_ECG_FS, 598 | ann_fs=parameters.DEFAULT_ECG_FS, 599 | std_threshold=parameters.VFIB_ABP_THRESHOLD, 600 | window_size=parameters.VFIB_WINDOW_SIZE, 601 | rolling_increment=parameters.VFIB_ROLLING_INCREMENT): 602 | sig, fields = wfdb.srdsamp(data_path + sample_name) 603 | channels = fields['signame'] 604 | 605 | # Start and end given in seconds 606 | start, end, alarm_duration = get_start_and_end(fields) 607 | alarm_sig = sig[int(start*fs):int(end*fs),:] 608 | 609 | ecg_channels = get_channels_of_type(channels, "ECG") 610 | abp_channels = get_channels_of_type(channels, "BP") 611 | 612 | # Find max duration of low frequency signal from all channels 613 | dlfmax = 0 614 | for channel_index in ecg_channels: 615 | channel_index = int(channel_index) 616 | channel_dlfmax = calculate_dlfmax(alarm_sig[:,channel_index]) 617 | dlfmax = max(dlfmax, channel_dlfmax) 618 | 619 | # Initialize R vector to a value based on the D_lfmax (duration of low frequency) 620 | if dlfmax > parameters.VFIB_DLFMAX_LIMIT: 621 | r_vector_value = 1. 622 | else: 623 | r_vector_value = 0. 624 | size = int((alarm_duration - window_size) / rolling_increment) + 1 625 | r_vector = [r_vector_value] * size 626 | 627 | # Adjust R vector based on whether standard deviation of ABP channel is > or < the threshold 628 | for channel_index in abp_channels: 629 | r_delta = get_abp_std_scores(alarm_sig[:,int(channel_index)], std_threshold, window_size, rolling_increment) 630 | r_vector = r_vector + r_delta 631 | 632 | # Adjust R vector based on dominant frequency in signal 633 | for channel_index in ecg_channels: 634 | channel_index = int(channel_index) 635 | 636 | dominant_freqs = get_dominant_freq_array(alarm_sig[:,channel_index]) 637 | regular_activity = get_regular_activity_array(alarm_sig, fields, ann_path, sample_name, ecg_ann_type) 638 | adjusted_dominant_freqs = adjust_dominant_freqs(dominant_freqs, regular_activity) 639 | 640 | new_r_vector = np.array([]) 641 | for dominant_freq, r_value in zip(adjusted_dominant_freqs, r_vector): 642 | if dominant_freq < parameters.VFIB_DOMINANT_FREQ_THRESHOLD: 643 | new_r_vector = np.append(new_r_vector, 0.) 
644 |             else:
645 |                 new_r_vector = np.append(new_r_vector, r_value)
646 | 
647 |         r_vector = new_r_vector
648 | 
649 |     return any([ r_value > 0 for r_value in r_vector ])
650 | 
651 | 
652 | ###############################################
653 | ##### Specific arrhythmias - helper funcs #####
654 | ###############################################
655 | 
656 | def calc_summed_asystole_score(ann_path,
657 |                                sample_name,
658 |                                subsig,
659 |                                channels,
660 |                                ecg_ann_type,
661 |                                current_start,
662 |                                current_end,
663 |                                verbose=False,
664 |                                data_fs=parameters.DEFAULT_ECG_FS):
665 |     summed_score = 0
666 | 
667 |     for channel_index, channel in enumerate(channels):
668 |         channel_type = get_channel_type(channel)
669 |         if channel_type == "Resp":
670 |             continue
671 | 
672 |         channel_subsig = subsig[:,channel_index]
673 |         invalids = calculate_channel_invalids(channel_subsig, channel_type)
674 |         cval = calculate_cval_channel(invalids)
675 | 
676 |         ann_type = get_ann_type(channel, channel_index, ecg_ann_type)
677 |         ann_fs = get_ann_fs(channel_type, ecg_ann_type)
678 | 
679 |         annotation = get_annotation(ann_path + sample_name, ann_type, ann_fs, current_start, current_end)
680 | 
681 |         if len(annotation.annsamp) > 0:  # detected beats argue against asystole
682 |             current_score = -cval
683 |         else:  # absence of beats in a valid window supports asystole
684 |             current_score = cval
685 | 
686 |         if verbose:
687 |             plt.figure(figsize=[7,5])
688 |             plt.plot(channel_subsig, 'g-')
689 |             annotation_seconds = np.array(annotation.annsamp) / float(ann_fs)
690 |             ann_x = [ (seconds - current_start) * data_fs for seconds in annotation_seconds ]
691 |             ann_y = [ channel_subsig[int(index)] for index in ann_x ]
692 |             plt.plot(ann_x, ann_y, 'bo', markersize=8)
693 |             plt.show()
694 | 
695 |             print(sample_name + ": " + channel + " [" + str(current_start) + ", " + str(current_end) + "] " + str(current_score))
696 | 
697 |         summed_score += current_score
698 | 
699 |     return summed_score
700 | 
701 | 
702 | def get_rr_intervals_list(ann_path, sample_name, ecg_ann_type, fields, start, end):
703 |     channels = fields['signame']
704 | 
705 |     rr_intervals_list = []
706 | 
707 |     for channel_index in range(len(channels)):
708 |         channel_name = channels[channel_index]
709 |         channel_type = get_channel_type(channel_name)
710 |         if channel_type == "Resp":
711 |             continue
712 | 
713 |         rr_intervals = get_channel_rr_intervals(ann_path, sample_name, channel_index, fields, ecg_ann_type, start, end)
714 |         rr_intervals_list.append(rr_intervals)
715 | 
716 |     return rr_intervals_list
717 | 
718 | 
719 | # Precondition: len(rr_intervals_list) > 0
720 | # Return the RR intervals with the minimum stdev of all the RR intervals in the list
721 | def min_stdev_rr_intervals(rr_intervals_list):
722 |     opt_rr_intervals = []
723 |     min_stdev = float('inf')
724 | 
725 |     for rr_intervals in rr_intervals_list:
726 |         stdev = np.std(rr_intervals)
727 |         if stdev < min_stdev:
728 |             opt_rr_intervals = rr_intervals
729 |             min_stdev = stdev
730 | 
731 |     return opt_rr_intervals
732 | 
733 | 
734 | # Best channel: minimum stdev with acceptable RR interval sum and count
735 | # If none have an acceptable RR interval sum and count --> select minimum stdev out of all RR intervals
736 | def find_best_channel(rr_intervals_list, alarm_duration):
737 |     count_and_sum = []
738 |     only_one_test = []
739 |     for rr_intervals in rr_intervals_list:
740 |         sum_check = check_sum_rr_intervals(rr_intervals, alarm_duration)
741 |         num_check = check_num_rr_intervals(rr_intervals)
742 | 
743 |         if sum_check and num_check:
744 |             count_and_sum.append(rr_intervals)
745 | 
746 |         elif sum_check or num_check:
747 |             only_one_test.append(rr_intervals)
748 | 
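    # Selection cascade: prefer channels that pass both the sum and count checks,
    # then channels that pass exactly one, and otherwise fall back to all channels;
    # within each tier the channel with the minimum RR-interval stdev wins.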
749 | if len(count_and_sum) > 0: 750 | return min_stdev_rr_intervals(count_and_sum) 751 | 752 | if len(only_one_test) > 0: 753 | return min_stdev_rr_intervals(only_one_test) 754 | 755 | return min_stdev_rr_intervals(rr_intervals_list) 756 | 757 | 758 | def get_min_hr(rr_intervals, num_beats_per_block): 759 | min_hr = float('inf') 760 | 761 | for index in range(num_beats_per_block, len(rr_intervals)): 762 | subsection = rr_intervals[index - num_beats_per_block:index] 763 | hr = float(num_beats_per_block) / sum(subsection) * parameters.NUM_SECS_IN_MIN 764 | min_hr = min(min_hr, hr) 765 | 766 | return min_hr 767 | 768 | 769 | def check_tachycardia_channel(rr_intervals_list, alarm_duration): 770 | for rr_intervals in rr_intervals_list: 771 | stdev_check = check_rr_stdev(rr_intervals) 772 | sum_check = check_sum_rr_intervals(rr_intervals, alarm_duration) 773 | hr_check = check_heart_rate(rr_intervals, alarm_duration) 774 | if stdev_check and sum_check and not hr_check: 775 | return True 776 | 777 | return False 778 | 779 | 780 | def get_max_hr(rr_intervals, num_beats_per_block): 781 | max_hr = -float('inf') 782 | 783 | for index in range(num_beats_per_block, len(rr_intervals)): 784 | subsection = rr_intervals[index - num_beats_per_block:index] 785 | hr = float(num_beats_per_block) / sum(subsection) * parameters.NUM_SECS_IN_MIN 786 | max_hr = max(max_hr, hr) 787 | 788 | return max_hr 789 | 790 | 791 | def hilbert_transform(x, fs, f_low, f_high, demod=False): 792 | N = len(x) 793 | f = scipy_fftpack.fft(x, n=N) 794 | i_high = int(np.floor(float(f_high)/fs*N)) 795 | i_low = int(np.floor(float(f_low)/fs*N)) 796 | win = scipy_signal.hamming( i_high - i_low ) 797 | 798 | f[0:i_low] = 0 799 | f[i_low:i_high] = f[i_low:i_high]*win 800 | f[i_high+1:] = 0 801 | 802 | if demod==True: 803 | # demodulate the signal, i.e. 
shift the freq spectrum to 0 804 | i_mid = int( np.floor((i_high+i_low)/2.0) ) 805 | f = np.concatenate( [f[i_mid:i_high], np.zeros(len(f)-(i_high-i_low)), f[i_low:i_mid] ] ) 806 | 807 | return 2*np.abs(scipy_fftpack.ifft(f, n=N)) 808 | 809 | 810 | def get_lf_sub(channel_sig, order): 811 | lf = abs(hilbert_transform(channel_sig, parameters.DEFAULT_ECG_FS, parameters.LF_LOW, parameters.LF_HIGH)) 812 | mf = abs(hilbert_transform(channel_sig, parameters.DEFAULT_ECG_FS, parameters.MF_LOW, parameters.MF_HIGH)) 813 | hf = abs(hilbert_transform(channel_sig, parameters.DEFAULT_ECG_FS, parameters.HF_LOW, parameters.HF_HIGH)) 814 | sub = mf - hf 815 | 816 | return lf, sub 817 | 818 | 819 | # Return list of ventricular beats for ECG channels 820 | def ventricular_beat_annotations(lf_subsig, sub_subsig, sample, ann_type, start_time, end_time, 821 | verbose=False, 822 | fs=parameters.DEFAULT_ECG_FS, 823 | threshold_ratio=parameters.VENTRICULAR_BEAT_THRESHOLD_RATIO, 824 | ann_fs=parameters.DEFAULT_ECG_FS): 825 | annotation = get_annotation(sample, ann_type, ann_fs, start_time, end_time) 826 | 827 | single_peak_indices = [ index - ann_fs * start_time for index in annotation.annsamp ] 828 | 829 | ventricular_beat_indices = np.array([]) 830 | nonventricular_beat_indices = np.array([]) 831 | 832 | for index in single_peak_indices: 833 | if index >= len(lf_subsig) or index >= len(sub_subsig): 834 | continue 835 | 836 | index = int(index) 837 | if lf_subsig[index] > sub_subsig[index]: 838 | ventricular_beat_indices = np.append(ventricular_beat_indices, index) 839 | else: 840 | nonventricular_beat_indices = np.append(nonventricular_beat_indices, index) 841 | 842 | if verbose: 843 | plt.figure(figsize=[8,5]) 844 | plt.plot(sub_subsig,'b-') 845 | plt.plot(lf_subsig,'r-') 846 | plt.plot(nonventricular_beat_indices, [sub_subsig[int(index)] for index in nonventricular_beat_indices], 'bo', markersize=8) 847 | plt.plot(ventricular_beat_indices, [ lf_subsig[int(index)] for index in ventricular_beat_indices ], 'ro', markersize=8) 848 | plt.show() 849 | 850 | return ventricular_beat_indices, nonventricular_beat_indices 851 | 852 | 853 | def max_ventricular_hr(ventricular_beats, num_beats, fs): 854 | max_hr = 0 855 | 856 | if len(ventricular_beats) < num_beats: 857 | return max_hr 858 | 859 | for index in range(num_beats-1, len(ventricular_beats)): 860 | sublist = ventricular_beats[index-num_beats+1:index] 861 | start_time = ventricular_beats[index-num_beats+1] / fs 862 | end_time = ventricular_beats[index] / fs 863 | 864 | hr = (num_beats-1) / (end_time - start_time) * parameters.NUM_SECS_IN_MIN 865 | max_hr = max(hr, max_hr) 866 | 867 | return max_hr 868 | 869 | 870 | ##### Modify this method to run regular baseline algorithm, using std vs. 
bank ventricular beat annotations 871 | def read_ventricular_beat_annotations(sample_name, metric, output_path=parameters.output_path_bank): 872 | ventricular_beats = [] 873 | nonventricular_beats = [] 874 | 875 | try: 876 | with open(output_path + sample_name + "_" + metric + ".csv", 'r') as f: 877 | reader = csv.DictReader(f) 878 | 879 | for row in reader: 880 | if row['is_true_beat'] == '1': 881 | ventricular_beats.append(int(row['ann_index'])) 882 | else: 883 | nonventricular_beats.append(int(row['ann_index'])) 884 | 885 | except Exception: 886 | with open(output_path + sample_name + ".csv", 'r') as f: 887 | reader = csv.DictReader(f) 888 | 889 | for row in reader: 890 | if row['is_true_beat'] == '1': 891 | ventricular_beats.append(int(row['ann_index'])) 892 | else: 893 | nonventricular_beats.append(int(row['ann_index'])) 894 | 895 | return ventricular_beats, nonventricular_beats 896 | 897 | 898 | def get_ventricular_beats_scores(channel_sig, 899 | ann_path, 900 | sample_name, 901 | ann_type, 902 | initial_start_time, 903 | initial_end_time, 904 | channel_name, 905 | fs=parameters.DEFAULT_ECG_FS, 906 | order=parameters.ORDER, 907 | max_hr_threshold=parameters.VTACH_MAX_HR, 908 | num_beats=parameters.VTACH_NUM_BEATS, 909 | window_size=parameters.VTACH_WINDOW_SIZE, 910 | rolling_increment=parameters.VTACH_ROLLING_INCREMENT): 911 | r_delta = np.array([]) 912 | end = window_size * fs 913 | 914 | lf, sub = get_lf_sub(channel_sig, order) 915 | 916 | while end <= channel_sig.size: 917 | start = end - window_size * fs 918 | start_index, end_index = int(start), int(end) 919 | 920 | channel_subsig = channel_sig[start_index:end_index] 921 | lf_subsig = lf[start_index:end_index] 922 | sub_subsig = sub[start_index:end_index] 923 | start_time = initial_start_time + start/fs 924 | end_time = start_time + window_size 925 | 926 | try: 927 | 928 | ##### Comment out the next line and uncomment the following line to use ventricular beat annotations generated beat-by-beat using DTW 929 | ventricular_beats, nonventricular_beats = ventricular_beat_annotations( 930 | lf_subsig, sub_subsig, ann_path + sample_name, ann_type, start_time, end_time) 931 | # ventricular_beats, nonventricular_beats = read_ventricular_beat_annotations(sample_name, "min") 932 | 933 | max_hr = max_ventricular_hr(ventricular_beats, num_beats, fs) 934 | 935 | invalids = calculate_channel_invalids(channel_subsig, "ECG") 936 | cval = calculate_cval_channel(invalids) 937 | 938 | if max_hr > max_hr_threshold: 939 | r_delta = np.append(r_delta, cval) 940 | else: 941 | r_delta = np.append(r_delta, 0) #-cval) 942 | 943 | except Exception as e: 944 | print(sample_name) 945 | print(e) 946 | r_delta = np.append(r_delta, 1) 947 | 948 | end += (rolling_increment * fs) 949 | 950 | return r_delta 951 | 952 | 953 | def get_abp_std_scores(channel_sig, 954 | std_threshold, 955 | window_size, 956 | rolling_increment, 957 | fs=parameters.DEFAULT_ECG_FS): 958 | r_delta = np.array([]) 959 | end = window_size * fs 960 | 961 | while end <= channel_sig.size: 962 | start = end - window_size * fs 963 | start_index, end_index = int(start), int(end) 964 | 965 | channel_subsig = channel_sig[start_index:end_index] 966 | end += (rolling_increment * fs) 967 | 968 | invalids = calculate_channel_invalids(channel_subsig, "BP") 969 | cval = calculate_cval_channel(invalids) 970 | 971 | std = np.std(channel_subsig) 972 | if std > std_threshold: 973 | r_delta = np.append(r_delta, 0) #-cval) 974 | else: 975 | r_delta = np.append(r_delta, cval) 976 | 977 | return r_delta 978 | 
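# Illustrative (hypothetical) usage of the rolling-window scorers above: both
# get_ventricular_beats_scores and get_abp_std_scores slide a window_size-second
# window forward by rolling_increment seconds and emit one score per window, so
# their outputs can be summed elementwise into the R vector used by the v-tach
# and v-fib tests. The sample name below is a placeholder taken from the
# commented examples elsewhere in this repository.
#
# sig, fields = wfdb.srdsamp(data_path + "v100s")
# abp_channels = get_channels_of_type(fields['signame'], "BP")
# if len(abp_channels) > 0:
#     abp_sig = sig[:, int(abp_channels[0])]
#     r_delta = get_abp_std_scores(abp_sig,
#                                  parameters.VFIB_ABP_THRESHOLD,
#                                  parameters.VFIB_WINDOW_SIZE,
#                                  parameters.VFIB_ROLLING_INCREMENT)
#     # one entry per window: cval for windows whose ABP std is below the
#     # threshold (flat pressure supports the arrhythmia), 0 otherwise
#     print(r_delta)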
979 | 980 | def calculate_dlfmax(channel_sig, 981 | order=parameters.ORDER): 982 | lf, sub = get_lf_sub(channel_sig, order) 983 | 984 | current_dlfmax_start = None 985 | dlfmax_duration = 0 986 | prev_low_dominance = 0 987 | 988 | for index in range(len(lf)): 989 | lf_sample = lf[index] 990 | sub_sample = sub[index] 991 | 992 | if lf_sample > sub_sample: 993 | # If not yet started a low dominance area, set the start index 994 | if current_dlfmax_start is None: 995 | current_dlfmax_start = index 996 | 997 | # If a separate low dominance area, reset 998 | elif index - prev_low_dominance > parameters.VFIB_LOW_DOMINANCE_INDEX_THRESHOLD: 999 | # Calculate duration of previous low dominance area and update max dlfmax 1000 | duration = prev_low_dominance - current_dlfmax_start 1001 | dlfmax_duration = max(dlfmax_duration, duration) 1002 | 1003 | # Start new area of low dominance 1004 | current_dlfmax_start = index 1005 | 1006 | # Update previous index seen with low frequency dominance 1007 | prev_low_dominance = index 1008 | 1009 | # Handle duration at the end of the segment 1010 | if current_dlfmax_start is not None: 1011 | duration = prev_low_dominance - current_dlfmax_start 1012 | dlfmax_duration = max(dlfmax_duration, duration) 1013 | 1014 | return dlfmax_duration 1015 | 1016 | 1017 | # Get dominant freq in signal in rolling window 1018 | def get_dominant_freq_array(channel_sig, 1019 | fs=parameters.DEFAULT_ECG_FS, 1020 | window_size=parameters.VFIB_WINDOW_SIZE, 1021 | rolling_increment=parameters.VFIB_ROLLING_INCREMENT): 1022 | 1023 | end = window_size * fs 1024 | dominant_freqs = np.array([]) 1025 | 1026 | while end < channel_sig.size: 1027 | start = end - window_size * fs 1028 | 1029 | start_index, end_index = int(start), int(end) 1030 | channel_subsig = channel_sig[start_index:end_index] 1031 | end += (rolling_increment * fs) 1032 | 1033 | xf, fft = get_signal_fft(channel_subsig, window_size, fs) 1034 | 1035 | # Index of the fft is 2 * the actual frequency 1036 | dominant_freq = np.argmax(fft) / 2 1037 | 1038 | dominant_freqs = np.append(dominant_freqs, dominant_freq) 1039 | return dominant_freqs 1040 | 1041 | 1042 | def get_regular_activity_array(sig, 1043 | fields, 1044 | ann_path, 1045 | sample_name, 1046 | ecg_ann_type, 1047 | arrhythmia_type="Ventricular_Flutter_Fib", 1048 | fs=parameters.DEFAULT_ECG_FS, 1049 | window_size=parameters.VFIB_WINDOW_SIZE, 1050 | rolling_increment=parameters.VFIB_ROLLING_INCREMENT): 1051 | regular_activity_array = np.array([]) 1052 | end = window_size * fs 1053 | 1054 | while end < sig[:,0].size: 1055 | start = end - window_size * fs 1056 | start_index, end_index = int(start), int(end) 1057 | subsig = sig[start_index:end_index] 1058 | 1059 | invalids_dict = calculate_invalids_sig(subsig, fields) 1060 | rr_dict = get_rr_dict(ann_path, sample_name, fields, ecg_ann_type, start/fs, end/fs) 1061 | 1062 | is_regular = is_rr_invalids_regular(rr_dict, invalids_dict, window_size, arrhythmia_type) 1063 | if is_regular: 1064 | regular_activity_array = np.append(regular_activity_array, 1) 1065 | else: 1066 | regular_activity_array = np.append(regular_activity_array, 0) 1067 | 1068 | end += (rolling_increment * fs) 1069 | 1070 | return regular_activity_array 1071 | 1072 | 1073 | def adjust_dominant_freqs(dominant_freqs, regular_activity): 1074 | adjusted_dominant_freqs = np.array([]) 1075 | 1076 | for freq, is_regular in zip(dominant_freqs, regular_activity): 1077 | if is_regular: 1078 | adjusted_dominant_freqs = np.append(adjusted_dominant_freqs, 0) 1079 | else: 1080 
| adjusted_dominant_freqs = np.append(adjusted_dominant_freqs, freq) 1081 | 1082 | return adjusted_dominant_freqs 1083 | 1084 | 1085 | 1086 | #################### 1087 | ##### Pipeline ##### 1088 | #################### 1089 | 1090 | # Returns true if alarm is classified as a true alarm 1091 | # Uses fplesinger annotations for ventricular tachycardia test 1092 | def classify_alarm(data_path, ann_path, sample_name, ecg_ann_type, verbose=False): 1093 | sig, fields = wfdb.srdsamp(data_path + sample_name) 1094 | 1095 | is_regular = is_sample_regular(data_path, ann_path, sample_name, ecg_ann_type) 1096 | if is_regular: 1097 | return False 1098 | 1099 | alarm_type = sample_name[0] 1100 | if alarm_type == "a": 1101 | arrhythmia_test = test_asystole 1102 | 1103 | elif alarm_type == "b": 1104 | arrhythmia_test = test_bradycardia 1105 | 1106 | elif alarm_type == "t": 1107 | arrhythmia_test = test_tachycardia 1108 | 1109 | elif alarm_type == "v": 1110 | arrhythmia_test = test_ventricular_tachycardia 1111 | 1112 | elif alarm_type == "f": 1113 | arrhythmia_test = test_ventricular_flutter_fibrillation 1114 | 1115 | else: 1116 | raise Exception("Unknown arrhythmia alarm type") 1117 | 1118 | # try: 1119 | return arrhythmia_test(data_path, ann_path, sample_name, ecg_ann_type, verbose) 1120 | # except Exception as e: 1121 | # print("sample_name: ", sample_name, e) 1122 | # return True 1123 | 1124 | # if __name__ == '__main__': 1125 | # data_path = '../sample_data/challenge_training_data/' 1126 | # ann_path = '../sample_data/challenge_training_multiann/' 1127 | # ecg_ann_type = 'gqrs' 1128 | 1129 | # print(classify_alarm(data_path, ann_path, "v199l", ecg_ann_type)) 1130 | --------------------------------------------------------------------------------