├── utils ├── __pycache__ │ ├── tools.cpython-39.pyc │ ├── eval_tools.cpython-39.pyc │ ├── optimizers.cpython-39.pyc │ ├── DINO_dataloader.cpython-39.pyc │ ├── base_dataloader.cpython-39.pyc │ ├── contrastive_dataloader.cpython-39.pyc │ └── timeseries_transformations.cpython-39.pyc ├── optimizers.py ├── tools.py ├── scheduler.py ├── contrastive_dataloader.py └── eval_tools.py ├── models ├── __pycache__ │ ├── seresnet2d.cpython-39.pyc │ ├── xresnet1d.cpython-39.pyc │ ├── basic_conv1d.cpython-39.pyc │ ├── signal_model.cpython-39.pyc │ ├── ensemble_model.cpython-39.pyc │ └── spectrogram_model.cpython-39.pyc ├── spectrogram_model.py ├── resnet.py ├── resnet_simclr.py ├── signal_model.py ├── seresnet2d.py ├── seresnet.py ├── ensemble_model.py ├── xresnet1d.py ├── basic_conv1d.py ├── inception_resnet_v2.py └── se_inception_resnet_v2.py ├── experiments ├── __pycache__ │ └── signal.cpython-39.pyc ├── BYOL_signal.py ├── SIMCLR_signal.py ├── SIMCLR_signal_finetune.py ├── BYOL_signal_finetune.py ├── run_signal.py ├── run_spectrogram.py └── run_ensembled.py ├── data_folder └── evaluation-2020-master │ ├── LICENSE │ ├── evaluate_12ECG_score.m │ ├── dx_mapping_scored.csv │ ├── README.md │ ├── .gitignore │ ├── Results │ ├── physionet_2020_unofficial_scores.csv │ ├── README.md │ ├── physionet_2020_official_scores.csv │ └── physionet_2020_metrics_perDatabase_official_entries.csv │ ├── weights.csv │ └── dx_mapping_unscored.csv ├── README.md └── data_preparation ├── reformat_memmap.py ├── data_extraction_without_preprocessing.py ├── stratify.py └── data_extraction_with_preprocessing.py /utils/__pycache__/tools.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/tools.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/seresnet2d.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/seresnet2d.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/xresnet1d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/xresnet1d.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval_tools.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/eval_tools.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/optimizers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/optimizers.cpython-39.pyc -------------------------------------------------------------------------------- /experiments/__pycache__/signal.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/experiments/__pycache__/signal.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/basic_conv1d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/basic_conv1d.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/signal_model.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/signal_model.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/ensemble_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/ensemble_model.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/DINO_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/DINO_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/base_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/base_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/spectrogram_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/spectrogram_model.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/contrastive_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/contrastive_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/timeseries_transformations.cpython-39.pyc: -------------------------------------------------------------------------------- 
class spectrogram_model(nn.Module):
    """Classifier built on an ``se_resnet34`` backbone with a 12-channel stem.

    The backbone's first convolution is replaced by one that accepts 12 input
    channels (presumably one spectrogram per ECG lead -- confirm against the
    data loader).  All backbone children except the last one are reused as
    the feature extractor, and a two-layer linear head maps the extracted
    features to ``no_classes`` outputs.

    Args:
        no_classes: Number of output units produced by the head.
    """

    def __init__(self, no_classes):
        super(spectrogram_model, self).__init__()
        self.backbone = se_resnet34()
        # Replace the stem so the network takes 12 input channels.
        self.backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        # Children are collected *after* the stem swap, so ``features`` uses
        # the new convolution.  The last child is dropped (assumed to be the
        # backbone's own ``fc`` layer -- TODO confirm in seresnet2d).
        list_of_modules = list(self.backbone.children())
        self.features = nn.Sequential(*list_of_modules[:-1])
        num_ftrs = self.backbone.fc.in_features

        # NOTE(review): there is no non-linearity between the two Linear
        # layers; kept as-is so existing checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=num_ftrs // 2),
            nn.Linear(in_features=num_ftrs // 2, out_features=no_classes),
        )

    def forward(self, x):
        """Return the head's output for input batch ``x``.

        The feature map is flattened from dim 1 onward rather than squeezed:
        ``squeeze()`` would also remove the batch axis when the batch holds a
        single sample, producing an un-batched output.
        """
        h = self.features(x)
        h = h.flatten(1)
        return self.fc(h)
