├── utils
├── __pycache__
│ ├── tools.cpython-39.pyc
│ ├── eval_tools.cpython-39.pyc
│ ├── optimizers.cpython-39.pyc
│ ├── DINO_dataloader.cpython-39.pyc
│ ├── base_dataloader.cpython-39.pyc
│ ├── contrastive_dataloader.cpython-39.pyc
│ └── timeseries_transformations.cpython-39.pyc
├── optimizers.py
├── tools.py
├── scheduler.py
├── contrastive_dataloader.py
└── eval_tools.py
├── models
├── __pycache__
│ ├── seresnet2d.cpython-39.pyc
│ ├── xresnet1d.cpython-39.pyc
│ ├── basic_conv1d.cpython-39.pyc
│ ├── signal_model.cpython-39.pyc
│ ├── ensemble_model.cpython-39.pyc
│ └── spectrogram_model.cpython-39.pyc
├── spectrogram_model.py
├── resnet.py
├── resnet_simclr.py
├── signal_model.py
├── seresnet2d.py
├── seresnet.py
├── ensemble_model.py
├── xresnet1d.py
├── basic_conv1d.py
├── inception_resnet_v2.py
└── se_inception_resnet_v2.py
├── experiments
├── __pycache__
│ └── signal.cpython-39.pyc
├── BYOL_signal.py
├── SIMCLR_signal.py
├── SIMCLR_signal_finetune.py
├── BYOL_signal_finetune.py
├── run_signal.py
├── run_spectrogram.py
└── run_ensembled.py
├── data_folder
└── evaluation-2020-master
│ ├── LICENSE
│ ├── evaluate_12ECG_score.m
│ ├── dx_mapping_scored.csv
│ ├── README.md
│ ├── .gitignore
│ ├── Results
│ ├── physionet_2020_unofficial_scores.csv
│ ├── README.md
│ ├── physionet_2020_official_scores.csv
│ └── physionet_2020_metrics_perDatabase_official_entries.csv
│ ├── weights.csv
│ └── dx_mapping_unscored.csv
├── README.md
└── data_preparation
├── reformat_memmap.py
├── data_extraction_without_preprocessing.py
├── stratify.py
└── data_extraction_with_preprocessing.py
/utils/__pycache__/tools.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/tools.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/seresnet2d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/seresnet2d.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/xresnet1d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/xresnet1d.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval_tools.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/eval_tools.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/optimizers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/optimizers.cpython-39.pyc
--------------------------------------------------------------------------------
/experiments/__pycache__/signal.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/experiments/__pycache__/signal.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/basic_conv1d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/basic_conv1d.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/signal_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/signal_model.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ensemble_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/ensemble_model.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/DINO_dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/DINO_dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/base_dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/base_dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/spectrogram_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/models/__pycache__/spectrogram_model.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/contrastive_dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/contrastive_dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/timeseries_transformations.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UARK-AICV/ECG_SSL_12Lead/HEAD/utils/__pycache__/timeseries_transformations.cpython-39.pyc
--------------------------------------------------------------------------------
/models/spectrogram_model.py:
--------------------------------------------------------------------------------
1 | from models.seresnet2d import se_resnet34
2 | import torch.nn as nn
3 |
class spectrogram_model(nn.Module):
    """SE-ResNet34 backbone with a two-layer linear classification head.

    The backbone's first convolution is replaced by one that accepts 12
    input channels, every backbone child except the last is reused as the
    feature extractor, and a fresh linear head maps the extracted features
    to ``no_classes`` outputs.

    Args:
        no_classes: number of output units of the classification head.
    """

    def __init__(self, no_classes):
        super(spectrogram_model, self).__init__()
        self.backbone = se_resnet34()
        # Swap the stem *before* listing the children so that ``features``
        # is built from the 12-channel convolution.
        self.backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        list_of_modules = list(self.backbone.children())
        # Drop the backbone's last child (its own ``fc`` classifier).
        self.features = nn.Sequential(*list_of_modules[:-1])
        num_ftrs = self.backbone.fc.in_features

        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=num_ftrs // 2),
            nn.Linear(in_features=num_ftrs // 2, out_features=no_classes),
        )

    def forward(self, x):
        """Return the head's output for a batch ``x``.

        NOTE(review): assumes ``features`` ends in a global pooling layer so
        its output is (batch, channels, 1, 1) -- confirm against seresnet2d.
        """
        h = self.features(x)
        # flatten(1) keeps the batch dimension. The previous ``squeeze()``
        # also removed it for a batch of one, so the output lost its batch
        # axis for single-sample inference.
        h = h.flatten(1)
        return self.fc(h)
23 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, PhysioNet/Computing in Cardiology Challenges
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/utils/optimizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py

    Momentum SGD with a layer-wise trust ratio. Parameters with exactly one
    dimension are excluded from both weight decay and the trust-ratio
    rescaling; all other parameters receive both.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate applied to the momentum buffer.
        weight_decay: L2 coefficient added to the gradient of non-1-D params.
        momentum: momentum factor for the ``mu`` buffer.
        eta: trust coefficient multiplying ``||p|| / ||update||``.
        weight_decay_filter: stored in the group defaults but not consulted
            by ``step`` (the ``ndim`` test is used instead).
        lars_adaptation_filter: stored in the group defaults but not
            consulted by ``step`` (the ``ndim`` test is used instead).
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss, as in ``torch.optim.Optimizer.step``.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim != 1:
                    # Weight decay, then rescale the update by the trust ratio.
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Fall back to a ratio of 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss
45 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | from .xresnet1d import xresnet1d50, xresnet1d101
5 |
6 |
class ResNet(nn.Module):
    """Feature extractor plus projection head on a pretrained torchvision ResNet.

    Args:
        base_model: key into ``resnet_dict`` ("resnet18" or "resnet50").
        out_dim: output size of the projection head.
        widen: accepted for signature compatibility; not used by this class.
        hidden: if True the head is Linear -> ReLU -> Linear, otherwise a
            single Linear layer.
    """

    def __init__(self, base_model, out_dim, widen=1.0, hidden=False):
        super(ResNet, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=True),
                            "resnet50": models.resnet50(pretrained=True)}

        resnet = self._get_basemodel(base_model)
        self.base_model = base_model

        list_of_modules = list(resnet.children())
        if "xresnet" in base_model:
            # Only reachable if ``resnet_dict`` is extended with xresnet entries.
            self.features = nn.Sequential(*list_of_modules[:-1], list_of_modules[-1][0])
            num_ftrs = resnet[-1][-1].in_features
            resnet[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)
        else:
            # A torchvision ResNet is not subscriptable, so the former
            # ``resnet[0][0] = nn.Conv1d(...)`` here raised TypeError for
            # every supported model name and the class could not be built.
            # NOTE(review): the torchvision stem is left unchanged -- confirm
            # whether a different input stem was intended for this variant.
            self.features = nn.Sequential(*list_of_modules[:-1])
            num_ftrs = resnet.fc.in_features

        # projection MLP
        if hidden:
            self.l1 = nn.Linear(num_ftrs, num_ftrs)
            self.l2 = nn.Linear(num_ftrs, out_dim)
        else:
            self.l1 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError(
                f"Invalid model name {model_name!r}. Check the config file and "
                f"pass one of: {', '.join(self.resnet_dict)}"
            ) from None

    def forward(self, x):
        """Return ``(h, z)``: backbone features and their projection."""
        h = self.features(x)
        # Keep the batch dimension even for a batch of one.
        h = h.flatten(1)

        x = self.l1(h)
        # ``l2`` exists only when the head was built with ``hidden=True``.
        if hasattr(self, "l2"):
            x = F.relu(x)
            x = self.l2(x)
        return h, x
49 |
--------------------------------------------------------------------------------
/models/resnet_simclr.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | from .xresnet1d import xresnet1d50, xresnet1d101
5 |
6 |
class ResNetSimCLR(nn.Module):
    """Backbone plus projection head used for contrastive pre-training.

    Args:
        base_model: key into ``resnet_dict`` ("resnet18", "resnet50",
            "xresnet1d50" or "xresnet1d101").
        out_dim: output size of the projection head.
        widen: forwarded to the xresnet1d constructors.
        hidden: if True the head is Linear -> ReLU -> Linear, otherwise a
            single Linear layer.
    """

    def __init__(self, base_model, out_dim, widen=1.0, hidden=False):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False),
                            "xresnet1d50": xresnet1d50(widen=widen),
                            "xresnet1d101": xresnet1d101(widen=widen)}

        resnet = self._get_basemodel(base_model)
        self.base_model = base_model

        list_of_modules = list(resnet.children())
        if "xresnet" in base_model:
            # Keep the first module of the xresnet head as part of the features.
            self.features = nn.Sequential(*list_of_modules[:-1], list_of_modules[-1][0])
            num_ftrs = resnet[-1][-1].in_features
            # Replaces the stem in place with a 12-input-channel convolution;
            # ``features`` shares the same container and so sees the new layer.
            resnet[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)
        else:
            self.features = nn.Sequential(*list_of_modules[:-1])
            num_ftrs = resnet.fc.in_features

        # projection MLP
        if hidden:
            self.l1 = nn.Linear(num_ftrs, num_ftrs)
            self.l2 = nn.Linear(num_ftrs, out_dim)
        else:
            self.l1 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError(
                f"Invalid model name {model_name!r}. Check the config file and "
                f"pass one of: {', '.join(self.resnet_dict)}"
            ) from None

    def forward(self, x):
        """Return ``(h, z)``: backbone features and their projection."""
        h = self.features(x)
        # Keep the batch dimension even for a batch of one.
        h = h.flatten(1)

        x = self.l1(h)
        # ``l2`` exists only when the head was built with ``hidden=True``.
        if hasattr(self, "l2"):
            x = F.relu(x)
            x = self.l2(x)
        return h, x
50 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/evaluate_12ECG_score.m:
--------------------------------------------------------------------------------
1 | % This file contains functions for evaluating algorithms for the 2020 PhysioNet/
2 | % Computing in Cardiology Challenge. You can run it as follows:
3 | %
4 | % evaluate_12ECG_score(labels, outputs, scores.csv)
5 | %
6 | % where 'labels' is a directory containing files with the labels, 'outputs' is a
7 | % directory containing files with the outputs from your model, and 'scores.csv'
8 | % (optional) is a collection of scores for the algorithm outputs.
9 | %
10 | % Each file of labels or outputs must have the format described on the Challenge
11 | % webpage. The scores for the algorithm outputs include the area under the
12 | % receiver-operating characteristic curve (AUROC), the area under the recall-
13 | % precision curve (AUPRC), accuracy (fraction of correct recordings), macro F-
14 | % measure, and the Challenge metric, which assigns different weights to
15 | % different misclassification errors.
16 |
function evaluate_12ECG_score(labels, outputs, output_file, class_output_file)
    % Run the Python evaluation script on a labels directory and an outputs
    % directory, optionally passing the score file names, and print
    % whatever the script writes.

    % Check for Python and NumPy.
    command = 'python -V';
    [status, ~] = system(command);
    if status~=0
        error('Python not found: please install Python or make it available by running "python ...".');
    end

    command = 'python -c "import numpy"';
    [status, ~] = system(command);
    if status~=0
        error('NumPy not found: please install NumPy or make it available to Python.');
    end

    % Define command for evaluating model outputs.
    switch nargin
        case 2
            command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs];
        case 3
            command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs ' ' output_file];
        case 4
            command = ['python evaluate_12ECG_score.py' ' ' labels ' ' outputs ' ' output_file ' ' class_output_file];
        otherwise
            command = '';
    end

    % Evaluate model outputs.
    [~, output] = system(command);
    % Print the text verbatim. Passing it as the format string would make
    % fprintf interpret any '%' or '\' emitted by the evaluation script.
    fprintf('%s', output);
end
47 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/dx_mapping_scored.csv:
--------------------------------------------------------------------------------
1 | Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total,Notes
2 | 1st degree av block,270492004,IAVB,722,106,0,0,797,769,2394,
3 | atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,3475,
4 | atrial flutter,164890007,AFL,0,54,0,1,73,186,314,
5 | bradycardia,426627000,Brady,0,271,11,0,0,6,288,
6 | complete right bundle branch block,713427006,CRBBB,0,113,0,0,542,28,683,We score 713427006 and 59118001 as the same diagnosis.
7 | incomplete right bundle branch block,713426002,IRBBB,0,86,0,0,1118,407,1611,
8 | left anterior fascicular block,445118002,LAnFB,0,0,0,0,1626,180,1806,
9 | left axis deviation,39732003,LAD,0,0,0,0,5146,940,6086,
10 | left bundle branch block,164909002,LBBB,236,38,0,0,536,231,1041,
11 | low qrs voltages,251146004,LQRSV,0,0,0,0,182,374,556,
12 | nonspecific intraventricular conduction disorder,698252002,NSIVCB,0,4,1,0,789,203,997,
13 | pacing rhythm,10370003,PR,0,3,0,0,296,0,299,
14 | premature atrial contraction,284470004,PAC,616,73,3,0,398,639,1729,We score 284470004 and 63593006 as the same diagnosis.
15 | premature ventricular contractions,427172004,PVC,0,188,0,0,0,0,188,We score 427172004 and 17338001 as the same diagnosis.
16 | prolonged pr interval,164947007,LPR,0,0,0,0,340,0,340,
17 | prolonged qt interval,111975006,LQT,0,4,0,0,118,1391,1513,
18 | qwave abnormal,164917005,QAb,0,1,0,0,548,464,1013,
19 | right axis deviation,47665007,RAD,0,1,0,0,343,83,427,
20 | right bundle branch block,59118001,RBBB,1857,1,2,0,0,542,2402,We score 713427006 and 59118001 as the same diagnosis.
21 | sinus arrhythmia,427393009,SA,0,11,2,0,772,455,1240,
22 | sinus bradycardia,426177001,SB,0,45,0,0,637,1677,2359,
23 | sinus rhythm,426783006,NSR,918,4,0,80,18092,1752,20846,
24 | sinus tachycardia,427084000,STach,0,303,11,1,826,1261,2402,
25 | supraventricular premature beats,63593006,SVPB,0,53,4,0,157,1,215,We score 284470004 and 63593006 as the same diagnosis.
26 | t wave abnormal,164934002,TAb,0,22,0,0,2345,2306,4673,
27 | t wave inversion,59931005,TInv,0,5,1,0,294,812,1112,
28 | ventricular premature beats,17338001,VPB,0,8,0,0,0,357,365,We score 427172004 and 17338001 as the same diagnosis.
29 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/README.md:
--------------------------------------------------------------------------------
1 | # PhysioNet/CinC Challenge 2020 Evaluation Metrics
2 |
3 | This repository contains the Python and MATLAB evaluation code for the PhysioNet/Computing in Cardiology Challenge 2020. The `evaluate_12ECG_score` script evaluates the output of your algorithm using the evaluation metric that is described on the [webpage](https://physionetchallenges.github.io/2020/) for the PhysioNet/CinC Challenge 2020. While this script reports multiple evaluation metrics, we use the last score (`Challenge Metric`) to evaluate your algorithm.
4 |
5 | ## Python
6 |
7 | You can run the Python evaluation code by installing the NumPy Python package and running
8 |
9 | python evaluate_12ECG_score.py labels outputs scores.csv class_scores.csv
10 |
11 | where `labels` is a directory containing files with one or more labels for each 12-lead ECG recording, such as the training database on the PhysioNet webpage; `outputs` is a directory containing files with outputs produced by your algorithm for those recordings; `scores.csv` (optional) is a collection of scores for your algorithm; and `class_scores.csv` (optional) is a collection of per-class scores for your algorithm.
12 |
13 | ## MATLAB
14 |
15 | You can run the MATLAB evaluation code by installing Python and the NumPy Python package and running
16 |
17 | evaluate_12ECG_score(labels, outputs, scores.csv, class_scores.csv)
18 |
19 | where `labels` is a directory containing files with one or more labels for each 12-lead ECG recording, such as the training database on the PhysioNet webpage; `outputs` is a directory containing files with outputs produced by your algorithm for those recordings; `scores.csv` (optional) is a collection of scores for your algorithm; and `class_scores.csv` (optional) is a collection of per-class scores for your algorithm.
20 |
21 | ## Troubleshooting
22 |
23 | Unable to run this code with your code? Try one of the [baseline classifiers](https://physionetchallenges.github.io/2020/#submissions) on the [training data](https://physionetchallenges.github.io/2020/#data). Unable to install or run Python? Try [Python](https://www.python.org/downloads/), [Anaconda](https://www.anaconda.com/products/individual), or your package manager.
24 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/Results/physionet_2020_unofficial_scores.csv:
--------------------------------------------------------------------------------
1 | Team name,CinC Abstract #,Validation Set Score,Hidden CPSC Set Score,Hidden G12EC Set Score,Hidden Undisclosed Set Score,Test Set Score,Training code produces output?,Model uses output from training code?,Open-source license?,Registered at CinC?,Preprint at CinC?,Presented at CinC?
2 | AAIST,No abstract,0.507,0.674,0.485,0.275,0.377,Y,Y,BSD 2 License,N,N,N
3 | AImsterdam,327,0.609,0.636,0.252,-0.093,0.198,N,N,Unknown,Y,N,N
4 | BERCLAB UND,79,0.197,0.564,0.127,0.106,0.141,Y,Y,Unknown,Y,N,N
5 | BME_Feng,69,0.001,0.003,0.001,-0.030,-0.016,Y,Y,BSD 2 License,N,N,N
6 | BraveHeart400,83,0.449,0.657,0.413,-0.265,0.034,Y,Y,BSD 2 License,N,N,N
7 | BRIC,331,0.539,0.652,0.189,0.030,0.127,Y,Y,BSD 2 License,N,N,N
8 | Chapman,No abstract,-0.204,-0.282,-0.190,0.004,-0.084,Y,Y,BSD 2 License,N,N,N
9 | Connected_Health,176,0.566,0.703,0.541,0.417,0.479,Y,Y,BSD 2 License,N (rejected abstract),N,N
10 | Health team Szeged,48,0.493,0.423,0.505,0.472,0.480,Y,Y,BSD 2 License,N,N,N
11 | IBMTpeakyFinders,173,0.282,-0.072,-0.052,-0.426,-0.269,Y,Y,BSD 2 License,N,N,N
12 | Kimball_IRL,31,0.178,0.444,0.138,-0.211,-0.042,Y,Y,BSD 2 License,Y,N,Y - poster
13 | LaussenLabs,353,-0.406,-0.455,-0.390,-0.848,-0.658,Y,N,BSD 2 License,Y,Y,Y - poster
14 | LIST_AIHealthCare,120,0.216,0.152,0.229,0.173,0.192,Y,Y,BSD 2 License,N,N,N
15 | Marquette,74,0.511,0.458,0.521,0.478,0.492,Y,Y,BSD 2 License,Y,Y (but after deadline),Y - poster
16 | Medics,187,0.189,0.480,0.146,NaN,NaN,Y,Y,BSD 2 License,N,N,N
17 | MetaHeart,196,0.616,0.758,0.590,0.194,0.370,Y,Y,BSD 2 License,Y,Y,Y - poster (but no response to questions)
18 | Metformin-121,136,0.623,0.865,0.586,0.413,0.505,Y,Y,BSD 2 License,N,N,N
19 | ML Warriors,412,0.389,0.395,0.390,0.181,0.269,N,N,BSD 2 License,N,N,N
20 | NACAS_12X,180,0.645,0.846,0.202,0.000,0.127,Y,Y,BSD 2 License,Y,Y (but did not update preprint),Y - poster
21 | nebula,39,0.526,0.736,0.086,0.052,0.109,N,N,BSD 2 License,Y,Y,Y - poster
22 | NN-MIH,63,0.585,0.665,0.567,0.367,0.456,N,N,BSD 2 License,Y,Y,Y - poster
23 | NTU-Accesslab,72,0.544,0.725,0.510,NaN,NaN,Y,Y,BSD 2 License,Y,Y,Y - poster
24 | Orange Peel,145,0.650,0.813,0.621,0.161,0.364,Y,Y ,BSD 2 License,N (no abstract submission),N,N
25 | SBU_AI,307,0.416,0.513,0.016,-0.028,0.024,N,N,BSD 2 License,Y,Y,Y - 3rd talk in 2nd session
26 | SpaceOn Flattop,7,0.681,0.871,0.219,0.126,0.208,N,N,BSD 2 License,Y,Y,Y - poster
27 | try again,No abstract,0.072,0.261,-0.266,-0.753,-0.553,Y,Y,BSD 2 License,N,N,N
28 | UniA4Life,314,-0.105,0.250,-0.156,-0.523,-0.339,N,N,GNU GPL V3 License,N,N,N
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Multimodality Multi-Lead ECG Arrhythmia Classification using Self-Supervised Learning
2 |
3 | Paper link: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9926925
4 |
5 | 1. Download the datasets from the PhysioNet 2020 Challenge, put them in the folder ./data_folder/datasets, and extract all of them.
6 | https://physionetchallenges.github.io/2020/
7 |
8 | 2. Preparing the data
9 | python data_preparation/data_extraction_without_preprocessing.py
10 | python data_preparation/reformat_memmap.py
11 |
12 | 3. Training base models
13 | python experiments/run_signal.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --save_folder ./checkpoints/base_signal
14 | python experiments/run_spectrogram.py --batch_size 256 --lr_rate 5e-3 --num_epoches 200 --gpu 0 --save_folder ./checkpoints/base_spectrogram
15 | (without gating fusion)
16 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --save_folder ./checkpoints/base_ensemble_wogating
17 | (with gating fusion)
18 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --gating --save_folder ./checkpoints/base_ensemble_wgating
19 |
20 | 4. Self-supervised learning for pretrained models
21 | (SimCLR)
22 | python experiments/SIMCLR_signal.py
23 | (BYOL)
24 | python experiments/BYOL_signal.py
25 | (DINO)
26 | python experiments/DINO_signal.py
27 | python experiments/DINO_spectrogram.py
28 |
29 | 5. Finetuning the main model based on the self-supervised pretrained models
30 | (SimCLR)
31 | python experiments/SIMCLR_signal_finetune.py
32 | (BYOL)
33 | python experiments/BYOL_signal_finetune.py
34 | (DINO)
35 | python experiments/run_signal.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints/DINO_signal_student.pth --save_folder ./checkpoints/finetune_signal
36 | python experiments/run_spectrogram.py --batch_size 256 --lr_rate 5e-3 --num_epoches 200 --gpu 0 --finetune ./checkpoints/DINO_spectrogram_student.pth --save_folder ./checkpoints/finetune_spectrogram
37 | (without gating fusion)
38 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints --save_folder ./checkpoints/finetune_ensemble_wogating
39 | (with gating fusion)
40 | python experiments/run_ensembled.py --batch_size 128 --lr_rate 5e-3 --num_epoches 100 --gpu 0 --finetune ./checkpoints --gating --save_folder ./checkpoints/finetune_ensemble_wgating
41 |
42 | 6. Searching the thresholds of classes for best Challenge score
43 | python experiments/threshold_search.py --model_type signal --best-type PRC --gpu 0 --weight_folder ./checkpoints/base_signal
44 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
def weights_init_xavier(m):
    """Initialise a single module in place; intended for ``model.apply``.

    * Conv1d / Conv2d: Xavier-normal weights, standard-normal bias (if any).
    * BatchNorm1d / BatchNorm2d: weights ~ N(1, 0.02), bias set to 0.
    * Linear: Xavier-normal weights; the bias is left untouched.

    Any other module type is ignored.
    """
    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        torch.nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias.data)
        return

    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        torch.nn.init.normal_(m.weight.data, mean=1, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0)
        return

    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight.data)
23 |
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    The returned array has ``epochs * niter_per_ep`` entries. The first
    ``warmup_epochs * niter_per_ep`` of them ramp linearly from
    ``start_warmup_value`` to ``base_value``; the remainder follow a
    half-cosine from ``base_value`` towards ``final_value``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    iters = np.arange(total_iters - warmup_iters)
    phase = np.pi * iters / len(iters)
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(phase))

    schedule = np.concatenate((warmup_schedule, cosine_part))
    assert len(schedule) == total_iters
    return schedule
36 |
def set_requires_grad(model, val):
    """Assign ``val`` to ``requires_grad`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = val
40 |
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Drop gradients of ``last_layer`` parameters during the first epochs.

    While ``epoch`` is below ``freeze_last_layer``, the ``grad`` of every
    parameter whose qualified name contains "last_layer" is reset to None.
    Once ``epoch`` reaches ``freeze_last_layer`` nothing is changed.
    """
    still_frozen = not (epoch >= freeze_last_layer)
    if still_frozen:
        for param_name, param in model.named_parameters():
            if "last_layer" in param_name:
                param.grad = None
