├── .ipynb_checkpoints ├── Inference-checkpoint.ipynb └── Model Performance Analysis-checkpoint.ipynb ├── ECG.py ├── Figure1.PNG ├── Inference.ipynb ├── Model Performance Analysis.ipynb ├── PreOpNet MACE.pt ├── PreOpNet Mort.pt ├── PreOpNet Plus Clinical Features ├── .ipynb_checkpoints │ └── Inference-checkpoint.ipynb ├── ECG.py ├── Inference.ipynb ├── Process Data.ipynb ├── __pycache__ │ ├── ECG.cpython-36.pyc │ └── models.cpython-36.pyc ├── best_roc_model_MACE_with_RCRI_Features.pt ├── best_roc_model_Mortality_with_RCRI_Features.pt └── models.py ├── Process Data.ipynb ├── README.md ├── license.txt └── models.py /.ipynb_checkpoints/Inference-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from sklearn.metrics import roc_curve, auc\n", 11 | "import models\n", 12 | "import torch\n", 13 | "\n", 14 | "from tqdm import tqdm\n", 15 | "import torch.nn.functional as F\n", 16 | "import os\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay\n", 19 | "from pthflops import count_ops\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "from ECG import ECG_loader\n", 23 | "def get_model(depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs=0):\n", 24 | " model = models.EffNet(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion,num_additional_features=additional_inputs)\n", 25 | " print('parameters: ' +str(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n", 26 | " model.eval()\n", 27 | " return model\n", 28 | "\n", 29 | "def point(y_total,yhat_total,t):\n", 30 | " specificity = []\n", 31 | " sensitivity = []\n", 32 | " for i in tqdm(t):\n", 33 | " tn, fp, fn, tp = 
confusion_matrix(y_total, yhat_total>i).ravel()\n", 34 | " specificity.append( tn / (tn+fp) )\n", 35 | " sensitivity.append( tp / (tp+fn) )\n", 36 | " return t[(np.array(specificity) + np.array(sensitivity) - 1).argmax()]\n", 37 | "def thresholded_output_transform(yhat,y):\n", 38 | " y_pred, y = yhat,y\n", 39 | " y_pred = torch.sigmoid(y_pred)\n", 40 | " return y_pred, y\n", 41 | "def produce_df(test_loader, checkpoint,folder = 'Test Results', depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs = 0,ds = None):\n", 42 | " \n", 43 | " # load model and checkpoint\n", 44 | " model = get_model(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion, additional_inputs=additional_inputs)\n", 45 | " checkpoint = torch.load(checkpoint)\n", 46 | " model.load_state_dict(checkpoint)\n", 47 | " model.to('cpu')\n", 48 | " yhat = torch.Tensor()\n", 49 | " \n", 50 | " # Calculate FLOPs\n", 51 | " if not ds is None:\n", 52 | " dl=torch.utils.data.DataLoader(ds, batch_size=1,num_workers=10,drop_last=False)\n", 53 | " for x,y in dl:\n", 54 | " count_ops(model,x)\n", 55 | " break\n", 56 | "\n", 57 | " # produce tensors\n", 58 | " y_total = torch.Tensor()\n", 59 | " yhat_total = torch.Tensor()\n", 60 | " for x,y in tqdm(test_loader):\n", 61 | " yhat = model(x)\n", 62 | " yhat,y = thresholded_output_transform(yhat,y)\n", 63 | " y_total = torch.cat((y_total,y),0)\n", 64 | " yhat_total = torch.cat((yhat_total,yhat.detach().cpu()),0)\n", 65 | " \n", 66 | " # Save predictions\n", 67 | " df = pd.DataFrame({'labels':y_total.flatten().tolist(),'Prediction':yhat_total.flatten().tolist()})\n", 68 | " df.to_csv(\"Predictions.csv\",index=False)\n", 69 | " # Produce ROC and CM\n", 70 | " fpr,tpr, t = roc_curve(y_total,yhat_total)\n", 71 | " thresh = point(y_total,yhat_total,t)\n", 72 | " lw = 2\n", 73 | " cm(y_total.flatten(),yhat_total.flatten(),folder=folder,threshold = thresh)\n", 74 | " 
bootstrap(df)\n", 75 | " return yhat_total, y_total, fpr, tpr, t\n", 76 | "def bootstrap(df):\n", 77 | " y_total,yhat_total = df['labels'],df['Prediction']\n", 78 | " fpr_boot = []\n", 79 | " tpr_boot = []\n", 80 | " aucs = []\n", 81 | " \n", 82 | " # bootstrap for confidence interval\n", 83 | " for i in tqdm(range(0,10000)):\n", 84 | " choices = np.random.choice(range(0,len(yhat_total)),int(len(yhat_total)/2))\n", 85 | " fpr,tpr, _ = roc_curve(y_total[choices],yhat_total[choices])\n", 86 | " fpr_boot.append(fpr)\n", 87 | " tpr_boot.append(tpr)\n", 88 | " aucs.append(auc(fpr,tpr))\n", 89 | " low,high = np.nanmean(aucs)-np.nanstd(aucs)*1.96,np.nanmean(aucs)+np.nanstd(aucs)*1.96\n", 90 | " lower_point = round(np.percentile(aucs,2.5),2)\n", 91 | " higher_point = round(np.percentile(aucs,97.5),2)\n", 92 | " mean_point = round(np.nanmean(aucs),2)\n", 93 | " x = plt.hist(aucs,bins = 50,label = 'mean: '+str(mean_point))\n", 94 | "\n", 95 | " plt.plot([np.percentile(aucs,2.5),np.percentile(aucs,2.5)],[0,max(x[0])],label = 'lower interval: '+str(lower_point))\n", 96 | " plt.plot([np.percentile(aucs,97.5),np.percentile(aucs,97.5)],[0,max(x[0])],label = 'higher interval: '+str(higher_point))\n", 97 | " plt.title(\"AUC Histogram\")\n", 98 | " plt.xlabel(\"AUC\")\n", 99 | " plt.legend()\n", 100 | " plt.show()\n", 101 | " \n", 102 | " plt.figure()\n", 103 | " lw = 2\n", 104 | " for i in range(0,1000):\n", 105 | " plt.plot(fpr_boot[i],tpr_boot[i], color='lightblue',\n", 106 | " lw=lw)\n", 107 | " fpr,tpr, _ = roc_curve(y_total,yhat_total)\n", 108 | " plt.plot(fpr, tpr, color='darkorange',\n", 109 | " lw=lw, label='ROC curve (area = %0.2f)' % auc(fpr,tpr))\n", 110 | " plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", 111 | " plt.xlim([0.0, 1.0])\n", 112 | " plt.ylim([0.0, 1.05])\n", 113 | " plt.xlabel('False Positive Rate')\n", 114 | " plt.ylabel('True Positive Rate')\n", 115 | " plt.title('ROC Curve')\n", 116 | " plt.legend(loc=\"lower right\")\n", 117 | " 
plt.show()\n", 118 | "def cm(y_total,yhat_total,Project_name = None,folder=None,threshold = 0.5):\n", 119 | " print(threshold)\n", 120 | " cm = confusion_matrix(y_total,yhat_total>threshold)\n", 121 | " tn, fp, fn, tp = confusion_matrix(y_total,yhat_total>threshold).ravel()\n", 122 | " specificity = ( tn / (tn+fp) )\n", 123 | " sensitivity= ( tp / (tp+fn) )\n", 124 | " print('Positive Predictive Value',round(tp/(tp+fp),2),'Negative Predictive Value', round(tn/(tn+fn),2), ' Specificty ', specificity, 'Sensitivity ', sensitivity)\n", 125 | " disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['No Event','Adverse Event'])\n", 126 | " disp.plot()\n", 127 | " plt.title(\"Confusion Matrix\")\n", 128 | " plt.show()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "Val_root = '/workspace/John/IntroECG-main/IntroECG-main/data/Definitely Not A Mistake/'\n", 138 | "Val_csv = '/workspace/John/IntroECG-main/IntroECG-main/data/Test_rcri_outcome.csv'\n" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "# RCRINet on RCRI Outcome" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "scrolled": false 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "bs = 2000\n", 157 | "checkpoint = 'RCRINet Mortality best_roc_model.pt'\n", 158 | "\n", 159 | "test_ds = ECG_loader(root = Val_root, csv = Val_csv,sliding = False,downsample=1)\n", 160 | "val_dataloader=torch.utils.data.DataLoader(test_ds, batch_size=bs,num_workers=10,drop_last=False)\n", 161 | "x = produce_df(val_dataloader,checkpoint,folder = 'Spare',stride = 8, dilation = 2)\n" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 
3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.6.10" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 4 193 | } 194 | -------------------------------------------------------------------------------- /ECG.py: -------------------------------------------------------------------------------- 1 | 2 | # 10/24/2020 3 | 4 | import os, os.path 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | import matplotlib.pyplot as plt 9 | import torch.utils.data 10 | import pandas as pd 11 | import pathlib 12 | import math 13 | verbose = False 14 | 15 | 16 | def rolling_average (array,value_length): 17 | new_array = np.zeros((1,5000,12)) 18 | assert array.shape == (1,5000,12), "array is not shape (1,2500,12)" 19 | for i in range(0,12): 20 | new_array[0,:,i]=pd.Series(array[0][:,i]).rolling(window=value_length,min_periods=1).mean() #min_periods ensure no NaNs before value_length fulfilled 21 | return new_array 22 | 23 | 24 | # + 25 | # 2, 4 , 8, 16, 26 | # - 27 | 28 | 5000//64 29 | 30 | 31 | def plot(array,color = 'blue'): 32 | lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] 33 | plt.rcParams["figure.figsize"] = [16,9] 34 | 35 | fig, axs = plt.subplots(len(lead_order)) 36 | fig.suptitle("array") 37 | # rolling_arr = rolling_average(array,15) 38 | if array.shape == (5000, 12): 39 | for i in range(0,12): 40 | axs[i].plot(array[:2500,i],label = 'window') 41 | # axs[i].plot(array[::2,i],label = 'downsample') 42 | # axs[i].plot(rolling_arr[:2500,i],label = 'rolling') 43 | axs[i].set(ylabel=str(lead_order[i])) 44 | elif array.shape == (12, 5000): 45 | for i in range(0,12): 46 | axs[i].plot(array[i,:2500],label = 
'window') 47 | # axs[i].plot(array[i,::2],label = 'downsample') 48 | # axs[i].plot(rolling_arr[i,:],label = 'rolling') 49 | axs[i].set(ylabel=str(lead_order[i])) 50 | elif array.shape == (1,5000,12): 51 | for i in range(0,12): 52 | axs[i].plot(array[0,:5000,i],label = 'window') 53 | # axs[i].plot(array[0,::2,i],label = 'downsample') 54 | # axs[i].plot(rolling_arr[0,:2500,i],label = 'rolling') 55 | axs[i].set(ylabel=str(lead_order[i])) 56 | elif array.shape == (1,1,5000,12): 57 | for i in range(0,12): 58 | axs[i].plot(array[0,0,:5000,i],label = 'window') 59 | # axs[i].plot(array[0,0,::2,i],label = 'downsample') 60 | # axs[i].plot(rolling_arr[0,0,:2500,i],label = 'rolling') 61 | axs[i].set(ylabel=str(lead_order[i])) 62 | else: 63 | print("ECG shape not valid: ",array.shape) 64 | 65 | plt.show() 66 | 67 | 68 | class ECG_loader(torch.utils.data.Dataset): 69 | 70 | def __init__(self, root=None,csv = None,bootstrap = False,sliding = False,downsample=0,rolling = 0,plot_data = False,additional_inputs = None,target = 'Mortality'): 71 | if root is None: 72 | root = "EchoNet_ECG_waveforms" 73 | if csv is None: 74 | csv = "EFFileList.csv" 75 | self.folder = pathlib.Path(root) 76 | self.file_list = pd.read_csv(csv) 77 | if bootstrap: 78 | self.file_list = self.file_list.sample(frac=0.5, replace=True).reset_index(drop=True) 79 | self.sliding = sliding 80 | self.downsample = downsample 81 | self.rolling = rolling 82 | self.plot_data = plot_data 83 | self.additional_inputs = additional_inputs 84 | self.target = target 85 | def __getitem__(self, index): 86 | if 'filename' in self.file_list.columns: 87 | fname = self.file_list.filename[index%len(self.file_list.index)] 88 | if 'Filename' in self.file_list.columns: 89 | fname = self.file_list.Filename[index%len(self.file_list.index)] 90 | waveform = np.load(os.path.join(self.folder, fname)) 91 | if waveform.shape == (5000,12): 92 | waveform = np.expand_dims(waveform,axis=0) 93 | x = [] 94 | if not self.additional_inputs is None: 95 | 
for i in self.additional_inputs: 96 | x.append(self.file_list[i][index%len(self.file_list.index)]) 97 | start = np.random.randint(2499) 98 | if self.rolling != 0: 99 | waveform = rolling_average(waveform,self.rolling) 100 | if self.plot_data == True: 101 | plot(waveform,color = 'orange') 102 | plt.show() 103 | if self.sliding: 104 | waveform = waveform[:,start:start+2500] 105 | if self.downsample>0: 106 | waveform = waveform[:,::self.downsample,:] 107 | 108 | target = self.file_list[self.target][index] 109 | target = torch.FloatTensor([target]) 110 | waveform = np.transpose(waveform[0,:,:],(1,0)) 111 | waveform = torch.FloatTensor(waveform) 112 | 113 | if not self.additional_inputs is None: 114 | return (waveform,torch.FloatTensor(x)), target 115 | else: 116 | return waveform, target 117 | 118 | 119 | def __len__(self): 120 | 121 | return math.ceil(len(self.file_list.index)) 122 | 123 | 124 | def _defaultdict_of_lists(): 125 | """Returns a defaultdict of lists. 126 | This is used to avoid issues with Windows (if this function is anonymous, 127 | the Echo dataset cannot be used in a dataloader). 
128 | """ 129 | 130 | return collections.defaultdict(list) 131 | -------------------------------------------------------------------------------- /Figure1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/Figure1.PNG -------------------------------------------------------------------------------- /Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from sklearn.metrics import roc_curve, auc\n", 11 | "import models\n", 12 | "import torch\n", 13 | "\n", 14 | "from tqdm import tqdm\n", 15 | "import torch.nn.functional as F\n", 16 | "import os\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay\n", 19 | "from pthflops import count_ops\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "from ECG import ECG_loader\n", 23 | "def get_model(depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs=0):\n", 24 | " model = models.EffNet(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion,num_additional_features=additional_inputs)\n", 25 | " print('parameters: ' +str(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n", 26 | " model.eval()\n", 27 | " return model\n", 28 | "\n", 29 | "def point(y_total,yhat_total,t):\n", 30 | " specificity = []\n", 31 | " sensitivity = []\n", 32 | " for i in tqdm(t):\n", 33 | " tn, fp, fn, tp = confusion_matrix(y_total, yhat_total>i).ravel()\n", 34 | " specificity.append( tn / (tn+fp) )\n", 35 | " sensitivity.append( tp / (tp+fn) )\n", 36 | " return t[(np.array(specificity) + np.array(sensitivity) - 1).argmax()]\n", 37 
| "def thresholded_output_transform(yhat,y):\n", 38 | " y_pred, y = yhat,y\n", 39 | " y_pred = torch.sigmoid(y_pred)\n", 40 | " return y_pred, y\n", 41 | "def produce_df(test_loader, checkpoint,folder = 'Test Results', depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs = 0,ds = None):\n", 42 | " \n", 43 | " # load model and checkpoint\n", 44 | " model = get_model(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion, additional_inputs=additional_inputs)\n", 45 | " checkpoint = torch.load(checkpoint)\n", 46 | " model.load_state_dict(checkpoint)\n", 47 | " model.to('cpu')\n", 48 | " yhat = torch.Tensor()\n", 49 | " \n", 50 | " # Calculate FLOPs\n", 51 | " if not ds is None:\n", 52 | " dl=torch.utils.data.DataLoader(ds, batch_size=1,num_workers=10,drop_last=False)\n", 53 | " for x,y in dl:\n", 54 | " count_ops(model,x)\n", 55 | " break\n", 56 | "\n", 57 | " # produce tensors\n", 58 | " y_total = torch.Tensor()\n", 59 | " yhat_total = torch.Tensor()\n", 60 | " for x,y in tqdm(test_loader):\n", 61 | " yhat = model(x)\n", 62 | " yhat,y = thresholded_output_transform(yhat,y)\n", 63 | " y_total = torch.cat((y_total,y),0)\n", 64 | " yhat_total = torch.cat((yhat_total,yhat.detach().cpu()),0)\n", 65 | " \n", 66 | " # Save predictions\n", 67 | " df = pd.DataFrame({'labels':y_total.flatten().tolist(),'Prediction':yhat_total.flatten().tolist()})\n", 68 | " df.to_csv(\"Predictions.csv\",index=False)\n", 69 | " # Produce ROC and CM\n", 70 | " fpr,tpr, t = roc_curve(y_total,yhat_total)\n", 71 | " thresh = point(y_total,yhat_total,t)\n", 72 | " lw = 2\n", 73 | " cm(y_total.flatten(),yhat_total.flatten(),folder=folder,threshold = thresh)\n", 74 | " bootstrap(df)\n", 75 | " return yhat_total, y_total, fpr, tpr, t\n", 76 | "def bootstrap(df):\n", 77 | " y_total,yhat_total = df['labels'],df['Prediction']\n", 78 | " fpr_boot = []\n", 79 | " tpr_boot = []\n", 80 | " aucs = 
[]\n", 81 | " \n", 82 | " # bootstrap for confidence interval\n", 83 | " for i in tqdm(range(0,10000)):\n", 84 | " choices = np.random.choice(range(0,len(yhat_total)),int(len(yhat_total)/2))\n", 85 | " fpr,tpr, _ = roc_curve(y_total[choices],yhat_total[choices])\n", 86 | " fpr_boot.append(fpr)\n", 87 | " tpr_boot.append(tpr)\n", 88 | " aucs.append(auc(fpr,tpr))\n", 89 | " low,high = np.nanmean(aucs)-np.nanstd(aucs)*1.96,np.nanmean(aucs)+np.nanstd(aucs)*1.96\n", 90 | " lower_point = round(np.percentile(aucs,2.5),2)\n", 91 | " higher_point = round(np.percentile(aucs,97.5),2)\n", 92 | " mean_point = round(np.nanmean(aucs),2)\n", 93 | " x = plt.hist(aucs,bins = 50,label = 'mean: '+str(mean_point))\n", 94 | "\n", 95 | " plt.plot([np.percentile(aucs,2.5),np.percentile(aucs,2.5)],[0,max(x[0])],label = 'lower interval: '+str(lower_point))\n", 96 | " plt.plot([np.percentile(aucs,97.5),np.percentile(aucs,97.5)],[0,max(x[0])],label = 'higher interval: '+str(higher_point))\n", 97 | " plt.title(\"AUC Histogram\")\n", 98 | " plt.xlabel(\"AUC\")\n", 99 | " plt.legend()\n", 100 | " plt.show()\n", 101 | " \n", 102 | " plt.figure()\n", 103 | " lw = 2\n", 104 | " for i in range(0,1000):\n", 105 | " plt.plot(fpr_boot[i],tpr_boot[i], color='lightblue',\n", 106 | " lw=lw)\n", 107 | " fpr,tpr, _ = roc_curve(y_total,yhat_total)\n", 108 | " plt.plot(fpr, tpr, color='darkorange',\n", 109 | " lw=lw, label='ROC curve (area = %0.2f)' % auc(fpr,tpr))\n", 110 | " plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", 111 | " plt.xlim([0.0, 1.0])\n", 112 | " plt.ylim([0.0, 1.05])\n", 113 | " plt.xlabel('False Positive Rate')\n", 114 | " plt.ylabel('True Positive Rate')\n", 115 | " plt.title('ROC Curve')\n", 116 | " plt.legend(loc=\"lower right\")\n", 117 | " plt.show()\n", 118 | "def cm(y_total,yhat_total,Project_name = None,folder=None,threshold = 0.5):\n", 119 | " print(threshold)\n", 120 | " cm = confusion_matrix(y_total,yhat_total>threshold)\n", 121 | " tn, fp, fn, tp = 
confusion_matrix(y_total,yhat_total>threshold).ravel()\n", 122 | " specificity = ( tn / (tn+fp) )\n", 123 | " sensitivity= ( tp / (tp+fn) )\n", 124 | " print('Positive Predictive Value',round(tp/(tp+fp),2),'Negative Predictive Value', round(tn/(tn+fn),2), ' Specificty ', specificity, 'Sensitivity ', sensitivity)\n", 125 | " disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['No Event','Adverse Event'])\n", 126 | " disp.plot()\n", 127 | " plt.title(\"Confusion Matrix\")\n", 128 | " plt.show()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "Val_root = '/workspace/John/IntroECG-main/IntroECG-main/data/Definitely Not A Mistake/'\n", 138 | "Val_csv = '/workspace/John/IntroECG-main/IntroECG-main/data/Test_rcri_outcome.csv'\n" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "# RCRINet on RCRI Outcome" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "scrolled": false 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "bs = 2000\n", 157 | "checkpoint = 'PreOpNet Mort.pt'\n", 158 | "\n", 159 | "test_ds = ECG_loader(root = Val_root, csv = Val_csv,sliding = False,downsample=1)\n", 160 | "val_dataloader=torch.utils.data.DataLoader(test_ds, batch_size=bs,num_workers=10,drop_last=False)\n", 161 | "x = produce_df(val_dataloader,checkpoint,folder = 'Spare',stride = 8, dilation = 2)\n" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 
| "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.8.5" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 4 193 | } 194 | -------------------------------------------------------------------------------- /PreOpNet MACE.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet MACE.pt -------------------------------------------------------------------------------- /PreOpNet Mort.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet Mort.pt -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/.ipynb_checkpoints/Inference-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from sklearn.metrics import roc_curve, auc\n", 11 | "import models\n", 12 | "import torch\n", 13 | "\n", 14 | "from tqdm import tqdm\n", 15 | "import torch.nn.functional as F\n", 16 | "import os\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay\n", 19 | "from pthflops import count_ops\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "from ECG import ECG_loader\n", 23 | "def get_model(depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs=0):\n", 24 | " model = models.EffNet(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion,num_additional_features=additional_inputs)\n", 25 | " print('parameters: ' +str(sum(p.numel() 
for p in model.parameters() if p.requires_grad)))\n", 26 | " model.eval()\n", 27 | " return model\n", 28 | "\n", 29 | "def point(y_total,yhat_total,t):\n", 30 | " specificity = []\n", 31 | " sensitivity = []\n", 32 | " for i in tqdm(t):\n", 33 | " tn, fp, fn, tp = confusion_matrix(y_total, yhat_total>i).ravel()\n", 34 | " specificity.append( tn / (tn+fp) )\n", 35 | " sensitivity.append( tp / (tp+fn) )\n", 36 | " return t[(np.array(specificity) + np.array(sensitivity) - 1).argmax()]\n", 37 | "def thresholded_output_transform(yhat,y):\n", 38 | " y_pred, y = yhat,y\n", 39 | " y_pred = torch.sigmoid(y_pred)\n", 40 | " return y_pred, y\n", 41 | "def produce_df(test_loader, checkpoint,folder = 'Test Results', depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs = 0,ds = None):\n", 42 | " \n", 43 | " # load model and checkpoint\n", 44 | " model = get_model(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion, additional_inputs=additional_inputs)\n", 45 | " checkpoint = torch.load(checkpoint)\n", 46 | " model.load_state_dict(checkpoint)\n", 47 | " model.to('cpu')\n", 48 | " yhat = torch.Tensor()\n", 49 | " \n", 50 | " # Calculate FLOPs\n", 51 | " if not ds is None:\n", 52 | " dl=torch.utils.data.DataLoader(ds, batch_size=1,num_workers=10,drop_last=False)\n", 53 | " for x,y in dl:\n", 54 | " count_ops(model,x)\n", 55 | " break\n", 56 | "\n", 57 | " # produce tensors\n", 58 | " y_total = torch.Tensor()\n", 59 | " yhat_total = torch.Tensor()\n", 60 | " for x,y in tqdm(test_loader):\n", 61 | " yhat = model(x)\n", 62 | " yhat,y = thresholded_output_transform(yhat,y)\n", 63 | " y_total = torch.cat((y_total,y),0)\n", 64 | " yhat_total = torch.cat((yhat_total,yhat.detach().cpu()),0)\n", 65 | " \n", 66 | " # Save predictions\n", 67 | " df = pd.DataFrame({'labels':y_total.flatten().tolist(),'Prediction':yhat_total.flatten().tolist()})\n", 68 | " 
df.to_csv(\"Predictions.csv\",index=False)\n", 69 | " \n", 70 | " # Produce ROC and CM\n", 71 | " fpr,tpr, t = roc_curve(y_total,yhat_total)\n", 72 | " thresh = point(y_total,yhat_total,t)\n", 73 | " lw = 2\n", 74 | " cm(y_total.flatten(),yhat_total.flatten(),folder=folder,threshold = thresh)\n", 75 | " bootstrap(df)\n", 76 | " return yhat_total, y_total, fpr, tpr, t\n", 77 | "def bootstrap(df):\n", 78 | " y_total,yhat_total = df['labels'],df['Prediction']\n", 79 | " fpr_boot = []\n", 80 | " tpr_boot = []\n", 81 | " aucs = []\n", 82 | " \n", 83 | " # bootstrap for confidence interval\n", 84 | " for i in tqdm(range(0,10000)):\n", 85 | " choices = np.random.choice(range(0,len(yhat_total)),int(len(yhat_total)/2))\n", 86 | " fpr,tpr, _ = roc_curve(y_total[choices],yhat_total[choices])\n", 87 | " fpr_boot.append(fpr)\n", 88 | " tpr_boot.append(tpr)\n", 89 | " aucs.append(auc(fpr,tpr))\n", 90 | " low,high = np.nanmean(aucs)-np.nanstd(aucs)*1.96,np.nanmean(aucs)+np.nanstd(aucs)*1.96\n", 91 | " lower_point = round(np.percentile(aucs,2.5),2)\n", 92 | " higher_point = round(np.percentile(aucs,97.5),2)\n", 93 | " mean_point = round(np.nanmean(aucs),2)\n", 94 | " x = plt.hist(aucs,bins = 50,label = 'mean: '+str(mean_point))\n", 95 | "\n", 96 | " plt.plot([np.percentile(aucs,2.5),np.percentile(aucs,2.5)],[0,max(x[0])],label = 'lower interval: '+str(lower_point))\n", 97 | " plt.plot([np.percentile(aucs,97.5),np.percentile(aucs,97.5)],[0,max(x[0])],label = 'higher interval: '+str(higher_point))\n", 98 | " plt.title(\"AUC Histogram\")\n", 99 | " plt.xlabel(\"AUC\")\n", 100 | " plt.legend()\n", 101 | " plt.show()\n", 102 | " \n", 103 | " plt.figure()\n", 104 | " lw = 2\n", 105 | " for i in range(0,1000):\n", 106 | " plt.plot(fpr_boot[i],tpr_boot[i], color='lightblue',\n", 107 | " lw=lw)\n", 108 | " fpr,tpr, _ = roc_curve(y_total,yhat_total)\n", 109 | " plt.plot(fpr, tpr, color='darkorange',\n", 110 | " lw=lw, label='ROC curve (area = %0.2f)' % auc(fpr,tpr))\n", 111 | " plt.plot([0, 
1], [0, 1], color='navy', lw=lw, linestyle='--')\n", 112 | " plt.xlim([0.0, 1.0])\n", 113 | " plt.ylim([0.0, 1.05])\n", 114 | " plt.xlabel('False Positive Rate')\n", 115 | " plt.ylabel('True Positive Rate')\n", 116 | " plt.title('ROC Curve')\n", 117 | " plt.legend(loc=\"lower right\")\n", 118 | " plt.show()\n", 119 | "def cm(y_total,yhat_total,Project_name = None,folder=None,threshold = 0.5):\n", 120 | " print(threshold)\n", 121 | " cm = confusion_matrix(y_total,yhat_total>threshold)\n", 122 | " tn, fp, fn, tp = confusion_matrix(y_total,yhat_total>threshold).ravel()\n", 123 | " specificity = ( tn / (tn+fp) )\n", 124 | " sensitivity= ( tp / (tp+fn) )\n", 125 | " print('Positive Predictive Value',round(tp/(tp+fp),2),'Negative Predictive Value', round(tn/(tn+fn),2), ' Specificty ', specificity, 'Sensitivity ', sensitivity)\n", 126 | " disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['No Event','Adverse Event'])\n", 127 | " disp.plot()\n", 128 | " plt.title(\"Confusion Matrix\")\n", 129 | " plt.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "Val_root = '/workspace/John/IntroECG-main/IntroECG-main/data/Definitely Not A Mistake/'\n", 139 | "Val_csv = '/workspace/John/IntroECG-main/IntroECG-main/data/Test_rcri_outcome.csv'\n" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "# RCRINet on RCRI Outcome" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "scrolled": false 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "bs = 2000\n", 158 | "checkpoint = 'best_roc_model_Mortality_with_RCRI_Features.pt'\n", 159 | "additional_inputs = ['CrGreaterThan2','is_risk','insulin','cad','chf','stroke']\n", 160 | "\n", 161 | "test_ds = ECG_loader(root = Val_root, csv = Val_csv,sliding = False,downsample=1,additional_inputs=additional_inputs)\n", 
162 | "val_dataloader=torch.utils.data.DataLoader(test_ds, batch_size=bs,num_workers=10,drop_last=False)\n", 163 | "x = produce_df(val_dataloader,checkpoint,folder = 'Spare',stride = 8, dilation = 2,additional_inputs=len(additional_inputs))\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "Python 3", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.6.10" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 4 195 | } 196 | -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/ECG.py: -------------------------------------------------------------------------------- 1 | 2 | # 10/24/2020 3 | 4 | import os, os.path 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | import matplotlib.pyplot as plt 9 | import torch.utils.data 10 | import pandas as pd 11 | import pathlib 12 | import math 13 | verbose = False 14 | 15 | 16 | def rolling_average (array,value_length): 17 | new_array = np.zeros((1,5000,12)) 18 | assert array.shape == (1,5000,12), "array is not shape (1,2500,12)" 19 | for i in range(0,12): 20 | new_array[0,:,i]=pd.Series(array[0][:,i]).rolling(window=value_length,min_periods=1).mean() #min_periods ensure no NaNs before value_length fulfilled 21 | return new_array 22 | 23 | 24 | # + 25 | # 2, 4 , 8, 16, 26 | # - 27 | 28 | 5000//64 29 | 30 | 31 | def plot(array,color = 'blue'): 32 | lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] 33 | 
plt.rcParams["figure.figsize"] = [16,9] 34 | 35 | fig, axs = plt.subplots(len(lead_order)) 36 | fig.suptitle("array") 37 | # rolling_arr = rolling_average(array,15) 38 | if array.shape == (5000, 12): 39 | for i in range(0,12): 40 | axs[i].plot(array[:2500,i],label = 'window') 41 | # axs[i].plot(array[::2,i],label = 'downsample') 42 | # axs[i].plot(rolling_arr[:2500,i],label = 'rolling') 43 | axs[i].set(ylabel=str(lead_order[i])) 44 | elif array.shape == (12, 5000): 45 | for i in range(0,12): 46 | axs[i].plot(array[i,:2500],label = 'window') 47 | # axs[i].plot(array[i,::2],label = 'downsample') 48 | # axs[i].plot(rolling_arr[i,:],label = 'rolling') 49 | axs[i].set(ylabel=str(lead_order[i])) 50 | elif array.shape == (1,5000,12): 51 | for i in range(0,12): 52 | axs[i].plot(array[0,:5000,i],label = 'window') 53 | # axs[i].plot(array[0,::2,i],label = 'downsample') 54 | # axs[i].plot(rolling_arr[0,:2500,i],label = 'rolling') 55 | axs[i].set(ylabel=str(lead_order[i])) 56 | elif array.shape == (1,1,5000,12): 57 | for i in range(0,12): 58 | axs[i].plot(array[0,0,:5000,i],label = 'window') 59 | # axs[i].plot(array[0,0,::2,i],label = 'downsample') 60 | # axs[i].plot(rolling_arr[0,0,:2500,i],label = 'rolling') 61 | axs[i].set(ylabel=str(lead_order[i])) 62 | else: 63 | print("ECG shape not valid: ",array.shape) 64 | 65 | plt.show() 66 | 67 | 68 | class ECG_loader(torch.utils.data.Dataset): 69 | 70 | def __init__(self, root=None,csv = None,bootstrap = False,sliding = False,downsample=0,rolling = 0,plot_data = False,additional_inputs = None,target = 'Mortality'): 71 | if root is None: 72 | root = "EchoNet_ECG_waveforms" 73 | if csv is None: 74 | csv = "EFFileList.csv" 75 | self.folder = pathlib.Path(root) 76 | self.file_list = pd.read_csv(csv) 77 | if bootstrap: 78 | self.file_list = self.file_list.sample(frac=0.5, replace=True).reset_index(drop=True) 79 | self.sliding = sliding 80 | self.downsample = downsample 81 | self.rolling = rolling 82 | self.plot_data = plot_data 83 | 
self.additional_inputs = additional_inputs 84 | self.target = target 85 | def __getitem__(self, index): 86 | if 'filename' in self.file_list.columns: 87 | fname = self.file_list.filename[index%len(self.file_list.index)] 88 | if 'Filename' in self.file_list.columns: 89 | fname = self.file_list.Filename[index%len(self.file_list.index)] 90 | waveform = np.load(os.path.join(self.folder, fname)) 91 | if waveform.shape == (5000,12): 92 | waveform = np.expand_dims(waveform,axis=0) 93 | x = [] 94 | if not self.additional_inputs is None: 95 | for i in self.additional_inputs: 96 | x.append(self.file_list[i][index%len(self.file_list.index)]) 97 | start = np.random.randint(2499) 98 | if self.rolling != 0: 99 | waveform = rolling_average(waveform,self.rolling) 100 | if self.plot_data == True: 101 | plot(waveform,color = 'orange') 102 | plt.show() 103 | if self.sliding: 104 | waveform = waveform[:,start:start+2500] 105 | if self.downsample>0: 106 | waveform = waveform[:,::self.downsample,:] 107 | 108 | target = self.file_list[self.target][index] 109 | target = torch.FloatTensor([target]) 110 | waveform = np.transpose(waveform[0,:,:],(1,0)) 111 | waveform = torch.FloatTensor(waveform) 112 | 113 | if not self.additional_inputs is None: 114 | return (waveform,torch.FloatTensor(x)), target 115 | else: 116 | return waveform, target 117 | 118 | 119 | def __len__(self): 120 | 121 | return math.ceil(len(self.file_list.index)) 122 | 123 | 124 | def _defaultdict_of_lists(): 125 | """Returns a defaultdict of lists. 126 | This is used to avoid issues with Windows (if this function is anonymous, 127 | the Echo dataset cannot be used in a dataloader). 
128 | """ 129 | 130 | return collections.defaultdict(list) 131 | -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from sklearn.metrics import roc_curve, auc\n", 11 | "import models\n", 12 | "import torch\n", 13 | "\n", 14 | "from tqdm import tqdm\n", 15 | "import torch.nn.functional as F\n", 16 | "import os\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay\n", 19 | "from pthflops import count_ops\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "from ECG import ECG_loader\n", 23 | "def get_model(depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs=0):\n", 24 | " model = models.EffNet(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion,num_additional_features=additional_inputs)\n", 25 | " print('parameters: ' +str(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n", 26 | " model.eval()\n", 27 | " return model\n", 28 | "\n", 29 | "def point(y_total,yhat_total,t):\n", 30 | " specificity = []\n", 31 | " sensitivity = []\n", 32 | " for i in tqdm(t):\n", 33 | " tn, fp, fn, tp = confusion_matrix(y_total, yhat_total>i).ravel()\n", 34 | " specificity.append( tn / (tn+fp) )\n", 35 | " sensitivity.append( tp / (tp+fn) )\n", 36 | " return t[(np.array(specificity) + np.array(sensitivity) - 1).argmax()]\n", 37 | "def thresholded_output_transform(yhat,y):\n", 38 | " y_pred, y = yhat,y\n", 39 | " y_pred = torch.sigmoid(y_pred)\n", 40 | " return y_pred, y\n", 41 | "def produce_df(test_loader, checkpoint,folder = 'Test Results', depth = 
[1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs = 0,ds = None):\n", 42 | " \n", 43 | " # load model and checkpoint\n", 44 | " model = get_model(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion, additional_inputs=additional_inputs)\n", 45 | " checkpoint = torch.load(checkpoint)\n", 46 | " model.load_state_dict(checkpoint)\n", 47 | " model.to('cpu')\n", 48 | " yhat = torch.Tensor()\n", 49 | " \n", 50 | " # Calculate FLOPs\n", 51 | " if not ds is None:\n", 52 | " dl=torch.utils.data.DataLoader(ds, batch_size=1,num_workers=10,drop_last=False)\n", 53 | " for x,y in dl:\n", 54 | " count_ops(model,x)\n", 55 | " break\n", 56 | "\n", 57 | " # produce tensors\n", 58 | " y_total = torch.Tensor()\n", 59 | " yhat_total = torch.Tensor()\n", 60 | " for x,y in tqdm(test_loader):\n", 61 | " yhat = model(x)\n", 62 | " yhat,y = thresholded_output_transform(yhat,y)\n", 63 | " y_total = torch.cat((y_total,y),0)\n", 64 | " yhat_total = torch.cat((yhat_total,yhat.detach().cpu()),0)\n", 65 | " \n", 66 | " # Save predictions\n", 67 | " df = pd.DataFrame({'labels':y_total.flatten().tolist(),'Prediction':yhat_total.flatten().tolist()})\n", 68 | " df.to_csv(\"Predictions.csv\",index=False)\n", 69 | " \n", 70 | " # Produce ROC and CM\n", 71 | " fpr,tpr, t = roc_curve(y_total,yhat_total)\n", 72 | " thresh = point(y_total,yhat_total,t)\n", 73 | " lw = 2\n", 74 | " cm(y_total.flatten(),yhat_total.flatten(),folder=folder,threshold = thresh)\n", 75 | " bootstrap(df)\n", 76 | " return yhat_total, y_total, fpr, tpr, t\n", 77 | "def bootstrap(df):\n", 78 | " y_total,yhat_total = df['labels'],df['Prediction']\n", 79 | " fpr_boot = []\n", 80 | " tpr_boot = []\n", 81 | " aucs = []\n", 82 | " \n", 83 | " # bootstrap for confidence interval\n", 84 | " for i in tqdm(range(0,10000)):\n", 85 | " choices = np.random.choice(range(0,len(yhat_total)),int(len(yhat_total)/2))\n", 86 | " fpr,tpr, _ = 
roc_curve(y_total[choices],yhat_total[choices])\n", 87 | " fpr_boot.append(fpr)\n", 88 | " tpr_boot.append(tpr)\n", 89 | " aucs.append(auc(fpr,tpr))\n", 90 | " low,high = np.nanmean(aucs)-np.nanstd(aucs)*1.96,np.nanmean(aucs)+np.nanstd(aucs)*1.96\n", 91 | " lower_point = round(np.percentile(aucs,2.5),2)\n", 92 | " higher_point = round(np.percentile(aucs,97.5),2)\n", 93 | " mean_point = round(np.nanmean(aucs),2)\n", 94 | " x = plt.hist(aucs,bins = 50,label = 'mean: '+str(mean_point))\n", 95 | "\n", 96 | " plt.plot([np.percentile(aucs,2.5),np.percentile(aucs,2.5)],[0,max(x[0])],label = 'lower interval: '+str(lower_point))\n", 97 | " plt.plot([np.percentile(aucs,97.5),np.percentile(aucs,97.5)],[0,max(x[0])],label = 'higher interval: '+str(higher_point))\n", 98 | " plt.title(\"AUC Histogram\")\n", 99 | " plt.xlabel(\"AUC\")\n", 100 | " plt.legend()\n", 101 | " plt.show()\n", 102 | " \n", 103 | " plt.figure()\n", 104 | " lw = 2\n", 105 | " for i in range(0,1000):\n", 106 | " plt.plot(fpr_boot[i],tpr_boot[i], color='lightblue',\n", 107 | " lw=lw)\n", 108 | " fpr,tpr, _ = roc_curve(y_total,yhat_total)\n", 109 | " plt.plot(fpr, tpr, color='darkorange',\n", 110 | " lw=lw, label='ROC curve (area = %0.2f)' % auc(fpr,tpr))\n", 111 | " plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", 112 | " plt.xlim([0.0, 1.0])\n", 113 | " plt.ylim([0.0, 1.05])\n", 114 | " plt.xlabel('False Positive Rate')\n", 115 | " plt.ylabel('True Positive Rate')\n", 116 | " plt.title('ROC Curve')\n", 117 | " plt.legend(loc=\"lower right\")\n", 118 | " plt.show()\n", 119 | "def cm(y_total,yhat_total,Project_name = None,folder=None,threshold = 0.5):\n", 120 | " print(threshold)\n", 121 | " cm = confusion_matrix(y_total,yhat_total>threshold)\n", 122 | " tn, fp, fn, tp = confusion_matrix(y_total,yhat_total>threshold).ravel()\n", 123 | " specificity = ( tn / (tn+fp) )\n", 124 | " sensitivity= ( tp / (tp+fn) )\n", 125 | " print('Positive Predictive Value',round(tp/(tp+fp),2),'Negative 
Predictive Value', round(tn/(tn+fn),2), ' Specificty ', specificity, 'Sensitivity ', sensitivity)\n", 126 | " disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['No Event','Adverse Event'])\n", 127 | " disp.plot()\n", 128 | " plt.title(\"Confusion Matrix\")\n", 129 | " plt.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "Val_root = '/workspace/John/IntroECG-main/IntroECG-main/data/Definitely Not A Mistake/'\n", 139 | "Val_csv = '/workspace/John/IntroECG-main/IntroECG-main/data/Test_rcri_outcome.csv'\n" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "# RCRINet on RCRI Outcome" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "scrolled": false 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "bs = 2000\n", 158 | "checkpoint = 'best_roc_model_Mortality_with_RCRI_Features.pt'\n", 159 | "additional_inputs = ['CrGreaterThan2','is_risk','insulin','cad','chf','stroke']\n", 160 | "\n", 161 | "test_ds = ECG_loader(root = Val_root, csv = Val_csv,sliding = False,downsample=1,additional_inputs=additional_inputs)\n", 162 | "val_dataloader=torch.utils.data.DataLoader(test_ds, batch_size=bs,num_workers=10,drop_last=False)\n", 163 | "x = produce_df(val_dataloader,checkpoint,folder = 'Spare',stride = 8, dilation = 2,additional_inputs=len(additional_inputs))\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "Python 3", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": 
"python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.6.10" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 4 195 | } 196 | -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/Process Data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd \n", 10 | "import numpy as np \n", 11 | "import xmltodict\n", 12 | "import base64\n", 13 | "import struct\n", 14 | "import argparse\n", 15 | "import os\n", 16 | "import sys\n", 17 | "from tqdm import tqdm\n", 18 | "from scipy import stats\n", 19 | "\n", 20 | "args = {\n", 21 | " 'in_XML' : '/workspace/data/NAS/Muse2013-19ECGs_XML', # where the XML files are located, put none if these are already calculated\n", 22 | " 'in_csv': '/workspace/data/NAS/Cedars EKG Analysis/Muse2008-19 XML Metadata.csv',# Where the csv is located for your input dataset\n", 23 | " 'in_dataset': '/workspace/data/drives/Internal_SSD/sdc/EKG Test/',# Where the converted XMLs will be moved, or, where the numpy arrays are currently stored\n", 24 | " 'out_dataset': '/workspace/data/drives/Internal_SSD/sdc/Comparison/', # Where the normalized EKGs are stored\n", 25 | " 'sample':100000, # sample size to estimate mean and standard deviation\n", 26 | " 'mean':np.array([-0.753910531911661,-0.5609376530271284,0.19297287888453685,0.6574240924693946,-0.09648643944226842,0.09648643944226842,-0.9398182103547104,-0.8866948773251518,-0.9585726095365399,-0.9142084935751398,-0.9573448180888456,-0.9300810636208064]),\n", 27 | " 
'std':np.array([32.082092358503644,34.97862852596865,38.153409045189754,27.612712586528637,19.076704522594877,19.076704522594877,42.77931881050877,63.872623440588406,61.15731462396783,54.12879139607189,48.435274440820855,43.34056377213695])\n", 28 | " \n", 29 | "}# replace mean and std with none to calculate your own" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "\n", 39 | "\n", 40 | "\n", 41 | "\n", 42 | "def file_path(path):\n", 43 | " return os.listdir(path)\n", 44 | " \n", 45 | "#need to update this function to check the output directory for the output file and then only on newly added EKGs\n", 46 | "#add timestamp to start file string \n", 47 | "#this is annoying because the XML file name is a random timestamp and the output file is the UniqueECGID\n", 48 | "\n", 49 | "\n", 50 | "#if not os.path.exists('/D:/Muse2019ECGs_npy/'):\n", 51 | "# os.mkdir('/D:/Muse2019ECGs_npy/')\n", 52 | "\n", 53 | "# parser = argparse.ArgumentParser(description='Input and outputs for XML EKG parsing')\n", 54 | "# parser.add_argument('input', type=str)\n", 55 | "# parser.set_defaults(output=os.getcwd() + '/EchoNet_ECG_waveforms/') #ensure this directory already exists\n", 56 | "\n", 57 | "# args = parser.parse_args()\n", 58 | "\n", 59 | "\n", 60 | "\n", 61 | "def decode_ekg_muse(raw_wave):\n", 62 | " \"\"\"\n", 63 | " Ingest the base64 encoded waveforms and transform to numeric\n", 64 | " \"\"\"\n", 65 | " # covert the waveform from base64 to byte array\n", 66 | " arr = base64.b64decode(bytes(raw_wave, 'utf-8'))\n", 67 | "\n", 68 | " # unpack every 2 bytes, little endian (16 bit encoding)\n", 69 | " unpack_symbols = ''.join([char*int(len(arr)/2) for char in 'h'])\n", 70 | " byte_array = struct.unpack(unpack_symbols, arr)\n", 71 | " return byte_array\n", 72 | "\n", 73 | "\n", 74 | "def decode_ekg_muse_to_array(raw_wave, downsample = 1):\n", 75 | " \"\"\"\n", 76 | " Ingest the base64 
encoded waveforms and transform to numeric\n", 77 | " downsample: 0.5 takes every other value in the array. Muse samples at 500/s and the sample model requires 250/s. So take every other.\n", 78 | " \"\"\"\n", 79 | " try:\n", 80 | " dwnsmpl = int(1//downsample)\n", 81 | " except ZeroDivisionError:\n", 82 | " print(\"You must downsample by more than 0\")\n", 83 | " # covert the waveform from base64 to byte array\n", 84 | " arr = base64.b64decode(bytes(raw_wave, 'utf-8'))\n", 85 | "\n", 86 | " # unpack every 2 bytes, little endian (16 bit encoding)\n", 87 | " unpack_symbols = ''.join([char*int(len(arr)/2) for char in 'h'])\n", 88 | " byte_array = struct.unpack(unpack_symbols, arr)\n", 89 | " return np.array(byte_array)[::dwnsmpl]\n", 90 | "\n", 91 | "\n", 92 | "\n", 93 | "def xml_to_np_array_file(path_to_xml, path_to_output = os.getcwd(),df=None):\n", 94 | " with open(path_to_xml, 'rb') as fd:\n", 95 | " dic = xmltodict.parse(fd.read().decode('utf8'))\n", 96 | " \n", 97 | " \"\"\"\n", 98 | " Upload the ECG as numpy array with shape=[5000,12] ([time, leads, 1]).\n", 99 | " The voltage unit should be in 1 mv/unit and the sampling rate should be 250/second (total 10 second).\n", 100 | " The leads should be ordered as follow I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6.\n", 101 | " \"\"\"\n", 102 | " try:\n", 103 | " pt_id = dic['RestingECG']['PatientDemographics']['PatientID']\n", 104 | " except:\n", 105 | " pt_id = \"none\"\n", 106 | " try:\n", 107 | " PharmaUniqueECGID = dic['RestingECG']['PharmaData']['PharmaUniqueECGID']\n", 108 | " except:\n", 109 | " PharmaUniqueECGID = \"none\"\n", 110 | " try:\n", 111 | " AcquisitionDateTime = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + \"_\" + dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(\":\",\"-\")\n", 112 | " except:\n", 113 | " AcquisitionDateTime = \"none\" \n", 114 | "\n", 115 | " # try:\n", 116 | " # requisition_number = dic['RestingECG']['Order']['RequisitionNumber']\n", 117 | " # 
except:\n", 118 | " # print(\"no requisition_number\")\n", 119 | " # requisition_number = \"none\"\n", 120 | "\n", 121 | " #need to instantiate leads in the proper order for the model\n", 122 | " lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n", 123 | "\n", 124 | " \"\"\"\n", 125 | " Each EKG will have this data structure:\n", 126 | " lead_data = {\n", 127 | " 'I': np.array\n", 128 | " }\n", 129 | " \"\"\"\n", 130 | "\n", 131 | " lead_data = dict.fromkeys(lead_order)\n", 132 | " #lead_data = {leadid: None for k in lead_order}\n", 133 | "\n", 134 | "# for all_lead_data in dic['RestingECG']['Waveform']:\n", 135 | "# for single_lead_data in lead['LeadData']:\n", 136 | "# leadname = single_lead_data['LeadID']\n", 137 | "# if leadname in (lead_order):\n", 138 | "\n", 139 | " for lead in dic['RestingECG']['Waveform']:\n", 140 | " for leadid in range(len(lead['LeadData'])):\n", 141 | " sample_length = len(decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData']))\n", 142 | " #sample_length is equivalent to dic['RestingECG']['Waveform']['LeadData']['LeadSampleCountTotal']\n", 143 | " if sample_length == 5000 or sample_length == 5500:\n", 144 | " lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData'], downsample = 1)\n", 145 | " elif sample_length == 2500:\n", 146 | " raise ValueError(\"Sample frequency too low, try with frequency = 500/s\")\n", 147 | " else:\n", 148 | " continue\n", 149 | " #ensures all leads have 2500 samples and also passes over the 3 second waveform\n", 150 | " lead_data['III'] = (np.array(lead_data[\"II\"]) - np.array(lead_data[\"I\"]))\n", 151 | " lead_data['aVR'] = -(np.array(lead_data[\"I\"]) + np.array(lead_data[\"II\"]))/2\n", 152 | " lead_data['aVF'] = (np.array(lead_data[\"II\"]) - np.array(lead_data[\"I\"]))/2\n", 153 | " lead_data['aVL'] = (np.array(lead_data[\"I\"]) - np.array(lead_data[\"II\"]))/2\n", 154 | " \n", 155 | " lead_data 
= {k: lead_data[k] for k in lead_order}\n", 156 | " # drops V3R, V4R, and V7 if it was a 15-lead ECG\n", 157 | "\n", 158 | " # now construct and reshape the array\n", 159 | " # converting the dictionary to an np.array\n", 160 | " temp = []\n", 161 | " for key,value in lead_data.items():\n", 162 | " temp.append(value)\n", 163 | "\n", 164 | " #transpose to be [time, leads, ]\n", 165 | " ekg_array = np.array(temp).T\n", 166 | "\n", 167 | " #expand dims to [time, leads, 1]\n", 168 | " ekg_array = np.expand_dims(ekg_array, axis=-1)\n", 169 | "\n", 170 | " # Here is a check to make sure all the model inputs are the right shape\n", 171 | "# assert ekg_array.shape == (2500, 12, 1), \"ekg_array is shape {} not (2500, 12, 1)\".format(ekg_array.shape )\n", 172 | "\n", 173 | " # filename = '/ekg_waveform_{}_{}.npy'.format(pt_id, requisition_number)\n", 174 | " filename = '{}_{}_{}.npy'.format(pt_id, AcquisitionDateTime,PharmaUniqueECGID)\n", 175 | " path_to_output = os.path.join(path_to_output,filename)\n", 176 | " # print(path_to_output)\n", 177 | " if len(ekg_array.shape)==3:\n", 178 | " ekg_array = ekg_array[:,:,0]\n", 179 | " # if len(ekg_array) == 5500:\n", 180 | " # ekg_array = ekg_array[:-500,:]\n", 181 | " with open(path_to_output, 'wb') as f:\n", 182 | " np.save(f, ekg_array)\n", 183 | " if not type(df) == type(pd.DataFrame()):\n", 184 | " col = list(dic['RestingECG']['PatientDemographics'].keys())\n", 185 | " col.append(\"Filename\")\n", 186 | " col.append('LocationName')\n", 187 | " col.append('SiteName')\n", 188 | " for key in dic['RestingECG']['RestingECGMeasurements'].keys():\n", 189 | " col.append(key)\n", 190 | " col.append('DiagnosisStatement')\n", 191 | " col.append('TestReason')\n", 192 | " col.append('AdmitDiagnosis')\n", 193 | " col.append('xmlFilename')\n", 194 | " df = pd.DataFrame(columns=col)\n", 195 | " pid = dic['RestingECG']['PatientDemographics']\n", 196 | " pid['Filename']=filename\n", 197 | " 
pid['LocationName']=dic['RestingECG']['TestDemographics']['LocationName']\n", 198 | " pid['SiteName']= dic['RestingECG']['TestDemographics']['SiteName']\n", 199 | " for key in dic['RestingECG']['RestingECGMeasurements'].keys():\n", 200 | " pid[key]=dic['RestingECG']['RestingECGMeasurements'][key]\n", 201 | " \n", 202 | " try:\n", 203 | " pid['AdmitDiagnosis'] = dic['RestingECG']['Order']['AdmitDiagnosis']\n", 204 | " except:\n", 205 | " pid['AdmitDiagnosis'] = ''\n", 206 | " try:\n", 207 | " pid['TestReason'] = dic['RestingECG']['TestDemographics']['TestReason']\n", 208 | " except:\n", 209 | " pid['TestReason'] = ''\n", 210 | " \n", 211 | " try:\n", 212 | " s = ''\n", 213 | " for sen in dic['RestingECG']['Diagnosis']['DiagnosisStatement']:\n", 214 | " s+=sen['StmtText']\n", 215 | " pid['DiagnosisStatement'] = s\n", 216 | " except:\n", 217 | " pid['DiagnosisStatement'] = ''\n", 218 | " pid['xmlFilename'] = path_to_xml.split('/')[-1]\n", 219 | " df=df.append(pid,ignore_index=True)\n", 220 | " return df\n", 221 | "\n", 222 | "def ekg_batch_run(ekg_list,args):\n", 223 | " i = 0\n", 224 | " x = 0\n", 225 | " df = 1\n", 226 | " for file in tqdm(ekg_list):\n", 227 | " try:\n", 228 | " df = xml_to_np_array_file(os.path.join(args['in_XML'],file), args['in_dataset'],df=df)\n", 229 | " i+=1\n", 230 | " except Exception as e:\n", 231 | "\n", 232 | " print(\"file failed: \", file)\n", 233 | " print(file, e)\n", 234 | " x+=1\n", 235 | " if i % 10000 == 9999:\n", 236 | " print(f\"Succesfully converted {i} EKGs, failed converting {x} EKGs\")\n", 237 | " df.to_csv(os.path.join(args['in_dataset'],'Personal_Data.csv'),index=False)\n", 238 | " df.to_csv(os.path.join(args['in_dataset'],'Personal_Data.csv'),index=False)\n", 239 | " print(f\"Succesfully converted {i} EKGs, failed converting {x} EKGs\")\n", 240 | "\n", 241 | "if not args['in_XML'] is None:\n", 242 | " ekg_file_list = []\n", 243 | " ekg_file_list = file_path(args['in_XML']) #if you want input to be a directory\n", 244 | " 
print(\"Number of EKGs found: \", len(ekg_file_list))\n", 245 | "\n", 246 | " ekg_batch_run(ekg_file_list,args)\n", 247 | "\n", 248 | "\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def std(n1,s1,n2,s2,y1,y2,y):\n", 258 | " top = n1*(s1**2)+n1*(s1**2) + n1*((y1-y)**2)+n2*((y2-y)**2)\n", 259 | " bot = n1+n2\n", 260 | " return np.sqrt(top/bot)\n", 261 | "def calc_mean_std(arr,mean,stds,count):\n", 262 | " count_t = count+1\n", 263 | " arr_mean = np.mean(arr,axis=0)\n", 264 | " arr_std = np.std(arr,axis=0)\n", 265 | " new_mean = arr_mean*1/count_t + mean * (count_t-1)/count_t\n", 266 | " return new_mean,std(5000,arr_std,5000*count,stds,arr_mean,mean,new_mean)\n", 267 | "\n", 268 | "\n", 269 | "fnames = pd.read_csv(args['in_csv']).Filename\n", 270 | "out = args['out_dataset']\n", 271 | "input_dir = args['in_dataset']\n", 272 | "trainData = []\n", 273 | "\n", 274 | "cur_mean = None\n", 275 | "cur_std = None\n", 276 | "\n", 277 | "arr = np.zeros(12)\n", 278 | "count = 1\n", 279 | "if args['mean'] is None:\n", 280 | " for npfile in tqdm(np.random.choice(fnames,args['sample'])):\n", 281 | " path = os.path.join(input_dir,npfile)\n", 282 | " file = np.load(path)\n", 283 | " trainData.append(file)\n", 284 | "\n", 285 | " trainData = np.array([trainData])\n", 286 | " m = []\n", 287 | "\n", 288 | " s = []\n", 289 | " for i in tqdm(range(0,12)):\n", 290 | " m.append(np.mean(trainData[:,:,:,i]))\n", 291 | " s.append(np.std(trainData[:,:,:,i],axis=None))\n", 292 | "else:\n", 293 | " print(\"Existing mean and std\")\n", 294 | " m = args['mean']\n", 295 | " s = args['std']\n", 296 | "del trainData\n", 297 | "for i in tqdm(fnames):\n", 298 | " if not os.path.exists(os.path.join(out,i)):\n", 299 | " file = np.load(os.path.join(input_dir,i))\n", 300 | " for k in range(0,12):\n", 301 | " file[:,k] = (file[:,k] - m[k])/s[k]\n", 302 | "\n", 303 | " 
np.save(os.path.join(out,i),file)\n", 304 | "print(\"Finished\")\n" 305 | ] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.6.10" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 4 329 | } 330 | -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/__pycache__/ECG.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet Plus Clinical Features/__pycache__/ECG.cpython-36.pyc -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet Plus Clinical Features/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/best_roc_model_MACE_with_RCRI_Features.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet Plus Clinical Features/best_roc_model_MACE_with_RCRI_Features.pt -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/best_roc_model_Mortality_with_RCRI_Features.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecg-net/PreOpNet/0732b48ff67131498cee25e1fec79d75b41a0c6f/PreOpNet Plus Clinical Features/best_roc_model_Mortality_with_RCRI_Features.pt -------------------------------------------------------------------------------- /PreOpNet Plus Clinical Features/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | # + 6 | from torch import nn 7 | from collections import OrderedDict 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self,in_channel,out_channel,expansion,activation,stride=1,padding = 1): 12 | super(Bottleneck, self).__init__() 13 | self.stride=stride 14 | self.conv1 = nn.Conv1d(in_channel,in_channel*expansion,kernel_size = 1) 15 | self.conv2 = nn.Conv1d(in_channel*expansion,in_channel*expansion,kernel_size = 3, groups = in_channel*expansion, 16 | padding=padding,stride = stride) 17 | self.conv3 = nn.Conv1d(in_channel*expansion,out_channel,kernel_size = 1, stride =1) 18 | self.b0 = nn.BatchNorm1d(in_channel*expansion) 19 | self.b1 = nn.BatchNorm1d(in_channel*expansion) 20 | self.d = nn.Dropout() 21 | self.act = activation() 22 | def forward(self,x): 23 | if self.stride == 1: 24 | y = self.act(self.b0(self.conv1(x))) 25 | y = self.act(self.b1(self.conv2(y))) 26 | y = self.conv3(y) 27 | y = self.d(y) 28 | y = x+y 29 | return y 30 | else: 31 | y = self.act(self.b0(self.conv1(x))) 32 | y = self.act(self.b1(self.conv2(y))) 33 | y = self.conv3(y) 34 | return y 35 | 36 | from torch import nn 37 | from collections import OrderedDict 38 | 39 | class MBConv(nn.Module): 40 | def __init__(self,in_channel,out_channels,expansion,layers,activation=nn.ReLU6,stride = 2): 41 | super(MBConv, self).__init__() 42 | self.stack = OrderedDict() 43 | for i in range(0,layers-1): 44 | self.stack['s'+str(i)] = 
Bottleneck(in_channel,in_channel,expansion,activation) 45 | #self.stack['a'+str(i)] = activation() 46 | self.stack['s'+str(layers+1)] = Bottleneck(in_channel,out_channels,expansion,activation,stride=stride) 47 | # self.stack['a'+str(layers+1)] = activation() 48 | self.stack = nn.Sequential(self.stack) 49 | 50 | self.bn = nn.BatchNorm1d(out_channels) 51 | def forward(self,x): 52 | x = self.stack(x) 53 | return self.bn(x) 54 | 55 | 56 | """def MBConv(in_channel,out_channels,expansion,layers,activation=nn.ReLU6,stride = 2): 57 | stack = OrderedDict() 58 | for i in range(0,layers-1): 59 | stack['b'+str(i)] = Bottleneck(in_channel,in_channel,expansion,activation) 60 | stack['b'+str(layers)] = Bottleneck(in_channel,out_channels,expansion,activation,stride=stride) 61 | return nn.Sequential(stack)""" 62 | 63 | 64 | class EffNet(nn.Module): 65 | 66 | def __init__(self,num_additional_features = 0,depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280], 67 | dilation = 1,stride = 2,expansion = 6): 68 | super(EffNet, self).__init__() 69 | print("depth ",depth) 70 | self.stage1 = nn.Conv1d(12, channels[0], kernel_size=3, stride=stride, padding=1,dilation = dilation) #1 conv 71 | self.b0 = nn.BatchNorm1d(channels[0]) 72 | self.stage2 = MBConv(channels[0], channels[1], expansion, depth[0], stride=2)# 16 #input, output, depth # 3 conv 73 | self.stage3 = MBConv(channels[1], channels[2], expansion, depth[1], stride=2)# 24 # 4 conv # d 2 74 | self.Pool = nn.MaxPool1d(3, stride=1, padding=1) # 75 | self.stage4 = MBConv(channels[2], channels[3], expansion, depth[2], stride=2)# 40 # 4 conv # d 2 76 | self.stage5 = MBConv(channels[3], channels[4], expansion, depth[3], stride=2)# 80 # 5 conv # d 77 | self.stage6 = MBConv(channels[4], channels[5], expansion, depth[4], stride=2)# 112 # 5 conv 78 | self.stage7 = MBConv(channels[5], channels[6], expansion, depth[5], stride=2)# 192 # 5 conv 79 | self.stage8 = MBConv(channels[6], channels[7], expansion, depth[6], stride=2)# 320 # 5 
conv 80 | 81 | self.stage9 = nn.Conv1d(channels[7], channels[8], kernel_size=1) 82 | self.AAP = nn.AdaptiveAvgPool1d(1) 83 | self.act = nn.ReLU() 84 | self.drop = nn.Dropout() 85 | self.num_additional_features = num_additional_features 86 | self.fc = nn.Linear(channels[8] + num_additional_features, 1) 87 | 88 | 89 | def forward(self, x): 90 | if self.num_additional_features >0: 91 | x,additional = x 92 | # N x 12 x 2500 93 | x = self.b0(self.stage1(x)) 94 | # N x 32 x 1250 95 | x = self.stage2(x) 96 | # N x 16 x 625 97 | x = self.stage3(x) 98 | # N x 24 x 313 99 | x = self.Pool(x) 100 | # N x 24 x 313 101 | 102 | x = self.stage4(x) 103 | # N x 40 x 157 104 | x = self.stage5(x) 105 | # N x 80 x 79 106 | x = self.stage6(x) 107 | # N x 112 x 40 108 | x = self.Pool(x) 109 | # N x 192 x 20 110 | 111 | x = self.stage7(x) 112 | # N x 320 x 10 113 | x = self.stage8(x) 114 | x = self.stage9(x) 115 | # N x 1280 x 10 116 | x = self.act(self.AAP(x)[:,:,0]) 117 | # N x 1280 118 | x = self.drop(x) 119 | if self.num_additional_features >0: 120 | x = torch.cat((x,additional),1) 121 | x = self.fc(x) 122 | # N x 1 123 | return x 124 | -------------------------------------------------------------------------------- /Process Data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Processing code adapted from https://github.com/PierreElias/IntroECG/\n", 10 | "\n", 11 | "\n", 12 | "import pandas as pd \n", 13 | "import numpy as np \n", 14 | "import xmltodict\n", 15 | "import base64\n", 16 | "import struct\n", 17 | "import argparse\n", 18 | "import os\n", 19 | "import sys\n", 20 | "from tqdm import tqdm\n", 21 | "from scipy import stats\n", 22 | "\n", 23 | "args = {\n", 24 | " 'in_XML' : '/workspace/data/NAS/Muse2013-19ECGs_XML', # where the XML files are located, put none if these are already calculated\n", 25 
| " 'in_csv': '/workspace/data/NAS/Cedars EKG Analysis/Muse2008-19 XML Metadata.csv',# Where the csv is located for your input dataset\n", 26 | " 'in_dataset': '/workspace/data/drives/Internal_SSD/sdc/EKG Test/',# Where the converted XMLs will be moved, or, where the numpy arrays are currently stored\n", 27 | " 'out_dataset': '/workspace/data/drives/Internal_SSD/sdc/Comparison/', # Where the normalized EKGs are stored\n", 28 | " 'sample':100000, # sample size to estimate mean and standard deviation\n", 29 | " 'mean':np.array([-0.753910531911661,-0.5609376530271284,0.19297287888453685,0.6574240924693946,-0.09648643944226842,0.09648643944226842,-0.9398182103547104,-0.8866948773251518,-0.9585726095365399,-0.9142084935751398,-0.9573448180888456,-0.9300810636208064]),\n", 30 | " 'std':np.array([32.082092358503644,34.97862852596865,38.153409045189754,27.612712586528637,19.076704522594877,19.076704522594877,42.77931881050877,63.872623440588406,61.15731462396783,54.12879139607189,48.435274440820855,43.34056377213695])\n", 31 | " \n", 32 | "}# replace mean and std with none to calculate your own" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "def file_path(path):\n", 46 | " return os.listdir(path)\n", 47 | " \n", 48 | "#need to update this function to check the output directory for the output file and then only on newly added EKGs\n", 49 | "#add timestamp to start file string \n", 50 | "#this is annoying because the XML file name is a random timestamp and the output file is the UniqueECGID\n", 51 | "\n", 52 | "\n", 53 | "#if not os.path.exists('/D:/Muse2019ECGs_npy/'):\n", 54 | "# os.mkdir('/D:/Muse2019ECGs_npy/')\n", 55 | "\n", 56 | "# parser = argparse.ArgumentParser(description='Input and outputs for XML EKG parsing')\n", 57 | "# parser.add_argument('input', type=str)\n", 58 | "# parser.set_defaults(output=os.getcwd() + 
'/EchoNet_ECG_waveforms/') #ensure this directory already exists\n", 59 | "\n", 60 | "# args = parser.parse_args()\n", 61 | "\n", 62 | "\n", 63 | "\n", 64 | "def decode_ekg_muse(raw_wave):\n", 65 | " \"\"\"\n", 66 | " Ingest the base64 encoded waveforms and transform to numeric\n", 67 | " \"\"\"\n", 68 | " # covert the waveform from base64 to byte array\n", 69 | " arr = base64.b64decode(bytes(raw_wave, 'utf-8'))\n", 70 | "\n", 71 | " # unpack every 2 bytes, little endian (16 bit encoding)\n", 72 | " unpack_symbols = ''.join([char*int(len(arr)/2) for char in 'h'])\n", 73 | " byte_array = struct.unpack(unpack_symbols, arr)\n", 74 | " return byte_array\n", 75 | "\n", 76 | "\n", 77 | "def decode_ekg_muse_to_array(raw_wave, downsample = 1):\n", 78 | " \"\"\"\n", 79 | " Ingest the base64 encoded waveforms and transform to numeric\n", 80 | " downsample: 0.5 takes every other value in the array. Muse samples at 500/s and the sample model requires 250/s. So take every other.\n", 81 | " \"\"\"\n", 82 | " try:\n", 83 | " dwnsmpl = int(1//downsample)\n", 84 | " except ZeroDivisionError:\n", 85 | " print(\"You must downsample by more than 0\")\n", 86 | " # covert the waveform from base64 to byte array\n", 87 | " arr = base64.b64decode(bytes(raw_wave, 'utf-8'))\n", 88 | "\n", 89 | " # unpack every 2 bytes, little endian (16 bit encoding)\n", 90 | " unpack_symbols = ''.join([char*int(len(arr)/2) for char in 'h'])\n", 91 | " byte_array = struct.unpack(unpack_symbols, arr)\n", 92 | " return np.array(byte_array)[::dwnsmpl]\n", 93 | "\n", 94 | "\n", 95 | "\n", 96 | "def xml_to_np_array_file(path_to_xml, path_to_output = os.getcwd(),df=None):\n", 97 | " with open(path_to_xml, 'rb') as fd:\n", 98 | " dic = xmltodict.parse(fd.read().decode('utf8'))\n", 99 | " \n", 100 | " \"\"\"\n", 101 | " Upload the ECG as numpy array with shape=[5000,12] ([time, leads, 1]).\n", 102 | " The voltage unit should be in 1 mv/unit and the sampling rate should be 250/second (total 10 second).\n", 103 | " 
The leads should be ordered as follow I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6.\n", 104 | " \"\"\"\n", 105 | " try:\n", 106 | " pt_id = dic['RestingECG']['PatientDemographics']['PatientID']\n", 107 | " except:\n", 108 | " pt_id = \"none\"\n", 109 | " try:\n", 110 | " PharmaUniqueECGID = dic['RestingECG']['PharmaData']['PharmaUniqueECGID']\n", 111 | " except:\n", 112 | " PharmaUniqueECGID = \"none\"\n", 113 | " try:\n", 114 | " AcquisitionDateTime = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + \"_\" + dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(\":\",\"-\")\n", 115 | " except:\n", 116 | " AcquisitionDateTime = \"none\" \n", 117 | "\n", 118 | " # try:\n", 119 | " # requisition_number = dic['RestingECG']['Order']['RequisitionNumber']\n", 120 | " # except:\n", 121 | " # print(\"no requisition_number\")\n", 122 | " # requisition_number = \"none\"\n", 123 | "\n", 124 | " #need to instantiate leads in the proper order for the model\n", 125 | " lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n", 126 | "\n", 127 | " \"\"\"\n", 128 | " Each EKG will have this data structure:\n", 129 | " lead_data = {\n", 130 | " 'I': np.array\n", 131 | " }\n", 132 | " \"\"\"\n", 133 | "\n", 134 | " lead_data = dict.fromkeys(lead_order)\n", 135 | " #lead_data = {leadid: None for k in lead_order}\n", 136 | "\n", 137 | "# for all_lead_data in dic['RestingECG']['Waveform']:\n", 138 | "# for single_lead_data in lead['LeadData']:\n", 139 | "# leadname = single_lead_data['LeadID']\n", 140 | "# if leadname in (lead_order):\n", 141 | "\n", 142 | " for lead in dic['RestingECG']['Waveform']:\n", 143 | " for leadid in range(len(lead['LeadData'])):\n", 144 | " sample_length = len(decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData']))\n", 145 | " #sample_length is equivalent to dic['RestingECG']['Waveform']['LeadData']['LeadSampleCountTotal']\n", 146 | " if sample_length == 5000 or sample_length == 5500:\n", 
147 | " lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData'], downsample = 1)\n", 148 | " elif sample_length == 2500:\n", 149 | " raise ValueError(\"Sample frequency too low, try with frequency = 500/s\")\n", 150 | " else:\n", 151 | " continue\n", 152 | " #ensures all leads have 2500 samples and also passes over the 3 second waveform\n", 153 | " lead_data['III'] = (np.array(lead_data[\"II\"]) - np.array(lead_data[\"I\"]))\n", 154 | " lead_data['aVR'] = -(np.array(lead_data[\"I\"]) + np.array(lead_data[\"II\"]))/2\n", 155 | " lead_data['aVF'] = (np.array(lead_data[\"II\"]) - np.array(lead_data[\"I\"]))/2\n", 156 | " lead_data['aVL'] = (np.array(lead_data[\"I\"]) - np.array(lead_data[\"II\"]))/2\n", 157 | " \n", 158 | " lead_data = {k: lead_data[k] for k in lead_order}\n", 159 | " # drops V3R, V4R, and V7 if it was a 15-lead ECG\n", 160 | "\n", 161 | " # now construct and reshape the array\n", 162 | " # converting the dictionary to an np.array\n", 163 | " temp = []\n", 164 | " for key,value in lead_data.items():\n", 165 | " temp.append(value)\n", 166 | "\n", 167 | " #transpose to be [time, leads, ]\n", 168 | " ekg_array = np.array(temp).T\n", 169 | "\n", 170 | " #expand dims to [time, leads, 1]\n", 171 | " ekg_array = np.expand_dims(ekg_array, axis=-1)\n", 172 | "\n", 173 | " # Here is a check to make sure all the model inputs are the right shape\n", 174 | "# assert ekg_array.shape == (2500, 12, 1), \"ekg_array is shape {} not (2500, 12, 1)\".format(ekg_array.shape )\n", 175 | "\n", 176 | " # filename = '/ekg_waveform_{}_{}.npy'.format(pt_id, requisition_number)\n", 177 | " filename = '{}_{}_{}.npy'.format(pt_id, AcquisitionDateTime,PharmaUniqueECGID)\n", 178 | " path_to_output = os.path.join(path_to_output,filename)\n", 179 | " # print(path_to_output)\n", 180 | " if len(ekg_array.shape)==3:\n", 181 | " ekg_array = ekg_array[:,:,0]\n", 182 | " # if len(ekg_array) == 5500:\n", 183 | " # ekg_array = 
ekg_array[:-500,:]\n", 184 | " with open(path_to_output, 'wb') as f:\n", 185 | " np.save(f, ekg_array)\n", 186 | " if not type(df) == type(pd.DataFrame()):\n", 187 | " col = list(dic['RestingECG']['PatientDemographics'].keys())\n", 188 | " col.append(\"Filename\")\n", 189 | " col.append('LocationName')\n", 190 | " col.append('SiteName')\n", 191 | " for key in dic['RestingECG']['RestingECGMeasurements'].keys():\n", 192 | " col.append(key)\n", 193 | " col.append('DiagnosisStatement')\n", 194 | " col.append('TestReason')\n", 195 | " col.append('AdmitDiagnosis')\n", 196 | " col.append('xmlFilename')\n", 197 | " df = pd.DataFrame(columns=col)\n", 198 | " pid = dic['RestingECG']['PatientDemographics']\n", 199 | " pid['Filename']=filename\n", 200 | " pid['LocationName']=dic['RestingECG']['TestDemographics']['LocationName']\n", 201 | " pid['SiteName']= dic['RestingECG']['TestDemographics']['SiteName']\n", 202 | " for key in dic['RestingECG']['RestingECGMeasurements'].keys():\n", 203 | " pid[key]=dic['RestingECG']['RestingECGMeasurements'][key]\n", 204 | " \n", 205 | " try:\n", 206 | " pid['AdmitDiagnosis'] = dic['RestingECG']['Order']['AdmitDiagnosis']\n", 207 | " except:\n", 208 | " pid['AdmitDiagnosis'] = ''\n", 209 | " try:\n", 210 | " pid['TestReason'] = dic['RestingECG']['TestDemographics']['TestReason']\n", 211 | " except:\n", 212 | " pid['TestReason'] = ''\n", 213 | " \n", 214 | " try:\n", 215 | " s = ''\n", 216 | " for sen in dic['RestingECG']['Diagnosis']['DiagnosisStatement']:\n", 217 | " s+=sen['StmtText']\n", 218 | " pid['DiagnosisStatement'] = s\n", 219 | " except:\n", 220 | " pid['DiagnosisStatement'] = ''\n", 221 | " pid['xmlFilename'] = path_to_xml.split('/')[-1]\n", 222 | " df=df.append(pid,ignore_index=True)\n", 223 | " return df\n", 224 | "\n", 225 | "def ekg_batch_run(ekg_list,args):\n", 226 | " i = 0\n", 227 | " x = 0\n", 228 | " df = 1\n", 229 | " for file in tqdm(ekg_list):\n", 230 | " try:\n", 231 | " df = 
xml_to_np_array_file(os.path.join(args['in_XML'],file), args['in_dataset'],df=df)\n", 232 | " i+=1\n", 233 | " except Exception as e:\n", 234 | "\n", 235 | " print(\"file failed: \", file)\n", 236 | " print(file, e)\n", 237 | " x+=1\n", 238 | " if i % 10000 == 9999:\n", 239 | " print(f\"Succesfully converted {i} EKGs, failed converting {x} EKGs\")\n", 240 | " df.to_csv(os.path.join(args['in_dataset'],'Personal_Data.csv'),index=False)\n", 241 | " df.to_csv(os.path.join(args['in_dataset'],'Personal_Data.csv'),index=False)\n", 242 | " print(f\"Succesfully converted {i} EKGs, failed converting {x} EKGs\")\n", 243 | "\n", 244 | "if not args['in_XML'] is None:\n", 245 | " ekg_file_list = []\n", 246 | " ekg_file_list = file_path(args['in_XML']) #if you want input to be a directory\n", 247 | " print(\"Number of EKGs found: \", len(ekg_file_list))\n", 248 | "\n", 249 | " ekg_batch_run(ekg_file_list,args)\n", 250 | "\n", 251 | "\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "def std(n1,s1,n2,s2,y1,y2,y):\n", 261 | " top = n1*(s1**2)+n1*(s1**2) + n1*((y1-y)**2)+n2*((y2-y)**2)\n", 262 | " bot = n1+n2\n", 263 | " return np.sqrt(top/bot)\n", 264 | "def calc_mean_std(arr,mean,stds,count):\n", 265 | " count_t = count+1\n", 266 | " arr_mean = np.mean(arr,axis=0)\n", 267 | " arr_std = np.std(arr,axis=0)\n", 268 | " new_mean = arr_mean*1/count_t + mean * (count_t-1)/count_t\n", 269 | " return new_mean,std(5000,arr_std,5000*count,stds,arr_mean,mean,new_mean)\n", 270 | "\n", 271 | "\n", 272 | "fnames = pd.read_csv(args['in_csv']).Filename\n", 273 | "out = args['out_dataset']\n", 274 | "input_dir = args['in_dataset']\n", 275 | "trainData = []\n", 276 | "\n", 277 | "cur_mean = None\n", 278 | "cur_std = None\n", 279 | "\n", 280 | "arr = np.zeros(12)\n", 281 | "count = 1\n", 282 | "if args['mean'] is None:\n", 283 | " for npfile in tqdm(np.random.choice(fnames,args['sample'])):\n", 
284 | " path = os.path.join(input_dir,npfile)\n", 285 | " file = np.load(path)\n", 286 | " trainData.append(file)\n", 287 | "\n", 288 | " trainData = np.array([trainData])\n", 289 | " m = []\n", 290 | "\n", 291 | " s = []\n", 292 | " for i in tqdm(range(0,12)):\n", 293 | " m.append(np.mean(trainData[:,:,:,i]))\n", 294 | " s.append(np.std(trainData[:,:,:,i],axis=None))\n", 295 | "else:\n", 296 | " print(\"Existing mean and std\")\n", 297 | " m = args['mean']\n", 298 | " s = args['std']\n", 299 | "del trainData\n", 300 | "for i in tqdm(fnames):\n", 301 | " if not os.path.exists(os.path.join(out,i)):\n", 302 | " file = np.load(os.path.join(input_dir,i))\n", 303 | " for k in range(0,12):\n", 304 | " file[:,k] = (file[:,k] - m[k])/s[k]\n", 305 | "\n", 306 | " np.save(os.path.join(out,i),file)\n", 307 | "print(\"Finished\")\n" 308 | ] 309 | } 310 | ], 311 | "metadata": { 312 | "kernelspec": { 313 | "display_name": "Python 3", 314 | "language": "python", 315 | "name": "python3" 316 | }, 317 | "language_info": { 318 | "codemirror_mode": { 319 | "name": "ipython", 320 | "version": 3 321 | }, 322 | "file_extension": ".py", 323 | "mimetype": "text/x-python", 324 | "name": "python", 325 | "nbconvert_exporter": "python", 326 | "pygments_lexer": "ipython3", 327 | "version": "3.6.10" 328 | } 329 | }, 330 | "nbformat": 4, 331 | "nbformat_minor": 4 332 | } 333 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PreOpNet 2 | 3 | PreOpNet is a purpose built deep learning architecture for predicting perioperative outcomes (mortality and major adverse cardiovascular events), and in our cohort, its predictive performance was superior to the revised cardiac risk index score and previously published neural architectures for interpreting ECGs despite requiring less computational processing. 
4 | 5 | ![](Figure1.PNG) 6 | 7 | ## Preprocessing 8 | 9 | The first notebook, Process Data.ipynb, converts your XML files to normalized numpy arrays. To run the notebook, edit the args dictionary with your relevant information: 10 | 11 | 1. in_XML: Where your XML files are located; set this to None if you already have a set of numpy arrays 12 | 2. in_csv: Where your csv is located. This csv should have a column labeled 'Filename', with the names of the numpy files, and a column for every topic of interest (mortality and MACE) 13 | 3. in_dataset: Where your numpy arrays are located, or will be written. 14 | 4. out_dataset: Where your normalized EKGs will be written 15 | 5. sample: The size of the random sample used to estimate the per-lead mean and std (only used when mean and std in args are set to None) 16 | 17 | ## Inference 18 | 19 | This notebook calculates the AUC, confusion matrix, and predictions for your dataset 20 | 21 | To run this notebook, fill out these variables: 22 | 1. Val_root: The location of your input dataset (this should match out_dataset from the first notebook) 23 | 2. Val_csv: The csv of your input dataset (this should match in_csv from the first notebook) 24 | 3. target: If you are testing on a column in your input dataset labeled something other than 'Mortality', you must provide a target to val_dataloader 25 | 4. checkpoint: Provide the path to the model checkpoint (PreOpNet Mort.pt for mortality, PreOpNet MACE.pt for MACE) 26 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Academic Software License: © 2022 Cedars-Sinai Medical Center ("Institution"). 2 | 3 | Academic or nonprofit researchers are permitted to use this Software (as defined below) subject to Paragraphs 1-4: 4 | 1.
Institution hereby grants to you free of charge, so long as you are an academic or nonprofit researcher, a nonexclusive license under Institution's copyright ownership interest in this software and any derivative works made by you thereof (collectively, the "Software") to use, copy, and make derivative works of the Software solely for educational or academic research purposes, in all cases subject to the terms of this Academic Software License. Except as granted herein, all rights are reserved by Institution, including the right to pursue patent protection of the Software. 5 | 6 | 2. Please note you are prohibited from further transferring the Software -- including any derivatives you make thereof -- to any person or entity. Failure by you to adhere to the requirements in Paragraphs 1 and 2 will result in immediate termination of the license granted to you pursuant to this Academic Software License effective as of the date you first used the Software. 7 | 8 | 3. IN NO EVENT SHALL INSTITUTION BE LIABLE TO ANY ENTITY OR PERSON FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE, EVEN IF INSTITUTION HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. INSTITUTION SPECIFICALLY DISCLAIMS ANY AND ALL WARRANTIES, EXPRESS AND IMPLIED, INCLUDING, BUT NOT LIMITED TO: (A) ANY IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE AND (B) THAT THE SOFTWARE WILL BE FREE FROM ANY INFRINGEMENT ON PATENTS, COPYRIGHTS, OR OTHER INTELLECTUAL PROPERTY RIGHTS OF THIRD PARTIES. THE SOFTWARE IS PROVIDED "AS IS." INSTITUTION HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS OF THIS SOFTWARE. 9 | 10 | 4. 
Any academic or scholarly publication arising from the use of this Software or any derivative works thereof will include the following acknowledgment: The Software used in this research was created by John Theurer, Christine Albert, and David Ouyang of Cedars-Sinai Medical Center. © 2022 Cedars-Sinai Medical Center. 11 | 12 | Commercial entities or for commercial use of the Software: please contact CSTechTransfer@cshs.org for licensing opportunities. 13 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | # + 6 | from torch import nn 7 | from collections import OrderedDict 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self,in_channel,out_channel,expansion,activation,stride=1,padding = 1): 12 | super(Bottleneck, self).__init__() 13 | self.stride=stride 14 | self.conv1 = nn.Conv1d(in_channel,in_channel*expansion,kernel_size = 1) 15 | self.conv2 = nn.Conv1d(in_channel*expansion,in_channel*expansion,kernel_size = 3, groups = in_channel*expansion, 16 | padding=padding,stride = stride) 17 | self.conv3 = nn.Conv1d(in_channel*expansion,out_channel,kernel_size = 1, stride =1) 18 | self.b0 = nn.BatchNorm1d(in_channel*expansion) 19 | self.b1 = nn.BatchNorm1d(in_channel*expansion) 20 | self.d = nn.Dropout() 21 | self.act = activation() 22 | def forward(self,x): 23 | if self.stride == 1: 24 | y = self.act(self.b0(self.conv1(x))) 25 | y = self.act(self.b1(self.conv2(y))) 26 | y = self.conv3(y) 27 | y = self.d(y) 28 | y = x+y 29 | return y 30 | else: 31 | y = self.act(self.b0(self.conv1(x))) 32 | y = self.act(self.b1(self.conv2(y))) 33 | y = self.conv3(y) 34 | return y 35 | 36 | from torch import nn 37 | from collections import OrderedDict 38 | 39 | class MBConv(nn.Module): 40 | def 
__init__(self,in_channel,out_channels,expansion,layers,activation=nn.ReLU6,stride = 2): 41 | super(MBConv, self).__init__() 42 | self.stack = OrderedDict() 43 | for i in range(0,layers-1): 44 | self.stack['s'+str(i)] = Bottleneck(in_channel,in_channel,expansion,activation) 45 | #self.stack['a'+str(i)] = activation() 46 | self.stack['s'+str(layers+1)] = Bottleneck(in_channel,out_channels,expansion,activation,stride=stride) 47 | # self.stack['a'+str(layers+1)] = activation() 48 | self.stack = nn.Sequential(self.stack) 49 | 50 | self.bn = nn.BatchNorm1d(out_channels) 51 | def forward(self,x): 52 | x = self.stack(x) 53 | return self.bn(x) 54 | 55 | 56 | """def MBConv(in_channel,out_channels,expansion,layers,activation=nn.ReLU6,stride = 2): 57 | stack = OrderedDict() 58 | for i in range(0,layers-1): 59 | stack['b'+str(i)] = Bottleneck(in_channel,in_channel,expansion,activation) 60 | stack['b'+str(layers)] = Bottleneck(in_channel,out_channels,expansion,activation,stride=stride) 61 | return nn.Sequential(stack)""" 62 | 63 | 64 | class EffNet(nn.Module): 65 | 66 | def __init__(self,num_additional_features = 0,depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280], 67 | dilation = 1,stride = 2,expansion = 6): 68 | super(EffNet, self).__init__() 69 | print("depth ",depth) 70 | self.stage1 = nn.Conv1d(12, channels[0], kernel_size=3, stride=stride, padding=1,dilation = dilation) #1 conv 71 | self.b0 = nn.BatchNorm1d(channels[0]) 72 | self.stage2 = MBConv(channels[0], channels[1], expansion, depth[0], stride=2)# 16 #input, output, depth # 3 conv 73 | self.stage3 = MBConv(channels[1], channels[2], expansion, depth[1], stride=2)# 24 # 4 conv # d 2 74 | self.Pool = nn.MaxPool1d(3, stride=1, padding=1) # 75 | self.stage4 = MBConv(channels[2], channels[3], expansion, depth[2], stride=2)# 40 # 4 conv # d 2 76 | self.stage5 = MBConv(channels[3], channels[4], expansion, depth[3], stride=2)# 80 # 5 conv # d 77 | self.stage6 = MBConv(channels[4], channels[5], expansion, 
depth[4], stride=2)# 112 # 5 conv 78 | self.stage7 = MBConv(channels[5], channels[6], expansion, depth[5], stride=2)# 192 # 5 conv 79 | self.stage8 = MBConv(channels[6], channels[7], expansion, depth[6], stride=2)# 320 # 5 conv 80 | 81 | self.stage9 = nn.Conv1d(channels[7], channels[8], kernel_size=1) 82 | self.AAP = nn.AdaptiveAvgPool1d(1) 83 | self.act = nn.ReLU() 84 | self.drop = nn.Dropout() 85 | self.num_additional_features = num_additional_features 86 | self.fc = nn.Linear(channels[8] + num_additional_features, 1) 87 | 88 | 89 | def forward(self, x): 90 | if self.num_additional_features >0: 91 | x,additional = x 92 | # N x 12 x 2500 93 | x = self.b0(self.stage1(x)) 94 | # N x 32 x 1250 95 | x = self.stage2(x) 96 | # N x 16 x 625 97 | x = self.stage3(x) 98 | # N x 24 x 313 99 | x = self.Pool(x) 100 | # N x 24 x 313 101 | 102 | x = self.stage4(x) 103 | # N x 40 x 157 104 | x = self.stage5(x) 105 | # N x 80 x 79 106 | x = self.stage6(x) 107 | # N x 112 x 40 108 | x = self.Pool(x) 109 | # N x 192 x 20 110 | 111 | x = self.stage7(x) 112 | # N x 320 x 10 113 | x = self.stage8(x) 114 | x = self.stage9(x) 115 | # N x 1280 x 10 116 | x = self.act(self.AAP(x)[:,:,0]) 117 | # N x 1280 118 | x = self.drop(x) 119 | if self.num_additional_features >0: 120 | x = torch.cat((x,additional),1) 121 | x = self.fc(x) 122 | # N x 1 123 | return x 124 | --------------------------------------------------------------------------------