class LARS(torch.optim.Optimizer):
    """SGD with momentum and a layer-wise adaptive rate scaling of the update.

    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py

    Parameters with ``ndim == 1`` are updated with plain momentum SGD: they
    receive neither weight decay nor the LARS scaling.  All other parameters
    get weight decay added to the gradient, and the result is rescaled by
    ``eta * ||p|| / ||update||`` when both norms are non-zero.

    Args:
        params: Iterable of parameters or parameter groups.
        lr: Learning rate.
        weight_decay: Coefficient of the decay term added to the gradient.
        momentum: Momentum factor applied to the running update ``mu``.
        eta: LARS trust coefficient.
        weight_decay_filter: Stored in the group defaults; not consulted by
            ``step`` (exclusion is decided by ``p.ndim`` instead).
        lars_adaptation_filter: Stored in the group defaults; not consulted
            by ``step``.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss, as accepted by ``torch.optim.Optimizer.step``.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs gradients.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim != 1:
                    # Weight decay, then the layer-wise trust ratio.
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Fall back to a ratio of 1 if either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss
class ResNetSimCLR(nn.Module):
    """Encoder plus projection head returning ``(features, projection)``.

    Args:
        base_model: Key selecting the encoder; one of ``resnet18``,
            ``resnet50``, ``xresnet1d50`` or ``xresnet1d101``.
        out_dim: Output size of the projection head.
        widen: Forwarded to the ``xresnet1d*`` constructors.
        hidden: If true the head is Linear -> ReLU -> Linear, otherwise a
            single Linear layer.

    Raises:
        ValueError: If ``base_model`` is not a known encoder name.
    """

    def __init__(self, base_model, out_dim, widen=1.0, hidden=False):
        super(ResNetSimCLR, self).__init__()
        # NOTE(review): every candidate encoder is instantiated here even
        # though only one is used; kept because the dict is a public
        # attribute, but it costs construction time and memory.
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False),
                            "xresnet1d50": xresnet1d50(widen=widen),
                            "xresnet1d101": xresnet1d101(widen=widen)}

        resnet = self._get_basemodel(base_model)
        self.base_model = base_model

        list_of_modules = list(resnet.children())
        if "xresnet" in base_model:
            # Keep the body plus the first layer of the final child
            # (presumably the pooling layer of the head -- TODO confirm in
            # xresnet1d); the rest of that child is replaced by l1/l2.
            self.features = nn.Sequential(*list_of_modules[:-1], list_of_modules[-1][0])
            num_ftrs = resnet[-1][-1].in_features
            # ``features`` holds references to the same child modules, so
            # swapping the first conv here also changes ``self.features``.
            resnet[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)
        else:
            self.features = nn.Sequential(*list_of_modules[:-1])
            num_ftrs = resnet.fc.in_features

        # projection MLP
        if hidden:
            self.l1 = nn.Linear(num_ftrs, num_ftrs)
            self.l2 = nn.Linear(num_ftrs, out_dim)
        else:
            self.l1 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the encoder registered under ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            valid = ", ".join(sorted(self.resnet_dict))
            raise ValueError(
                "Invalid model name {!r}. Check the config file and pass one of: {}".format(
                    model_name, valid
                )
            ) from None

    def forward(self, x):
        """Return ``(h, z)``: flattened encoder features and their projection."""
        h = self.features(x)
        # flatten(1) instead of squeeze(): squeeze() would also drop the
        # batch axis for a batch of one sample.
        h = h.flatten(1)

        x = self.l1(h)
        # ``l2`` only exists when the head was built with hidden=True.
        if getattr(self, "l2", None) is not None:
            x = F.relu(x)
            x = self.l2(x)
        return h, x
32 | switch nargin 33 | case 2 34 | command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs]; 35 | case 3 36 | command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs ' ' output_file]; 37 | case 4 38 | command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs ' ' output_file ' ' class_output_file]; 39 | otherwise 40 | command = ''; 41 | end 42 | 43 | % Evaluate model outputs. 44 | [~, output] = system(command); 45 | fprintf(output); 46 | end 47 | -------------------------------------------------------------------------------- /data_folder/evaluation-2020-master/dx_mapping_scored.csv: -------------------------------------------------------------------------------- 1 | Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total,Notes 2 | 1st degree av block,270492004,IAVB,722,106,0,0,797,769,2394, 3 | atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,3475, 4 | atrial flutter,164890007,AFL,0,54,0,1,73,186,314, 5 | bradycardia,426627000,Brady,0,271,11,0,0,6,288, 6 | complete right bundle branch block,713427006,CRBBB,0,113,0,0,542,28,683,We score 713427006 and 59118001 as the same diagnosis. 7 | incomplete right bundle branch block,713426002,IRBBB,0,86,0,0,1118,407,1611, 8 | left anterior fascicular block,445118002,LAnFB,0,0,0,0,1626,180,1806, 9 | left axis deviation,39732003,LAD,0,0,0,0,5146,940,6086, 10 | left bundle branch block,164909002,LBBB,236,38,0,0,536,231,1041, 11 | low qrs voltages,251146004,LQRSV,0,0,0,0,182,374,556, 12 | nonspecific intraventricular conduction disorder,698252002,NSIVCB,0,4,1,0,789,203,997, 13 | pacing rhythm,10370003,PR,0,3,0,0,296,0,299, 14 | premature atrial contraction,284470004,PAC,616,73,3,0,398,639,1729,We score 284470004 and 63593006 as the same diagnosis. 15 | premature ventricular contractions,427172004,PVC,0,188,0,0,0,0,188,We score 427172004 and 17338001 as the same diagnosis. 
16 | prolonged pr interval,164947007,LPR,0,0,0,0,340,0,340, 17 | prolonged qt interval,111975006,LQT,0,4,0,0,118,1391,1513, 18 | qwave abnormal,164917005,QAb,0,1,0,0,548,464,1013, 19 | right axis deviation,47665007,RAD,0,1,0,0,343,83,427, 20 | right bundle branch block,59118001,RBBB,1857,1,2,0,0,542,2402,We score 713427006 and 59118001 as the same diagnosis. 21 | sinus arrhythmia,427393009,SA,0,11,2,0,772,455,1240, 22 | sinus bradycardia,426177001,SB,0,45,0,0,637,1677,2359, 23 | sinus rhythm,426783006,NSR,918,4,0,80,18092,1752,20846, 24 | sinus tachycardia,427084000,STach,0,303,11,1,826,1261,2402, 25 | supraventricular premature beats,63593006,SVPB,0,53,4,0,157,1,215,We score 284470004 and 63593006 as the same diagnosis. 26 | t wave abnormal,164934002,TAb,0,22,0,0,2345,2306,4673, 27 | t wave inversion,59931005,TInv,0,5,1,0,294,812,1112, 28 | ventricular premature beats,17338001,VPB,0,8,0,0,0,357,365,We score 427172004 and 17338001 as the same diagnosis. 29 | -------------------------------------------------------------------------------- /data_folder/evaluation-2020-master/README.md: -------------------------------------------------------------------------------- 1 | # PhysioNet/CinC Challenge 2020 Evaluation Metrics 2 | 3 | This repository contains the Python and MATLAB evaluation code for the PhysioNet/Computing in Cardiology Challenge 2020. The `evaluate_12ECG_score` script evaluates the output of your algorithm using the evaluation metric that is described on the [webpage](https://physionetchallenges.github.io/2020/) for the PhysioNet/CinC Challenge 2020. While this script reports multiple evaluation metric, we use the last score (`Challenge Metric`) to evaluate your algorithm. 
4 | 5 | ## Python 6 | 7 | You can run the Python evaluation code by installing the NumPy Python package and running 8 | 9 | python evaluate_12ECG_score.py labels outputs scores.csv class_scores.csv 10 | 11 | where `labels` is a directory containing files with one or more labels for each 12-lead ECG recording, such as the training database on the PhysioNet webpage; `outputs` is a directory containing files with outputs produced by your algorithm for those recordings; `scores.csv` (optional) is a collection of scores for your algorithm; and `class_scores.csv` (optional) is a collection of per-class scores for your algorithm. 12 | 13 | ## MATLAB 14 | 15 | You can run the MATLAB evaluation code by installing Python and the NumPy Python package and running 16 | 17 | evaluate_12ECG_score(labels, outputs, scores.csv, class_scores.csv) 18 | 19 | where `labels` is a directory containing files with one or more labels for each 12-lead ECG recording, such as the training database on the PhysioNet webpage; `outputs` is a directory containing files with outputs produced by your algorithm for those recordings; `scores.csv` (optional) is a collection of scores for your algorithm; and `class_scores.csv` (optional) is a collection of per-class scores for your algorithm. 20 | 21 | ## Troubleshooting 22 | 23 | Unable to run this code with your code? Try one of the [baseline classifiers](https://physionetchallenges.github.io/2020/#submissions) on the [training data](https://physionetchallenges.github.io/2020/#data). Unable to install or run Python? Try [Python](https://www.python.org/downloads/), [Anaconda](https://www.anaconda.com/products/individual), or your package manager. 
24 | -------------------------------------------------------------------------------- /data_folder/evaluation-2020-master/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /data_folder/evaluation-2020-master/Results/physionet_2020_unofficial_scores.csv: -------------------------------------------------------------------------------- 1 | Team name,CinC Abstract #,Validation Set Score,Hidden CPSC Set Score,Hidden G12EC Set Score,Hidden Undisclosed Set Score,Test Set Score,Training code produces output?,Model uses output from training code?,Open-source license?,Registered at CinC?,Preprint at CinC?,Presented at CinC? 
2 | AAIST,No abstract,0.507,0.674,0.485,0.275,0.377,Y,Y,BSD 2 License,N,N,N 3 | AImsterdam,327,0.609,0.636,0.252,-0.093,0.198,N,N,Unknown,Y,N,N 4 | BERCLAB UND,79,0.197,0.564,0.127,0.106,0.141,Y,Y,Unknown,Y,N,N 5 | BME_Feng,69,0.001,0.003,0.001,-0.030,-0.016,Y,Y,BSD 2 License,N,N,N 6 | BraveHeart400,83,0.449,0.657,0.413,-0.265,0.034,Y,Y,BSD 2 License,N,N,N 7 | BRIC,331,0.539,0.652,0.189,0.030,0.127,Y,Y,BSD 2 License,N,N,N 8 | Chapman,No abstract,-0.204,-0.282,-0.190,0.004,-0.084,Y,Y,BSD 2 License,N,N,N 9 | Connected_Health,176,0.566,0.703,0.541,0.417,0.479,Y,Y,BSD 2 License,N (rejected abstract),N,N 10 | Health team Szeged,48,0.493,0.423,0.505,0.472,0.480,Y,Y,BSD 2 License,N,N,N 11 | IBMTpeakyFinders,173,0.282,-0.072,-0.052,-0.426,-0.269,Y,Y,BSD 2 License,N,N,N 12 | Kimball_IRL,31,0.178,0.444,0.138,-0.211,-0.042,Y,Y,BSD 2 License,Y,N,Y - poster 13 | LaussenLabs,353,-0.406,-0.455,-0.390,-0.848,-0.658,Y,N,BSD 2 License,Y,Y,Y - poster 14 | LIST_AIHealthCare,120,0.216,0.152,0.229,0.173,0.192,Y,Y,BSD 2 License,N,N,N 15 | Marquette,74,0.511,0.458,0.521,0.478,0.492,Y,Y,BSD 2 License,Y,Y (but after deadline),Y - poster 16 | Medics,187,0.189,0.480,0.146,NaN,NaN,Y,Y,BSD 2 License,N,N,N 17 | MetaHeart,196,0.616,0.758,0.590,0.194,0.370,Y,Y,BSD 2 License,Y,Y,Y - poster (but no response to questions) 18 | Metformin-121,136,0.623,0.865,0.586,0.413,0.505,Y,Y,BSD 2 License,N,N,N 19 | ML Warriors,412,0.389,0.395,0.390,0.181,0.269,N,N,BSD 2 License,N,N,N 20 | NACAS_12X,180,0.645,0.846,0.202,0.000,0.127,Y,Y,BSD 2 License,Y,Y (but did not update preprint),Y - poster 21 | nebula,39,0.526,0.736,0.086,0.052,0.109,N,N,BSD 2 License,Y,Y,Y - poster 22 | NN-MIH,63,0.585,0.665,0.567,0.367,0.456,N,N,BSD 2 License,Y,Y,Y - poster 23 | NTU-Accesslab,72,0.544,0.725,0.510,NaN,NaN,Y,Y,BSD 2 License,Y,Y,Y - poster 24 | Orange Peel,145,0.650,0.813,0.621,0.161,0.364,Y,Y ,BSD 2 License,N (no abstract submission),N,N 25 | SBU_AI,307,0.416,0.513,0.016,-0.028,0.024,N,N,BSD 2 License,Y,Y,Y - 3rd talk in 2nd 
session 26 | SpaceOn Flattop,7,0.681,0.871,0.219,0.126,0.208,N,N,BSD 2 License,Y,Y,Y - poster 27 | try again,No abstract,0.072,0.261,-0.266,-0.753,-0.553,Y,Y,BSD 2 License,N,N,N 28 | UniA4Life,314,-0.105,0.250,-0.156,-0.523,-0.339,N,N,GNU GPL V3 License,N,N,N 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Multimodality Multi-Lead ECG Arrhythmia Classification using Self-Supervised Learning 2 | 3 | Paper link: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9926925 4 | 5 | 1. Download datasets from the PhysioNet 2020 Competition. Put in the folder ./data_folder/datasets and extract all of them . 6 | https://physionetchallenges.github.io/2020/ 7 | 8 | 2. Preparing the data 9 | python data_preparation/data_extraction_without_preprocessing.py 10 | python data_preparation/reformat_memmap.py 11 | 12 | 3. Training base models 13 | python experiments/run_signal.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --save_folder ./checkpoints/base_signal 14 | python experiments/run_spectrogram.py --batch_size 256 --lr_rate 5e-3 --num_epoches 200 --gpu 0 --save_folder ./checkpoints/base_spectrogram 15 | (without gating fusion) 16 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --save_folder ./checkpoints/base_ensemble_wogating 17 | (with gating fusion) 18 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --gating --save_folder ./checkpoints/base_ensemble_wgating 19 | 20 | 4. Self-supervised learning for pretrained models 21 | (SimCLR) 22 | python experiments/SIMCLR_signal.py 23 | (BYOL) 24 | python experiments/BYOL_signal.py 25 | (DINO) 26 | python experiments/DINO_signal.py 27 | python experiments/DINO_spectrogram.py 28 | 29 | 5. 
Finetuning the main model based on the self-supervised pretrained models 30 | (SimCLR) 31 | python experiments/SIMCLR_signal_finetune.py 32 | (BYOL) 33 | python experiments/BYOL_signal_finetune.py 34 | (DINO) 35 | python experiments/run_signal.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints/DINO_signal_student.pth --save_folder ./checkpoints/finetune_signal 36 | python experiments/run_spectrogram.py --batch_size 256 --lr_rate 5e-3 --num_epoches 200 --gpu 0 --finetune ./checkpoints/DINO_spectrogram_student.pth --save_folder ./checkpoints/finetune_spectrogram 37 | (without gating fusion) 38 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints --save_folder ./checkpoints/finetune_ensemble_wogating 39 | (with gating fusion) 40 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints --gating --save_folder ./checkpoints/finetune_ensemble_wgating 41 | 42 | 6. 
def weights_init_xavier(m):
    """Initialise a single module in place; intended for ``model.apply``.

    * Conv1d / Conv2d: Xavier-normal weights, standard-normal bias (if any).
    * BatchNorm1d / BatchNorm2d: weights ~ N(1, 0.02), bias 0.  Layers built
      with ``affine=False`` have no weight/bias and are left untouched.
    * Linear: Xavier-normal weights; the bias keeps its existing values.

    Any other module type is ignored.
    """
    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        torch.nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # weight and bias are None when the layer was created with affine=False.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight.data)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Return a per-iteration schedule: linear warm-up, then cosine decay.

    Args:
        base_value: Value reached at the end of warm-up / start of the decay.
        final_value: Value the cosine decay tends towards.
        epochs: Total number of epochs.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Epochs of linear warm-up; may be fractional, in which
            case the iteration count is truncated to an integer.
        start_warmup_value: Value at the first warm-up iteration.

    Returns:
        1-D ``numpy`` array of length ``epochs * niter_per_ep``.
    """
    warmup_schedule = np.array([])
    # np.linspace needs an integer sample count, so truncate here.
    warmup_iters = int(warmup_epochs * niter_per_ep)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    # Sanity check: fails when warm-up is longer than the whole run.
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = val


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Drop gradients of parameters whose name contains ``"last_layer"``.

    Does nothing once ``epoch >= freeze_last_layer``.  Meant to be called
    after ``backward()`` and before the optimizer step.
    """
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None


def open_all_layers(model):
    """Make every parameter of ``model`` trainable."""
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """Train only the named direct children of ``model``; freeze the rest.

    Args:
        model: Module whose direct children are (un)frozen.
        open_layers: A child name or a list of child names to keep trainable.

    Raises:
        AssertionError: If a requested name is not an attribute of ``model``.
    """
    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer in open_layers:
        assert hasattr(
            model, layer
        ), '"{}" is not an attribute of the model, please provide the correct name'.format(
            layer
        )

    for name, module in model.named_children():
        trainable = name in open_layers
        for p in module.parameters():
            p.requires_grad = trainable
__Validation Set:__ Includes recordings from the hidden CPSC and G12EC sets. 10 | 2. __Hidden CPSC Set:__ Split between the validation and test sets. 11 | 3. __Hidden G12EC Set:__ Split between the validation and test sets. 12 | 4. __Hidden Undisclosed Set:__ All recordings were part of the test sets. 13 | 5. __Test Set:__ Includes recordings from the hidden CPSC, G12EC, and undisclosed test sets. 14 | 15 | To refer to these tables in a publication, please cite [Perez Alday EA, Gu A, Shah AJ, Robichaux C, Wong AI, Liu C, Liu F, Rad AB, Elola A, Seyedi S, Li Q, Sharma A, Clifford GD*, Reyna MA*. Classification of 12-lead ECGs: the PhysioNet/Computing in Cardiology Challenge 2020. Physiol Meas. 41 (2020). doi: 10.1088/1361-6579/abc960](https://iopscience.iop.org/article/10.1088/1361-6579/abc960). 16 | 17 | 1. Official entries that were scored on the validation and test data and ranked in the Challenge: 18 | [physionet_2020_official_scores.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_official_scores.csv) 19 | 2. Unofficial entries that were scored on the validation and test data but unranked because they did not satisfy all of the [rules](https://physionetchallenges.github.io/2020/#rules-and-deadlines) or were unsuccessful on one or more of the test sets: 20 | [physionet_2020_unofficial_scores.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_unofficial_scores.csv) 21 | 3. Challenge and other scoring metrics for all official entries, broken down by database in the validation and test data: 22 | [physionet_2020_full_metrics_official_entries.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_full_metrics_official_entries.csv) 23 | 4. 
Per-class scoring metrics on the validation data: 24 | [physionet_2020_validation_metrics_by_class_official_entries.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_validation_metrics_by_class_official_entries.csv) 25 | -------------------------------------------------------------------------------- /data_preparation/reformat_memmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | def npys_to_memmap(npys, target_filename, max_len=0, delete_npys=True): 8 | memmap = None 9 | start = []#start_idx in current memmap file 10 | length = []#length of segment 11 | filenames= []#memmap files 12 | file_idx=[]#corresponding memmap file for sample 13 | shape=[] 14 | 15 | for idx,npy in tqdm(list(enumerate(npys))): 16 | data = np.load(npy, allow_pickle=True) 17 | if(memmap is None or (max_len>0 and start[-1]+length[-1]>max_len)): 18 | filenames.append(target_filename) 19 | 20 | if(memmap is not None):#an existing memmap exceeded max_len 21 | shape.append([start[-1]+length[-1]]+[l for l in data.shape[1:]]) 22 | del memmap 23 | #create new memmap 24 | start.append(0) 25 | length.append(data.shape[0]) 26 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape) 27 | else: 28 | #append to existing memmap 29 | start.append(start[-1]+length[-1]) 30 | length.append(data.shape[0]) 31 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple([start[-1]+length[-1]]+[l for l in data.shape[1:]])) 32 | 33 | #store mapping memmap_id to memmap_file_id 34 | file_idx.append(len(filenames)-1) 35 | #insert the actual data 36 | memmap[start[-1]:start[-1]+length[-1]]=data[:] 37 | memmap.flush() 38 | if(delete_npys is True): 39 | npy.unlink() 40 | del memmap 41 | 42 | #append final shape if necessary 43 | if(len(shape)= self.cur_cycle_steps: 68 | self.cycle += 1 69 
| self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 71 | else: 72 | if epoch >= self.first_cycle_steps: 73 | if self.cycle_mult == 1.: 74 | self.step_in_cycle = epoch % self.first_cycle_steps 75 | self.cycle = epoch // self.first_cycle_steps 76 | else: 77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 78 | self.cycle = n 79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 81 | else: 82 | self.cur_cycle_steps = self.first_cycle_steps 83 | self.step_in_cycle = epoch 84 | 85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 86 | self.last_epoch = math.floor(epoch) 87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 88 | param_group['lr'] = lr -------------------------------------------------------------------------------- /data_folder/evaluation-2020-master/dx_mapping_unscored.csv: -------------------------------------------------------------------------------- 1 | Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total 2 | 2nd degree av block,195042002,IIAVB,0,21,0,0,14,23,58 3 | abnormal QRS,164951009,abQRS,0,0,0,0,3389,0,3389 4 | accelerated junctional rhythm,426664006,AJR,0,0,0,0,0,19,19 5 | acute myocardial infarction,57054005,AMI,0,0,6,0,0,0,6 6 | acute myocardial ischemia,413444003,AMIs,0,1,0,0,0,1,2 7 | anterior ischemia,426434006,AnMIs,0,0,0,0,44,281,325 8 | anterior myocardial infarction,54329005,AnMI,0,62,0,0,354,0,416 9 | atrial bigeminy,251173003,AB,0,0,3,0,0,0,3 10 | atrial fibrillation and flutter,195080001,AFAFL,0,39,0,0,0,2,41 11 | atrial hypertrophy,195126007,AH,0,2,0,0,0,60,62 12 | atrial pacing pattern,251268003,AP,0,0,0,0,0,52,52 13 | atrial 
tachycardia,713422000,ATach,0,15,0,0,0,28,43 14 | atrioventricular junctional rhythm,29320008,AVJR,0,6,0,0,0,0,6 15 | av block,233917008,AVB,0,5,0,0,0,74,79 16 | blocked premature atrial contraction,251170000,BPAC,0,2,3,0,0,0,5 17 | brady tachy syndrome,74615001,BTS,0,1,1,0,0,0,2 18 | bundle branch block,6374002,BBB,0,0,1,20,0,116,137 19 | cardiac dysrhythmia,698247007,CD,0,0,0,16,0,0,16 20 | chronic atrial fibrillation,426749004,CAF,0,1,0,0,0,0,1 21 | chronic myocardial ischemia,413844008,CMI,0,161,0,0,0,0,161 22 | complete heart block,27885002,CHB,0,27,0,0,16,8,51 23 | congenital incomplete atrioventricular heart block,204384007,CIAHB,0,0,0,2,0,0,2 24 | coronary heart disease,53741008,CHD,0,0,16,21,0,0,37 25 | decreased qt interval,77867006,SQT,0,1,0,0,0,0,1 26 | diffuse intraventricular block,82226007,DIB,0,1,0,0,0,0,1 27 | early repolarization,428417006,ERe,0,0,0,0,0,140,140 28 | fusion beats,13640000,FB,0,0,7,0,0,0,7 29 | heart failure,84114007,HF,0,0,0,7,0,0,7 30 | heart valve disorder,368009,HVD,0,0,0,6,0,0,6 31 | high t-voltage,251259000,HTV,0,1,0,0,0,0,1 32 | idioventricular rhythm,49260003,IR,0,0,2,0,0,0,2 33 | incomplete left bundle branch block,251120003,ILBBB,0,42,0,0,77,86,205 34 | indeterminate cardiac axis,251200008,ICA,0,0,0,0,156,0,156 35 | inferior ischaemia,425419005,IIs,0,0,0,0,219,451,670 36 | inferior ST segment depression,704997005,ISTD,0,1,0,0,0,0,1 37 | junctional escape,426995002,JE,0,4,0,0,0,5,9 38 | junctional premature complex,251164006,JPC,0,2,0,0,0,0,2 39 | junctional tachycardia,426648003,JTach,0,2,0,0,0,4,6 40 | lateral ischaemia,425623009,LIs,0,0,0,0,142,903,1045 41 | left atrial abnormality,253352002,LAA,0,0,0,0,0,72,72 42 | left atrial enlargement,67741000119109,LAE,0,1,0,0,427,870,1298 43 | left atrial hypertrophy,446813000,LAH,0,40,0,0,0,0,40 44 | left posterior fascicular block,445211001,LPFB,0,0,0,0,177,25,202 45 | left ventricular hypertrophy,164873001,LVH,0,158,10,0,2359,1232,3759 46 | left ventricular 
strain,370365005,LVS,0,1,0,0,0,0,1 47 | mobitz type i wenckebach atrioventricular block,54016002,MoI,0,0,3,0,0,0,3 48 | myocardial infarction,164865005,MI,0,376,9,368,5261,7,6021 49 | myocardial ischemia,164861001,MIs,0,384,0,0,2175,0,2559 50 | nonspecific st t abnormality,428750005,NSSTTA,0,1290,0,0,381,1883,3554 51 | old myocardial infarction,164867002,OldMI,0,1168,0,0,0,0,1168 52 | paired ventricular premature complexes,251182009,VPVC,0,0,23,0,0,0,23 53 | paroxysmal atrial fibrillation,282825002,PAF,0,0,1,1,0,0,2 54 | paroxysmal supraventricular tachycardia,67198005,PSVT,0,0,3,0,24,0,27 55 | paroxysmal ventricular tachycardia,425856008,PVT,0,0,15,0,0,0,15 56 | r wave abnormal,164921003,RAb,0,1,0,0,0,10,11 57 | rapid atrial fibrillation,314208002,RAF,0,0,0,2,0,0,2 58 | right atrial abnormality,253339007,RAAb,0,0,0,0,0,14,14 59 | right atrial hypertrophy,446358003,RAH,0,18,0,0,99,0,117 60 | right ventricular hypertrophy,89792004,RVH,0,20,0,0,126,86,232 61 | s t changes,55930002,STC,0,1,0,0,770,6,777 62 | shortened pr interval,49578007,SPRI,0,3,0,0,0,2,5 63 | sinoatrial block,65778007,SAB,0,9,0,0,0,0,9 64 | sinus node dysfunction,60423000,SND,0,0,2,0,0,0,2 65 | st depression,429622005,STD,869,57,4,0,1009,38,1977 66 | st elevation,164931005,STE,220,66,4,0,28,134,452 67 | st interval abnormal,164930006,STIAb,0,481,2,0,0,992,1475 68 | supraventricular bigeminy,251168009,SVB,0,0,1,0,0,0,1 69 | supraventricular tachycardia,426761007,SVT,0,3,1,0,27,32,63 70 | suspect arm ecg leads reversed,251139008,ALR,0,0,0,0,0,12,12 71 | transient ischemic attack,266257000,TIA,0,0,7,0,0,0,7 72 | u wave abnormal,164937009,UAb,0,1,0,0,0,0,1 73 | ventricular bigeminy,11157007,VBig,0,5,9,0,82,2,98 74 | ventricular ectopics,164884008,VEB,700,0,49,0,1154,41,1944 75 | ventricular escape beat,75532003,VEsB,0,3,1,0,0,0,4 76 | ventricular escape rhythm,81898007,VEsR,0,1,0,0,0,1,2 77 | ventricular fibrillation,164896001,VF,0,10,0,25,0,3,38 78 | ventricular flutter,111288001,VFL,0,1,0,0,0,0,1 79 
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'


class MLPHead(nn.Module):
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super(MLPHead, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size)
        )

    def forward(self, x):
        return self.net(x)


def regression_loss(x, y):
    """Return ``2 - 2 * cos(x, y)`` for every row of the two batches.

    Both inputs are L2-normalised along dim 1 first, so each entry of the
    result lies in ``[0, 4]``.
    """
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    return 2 - 2 * (x * y).sum(dim=-1)


def run():
    """Pre-train the signal encoder with BYOL and keep the best checkpoint.

    Reads the dataset summary from ``./data_folder/data_summary_without_preprocessing``
    and writes the online network with the lowest mean epoch loss to
    ``./checkpoints/BYOL_signal.pth``.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')
    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0

    transforms = ["TimeOut_difflead","GaussianNoise"]

    batch_size = 1024
    learning_rate = 1e-3
    no_epoches = 400
    # Momentum of the target-network moving average (0.996 is the other value
    # the original author experimented with).
    t_d = 0.9
    checkpoint_folder = './checkpoints'

    get_mean = np.load(os.path.join(data_folder,"mean.npy"))
    get_std = np.load(os.path.join(data_folder,"std.npy"))

    t_params = {"gaussian_scale":[0.005,0.025], "global_crop_scale": [0.5, 1.0], "local_crop_scale": [0.1, 0.5],
                "output_size": 250, "warps": 3, "radius": 10, "shift_range":[0.2,0.5],
                "epsilon": 10, "magnitude_range": [0.5, 2], "downsample_ratio": 0.2, "to_crop_ratio_range": [0.2, 0.4],
                "bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1, "stats_mean":get_mean,"stats_std":get_std}

    train_dataset = ECG_contrastive_dataset(summary_folder=data_folder, signal_size=signal_size, stride=train_stride,
                                            chunk_length=train_chunk_length, transforms=transforms,t_params=t_params,
                                            equivalent_classes=equivalent_classes, sample_items_per_record=1,random_crop=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size,drop_last=True)

    no_classes = 24
    online_network = signal_model_byol(no_classes)
    target_network = deepcopy(online_network)
    online_network.to(ctx)
    target_network.to(ctx)

    # The target network is only updated through the moving average below.
    set_requires_grad(target_network,False)

    optimizer = torch.optim.Adam(online_network.parameters(),lr=learning_rate)
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    # torch.save does not create missing directories.
    os.makedirs(checkpoint_folder, exist_ok=True)

    # The summed two-term loss can exceed 2, so start from +inf to make sure
    # the first epoch always produces a checkpoint.
    lowest_train_loss = float('inf')
    for epoch in range(1,no_epoches+1):
        print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
        print('Current learning rate: ',optimizer.param_groups[0]['lr'])
        online_network.train()
        train_loss = 0
        no_batches = 0

        for sample in tqdm(train_dataloader):
            data_i = sample['sig_i'].to(ctx).float()
            data_j = sample['sig_j'].to(ctx).float()

            # The model returns (features, projector, predictor, output).
            _, _, pred_i, _ = online_network(data_i)
            _, _, pred_j, _ = online_network(data_j)

            with torch.no_grad():
                # Each view must go through the target network once; feeding
                # data_j twice would compare view j with itself.
                _, proj_i, _, _ = target_network(data_i)
                _, proj_j, _, _ = target_network(data_j)

            # Symmetric loss: prediction of one view vs. target projection of the other.
            loss = regression_loss(pred_i, proj_j)
            loss += regression_loss(pred_j, proj_i)
            total_loss = loss.mean()

            train_loss += total_loss.item()
            no_batches += 1

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Exponential moving average of the online parameters.
            with torch.no_grad():
                for param_q, param_k in zip(online_network.parameters(), target_network.parameters()):
                    param_k.mul_(t_d).add_(param_q, alpha=1. - t_d)

        # Step the schedule after the epoch's optimizer updates so that the
        # learning rate printed above is the one actually used.
        scheduler_steplr.step()

        whole_train_loss = train_loss / max(no_batches, 1)
        print(f'Train Loss: {whole_train_loss}')
        if whole_train_loss < lowest_train_loss:
            lowest_train_loss = whole_train_loss
            torch.save(online_network.state_dict(), os.path.join(checkpoint_folder, 'BYOL_signal.pth'))


if __name__ == "__main__":
    run()
class SEResNet_1d(nn.Module):
    """1-D ResNet trunk built from a configurable residual block type.

    Args:
        input_channel: number of channels of the input signal.
        block: residual block class; it must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: size of the final linear layer.

    The head uses a fixed ``AvgPool1d(8)`` followed by a linear layer with
    ``512 * block.expansion`` inputs, so the input length has to reduce to a
    single pooled step for the flattened features to match that layer.
    """

    def __init__(self, input_channel, block, layers, num_classes=1000):
        # Initialise nn.Module before assigning any attribute.
        super().__init__()
        self.inplanes = 64

        self.conv1 = nn.Conv1d(input_channel, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool1d(8)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may change shape."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 convolution so the shortcut matches the main branch.
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def se_resnet18_1d(input_channel, num_classes=1000):
    """Construct a 1-D SE-ResNet-18 (two basic blocks per stage).

    Args:
        input_channel: number of channels of the input signal.
        num_classes: size of the output layer.
    """
    model = SEResNet_1d(input_channel, SE_BasicBlock3x3_1d, [2, 2, 2, 2], num_classes)
    return model
144 | Args: 145 | pretrained (bool): If True, returns a model pre-trained on ImageNet 146 | """ 147 | model = SEResNet_1d(input_channel, SE_BasicBlock3x3_1d, [3, 4, 6, 3], num_classes) 148 | return model -------------------------------------------------------------------------------- /data_preparation/data_extraction_without_preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | import wfdb 7 | from scipy.signal import resample 8 | from stratify import stratify 9 | 10 | save_folder = "./data_folder/extracted_data_without_preprocessing" 11 | save_summary = "./data_folder/data_summary_without_preprocessing" 12 | raw_data_cinc = "./data_folder/datasets" 13 | dataset_names = ["ICBEB2018","ICBEB2018_2","INCART","PTB","PTB-XL","Georgia"] 14 | mapping_scored_path = "./data_folder/evaluation-2020-master/dx_mapping_scored.csv" # 27 main labels 15 | target_fs = 100 16 | strat_folds = 10 17 | channels = 12 18 | 19 | mapping_scored_df = pd.read_csv(mapping_scored_path) 20 | dx_mapping_snomed_abbrev = {a:b for [a,b] in list(mapping_scored_df.apply(lambda row: [row["SNOMED CT Code"],row["Abbreviation"]],axis=1))} 21 | list_label_available = np.array(mapping_scored_df["SNOMED CT Code"]) 22 | 23 | CPSC_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[0],'**/*.hea')) 24 | print('No files in CPSC:', len(CPSC_files)) 25 | CPSC_extra_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[1],'**/*.hea')) 26 | print('No files in CPSC-Extra:', len(CPSC_extra_files)) 27 | SPeter_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[2],'**/*.hea')) 28 | print('No files in StPetersburg:', len(SPeter_files)) 29 | PTB_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[3],'**/*.hea')) 30 | print('No files in PTB:', len(PTB_files)) 31 | PTBXL_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[4],'**/*.hea')) 32 
| print('No files in PTB-XL:', len(PTBXL_files)) 33 | Georgia_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[5],'**/*.hea')) 34 | print('No files in Georgia:', len(Georgia_files)) 35 | 36 | all_files = CPSC_files + CPSC_extra_files + SPeter_files + PTB_files + PTBXL_files + Georgia_files 37 | print('Total no files:',len(all_files)) 38 | # (7500, 12) 39 | # {'fs': 500, 'sig_len': 7500, 'n_sig': 12, 'base_date': None, 'base_time': datetime.time(0, 0, 12), 40 | # 'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'], 41 | # 'sig_name': ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'], 42 | # 'comments': ['Age: 74', 'Sex: Male', 'Dx: 59118001', 'Rx: Unknown', 'Hx: Unknown', 'Sx: Unknown']} 43 | 44 | skip_files = 0 45 | metadata = [] 46 | for idx, hea_file in enumerate(tqdm(all_files)): 47 | file_name = hea_file.split("/")[-1].split(".hea")[0] 48 | data_folder = hea_file.split("/")[-3] 49 | sigbufs, header = wfdb.rdsamp(str(hea_file)[:-4]) 50 | 51 | if(np.any(np.isnan(sigbufs))): 52 | print("Warning:",str(hea_file),"is corrupt. 
Skipping.") 53 | continue 54 | 55 | labels=[] 56 | age=np.nan 57 | sex="nan" 58 | for l in header["comments"]: 59 | arrs = l.strip().split(' ') 60 | if l.startswith('Dx:'): 61 | for x in arrs[1].split(','): 62 | if int(x) in list_label_available: 63 | labels.append(x) 64 | elif l.startswith('Age:'): 65 | try: 66 | age = int(arrs[1]) 67 | except: 68 | age= np.nan 69 | elif l.startswith('Sex:'): 70 | sex = arrs[1].strip().lower() 71 | if(sex=="m"): 72 | sex="male" 73 | elif(sex=="f"): 74 | sex="female" 75 | 76 | if len(labels) == 0: 77 | skip_files += 1 78 | continue 79 | 80 | ori_fs = header['fs'] 81 | factor = target_fs/ori_fs 82 | timesteps_new = int(len(sigbufs)*factor) 83 | data = np.zeros((timesteps_new, channels), dtype=np.float32) 84 | for i in range(channels): 85 | data[:,i] = resample(sigbufs[:,0],timesteps_new) 86 | 87 | np.save(os.path.join(save_folder,file_name+".npy"),data) 88 | 89 | metadata.append({"data":file_name+".npy","label":labels,"sex":sex,"age":age,"dataset":data_folder}) 90 | 91 | df = pd.DataFrame(metadata) 92 | lbl_itos = np.unique([item for sublist in list(df.label) for item in sublist]) 93 | lbl_stoi = {s:i for i,s in enumerate(lbl_itos)} 94 | df["label"] = df["label"].apply(lambda x: [lbl_stoi[y] for y in x]) 95 | 96 | df["strat_fold"]=-1 97 | for ds in np.unique(df["dataset"]): 98 | print("Creating CV folds:",ds) 99 | dfx = df[df.dataset==ds] 100 | idxs = np.array(dfx.index.values) 101 | lbl_itosx = np.unique([item for sublist in list(dfx.label) for item in sublist]) 102 | stratified_ids = stratify(list(dfx["label"]), lbl_itosx, [1./strat_folds]*strat_folds) 103 | 104 | for i,split in enumerate(stratified_ids): 105 | df.loc[idxs[split],"strat_fold"]=i 106 | 107 | print("Add Mean Column") 108 | df["data_mean"]=df["data"].apply(lambda x: np.mean(np.load(x if save_folder is None else os.path.join(save_folder,x), allow_pickle=True),axis=0)) 109 | print("Add Std Column") 110 | df["data_std"]=df["data"].apply(lambda x: np.std(np.load(x if 
data_folder is None else os.path.join(save_folder,x), allow_pickle=True),axis=0)) 111 | print("Add Length Column") 112 | df["data_length"]=df["data"].apply(lambda x: len(np.load(x if data_folder is None else os.path.join(save_folder,x), allow_pickle=True))) 113 | 114 | #save means and stds 115 | df_mean = df["data_mean"].mean() 116 | df_std = df["data_std"].mean() 117 | 118 | # save dataset 119 | df.to_pickle(os.path.join(save_summary,'df.pkl'),protocol=4) 120 | np.save(os.path.join(save_summary,"lbl_itos.npy"),lbl_itos) 121 | np.save(os.path.join(save_summary,"mean.npy"),df_mean) 122 | np.save(os.path.join(save_summary,"std.npy"),df_std) 123 | 124 | 125 | # file1 = 'df.pkl' 126 | # file2 = 'lbl_itos.npy' 127 | # file3 = 'memmap.npy' 128 | # file4 = 'memmap_meta.npz' 129 | # file5 = 'df_memmap.pkl' 130 | # file6 = 'mean.npy' 131 | # file7 = 'std.npy' -------------------------------------------------------------------------------- /data_preparation/stratify.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def stratify(data, classes, ratios, samples_per_group=None): 4 | """Stratifying procedure. Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/ (based on Sechidis 2011) 5 | 6 | data is a list of lists: a list of labels, for each sample (possibly containing duplicates not multi-hot encoded). 
7 | 8 | classes is the list of classes each label can take 9 | 10 | ratios is a list, summing to 1, of how the dataset should be split 11 | 12 | samples_per_group: list with number of samples per patient/group 13 | 14 | """ 15 | np.random.seed(0) # fix the random seed 16 | # data is now always a list of lists; len(data) is the number of patients; data[i] is the list of all labels for patient i (possibly multiple identical entries) 17 | 18 | if(samples_per_group is None): 19 | samples_per_group = np.ones(len(data)) 20 | 21 | #size is the number of ecgs 22 | size = np.sum(samples_per_group) 23 | 24 | # Organize data per label: for each label l, per_label_data[l] contains the list of patients 25 | # in data which have this label (potentially multiple identical entries) 26 | per_label_data = {c: [] for c in classes} 27 | for i, d in enumerate(data): 28 | for l in d: 29 | per_label_data[l].append(i) 30 | 31 | # In order not to compute lengths each time, they are tracked here. 32 | subset_sizes = [r * size for r in ratios] #list of subset_sizes in terms of ecgs 33 | per_label_subset_sizes = { c: [r * len(per_label_data[c]) for r in ratios] for c in classes } #dictionary with label: list of subset sizes in terms of patients 34 | 35 | # For each subset we want, the set of sample-ids which should end up in it 36 | stratified_data_ids = [set() for _ in range(len(ratios))] #initialize empty 37 | 38 | # For each sample in the data set 39 | print("Starting fold distribution...") 40 | size_prev=size+1 #just for output 41 | while size > 0: 42 | if(int(size_prev/1000) > int(size/1000)): 43 | print("Remaining entries to distribute:",size,"non-empty labels:", np.sum([1 for l, label_data in per_label_data.items() if len(label_data)>0])) 44 | size_prev=size 45 | # Compute |Di| 46 | lengths = { 47 | l: len(label_data) 48 | for l, label_data in per_label_data.items() 49 | } #dictionary label: number of ecgs with this label that have not been assigned to a fold yet 50 | try: 51 | # Find 
label of smallest |Di| 52 | label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get) 53 | except ValueError: 54 | # If the dictionary in `min` is empty we get a Value Error. 55 | # This can happen if there are unlabeled samples. 56 | # In this case, `size` would be > 0 but only samples without label would remain. 57 | # "No label" could be a class in itself: it's up to you to format your data accordingly. 58 | break 59 | # For each patient with label `label` get patient and corresponding counts 60 | unique_samples, unique_counts = np.unique(per_label_data[label],return_counts=True) 61 | idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1] 62 | unique_samples = unique_samples[idxs_sorted] # this is a list of all patient ids with this label sort by size descending 63 | unique_counts = unique_counts[idxs_sorted] # these are the corresponding counts 64 | 65 | # loop through all patient ids with this label 66 | for current_id, current_count in zip(unique_samples,unique_counts): 67 | 68 | subset_sizes_for_label = per_label_subset_sizes[label] #current subset sizes for the chosen label 69 | 70 | # Find argmax clj i.e. 
subset in greatest need of the current label 71 | largest_subsets = np.argwhere(subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten() 72 | 73 | # if there is a single best choice: assign it 74 | if len(largest_subsets) == 1: 75 | subset = largest_subsets[0] 76 | # If there is more than one such subset, find the one in greatest need of any label 77 | else: 78 | largest_subsets2 = np.argwhere(np.array(subset_sizes)[largest_subsets] == np.amax(np.array(subset_sizes)[largest_subsets])).flatten() 79 | subset = largest_subsets[np.random.choice(largest_subsets2)] 80 | 81 | # Store the sample's id in the selected subset 82 | stratified_data_ids[subset].add(current_id) 83 | 84 | # There is current_count fewer samples to distribute 85 | size -= samples_per_group[current_id] 86 | # The selected subset needs current_count fewer samples 87 | subset_sizes[subset] -= samples_per_group[current_id] 88 | 89 | # In the selected subset, there is one more example for each label 90 | # the current sample has 91 | for l in data[current_id]: 92 | per_label_subset_sizes[l][subset] -= 1 93 | 94 | # Remove the sample from the dataset, meaning from all per_label dataset created 95 | for x in per_label_data.keys(): 96 | per_label_data[x] = [y for y in per_label_data[x] if y!=current_id] 97 | 98 | # Create the stratified dataset as a list of subsets, each containing the orginal labels 99 | stratified_data_ids = [sorted(strat) for strat in stratified_data_ids] 100 | #stratified_data = [ 101 | # [data[i] for i in strat] for strat in stratified_data_ids 102 | #] 103 | 104 | # Return both the stratified indexes, to be used to sample the `features` associated with your labels 105 | # And the stratified labels dataset 106 | 107 | #return stratified_data_ids, stratified_data 108 | return stratified_data_ids 109 | -------------------------------------------------------------------------------- /models/ensemble_model.py: 
class ensemble_model(nn.Module):
    """Classifier that fuses a 1-D signal backbone with a 2-D spectrogram backbone.

    Args:
        no_classes: number of output logits.
        gate: when True, a 2-way softmax gate re-weights the two feature
            vectors before they are concatenated.
        w_time: optional checkpoint path for the signal feature extractor.
        w_spec: optional checkpoint path for the spectrogram feature extractor.
        device: ``map_location`` used when loading the checkpoints.

    Checkpoints are loaded with ``strict=False``, so keys that do not match
    are ignored without an error.
    """

    def __init__(self, no_classes=24, gate=False, w_time=None, w_spec=None,device=None):
        super().__init__()
        # gating encoding
        self.gate = gate

        # Time series module
        self.time_backbone = xresnet1d50(widen=1.0)
        time_list_of_modules = list(self.time_backbone.children())
        # everything but the head, plus the first module of the head
        self.time_features = nn.Sequential(*time_list_of_modules[:-1], time_list_of_modules[-1][0])
        time_num_ftrs = self.time_backbone[-1][-1].in_features
        # 12-lead input stem
        self.time_backbone[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)

        if w_time is not None:
            time_state_dict = torch.load(w_time,map_location=device)
            self.time_features.load_state_dict(time_state_dict,strict=False)

        # Spectrogram module
        self.spec_backbone = se_resnet34()
        self.spec_backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        spec_list_of_modules = list(self.spec_backbone.children())
        self.spec_features = nn.Sequential(*spec_list_of_modules[:-1])
        spec_num_ftrs = self.spec_backbone.fc.in_features

        if w_spec is not None:
            spec_state_dict = torch.load(w_spec,map_location=device)
            self.spec_features.load_state_dict(spec_state_dict,strict=False)

        num_ftrs = time_num_ftrs + spec_num_ftrs
        if self.gate:
            self.gate_fc = nn.Linear(num_ftrs,2)
        # The classification head is the same with and without the gate.
        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs,out_features=num_ftrs//2),
            nn.Linear(in_features=num_ftrs//2,out_features=no_classes)
        )

    def forward(self, x_sig, x_spec):
        """Return the class logits for a batch of (signal, spectrogram) pairs."""
        # flatten(1) keeps the batch dimension even for a batch of one sample;
        # a bare squeeze() would drop it and break the concatenation below.
        h_time = self.time_features(x_sig).flatten(start_dim=1)
        h_spec = self.spec_features(x_spec).flatten(start_dim=1)

        if self.gate:
            h_gate = F.softmax(self.gate_fc(torch.cat((h_time,h_spec),dim=1)),dim=1)
            h_comb = torch.cat([h_time*h_gate[:,0:1],h_spec*h_gate[:,1:2]],dim=1)
        else:
            h_comb = torch.cat((h_time,h_spec),1)
        return self.fc(h_comb)


class ensemble_model_3head(nn.Module):
    """Gated fusion model with one extra classification head per modality.

    Args:
        no_classes: number of output logits of every head.
        w_time: optional checkpoint path for the signal feature extractor.
        w_spec: optional checkpoint path for the spectrogram feature extractor.
        device: ``map_location`` used when loading the checkpoints.

    Checkpoints are loaded with ``strict=False``, so keys that do not match
    are ignored without an error.
    """

    def __init__(self, no_classes=24,w_time=None, w_spec=None,device=None):
        super().__init__()

        # Time series module
        self.time_backbone = xresnet1d50(widen=1.0)
        time_list_of_modules = list(self.time_backbone.children())
        # everything but the head, plus the first module of the head
        self.time_features = nn.Sequential(*time_list_of_modules[:-1], time_list_of_modules[-1][0])
        time_num_ftrs = self.time_backbone[-1][-1].in_features
        # 12-lead input stem
        self.time_backbone[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)

        if w_time is not None:
            time_state_dict = torch.load(w_time,map_location=device)
            self.time_features.load_state_dict(time_state_dict,strict=False)

        # Spectrogram module
        self.spec_backbone = se_resnet34()
        self.spec_backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        spec_list_of_modules = list(self.spec_backbone.children())
        self.spec_features = nn.Sequential(*spec_list_of_modules[:-1])
        spec_num_ftrs = self.spec_backbone.fc.in_features

        if w_spec is not None:
            spec_state_dict = torch.load(w_spec,map_location=device)
            self.spec_features.load_state_dict(spec_state_dict,strict=False)

        num_ftrs = time_num_ftrs + spec_num_ftrs
        self.gate_fc = nn.Linear(num_ftrs,2)

        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs,out_features=num_ftrs//2),
            nn.Linear(in_features=num_ftrs//2,out_features=no_classes)
        )
        self.fc_time = nn.Sequential(
            nn.Linear(in_features=time_num_ftrs,out_features=time_num_ftrs//2),
            nn.Linear(in_features=time_num_ftrs//2,out_features=no_classes)
        )
        self.fc_spec = nn.Sequential(
            nn.Linear(in_features=spec_num_ftrs,out_features=spec_num_ftrs//2),
            nn.Linear(in_features=spec_num_ftrs//2,out_features=no_classes)
        )

    def forward(self, x_sig, x_spec):
        """Return ``(fused logits, signal logits, spectrogram logits, gate weights)``."""
        # flatten(1) keeps the batch dimension even for a batch of one sample;
        # a bare squeeze() would drop it and break the concatenation below.
        h_time = self.time_features(x_sig).flatten(start_dim=1)
        h_spec = self.spec_features(x_spec).flatten(start_dim=1)

        h_gate = F.softmax(self.gate_fc(torch.cat((h_time,h_spec),dim=1)),dim=1)
        h_encode = torch.cat([h_time*h_gate[:,0:1],h_spec*h_gate[:,1:2]],dim=1)
        y = self.fc(h_encode)
        y_time = self.fc_time(h_time)
        y_spec = self.fc_spec(h_spec)

        return y, y_time, y_spec, h_gate

    def freeze_backbone(self):
        """Stop gradient updates for both feature extractors."""
        for p in self.spec_features.parameters():
            p.requires_grad = False
        for p in self.time_features.parameters():
            p.requires_grad = False

    def freeze_gate(self):
        """Stop gradient updates for the gating layer."""
        for p in self.gate_fc.parameters():
            p.requires_grad = False
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'
eps = 1e-7


class Flatten(nn.Module):
    """Collapse every non-batch dimension: ``(B, ...)`` -> ``(B, -1)``."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input_tensor):
        # reshape (unlike view) also accepts non-contiguous input.
        return input_tensor.reshape(input_tensor.size(0), -1)


class Projection(nn.Module):
    """Projection head: Flatten -> Linear -> ReLU -> Linear, L2-normalised.

    Args:
        input_dim: width of the flattened encoder output.
        hidden_dim: width of the hidden layer.
        output_dim: width of the returned embedding.
    """

    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.model = nn.Sequential(
            Flatten(),
            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim, bias=True))

    def forward(self, x):
        """Return unit-norm embeddings of shape ``(B, output_dim)``."""
        x = self.model(x)
        return F.normalize(x, dim=1)


def nt_xent_loss(out_1, out_2, temperature, eps=1e-6):
    """Normalised-temperature cross-entropy loss for two augmented views.

    Assumes ``out_1`` and ``out_2`` are L2-normalised.

    Args:
        out_1: ``[batch_size, dim]`` embeddings of the first view.
        out_2: ``[batch_size, dim]`` embeddings of the second view.
        temperature: softmax temperature dividing the similarities.
        eps: lower clamp / additive constant for numerical stability.

    Returns:
        Scalar tensor with the mean loss over the ``2 * batch_size`` anchors.
    """
    # out: [2 * batch_size, dim]
    out = torch.cat([out_1, out_2], dim=0)

    # cov and sim: [2 * batch_size, 2 * batch_size]; neg: [2 * batch_size]
    cov = torch.mm(out, out.t().contiguous())
    sim = torch.exp(cov / temperature)
    neg = sim.sum(dim=-1)

    # Each row sum contains the anchor's similarity with itself. For unit
    # vectors that term is exp(1 / temperature), not e; the two coincide only
    # at temperature == 1.
    self_sim = math.exp(1.0 / temperature)
    neg = torch.clamp(neg - self_sim, min=eps)  # clamp for numerical stability

    # Positive similarity, pos becomes [2 * batch_size]
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)

    loss = -torch.log(pos / (neg + eps)).mean()

    return loss