47 |
48 |
def open_all_layers(model):
    """Make every parameter of ``model`` trainable."""
    for weight in model.parameters():
        weight.requires_grad_(True)
52 |
53 |
def open_specified_layers(model, open_layers):
    """Train only the named direct children of ``model``; freeze the others.

    Args:
        model: module whose direct children are switched on or off.
        open_layers: a child name or a list of child names to keep trainable.

    Every name must be an attribute of ``model`` (checked with ``assert``).
    Parameters of children not listed get ``requires_grad = False``.
    """
    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer_name in open_layers:
        assert hasattr(model, layer_name), (
            '"{}" is not an attribute of the model, please provide the correct name'.format(
                layer_name
            )
        )

    for child_name, child in model.named_children():
        trainable = child_name in open_layers
        for param in child.parameters():
            param.requires_grad = trainable
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/Results/README.md:
--------------------------------------------------------------------------------
1 | # PhysioNet/CinC Challenge 2020 Results
2 |
3 | This folder contains several files with the results of the 2020 Challenge.
4 |
5 | We introduced a [new scoring metric](https://physionetchallenges.github.io/2020/#scoring) for this Challenge. We used this scoring metric to evaluate and rank the Challenge entries. We included several other metrics for reference. The area under the receiver operating characteristic (AUROC), area under the precision recall curve (AUPRC), and _F_-measure scores are the macro-average of the scores across all classes. The accuracy metric is the fraction of correctly diagnosed recordings, i.e., all classes for the recording are correct. These metrics were computed by the [evaluate_12ECG_score.py](https://github.com/physionetchallenges/evaluation-2020/blob/master/evaluate_12ECG_score.py) script in this repository. Please see this script for more details of these scores.
6 |
7 | We included the scores on the following datasets:
8 |
9 | 1. __Validation Set:__ Includes recordings from the hidden CPSC and G12EC sets.
10 | 2. __Hidden CPSC Set:__ Split between the validation and test sets.
11 | 3. __Hidden G12EC Set:__ Split between the validation and test sets.
12 | 4. __Hidden Undisclosed Set:__ All recordings were part of the test sets.
13 | 5. __Test Set:__ Includes recordings from the hidden CPSC, G12EC, and undisclosed test sets.
14 |
15 | To refer to these tables in a publication, please cite [Perez Alday EA, Gu A, Shah AJ, Robichaux C, Wong AI, Liu C, Liu F, Rad AB, Elola A, Seyedi S, Li Q, Sharma A, Clifford GD*, Reyna MA*. Classification of 12-lead ECGs: the PhysioNet/Computing in Cardiology Challenge 2020. Physiol Meas. 41 (2020). doi: 10.1088/1361-6579/abc960](https://iopscience.iop.org/article/10.1088/1361-6579/abc960).
16 |
17 | 1. Official entries that were scored on the validation and test data and ranked in the Challenge:
18 | [physionet_2020_official_scores.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_official_scores.csv)
19 | 2. Unofficial entries that were scored on the validation and test data but unranked because they did not satisfy all of the [rules](https://physionetchallenges.github.io/2020/#rules-and-deadlines) or were unsuccessful on one or more of the test sets:
20 | [physionet_2020_unofficial_scores.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_unofficial_scores.csv)
21 | 3. Challenge and other scoring metrics for all official entries, broken down by database in the validation and test data:
22 | [physionet_2020_full_metrics_official_entries.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_full_metrics_official_entries.csv)
23 | 4. Per-class scoring metrics on the validation data:
24 | [physionet_2020_validation_metrics_by_class_official_entries.csv](https://github.com/physionetchallenges/evaluation-2020/blob/master/Results/physionet_2020_validation_metrics_by_class_official_entries.csv)
25 |
--------------------------------------------------------------------------------
/data_preparation/reformat_memmap.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pandas as pd
4 | import pickle
5 | from tqdm import tqdm
6 |
7 | def npys_to_memmap(npys, target_filename, max_len=0, delete_npys=True):
8 | memmap = None
9 | start = []#start_idx in current memmap file
10 | length = []#length of segment
11 | filenames= []#memmap files
12 | file_idx=[]#corresponding memmap file for sample
13 | shape=[]
14 |
15 | for idx,npy in tqdm(list(enumerate(npys))):
16 | data = np.load(npy, allow_pickle=True)
17 | if(memmap is None or (max_len>0 and start[-1]+length[-1]>max_len)):
18 | filenames.append(target_filename)
19 |
20 | if(memmap is not None):#an existing memmap exceeded max_len
21 | shape.append([start[-1]+length[-1]]+[l for l in data.shape[1:]])
22 | del memmap
23 | #create new memmap
24 | start.append(0)
25 | length.append(data.shape[0])
26 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape)
27 | else:
28 | #append to existing memmap
29 | start.append(start[-1]+length[-1])
30 | length.append(data.shape[0])
31 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple([start[-1]+length[-1]]+[l for l in data.shape[1:]]))
32 |
33 | #store mapping memmap_id to memmap_file_id
34 | file_idx.append(len(filenames)-1)
35 | #insert the actual data
36 | memmap[start[-1]:start[-1]+length[-1]]=data[:]
37 | memmap.flush()
38 | if(delete_npys is True):
39 | npy.unlink()
40 | del memmap
41 |
42 | #append final shape if necessary
43 | if(len(shape)= self.cur_cycle_steps:
68 | self.cycle += 1
69 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
71 | else:
72 | if epoch >= self.first_cycle_steps:
73 | if self.cycle_mult == 1.:
74 | self.step_in_cycle = epoch % self.first_cycle_steps
75 | self.cycle = epoch // self.first_cycle_steps
76 | else:
77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
78 | self.cycle = n
79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
81 | else:
82 | self.cur_cycle_steps = self.first_cycle_steps
83 | self.step_in_cycle = epoch
84 |
85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
86 | self.last_epoch = math.floor(epoch)
87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
88 | param_group['lr'] = lr
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/dx_mapping_unscored.csv:
--------------------------------------------------------------------------------
1 | Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total
2 | 2nd degree av block,195042002,IIAVB,0,21,0,0,14,23,58
3 | abnormal QRS,164951009,abQRS,0,0,0,0,3389,0,3389
4 | accelerated junctional rhythm,426664006,AJR,0,0,0,0,0,19,19
5 | acute myocardial infarction,57054005,AMI,0,0,6,0,0,0,6
6 | acute myocardial ischemia,413444003,AMIs,0,1,0,0,0,1,2
7 | anterior ischemia,426434006,AnMIs,0,0,0,0,44,281,325
8 | anterior myocardial infarction,54329005,AnMI,0,62,0,0,354,0,416
9 | atrial bigeminy,251173003,AB,0,0,3,0,0,0,3
10 | atrial fibrillation and flutter,195080001,AFAFL,0,39,0,0,0,2,41
11 | atrial hypertrophy,195126007,AH,0,2,0,0,0,60,62
12 | atrial pacing pattern,251268003,AP,0,0,0,0,0,52,52
13 | atrial tachycardia,713422000,ATach,0,15,0,0,0,28,43
14 | atrioventricular junctional rhythm,29320008,AVJR,0,6,0,0,0,0,6
15 | av block,233917008,AVB,0,5,0,0,0,74,79
16 | blocked premature atrial contraction,251170000,BPAC,0,2,3,0,0,0,5
17 | brady tachy syndrome,74615001,BTS,0,1,1,0,0,0,2
18 | bundle branch block,6374002,BBB,0,0,1,20,0,116,137
19 | cardiac dysrhythmia,698247007,CD,0,0,0,16,0,0,16
20 | chronic atrial fibrillation,426749004,CAF,0,1,0,0,0,0,1
21 | chronic myocardial ischemia,413844008,CMI,0,161,0,0,0,0,161
22 | complete heart block,27885002,CHB,0,27,0,0,16,8,51
23 | congenital incomplete atrioventricular heart block,204384007,CIAHB,0,0,0,2,0,0,2
24 | coronary heart disease,53741008,CHD,0,0,16,21,0,0,37
25 | decreased qt interval,77867006,SQT,0,1,0,0,0,0,1
26 | diffuse intraventricular block,82226007,DIB,0,1,0,0,0,0,1
27 | early repolarization,428417006,ERe,0,0,0,0,0,140,140
28 | fusion beats,13640000,FB,0,0,7,0,0,0,7
29 | heart failure,84114007,HF,0,0,0,7,0,0,7
30 | heart valve disorder,368009,HVD,0,0,0,6,0,0,6
31 | high t-voltage,251259000,HTV,0,1,0,0,0,0,1
32 | idioventricular rhythm,49260003,IR,0,0,2,0,0,0,2
33 | incomplete left bundle branch block,251120003,ILBBB,0,42,0,0,77,86,205
34 | indeterminate cardiac axis,251200008,ICA,0,0,0,0,156,0,156
35 | inferior ischaemia,425419005,IIs,0,0,0,0,219,451,670
36 | inferior ST segment depression,704997005,ISTD,0,1,0,0,0,0,1
37 | junctional escape,426995002,JE,0,4,0,0,0,5,9
38 | junctional premature complex,251164006,JPC,0,2,0,0,0,0,2
39 | junctional tachycardia,426648003,JTach,0,2,0,0,0,4,6
40 | lateral ischaemia,425623009,LIs,0,0,0,0,142,903,1045
41 | left atrial abnormality,253352002,LAA,0,0,0,0,0,72,72
42 | left atrial enlargement,67741000119109,LAE,0,1,0,0,427,870,1298
43 | left atrial hypertrophy,446813000,LAH,0,40,0,0,0,0,40
44 | left posterior fascicular block,445211001,LPFB,0,0,0,0,177,25,202
45 | left ventricular hypertrophy,164873001,LVH,0,158,10,0,2359,1232,3759
46 | left ventricular strain,370365005,LVS,0,1,0,0,0,0,1
47 | mobitz type i wenckebach atrioventricular block,54016002,MoI,0,0,3,0,0,0,3
48 | myocardial infarction,164865005,MI,0,376,9,368,5261,7,6021
49 | myocardial ischemia,164861001,MIs,0,384,0,0,2175,0,2559
50 | nonspecific st t abnormality,428750005,NSSTTA,0,1290,0,0,381,1883,3554
51 | old myocardial infarction,164867002,OldMI,0,1168,0,0,0,0,1168
52 | paired ventricular premature complexes,251182009,VPVC,0,0,23,0,0,0,23
53 | paroxysmal atrial fibrillation,282825002,PAF,0,0,1,1,0,0,2
54 | paroxysmal supraventricular tachycardia,67198005,PSVT,0,0,3,0,24,0,27
55 | paroxysmal ventricular tachycardia,425856008,PVT,0,0,15,0,0,0,15
56 | r wave abnormal,164921003,RAb,0,1,0,0,0,10,11
57 | rapid atrial fibrillation,314208002,RAF,0,0,0,2,0,0,2
58 | right atrial abnormality,253339007,RAAb,0,0,0,0,0,14,14
59 | right atrial hypertrophy,446358003,RAH,0,18,0,0,99,0,117
60 | right ventricular hypertrophy,89792004,RVH,0,20,0,0,126,86,232
61 | s t changes,55930002,STC,0,1,0,0,770,6,777
62 | shortened pr interval,49578007,SPRI,0,3,0,0,0,2,5
63 | sinoatrial block,65778007,SAB,0,9,0,0,0,0,9
64 | sinus node dysfunction,60423000,SND,0,0,2,0,0,0,2
65 | st depression,429622005,STD,869,57,4,0,1009,38,1977
66 | st elevation,164931005,STE,220,66,4,0,28,134,452
67 | st interval abnormal,164930006,STIAb,0,481,2,0,0,992,1475
68 | supraventricular bigeminy,251168009,SVB,0,0,1,0,0,0,1
69 | supraventricular tachycardia,426761007,SVT,0,3,1,0,27,32,63
70 | suspect arm ecg leads reversed,251139008,ALR,0,0,0,0,0,12,12
71 | transient ischemic attack,266257000,TIA,0,0,7,0,0,0,7
72 | u wave abnormal,164937009,UAb,0,1,0,0,0,0,1
73 | ventricular bigeminy,11157007,VBig,0,5,9,0,82,2,98
74 | ventricular ectopics,164884008,VEB,700,0,49,0,1154,41,1944
75 | ventricular escape beat,75532003,VEsB,0,3,1,0,0,0,4
76 | ventricular escape rhythm,81898007,VEsR,0,1,0,0,0,1,2
77 | ventricular fibrillation,164896001,VF,0,10,0,25,0,3,38
78 | ventricular flutter,111288001,VFL,0,1,0,0,0,0,1
79 | ventricular hypertrophy,266249003,VH,0,5,0,13,30,71,119
80 | ventricular pacing pattern,251266004,VPP,0,0,0,0,0,46,46
81 | ventricular pre excitation,195060002,VPEx,0,6,0,0,0,2,8
82 | ventricular tachycardia,164895002,VTach,0,1,1,10,0,0,12
83 | ventricular trigeminy,251180001,VTrig,0,4,4,0,20,1,29
84 | wandering atrial pacemaker,195101003,WAP,0,0,0,0,0,7,7
85 | wolff parkinson white pattern,74390002,WPW,0,0,4,2,80,2,88
86 |
--------------------------------------------------------------------------------
/experiments/BYOL_signal.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | from torch.utils.data import DataLoader
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 |
11 | import sys
12 | current_path = os.getcwd()
13 | sys.path.append(current_path)
14 |
15 | from models.signal_model import signal_model_byol
16 | from utils.contrastive_dataloader import ECG_contrastive_dataset
17 | from utils.eval_tools import load_weights
18 | from utils.optimizers import LARS
19 | from utils.eval_tools import load_weights
20 | from utils.tools import weights_init_xavier, set_requires_grad
21 |
22 | ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'
23 |
class MLPHead(nn.Module):
    """Small MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: Size of the input feature vector.
        mlp_hidden_size: Width of the hidden layer.
        projection_size: Size of the output vector.
    """

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super().__init__()

        stages = [
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        ]
        # Kept under the attribute name ``net`` so state-dict keys stay
        # ``net.0.*``, ``net.1.*`` and ``net.3.*``.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return the projected tensor."""
        projected = self.net(x)
        return projected
37 |
def regression_loss(x, y):
    """Return ``2 - 2 * cos(x, y)`` for each row of ``x`` and ``y``.

    Both inputs are L2-normalised along dimension 1 first, then the
    element-wise product is summed over the last dimension. No reduction over
    the remaining dimensions is applied.
    """
    x_unit = F.normalize(x, dim=1)
    y_unit = F.normalize(y, dim=1)
    similarity = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * similarity
42 |
def run():
    """Pre-train the signal model with BYOL and checkpoint the best online weights.

    Two augmented views of each record (``sig_i`` / ``sig_j``) are fed to an
    online network and to a gradient-free target network. The online
    predictor output of one view is regressed onto the target projection of
    the *other* view, symmetrically. After each optimizer step the target
    parameters are updated as an exponential moving average of the online
    parameters. The online ``state_dict`` is written to
    ``./checkpoints/BYOL_signal.pth`` whenever the mean epoch loss improves.

    Raises:
        RuntimeError: If the training dataloader yields no batches.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')
    # SNOMED CT codes that are treated as one class
    # (per the original comment: [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]).
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0

    transforms = ["TimeOut_difflead", "GaussianNoise"]

    batch_size = 1024
    learning_rate = 1e-3
    no_epoches = 400
    ema_decay = 0.9  # target-network momentum (0.996 was the commented-out alternative)

    checkpoint_dir = './checkpoints'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'BYOL_signal.pth')

    get_mean = np.load(os.path.join(data_folder, "mean.npy"))
    get_std = np.load(os.path.join(data_folder, "std.npy"))

    t_params = {"gaussian_scale": [0.005, 0.025], "global_crop_scale": [0.5, 1.0], "local_crop_scale": [0.1, 0.5],
                "output_size": 250, "warps": 3, "radius": 10, "shift_range": [0.2, 0.5],
                "epsilon": 10, "magnitude_range": [0.5, 2], "downsample_ratio": 0.2, "to_crop_ratio_range": [0.2, 0.4],
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1, "stats_mean": get_mean, "stats_std": get_std}

    train_dataset = ECG_contrastive_dataset(summary_folder=data_folder, signal_size=signal_size, stride=train_stride,
                                            chunk_length=train_chunk_length, transforms=transforms, t_params=t_params,
                                            equivalent_classes=equivalent_classes, sample_items_per_record=1,
                                            random_crop=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size, drop_last=True)

    no_classes = 24
    online_network = signal_model_byol(no_classes)
    target_network = deepcopy(online_network)
    online_network.to(ctx)
    target_network.to(ctx)

    # The target network is only ever updated through the moving average below.
    set_requires_grad(target_network, False)

    optimizer = torch.optim.Adam(online_network.parameters(), lr=learning_rate)
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    # Dummy optimizer step kept from the original: the scheduler is stepped at
    # the start of every epoch, and PyTorch warns when scheduler.step() is
    # called before any optimizer.step().
    optimizer.zero_grad()
    optimizer.step()

    lowest_train_loss = 2
    for epoch in range(1, no_epoches + 1):
        print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
        print('Current learning rate: ', optimizer.param_groups[0]['lr'])
        scheduler_steplr.step()
        online_network.train()
        train_loss = 0
        num_batches = 0

        for sample in tqdm(train_dataloader):
            data_i = sample['sig_i'].to(ctx).float()
            data_j = sample['sig_j'].to(ctx).float()

            # Per the original comment the model returns
            # (features, projector, predictor, output).
            _, _, pred_i, _ = online_network(data_i)
            _, _, pred_j, _ = online_network(data_j)

            with torch.no_grad():
                # Target projections of BOTH views. The original fed data_j
                # twice, so the second loss term compared view j with itself.
                _, target_proj_i, _, _ = target_network(data_i)
                _, target_proj_j, _, _ = target_network(data_j)

            # view i -> view j, plus view j -> view i
            loss = regression_loss(pred_i, target_proj_j)
            loss += regression_loss(pred_j, target_proj_i)
            total_loss = loss.mean()

            train_loss += total_loss.item()
            num_batches += 1

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Exponential moving average of the online weights into the target.
            for param_q, param_k in zip(online_network.parameters(), target_network.parameters()):
                param_k.data = param_k.data * ema_decay + param_q.data * (1. - ema_decay)

        if num_batches == 0:
            # drop_last=True yields nothing when the dataset is smaller than one batch.
            raise RuntimeError(
                'Training dataloader produced no batches; '
                'the dataset is smaller than batch_size={}.'.format(batch_size)
            )

        whole_train_loss = train_loss / num_batches
        print(f'Train Loss: {whole_train_loss}')
        if whole_train_loss < lowest_train_loss:
            lowest_train_loss = whole_train_loss
            torch.save(online_network.state_dict(), checkpoint_path)
135 |
136 |
137 | if __name__ == "__main__":
138 | run()
--------------------------------------------------------------------------------
/models/seresnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 |
def conv3x3_1d(in_planes, out_planes, stride=1):
    """Build a bias-free ``nn.Conv1d`` with kernel size 3 and padding 1."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv1d(in_planes, out_planes, **conv_kwargs)
9 |
def conv5x5_1d(in_planes, out_planes, stride=1):
    """Build a bias-free ``nn.Conv1d`` with kernel size 5 and padding 1.

    With padding 1 a stride-1 call shortens the sequence by 2 samples.
    """
    conv_kwargs = dict(kernel_size=5, stride=stride, padding=1, bias=False)
    return nn.Conv1d(in_planes, out_planes, **conv_kwargs)
13 |
def conv7x7_1d(in_planes, out_planes, stride=1):
    """Build a bias-free ``nn.Conv1d`` with kernel size 7 and padding 1.

    With padding 1 a stride-1 call shortens the sequence by 4 samples.
    """
    conv_kwargs = dict(kernel_size=7, stride=stride, padding=1, bias=False)
    return nn.Conv1d(in_planes, out_planes, **conv_kwargs)
17 |
class SELayer_1d(nn.Module):
    """Squeeze-and-excitation gate for 3-D ``(batch, channel, length)`` inputs.

    The input is average-pooled to one value per channel, passed through a
    bias-free bottleneck MLP (``channel -> channel // reduction -> channel``)
    ending in a sigmoid, and the resulting per-channel gate rescales the input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        batch, channels, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate.expand_as(x)
35 |
class SE_BasicBlock3x3_1d(nn.Module):
    """Residual block of two 3-wide 1-D convolutions followed by an SE gate.

    Args:
        inplanes3: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the
            shortcut; when ``None`` the input itself is the shortcut.
        reduction: Bottleneck ratio forwarded to ``SELayer_1d``.
    """

    expansion = 1

    def __init__(self, inplanes3, planes, stride=1, downsample=None, reduction=16):
        super().__init__()
        self.conv1 = conv3x3_1d(inplanes3, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3_1d(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)
        self.se = SELayer_1d(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(se(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.se(self.bn2(self.conv2(out)))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
67 |
class SEResNet_1d(nn.Module):
    """1-D ResNet-style classifier assembled from a pluggable residual block.

    Args:
        input_channel: Number of channels of the input signal.
        block: Block class; must expose an ``expansion`` attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: Four integers, the number of blocks in each stage.
        num_classes: Size of the final linear layer's output.

    The head is a fixed ``AvgPool1d(8)`` followed by a flatten and a linear
    layer with ``512 * block.expansion`` inputs, so the pooled feature map
    must have length 1 for the shapes to match.
    """

    def __init__(self, input_channel, block, layers, num_classes=1000):
        super().__init__()
        # Running channel count, advanced by _make_layer.
        self.inplanes = 64

        # Stem: strided 7-wide convolution followed by max pooling.
        self.conv1 = nn.Conv1d(
            input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Four stages; all but the first halve the temporal resolution.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, layers[index], stride=stride)
            setattr(self, "layer{}".format(index + 1), stage)

        self.avgpool = nn.AvgPool1d(8)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block``; only the first may stride.

        A 1x1 convolution plus batch norm is used as the shortcut projection
        of the first block whenever the stride or channel count changes.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, four stages, average pooling and the linear classifier."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        features = stem
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
131 |
132 |
def se_resnet18_1d(input_channel, num_classes=1000):
    """Construct an ``SEResNet_1d`` with the ResNet-18 layout ``[2, 2, 2, 2]``.

    Args:
        input_channel: Number of channels of the input signal.
        num_classes: Output size of the final linear layer.
    """
    stage_depths = [2, 2, 2, 2]
    return SEResNet_1d(input_channel, SE_BasicBlock3x3_1d, stage_depths, num_classes)
140 |
141 |
def se_resnet34_1d(input_channel, num_classes=1000):
    """Construct an ``SEResNet_1d`` with the ResNet-34 layout ``[3, 4, 6, 3]``.

    Args:
        input_channel: Number of channels of the input signal.
        num_classes: Output size of the final linear layer.
    """
    stage_depths = [3, 4, 6, 3]
    return SEResNet_1d(input_channel, SE_BasicBlock3x3_1d, stage_depths, num_classes)
--------------------------------------------------------------------------------
/data_preparation/data_extraction_without_preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import wfdb
7 | from scipy.signal import resample
8 | from stratify import stratify
9 |
# Extraction script: reads every WFDB record of the listed datasets, keeps the
# records carrying at least one scored diagnosis, resamples all leads to
# `target_fs`, stores each record as .npy and writes a summary DataFrame with
# per-dataset stratified folds and per-record statistics.
save_folder = "./data_folder/extracted_data_without_preprocessing"
save_summary = "./data_folder/data_summary_without_preprocessing"
raw_data_cinc = "./data_folder/datasets"
dataset_names = ["ICBEB2018", "ICBEB2018_2", "INCART", "PTB", "PTB-XL", "Georgia"]
mapping_scored_path = "./data_folder/evaluation-2020-master/dx_mapping_scored.csv"  # 27 main labels
target_fs = 100
strat_folds = 10
channels = 12

# np.save / to_pickle do not create missing directories.
os.makedirs(save_folder, exist_ok=True)
os.makedirs(save_summary, exist_ok=True)

mapping_scored_df = pd.read_csv(mapping_scored_path)
dx_mapping_snomed_abbrev = dict(zip(mapping_scored_df["SNOMED CT Code"], mapping_scored_df["Abbreviation"]))
list_label_available = np.array(mapping_scored_df["SNOMED CT Code"])
# Set of plain ints for O(1) membership tests while parsing headers.
scored_codes = {int(code) for code in list_label_available}

# Display names used in the log output, in the same order as dataset_names.
dataset_display_names = ["CPSC", "CPSC-Extra", "StPetersburg", "PTB", "PTB-XL", "Georgia"]
all_files = []
for display_name, dataset_name in zip(dataset_display_names, dataset_names):
    dataset_files = glob.glob(os.path.join(raw_data_cinc, dataset_name, '**/*.hea'))
    print('No files in {}:'.format(display_name), len(dataset_files))
    all_files += dataset_files
print('Total no files:', len(all_files))
# Example record, as returned by wfdb.rdsamp:
# signal shape (7500, 12)
# {'fs': 500, 'sig_len': 7500, 'n_sig': 12, 'base_date': None, 'base_time': datetime.time(0, 0, 12),
#  'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'],
#  'sig_name': ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'],
#  'comments': ['Age: 74', 'Sex: Male', 'Dx: 59118001', 'Rx: Unknown', 'Hx: Unknown', 'Sx: Unknown']}

skip_files = 0
metadata = []
for hea_file in tqdm(all_files):
    # <raw_data_cinc>/<dataset>/<subfolder>/<record>.hea
    path_parts = os.path.normpath(hea_file).split(os.sep)
    file_name = path_parts[-1].split(".hea")[0]
    data_folder = path_parts[-3]
    sigbufs, header = wfdb.rdsamp(str(hea_file)[:-4])

    if np.any(np.isnan(sigbufs)):
        print("Warning:", str(hea_file), "is corrupt. Skipping.")
        continue

    labels = []
    age = np.nan
    sex = "nan"
    for l in header["comments"]:
        arrs = l.strip().split(' ')
        if l.startswith('Dx:'):
            # Keep only the diagnoses that are part of the scored mapping.
            for x in arrs[1].split(','):
                if int(x) in scored_codes:
                    labels.append(x)
        elif l.startswith('Age:'):
            try:
                age = int(arrs[1])
            except (ValueError, IndexError):
                # Missing or non-numeric age (e.g. "NaN").
                age = np.nan
        elif l.startswith('Sex:'):
            sex = arrs[1].strip().lower()
            if sex == "m":
                sex = "male"
            elif sex == "f":
                sex = "female"

    if len(labels) == 0:
        skip_files += 1
        continue

    ori_fs = header['fs']
    factor = target_fs / ori_fs
    timesteps_new = int(len(sigbufs) * factor)
    data = np.zeros((timesteps_new, channels), dtype=np.float32)
    for i in range(channels):
        # Resample lead i; the original always resampled lead 0, so every
        # stored channel was a copy of the first lead.
        data[:, i] = resample(sigbufs[:, i], timesteps_new)

    np.save(os.path.join(save_folder, file_name + ".npy"), data)

    metadata.append({"data": file_name + ".npy", "label": labels, "sex": sex, "age": age, "dataset": data_folder})

print('Skipped files without scored labels:', skip_files)
if not metadata:
    raise SystemExit("No usable records found under {}".format(raw_data_cinc))

df = pd.DataFrame(metadata)
lbl_itos = np.unique([item for sublist in list(df.label) for item in sublist])
lbl_stoi = {s: i for i, s in enumerate(lbl_itos)}
df["label"] = df["label"].apply(lambda x: [lbl_stoi[y] for y in x])

df["strat_fold"] = -1
for ds in np.unique(df["dataset"]):
    print("Creating CV folds:", ds)
    dfx = df[df.dataset == ds]
    idxs = np.array(dfx.index.values)
    lbl_itosx = np.unique([item for sublist in list(dfx.label) for item in sublist])
    stratified_ids = stratify(list(dfx["label"]), lbl_itosx, [1. / strat_folds] * strat_folds)

    for i, split in enumerate(stratified_ids):
        df.loc[idxs[split], "strat_fold"] = i


def _load_record(record_name):
    """Load one extracted record from `save_folder`."""
    return np.load(os.path.join(save_folder, record_name), allow_pickle=True)


# The original tested the unrelated loop variable `data_folder` when building
# the std and length columns; records always live in `save_folder`.
print("Add Mean Column")
df["data_mean"] = df["data"].apply(lambda x: np.mean(_load_record(x), axis=0))
print("Add Std Column")
df["data_std"] = df["data"].apply(lambda x: np.std(_load_record(x), axis=0))
print("Add Length Column")
df["data_length"] = df["data"].apply(lambda x: len(_load_record(x)))

# dataset-level statistics: average of the per-record means and stds
df_mean = df["data_mean"].mean()
df_std = df["data_std"].mean()

# save dataset
df.to_pickle(os.path.join(save_summary, 'df.pkl'), protocol=4)
np.save(os.path.join(save_summary, "lbl_itos.npy"), lbl_itos)
np.save(os.path.join(save_summary, "mean.npy"), df_mean)
np.save(os.path.join(save_summary, "std.npy"), df_std)
123 |
124 |
125 | # file1 = 'df.pkl'
126 | # file2 = 'lbl_itos.npy'
127 | # file3 = 'memmap.npy'
128 | # file4 = 'memmap_meta.npz'
129 | # file5 = 'df_memmap.pkl'
130 | # file6 = 'mean.npy'
131 | # file7 = 'std.npy'
--------------------------------------------------------------------------------
/data_preparation/stratify.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def stratify(data, classes, ratios, samples_per_group=None):
    """Split multi-label samples into stratified subsets.

    Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/
    (based on Sechidis 2011). Labels are processed from the rarest remaining
    one upwards; each sample goes to the subset that still needs that label
    most, ties being broken by overall remaining capacity and then randomly.

    Args:
        data: List of lists; ``data[i]`` holds the labels of sample/patient
            ``i`` (duplicates allowed, not multi-hot encoded).
        classes: Iterable of all labels that may occur in ``data``.
        ratios: List summing to 1 describing the desired split.
        samples_per_group: Optional number of samples per patient/group;
            defaults to one per entry of ``data``.

    Returns:
        One sorted list of sample indices per entry of ``ratios``. Samples
        without any label are not assigned to a subset.

    Note:
        Re-seeds NumPy's global random generator with 0 on every call so the
        split is reproducible.
    """
    np.random.seed(0)  # fix the random seed

    if samples_per_group is None:
        samples_per_group = np.ones(len(data))

    # size is the number of ecgs still to distribute
    size = np.sum(samples_per_group)

    # For each label, the samples that carry it and are not assigned yet
    # (a sample appears once per occurrence of the label).
    per_label_data = {c: [] for c in classes}
    for i, sample_labels in enumerate(data):
        for l in sample_labels:
            per_label_data[l].append(i)

    # Remaining capacity of every subset, overall and per label.
    subset_sizes = [r * size for r in ratios]
    per_label_subset_sizes = {
        c: [r * len(per_label_data[c]) for r in ratios] for c in classes
    }

    # For each subset, the sample ids that end up in it.
    stratified_data_ids = [set() for _ in range(len(ratios))]

    print("Starting fold distribution...")
    size_prev = size + 1  # just for output
    while size > 0:
        if int(size_prev / 1000) > int(size / 1000):
            print("Remaining entries to distribute:", size, "non-empty labels:",
                  np.sum([1 for l, label_data in per_label_data.items() if len(label_data) > 0]))
        size_prev = size

        # Number of unassigned entries per label that still has any.
        remaining = {l: len(ids) for l, ids in per_label_data.items() if ids}
        if not remaining:
            # size > 0 but only unlabeled samples are left: nothing to stratify on.
            break
        # Rarest remaining label (first one in dict order on ties).
        label = min(remaining, key=remaining.get)

        # Samples with this label, ordered by number of occurrences, descending.
        unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
        idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
        unique_samples = unique_samples[idxs_sorted]

        for current_id in unique_samples:
            # Copy as an array: the underlying list is mutated further down.
            subset_sizes_for_label = np.asarray(per_label_subset_sizes[label])

            # Subset(s) in greatest need of the current label.
            largest_subsets = np.argwhere(
                subset_sizes_for_label == np.amax(subset_sizes_for_label)
            ).flatten()

            if len(largest_subsets) == 1:
                subset = largest_subsets[0]
            else:
                # Tie: prefer the subset in greatest need of any sample,
                # picking randomly among those that are still tied.
                candidate_sizes = np.array(subset_sizes)[largest_subsets]
                largest_subsets2 = np.argwhere(
                    candidate_sizes == np.amax(candidate_sizes)
                ).flatten()
                subset = largest_subsets[np.random.choice(largest_subsets2)]

            stratified_data_ids[subset].add(current_id)

            # Fewer samples left overall and in the selected subset.
            size -= samples_per_group[current_id]
            subset_sizes[subset] -= samples_per_group[current_id]

            # The selected subset needs one fewer example of each label the
            # sample carries (once per occurrence).
            for l in data[current_id]:
                per_label_subset_sizes[l][subset] -= 1

            # Drop the sample from the pending lists. Only the lists of its
            # own labels can contain it, so the other labels are not rescanned.
            for l in set(data[current_id]):
                per_label_data[l] = [y for y in per_label_data[l] if y != current_id]

    # Return the stratified indexes, to be used to select the samples of each subset.
    return [sorted(strat) for strat in stratified_data_ids]
109 |
--------------------------------------------------------------------------------
/models/ensemble_model.py:
--------------------------------------------------------------------------------
1 | from models.xresnet1d import xresnet1d50
2 | from models.seresnet2d import se_resnet34
3 | import torch.nn.functional as F
4 | import torch
5 | import torch.nn as nn
6 |
class ensemble_model(nn.Module):
    """Classifier that fuses a 1-D signal encoder with a 2-D spectrogram encoder.

    The signal branch is an ``xresnet1d50`` and the spectrogram branch an
    ``se_resnet34``; both get a 12-channel input convolution. Their feature
    vectors are concatenated (optionally re-weighted by a learned two-way
    softmax gate) and classified by a two-layer linear head.

    Args:
        no_classes: number of output logits.
        gate: when true, scale each branch's features by a softmax gate
            before the classification head.
        w_time: optional checkpoint path loaded (non-strictly) into the
            signal feature extractor.
        w_spec: optional checkpoint path loaded (non-strictly) into the
            spectrogram feature extractor.
        device: forwarded to ``torch.load`` as ``map_location``.
    """

    def __init__(self, no_classes=24, gate=False, w_time=None, w_spec=None, device=None):
        super(ensemble_model, self).__init__()
        # gating encoding
        self.gate = gate

        # Time series module
        self.time_backbone = xresnet1d50(widen=1.0)
        time_list_of_modules = list(self.time_backbone.children())
        # Keep everything but the backbone's last child, plus that child's
        # first sub-module (presumably the pooling layer of the head -- TODO confirm).
        self.time_features = nn.Sequential(*time_list_of_modules[:-1], time_list_of_modules[-1][0])
        time_num_ftrs = self.time_backbone[-1][-1].in_features
        # The first stem module is shared with ``time_features``, so replacing
        # its first layer here also changes the feature extractor.
        self.time_backbone[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)

        if w_time is not None:
            time_state_dict = torch.load(w_time, map_location=device)
            self.time_features.load_state_dict(time_state_dict, strict=False)

        # Spectrogram module
        self.spec_backbone = se_resnet34()
        self.spec_backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        spec_list_of_modules = list(self.spec_backbone.children())
        self.spec_features = nn.Sequential(*spec_list_of_modules[:-1])
        spec_num_ftrs = self.spec_backbone.fc.in_features

        if w_spec is not None:
            spec_state_dict = torch.load(w_spec, map_location=device)
            self.spec_features.load_state_dict(spec_state_dict, strict=False)

        # The classification head is the same with or without gating; only the
        # gate layer itself is optional.
        num_ftrs = time_num_ftrs + spec_num_ftrs
        if self.gate:
            self.gate_fc = nn.Linear(num_ftrs, 2)
        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=num_ftrs // 2),
            nn.Linear(in_features=num_ftrs // 2, out_features=no_classes),
        )

    def forward(self, x_sig, x_spec):
        """Return class logits for a batch of signals and their spectrograms."""
        # flatten(start_dim=1) instead of squeeze(): squeeze() would also drop
        # the batch axis for a batch of one and break the concatenation below.
        h_time = self.time_features(x_sig).flatten(start_dim=1)
        h_spec = self.spec_features(x_spec).flatten(start_dim=1)

        h_comb = torch.cat((h_time, h_spec), dim=1)
        if self.gate:
            # Two gate weights per sample: column 0 scales the signal
            # features, column 1 the spectrogram features.
            h_gate = F.softmax(self.gate_fc(h_comb), dim=1)
            h_comb = torch.cat([h_time * h_gate[:, 0:1], h_spec * h_gate[:, 1:2]], dim=1)
        return self.fc(h_comb)
65 |
66 |
67 |
class ensemble_model_3head(nn.Module):
    """Gated signal/spectrogram ensemble with three classification heads.

    Besides the fused head that works on the gate-weighted concatenation of
    both feature vectors, each branch has its own head, so ``forward`` yields
    fused, signal-only and spectrogram-only logits plus the gate weights.

    Args:
        no_classes: number of output logits of every head.
        w_time: optional checkpoint path loaded (non-strictly) into the
            signal feature extractor.
        w_spec: optional checkpoint path loaded (non-strictly) into the
            spectrogram feature extractor.
        device: forwarded to ``torch.load`` as ``map_location``.
    """

    def __init__(self, no_classes=24, w_time=None, w_spec=None, device=None):
        super(ensemble_model_3head, self).__init__()

        # Time series module
        self.time_backbone = xresnet1d50(widen=1.0)
        time_list_of_modules = list(self.time_backbone.children())
        # Keep everything but the backbone's last child, plus that child's
        # first sub-module (presumably the pooling layer of the head -- TODO confirm).
        self.time_features = nn.Sequential(*time_list_of_modules[:-1], time_list_of_modules[-1][0])
        time_num_ftrs = self.time_backbone[-1][-1].in_features
        # The first stem module is shared with ``time_features``, so replacing
        # its first layer here also changes the feature extractor.
        self.time_backbone[0][0] = nn.Conv1d(12, 32, kernel_size=5, stride=2, padding=2)

        if w_time is not None:
            time_state_dict = torch.load(w_time, map_location=device)
            self.time_features.load_state_dict(time_state_dict, strict=False)

        # Spectrogram module
        self.spec_backbone = se_resnet34()
        self.spec_backbone.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3)
        spec_list_of_modules = list(self.spec_backbone.children())
        self.spec_features = nn.Sequential(*spec_list_of_modules[:-1])
        spec_num_ftrs = self.spec_backbone.fc.in_features

        if w_spec is not None:
            spec_state_dict = torch.load(w_spec, map_location=device)
            self.spec_features.load_state_dict(spec_state_dict, strict=False)

        num_ftrs = time_num_ftrs + spec_num_ftrs
        self.gate_fc = nn.Linear(num_ftrs, 2)

        # Fused head plus one head per branch.
        self.fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=num_ftrs // 2),
            nn.Linear(in_features=num_ftrs // 2, out_features=no_classes),
        )
        self.fc_time = nn.Sequential(
            nn.Linear(in_features=time_num_ftrs, out_features=time_num_ftrs // 2),
            nn.Linear(in_features=time_num_ftrs // 2, out_features=no_classes),
        )
        self.fc_spec = nn.Sequential(
            nn.Linear(in_features=spec_num_ftrs, out_features=spec_num_ftrs // 2),
            nn.Linear(in_features=spec_num_ftrs // 2, out_features=no_classes),
        )

    def forward(self, x_sig, x_spec):
        """Return ``(fused_logits, signal_logits, spectrogram_logits, gate_weights)``."""
        # flatten(start_dim=1) instead of squeeze(): squeeze() would also drop
        # the batch axis for a batch of one and break the concatenation below.
        h_time = self.time_features(x_sig).flatten(start_dim=1)
        h_spec = self.spec_features(x_spec).flatten(start_dim=1)

        # Two gate weights per sample: column 0 scales the signal features,
        # column 1 the spectrogram features.
        h_gate = F.softmax(self.gate_fc(torch.cat((h_time, h_spec), dim=1)), dim=1)
        h_encode = torch.cat([h_time * h_gate[:, 0:1], h_spec * h_gate[:, 1:2]], dim=1)
        y = self.fc(h_encode)
        y_time = self.fc_time(h_time)
        y_spec = self.fc_spec(h_spec)

        return y, y_time, y_spec, h_gate

    def freeze_backbone(self):
        """Disable gradients for both feature extractors."""
        for p in self.spec_features.parameters():
            p.requires_grad = False
        for p in self.time_features.parameters():
            p.requires_grad = False

    def freeze_gate(self):
        """Disable gradients for the gating layer."""
        for p in self.gate_fc.parameters():
            p.requires_grad = False
140 |
--------------------------------------------------------------------------------
/experiments/SIMCLR_signal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | from torch.optim.lr_scheduler import CosineAnnealingLR
6 | from torch.utils.data import DataLoader
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 | import torch.distributed as dist
10 | import math
11 |
12 | import sys
13 | current_path = os.getcwd()
14 | sys.path.append(current_path)
15 |
16 | from models.signal_model import signal_model_simclr
17 | from utils.contrastive_dataloader import ECG_contrastive_dataset
18 | from utils.tools import weights_init_xavier
19 |
# Device string used for the models and every batch: the first CUDA GPU when
# one is available, otherwise the CPU.
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'
# NOTE(review): this module-level constant is not read anywhere in this script
# (nt_xent_loss has its own ``eps`` parameter) -- confirm before removing.
eps = 1e-7
22 |
class Flatten(nn.Module):
    """Collapse every axis after the first into a single one."""

    def forward(self, input_tensor):
        """Return ``input_tensor`` viewed as ``(size of axis 0, everything else)``."""
        leading = input_tensor.shape[0]
        return input_tensor.view(leading, -1)
30 |
class Projection(nn.Module):
    """Projection head: flatten, linear, ReLU, linear, then L2-normalise.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: width of the hidden linear layer.
        output_dim: size of the returned embedding.
    """

    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Layer order is kept as-is so the parameter names in the state dict
        # (model.1.*, model.3.*) stay the same.
        layers = [
            Flatten(),
            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim, bias=True),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the projected embedding, normalised along dim 1."""
        projected = self.model(x)
        return F.normalize(projected, dim=1)
48 |
def nt_xent_loss(out_1, out_2, temperature, eps=1e-6):
    """Compute the NT-Xent (normalised temperature-scaled cross-entropy) loss.

    Both inputs are assumed to be L2-normalised along the last dimension.

    Args:
        out_1: embeddings of the first view, shape ``[batch_size, dim]``.
        out_2: embeddings of the second view, shape ``[batch_size, dim]``.
        temperature: softmax temperature dividing the similarities.
        eps: lower bound / additive constant used for numerical stability.

    Returns:
        Scalar tensor holding the mean loss over the ``2 * batch_size`` anchors.
    """
    # out: [2 * batch_size, dim]
    out = torch.cat([out_1, out_2], dim=0)

    # cov and sim: [2 * batch_size, 2 * batch_size]
    cov = torch.mm(out, out.t().contiguous())
    sim = torch.exp(cov / temperature)

    # Remove each anchor's similarity with itself from its row sum. That term
    # is exp(<x, x> / temperature), i.e. exp(1 / temperature) for normalised
    # inputs; subtracting a constant e is only right when temperature == 1.
    neg = sim.sum(dim=-1) - torch.diagonal(sim)
    neg = torch.clamp(neg, min=eps)  # clamp for numerical stability

    # Positive similarity, pos becomes [2 * batch_size]
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)

    loss = -torch.log(pos / (neg + eps)).mean()

    return loss
84 |
85 |
def run():
    """Pre-train the signal encoder with a SimCLR-style contrastive objective.

    Builds the contrastive ECG dataset from ``./data_folder``, trains
    ``signal_model_simclr`` together with a ``Projection`` head using
    ``nt_xent_loss`` and, whenever the mean training loss of an epoch is the
    lowest seen so far, writes the encoder weights to
    ``./checkpoints/SIMCLR_signal.pth``.

    Raises:
        RuntimeError: if the dataloader yields no batch (with
            ``drop_last=True`` this happens when the dataset holds fewer than
            ``batch_size`` items).
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')
    # Pairs of diagnosis codes handed to the dataset as equivalent classes; the
    # original author annotated them as
    # [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']].
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0

    transforms = ["TimeOut_difflead", "GaussianNoise"]

    batch_size = 1024
    learning_rate = 1e-3
    no_epoches = 1000

    # Statistics stored next to the data; forwarded to the transformations.
    get_mean = np.load(os.path.join(data_folder, "mean.npy"))
    get_std = np.load(os.path.join(data_folder, "std.npy"))

    t_params = {"gaussian_scale": [0.005, 0.025], "global_crop_scale": [0.5, 1.0], "local_crop_scale": [0.1, 0.5],
                "output_size": 250, "warps": 3, "radius": 10, "shift_range": [0.2, 0.5],
                "epsilon": 10, "magnitude_range": [0.5, 2], "downsample_ratio": 0.2, "to_crop_ratio_range": [0.2, 0.4],
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1, "stats_mean": get_mean, "stats_std": get_std}

    train_dataset = ECG_contrastive_dataset(summary_folder=data_folder, signal_size=signal_size, stride=train_stride,
                                            chunk_length=train_chunk_length, transforms=transforms, t_params=t_params,
                                            equivalent_classes=equivalent_classes, sample_items_per_record=1,
                                            random_crop=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size, drop_last=True)

    no_classes = 24
    model = signal_model_simclr(no_classes)
    projection_head = Projection(model.num_ftrs, hidden_dim=512, output_dim=128)

    model.apply(weights_init_xavier)
    projection_head.apply(weights_init_xavier)
    model.to(ctx)
    projection_head.to(ctx)

    # The projection head is part of the contrastive objective and has to be
    # optimised together with the encoder; otherwise it keeps its initial weights.
    trainable_params = list(model.parameters()) + list(projection_head.parameters())
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    # torch.save does not create missing directories.
    os.makedirs('./checkpoints', exist_ok=True)

    lowest_train_loss = float('inf')
    for epoch in range(1, no_epoches + 1):
        print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
        print('Current learning rate: ', optimizer.param_groups[0]['lr'])
        model.train()
        projection_head.train()
        train_loss = 0
        num_batches = 0

        for sample in tqdm(train_dataloader):
            data_i = sample['sig_i'].to(ctx).float()
            data_j = sample['sig_j'].to(ctx).float()

            # The first output of the model is the representation h.
            h1 = model(data_i)[0]
            h2 = model(data_j)[0]

            # PROJECT: h -> z (embedding used by the contrastive loss)
            z1 = projection_head(h1.squeeze())
            z2 = projection_head(h2.squeeze())

            loss = nt_xent_loss(z1, z2, temperature=0.1)

            train_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if num_batches == 0:
            raise RuntimeError('The training dataloader produced no batches; '
                               'the dataset is smaller than the batch size.')

        # Step the schedule after this epoch's optimizer updates so that the
        # learning rate printed above is the one the epoch actually used.
        scheduler_steplr.step()

        whole_train_loss = train_loss / num_batches
        print(f'Train Loss: {whole_train_loss}')
        if whole_train_loss < lowest_train_loss:
            lowest_train_loss = whole_train_loss
            torch.save(model.state_dict(), './checkpoints/SIMCLR_signal.pth')
165 |
166 |
# Start pre-training only when this file is executed as a script.
if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/data_folder/evaluation-2020-master/Results/physionet_2020_metrics_perDatabase_official_entries.csv:
--------------------------------------------------------------------------------
1 | ,,Database ->,Validation Dataset,Validation Dataset,Validation Dataset,Validation Dataset,Validation Dataset,Hidden CSPC Set,Hidden CSPC Set,Hidden CSPC Set,Hidden CSPC Set,Hidden CSPC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden G12EC Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Hidden Undisclosed Set,Test set,Test set,Test set,Test set,Test set
2 | Final ranking,Team name,CinC Abstract #,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score,AUROC,AUPRC,Accuracy,F-measure,Challenge Score
3 | 1,prna,107,0.893,0.428,0.279,0.411,0.587,0.964,0.832,0.532,0.202,0.761,0.871,0.414,0.206,0.401,0.558,0.889,0.523,0.365,0.379,0.492,0.880,0.429,0.330,0.409,0.533
4 | 2,Between_a_ROC_and_a_heart_place,112,0.932,0.548,0.367,0.495,0.672,0.971,0.856,0.555,0.214,0.845,0.915,0.538,0.305,0.495,0.639,0.906,0.578,0.305,0.378,0.412,0.889,0.502,0.324,0.464,0.520
5 | 3,HeartBeats,281,0.945,0.556,0.400,0.525,0.682,0.968,0.816,0.587,0.229,0.852,0.930,0.558,0.338,0.523,0.649,0.909,0.594,0.309,0.416,0.396,0.900,0.510,0.340,0.487,0.514
6 | 4,Triage,133,0.909,0.491,0.424,0.471,0.640,0.962,0.829,0.610,0.216,0.833,0.894,0.491,0.367,0.472,0.609,0.907,0.584,0.358,0.396,0.370,0.897,0.492,0.382,0.462,0.485
7 | 5,Sharif AI Team,445,0.930,0.488,0.359,0.452,0.609,0.966,0.805,0.499,0.222,0.793,0.913,0.489,0.311,0.456,0.577,0.908,0.562,0.297,0.371,0.314,0.911,0.469,0.316,0.426,0.437
8 | 6,DSAIL_SNU,328,0.947,0.570,0.389,0.541,0.688,0.981,0.899,0.598,0.250,0.872,0.937,0.561,0.320,0.538,0.654,0.929,0.592,0.319,0.286,0.228,0.900,0.514,0.341,0.433,0.420
9 | 7,UMCUVA,253,0.932,0.516,0.243,0.468,0.586,0.956,0.823,0.410,0.197,0.643,0.919,0.510,0.189,0.469,0.574,0.917,0.592,0.225,0.377,0.298,0.915,0.481,0.228,0.438,0.417
10 | 8,CQUPT_ECG,85,0.932,0.507,0.320,0.468,0.640,0.966,0.815,0.501,0.190,0.800,0.915,0.504,0.259,0.459,0.609,0.781,0.359,0.064,0.219,0.248,0.821,0.364,0.160,0.321,0.411
11 | 9,ECU,161,0.916,0.490,0.362,0.450,0.623,0.959,0.808,0.538,0.201,0.797,0.905,0.496,0.309,0.476,0.596,0.802,0.325,0.199,0.209,0.205,0.832,0.365,0.262,0.352,0.382
12 | 10,PALab,35,0.942,0.549,0.381,0.530,0.653,0.971,0.873,0.574,0.218,0.836,0.928,0.541,0.319,0.525,0.623,0.751,0.319,0.247,0.199,0.144,0.852,0.393,0.296,0.380,0.359
13 | 11,HITTING,171,0.701,0.273,0.338,0.366,0.435,0.841,0.606,0.337,0.220,0.556,0.699,0.291,0.334,0.381,0.418,0.730,0.378,0.344,0.386,0.290,0.695,0.289,0.339,0.366,0.354
14 | 12,Gio_Ivo,116,0.830,0.314,0.045,0.296,0.426,0.882,0.619,0.117,0.116,0.452,0.799,0.312,0.026,0.304,0.421,0.810,0.376,0.047,0.244,0.205,0.777,0.302,0.047,0.266,0.298
15 | 13,AUTh Team,417,0.879,0.388,0.057,0.349,0.470,0.918,0.698,0.093,0.169,0.447,0.869,0.397,0.047,0.358,0.476,0.834,0.412,0.008,0.234,0.143,0.815,0.329,0.028,0.272,0.281
16 | 14,BioS,124,,,,,,,,,,,,,,,,,,,,,,,,,
17 | 15,UC_Lab_Kn,229,0.938,0.528,0.322,0.480,0.656,0.973,0.871,0.606,0.237,0.840,0.851,0.392,0.221,0.326,0.300,0.828,0.464,0.289,0.277,0.190,0.845,0.391,0.294,0.315,0.270
18 | 16,Cardio-Challengers,225,0.498,0.060,0.000,0.105,0.337,0.500,0.135,0.000,0.058,0.176,0.498,0.067,0.000,0.115,0.369,0.495,0.103,0.001,0.119,0.198,0.498,0.072,0.001,0.116,0.258
19 | 17,JuJuRock,134,0.457,0.060,0.021,0.271,0.406,0.577,0.169,0.013,0.127,0.253,0.439,0.065,0.022,0.292,0.437,0.441,0.093,0.009,0.177,0.125,0.436,0.064,0.012,0.223,0.244
20 | 18,Minibus,282,0.864,0.489,0.476,0.430,0.446,0.963,0.861,0.755,0.235,0.722,0.862,0.491,0.391,0.447,0.394,0.828,0.493,0.284,0.324,0.088,0.828,0.448,0.357,0.409,0.236
21 | 19,Desafinado,363,0.906,0.478,0.224,0.413,0.576,0.967,0.833,0.325,0.182,0.681,0.887,0.481,0.186,0.412,0.556,0.806,0.358,0.009,0.202,-0.013,0.822,0.362,0.089,0.298,0.233
22 | 20,TeamUIO,227,0.846,0.334,0.079,0.309,0.377,0.856,0.605,0.09,0.165,0.379,0.812,0.315,0.07,0.307,0.382,0.677,0.215,0.133,0.15,0.076,0.728,0.218,0.107,0.233,0.206
23 | 21,Eagles,138,0.677,0.189,0.160,0.195,0.214,0.714,0.351,0.146,0.104,0.235,0.653,0.188,0.155,0.185,0.205,0.648,0.270,0.302,0.186,0.205,0.647,0.202,0.240,0.200,0.205
24 | 22,BUTTeam,189,0.940,0.531,0.395,0.522,0.696,0.974,0.844,0.661,0.245,0.892,0.864,0.381,0.238,0.259,0.235,0.877,0.467,0.292,0.251,0.104,0.850,0.392,0.307,0.277,0.202
25 | 23,DSC,71,0.769,0.374,0.429,0.536,0.616,0.900,0.647,0.597,0.231,0.824,0.670,0.248,0.288,0.350,0.301,0.668,0.304,0.286,0.283,0.062,0.658,0.245,0.311,0.316,0.194
26 | 24,Pink Irish Hat,198,0.878,0.447,0.417,0.381,0.511,0.944,0.796,0.653,0.274,0.762,0.715,0.267,0.170,0.193,0.127,0.776,0.397,0.267,0.287,0.123,0.748,0.311,0.271,0.256,0.167
27 | 25,Madhardmax,185,0.921,0.471,0.365,0.461,0.533,0.958,0.810,0.508,0.221,0.544,0.914,0.470,0.315,0.454,0.525,0.916,0.542,0.240,0.281,-0.109,0.895,0.426,0.284,0.373,0.155
28 | 26,Care4MyHeart,127,0.869,0.361,0.250,0.350,0.379,0.929,0.721,0.408,0.168,0.611,0.862,0.362,0.208,0.352,0.342,0.820,0.376,0.108,0.239,-0.027,0.828,0.315,0.166,0.290,0.146
29 | 27,MCIRCC,374,0.907,0.442,0.333,0.433,0.616,0.956,0.810,0.665,0.234,0.813,0.810,0.315,0.232,0.199,0.162,0.792,0.398,0.274,0.231,0.050,0.807,0.328,0.296,0.243,0.141
30 | 28,heartly-ai,356,0.870,0.370,0.310,0.230,0.159,0.927,0.730,0.514,0.202,0.351,0.847,0.378,0.249,0.210,0.128,0.870,0.476,0.346,0.262,0.116,0.847,0.387,0.330,0.237,0.136
31 | 29,Code Team,130,0.940,0.531,0.369,0.513,0.657,0.968,0.850,0.658,0.302,0.830,0.835,0.370,0.248,0.237,0.181,0.835,0.467,0.276,0.237,0.023,0.831,0.376,0.300,0.256,0.132
32 | 30,ISIBrno,32,0.922,0.533,0.417,0.510,0.659,0.977,0.893,0.717,0.262,0.847,0.815,0.356,0.246,0.270,0.195,0.492,0.102,0.002,0.046,-0.006,0.675,0.222,0.141,0.191,0.122
33 | 31,Alba_W.O.,61,0.602,0.182,0.368,0.220,0.308,0.815,0.491,0.568,0.434,0.709,0.576,0.120,0.192,0.125,0.094,0.540,0.151,0.302,0.115,0.035,0.554,0.125,0.291,0.130,0.102
34 | 32,AI Strollers,277,0.625,0.124,0.000,0.140,0.342,0.783,0.437,0.001,0.117,0.212,0.622,0.135,0.000,0.153,0.359,0.629,0.197,0.000,0.144,0.096,0.599,0.136,0.000,0.142,0.077
35 | 33,ECGLearner,95,0.903,0.473,0.421,0.451,0.486,0.956,0.807,0.674,0.195,0.669,0.886,0.461,0.339,0.441,0.452,0.786,0.315,0.054,0.159,-0.347,0.829,0.324,0.194,0.300,0.001
36 | 34,Leicester-Fox,135,0.647,0.242,0.340,0.316,0.395,0.861,0.650,0.543,0.218,0.717,0.635,0.226,0.279,0.301,0.340,0.556,0.163,0.021,0.118,-0.309,0.587,0.163,0.146,0.207,-0.012
37 | 35,deepzx987,424,0.825,0.326,0.331,0.286,0.305,0.919,0.711,0.527,0.228,0.648,0.812,0.318,0.277,0.259,0.25,0.694,0.239,0.08,0.118,-0.287,0.742,0.229,0.181,0.182,-0.035
38 | 36,CVC,128,0.858,0.386,0.208,0.369,0.476,0.951,0.768,0.431,0.199,0.491,0.782,0.305,0.140,0.237,0.150,0.774,0.351,0.088,0.202,-0.287,0.762,0.292,0.135,0.233,-0.080
39 | 37,Cordi-Ak,297,0.815,0.299,0.083,0.238,0.304,0.773,0.507,0.209,0.160,0.254,0.733,0.219,0.025,0.183,0.267,0.639,0.190,0.013,0.090,-0.387,0.688,0.167,0.035,0.136,-0.113
40 | 38,MIndS,339,0.845,0.343,0.317,0.296,0.368,0.905,0.672,0.463,0.190,0.587,0.836,0.336,0.275,0.287,0.333,0.657,0.175,0.017,0.080,-0.489,0.828,0.448,0.357,0.409,-0.128
41 | 39,easyG,148,0.865,0.369,0.392,0.310,0.403,0.919,0.730,0.651,0.312,0.692,0.789,0.295,0.229,0.176,0.066,0.690,0.190,0.020,0.081,-0.622,0.726,0.209,0.141,0.126,-0.290
42 | 40,BiSP Lab,406,0.690,0.164,0.205,0.081,-0.179,0.683,0.283,0.278,0.062,-0.228,0.638,0.133,0.181,0.029,-0.087,0.602,0.147,0.031,0.027,-0.740,0.585,0.109,0.097,0.031,-0.476
43 | 41,Technion_AIMLAB,202,0.662,0.175,0.117,0.000,-0.406,0.774,0.448,0.233,0.000,-0.455,0.646,0.160,0.084,0.000,-0.390,0.650,0.249,0.004,0.000,-0.848,0.645,0.180,0.050,0.000,-0.658
44 |
--------------------------------------------------------------------------------
/experiments/SIMCLR_signal_finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | from torch.utils.data import DataLoader
8 | import torch.nn as nn
9 |
10 | import sys
11 | current_path = os.getcwd()
12 | sys.path.append(current_path)
13 |
14 | from models.signal_model import signal_model_simclr
15 | from utils.base_dataloader import ECG_dataset_base
16 | from utils.eval_tools import load_weights
17 | from utils.eval_tools import compute_accuracy, compute_f_measure_mod
18 | from utils.eval_tools import compute_auc, load_weights, compute_challenge_metric
19 | from utils.tools import open_all_layers, open_specified_layers
20 |
# Device string used for the model and every batch: the first CUDA GPU when
# one is available, otherwise the CPU.
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'
22 |
23 |
def run():
    """Fine-tune the SimCLR pre-trained signal model on the labelled ECG data.

    Loads ``./checkpoints/SIMCLR_signal.pth`` into ``signal_model_simclr``,
    trains it with ``BCEWithLogitsLoss`` and, after every epoch, prints the
    training metrics and the validation metrics. Validation predictions of
    windows that share the same ``sample['idx']`` are averaged before scoring.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')

    normal_class = '426783006'
    # Pairs of diagnosis codes treated as equivalent classes; the original
    # author annotated them as
    # [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']].
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, weights = load_weights(weights_file, equivalent_classes)

    no_fold = 8
    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size // 2  # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = 256
    learning_rate = 5e-3
    no_epoches = 80
    warmup_epoches = 10
    # Probability above which a class counts as predicted.
    threshold = 0.1

    train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size,
                                     stride=train_stride, chunk_length=train_chunk_length, transforms=transforms,
                                     stft_inc=False, meta_inc=False, t_or_v='train',
                                     equivalent_classes=equivalent_classes, sample_items_per_record=5,
                                     preload=False, random_crop=True, val_fold=no_fold)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size)

    val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size,
                                   stride=val_stride, chunk_length=val_chunk_length, transforms=transforms,
                                   stft_inc=False, meta_inc=False, t_or_v='val',
                                   equivalent_classes=equivalent_classes, sample_items_per_record=1,
                                   preload=True, random_crop=False, val_fold=no_fold)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4, batch_size=batch_size)

    no_classes = train_dataset.get_num_classes()
    model = signal_model_simclr(no_classes)
    state_dict = torch.load('./checkpoints/SIMCLR_signal.pth', map_location=ctx)
    model.load_state_dict(state_dict, strict=True)
    model.to(ctx)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    for epoch in range(1, no_epoches + 1):
        print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
        print('Current learning rate: ', optimizer.param_groups[0]['lr'])

        if epoch <= warmup_epoches:
            # NOTE(review): open_specified_layers lives in utils.tools and is not
            # visible here. If it follows the usual "train only the listed
            # layers" contract, this call unfreezes the backbone instead of
            # freezing it, contradicting the message printed below -- confirm.
            open_specified_layers(model, ['backbone', 'features'])
            print('Freeze the backbone')
        else:
            open_all_layers(model)

        model.train()
        train_loss = 0
        train_pred = []
        train_gt = []

        for batch_idx, sample in enumerate(tqdm(train_dataloader)):
            # Each record contributes several crops; fold them into the batch axis.
            signal = sample['sig'].to(ctx).float()
            signal = signal.view(-1, no_channels, signal_size)
            label = sample['lbl'].to(ctx).float()
            label = label.view(-1, no_classes)

            _, pred = model(signal)
            result = torch.sigmoid(pred)

            loss = criterion(pred, label)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_pred.append(result.detach().cpu().numpy())
            train_gt.append(label.detach().cpu().numpy())

        # Step the schedule after this epoch's optimizer updates so that the
        # learning rate printed above is the one the epoch actually used.
        scheduler_steplr.step()

        train_pred = np.concatenate(train_pred, axis=0)
        train_gt = np.concatenate(train_gt, axis=0)

        print(f'Train Loss: {train_loss / (batch_idx + 1)}')
        # The builtin bool replaces the np.bool alias, which NumPy removed.
        train_gt_bool = train_gt.astype(bool)
        train_pred_bool = (train_pred > threshold).astype(bool)
        print(f'Accuracy: {compute_accuracy(train_gt_bool, train_pred_bool)}')
        print(f'F1 macro score: {compute_f_measure_mod(train_gt_bool, train_pred_bool)}')

        # Accuracy, F1 macro score, AUROC, AUPRC, Challenge metric
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []

            for batch_idx, sample in enumerate(val_dataloader):
                signal = sample['sig'].to(ctx).float()
                label = sample['lbl'].to(ctx).float()
                name = sample['idx']

                _, pred = model(signal)
                result = torch.sigmoid(pred)

                loss = criterion(pred, label)
                val_loss += loss.item()

                val_pred.append(result.detach().cpu().numpy())
                val_gt.append(label.detach().cpu().numpy())
                val_name.append(name)

        val_pred = np.concatenate(val_pred, axis=0)
        val_gt = np.concatenate(val_gt, axis=0)
        val_name = np.concatenate(val_name, axis=0)

        # Column 0 holds the identifier, followed by no_classes label columns
        # and no_classes prediction columns. Averaging per identifier merges
        # all windows that carry the same ``idx``.
        df_pred = pd.DataFrame(data=val_pred)
        df_gt = pd.DataFrame(data=val_gt)
        df_name = pd.DataFrame(data=val_name)
        df_concat = pd.concat([df_name, df_gt, df_pred], axis=1, ignore_index=True)
        df_concat_group = df_concat.groupby([0]).mean()
        # Slice by the real class count instead of a hard-coded 24/48.
        val_gt_after = df_concat_group.iloc[:, :no_classes].to_numpy()
        val_pred_after = df_concat_group.iloc[:, no_classes:2 * no_classes].to_numpy()

        print('######## VALIDATION ########')
        print(f'-----> Val Loss: {val_loss / (batch_idx + 1)}')
        # AUROC and AUPRC are computed on the probabilities, before thresholding.
        auroc, auprc = compute_auc(val_gt_after, val_pred_after.astype(np.float64))
        val_gt_bool = val_gt_after.astype(bool)
        val_pred_bool = (val_pred_after > threshold).astype(bool)
        print(f'-----> Accuracy: {compute_accuracy(val_gt_bool, val_pred_bool)}')
        print(f'-----> F1 macro score: {compute_f_measure_mod(val_gt_bool, val_pred_bool)}')
        print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')
        challenge_metric = compute_challenge_metric(weights, val_gt_bool, val_pred_bool, classes, normal_class)
        print(f'-----> Challenge metric: {challenge_metric}')
