├── README.md
├── app
    ├── README.md
    ├── app.py
    └── img.jpg
├── classifier
    ├── README.md
    ├── classifier.sh
    ├── classifier_LR.py
    ├── classifier_MLP.py
    └── classifier_RF.py
├── data
    ├── README.md
    └── data.py
├── docs
    ├── app_bokeh.png
    ├── app_main.png
    ├── app_start.png
    └── map.png
├── extractor
    ├── README.md
    ├── gen_feature.py
    ├── gen_feature.sh
    ├── map_gen_feature.py
    └── reduce_gen_feature.py
├── length_effect
    ├── README.md
    ├── draw.py
    ├── length_effect.py
    ├── length_effect.sh
    ├── map_length_effect.py
    └── reduce_length_effect.py
├── lib
    ├── __init__.py
    ├── util.py
    └── visual.py
├── model
    └── weights_tail_best.hdf5
└── requirements.txt


/README.md:
--------------------------------------------------------------------------------
# Anomaly Detection in Time Series with Triadic Motif Fields and Application in Atrial Fibrillation ECG Classification

[![arXiv](https://img.shields.io/badge/arXiv-2012.04936-b31b1b.svg)](https://arxiv.org/abs/2012.04936) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/ydup/bokeh/master?urlpath=/proxy/5006/bokeh-app) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/anomaly-detection-in-time-series-with-triadic/atrial-fibrillation-detection-on-physionet)](https://paperswithcode.com/sota/atrial-fibrillation-detection-on-physionet?p=anomaly-detection-in-time-series-with-triadic)

Author: [Yadong Zhang](https://github.com/ydup) and Xin Chen

Paper: [![arXiv](https://img.shields.io/badge/arXiv-2012.04936-b31b1b.svg)](https://arxiv.org/abs/2012.04936)

Online demo: [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/ydup/bokeh/master?urlpath=/proxy/5006/bokeh-app)

### Modules

Module | Path | Note | Default Settings
--- | --- | --- | ---
Basic | 1. [_lib_](lib/)<br>2. [_data_](data/)<br>3. [_model_](model/) | 1. Basic functions of the project.<br>2. Dataset processing.<br>3. Saved tail model weights. | 1. -<br>2. No filter, z-normalization<br>3. MLP model
Classification | 1. [_extractor_](extractor/)<br>2. [_classifier_](classifier/) | 1. Feature extraction from TMF images based on transfer learning.<br>2. Classification of feature vectors into AF and non-AF probabilities. | 1. VGG16; map-reduce over ```10``` nodes with an MPI size of ```5```.<br>2. -
Evaluation | 1. [_length\_effect_](length_effect/) | 1. Evaluate the trained model on ECG signals of varying length. | 1. VGG16-MLP; map-reduce over ```10``` nodes with an MPI size of ```5```.
App | 1. [_pyQT app_](app/)<br>2. [_bokeh app_](https://github.com/ydup/bokeh) | 1. Local app for classification and interpretation.<br>2. Web server for interpretation. | VGG16-MLP

### Structure of the Parallel Code (MPI)

[_extractor_](extractor/) and [_length\_effect_](length_effect/) are parallelized with map-reduce on a Linux cluster:
+ ```.py```: main code.
+ ```.sh```: script for a single submission to the PBS queue.
+ ```map*.py```: map the tasks to multiple nodes and MPI processes.
+ ```reduce*.py```: collect the results from the finished tasks.

### App Guidelines

Features | Classification | Visualization | Interactive | Remote | Local
--- | --- | --- | --- | --- | ---
[pyQT app](app/) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark:
[bokeh app](https://github.com/ydup/bokeh) | :x: (available in the future) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:

#### [pyQT app](app/)

1. Start page (click ```start```)
    + Start button
    + Progress bar & status
    ![hello](docs/app_start.png)
2. Main page (from top to bottom)
    + Time series with label
    + Symmetrized Grad-CAM of AF and its predicted probability
    + Symmetrized Grad-CAM of non-AF and its predicted probability
    + Sliders of ```time index``` and ```delay``` to adjust the triadic time series motifs
        - The triad (red) in the time series corresponds to the cross (white) in the two Symmetrized Grad-CAM images.
        - The text with a red background indicates the predicted type.
    ![main](docs/app_main.png)
#### [bokeh app](https://github.com/ydup/bokeh)
![bokeh](docs/app_bokeh.png)

### [Requirements](./requirements.txt)
Python 3.6:
```
matplotlib
mpi4py==3.0.3
numba==0.50.1
scikit-learn==0.23.0
scipy==1.5.2
tensorflow==1.14.0
opencv-python
tqdm
PyQt5
```

### Citation

Cite our work with:
```latex
@misc{zhang2020anomaly,
      title={Anomaly Detection in Time Series with Triadic Motif Fields and Application in Atrial Fibrillation ECG Classification},
      author={Yadong Zhang and Xin Chen},
      year={2020},
      eprint={2012.04936},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
# Application

Demo:
1. Default index of the ECG in the test dataset (an AF signal):
```shell
$ python3 app.py
```
2. A user-defined index, for example 1:
```shell
$ python3 app.py 1
```

--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
'''
ECG application (classification and interpretation).
Author: Yadong Zhang
E-mail: zhangyadong@stu.xjtu.edu.cn

Demo:
    $ python app.py
'''
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib import pyplot as plt
plt.rc('text', usetex=True)  # requires a working LaTeX installation
plt.rc('font', family='Times New Roman')
import matplotlib.patches as patches
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

import time
import pickle
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *

import sys
sys.path.append('../')
from lib.util import TMF_image, get_GCAM, get_sym_GCAM
from lib.util import build_fullnet as build

StyleSheet = '''
#RedProgressBar {
    text-align: center;
    min-height: 12px;
    max-height: 12px;
}
#RedProgressBar::chunk {
    background-color: #F44336;
}
'''

DETAIL = {0: 'Load data', 10: 'Load model', 20: 'Get TMF image', 30: 'Get Grad-CAM of AF', 55: 'Get symmetrized Grad-CAM of AF', 60: 'Get Grad-CAM of non-AF', 85: 'Get symmetrized Grad-CAM of non-AF', 90: 'Predict', 100: 'Finished'}  # status message for each progress value

class TSCanvas(FigureCanvasQTAgg):
    def __init__(self, parent=None, width=5, height=4, dpi=100):
        '''
        ECG time series
        '''
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        super(TSCanvas, self).__init__(fig)

class HMCanvas(FigureCanvasQTAgg):
    def __init__(self, parent=None, width=5, height=4, dpi=100):
        '''
        Symmetrized Grad-CAM of TMF images
        '''
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        super(HMCanvas, self).__init__(fig)

class MainWindow(QtWidgets.QMainWindow):

    def __init__(self, ts, gcam_AF, gcam_nAF, proba_AF, proba_nAF, label):

        super().__init__()
        self.gcam_AF = gcam_AF
        self.gcam_nAF = gcam_nAF
        self.ts = ts
        self.setFixedSize(640, 800)
        win = len(self.ts)  # time series length
        D = 3  # triad
        self.overlap = win-(D-1)*self.gcam_AF.shape[0]  # overlap of the heatmap

        # ECG time series, initial plot
        self.sc1 = TSCanvas(self, width=4.0, height=5, dpi=100)
        self.sc1.axes.plot(np.arange(len(self.ts)), self.ts, 'k', linewidth=1.0)
        self.sc1.axes.set_xlim([0, len(self.ts)])
        self.sc1.axes.set_yticks([])
        # AF heatmap, initial plot
        self.sc2 = HMCanvas(self, width=4.0, height=5, dpi=100)
        self.sc2.axes.imshow(self.gcam_AF, cmap=plt.cm.jet)
        self.sc2.axes.set_xticks([])
        self.sc2.axes.set_yticks([])
        # Non-AF heatmap, initial plot
        self.sc3 = HMCanvas(self, width=4.0, height=5, dpi=100)
        self.sc3.axes.imshow(self.gcam_nAF, cmap=plt.cm.jet)
        self.sc3.axes.set_xticks([])
        self.sc3.axes.set_yticks([])

        # Slider for the delay (gap) of the triad
        self.s = QSlider(Qt.Horizontal)
        self.s.setMinimum(1)
        self.s.setMaximum(self.gcam_AF.shape[0])
        self.s.setSingleStep(1)
        self.s.setValue(20)
        self.s.setTickPosition(QSlider.TicksBelow)
        self.s.setTickInterval(100)
        self.s.valueChanged.connect(self.triadchange)  # redraw the plots when the slider moves

        # Slider for the initial time index of the triad
        self.idx = QSlider(Qt.Horizontal)
        self.idx.setMinimum(0)
        self.idx.setMaximum(self.gcam_AF.shape[1]-1)
        self.idx.setSingleStep(1)
        self.idx.setValue(20)
        self.idx.setTickPosition(QSlider.TicksBelow)
        self.idx.setTickInterval(100)
        self.idx.valueChanged.connect(self.triadchange)

        # Labels
        self.l0 = QLabel('Label: %s'%('AF' if label == 1 else 'non-AF'))
        self.l0.setAlignment(Qt.AlignCenter)
        self.imp1 = QLabel('Probability of AF:%.3f'%proba_AF)
        self.imp1.setAlignment(Qt.AlignCenter)
        self.imp2 = QLabel('Probability of non-AF:%.3f'%proba_nAF)
        self.imp2.setAlignment(Qt.AlignCenter)
        if proba_AF > proba_nAF:
            self.imp1.setStyleSheet("background-color: red")
        elif proba_AF < proba_nAF:
            self.imp2.setStyleSheet("background-color: red")
        self.l1 = QLabel('Time index:')
        self.l1.setAlignment(Qt.AlignCenter)
        self.l2 = QLabel('Delay:')
        self.l2.setAlignment(Qt.AlignCenter)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.l0)  # Label: ground truth
        layout.addWidget(self.sc1)  # Time series
        layout.addWidget(self.imp1)  # Label: probability of AF
        layout.addWidget(self.sc2)  # AF heatmap
        layout.addWidget(self.imp2)  # Label: probability of non-AF
        layout.addWidget(self.sc3)  # Non-AF heatmap
        layout.addWidget(self.l1)  # Label: initial time index
        layout.addWidget(self.idx)  # Time index slider
        layout.addWidget(self.l2)  # Label: delay
        layout.addWidget(self.s)  # Delay slider
        # A QMainWindow already owns a layout, so put ours on a central widget.
        widget = QtWidgets.QWidget()
        widget.setLayout(layout)
        self.setCentralWidget(widget)
        self.show()

    def triadchange(self):
        '''
        Re-draw the triad, time series and heatmap
        '''
        start = self.idx.value()
        gap = self.s.value()

        # Re-draw the AF heatmap
        self.sc2.axes.cla()  # Clear the canvas.
        self.sc2.axes.imshow(self.gcam_AF, cmap=plt.cm.jet)
        # Mark the selected (time index, delay) with a white +
        self.sc2.axes.scatter(x=[start], y=[gap-1], s=100, color='w', marker='+')
        self.sc2.axes.set_xlim([0, self.gcam_AF.shape[1]])
        self.sc2.axes.set_ylim([self.gcam_AF.shape[0], 0])
        self.sc2.axes.set_xticks([])
        self.sc2.axes.set_yticks([])
        # Trigger the canvas to update and redraw.
        self.sc2.draw()

        # Re-draw the non-AF heatmap
        self.sc3.axes.cla()  # Clear the canvas.
        self.sc3.axes.imshow(self.gcam_nAF, cmap=plt.cm.jet)
        # Mark the selected (time index, delay) with a white +
        self.sc3.axes.scatter(x=[start], y=[gap-1], s=100, color='w', marker='+')
        self.sc3.axes.set_xlim([0, self.gcam_nAF.shape[1]])
        self.sc3.axes.set_ylim([self.gcam_nAF.shape[0], 0])
        self.sc3.axes.set_xticks([])
        self.sc3.axes.set_yticks([])
        # Trigger the canvas to update and redraw.
        self.sc3.draw()

        # A cross in the symmetrized (mirrored) part of the image maps back to its original triad
        right_bound = len(self.ts)-(3-1)*gap
        if start + 1 > right_bound:
            start = self.gcam_AF.shape[1] - start - 1
            gap = self.gcam_AF.shape[0] - gap + 1

        # Re-draw time series
        self.sc1.axes.cla()  # Clear the canvas.
        self.sc1.axes.plot(np.arange(len(self.ts)), self.ts, 'k', linewidth=1.0)
        self.sc1.axes.plot(np.arange(start, start+gap*3, gap), self.ts[np.arange(start, start+gap*3, gap)], 'ro-', markersize=5)
        self.sc1.axes.set_xlim([0, len(self.ts)])
        self.sc1.axes.set_yticks([])
        self.sc1.draw()

        # Set new text for labels
        self.l1.setText('Time index:'+str(start))
        self.l2.setText('Delay:'+str(gap))

    def closeEvent(self, event):

        reply = QMessageBox.question(self, 'Warning',
            "Sure to exit?", QMessageBox.Yes |
            QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            self.hide()
            self.dialog = StartWindow()  # return to the start page
        else:
            event.ignore()

class Thread(QThread):
    _signal = pyqtSignal(int)
    def __init__(self):
        super(Thread, self).__init__()

    def __del__(self):
        self.wait()

    def run(self):
        pnums = list(DETAIL.keys())
        pnums.sort()

        self._signal.emit(pnums[0])  # Load data
        with open(ecg_path, 'rb') as f:
            data = np.load(f)
        with open(label_path, 'rb') as f:
            label = pickle.load(f)
        self._signal.emit(pnums[1])  # Load model (built once)
        net = build()
        self._signal.emit(pnums[2])  # Get TMF image
        # ECG signal at index IDX (the default index is an AF signal)
        ts = data[IDX]
        val_y = label['Y_test'][IDX]

        D = 3
        shape = np.array([len(range(1, (len(ts)-1)//(D-1) + 1)), len(range(0, len(ts)-(D-1)*1)), D])
        overlap = len(ts)-(D-1)*shape[0]
        # TMF image: [W, H, 3], expanded to [1, W, H, 3]
        img = np.zeros(shape)
        img = TMF_image(ts, overlap, img, D)
        img = np.expand_dims(img, axis=0)
        self._signal.emit(pnums[3])  # Get Grad-CAM of AF
        # SG-CAM image of AF (class index 1)
        gcam = get_GCAM(net, img, [0,1], layers=-3)
        self._signal.emit(pnums[4])
        gcam_norm = get_sym_GCAM(overlap, gcam, D)
        AF_gcam = gcam_norm.copy()
        self._signal.emit(pnums[5])  # Get Grad-CAM of non-AF
        # SG-CAM image of non-AF (class index 0)
        gcam = get_GCAM(net, img, [1,0], layers=-3)
        self._signal.emit(pnums[6])
        gcam_norm = get_sym_GCAM(overlap, gcam, D)
        nAF_gcam = gcam_norm.copy()
        self._signal.emit(pnums[7])  # Predict
        proba = net.predict(img)
        print('Predicted probabilities:', proba)
        nAF_proba, AF_proba = proba[0]
        global RES
        RES = [ts, AF_gcam, nAF_gcam, AF_proba, nAF_proba, val_y]
        self._signal.emit(pnums[8])  # Finished

class StartWindow(QWidget):

    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):

        pixmap = QPixmap("img.jpg")

        # Start menu image
        lbl = QLabel(self)
        lbl.setPixmap(pixmap)
        self.setFixedSize(pixmap.width(), pixmap.height())
        self.parentSize = [pixmap.width(), pixmap.height()]
        self.center()

        # Progress bar
        self.progress = QProgressBar(self, objectName="RedProgressBar")
        self.progress.move(pixmap.width()//2-50, pixmap.height()-150)
        self.progress.resize(100, 20)
        # Start button
        self.startBtn = QPushButton("Start", self)
        self.startBtn.move(pixmap.width()//2-50, pixmap.height()-100)
        self.startBtn.clicked.connect(self.on_pushButton_clicked)
        self.startBtn.resize(100, 50)
        # Title
        self.label = QLabel(self)
        self.label.setFixedWidth(800)
        self.label.setFixedHeight(100)
        # QWidget.move() expects integers, so use floor division
        self.label.move((pixmap.width()-self.label.width())//2, (pixmap.height()-self.label.height())//2)
        self.label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.label.setAlignment(Qt.AlignCenter)
        self.label.setText(u"Triadic Motif Field")
        self.label.setAutoFillBackground(False)
        self.label.setFont(QFont("Roman times", 30, QFont.Bold))

        self.setWindowTitle('TMF: Interpretable Classification Model')
        self.show()

    def on_pushButton_clicked(self):
        self.thread = Thread()
        self.thread._signal.connect(self.signal_accept)
        self.thread.start()
        self.startBtn.setEnabled(False)

    def signal_accept(self, msg):
        self.progress.setValue(int(msg))
        self.label.setText(DETAIL[int(msg)])
        print(DETAIL[int(msg)])
        if self.progress.value() == 100:
            self.hide()
            self.progress.setValue(0)
            global RES
            self.dialog = MainWindow(*RES)

    def center(self):
        qr = self.frameGeometry()
        cp = QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())

    def closeEvent(self, event):

        reply = QMessageBox.question(self, 'Warning',
            "Sure to exit?", QMessageBox.Yes |
            QMessageBox.No, QMessageBox.No)

        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

if __name__ == '__main__':

    if len(sys.argv) > 1:
        IDX = int(sys.argv[1])
    else:  # default value
        IDX = 3736  # AF signal

    ecg_path = '../data/ECG_X_test.bin'  # 2-d array, [N, dim]
    label_path = '../data/ECG_info.pkl'  # dict, key='Y_test', 0 and 1 indicate non-AF and AF
    app = QApplication(sys.argv)
    app.setStyleSheet(StyleSheet)
    ex = StartWindow()
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/app/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/app/img.jpg
--------------------------------------------------------------------------------
/classifier/README.md:
--------------------------------------------------------------------------------
# Classifiers

1. MLP: early stopping (keeps the weights with the lowest validation loss)
2. RF and LR: random search

## Default

All three classifiers use the features that VGG16 extracts from the TMF images of the original ECG signals. ```mode = 'no'``` means that no filter is applied to the ECG signals, and ```feature = 'vgg16'``` means that the VGG16 features are used.

## Demo

1. Run directly from the command line
```shell
$ python3 classifier_MLP.py vgg16
```
2. Submit a PBS job
```shell
$ qsub -v classifier=MLP,feature=vgg16 classifier.sh
```

--------------------------------------------------------------------------------
/classifier/classifier.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#PBS -N Train_{classifier}_{feature}
#PBS -l nodes=1:ppn=20
#PBS -l walltime=88888:00:00
#PBS -q adf
#PBS -j oe
#PBS -m ae
#PBS -M your@email.com

which python

cd $PBS_O_WORKDIR
NPROCS=`wc -l < $PBS_NODEFILE`

# ${classifier} and ${feature} are passed in with `qsub -v`
python classifier_${classifier}.py ${feature}

--------------------------------------------------------------------------------
/classifier/classifier_LR.py:
--------------------------------------------------------------------------------
"""
Logistic regression with a random search for the best parameters on the validation dataset
Author: Yadong Zhang
E-mail: zhangyadong@stu.xjtu.edu.cn

Demo:
    $ python classifier_LR.py vgg16
"""
import sys
sys.path.append('../')
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
import numpy as np
import pickle
from lib.util import eval

# Param
data_path = '../'
mode = 'no'
nb_classes = 2
feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'

# Load dataset
with open(data_path+'data/ECG_info.pkl', 'rb') as f:
    label = pickle.load(f)

with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    train_x = np.load(f)
train_y = label['Y_train']

with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    val_x = np.load(f)
val_y = label['Y_val']

with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    test_x = np.load(f)
test_y = label['Y_test']

X = np.concatenate([train_x, val_x], axis=0)
y = np.concatenate([train_y, val_y], axis=0)
train_indices = np.arange(train_x.shape[0])
val_indices = np.arange(train_x.shape[0], X.shape[0])

# Search space of the logistic regression hyper-parameters.
# ('l1', 'lbfgs') is not a valid combination; scikit-learn records such fits as failed.
random_grid = {'penalty' : ['l1', 'l2'],
               'C' : [0.4, 0.6, 0.8, 1],
               'solver' : ['liblinear', 'saga', 'lbfgs']}

lr = LogisticRegression(n_jobs=20, verbose=2)  # random_state = 42

# A single predefined split: fit on the training set, score on the validation set
cv = [(train_indices, val_indices)]

search = RandomizedSearchCV(
    lr,
    param_distributions=random_grid,
    cv=cv,
    n_iter=10,
    verbose=10,
    n_jobs=20
)

search.fit(X, y)
print(search.best_params_)
# Predict with the refitted best model; `lr` itself is never fitted by the search
pred = search.predict_proba(test_x)
print(feature, eval(test_y, pred[:, 1]))
--------------------------------------------------------------------------------
/classifier/classifier_MLP.py:
--------------------------------------------------------------------------------
"""
MLP classifier of the extracted features, keeping the weights with the lowest validation loss (early stopping).
Author: Yadong Zhang
E-mail: zhangyadong@stu.xjtu.edu.cn

Demo:
    $ python3 classifier_MLP.py vgg16
"""
import sys
sys.path.append('../')
import numpy as np
import pickle
from tensorflow import keras
from lib.util import eval

# Param
data_path = '../'
nb_classes = 2
batch_size = 16
nb_epochs = 10
mode = 'no'
feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'

# Load data
with open(data_path+'data/ECG_info.pkl', 'rb') as f:
    label = pickle.load(f)

with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    train_x = np.load(f)
train_y = keras.utils.to_categorical(label['Y_train'], num_classes=nb_classes)

with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    val_x = np.load(f)
val_y = keras.utils.to_categorical(label['Y_val'], num_classes=nb_classes)

with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    test_x = np.load(f)
test_y = keras.utils.to_categorical(label['Y_test'], num_classes=nb_classes)

dim = train_x.shape[1]

def classifier(nb_classes):
    # classifier head: one hidden layer of 128 ReLU units
    x = keras.layers.Input(shape=(dim,))  # shape must be a tuple
    dnn = keras.layers.Dense(128, activation='relu')(x)
    predictions = keras.layers.Dense(nb_classes, activation='softmax')(dnn)
    model = keras.models.Model(inputs=x, outputs=predictions)
    optimizer = keras.optimizers.Adam()
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

def run(idx):
    net = classifier(nb_classes)

    reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=0.0001)
    checkpointer = keras.callbacks.ModelCheckpoint(filepath='../model/weights_tail_{0}.hdf5'.format(idx), verbose=1, save_best_only=True)

    # Train model on dataset (`epochs` replaces the deprecated Keras 1 argument `nb_epoch`)
    hist = net.fit(train_x, train_y, batch_size=batch_size, epochs=nb_epochs, verbose=1, validation_data=(val_x, val_y), callbacks=[reduce_lr, checkpointer], shuffle=True)

    # Restore the weights with the lowest validation loss
    net.load_weights('../model/weights_tail_{0}.hdf5'.format(idx))

    test_pred = net.predict(test_x)
    res = eval(test_y[:, 1], test_pred[:, 1])
    print('ROC_AUC:{0}, PR_AUC:{1}, F1:{2}'.format(*res))

    return [idx] + res


run(feature)

--------------------------------------------------------------------------------
/classifier/classifier_RF.py:
--------------------------------------------------------------------------------
"""
Random forest classifier with a random search for the best parameters according to the performance on the validation dataset
Author: Yadong Zhang
E-mail: zhangyadong@stu.xjtu.edu.cn

Demo:
    $ python3 classifier_RF.py vgg16
"""
import sys
sys.path.append('../')
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
import numpy as np
import pickle
from lib.util import eval

# Param
data_path = '../'
mode = 'no'
nb_classes = 2
feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'

# Load data
with open(data_path+'data/ECG_info.pkl', 'rb') as f:
    label = pickle.load(f)

with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    train_x = np.load(f)
train_y = label['Y_train']

with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    val_x = np.load(f)
val_y = label['Y_val']

with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
    test_x = np.load(f)
test_y = label['Y_test']

X = np.concatenate([train_x, val_x], axis=0)
y = np.concatenate([train_y, val_y], axis=0)
train_indices = np.arange(train_x.shape[0])
val_indices = np.arange(train_x.shape[0], X.shape[0])

n_estimators = [500, 1000, 1500]
max_features = [32, 64, 128]
min_samples_split = [512, 1024]
bootstrap = [True, False]

# Create the random grid
random_grid = {'n_estimators': n_estimators,
               'max_features': max_features,
               'min_samples_split': min_samples_split,
               'bootstrap': bootstrap}

rf = RandomForestClassifier(random_state=43, n_jobs=20)

# A single predefined split: fit on the training set, score on the validation set
cv = [(train_indices, val_indices)]

search = RandomizedSearchCV(
    rf,
    param_distributions=random_grid,
    cv=cv,
    n_iter=10,
    verbose=10,
    n_jobs=20)

search.fit(X, y)
print(search.best_params_)
pred = search.predict_proba(test_x)
print(feature, eval(test_y, pred[:, 1]))

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Data

1. Source data: ```challenge2017.pkl``` (from [MINA](https://github.com/hsd1503/MINA))
2. Processed data (run ```python3 data.py```)
    + Training, validation and test ECG signals: ```ECG_X_*.bin``` (NumPy array)
    + Labels: ```ECG_info.pkl``` (pickled dict with the keys ```Y_train```, ```Y_val```, ```Y_test``` and ```pid_test```)

--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
"""
Create the dataset
Steps:
    1. Get challenge2017.pkl from https://github.com/hsd1503/MINA
    2. Run:
        $ python3 data.py
"""
import sys
sys.path.append('../')
from lib.util import make_data_physionet

make_data_physionet('./')

--------------------------------------------------------------------------------
/docs/app_bokeh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_bokeh.png
--------------------------------------------------------------------------------
/docs/app_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_main.png
--------------------------------------------------------------------------------
/docs/app_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_start.png
--------------------------------------------------------------------------------
/docs/map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/map.png -------------------------------------------------------------------------------- /extractor/README.md: -------------------------------------------------------------------------------- 1 | # Feature extractors 2 | 3 | ## Default 4 | 5 | Features are generated using VGG16 (excluding its top three layers) from the TMF images of original ECG signals. ```model_type = 'vgg16', mode='no'```. MPI size is set as 5 in ```.sh``` and ```reduce*.py``` files. 6 | 7 | ## Demo 8 | 9 | 1. Submit the parallel jobs: training set, 10 nodes, and no filter 10 | ``` 11 | $ python map_gen_feature.py train 10 no 12 | ``` 13 | 2. Collect all the results from the finished jobs 14 | ``` 15 | $ python reduce_gen_feature.py train 10 no 16 | ``` -------------------------------------------------------------------------------- /extractor/gen_feature.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the features of TMF image 3 | Author: Yadong Zhang 4 | E-mail: zhangyadong@stu.xjtu.edu.cn 5 | 6 | Demo: map the dataset to 10 nodes and generate the first slice using 5 process 7 | $ mpirun -n 5 python3 gen_feature.py --mode train --freq no --slice 0 --nodes 10 8 | """ 9 | import os 10 | import sys 11 | sys.path.append('../') 12 | from tqdm import tqdm 13 | import numpy as np 14 | import pandas as pd 15 | from mpi4py import MPI 16 | from collections import OrderedDict, Counter 17 | import pickle as dill 18 | import argparse 19 | from tensorflow.keras.applications.vgg16 import VGG16 20 | from tensorflow.keras.applications import VGG19, ResNet50 21 | from tensorflow.keras.models import Model 22 | from tensorflow.keras.layers import Dense, GlobalAveragePooling2D 23 | from scipy.signal import butter, lfilter 24 | from lib.util import TMF_image as gen_TMF 25 | from lib.util import mkdir 26 | from 
lib.visual import filt_ECG 27 | 28 | def build(model_type): 29 | # create the base pre-trained model 30 | if model_type == 'vgg16': 31 | base_model = VGG16(weights='imagenet', include_top=False) 32 | elif model_type == 'vgg19': 33 | base_model = VGG19(weights='imagenet', include_top=False) 34 | elif model_type == 'resnet50': 35 | base_model = ResNet50(weights='imagenet', include_top=False) 36 | else: 37 | raise(Exception('model_type must be one of vgg16, vgg19 and resnet50.')) 38 | # add a global spatial average pooling layer 39 | x = base_model.output 40 | x = GlobalAveragePooling2D()(x) 41 | # this is the model we will train 42 | model = Model(inputs=base_model.input, outputs=x) 43 | return model 44 | 45 | parser = argparse.ArgumentParser() 46 | 47 | # Get the settings from command line 48 | parser.add_argument('--mode', type=str, help='data type: train, val or test') 49 | parser.add_argument('--freq', type=str, default='mid', help='filter type: low, mid, high, no') 50 | parser.add_argument('--slice', type=int, default=0, help='slice index of the dataset') 51 | parser.add_argument('--nodes', type=int, default=8, help='total slice of the dataset') 52 | args = parser.parse_args() 53 | 54 | mode = str(args.mode) # train, val or test 55 | freq = str(args.freq) # filter type 56 | slidx = int(args.slice) # index of node 57 | slice_num = int(args.nodes) # total nodes 58 | model_type = 'vgg16' # model for feature extraction 59 | 60 | extractor = build(model_type) 61 | 62 | comm = MPI.COMM_WORLD 63 | mpisize = int(comm.Get_size()) # total num of the cpu cores, the n_splits of the k-Fold 64 | mpirank = int(comm.Get_rank()) # rank of this core 65 | 66 | print(mpisize, mpirank) 67 | 68 | if mpirank == 0: 69 | # Slice and broadcast the data 70 | with open('../data/ECG_X_{0}.bin'.format(mode), 'rb') as fin: 71 | X = np.load(fin) 72 | total_num = len(X)//slice_num # map to nodes 73 | if slidx != slice_num-1: 74 | X = X[slidx*total_num: (slidx+1)*total_num] 75 | else: 76 | X = 
X[slidx*total_num: ] 77 total_num = len(X) 78 total_idx = np.arange(total_num) 79 chunk_len = total_num//(mpisize-1) if total_num >= mpisize - 1 else total_num # chunk size per process 80 data = [X[total_idx[i*chunk_len: (i+1)*chunk_len]] for i in tqdm(range(mpisize), desc='Scatter data')] # one chunk per MPI process 81 else: 82 data = None 83 data = comm.scatter(data, root=0) 84 85 # TMF 86 data_path = '../feature-{3}/{0}/{1}/{2}/'.format(*[mode, freq, slidx, model_type]) 87 mkdir(data_path) 88 89 all_feature = [] 90 91 win = 3000 # length of the time series 92 D = 3 # order of the motif, 3 means triad 93 shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D]) # TMF image shape 94 overlap = win-(D-1)*shape[0] # overlap caused by the rotation in the TMF image 95 TMF = np.zeros(shape) # placeholder of TMF image 96 97 for ts in tqdm(data): 98 filt_ts = filt_ECG(ts, freq) 99 TMF = gen_TMF(filt_ts, overlap, TMF, D) # gen_TMF is JIT-compiled with numba 100 img = np.expand_dims(TMF, axis=0) # [1, W, H, 3] 101 feature = extractor.predict(img) # [1, 512] 102 all_feature.append(feature) 103 104 all_feature = np.concatenate(all_feature, axis=0) if len(all_feature) != 0 else np.array([]) 105 np.save(data_path+str(mpirank), all_feature) 106 107 -------------------------------------------------------------------------------- /extractor/gen_feature.sh: -------------------------------------------------------------------------------- 1 #!/bin/bash 2 #PBS -N ECG-${slice} 3 #PBS -l nodes=1:ppn=20 4 #PBS -l walltime=88888:00:00 5 #PBS -q adf 6 #PBS -j oe 7 #PBS -m ae 8 #PBS -M your@email.com 9 10 which python 11 12 cd $PBS_O_WORKDIR 13 NPROCS=`wc -l < $PBS_NODEFILE` 14 15 mpirun -n 5 python gen_feature.py --mode ${mode} --freq ${freq} --slice ${slice} --nodes ${nodes} 16 17 
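A note on the TMF geometry used by gen_feature.py above: for a window of `win` samples and motif order `D`, row `i` of the image holds the triads with lag `i+1`, and only `win-(D-1)(i+1)` start positions are valid in that row. `TMF_image` in lib/util.py fills the remaining cells with a rotated copy of earlier rows, which is where `overlap` comes from. The sketch below restates that loop in plain NumPy (no numba) so the shape and filling pattern can be inspected on a short series. `tmf_shape` and `tmf_image_reference` are hypothetical helper names introduced only for illustration; they are not functions in this repository. Cells start as NaN so you can check that every cell gets written.

```python
import numpy as np


def tmf_shape(win, D=3):
    """TMF image shape and overlap, using the same formulas as gen_feature.py."""
    shape = np.array([len(range(1, (win - 1) // (D - 1) + 1)),  # lags 1..(win-1)//(D-1)
                      len(range(0, win - (D - 1) * 1)),          # motif start positions
                      D])                                        # points per motif
    overlap = win - (D - 1) * shape[0]
    return shape, overlap


def tmf_image_reference(ts, D=3):
    """Plain-NumPy restatement of the TMF_image loop (illustration only, no numba)."""
    ts = np.asarray(ts, dtype=float)
    shape, overlap = tmf_shape(len(ts), D)
    tmf = np.full(shape, np.nan)  # NaN marks cells that were never written
    for i in range(shape[0]):
        lag = i + 1
        right_bound = len(ts) - (D - 1) * lag  # valid start positions for this lag
        for j in range(right_bound):
            motif = ts[j:j + D * lag:lag]  # D points spaced `lag` samples apart
            tmf[i, j, :] = motif
            if j < right_bound - overlap:
                # rotated copy fills the lower-right part of the image
                tmf[shape[0] - i - 1, shape[1] - j - 1, :] = motif
    return tmf


if __name__ == "__main__":
    shape, overlap = tmf_shape(3000)
    print(shape.tolist(), overlap)  # [1499, 2998, 3] 2
    img = tmf_image_reference(np.arange(7.0))
    print(img[0, 0], img[1, 0], np.isnan(img).any())
```

Assuming numba is installed and the repository root is on `sys.path`, `lib.util.TMF_image(ts, overlap, np.zeros(shape), 3)` should return the same array for the same `ts`. For the default `win = 3000` the image is 1499 x 2998 x 3, which the VGG16 backbone accepts because it is built with `include_top=False` and no fixed input shape.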
-------------------------------------------------------------------------------- /extractor/map_gen_feature.py: -------------------------------------------------------------------------------- 1 """ 2 Submit gen_feature.py to multiple nodes and cores so that the slices run in parallel. 3 Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: 7 $ python map_gen_feature.py train 10 no 8 """ 9 import sys 10 sys.path.append('../') 11 import subprocess 12 from lib.util import mkdir 13 mode = str(sys.argv[1]) 14 nodes = int(sys.argv[2]) 15 freq = str(sys.argv[3]) # no, mid 16 17 for idx in range(nodes): 18 qsub_command = """qsub -v slice={0},nodes={1},mode={2},freq={3} -q adf gen_feature.sh""".format(*[idx, nodes, mode, freq]) # submit the job script 19 exit_status = subprocess.call(qsub_command, shell=True) # run qsub 20 if exit_status != 0: # qsub returns a non-zero status if the job was not submitted 21 print("Job {0} failed to submit".format(qsub_command)) 22 23 24 -------------------------------------------------------------------------------- /extractor/reduce_gen_feature.py: -------------------------------------------------------------------------------- 1 """ 2 Collect the features from all nodes and processes 3 Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: collect the features of the training dataset 7 $ python reduce_gen_feature.py train 10 no 8 """ 9 import numpy as np 10 import argparse 11 import sys 12 13 mode = str(sys.argv[1]) 14 slide_num = int(sys.argv[2]) # must equal the --nodes parameter of gen_feature.py 15 mpi_size = 5 # must equal mpirun -n in the .sh script 16 freq = str(sys.argv[3]) # filter type, e.g. no 17 path = '../feature-vgg16/{0}/{1}/'.format(*[mode, freq]) 18 feature = [] 19 for i in range(slide_num): 20 for j in range(mpi_size): 21 with open(path+'{0}/{1}.npy'.format(*[i, j]), "rb") as f: 22 tmp = np.load(f) 23 # print(tmp.shape) 24 if len(tmp): 25 
feature.append(tmp) 26 27 feature = np.concatenate(feature, axis=0) 28 np.save(path+freq, feature) 29 print(feature.shape) 30 31 32 33 -------------------------------------------------------------------------------- /length_effect/README.md: -------------------------------------------------------------------------------- 1 # Length effect 2 3 ## Default 4 5 The length effect of the trained VGG-MLP model is evaluated on the test dataset. Signal lengths range from 100 to 3000 samples in steps of 100. 6 7 ## Demo 8 9 1. Submit the parallel jobs: 10 nodes 10 ``` 11 $ python map_length_effect.py 10 12 ``` 13 2. Collect all the results from the finished jobs 14 ``` 15 $ python reduce_length_effect.py prob 10 # collect probabilities 16 $ python reduce_length_effect.py time 10 # collect running times 17 ``` -------------------------------------------------------------------------------- /length_effect/draw.py: -------------------------------------------------------------------------------- 1 """ 2 Plot the length-effect curves 3 Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: 7 $ python3 draw.py 8 """ 9 from matplotlib.collections import PatchCollection, LineCollection 10 from matplotlib import gridspec 11 import matplotlib.pyplot as plt 12 import numpy as np 13 import pickle 14 import sys 15 sys.path.append('../') 16 plt.rc('text', usetex=True) 17 plt.rc('font', family='Times New Roman') 18 plt.rcParams['xtick.direction'] = 'in' 19 plt.rcParams['ytick.direction'] = 'in' 20 from matplotlib.patches import Rectangle 21 from lib.visual import detectR 22 from lib.util import eval 23 24 # load dataset 25 with open('./slice/prob.npy', 'rb') as f: 26 prob_res = np.load(f) 27 with open('./slice/time.npy', 'rb') as f: 28 time_res = np.load(f) 29 with open('../data/ECG_X_test.bin', 'rb') as f: 30 data = np.load(f) 31 with open('../data/ECG_info.pkl', 
'rb') as f: 32 label = pickle.load(f) 33 label = label['Y_test'] 34 35 '''Draw the length effect of an AF signal''' 36 idx = 18107 # an AF signal 37 ts = data[idx] 38 out = prob_res[idx] 39 t = time_res[idx] 40 41 fig = plt.figure(figsize=(7, 4)) 42 gs = gridspec.GridSpec(2, 1, height_ratios=[5, 1], hspace=0.) 43 # Upper panel 44 ax = plt.subplot(gs[0]) 45 vertical = 1600 46 color = 'tab:red' 47 line = ax.plot(range(100, 3100, 100), 100 * np.array([p[1] for p in out]), '.r-', linewidth=1) 48 ax.plot([vertical, vertical], [0, 100], 'r--', linewidth=1) 49 ax.set_ylim([0, 100]) 50 ax.set_ylabel('Probability of AF (\%)', color=color, fontsize=15) 51 ax.tick_params(axis='y', labelcolor=color) 52 53 ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis 54 ax2.spines['right'].set_color('tab:blue') 55 ax2.spines['left'].set_color('tab:red') 56 color = 'tab:blue' 57 # the x-label is set on the lower panel 58 ax2.set_ylabel('Time consumption (s)', color=color, fontsize=15) 59 bar = ax2.bar(x=range(100, 3100, 100), height=np.array(t), width=20, color=color, alpha=0.7) 60 ax2.tick_params(axis='y', labelcolor=color) 61 ax2.set_xlim([0, 3000]) 62 ax2.set_xticks([]) 63 64 # Lower panel 65 ax = plt.subplot(gs[1]) 66 ax.plot(ts, 'k', linewidth=1) 67 ax.set_ylabel('ECG', fontsize=15) 68 ax.set_xlim([0, 3000]) 69 ax.set_xlabel('Length of ECG signals', fontsize=15) 70 ax.set_ylim([-2.5, 7.5]) 71 ax.plot([vertical, vertical], [-2.5, 7.5], 'r--', linewidth=1) 72 R_peak = detectR(ts) 73 R_peak.sort() 74 ax.plot(R_peak[6:8], ts[R_peak][6:8], 'r.', markersize=5) 75 ax.text(R_peak[6]-160, ts[R_peak][6]-2, '7th', fontsize=15) 76 ax.text(R_peak[7]-160, ts[R_peak][7]-2, '8th', fontsize=15) 77 xt = np.array(ax.get_xticks(), dtype=int) 78 xt = np.append(xt, vertical) 79 xtl = xt.tolist() 80 xtl.remove(1500) 81 xtl[-1] = str(vertical) 82 ax.set_xticks(np.delete(xt, 3)) 83 
ax.set_xticklabels(xtl) 84 ax.set_ylim([-2, 7]) 85 boxes = [Rectangle((0, -2), 1600, 9)] 86 pc = PatchCollection(boxes, facecolor='red', alpha=0.2, edgecolor=None) 87 ax.add_collection(pc) 88 plt.tight_layout() 89 fig.savefig('flexibility.pdf', dpi=300) 90 91 '''Draw the statistical result''' 92 res = [] 93 for i in range(30): 94 res.append(eval(label, prob_res[:, i, 1])) 95 res = np.array(res) # [30, 3] 96 97 fig = plt.figure(figsize=(7, 3)) 98 ax = plt.subplot() 99 ax.plot(np.arange(100, 3100, 100), 100 * res[:, 0], 'vy-', linewidth=1.0, markersize=6) 100 ax.plot(np.arange(100, 3100, 100), 100 * res[:, 1], '.r--', linewidth=1.0, markersize=4) 101 ax.plot(np.arange(100, 3100, 100), 100 * res[:, 2], 'og-.', linewidth=1.0, markersize=4) 102 ax.set_ylim([0, 100]) 103 ax.set_xlim([0, 3000]) 104 ax.set_xlabel('Length of ECG signals', fontsize=15) 105 ax.set_ylabel('Performance (\%)', fontsize=15, color='k') 106 ax.legend(['ROC AUC', 'PR AUC', 'F1'], fontsize=15, 107 framealpha=0, loc='center') # , ncol=3) 108 ax.tick_params(axis='y', labelcolor='k') 109 ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis 110 ax2.spines['right'].set_color('tab:blue') 111 ax2.spines['left'].set_color('k') 112 color = 'tab:blue' 113 # the x-label is already set on ax 114 ax2.set_ylabel('Time consumption (s)', color=color, fontsize=15) 115 bar = ax2.bar(x=range(100, 3100, 100), height=np.mean(time_res, axis=0), width=20, color=color, alpha=0.7) 116 ax2.tick_params(axis='y', labelcolor=color) 117 ax2.set_xlim([0, 3000]) 118 plt.tight_layout() 119 fig.savefig('length_effect.pdf', dpi=300) 120 -------------------------------------------------------------------------------- /length_effect/length_effect.py: -------------------------------------------------------------------------------- 1 """ 2 Measure the running time and predicted probability of the model for ECG signals of different lengths 3 
Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: 7 $ mpirun -n 5 python3 length_effect.py --slice 0 --nodes 10 8 """ 9 import os 10 import sys 11 sys.path.append('../') 12 from tqdm import tqdm 13 from mpi4py import MPI 14 import numpy as np 15 import pandas as pd 16 from collections import OrderedDict, Counter 17 import pickle as dill 18 import argparse 19 import time 20 from lib.util import TMF_image as gen_TMF 21 from lib.util import mkdir 22 from lib.util import build_fullnet as build 23 24 net = build() 25 26 parser = argparse.ArgumentParser() 27 28 # Get the settings from command line 29 parser.add_argument('--slice', type=int, default=0, help='slice index of the dataset') 30 parser.add_argument('--nodes', type=int, default=8, help='total number of slices (nodes)') 31 args = parser.parse_args() 32 33 mode = 'test' 34 slidx = int(args.slice) 35 slice_num = int(args.nodes) # node num 36 37 comm = MPI.COMM_WORLD 38 mpisize = int(comm.Get_size()) # number of MPI processes 39 mpirank = int(comm.Get_rank()) # rank of this process 40 41 if mpirank == 0: 42 with open('../data/ECG_X_{0}.bin'.format(mode), 'rb') as fin: 43 X = np.load(fin) 44 total_num = len(X)//slice_num 45 if slidx != slice_num-1: 46 X = X[slidx*total_num: (slidx+1)*total_num] 47 else: 48 X = X[slidx*total_num: ] 49 total_num = len(X) 50 total_idx = np.arange(total_num) 51 chunk_len = total_num//(mpisize-1) if total_num >= mpisize - 1 else total_num 52 data = [X[total_idx[i*chunk_len: (i+1)*chunk_len]] for i in tqdm(range(mpisize), desc='Scatter data')] 53 else: 54 data = None 55 56 data = comm.scatter(data, root=0) 57 58 # TMF 59 data_path = 'slice/{0}/'.format(slidx) 60 mkdir(data_path) 61 62 D = 3 63 output = [] 64 time_consumption = [] 65 66 for ts in tqdm(data): 67 onesample = [] 68 onetime = [] 69 for win in range(100, 
3100, 100): 70 time_start=time.time() # start time 71 shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D]) 72 overlap = win-(D-1)*shape[0] 73 TMF = np.zeros(shape) 74 TMF = gen_TMF(ts[0: win], overlap, TMF, D) 75 img = np.expand_dims(TMF, axis=0) # [1, W, H, 3] 76 prob = net.predict(img) # [1, 2] 77 time_end=time.time() # end time 78 onesample.append(prob) 79 onetime.append(time_end - time_start) 80 if len(onetime) != 0 and len(onesample) != 0: 81 onesample = np.concatenate(onesample, axis=0) # [S, 2] 82 onetime = np.array(onetime) # [S] 83 output.append(onesample) 84 time_consumption.append(onetime) 85 86 output = np.stack(output, axis=0) if len(output) != 0 else np.array([]) # [N, S, 2] 87 time_consumption = np.stack(time_consumption, axis=0) if len(time_consumption) != 0 else np.array([]) # [N, S] 88 89 np.save(data_path+'prob_'+str(mpirank), output) 90 np.save(data_path+'time_'+str(mpirank), time_consumption) 91 92 93 -------------------------------------------------------------------------------- /length_effect/length_effect.sh: -------------------------------------------------------------------------------- 1 #!/bin/bash 2 #PBS -N ECG-${slice} 3 #PBS -l nodes=1:ppn=20 4 #PBS -l walltime=88888:00:00 5 #PBS -q adf 6 #PBS -j oe 7 #PBS -m ae 8 #PBS -M your@email.com 9 10 which python 11 12 cd $PBS_O_WORKDIR 13 NPROCS=`wc -l < $PBS_NODEFILE` 14 15 mpirun -n 5 python length_effect.py --slice ${slice} --nodes ${nodes} 16 17 -------------------------------------------------------------------------------- /length_effect/map_length_effect.py: -------------------------------------------------------------------------------- 1 """ 2 Submit length_effect.py to multiple nodes and cores so that the slices run in parallel. 
3 Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: 7 $ python3 map_length_effect.py 10 8 """ 9 import sys 10 sys.path.append('../') 11 import subprocess 12 from lib.util import mkdir 13 14 nodes = int(sys.argv[1]) 15 for idx in range(nodes): 16 mkdir('slice/{0}/'.format(idx)) 17 qsub_command = """qsub -v slice={0},nodes={1} -q adf length_effect.sh""".format(*[idx, nodes]) 18 exit_status = subprocess.call(qsub_command, shell=True) # run qsub 19 if exit_status != 0: # qsub returns a non-zero status if the job was not submitted 20 print("Job {0} failed to submit".format(qsub_command)) 21 22 23 -------------------------------------------------------------------------------- /length_effect/reduce_length_effect.py: -------------------------------------------------------------------------------- 1 """ 2 Collect the results of the length-effect experiment 3 Author: Yadong Zhang 4 E-mail: zhangyadong@stu.xjtu.edu.cn 5 6 Demo: 7 $ python reduce_length_effect.py prob 10 8 """ 9 import numpy as np 10 import argparse 11 import sys 12 13 mode = str(sys.argv[1]) # 'prob' or 'time' 14 slide_num = int(sys.argv[2]) 15 mpi_size = 5 # must equal mpirun -n in length_effect.sh 16 17 path = 'slice/' 18 feature = [] 19 for i in range(slide_num): 20 for j in range(mpi_size): 21 with open(path+'{0}/{2}_{1}.npy'.format(*[i, j, mode]), "rb") as f: 22 tmp = np.load(f) 23 if len(tmp): 24 feature.append(tmp) 25 26 feature = np.concatenate(feature, axis=0) 27 np.save(path+mode, feature) 28 print(feature.shape) 29 30 31 32 -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/lib/__init__.py -------------------------------------------------------------------------------- /lib/util.py: 
-------------------------------------------------------------------------------- 1 import os 2 import sys 3 import random 4 import pickle as dill 5 6 import cv2 7 import numba as nb 8 import numpy as np 9 import pandas as pd 10 from tqdm import tqdm 11 from collections import OrderedDict, Counter 12 13 from sklearn import metrics 14 from sklearn.metrics import roc_auc_score, average_precision_score, f1_score 15 from sklearn.metrics import log_loss 16 from sklearn.metrics import confusion_matrix 17 18 import tensorflow as tf 19 from tensorflow.keras import backend as K 20 from tensorflow.keras.applications.vgg16 import VGG16 21 from tensorflow.keras.models import Model 22 from tensorflow import keras 23 24 @nb.jit(nopython=True) 25 def TMF_image(ts, overlap, TMF, D=3): 26 ''' 27 generate the triadic motif field (TMF) image of a time series 28 :param ts: 1-d array, time series of length win 29 :param overlap: int, overlap caused by the rotation in the TMF image, overlap = win-(D-1)*shape[0] 30 :param TMF: 3-d array, np.zeros(shape), placeholder of the TMF image (filled in place) 31 where shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D]) # TMF image shape 32 :param D: int, number of ordinal points, default 3 (triad) 33 return image, [W, H, D] 34 ''' 35 shape = TMF.shape 36 for i in range(shape[0]): 37 right_bound = len(ts)-(D-1)*(i+1) 38 for j in range(right_bound): 39 motif_idx = np.arange(j, j+D*(i+1), (i+1)) 40 if j < right_bound - overlap: 41 TMF[i, j, :] = ts[motif_idx] 42 TMF[shape[0]-i-1, shape[1]-j-1, :] = ts[motif_idx] 43 else: 44 TMF[i, j, :] = ts[motif_idx] 45 return TMF 46 47 def get_GCAM(model, inputs, targets, layers=-11): 48 ''' 49 Grad-CAM 50 :param model: keras model 51 :param inputs: 4-d array, shape is [1, W, H, C] 52 :param targets: 1-d array, shape is [nb_class] (one-hot) 53 :param layers: int, index of the layer to visualize 54 return: 2-d 
array, shape is [W, H] 55 ''' 56 class_idx = np.argmax(targets, axis=-1) 57 class_output = model.output[:, class_idx] # score of the target class 58 last_conv_layer = model.layers[layers] 59 60 61 x = inputs.copy() 62 grads = K.gradients(class_output, last_conv_layer.output)[0] # requires graph mode (TF 1.x, or tf.compat.v1.disable_eager_execution() in TF 2.x) 63 pooled_grads = K.mean(grads, axis=(0, 1, 2)) 64 iterate = K.function(model.input, [pooled_grads, last_conv_layer.output[0]]) 65 pooled_grads_value, conv_layer_output_value = iterate(x) 66 for i, grads_value in enumerate(pooled_grads_value): 67 conv_layer_output_value[:, :, i] *= grads_value 68 heatmap = np.mean(conv_layer_output_value, axis=-1) 69 heatmap[heatmap<0] = 0 # ReLU 70 heatmap = cv2.resize(heatmap, (x.shape[2], x.shape[1])) 71 gcam = (heatmap-heatmap.min())/(heatmap.max() - heatmap.min()) 72 return gcam 73 74 @nb.jit(nopython=True) 75 def get_sym_GCAM(overlap, gcam, D=3): 76 ''' 77 symmetrize Grad-CAM: average each cell with its rotated counterpart, then rescale to [0, 1] 78 :param overlap: int, overlap caused by the rotation in the TMF image, overlap = win-(D-1)*shape[0] 79 where shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D]) # TMF image shape 80 :param gcam: 2-d array with shape shape[:2], Grad-CAM map (modified in place) 81 :param D: int, number of ordinal points, default 3 (triad) 82 return image, [W, H] 83 ''' 84 shape = gcam.shape 85 for i in range(shape[0]): 86 right_bound = shape[1]+(D-1)-(D-1)*(i+1) 87 for j in range(right_bound): 88 if j < right_bound - overlap: 89 gcam[i, j] = (gcam[i, j] + gcam[shape[0]-i-1, shape[1]-j-1])/2.0 90 gcam[shape[0]-i-1, shape[1]-j-1] = gcam[i, j] 91 92 else: 93 pass # cells without a rotated counterpart are left unchanged 94 gcam = (gcam - gcam.min())/(gcam.max()-gcam.min()) 95 return gcam 96 97 def get_cam_image(net, ts, return_proba=False): 98 ''' 99 generate the SG-CAM of AF and non-AF 100 :param net: keras model 101 :param ts: 1-d array, time series 102 :param return_proba: 
boolean 103 | return SG-CAM of non-AF, SG-CAM of AF, predicted probability (if return_proba=True) 104 | ''' 105 | D = 3 106 | shape = np.array([len(range(1, (len(ts)-1)//(D-1) + 1)), len(range(0, len(ts)-(D-1)*1)), D]) 107 | overlap = len(ts)-(D-1)*shape[0] 108 | # TMF image: [1, W, H, 3] 109 | img = np.zeros(shape) 110 | img = TMF_image(ts, overlap, img, D) 111 | img = np.expand_dims(img, axis=0) 112 | # SG-CAM image of non-AF 113 | gcam = get_GCAM(net, img, [1,0], layers=-3) 114 | gcam_norm = get_sym_GCAM(overlap, gcam, D) 115 | nAF_cam = gcam_norm.copy() 116 | # SG-CAM image of AF 117 | gcam = get_GCAM(net, img, [0,1], layers=-3) 118 | gcam_norm = get_sym_GCAM(overlap, gcam, D) 119 | AF_cam = gcam_norm.copy() 120 | if not return_proba: 121 | return nAF_cam, AF_cam 122 | else: 123 | proba = net.predict(img) 124 | return nAF_cam, AF_cam, proba 125 | 126 | def build_fullnet(path='../model/weights_tail_best.hdf5'): 127 | ''' 128 | build the full network (VGG16-MLP) 129 | :param path: str, path of the trained MLP 130 | return: keras network 131 | ''' 132 | base_model = VGG16(weights='imagenet', include_top=False) 133 | x = base_model.output 134 | vec = keras.layers.GlobalAveragePooling2D()(x) 135 | 136 | row_input = keras.layers.Input((vec.shape[1:])) 137 | dnn = keras.layers.Dense(128, activation='relu')(row_input) 138 | 139 | predictions = keras.layers.Dense(2, activation='softmax')(dnn) 140 | 141 | tail_model = keras.models.Model(inputs=row_input, outputs=predictions) 142 | if path is not None: 143 | tail_model.load_weights(path) 144 | 145 | out = tail_model(vec) 146 | network = keras.models.Model(inputs=base_model.input, outputs=out) 147 | return network 148 | 149 | def eval(target, predict): 150 | ''' 151 | evaluate the results 152 | :param target: 1-d array, 0 or 1 153 | :param predict: 1-d array, probability of AF 154 | return roc_auc, pr_auc and F1 155 | ''' 156 | ROC_AUC = metrics.roc_auc_score(target, predict) 157 | # PR_AUC 158 | precision, recall, 
_thresholds = metrics.precision_recall_curve( 159 target, predict) 160 PR_AUC = metrics.auc(recall, precision) 161 # F1 162 predict = np.array([i > 0.5 for i in predict], dtype=int) 163 F1 = metrics.f1_score(target, predict) 164 return [ROC_AUC, PR_AUC, F1] 165 166 def eval_patient(pid, label, proba): 167 ''' 168 evaluate the patient-wise accuracy on AF segments 169 :param pid: 1-d array, shape is [N,], id of patients 170 :param label: 1-d array, shape is [N,], 0 or 1, 1 indicates the AF segment 171 :param proba: 1-d array, shape is [N,], probability of AF 172 return: 2-d array, [number of AF patients, 3], columns are pid, number of AF segments, patient-wise accuracy (sorted by accuracy, ascending) 173 ''' 174 wrong = pid[(label == 1) & (proba < 0.5)] 175 wrong_pid = np.unique(wrong) 176 AF = pid[(label == 1)] 177 AF_pid = np.unique(AF) 178 detail = [] 179 for i in AF_pid: 180 # [id, number of AF segments, accuracy] 181 total = float(np.sum(AF == i)) 182 detail.append([i, total, (total - np.sum(wrong == i)) / total]) 183 detail.sort(key=lambda x: x[2]) # sort according to the accuracy 184 return np.array(detail) 185 186 def mkdir(path): 187 """ 188 create the directory at path if it does not exist 189 :param path: str, directory path 190 return: boolean, True if the directory was created 191 """ 192 path = path.strip() 193 path = path.rstrip("\\") 194 isExists = os.path.exists(path) 195 196 if not isExists: 197 os.makedirs(path) 198 print(path+' is created!') 199 return True 200 else: 201 print(path+' already exists!') 202 return False 203 204 def slide_and_cut(X, Y, window_size, stride, output_pid=False): 205 ''' 206 From https://github.com/hsd1503/MINA 207 MINA: Multilevel Knowledge-Guided Attention for Modeling Electrocardiography Signals, IJCAI 2019 208 ''' 209 out_X = [] 210 out_Y = [] 211 out_pid = [] 212 n_sample = X.shape[0] 213 mode = 0 214 for i in range(n_sample): 215 tmp_ts = X[i] 216 tmp_Y = Y[i] 217 if tmp_Y == 0: 218 i_stride = stride 219 elif tmp_Y == 1: 220 i_stride = 
stride//10 221 | for j in range(0, len(tmp_ts)-window_size, i_stride): 222 | out_X.append(tmp_ts[j:j+window_size]) 223 | out_Y.append(tmp_Y) 224 | out_pid.append(i) 225 | if output_pid: 226 | return np.array(out_X), np.array(out_Y), np.array(out_pid) 227 | else: 228 | return np.array(out_X), np.array(out_Y) 229 | 230 | def make_data_physionet(data_path, n_split=50, window_size=3000, stride=500): 231 | ''' 232 | From https://github.com/hsd1503/MINA 233 | MINA: Multilevel Knowledge-Guided Attention for Modeling Electrocardiography Signals, IJCAI 2019 234 | ''' 235 | # read pkl 236 | with open(os.path.join(data_path, 'challenge2017.pkl'), 'rb') as fin: 237 | res = dill.load(fin) 238 | ## scale data 239 | all_data = res['data'] 240 | for i in range(len(all_data)): 241 | tmp_data = all_data[i] 242 | tmp_std = np.std(tmp_data) 243 | tmp_mean = np.mean(tmp_data) 244 | all_data[i] = (tmp_data - tmp_mean) / tmp_std # normalize 245 | all_data = res['data'] 246 | all_data = np.array(all_data) 247 | ## encode label 248 | all_label = [] 249 | for i in res['label']: 250 | if i == 'A': 251 | all_label.append(1) 252 | else: 253 | all_label.append(0) 254 | all_label = np.array(all_label) 255 | 256 | # split train test 257 | n_sample = len(all_label) 258 | split_idx_1 = int(0.75 * n_sample) 259 | split_idx_2 = int(0.85 * n_sample) 260 | 261 | shuffle_idx = np.random.RandomState(seed=40).permutation(n_sample) 262 | all_data = all_data[shuffle_idx] 263 | all_label = all_label[shuffle_idx] 264 | 265 | X_train = all_data[:split_idx_1] 266 | X_val = all_data[split_idx_1:split_idx_2] 267 | X_test = all_data[split_idx_2:] 268 | Y_train = all_label[:split_idx_1] 269 | Y_val = all_label[split_idx_1:split_idx_2] 270 | Y_test = all_label[split_idx_2:] 271 | 272 | # slide and cut 273 | print(Counter(Y_train), Counter(Y_val), Counter(Y_test)) 274 | X_train, Y_train = slide_and_cut( 275 | X_train, Y_train, window_size=window_size, stride=stride) 276 | X_val, Y_val = slide_and_cut( 277 | X_val, 
Y_val, window_size=window_size, stride=stride) 278 | X_test, Y_test, pid_test = slide_and_cut( 279 | X_test, Y_test, window_size=window_size, stride=stride, output_pid=True) 280 | print('after: ') 281 | print(Counter(Y_train), Counter(Y_val), Counter(Y_test)) 282 | 283 | # shuffle train 284 | shuffle_pid = np.random.RandomState(seed=42).permutation( 285 | Y_train.shape[0]) # np.random.permutation(Y_train.shape[0]) 286 | X_train = X_train[shuffle_pid] 287 | Y_train = Y_train[shuffle_pid] 288 | 289 | # save 290 | res = {'Y_train': Y_train, 'Y_val': Y_val, 291 | 'Y_test': Y_test, 'pid_test': pid_test} 292 | with open(os.path.join(data_path, 'ECG_info.pkl'), 'wb') as fout: 293 | dill.dump(res, fout) 294 | 295 | fout = open(os.path.join(data_path, 'ECG_X_train.bin'), 'wb') 296 | np.save(fout, X_train) 297 | fout.close() 298 | 299 | fout = open(os.path.join(data_path, 'ECG_X_val.bin'), 'wb') 300 | np.save(fout, X_val) 301 | fout.close() 302 | 303 | fout = open(os.path.join(data_path, 'ECG_X_test.bin'), 'wb') 304 | np.save(fout, X_test) 305 | fout.close() 306 | -------------------------------------------------------------------------------- /lib/visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | import matplotlib.pyplot as plt 5 | from matplotlib.collections import LineCollection 6 | from matplotlib import gridspec 7 | plt.rc('text', usetex=True) 8 | plt.rc('font', family='Times New Roman') 9 | plt.rcParams['xtick.direction'] = 'in' 10 | plt.rcParams['ytick.direction'] = 'in' 11 | import matplotlib.patches as patches 12 | from tqdm import tqdm 13 | 14 | from scipy.signal import butter, lfilter 15 | ECG_param = {'PQ_max': 0.2, 'QRS_max': 0.150, 16 | 'P_min': 0.08, 'RR_min': 0.200, 17 | 'sample_freq': 300 18 | } # unit: s, Hz 19 | 20 | def detectR(data, threshold=0.96): 21 | ''' 22 | detect R peaks from ECG signals 23 | :param data: 1-d array, ECG signals 24 | :param threshold: float, 
threshold for the R peaks 25 return: R peak indices 26 ''' 27 RR_interval = int(ECG_param['RR_min']*ECG_param['sample_freq']) 28 # Band-pass filter (10-50 Hz) to emphasise the QRS complex for R detection 29 nyquist_freq = 0.5 * ECG_param['sample_freq'] 30 bandpass = [10, 50] # R band 31 low = bandpass[0] / nyquist_freq 32 high = bandpass[1] / nyquist_freq 33 filter_order = 1 34 b, a = butter(filter_order, [low, high], btype="band") 35 ts = lfilter(b, a, data) 36 37 # Squared second-order difference of the filtered time series 38 diff2ts = np.power(np.diff(np.diff(ts)), 2) 39 # Use the (threshold*100)-th percentile of diff2ts (96th by default) as the R-peak threshold 40 thres_amp= np.percentile(diff2ts, int(threshold*100)) 41 42 idxsort = np.argsort(diff2ts) # indices sorted by diff2ts, ascending 43 idx = np.arange(len(idxsort)) 44 filtered = [] 45 for i in idxsort[::-1]: 46 if diff2ts[i] > thres_amp: 47 if i in idx: 48 filtered.append(i) # i is still in idx, so it was not suppressed by a stronger peak 49 idx = np.setdiff1d(idx, range(i-RR_interval//2, i+RR_interval//2)) # suppress indices within half an RR interval 50 R_peak = [] 51 for i in filtered: 52 try: 53 p = i-RR_interval//2 + np.argmax(data[i-RR_interval//2: i+RR_interval//2]) 54 if np.max(data[i-RR_interval//2: i+RR_interval//2]) > np.mean(data): 55 R_peak.append(p) 56 except ValueError: # empty window near the start of the signal 57 pass 58 59 return R_peak 60 61 def flip_ECG(data): 62 ''' 63 Flip the ECG signal if its dominant deflection is negative (inverted recording) 64 :param data: 1-d array, ECG signal 65 return: ECG signal 66 ''' 67 less, more = np.percentile(data, (2, 98)) 68 if np.abs(less) > np.abs(more): 69 return -data 70 else: 71 return data 72 73 def get_beats(data, R_peak:list): 74 ''' 75 Cut the ECG signal into beats centred on the R peaks 76 :param data: 1-d array, ECG signal 77 :param R_peak: list, R_peak indices 78 return: list of beats 79 ''' 80 R_peak = np.sort(R_peak) 81 RR_interval = int(np.mean(np.diff(R_peak))) 82 beats = [] 83 for i in R_peak: 84 
beats.append(data[i-RR_interval//2: i+RR_interval//2]) 85 return beats 86 87 def filt_ECG(data, freq='mid', sample_freq=300): 88 ''' 89 filter the 1-d ECG signal data with a first-order Butterworth filter 90 :param freq: str, 'low' (0.001-0.5 Hz band), 'mid' (0.5-50 Hz band), 'high' (above 50 Hz) or 'no' (no filtering) 91 :param sample_freq: int, sampling frequency in Hz 92 return: filtered ECG signal 93 ''' 94 if freq == 'no': return data 95 nyquist_freq = 0.5 * sample_freq 96 select = {'low': (0.001 / nyquist_freq, 0.5 / nyquist_freq), 'mid': (0.5 / nyquist_freq, 50 / nyquist_freq), 'high': 50 / nyquist_freq} 97 filter_order = 1 98 if freq == 'high': 99 b, a = butter(filter_order, select[freq], btype="high") 100 else: 101 b, a = butter(filter_order, select[freq], btype="band") 102 out = lfilter(b, a, data) 103 return out 104 105 def argmax2D(matrix): 106 """ 107 find the (row, column) index of the maximum value in a 2-d matrix 108 :param matrix: 2-d array 109 :return: row, column 110 """ 111 cols = matrix.shape[1] 112 loc = np.argmax(matrix) 113 y, x = loc // cols, loc % cols 114 return y, x 115 116 def plot_cam(ts, img, path=None, point=None, peak_idx=None): 117 ''' 118 plot the time series and its SG-CAM image 119 :param ts: 1-d array, ECG signal 120 :param img: 2-d array, image of SG-CAM 121 :param path: str, path to save the figure 122 :param point: optional (column, row) position in the SG-CAM; the maximum within a 200x200 window around it is marked 123 :param peak_idx: optional index of the R peak whose QRS and P regions are boxed 124 ''' 125 if point: 126 win = 100 127 centerx, centery = argmax2D(img[point[1]-win:point[1]+win, point[0]-win: point[0]+win]) 128 centerx += point[1]-win 129 centery += point[0]-win 130 131 fig = plt.figure(figsize=(6, 4)) 132 gs = gridspec.GridSpec(2, 1, height_ratios=[1,4], hspace=0.) 
133 ax = plt.subplot(gs[0]) 134 ax.set_xticks([]) 135 ax.set_yticks([]) 136 ax.set_xlim([0, 3000-3]) 137 ax.plot(ts, 'k', linewidth=1) 138 if point: 139 motif = [ 140 [centery, centerx+1, 'ro--', 4] 141 ] 142 for start, gap, ty, s in motif: 143 ix = np.arange(start, start + gap*3, gap) 144 ax.plot(ix, ts[ix], ty, markersize=s) 145 146 ax.set_ylim([-4, 8.5]) 147 ts_flip = flip_ECG(ts) 148 R_peak = detectR(ts_flip) 149 R_peak.sort() 150 for i in R_peak: 151 ax.plot([i, i], [-4, 8.5], 'r--', linewidth=0.5) 152 153 ax = plt.subplot(gs[1]) 154 ax.imshow(img, cmap=plt.cm.jet, vmax=1, vmin=0) 155 if point: 156 ax.plot([centery], [centerx], 'r+', markersize=10, alpha=0.5) 157 158 horizon_move = 100 159 if peak_idx is not None: 160 rect = patches.Rectangle([R_peak[peak_idx]-50, 0],50*2,30*2,linestyle='--',linewidth=1,edgecolor='w',facecolor='none') 161 bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.1) 162 ax.text(R_peak[peak_idx]+horizon_move, 160, "QRS", color='w', ha="center", va="center", size=14, 163 bbox=bbox_props) 164 ax.add_patch(rect) 165 166 rect = patches.Rectangle([R_peak[peak_idx]-170, 0],50*2,30*2,linestyle='--',linewidth=1,edgecolor='w',facecolor='none') 167 bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.1) 168 ax.text(R_peak[peak_idx]-250+horizon_move, 160, "P", color='w', ha="center", va="center", size=14, 169 bbox=bbox_props) 170 ax.add_patch(rect) 171 ax.set_xticks([]) 172 ax.set_yticks([]) 173 plt.tight_layout() 174 if path: 175 fig.savefig(path, dpi=300) 176 177 def plot_basic_tsne(X_embedded, val_y, alpha=1, s=1, path=None): 178 ''' 179 draw the t-SNE figure of AF and non-AF 180 :param X_embedded: 2-d array, shape is [N, 2], the low-dimensional points 181 :param val_y: 1-d array, shape is [N], the label of the data, 1 and 0 indicate AF and non-AF 182 :param alpha: float, alpha of the scatters 183 :param s: float, size of the scatters 
184 | :param path: str, path of saving the figure 185 | ''' 186 | fig = plt.figure(figsize=(6, 6)) 187 | plt.scatter(x=X_embedded[val_y==1][:,0], y=X_embedded[val_y==1][:,1],color='r',s=s,alpha=alpha,edgecolors="none") 188 | plt.scatter(x=X_embedded[val_y==0][:,0], y=X_embedded[val_y==0][:,1],color='b',s=s,alpha=alpha,edgecolors="none") 189 | lgnd = plt.legend(['AF', 'non-AF'],fancybox=False, framealpha=0,fontsize=15, loc='lower left') 190 | 191 | for handle in lgnd.legendHandles: 192 | handle.set_sizes([12.0]) 193 | handle.set_alpha(1.0) 194 | ax = plt.gca() 195 | ax.spines['top'].set_visible(False) 196 | ax.spines['right'].set_visible(False) 197 | ax.spines['bottom'].set_visible(False) 198 | ax.spines['left'].set_visible(False) 199 | 200 | ax.set_xticks([]) 201 | ax.set_yticks([]) 202 | 203 | fig.tight_layout() 204 | if path: 205 | fig.savefig(path, dpi=300) 206 | 207 | def plot_AF_pid(X_embedded, val_y, pid_test, path=None): 208 | ''' 209 | draw tsne figure of AF according to the id of patients (pid) 210 | :param X_embedded: 2-d array, shape is [N, 2], the low-dimensional points 211 | :param val_y: 1-d array, shape is [N], the label of the data, 1 and 0 indicated the AF and non-AF 212 | :param pid_test: 1-d array, shape is [N], the id of the patients 213 | :param path: str, path of saving the figure 214 | ''' 215 | fig = plt.figure(figsize=(6, 6)) 216 | plt.scatter(x=X_embedded[val_y==1][:,0], y=X_embedded[val_y==1][:,1], 217 | s=0.5,alpha=1,c=pid_test[val_y==1], cmap='jet') 218 | 219 | x, y, s = X_embedded[val_y==1][:,0], X_embedded[val_y==1][:,1], pid_test[val_y==1] 220 | for i in tqdm(np.unique(s)): 221 | for xi, yi in zip(x[s==i], y[s==i]): 222 | plt.plot([np.mean(x[s==i]), xi], [np.mean(y[s==i]), yi], 'k', linewidth=0.1, alpha=0.2) 223 | ax = plt.gca() 224 | ax.spines['top'].set_visible(False) 225 | ax.spines['right'].set_visible(False) 226 | ax.spines['bottom'].set_visible(False) 227 | ax.spines['left'].set_visible(False) 228 | 229 | ax.set_xticks([]) 230 
| ax.set_yticks([]) 231 | fig.tight_layout() 232 | if path: 233 | fig.savefig(path, dpi=300) 234 | -------------------------------------------------------------------------------- /model/weights_tail_best.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/model/weights_tail_best.hdf5 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | mpi4py==3.0.3 3 | numba==0.50.1 4 | scikit-learn==0.23.0 5 | scipy==1.5.2 6 | tensorflow==1.14.0 7 | opencv-python 8 | tqdm 9 | PyQT5 --------------------------------------------------------------------------------
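In the image-plotting code above, the pixel highlighted on the image (row `centerx`, column `centery`) is drawn on the ECG as three samples produced by `np.arange(start, start + gap*3, gap)` with `start = centery` and `gap = centerx + 1`. The sketch below isolates that index mapping, assuming a 1-D NumPy series. `motif_indices` and `pixel_to_motif` are hypothetical helper names used only for illustration; they are not functions from this repository.

```python
import numpy as np


def motif_indices(start, gap):
    """Return the three sample indices of a triadic motif.

    Mirrors `np.arange(start, start + gap*3, gap)` from the plotting code.
    """
    return np.arange(start, start + gap * 3, gap)


def pixel_to_motif(row, col):
    """Map an image pixel (row, col) to motif sample indices.

    Hypothetical helper following the plotting code:
    start = col (`centery`), gap = row + 1 (`centerx + 1`).
    """
    return motif_indices(col, row + 1)


if __name__ == "__main__":
    ts = np.sin(np.linspace(0, 4 * np.pi, 100))  # toy signal, not ECG data
    ix = pixel_to_motif(row=4, col=10)
    print(ix)      # [10 15 20]
    print(ts[ix])  # amplitudes of the three motif samples
```

Because the gap is `row + 1`, row 0 already corresponds to a gap of one sample, so `np.arange` never receives a zero step.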