def run():
    """Pre-train the signal encoder contrastively and keep the best weights.

    Builds the contrastive dataset from ``./data_folder``, trains the encoder
    together with a projection head using :func:`nt_xent_loss`, and writes the
    encoder state dict to ``./checkpoints/SIMCLR_signal.pth`` whenever the
    mean epoch loss improves.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')
    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0

    transforms = ["TimeOut_difflead","GaussianNoise"]

    batch_size = 1024
    learning_rate = 1e-3
    no_epoches = 1000

    checkpoint_dir = './checkpoints'
    checkpoint_path = './checkpoints/SIMCLR_signal.pth'

    get_mean = np.load(os.path.join(data_folder,"mean.npy"))
    get_std = np.load(os.path.join(data_folder,"std.npy"))

    t_params = {"gaussian_scale":[0.005,0.025], "global_crop_scale": [0.5, 1.0], "local_crop_scale": [0.1, 0.5],
                "output_size": 250, "warps": 3, "radius": 10, "shift_range":[0.2,0.5],
                "epsilon": 10, "magnitude_range": [0.5, 2], "downsample_ratio": 0.2, "to_crop_ratio_range": [0.2, 0.4],
                "bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1, "stats_mean":get_mean,"stats_std":get_std}

    train_dataset = ECG_contrastive_dataset(summary_folder=data_folder, signal_size=signal_size, stride=train_stride,
                                            chunk_length=train_chunk_length, transforms=transforms,t_params=t_params,
                                            equivalent_classes=equivalent_classes, sample_items_per_record=1,random_crop=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size,drop_last=True)

    no_classes = 24
    model = signal_model_simclr(no_classes)
    projection_head = Projection(model.num_ftrs, hidden_dim=512, output_dim=128)

    model.apply(weights_init_xavier)
    projection_head.apply(weights_init_xavier)
    model.to(ctx)
    projection_head.to(ctx)

    # The projection head must be optimised together with the encoder;
    # passing only model.parameters() left it at its initial weights.
    trainable_params = list(model.parameters()) + list(projection_head.parameters())
    optimizer = torch.optim.Adam(trainable_params,lr=learning_rate)
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    # Create the output directory up front so the first save cannot fail
    # after a full epoch of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Start from +inf so the first finished epoch always produces a checkpoint.
    lowest_train_loss = float('inf')
    for epoch in range(1,no_epoches+1):
        print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
        print('Current learning rate: ',optimizer.param_groups[0]['lr'])
        model.train()
        projection_head.train()
        train_loss = 0
        num_batches = 0

        for sample in tqdm(train_dataloader):
            data_i = sample['sig_i'].to(ctx).float()
            data_j = sample['sig_j'].to(ctx).float()

            h1 = model(data_i)[0]
            h2 = model(data_j)[0]

            # Projection flattens its input itself, so no squeeze is needed.
            z1 = projection_head(h1)
            z2 = projection_head(h2)

            loss = nt_xent_loss(z1,z2,temperature=0.1)

            train_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Step the schedule after this epoch's optimizer updates.
        scheduler_steplr.step()

        if num_batches == 0:
            # drop_last=True yields nothing when the dataset is smaller than
            # one batch; there is no loss to report in that case.
            print('Train Loss: no batches were produced by the dataloader')
            continue

        whole_train_loss = train_loss / num_batches
        print(f'Train Loss: {whole_train_loss}')
        if whole_train_loss < lowest_train_loss:
            lowest_train_loss = whole_train_loss
            torch.save(model.state_dict(), checkpoint_path)


if __name__ == "__main__":
    run()
Set,Hidden CSPC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Test set,Test set,Test set,Test set,Test set 2 | Final ranking,Team name,CinC Abstract #,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score 3 | 1,prna,107,0.893,0.428,0.279,0.411,0.587,0.964,0.832,0.532,0.202,0.761,0.871,0.414,0.206,0.401,0.558,0.889,0.523,0.365,0.379,0.492,0.880,0.429,0.330,0.409,0.533 4 | 2,Between_a_ROC_and_a_heart_place,112,0.932,0.548,0.367,0.495,0.672,0.971,0.856,0.555,0.214,0.845,0.915,0.538,0.305,0.495,0.639,0.906,0.578,0.305,0.378,0.412,0.889,0.502,0.324,0.464,0.520 5 | 3,HeartBeats,281,0.945,0.556,0.400,0.525,0.682,0.968,0.816,0.587,0.229,0.852,0.930,0.558,0.338,0.523,0.649,0.909,0.594,0.309,0.416,0.396,0.900,0.510,0.340,0.487,0.514 6 | 4,Triage,133,0.909,0.491,0.424,0.471,0.640,0.962,0.829,0.610,0.216,0.833,0.894,0.491,0.367,0.472,0.609,0.907,0.584,0.358,0.396,0.370,0.897,0.492,0.382,0.462,0.485 7 | 5,Sharif AI Team,445,0.930,0.488,0.359,0.452,0.609,0.966,0.805,0.499,0.222,0.793,0.913,0.489,0.311,0.456,0.577,0.908,0.562,0.297,0.371,0.314,0.911,0.469,0.316,0.426,0.437 8 | 6,DSAIL_SNU,328,0.947,0.570,0.389,0.541,0.688,0.981,0.899,0.598,0.250,0.872,0.937,0.561,0.320,0.538,0.654,0.929,0.592,0.319,0.286,0.228,0.900,0.514,0.341,0.433,0.420 9 | 7,UMCUVA,253,0.932,0.516,0.243,0.468,0.586,0.956,0.823,0.410,0.197,0.643,0.919,0.510,0.189,0.469,0.574,0.917,0.592,0.225,0.377,0.298,0.915,0.481,0.228,0.438,0.417 10 | 8,CQUPT_ECG,85,0.932,0.507,0.320,0.468,0.640,0.966,0.815,0.501,0.190,0.800,0.915,0.504,0.259,0.459,0.609,0.781,0.359,0.064,0.219,0.248,0.821,0.364,0.160,0.321,0.411 11 | 
9,ECU,161,0.916,0.490,0.362,0.450,0.623,0.959,0.808,0.538,0.201,0.797,0.905,0.496,0.309,0.476,0.596,0.802,0.325,0.199,0.209,0.205,0.832,0.365,0.262,0.352,0.382 12 | 10,PALab,35,0.942,0.549,0.381,0.530,0.653,0.971,0.873,0.574,0.218,0.836,0.928,0.541,0.319,0.525,0.623,0.751,0.319,0.247,0.199,0.144,0.852,0.393,0.296,0.380,0.359 13 | 11,HITTING,171,0.701,0.273,0.338,0.366,0.435,0.841,0.606,0.337,0.220,0.556,0.699,0.291,0.334,0.381,0.418,0.730,0.378,0.344,0.386,0.290,0.695,0.289,0.339,0.366,0.354 14 | 12,Gio_Ivo,116,0.830,0.314,0.045,0.296,0.426,0.882,0.619,0.117,0.116,0.452,0.799,0.312,0.026,0.304,0.421,0.810,0.376,0.047,0.244,0.205,0.777,0.302,0.047,0.266,0.298 15 | 13,AUTh Team,417,0.879,0.388,0.057,0.349,0.470,0.918,0.698,0.093,0.169,0.447,0.869,0.397,0.047,0.358,0.476,0.834,0.412,0.008,0.234,0.143,0.815,0.329,0.028,0.272,0.281 16 | 14,BioS,124,,,,,,,,,,,,,,,,,,,,,,,,, 17 | 15,UC_Lab_Kn,229,0.938,0.528,0.322,0.480,0.656,0.973,0.871,0.606,0.237,0.840,0.851,0.392,0.221,0.326,0.300,0.828,0.464,0.289,0.277,0.190,0.845,0.391,0.294,0.315,0.270 18 | 16,Cardio-Challengers,225,0.498,0.060,0.000,0.105,0.337,0.500,0.135,0.000,0.058,0.176,0.498,0.067,0.000,0.115,0.369,0.495,0.103,0.001,0.119,0.198,0.498,0.072,0.001,0.116,0.258 19 | 17,JuJuRock,134,0.457,0.060,0.021,0.271,0.406,0.577,0.169,0.013,0.127,0.253,0.439,0.065,0.022,0.292,0.437,0.441,0.093,0.009,0.177,0.125,0.436,0.064,0.012,0.223,0.244 20 | 18,Minibus,282,0.864,0.489,0.476,0.430,0.446,0.963,0.861,0.755,0.235,0.722,0.862,0.491,0.391,0.447,0.394,0.828,0.493,0.284,0.324,0.088,0.828,0.448,0.357,0.409,0.236 21 | 19,Desafinado,363,0.906,0.478,0.224,0.413,0.576,0.967,0.833,0.325,0.182,0.681,0.887,0.481,0.186,0.412,0.556,0.806,0.358,0.009,0.202,-0.013,0.822,0.362,0.089,0.298,0.233 22 | 20,TeamUIO,227,0.846,0.334,0.079,0.309,0.377,0.856,0.605,0.09,0.165,0.379,0.812,0.315,0.07,0.307,0.382,0.677,0.215,0.133,0.15,0.076,0.728,0.218,0.107,0.233,0.206 23 | 
21,Eagles,138,0.677,0.189,0.160,0.195,0.214,0.714,0.351,0.146,0.104,0.235,0.653,0.188,0.155,0.185,0.205,0.648,0.270,0.302,0.186,0.205,0.647,0.202,0.240,0.200,0.205 24 | 22,BUTTeam,189,0.940,0.531,0.395,0.522,0.696,0.974,0.844,0.661,0.245,0.892,0.864,0.381,0.238,0.259,0.235,0.877,0.467,0.292,0.251,0.104,0.850,0.392,0.307,0.277,0.202 25 | 23,DSC,71,0.769,0.374,0.429,0.536,0.616,0.900,0.647,0.597,0.231,0.824,0.670,0.248,0.288,0.350,0.301,0.668,0.304,0.286,0.283,0.062,0.658,0.245,0.311,0.316,0.194 26 | 24,Pink Irish Hat,198,0.878,0.447,0.417,0.381,0.511,0.944,0.796,0.653,0.274,0.762,0.715,0.267,0.170,0.193,0.127,0.776,0.397,0.267,0.287,0.123,0.748,0.311,0.271,0.256,0.167 27 | 25,Madhardmax,185,0.921,0.471,0.365,0.461,0.533,0.958,0.810,0.508,0.221,0.544,0.914,0.470,0.315,0.454,0.525,0.916,0.542,0.240,0.281,-0.109,0.895,0.426,0.284,0.373,0.155 28 | 26,Care4MyHeart,127,0.869,0.361,0.250,0.350,0.379,0.929,0.721,0.408,0.168,0.611,0.862,0.362,0.208,0.352,0.342,0.820,0.376,0.108,0.239,-0.027,0.828,0.315,0.166,0.290,0.146 29 | 27,MCIRCC,374,0.907,0.442,0.333,0.433,0.616,0.956,0.810,0.665,0.234,0.813,0.810,0.315,0.232,0.199,0.162,0.792,0.398,0.274,0.231,0.050,0.807,0.328,0.296,0.243,0.141 30 | 28,heartly-ai,356,0.870,0.370,0.310,0.230,0.159,0.927,0.730,0.514,0.202,0.351,0.847,0.378,0.249,0.210,0.128,0.870,0.476,0.346,0.262,0.116,0.847,0.387,0.330,0.237,0.136 31 | 29,Code Team,130,0.940,0.531,0.369,0.513,0.657,0.968,0.850,0.658,0.302,0.830,0.835,0.370,0.248,0.237,0.181,0.835,0.467,0.276,0.237,0.023,0.831,0.376,0.300,0.256,0.132 32 | 30,ISIBrno,32,0.922,0.533,0.417,0.510,0.659,0.977,0.893,0.717,0.262,0.847,0.815,0.356,0.246,0.270,0.195,0.492,0.102,0.002,0.046,-0.006,0.675,0.222,0.141,0.191,0.122 33 | 31,Alba_W.O.,61,0.602,0.182,0.368,0.220,0.308,0.815,0.491,0.568,0.434,0.709,0.576,0.120,0.192,0.125,0.094,0.540,0.151,0.302,0.115,0.035,0.554,0.125,0.291,0.130,0.102 34 | 32,AI 
Strollers,277,0.625,0.124,0.000,0.140,0.342,0.783,0.437,0.001,0.117,0.212,0.622,0.135,0.000,0.153,0.359,0.629,0.197,0.000,0.144,0.096,0.599,0.136,0.000,0.142,0.077 35 | 33,ECGLearner,95,0.903,0.473,0.421,0.451,0.486,0.956,0.807,0.674,0.195,0.669,0.886,0.461,0.339,0.441,0.452,0.786,0.315,0.054,0.159,-0.347,0.829,0.324,0.194,0.300,0.001 36 | 34,Leicester-Fox,135,0.647,0.242,0.340,0.316,0.395,0.861,0.650,0.543,0.218,0.717,0.635,0.226,0.279,0.301,0.340,0.556,0.163,0.021,0.118,-0.309,0.587,0.163,0.146,0.207,-0.012 37 | 35,deepzx987,424,0.825,0.326,0.331,0.286,0.305,0.919,0.711,0.527,0.228,0.648,0.812,0.318,0.277,0.259,0.25,0.694,0.239,0.08,0.118,-0.287,0.742,0.229,0.181,0.182,-0.035 38 | 36,CVC,128,0.858,0.386,0.208,0.369,0.476,0.951,0.768,0.431,0.199,0.491,0.782,0.305,0.140,0.237,0.150,0.774,0.351,0.088,0.202,-0.287,0.762,0.292,0.135,0.233,-0.080 39 | 37,Cordi-Ak,297,0.815,0.299,0.083,0.238,0.304,0.773,0.507,0.209,0.160,0.254,0.733,0.219,0.025,0.183,0.267,0.639,0.190,0.013,0.090,-0.387,0.688,0.167,0.035,0.136,-0.113 40 | 38,MIndS,339,0.845,0.343,0.317,0.296,0.368,0.905,0.672,0.463,0.190,0.587,0.836,0.336,0.275,0.287,0.333,0.657,0.175,0.017,0.080,-0.489,0.828,0.448,0.357,0.409,-0.128 41 | 39,easyG,148,0.865,0.369,0.392,0.310,0.403,0.919,0.730,0.651,0.312,0.692,0.789,0.295,0.229,0.176,0.066,0.690,0.190,0.020,0.081,-0.622,0.726,0.209,0.141,0.126,-0.290 42 | 40,BiSP Lab,406,0.690,0.164,0.205,0.081,-0.179,0.683,0.283,0.278,0.062,-0.228,0.638,0.133,0.181,0.029,-0.087,0.602,0.147,0.031,0.027,-0.740,0.585,0.109,0.097,0.031,-0.476 43 | 41,Technion_AIMLAB,202,0.662,0.175,0.117,0.000,-0.406,0.774,0.448,0.233,0.000,-0.455,0.646,0.160,0.084,0.000,-0.390,0.650,0.249,0.004,0.000,-0.848,0.645,0.180,0.050,0.000,-0.658 44 | -------------------------------------------------------------------------------- /experiments/SIMCLR_signal_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | 
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'


def run():
    """Fine-tune the SimCLR-pretrained signal model and report metrics.

    Loads ``./checkpoints/SIMCLR_signal.pth`` into ``signal_model_simclr``,
    trains with ``BCEWithLogitsLoss`` on the training folds and, after every
    epoch, evaluates on the validation fold. Validation windows are averaged
    per record (grouped by ``sample['idx']``) before the metrics are computed.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')

    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    normal_class = '426783006'
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, weights = load_weights(weights_file, equivalent_classes)

    no_fold = 8
    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size//2 # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = 256
    learning_rate = 5e-3
    no_epoches = 80
    warmup_epoches = 10
    # Probability above which a class counts as predicted.
    decision_threshold = 0.1

    train_dataset = ECG_dataset_base(summary_folder=data_folder,classes=classes, signal_size=signal_size, stride=train_stride,
                                     chunk_length=train_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='train',
                                     equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False,random_crop=True,val_fold=no_fold)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size)

    val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes,signal_size=signal_size, stride=val_stride,
                                   chunk_length=val_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='val',
                                   equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True,random_crop=False,val_fold=no_fold)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4,batch_size=batch_size)

    no_classes = train_dataset.get_num_classes()
    model = signal_model_simclr(no_classes)
    state_dict = torch.load('./checkpoints/SIMCLR_signal.pth',map_location=ctx)
    model.load_state_dict(state_dict,strict=True)
    model.to(ctx)

    optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    for epoch in range(1,no_epoches+1):
        print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
        print('Current learning rate: ',optimizer.param_groups[0]['lr'])

        if epoch <= warmup_epoches:
            # NOTE(review): the helper is named "open" but the message says
            # "Freeze" - confirm in utils.tools which layers end up trainable.
            open_specified_layers(model,['backbone','features'])
            print('Freeze the backbone')
        else:
            open_all_layers(model)

        model.train()
        train_loss = 0
        train_pred = []
        train_gt = []

        for batch_idx, sample in enumerate(tqdm(train_dataloader)):
            signal = sample['sig'].to(ctx).float()
            signal = signal.view(-1,no_channels,signal_size)
            label = sample['lbl'].to(ctx).float()
            label = label.view(-1,no_classes)

            _, pred = model(signal)
            result = torch.sigmoid(pred)

            loss = criterion(pred,label)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_pred.append(result.detach().cpu().numpy())
            train_gt.append(label.detach().cpu().numpy())

        train_pred = np.concatenate(train_pred,axis=0)
        train_gt = np.concatenate(train_gt,axis=0)

        print(f'Train Loss: {train_loss / (batch_idx + 1)}')
        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
        train_pred = (train_pred>decision_threshold)
        print(f'Accuracy: {compute_accuracy(train_gt.astype(bool),train_pred.astype(bool))}')
        print(f'F1 macro score: {compute_f_measure_mod(train_gt.astype(bool),train_pred.astype(bool))}')

        # Accuracy, F1 macro score, AUROC, AUPRC, Challenge metric
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []

            for batch_idx, sample in enumerate(val_dataloader):
                signal = sample['sig'].to(ctx).float()
                label = sample['lbl'].to(ctx).float()
                name = sample['idx']

                _, pred = model(signal)
                result = torch.sigmoid(pred)

                loss = criterion(pred,label)
                val_loss += loss.item()

                val_pred.append(result.detach().cpu().numpy())
                val_gt.append(label.detach().cpu().numpy())
                val_name.append(name)

            val_pred = np.concatenate(val_pred,axis=0)
            val_gt = np.concatenate(val_gt,axis=0)
            val_name = np.concatenate(val_name,axis=0)

            # One row per window: [record id | labels | probabilities];
            # averaging per record id merges the overlapping windows.
            df_pred = pd.DataFrame(data=val_pred)
            df_gt = pd.DataFrame(data=val_gt)
            df_name = pd.DataFrame(data=val_name)
            df_concat = pd.concat([df_name,df_gt,df_pred],axis=1,ignore_index=True)
            df_concat_group = df_concat.groupby([0]).mean()
            # After grouping, the id column is the index, so the first
            # no_classes columns are labels and the next no_classes are scores.
            gt_columns = df_concat_group.columns[np.arange(0,no_classes)]
            pred_columns = df_concat_group.columns[np.arange(no_classes,2*no_classes)]
            val_gt_after = df_concat_group[gt_columns].to_numpy()
            val_pred_after = df_concat_group[pred_columns].to_numpy()

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / (batch_idx + 1)}')
            auroc, auprc = compute_auc(val_gt_after,val_pred_after.astype(np.float64))
            val_pred_after = (val_pred_after>decision_threshold)
            print(f'-----> Accuracy: {compute_accuracy(val_gt_after.astype(bool),val_pred_after.astype(bool))}')
            print(f'-----> F1 macro score: {compute_f_measure_mod(val_gt_after.astype(bool),val_pred_after.astype(bool))}')
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')
            print(f'-----> Challenge metric: {compute_challenge_metric(weights,val_gt_after.astype(bool),val_pred_after.astype(bool),classes,normal_class)}')

        # Advance the schedule only after this epoch's optimizer updates, so
        # the first epoch really runs at the configured learning rate.
        scheduler_steplr.step()