165 |
# Start fine-tuning only when this file is executed as a script.
if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/experiments/BYOL_signal_finetune.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import os
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import torch
7 | from torch.optim.lr_scheduler import CosineAnnealingLR
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import torch.nn as nn
11 |
12 | import sys
13 | current_path = os.getcwd()
14 | sys.path.append(current_path)
15 |
16 | from models.signal_model import signal_model_byol
17 | from utils.base_dataloader import ECG_dataset_base
18 | from utils.eval_tools import load_weights
19 | from utils.eval_tools import compute_accuracy, compute_f_measure_mod
20 | from utils.eval_tools import compute_auc, load_weights, compute_challenge_metric
21 | from utils.tools import open_all_layers, open_specified_layers
22 |
# Device string used for the model and every batch: the first CUDA GPU when
# one is available, otherwise the CPU.
ctx = "cuda:0" if torch.cuda.is_available() else 'cpu'
24 |
def run():
    """Fine-tune the BYOL pre-trained signal model on the labelled ECG data.

    Loads ``./checkpoints/BYOL_signal.pth`` into ``signal_model_byol``, trains
    it with ``BCEWithLogitsLoss`` and, after every epoch, prints the training
    metrics and the validation metrics. Validation predictions of windows
    that share the same ``sample['idx']`` are averaged before scoring.
    """
    root_folder = './data_folder'
    data_folder = os.path.join(root_folder, 'data_summary_without_preprocessing')

    normal_class = '426783006'
    # Pairs of diagnosis codes treated as equivalent classes; the original
    # author annotated them as
    # [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']].
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, weights = load_weights(weights_file, equivalent_classes)

    no_fold = 8
    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    val_stride = signal_size // 2  # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = 256
    learning_rate = 5e-3
    no_epoches = 80
    warmup_epoches = 5
    # Probability above which a class counts as predicted.
    threshold = 0.1

    train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size,
                                     stride=train_stride, chunk_length=train_chunk_length, transforms=transforms,
                                     stft_inc=False, meta_inc=False, t_or_v='train',
                                     equivalent_classes=equivalent_classes, sample_items_per_record=5,
                                     preload=False, random_crop=True, val_fold=no_fold)
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4, batch_size=batch_size)

    val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size,
                                   stride=val_stride, chunk_length=val_chunk_length, transforms=transforms,
                                   stft_inc=False, meta_inc=False, t_or_v='val',
                                   equivalent_classes=equivalent_classes, sample_items_per_record=1,
                                   preload=True, random_crop=False, val_fold=no_fold)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4, batch_size=batch_size)

    no_classes = 24
    model = signal_model_byol(no_classes)
    state_dict = torch.load('./checkpoints/BYOL_signal.pth', map_location=ctx)
    model.load_state_dict(state_dict, strict=True)
    model.to(ctx)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()
    scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-4, last_epoch=-1)

    for epoch in range(1, no_epoches + 1):
        print('===================Epoch [{}/{}]'.format(epoch, no_epoches))
        print('Current learning rate: ', optimizer.param_groups[0]['lr'])

        if epoch <= warmup_epoches:
            # NOTE(review): open_specified_layers lives in utils.tools and is not
            # visible here. If it follows the usual "train only the listed
            # layers" contract, this call unfreezes the backbone instead of
            # freezing it, contradicting the message printed below -- confirm.
            open_specified_layers(model, ['backbone', 'features'])
            print('Freeze the backbone')
        else:
            open_all_layers(model)

        model.train()
        train_loss = 0
        train_pred = []
        train_gt = []

        for batch_idx, sample in enumerate(tqdm(train_dataloader)):
            # Each record contributes several crops; fold them into the batch axis.
            signal = sample['sig'].to(ctx).float()
            signal = signal.view(-1, no_channels, signal_size)
            label = sample['lbl'].to(ctx).float()
            label = label.view(-1, no_classes)

            # The model returns four values; the last one holds the class logits.
            _, _, _, pred = model(signal)
            result = torch.sigmoid(pred)

            loss = criterion(pred, label)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_pred.append(result.detach().cpu().numpy())
            train_gt.append(label.detach().cpu().numpy())

        # Step the schedule after this epoch's optimizer updates so that the
        # learning rate printed above is the one the epoch actually used.
        scheduler_steplr.step()

        train_pred = np.concatenate(train_pred, axis=0)
        train_gt = np.concatenate(train_gt, axis=0)

        print(f'Train Loss: {train_loss / (batch_idx + 1)}')
        # The builtin bool replaces the np.bool alias, which NumPy removed.
        train_gt_bool = train_gt.astype(bool)
        train_pred_bool = (train_pred > threshold).astype(bool)
        print(f'Accuracy: {compute_accuracy(train_gt_bool, train_pred_bool)}')
        print(f'F1 macro score: {compute_f_measure_mod(train_gt_bool, train_pred_bool)}')

        # Accuracy, F1 macro score, AUROC, AUPRC, Challenge metric
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []

            for batch_idx, sample in enumerate(val_dataloader):
                signal = sample['sig'].to(ctx).float()
                label = sample['lbl'].to(ctx).float()
                name = sample['idx']

                _, _, _, pred = model(signal)
                result = torch.sigmoid(pred)

                loss = criterion(pred, label)
                val_loss += loss.item()

                val_pred.append(result.detach().cpu().numpy())
                val_gt.append(label.detach().cpu().numpy())
                val_name.append(name)

        val_pred = np.concatenate(val_pred, axis=0)
        val_gt = np.concatenate(val_gt, axis=0)
        val_name = np.concatenate(val_name, axis=0)

        # Column 0 holds the identifier, followed by no_classes label columns
        # and no_classes prediction columns. Averaging per identifier merges
        # all windows that carry the same ``idx``.
        df_pred = pd.DataFrame(data=val_pred)
        df_gt = pd.DataFrame(data=val_gt)
        df_name = pd.DataFrame(data=val_name)
        df_concat = pd.concat([df_name, df_gt, df_pred], axis=1, ignore_index=True)
        df_concat_group = df_concat.groupby([0]).mean()
        # Slice by the class count instead of a hard-coded 24/48.
        val_gt_after = df_concat_group.iloc[:, :no_classes].to_numpy()
        val_pred_after = df_concat_group.iloc[:, no_classes:2 * no_classes].to_numpy()

        print('######## VALIDATION ########')
        print(f'-----> Val Loss: {val_loss / (batch_idx + 1)}')
        # AUROC and AUPRC are computed on the probabilities, before thresholding.
        auroc, auprc = compute_auc(val_gt_after, val_pred_after.astype(np.float64))
        val_gt_bool = val_gt_after.astype(bool)
        val_pred_bool = (val_pred_after > threshold).astype(bool)
        print(f'-----> Accuracy: {compute_accuracy(val_gt_bool, val_pred_bool)}')
        print(f'-----> F1 macro score: {compute_f_measure_mod(val_gt_bool, val_pred_bool)}')
        print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')
        challenge_metric = compute_challenge_metric(weights, val_gt_bool, val_pred_bool, classes, normal_class)
        print(f'-----> Challenge metric: {challenge_metric}')
