├── .gitignore
├── README.md
├── dataloader.py
├── figure
│   ├── architecture.png
│   ├── incresblock.png
│   └── inference.gif
├── infer.py
├── model.py
├── preprocess.py
├── requirements.txt
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | #data
133 | saved_data/
134 | runs/
135 | best_model/
136 | results/
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeismoNet
2 |
3 | This repository contains code used for the paper "End-to-End Deep Learning for Reliable Cardiac Activity Monitoring using Seismocardiograms" which has been accepted for presentation at the [19th International Conference on Machine Learning and Applications](https://www.icmla-conference.org/icmla20/index.html), Boca Raton, FL, USA.
4 |
5 | SeismoNet is a Deep Convolutional Neural Network which aims to provide an end-to-end solution to robustly observe heart activity from Seismocardiogram (SCG) signals. These SCG signals are motion-based and can be acquired in an easy, user-friendly fashion. SeismoNet transforms the SCG signal into an interpretable waveform consisting of relevant information which allows for extraction of heart rate indices.
6 |
7 | Preprint available at [arxiv](https://arxiv.org/abs/2010.05662) :newspaper:
8 |
9 | ## Getting Started :rocket:
10 |
11 | * [preprocess.py](preprocess.py) Preprocesses the CEBS dataset available on [PhysioNet](https://physionet.org/content/cebsdb/1.0.0/).
12 | * [trainer.py](trainer.py) Trains the model.
13 | * [infer.py](infer.py) Runs inference on input SCG signals and, optionally, evaluates the detected heartbeats.
14 | * [utils.py](utils.py) Helper functions (windowing, label generation, data loaders, peak detection and evaluation).
15 | * [model.py](model.py) SeismoNet architecture in PyTorch.
16 |
17 | ## Model Architecture
18 |
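19 | ![SeismoNet architecture](figure/architecture.png)
20 | ![Inception-residual block](figure/incresblock.png)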
21 | ## Usage
22 |
23 | Install all dependencies with:
24 | ```bash
25 | $ pip install -r requirements.txt
26 | ```
27 | Download datasets with:
28 | ```bash
29 | $ wget -r -N -c -np https://physionet.org/files/cebsdb/1.0.0/
30 | ```
31 | Preprocess raw data:
32 | ```bash
33 | $ python preprocess.py --data_path /path/to/data
34 | ```
35 | Train SeismoNet using preprocessed data:
36 | ```bash
37 | $ python trainer.py --data_path saved_data/
38 | ```
39 |
40 | Take inference and evaluate model:
41 | ```bash
42 | $ python infer.py --best_model /path/to/model --data_path saved_data/ --create_test_file --evaluate
43 | ```
44 | ## Inference
45 |
46 | ![Inference example](figure/inference.gif)
47 | ## Authors :mortar_board:
48 |
49 | [Prithvi Suresh](https://github.com/prithusuresh/), [Naveen Narayanan](https://github.com/naveenggmu/), Pranav CV, Vineeth Vijayaraghavan
50 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import Dataset
7 | from glob import glob
8 |
class CEBSDataset(Dataset):
    """In-memory dataset of preprocessed CEBS windows.

    Loads every ``inputSig_<n>.pt`` found in ``<data_path>/preprocessed_data``
    together with the matching ``groundTruth<ecg_channel>_<n>.pt`` and
    concatenates them along dim 0 as float32 tensors.

    Args:
        data_path: directory containing ``preprocessed_data/`` (e.g. ``saved_data/b``).
        ecg_channel: which ground-truth series to load (1 or 2, as written by preprocess.py).

    Raises:
        FileNotFoundError: if no ``inputSig_*.pt`` files exist under ``data_path``.
    """

    def __init__(self, data_path, ecg_channel=1):
        self.data_path = data_path
        input_dir = os.path.join(data_path, "preprocessed_data")
        # Lexicographic order is kept on purpose: it fixes the concatenation order,
        # which the seeded train/val/test split downstream depends on.
        input_files = sorted(glob(os.path.join(input_dir, "inputSig_*.pt")))
        if not input_files:
            raise FileNotFoundError("No inputSig_*.pt files found in {}".format(input_dir))

        inputs, grounds = [], []
        for inp_file in input_files:
            stem = os.path.splitext(os.path.basename(inp_file))[0]  # "inputSig_<n>"
            p_no = stem.split("_")[-1]
            gt_file = os.path.join(os.path.dirname(inp_file),
                                   "groundTruth{}_{}.pt".format(ecg_channel, p_no))
            inputs.append(torch.load(inp_file))
            grounds.append(torch.load(gt_file))

        self.input = torch.cat(inputs).type(torch.float)
        self.ground = torch.cat(grounds).type(torch.float)

    def __len__(self):
        """Number of windows loaded."""
        return len(self.input)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair at position ``idx``."""
        label = self.ground[idx]
        input_tensor = self.input[idx]
        return input_tensor, label
--------------------------------------------------------------------------------
/figure/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/architecture.png
--------------------------------------------------------------------------------
/figure/incresblock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/incresblock.png
--------------------------------------------------------------------------------
/figure/inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/inference.gif
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | '''infer'''
2 | import os
3 | import sys
4 | import signal
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import Dataset, DataLoader, TensorDataset
10 | from tqdm import tqdm
11 | from torch import optim
12 | import torch.nn.functional as F
13 | from glob import glob
14 | import warnings
15 | import matplotlib.pyplot as plt
16 |
17 | from sklearn.model_selection import train_test_split
18 | import argparse
19 |
20 | from model import SeismoNet
21 | from utils import *
22 |
def main(args):
    """Run SeismoNet inference over a test set, optionally scoring and plotting it.

    The test set comes from ``create_loaders`` (``--create_test_file``) or from a
    saved tensor (``--test_tensor_file``). Per-window metrics are written to
    ``results/results.csv`` (empty unless ``--evaluate``). With ``--save_figures``,
    one PNG per window is written to ``results/figures/``.

    Raises:
        ValueError: if no test tensor file is given without ``--create_test_file``,
            if the test set is empty, or if ``--evaluate`` is used on unlabelled data.
    """
    if args.create_test_file:
        print("Creating Test File... ")
        data_path = os.path.join(args.data_path, args.file_type)
        _, _, test_loader = create_loaders(data_path, "data.pt", "labels.pt")
    else:
        print("Loading Test File... ")
        if args.test_tensor_file is None:
            raise ValueError("--test_tensor_file is required unless --create_test_file is given")
        test_tensor = torch.load(args.test_tensor_file)
        test_loader = DataLoader(TensorDataset(test_tensor), batch_size=1, pin_memory=True)

    print("Loading Model... ")
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        raise ValueError("The test set is empty")
    # Batches are [inputs] or [inputs, labels]; the constructor receives the input shape.
    model = SeismoNet(first_batch[0].shape)
    # map_location lets a GPU-saved checkpoint load anywhere; infer() moves the model afterwards.
    model.load_state_dict(torch.load(args.best_model, map_location="cpu")["model"])

    os.makedirs("results", exist_ok=True)
    if args.save_figures:
        os.makedirs(os.path.join("results", "figures"), exist_ok=True)

    metrics = []
    for i, batch in enumerate(test_loader):
        inputs = batch[0]
        labels = batch[1] if len(batch) > 1 else None
        pred_distance_transform, pred_peak_locations = infer(model, inputs, downsampling_factor=1)

        if args.evaluate:
            if labels is None:
                raise ValueError("--evaluate needs labelled data; use --create_test_file")
            # Flatten first: labels are batched (e.g. (1, 1, L)), and np.where(...)[0]
            # on a multi-dimensional array would return batch indices, not positions.
            label_flat = labels.cpu().numpy().flatten()
            actual_peak_locations = np.where(label_flat == 0.0)[0]
            metrics.append(evaluate_window(actual_peak_locations, pred_peak_locations))

        if args.save_figures:
            fig = plt.figure(figsize=[10, 5])
            plt.subplot(1, 2, 1)
            plt.plot(inputs.cpu().numpy().flatten())
            plt.subplot(1, 2, 2)
            plt.plot(pred_distance_transform.flatten())
            if labels is not None:
                plt.plot(labels.cpu().numpy().flatten())
            plt.scatter(pred_peak_locations, pred_distance_transform.flatten()[pred_peak_locations])
            plt.savefig("results/figures/{}.png".format(i + 1))
            plt.close(fig)  # otherwise one figure per window stays open

    pd.DataFrame(metrics).to_csv("results/results.csv")
74 |
75 |
76 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--create_test_file", action="store_true", help="Create test file if not already present")
    parser.add_argument('--test_tensor_file', nargs="?", help='Path to a saved test tensor (.pt); required unless --create_test_file is given')
    parser.add_argument('--data_path', nargs="?", const="saved_data/", default="saved_data/", help='Path to saved files directory')
    parser.add_argument('--file_type', nargs="?", const="b", default="b", help="file type")
    parser.add_argument('--best_model', nargs="?", const="best_model/best_model_pretrained.pt", default="best_model/best_model_pretrained.pt", help="Best Model File")
    parser.add_argument('--evaluate', action="store_true", help="Compare against label or not")
    parser.add_argument('--save_figures', action="store_true", help="save figure along with results")
    args = parser.parse_args()
    # Fail fast with a usage message instead of torch.load(None) inside main().
    if not args.create_test_file and args.test_tensor_file is None:
        parser.error("--test_tensor_file is required unless --create_test_file is given")

    main(args)
89 |
90 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
class IncBlock(nn.Module):
    """Inception-style block with a 1x1 residual shortcut for 1-D feature maps.

    Four parallel branches each produce ``out_channels // 4`` channels:
    - a single ``size``-wide convolution;
    - three 1x1-bottleneck branches with kernels ``size + 2``, ``size + 4`` and
      ``size + 6``.
    Their concatenation is added to a bias-free 1x1 projection of the input and
    passed through ReLU.

    With the defaults (``size=15``, ``padding=7``, ``stride=1``) every branch
    preserves the input length. The shortcut is never strided, so ``stride != 1``
    produces a length mismatch in ``forward``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; must be divisible by 4.
        size: kernel width of the first branch.
        stride: stride of the wide convolution in every branch.
        padding: padding of the first branch; the other branches use +1, +2, +3.

    Raises:
        ValueError: if ``out_channels`` is not divisible by 4.
    """

    def __init__(self, in_channels, out_channels, size=15, stride=1, padding=7):
        super(IncBlock, self).__init__()
        if out_channels % 4 != 0:
            raise ValueError("out_channels must be divisible by 4, got {}".format(out_channels))
        branch = out_channels // 4

        # Attribute names and Sequential layouts are part of the state_dict keys
        # used by saved checkpoints; keep them stable.
        self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

        self.conv1 = nn.Sequential(nn.Conv1d(in_channels, branch, kernel_size=size, stride=stride, padding=padding),
                                   nn.BatchNorm1d(branch))
        self.conv2 = self._bottleneck(in_channels, branch, size + 2, stride, padding + 1)
        self.conv3 = self._bottleneck(in_channels, branch, size + 4, stride, padding + 2)
        self.conv4 = self._bottleneck(in_channels, branch, size + 6, stride, padding + 3)

        self.relu = nn.ReLU()

    @staticmethod
    def _bottleneck(in_channels, branch, kernel, stride, padding):
        """1x1 reduction -> BN -> LeakyReLU -> wide conv -> BN."""
        return nn.Sequential(nn.Conv1d(in_channels, branch, kernel_size=1, bias=False),
                             nn.BatchNorm1d(branch),
                             nn.LeakyReLU(0.2),
                             nn.Conv1d(branch, branch, kernel_size=kernel, stride=stride, padding=padding),
                             nn.BatchNorm1d(branch))

    def forward(self, x):
        res = self.conv1x1(x)
        concat = torch.cat((self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)), dim=1)
        concat += res
        return self.relu(concat)
52 |
class AveragingBlock(nn.Module):
    """Convolutional ensemble-averaging front end used at the start of SeismoNet.

    Two unpadded k=3 1-D convolutions expand the signal to 16 feature maps.
    These are reinterpreted as the height axis of a single-channel 2-D image,
    which a strided 3x3 convolution, a 2x2 max-pool and a (3, 15) convolution
    collapse back to one row. A final k=3 1-D convolution maps it to
    ``out_channels``.

    For an input of length L the output length is ``((L - 7) // 2 + 1) // 2``.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(AveragingBlock, self).__init__()
        # Construction order matches the checkpointed layout; names are state_dict keys.
        self.conv1, self.bn1 = nn.Conv1d(in_channels, 8, 3), nn.BatchNorm1d(8)
        self.conv2, self.bn2 = nn.Conv1d(8, 16, 3), nn.BatchNorm1d(16)
        self.conv3, self.bn3 = nn.Conv2d(1, 1, (3, 3), 2), nn.BatchNorm2d(1)
        self.conv4, self.bn4 = nn.Conv2d(1, 1, (3, 15), padding=(0, 7)), nn.BatchNorm2d(1)
        self.conv5, self.bn5 = nn.Conv1d(1, out_channels, 3, padding=1), nn.BatchNorm1d(out_channels)

        self.relu1 = nn.LeakyReLU(0.2)

        self.mp1 = nn.MaxPool1d(2)  # registered but not used in forward
        self.mp2 = nn.MaxPool2d((2, 2))

    def forward(self, x):
        act = self.relu1
        feats = act(self.bn1(self.conv1(x)))
        feats = act(self.bn2(self.conv2(feats)))
        # (B, 16, W) -> (B, 1, 16, W): the 16 channels become image rows.
        image = feats.unsqueeze(1)
        image = self.mp2(act(self.bn3(self.conv3(image))))
        row = act(self.bn4(self.conv4(image))).squeeze(1)
        return act(self.bn5(self.conv5(row)))
98 |
class SeismoNet(nn.Module):
    """SeismoNet: ensemble-averaging front end, U-Net-style encoder/decoder, denoising tail.

    Args:
        shape: accepted for interface compatibility; not used by the model.

    ``forward`` takes a (batch, 1, L) tensor; ``self.en1`` expects one channel.
    The skip-connection crops in ``forward`` are fixed. For example, they line up
    for 80- and 50000-sample inputs (the latter is wlen * fs with preprocess.py's
    defaults), where the output has the same length as the input. Other lengths
    may raise inside ``torch.cat``. TODO confirm the general length constraint.
    """

    def __init__(self, shape):
        super(SeismoNet, self).__init__()
        in_channels = 1
        # Attribute names and Sequential indices form the state_dict keys of saved
        # checkpoints; do not reorder.
        self.cea = nn.Sequential(AveragingBlock())

        self.en1 = nn.Sequential(nn.Conv1d(in_channels, 32, 3, padding=1),
                                 nn.BatchNorm1d(32),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(32, 32, 5, stride=2, padding=2),
                                 IncBlock(32, 32))

        self.en2 = nn.Sequential(nn.Conv1d(32, 64, 3, padding=1),
                                 nn.BatchNorm1d(64),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(64, 64, 5, stride=2, padding=2),
                                 IncBlock(64, 64))

        self.en3 = nn.Sequential(nn.Conv1d(64, 128, 3, padding=1),
                                 nn.BatchNorm1d(128),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(128, 128, 3, stride=2, padding=1),
                                 IncBlock(128, 128))

        self.en4 = nn.Sequential(nn.Conv1d(128, 256, 3, padding=1),
                                 nn.BatchNorm1d(256),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv1d(256, 256, 5, stride=2, padding=1),
                                 IncBlock(256, 256))

        self.en5 = nn.Sequential(nn.Conv1d(256, 512, 3, padding=1),
                                 nn.BatchNorm1d(512),
                                 nn.LeakyReLU(0.2),
                                 IncBlock(512, 512))

        self.de1 = nn.Sequential(nn.ConvTranspose1d(512, 256, 1),
                                 nn.BatchNorm1d(256),
                                 nn.LeakyReLU(0.2),
                                 IncBlock(256, 256))

        self.de2 = nn.Sequential(nn.Conv1d(512, 256, 3, padding=1),
                                 nn.BatchNorm1d(256),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose1d(256, 128, 3, stride=2),
                                 IncBlock(128, 128))

        self.de3 = nn.Sequential(nn.Conv1d(256, 128, 3, stride=1, padding=1),
                                 nn.BatchNorm1d(128),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose1d(128, 64, 3, stride=2),
                                 IncBlock(64, 64))

        self.de4 = nn.Sequential(nn.Conv1d(128, 64, 3, stride=1, padding=1),
                                 nn.BatchNorm1d(64),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose1d(64, 32, 3, stride=2),
                                 IncBlock(32, 32))

        self.de5 = nn.Sequential(nn.Conv1d(64, 32, 3, stride=1, padding=1),
                                 nn.BatchNorm1d(32),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose1d(32, 16, 3, stride=2),
                                 IncBlock(16, 16))

        self.de6 = nn.Sequential(nn.ConvTranspose1d(16, 8, 2, stride=2),
                                 nn.BatchNorm1d(8),
                                 nn.LeakyReLU(0.2))

        self.de7 = nn.Sequential(nn.ConvTranspose1d(8, 4, 2, stride=2),
                                 nn.BatchNorm1d(4),
                                 nn.LeakyReLU(0.2))

        self.de8 = nn.Sequential(nn.ConvTranspose1d(4, 2, 1, stride=1),
                                 nn.BatchNorm1d(2),
                                 nn.LeakyReLU(0.2))

        self.de9 = nn.Sequential(nn.ConvTranspose1d(2, 1, 1, stride=1),
                                 nn.BatchNorm1d(1),
                                 nn.LeakyReLU(0.2))

    def forward(self, x):
        x = self.cea(x)  # Convolutional Ensemble Averaging
        # Zero-pad one sample on each side (same as nn.ConstantPad1d((1, 1), 0),
        # without building a new module on every call).
        x = F.pad(x, (1, 1))

        # Contracting path
        e1 = self.en1(x)
        e2 = self.en2(e1)
        e3 = self.en3(e2)
        e4 = self.en4(e3)
        e5 = self.en5(e4)

        # Expanding path; the fixed crops trim transposed-conv overhang to the skip length.
        d1 = self.de1(e5)
        d2 = self.de2(torch.cat([d1, e4], 1))
        d3 = self.de3(torch.cat([d2, e3], 1))
        d4 = self.de4(torch.cat([d3[:, :, :-2], e2], 1))
        d5 = self.de5(torch.cat([d4[:, :, :-1], e1], 1))[:, :, :-1]

        # Denoising block
        d6 = self.de6(d5)
        d7 = self.de7(d6)
        d8 = self.de8(d7)
        d9 = self.de9(d8)
        return d9
209 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import wfdb
5 | import torch
6 | from torch.utils.data import Dataset
7 | from tqdm import tqdm
8 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
9 | import pickle
10 | import argparse
11 | from glob import glob
12 |
13 | from utils import *
14 | from dataloader import CEBSDataset
15 |
16 |
def _ensure_output_dirs(file_type):
    """Create saved_data/<file_type>/{pickle_files,preprocessed_data} when missing."""
    if not os.path.exists("saved_data"):
        print("Creating Saved Data Path")
    for sub in ("pickle_files", "preprocessed_data"):
        os.makedirs(os.path.join("saved_data", file_type, sub), exist_ok=True)


def _read_record(record):
    """Read one WFDB record (path without extension) into the per-subject signal dict."""
    x, _info = wfdb.rdsamp(record)
    all_peaks = wfdb.io.rdann(record, 'atr').sample
    # NOTE(review): annotations are split by alternating index into two R-peak
    # series, presumably one per ECG lead; kept as-is, confirm against CEBS docs.
    return {"rpeak1": all_peaks[::2],
            "rpeak2": all_peaks[1::2],
            "resp": x[:, 2].flatten(),
            "scg": x[:, 3].flatten(),
            "ecg1": x[:, 0].flatten(),
            "ecg2": x[:, 1].flatten(),
            }


def _window_tensors(subject, fs, wlen, overlap):
    """Window one subject; return (scg, ecg1+ecg2, truth1, truth2) float tensors, or None.

    Windows shorter than ``wlen * fs`` samples and windows without R-peaks in
    either series are skipped (``distanceTransform`` needs at least one peak).
    """
    n_samples = wlen * fs
    scg_w, ecg1_w, ecg2_w, truth1, truth2 = [], [], [], [], []
    for scg, ecg1, rpeak1, ecg2, rpeak2 in generateSignals(subject, fs, wlen, overlap):
        if scg.shape[0] != n_samples or ecg1.shape[0] != n_samples or ecg2.shape[0] != n_samples:
            continue
        if rpeak1 is None or rpeak2 is None or len(rpeak1) == 0 or len(rpeak2) == 0:
            continue
        scg_w.append(scg.reshape((1, -1)))
        ecg1_w.append(ecg1.reshape((1, -1)))
        ecg2_w.append(ecg2.reshape((1, -1)))
        truth1.append(distanceTransform(ecg1, rpeak1).reshape((1, -1)))
        truth2.append(distanceTransform(ecg2, rpeak2).reshape((1, -1)))

    if not scg_w:
        return None

    def to_tensor(rows):
        # (N, 1, L) float32, same layout torch.tensor(list_of_(1, L)_arrays) produced.
        return torch.from_numpy(np.stack(rows)).type(torch.float)

    ecg12 = torch.cat((to_tensor(ecg1_w), to_tensor(ecg2_w)), 1)
    return to_tensor(scg_w), ecg12, to_tensor(truth1), to_tensor(truth2)


def main(args):
    """Preprocess CEBS WFDB records into windowed tensors under saved_data/<file_type>/.

    For each ``<file_type>*.dat`` record in ``args.data_path``:
    - pickles the raw per-subject dict;
    - saves inputSig/groundTruth1/groundTruth2/ecg12 tensors to ``preprocessed_data/``.
    It then writes the concatenated ``data.pt`` and ``labels.pt`` via ``CEBSDataset``.
    """
    file_type = args.file_type
    data_path = args.data_path

    _ensure_output_dirs(file_type)
    pickle_dir = os.path.join("saved_data", file_type, "pickle_files")
    saving_path = os.path.join("saved_data", file_type, "preprocessed_data")

    files = sorted(glob(os.path.join(data_path, file_type + "*.dat")))
    for dat_file in tqdm(files, total=len(files)):
        record = os.path.splitext(dat_file)[0]  # strip the ".dat" suffix, not characters
        name = os.path.basename(record)          # e.g. "b001"

        subject = _read_record(record)
        with open(os.path.join(pickle_dir, "{}.pkl".format(name)), "wb") as f:
            pickle.dump(subject, f)

        tensors = _window_tensors(subject, args.fs, args.wlen, args.overlap)
        if tensors is None:
            print("Skipping {}: no complete windows with R-peaks".format(name))
            continue
        inputSig_t, ecg12Sig_t, groundTruth1_t, groundTruth2_t = tensors

        # Subject number from the file name, independent of how deep data_path is.
        p_no = int(name[len(file_type):])
        torch.save(inputSig_t, os.path.join(saving_path, "inputSig_{}.pt".format(p_no)))
        torch.save(groundTruth1_t, os.path.join(saving_path, "groundTruth1_{}.pt".format(p_no)))
        torch.save(groundTruth2_t, os.path.join(saving_path, "groundTruth2_{}.pt".format(p_no)))
        torch.save(ecg12Sig_t, os.path.join(saving_path, "ecg12_{}.pt".format(p_no)))

    print("--Saving Data--")
    data = CEBSDataset(os.path.join("saved_data/", file_type))
    torch.save(data.input, os.path.join("saved_data/", file_type, "data.pt"))
    torch.save(data.ground, os.path.join("saved_data/", file_type, "labels.pt"))
104 |
105 |
106 |
107 |
108 |
if __name__ == "__main__":
    # (flag, type, default, help) for each CLI option, in declaration order.
    cli = argparse.ArgumentParser()
    option_specs = [
        ('--file_type', str, "b", 'm, p or b'),
        ('--data_path', str, "../files/", "path to data files"),
        ('--wlen', int, 10, "window length in seconds"),
        ('--overlap', int, 5, "overlap length in seconds"),
        ('--fs', int, 5000, "sampling frequency"),
    ]
    for flag, kind, fallback, description in option_specs:
        cli.add_argument(flag, nargs='?', type=kind, default=fallback, help=description)

    main(cli.parse_args())
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://download.pytorch.org/whl/torch_stable.html
2 | torch==1.6.0
3 | torchvision==0.7.0
4 | numpy==1.19.2
5 | pandas==1.1.3
6 | scipy==1.5.2
7 | scikit-learn
8 | wfdb
9 | tqdm
10 | matplotlib
11 | tensorboard
6 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | '''trainer'''
2 |
3 | import os
4 | import sys
5 | import signal
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import Dataset, DataLoader, TensorDataset
11 | from tqdm import tqdm
12 | from torch import optim
13 | import torch.nn.functional as F
14 | from torch.utils.tensorboard import SummaryWriter
15 | from glob import glob
16 | import warnings
17 |
18 | from sklearn.model_selection import train_test_split
19 | import argparse
20 |
21 | from model import SeismoNet
22 | from dataloader import CEBSDataset
23 | from utils import *
24 | warnings.filterwarnings("ignore")
25 |
26 |
def dump_and_exit(signalnumber, frame):
    """SIGINT handler: save the latest checkpoint, if any, then exit with status 0.

    ``model_state`` is a module-level global that ``main`` assigns after each
    epoch's validation pass. If the signal arrives before that, there is nothing
    to save, and the handler must not crash with NameError.
    """
    state = globals().get("model_state")
    if state is not None:
        os.makedirs("best_model", exist_ok=True)
        torch.save(state, "best_model/best_model_on_SIGINT.pt")
    else:
        print("No completed epoch yet; nothing to save.")
    sys.exit(0)
32 |
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss (NaN if empty)."""
    model.train()
    total = 0.0
    for inputs, labels in tqdm(loader, total=len(loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # .item() takes a plain float; summing the tensor itself would chain
        # autograd history across every batch of the epoch.
        total += loss.item()
    return total / len(loader) if len(loader) else float("nan")


def _evaluate(model, loader, criterion, device):
    """Return the mean batch loss of ``model`` on ``loader`` without gradients (NaN if empty)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, labels in tqdm(loader, total=len(loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            total += criterion(model(inputs), labels).item()
    return total / len(loader) if len(loader) else float("nan")


def main(args):
    """Train SeismoNet on preprocessed CEBS data.

    Uses SGD with step decay at epochs 100 and 200 and a SmoothL1 loss.
    - Saves the checkpoint with the lowest validation loss to ``best_model/best_model.pt``.
    - Saves the final one to ``best_model/best_model_training_completed.pt``.
    - Logs losses to TensorBoard.
    The latest checkpoint dict is also published in the global ``model_state``
    for the SIGINT handler.
    """
    global model_state

    test_size = float(args.test_size)
    val_size = float(args.val_size)
    data_path = os.path.join(args.data_path, args.file_type)
    lr = float(args.lr)
    train_batch_size = int(args.train_batch_size)
    val_batch_size = int(args.val_batch_size)
    epochs = int(args.epochs)

    print("Training SeismoNet on CEBS")

    train_loader, val_loader, _test_loader = create_loaders(data_path, "data.pt", "labels.pt", test_size, val_size,
                                                            train_batch_size, val_batch_size)
    writer = SummaryWriter()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SeismoNet(get_shape(train_loader)).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)
    criterion = nn.SmoothL1Loss()

    best_loss = float("inf")
    os.makedirs("best_model", exist_ok=True)
    model_state = None

    for epoch in range(epochs):
        print('epochs {}/{} '.format(epoch + 1, epochs))

        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        scheduler.step()
        val_loss = _evaluate(model, val_loader, criterion, device)

        model_state = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_loss': val_loss
        }

        if val_loss <= best_loss:
            best_loss = val_loss
            torch.save(model_state, f='best_model/best_model.pt')

        print('train loss: {:.4f} val loss : {:.4f}'.format(train_loss, val_loss))
        writer.add_scalar("Loss/train_loss", train_loss, epoch)
        writer.add_scalar("Loss/val_loss", val_loss, epoch)

    writer.close()

    print("Completed")
    if model_state is not None:  # nothing to save when epochs == 0
        torch.save(model_state, f='best_model/best_model_training_completed.pt')
124 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Numeric options carry a type so malformed values fail at parse time with a usage message.
    parser.add_argument('--data_path', nargs="?", const="saved_data/", default="saved_data/", help='Path to saved files directory')
    parser.add_argument('--file_type', nargs="?", const="b", default="b", help="file type")
    parser.add_argument('--test_size', nargs='?', type=float, const=0.2, default=0.2, help='Size of Test Set (float)')
    parser.add_argument('--val_size', nargs='?', type=float, const=0.2, default=0.2, help='Size of Validation Set (float)')
    parser.add_argument('--train_batch_size', nargs='?', type=int, const=32, default=32, help='Batch Size of Train Loader')
    parser.add_argument('--val_batch_size', nargs='?', type=int, const=32, default=32, help='Batch Size of Validation Loader')
    parser.add_argument('--epochs', nargs='?', type=int, const=300, default=300, help='Number of Epochs')
    parser.add_argument("--lr", nargs="?", type=float, const=0.001, default=0.001, help='Learning Rate')
    signal.signal(signal.SIGINT, dump_and_exit)
    args = parser.parse_args()
    main(args)
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
5 | from sklearn.model_selection import train_test_split
6 | from torch.utils.data import Dataset, DataLoader, TensorDataset
7 | from scipy.signal import find_peaks
8 | import torch
9 |
10 |
def generateSignals(data, fs=5000, wlen=10, overlap=5):
    """Yield overlapping windows of one subject's signals.

    Args:
        data: dict with array entries "scg", "ecg1", "ecg2" (samples) and
            "rpeak1", "rpeak2" (sorted sample indices, as numpy arrays).
        fs: sampling frequency in Hz.
        wlen: window length in seconds.
        overlap: overlap between consecutive windows in seconds; must be < wlen.

    Yields:
        ``(scg, ecg1, rpeak1, ecg2, rpeak2)`` per window. Signals are slices that
        may be shorter than a full window at the end. R-peaks are shifted to
        window-relative indices; a peak exactly at ``start + window length`` is
        included (inclusive upper bound, kept from the original).

    Raises:
        ValueError: on first iteration, if ``overlap >= wlen`` (non-positive hop).
    """
    win = wlen * fs
    # Integer hop; avoids the float truncation of int((1 - overlap / wlen) * win).
    hop = int(round(win - overlap * fs))
    if hop <= 0:
        raise ValueError("overlap ({}) must be smaller than wlen ({})".format(overlap, wlen))

    rpeak1 = data["rpeak1"]
    rpeak2 = data["rpeak2"]
    total_length = len(data["scg"])

    for start in range(0, total_length, hop):
        stop = start + win
        yield (data["scg"][start:stop],
               data["ecg1"][start:stop],
               rpeak1[(rpeak1 >= start) & (rpeak1 <= stop)] - start,
               data["ecg2"][start:stop],
               rpeak2[(rpeak2 >= start) & (rpeak2 <= stop)] - start)
18 |
def distanceTransform(signal, rpeaks):
    """Min-max-scaled distance from every sample to its nearest R-peak.

    Args:
        signal: the window; only its length is used.
        rpeaks: R-peak sample indices, assumed sorted ascending (as the
            original loop also assumed). At least one is required.

    Returns:
        float64 array of shape ``(len(signal), 1)`` scaled to [0, 1]. Peaks map
        to the minimum (0 when a peak lies inside the window). The scaling
        reproduces ``sklearn.preprocessing.MinMaxScaler`` exactly; a constant
        transform maps to 0.

    Raises:
        ValueError: if ``rpeaks`` is empty.
    """
    peaks = np.asarray(rpeaks, dtype=np.int64).ravel()
    if peaks.size == 0:
        raise ValueError("distanceTransform needs at least one R-peak")

    positions = np.arange(len(signal), dtype=np.int64)
    # For each sample, the nearest peak is either the first peak at/after it or the one before.
    right = np.searchsorted(peaks, positions)
    after = peaks[np.minimum(right, peaks.size - 1)]
    before = peaks[np.maximum(right - 1, 0)]
    transform = np.minimum(np.abs(positions - before), np.abs(after - positions))
    transform = transform.astype(np.float64).reshape((-1, 1))
    if transform.size == 0:
        return transform

    # Same arithmetic as MinMaxScaler(feature_range=(0, 1)): X * scale_ + min_.
    t_min = transform.min()
    t_range = transform.max() - t_min
    scale = 1.0 / t_range if t_range != 0 else 1.0
    return transform * scale + (0.0 - t_min * scale)
39 |
def create_loaders(data_path, inp_file="data.pt", label_file="labels.pt", test_size=0.2, val_size=0.2,
                   train_batch_size=64, val_batch_size=64, shuffle_train=True, num_workers=4):
    """Split saved tensors into train/val/test sets and wrap them in DataLoaders.

    The split is seeded (random_state 42, then 32), so repeated calls with the same
    sizes reproduce the same partition. For example, infer.py's call rebuilds the
    test set used in training when the sizes match.

    Args:
        data_path: directory containing ``inp_file`` and ``label_file``.
        test_size, val_size: fractions of the whole data set; their sum must be > 0.
        train_batch_size, val_batch_size: batch sizes of the train/val loaders.
        shuffle_train: reshuffle the training set every epoch.
        num_workers: worker processes for the train/val loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``; the test loader uses batch size 1.

    Raises:
        ValueError: if ``test_size + val_size`` is not positive.
    """
    holdout = test_size + val_size
    if holdout <= 0:
        raise ValueError("test_size + val_size must be positive")

    data = torch.load(os.path.join(data_path, inp_file))
    target = torch.load(os.path.join(data_path, label_file))
    x_train, x_hold, y_train, y_hold = train_test_split(data, target, random_state=42, test_size=holdout)
    x_val, x_test, y_val, y_test = train_test_split(x_hold, y_hold, random_state=32, test_size=test_size / holdout)
    train, val, test = TensorDataset(x_train, y_train), TensorDataset(x_val, y_val), TensorDataset(x_test, y_test)

    train_loader = DataLoader(train, batch_size=train_batch_size, shuffle=shuffle_train,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val, batch_size=val_batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test, batch_size=1, shuffle=False)

    return train_loader, val_loader, test_loader
52 |
def get_shape(loader):
    """Return the shape of the first input batch yielded by ``loader``.

    Works for loaders yielding ``(inputs, labels, ...)`` tuples/lists or a bare
    tensor. Returns ``None`` when the loader is empty.
    """
    batch = next(iter(loader), None)
    if batch is None:
        return None
    if isinstance(batch, (list, tuple)):
        batch = batch[0]
    return batch.shape
56 |
57 |
def infer(model, inp, prominence=0.3, distance=625, smoothen=True, downsampling_factor=10, device=None):
    """Run ``model`` on one window and locate valleys (predicted peaks) in its output.

    Args:
        model: network mapping (1, 1, L) to an output with a single-element batch/channel.
        inp: batch tensor; channel 0 is used and reshaped to (1, 1, L), so the
            batch must contain exactly one window.
        prominence, distance: passed to ``getValleys``; ``distance`` is in
            full-resolution samples and divided by ``downsampling_factor``.
        smoothen: apply ``smooth`` to the prediction before valley search.
        downsampling_factor: stride used for the valley search (1 = no downsampling).
        device: torch device; defaults to CUDA when available, else CPU. The model
            is moved to this device (in place, as before) and put in eval mode.

    Returns:
        ``(out, valleys)``: the (optionally smoothed) prediction as a 1-D numpy
        array, and valley indices in full-resolution samples.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    inp = inp[:, 0, :].view(1, 1, inp.shape[-1]).to(device)
    with torch.no_grad():
        pred = model(inp)
    out = pred.cpu().view(pred.shape[-1]).numpy()
    if smoothen:
        out = smooth(out)
    downsampled = out.flatten()[::downsampling_factor]
    valley_loc_downsampled, _ = getValleys(downsampled, prominence=prominence,
                                           distance=max(1, distance // downsampling_factor))
    return out, valley_loc_downsampled * downsampling_factor
74 |
75 | def getValleys(signal, prominence, distance ):
76 | signal = signal*-1
77 | valley_loc, _ = find_peaks(signal, prominence = prominence,distance = distance)
78 | return valley_loc,_
79 |
def smooth(signal, window_len=50):
    """Return a centred moving average of ``signal`` as a flat NumPy array.

    Uses a pandas rolling mean with ``min_periods=1``, so edge samples average
    over the part of the window that exists.
    """
    rolled = pd.DataFrame(signal).rolling(window_len, center=True, min_periods=1)
    return np.ravel(rolled.mean().to_numpy())
83 |
def _interval_stats(intervals, fs):
    """Return (mean IBI ms, SDNN ms, pNN50, RMSSD ms) for intervals in samples; NaN when undefined."""
    intervals = np.asarray(intervals, dtype=float)
    if intervals.size == 0:
        return np.nan, np.nan, np.nan, np.nan
    successive = np.diff(intervals * 1000 / fs)
    rmssd = np.sqrt(np.mean(successive ** 2)) if successive.size else np.nan
    # NN50 counts successive differences larger than 50 ms in absolute value.
    nn50 = np.count_nonzero(np.abs(successive) > 50)
    pnn50 = nn50 / intervals.size
    mean_ibi = intervals.mean() * 1000 / fs
    sdnn = intervals.std() * 1000 / fs
    return mean_ibi, sdnn, pnn50, rmssd


def evaluate_window(actual, detected, fs=5000, tolerance=75):
    """Compare detected beat locations with reference ones for one window.

    A reference beat counts as detected when exactly one detection lies within
    ``tolerance`` ms of it. Beats with zero or several candidates are recorded as
    NaN, so intervals are only computed between consecutive reference beats.

    Args:
        actual: reference beat sample indices.
        detected: detected beat sample indices.
        fs: sampling frequency in Hz.
        tolerance: matching tolerance in milliseconds.

    Returns:
        dict of counts, interval statistics (ms) for reference and matched beats,
        sensitivity and PPV. Undefined statistics are NaN.
    """
    actual = np.asarray(actual)
    detected = np.asarray(detected)
    tolerance = (tolerance / 1000) * fs

    matched_beats = []
    correct = 0
    for correct_peak in actual:
        matched = detected[np.abs(correct_peak - detected) < tolerance]
        if len(matched) == 1:
            correct += 1
            matched_beats.append(matched[0])
        else:
            # Missed or ambiguous: keep a NaN placeholder so later beats stay aligned.
            matched_beats.append(np.nan)

    matched_intervals = np.diff(np.asarray(matched_beats, dtype=float))
    matched_intervals = matched_intervals[~np.isnan(matched_intervals)]
    matched_mIBI, matched_SDNN, matched_pNN50, matched_RMSSD = _interval_stats(matched_intervals, fs)
    actual_mIBI, actual_SDNN, actual_pNN50, actual_RMSSD = _interval_stats(np.diff(actual), fs)

    metrics = {
        "Total Positives": len(actual),
        "Total Detected": len(detected),
        "True Positives": correct,
        # Key spelling kept: it is the CSV column name downstream.
        "False Positivies": len(detected) - correct,
        "Missed": len(actual) - correct,
        "Actual Mean Inter Beat Interval": actual_mIBI,
        "Detected Mean Inter Beat Interval": matched_mIBI,
        "Actual Standard Deviation of Intervals": actual_SDNN,
        "Detected Standard Deviation of Intervals": matched_SDNN,
        "Actual pNN50": actual_pNN50,
        'Detected pNN50': matched_pNN50,
        'Actual RMSSD': actual_RMSSD,
        'Detected RMSSD': matched_RMSSD,
        'Sensitivity': correct / len(actual) if len(actual) else np.nan,
        'PPV': correct / len(detected) if len(detected) else np.nan,
    }

    return metrics
--------------------------------------------------------------------------------