if __name__ == "__main__":
    run()
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'


def run():
    """Fine-tune the BYOL-pretrained signal model and report metrics.

    Loads ``./checkpoints/BYOL_signal.pth`` into ``signal_model_byol``,
    trains with ``BCEWithLogitsLoss`` on the training folds and, after every
    epoch, evaluates on the validation fold. Validation windows are averaged
    per record (grouped by ``sample['idx']``) before the metrics are computed.
    The model returns four values; only the last one is used as logits.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')

    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    normal_class = '426783006'
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, weights = load_weights(weights_file, equivalent_classes)

    no_fold = 8
    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size//2 # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = 256
    learning_rate = 5e-3
    no_epoches = 80
    warmup_epoches = 5
    # Probability above which a class counts as predicted.
    decision_threshold = 0.1

    train_dataset = ECG_dataset_base(summary_folder=data_folder,classes=classes, signal_size=signal_size, stride=train_stride,
                                     chunk_length=train_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='train',
                                     equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False,random_crop=True,val_fold=no_fold)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size)

    val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes,signal_size=signal_size, stride=val_stride,
                                   chunk_length=val_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='val',
                                   equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True,random_crop=False,val_fold=no_fold)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4,batch_size=batch_size)

    no_classes = 24
    model = signal_model_byol(no_classes)
    state_dict = torch.load('./checkpoints/BYOL_signal.pth',map_location=ctx)
    model.load_state_dict(state_dict,strict=True)
    model.to(ctx)

    optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    for epoch in range(1,no_epoches+1):
        print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
        print('Current learning rate: ',optimizer.param_groups[0]['lr'])

        if epoch <= warmup_epoches:
            # NOTE(review): the helper is named "open" but the message says
            # "Freeze" - confirm in utils.tools which layers end up trainable.
            open_specified_layers(model,['backbone','features'])
            print('Freeze the backbone')
        else:
            open_all_layers(model)

        model.train()
        train_loss = 0
        train_pred = []
        train_gt = []

        for batch_idx, sample in enumerate(tqdm(train_dataloader)):
            signal = sample['sig'].to(ctx).float()
            signal = signal.view(-1,no_channels,signal_size)
            label = sample['lbl'].to(ctx).float()
            label = label.view(-1,no_classes)

            _,_,_, pred = model(signal)
            result = torch.sigmoid(pred)

            loss = criterion(pred,label)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_pred.append(result.detach().cpu().numpy())
            train_gt.append(label.detach().cpu().numpy())

        train_pred = np.concatenate(train_pred,axis=0)
        train_gt = np.concatenate(train_gt,axis=0)

        print(f'Train Loss: {train_loss / (batch_idx + 1)}')
        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
        train_pred = (train_pred>decision_threshold)
        print(f'Accuracy: {compute_accuracy(train_gt.astype(bool),train_pred.astype(bool))}')
        print(f'F1 macro score: {compute_f_measure_mod(train_gt.astype(bool),train_pred.astype(bool))}')

        # Accuracy, F1 macro score, AUROC, AUPRC, Challenge metric
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []

            for batch_idx, sample in enumerate(val_dataloader):
                signal = sample['sig'].to(ctx).float()
                label = sample['lbl'].to(ctx).float()
                name = sample['idx']

                _,_,_, pred = model(signal)
                result = torch.sigmoid(pred)

                loss = criterion(pred,label)
                val_loss += loss.item()

                val_pred.append(result.detach().cpu().numpy())
                val_gt.append(label.detach().cpu().numpy())
                val_name.append(name)

            val_pred = np.concatenate(val_pred,axis=0)
            val_gt = np.concatenate(val_gt,axis=0)
            val_name = np.concatenate(val_name,axis=0)

            # One row per window: [record id | labels | probabilities];
            # averaging per record id merges the overlapping windows.
            df_pred = pd.DataFrame(data=val_pred)
            df_gt = pd.DataFrame(data=val_gt)
            df_name = pd.DataFrame(data=val_name)
            df_concat = pd.concat([df_name,df_gt,df_pred],axis=1,ignore_index=True)
            df_concat_group = df_concat.groupby([0]).mean()
            # After grouping, the id column is the index, so the first
            # no_classes columns are labels and the next no_classes are scores.
            gt_columns = df_concat_group.columns[np.arange(0,no_classes)]
            pred_columns = df_concat_group.columns[np.arange(no_classes,2*no_classes)]
            val_gt_after = df_concat_group[gt_columns].to_numpy()
            val_pred_after = df_concat_group[pred_columns].to_numpy()

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / (batch_idx + 1)}')
            auroc, auprc = compute_auc(val_gt_after,val_pred_after.astype(np.float64))
            val_pred_after = (val_pred_after>decision_threshold)
            print(f'-----> Accuracy: {compute_accuracy(val_gt_after.astype(bool),val_pred_after.astype(bool))}')
            print(f'-----> F1 macro score: {compute_f_measure_mod(val_gt_after.astype(bool),val_pred_after.astype(bool))}')
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')
            print(f'-----> Challenge metric: {compute_challenge_metric(weights,val_gt_after.astype(bool),val_pred_after.astype(bool),classes,normal_class)}')

        # Advance the schedule only after this epoch's optimizer updates, so
        # the first epoch really runs at the configured learning rate.
        scheduler_steplr.step()


if __name__ == "__main__":
    run()
glob.glob(os.path.join(raw_data_cinc,dataset_names[3],'**/*.hea')) 31 | print('No files in PTB:', len(PTB_files)) 32 | PTBXL_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[4],'**/*.hea')) 33 | print('No files in PTB-XL:', len(PTBXL_files)) 34 | Georgia_files = glob.glob(os.path.join(raw_data_cinc,dataset_names[5],'**/*.hea')) 35 | print('No files in Georgia:', len(Georgia_files)) 36 | 37 | all_files = CPSC_files + CPSC_extra_files + SPeter_files + PTB_files + PTBXL_files + Georgia_files 38 | print('Total no files:',len(all_files)) 39 | # (7500, 12) 40 | # {'fs': 500, 'sig_len': 7500, 'n_sig': 12, 'base_date': None, 'base_time': datetime.time(0, 0, 12), 41 | # 'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'], 42 | # 'sig_name': ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'], 43 | # 'comments': ['Age: 74', 'Sex: Male', 'Dx: 59118001', 'Rx: Unknown', 'Hx: Unknown', 'Sx: Unknown']} 44 | 45 | # CSPC fs 500 level 3 stride 50 46 | # CSPC extra fs 500 level 3 stride 50 47 | # StPeter fs 257 level 2 stride 30 48 | # PTB fs 1000 level 2 stride 100 49 | # PTB XL fs 500 level 3 stride 50 50 | # Grorgia fs 500 level 3 stride 50 51 | 52 | def madev(d, axis=None): 53 | """ Mean absolute deviation of a signal """ 54 | return np.mean(np.absolute(d - np.mean(d, axis)), axis) 55 | 56 | def moving_average(x, w): 57 | return np.convolve(x, np.ones(w), 'valid') / w 58 | 59 | 60 | skip_files = 0 61 | metadata = [] 62 | for idx, hea_file in enumerate(tqdm(all_files)): 63 | file_name = hea_file.split("/")[-1].split(".hea")[0] 64 | data_folder = hea_file.split("/")[-3] 65 | sigbufs, header = wfdb.rdsamp(str(hea_file)[:-4]) 66 | 67 | if(np.any(np.isnan(sigbufs))): 68 | print("Warning:",str(hea_file),"is corrupt. 
Skipping.") 69 | continue 70 | 71 | labels=[] 72 | age=np.nan 73 | sex="nan" 74 | for l in header["comments"]: 75 | arrs = l.strip().split(' ') 76 | if l.startswith('Dx:'): 77 | for x in arrs[1].split(','): 78 | if int(x) in list_label_available: 79 | labels.append(x) 80 | elif l.startswith('Age:'): 81 | try: 82 | age = int(arrs[1]) 83 | except: 84 | age= np.nan 85 | elif l.startswith('Sex:'): 86 | sex = arrs[1].strip().lower() 87 | if(sex=="m"): 88 | sex="male" 89 | elif(sex=="f"): 90 | sex="female" 91 | 92 | if len(labels) == 0: 93 | skip_files += 1 94 | continue 95 | 96 | if data_folder == "ICBEB2018": 97 | level = 3 98 | stride = 50 99 | elif data_folder == "ICBEB2018_2": 100 | level = 3 101 | stride = 50 102 | elif data_folder == "INCART": 103 | level = 3 104 | stride = 30 105 | elif data_folder == "PTB": 106 | level = 3 107 | stride = 100 108 | elif data_folder == "PTB-XL": 109 | level = 3 110 | stride = 50 111 | elif data_folder == "Georgia": 112 | level = 3 113 | stride = 50 114 | 115 | 116 | # DENOISE 117 | # Create wavelet object and define parameters 118 | w = pywt.Wavelet('db4') 119 | maxlev = pywt.dwt_max_level(len(sigbufs[:,0]), w.dec_len) 120 | denoised_data = np.zeros((len(sigbufs), channels), dtype=np.float32) 121 | # Decompose into wavelet components, to the level selected: 122 | for cha in range(channels): 123 | coeffs = pywt.wavedec(sigbufs[:,cha], 'db4', mode='periodic',level=maxlev) 124 | 125 | sigma = (1/0.6745) * madev(coeffs[-level]) 126 | uthresh = sigma * np.sqrt(2 * np.log(len(sigbufs[:,cha]))) 127 | 128 | coeffs[1:] = (pywt.threshold(i, value=uthresh, mode='hard') for i in coeffs[1:]) 129 | 130 | datarec = pywt.waverec(coeffs, 'db4') 131 | if len(datarec) < len(sigbufs): 132 | datarec = np.pad(datarec,len(sigbufs)-len(datarec),'edge') 133 | denoised_data[:,cha] = datarec 134 | elif len(datarec) > len(sigbufs): 135 | denoised_data[:,cha] = datarec[0:len(sigbufs)] 136 | else: 137 | denoised_data[:,cha] = datarec 138 | 139 | # BASELINE 
WANDER REMOVAL 140 | baseline_removal_data = np.zeros((len(sigbufs), channels), dtype=np.float32) 141 | for cha in range(channels): 142 | avg_output = moving_average(denoised_data[:,cha],stride) 143 | avg_pad = np.pad(avg_output,(0,len(sigbufs[:,cha])-len(avg_output)),'edge') 144 | baseline_removal_data[:,cha] = denoised_data[:,cha]- avg_pad 145 | 146 | 147 | ori_fs = header['fs'] 148 | factor = target_fs/ori_fs 149 | timesteps_new = int(len(sigbufs)*factor) 150 | data = np.zeros((timesteps_new, channels), dtype=np.float32) 151 | for i in range(channels): 152 | data[:,i] = resample(baseline_removal_data[:,0],timesteps_new) 153 | 154 | np.save(os.path.join(save_folder,file_name+".npy"),data) 155 | 156 | metadata.append({"data":file_name+".npy","label":labels,"sex":sex,"age":age,"dataset":data_folder}) 157 | 158 | df =pd.DataFrame(metadata) 159 | lbl_itos = np.unique([item for sublist in list(df.label) for item in sublist]) 160 | lbl_stoi = {s:i for i,s in enumerate(lbl_itos)} 161 | df["label"] = df["label"].apply(lambda x: [lbl_stoi[y] for y in x]) 162 | 163 | df["strat_fold"]=-1 164 | for ds in np.unique(df["dataset"]): 165 | print("Creating CV folds:",ds) 166 | dfx = df[df.dataset==ds] 167 | idxs = np.array(dfx.index.values) 168 | lbl_itosx = np.unique([item for sublist in list(dfx.label) for item in sublist]) 169 | stratified_ids = stratify(list(dfx["label"]), lbl_itosx, [1./strat_folds]*strat_folds) 170 | 171 | for i,split in enumerate(stratified_ids): 172 | df.loc[idxs[split],"strat_fold"]=i 173 | 174 | print("Add Mean Column") 175 | df["data_mean"]=df["data"].apply(lambda x: np.mean(np.load(x if save_folder is None else os.path.join(save_folder,x), allow_pickle=True),axis=0)) 176 | print("Add Std Column") 177 | df["data_std"]=df["data"].apply(lambda x: np.std(np.load(x if data_folder is None else os.path.join(save_folder,x), allow_pickle=True),axis=0)) 178 | print("Add Length Column") 179 | df["data_length"]=df["data"].apply(lambda x: len(np.load(x if 
def parse_args():
    """Parse the command-line options of the signal-model training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr_rate", type=float, default=5e-3)
    parser.add_argument("--num_epoches", type=int, default=100)
    # Fold ids to validate on; the sentinel 11 means "all ten folds".
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type, which was
    # never declared and raised AttributeError on the first save.
    # NOTE(review): the default prefix is a guess - confirm the intended name.
    parser.add_argument("--model_type", type=str, default="signal")
    return parser.parse_args()