167 |
168 |
169 | if __name__ == "__main__":
170 | run()
--------------------------------------------------------------------------------
/data_preparation/data_extraction_with_preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import wfdb
7 | from scipy.signal import resample
8 | import pywt
9 | from stratify import stratify
10 |
# ---------------------------------------------------------------------------
# Configuration: output locations, raw-data location and target signal format.
# ---------------------------------------------------------------------------
save_folder = "./data_folder/extracted_data_with_preprocessing"
save_summary = "./data_folder/data_summary_with_preprocessing"
raw_data_cinc = "./data_folder/datasets"
dataset_names = ["ICBEB2018","ICBEB2018_2","INCART","PTB","PTB-XL","Georgia"]
mapping_scored_path = "./data_folder/evaluation-2020-master/dx_mapping_scored.csv" # 27 main labels
target_fs = 100
strat_folds = 10
channels = 12

# SNOMED CT code -> abbreviation lookup, plus the array of codes that are kept.
mapping_scored_df = pd.read_csv(mapping_scored_path)
dx_mapping_snomed_abbrev = dict(
    zip(mapping_scored_df["SNOMED CT Code"], mapping_scored_df["Abbreviation"])
)
list_label_available = np.array(mapping_scored_df["SNOMED CT Code"])


def _header_files(dataset_name):
    """Return the ``.hea`` header paths found under one dataset folder."""
    return glob.glob(os.path.join(raw_data_cinc, dataset_name, '**/*.hea'))


