├── README.md
├── app
│   ├── README.md
│   ├── app.py
│   └── img.jpg
├── classifier
│   ├── README.md
│   ├── classifier.sh
│   ├── classifier_LR.py
│   ├── classifier_MLP.py
│   └── classifier_RF.py
├── data
│   ├── README.md
│   └── data.py
├── docs
│   ├── app_bokeh.png
│   ├── app_main.png
│   ├── app_start.png
│   └── map.png
├── extractor
│   ├── README.md
│   ├── gen_feature.py
│   ├── gen_feature.sh
│   ├── map_gen_feature.py
│   └── reduce_gen_feature.py
├── length_effect
│   ├── README.md
│   ├── draw.py
│   ├── length_effect.py
│   ├── length_effect.sh
│   ├── map_length_effect.py
│   └── reduce_length_effect.py
├── lib
│   ├── __init__.py
│   ├── util.py
│   └── visual.py
├── model
│   └── weights_tail_best.hdf5
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Anomaly Detection in Time Series with Triadic Motif Fields and Application in Atrial Fibrillation ECG Classification
2 |
3 | [arXiv](https://arxiv.org/abs/2012.04936) [Binder](https://mybinder.org/v2/gh/ydup/bokeh/master?urlpath=/proxy/5006/bokeh-app) [PapersWithCode](https://paperswithcode.com/sota/atrial-fibrillation-detection-on-physionet?p=anomaly-detection-in-time-series-with-triadic)
4 |
5 | Authors: [Yadong Zhang](https://github.com/ydup) and Xin Chen
6 |
7 | Paper: [arXiv:2012.04936](https://arxiv.org/abs/2012.04936)
8 |
9 | Online demo: [Binder](https://mybinder.org/v2/gh/ydup/bokeh/master?urlpath=/proxy/5006/bokeh-app)
10 |
11 | ### Modules
12 |
13 | Module | Path | Note | Default Settings
14 | --- | --- | --- | ---
15 | Basic | 1. [_lib_](lib/)<br>2. [_data_](data/)<br>3. [_model_](model/) | 1. Basic functions of the project.<br>2. Dataset processing.<br>3. Saved tail model weights. | 1. -<br>2. No filter, z-normalization<br>3. MLP model
16 | Classification | 1. [_extractor_](extractor/)<br>2. [_classifier_](classifier/) | 1. Feature extraction from TMF images based on transfer learning.<br>2. Classification of feature vectors into AF and non-AF probabilities. | 1. VGG16; map-reduce uses ```10``` nodes and MPI size ```5```.<br>2. -
17 | Evaluation | 1. [_length\_effect_](length_effect/) | 1. Evaluate the trained model on varying-length ECG signals. | 1. VGG16-MLP; map-reduce uses ```10``` nodes and MPI size ```5```.
18 | App | 1. [_pyQT app_](app/)<br>2. [_bokeh app_](https://github.com/ydup/bokeh) | 1. Local app for classification and interpretation.<br>2. Web server for interpretation. | VGG16-MLP
19 |
20 | ### Structures of Parallel Codes (mpi)
21 |
22 | [_extractor_](extractor/) and [_length\_effect_](length_effect/) are parallelized on the Linux cluster with a map-reduce pattern:
23 | + ```.py```: main code.
24 | + ```.sh```: script for a single submission to the PBS queue.
25 | + ```map*.py```: map the tasks to multiple nodes and MPI processes.
26 | + ```reduce*.py```: collect the results from the finished tasks.
27 |
28 | ### Guidelines for the Apps
29 |
30 | App | Classification | Visualization | Interactive | Remote | Local
31 | --- | --- | --- | --- |--- |---
32 | [pyQT app](app/) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark:
33 | [bokeh app](https://github.com/ydup/bokeh) | :x: (available in future) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
34 |
35 | #### [pyQT app](app/)
36 |
37 | 1. Start page (click ```Start```)
38 | + Start button
39 | + Progress bar & status
40 | ![Start page](docs/app_start.png)
41 | 2. Main page (from top to bottom)
42 | + Time series with label
43 | + Symmetrized Grad-CAM of AF and its predicted probability
44 | + Symmetrized Grad-CAM of non-AF and its predicted probability
45 | + Sliders of ```time index``` and ```delay``` to adjust the triadic time series motifs
46 | - The triad (red) in the time series corresponds to the cross (white) in the two symmetrized Grad-CAM images.
47 | - The text with the red background indicates the predicted type.
48 | ![Main page](docs/app_main.png)
49 | #### [bokeh app](https://github.com/ydup/bokeh)
50 | 
51 |
52 | ### [Requirements](./requirements.txt)
53 | Python 3.6:
54 | ```
55 | matplotlib
56 | mpi4py==3.0.3
57 | numba==0.50.1
58 | scikit-learn==0.23.0
59 | scipy==1.5.2
60 | tensorflow==1.14.0
61 | opencv-python
62 | tqdm
63 | PyQT5
64 | ```
65 |
66 | ### Citation
67 |
68 | Cite our work with:
69 | ```latex
70 | @misc{zhang2020anomaly,
71 | title={Anomaly Detection in Time Series with Triadic Motif Fields and Application in Atrial Fibrillation ECG Classification},
72 | author={Yadong Zhang and Xin Chen},
73 | year={2020},
74 | eprint={2012.04936},
75 | archivePrefix={arXiv},
76 | primaryClass={cs.LG}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | # Application
2 |
3 | Demo:
4 | 1. Default ECG index in the test dataset (an AF signal):
5 | ```shell
6 | $ python3 app.py
7 | ```
8 | 2. User-defined index, for example ```1```:
9 | ```shell
10 | $ python3 app.py 1
11 | ```
12 |
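13 | The same computation can be run without the GUI. A minimal sketch using the library functions (assumes the processed data from [_data_](../data/) and the trained weights in [_model_](../model/); the index ```3736``` is the app's default AF example):
14 | ```python
15 | import sys
16 | import numpy as np
17 | sys.path.append('../')
18 | from lib.util import build_fullnet, get_cam_image
19 |
20 | with open('../data/ECG_X_test.bin', 'rb') as f:
21 |     data = np.load(f)
22 | net = build_fullnet()  # VGG16-MLP with the trained tail weights
23 | nAF_cam, AF_cam, proba = get_cam_image(net, data[3736], return_proba=True)
24 | print('P(non-AF), P(AF):', proba[0])
25 | ```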
--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
1 | '''
2 | ECG application (classification and interpretation).
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python app.py
8 | '''
9 | import matplotlib
10 | matplotlib.use('Qt5Agg')
11 | from matplotlib import pyplot as plt
12 | plt.rc('text', usetex=True)
13 | plt.rc('font', family='Times New Roman')
14 | import matplotlib.patches as patches
15 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT as NavigationToolbar
16 | from matplotlib.figure import Figure
17 |
18 | import time
19 | import pickle
20 | import numpy as np
21 | from PyQt5 import QtCore, QtGui, QtWidgets
22 | from PyQt5.QtWidgets import *
23 | from PyQt5.QtCore import *
24 | from PyQt5.QtGui import *
25 |
26 | import sys
27 | sys.path.append('../')
28 | from lib.util import TMF_image, get_GCAM, get_sym_GCAM
29 | from lib.util import build_fullnet as build
30 |
31 | StyleSheet = '''
32 | #RedProgressBar {
33 | text-align: center;
34 | min-height: 12px;
35 | max-height: 12px;
36 | }
37 | #RedProgressBar::chunk {
38 | background-color: #F44336;
39 | }
40 | '''
41 |
42 | DETAIL = {0: 'Load data', 10: 'Load model', 20: 'Get TMF image', 30: 'Get Grad-CAM of non-AF', 55: 'Get symmetrized Grad-CAM of non-AF', 60: 'Get Grad-CAM of AF', 85: 'Get symmetrized Grad-CAM of AF', 90: 'Predict', 100: 'Finished'} # status text for each progress value (order matches Thread.run)
43 |
44 | class TSCanvas(FigureCanvasQTAgg):
45 | def __init__(self, parent=None, width=5, height=4, dpi=100):
46 | '''
47 | ECG time series
48 | '''
49 | fig = Figure(figsize=(width, height), dpi=dpi)
50 | self.axes = fig.add_subplot(111)
51 | super(TSCanvas, self).__init__(fig)
52 |
53 | class HMCanvas(FigureCanvasQTAgg):
54 | def __init__(self, parent=None, width=5, height=4, dpi=100):
55 | '''
56 | Symmetrized Grad-CAM of TMF images
57 | '''
58 | fig = Figure(figsize=(width, height), dpi=dpi)
59 | self.axes = fig.add_subplot(111)
60 | super(HMCanvas, self).__init__(fig)
61 |
62 | class MainWindow(QtWidgets.QMainWindow):
63 |
64 | def __init__(self, ts, gcam_AF, gcam_nAF, proba_AF, proba_nAF, label):
65 |
66 | super().__init__()
67 | self.gcam_AF = gcam_AF
68 | self.gcam_nAF = gcam_nAF
69 | self.ts = ts
70 | self.setFixedSize(640, 800)
71 | win = len(self.ts) # time series length
72 | D = 3 # triad
73 | self.overlap = win-(D-1)*self.gcam_AF.shape[0] # overlap of the heatmap
74 |
75 | # ECG time series, initial plot
76 | self.sc1 = HMCanvas(self, width=4.0, height=5, dpi=100)
77 | self.sc1.axes.plot(np.arange(len(self.ts)), self.ts, 'k', linewidth=1.0)
78 | self.sc1.axes.set_xlim([0, len(self.ts)])
79 | self.sc1.axes.set_yticks([])
80 | # Heatmap, initial plot
81 | self.sc2 = HMCanvas(self, width=4.0, height=5, dpi=100)
82 | self.sc2.axes.imshow(self.gcam_AF, cmap=plt.cm.jet)
83 | self.sc2.axes.set_xticks([])
84 | self.sc2.axes.set_yticks([])
85 | # Heatmap, initial plot
86 | self.sc3 = HMCanvas(self, width=4.0, height=5, dpi=100)
87 | self.sc3.axes.imshow(self.gcam_nAF, cmap=plt.cm.jet)
88 | self.sc3.axes.set_xticks([])
89 | self.sc3.axes.set_yticks([])
90 |
91 | # Slider for gap
92 | self.s=QSlider(Qt.Horizontal)
93 | self.s.setMinimum(1)
94 | self.s.setMaximum(self.gcam_AF.shape[0])
95 | self.s.setSingleStep(1)
96 | self.s.setValue(20)
97 | self.s.setTickPosition(QSlider.TicksBelow)
98 | self.s.setTickInterval(100)
99 | self.s.valueChanged.connect(self.triadchange) # connect with the triadchange function to redraw the plots
100 |
101 | # Slider for initial time index of triad
102 | self.idx=QSlider(Qt.Horizontal)
103 | self.idx.setMinimum(0)
104 | self.idx.setMaximum(self.gcam_AF.shape[1]-1)
105 | self.idx.setSingleStep(1)
106 | self.idx.setValue(20)
107 | self.idx.setTickPosition(QSlider.TicksBelow)
108 | self.idx.setTickInterval(100)
109 | self.idx.valueChanged.connect(self.triadchange)
110 |
111 | # Labels
112 | self.l0 = QLabel('Label: %s'%('AF' if label == 1 else 'non-AF'))
113 | self.l0.setAlignment(Qt.AlignCenter)
114 | self.imp1 = QLabel('Probability of AF:%.3f'%proba_AF)
115 | self.imp1.setAlignment(Qt.AlignCenter)
116 | self.imp2 = QLabel('Probability of non-AF:%.3f'%proba_nAF)
117 | self.imp2.setAlignment(Qt.AlignCenter)
118 | if proba_AF > proba_nAF:
119 | self.imp1.setStyleSheet("background-color: red")
120 | elif proba_AF < proba_nAF:
121 | self.imp2.setStyleSheet("background-color: red")
122 | self.l1 = QLabel('Time index:')
123 | self.l1.setAlignment(Qt.AlignCenter)
124 | self.l2 = QLabel('Delay:')
125 | self.l2.setAlignment(Qt.AlignCenter)
126 |
127 | layout = QtWidgets.QVBoxLayout()
128 | layout.addWidget(self.l0) # label
129 | layout.addWidget(self.sc1) # Time series
130 | layout.addWidget(self.imp1) # Label: probability of AF
131 | layout.addWidget(self.sc2) # Heatmap
132 | layout.addWidget(self.imp2) # Label: probability of non-AF
133 | layout.addWidget(self.sc3) # Heatmap
134 | layout.addWidget(self.l1) # Label: initial time index
135 | layout.addWidget(self.idx) # Time index slider
136 | layout.addWidget(self.l2) # Label: gap
137 | layout.addWidget(self.s) # Gap slider
138 | # Create a placeholder widget to hold our toolbar and canvas.
139 | widget = QtWidgets.QWidget()
140 | widget.setLayout(layout)
141 | self.setCentralWidget(widget)
142 | self.show()
143 | self.setLayout(layout)
144 |
145 | def triadchange(self):
146 | '''
147 | Re-draw the triad, time series and heatmap
148 | '''
149 | start = self.idx.value()
150 | gap = self.s.value()
151 |
152 | # Re-draw the heatmap
153 | self.sc2.axes.cla() # Clear the canvas.
154 | self.sc2.axes.imshow(self.gcam_AF, cmap=plt.cm.jet)
155 | # Change color of + in heatmap
156 | self.sc2.axes.scatter(x=[start], y=[gap-1], s=100, color='w', marker='+')
157 | self.sc2.axes.set_xlim([0, self.gcam_AF.shape[1]])
158 | self.sc2.axes.set_ylim([self.gcam_AF.shape[0], 0])
159 | self.sc2.axes.set_xticks([])
160 | self.sc2.axes.set_yticks([])
161 | # Trigger the canvas to update and redraw.
162 | self.sc2.draw()
163 |
164 | # Re-draw the heatmap
165 | self.sc3.axes.cla() # Clear the canvas.
166 | self.sc3.axes.imshow(self.gcam_nAF, cmap=plt.cm.jet)
167 | # Change color of + in heatmap
168 | self.sc3.axes.scatter(x=[start], y=[gap-1], s=100, color='w', marker='+')
169 | self.sc3.axes.set_xlim([0, self.gcam_nAF.shape[1]])
170 | self.sc3.axes.set_ylim([self.gcam_nAF.shape[0], 0])
171 | self.sc3.axes.set_xticks([])
172 | self.sc3.axes.set_yticks([])
173 | # Trigger the canvas to update and redraw.
174 | self.sc3.draw()
175 |
176 | # Fix the symmetrized triad
177 | right_bound = len(self.ts)-(3-1)*gap
178 | if start + 1 > right_bound:
179 | start = self.gcam_AF.shape[1] - start - 1
180 | gap = self.gcam_AF.shape[0] - gap + 1
181 |
182 | # Re-draw time series
183 | self.sc1.axes.cla() # Clear the canvas.
184 | self.sc1.axes.plot(np.arange(len(self.ts)), self.ts, 'k', linewidth=1.0)
185 | self.sc1.axes.plot(np.arange(start, start+gap*3, gap), self.ts[np.arange(start, start+gap*3, gap)], 'ro-', markersize=5)
186 | self.sc1.axes.set_xlim([0, len(self.ts)])
187 | self.sc1.axes.set_yticks([])
188 | self.sc1.draw()
189 |
190 | # Set new text for labels
191 | self.l1.setText('Time index:'+str(start))
192 | self.l2.setText('Delay:'+str(gap))
193 |
194 | def closeEvent(self, event):
195 |
196 | reply = QMessageBox.question(self, 'Warning',
197 | "Sure to exit?", QMessageBox.Yes |
198 | QMessageBox.No, QMessageBox.No)
199 | if reply == QMessageBox.Yes:
200 | self.hide()
201 | self.dialog = StartWindow()
202 | else:
203 | event.ignore()
204 |
205 | class Thread(QThread):
206 | _signal = pyqtSignal(int)
207 | def __init__(self):
208 | super(Thread, self).__init__()
209 |
210 | def __del__(self):
211 | self.wait()
212 |
213 | def run(self):
214 | pnums = list(DETAIL.keys())
215 | pnums.sort()
216 |
217 | self._signal.emit(pnums[0])
218 | with open(ecg_path, 'rb') as f:
219 | data = np.load(f)
220 | with open(label_path, 'rb') as f:
221 | label = pickle.load(f)
222 | self._signal.emit(pnums[1])
223 | net = build() # build the VGG16-MLP network once
225 | self._signal.emit(pnums[2])
226 | # AF signal
227 | ts = data[IDX]
228 | val_y = label['Y_test'][IDX]
229 |
230 | D = 3
231 | shape = np.array([len(range(1, (len(ts)-1)//(D-1) + 1)), len(range(0, len(ts)-(D-1)*1)), D])
232 | overlap = len(ts)-(D-1)*shape[0]
233 | # TMF image: [1, W, H, 3]
234 | img = np.zeros(shape)
235 | img = TMF_image(ts, overlap, img, D)
236 | img = np.expand_dims(img, axis=0)
237 | self._signal.emit(pnums[3])
238 | # SG-CAM image of non-AF
239 | gcam = get_GCAM(net, img, [1,0], layers=-3)
240 | self._signal.emit(pnums[4])
241 | gcam_norm = get_sym_GCAM(overlap, gcam, D)
242 | nAF_gcam = gcam_norm.copy()
243 | self._signal.emit(pnums[5])
244 | # SG-CAM image of AF
245 | gcam = get_GCAM(net, img, [0,1], layers=-3)
246 | self._signal.emit(pnums[6])
247 | gcam_norm = get_sym_GCAM(overlap, gcam, D)
248 | AF_gcam = gcam_norm.copy()
249 | self._signal.emit(pnums[7])
250 | proba = net.predict(img)
251 | print('Predicted probabilities:', proba)
252 | nAF_proba, AF_proba = proba[0]
253 | global RES
254 | RES = [ts, AF_gcam, nAF_gcam, AF_proba, nAF_proba, val_y]
255 | self._signal.emit(pnums[8])
256 |
257 | class StartWindow(QWidget):
258 |
259 | def __init__(self):
260 | super().__init__()
261 | self.initUI()
262 |
263 | def initUI(self):
264 |
265 | pixmap = QPixmap("img.jpg")
266 |
267 | # Start menu image
268 | lbl = QLabel(self)
269 | lbl.setPixmap(pixmap)
270 | self.setFixedSize(pixmap.width(),pixmap.height())
271 | self.parentSize = [pixmap.width(),pixmap.height()]
272 | self.center()
273 |
274 | # Progress bar
275 | self.progress = QProgressBar(self, objectName="RedProgressBar")
276 | self.progress.move(pixmap.width()//2-50, pixmap.height()-150)
277 | self.progress.resize(100, 20)
278 | # Start button
279 | self.startBtn = QPushButton("Start", self)
280 | self.startBtn.move(pixmap.width()//2-50, pixmap.height()-100)
281 | self.startBtn.clicked.connect(self.on_pushButton_clicked)
282 | self.startBtn.resize(100, 50)
283 | # Title
284 | self.label = QLabel(self)
285 | self.label.setFixedWidth(800)
286 | self.label.setFixedHeight(100)
287 | self.label.move((pixmap.width()-self.label.width())//2, (pixmap.height()-self.label.height())//2)
288 | self.label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
289 | self.label.setAlignment(Qt.AlignCenter)
290 | self.label.setText(u"Triadic Motif Field")
291 | self.label.setAutoFillBackground(False)
292 | self.label.setFont(QFont("Roman times", 30, QFont.Bold))
293 |
294 | self.setWindowTitle('TMF: Interpretable Classification Model')
295 | self.show()
296 |
297 | def on_pushButton_clicked(self):
298 | self.thread = Thread()
299 | self.thread._signal.connect(self.signal_accept)
300 | self.thread.start()
301 | self.startBtn.setEnabled(False)
302 |
303 | def signal_accept(self, msg):
304 | self.progress.setValue(int(msg))
305 | self.label.setText(DETAIL[int(msg)])
306 | print(DETAIL[int(msg)])
307 | if self.progress.value() == 100:
308 | self.hide()
309 | self.progress.setValue(0)
310 | global RES
311 | self.dialog = MainWindow(*RES)
312 |
313 | def center(self):
314 | qr = self.frameGeometry()
315 | cp = QDesktopWidget().availableGeometry().center()
316 | qr.moveCenter(cp)
317 | self.move(qr.topLeft())
318 |
319 | def closeEvent(self, event):
320 |
321 | reply = QMessageBox.question(self, 'Warning',
322 | "Sure to exit?", QMessageBox.Yes |
323 | QMessageBox.No, QMessageBox.No)
324 |
325 | if reply == QMessageBox.Yes:
326 | event.accept()
327 | else:
328 | event.ignore()
329 |
330 | if __name__ == '__main__':
331 |
332 | if len(sys.argv) > 1:
333 | IDX = int(sys.argv[1])
334 | else: # default value
335 | IDX = 3736 # AF signal
336 |
337 | ecg_path = '../data/ECG_X_test.bin' # 2-d array, [N, dim]
338 | label_path = '../data/ECG_info.pkl' # dict, key='Y_test', 0 and 1 indicate non-AF and AF
339 | app = QApplication(sys.argv)
340 | app.setStyleSheet(StyleSheet)
341 | ex = StartWindow()
342 | sys.exit(app.exec_())
--------------------------------------------------------------------------------
/app/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/app/img.jpg
--------------------------------------------------------------------------------
/classifier/README.md:
--------------------------------------------------------------------------------
1 | # Classifiers
2 |
3 | 1. MLP: early stopping
4 | 2. RF and LR: random search
5 |
6 | ## Default
7 |
8 | Three classifiers use the feature extracted from VGG16 which is based on the original time series. ```mode = 'no'``` means no filter was used upon the ECG signal.
9 | ```feature = 'vgg16'``` means it use the VGG16 features.
10 |
11 | ## Demo
12 |
13 | 1. Run directly in command line
14 | ```shell
15 | $ python3 classifier_MLP.py vgg16
16 | ```
17 | 2. Submit a pbs job
18 | ```shell
19 | $ qsub -v classifier=MLP,feature=vgg16 classifier.sh
20 | ```
21 |
22 |
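23 | For reference, a minimal sketch of loading the features and labels used by all three classifiers (paths follow the defaults above; the ```.npy``` files are produced by the [_extractor_](../extractor/) module):
24 | ```python
25 | import pickle
26 | import numpy as np
27 |
28 | mode, feature = 'no', 'vgg16'
29 | with open('../data/ECG_info.pkl', 'rb') as f:
30 |     label = pickle.load(f)
31 | with open('../feature-{1}/train/{0}/{0}.npy'.format(mode, feature), 'rb') as f:
32 |     train_x = np.load(f)  # [N, 512] pooled VGG16 feature vectors
33 | print(train_x.shape, label['Y_train'].shape)
34 | ```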
--------------------------------------------------------------------------------
/classifier/classifier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -N Train_${classifier}_${feature}
3 | #PBS -l nodes=1:ppn=20
4 | #PBS -l walltime=88888:00:00
5 | #PBS -q adf
6 | #PBS -j oe
7 | #PBS -m ae
8 | #PBS -M your@email.com
9 |
10 | which python
11 |
12 | cd $PBS_O_WORKDIR
13 | NPROCS=`wc -l < $PBS_NODEFILE`
14 |
15 | python classifier_${classifier}.py ${feature}
16 |
17 |
--------------------------------------------------------------------------------
/classifier/classifier_LR.py:
--------------------------------------------------------------------------------
1 | """
2 | Logistic regression with randomized search for the best parameters on the validation dataset
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python classifier_LR.py vgg16
8 | """
9 | import os
10 | import sys
11 | sys.path.append('../')
12 | from sklearn.ensemble import RandomForestRegressor
13 | from sklearn.linear_model import LogisticRegression
14 | from sklearn.model_selection import RandomizedSearchCV
15 | import numpy as np
16 | import pickle
17 | from tensorflow import keras
18 | from sklearn import metrics
19 | from lib.util import eval
20 |
21 | # Param
22 | data_path = '../'
23 | mode = 'no'
24 | nb_classes = 2
25 | feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'
26 |
27 | # Load dataset
28 | with open(data_path+'data/ECG_info.pkl', 'rb') as f:
29 | label = pickle.load(f)
30 |
31 | with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
32 | train_x = np.load(f)
33 | train_y = label['Y_train']
34 |
35 | with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
36 | val_x = np.load(f)
37 | val_y = label['Y_val']
38 |
39 | with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
40 | test_x = np.load(f)
41 | test_y = label['Y_test']
42 |
43 | X = np.concatenate([train_x, val_x], axis=0)
44 | y = np.concatenate([train_y, val_y], axis=0)
45 | train_indices = np.arange(train_x.shape[0])
46 | val_indices = np.arange(train_x.shape[0], X.shape[0])
47 |
48 | # Hyperparameter grid for the logistic regression
49 | random_grid = {'penalty' : ['l1', 'l2'],
50 | 'C' : [0.4, 0.6, 0.8, 1],
51 | 'solver' : ['liblinear', 'saga', 'lbfgs']}
52 |
53 | lr = LogisticRegression(n_jobs=20, verbose=2)
54 |
55 | cv = [(train_indices, val_indices)]
56 |
57 | search = RandomizedSearchCV(
58 | lr,
59 | param_distributions=random_grid,
60 | cv=cv,
61 | n_iter=10,
62 | verbose=10,
63 | n_jobs=20
64 | )
65 |
66 | search.fit(X, y)
67 | print(search.best_params_)
68 | pred = search.predict_proba(test_x)
69 | print(feature, eval(test_y, pred[:, 1]))
70 |
--------------------------------------------------------------------------------
/classifier/classifier_MLP.py:
--------------------------------------------------------------------------------
1 | """
2 | MLP classifier of the extracted features with early stopping strategy on the validation dataset.
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python3 classifier_MLP.py vgg16
8 | """
9 | import sys
10 | sys.path.append('../')
11 | import numpy as np
12 | import pickle
13 | from tensorflow import keras
14 | from lib.util import eval
15 |
16 | # Param
17 | data_path = '../'
18 | nb_classes = 2
19 | batch_size = 16
20 | nb_epochs = 10
21 | mode = 'no'
22 | feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'
23 |
24 | # Load data
25 | with open(data_path+'data/ECG_info.pkl', 'rb') as f:
26 | label = pickle.load(f)
27 |
28 | with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
29 | train_x = np.load(f)
30 | train_y = keras.utils.to_categorical(label['Y_train'], num_classes=nb_classes)
31 |
32 | with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
33 | val_x = np.load(f)
34 | val_y = keras.utils.to_categorical(label['Y_val'], num_classes=nb_classes)
35 |
36 | with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
37 | test_x = np.load(f)
38 | test_y = keras.utils.to_categorical(label['Y_test'], num_classes=nb_classes)
39 |
40 | dim = train_x.shape[1]
41 |
42 | def classifier(nb_classes):
43 | # classifier
44 | x = keras.layers.Input(shape=(dim,))
45 | dnn = keras.layers.Dense(128, activation='relu')(x)
46 | predictions = keras.layers.Dense(nb_classes, activation='softmax')(dnn)
47 | model = keras.models.Model(inputs=x, outputs=predictions)
48 | optimizer = keras.optimizers.Adam()
49 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
50 | return model
51 |
52 | def run(idx):
53 | net = classifier(nb_classes)
54 |
55 | reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=0.0001)
56 | checkpointer = keras.callbacks.ModelCheckpoint(filepath='../model/weights_tail_{0}.hdf5'.format(idx), verbose=1, save_best_only=True)
57 |
58 | # Train model on dataset
59 | hist = net.fit(train_x, train_y, batch_size=batch_size, epochs=nb_epochs, verbose=1, validation_data=(val_x, val_y), callbacks = [reduce_lr, checkpointer], shuffle=True)
60 |
61 | net.load_weights('../model/weights_tail_{0}.hdf5'.format(idx))
62 |
63 | test_pred = net.predict(test_x)
64 | res = eval(test_y[:, 1], test_pred[:, 1])
65 | print('ROC_AUC:{0}, PR_AUC:{1}, F1:{2}'.format(*res))
66 |
67 | return [idx] + res
68 |
69 |
70 | run(feature)
71 |
72 |
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/classifier/classifier_RF.py:
--------------------------------------------------------------------------------
1 | """
2 | Random forest classifier with randomized search for the best parameters according to the performance on the validation dataset
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python3 classifier_RF.py vgg16
8 | """
9 | import sys
10 | sys.path.append('../')
11 | from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
12 | from sklearn.model_selection import RandomizedSearchCV
13 | import numpy as np
14 | import pickle
15 | from tensorflow import keras
16 | from sklearn import metrics
17 | from lib.util import eval
18 |
19 | # Param
20 | data_path = '../'
21 | mode = 'no'
22 | nb_classes = 2
23 | feature = str(sys.argv[1]) if len(sys.argv) > 1 else 'vgg16'
24 |
25 | # Load data
26 | with open(data_path+'data/ECG_info.pkl', 'rb') as f:
27 | label = pickle.load(f)
28 |
29 | with open(data_path+'feature-{1}/train/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
30 | train_x = np.load(f)
31 | train_y = label['Y_train']
32 |
33 | with open(data_path+'feature-{1}/val/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
34 | val_x = np.load(f)
35 | val_y = label['Y_val']
36 |
37 | with open(data_path+'feature-{1}/test/{0}/{0}.npy'.format(*[mode, feature]), 'rb') as f:
38 | test_x = np.load(f)
39 | test_y = label['Y_test']
40 |
41 | X = np.concatenate([train_x, val_x], axis=0)
42 | y = np.concatenate([train_y, val_y], axis=0)
43 | train_indices = np.arange(train_x.shape[0])
44 | val_indices = np.arange(train_x.shape[0], X.shape[0])
45 |
46 | n_estimators = list([500, 1000, 1500])
47 | max_features = list([32, 64, 128])
48 | min_samples_split = [512, 1024]
49 | bootstrap = [True, False]
50 |
51 | # Create the random grid
52 | random_grid = {'n_estimators': n_estimators,
53 | 'max_features': max_features,
54 | 'min_samples_split': min_samples_split,
55 | 'bootstrap': bootstrap}
56 |
57 | rf = RandomForestClassifier(random_state=43, n_jobs=20)
58 |
59 | cv = [(train_indices, val_indices)]
60 |
61 | search = RandomizedSearchCV(
62 | rf,
63 | param_distributions=random_grid,
64 | cv=cv,
65 | n_iter=10,
66 | verbose=10,
67 | n_jobs=20)
68 |
69 | search.fit(X, y)
70 | print(search.best_params_)
71 | pred = search.predict_proba(test_x)
72 | print(feature, eval(test_y, pred[:, 1]))
73 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | 1. Source data: ```challenge2017.pkl``` (from [MINA](https://github.com/hsd1503/MINA))
4 | 2. Processed data (run ```python3 data.py```)
5 | + training, validation and test ECG signals: ```ECG_X_*.bin``` (numpy array)
6 | + labels: ```ECG_info.pkl``` (pickled dict with keys ```Y_train```, ```Y_val```, ```Y_test```, ```pid_test```)
7 |
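8 | A minimal sketch of loading the processed files (run from this folder after ```python3 data.py```):
9 | ```python
10 | import pickle
11 | import numpy as np
12 |
13 | with open('ECG_X_test.bin', 'rb') as f:
14 |     X_test = np.load(f)  # 2-d array, [N, 3000]
15 | with open('ECG_info.pkl', 'rb') as f:
16 |     info = pickle.load(f)  # dict: Y_train, Y_val, Y_test, pid_test
17 | print(X_test.shape, info['Y_test'].shape)
18 | ```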
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Create dataset
3 | steps:
4 | 1. use the challenge2017.pkl from https://github.com/hsd1503/MINA
5 | 2. run:
6 | $ python3 data.py
7 | """
8 | import sys
9 | sys.path.append('../')
10 | from lib.util import make_data_physionet
11 |
12 | make_data_physionet('./')
13 |
--------------------------------------------------------------------------------
/docs/app_bokeh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_bokeh.png
--------------------------------------------------------------------------------
/docs/app_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_main.png
--------------------------------------------------------------------------------
/docs/app_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/app_start.png
--------------------------------------------------------------------------------
/docs/map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/docs/map.png
--------------------------------------------------------------------------------
/extractor/README.md:
--------------------------------------------------------------------------------
1 | # Feature extractors
2 |
3 | ## Default
4 |
5 | Features are generated with VGG16 (excluding its top three layers) from the TMF images of the original ECG signals: ```model_type = 'vgg16', mode = 'no'```. The MPI size is set to 5 in the ```.sh``` and ```reduce*.py``` files.
6 |
7 | ## Demo
8 |
9 | 1. Submit the parallel jobs: training set, 10 nodes, and no filter
10 | ```
11 | $ python map_gen_feature.py train 10 no
12 | ```
13 | 2. Collect all the results from the finished jobs
14 | ```
15 | $ python reduce_gen_feature.py train 10 no
16 | ```
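17 |
18 | Before reducing, it can help to check that every slice finished. A small sketch assuming the default ```vgg16```/```no``` settings with 10 nodes and MPI size 5:
19 | ```python
20 | import os
21 |
22 | nodes, mpi_size = 10, 5
23 | missing = [(i, j) for i in range(nodes) for j in range(mpi_size)
24 |            if not os.path.exists('../feature-vgg16/train/no/{0}/{1}.npy'.format(i, j))]
25 | print('missing slices:', missing)  # should be empty before running reduce_gen_feature.py
26 | ```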
--------------------------------------------------------------------------------
/extractor/gen_feature.py:
--------------------------------------------------------------------------------
1 | """
2 | Extract the features of TMF image
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo: map the dataset to 10 nodes and generate the first slice using 5 process
7 | $ mpirun -n 5 python3 gen_feature.py --mode train --freq no --slice 0 --nodes 10
8 | """
9 | import os
10 | import sys
11 | sys.path.append('../')
12 | from tqdm import tqdm
13 | import numpy as np
14 | import pandas as pd
15 | from mpi4py import MPI
16 | from collections import OrderedDict, Counter
17 | import pickle as dill
18 | import argparse
19 | from tensorflow.keras.applications.vgg16 import VGG16
20 | from tensorflow.keras.applications import VGG19, ResNet50
21 | from tensorflow.keras.models import Model
22 | from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
23 | from scipy.signal import butter, lfilter
24 | from lib.util import TMF_image as gen_TMF
25 | from lib.util import mkdir
26 | from lib.visual import filt_ECG
27 |
28 | def build(model_type):
29 | # create the base pre-trained model
30 | if model_type == 'vgg16':
31 | base_model = VGG16(weights='imagenet', include_top=False)
32 | elif model_type == 'vgg19':
33 | base_model = VGG19(weights='imagenet', include_top=False)
34 | elif model_type == 'resnet50':
35 | base_model = ResNet50(weights='imagenet', include_top=False)
36 | else:
37 | raise(Exception('model_type must be one of vgg16, vgg19 and resnet50.'))
38 | # add a global spatial average pooling layer
39 | x = base_model.output
40 | x = GlobalAveragePooling2D()(x)
41 | # this is the model we will train
42 | model = Model(inputs=base_model.input, outputs=x)
43 | return model
44 |
45 | parser = argparse.ArgumentParser()
46 |
47 | # Get the settings from command line
48 | parser.add_argument('--mode', type=str, help='data type: train, val or test')
49 | parser.add_argument('--freq', type=str, default='mid', help='filter type: low, mid, high, no')
50 | parser.add_argument('--slice', type=int, default=0, help='slice index of the dataset')
51 | parser.add_argument('--nodes', type=int, default=8, help='total slice of the dataset')
52 | args = parser.parse_args()
53 |
54 | mode = str(args.mode) # train, val or test
55 | freq = str(args.freq) # filter type
56 | slidx = int(args.slice) # index of node
57 | slice_num = int(args.nodes) # total nodes
58 | model_type = 'vgg16' # model for feature extraction
59 |
60 | extractor = build(model_type)
61 |
62 | comm = MPI.COMM_WORLD
63 | mpisize = int(comm.Get_size()) # total num of the cpu cores, the n_splits of the k-Fold
64 | mpirank = int(comm.Get_rank()) # rank of this core
65 |
66 | print(mpisize, mpirank)
67 |
68 | if mpirank == 0:
69 | # Slice and broadcast the data
70 | with open('../data/ECG_X_{0}.bin'.format(mode), 'rb') as fin:
71 | X = np.load(fin)
72 | total_num = len(X)//slice_num # map to nodes
73 | if slidx != slice_num-1:
74 | X = X[slidx*total_num: (slidx+1)*total_num]
75 | else:
76 | X = X[slidx*total_num: ]
77 | total_num = len(X)
78 | total_idx = np.arange(total_num)
79 | chunk_len = total_num//(mpisize-1) if total_num >= mpisize - 1 else total_num # map to process
80 | data = [X[total_idx[i*chunk_len: (i+1)*chunk_len]] for i in tqdm(range(mpisize), desc='Scatter data')] # scatter the data to the other processes
81 | else:
82 | data = None
83 | data = comm.scatter(data, root=0)
84 |
85 | # TMF
86 | data_path = '../feature-{3}/{0}/{1}/{2}/'.format(*[mode, freq, slidx, model_type])
87 | mkdir(data_path)
88 |
89 | all_feature = []
90 |
91 | win = 3000 # length of the time series
92 | D = 3 # order of the motif, 3 means triad
93 | shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D]) # TMF image shape
94 | overlap = win-(D-1)*shape[0] # overlap caused by the rotation in the TMF image
95 | TMF = np.zeros(shape) # placeholder of TMF image
96 |
97 | for ts in tqdm(data):
98 | filt_ts = filt_ECG(ts, freq)
99 | TMF = gen_TMF(filt_ts, overlap, TMF, D) # the gen_TMF function is optimized with numba package
100 | img = np.expand_dims(TMF, axis=0) # [1, W, H, 3]
101 | feature = extractor.predict(img) # [1, 512]
102 | all_feature.append(feature)
103 |
104 | all_feature = np.concatenate(all_feature, axis=0) if len(all_feature) != 0 else np.array([])
105 | np.save(data_path+str(mpirank), all_feature)
106 |
107 |
--------------------------------------------------------------------------------
/extractor/gen_feature.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -N ECG-${slice}
3 | #PBS -l nodes=1:ppn=20
4 | #PBS -l walltime=88888:00:00
5 | #PBS -q adf
6 | #PBS -j oe
7 | #PBS -m ae
8 | #PBS -M your@email.com
9 |
10 | which python
11 |
12 | cd $PBS_O_WORKDIR
13 | NPROCS=`wc -l < $PBS_NODEFILE`
14 |
15 | mpirun -n 5 python gen_feature.py --mode ${mode} --freq ${freq} --slice ${slice} --nodes ${nodes}
16 |
17 |
--------------------------------------------------------------------------------
/extractor/map_gen_feature.py:
--------------------------------------------------------------------------------
1 | """
2 | Map the gen_feature.py tasks to different nodes and cores to run them in parallel.
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python map_gen_feature.py train 10 no
8 | """
9 | import sys
10 | sys.path.append('../')
11 | import subprocess
12 | from lib.util import mkdir
13 | mode= str(sys.argv[1])
14 | nodes = int(sys.argv[2])
15 | freq = str(sys.argv[3]) # no, mid
16 |
17 | for idx in range(nodes):
18 | qsub_command = """qsub -v slice={0},nodes={1},mode={2},freq={3} -q adf gen_feature.sh""".format(*[idx, nodes, mode, freq]) # submit the job script
19 | exit_status = subprocess.call(qsub_command, shell=True) # upload
20 | if exit_status == 1: # check that the job submitted successfully
21 | print("Job {0} failed to submit".format(qsub_command))
22 |
23 |
24 |
--------------------------------------------------------------------------------
/extractor/reduce_gen_feature.py:
--------------------------------------------------------------------------------
1 | """
2 | Collect the features
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo: collect the features of training dataset
7 | $ python reduce_gen_feature.py train 10 no
8 | """
9 | import numpy as np
10 | import argparse
11 | import sys
12 |
13 | mode = str(sys.argv[1])
14 | slide_num = int(sys.argv[2]) # slice num must match the nodes parameter of gen_feature.py
15 | mpi_size = 5 # must be same as the mpirun -n in .sh script
16 | freq = str(sys.argv[3]) # no filter
17 | path = '../feature-vgg16/{0}/{1}/'.format(*[mode, freq])
18 | feature = []
19 | for i in range(slide_num):
20 | for j in range(mpi_size):
21 | with open(path+'{0}/{1}.npy'.format(*[i, j]), "rb") as f:
22 | tmp = np.load(f)
23 | # print(tmp.shape)
24 | if len(tmp):
25 | feature.append(tmp)
26 |
27 | feature = np.concatenate(feature, axis=0)
28 | np.save(path+freq, feature)
29 | print(feature.shape)
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/length_effect/README.md:
--------------------------------------------------------------------------------
1 | # Length effect
2 |
3 | ## Default
4 |
5 | The length effect of the trained VGG16-MLP model is evaluated on the test dataset. Lengths are selected from 100 to 3000 with a step of 100.
6 |
7 | ## Demo
8 |
9 | 1. Submit the parallel jobs: 10 nodes
10 | ```
11 | $ python map_length_effect.py 10
12 | ```
13 | 2. Collect all the results from the finished jobs
14 | ```
15 | $ python reduce_length_effect.py prob 10 # collect probabilities
16 | $ python reduce_length_effect.py time 10 # collect time consumptions
17 | ```
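18 |
19 | After reducing, the per-length metrics can be recomputed from the collected arrays. A minimal sketch of what ```draw.py``` does (```prob.npy``` has shape [N, 30, 2]):
20 | ```python
21 | import sys
22 | import pickle
23 | import numpy as np
24 | sys.path.append('../')
25 | from lib.util import eval
26 |
27 | with open('./slice/prob.npy', 'rb') as f:
28 |     prob = np.load(f)
29 | with open('../data/ECG_info.pkl', 'rb') as f:
30 |     label = pickle.load(f)['Y_test']
31 | for i, win in enumerate(range(100, 3100, 100)):
32 |     roc_auc, pr_auc, f1 = eval(label, prob[:, i, 1])
33 |     print(win, roc_auc, pr_auc, f1)
34 | ```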
--------------------------------------------------------------------------------
/length_effect/draw.py:
--------------------------------------------------------------------------------
1 | """
2 | Draw the lines of length effect
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python3 draw.py
8 | """
9 | from matplotlib.collections import PatchCollection, LineCollection
10 | from matplotlib import gridspec
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import pickle
14 | import sys
15 | sys.path.append('../')
16 | plt.rc('text', usetex=True)
17 | plt.rc('font', family='Times New Roman')
18 | plt.rcParams['xtick.direction'] = 'in'
19 | plt.rcParams['ytick.direction'] = 'in'
20 | from matplotlib.patches import Rectangle
21 | from lib.visual import detectR
22 | from lib.util import eval
23 |
24 | # load dataset
25 | with open('./slice/prob.npy', 'rb') as f:
26 | prob_res = np.load(f)
27 | with open('./slice/time.npy', 'rb') as f:
28 | time_res = np.load(f)
29 | with open('../data/ECG_X_test.bin', 'rb') as f:
30 | data = np.load(f)
31 | with open('../data/ECG_info.pkl', 'rb') as f:
32 | label = pickle.load(f)
33 | label = label['Y_test']
34 |
35 | '''Draw the length effect of an AF signal'''
36 | idx = 18107 # an AF signal
37 | ts = data[idx]
38 | out = prob_res[idx]
39 | t = time_res[idx]
40 |
41 | fig = plt.figure(figsize=(7, 4))
42 | gs = gridspec.GridSpec(2, 1, height_ratios=[5, 1], hspace=0.)
43 | # Upper panel
44 | ax = plt.subplot(gs[0])
45 | vertical = 1600
46 | color = 'tab:red'
47 | line = ax.plot(range(100, 3100, 100), 100 * np.array([p[1] for p in out]), '.r-', linewidth=1)
48 | ax.plot([vertical, vertical], [0, 100], 'r--', linewidth=1)
49 | ax.set_ylim([0, 100])
50 | ax.set_ylabel('Probability of AF (\%)', color=color, fontsize=15)
51 | ax.tick_params(axis='y', labelcolor=color)
52 |
53 | ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis
54 | ax2.spines['right'].set_color('tab:blue')
55 | ax2.spines['left'].set_color('tab:red')
56 | color = 'tab:blue'
57 | # we already handled the x-label with ax1
58 | ax2.set_ylabel('Time consumption (s)', color=color, fontsize=15)
59 | bar = ax2.bar(x=range(100, 3100, 100), height=np.array(t), width=20, color=color, alpha=0.7)
60 | ax2.tick_params(axis='y', labelcolor=color)
61 | ax2.set_xlim([0, 3000])
62 | ax2.set_xticks([])
63 |
64 | # Lower panel
65 | ax = plt.subplot(gs[1])
66 | ax.plot(ts, 'k', linewidth=1)
67 | ax.set_ylabel('ECG', fontsize=15)
68 | ax.set_xlim([0, 3000])
69 | ax.set_xlabel('Length of ECG signals', fontsize=15)
70 | ax.set_ylim([-2.5, 7.5])
71 | ax.plot([vertical, vertical], [-2.5, 7.5], 'r--', linewidth=1)
72 | R_peak = detectR(ts)
73 | R_peak.sort()
74 | ax.plot(R_peak[6:8], ts[R_peak][6:8], 'r.', markersize=5)
75 | ax.text(R_peak[6]-160, ts[R_peak][6]-2, '7th', fontsize=15)
76 | ax.text(R_peak[7]-160, ts[R_peak][7]-2, '8th', fontsize=15)
77 | xt = np.array(ax.get_xticks(), dtype=int)
78 | xt = np.append(xt, vertical)
79 | xtl = xt.tolist()
80 | xtl.remove(1500)
81 | xtl[-1] = str(vertical)
82 | ax.set_xticks(np.delete(xt, 3))
83 | ax.set_xticklabels(xtl)
84 | ax.set_ylim([-2, 7])
85 | boxes = [Rectangle((0, -2), 1600, 9)]
86 | pc = PatchCollection(boxes, facecolor='red', alpha=0.2, edgecolor=None)
87 | ax.add_collection(pc)
88 | plt.tight_layout()
89 | fig.savefig('flexibility.pdf', dpi=300)
90 |
91 | '''Draw the statistical result'''
92 | res = []
93 | for i in range(30):
94 | res.append(eval(label, prob_res[:, i, 1]))
95 | res = np.array(res) # [30, 3]
96 |
97 | fig = plt.figure(figsize=(7, 3))
98 | ax = plt.subplot()
99 | ax.plot(np.arange(100, 3100, 100), 100 * res[:, 0], 'vy-', linewidth=1.0, markersize=6)
100 | ax.plot(np.arange(100, 3100, 100), 100 * res[:, 1], '.r--', linewidth=1.0, markersize=4)
101 | ax.plot(np.arange(100, 3100, 100), 100 * res[:, 2], 'og-.', linewidth=1.0, markersize=4)
102 | ax.set_ylim([0, 100])
103 | ax.set_xlim([0, 3000])
104 | ax.set_xlabel('Length of ECG signals', fontsize=15)
105 | ax.set_ylabel('Performance (\%)', fontsize=15, color='k')
106 | ax.legend(['ROC AUC', 'PR AUC', 'F1'], fontsize=15,
107 | framealpha=0, loc='center') # , ncol=3)
108 | ax.tick_params(axis='y', labelcolor='k')
109 | ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis
110 | ax2.spines['right'].set_color('tab:blue')
111 | ax2.spines['left'].set_color('k')
112 | color = 'tab:blue'
113 | # we already handled the x-label with ax1
114 | ax2.set_ylabel('Time consumption (s)', color=color, fontsize=15)
115 | bar = ax2.bar(x=range(100, 3100, 100), height=np.mean(time_res, axis=0), width=20, color=color, alpha=0.7)
116 | ax2.tick_params(axis='y', labelcolor=color)
117 | ax2.set_xlim([0, 3000])
118 | plt.tight_layout()
119 | fig.savefig('length_effect.pdf', dpi=300)
120 |
--------------------------------------------------------------------------------
/length_effect/length_effect.py:
--------------------------------------------------------------------------------
1 | """
2 | Get the performance (time and probability) of model with different length of ECG signals
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ mpirun -n 5 python3 length_effect.py --slice 0 --nodes 10
8 | """
9 | import os
10 | import sys
11 | sys.path.append('../')
12 | from tqdm import tqdm
13 | from mpi4py import MPI
14 | import numpy as np
15 | import pandas as pd
16 | from collections import OrderedDict, Counter
17 | import pickle as dill
18 | import argparse
19 | import time
20 | from lib.util import TMF_image as gen_TMF
21 | from lib.util import mkdir
22 | from lib.util import build_fullnet as build
23 |
24 | net = build()
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | # Get the settings from command line
29 | parser.add_argument('--slice', type=int, default=0, help='slice index of the dataset')
30 | parser.add_argument('--nodes', type=int, default=8, help='total slice of the dataset')
31 | args = parser.parse_args()
32 |
33 | mode = 'test'
34 | slidx = int(args.slice)
35 | slice_num = int(args.nodes) # node num
36 |
37 | comm = MPI.COMM_WORLD
38 | mpisize = int(comm.Get_size()) # total num of the cpu cores, the n_splits of the k-Fold
39 | mpirank = int(comm.Get_rank()) # rank of this core
40 |
41 | if mpirank == 0:
42 | with open('../data/ECG_X_{0}.bin'.format(mode), 'rb') as fin:
43 | X = np.load(fin)
44 | total_num = len(X)//slice_num
45 | if slidx != slice_num-1:
46 | X = X[slidx*total_num: (slidx+1)*total_num]
47 | else:
48 | X = X[slidx*total_num: ]
49 | total_num = len(X)
50 | total_idx = np.arange(total_num)
51 | chunk_len = total_num//(mpisize-1) if total_num >= mpisize - 1 else total_num
52 | data = [X[total_idx[i*chunk_len: (i+1)*chunk_len]] for i in tqdm(range(mpisize), desc='Scatter data')]
53 | else:
54 | data = None
55 |
56 | data = comm.scatter(data, root=0)
57 |
58 | # TMF
59 | data_path = 'slice/{0}/'.format(slidx)
60 | mkdir(data_path)
61 |
62 | D = 3
63 | output = []
64 | time_consumption = []
65 |
66 | for ts in tqdm(data):
67 | onesample = []
68 | onetime = []
69 | for win in range(100, 3100, 100):
70 | time_start=time.time() # start time
71 | shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D])
72 | overlap = win-(D-1)*shape[0]
73 | TMF = np.zeros(shape)
74 | TMF = gen_TMF(ts[0: win], overlap, TMF, D)
75 | img = np.expand_dims(TMF, axis=0) # [1, W, H, 3]
76 | prob = net.predict(img) # [1, 2]
77 | time_end=time.time() # end time
78 | onesample.append(prob)
79 | onetime.append(time_end - time_start)
80 | if len(onetime) != 0 and len(onesample) != 0:
81 | onesample = np.concatenate(onesample, axis=0) # [S, 2]
82 | onetime = np.array(onetime) # [S]
83 | output.append(onesample)
84 | time_consumption.append(onetime)
85 |
86 | output = np.stack(output, axis=0) if len(output) != 0 else np.array([]) # [N, S, 2]
87 | time_consumption = np.stack(time_consumption, axis=0) if len(time_consumption) != 0 else np.array([]) # [N, S]
88 |
89 | np.save(data_path+'prob_'+str(mpirank), output)
90 | np.save(data_path+'time_'+str(mpirank), time_consumption)
91 |
92 |
93 |
--------------------------------------------------------------------------------
/length_effect/length_effect.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -N ECG-${slice}
3 | #PBS -l nodes=1:ppn=20
4 | #PBS -l walltime=88888:00:00
5 | #PBS -q adf
6 | #PBS -j oe
7 | #PBS -m ae
8 | #PBS -M your@email.com
9 |
10 | which python
11 |
12 | cd $PBS_O_WORKDIR
13 | NPROCS=`wc -l < $PBS_NODEFILE`
14 |
15 | mpirun -n 5 python length_effect.py --slice ${slice} --nodes ${nodes}
16 |
17 |
--------------------------------------------------------------------------------
/length_effect/map_length_effect.py:
--------------------------------------------------------------------------------
1 | """
2 | Map the length_effect.py tasks to different nodes and cores to run them in parallel.
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python3 map_length_effect.py 10
8 | """
9 | import sys
10 | sys.path.append('../')
11 | import subprocess
12 | from lib.util import mkdir
13 |
14 | nodes = int(sys.argv[1])
15 | for idx in range(nodes):
16 | mkdir('slice/{0}/'.format(idx))
17 | qsub_command = """qsub -v slice={0},nodes={1} -q adf length_effect.sh""".format(*[idx, nodes])
18 | exit_status = subprocess.call(qsub_command, shell=True) # upload
19 | if exit_status == 1: # check that the job submitted successfully
20 | print("Job {0} failed to submit".format(qsub_command))
21 |
22 |
23 |
--------------------------------------------------------------------------------
/length_effect/reduce_length_effect.py:
--------------------------------------------------------------------------------
1 | """
2 | Collect the result of the length effect
3 | Author: Yadong Zhang
4 | E-mail: zhangyadong@stu.xjtu.edu.cn
5 |
6 | Demo:
7 | $ python reduce_length_effect.py prob 10
8 | """
9 | import numpy as np
10 | import argparse
11 | import sys
12 |
13 | mode = str(sys.argv[1]) # 'prob' or 'time'
14 | slide_num = int(sys.argv[2])
15 | mpi_size = 5
16 |
17 | path = 'slice/'
18 | feature = []
19 | for i in range(slide_num):
20 | for j in range(mpi_size):
21 | with open(path+'{0}/{2}_{1}.npy'.format(*[i, j, mode]), "rb") as f:
22 | tmp = np.load(f)
23 | if len(tmp):
24 | feature.append(tmp)
25 |
26 | feature = np.concatenate(feature, axis=0)
27 | np.save(path+mode, feature)
28 | print(feature.shape)
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/lib/__init__.py
--------------------------------------------------------------------------------
/lib/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import pickle as dill
5 |
6 | import cv2
7 | import numba as nb
8 | import numpy as np
9 | import pandas as pd
10 | from tqdm import tqdm
11 | from collections import OrderedDict, Counter
12 |
13 | from sklearn import metrics
14 | from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
15 | from sklearn.metrics import log_loss
16 | from sklearn.metrics import confusion_matrix
17 |
18 | import tensorflow as tf
19 | from tensorflow.keras import backend as K
20 | from tensorflow.keras.applications.vgg16 import VGG16
21 | from tensorflow.keras.models import Model
22 | from tensorflow import keras
23 |
24 | @nb.jit(nopython=True)
25 | def TMF_image(ts, overlap, TMF, D=3):
26 | '''
27 | generate the triadic motif field (TMF) image of a time series
28 | :param ts: 1-d array, time series
29 | :param overlap: int, overlap caused by the rotation in the TMF image, overlap = win-(D-1)*shape[0]
30 | :param TMF: 3-d array, placeholder of the TMF image, np.zeros(shape) with
31 | shape = np.array([len(range(1, (win-1)//(D-1) + 1)), len(range(0, win-(D-1)*1)), D])
32 | :param D: int, number of ordinal points, default 3 (triad)
33 | return image, [W, H, D]
34 | '''
35 | shape = TMF.shape
36 | for i in range(shape[0]):
37 | right_bound = len(ts)-(D-1)*(i+1)
38 | for j in range(right_bound):
39 | motif_idx = np.arange(j, j+D*(i+1), (i+1))
40 | if j < right_bound - overlap:
41 | TMF[i, j, :] = ts[motif_idx]
42 | TMF[shape[0]-i-1, shape[1]-j-1, :] = ts[motif_idx]
43 | else:
44 | TMF[i, j, :] = ts[motif_idx]
45 | return TMF
46 |
47 | def get_GCAM(model, inputs, targets, layers=-11):
48 | '''
49 | Grad-CAM
50 | :param model: keras model
51 | :param inputs: 4-d array, shape is [1, W, H, C]
52 | :param targets: 1-d array, shape is [nb_class] (one-hot)
53 | :param layers: int value, visualize which layer
54 | return: 2-d array, shape is [W, H]
55 | '''
56 | class_idx = np.argmax(targets, axis=-1)
57 | class_output = model.output[:, class_idx]
58 | last_conv_layer = model.layers[layers]
59 |
60 |
61 | x = inputs.copy()
62 | grads = K.gradients(class_output, last_conv_layer.output)[0]
63 | pooled_grads = K.mean(grads, axis=(0, 1, 2))
64 | iterate = K.function(model.input, [pooled_grads, last_conv_layer.output[0]])
65 | pooled_grads_value, conv_layer_output_value = iterate(x)
66 | for i, grads_value in enumerate(pooled_grads_value):
67 | conv_layer_output_value[:, :, i] *= grads_value
68 | heatmap = np.mean(conv_layer_output_value, axis=-1)
69 | heatmap[heatmap<0] = 0 # ReLU
70 | heatmap = cv2.resize(heatmap, (x.shape[2], x.shape[1]))
71 | gcam = (heatmap-heatmap.min())/(heatmap.max() - heatmap.min())
72 | return gcam
73 |
74 | @nb.jit(nopython=True)
75 | def get_sym_GCAM(overlap, gcam, D=3):
76 | '''
77 | normalize a Grad-CAM into a symmetrized Grad-CAM
78 | :param overlap: int, overlap caused by the rotation in the TMF image,
79 | overlap = win-(D-1)*shape[0]
80 | :param gcam: 2-d array, Grad-CAM of the TMF image, shape [W, H]
81 | :param D: int, number of ordinal points, default 3 (triad)
82 | return image, [W, H]
83 | '''
84 | shape = gcam.shape
85 | for i in range(shape[0]):
86 | right_bound = shape[1]+(D-1)-(D-1)*(i+1)
87 | for j in range(right_bound):
88 | if j < right_bound - overlap:
89 | gcam[i, j] = (gcam[i, j] + gcam[shape[0]-i-1, shape[1]-j-1])/2.0
90 | gcam[shape[0]-i-1, shape[1]-j-1] = gcam[i, j]
91 |
92 | else:
93 | gcam[i, j] = gcam[i, j]
94 | gcam = (gcam - gcam.min())/(gcam.max()-gcam.min())
95 | return gcam
96 |
97 | def get_cam_image(net, ts, return_proba=False):
98 | '''
99 | generate the SG-CAM of AF and non-AF
100 | :param net: keras model
101 | :param ts: 1-d array, time series
102 | :param return_proba: boolean
103 | return SG-CAM of non-AF, SG-CAM of AF, predicted probability (if return_proba=True)
104 | '''
105 | D = 3
106 | shape = np.array([len(range(1, (len(ts)-1)//(D-1) + 1)), len(range(0, len(ts)-(D-1)*1)), D])
107 | overlap = len(ts)-(D-1)*shape[0]
108 | # TMF image: [1, W, H, 3]
109 | img = np.zeros(shape)
110 | img = TMF_image(ts, overlap, img, D)
111 | img = np.expand_dims(img, axis=0)
112 | # SG-CAM image of non-AF
113 | gcam = get_GCAM(net, img, [1,0], layers=-3)
114 | gcam_norm = get_sym_GCAM(overlap, gcam, D)
115 | nAF_cam = gcam_norm.copy()
116 | # SG-CAM image of AF
117 | gcam = get_GCAM(net, img, [0,1], layers=-3)
118 | gcam_norm = get_sym_GCAM(overlap, gcam, D)
119 | AF_cam = gcam_norm.copy()
120 | if not return_proba:
121 | return nAF_cam, AF_cam
122 | else:
123 | proba = net.predict(img)
124 | return nAF_cam, AF_cam, proba
125 |
126 | def build_fullnet(path='../model/weights_tail_best.hdf5'):
127 | '''
128 | build the full network (VGG16-MLP)
129 | :param path: str, path of the trained MLP
130 | return: keras network
131 | '''
132 | base_model = VGG16(weights='imagenet', include_top=False)
133 | x = base_model.output
134 | vec = keras.layers.GlobalAveragePooling2D()(x)
135 |
136 | row_input = keras.layers.Input((vec.shape[1:]))
137 | dnn = keras.layers.Dense(128, activation='relu')(row_input)
138 |
139 | predictions = keras.layers.Dense(2, activation='softmax')(dnn)
140 |
141 | tail_model = keras.models.Model(inputs=row_input, outputs=predictions)
142 | if path is not None:
143 | tail_model.load_weights(path)
144 |
145 | out = tail_model(vec)
146 | network = keras.models.Model(inputs=base_model.input, outputs=out)
147 | return network
148 |
149 | def eval(target, predict):
150 | '''
151 | evaluate the results
152 | :param target: 1-d array, 0 or 1
153 | :param predict: 1-d array, probability of AF
154 | return roc_auc, pr_auc and F1
155 | '''
156 | ROC_AUC = metrics.roc_auc_score(target, predict)
157 | # PR_AUC
158 | precision, recall, _thresholds = metrics.precision_recall_curve(
159 | target, predict)
160 | PR_AUC = metrics.auc(recall, precision)
161 | # F1
162 | predict = np.array([i > 0.5 for i in predict], dtype=int)
163 | F1 = metrics.f1_score(target, predict)
164 | return [ROC_AUC, PR_AUC, F1]
165 |
166 | def eval_patient(pid, label, proba):
167 | '''
168 | evaluate the patient-wise accuracy
169 | :param pid: 1-d array, shape is [N,], id of patients
170 | :param label: 1-d array, shape is [N,], 0 or 1, 1 indicates the AF segment
171 | :param proba: 1-d array, shape is [N,], probability of AF
172 | return: 2-d array, [N, 3], columns are pid, count of segment, patient-wise accuracy
173 | '''
174 | wrong = pid[(label == 1) & (proba < 0.5)]
175 | wrong_pid = np.unique(wrong)
176 | AF = pid[(label == 1)]
177 | AF_pid = np.unique(AF)
178 | detail = []
179 | for i in AF_pid:
180 | # [id, length, accuracy]
181 | total = float(np.sum(AF == i))
182 | detail.append([i, total, (total - np.sum(wrong == i)) / total])
183 | detail.sort(key=lambda x: x[2]) # sort according to the accuracy
184 | return np.array(detail)
185 |
186 | def mkdir(path):
187 | """
188 | mkdir of the path
189 | :param input: string of the path
190 | return: boolean
191 | """
192 | path = path.strip()
193 | path = path.rstrip("\\")
194 | isExists = os.path.exists(path)
195 |
196 | if not isExists:
197 | os.makedirs(path)
198 | print(path+' is created!')
199 | return True
200 | else:
201 | print(path+' already exists!')
202 | return False
203 |
204 | def slide_and_cut(X, Y, window_size, stride, output_pid=False):
205 | '''
206 | From https://github.com/hsd1503/MINA
207 | MINA: Multilevel Knowledge-Guided Attention for Modeling Electrocardiography Signals, IJCAI 2019
208 | '''
209 | out_X = []
210 | out_Y = []
211 | out_pid = []
212 | n_sample = X.shape[0]
213 | mode = 0
214 | for i in range(n_sample):
215 | tmp_ts = X[i]
216 | tmp_Y = Y[i]
217 | if tmp_Y == 0:
218 | i_stride = stride
219 | elif tmp_Y == 1:
220 | i_stride = stride//10
221 | for j in range(0, len(tmp_ts)-window_size, i_stride):
222 | out_X.append(tmp_ts[j:j+window_size])
223 | out_Y.append(tmp_Y)
224 | out_pid.append(i)
225 | if output_pid:
226 | return np.array(out_X), np.array(out_Y), np.array(out_pid)
227 | else:
228 | return np.array(out_X), np.array(out_Y)
229 |
230 | def make_data_physionet(data_path, n_split=50, window_size=3000, stride=500):
231 | '''
232 | From https://github.com/hsd1503/MINA
233 | MINA: Multilevel Knowledge-Guided Attention for Modeling Electrocardiography Signals, IJCAI 2019
234 | '''
235 | # read pkl
236 | with open(os.path.join(data_path, 'challenge2017.pkl'), 'rb') as fin:
237 | res = dill.load(fin)
238 | ## scale data
239 | all_data = res['data']
240 | for i in range(len(all_data)):
241 | tmp_data = all_data[i]
242 | tmp_std = np.std(tmp_data)
243 | tmp_mean = np.mean(tmp_data)
244 | all_data[i] = (tmp_data - tmp_mean) / tmp_std # normalize
245 | # all_data is the same list object as res['data'], already normalized in place
246 | all_data = np.array(all_data)
247 | ## encode label
248 | all_label = []
249 | for i in res['label']:
250 | if i == 'A':
251 | all_label.append(1)
252 | else:
253 | all_label.append(0)
254 | all_label = np.array(all_label)
255 |
256 | # split train test
257 | n_sample = len(all_label)
258 | split_idx_1 = int(0.75 * n_sample)
259 | split_idx_2 = int(0.85 * n_sample)
260 |
261 | shuffle_idx = np.random.RandomState(seed=40).permutation(n_sample)
262 | all_data = all_data[shuffle_idx]
263 | all_label = all_label[shuffle_idx]
264 |
265 | X_train = all_data[:split_idx_1]
266 | X_val = all_data[split_idx_1:split_idx_2]
267 | X_test = all_data[split_idx_2:]
268 | Y_train = all_label[:split_idx_1]
269 | Y_val = all_label[split_idx_1:split_idx_2]
270 | Y_test = all_label[split_idx_2:]
271 |
272 | # slide and cut
273 | print(Counter(Y_train), Counter(Y_val), Counter(Y_test))
274 | X_train, Y_train = slide_and_cut(
275 | X_train, Y_train, window_size=window_size, stride=stride)
276 | X_val, Y_val = slide_and_cut(
277 | X_val, Y_val, window_size=window_size, stride=stride)
278 | X_test, Y_test, pid_test = slide_and_cut(
279 | X_test, Y_test, window_size=window_size, stride=stride, output_pid=True)
280 | print('after: ')
281 | print(Counter(Y_train), Counter(Y_val), Counter(Y_test))
282 |
283 | # shuffle train
284 | shuffle_pid = np.random.RandomState(seed=42).permutation(
285 | Y_train.shape[0]) # np.random.permutation(Y_train.shape[0])
286 | X_train = X_train[shuffle_pid]
287 | Y_train = Y_train[shuffle_pid]
288 |
289 | # save
290 | res = {'Y_train': Y_train, 'Y_val': Y_val,
291 | 'Y_test': Y_test, 'pid_test': pid_test}
292 | with open(os.path.join(data_path, 'ECG_info.pkl'), 'wb') as fout:
293 | dill.dump(res, fout)
294 |
295 | with open(os.path.join(data_path, 'ECG_X_train.bin'), 'wb') as fout:
296 | np.save(fout, X_train)
297 |
298 | with open(os.path.join(data_path, 'ECG_X_val.bin'), 'wb') as fout:
299 | np.save(fout, X_val)
300 |
301 | with open(os.path.join(data_path, 'ECG_X_test.bin'), 'wb') as fout:
302 | np.save(fout, X_test)
303 |
--------------------------------------------------------------------------------
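
The two evaluation helpers above consume flat arrays of segment labels and predicted AF probabilities. A minimal usage sketch with synthetic inputs (the `label`/`proba`/`pid` arrays are illustrative stand-ins for the classifier output; it assumes the repository root is on `sys.path` so that `lib.util` is importable):

```python
import numpy as np
from lib.util import eval, eval_patient  # assumes the repo root is on sys.path

rng = np.random.RandomState(0)
label = rng.randint(0, 2, size=200)        # 0 = non-AF, 1 = AF (synthetic)
proba = label * 0.7 + rng.rand(200) * 0.3  # toy AF probabilities
pid = rng.randint(0, 20, size=200)         # 20 synthetic patient ids

roc_auc, pr_auc, f1 = eval(label, proba)   # segment-wise metrics
detail = eval_patient(pid, label, proba)   # rows: [pid, segment count, accuracy]
print(roc_auc, pr_auc, f1)
print(detail[:5])                          # five lowest-accuracy AF patients
```
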
/lib/visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 |
4 | import matplotlib.pyplot as plt
5 | from matplotlib.collections import LineCollection
6 | from matplotlib import gridspec
7 | plt.rc('text', usetex=True)
8 | plt.rc('font', family='Times New Roman')
9 | plt.rcParams['xtick.direction'] = 'in'
10 | plt.rcParams['ytick.direction'] = 'in'
11 | import matplotlib.patches as patches
12 | from tqdm import tqdm
13 |
14 | from scipy.signal import butter, lfilter
15 | ECG_param = {'PQ_max': 0.2, 'QRS_max': 0.150,
16 | 'P_min': 0.08, 'RR_min': 0.200,
17 | 'sample_freq': 300
18 | } # unit: s, Hz
19 |
20 | def detectR(data, threshold=0.96):
21 | '''
22 | detect R peaks from ECG signals
23 | :param data: 1-d array, ECG signals
24 | :param threshold: float, threshold for the R peaks
25 | return: R peak indices
26 | '''
27 | RR_interval = int(ECG_param['RR_min']*ECG_param['sample_freq'])
28 | # Filter the QRS regions for R detection
29 | nyquist_freq = 0.5 * ECG_param['sample_freq']
30 | bandpass = [10, 50] # R band
31 | low = bandpass[0] / nyquist_freq
32 | high = bandpass[1] / nyquist_freq
33 | filter_order = 1
34 | b, a = butter(filter_order, [low, high], btype="band")
35 | ts = lfilter(b, a, data)
36 |
37 | # Second order difference of the filtered time series
38 | diff2ts = np.power(np.diff(np.diff(ts)), 2)
39 | # Use the 96th percentile of the squared second difference as the R-peak threshold
40 | thres_amp = np.percentile(diff2ts, int(threshold*100))
41 |
42 | idxsort = np.argsort(diff2ts) # indices sorted by squared second difference
43 | idx = np.arange(len(idxsort)) # candidate positions not yet suppressed
44 | filtered = []
45 | for i in idxsort[::-1]: # visit candidates from strongest to weakest
46 | if diff2ts[i] > thres_amp:
47 | if i in idx:
48 | filtered.append(i) # still in idx, so not suppressed by a stronger nearby peak
49 | idx = np.setdiff1d(idx, range(i-RR_interval//2, i+RR_interval//2)) # suppress neighbors within half an RR interval
50 | R_peak = []
51 | for i in filtered:
52 | try:
53 | p = i-RR_interval//2 + np.argmax(data[i-RR_interval//2: i+RR_interval//2])
54 | if np.max(data[i-RR_interval//2: i+RR_interval//2]) > np.mean(data):
55 | R_peak.append(p)
56 | except ValueError: # empty window at the signal boundary
57 | pass
58 |
59 | return R_peak
60 |
61 | def flip_ECG(data):
62 | '''
63 | Flip ECG signals that were recorded with inverted polarity
64 | :param data: 1-d array, ECG signal
65 | :return: ECG signal with consistent (positive R-peak) polarity
66 | '''
67 | less, more = np.percentile(data, (2, 98))
68 | if np.abs(less) > np.abs(more):
69 | return -data
70 | else:
71 | return data
72 |
73 | def get_beats(data, R_peak:list):
74 | '''
75 | Extract the beats centered on the R peaks
76 | :param data: 1-d array, ECG signal
77 | :param R_peak: list, R-peak indices
78 | :return: list of beats, each spanning half a mean RR interval on either side of an R peak
79 | '''
80 | R_peak = np.sort(R_peak)
81 | RR_interval = int(np.mean(np.diff(R_peak)))
82 | beats = []
83 | for i in R_peak:
84 | beats.append(data[i-RR_interval//2: i+RR_interval//2])
85 | return beats
86 |
87 | def filt_ECG(data, freq='mid', sample_freq=300):
88 | '''
89 | filter the ECG signal; data is a 1-d array
90 | :param freq: str, one of 'no', 'low', 'mid', 'high'
91 | :param sample_freq: int, sampling frequency in Hz
92 | :return: filtered ECG signal
93 | '''
94 | if freq == 'no': return data
95 | nyquist_freq = 0.5 * sample_freq
96 | select = {'low': (0.001 / nyquist_freq, 0.5 / nyquist_freq), 'mid': (0.5 / nyquist_freq, 50 / nyquist_freq), 'high': 50 / nyquist_freq}
97 | filter_order = 1
98 | if freq == 'high':
99 | b, a = butter(filter_order, select[freq], btype="high")
100 | else:
101 | b, a = butter(filter_order, select[freq], btype="band")
102 | out = lfilter(b, a, data)
103 | return out
104 |
105 | def argmax2D(matrix):
106 | """
107 | find the index of the maximum value in a 2-d matrix
108 | :param matrix: 2-d array
109 | :return: y, x (row and column of the maximum)
110 | """
111 | cols = matrix.shape[1]
112 | loc = np.argmax(matrix) # index into the flattened matrix
113 | y, x = loc // cols, loc % cols
114 | return y, x
115 |
116 | def plot_cam(ts, img, path=None, point=None, peak_idx=None):
117 | '''
118 | plot the time series and its SG-CAM image
119 | :param ts: 1-d array, ECG signal
120 | :param img: 2-d array, SG-CAM image
121 | :param path: str, path to save the figure
122 | :param point: tuple (x, y) or None, highlight the triadic motif around this point of the SG-CAM
123 | :param peak_idx: int or None, annotate the P wave and QRS complex around this R peak
124 | '''
125 | if point:
126 | win = 100
127 | centerx, centery = argmax2D(img[point[1]-win:point[1]+win, point[0]-win: point[0]+win])
128 | centerx += point[1]-win
129 | centery += point[0]-win
130 |
131 | fig = plt.figure(figsize=(6, 4))
132 | gs = gridspec.GridSpec(2, 1, height_ratios=[1,4], hspace=0.)
133 | ax = plt.subplot(gs[0])
134 | ax.set_xticks([])
135 | ax.set_yticks([])
136 | ax.set_xlim([0, 3000-3])
137 | ax.plot(ts, 'k', linewidth=1)
138 | if point:
139 | motif = [
140 | [centery, centerx+1, 'ro--', 4]
141 | ]
142 | for start, gap, ty, s in motif:
143 | ix = np.arange(start, start + gap*3, gap)
144 | ax.plot(ix, ts[ix], ty, markersize=s)
145 |
146 | ax.set_ylim([-4, 8.5])
147 | ts_flip = flip_ECG(ts)
148 | R_peak = detectR(ts_flip)
149 | R_peak.sort()
150 | for i in R_peak:
151 | ax.plot([i, i], [-4, 8.5], 'r--', linewidth=0.5)
152 |
153 | ax = plt.subplot(gs[1])
154 | ax.imshow(img, cmap=plt.cm.jet, vmax=1, vmin=0)
155 | if point:
156 | ax.plot([centery], [centerx], 'r+', markersize=10, alpha=0.5)
157 |
158 | horizon_move = 100
159 | if peak_idx is not None: # index 0 is a valid peak
160 | rect = patches.Rectangle([R_peak[peak_idx]-50, 0],50*2,30*2,linestyle='--',linewidth=1,edgecolor='w',facecolor='none')
161 | bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.1)
162 | ax.text(R_peak[peak_idx]+horizon_move, 160, "QRS", color='w', ha="center", va="center", size=14,
163 | bbox=bbox_props)
164 | ax.add_patch(rect)
165 |
166 | rect = patches.Rectangle([R_peak[peak_idx]-170, 0],50*2,30*2,linestyle='--',linewidth=1,edgecolor='w',facecolor='none')
167 | bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.1)
168 | ax.text(R_peak[peak_idx]-250+horizon_move, 160, "P", color='w', ha="center", va="center", size=14,
169 | bbox=bbox_props)
170 | ax.add_patch(rect)
171 | ax.set_xticks([])
172 | ax.set_yticks([])
173 | plt.tight_layout()
174 | if path:
175 | fig.savefig(path, dpi=300)
176 |
177 | def plot_basic_tsne(X_embedded, val_y, alpha=1, s=1, path=None):
178 | '''
179 | draw the t-SNE figure of the AF and non-AF segments
180 | :param X_embedded: 2-d array, shape is [N, 2], the low-dimensional points
181 | :param val_y: 1-d array, shape is [N], the labels, where 1 and 0 indicate AF and non-AF
182 | :param alpha: float, alpha of the scatter points
183 | :param s: float, size of the scatter points
184 | :param path: str, path for saving the figure
185 | '''
186 | fig = plt.figure(figsize=(6, 6))
187 | plt.scatter(x=X_embedded[val_y==1][:,0], y=X_embedded[val_y==1][:,1],color='r',s=s,alpha=alpha,edgecolors="none")
188 | plt.scatter(x=X_embedded[val_y==0][:,0], y=X_embedded[val_y==0][:,1],color='b',s=s,alpha=alpha,edgecolors="none")
189 | lgnd = plt.legend(['AF', 'non-AF'],fancybox=False, framealpha=0,fontsize=15, loc='lower left')
190 |
191 | for handle in lgnd.legendHandles:
192 | handle.set_sizes([12.0])
193 | handle.set_alpha(1.0)
194 | ax = plt.gca()
195 | ax.spines['top'].set_visible(False)
196 | ax.spines['right'].set_visible(False)
197 | ax.spines['bottom'].set_visible(False)
198 | ax.spines['left'].set_visible(False)
199 |
200 | ax.set_xticks([])
201 | ax.set_yticks([])
202 |
203 | fig.tight_layout()
204 | if path:
205 | fig.savefig(path, dpi=300)
206 |
207 | def plot_AF_pid(X_embedded, val_y, pid_test, path=None):
208 | '''
209 | draw the t-SNE figure of the AF segments, colored by patient id (pid)
210 | :param X_embedded: 2-d array, shape is [N, 2], the low-dimensional points
211 | :param val_y: 1-d array, shape is [N], the labels, where 1 and 0 indicate AF and non-AF
212 | :param pid_test: 1-d array, shape is [N], the id of the patients
213 | :param path: str, path for saving the figure
214 | '''
215 | fig = plt.figure(figsize=(6, 6))
216 | plt.scatter(x=X_embedded[val_y==1][:,0], y=X_embedded[val_y==1][:,1],
217 | s=0.5,alpha=1,c=pid_test[val_y==1], cmap='jet')
218 |
219 | x, y, s = X_embedded[val_y==1][:,0], X_embedded[val_y==1][:,1], pid_test[val_y==1]
220 | for i in tqdm(np.unique(s)):
221 | for xi, yi in zip(x[s==i], y[s==i]):
222 | plt.plot([np.mean(x[s==i]), xi], [np.mean(y[s==i]), yi], 'k', linewidth=0.1, alpha=0.2)
223 | ax = plt.gca()
224 | ax.spines['top'].set_visible(False)
225 | ax.spines['right'].set_visible(False)
226 | ax.spines['bottom'].set_visible(False)
227 | ax.spines['left'].set_visible(False)
228 |
229 | ax.set_xticks([])
230 | ax.set_yticks([])
231 | fig.tight_layout()
232 | if path:
233 | fig.savefig(path, dpi=300)
234 |
--------------------------------------------------------------------------------
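
A hedged sketch of how the R-peak utilities in `lib/visual.py` chain together; the sine-plus-noise signal is a synthetic stand-in for one z-normalized 3000-sample ECG window at 300 Hz:

```python
import numpy as np
from lib.visual import flip_ECG, detectR, get_beats  # assumes the repo root is on sys.path

# synthetic stand-in for a single 3000-sample ECG segment (300 Hz)
t = np.arange(3000)
ecg = np.sin(2 * np.pi * t / 300) + 0.05 * np.random.RandomState(0).randn(3000)

ecg = flip_ECG(ecg)            # flip the signal if its polarity looks inverted
peaks = detectR(ecg)           # indices of the detected R peaks
beats = get_beats(ecg, peaks)  # windows of half a mean RR interval around each peak
# note: windows that extend past the signal boundary come back truncated or empty
print(len(peaks), [len(b) for b in beats])
```
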
/model/weights_tail_best.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydup/Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/e1f4ded660e81985223e49a1095dda2a4f654f40/model/weights_tail_best.hdf5
--------------------------------------------------------------------------------
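
The tail-model weights are not stored in the listing itself; the raw URL above points at the committed HDF5 file. A minimal fetch sketch, assuming it is run from the repository root:

```python
import os
import urllib.request

URL = ('https://raw.githubusercontent.com/ydup/'
       'Anomaly-Detection-in-Time-Series-with-Triadic-Motif-Fields/'
       'e1f4ded660e81985223e49a1095dda2a4f654f40/model/weights_tail_best.hdf5')
os.makedirs('model', exist_ok=True)            # ensure the target directory exists
urllib.request.urlretrieve(URL, 'model/weights_tail_best.hdf5')
```
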
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | mpi4py==3.0.3
3 | numba==0.50.1
4 | scikit-learn==0.23.0
5 | scipy==1.5.2
6 | tensorflow==1.14.0
7 | opencv-python
8 | tqdm
9 | PyQt5
10 | dill # required by the dataset (de)serialization in lib/util.py
--------------------------------------------------------------------------------
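
Note: with tensorflow pinned to 1.14.0, these requirements assume a Python 3.6/3.7 interpreter; TensorFlow 1.14 wheels were not published for later Python versions.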