def _average_per_record(names, gt, pred):
    """Average chunk-level labels and predictions over chunks sharing a record id.

    Returns `(gt_per_record, pred_per_record)`, each with one row per
    distinct value in `names` and as many columns as `gt` has.
    """
    n_classes = gt.shape[1]
    frame = pd.concat([pd.DataFrame(data=names), pd.DataFrame(data=gt), pd.DataFrame(data=pred)],
                      axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping on column 0, the first n_classes columns are the labels
    # and the next n_classes the predictions (no hard-coded 24/48).
    return (grouped.iloc[:, :n_classes].to_numpy(),
            grouped.iloc[:, n_classes:2 * n_classes].to_numpy())


def run():
    """Train and validate the 1-D signal model, one run per requested fold.

    For every fold the model is trained with Adam + cosine annealing and
    BCE-with-logits loss; after each epoch the validation chunks are
    averaged per record, AUROC/AUPRC are printed, and the weights are saved
    whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:" + args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')

    # SNOMED CT codes of [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size // 2  # overlapping validation windows
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    fold_range = np.arange(10) if 11 in args.fold else args.fold

    # torch.save does not create missing directories.
    save_dir = f'./checkpoints/{args.save_folder}'
    os.makedirs(save_dir, exist_ok=True)
    suffix = '_finetune' if args.finetune is not None else ''

    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=train_stride,
                                         chunk_length=train_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='train',
                                         equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False, random_crop=True, val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size)

        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=val_stride,
                                       chunk_length=val_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='val',
                                       equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True, random_crop=False, val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4, batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()
        model = signal_model(no_classes)

        # use the pretrained weights from self-supervised learning if given
        if args.finetune is not None:
            state_dict = torch.load(args.finetune, map_location=ctx)
            model.load_state_dict(state_dict, strict=False)
        else:
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-5, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1, no_epoches + 1):
            print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
            print('Current learning rate: ', optimizer.param_groups[0]['lr'])

            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # Each item carries several crops per record; flatten them
                # into the batch dimension.
                signal = sample['sig'].to(ctx).float().view(-1, no_channels, signal_size)
                label = sample['lbl'].to(ctx).float().view(-1, no_classes)

                loss = criterion(model(signal), label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # Step the schedule AFTER this epoch's optimizer steps, so the
            # first epoch really runs at the configured learning rate (the
            # previous dummy optimizer.step() only silenced the warning).
            scheduler_steplr.step()

            model.eval()
            val_loss = 0
            val_pred, val_gt, val_name = [], [], []
            with torch.no_grad():
                for sample in val_dataloader:
                    signal = sample['sig'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()

                    pred = model(signal)
                    val_loss += criterion(pred, label).item()

                    val_pred.append(torch.sigmoid(pred).detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(sample['idx'])

            val_gt_after, val_pred_after = _average_per_record(
                np.concatenate(val_name, axis=0),
                np.concatenate(val_gt, axis=0),
                np.concatenate(val_pred, axis=0))

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after, val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_fold{no_fold}_bestROC{suffix}.pth')
            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_fold{no_fold}_bestPRC{suffix}.pth')
def parse_args():
    """Parse the command-line options of the spectrogram-model training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr_rate", type=float, default=5e-3)
    parser.add_argument("--num_epoches", type=int, default=100)
    # Fold ids to validate on; the sentinel 11 means "all ten folds".
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type, which was
    # never declared and raised AttributeError on the first save.
    # NOTE(review): the default prefix is a guess - confirm the intended name.
    parser.add_argument("--model_type", type=str, default="spectrogram")
    return parser.parse_args()


def _average_per_record(names, gt, pred):
    """Average chunk-level labels and predictions over chunks sharing a record id.

    Returns `(gt_per_record, pred_per_record)`, each with one row per
    distinct value in `names` and as many columns as `gt` has.
    """
    n_classes = gt.shape[1]
    frame = pd.concat([pd.DataFrame(data=names), pd.DataFrame(data=gt), pd.DataFrame(data=pred)],
                      axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping on column 0, the first n_classes columns are the labels
    # and the next n_classes the predictions (no hard-coded 24/48).
    return (grouped.iloc[:, :n_classes].to_numpy(),
            grouped.iloc[:, n_classes:2 * n_classes].to_numpy())


def run():
    """Train and validate the spectrogram (STFT) model, one run per requested fold.

    For every fold the model is trained with Adam + cosine annealing and
    BCE-with-logits loss; after each epoch the validation chunks are
    averaged per record, AUROC/AUPRC are printed, and the weights are saved
    whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:" + args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')

    # SNOMED CT codes of [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size // 2  # overlapping validation windows
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    fold_range = np.arange(10) if 11 in args.fold else args.fold

    # torch.save does not create missing directories.
    save_dir = f'./checkpoints/{args.save_folder}'
    os.makedirs(save_dir, exist_ok=True)
    suffix = '_finetune' if args.finetune is not None else ''

    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=train_stride,
                                         chunk_length=train_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='train',
                                         equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False, random_crop=True, val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size)

        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=val_stride,
                                       chunk_length=val_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='val',
                                       equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True, random_crop=False, val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4, batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()
        model = spectrogram_model(no_classes)

        # use the pretrained weights from self-supervised learning if given
        if args.finetune is not None:
            state_dict = torch.load(args.finetune, map_location=ctx)
            model.load_state_dict(state_dict, strict=False)
        else:
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-6, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1, no_epoches + 1):
            print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
            print('Current learning rate: ', optimizer.param_groups[0]['lr'])

            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # Each item carries several crops per record; flatten them
                # into the batch dimension. 13x21 is the per-lead STFT size
                # the original code reshapes to.
                stft = sample['stft'].to(ctx).float().view(-1, no_channels, 13, 21)
                label = sample['lbl'].to(ctx).float().view(-1, no_classes)

                loss = criterion(model(stft), label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # Step the schedule AFTER this epoch's optimizer steps, so the
            # first epoch really runs at the configured learning rate (the
            # previous dummy optimizer.step() only silenced the warning).
            scheduler_steplr.step()

            model.eval()
            val_loss = 0
            val_pred, val_gt, val_name = [], [], []
            with torch.no_grad():
                for sample in val_dataloader:
                    stft = sample['stft'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()

                    pred = model(stft)
                    val_loss += criterion(pred, label).item()

                    val_pred.append(torch.sigmoid(pred).detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(sample['idx'])

            val_gt_after, val_pred_after = _average_per_record(
                np.concatenate(val_name, axis=0),
                np.concatenate(val_gt, axis=0),
                np.concatenate(val_pred, axis=0))

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after, val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_fold{no_fold}_bestROC{suffix}.pth')
            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_fold{no_fold}_bestPRC{suffix}.pth')


if __name__ == "__main__":
    run()
def parse_args():
    """Parse the command-line options of the ensemble-model training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr_rate", type=float, default=5e-3)
    parser.add_argument("--num_epoches", type=int, default=100)
    # Fold ids to validate on; the sentinel 11 means "all ten folds".
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--gating", action="store_true")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type, which was
    # never declared and raised AttributeError on the first save.
    # NOTE(review): the default prefix is a guess - confirm the intended name.
    parser.add_argument("--model_type", type=str, default="ensemble")
    return parser.parse_args()


def _average_per_record(names, gt, pred):
    """Average chunk-level labels and predictions over chunks sharing a record id.

    Returns `(gt_per_record, pred_per_record)`, each with one row per
    distinct value in `names` and as many columns as `gt` has.
    """
    n_classes = gt.shape[1]
    frame = pd.concat([pd.DataFrame(data=names), pd.DataFrame(data=gt), pd.DataFrame(data=pred)],
                      axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping on column 0, the first n_classes columns are the labels
    # and the next n_classes the predictions (no hard-coded 24/48).
    return (grouped.iloc[:, :n_classes].to_numpy(),
            grouped.iloc[:, n_classes:2 * n_classes].to_numpy())


def run():
    """Train and validate the signal + spectrogram ensemble, one run per requested fold.

    For every fold the model is trained with Adam + cosine annealing and
    BCE-with-logits loss; after each epoch the validation chunks are
    averaged per record, AUROC/AUPRC are printed, and the weights are saved
    whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:" + args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')

    # SNOMED CT codes of [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size // 2  # overlapping validation windows
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    fold_range = np.arange(10) if 11 in args.fold else args.fold

    # torch.save does not create missing directories.
    save_dir = f'./checkpoints/{args.save_folder}'
    os.makedirs(save_dir, exist_ok=True)
    # File-name tag kept byte-for-byte (including the historical
    # "wthoutgating" spelling) so existing checkpoints keep their names.
    gating_tag = 'withgating' if args.gating else 'wthoutgating'

    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=train_stride,
                                         chunk_length=train_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='train',
                                         equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False, random_crop=True, val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size)
        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=val_stride,
                                       chunk_length=val_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='val',
                                       equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True, random_crop=False, val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4, batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()

        if args.finetune is not None:
            # NOTE(review): the value of --finetune is only used as a switch;
            # the branch weights always come from these two fixed files.
            checkpoint_folder = "./checkpoints"
            w_time = os.path.join(checkpoint_folder, "DINO_signal_student.pth")
            w_spec = os.path.join(checkpoint_folder, "DINO_spectrogram_student.pth")
            model = ensemble_model(no_classes, args.gating, w_time, w_spec, ctx)
        else:
            model = ensemble_model(no_classes, args.gating)
            # NOTE(review): Xavier init is applied only when training from
            # scratch, matching the sibling training scripts - confirm.
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-5, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1, no_epoches + 1):
            print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
            print('Current learning rate: ', optimizer.param_groups[0]['lr'])

            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # Each item carries several crops per record; flatten them
                # into the batch dimension.
                signal = sample['sig'].to(ctx).float().view(-1, no_channels, signal_size)
                stft = sample['stft'].to(ctx).float().view(-1, no_channels, 13, 21)
                label = sample['lbl'].to(ctx).float().view(-1, no_classes)

                loss = criterion(model(signal, stft), label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # Step the schedule AFTER this epoch's optimizer steps, so the
            # first epoch really runs at the configured learning rate (the
            # previous dummy optimizer.step() only silenced the warning).
            scheduler_steplr.step()

            model.eval()
            val_loss = 0
            val_pred, val_gt, val_name = [], [], []
            with torch.no_grad():
                for sample in val_dataloader:
                    signal = sample['sig'].to(ctx).float()
                    stft = sample['stft'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()

                    pred = model(signal, stft)
                    val_loss += criterion(pred, label).item()

                    val_pred.append(torch.sigmoid(pred).detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(sample['idx'])

            val_gt_after, val_pred_after = _average_per_record(
                np.concatenate(val_name, axis=0),
                np.concatenate(val_gt, axis=0),
                np.concatenate(val_pred, axis=0))

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after, val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_{gating_tag}_fold{no_fold}_bestROC.pth')

            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(), f'{save_dir}/{args.model_type}_{gating_tag}_fold{no_fold}_bestPRC.pth')


if __name__ == "__main__":
    run()
File to edit: nbs/13_xresnet1d.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['delegates', 'store_attr', 'init_default', 'BatchNorm', 'NormType', 'ConvLayer', 'AdaptiveAvgPool', 4 | 'MaxPool', 'AvgPool', 'ResBlock', 'init_cnn', 'XResNet1d', 'xresnet1d18', 'xresnet1d34', 'xresnet1d50', 5 | 'xresnet1d101', 'xresnet1d152', 'xresnet1d18_deep', 'xresnet1d34_deep', 'xresnet1d50_deep', 6 | 'xresnet1d18_deeper', 'xresnet1d34_deeper', 'xresnet1d50_deeper'] 7 | 8 | # Cell 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .basic_conv1d import create_head1d, Flatten 14 | 15 | from enum import Enum 16 | import re 17 | 18 | # Cell 19 | import inspect 20 | 21 | def delegates(to=None, keep=False): 22 | "Decorator: replace `**kwargs` in signature with params from `to`" 23 | def _f(f): 24 | if to is None: to_f,from_f = f.__base__.__init__,f.__init__ 25 | else: to_f,from_f = to,f 26 | sig = inspect.signature(from_f) 27 | sigd = dict(sig.parameters) 28 | k = sigd.pop('kwargs') 29 | s2 = {k:v for k,v in inspect.signature(to_f).parameters.items() 30 | if v.default != inspect.Parameter.empty and k not in sigd} 31 | sigd.update(s2) 32 | if keep: sigd['kwargs'] = k 33 | from_f.__signature__ = sig.replace(parameters=sigd.values()) 34 | return f 35 | return _f 36 | 37 | def store_attr(self, nms): 38 | "Store params named in comma-separated `nms` from calling context into attrs in `self`" 39 | mod = inspect.currentframe().f_back.f_locals 40 | for n in re.split(', *', nms): setattr(self,n,mod[n]) 41 | 42 | # Cell 43 | NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero') 44 | 45 | def _conv_func(ndim=2, transpose=False): 46 | "Return the proper conv `ndim` function, potentially `transposed`." 
47 | assert 1 <= ndim <=3 48 | return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d') 49 | 50 | def init_default(m, func=nn.init.kaiming_normal_): 51 | "Initialize `m` weights with `func` and set `bias` to 0." 52 | if func and hasattr(m, 'weight'): func(m.weight) 53 | with torch.no_grad(): 54 | if getattr(m, 'bias', None) is not None: m.bias.fill_(0.) 55 | return m 56 | 57 | def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs): 58 | "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`." 59 | assert 1 <= ndim <= 3 60 | bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs) 61 | if bn.affine: 62 | bn.bias.data.fill_(1e-3) 63 | bn.weight.data.fill_(0. if zero else 1.) 64 | return bn 65 | 66 | def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs): 67 | "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`." 68 | return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs) 69 | 70 | # Cell 71 | class ConvLayer(nn.Sequential): 72 | "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers." 
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim` (1, 2 or 3) with output size `sz`."
    assert 1 <= ndim <= 3
    return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)


def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    """nn.MaxPool layer for `ndim` (1, 2 or 3).

    `ceil_mode` is forwarded to the torch layer; it used to be accepted and
    silently dropped, unlike in `AvgPool`.
    """
    assert 1 <= ndim <= 3
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding,
                                          ceil_mode=ceil_mode)


def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim` (1, 2 or 3)."
    assert 1 <= ndim <= 3
    return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding,
                                          ceil_mode=ceil_mode)
class ResBlock(nn.Module):
    """Residual block: `act(convpath(x) + idpath(x))`.

    The conv path is two `ConvLayer`s when `expansion == 1` and three
    (1 -> `kernel_size` -> 1) otherwise. `ni` and `nf` are multiplied by
    `expansion`; `nh2` / `nh1` default to the un-multiplied `nf` / `nh2`.
    The identity path gets a 1-wide `ConvLayer` when channel counts differ and a
    `pool(2, ceil_mode=True)` when `stride != 1` (before the conv if `pool_first`).
    """
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None,
                 nh1=None, nh2=None, dw=False, g2=1, sa=False, sym=False,
                 norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
                 pool=AvgPool, pool_first=True, **kwargs):
        super().__init__()
        # The last conv of the residual path uses the zero-initialised norm variant.
        if norm_type == NormType.Batch:
            last_norm = NormType.BatchZero
        elif norm_type == NormType.Instance:
            last_norm = NormType.InstanceZero
        else:
            last_norm = norm_type
        nh2 = nf if nh2 is None else nh2
        nh1 = nh2 if nh1 is None else nh1
        nf, ni = nf * expansion, ni * expansion
        inner_kw = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        final_kw = dict(norm_type=last_norm, act_cls=None, ndim=ndim, **kwargs)
        if expansion == 1:
            conv_layers = [
                ConvLayer(ni, nh2, kernel_size, stride=stride,
                          groups=ni if dw else groups, **inner_kw),
                ConvLayer(nh2, nf, kernel_size, groups=g2, **final_kw),
            ]
        else:
            conv_layers = [
                ConvLayer(ni, nh1, 1, **inner_kw),
                ConvLayer(nh1, nh2, kernel_size, stride=stride,
                          groups=nh1 if dw else groups, **inner_kw),
                ConvLayer(nh2, nf, 1, groups=g2, **final_kw),
            ]
        self.convs = nn.Sequential(*conv_layers)
        residual = [self.convs]
        # NOTE(review): SEModule and SimpleSelfAttention are not defined or imported in
        # the visible part of this module - these branches look like they raise
        # NameError when `reduction` / `sa` are set; confirm.
        if reduction:
            residual.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa:
            residual.append(SimpleSelfAttention(nf, ks=1, sym=sym))
        self.convpath = nn.Sequential(*residual)
        shortcut = []
        if ni != nf:
            shortcut.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        if stride != 1:
            shortcut.insert((1, 0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*shortcut)
        if act_cls is nn.ReLU:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = act_cls()

    def forward(self, x):
        "Sum of the conv path and the identity path, passed through the activation."
        return self.act(self.convpath(x) + self.idpath(x))


def init_cnn(m):
    "Recursively zero every non-None `bias` and kaiming-normal init Conv1d/Conv2d/Linear weights."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)


class XResNet1d(nn.Sequential):
    """1-d XResNet: three stem `ConvLayer`s, a max-pool, one stage per entry of `layers`, and a head.

    `layers[i]` is the number of `block`s in stage `i`; every stage after the first
    uses stride 2. The head comes from `create_head1d`. The parameter `p` is accepted
    but not used by this constructor.
    """
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000,
                 stem_szs=(32, 32, 64), kernel_size=5, kernel_size_stem=5, widen=1.0, sa=False,
                 act_cls=nn.ReLU, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False,
                 bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
        # Must be called from this frame: store_attr reads the caller's locals.
        store_attr(self, 'block,expansion,act_cls')
        stem_channels = [input_channels, *stem_szs]
        stem = []
        for i in range(3):
            stem.append(ConvLayer(stem_channels[i], stem_channels[i + 1], ks=kernel_size_stem,
                                  stride=2 if i == 0 else 1, act_cls=act_cls, ndim=1))

        stage_widths = [64, 64, 64, 64] + [32] * (len(layers) - 4)
        block_szs = [64 // expansion] + [int(w * widen) for w in stage_widths]
        n_stages = len(layers)
        blocks = []
        for i, n_blocks in enumerate(layers):
            blocks.append(self._make_layer(
                ni=block_szs[i], nf=block_szs[i + 1], blocks=n_blocks,
                stride=1 if i == 0 else 2, kernel_size=kernel_size,
                sa=sa and i == n_stages - 4, ndim=1, **kwargs))

        head = create_head1d(block_szs[-1] * expansion, nc=num_classes, lin_ftrs=lin_ftrs_head,
                             ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head,
                             concat_pooling=concat_pooling)

        super().__init__(
            *stem, nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            *blocks,
            head,
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
        "One stage: `blocks` blocks; only the first gets `ni` and `stride`, only the last gets `sa`."
        stage = []
        for i in range(blocks):
            is_first = i == 0
            stage.append(self.block(self.expansion, ni if is_first else nf, nf,
                                    stride=stride if is_first else 1, kernel_size=kernel_size,
                                    sa=sa and i == (blocks - 1), act_cls=self.act_cls, **kwargs))
        return nn.Sequential(*stage)

    def get_layer_groups(self):
        "Return the module at index 3 and the head (last module)."
        return (self[3], self[-1])

    def get_output_layer(self):
        "Return the last module of the head."
        return self[-1][-1]

    def set_output_layer(self, x):
        "Replace the last module of the head with `x`."
        self[-1][-1] = x


def _xresnet1d(expansion, layers, **kwargs):
    "Build an `XResNet1d` from `ResBlock`s with the given `expansion` and stage sizes."
    return XResNet1d(ResBlock, expansion, layers, **kwargs)


# Factory functions: (expansion, blocks per stage); **kwargs go to XResNet1d.
def xresnet1d18(**kwargs):
    return _xresnet1d(1, [2, 2, 2, 2], **kwargs)


def xresnet1d34(**kwargs):
    return _xresnet1d(1, [3, 4, 6, 3], **kwargs)


def xresnet1d50(**kwargs):
    return _xresnet1d(4, [3, 4, 6, 3], **kwargs)


def xresnet1d101(**kwargs):
    return _xresnet1d(4, [3, 4, 23, 3], **kwargs)


def xresnet1d152(**kwargs):
    return _xresnet1d(4, [3, 8, 36, 3], **kwargs)


def xresnet1d18_deep(**kwargs):
    return _xresnet1d(1, [2, 2, 2, 2, 1, 1], **kwargs)


def xresnet1d34_deep(**kwargs):
    return _xresnet1d(1, [3, 4, 6, 3, 1, 1], **kwargs)


def xresnet1d50_deep(**kwargs):
    return _xresnet1d(4, [3, 4, 6, 3, 1, 1], **kwargs)


def xresnet1d18_deeper(**kwargs):
    return _xresnet1d(1, [2, 2, 1, 1, 1, 1, 1, 1], **kwargs)


def xresnet1d34_deeper(**kwargs):
    return _xresnet1d(1, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)


def xresnet1d50_deeper(**kwargs):
    return _xresnet1d(4, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)
class Flatten(nn.Module):
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
    def __init__(self, full: bool = False):
        super().__init__()
        self.full = full

    def forward(self, x):
        return x.view(-1) if self.full else x.view(x.size(0), -1)


def listify(p=None, q=None):
    """Make `p` listy and the same length as `q`.

    `None` becomes `[]`; strings, non-iterables and iterables without a length
    (e.g. rank-0 tensors) are wrapped in a one-element list. A one-element `p`
    is repeated to the target length, which is `q` itself if it is an int,
    `len(q)` if `q` is given, otherwise `len(p)`.

    Raises:
        AssertionError: if the final length does not match the target length.
    """
    if p is None:
        p = []
    elif isinstance(p, str):
        p = [p]
    elif not isinstance(p, Iterable):
        p = [p]
    else:
        # Rank 0 tensors in PyTorch are Iterable but don't have a length.
        try:
            len(p)
        except TypeError:
            p = [p]
    n = q if isinstance(q, int) else len(p) if q is None else len(q)
    if len(p) == 1:
        p = p * n
    assert len(p) == n, f'List len mismatch ({len(p)} vs {n})'
    return list(p)


def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    "List of batchnorm (if `bn`), dropout (if `p` != 0) and linear (`n_in`,`n_out`) layers followed by `actn`."
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0:
        layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None:
        layers.append(actn)
    return layers


def _activation(act):
    "Return the activation module for `act` ('relu', 'elu', 'prelu'), or None for any other value."
    if act == "relu":
        return nn.ReLU(True)
    # These were built as `nn.ELU(True)` / `nn.PReLU(True)`: the positional boolean
    # is `alpha` / `num_parameters` there, not `inplace`. Use the defaults
    # (alpha=1.0, one parameter), which is what `True` stood for numerically.
    if act == "elu":
        return nn.ELU()
    if act == "prelu":
        return nn.PReLU()
    return None


def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
    """Sequential of optional dropout, Conv1d (padding `(kernel_size-1)//2`), optional BatchNorm1d and activation.

    The convolution has a bias only when `bn` is False. Unknown `act` values add no activation.
    """
    lst = []
    if drop_p > 0:
        lst.append(nn.Dropout(drop_p))
    lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                         padding=(kernel_size - 1) // 2, dilation=dilation, bias=not bn))
    if bn:
        lst.append(nn.BatchNorm1d(out_planes))
    activation = _activation(act)
    if activation is not None:
        lst.append(activation)
    return nn.Sequential(*lst)


def _fc(in_planes, out_planes, act="relu", bn=True):
    "Sequential of Linear (bias only when `bn` is False), optional BatchNorm1d and activation."
    lst = [nn.Linear(in_planes, out_planes, bias=not bn)]
    if bn:
        lst.append(nn.BatchNorm1d(out_planes))
    activation = _activation(act)
    if activation is not None:
        lst.append(activation)
    return nn.Sequential(*lst)
class AdaptiveConcatPool1d(nn.Module):
    "Layer that concats `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d` outputs along dim 1 (max first)."
    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        sz = sz or 1
        self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)

    def attrib(self, relevant, irrelevant):
        # NOTE(review): attrib_adaptiveconcatpool is listed in __all__ but no definition
        # is visible in this module - confirm where it comes from.
        return attrib_adaptiveconcatpool(self, relevant, irrelevant)


class SqueezeExcite1d(nn.Module):
    """Squeeze excite block as used for example in LSTM FCN.

    Holds two weight tensors of shape (1, channels//reduction, channels) and
    (1, channels, channels//reduction); the input is rescaled per channel by a
    sigmoid gate computed from its mean over the last dimension.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        channels_reduced = channels // reduction
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))

    def forward(self, x):
        # input is bs,ch,seq
        z = torch.mean(x, dim=2, keepdim=True)            # bs,ch,1
        intermed = F.relu(torch.matmul(self.w1, z))       # (1,ch_red,ch) @ (bs,ch,1) = bs,ch_red,1
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        s = torch.sigmoid(torch.matmul(self.w2, intermed))  # (1,ch,ch_red) @ (bs,ch_red,1) = bs,ch,1
        return s * x                                      # bs,ch,1 * bs,ch,seq = bs,ch,seq


def weight_init(m):
    """Weight initialization for a single module; use via `model.apply(weight_init)`.

    Conv1d / Linear: kaiming-normal weight, zero bias. BatchNorm1d: weight 1,
    bias 0. SqueezeExcite1d: normal init with std derived from the tensor sizes.
    """
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, SqueezeExcite1d):
        # `size` is a method: `m.w1.size[0]` raised TypeError, so this branch never ran.
        # NOTE(review): the indices are kept as written, but both tensors carry a leading
        # dimension of length 1 from `unsqueeze(0)`, so w1.size(0) == 1 - confirm
        # which dimension the std was meant to use.
        stdv1 = math.sqrt(2. / m.w1.size(0))
        nn.init.normal_(m.w1, 0., stdv1)
        stdv2 = math.sqrt(1. / m.w2.size(1))
        nn.init.normal_(m.w2, 0., stdv2)
def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn_final: bool = False, bn: bool = True,
                  act="relu", concat_pooling=True):
    """Model head that takes `nf` features, runs through `lin_ftrs`, and ends in `nc` classes.

    Args:
        nf: Channels of the incoming feature map (doubled internally when `concat_pooling`).
        nc: Number of outputs of the final linear layer.
        lin_ftrs: Sizes of the hidden linear layers (list or tuple), or None for none.
        ps: Dropout probability, or one per linear layer. A single value is used
            for the last layer and halved for the earlier ones.
        bn_final: Append a `BatchNorm1d(nc, momentum=0.01)` after the last layer.
        bn: Put a BatchNorm1d in front of every linear layer.
        act: "relu" for ReLU between layers; any other value selects ELU.
        concat_pooling: Pool with `AdaptiveConcatPool1d`, else `nn.MaxPool1d(2)`.
    """
    n_in = 2 * nf if concat_pooling else nf
    # list(...) so that a tuple of hidden sizes works as well as a list
    lin_ftrs = [n_in, nc] if lin_ftrs is None else [n_in] + list(lin_ftrs) + [nc]
    ps = listify(ps)
    if len(ps) == 1:
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
    layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
    for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += bn_drop_lin(ni, no, bn, p, actn)
    if bn_final:
        layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)