# Collect the header files of every source dataset and report the counts.
CPSC_files = _header_files(dataset_names[0])
print('No files in CPSC:', len(CPSC_files))
CPSC_extra_files = _header_files(dataset_names[1])
print('No files in CPSC-Extra:', len(CPSC_extra_files))
SPeter_files = _header_files(dataset_names[2])
print('No files in StPetersburg:', len(SPeter_files))
PTB_files = _header_files(dataset_names[3])
print('No files in PTB:', len(PTB_files))
PTBXL_files = _header_files(dataset_names[4])
print('No files in PTB-XL:', len(PTBXL_files))
Georgia_files = _header_files(dataset_names[5])
print('No files in Georgia:', len(Georgia_files))

all_files = [
    *CPSC_files,
    *CPSC_extra_files,
    *SPeter_files,
    *PTB_files,
    *PTBXL_files,
    *Georgia_files,
]
print('Total no files:',len(all_files))

# Example of what wfdb.rdsamp returns for one record:
#   signal shape: (7500, 12)
#   header: {'fs': 500, 'sig_len': 7500, 'n_sig': 12, 'base_date': None, 'base_time': datetime.time(0, 0, 12),
#            'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'],
#            'sig_name': ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'],
#            'comments': ['Age: 74', 'Sex: Male', 'Dx: 59118001', 'Rx: Unknown', 'Hx: Unknown', 'Sx: Unknown']}

# Per-dataset notes (sampling rate, wavelet level, moving-average stride):
#   CPSC          fs  500  level 3  stride  50
#   CPSC-Extra    fs  500  level 3  stride  50
#   StPetersburg  fs  257  level 2  stride  30
#   PTB           fs 1000  level 2  stride 100
#   PTB-XL        fs  500  level 3  stride  50
#   Georgia       fs  500  level 3  stride  50
51 |
def madev(d, axis=None):
    """Return the mean absolute deviation of a signal.

    Args:
        d: Array-like input values.
        axis: Axis (or ``None`` for the flattened input) along which the
            deviation is computed.

    Returns:
        The mean of ``|d - mean(d)|`` along ``axis``; a scalar when ``axis``
        is ``None``.
    """
    d = np.asarray(d)
    # keepdims=True keeps the mean broadcastable against ``d`` for any axis;
    # without it only axis=None and axis=0 subtract correctly.
    center = np.mean(d, axis=axis, keepdims=True)
    return np.mean(np.absolute(d - center), axis=axis)
55 |
def moving_average(x, w):
    """Return the running mean of ``x`` over a window of ``w`` samples.

    Only fully overlapping windows are evaluated ('valid' convolution), so
    the result has ``len(x) - w + 1`` entries when ``len(x) >= w``.
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, 'valid')
    return window_sums / w
58 |
59 |
# Per-dataset preprocessing parameters:
#   (wavelet detail level used for the noise estimate,
#    moving-average window in samples used for baseline removal).
_PREPROCESSING_PARAMS = {
    "ICBEB2018": (3, 50),
    "ICBEB2018_2": (3, 50),
    "INCART": (3, 30),
    "PTB": (3, 100),
    "PTB-XL": (3, 50),
    "Georgia": (3, 50),
}


def _wavelet_denoise(signal, level, wavelet):
    """Hard-threshold the wavelet detail coefficients of every channel.

    Args:
        signal: Array indexed as ``signal[:, channel]``.
        level: The detail band ``coeffs[-level]`` is used to estimate the noise.
        wavelet: ``pywt.Wavelet`` used for decomposition and reconstruction.

    Returns:
        float32 array of shape ``(len(signal), channels)``.
    """
    n_samples = len(signal)
    maxlev = pywt.dwt_max_level(n_samples, wavelet.dec_len)
    denoised = np.zeros((n_samples, channels), dtype=np.float32)
    for cha in range(channels):
        coeffs = pywt.wavedec(signal[:, cha], wavelet, mode='periodic', level=maxlev)

        # Universal threshold derived from the deviation of one detail band.
        sigma = (1 / 0.6745) * madev(coeffs[-level])
        uthresh = sigma * np.sqrt(2 * np.log(n_samples))
        coeffs[1:] = [pywt.threshold(c, value=uthresh, mode='hard') for c in coeffs[1:]]

        # NOTE(review): decomposition uses mode='periodic' but reconstruction
        # uses the default mode, as in the original script - confirm intended.
        datarec = pywt.waverec(coeffs, wavelet)
        if len(datarec) < n_samples:
            # Pad only at the end; a scalar pad width would pad both ends and
            # make the result longer than the destination column.
            datarec = np.pad(datarec, (0, n_samples - len(datarec)), 'edge')
        denoised[:, cha] = datarec[:n_samples]
    return denoised


def _remove_baseline(signal, window):
    """Subtract an edge-padded moving average from every channel."""
    n_samples = len(signal)
    corrected = np.zeros((n_samples, channels), dtype=np.float32)
    for cha in range(channels):
        avg_output = moving_average(signal[:, cha], window)
        avg_pad = np.pad(avg_output, (0, n_samples - len(avg_output)), 'edge')
        corrected[:, cha] = signal[:, cha] - avg_pad
    return corrected


def _parse_header_comments(comments):
    """Extract (labels, age, sex) from the WFDB header comment lines.

    Only diagnosis codes listed in ``list_label_available`` are kept.
    """
    labels = []
    age = np.nan
    sex = "nan"
    for comment in comments:
        arrs = comment.strip().split(' ')
        if comment.startswith('Dx:'):
            for code in arrs[1].split(','):
                if int(code) in list_label_available:
                    labels.append(code)
        elif comment.startswith('Age:'):
            try:
                age = int(arrs[1])
            except (ValueError, IndexError):
                # Missing or non-numeric ages (e.g. "NaN") stay undefined.
                age = np.nan
        elif comment.startswith('Sex:'):
            sex = arrs[1].strip().lower()
            if sex == "m":
                sex = "male"
            elif sex == "f":
                sex = "female"
    return labels, age, sex


# np.save does not create missing directories.
os.makedirs(save_folder, exist_ok=True)

# The wavelet object does not depend on the record; build it once.
_db4_wavelet = pywt.Wavelet('db4')

skip_files = 0
metadata = []
for idx, hea_file in enumerate(tqdm(all_files)):
    # Expected layout: <raw_data_cinc>/<dataset>/<subfolder>/<record>.hea
    file_name = os.path.basename(hea_file).split(".hea")[0]
    data_folder = os.path.basename(os.path.dirname(os.path.dirname(hea_file)))
    sigbufs, header = wfdb.rdsamp(str(hea_file)[:-4])

    if np.any(np.isnan(sigbufs)):
        print("Warning:",str(hea_file),"is corrupt. Skipping.")
        continue

    labels, age, sex = _parse_header_comments(header["comments"])

    # Records without any scored diagnosis are not exported.
    if len(labels) == 0:
        skip_files += 1
        continue

    if data_folder not in _PREPROCESSING_PARAMS:
        # Previously an unknown folder silently reused the parameters of the
        # preceding record (or raised NameError on the first one).
        print("Warning:", str(hea_file), "belongs to unknown dataset", data_folder, "- skipping.")
        continue
    level, stride = _PREPROCESSING_PARAMS[data_folder]

    # DENOISE
    denoised_data = _wavelet_denoise(sigbufs, level, _db4_wavelet)

    # BASELINE WANDER REMOVAL
    baseline_removal_data = _remove_baseline(denoised_data, stride)

    # RESAMPLE to target_fs. Every channel is resampled from its own data;
    # the original code resampled channel 0 into all output channels.
    ori_fs = header['fs']
    factor = target_fs/ori_fs
    timesteps_new = int(len(sigbufs)*factor)
    data = resample(baseline_removal_data, timesteps_new, axis=0).astype(np.float32)

    np.save(os.path.join(save_folder,file_name+".npy"),data)

    metadata.append({"data":file_name+".npy","label":labels,"sex":sex,"age":age,"dataset":data_folder})
157 |
# Build the summary table and map the SNOMED label strings to integer ids.
df = pd.DataFrame(metadata)
lbl_itos = np.unique([item for sublist in list(df.label) for item in sublist])
lbl_stoi = {s: i for i, s in enumerate(lbl_itos)}
df["label"] = df["label"].apply(lambda x: [lbl_stoi[y] for y in x])

# Assign stratified cross-validation folds independently per source dataset.
df["strat_fold"] = -1
for ds in np.unique(df["dataset"]):
    print("Creating CV folds:",ds)
    dfx = df[df.dataset == ds]
    idxs = np.array(dfx.index.values)
    lbl_itosx = np.unique([item for sublist in list(dfx.label) for item in sublist])
    stratified_ids = stratify(list(dfx["label"]), lbl_itosx, [1./strat_folds]*strat_folds)

    for i, split in enumerate(stratified_ids):
        df.loc[idxs[split], "strat_fold"] = i


def _record_stats(file_name):
    """Load one saved record and return (per-channel mean, per-channel std, length)."""
    # The path always depends on save_folder; the original std/length columns
    # tested the unrelated loop variable `data_folder` instead.
    path = file_name if save_folder is None else os.path.join(save_folder, file_name)
    record = np.load(path, allow_pickle=True)
    return np.mean(record, axis=0), np.std(record, axis=0), len(record)


# Each file is read once and all three statistics are derived from that read.
record_stats = df["data"].apply(_record_stats)
print("Add Mean Column")
df["data_mean"] = record_stats.apply(lambda s: s[0])
print("Add Std Column")
df["data_std"] = record_stats.apply(lambda s: s[1])
print("Add Length Column")
df["data_length"] = record_stats.apply(lambda s: s[2])

#save means and stds
df_mean = df["data_mean"].mean()
df_std = df["data_std"].mean()

# save dataset
os.makedirs(save_summary, exist_ok=True)  # the save calls below do not create it
df.to_pickle(os.path.join(save_summary,'df.pkl'),protocol=4)
np.save(os.path.join(save_summary,"lbl_itos.npy"),lbl_itos)
np.save(os.path.join(save_summary,"mean.npy"),df_mean)
np.save(os.path.join(save_summary,"std.npy"),df_std)
190 |
191 | # file1 = 'df.pkl'
192 | # file2 = 'lbl_itos.npy'
193 | # file3 = 'memmap.npy'
194 | # file4 = 'memmap_meta.npz'
195 | # file5 = 'df_memmap.pkl'
196 | # file6 = 'mean.npy'
197 | # file7 = 'std.npy'
198 |
199 |
--------------------------------------------------------------------------------
/experiments/run_signal.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import torch
7 | from torch.optim.lr_scheduler import CosineAnnealingLR
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import torch.nn as nn
11 | import argparse
12 |
13 | import sys
14 | current_path = os.getcwd()
15 | sys.path.append(current_path)
16 |
17 | from models.signal_model import signal_model
18 | from utils.base_dataloader import ECG_dataset_base
19 | from utils.eval_tools import compute_auc, load_weights
20 | from utils.tools import weights_init_xavier
21 |
def parse_args():
    """Parse the command-line options of the signal training script.

    Returns:
        argparse.Namespace with ``batch_size``, ``lr_rate``, ``num_epoches``,
        ``fold`` (list of ints; the default ``[11]`` selects all ten folds),
        ``gpu``, ``finetune`` (path to pretrained weights or ``None``),
        ``save_folder`` and ``model_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size",type=int, default=128)
    parser.add_argument("--lr_rate",type=float, default=5e-3)
    parser.add_argument("--num_epoches",type=int, default=100)
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type; without this
    # option the first checkpoint save raised AttributeError.
    parser.add_argument("--model_type", type=str, default="signal_model")
    return parser.parse_args()
32 |
def _average_per_record(names, targets, scores, no_classes):
    """Average chunk-level targets and scores over chunks of the same record.

    Args:
        names: Record identifier of every chunk (assumed 1-D - TODO confirm).
        targets: Ground-truth labels, one row per chunk with ``no_classes`` columns.
        scores: Predicted probabilities, same layout as ``targets``.
        no_classes: Number of label columns.

    Returns:
        Tuple ``(targets_per_record, scores_per_record)`` of numpy arrays with
        one row per distinct record identifier.
    """
    frame = pd.concat(
        [pd.DataFrame(data=names), pd.DataFrame(data=targets), pd.DataFrame(data=scores)],
        axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping, the identifier is the index; the label columns come
    # first, followed by the prediction columns.
    gt_after = grouped.iloc[:, :no_classes].to_numpy()
    pred_after = grouped.iloc[:, no_classes:2 * no_classes].to_numpy()
    return gt_after, pred_after


def run():
    """Train and validate ``signal_model`` on the selected cross-validation folds.

    For every fold the model is trained with Adam and BCE-with-logits loss
    under a cosine learning-rate schedule. After each epoch the validation
    predictions are averaged per record, AUROC/AUPRC are computed, and the
    model is checkpointed whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:"+args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')

    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    # train_stride = signal_size//2
    # train_chunk_length = signal_size
    val_stride = signal_size//2 # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    # Checkpoint naming. getattr keeps this working even when the parser does
    # not define --model_type (previously an AttributeError on first save).
    model_type = getattr(args, "model_type", "signal_model")
    finetune_tag = '_finetune' if args.finetune is not None else ''
    checkpoint_dir = os.path.join('./checkpoints', args.save_folder or '')
    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create it

    # The sentinel fold 11 means "run all ten folds".
    list_folds = args.fold
    if 11 in list_folds:
        fold_range = np.arange(10)
    else:
        fold_range = list_folds

    # run 10 fold cross validation
    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder,classes=classes, signal_size=signal_size, stride=train_stride,
            chunk_length=train_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='train',
            equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False,random_crop=True,val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size)

        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes,signal_size=signal_size, stride=val_stride,
            chunk_length=val_chunk_length, transforms=transforms, stft_inc=False, meta_inc=False, t_or_v='val',
            equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True,random_crop=False,val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4,batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()
        model = signal_model(no_classes)

        # use the pretrain models from self-supervised learning
        if args.finetune is not None:
            state_dict = torch.load(args.finetune,map_location=ctx)
            model.load_state_dict(state_dict,strict=False)
        else:
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-5, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1,no_epoches+1):
            print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
            print('Current learning rate: ',optimizer.param_groups[0]['lr'])

            # ---- training ----
            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # The view calls flatten any extra leading dimensions of the
                # batch into the batch axis.
                signal = sample['sig'].to(ctx).float()
                signal = signal.view(-1,no_channels,signal_size)
                label = sample['lbl'].to(ctx).float()
                label = label.view(-1,no_classes)

                pred = model(signal)
                loss = criterion(pred,label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # ---- validation ----
            model.eval()
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []
            with torch.no_grad():
                for sample in val_dataloader:
                    signal = sample['sig'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()
                    name = sample['idx']

                    pred = model(signal)
                    result = torch.sigmoid(pred)

                    loss = criterion(pred,label)
                    val_loss += loss.item()

                    val_pred.append(result.detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(name)

            val_pred = np.concatenate(val_pred,axis=0)
            val_gt = np.concatenate(val_gt,axis=0)
            val_name = np.concatenate(val_name,axis=0)

            # Overlapping validation chunks are merged into one prediction
            # per record before scoring.
            val_gt_after, val_pred_after = _average_per_record(val_name, val_gt, val_pred, no_classes)

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after,val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_fold{no_fold}_bestROC{finetune_tag}.pth'))
            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_fold{no_fold}_bestPRC{finetune_tag}.pth'))

            # Advance the schedule only after this epoch's optimizer updates;
            # stepping first skips the initial learning rate.
            scheduler_steplr.step()


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/experiments/run_spectrogram.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import torch
7 | from torch.optim.lr_scheduler import CosineAnnealingLR
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import torch.nn as nn
11 | import argparse
12 |
13 | import sys
14 | current_path = os.getcwd()
15 | sys.path.append(current_path)
16 |
17 | from models.spectrogram_model import spectrogram_model
18 | from utils.base_dataloader import ECG_dataset_base
19 | from utils.eval_tools import compute_auc, load_weights
20 | from utils.tools import weights_init_xavier
21 |
def parse_args():
    """Parse the command-line options of the spectrogram training script.

    Returns:
        argparse.Namespace with ``batch_size``, ``lr_rate``, ``num_epoches``,
        ``fold`` (list of ints; the default ``[11]`` selects all ten folds),
        ``gpu``, ``finetune`` (path to pretrained weights or ``None``),
        ``save_folder`` and ``model_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size",type=int, default=128)
    parser.add_argument("--lr_rate",type=float, default=5e-3)
    parser.add_argument("--num_epoches",type=int, default=100)
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type; without this
    # option the first checkpoint save raised AttributeError.
    parser.add_argument("--model_type", type=str, default="spectrogram_model")
    return parser.parse_args()
32 |
def _average_per_record(names, targets, scores, no_classes):
    """Average chunk-level targets and scores over chunks of the same record.

    Args:
        names: Record identifier of every chunk (assumed 1-D - TODO confirm).
        targets: Ground-truth labels, one row per chunk with ``no_classes`` columns.
        scores: Predicted probabilities, same layout as ``targets``.
        no_classes: Number of label columns.

    Returns:
        Tuple ``(targets_per_record, scores_per_record)`` of numpy arrays with
        one row per distinct record identifier.
    """
    frame = pd.concat(
        [pd.DataFrame(data=names), pd.DataFrame(data=targets), pd.DataFrame(data=scores)],
        axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping, the identifier is the index; the label columns come
    # first, followed by the prediction columns.
    gt_after = grouped.iloc[:, :no_classes].to_numpy()
    pred_after = grouped.iloc[:, no_classes:2 * no_classes].to_numpy()
    return gt_after, pred_after


def run():
    """Train and validate ``spectrogram_model`` on the selected cross-validation folds.

    For every fold the model is trained on the ``'stft'`` input with Adam and
    BCE-with-logits loss under a cosine learning-rate schedule. After each
    epoch the validation predictions are averaged per record, AUROC/AUPRC are
    computed, and the model is checkpointed whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:"+args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')

    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    # train_stride = signal_size//2
    # train_chunk_length = signal_size
    val_stride = signal_size//2 # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    # Checkpoint naming. getattr keeps this working even when the parser does
    # not define --model_type (previously an AttributeError on first save).
    model_type = getattr(args, "model_type", "spectrogram_model")
    finetune_tag = '_finetune' if args.finetune is not None else ''
    checkpoint_dir = os.path.join('./checkpoints', args.save_folder or '')
    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create it

    # The sentinel fold 11 means "run all ten folds".
    list_folds = args.fold
    if 11 in list_folds:
        fold_range = np.arange(10)
    else:
        fold_range = list_folds

    # run 10 fold cross validation
    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=train_stride,
            chunk_length=train_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='train',
            equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False,random_crop=True,val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size)

        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes,signal_size=signal_size, stride=val_stride,
            chunk_length=val_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='val',
            equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True,random_crop=False,val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4,batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()
        model = spectrogram_model(no_classes)

        # use the pretrain models from self-supervised learning
        if args.finetune is not None:
            state_dict = torch.load(args.finetune,map_location=ctx)
            model.load_state_dict(state_dict,strict=False)
        else:
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-6, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1,no_epoches+1):
            print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
            print('Current learning rate: ',optimizer.param_groups[0]['lr'])

            # ---- training ----
            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # The view calls flatten any extra leading dimensions of the
                # batch into the batch axis.
                stft = sample['stft'].to(ctx).float()
                stft = stft.view(-1,no_channels,13,21)
                label = sample['lbl'].to(ctx).float()
                label = label.view(-1,no_classes)

                pred = model(stft)
                loss = criterion(pred,label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # ---- validation ----
            model.eval()
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []
            with torch.no_grad():
                for sample in val_dataloader:
                    stft = sample['stft'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()
                    name = sample['idx']

                    pred = model(stft)
                    result = torch.sigmoid(pred)

                    loss = criterion(pred,label)
                    val_loss += loss.item()

                    val_pred.append(result.detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(name)

            val_pred = np.concatenate(val_pred,axis=0)
            val_gt = np.concatenate(val_gt,axis=0)
            val_name = np.concatenate(val_name,axis=0)

            # Overlapping validation chunks are merged into one prediction
            # per record before scoring.
            val_gt_after, val_pred_after = _average_per_record(val_name, val_gt, val_pred, no_classes)

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after,val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_fold{no_fold}_bestROC{finetune_tag}.pth'))
            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_fold{no_fold}_bestPRC{finetune_tag}.pth'))

            # Advance the schedule only after this epoch's optimizer updates;
            # stepping first skips the initial learning rate.
            scheduler_steplr.step()


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/experiments/run_ensembled.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import torch
7 | from torch.optim.lr_scheduler import CosineAnnealingLR
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import torch.nn as nn
11 | import argparse
12 |
13 | import sys
14 | current_path = os.getcwd()
15 | sys.path.append(current_path)
16 |
17 | from models.ensemble_model import ensemble_model
18 | from utils.base_dataloader import ECG_dataset_base
19 | from utils.eval_tools import compute_auc, load_weights
20 | from utils.tools import weights_init_xavier
21 |
def parse_args():
    """Parse the command-line options of the ensemble training script.

    Returns:
        argparse.Namespace with ``batch_size``, ``lr_rate``, ``num_epoches``,
        ``fold`` (list of ints; the default ``[11]`` selects all ten folds),
        ``gpu``, ``gating`` (flag), ``finetune``, ``save_folder`` and
        ``model_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size",type=int, default=128)
    parser.add_argument("--lr_rate",type=float, default=5e-3)
    parser.add_argument("--num_epoches",type=int, default=100)
    parser.add_argument('--fold', nargs='+', type=int, default=[11])
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--gating",action="store_true")
    parser.add_argument("--finetune", type=str, default=None)
    parser.add_argument("--save_folder", type=str, default=None)
    # run() builds checkpoint file names from args.model_type; without this
    # option the first checkpoint save raised AttributeError.
    parser.add_argument("--model_type", type=str, default="ensemble_model")
    return parser.parse_args()
33 |
def _average_per_record(names, targets, scores, no_classes):
    """Average chunk-level targets and scores over chunks of the same record.

    Args:
        names: Record identifier of every chunk (assumed 1-D - TODO confirm).
        targets: Ground-truth labels, one row per chunk with ``no_classes`` columns.
        scores: Predicted probabilities, same layout as ``targets``.
        no_classes: Number of label columns.

    Returns:
        Tuple ``(targets_per_record, scores_per_record)`` of numpy arrays with
        one row per distinct record identifier.
    """
    frame = pd.concat(
        [pd.DataFrame(data=names), pd.DataFrame(data=targets), pd.DataFrame(data=scores)],
        axis=1, ignore_index=True)
    grouped = frame.groupby([0]).mean()
    # After grouping, the identifier is the index; the label columns come
    # first, followed by the prediction columns.
    gt_after = grouped.iloc[:, :no_classes].to_numpy()
    pred_after = grouped.iloc[:, no_classes:2 * no_classes].to_numpy()
    return gt_after, pred_after


def run():
    """Train and validate ``ensemble_model`` on the selected cross-validation folds.

    The model receives both the raw signal (``'sig'``) and its ``'stft'``
    input. For every fold it is trained with Adam and BCE-with-logits loss
    under a cosine learning-rate schedule. After each epoch the validation
    predictions are averaged per record, AUROC/AUPRC are computed, and the
    model is checkpointed whenever either metric improves.
    """
    args = parse_args()

    ctx = "cuda:"+args.gpu if torch.cuda.is_available() else 'cpu'

    root_folder = './data_folder'
    data_folder = os.path.join(root_folder,'data_summary_without_preprocessing')

    # equivalent_classes = [['CRBBB', 'RBBB'], ['PAC', 'SVPB'], ['PVC', 'VPB']]
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    weights_file = './data_folder/evaluation-2020-master/weights.csv'
    classes, _ = load_weights(weights_file, equivalent_classes)

    no_channels = 12
    signal_size = 250
    train_stride = signal_size
    train_chunk_length = 0
    # train_stride = signal_size//2
    # train_chunk_length = signal_size
    val_stride = signal_size//2 # overlap sample signal
    val_chunk_length = signal_size

    transforms = True
    batch_size = args.batch_size
    learning_rate = args.lr_rate
    no_epoches = args.num_epoches

    # Checkpoint naming. getattr keeps this working even when the parser does
    # not define --model_type (previously an AttributeError on first save).
    # The 'wthoutgating' spelling is kept so existing file names stay valid.
    model_type = getattr(args, "model_type", "ensemble_model")
    gating_tag = 'withgating' if args.gating else 'wthoutgating'
    checkpoint_dir = os.path.join('./checkpoints', args.save_folder or '')
    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create it

    # The sentinel fold 11 means "run all ten folds".
    list_folds = args.fold
    if 11 in list_folds:
        fold_range = np.arange(10)
    else:
        fold_range = list_folds

    # run 10 fold cross validation
    for no_fold in fold_range:
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')
        print(f'Starting fold {no_fold} ...')
        print('### FOLD-FOLD-FOLD-FOLD-FOLD ###')

        train_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes, signal_size=signal_size, stride=train_stride,
            chunk_length=train_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='train',
            equivalent_classes=equivalent_classes, sample_items_per_record=5, preload=False,random_crop=True,val_fold=no_fold)
        train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=4,batch_size=batch_size)
        val_dataset = ECG_dataset_base(summary_folder=data_folder, classes=classes,signal_size=signal_size, stride=val_stride,
            chunk_length=val_chunk_length, transforms=transforms, stft_inc=True, meta_inc=False, t_or_v='val',
            equivalent_classes=equivalent_classes, sample_items_per_record=1, preload=True,random_crop=False,val_fold=no_fold)
        val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=4,batch_size=batch_size)

        no_classes = train_dataset.get_num_classes()

        if args.finetune is not None:
            # NOTE: --finetune only acts as a switch here; the pretrained
            # weight files are the fixed paths below, not the option's value.
            checkpoint_folder = "./checkpoints"
            w_time = os.path.join(checkpoint_folder,"DINO_signal_student.pth")
            w_spec = os.path.join(checkpoint_folder,"DINO_spectrogram_student.pth")
            model = ensemble_model(no_classes, args.gating,w_time,w_spec,ctx)
        else:
            model = ensemble_model(no_classes, args.gating)
            model.apply(weights_init_xavier)
        model.to(ctx)

        optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
        criterion = nn.BCEWithLogitsLoss()
        scheduler_steplr = CosineAnnealingLR(optimizer, no_epoches, eta_min=1e-5, last_epoch=-1)

        best_auroc = 0
        best_auprc = 0
        for epoch in range(1,no_epoches+1):
            print('===================Epoch [{}/{}]'.format(epoch,no_epoches))
            print('Current learning rate: ',optimizer.param_groups[0]['lr'])

            # ---- training ----
            model.train()
            train_loss = 0
            for sample in tqdm(train_dataloader):
                # The view calls flatten any extra leading dimensions of the
                # batch into the batch axis.
                signal = sample['sig'].to(ctx).float()
                signal = signal.view(-1,no_channels,signal_size)
                stft = sample['stft'].to(ctx).float()
                stft = stft.view(-1,no_channels,13,21)
                label = sample['lbl'].to(ctx).float()
                label = label.view(-1,no_classes)

                pred = model(signal,stft)
                loss = criterion(pred,label)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Train Loss: {train_loss / max(len(train_dataloader), 1)}')

            # ---- validation ----
            model.eval()
            val_loss = 0
            val_pred = []
            val_gt = []
            val_name = []
            with torch.no_grad():
                for sample in val_dataloader:
                    signal = sample['sig'].to(ctx).float()
                    stft = sample['stft'].to(ctx).float()
                    label = sample['lbl'].to(ctx).float()
                    name = sample['idx']

                    pred = model(signal,stft)
                    result = torch.sigmoid(pred)

                    loss = criterion(pred,label)
                    val_loss += loss.item()

                    val_pred.append(result.detach().cpu().numpy())
                    val_gt.append(label.detach().cpu().numpy())
                    val_name.append(name)

            val_pred = np.concatenate(val_pred,axis=0)
            val_gt = np.concatenate(val_gt,axis=0)
            val_name = np.concatenate(val_name,axis=0)

            # Overlapping validation chunks are merged into one prediction
            # per record before scoring.
            val_gt_after, val_pred_after = _average_per_record(val_name, val_gt, val_pred, no_classes)

            print('######## VALIDATION ########')
            print(f'-----> Val Loss: {val_loss / max(len(val_dataloader), 1)}')
            auroc, auprc = compute_auc(val_gt_after,val_pred_after.astype(np.float64))
            print(f'-----> AU_ROC: {auroc}, AUPRC: {auprc}')

            if auroc > best_auroc:
                best_auroc = auroc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_{gating_tag}_fold{no_fold}_bestROC.pth'))

            if auprc > best_auprc:
                best_auprc = auprc
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir, f'{model_type}_{gating_tag}_fold{no_fold}_bestPRC.pth'))

            # Advance the schedule only after this epoch's optimizer updates;
            # stepping first skips the initial learning rate.
            scheduler_steplr.step()


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/models/xresnet1d.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_xresnet1d.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['delegates', 'store_attr', 'init_default', 'BatchNorm', 'NormType', 'ConvLayer', 'AdaptiveAvgPool',
4 | 'MaxPool', 'AvgPool', 'ResBlock', 'init_cnn', 'XResNet1d', 'xresnet1d18', 'xresnet1d34', 'xresnet1d50',
5 | 'xresnet1d101', 'xresnet1d152', 'xresnet1d18_deep', 'xresnet1d34_deep', 'xresnet1d50_deep',
6 | 'xresnet1d18_deeper', 'xresnet1d34_deeper', 'xresnet1d50_deeper']
7 |
8 | # Cell
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .basic_conv1d import create_head1d, Flatten
14 |
15 | from enum import Enum
16 | import re
17 |
18 | # Cell
19 | import inspect
20 |
def delegates(to=None, keep=False):
    """Decorator: replace `**kwargs` in a signature with the defaulted params of `to`.

    With `to=None` the decorated object is treated as a class and its
    `__init__` signature is extended from the base class `__init__`.
    With `keep=True` the `kwargs` entry is re-appended after the merged params.
    """
    def _decorate(target):
        if to is None:
            source_fn, patched_fn = target.__base__.__init__, target.__init__
        else:
            source_fn, patched_fn = to, target
        original_sig = inspect.signature(patched_fn)
        merged = dict(original_sig.parameters)
        var_kw = merged.pop('kwargs')
        # Only parameters that carry a default and are not already present are borrowed.
        for name, param in inspect.signature(source_fn).parameters.items():
            if param.default != inspect.Parameter.empty and name not in merged:
                merged[name] = param
        if keep:
            merged['kwargs'] = var_kw
        patched_fn.__signature__ = original_sig.replace(parameters=merged.values())
        return target
    return _decorate
