├── .gitignore ├── README.md ├── dataloader.py ├── figure ├── architecture.png ├── incresblock.png └── inference.gif ├── infer.py ├── model.py ├── preprocess.py ├── requirements.txt ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | #data 133 | saved_data/ 134 | runs/ 135 | best_model/ 136 | results/ 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeismoNet 2 | 3 | This repository contains code used for the paper "End-to-End Deep Learning for Reliable Cardiac Activity Monitoring using Seismocardiograms" which has been accepted for presentation at the [19th International Conference on Machine Learning and Applications](https://www.icmla-conference.org/icmla20/index.html), Boca Raton, FL, USA. 4 | 5 | SeismoNet is a Deep Convolutional Neural Network which aims to provide an end-to-end solution to robustly observe heart activity from Seismocardiogram (SCG) signals. These SCG signals are motion-based and can be acquired in an easy, user-friendly fashion. SeismoNet transforms the SCG signal into an interpretable waveform consisting of relevant information which allows for extraction of heart rate indices. 6 | 7 | Preprint available at [arxiv](https://arxiv.org/abs/2010.05662) :newspaper: 8 | 9 | ## Getting Started :rocket: 10 | 11 | * [preprocess.py](preprocess.py) Preprocesses the CEBS dataset available at [physionet](https://physionet.org/content/cebsdb/1.0.0/) 12 | * [trainer.py](trainer.py) Helps train the model. 13 | * [infer.py](infer.py) Helps take inference on any input SCG signal. 14 | * [utils.py](utils.py) This file consists of all the helper functions. 15 | * [model.py](model.py) SeismoNet architecture in torch 16 | 17 | ## Model Architecture 18 | 19 | 20 | 21 | ## Usage 22 | 23 | Install all dependencies with: 24 | ```bash 25 | $ pip install -r requirements.txt 26 | ``` 27 | Download datasets with: 28 | ```bash 29 | $ wget -r -N -c -np https://physionet.org/files/cebsdb/1.0.0/ 30 | ``` 31 | Preprocess raw data: 32 | ```bash 33 | $ python preprocess.py --data_path /path/to/data 34 | ``` 35 | Train SeismoNet using preprocessed data: 36 | ```bash 37 | $ python trainer.py --data_path /path/to/preprocessed/data 38 | ``` 39 | 40 | Take inference and evaluate model: 41 | ```bash 42 | $ python infer.py --best_model /path/to/model --data_path /path/to/preprocessed/data --evaluate 43 | ``` 44 | ## Inference 45 | 46 | 47 | ## Authors :mortar_board: 48 | 49 | [Prithvi Suresh](https://github.com/prithusuresh/), [Naveen Narayanan](https://github.com/naveenggmu/), Pranav CV, Vineeth Vijayaraghavan 50 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import Dataset 7 | from glob import glob 8 | 9 | class CEBSDataset(Dataset): 10 | def __init__(self, data_path, ecg_channel = 1): 11 | 12 | 13 | 14 | self.data_path = data_path #---"saved_data/b" 15 | self.input = [] 16 | self.ground = [] 17 | 18 | gt_file_suffix = "groundTruth{}_".format(ecg_channel) 19 | p_files = sorted(glob(os.path.join(data_path, "preprocessed_data","inputSig_*.pt"))) 20 | 21 | for inp_file in p_files: 22 | p_no = inp_file.split(".")[-2].split("_")[-1] 23 | self.input.append(torch.load(inp_file)) 24 | gt_file_name = '/'.join(inp_file.split("/")[:-1]) +"/"+ gt_file_suffix + str(p_no) + ".pt" 25 | self.ground.append(torch.load(gt_file_name)) 26 | 27 | self.input = torch.cat(self.input).type(torch.float) 28 | self.ground = torch.cat(self.ground).type(torch.float) 29 | 30 | def __len__(self): 31 | return len(self.mer_input) 32 | 33 | def __getitem__(self,idx): 34 | 35 | label = self.ground[idx] 36 | input_tensor = self.input[idx] 37 | return input_tensor,label -------------------------------------------------------------------------------- /figure/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/architecture.png -------------------------------------------------------------------------------- /figure/incresblock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/incresblock.png -------------------------------------------------------------------------------- /figure/inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prithusuresh/SeismoNet/8850f35a4d1d9db520d1e38b58347544c0daa012/figure/inference.gif -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | '''infer''' 2 | import os 3 | import sys 4 | import signal 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import Dataset, DataLoader, TensorDataset 10 | from tqdm import tqdm 11 | from torch import optim 12 | import torch.nn.functional as F 13 | from glob import glob 14 | import warnings 15 | import matplotlib.pyplot as plt 16 | 17 | from sklearn.model_selection import train_test_split 18 | import argparse 19 | 20 | from model import SeismoNet 21 | from utils import * 22 | 23 | def main(args): 24 | 25 | 26 | if args.create_test_file: 27 | print ("Creating Test File... ") 28 | data_path = os.path.join(args.data_path, args.file_type) 29 | __, _, test_loader = create_loaders(data_path, "data.pt","labels.pt") 30 | 31 | else: 32 | print ("Loading Test File... ") 33 | test_tensor = torch.load(args.test_tensor_file) 34 | test_dataset = TensorDataset(test_tensor) 35 | test_loader = DataLoader(test_dataset, batch_size = 1, pin_memory = True) 36 | 37 | print ("Loading Model... ") 38 | model = SeismoNet(get_shape(test_loader)) 39 | model.load_state_dict(torch.load(args.best_model)["model"]) 40 | window_info = [] 41 | metrics = [] 42 | if not(os.path.exists("results/")): 43 | os.mkdir("results/") 44 | 45 | for i,x in enumerate(test_loader): 46 | 47 | if len(x) > 1: 48 | pred_distance_transform, pred_peak_locations = infer(model, x[0] ,downsampling_factor = 1) 49 | print (pred_distance_transform, pred_peak_locations) 50 | 51 | else: 52 | pred_distance_transform, peak_locations = infer(model, x, downsampling_factor = 1) 53 | 54 | if args.evaluate: 55 | assert len(x)>1 56 | actual_peak_locations = np.where(x[1] == 0.0)[0] #provide actual rpeak locations as array 57 | metrics.append(evaluate_window(actual_peak_locations, pred_peak_locations)) 58 | 59 | if args.save_figures: 60 | if not(os.path.exits("results/figures")): 61 | os.mkdir("results/figures") 62 | plt.figure(figsize = [10,5]) 63 | plt.subplot(1,2,1) 64 | plt.plot(x[0].cpu().numpy().flatten()) 65 | plt.subplot(1,2,2) 66 | plt.plot(pred_distance_transform.flatten()) 67 | plt.plot(x[1].cpu().numpy().flatten()) 68 | plt.scatter(pred_peak_locations, pred_distance_transform.flatten()[pred_peak_locations]) 69 | plt.savefig("results/figures/{}.png".format(i+1)) 70 | 71 | window_info.append(pred_peak_locations) 72 | metrics = pd.DataFrame(metrics) 73 | metrics.to_csv("results/results.csv") 74 | 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--create_test_file", action ="store_true", help = "Create test file if not already present") 80 | parser.add_argument('--test_tensor_file',nargs="?" , help = 'Path to saved files directory') 81 | parser.add_argument('--data_path',nargs="?", const = "saved_data/", default = "saved_data/", help = 'Path to saved files directory') 82 | parser.add_argument('--file_type',nargs="?", const = "b", default = "b", help = "file type") 83 | parser.add_argument('--best_model',nargs="?", const = "best_model/best_model_pretrained.pt", default = "best_model/best_model_pretrained.pt", help = "Best Model File") 84 | parser.add_argument('--evaluate', action = "store_true", help = "Compare against label or not") 85 | parser.add_argument('--save_figures', action = "store_true", help = "save figure along with results") 86 | args = parser.parse_args() 87 | 88 | main(args) 89 | 90 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | class IncBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, size = 15, stride = 1, padding = 7): 8 | super(IncBlock,self).__init__() 9 | 10 | self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias = False) 11 | 12 | self.conv1 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = size, stride = stride, padding = padding ), 13 | nn.BatchNorm1d(out_channels//4)) 14 | 15 | self.conv2 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False), 16 | nn.BatchNorm1d(out_channels//4), 17 | nn.LeakyReLU(0.2), 18 | nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size +2 , stride = stride, padding = padding + 1), 19 | nn.BatchNorm1d(out_channels//4)) 20 | 21 | self.conv3 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False), 22 | nn.BatchNorm1d(out_channels//4), 23 | nn.LeakyReLU(0.2), 24 | nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size + 4 , stride = stride, padding = padding + 2), 25 | nn.BatchNorm1d(out_channels//4)) 26 | 27 | self.conv4 = nn.Sequential(nn.Conv1d(in_channels, out_channels//4, kernel_size = 1, bias = False), 28 | nn.BatchNorm1d(out_channels//4), 29 | nn.LeakyReLU(0.2), 30 | nn.Conv1d(out_channels//4, out_channels//4, kernel_size = size + 6 , stride = stride, padding = padding + 3), 31 | nn.BatchNorm1d(out_channels//4)) 32 | 33 | self.relu = nn.ReLU() 34 | 35 | def forward(self,x): 36 | 37 | res = self.conv1x1(x) 38 | 39 | c1 = self.conv1(x) 40 | 41 | c2 = self.conv2(x) 42 | 43 | c3 = self.conv3(x) 44 | 45 | c4 = self.conv4(x) 46 | 47 | concat = torch.cat((c1,c2,c3,c4),dim = 1) 48 | 49 | concat+=res 50 | 51 | return self.relu(concat) 52 | 53 | class AveragingBlock(nn.Module): 54 | 55 | def __init__(self,in_channels = 1, out_channels = 1): 56 | 57 | super(AveragingBlock, self).__init__() 58 | 59 | self.conv1 = nn.Conv1d(in_channels,8,3) 60 | self.bn1 = nn.BatchNorm1d(8) 61 | 62 | self.conv2 = nn.Conv1d(8,16,3) 63 | self.bn2 =nn.BatchNorm1d(16) 64 | 65 | self.conv3 = nn.Conv2d(1,1,(3,3), 2) 66 | self.bn3 = nn.BatchNorm2d(1) 67 | 68 | self.conv4 = nn.Conv2d(1, 1, (3,15), padding = (0,7)) 69 | self.bn4 = nn.BatchNorm2d(1) 70 | 71 | self.conv5 = nn.Conv1d(1,out_channels,3, padding = 1) 72 | self.bn5 = nn.BatchNorm1d(out_channels) 73 | 74 | self.relu1 = nn.LeakyReLU(0.2) 75 | 76 | self.mp1 = nn.MaxPool1d(2) 77 | self.mp2 = nn.MaxPool2d((2,2)) 78 | 79 | def forward(self, x): 80 | 81 | x = self.relu1(self.bn1(self.conv1(x))) 82 | 83 | x = self.relu1(self.bn2(self.conv2(x))) 84 | 85 | x = x.view(x.shape[0],1,x.shape[1],x.shape[2]) 86 | 87 | x = self.relu1(self.bn3(self.conv3(x))) 88 | 89 | x = self.mp2(x) 90 | 91 | x = self.relu1(self.bn4(self.conv4(x))) 92 | 93 | x = torch.squeeze(x, dim = 1) 94 | 95 | x = self.relu1(self.bn5(self.conv5(x))) 96 | 97 | return x 98 | 99 | class SeismoNet(nn.Module): 100 | def __init__(self, shape): 101 | super(SeismoNet, self).__init__() 102 | in_channels = 1 103 | self.cea = nn.Sequential(AveragingBlock()) 104 | 105 | self.en1 = nn.Sequential(nn.Conv1d(in_channels, 32, 3, padding = 1), 106 | nn.BatchNorm1d(32), 107 | nn.LeakyReLU(0.2), 108 | nn.Conv1d(32, 32, 5, stride = 2, padding = 2), 109 | IncBlock(32,32)) 110 | 111 | self.en2 = nn.Sequential(nn.Conv1d(32, 64, 3, padding = 1), 112 | nn.BatchNorm1d(64), 113 | nn.LeakyReLU(0.2), 114 | nn.Conv1d(64, 64, 5, stride = 2, padding = 2), 115 | IncBlock(64,64)) 116 | 117 | self.en3 = nn.Sequential(nn.Conv1d(64,128, 3, padding = 1), 118 | nn.BatchNorm1d(128), 119 | nn.LeakyReLU(0.2), 120 | nn.Conv1d(128, 128, 3, stride = 2, padding = 1), 121 | IncBlock(128,128)) 122 | 123 | self.en4 = nn.Sequential(nn.Conv1d(128,256, 3,padding = 1), 124 | nn.BatchNorm1d(256), 125 | nn.LeakyReLU(0.2), 126 | nn.Conv1d(256, 256, 5, stride = 2, padding = 1), 127 | IncBlock(256,256)) 128 | 129 | self.en5 = nn.Sequential(nn.Conv1d(256,512, 3, padding = 1), 130 | nn.BatchNorm1d(512), 131 | nn.LeakyReLU(0.2), 132 | IncBlock(512,512)) 133 | 134 | 135 | self.de1 = nn.Sequential(nn.ConvTranspose1d(512,256,1), 136 | nn.BatchNorm1d(256), 137 | nn.LeakyReLU(0.2), 138 | IncBlock(256,256)) 139 | 140 | self.de2 = nn.Sequential(nn.Conv1d(512,256,3, padding = 1), 141 | nn.BatchNorm1d(256), 142 | nn.LeakyReLU(0.2), 143 | nn.ConvTranspose1d(256,128,3, stride = 2), 144 | IncBlock(128,128)) 145 | 146 | self.de3 = nn.Sequential(nn.Conv1d(256,128,3, stride = 1, padding = 1), 147 | nn.BatchNorm1d(128), 148 | nn.LeakyReLU(0.2), 149 | nn.ConvTranspose1d(128,64,3, stride = 2), 150 | IncBlock(64,64)) 151 | 152 | self.de4 = nn.Sequential(nn.Conv1d(128,64,3, stride = 1, padding = 1), 153 | nn.BatchNorm1d(64), 154 | nn.LeakyReLU(0.2), 155 | nn.ConvTranspose1d(64,32,3, stride = 2), 156 | IncBlock(32,32)) 157 | 158 | self.de5 = nn.Sequential(nn.Conv1d(64,32,3, stride = 1, padding = 1), 159 | nn.BatchNorm1d(32), 160 | nn.LeakyReLU(0.2), 161 | nn.ConvTranspose1d(32,16,3, stride = 2), 162 | IncBlock(16,16)) 163 | 164 | self.de6 = nn.Sequential(nn.ConvTranspose1d(16,8,2,stride =2), 165 | nn.BatchNorm1d(8), 166 | nn.LeakyReLU(0.2)) 167 | 168 | self.de7 = nn.Sequential(nn.ConvTranspose1d(8,4,2,stride =2), 169 | nn.BatchNorm1d(4), 170 | nn.LeakyReLU(0.2)) 171 | 172 | self.de8 = nn.Sequential(nn.ConvTranspose1d(4,2,1,stride =1), 173 | nn.BatchNorm1d(2), 174 | nn.LeakyReLU(0.2)) 175 | 176 | self.de9 = nn.Sequential(nn.ConvTranspose1d(2,1,1,stride =1), 177 | nn.BatchNorm1d(1), 178 | nn.LeakyReLU(0.2)) 179 | 180 | 181 | def forward(self,x): 182 | 183 | x = self.cea(x) #-Convolutional Ensemble Averaging-- 184 | 185 | x = nn.ConstantPad1d((1,1),0)(x) 186 | 187 | e1 = self.en1(x) #----------------------------------- 188 | e2 = self.en2(e1) #----------------------------------- 189 | e3 = self.en3(e2) #---------Contracting Path---------- 190 | e4 = self.en4(e3) #----------------------------------- 191 | e5 = self.en5(e4) #----------------------------------- 192 | 193 | d1 = self.de1(e5) #----------------------------------- 194 | cat = torch.cat([d1,e4],1) #----------------------------------- 195 | d2 = self.de2(cat) #----------------------------------- 196 | cat = torch.cat([d2,e3],1) #----------------------------------- 197 | d3 = self.de3(cat) #----------Expanding Path----------- 198 | cat = torch.cat([d3[:,:,:-2],e2],1) #----------------------------------- 199 | d4 = self.de4(cat) #----------------------------------- 200 | cat = torch.cat([d4[:,:,:-1],e1],1) #----------------------------------- 201 | d5 = self.de5(cat)[:,:,:-1] #----------------------------------- 202 | d6 = self.de6(d5) #----------------------------------- 203 | 204 | d7 = self.de7(d6) #----------------------------------- 205 | d8 = self.de8(d7) #---------Denoising Block----------- 206 | d9 = self.de9(d8) #----------------------------------- 207 | 208 | return d9 209 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import wfdb 5 | import torch 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 9 | import pickle 10 | import argparse 11 | from glob import glob 12 | 13 | from utils import * 14 | from dataloader import CEBSDataset 15 | 16 | 17 | def main(args): 18 | 19 | file_type = args.file_type 20 | data_path = args.data_path 21 | 22 | if not(os.path.exists("saved_data")): 23 | print ("Creating Saved Data Path") 24 | os.mkdir("saved_data") 25 | if not(os.path.exists("saved_data/{}".format(file_type))): 26 | os.mkdir("saved_data/{}".format(file_type)) 27 | if not(os.path.exists("saved_data/{}/pickle_files".format(file_type))): 28 | os.mkdir("saved_data/{}/pickle_files".format(file_type)) 29 | 30 | files = sorted(glob(os.path.join(data_path,file_type+"*.dat"))) 31 | 32 | for i in tqdm(files, total = len(files)): 33 | 34 | i = i.rstrip(".dat") 35 | [x,info] = wfdb.rdsamp(i) 36 | ann = wfdb.io.rdann(i,'atr') 37 | all_peaks = ann.sample 38 | 39 | subjectWise_dict ={"rpeak1": all_peaks[::2], 40 | "rpeak2": all_peaks[1::2], 41 | "resp": x[:,2].flatten(), 42 | "scg": x[:,3].flatten(), 43 | "ecg1":x[:,0].flatten(), 44 | "ecg2":x[:,1].flatten(), 45 | } 46 | with open("saved_data/{}/pickle_files/{}.pkl".format(file_type,i.split("/")[-1]), "wb") as f: 47 | pickle.dump(subjectWise_dict,f) 48 | 49 | wlen = args.wlen 50 | overlap = args.overlap 51 | fs = args.fs 52 | 53 | 54 | generator = generateSignals(subjectWise_dict, fs, wlen, overlap) 55 | 56 | scgSig = [] 57 | ecg1Sig = [] 58 | ecg2Sig = [] 59 | 60 | groundTruth1 = [] 61 | groundTruth2 = [] 62 | for scg,ecg1,rpeak1,ecg2,rpeak2 in generator: 63 | if ecg1.shape[0] != wlen*fs or ecg2.shape[0] != wlen*fs or scg.shape[0] != wlen*fs or rpeak1 is None or rpeak2 is None: 64 | continue 65 | transform1 = distanceTransform(ecg1, rpeak1) 66 | 67 | transform2 = distanceTransform(ecg2, rpeak2) 68 | 69 | 70 | scgSig.append(scg.reshape((1,-1))) 71 | 72 | ecg1Sig.append(ecg1.reshape((1,-1))) 73 | ecg2Sig.append(ecg2.reshape((1,-1))) 74 | 75 | 76 | groundTruth1.append(transform1.reshape((1,-1))) 77 | groundTruth2.append(transform2.reshape((1,-1))) 78 | 79 | inputSig_t = torch.tensor(scgSig).type(torch.float) 80 | ecg1Sig_t = torch.tensor(ecg1Sig).type(torch.float) 81 | ecg2Sig_t = torch.tensor(ecg2Sig).type(torch.float) 82 | 83 | 84 | ecg12Sig_t = torch.cat((ecg1Sig_t, ecg2Sig_t),1) 85 | 86 | groundTruth1_t = torch.tensor(groundTruth1).type(torch.float) 87 | groundTruth2_t = torch.tensor(groundTruth2).type(torch.float) 88 | saving_path = 'saved_data/{}/preprocessed_data/'.format(file_type) 89 | 90 | if not(os.path.exists(saving_path)): 91 | os.mkdir(saving_path) 92 | 93 | 94 | p_no = int(i.split("/")[2].split(".")[0].lstrip(file_type)) 95 | torch.save(inputSig_t, saving_path+"inputSig_{}.pt".format(p_no)) 96 | torch.save(groundTruth1_t, saving_path+"groundTruth1_{}.pt".format(p_no)) 97 | torch.save(groundTruth2_t, saving_path+"groundTruth2_{}.pt".format(p_no)) 98 | torch.save(ecg12Sig_t,saving_path+"ecg12_{}.pt".format(p_no)) 99 | 100 | print("--Saving Data--") 101 | data = CEBSDataset(os.path.join("saved_data/", file_type)) 102 | torch.save(data.input, os.path.join("saved_data/", file_type, "data.pt")) 103 | torch.save(data.ground, os.path.join("saved_data/", file_type, "labels.pt")) 104 | 105 | 106 | 107 | 108 | 109 | if __name__ =="__main__": 110 | 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--file_type', nargs='?',type = str, default= "b", help = 'm, p or b') 113 | parser.add_argument('--data_path', nargs = '?', type = str, default = "../files/", help= "path to data files") 114 | parser.add_argument('--wlen', nargs = '?', type = int, default = 10, help= "window length in seconds") 115 | parser.add_argument('--overlap', nargs = '?', type = int, default = 5, help= "overlap length in seconds") 116 | parser.add_argument('--fs', nargs = '?', type = int, default = 5000, help= "sampling frequency") 117 | 118 | 119 | args = parser.parse_args() 120 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch===1.6.0 -f https://download.pytorch.org/whl/torch_stable.html 2 | torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html 3 | numpy==1.19.2 4 | pandas==1.1.3 5 | scipy==1.5.2 6 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | '''trainer''' 2 | 3 | import os 4 | import sys 5 | import signal 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset, DataLoader, TensorDataset 11 | from tqdm import tqdm 12 | from torch import optim 13 | import torch.nn.functional as F 14 | from torch.utils.tensorboard import SummaryWriter 15 | from glob import glob 16 | import warnings 17 | 18 | from sklearn.model_selection import train_test_split 19 | import argparse 20 | 21 | from model import SeismoNet 22 | from dataloader import CEBSDataset 23 | from utils import * 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | def dump_and_exit(signalnumber, frame): 28 | if not(os.path.exists("best_model")): 29 | os.mkdir("best_model") 30 | torch.save(model_state, "best_model/best_model_on_SIGINT.pt") 31 | sys.exit(0) 32 | 33 | def main(args): 34 | 35 | 36 | global model_state 37 | 38 | test_size = float(args.test_size) 39 | val_size = float(args.val_size) 40 | data_path = os.path.join(args.data_path, args.file_type) 41 | lr = float(args.lr) 42 | train_batch_size = int(args.train_batch_size) 43 | val_batch_size = int(args.val_batch_size) 44 | epochs = args.epochs 45 | typ = args.file_type 46 | 47 | print ("Training SeismoNet on CEBS") 48 | 49 | 50 | train_loader, val_loader, test_loader = create_loaders(data_path, "data.pt","labels.pt", test_size, val_size, train_batch_size, val_batch_size) 51 | writer = SummaryWriter() 52 | 53 | model= SeismoNet(get_shape(train_loader)).cuda() 54 | 55 | 56 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 57 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[100,200], gamma=0.1) 58 | 59 | criterion = nn.SmoothL1Loss() 60 | 61 | best_loss = 1000 62 | best_accuracy = 0 63 | if not(os.path.exists("best_model/")): 64 | os.mkdir("best_model/") 65 | 66 | for epoch in range(int(epochs)): 67 | 68 | model.train() 69 | print('epochs {}/{} '.format(epoch+1,epochs)) 70 | running_loss = 0.0 71 | running_loss_v = 0.0 72 | correct = 0 73 | correct_v = 0 74 | for idx, (inputs,labels) in tqdm(enumerate(train_loader), total = len(train_loader)): 75 | 76 | inputs = inputs.cuda() 77 | labels = labels.cuda() 78 | 79 | optimizer.zero_grad() 80 | 81 | y_pred= model(inputs) 82 | 83 | 84 | loss = criterion(y_pred,labels) 85 | running_loss += loss 86 | loss.backward() 87 | optimizer.step() 88 | 89 | 90 | scheduler.step() 91 | model.eval() 92 | with torch.no_grad(): 93 | for idx,(inputs_v,labels_v) in tqdm(enumerate(val_loader),total=len(val_loader)): 94 | 95 | inputs_v = inputs_v.cuda() 96 | labels_v = labels_v.cuda() 97 | y_pred_v = model(inputs_v) 98 | loss_v = criterion(y_pred_v,labels_v) 99 | 100 | running_loss_v += loss_v 101 | 102 | 103 | val_loss = running_loss_v/len(val_loader) 104 | model_state = { 105 | 'epoch': epoch, 106 | 'model': model.state_dict(), 107 | 'optimizer': optimizer.state_dict(), 108 | 'val_loss': val_loss 109 | } 110 | 111 | if (val_loss <= best_loss): 112 | best_loss = running_loss_v/len(val_loader) 113 | out = torch.save(model_state, f='best_model/best_model.pt') 114 | 115 | print('train loss: {:.4f} val loss : {:.4f}'.format(running_loss/len(train_loader), running_loss_v/len(val_loader))) 116 | writer.add_scalar("Loss/train_loss",running_loss/len(train_loader), epoch ) 117 | writer.add_scalar("Loss/val_loss",running_loss_v/len(val_loader), epoch ) 118 | 119 | 120 | writer.close() 121 | 122 | print ("Completed") 123 | torch.save(model_state, f='best_model/best_model_training_completed.pt') 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--data_path',nargs="?", const = "saved_data/", default = "saved_data/", help = 'Path to saved files directory') 128 | parser.add_argument('--file_type',nargs="?", const = "b", default = "b", help = "file type") 129 | parser.add_argument('--test_size', nargs='?',const = 0.2,default = 0.2, help = 'Size of Test Set (float)') 130 | parser.add_argument('--val_size',nargs='?', const = 0.2,default = 0.2, help = 'Size of Validation Set (float)') 131 | parser.add_argument('--train_batch_size',nargs='?', const = 32,default = 32, help = 'Batch Size of Train Loader') 132 | parser.add_argument('--val_batch_size',nargs='?', const = 32,default = 32, help = 'Batch Size of Validation Loader') 133 | parser.add_argument('--epochs',nargs='?', const = 300,default = 300, help = 'Number of Epochs') 134 | parser.add_argument("--lr",nargs = "?",const = 0.001,default = 0.001, help = 'Learning Rate') 135 | signal.signal(signal.SIGINT, dump_and_exit) 136 | args = parser.parse_args() 137 | main(args) 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 5 | from sklearn.model_selection import train_test_split 6 | from torch.utils.data import Dataset, DataLoader, TensorDataset 7 | from scipy.signal import find_peaks 8 | import torch 9 | 10 | 11 | def generateSignals(data,fs = 5000, wlen = 10, overlap = 5): 12 | wlen = wlen*fs 13 | overlap = (overlap*fs)/wlen 14 | totalLength = len(data["scg"]) 15 | 16 | for start in range(0, totalLength, int((1-overlap)*wlen)): 17 | yield data["scg"][start:start+wlen], data["ecg1"][start:start+wlen], data["rpeak1"][(data["rpeak1"] >=start) & (data["rpeak1"] <=start + wlen )] - start, data["ecg2"][start:start+wlen], data["rpeak2"][(data["rpeak2"] >=start) & (data["rpeak2"] <=start + wlen )] - start 18 | 19 | def distanceTransform(signal, rpeaks): 20 | length = len(signal) 21 | transform = [] 22 | lower = rpeaks[0] 23 | for j in range(0, lower): 24 | transform.append(abs(lower - j)) 25 | for i in range(1,len(rpeaks)): 26 | upper = rpeaks[i] 27 | lower = rpeaks[i-1] 28 | middle = (upper + lower)/2 29 | for k in range(lower, upper): 30 | transform.append(abs(k-lower)) if k < middle else transform.append(abs(k-upper)) 31 | for i in range(upper,length): 32 | transform.append(abs(i-upper)) 33 | transform = np.array(transform) 34 | from sklearn.preprocessing import MinMaxScaler 35 | scaler = MinMaxScaler() 36 | scaledTransform = scaler.fit_transform(transform.reshape((-1,1))) 37 | 38 | return scaledTransform 39 | 40 | def create_loaders(data_path, inp_file = "data.pt", label_file = "labels.pt", test_size = 0.2, val_size = 0.2, train_batch_size = 64, val_batch_size = 64): 41 | data = torch.load(os.path.join(data_path,inp_file)) 42 | target = torch.load(os.path.join(data_path,label_file)) 43 | x_train, x_val, y_train, y_val = train_test_split(data, target, random_state = 42, test_size = val_size + test_size) 44 | x_val,x_test, y_val,y_test = train_test_split(x_val,y_val, random_state = 32, test_size = (test_size/(test_size + val_size))) 45 | train, val, test = TensorDataset(x_train, y_train), TensorDataset(x_val, y_val), TensorDataset(x_test, y_test) 46 | 47 | train_loader = DataLoader(train, batch_size=train_batch_size, shuffle =False, num_workers = 4, pin_memory = True) 48 | val_loader = DataLoader(val, batch_size = val_batch_size, shuffle = False,num_workers = 4, pin_memory = True) 49 | test_loader = DataLoader(test, batch_size = 1 , shuffle = False) 50 | 51 | return train_loader, val_loader, test_loader 52 | 53 | def get_shape(loader): 54 | for x,y in loader: 55 | return x.shape 56 | 57 | 58 | def infer(model, inp, prominence = 0.3, distance = 625,smoothen = True, downsampling_factor = 10): 59 | model.cuda() 60 | model.eval() 61 | inp = inp[:,0,:].view(1, 1, inp.shape[-1]).cuda() 62 | with torch.no_grad(): 63 | pred = model(inp) 64 | if smoothen: 65 | out=smooth(pred.cpu().detach().view(pred.shape[-1]).numpy()) 66 | else: 67 | out = pred.cpu().detach().view(pred.shape[-1]).numpy() 68 | if (downsampling_factor!=1): 69 | downsampled = out.flatten()[0::downsampling_factor] 70 | else: 71 | downsampled = out.flatten() 72 | valley_loc_downsampled,_ = getValleys(downsampled, prominence = prominence,distance = max(1,distance//downsampling_factor)) 73 | return out,valley_loc_downsampled*downsampling_factor 74 | 75 | def getValleys(signal, prominence, distance ): 76 | signal = signal*-1 77 | valley_loc, _ = find_peaks(signal, prominence = prominence,distance = distance) 78 | return valley_loc,_ 79 | 80 | def smooth(signal,window_len=50): 81 | y = pd.DataFrame(signal).rolling(window_len,center = True, min_periods = 1).mean().values.reshape((-1,)) 82 | return y 83 | 84 | def evaluate_window(actual, detected, fs = 5000, tolerance = 75): 85 | 86 | tolerance = (tolerance/1000)*fs 87 | grouped_missed = [] 88 | FP= 0 89 | matched_beats = [] 90 | correct = 0 91 | for correctPeak in actual: 92 | matched = detected[np.where(abs(correctPeak - detected) < tolerance)[0]] 93 | try: 94 | assert len(matched) == 1 95 | correct+=1 96 | matched_beats.append(matched[0]) 97 | except AssertionError: 98 | if len(matched) > 1: 99 | FP+= len(matched) - 1 100 | else: 101 | matched = [np.NaN] 102 | matched_beats.append(np.NaN) 103 | temp = np.asarray([correctPeak, matched[0]]) 104 | grouped_missed.append(temp) 105 | 106 | grouped_missed = np.asarray(grouped_missed) 107 | matched_beats = np.asarray(matched_beats) 108 | matched_interbeat_intervals = np.diff(matched_beats) 109 | matched_interbeat_intervals = matched_interbeat_intervals[~np.isnan(matched_interbeat_intervals)] 110 | matched_IBI_SD = np.diff(matched_interbeat_intervals*1000/fs) 111 | matched_RMSSD = rms = np.sqrt(np.mean(matched_IBI_SD**2)) 112 | matched_NN50 = len(np.where(matched_IBI_SD>50)[0]) 113 | matched_pNN50 = matched_NN50/ len(matched_interbeat_intervals) 114 | matched_mIBI = matched_interbeat_intervals.mean()*1000/fs 115 | matched_SDNN = matched_interbeat_intervals.std()*1000/fs 116 | actual_interbeat_intervals = np.diff(actual) 117 | actual_IBI_SD = np.diff(actual_interbeat_intervals*1000/fs) 118 | actual_RMSSD = rms = np.sqrt(np.mean(actual_IBI_SD**2)) 119 | actual_NN50 = len(np.where(actual_IBI_SD>50)[0]) 120 | actual_pNN50 = actual_NN50/ len(actual_interbeat_intervals) 121 | actual_mIBI = actual_interbeat_intervals.mean()*1000/fs 122 | actual_SDNN = actual_interbeat_intervals.std()*1000/fs 123 | 124 | metrics = { 125 | "Total Positives": len(actual), 126 | "Total Detected": len(detected), 127 | "True Positives": correct, 128 | "False Positivies": len(detected) - correct, 129 | "Missed": len(actual) - correct, 130 | "Actual Mean Inter Beat Interval" : actual_mIBI, 131 | "Detected Mean Inter Beat Interval": matched_mIBI, 132 | "Actual Standard Deviation of Intervals": actual_SDNN, 133 | "Detected Standard Deviation of Intervals": matched_SDNN, 134 | "Actual pNN50" : actual_pNN50, 135 | 'Detected pNN50': matched_pNN50, 136 | 'Actual RMSSD' : actual_RMSSD, 137 | 'Detected RMSSD': matched_RMSSD, 138 | 'Sensitivity' : correct/(correct + (len(actual) - correct)), 139 | 'PPV' : correct/(correct + (len(detected) - correct)) 140 | } 141 | 142 | return metrics --------------------------------------------------------------------------------