class basic_conv1d(nn.Sequential):
    """Stack of `_conv1d` stages (one per entry of `filters`) followed by a head.

    `kernel_size` may be an int (shared) or one value per stage. With
    `headless=True` the last stage has no batchnorm/activation and the head is
    global average pooling + `Flatten`; otherwise the head comes from
    `create_head1d`. With `split_first_layer=True` the first stage is a stride-1
    convolution without activation followed by a kernel-size-1 convolution.
    """
    def __init__(self, filters=[128, 128, 128, 128], kernel_size=3, stride=2, dilation=1, pool=0,
                 pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8,
                 act="relu", bn=True, headless=False, split_first_layer=False, drop_p=0.,
                 lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
                 act_head="relu", concat_pooling=True):
        layers = []
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * len(filters)
        for i in range(len(filters)):
            layers_tmp = []
            is_split_first = split_first_layer is True and i == 0
            is_headless_last = headless is True and i == len(filters) - 1

            layers_tmp.append(_conv1d(
                input_channels if i == 0 else filters[i - 1], filters[i],
                kernel_size=kernel_size[i],
                stride=(1 if is_split_first else stride),
                dilation=dilation,
                act="none" if (is_headless_last or is_split_first) else act,
                bn=False if is_headless_last else bn,
                drop_p=(0. if i == 0 else drop_p)))
            if is_split_first:
                layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1,
                                          act=act, bn=bn, drop_p=0.))
            # NOTE(review): the source text of the next two conditions was lost in
            # extraction ("if(pool>0 and i0):"). They are reconstructed from the otherwise
            # unused `pool` / `pool_stride` parameters and the surviving
            # SqueezeExcite1d line - confirm against the upstream file.
            if pool > 0 and i < len(filters) - 1:
                layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
            if squeeze_excite_reduction > 0:
                layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        # head
        self.headless = headless
        if headless is True:
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
        else:
            head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
                                 bn_final=bn_final_head, bn=bn_head, act=act_head,
                                 concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        "Return the module at index 2 and the head (last module)."
        return (self[2], self[-1])

    def get_output_layer(self):
        "Return the last module of the head, or None for a headless model."
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self, x):
        "Replace the last module of the head with `x`; no-op for a headless model."
        if self.headless is False:
            self[-1][-1] = x


def fcn(filters=[128] * 5, num_classes=2, input_channels=8, **kwargs):
    "Headless `basic_conv1d` whose last stage has `num_classes` filters. `kwargs` are ignored."
    filters_in = filters + [num_classes]
    return basic_conv1d(filters=filters_in, kernel_size=3, stride=1, pool=2, pool_stride=2,
                        input_channels=input_channels, act="relu", bn=True, headless=True)


def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False,
             bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    "`basic_conv1d` with filters [128,256,128], kernel sizes [8,5,3], stride 1 and no pooling."
    return basic_conv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0, pool_stride=2,
                        num_classes=num_classes, input_channels=input_channels, act="relu", bn=True,
                        lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                        bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)