36 |
def store_attr(self, nms):
    """Copy the caller's local variables named in comma-separated `nms` onto `self`."""
    # Must be read at function level: f_back is the frame of whoever called store_attr.
    caller_locals = inspect.currentframe().f_back.f_locals
    names = re.split(', *', nms)
    for name in names:
        setattr(self, name, caller_locals[name])
41 |
42 | # Cell
# Normalisation flavours selectable through the `norm_type` argument of the layers below.
# `BatchZero` makes `BatchNorm` initialise the norm weight to 0 instead of 1.
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')
44 |
45 | def _conv_func(ndim=2, transpose=False):
46 | "Return the proper conv `ndim` function, potentially `transposed`."
47 | assert 1 <= ndim <=3
48 | return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')
49 |
def init_default(m, func=nn.init.kaiming_normal_):
    """Apply `func` to `m.weight` (when both exist), zero `m.bias` if present, return `m`."""
    if func and hasattr(m, 'weight'):
        func(m.weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        # fill in place without recording the op in autograd
        with torch.no_grad():
            bias.fill_(0.)
    return m
56 |
57 | def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
58 | "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
59 | assert 1 <= ndim <= 3
60 | bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
61 | if bn.affine:
62 | bn.bias.data.fill_(1e-3)
63 | bn.weight.data.fill_(0. if zero else 1.)
64 | return bn
65 |
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    """Create a BatchNorm layer with `nf` features; `NormType.BatchZero` zero-initialises its weight."""
    zero_init = (norm_type == NormType.BatchZero)
    return _get_norm('BatchNorm', nf, ndim, zero=zero_init, **kwargs)
69 |
70 | # Cell
class ConvLayer(nn.Sequential):
    """Sequence of a convolution (`ni` -> `nf`), an optional activation and an optional norm layer.

    Args:
        ni, nf: input / output channel counts.
        ks, stride, padding: convolution geometry; `padding` defaults to
            `(ks-1)//2` (0 for transposed convolutions).
        bias: convolution bias; defaults to False when a batch/instance norm follows.
        ndim: 1, 2 or 3, selects the `nn.Conv{ndim}d` family.
        norm_type: a `NormType` member choosing batch / instance / weight / spectral norm.
        bn_1st: if True the norm layer is placed before the activation.
        act_cls: activation class instantiated with no arguments, or None.
        transpose: use the transposed convolution.
        init: initialiser applied to the convolution weight.
        xtra: optional extra module appended last.
        **kwargs: forwarded to the convolution constructor.
    """
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=nn.ReLU, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
        # Imported locally: the module never brought these into scope, so the
        # Weight / Spectral branches used to fail with NameError.
        from torch.nn.utils import spectral_norm, weight_norm

        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs), init)
        if norm_type==NormType.Weight: conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act_cls is not None: act_bn.append(act_cls())
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn:
            # No `InstanceNorm` helper exists in this file; build the layer through
            # `_get_norm` directly. affine=True so the zero-init variant has a weight to zero.
            act_bn.append(_get_norm('InstanceNorm', nf, ndim,
                                    zero=norm_type==NormType.InstanceZero, affine=True))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        super().__init__(*layers)
92 |
93 | # Cell
def AdaptiveAvgPool(sz=1, ndim=2):
    """Instantiate `nn.AdaptiveAvgPool{ndim}d` with output size `sz`."""
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AdaptiveAvgPool{ndim}d")
    return pool_cls(sz)
98 |
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    """Instantiate `nn.MaxPool{ndim}d`.

    Args:
        ks: pooling window size.
        stride: window stride (the torch default applies when None).
        padding: implicit padding on both sides.
        ndim: 1, 2 or 3.
        ceil_mode: use ceil instead of floor when computing the output length.
            Previously accepted but silently ignored; it is now forwarded,
            matching `AvgPool`.
    """
    assert 1 <= ndim <= 3
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
103 |
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    """Instantiate `nn.AvgPool{ndim}d` with the given window, stride, padding and `ceil_mode`."""
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AvgPool{ndim}d")
    return pool_cls(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
108 |
109 | # Cell
class ResBlock(nn.Module):
    """Residual block mapping `ni*expansion` to `nf*expansion` channels with `stride`.

    The block sums a convolutional path and an identity path and applies an
    activation to the sum. With `expansion == 1` the convolutional path has two
    `kernel_size` convolutions; otherwise it is a 1 / `kernel_size` / 1 bottleneck.
    The last convolution of the path has no activation and uses the zero-init
    variant of batch/instance norm.
    """
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
                 pool=AvgPool, pool_first=True, **kwargs):
        super().__init__()
        # norm used by the final conv of the path: the "Zero" flavour of batch/instance norm
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        # hidden widths default to the (unexpanded) output width
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        layers = [ConvLayer(ni, nh2, kernel_size, stride=stride, groups=ni if dw else groups, **k0),
                  ConvLayer(nh2, nf, kernel_size, groups=g2, **k1)
        ] if expansion == 1 else [
                  ConvLayer(ni, nh1, 1, **k0),
                  ConvLayer(nh1, nh2, kernel_size, stride=stride, groups=nh1 if dw else groups, **k0),
                  ConvLayer(nh2, nf, 1, groups=g2, **k1)]
        self.convs = nn.Sequential(*layers)
        convpath = [self.convs]
        # NOTE(review): SEModule and SimpleSelfAttention are neither defined nor imported
        # in this file, so a truthy `reduction` or `sa` raises NameError unless they are
        # provided elsewhere.
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        # identity path: 1x1 conv when channel counts differ, pooling when strided
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        # (1,0)[pool_first] -> insert index 0 when pool_first is True, else index 1
        if stride!=1: idpath.insert((1,0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls()

    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
142 |
143 |
144 |
145 | # Cell
def init_cnn(m):
    """Recursively zero every bias and Kaiming-normal-init Conv1d/Conv2d/Linear weights under `m`."""
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)
150 |
151 | # Cell
class XResNet1d(nn.Sequential):
    """1-D XResNet: three-conv stem, max-pool, one stage of `block`s per entry of `layers`, then a head.

    The stage widths are fixed to 64 for the first four stages and 32 for any
    further ones (each scaled by `widen`); the head is built by `create_head1d`.
    """
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000, stem_szs=(32,32,64),kernel_size=5,kernel_size_stem=5,
                 widen=1.0, sa=False, act_cls=nn.ReLU, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
        # NOTE: `p` is accepted but not used anywhere in this constructor.
        # Stored before super().__init__() because _make_layer reads them below.
        store_attr(self, 'block,expansion,act_cls')
        stem_szs = [input_channels, *stem_szs]
        # only the first stem convolution is strided
        stem = [ConvLayer(stem_szs[i], stem_szs[i+1], ks=kernel_size_stem, stride=2 if i==0 else 1, act_cls=act_cls, ndim=1)
                for i in range(3)]

        #block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
        block_szs = [int(o*widen) for o in [64,64,64,64] +[32]*(len(layers)-4)]
        block_szs = [64//expansion] + block_szs
        # first stage keeps resolution, every later stage downsamples by 2;
        # self-attention (if `sa`) is only requested for stage index len(layers)-4
        blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
                                   stride=1 if i==0 else 2, kernel_size=kernel_size, sa=sa and i==len(layers)-4, ndim=1, **kwargs)
                  for i,l in enumerate(layers)]

        head = create_head1d(block_szs[-1]*expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)

        super().__init__(
            *stem, nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            *blocks,
            head,
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
        # Build one stage of `blocks` residual blocks; only the first block changes
        # channels / stride, and only the last one may receive `sa`.
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                         kernel_size=kernel_size, sa=sa and i==(blocks-1), act_cls=self.act_cls, **kwargs)
              for i in range(blocks)])

    def get_layer_groups(self):
        # Returns the module at index 3 and the head. With the layout built in
        # __init__, index 3 is the stem max-pool.
        return (self[3],self[-1])

    def get_output_layer(self):
        # last module of the head
        return self[-1][-1]

    def set_output_layer(self,x):
        # replace the last module of the head with `x`
        self[-1][-1]=x
191 |
192 | # Cell
def _xresnet1d(expansion, layers, **kwargs):
    """Build an `XResNet1d` from `ResBlock`s with the given expansion and per-stage block counts."""
    return XResNet1d(ResBlock, expansion, layers, **kwargs)


# Factory functions: each fixes (expansion, blocks-per-stage) and forwards **kwargs.
def xresnet1d18(**kwargs):
    return _xresnet1d(1, [2, 2, 2, 2], **kwargs)


def xresnet1d34(**kwargs):
    return _xresnet1d(1, [3, 4, 6, 3], **kwargs)


def xresnet1d50(**kwargs):
    return _xresnet1d(4, [3, 4, 6, 3], **kwargs)


def xresnet1d101(**kwargs):
    return _xresnet1d(4, [3, 4, 23, 3], **kwargs)


def xresnet1d152(**kwargs):
    return _xresnet1d(4, [3, 8, 36, 3], **kwargs)


def xresnet1d18_deep(**kwargs):
    return _xresnet1d(1, [2,2,2,2,1,1], **kwargs)


def xresnet1d34_deep(**kwargs):
    return _xresnet1d(1, [3,4,6,3,1,1], **kwargs)


def xresnet1d50_deep(**kwargs):
    return _xresnet1d(4, [3,4,6,3,1,1], **kwargs)


def xresnet1d18_deeper(**kwargs):
    return _xresnet1d(1, [2,2,1,1,1,1,1,1], **kwargs)


def xresnet1d34_deeper(**kwargs):
    return _xresnet1d(1, [3,4,6,3,1,1,1,1], **kwargs)


def xresnet1d50_deeper(**kwargs):
    return _xresnet1d(4, [3,4,6,3,1,1,1,1], **kwargs)
--------------------------------------------------------------------------------
/models/basic_conv1d.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_basic_conv1d.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['cd_adaptiveconcatpool', 'attrib_adaptiveconcatpool', 'AdaptiveConcatPool1d', 'SqueezeExcite1d',
4 | 'weight_init', 'create_head1d', 'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d']
5 |
6 | # Cell
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import math
11 | from typing import Iterable
12 |
class Flatten(nn.Module):
    """Flatten `x` keeping the batch axis, or to a rank-1 tensor when `full` is True."""
    def __init__(self, full:bool=False):
        super().__init__()
        self.full = full

    def forward(self, x):
        if self.full:
            return x.view(-1)
        batch = x.size(0)
        return x.view(batch, -1)
19 |
20 |
def listify(p=None, q=None):
    """Make `p` a list with the same length as `q`.

    Args:
        p: None (-> empty list), a string or non-iterable (-> one-element list),
            an iterable without a length such as a rank-0 tensor (-> one-element
            list), or any sized iterable (used as is).
        q: target length: an int, a sized object, or None to keep `len(p)`.

    Returns:
        A new list; a one-element `p` is repeated to the target length.

    Raises:
        AssertionError: if the final length does not match the target.
    """
    if p is None: p=[]
    elif isinstance(p, str): p = [p]
    elif not isinstance(p, Iterable): p = [p]
    #Rank 0 tensors in PyTorch are Iterable but don't have a length.
    else:
        # Only a missing/unsupported len() means "wrap it"; a bare `except`
        # would also swallow KeyboardInterrupt and unrelated errors.
        try: len(p)
        except TypeError: p = [p]
    n = q if type(q)==int else len(p) if q is None else len(q)
    if len(p)==1: p = p * n
    assert len(p)==n, f'List len mismatch ({len(p)} vs {n})'
    return list(p)
34 |
35 |
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    """Return the layer list [BatchNorm1d?] + [Dropout?] + Linear(`n_in`, `n_out`) + [`actn`?]."""
    layers = []
    if bn:
        layers.append(nn.BatchNorm1d(n_in))
    if p != 0:
        layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None:
        layers.append(actn)
    return layers
43 |
44 | # Cell
45 | def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
46 | lst=[]
47 | if(drop_p>0):
48 | lst.append(nn.Dropout(drop_p))
49 | lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn)))
50 | if(bn):
51 | lst.append(nn.BatchNorm1d(out_planes))
52 | if(act=="relu"):
53 | lst.append(nn.ReLU(True))
54 | if(act=="elu"):
55 | lst.append(nn.ELU(True))
56 | if(act=="prelu"):
57 | lst.append(nn.PReLU(True))
58 | return nn.Sequential(*lst)
59 |
60 | def _fc(in_planes,out_planes, act="relu", bn=True):
61 | lst = [nn.Linear(in_planes, out_planes, bias=not(bn))]
62 | if(bn):
63 | lst.append(nn.BatchNorm1d(out_planes))
64 | if(act=="relu"):
65 | lst.append(nn.ReLU(True))
66 | if(act=="elu"):
67 | lst.append(nn.ELU(True))
68 | if(act=="prelu"):
69 | lst.append(nn.PReLU(True))
70 | return nn.Sequential(*lst)
71 |
class AdaptiveConcatPool1d(nn.Module):
    """Concatenate `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d` outputs along the channel axis."""
    def __init__(self, sz=None):
        """`sz` is the pooled length per branch; None (or 0) means 1."""
        super().__init__()
        out_size = sz or 1
        self.ap = nn.AdaptiveAvgPool1d(out_size)
        self.mp = nn.AdaptiveMaxPool1d(out_size)

    def forward(self, x):
        # max-pooled features come first, then the average-pooled ones
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat([pooled_max, pooled_avg], 1)

    def attrib(self,relevant,irrelevant):
        return attrib_adaptiveconcatpool(self,relevant,irrelevant)
82 |
83 |
84 | # Cell
class SqueezeExcite1d(nn.Module):
    '''Squeeze-and-excite block (as used for example in LSTM FCN): rescales channels by a learned gate.'''
    def __init__(self,channels,reduction=16):
        super().__init__()
        squeezed = channels//reduction
        # leading singleton axis lets matmul broadcast over the batch
        self.w1 = torch.nn.Parameter(torch.randn(squeezed, channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, squeezed).unsqueeze(0))

    def forward(self, x):
        # x: (bs, ch, seq)
        channel_mean = torch.mean(x, dim=2, keepdim=True)      # (bs, ch, 1)
        hidden = F.relu(torch.matmul(self.w1, channel_mean))   # (bs, ch_red, 1)
        gate = torch.sigmoid(torch.matmul(self.w2, hidden))    # (bs, ch, 1)
        return gate*x                                          # (bs, ch, seq)
99 |
100 | # Cell
def weight_init(m):
    '''Initialise one module; apply to a whole model `n` via `n.apply(weight_init)`.

    - Conv1d / Linear: Kaiming-normal weight, zero bias.
    - BatchNorm1d: weight 1, bias 0 (skipped for non-affine layers).
    - SqueezeExcite1d: zero-mean normal weights scaled by the input width.
    '''
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        # affine=False layers have weight/bias set to None
        if m.weight is not None:
            nn.init.constant_(m.weight,1)
        if m.bias is not None:
            nn.init.constant_(m.bias,0)
    if isinstance(m,SqueezeExcite1d):
        # `Tensor.size` is a method; the former `m.w1.size[0]` raised TypeError.
        # w1 is (1, ch_red, ch) and w2 is (1, ch, ch_red); the last axis is the
        # fan-in of each matmul, which is used for the scale here.
        # NOTE(review): the original indices were ambiguous after the unsqueeze; confirm fan-in is intended.
        stdv1=math.sqrt(2./m.w1.size(-1))
        nn.init.normal_(m.w1,0.,stdv1)
        stdv2=math.sqrt(1./m.w2.size(-1))
        nn.init.normal_(m.w2,0.,stdv2)
115 |
116 | # Cell
def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn_final:bool=False, bn:bool=True, act="relu", concat_pooling=True):
    """Model head: pooling, flatten, then bn/dropout/linear stages from `nf` features to `nc` classes.

    `lin_ftrs` lists the hidden widths between input and output. A single
    dropout value `ps` is applied to the last stage and halved for the hidden ones.
    """
    in_features = 2*nf if concat_pooling else nf
    hidden = [] if lin_ftrs is None else lin_ftrs
    widths = [in_features] + hidden + [nc] #was [nf, 512,nc]
    n_hidden = len(widths) - 2

    ps = listify(ps)
    if len(ps)==1:
        ps = [ps[0]/2] * n_hidden + ps

    # one activation instance is shared by every hidden stage; the last stage has none
    hidden_act = nn.ReLU(inplace=True) if act=="relu" else nn.ELU(inplace=True)
    actns = [hidden_act] * n_hidden + [None]

    pool = AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2)
    layers = [pool, Flatten()]
    for ni, no, p, actn in zip(widths[:-1], widths[1:], ps, actns):
        layers.extend(bn_drop_lin(ni, no, bn, p, actn))
    if bn_final:
        layers.append(nn.BatchNorm1d(widths[-1], momentum=0.01))
    return nn.Sequential(*layers)