def schirrmeister(num_classes=2, input_channels=8, kernel_size=10, lin_ftrs_head=None, ps_head=0.5,
                  bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    "`basic_conv1d` with filters [25,50,100,200], split first layer, stride 3, pool 3 and dropout 0.5."
    return basic_conv1d(filters=[25, 50, 100, 200], kernel_size=kernel_size, stride=3, pool=3,
                        pool_stride=1, num_classes=num_classes, input_channels=input_channels,
                        act="relu", bn=True, headless=False, split_first_layer=True, drop_p=0.5,
                        lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                        bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)


def sen(filters=[128] * 5, num_classes=2, input_channels=8, kernel_size=3, squeeze_excite_reduction=16,
        drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
        act_head="relu", concat_pooling=True, **kwargs):
    "`basic_conv1d` with stride 2, no pooling and a squeeze-excite block in every stage."
    return basic_conv1d(filters=filters, kernel_size=kernel_size, stride=2, pool=0, pool_stride=0,
                        input_channels=input_channels, act="relu", bn=True, num_classes=num_classes,
                        squeeze_excite_reduction=squeeze_excite_reduction, drop_p=drop_p,
                        lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                        bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)


def basic1d(filters=[128] * 5, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1,
            squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True,
            headless=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False,
            bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    "Pass-through constructor for `basic_conv1d`; extra `kwargs` are ignored."
    return basic_conv1d(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation,
                        pool=pool, pool_stride=pool_stride,
                        squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes,
                        input_channels=input_channels, act=act, bn=bn, headless=headless,
                        drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head,
                        bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head,
                        concat_pooling=concat_pooling)
class Normalize(object):
    """Normalize one element of a `(data, label)` sample using given stats.

    Args:
        stats_mean: Array subtracted from the selected element, or None to skip.
        stats_std: Array the element is divided by (1e-8 is added), or None to skip.
        input: If True the data element is normalized, otherwise the label element.
        channels: Indices along the last axis of the stats that should be
            normalized; every other index is neutralised (mean 0, std 1).
            Empty means all channels.
    """
    def __init__(self, stats_mean, stats_std, input=True, channels=()):
        # astype returns a copy, so the masking below never touches the caller's arrays
        self.stats_mean = stats_mean.astype(np.float32) if stats_mean is not None else None
        self.stats_std = stats_std.astype(np.float32) + 1e-8 if stats_std is not None else None
        self.input = input
        if len(channels) > 0:
            # The loop used to run over len(stats_mean) while indexing `[:, i]`, which
            # fails on 1-d stats and only visits index 0 on (1, C) stats. Iterate the
            # channel (last) axis instead and tolerate missing stats.
            for stats, neutral in ((self.stats_mean, 0), (self.stats_std, 1)):
                if stats is None:
                    continue
                for i in range(stats.shape[-1]):
                    if i not in channels:
                        stats[..., i] = neutral

    def __call__(self, sample):
        datax, labelx = sample
        data = datax if self.input else labelx
        # assuming channel last
        if self.stats_mean is not None:
            data = data - self.stats_mean
        if self.stats_std is not None:
            data = data / self.stats_std

        if self.input:
            return (data, labelx)
        else:
            return (datax, data)


def replace_labels(x, stay_idx, remove_idx):
    "Return a list copy of `x` where every entry equal to `remove_idx` is replaced by `stay_idx`."
    return [stay_idx if y == remove_idx else y for y in x]


def keep_one_random_class(x):
    "Return one entry of `x`, drawn with `np.random.choice`."
    return np.random.choice(x, 1)[0]


def transformations_from_strings(transformations, t_params):
    """Build the transform pipeline named by `transformations`.

    With `transformations=None` only `[ToTensor()]` is returned. Otherwise the
    list is `ToTensor(transpose_data=False)`, one transform per name, a
    `Normalize` built from `t_params["stats_mean"]` / `t_params["stats_std"]`
    and a final `Transpose()`.

    Raises:
        Exception: for a name that is not one of "RandomResizedCrop", "TimeOut",
            "GaussianNoise" or "TimeOut_difflead".
    """
    if transformations is None:
        return [ToTensor()]

    def str_to_trafo(trafo):
        if trafo == "RandomResizedCrop":
            return TRandomResizedCrop(crop_ratio_range=t_params["rr_crop_ratio_range"],
                                      output_size=t_params["output_size"])
        elif trafo == "TimeOut":
            return TTimeOut(crop_ratio_range=t_params["to_crop_ratio_range"])
        elif trafo == "GaussianNoise":
            return TGaussianNoise(scale=t_params["gaussian_scale"])
        elif trafo == "TimeOut_difflead":
            return TTimeOut_difflead(crop_ratio_range=t_params["to_crop_ratio_range"])
        else:
            raise Exception(str(trafo) + " is not a valid transformation")

    trafo_list = ([ToTensor(transpose_data=False)]
                  + [str_to_trafo(trafo) for trafo in transformations]
                  + [Normalize(stats_mean=t_params["stats_mean"], stats_std=t_params["stats_std"])]
                  + [Transpose()])

    return trafo_list
class ECG_contrastive_dataset(torch.utils.data.Dataset):
    """Dataset yielding two independently transformed crops of the same signal chunk.

    Reads `df_memmap.pkl`, `lbl_itos.npy`, `mean.npy`, `std.npy`,
    `memmap_meta.npz` and `memmap.npy` from `summary_folder`. Every record is
    split into chunks of `chunk_length` samples (step `stride`, or
    `chunk_length` when `stride` is None; `chunk_length == 0` keeps the record
    whole); chunks shorter than `signal_size` are dropped. Each item is a
    `signal_size` crop of one chunk, returned as a dict with keys
    `sig_i`, `sig_j`, `lbl` and `idx`.
    """
    def __init__(self, summary_folder, signal_size, stride, chunk_length, transforms, t_params,
                 equivalent_classes, sample_items_per_record=1, random_crop=True):

        self.folder = summary_folder
        self.signal_size = signal_size
        self.transforms = transformations_from_strings(transforms, t_params)
        # number of small samples we want to take out of the big signal data
        self.sample_items_per_record = sample_items_per_record
        # from the large signal data, we randomly choose where we acquire the sample data
        self.random_crop = random_crop

        # Loading data info
        # NOTE(review): unpickling executes arbitrary code - only load trusted folders.
        with open(os.path.join(self.folder, "df_memmap.pkl"), "rb") as f:
            self.df = pickle.load(f)
        self.lbl_itos = np.load(os.path.join(self.folder, "lbl_itos.npy"))
        self.mean = np.load(os.path.join(self.folder, "mean.npy"))
        self.std = np.load(os.path.join(self.folder, "std.npy"))

        # Grouping the equivalent classes, remove the corresponding classes
        stack_remove_idx = []
        for stay_class, remove_class in equivalent_classes:
            if stay_class not in self.lbl_itos or remove_class not in self.lbl_itos:
                print(f'{stay_class},{remove_class}: one of those is not in the dictionary')
                continue
            stay_idx = np.where(self.lbl_itos == stay_class)[0][0]
            remove_idx = np.where(self.lbl_itos == remove_class)[0][0]
            self.df['label'] = self.df['label'].apply(
                lambda x: replace_labels(x, stay_idx, remove_idx))
            stack_remove_idx.append(remove_idx)

        self.df['label'] = self.df['label'].apply(lambda x: keep_one_random_class(x))
        # NOTE(review): lbl_itos is compacted here but the label indices stored in df
        # are not renumbered - verify that downstream code expects that.
        self.lbl_itos = np.delete(self.lbl_itos, stack_remove_idx)

        self.timeseries_df_data = np.array(self.df['data'])
        if self.timeseries_df_data.dtype not in [np.int16, np.int32, np.int64]:
            # np.bytes_ is the same type as the np.string_ alias removed in NumPy 2.0
            self.timeseries_df_data = np.array(self.df["data"].astype(str)).astype(np.bytes_)

        # stack arrays/lists for proper batching; the check used to look at the
        # `data` column for the list case although it is the labels that get stacked
        first_label = self.df['label'].iloc[0]
        if isinstance(first_label, (list, np.ndarray)):
            self.timeseries_df_label = np.stack(self.df['label'])
        else:  # single integers/floats
            self.timeseries_df_label = np.array(self.df['label'])
        # everything else cannot be batched anyway
        if self.timeseries_df_label.dtype not in [np.int16, np.int32, np.int64, np.float32, np.float64]:
            self.timeseries_df_label = np.array(self.df['label'].apply(lambda x: str(x))).astype(np.bytes_)

        # load meta data for memmap npy
        self.mode = "memmap"
        memmap_meta = np.load(os.path.join(self.folder, 'memmap_meta.npz'), allow_pickle=True)
        self.memmap_start = memmap_meta["start"]
        self.memmap_shape = memmap_meta["shape"]
        self.memmap_length = memmap_meta["length"]
        self.memmap_file_idx = memmap_meta["file_idx"]
        self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
        # save as bytes to avoid issue with mp
        self.memmap_filenames = np.array(memmap_meta["filenames"]).astype(np.bytes_)

        # load data from memmap file
        self.memmap_signaldata = np.memmap(os.path.join(self.folder, 'memmap.npy'), self.memmap_dtype,
                                           mode='r', shape=tuple(self.memmap_shape[0]))

        # get the position of each chunk inside the stacked memmap signal data
        self.df_idx_mapping = []
        self.start_idx_mapping = []
        self.end_idx_mapping = []
        start_idx = 0
        min_chunk_length = signal_size

        for df_idx, (_, row) in enumerate(self.df.iterrows()):
            data_length = self.memmap_length[row["data"]]

            if chunk_length == 0:  # do not split into chunks
                idx_start = [start_idx]
                idx_end = [data_length]
            else:
                step = chunk_length if stride is None else stride
                idx_start = list(range(start_idx, data_length, step))
                idx_end = [min(l + chunk_length, data_length) for l in idx_start]

            for i_s, i_e in zip(idx_start, idx_end):
                # the first too-short chunk and everything after it is dropped
                if i_e - i_s < min_chunk_length:
                    break
                self.df_idx_mapping.append(df_idx)
                self.start_idx_mapping.append(i_s)
                self.end_idx_mapping.append(i_e)

        # convert to np.array to avoid mp issues with python lists
        self.df_idx_mapping = np.array(self.df_idx_mapping)
        self.start_idx_mapping = np.array(self.start_idx_mapping)
        self.end_idx_mapping = np.array(self.end_idx_mapping)

    def __len__(self):
        return len(self.df_idx_mapping)

    @property
    def is_empty(self):
        "True when no chunk survived the length filter."
        return len(self.df_idx_mapping) == 0

    def _crop_start(self, idx):
        "Start of the `signal_size` crop relative to the chunk: random, or centred when `random_crop` is off."
        timesteps = self.get_sample_length(idx)
        if not self.random_crop:
            # this may be for valid and the timesteps is probably equal to the signal_size
            return (timesteps - self.signal_size) // 2
        if timesteps == self.signal_size:
            return 0
        return random.randint(0, timesteps - self.signal_size - 1)

    def __getitem__(self, idx):
        """Return a dict with keys `sig_i`, `sig_j`, `lbl`, `idx`.

        With `sample_items_per_record == 1` the values are those of a single
        crop; otherwise the signals are stacked with `torch.stack`, the labels
        become a tensor and `idx` is a list.
        """
        if self.sample_items_per_record == 1:
            data_i, data_j, label, patient = self.get_signal_sample(idx, self._crop_start(idx))
            return {'sig_i': data_i, 'sig_j': data_j, 'lbl': label, 'idx': patient}

        lst_data_i, lst_data_j, lst_lbl, lst_patient = [], [], [], []
        for _ in range(self.sample_items_per_record):
            data_i, data_j, label, patient = self.get_signal_sample(idx, self._crop_start(idx))
            lst_data_i.append(data_i)
            lst_data_j.append(data_j)
            lst_patient.append(patient)
            lst_lbl.append(label)
        lst_data_i = torch.stack(lst_data_i)
        lst_data_j = torch.stack(lst_data_j)
        lst_lbl = torch.from_numpy(np.stack(lst_lbl))

        return {'sig_i': lst_data_i, 'sig_j': lst_data_j, 'lbl': lst_lbl, 'idx': lst_patient}

    def get_signal_sample(self, idx, start_idx_rel):
        """Crop chunk `idx` at relative offset `start_idx_rel` and transform it twice.

        Returns:
            `(aug_i, aug_j, label, df_idx)`; `label` is the stored label, not the
            one returned by the transforms.
        """
        df_idx = self.df_idx_mapping[idx]
        start_idx = self.start_idx_mapping[idx]
        end_idx = self.end_idx_mapping[idx]
        # determine crop idxs
        timesteps = end_idx - start_idx
        assert(timesteps >= self.signal_size)
        start_idx_crop = start_idx + start_idx_rel
        end_idx_crop = start_idx_crop + self.signal_size

        memmap_idx = self.timeseries_df_data[df_idx]
        idx_offset = self.memmap_start[memmap_idx]

        signal_data = np.copy(self.memmap_signaldata[idx_offset + start_idx_crop: idx_offset + end_idx_crop])

        label = self.timeseries_df_label[df_idx]
        sample1 = (signal_data, label)
        # own buffer for the second view, so an in-place transform applied to one
        # view cannot leak into the other
        sample2 = (signal_data.copy(), label)

        for trans in self.transforms:
            sample1 = trans(sample1)
            sample2 = trans(sample2)

        aug_i, _ = sample1
        aug_j, _ = sample2

        return aug_i, aug_j, label, df_idx

    def get_id_mapping(self):
        return self.df_idx_mapping

    def get_sample_id(self, idx):
        return self.df_idx_mapping[idx]

    def get_sample_length(self, idx):
        return self.end_idx_mapping[idx] - self.start_idx_mapping[idx]

    def get_sample_start(self, idx):
        return self.start_idx_mapping[idx]
class BasicConv1d(nn.Module):
    "Conv1d (no bias) -> BatchNorm1d (eps=.001) -> ReLU."
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm1d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_5b(nn.Module):
    "Four parallel branches on a 192-channel input, concatenated to 96+64+96+64 = 320 channels."
    def __init__(self):
        super(Mixed_5b, self).__init__()

        self.branch0 = BasicConv1d(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(192, 48, kernel_size=1, stride=1),
            BasicConv1d(48, 64, kernel_size=5, stride=1, padding=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(192, 64, kernel_size=1, stride=1),
            BasicConv1d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv1d(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            # 1-d pooling along the signal axis. This was nn.AvgPool2d, which on a
            # (batch, channels, length) tensor also averages neighbouring channels.
            nn.AvgPool1d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv1d(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block35(nn.Module):
    "Residual block on 320 channels: `relu(x + scale * conv(cat(branches)))`."
    def __init__(self, scale=1.0):
        super(Block35, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv1d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv1d(48, 64, kernel_size=3, stride=1, padding=1)
        )

        # attribute keeps its `conv2d` name (state-dict key) although it is an nn.Conv1d
        self.conv2d = nn.Conv1d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_6a(nn.Module):
    "Stride-2 reduction block: 320 channels in, 384+384+320 = 1088 channels out."
    def __init__(self):
        super(Mixed_6a, self).__init__()

        self.branch0 = BasicConv1d(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv1d(320, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out
BasicConv1d(1088, 128, kernel_size=1, stride=1), 116 | BasicConv1d(128, 160, kernel_size=7, stride=1, padding=3), 117 | BasicConv1d(160, 192, kernel_size=7, stride=1, padding=3) 118 | ) 119 | 120 | self.conv2d = nn.Conv1d(384, 1088, kernel_size=1, stride=1) 121 | self.relu = nn.ReLU(inplace=False) 122 | 123 | def forward(self, x): 124 | x0 = self.branch0(x) 125 | x1 = self.branch1(x) 126 | out = torch.cat((x0, x1), 1) 127 | out = self.conv2d(out) 128 | out = out * self.scale + x 129 | out = self.relu(out) 130 | return out 131 | 132 | 133 | class Mixed_7a(nn.Module): 134 | def __init__(self): 135 | super(Mixed_7a, self).__init__() 136 | 137 | self.branch0 = nn.Sequential( 138 | BasicConv1d(1088, 256, kernel_size=1, stride=1), 139 | BasicConv1d(256, 384, kernel_size=3, stride=2) 140 | ) 141 | 142 | self.branch1 = nn.Sequential( 143 | BasicConv1d(1088, 256, kernel_size=1, stride=1), 144 | BasicConv1d(256, 288, kernel_size=3, stride=2) 145 | ) 146 | 147 | self.branch2 = nn.Sequential( 148 | BasicConv1d(1088, 256, kernel_size=1, stride=1), 149 | BasicConv1d(256, 288, kernel_size=3, stride=1, padding=1), 150 | BasicConv1d(288, 320, kernel_size=3, stride=2) 151 | ) 152 | 153 | self.branch3 = nn.MaxPool1d(3, stride=2) 154 | 155 | def forward(self, x): 156 | x0 = self.branch0(x) 157 | x1 = self.branch1(x) 158 | x2 = self.branch2(x) 159 | x3 = self.branch3(x) 160 | out = torch.cat((x0, x1, x2, x3), 1) 161 | return out 162 | 163 | 164 | class Block8(nn.Module): 165 | 166 | def __init__(self, scale=1.0, no_relu=False): 167 | super(Block8, self).__init__() 168 | 169 | self.scale = scale 170 | 171 | self.branch0 = BasicConv1d(2080, 192, kernel_size=1, stride=1) 172 | 173 | self.branch1 = nn.Sequential( 174 | BasicConv1d(2080, 192, kernel_size=1, stride=1), 175 | BasicConv1d(192, 224, kernel_size=3, stride=1, padding=1), 176 | BasicConv1d(224, 256, kernel_size=3, stride=1, padding=1) 177 | ) 178 | 179 | self.conv2d = nn.Conv1d(448, 2080, kernel_size=1, stride=1) 180 | self.relu = 
def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the channel multiplier a pooling mode applies to its input.

    'catavgmax' concatenates the average- and max-pooled tensors and so
    doubles the channel count; every other mode keeps it unchanged.
    """
    return 2 if pool_type == 'catavgmax' else 1


class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Despite the ``2d`` in the name, the layer pools 1-D feature maps with
    ``nn.AdaptiveAvgPool1d`` / ``nn.AdaptiveMaxPool1d``.

    Args:
        output_size: Target length handed to the adaptive pooling modules.
        pool_type: 'avg' (average), 'max' (maximum), 'avgmax' (mean of the
            two) or 'catavgmax' (both concatenated on dim 1).  Any other
            value falls back to average pooling.
        flatten: When true, the pooled tensor is flattened from dim 1 on.
    """

    def __init__(self, output_size=1, pool_type='avg', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        self.flatten = flatten
        self.pool = nn.AdaptiveAvgPool1d(output_size)
        # Always built (it has no parameters) so that the output width
        # agrees with feat_mult() for every supported pool_type.
        self.max_pool = nn.AdaptiveMaxPool1d(output_size)

    def forward(self, x):
        """Pool ``x`` according to ``pool_type`` and optionally flatten it."""
        if self.pool_type == 'max':
            x = self.max_pool(x)
        elif self.pool_type == 'avgmax':
            x = 0.5 * (self.pool(x) + self.max_pool(x))
        elif self.pool_type == 'catavgmax':
            # Doubles the channel count, matching feat_mult() == 2.
            x = torch.cat((self.pool(x), self.max_pool(x)), 1)
        else:
            x = self.pool(x)
        if self.flatten:
            x = x.flatten(1)
        return x

    def feat_mult(self):
        """Return the channel multiplier for the configured ``pool_type``."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'output_size=' + str(self.output_size) \
            + ', pool_type=' + self.pool_type + ')'


class InceptionResnetV2(nn.Module):
    """Inception-ResNet-v2 style network assembled from 1-D layers.

    Args:
        num_classes: Output width of the final linear classifier.
        in_chans: Number of channels of the input passed to the first conv.
        drop_rate: Dropout probability applied to the pooled features;
            dropout is skipped when it is not greater than 0.
        global_pool: Pooling mode forwarded to ``SelectAdaptivePool2d``.
    """

    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        # Stem.
        self.conv2d_1a = BasicConv1d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv1d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool1d(3, stride=2)
        self.conv2d_3b = BasicConv1d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv1d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool1d(3, stride=2)

        # Residual stages separated by the Mixed_* transition blocks.
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(no_relu=True)
        self.conv2d_7b = BasicConv1d(2080, self.num_features, kernel_size=1, stride=1)

        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
        self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

    def get_classifier(self):
        """Return the current classification head."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and the classification head.

        A falsy ``num_classes`` installs ``nn.Identity`` so that ``forward``
        returns the pooled features.
        """
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        if num_classes:
            num_features = self.num_features * self.global_pool.feat_mult()
            self.classif = nn.Linear(num_features, num_classes)
        else:
            self.classif = nn.Identity()

    def forward_features(self, x):
        """Run the backbone up to and including ``conv2d_7b`` (no pooling)."""
        stages = (
            self.conv2d_1a, self.conv2d_2a, self.conv2d_2b, self.maxpool_3a,
            self.conv2d_3b, self.conv2d_4a, self.maxpool_5a,
            self.mixed_5b, self.repeat,
            self.mixed_6a, self.repeat_1,
            self.mixed_7a, self.repeat_2,
            self.block8, self.conv2d_7b,
        )
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        """Return classifier outputs for ``x``."""
        x = self.forward_features(x)
        x = self.global_pool(x).flatten(1)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classif(x)
        return x
class SELayer_1d(nn.Module):
    """Squeeze-and-excitation gate for tensors unpacked as (batch, channel, length).

    The input is average-pooled to one value per channel, passed through a
    two-layer bias-free bottleneck ending in a sigmoid, and the resulting
    per-channel gate multiplies the input.

    Args:
        channel: Number of channels of the input.
        reduction: Divisor giving the bottleneck width ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate."""
        batch, channels, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1)
        return x * gate.expand_as(x)


class BasicConv1d(nn.Module):
    """Bias-free ``Conv1d`` followed by ``BatchNorm1d`` (eps=0.001) and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        # No conv bias: the batch norm that follows has its own shift term.
        self.conv = nn.Conv1d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))
class Mixed_5b(nn.Module):
    """Length-preserving inception block for 192-channel inputs.

    Four parallel branches (96, 64, 96 and 64 output channels) are
    concatenated along dim 1, giving 320 channels.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv1d(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(192, 48, kernel_size=1, stride=1),
            BasicConv1d(48, 64, kernel_size=5, stride=1, padding=2),
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(192, 64, kernel_size=1, stride=1),
            BasicConv1d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv1d(96, 96, kernel_size=3, stride=1, padding=1),
        )

        # Average pooling that ignores padded positions, then a 1x1 conv.
        self.branch3 = nn.Sequential(
            nn.AvgPool1d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv1d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        branch_outputs = (
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        )
        return torch.cat(branch_outputs, 1)


class Block35(nn.Module):
    """Scaled residual block operating on 320 channels.

    Three branches (32, 32 and 64 output channels) are concatenated into
    128 channels, projected back to 320 channels by a 1x1 convolution,
    multiplied by ``scale``, added to the input and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv1d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 32, kernel_size=3, stride=1, padding=1),
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv1d(48, 64, kernel_size=3, stride=1, padding=1),
        )

        # The attribute is called ``conv2d`` but holds a 1-D convolution.
        self.conv2d = nn.Conv1d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(x + scale * projection(branches(x)))``."""
        residual = torch.cat(
            (self.branch0(x), self.branch1(x), self.branch2(x)), 1
        )
        residual = self.conv2d(residual)
        return self.relu(residual * self.scale + x)
class Mixed_6a(nn.Module):
    """Stride-2 reduction stage for 320-channel inputs.

    Two convolutional branches (384 channels each) and a max-pool branch
    are concatenated along dim 1; with 320 input channels the result has
    384 + 384 + 320 = 1088 channels.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv1d(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv1d(320, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        )

        # Parameter-free branch; keeps the input channel count.
        self.branch2 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate them on dim 1."""
        branch_outputs = (self.branch0(x), self.branch1(x), self.branch2(x))
        return torch.cat(branch_outputs, 1)


class Block17(nn.Module):
    """Scaled residual block operating on 1088 channels.

    Two branches (192 channels each) are concatenated into 384 channels,
    projected back to 1088 channels by a 1x1 convolution, multiplied by
    ``scale``, added to the input and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv1d(1088, 192, kernel_size=1, stride=1)

        # Kernel-7 convolutions with padding 3 keep the sequence length.
        self.branch1 = nn.Sequential(
            BasicConv1d(1088, 128, kernel_size=1, stride=1),
            BasicConv1d(128, 160, kernel_size=7, stride=1, padding=3),
            BasicConv1d(160, 192, kernel_size=7, stride=1, padding=3),
        )

        # The attribute is called ``conv2d`` but holds a 1-D convolution.
        self.conv2d = nn.Conv1d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(x + scale * projection(branches(x)))``."""
        residual = torch.cat((self.branch0(x), self.branch1(x)), 1)
        residual = self.conv2d(residual)
        return self.relu(residual * self.scale + x)


class Mixed_7a(nn.Module):
    """Stride-2 reduction stage for 1088-channel inputs.

    Three convolutional branches (384, 288 and 320 channels) and a max-pool
    branch are concatenated along dim 1; with 1088 input channels the
    result has 384 + 288 + 320 + 1088 = 2080 channels.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        )

        self.branch1 = nn.Sequential(
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=2),
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv1d(288, 320, kernel_size=3, stride=2),
        )

        # Parameter-free branch; keeps the input channel count.
        self.branch3 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        branch_outputs = (
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        )
        return torch.cat(branch_outputs, 1)
class Block8(nn.Module):
    """Scaled residual block operating on 2080 channels.

    Two branches (192 and 256 output channels) are concatenated into 448
    channels, projected back to 2080 channels by a 1x1 convolution,
    multiplied by ``scale`` and added to the block input.

    Args:
        scale: Multiplier applied to the projected branch output before the
            residual addition.
        no_relu: When true, no activation is applied after the addition.
    """

    def __init__(self, scale=1.0, no_relu=False):
        super(Block8, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv1d(2080, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(2080, 192, kernel_size=1, stride=1),
            BasicConv1d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv1d(224, 256, kernel_size=3, stride=1, padding=1)
        )

        # The attribute is called ``conv2d`` but holds a 1-D convolution.
        self.conv2d = nn.Conv1d(448, 2080, kernel_size=1, stride=1)
        self.relu = None if no_relu else nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``x`` plus the scaled residual, optionally passed through ReLU."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if self.relu is not None:
            out = self.relu(out)
        return out


def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the channel multiplier a pooling mode applies to its input.

    'catavgmax' concatenates the average- and max-pooled tensors and so
    doubles the channel count; every other mode keeps it unchanged.
    """
    return 2 if pool_type == 'catavgmax' else 1


class SelectAdaptivePool1d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Args:
        output_size: Target length handed to the adaptive pooling modules.
        pool_type: 'avg' (average), 'max' (maximum), 'avgmax' (mean of the
            two) or 'catavgmax' (both concatenated on dim 1).  Any other
            value falls back to average pooling.
        flatten: When true, the pooled tensor is flattened from dim 1 on.
    """

    def __init__(self, output_size=1, pool_type='avg', flatten=False):
        super(SelectAdaptivePool1d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        self.flatten = flatten
        self.pool = nn.AdaptiveAvgPool1d(output_size)
        # Always built (it has no parameters) so that the output width
        # agrees with feat_mult() for every supported pool_type.
        self.max_pool = nn.AdaptiveMaxPool1d(output_size)

    def forward(self, x):
        """Pool ``x`` according to ``pool_type`` and optionally flatten it."""
        if self.pool_type == 'max':
            x = self.max_pool(x)
        elif self.pool_type == 'avgmax':
            x = 0.5 * (self.pool(x) + self.max_pool(x))
        elif self.pool_type == 'catavgmax':
            # Doubles the channel count, matching feat_mult() == 2.
            x = torch.cat((self.pool(x), self.max_pool(x)), 1)
        else:
            x = self.pool(x)
        if self.flatten:
            x = x.flatten(1)
        return x

    def feat_mult(self):
        """Return the channel multiplier for the configured ``pool_type``."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'output_size=' + str(self.output_size) \
            + ', pool_type=' + self.pool_type + ')'
class SE_InceptionResnetV2(nn.Module):
    """1-D Inception-ResNet-v2 style network with squeeze-and-excitation gates.

    In the three repeated stages every residual block is followed by an
    ``SELayer_1d`` of matching width (320, 1088 and 2080 channels).

    Args:
        num_classes: Output width of the final linear classifier.
        in_chans: Number of channels of the input passed to the first conv.
        drop_rate: Dropout probability applied to the pooled features;
            dropout is skipped when it is not greater than 0.
        global_pool: Pooling mode forwarded to ``SelectAdaptivePool1d``.
    """

    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super().__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        def se_stage(block_cls, scale, channels, repeats):
            # Alternates residual block and SE gate, block first, so the
            # Sequential indices match a hand-written listing.
            layers = []
            for _ in range(repeats):
                layers.append(block_cls(scale=scale))
                layers.append(SELayer_1d(channel=channels))
            return nn.Sequential(*layers)

        # Stem.
        self.conv2d_1a = BasicConv1d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv1d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool1d(3, stride=2)
        self.conv2d_3b = BasicConv1d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv1d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool1d(3, stride=2)

        # Residual stages separated by the Mixed_* transition blocks.
        self.mixed_5b = Mixed_5b()
        self.repeat = se_stage(Block35, 0.17, 320, 10)
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = se_stage(Block17, 0.10, 1088, 20)
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = se_stage(Block8, 0.20, 2080, 9)
        self.block8 = Block8(no_relu=True)
        self.conv2d_7b = BasicConv1d(2080, self.num_features, kernel_size=1, stride=1)

        self.global_pool = SelectAdaptivePool1d(pool_type=global_pool)
        # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
        head_inputs = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_inputs, num_classes)

    def get_classifier(self):
        """Return the current classification head."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and the classification head.

        A falsy ``num_classes`` installs ``nn.Identity`` so that ``forward``
        returns the pooled features.
        """
        self.global_pool = SelectAdaptivePool1d(pool_type=global_pool)
        self.num_classes = num_classes
        if not num_classes:
            self.classif = nn.Identity()
            return
        head_inputs = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_inputs, num_classes)

    def forward_features(self, x):
        """Run the backbone up to and including ``conv2d_7b`` (no pooling)."""
        stages = (
            self.conv2d_1a, self.conv2d_2a, self.conv2d_2b, self.maxpool_3a,
            self.conv2d_3b, self.conv2d_4a, self.maxpool_5a,
            self.mixed_5b, self.repeat,
            self.mixed_6a, self.repeat_1,
            self.mixed_7a, self.repeat_2,
            self.block8, self.conv2d_7b,
        )
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        """Return classifier outputs for ``x``."""
        features = self.forward_features(x)
        pooled = self.global_pool(features).flatten(1)
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        return self.classif(pooled)
# Check if the input is a number.
def is_number(x):
    """Return True when ``float(x)`` succeeds, False otherwise.

    Values that ``float`` rejects with ``TypeError`` (for example ``None``)
    count as "not a number" instead of propagating the exception.
    """
    try:
        float(x)
    except (TypeError, ValueError):
        return False
    return True


def load_table(table_file):
    """Load a labelled, comma-separated numeric table.

    The table should have the following form:

        ,    a,   b,   c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Args:
        table_file: Path of the text file to read.

    Returns:
        A tuple ``(rows, cols, values)``: the row labels (first entry of
        every data line), the column labels (header entries after the empty
        corner cell) and a ``float64`` array of shape
        ``(len(rows), len(cols))``.  Entries that are not numbers are NaN.

    Raises:
        ValueError: If the table has no data rows, no data columns, or
            lines with differing numbers of entries.
    """
    with open(table_file, 'r') as f:
        # Blank lines (e.g. a trailing newline) carry no data and are skipped.
        table = [[cell.strip() for cell in line.split(',')]
                 for line in f if line.strip()]

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table) - 1
    if num_rows < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Every line, including the header and the last data row, must have the
    # same number of entries.
    widths = {len(line) - 1 for line in table}
    if len(widths) != 1:
        raise ValueError('The table {} has rows with different lengths.'.format(table_file))
    num_cols = widths.pop()
    if num_cols < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Find the row and column labels: rows come from the first column,
    # columns from the header line.
    header, body = table[0], table[1:]
    rows = [line[0] for line in body]
    cols = header[1:]

    # Find the entries of the table.
    values = np.full((num_rows, num_cols), np.nan, dtype=np.float64)
    for i, line in enumerate(body):
        for j, cell in enumerate(line[1:]):
            if is_number(cell):
                values[i, j] = float(cell)

    return rows, cols, values
# For each set of equivalent classes, replace each class with the representative class for the set.
def replace_equivalent_classes(classes, equivalent_classes):
    """Map every class onto the first member of its equivalence group.

    ``classes`` is modified in place and also returned.  When a class
    appears in several groups, the last matching group decides.
    """
    for position in range(len(classes)):
        original = classes[position]
        for group in equivalent_classes:
            if original in group:
                # Use the first class as the representative class.
                classes[position] = group[0]
    return classes


# Load weights.
def load_weights(weight_file, equivalent_classes):
    """Load a square weight table and collapse equivalent classes.

    Returns ``(classes, weights)``: the representative class names in order
    of first appearance and the weight sub-matrix restricted to them.
    Assertions fail when row and column labels differ or when classes that
    map to the same representative carry different weights.
    """
    # Load the weight matrix.
    rows, cols, values = load_table(weight_file)
    assert(rows == cols)

    # For each collection of equivalent classes, replace each class with the representative class for the set.
    rows = replace_equivalent_classes(rows, equivalent_classes)

    # Check that equivalent classes have identical weights.
    total = len(rows)
    for first in range(total):
        for second in range(first + 1, total):
            if rows[first] == rows[second]:
                assert(np.all(values[first, :] == values[second, :]))
                assert(np.all(values[:, first] == values[:, second]))

    # Use representative classes, keeping the first occurrence of each.
    classes = []
    indices = []
    for position, label in enumerate(rows):
        if label not in classes:
            classes.append(label)
            indices.append(position)
    weights = values[np.ix_(indices, indices)]

    return classes, weights
# Compute recording-wise accuracy.
# input is np.bool
def compute_accuracy(labels, outputs):
    """Return the fraction of recordings whose whole label row is reproduced."""
    num_recordings, num_classes = np.shape(labels)
    num_correct_recordings = 0
    for idx in range(num_recordings):
        if np.all(labels[idx, :] == outputs[idx, :]):
            num_correct_recordings += 1

    return float(num_correct_recordings) / float(num_recordings)


def compute_confusion_matrices(labels, outputs, normalize=False):
    """Compute one binary confusion matrix per class.

    For class k the returned ``A[k]`` is laid out as

        [TN_k FN_k]
        [FP_k TP_k]

    i.e. ``A[k, output, label]``.  If ``normalize`` is true, each
    recording's contribution is divided by its number of positive labels
    (at least 1).

    Raises:
        ValueError: If a label/output pair is not made of 0/1 values.
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, 2, 2))

    # (label, output) combinations: TP, FP, FN, TN.
    outcomes = ((1, 1), (0, 1), (1, 0), (0, 0))

    for i in range(num_recordings):
        if normalize:
            credit = 1.0 / float(max(np.sum(labels[i, :]), 1))
        else:
            credit = 1
        for j in range(num_classes):
            label, output = labels[i, j], outputs[i, j]
            for label_value, output_value in outcomes:
                if label == label_value and output == output_value:
                    A[j, output_value, label_value] += credit
                    break
            else:  # This condition should not happen.
                raise ValueError('Error in computing the confusion matrix.')

    return A


def _macro_f_measure(tp, fp, fn):
    """Return the NaN-ignoring mean of per-class ``2TP / (2TP + FP + FN)``."""
    tp = np.asarray(tp, dtype=np.float64)
    denominator = 2 * tp + fp + fn
    # Classes without any TP, FP or FN have an undefined F-measure (NaN).
    f_measure = np.full(np.shape(denominator), np.nan, dtype=np.float64)
    np.divide(2 * tp, denominator, out=f_measure, where=denominator != 0)
    return np.nanmean(f_measure)


# Compute macro F-measure.
# input is np.bool
def compute_f_measure(labels, outputs):
    """Return the macro F-measure based on ``compute_confusion_matrices``."""
    # [[tn,fn],[fp,tp]]
    A = compute_confusion_matrices(labels, outputs)
    return _macro_f_measure(tp=A[:, 1, 1], fp=A[:, 1, 0], fn=A[:, 0, 1])


def compute_f_measure_mod(labels, outputs):
    """Return the macro F-measure based on ``multilabel_confusion_matrix``."""
    # [[tn,fp],[fn,tp]]
    A = multilabel_confusion_matrix(labels, outputs)
    return _macro_f_measure(tp=A[:, 1, 1], fp=A[:, 0, 1], fn=A[:, 1, 0])
macro AUPRC. 177 | # input scalar np.float64 for outputs 178 | def compute_auc(labels, outputs): 179 | num_recordings, num_classes = np.shape(labels) 180 | 181 | # Compute and summarize the confusion matrices for each class across at distinct output values. 182 | auroc = np.zeros(num_classes) 183 | auprc = np.zeros(num_classes) 184 | 185 | for k in range(num_classes): 186 | # We only need to compute TPs, FPs, FNs, and TNs at distinct output values. 187 | thresholds = np.unique(outputs[:, k]) 188 | thresholds = np.append(thresholds, thresholds[-1]+1) 189 | thresholds = thresholds[::-1] 190 | num_thresholds = len(thresholds) 191 | 192 | # Initialize the TPs, FPs, FNs, and TNs. 193 | tp = np.zeros(num_thresholds) 194 | fp = np.zeros(num_thresholds) 195 | fn = np.zeros(num_thresholds) 196 | tn = np.zeros(num_thresholds) 197 | fn[0] = np.sum(labels[:, k]==1) 198 | tn[0] = np.sum(labels[:, k]==0) 199 | 200 | # Find the indices that result in sorted output values. 201 | idx = np.argsort(outputs[:, k])[::-1] 202 | 203 | # Compute the TPs, FPs, FNs, and TNs for class k across thresholds. 204 | i = 0 205 | for j in range(1, num_thresholds): 206 | # Initialize TPs, FPs, FNs, and TNs using values at previous threshold. 207 | tp[j] = tp[j-1] 208 | fp[j] = fp[j-1] 209 | fn[j] = fn[j-1] 210 | tn[j] = tn[j-1] 211 | 212 | # Update the TPs, FPs, FNs, and TNs at i-th output value. 213 | while i < num_recordings and outputs[idx[i], k] >= thresholds[j]: 214 | if labels[idx[i], k]: 215 | tp[j] += 1 216 | fn[j] -= 1 217 | else: 218 | fp[j] += 1 219 | tn[j] -= 1 220 | i += 1 221 | 222 | # Summarize the TPs, FPs, FNs, and TNs for class k. 
223 | tpr = np.zeros(num_thresholds) 224 | tnr = np.zeros(num_thresholds) 225 | ppv = np.zeros(num_thresholds) 226 | for j in range(num_thresholds): 227 | if tp[j] + fn[j]: 228 | tpr[j] = float(tp[j]) / float(tp[j] + fn[j]) 229 | else: 230 | tpr[j] = float('nan') 231 | if fp[j] + tn[j]: 232 | tnr[j] = float(tn[j]) / float(fp[j] + tn[j]) 233 | else: 234 | tnr[j] = float('nan') 235 | if tp[j] + fp[j]: 236 | ppv[j] = float(tp[j]) / float(tp[j] + fp[j]) 237 | else: 238 | ppv[j] = float('nan') 239 | 240 | # Compute AUROC as the area under a piecewise linear function with TPR/ 241 | # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area 242 | # under a piecewise constant with TPR/recall (x-axis) and PPV/precision 243 | # (y-axis) for class k. 244 | for j in range(num_thresholds-1): 245 | auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j]) 246 | auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1] 247 | 248 | # Compute macro AUROC and macro AUPRC across classes. 249 | macro_auroc = np.nanmean(auroc) 250 | macro_auprc = np.nanmean(auprc) 251 | 252 | return macro_auroc, macro_auprc 253 | 254 | # computer f beta and g beta 255 | # input is np.bool 256 | def compute_beta_measures(labels, outputs, beta): 257 | num_recordings, num_classes = np.shape(labels) 258 | 259 | A = compute_confusion_matrices(labels, outputs, normalize=True) 260 | 261 | f_beta_measure = np.zeros(num_classes) 262 | g_beta_measure = np.zeros(num_classes) 263 | for k in range(num_classes): 264 | tp, fp, fn, tn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1], A[k, 0, 0] 265 | if (1+beta**2)*tp + fp + beta**2*fn: 266 | f_beta_measure[k] = float((1+beta**2)*tp) / float((1+beta**2)*tp + fp + beta**2*fn) 267 | else: 268 | f_beta_measure[k] = float('nan') 269 | if tp + fp + beta*fn: 270 | g_beta_measure[k] = float(tp) / float(tp + fp + beta*fn) 271 | else: 272 | g_beta_measure[k] = float('nan') 273 | 274 | macro_f_beta_measure = np.nanmean(f_beta_measure) 275 | macro_g_beta_measure = 
# Compute Challenge Metric
# input is np.bool
def compute_challenge_metric(weights,labels,outputs,classes,normal_class):
    """Return the weighted score normalised between two reference models.

    The observed weighted score is rescaled so that a model reproducing the
    labels scores 1 and a model that always outputs only ``normal_class``
    scores 0.  When those two reference scores coincide, 0.0 is returned.

    Args:
        weights: Square array of pairwise class weights, ordered like ``classes``.
        labels: Binary array unpacked as (num_recordings, num_classes).
        outputs: Binary array with the same shape as ``labels``.
        classes: List of class names; used to locate ``normal_class``.
        normal_class: Name of the class predicted by the reference
            "inactive" model.

    Raises:
        ValueError: If ``normal_class`` is not in ``classes``.
    """
    num_recordings, num_classes = np.shape(labels)
    normal_index = classes.index(normal_class)

    # Compute the observed score.
    A = compute_modified_confusion_matrix(labels, outputs)
    observed_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the correct label(s).
    correct_outputs = labels
    A = compute_modified_confusion_matrix(labels, correct_outputs)
    correct_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the normal class.
    # The builtin bool is used because the np.bool alias is not available
    # in every NumPy release.
    inactive_outputs = np.zeros((num_recordings, num_classes), dtype=bool)
    inactive_outputs[:, normal_index] = 1
    A = compute_modified_confusion_matrix(labels, inactive_outputs)
    inactive_score = np.nansum(weights * A)

    if correct_score != inactive_score:
        normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score)
    else:
        normalized_score = 0.0

    return normalized_score


def compute_modified_confusion_matrix(labels, outputs):
    """Compute a multi-class, multi-label confusion matrix.

    Rows are the labels and columns are the outputs.  For each recording,
    every (positive label, positive output) pair receives credit
    ``1 / n``, where ``n`` is the number of classes that are positive in
    the labels or the outputs (at least 1).
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, num_classes))

    # Iterate over all of the recordings.
    for i in range(num_recordings):
        # Calculate the number of positive labels and/or outputs.
        normalization = float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1))
        # Assign full and/or partial credit for each positive class: the
        # outer product marks every (label j, output k) pair that is set.
        positive_pairs = np.outer(labels[i, :] != 0, outputs[i, :] != 0)
        A += positive_pairs.astype(np.float64) / normalization

    return A


# Load weights.
# NOTE: this module defines load_weights twice with the same body; this later
# definition is the one bound to the name after import.
def load_weights(weight_file, equivalent_classes):
    """Load a square weight table and collapse equivalent classes.

    Returns ``(classes, weights)``: the representative class names in order
    of first appearance and the weight sub-matrix restricted to them.
    Assertions fail when row and column labels differ or when classes that
    map to the same representative carry different weights.
    """
    # Load the weight matrix.
    rows, cols, values = load_table(weight_file)
    assert(rows == cols)

    # For each collection of equivalent classes, replace each class with the representative class for the set.
    rows = replace_equivalent_classes(rows, equivalent_classes)

    # Check that equivalent classes have identical weights.
    for j, x in enumerate(rows):
        for k, y in enumerate(rows[j+1:]):
            if x == y:
                assert(np.all(values[j, :] == values[j+1+k, :]))
                assert(np.all(values[:, j] == values[:, j+1+k]))

    # Use representative classes.
    classes = [x for j, x in enumerate(rows) if x not in rows[:j]]
    indices = [rows.index(x) for x in classes]
    weights = values[np.ix_(indices, indices)]

    return classes, weights