128 |
129 | # Cell
class basic_conv1d(nn.Sequential):
    '''Stack of 1-D convolution stages followed by a pooling head.

    Each entry of `filters` yields one stage: a `_conv1d` unit, optionally an
    extra 1x1 `_conv1d` (first stage with `split_first_layer`), optionally a
    max-pool (all stages but the last, when `pool > 0`) and optionally a
    `SqueezeExcite1d` (when `squeeze_excite_reduction > 0`).

    With `headless=True` the last stage has neither activation nor batch norm
    and the head is global average pooling + flatten; otherwise the head comes
    from `create_head1d`.
    '''
    def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        layers = []
        if(isinstance(kernel_size,int)):
            kernel_size = [kernel_size]*len(filters)
        for i in range(len(filters)):
            is_first = i==0
            is_last = i==len(filters)-1
            split_here = split_first_layer is True and is_first
            bare_last = headless is True and is_last

            layers_tmp = [_conv1d(input_channels if is_first else filters[i-1], filters[i],
                                  kernel_size=kernel_size[i],
                                  stride=(1 if split_here else stride),
                                  dilation=dilation,
                                  act="none" if (bare_last or split_here) else act,
                                  bn=False if bare_last else bn,
                                  drop_p=(0. if is_first else drop_p))]
            if split_here:
                layers_tmp.append(_conv1d(filters[0],filters[0],kernel_size=1,stride=1,act=act, bn=bn,drop_p=0.))
            # The two conditions below were fused into the invalid `if(pool>0 and i0):`
            # in this copy of the file; restored as separate pooling and squeeze-excite branches.
            if(pool>0 and not is_last):
                layers_tmp.append(nn.MaxPool1d(pool,stride=pool_stride,padding=(pool-1)//2))
            if(squeeze_excite_reduction>0):
                layers_tmp.append(SqueezeExcite1d(filters[i],squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        #head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
        self.headless = headless
        if(headless is True):
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1),Flatten())
        else:
            head=create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        # the stage at index 2 and the head
        return (self[2],self[-1])

    def get_output_layer(self):
        # last module of the head, or None for headless models
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self,x):
        # replace the last module of the head; no-op for headless models
        if self.headless is False:
            self[-1][-1] = x
175 |
176 |
177 | # Cell
def fcn(filters=[128]*5,num_classes=2,input_channels=8,**kwargs):
    """Headless fully-convolutional net whose last stage has `num_classes` channels."""
    filters_in = filters + [num_classes]
    return basic_conv1d(
        filters=filters_in,
        kernel_size=3,
        stride=1,
        pool=2,
        pool_stride=2,
        input_channels=input_channels,
        act="relu",
        bn=True,
        headless=True,
    )


def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Three-stage net with filters 128/256/128 and kernels 8/5/3, stride 1, no pooling."""
    return basic_conv1d(
        filters=[128,256,128],
        kernel_size=[8,5,3],
        stride=1,
        pool=0,
        pool_stride=2,
        num_classes=num_classes,
        input_channels=input_channels,
        act="relu",
        bn=True,
        lin_ftrs_head=lin_ftrs_head,
        ps_head=ps_head,
        bn_final_head=bn_final_head,
        bn_head=bn_head,
        act_head=act_head,
        concat_pooling=concat_pooling,
    )


def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Four-stage net (25/50/100/200 filters) with a split first layer, pooling and dropout 0.5."""
    return basic_conv1d(
        filters=[25,50,100,200],
        kernel_size=kernel_size,
        stride=3,
        pool=3,
        pool_stride=1,
        num_classes=num_classes,
        input_channels=input_channels,
        act="relu",
        bn=True,
        headless=False,
        split_first_layer=True,
        drop_p=0.5,
        lin_ftrs_head=lin_ftrs_head,
        ps_head=ps_head,
        bn_final_head=bn_final_head,
        bn_head=bn_head,
        act_head=act_head,
        concat_pooling=concat_pooling,
    )


def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Strided conv net with a squeeze-excite block in every stage."""
    return basic_conv1d(
        filters=filters,
        kernel_size=kernel_size,
        stride=2,
        pool=0,
        pool_stride=0,
        input_channels=input_channels,
        act="relu",
        bn=True,
        num_classes=num_classes,
        squeeze_excite_reduction=squeeze_excite_reduction,
        drop_p=drop_p,
        lin_ftrs_head=lin_ftrs_head,
        ps_head=ps_head,
        bn_final_head=bn_final_head,
        bn_head=bn_head,
        act_head=act_head,
        concat_pooling=concat_pooling,
    )


def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Pass-through constructor for `basic_conv1d` that tolerates and drops extra `**kwargs`."""
    return basic_conv1d(
        filters=filters,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        pool=pool,
        pool_stride=pool_stride,
        squeeze_excite_reduction=squeeze_excite_reduction,
        num_classes=num_classes,
        input_channels=input_channels,
        act=act,
        bn=bn,
        headless=headless,
        drop_p=drop_p,
        lin_ftrs_head=lin_ftrs_head,
        ps_head=ps_head,
        bn_final_head=bn_final_head,
        bn_head=bn_head,
        act_head=act_head,
        concat_pooling=concat_pooling,
    )
193 |
--------------------------------------------------------------------------------
/utils/contrastive_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 | import os
5 | import pickle
6 | import os
7 | import random
8 |
9 | from .timeseries_transformations import TTimeOut,ToTensor,TGaussianNoise, TRandomResizedCrop,TTimeOut_difflead,Transpose
10 |
class Normalize(object):
    """Normalize one element of a `(data, label)` sample with fixed statistics.

    Args:
        stats_mean: per-channel mean array (channel on the last axis) or None to skip centering.
        stats_std: per-channel std array (channel on the last axis) or None to skip scaling;
            1e-8 is added to avoid division by zero.
        input: normalize the data element when True, the label element otherwise.
        channels: if non-empty, only these channel indices are normalized; all
            other channels get mean 0 / std 1 (left unchanged).
    """
    def __init__(self, stats_mean, stats_std, input=True, channels=()):
        # astype() copies, so the caller's arrays are never modified below
        self.stats_mean=stats_mean.astype(np.float32) if stats_mean is not None else None
        self.stats_std=stats_std.astype(np.float32)+1e-8 if stats_std is not None else None
        self.input = input
        if(len(channels)>0):
            reference = self.stats_mean if self.stats_mean is not None else self.stats_std
            if reference is not None:
                # Channels live on the LAST axis. The previous loop iterated over
                # len(stats_mean) but indexed [:, i], which fails for 1-D stats and
                # walks the wrong axis for 2-D ones.
                for i in range(reference.shape[-1]):
                    if i not in channels:
                        if self.stats_mean is not None:
                            self.stats_mean[..., i]=0
                        if self.stats_std is not None:
                            self.stats_std[..., i]=1

    def __call__(self, sample):
        datax, labelx = sample
        data = datax if self.input else labelx
        #assuming channel last
        if(self.stats_mean is not None):
            data = data - self.stats_mean
        if(self.stats_std is not None):
            data = data/self.stats_std

        if(self.input):
            return (data, labelx)
        else:
            return (datax, data)
37 |
def replace_labels(x, stay_idx, remove_idx):
    """Return a list copy of `x` in which every entry equal to `remove_idx` becomes `stay_idx`."""
    return [stay_idx if y == remove_idx else y for y in x]
46 |
def keep_one_random_class(x):
    """Pick one element of `x` uniformly at random using NumPy's global RNG."""
    drawn = np.random.choice(x, 1)
    return drawn[0]
50 |
def transformations_from_strings(transformations, t_params):
    """Turn a list of transformation names into a list of transform objects.

    Returns `[ToTensor()]` when `transformations` is None. Otherwise the list is
    ToTensor (no transpose), the named transforms in order, Normalize with the
    stats from `t_params`, then Transpose. An unknown name raises `Exception`.
    """
    if transformations is None:
        return [ToTensor()]

    def str_to_trafo(trafo):
        # parameters are looked up lazily, only for the transform actually requested
        if trafo == "RandomResizedCrop":
            return TRandomResizedCrop(crop_ratio_range=t_params["rr_crop_ratio_range"], output_size=t_params["output_size"])
        if trafo == "TimeOut":
            return TTimeOut(crop_ratio_range=t_params["to_crop_ratio_range"])
        if trafo == "GaussianNoise":
            return TGaussianNoise(scale=t_params["gaussian_scale"])
        if trafo == "TimeOut_difflead":
            return TTimeOut_difflead(crop_ratio_range=t_params["to_crop_ratio_range"])
        raise Exception(str(trafo) + " is not a valid transformation")

    trafo_list = [ToTensor(transpose_data=False)]
    trafo_list.extend([str_to_trafo(trafo) for trafo in transformations])
    trafo_list.append(Normalize(stats_mean=t_params["stats_mean"],stats_std=t_params["stats_std"]))
    trafo_list.append(Transpose())
    return trafo_list
71 |
72 |
class ECG_contrastive_dataset(torch.utils.data.Dataset):
    """Dataset yielding two augmented views of a crop taken from memmapped signal records.

    Files read from `summary_folder`: `df_memmap.pkl`, `lbl_itos.npy`, `mean.npy`,
    `std.npy`, `memmap_meta.npz` and `memmap.npy`.

    Args:
        summary_folder: folder containing the files listed above.
        signal_size: number of timesteps of every returned crop.
        stride: step between chunk starts; `chunk_length` is used when None.
        chunk_length: records are split into chunks of this length (0 = no split).
        transforms: transformation names passed to `transformations_from_strings`.
        t_params: parameters for those transformations.
        equivalent_classes: `(stay_class, remove_class)` pairs; labels of the
            second class are rewritten to the first.
        sample_items_per_record: number of crops drawn per `__getitem__` call.
        random_crop: draw the crop start at random; otherwise take the centre crop.
    """
    def __init__(self, summary_folder, signal_size, stride, chunk_length, transforms, t_params,
                 equivalent_classes, sample_items_per_record=1, random_crop=True):

        self.folder = summary_folder
        self.signal_size = signal_size
        self.transforms = transformations_from_strings(transforms, t_params)
        # number of small samples we want to take out of the big signal data
        self.sample_items_per_record = sample_items_per_record
        # from the large signal data, we randomly choose where we acquire the sample data
        self.random_crop = random_crop

        # Loading data info
        # NOTE: pickle executes arbitrary code on load; only point this at trusted folders.
        with open(os.path.join(self.folder,"df_memmap.pkl"), "rb") as handle:
            self.df = pickle.load(handle)
        self.lbl_itos = np.load(os.path.join(self.folder,"lbl_itos.npy"))
        self.mean = np.load(os.path.join(self.folder,"mean.npy"))
        self.std = np.load(os.path.join(self.folder,"std.npy"))

        stack_remove_idx = []
        # Grouping the equivalent classes, remove the corresponding classes
        if len(equivalent_classes)!=0:
            for stay_class, remove_class in equivalent_classes:
                if stay_class not in self.lbl_itos or remove_class not in self.lbl_itos:
                    print(f'{stay_class},{remove_class}: one of those is not in the dictionary')
                else:
                    stay_idx = np.where(self.lbl_itos==stay_class)[0][0]
                    remove_idx = np.where(self.lbl_itos==remove_class)[0][0]
                    self.df['label'] = self.df['label'].apply(lambda x: replace_labels(x,stay_idx,remove_idx))
                    stack_remove_idx.append(remove_idx)

        # NOTE(review): applied unconditionally here (np.delete with an empty index list
        # is a no-op); confirm this matches the intended nesting.
        self.df['label'] = self.df['label'].apply(lambda x: keep_one_random_class(x))
        self.lbl_itos = np.delete(self.lbl_itos,stack_remove_idx)

        self.timeseries_df_data = np.array(self.df['data'])
        if(self.timeseries_df_data.dtype not in [np.int16, np.int32, np.int64]):
            # np.bytes_ is the same type as the alias np.string_, which NumPy 2.0 removed
            self.timeseries_df_data = np.array(self.df["data"].astype(str)).astype(np.bytes_)

        #stack arrays/lists for proper batching
        # NOTE(review): the first test inspects the 'data' column while the stacking is
        # done on 'label'; left as is, verify which column was meant.
        if(isinstance(self.df['data'].iloc[0],list) or isinstance(self.df['label'].iloc[0],np.ndarray)):
            self.timeseries_df_label = np.stack(self.df['label'])
        else: # single integers/floats
            self.timeseries_df_label = np.array(self.df['label'])
        #everything else cannot be batched anyway mp.Manager().list(self.timeseries_df_label)
        if(self.timeseries_df_label.dtype not in [np.int16, np.int32, np.int64, np.float32, np.float64]):
            self.timeseries_df_label = np.array(self.df['label'].apply(lambda x:str(x))).astype(np.bytes_)

        # load meta data for memmap npy
        self.mode = "memmap"
        memmap_meta = np.load(os.path.join(self.folder,'memmap_meta.npz'), allow_pickle=True)
        self.memmap_start = memmap_meta["start"]
        self.memmap_shape = memmap_meta["shape"]
        self.memmap_length = memmap_meta["length"]
        self.memmap_file_idx = memmap_meta["file_idx"]
        self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
        self.memmap_filenames = np.array(memmap_meta["filenames"]).astype(np.bytes_)#save as byte to avoid issue with mp

        # load data from memmap file
        self.memmap_signaldata = np.memmap(os.path.join(self.folder,'memmap.npy'),self.memmap_dtype, mode='r', shape=tuple(self.memmap_shape[0]))

        # get the position of the signal inside the stacked memmap signal data
        self.df_idx_mapping = []
        self.start_idx_mapping = []
        self.end_idx_mapping = []
        start_idx = 0
        min_chunk_length = signal_size

        for df_idx,(_,row) in enumerate(self.df.iterrows()):
            data_length = self.memmap_length[row["data"]]

            if(chunk_length == 0): # do not split into chunks
                idx_start = [start_idx]
                idx_end = [data_length]
            else:
                idx_start = list(range(start_idx,data_length,chunk_length if stride is None else stride))
                idx_end = [min(l+chunk_length, data_length) for l in idx_start]

            #remove final chunk(s) if too short
            for i in range(len(idx_start)):
                if(idx_end[i]-idx_start[i]< min_chunk_length):
                    del idx_start[i:]
                    del idx_end[i:]
                    break
            #append to lists
            for i_s,i_e in zip(idx_start,idx_end):
                self.df_idx_mapping.append(df_idx)
                self.start_idx_mapping.append(i_s)
                self.end_idx_mapping.append(i_e)

        #convert to np.array to avoid mp issues with python lists
        self.df_idx_mapping = np.array(self.df_idx_mapping)
        self.start_idx_mapping = np.array(self.start_idx_mapping)
        self.end_idx_mapping = np.array(self.end_idx_mapping)

    def __len__(self):
        return len(self.df_idx_mapping)

    @property
    def is_empty(self):
        return len(self.df_idx_mapping)==0

    def __getitem__(self, idx):
        """Return a dict with keys 'sig_i', 'sig_j', 'lbl', 'idx' for chunk `idx`.

        With `sample_items_per_record == 1` the values describe a single crop;
        otherwise the signals and labels are stacked and 'idx' is a list.
        """
        lst_data_i = []
        lst_data_j = []
        lst_lbl = []
        lst_patient = []
        for _ in range(self.sample_items_per_record):
            #determine crop idxs
            timesteps= self.get_sample_length(idx)

            if(self.random_crop): #random crop
                if(timesteps==self.signal_size):
                    start_idx_rel = 0
                else:
                    # get random start of the crop inside the big signal (randint is inclusive)
                    start_idx_rel = random.randint(0, timesteps - self.signal_size -1)
            else:
                # if not random, this may be for valid and the timesteps is probably equal to the signal_size
                start_idx_rel = (timesteps - self.signal_size)//2

            data_i, data_j, label, patient = self.get_signal_sample(idx,start_idx_rel)
            if(self.sample_items_per_record==1):
                return {'sig_i':data_i,'sig_j':data_j,'lbl':label,'idx':patient}
            lst_data_i.append(data_i)
            lst_data_j.append(data_j)
            lst_patient.append(patient)
            lst_lbl.append(label)
        lst_data_i = torch.stack(lst_data_i)
        lst_data_j = torch.stack(lst_data_j)
        lst_lbl = torch.from_numpy(np.stack(lst_lbl))

        return {'sig_i':lst_data_i,'sig_j':lst_data_j,'lbl':lst_lbl,'idx':lst_patient}

    def get_signal_sample(self, idx,start_idx_rel):
        """Crop `signal_size` timesteps from chunk `idx` and return `(view_i, view_j, label, df_idx)`."""
        df_idx = self.df_idx_mapping[idx]
        start_idx = self.start_idx_mapping[idx]
        end_idx = self.end_idx_mapping[idx]
        # determine crop idxs
        timesteps= end_idx - start_idx
        assert(timesteps>=self.signal_size)
        start_idx_crop = start_idx + start_idx_rel
        end_idx_crop = start_idx_crop+self.signal_size

        memmap_idx = self.timeseries_df_data[df_idx]
        idx_offset = self.memmap_start[memmap_idx]

        signal_data = np.copy(self.memmap_signaldata[idx_offset + start_idx_crop: idx_offset + end_idx_crop])

        label = self.timeseries_df_label[df_idx]
        # NOTE(review): both views start from the same array object; this presumes the
        # transforms do not modify their input in place. Verify.
        sample1 = (signal_data,label)
        sample2 = (signal_data,label)

        # the same transform objects are applied to both views, one after the other
        for trans in self.transforms:
            sample1 = trans(sample1)
            sample2 = trans(sample2)

        aug_i, _ = sample1
        aug_j, _ = sample2

        return aug_i, aug_j, label, df_idx

    def get_id_mapping(self):
        return self.df_idx_mapping

    def get_sample_id(self,idx):
        return self.df_idx_mapping[idx]

    def get_sample_length(self,idx):
        return self.end_idx_mapping[idx]-self.start_idx_mapping[idx]

    def get_sample_start(self,idx):
        return self.start_idx_mapping[idx]
--------------------------------------------------------------------------------
/models/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class BasicConv1d(nn.Module):
    """Bias-free Conv1d followed by BatchNorm1d (eps=1e-3) and a non-inplace ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv1d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False)
        self.bn = nn.BatchNorm1d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
18 |
19 |
class Mixed_5b(nn.Module):
    """Inception mixing block: 192 input channels -> 320 output channels.

    Four parallel branches (96 + 64 + 96 + 64 channels) are computed from the
    same input and concatenated along the channel dimension (dim 1).
    """

    def __init__(self):
        super(Mixed_5b, self).__init__()

        self.branch0 = BasicConv1d(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv1d(192, 48, kernel_size=1, stride=1),
            BasicConv1d(48, 64, kernel_size=5, stride=1, padding=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv1d(192, 64, kernel_size=1, stride=1),
            BasicConv1d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv1d(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            # Every other layer here is a Conv1d, so the pooling must be 1-D
            # too. A 2-D pool would read a 3-D tensor as (C, H, W), i.e. treat
            # the batch as channels and average over neighbouring channels.
            nn.AvgPool1d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv1d(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out
49 |
50 |
class Block35(nn.Module):
    """Residual Inception block for 320-channel feature maps.

    Three conv branches (32 + 32 + 64 = 128 channels) are concatenated,
    projected back to 320 channels with a kernel-size-1 convolution, multiplied
    by ``scale`` and added to the block input before the final ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(320, 32, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 32, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv1d(48, 64, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.conv2d = nn.Conv1d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat([self.branch0(x), self.branch1(x), self.branch2(x)], 1)
        projected = self.conv2d(merged)
        return self.relu(projected * self.scale + x)
82 |
83 |
class Mixed_6a(nn.Module):
    """Stride-2 reduction block: 320 input channels -> 1088 output channels.

    Two strided conv branches (384 channels each) and a strided max-pool of the
    input (320 channels) are concatenated along the channel dimension.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv1d(320, 384, kernel_size=3, stride=2)

        branch1_layers = [
            BasicConv1d(320, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.branch2 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branch_outputs = [self.branch0(x), self.branch1(x), self.branch2(x)]
        return torch.cat(branch_outputs, 1)
104 |
105 |
class Block17(nn.Module):
    """Residual Inception block for 1088-channel feature maps.

    Two conv branches (192 + 192 = 384 channels) are concatenated, projected
    back to 1088 channels, multiplied by ``scale`` and added to the block
    input before the final ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(1088, 192, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(1088, 128, kernel_size=1, stride=1),
            BasicConv1d(128, 160, kernel_size=7, stride=1, padding=3),
            BasicConv1d(160, 192, kernel_size=7, stride=1, padding=3),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.conv2d = nn.Conv1d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat([self.branch0(x), self.branch1(x)], 1)
        projected = self.conv2d(merged)
        return self.relu(projected * self.scale + x)
131 |
132 |
class Mixed_7a(nn.Module):
    """Stride-2 reduction block: 1088 input channels -> 2080 output channels.

    Three strided conv branches (384 + 288 + 320 channels) and a strided
    max-pool of the input (1088 channels) are concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()

        branch0_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch0 = nn.Sequential(*branch0_layers)

        branch1_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv1d(288, 320, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.branch3 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate their outputs."""
        branch_outputs = [
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        ]
        return torch.cat(branch_outputs, 1)
162 |
163 |
class Block8(nn.Module):
    """Residual Inception block for 2080-channel feature maps.

    Two conv branches (192 + 256 = 448 channels) are concatenated, projected
    back to 2080 channels, multiplied by ``scale`` and added to the block
    input. The trailing ReLU is omitted when ``no_relu`` is true.
    """

    def __init__(self, scale=1.0, no_relu=False):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(2080, 192, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(2080, 192, kernel_size=1, stride=1),
            BasicConv1d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv1d(224, 256, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.conv2d = nn.Conv1d(448, 2080, kernel_size=1, stride=1)
        if no_relu:
            self.relu = None
        else:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``conv2d(cat(branches)) * scale + x``, passed through ReLU if enabled."""
        merged = torch.cat([self.branch0(x), self.branch1(x)], 1)
        out = self.conv2d(merged) * self.scale + x
        if self.relu is None:
            return out
        return self.relu(out)
191 |
def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the feature multiplier for a pooling mode: 2 for ``'catavgmax'``, else 1."""
    return 2 if pool_type == 'catavgmax' else 1
197 |
class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Despite the name, the layer uses 1-D adaptive pooling modules.

    Args:
        output_size: target length handed to the adaptive pooling modules.
        pool_type: ``'avg'`` (average), ``'max'`` (maximum), ``'avgmax'``
            (mean of average and maximum) or ``'catavgmax'`` (average and
            maximum concatenated along the channel dimension, which doubles
            the channel count). Any other value falls back to average pooling.
        flatten: when true, the pooled tensor is flattened from dim 1 onwards.
    """
    def __init__(self, output_size=1, pool_type='avg', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        self.flatten = flatten
        # `pool` is the primary pooling op; it stays an average pool for every
        # mode except plain 'max'.
        if pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool1d(output_size)
        else:
            self.pool = nn.AdaptiveAvgPool1d(output_size)
        # The combined modes additionally need a max pool.
        if pool_type in ('avgmax', 'catavgmax'):
            self.max_pool = nn.AdaptiveMaxPool1d(output_size)
        else:
            self.max_pool = None

    def forward(self, x):
        """Pool ``x`` according to ``pool_type`` and optionally flatten it."""
        pooled = self.pool(x)
        if self.pool_type == 'catavgmax':
            # Channel count doubles here, consistent with feat_mult() == 2.
            pooled = torch.cat((pooled, self.max_pool(x)), 1)
        elif self.pool_type == 'avgmax':
            pooled = 0.5 * (pooled + self.max_pool(x))
        if self.flatten:
            pooled = pooled.flatten(1)
        return pooled

    def feat_mult(self):
        """Return the channel multiplier implied by ``pool_type``."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'output_size=' + str(self.output_size) \
            + ', pool_type=' + self.pool_type + ')'
221 |
class InceptionResnetV2(nn.Module):
    """Inception-ResNet-v2 classifier assembled from 1-D convolution blocks.

    Args:
        num_classes: output size of the final linear layer.
        in_chans: number of input channels of the first convolution.
        drop_rate: dropout probability applied to the pooled features; dropout
            is skipped unless it is positive.
        global_pool: pooling mode forwarded to ``SelectAdaptivePool2d``.
    """

    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super().__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        # Stem: plain convolutions and max-pools up to 192 channels.
        self.conv2d_1a = BasicConv1d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv1d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool1d(3, stride=2)
        self.conv2d_3b = BasicConv1d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv1d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool1d(3, stride=2)

        # Inception stages: 10 x Block35, 20 x Block17, 9 x Block8.
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*(Block35(scale=0.17) for _ in range(10)))
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*(Block17(scale=0.10) for _ in range(20)))
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*(Block8(scale=0.20) for _ in range(9)))
        self.block8 = Block8(no_relu=True)

        # Head.
        self.conv2d_7b = BasicConv1d(2080, self.num_features, kernel_size=1, stride=1)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
        head_width = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_width, num_classes)

    def get_classifier(self):
        """Return the current classification head."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and the head.

        A falsy ``num_classes`` installs ``nn.Identity`` as the head.
        """
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        if not num_classes:
            self.classif = nn.Identity()
            return
        head_width = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_width, num_classes)

    def forward_features(self, x):
        """Run the stem and all Inception stages; return the unpooled feature map."""
        stages = (
            self.conv2d_1a,
            self.conv2d_2a,
            self.conv2d_2b,
            self.maxpool_3a,
            self.conv2d_3b,
            self.conv2d_4a,
            self.maxpool_5a,
            self.mixed_5b,
            self.repeat,
            self.mixed_6a,
            self.repeat_1,
            self.mixed_7a,
            self.repeat_2,
            self.block8,
            self.conv2d_7b,
        )
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        """Extract features, pool and flatten them, apply dropout and the head."""
        features = self.forward_features(x)
        pooled = self.global_pool(features).flatten(1)
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        return self.classif(pooled)
327 |
--------------------------------------------------------------------------------
/models/se_inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class SELayer_1d(nn.Module):
    """Squeeze-and-excitation gate for 3-D ``(batch, channel, length)`` tensors.

    The input is average-pooled to one value per channel, passed through a
    bias-free bottleneck MLP (``channel -> channel // reduction -> channel``)
    ending in a sigmoid, and the resulting per-channel gate rescales the input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        bottleneck = channel // reduction
        gate_layers = [
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        batch, channels, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate.expand_as(x)
22 |
class BasicConv1d(nn.Module):
    """Bias-free ``Conv1d`` followed by ``BatchNorm1d`` (eps=1e-3) and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv1d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm1d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
36 |
37 |
class Mixed_5b(nn.Module):
    """Inception mixing block: 192 input channels -> 320 output channels.

    Four parallel branches (96 + 64 + 96 + 64 channels) are computed from the
    same input and concatenated along the channel dimension (dim 1).
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv1d(192, 96, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(192, 48, kernel_size=1, stride=1),
            BasicConv1d(48, 64, kernel_size=5, stride=1, padding=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv1d(192, 64, kernel_size=1, stride=1),
            BasicConv1d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv1d(96, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        branch3_layers = [
            nn.AvgPool1d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv1d(192, 64, kernel_size=1, stride=1),
        ]
        self.branch3 = nn.Sequential(*branch3_layers)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs."""
        branch_outputs = [
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        ]
        return torch.cat(branch_outputs, 1)
67 |
68 |
class Block35(nn.Module):
    """Residual Inception block for 320-channel feature maps.

    Three conv branches (32 + 32 + 64 = 128 channels) are concatenated,
    projected back to 320 channels with a kernel-size-1 convolution, multiplied
    by ``scale`` and added to the block input before the final ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(320, 32, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 32, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv1d(320, 32, kernel_size=1, stride=1),
            BasicConv1d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv1d(48, 64, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.conv2d = nn.Conv1d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat([self.branch0(x), self.branch1(x), self.branch2(x)], 1)
        projected = self.conv2d(merged)
        return self.relu(projected * self.scale + x)
100 |
101 |
class Mixed_6a(nn.Module):
    """Stride-2 reduction block: 320 input channels -> 1088 output channels.

    Two strided conv branches (384 channels each) and a strided max-pool of the
    input (320 channels) are concatenated along the channel dimension.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv1d(320, 384, kernel_size=3, stride=2)

        branch1_layers = [
            BasicConv1d(320, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.branch2 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branch_outputs = [self.branch0(x), self.branch1(x), self.branch2(x)]
        return torch.cat(branch_outputs, 1)
122 |
123 |
class Block17(nn.Module):
    """Residual Inception block for 1088-channel feature maps.

    Two conv branches (192 + 192 = 384 channels) are concatenated, projected
    back to 1088 channels, multiplied by ``scale`` and added to the block
    input before the final ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(1088, 192, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(1088, 128, kernel_size=1, stride=1),
            BasicConv1d(128, 160, kernel_size=7, stride=1, padding=3),
            BasicConv1d(160, 192, kernel_size=7, stride=1, padding=3),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.conv2d = nn.Conv1d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat([self.branch0(x), self.branch1(x)], 1)
        projected = self.conv2d(merged)
        return self.relu(projected * self.scale + x)
149 |
150 |
class Mixed_7a(nn.Module):
    """Stride-2 reduction block: 1088 input channels -> 2080 output channels.

    Three strided conv branches (384 + 288 + 320 channels) and a strided
    max-pool of the input (1088 channels) are concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()

        branch0_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch0 = nn.Sequential(*branch0_layers)

        branch1_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv1d(1088, 256, kernel_size=1, stride=1),
            BasicConv1d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv1d(288, 320, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.branch3 = nn.MaxPool1d(3, stride=2)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate their outputs."""
        branch_outputs = [
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        ]
        return torch.cat(branch_outputs, 1)
180 |
181 |
class Block8(nn.Module):
    """Residual Inception block for 2080-channel feature maps.

    Two conv branches (192 + 256 = 448 channels) are concatenated, projected
    back to 2080 channels, multiplied by ``scale`` and added to the block
    input. The trailing ReLU is omitted when ``no_relu`` is true.
    """

    def __init__(self, scale=1.0, no_relu=False):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv1d(2080, 192, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv1d(2080, 192, kernel_size=1, stride=1),
            BasicConv1d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv1d(224, 256, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        self.conv2d = nn.Conv1d(448, 2080, kernel_size=1, stride=1)
        if no_relu:
            self.relu = None
        else:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``conv2d(cat(branches)) * scale + x``, passed through ReLU if enabled."""
        merged = torch.cat([self.branch0(x), self.branch1(x)], 1)
        out = self.conv2d(merged) * self.scale + x
        if self.relu is None:
            return out
        return self.relu(out)
209 |
def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the feature multiplier for a pooling mode: 2 for ``'catavgmax'``, else 1."""
    return 2 if pool_type == 'catavgmax' else 1
215 |
class SelectAdaptivePool1d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Args:
        output_size: target length handed to the adaptive pooling modules.
        pool_type: ``'avg'`` (average), ``'max'`` (maximum), ``'avgmax'``
            (mean of average and maximum) or ``'catavgmax'`` (average and
            maximum concatenated along the channel dimension, which doubles
            the channel count). Any other value falls back to average pooling.
        flatten: when true, the pooled tensor is flattened from dim 1 onwards.
    """
    def __init__(self, output_size=1, pool_type='avg', flatten=False):
        super(SelectAdaptivePool1d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        self.flatten = flatten
        # `pool` is the primary pooling op; it stays an average pool for every
        # mode except plain 'max'.
        if pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool1d(output_size)
        else:
            self.pool = nn.AdaptiveAvgPool1d(output_size)
        # The combined modes additionally need a max pool.
        if pool_type in ('avgmax', 'catavgmax'):
            self.max_pool = nn.AdaptiveMaxPool1d(output_size)
        else:
            self.max_pool = None

    def forward(self, x):
        """Pool ``x`` according to ``pool_type`` and optionally flatten it."""
        pooled = self.pool(x)
        if self.pool_type == 'catavgmax':
            # Channel count doubles here, consistent with feat_mult() == 2.
            pooled = torch.cat((pooled, self.max_pool(x)), 1)
        elif self.pool_type == 'avgmax':
            pooled = 0.5 * (pooled + self.max_pool(x))
        if self.flatten:
            pooled = pooled.flatten(1)
        return pooled

    def feat_mult(self):
        """Return the channel multiplier implied by ``pool_type``."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'output_size=' + str(self.output_size) \
            + ', pool_type=' + self.pool_type + ')'
239 |
class SE_InceptionResnetV2(nn.Module):
    """Inception-ResNet-v2 classifier (1-D) with an SE gate after every residual block.

    Args:
        num_classes: output size of the final linear layer.
        in_chans: number of input channels of the first convolution.
        drop_rate: dropout probability applied to the pooled features; dropout
            is skipped unless it is positive.
        global_pool: pooling mode forwarded to ``SelectAdaptivePool1d``.
    """

    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super().__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        # Stem: plain convolutions and max-pools up to 192 channels.
        self.conv2d_1a = BasicConv1d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv1d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool1d(3, stride=2)
        self.conv2d_3b = BasicConv1d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv1d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool1d(3, stride=2)

        # Inception stages: each residual block is followed by its SE layer.
        self.mixed_5b = Mixed_5b()
        self.repeat = self._se_stage(lambda: Block35(scale=0.17), 320, 10)
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = self._se_stage(lambda: Block17(scale=0.10), 1088, 20)
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = self._se_stage(lambda: Block8(scale=0.20), 2080, 9)
        self.block8 = Block8(no_relu=True)

        # Head.
        self.conv2d_7b = BasicConv1d(2080, self.num_features, kernel_size=1, stride=1)
        self.global_pool = SelectAdaptivePool1d(pool_type=global_pool)
        # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
        head_width = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_width, num_classes)

    @staticmethod
    def _se_stage(block_factory, channels, repeats):
        """Build ``repeats`` (block, SELayer_1d) pairs as one ``nn.Sequential``."""
        layers = []
        for _ in range(repeats):
            layers.append(block_factory())
            layers.append(SELayer_1d(channel=channels))
        return nn.Sequential(*layers)

    def get_classifier(self):
        """Return the current classification head."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and the head.

        A falsy ``num_classes`` installs ``nn.Identity`` as the head.
        """
        self.global_pool = SelectAdaptivePool1d(pool_type=global_pool)
        self.num_classes = num_classes
        if not num_classes:
            self.classif = nn.Identity()
            return
        head_width = self.num_features * self.global_pool.feat_mult()
        self.classif = nn.Linear(head_width, num_classes)

    def forward_features(self, x):
        """Run the stem and all Inception stages; return the unpooled feature map."""
        stages = (
            self.conv2d_1a,
            self.conv2d_2a,
            self.conv2d_2b,
            self.maxpool_3a,
            self.conv2d_3b,
            self.conv2d_4a,
            self.maxpool_5a,
            self.mixed_5b,
            self.repeat,
            self.mixed_6a,
            self.repeat_1,
            self.mixed_7a,
            self.repeat_2,
            self.block8,
            self.conv2d_7b,
        )
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        """Extract features, pool and flatten them, apply dropout and the head."""
        features = self.forward_features(x)
        pooled = self.global_pool(features).flatten(1)
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        return self.classif(pooled)
--------------------------------------------------------------------------------
/utils/eval_tools.py:
--------------------------------------------------------------------------------
1 | from tracemalloc import start
2 | import numpy as np
3 | import time
4 | from sklearn.metrics import multilabel_confusion_matrix
5 | import numpy as np
6 | import os
7 |
8 | # Check if the input is a number.
def is_number(x):
    """Return True if ``float(x)`` succeeds, False otherwise.

    Values that ``float`` rejects by type (``None``, lists, ...) count as
    non-numbers instead of raising.
    """
    try:
        float(x)
    except (ValueError, TypeError):
        return False
    return True
15 |
def load_table(table_file):
    """Load a labelled comma-separated table of numbers.

    The table should have the following form::

        ,    a,   b,   c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Args:
        table_file: path of the text file to read.

    Returns:
        ``(rows, cols, values)``: ``rows`` are the labels from the first
        column, ``cols`` the labels from the header line, and ``values`` a
        float64 array of shape ``(len(rows), len(cols))``. Cells that cannot
        be parsed as a float become NaN.

    Raises:
        ValueError: if the table has no data rows or columns, or if its lines
            do not all have the same number of cells.
    """
    # Blank lines (e.g. a trailing newline at the end of the file) carry no cells.
    with open(table_file, 'r') as f:
        table = [[cell.strip() for cell in line.split(',')] for line in f if line.strip()]

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table) - 1
    if num_rows < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Every line, including the header and the last data row, must agree.
    widths = {len(line) - 1 for line in table}
    if len(widths) != 1:
        raise ValueError('The table {} has rows with different lengths.'.format(table_file))
    num_cols = widths.pop()
    if num_cols < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Row labels live in the first column, column labels in the header line.
    rows = [table[i + 1][0] for i in range(num_rows)]
    cols = [table[0][j + 1] for j in range(num_cols)]

    # Find the entries of the table.
    values = np.zeros((num_rows, num_cols), dtype=np.float64)
    for i in range(num_rows):
        for j in range(num_cols):
            try:
                values[i, j] = float(table[i + 1][j + 1])
            except ValueError:
                values[i, j] = float('nan')

    return rows, cols, values
57 |
58 | # For each set of equivalent classes, replace each class with the representative class for the set.
def replace_equivalent_classes(classes, equivalent_classes):
    """Map each class to the representative (first member) of its equivalence set.

    ``classes`` is modified in place and also returned. Membership is tested
    against the entry's original value for every set in turn.
    """
    for position in range(len(classes)):
        label = classes[position]
        for group in equivalent_classes:
            if label in group:
                # Use the first class as the representative class.
                classes[position] = group[0]
    return classes
65 |
66 | # Load weights.
def load_weights(weight_file, equivalent_classes):
    """Load a square weight table and collapse equivalent classes.

    Returns ``(classes, weights)``: the representative class labels in order
    of first appearance and the weight sub-matrix restricted to them.
    Assertions fail if the row and column labels differ, or if two rows that
    map to the same representative do not carry identical weights.
    """
    # Load the weight matrix.
    rows, cols, values = load_table(weight_file)
    assert rows == cols

    # Replace every class with the representative of its equivalence set.
    rows = replace_equivalent_classes(rows, equivalent_classes)

    # Check that equivalent classes have identical weights.
    total = len(rows)
    for first in range(total):
        for second in range(first + 1, total):
            if rows[first] == rows[second]:
                assert np.all(values[first, :] == values[second, :])
                assert np.all(values[:, first] == values[:, second])

    # Keep one entry per representative class, in order of first appearance.
    classes = []
    for label in rows:
        if label not in classes:
            classes.append(label)
    indices = [rows.index(label) for label in classes]
    weights = values[np.ix_(indices, indices)]

    return classes, weights
88 |
89 |
90 | # Compute recording-wise accuracy.
91 | # input is np.bool
def compute_accuracy(labels, outputs):
    """Return the fraction of recordings whose whole output row equals its label row."""
    num_recordings, num_classes = np.shape(labels)

    num_correct_recordings = 0
    for row in range(num_recordings):
        if np.all(labels[row, :] == outputs[row, :]):
            num_correct_recordings += 1

    return float(num_correct_recordings) / float(num_recordings)
98 |
99 |
def compute_confusion_matrices(labels, outputs, normalize=False):
    """Compute a binary confusion matrix for each class k.

    The result has shape ``(num_classes, 2, 2)`` and is indexed as
    ``A[k, output, label]``::

        [TN_k FN_k]
        [FP_k TP_k]

    If ``normalize`` is true, each recording contributes ``1 / max(n, 1)``
    instead of 1, where ``n`` is the sum of that recording's label row.

    Raises:
        ValueError: if a label or output entry is neither 0 nor 1.
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, 2, 2))

    for i in range(num_recordings):
        if normalize:
            credit = 1.0 / float(max(np.sum(labels[i, :]), 1))
        else:
            credit = 1
        for j in range(num_classes):
            truth = labels[i, j]
            guess = outputs[i, j]
            if (truth == 1 or truth == 0) and (guess == 1 or guess == 0):
                # Row index is the output, column index is the label.
                A[j, int(guess), int(truth)] += credit
            else:  # This condition should not happen.
                raise ValueError('Error in computing the confusion matrix.')

    return A
141 |
142 | # Compute macro F-measure.
143 | # input is np.bool
def compute_f_measure(labels, outputs):
    """Return the macro F-measure: the NaN-ignoring mean of per-class ``2TP / (2TP + FP + FN)``.

    Classes whose denominator is zero contribute NaN and are therefore
    excluded from the mean.
    """
    num_recordings, num_classes = np.shape(labels)
    # Per class: [[tn, fn], [fp, tp]]
    A = compute_confusion_matrices(labels, outputs)

    per_class = np.full(num_classes, float('nan'))
    for k in range(num_classes):
        tp, fp, fn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1]
        denominator = 2 * tp + fp + fn
        if denominator:
            per_class[k] = float(2 * tp) / float(denominator)

    return np.nanmean(per_class)
160 |
def compute_f_measure_mod(labels, outputs):
    """Return the macro F-measure using ``multilabel_confusion_matrix`` for the counts.

    Classes whose denominator ``2TP + FP + FN`` is zero contribute NaN and
    are excluded from the mean.
    """
    num_recordings, num_classes = np.shape(labels)
    # Per class: [[tn, fp], [fn, tp]]
    A = multilabel_confusion_matrix(labels, outputs)

    per_class = np.full(num_classes, float('nan'))
    for k in range(num_classes):
        tp, fn, fp = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1]
        denominator = 2 * tp + fp + fn
        if denominator:
            per_class[k] = float(2 * tp) / float(denominator)

    return np.nanmean(per_class)
175 |
176 | # Compute macro AUROC and macro AUPRC.
177 | # input scalar np.float64 for outputs
def compute_auc(labels, outputs):
    """Compute macro AUROC and macro AUPRC over all classes.

    Args:
        labels: 2-D array indexed as ``[recording, class]``; entries are
            compared against 1 and 0 and tested for truthiness.
        outputs: 2-D array indexed the same way, holding the scores that are
            thresholded.

    Returns:
        ``(macro_auroc, macro_auprc)``: NaN-ignoring means of the per-class
        areas. A class whose rates contain NaN (e.g. no positive or no
        negative labels) yields a NaN area and is left out of the mean.
    """
    num_recordings, num_classes = np.shape(labels)

    # Compute and summarize the confusion matrices for each class at distinct output values.
    auroc = np.zeros(num_classes)
    auprc = np.zeros(num_classes)

    for k in range(num_classes):
        # We only need to compute TPs, FPs, FNs, and TNs at distinct output values.
        thresholds = np.unique(outputs[:, k])
        # Extra threshold above the maximum score: nothing is predicted positive there.
        thresholds = np.append(thresholds, thresholds[-1]+1)
        # Sweep from the highest threshold down to the lowest.
        thresholds = thresholds[::-1]
        num_thresholds = len(thresholds)

        # Initialize the TPs, FPs, FNs, and TNs.
        tp = np.zeros(num_thresholds)
        fp = np.zeros(num_thresholds)
        fn = np.zeros(num_thresholds)
        tn = np.zeros(num_thresholds)
        # At the first (highest) threshold every positive is missed and every negative is rejected.
        fn[0] = np.sum(labels[:, k]==1)
        tn[0] = np.sum(labels[:, k]==0)

        # Find the indices that result in sorted output values.
        idx = np.argsort(outputs[:, k])[::-1]

        # Compute the TPs, FPs, FNs, and TNs for class k across thresholds.
        # `i` walks the recordings in descending score order and is never reset,
        # so each recording is counted exactly once over the whole sweep.
        i = 0
        for j in range(1, num_thresholds):
            # Initialize TPs, FPs, FNs, and TNs using values at previous threshold.
            tp[j] = tp[j-1]
            fp[j] = fp[j-1]
            fn[j] = fn[j-1]
            tn[j] = tn[j-1]

            # Update the TPs, FPs, FNs, and TNs at i-th output value.
            while i < num_recordings and outputs[idx[i], k] >= thresholds[j]:
                if labels[idx[i], k]:
                    tp[j] += 1
                    fn[j] -= 1
                else:
                    fp[j] += 1
                    tn[j] -= 1
                i += 1

        # Summarize the TPs, FPs, FNs, and TNs for class k.
        # A rate whose denominator is zero is recorded as NaN.
        tpr = np.zeros(num_thresholds)
        tnr = np.zeros(num_thresholds)
        ppv = np.zeros(num_thresholds)
        for j in range(num_thresholds):
            if tp[j] + fn[j]:
                tpr[j] = float(tp[j]) / float(tp[j] + fn[j])
            else:
                tpr[j] = float('nan')
            if fp[j] + tn[j]:
                tnr[j] = float(tn[j]) / float(fp[j] + tn[j])
            else:
                tnr[j] = float('nan')
            if tp[j] + fp[j]:
                ppv[j] = float(tp[j]) / float(tp[j] + fp[j])
            else:
                ppv[j] = float('nan')

        # Compute AUROC as the area under a piecewise linear function with TPR/
        # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area
        # under a piecewise constant with TPR/recall (x-axis) and PPV/precision
        # (y-axis) for class k.
        for j in range(num_thresholds-1):
            auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j])
            auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1]

    # Compute macro AUROC and macro AUPRC across classes.
    macro_auroc = np.nanmean(auroc)
    macro_auprc = np.nanmean(auprc)

    return macro_auroc, macro_auprc
253 |
# Compute the macro F-beta and G-beta measures.
# Inputs are boolean arrays.
def compute_beta_measures(labels, outputs, beta):
    """Return ``(macro F-beta, macro G-beta)`` from normalized confusion matrices.

    Per class, ``F = (1+beta^2)TP / ((1+beta^2)TP + FP + beta^2 FN)`` and
    ``G = TP / (TP + FP + beta FN)``. A class with a zero denominator
    contributes NaN and is excluded from the corresponding mean.
    """
    num_recordings, num_classes = np.shape(labels)

    A = compute_confusion_matrices(labels, outputs, normalize=True)

    f_scores = np.full(num_classes, float('nan'))
    g_scores = np.full(num_classes, float('nan'))
    for k in range(num_classes):
        tp, fp, fn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1]

        f_denominator = (1+beta**2)*tp + fp + beta**2*fn
        if f_denominator:
            f_scores[k] = float((1+beta**2)*tp) / float(f_denominator)

        g_denominator = tp + fp + beta*fn
        if g_denominator:
            g_scores[k] = float(tp) / float(g_denominator)

    return np.nanmean(f_scores), np.nanmean(g_scores)
278 |
# Compute the Challenge metric.
# labels and outputs are expected to be boolean arrays.
def compute_challenge_metric(weights, labels, outputs, classes, normal_class):
    """Return the weighted Challenge score, normalized against two baselines.

    The observed score is rescaled so that a model which always outputs the
    reference labels scores 1 and a model which always outputs only the
    normal class scores 0.

    Args:
        weights: Matrix multiplied element-wise with the modified confusion
            matrix (classes x classes).
        labels: 2-D array (recordings x classes) of reference labels.
        outputs: 2-D array (recordings x classes) of predicted labels.
        classes: Sequence of class identifiers; must provide ``.index``.
        normal_class: Identifier of the normal class within ``classes``.

    Returns:
        ``(observed - inactive) / (correct - inactive)`` as a float, or
        ``0.0`` when the correct and inactive scores coincide.

    Raises:
        ValueError: If ``normal_class`` is not present in ``classes``.
    """
    num_recordings, num_classes = np.shape(labels)
    normal_index = classes.index(normal_class)

    # Compute the observed score.
    A = compute_modified_confusion_matrix(labels, outputs)
    observed_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the correct label(s).
    correct_outputs = labels
    A = compute_modified_confusion_matrix(labels, correct_outputs)
    correct_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the normal class.
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    inactive_outputs = np.zeros((num_recordings, num_classes), dtype=bool)
    inactive_outputs[:, normal_index] = True
    A = compute_modified_confusion_matrix(labels, inactive_outputs)
    inactive_score = np.nansum(weights * A)

    # Guard against a zero denominator when both baselines score the same.
    if correct_score != inactive_score:
        normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score)
    else:
        normalized_score = 0.0

    return normalized_score
306 |
def compute_modified_confusion_matrix(labels, outputs):
    """Build a multi-class, multi-label confusion matrix with shared credit.

    Rows correspond to reference labels and columns to outputs.  For each
    recording, every (positive label, positive output) pair receives
    ``1 / n`` credit, where ``n`` is the number of classes that are positive
    in the labels and/or the outputs of that recording (at least 1).

    Args:
        labels: 2-D array (recordings x classes) of reference labels.
        outputs: 2-D array (recordings x classes) of predicted labels.

    Returns:
        Float array of shape (classes, classes) with the accumulated credit.
    """
    num_recordings, num_classes = np.shape(labels)
    matrix = np.zeros((num_classes, num_classes))

    for rec in range(num_recordings):
        label_row = labels[rec, :]
        output_row = outputs[rec, :]

        # Number of classes positive in the labels and/or outputs; never
        # below 1 so the credit is always well defined.
        positives = np.sum(np.any((label_row, output_row), axis=0))
        credit = 1.0 / float(max(positives, 1))

        # Every positive label shares credit with every positive output.
        label_idx = np.flatnonzero(label_row)
        output_idx = np.flatnonzero(output_row)
        if label_idx.size and output_idx.size:
            matrix[np.ix_(label_idx, output_idx)] += credit

    return matrix
326 |
# Load the weight matrix and collapse equivalent classes.
def load_weights(weight_file, equivalent_classes):
    """Load the weight table and reduce it to one entry per class.

    Args:
        weight_file: Passed to ``load_table``, which returns the row labels,
            column labels and values of the table.
        equivalent_classes: Passed to ``replace_equivalent_classes`` to map
            each row label onto its representative class.

    Returns:
        Tuple ``(classes, weights)``: the distinct representative classes in
        first-occurrence order, and the sub-matrix of the table values
        restricted to the rows and columns of those first occurrences.

    Raises:
        AssertionError: If the row and column labels differ, or if two rows
            mapped to the same representative class do not have identical
            rows and columns in the table.
    """
    rows, cols, values = load_table(weight_file)
    assert rows == cols

    # Replace each class with the representative of its equivalence set.
    rows = replace_equivalent_classes(rows, equivalent_classes)

    # Classes that collapse onto the same representative must carry
    # identical weights, otherwise keeping only one of them loses data.
    num_rows = len(rows)
    for first in range(num_rows):
        for second in range(first + 1, num_rows):
            if rows[first] == rows[second]:
                assert np.all(values[first, :] == values[second, :])
                assert np.all(values[:, first] == values[:, second])

    # Keep the first occurrence of every representative class.
    classes = []
    indices = []
    for position, label in enumerate(rows):
        if label not in classes:
            classes.append(label)
            indices.append(position)

    weights = values[np.ix_(indices, indices)]

    return classes, weights
349 |
350 |
--------------------------------------------------------------------------------