# ├── CNN ├── run_resnet.py ├── run_vggnet.py └── source │ ├── datamanager.py │ ├── neuralnet_resnet34.py │ ├── neuralnet_vggnet16.py │ └── tf_process.py ├── LICENSE ├── README.md ├── figures ├── normal.png ├── qrs_lead_avr.png ├── qrs_lead_i.png ├── qrs_voted.png ├── resnet34.png ├── stemi.png └── vggnet16.png └── preprocessing.py
# /CNN/run_resnet.py: --------------------------------------------------------------------------------
import argparse

import tensorflow as tf

import source.neuralnet_resnet34 as nn
import source.datamanager as dman
import source.tf_process as tfp

def main(flags=None):
    """Build the dataset and the 1D ResNet-34, train it, then print test-set results.

    Args:
        flags: parsed command-line namespace. Defaults to the module-level ``FLAGS``
            created by the ``__main__`` block, so ``main()`` keeps working unchanged.
    """
    if flags is None:
        flags = FLAGS

    dataset = dman.DataSet(setname=flags.setname, tr_ratio=flags.tr_ratio)

    neuralnet = nn.ConvNet(data_dim=dataset.data_dim, channel=dataset.channel, num_class=dataset.num_class, learning_rate=flags.lr)

    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        tfp.training(sess=sess, neuralnet=neuralnet, saver=saver, dataset=dataset, epochs=flags.epoch, batch_size=flags.batch, dropout=flags.dropout)
        tfp.validation(sess=sess, neuralnet=neuralnet, saver=saver, dataset=dataset)
    finally:
        # Release the session's resources even if training/validation fails.
        sess.close()

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=300, help='Number of epoch for training')
    parser.add_argument('--batch', type=int, default=200, help='Mini-batch size for training')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training')
    # The value is fed to ConvNet.dropout_prob, which is a *keep* probability.
    # NOTE: the ResNet-34 graph never uses dropout_prob, so this flag has no effect here.
    parser.add_argument('--dropout', type=float, default=1, help='Keep probability for dropout during training (1 = no dropout).')
    parser.add_argument('--setname', type=str, default="dataset_BP", help='Name of dataset for use.')
    parser.add_argument('--tr_ratio', type=float, default=0.9, help='Ratio of patient for training to total patient.')

    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        # parse_known_args drops typos such as "--epochs" silently; make that visible.
        print("Ignoring unrecognized arguments: %s" % " ".join(unparsed))

    main()
# -------------------------------------------------------------------------------- /CNN/run_vggnet.py: --------------------------------------------------------------------------------
import argparse

import tensorflow as tf

import source.neuralnet_vggnet16 as nn
import source.datamanager as dman
import source.tf_process as tfp

def main(flags=None):
    """Build the dataset and the 1D VGGNet-16, train it, then print test-set results.

    Args:
        flags: parsed command-line namespace. Defaults to the module-level ``FLAGS``
            created by the ``__main__`` block, so ``main()`` keeps working unchanged.
    """
    if flags is None:
        flags = FLAGS

    dataset = dman.DataSet(setname=flags.setname, tr_ratio=flags.tr_ratio)

    neuralnet = nn.ConvNet(data_dim=dataset.data_dim, channel=dataset.channel, num_class=dataset.num_class, learning_rate=flags.lr)

    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        tfp.training(sess=sess, neuralnet=neuralnet, saver=saver, dataset=dataset, epochs=flags.epoch, batch_size=flags.batch, dropout=flags.dropout)
        tfp.validation(sess=sess, neuralnet=neuralnet, saver=saver, dataset=dataset)
    finally:
        # Release the session's resources even if training/validation fails.
        sess.close()

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=300, help='Number of epoch for training')
    parser.add_argument('--batch', type=int, default=200, help='Mini-batch size for training')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training')
    # The value is fed to ConvNet.dropout_prob, which VGGNet passes to tf.nn.dropout as keep_prob.
    parser.add_argument('--dropout', type=float, default=1, help='Keep probability for dropout during training (1 = no dropout).')
    parser.add_argument('--setname', type=str, default="dataset_BP", help='Name of dataset for use.')
    parser.add_argument('--tr_ratio', type=float, default=0.9, help='Ratio of patient for training to total patient.')

    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        # parse_known_args drops typos such as "--epochs" silently; make that visible.
        print("Ignoring unrecognized arguments: %s" % " ".join(unparsed))

    main()
# -------------------------------------------------------------------------------- /CNN/source/datamanager.py: --------------------------------------------------------------------------------
import os, inspect, glob, random, shutil
import numpy as np

PACK_PATH = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))+"/.."

class DataSet(object):
    """Patient-level train/test split of per-class ``.npy`` recordings, plus batch iteration.

    On construction, the recordings in ``<PACK_PATH>/../<setname>/<class>/`` are split by
    patient into ``<PACK_PATH>/dataset/{train,test}/<class>/`` (see ``split_train_test``).
    The resulting file lists are then indexed for ``next_batch``.

    Attributes:
        class_names: sorted class-folder names found in the training split.
        num_class: ``len(class_names)``.
        npy_train: training file paths (shuffled once); npy_test: test file paths (sorted).
        amount_tr / amount_te: number of training / test files.
        data_dim / channel: ``shape[0]`` / ``shape[1]`` of the first training sample.
        bounds: per-class start indices computed on the *sorted* training list.
    """

    def __init__(self, setname, tr_ratio=0.9):

        self.data_path = os.path.join(PACK_PATH, "..", setname)
        self.tr_ratio = tr_ratio
        self.split_train_test()

        self.data4nn = os.path.join(PACK_PATH, "dataset")
        self.path_classes = sorted(glob.glob(os.path.join(self.data4nn, "train", "*")))
        self.class_names = [os.path.basename(path_class) for path_class in self.path_classes]
        self.num_class = len(self.class_names)

        self.npy_train = sorted(glob.glob(os.path.join(self.data4nn, "train", "*", "*.npy")))
        self.npy_test = sorted(glob.glob(os.path.join(self.data4nn, "test", "*", "*.npy")))
        if not self.npy_train:
            # Fail with a clear message instead of an IndexError further down.
            raise ValueError("No training .npy files found under %s" % os.path.join(self.data4nn, "train"))

        # Must run on the sorted list, before the shuffle below.
        self._compute_bounds()

        random.shuffle(self.npy_train)

        self.idx_tr = 0
        self.idx_te = 0

        self.amount_tr = len(self.npy_train)
        self.amount_te = len(self.npy_test)
        print("Training: %d" %(self.amount_tr))
        print("Test: %d" %(self.amount_te))

        sample = np.load(self.npy_train[0])
        self.data_dim = sample.shape[0]
        self.channel = sample.shape[1]

    def _compute_bounds(self):
        """Fill ``self.bounds`` with class boundary indices of the sorted ``npy_train``.

        NOTE(review): ``npy_train`` is shuffled right after this runs, so ``bounds`` no longer
        lines up with it. Nothing in the visible training/validation code reads ``bounds``.
        The logic is kept unchanged.
        """
        self.boundtxt = self.npy_train[0].split("/")[-2]
        for cidx, clsname in enumerate(self.class_names):
            if(clsname in self.boundtxt):
                bndidx = cidx
                break
        self.bounds = [0]
        for _ in range(bndidx):
            self.bounds.append(0)
        for idx, _ in enumerate(self.npy_train):
            if(self.boundtxt in self.npy_train[idx]):
                continue
            # idx is the first file of a new class: find the previous and the new class index.
            for cidx, clsname in enumerate(self.class_names):
                if(clsname in self.npy_train[idx-1]):
                    idx_pri = cidx
                if(clsname in self.npy_train[idx]):
                    idx_pos = cidx

            if(bndidx == idx_pri):
                self.bounds.append(idx)
                for _ in range(abs(idx_pri-idx_pos)-1):
                    self.bounds.append(idx)
            else:
                for _ in range(abs(idx_pri-idx_pos)-1):
                    self.bounds.append(idx)
                self.bounds.append(idx)
            self.boundtxt = self.npy_train[idx].split("/")[-2]
            bndidx = idx_pos
        for _ in range(abs(self.num_class-bndidx)-1):
            self.bounds.append(idx)

    def makedir(self, path):
        """Create ``path`` (and any missing parents); do nothing if it already exists."""
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def _patient_id(path):
        """Return the patient id encoded in a file name of the form ``<patient>,<rest>``."""
        return os.path.basename(path).split(",")[0]

    def split_train_test(self):
        """Split every class's ``.npy`` files into train/test sets by patient.

        For each class folder under ``self.data_path``, the distinct patient ids are
        shuffled. The first ``int(n_patients * tr_ratio)`` patients go to train and the
        rest go to test, so no patient appears in both sets. The output tree
        ``<PACK_PATH>/dataset/{train,test}/<class>/`` is deleted and rebuilt on every call.
        That is the same location ``__init__`` reads from, whatever the working directory.
        """
        out_root = os.path.join(PACK_PATH, "dataset")
        # Start from a clean tree so files from a previous split cannot leak in.
        shutil.rmtree(out_root, ignore_errors=True)
        for split in ("train", "test"):
            self.makedir(os.path.join(out_root, split))

        # Only directories are classes; stray files in the set folder are ignored.
        class_dirs = sorted(d for d in glob.glob(os.path.join(self.data_path, "*")) if os.path.isdir(d))
        for cdir in class_dirs:
            clsname = os.path.basename(cdir)
            npylist = sorted(glob.glob(os.path.join(cdir, "*.npy")))

            # Distinct patient ids in first-seen order (plain dicts are unordered on Python 3.5).
            patients, seen = [], set()
            for npyname in npylist:
                pid = self._patient_id(npyname)
                if pid not in seen:
                    seen.add(pid)
                    patients.append(pid)

            random.shuffle(patients)
            numtr = int(len(patients) * self.tr_ratio)
            test_patients = set(patients[numtr:])

            for split in ("train", "test"):
                self.makedir(os.path.join(out_root, split, clsname))

            for npyname in npylist:
                # Compare ids exactly: a substring test on the path would also match e.g.
                # patient "12" (or any directory containing "1") when looking for patient "1".
                split = "test" if self._patient_id(npyname) in test_patients else "train"
                shutil.copyfile(npyname, os.path.join(out_root, split, clsname, os.path.basename(npyname)))

    def split_cls(self, pathlist, bound_start, bound_end):
        """Return ``pathlist[bound_start:bound_end]``, or ``[]`` if it cannot be sliced."""
        try:
            return pathlist[bound_start:bound_end]
        except (TypeError, IndexError):
            return []

    def _label_vector(self, path):
        """One-hot label from the sample's class folder name (all zeros if the class is unknown)."""
        label_vector = np.zeros(self.num_class)
        clsname = os.path.basename(os.path.dirname(path))
        if clsname in self.class_names:
            label_vector[self.class_names.index(clsname)] = 1
        return label_vector

    def next_batch(self, batch_size=1, train=False):
        """Return the next batch from the training or test split.

        train=True: returns ``(data, label)`` with ``batch_size`` (at least 1) consecutive
            training samples stacked on axis 0, wrapping around the shuffled list.
        train=False: ``batch_size`` is ignored. Returns one test sample per call as
            ``(data, label, path)`` with a leading batch axis of 1. Once every test sample
            has been returned, returns ``(None, None, None)`` and restarts from the first.
        """
        if(train):
            samples, labels = [], []
            while(True):
                path = self.npy_train[self.idx_tr]
                samples.append(np.load(path))
                labels.append(self._label_vector(path))
                # Always advance, so the sample that fills this batch is not repeated
                # as the first sample of the next batch.
                self.idx_tr = (self.idx_tr + 1) % self.amount_tr
                if(len(samples) >= batch_size): break
            return np.stack(samples).astype(np.float64), np.stack(labels)

        if(self.idx_te >= self.amount_te):
            self.idx_te = 0
            return None, None, None

        tmppath = self.npy_test[self.idx_te]
        data = np.expand_dims(np.load(tmppath), axis=0).astype(np.float64)
        label = np.expand_dims(self._label_vector(tmppath), axis=0)

        self.idx_te += 1

        return data, label, tmppath
# -------------------------------------------------------------------------------- /CNN/source/neuralnet_resnet34.py: --------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

class ConvNet(object):
    """1D ResNet-34-style classifier built with TF1 graph-mode layers.

    Placeholders: ``inputs`` [None, data_dim, channel], ``labels`` [None, num_class]
    (one-hot), ``dropout_prob`` (scalar; declared for API parity with the VGG model but
    not used by this graph). Exposes ``score`` (softmax), ``loss``, ``optimizer`` (Adam),
    ``pred``, ``accuracy`` and merged ``summaries``.
    """

    def __init__(self, data_dim, channel, num_class, learning_rate):

        print("\n** Initialize CNN Layers")
        self.num_class = num_class
        self.inputs = tf.placeholder(tf.float32, [None, data_dim, channel])
        self.labels = tf.placeholder(tf.float32, [None, self.num_class])
        self.dropout_prob = tf.placeholder(tf.float32, shape=[])
        print("Input: "+str(self.inputs.shape))

        fc = self.convnet_module()
        self.score = tf.nn.softmax(fc)

        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc, labels=self.labels))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

        self.pred = tf.argmax(self.score, 1)
        self.correct = tf.equal(self.pred, tf.argmax(self.labels, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct, tf.float32))

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
        self.summaries = tf.summary.merge_all()

    def _residual_block(self, inputs, filters, project=False):
        """Two width-3 convolutions plus a skip connection (1x1 projection if ``project``)."""
        conv_a = self.convolution(inputs=inputs, filters=filters, k_size=3, stride=1, padding="SAME")
        conv_b = self.convolution(inputs=conv_a, filters=filters, k_size=3, stride=1, padding="SAME")
        # The projection is created after the two main convolutions. This keeps the
        # auto-generated layer names (and therefore checkpoint keys) in their established order.
        if project:
            shortcut = self.convolution(inputs=inputs, filters=filters, k_size=1, stride=1, padding="SAME")
        else:
            shortcut = inputs
        return shortcut + conv_b

    def convnet_module(self):
        """Stem conv + 16 residual blocks in 4 stages, average pooling, linear classifier."""
        conv1_1 = self.convolution(inputs=self.inputs, filters=64, k_size=7, stride=1, padding="SAME")
        net = self.maxpool(inputs=conv1_1, pool_size=2)

        # (filters, residual blocks) per stage. Stages after the first start with 2x
        # max-pooling, and their first block uses a projection shortcut (width changes).
        for stage, (filters, num_blocks) in enumerate(((64, 3), (128, 4), (256, 6), (512, 3))):
            if stage > 0:
                net = self.maxpool(inputs=net, pool_size=2)
            for bidx in range(num_blocks):
                net = self._residual_block(net, filters, project=(stage > 0 and bidx == 0))

        pool5 = self.avgpool(inputs=net, pool_size=2)

        flat = self.flatten(inputs=pool5)

        fc1 = self.fully_connected(inputs=flat, num_outputs=self.num_class, activate_fn=None)

        return fc1

    def convolution(self, inputs=None, filters=32, k_size=3, stride=1, padding="SAME"):
        """ReLU-activated 1D convolution (He-normal init); ``stride`` is honoured."""
        conv = tf.layers.conv1d(inputs=inputs, filters=filters, kernel_size=k_size, strides=stride,
            padding=padding, data_format='channels_last', dilation_rate=1,
            activation=tf.nn.relu, use_bias=True,
            kernel_initializer=tf.contrib.keras.initializers.he_normal(), bias_initializer=tf.contrib.keras.initializers.he_normal(),
            kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, trainable=True, name=None, reuse=None)

        print("Convolution: "+str(conv.shape))
        return conv

    def relu(self, inputs=None):
        """Element-wise ReLU."""
        re = tf.nn.relu(features=inputs, name=None)

        print("ReLU: "+str(re.shape))
        return re

    def maxpool(self, inputs=None, pool_size=2):
        """1D max-pooling with stride equal to ``pool_size``."""
        maxp = tf.layers.max_pooling1d(inputs=inputs, pool_size=pool_size, strides=pool_size, padding='SAME', data_format='channels_last', name=None)

        print("Max Pool: "+str(maxp.shape))
        return maxp

    def avgpool(self, inputs=None, pool_size=2):
        """1D average-pooling with stride equal to ``pool_size``."""
        avgp = tf.layers.average_pooling1d(inputs=inputs, pool_size=pool_size, strides=pool_size, padding='SAME', data_format='channels_last', name=None)

        print("Avg Pool: "+str(avgp.shape))
        return avgp

    def flatten(self, inputs=None):
        """Flatten all non-batch dimensions."""
        flat = tf.contrib.layers.flatten(inputs=inputs)

        print("Flatten: "+str(flat.shape))
        return flat

    def fully_connected(self, inputs=None, num_outputs=None, activate_fn=None):
        """Dense layer with He-normal weight and bias initialisation."""
        full_con = tf.contrib.layers.fully_connected(inputs=inputs, num_outputs=num_outputs,
            activation_fn=activate_fn, normalizer_fn=None, normalizer_params=None,
            weights_initializer=tf.contrib.keras.initializers.he_normal(), weights_regularizer=None,
            biases_initializer=tf.contrib.keras.initializers.he_normal(), biases_regularizer=None, reuse=None,
            variables_collections=None, outputs_collections=None, trainable=True, scope=None)

        print("Fully Connected: "+str(full_con.shape))
        return full_con
# -------------------------------------------------------------------------------- /CNN/source/neuralnet_vggnet16.py: --------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

class ConvNet(object):
    """1D VGGNet-16-style classifier built with TF1 graph-mode layers.

    Placeholders: ``inputs`` [None, data_dim, channel], ``labels`` [None, num_class]
    (one-hot), ``dropout_prob`` (scalar keep probability for the two dropout layers).
    Exposes ``score`` (softmax), ``loss``, ``optimizer`` (Adam), ``pred``, ``accuracy``
    and merged ``summaries``.
    """

    def __init__(self, data_dim, channel, num_class, learning_rate):

        print("\n** Initialize CNN Layers")
        self.num_class = num_class
        self.inputs = tf.placeholder(tf.float32, [None, data_dim, channel])
        self.labels = tf.placeholder(tf.float32, [None, self.num_class])
        self.dropout_prob = tf.placeholder(tf.float32, shape=[])
        print("Input: "+str(self.inputs.shape))

        fc = self.convnet_module()
        self.score = tf.nn.softmax(fc)

        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc, labels=self.labels))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

        self.pred = tf.argmax(self.score, 1)
        self.correct = tf.equal(self.pred, tf.argmax(self.labels, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct, tf.float32))

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
        self.summaries = tf.summary.merge_all()

    def convnet_module(self):
        """13 convolutions in 5 max-pooled stages, then two 4096-unit layers with dropout and a linear classifier."""
        net = self.inputs
        # (filters, convolutions) per stage; every stage ends with 2x max-pooling.
        # Layers are created in the same order as an unrolled definition, so checkpoint keys match.
        for filters, num_convs in ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)):
            for _ in range(num_convs):
                net = self.convolution(inputs=net, filters=filters, k_size=3, stride=1, padding="SAME")
            net = self.maxpool(inputs=net, pool_size=2)

        flat = self.flatten(inputs=net)

        fc1 = self.fully_connected(inputs=flat, num_outputs=4096, activate_fn=tf.nn.relu)
        drop1 = tf.nn.dropout(fc1, keep_prob=self.dropout_prob)
        fc2 = self.fully_connected(inputs=drop1, num_outputs=4096, activate_fn=tf.nn.relu)
        drop2 = tf.nn.dropout(fc2, keep_prob=self.dropout_prob)
        fc3 = self.fully_connected(inputs=drop2, num_outputs=self.num_class, activate_fn=None)

        return fc3

    def convolution(self, inputs=None, filters=32, k_size=3, stride=1, padding="SAME"):
        """ReLU-activated 1D convolution (He-normal init); ``stride`` is honoured."""
        conv = tf.layers.conv1d(inputs=inputs, filters=filters, kernel_size=k_size, strides=stride,
            padding=padding, data_format='channels_last', dilation_rate=1,
            activation=tf.nn.relu, use_bias=True,
            kernel_initializer=tf.contrib.keras.initializers.he_normal(), bias_initializer=tf.contrib.keras.initializers.he_normal(),
            kernel_regularizer=None, bias_regularizer=None,
            activity_regularizer=None, trainable=True, name=None, reuse=None)

        print("Convolution: "+str(conv.shape))
        return conv

    def relu(self, inputs=None):
        """Element-wise ReLU."""
        re = tf.nn.relu(features=inputs, name=None)

        print("ReLU: "+str(re.shape))
        return re

    def maxpool(self, inputs=None, pool_size=2):
        """1D max-pooling with stride equal to ``pool_size``."""
        maxp = tf.layers.max_pooling1d(inputs=inputs, pool_size=pool_size, strides=pool_size, padding='SAME', data_format='channels_last', name=None)

        print("Max Pool: "+str(maxp.shape))
        return maxp

    def flatten(self, inputs=None):
        """Flatten all non-batch dimensions."""
        flat = tf.contrib.layers.flatten(inputs=inputs)

        print("Flatten: "+str(flat.shape))
        return flat

    def fully_connected(self, inputs=None, num_outputs=None, activate_fn=None):
        """Dense layer with He-normal weight and bias initialisation."""
        full_con = tf.contrib.layers.fully_connected(inputs=inputs, num_outputs=num_outputs,
            activation_fn=activate_fn, normalizer_fn=None, normalizer_params=None,
            weights_initializer=tf.contrib.keras.initializers.he_normal(), weights_regularizer=None,
            biases_initializer=tf.contrib.keras.initializers.he_normal(), biases_regularizer=None, reuse=None,
            variables_collections=None, outputs_collections=None, trainable=True, scope=None)

        print("Fully Connected: "+str(full_con.shape))
        return full_con
# -------------------------------------------------------------------------------- /CNN/source/tf_process.py: --------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')
import os, inspect, time
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

PACK_PATH = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))+"/.."

def _save_curve(data, name, ylabel):
    """Save a per-iteration metric to ``<name>.npy`` and plot it to ``<name>.png`` (both in the cwd)."""
    np.save(name, np.asarray(data))
    plt.clf()
    plt.rcParams['font.size'] = 15
    plt.plot(data)
    plt.ylabel(ylabel)
    plt.xlabel("Iteration")
    plt.tight_layout(pad=1, w_pad=1, h_pad=1)
    plt.savefig("%s.png" % name)
    plt.close()

def loss_record(data):
    """Save the training-loss history to loss.npy and loss.png."""
    _save_curve(data, "loss", "Cross-Entropy loss")

def acc_record(data):
    """Save the training-accuracy history to accuracy.npy and accuracy.png."""
    _save_curve(data, "accuracy", "Accuracy")

def training(sess, neuralnet, saver, dataset, epochs, batch_size, dropout, print_step=10):
    """Train ``neuralnet`` on mini-batches from ``dataset`` and checkpoint it.

    Runs ``int(amount_tr / batch_size * epochs)`` optimizer steps with
    ``dropout_prob=dropout``. After each step, loss/accuracy are re-evaluated on the same
    batch with ``dropout_prob=1`` and logged to TensorBoard. At every epoch boundary a
    checkpoint is written to ``<PACK_PATH>/Checkpoint/model_checker``, and a progress line
    is printed every ``print_step`` epochs. The final weights are saved after the last step.
    """
    print("\n* Training to %d epochs (%d of minibatch size)" %(epochs, batch_size))

    ckpt_dir = PACK_PATH+'/Checkpoint'
    train_writer = tf.summary.FileWriter(ckpt_dir)

    iters_per_epoch = dataset.amount_tr / batch_size
    epoch = 0
    start_time = time.time()
    list_loss, list_acc = [], []
    # Defined up front so the final report works even if no step runs.
    tmp_loss, tmp_acc = float('nan'), float('nan')
    try:
        for it in range(int(iters_per_epoch * epochs)):
            X_tr, Y_tr = dataset.next_batch(batch_size=batch_size, train=True)
            sess.run(neuralnet.optimizer, feed_dict={neuralnet.inputs:X_tr, neuralnet.labels:Y_tr, neuralnet.dropout_prob:dropout})
            # One evaluation run (dropout off) yields loss, accuracy and the summaries together.
            tmp_loss, tmp_acc, summaries = sess.run([neuralnet.loss, neuralnet.accuracy, neuralnet.summaries], feed_dict={neuralnet.inputs:X_tr, neuralnet.labels:Y_tr, neuralnet.dropout_prob:1})
            train_writer.add_summary(summaries, it)

            list_loss.append(tmp_loss)
            list_acc.append(tmp_acc)

            if(it > iters_per_epoch*epoch):
                if(epoch % print_step == 0): print("Epoch [%d / %d] \nLoss: %.5f \tAccuracy: %.5f" %(epoch, epochs, tmp_loss, tmp_acc))
                saver.save(sess, ckpt_dir+"/model_checker")
                epoch += 1

        # validation() restores this checkpoint. Saving only at epoch boundaries would
        # drop the updates made since the last boundary.
        if list_loss:
            saver.save(sess, ckpt_dir+"/model_checker")
    finally:
        train_writer.close()

    print("Final Epoch \nLoss: %.5f \tAccuracy: %.5f" %(tmp_loss, tmp_acc))
    elapsed_time = time.time() - start_time
    print("Elapsed: "+str(elapsed_time))

    loss_record(data=list_loss)
    acc_record(data=list_acc)

def validation(sess, neuralnet, saver, dataset):
    """Evaluate on the test split and print a confusion matrix.

    Restores ``<PACK_PATH>/Checkpoint/model_checker`` if it exists. Rows are the true
    class (from the one-hot label) and columns the predicted class.
    """
    ckpt = PACK_PATH+"/Checkpoint/model_checker"
    if(os.path.exists(ckpt+".index")):
        saver.restore(sess, ckpt)

    confmat = np.zeros((dataset.num_class, dataset.num_class))
    while(True):
        X_te, Y_te, path_te = dataset.next_batch(batch_size=1)
        if(X_te is None): break
        preds = sess.run(neuralnet.pred, feed_dict={neuralnet.inputs:X_te, neuralnet.labels:Y_te, neuralnet.dropout_prob:1})
        # The data manager returns an all-zero label when the sample's class is not among
        # the training classes; skip it rather than count it as class 0.
        if(not Y_te[0].any()): continue
        idx_y, idx_p = np.argmax(Y_te, axis=1)[0], preds[0]
        confmat[idx_y][idx_p] += 1

    print("Confusion Matrix")
    print(confmat)
# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 YeongHyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Preprocessing Method for Performance Enhancement in CNN-based STEMI Detection from 12-lead ECG 2 | ===== 3 | 4 | This repository provides the source code of the paper "Preprocessing Method for Performance Enhancement in CNN-based STEMI Detection from 12-lead ECG" [pdf]. 5 | 6 | ## STEMI: ST-elevation myocardial infarction 7 | STEMI is caused by the complete occlusion of a coronary artery and is characterized by a sudden shutdown of blood flow due to a thrombus or embolism. In the ECG of a STEMI patient, the ST segment is more elevated than in a normal ECG. 8 | 9 |
10 | 11 | 12 |

ECG of a normal subject (upper) and of a STEMI patient (lower)

13 |
14 | 15 | ## Requirements 16 | * Python 3.5.2 17 | * TensorFlow 1.4.0 18 | * NumPy 1.13.3 19 | * SciPy 1.2.0 20 | * WFDB 2.2.1 21 | * Matplotlib 3.0.2 22 | 23 | 24 | ## Usage 25 | ### Preparing the dataset 26 | First, organize the dataset as shown below, storing each recording as a `.npy` file. Keep the channel (lead) information on axis 0 and the time information on axis 1. For example, 12-lead data with 5500 samples (500 Hz × 11 s) should be saved with shape `(12, 5500)`. 27 | ``` 28 | Raw_ECG 29 | ├── Normal 30 | │ ├── data_1 31 | │ ├── data_2 32 | │ ├── data_3 33 | │ │ ... 34 | │ └── data_n 35 | └── STEMI 36 | ``` 37 | Then, run the Python script as follows. 38 | ``` 39 | $ python preprocessing.py 40 | $ python preprocessing.py --help # show the available options 41 | ``` 42 | 43 | While the script runs, each recording is segmented into pulses by voting on and choosing the location of the QRS complex, as shown below. 44 | 45 |
46 | 47 | 48 | 49 |

The top figure shows the location-voting result used to find the QRS complex. The lower two figures show the most-voted time location, taken as the QRS complex, in lead I and lead aVR, respectively.

50 |
51 | 52 | ### Training and Test 53 | ``` 54 | $ cd CNN 55 | $ python run_resnet.py 56 | ``` 57 | To train VGGNet instead, run `run_vggnet.py` in place of `run_resnet.py`. 58 | 59 |
60 | 61 |

1D-VGGNet-16

62 | 63 |

1D-ResNet-34

64 |
65 | 66 | ### Pre-Trained CNN 67 | Pre-trained models are available on Google Drive, with saved parameters for both 1D-VGGNet-16 and 1D-ResNet-34. 68 | 69 | Available since Mar. 31, 2019 70 | 71 | ### BibTeX 72 | ``` 73 | @Article{8771175, 74 | author={Park, YeongHyeon and Yun, Il Dong and Kang, Si-Hyuck}, 75 | journal={IEEE Access}, 76 | title={Preprocessing Method for Performance Enhancement in CNN-Based STEMI Detection From 12-Lead ECG}, 77 | year={2019}, 78 | volume={7}, 79 | pages={99964-99977}, 80 | ISSN={2169-3536}, 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /figures/normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/normal.png -------------------------------------------------------------------------------- /figures/qrs_lead_avr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/qrs_lead_avr.png -------------------------------------------------------------------------------- /figures/qrs_lead_i.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/qrs_lead_i.png -------------------------------------------------------------------------------- /figures/qrs_voted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/qrs_voted.png 
-------------------------------------------------------------------------------- /figures/resnet34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/resnet34.png -------------------------------------------------------------------------------- /figures/stemi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/stemi.png -------------------------------------------------------------------------------- /figures/vggnet16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeongHyeon/Preprocessing-Method-for-STEMI-Detection/dc19b72e9f45067d74d21da5db562ca642e10ae9/figures/vggnet16.png -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os, inspect, glob, random, scipy, peakutils, argparse 2 | from scipy import signal 3 | from wfdb import processing 4 | 5 | import numpy as np 6 | 7 | PACK_PATH = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))+"/.." 
class Preprocess(object):
    """Filter raw multi-lead ECG records and optionally slice them into beats.

    Records are read from ``<rawpath>/<class>/*.npy`` and written to
    ``dataset_<setname>/<class>/``.  Each record is cropped to columns
    250..5249; every row is then treated as one lead.

    Letters in ``setname`` select the filtering:

    * ``N`` - notch filter at every frequency in ``rmfreqs``;
    * ``H`` - 1 Hz Butterworth high-pass;
    * ``B`` - notch followed by high-pass;
    * anything else (e.g. ``R``) - no filtering.

    If ``setname`` contains ``N``/``H``/``B``/``R`` without the matching
    ``NP``/``HP``/``BP``/``RP``, the whole filtered record is saved.
    Otherwise peak voting runs and one resampled, range-normalised segment
    per detected beat is saved.
    """

    def __init__(self, rawpath, setname="BP", fs=500, nfft=4096, rmfreqs=None, outdim=600):
        """Store the configuration.

        Args:
            rawpath: directory containing one sub-directory per class.
            setname: pipeline selector (see class docstring).
            fs: sampling rate in Hz.
            nfft: FFT length used by ``fast_fourier_transform``.
            rmfreqs: notch frequencies in Hz; defaults to [60, 120, 180, 240].
            outdim: number of samples each saved beat is resampled to.
        """
        self.rawpath = rawpath
        self.setname = setname
        self.fs = fs
        self.nfft = nfft
        # None instead of a list default so instances never share one list.
        self.rmfreqs = [60, 120, 180, 240] if rmfreqs is None else list(rmfreqs)
        self.outdim = outdim

    def make(self):
        """Process every ``*.npy`` record under ``rawpath`` and save the results."""
        outroot = "dataset_%s" % (self.setname)
        self.makedir(path=outroot)

        subclasses = sorted(
            path for path in glob.glob(os.path.join(self.rawpath, "*")) if os.path.isdir(path)
        )
        for subclass in subclasses:
            subname = os.path.basename(subclass)
            print("\n%s" % (subname))
            outdir = os.path.join(outroot, subname)
            self.makedir(path=outdir)

            for npy in sorted(glob.glob(os.path.join(subclass, "*.npy"))):
                npyname = os.path.splitext(os.path.basename(npy))[0]
                print(npyname)
                # Fixed crop applied to every record: keep columns 250..5249.
                data = np.load(npy)[:, 250:5250]
                x_total = self._filter_leads(data)

                if self._is_whole_record():
                    np.save(os.path.join(outdir, "%s" % (npyname)), x_total)
                    continue

                indices, interval = self._vote_peaks(x_total)
                self._save_beats(x_total, indices, interval, outdir, npyname)

    def _is_whole_record(self):
        """Return True when setname picks a filter (N/H/B/R) without its ``P`` variant."""
        return any(tag in self.setname and tag + "P" not in self.setname for tag in "NHBR")

    def _filter_leads(self, data):
        """Return a copy of ``data`` with the setname-selected filter applied to each row."""
        # An integer buffer would silently truncate the filters' float output.
        dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
        x_total = np.zeros(data.shape, dtype=dtype)
        for didx, dat in enumerate(data):
            if "N" in self.setname:
                x_total[didx] = self.notchfilter(sig=dat, fs=self.fs, freqs=self.rmfreqs, Q=1)
            elif "H" in self.setname:
                x_total[didx] = self.highpassfilter(data=dat, cutoff=1, fs=self.fs)
            elif "B" in self.setname:
                x_notch = self.notchfilter(sig=dat, fs=self.fs, freqs=self.rmfreqs, Q=1)
                x_total[didx] = self.highpassfilter(data=x_notch, cutoff=1, fs=self.fs)
            else:
                x_total[didx] = dat
        return x_total

    def _vote_peaks(self, x_total):
        """Let every lead and its negation vote for peak positions; return filtered peaks.

        Returns:
            ``(indices, interval)`` as produced by ``peak_filtering``.
        """
        maxlen = x_total.shape[1]
        x_vote = np.zeros((maxlen))
        for lead in x_total:
            for sig in (lead, -lead):
                for idx in self.peak_selection(signal=sig, fs=self.fs, threshold=0.8):
                    # Clamp the start: a negative start would yield an empty slice,
                    # so peaks in the first 10 samples would never be counted.
                    x_vote[max(idx - 10, 0):idx + 10] += 1
        x_vote /= 10
        indices = self.peak_selection(signal=x_vote, fs=self.fs, threshold=0.5)
        return self.peak_filtering(indices=indices, fs=self.fs, maxlen=maxlen)

    def _save_beats(self, x_total, indices, interval, outdir, npyname):
        """Cut ``interval`` samples centred on each peak, resample, normalise and save."""
        nsamples = x_total.shape[1]
        term = int(interval / 2)
        for i, pidx in enumerate(indices):
            sp = max(pidx - term, 0)
            # Slice ends are exclusive, so ``nsamples`` itself is a valid end.
            ep = min(pidx + term, nsamples)
            if (ep - sp) < interval * 0.9:
                continue  # window clipped by a record edge
            rows = x_total[:, sp:ep].T.astype(np.float64)  # (samples, leads)
            rows = self.linear_interpolation(data=rows, outdim=self.outdim)
            rows = self.range_regularization(data=rows)
            np.save(os.path.join(outdir, "%s_%d" % (npyname, i)), rows)

    def makedir(self, path):
        """Create ``path`` (and missing parents); an existing directory is not an error."""
        os.makedirs(path, exist_ok=True)

    def butter_highpass(self, cutoff, fs, order=5):
        """Design a digital Butterworth high-pass filter and return ``(b, a)``.

        ``cutoff`` is in Hz and is normalised by the Nyquist rate ``fs / 2``.
        """
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        return b, a

    def highpassfilter(self, data, cutoff, fs, order=5):
        """High-pass ``data`` with a Butterworth filter applied by ``filtfilt``."""
        b, a = self.butter_highpass(cutoff, fs, order=order)
        return signal.filtfilt(b, a, data)

    def notchfilter(self, sig, fs=500, freqs=(60,), Q=1):
        """Apply one IIR notch per frequency in ``freqs`` (Hz), one after another.

        Each notch is designed by ``scipy.signal.iirnotch`` at ``f0 / (fs / 2)``
        and applied with ``lfilter``.
        """
        for f0 in freqs:
            w0 = f0 / (fs / 2)  # normalised frequency (1.0 == Nyquist)
            b, a = signal.iirnotch(w0, Q)
            sig = signal.lfilter(b=b, a=a, x=sig)
        return sig

    def fast_fourier_transform(self, sig, fs=500, nfft=4096):
        """Return ``(|FFT(sig, nfft)|, tick_bins, tick_hz)``.

        ``tick_bins``/``tick_hz`` list the first bin of the positive half whose
        truncated frequency equals each successive multiple of 50 Hz.
        """
        sig_ft = np.fft.fft(sig, n=nfft)

        fftn = int(nfft / 2)
        fstep = (fs / 2) / fftn

        x_freq, x_freq_val = [], []
        cnt = 0
        for i in range(fftn):
            if int(i * fstep) == int(50 * cnt):
                cnt += 1
                x_freq.append(i)
                x_freq_val.append(int(i * fstep))

        return abs(sig_ft), x_freq, x_freq_val

    def magnitude2dB(self, mag):
        """Return ``10 * log10(mag + 1e-9)`` with the first (DC) bin zeroed.

        The input array is not modified.
        NOTE(review): amplitude spectra are usually converted with 20*log10;
        the factor 10 is kept unchanged - confirm which is intended.
        """
        mag = np.array(mag, dtype=float)  # copy so the caller's array is untouched
        mag[0] = 0  # remove DC term
        return 10 * np.log10(mag + 1e-9)

    def peak_selection(self, signal, fs=500, threshold=0.2):
        """Return peak indices, relaxing ``threshold`` until enough peaks are found.

        Asks ``peakutils.indexes`` (min_dist=100) for at least
        ``int(len(signal) / fs - 1)`` peaks, shrinking the relative threshold by
        5% per attempt.  Stops relaxing below 1e-6 so a flat signal cannot hang.
        """
        # The parameter name shadows the ``scipy.signal`` module inside this method.
        required = int((signal.shape[0] / fs) - 1)
        indices = peakutils.indexes(signal, thres=threshold, min_dist=100)
        while len(indices) < required and threshold > 1e-6:
            threshold *= 0.95
            indices = peakutils.indexes(signal, thres=threshold, min_dist=100)
        return indices

    def peak_filtering(self, indices, fs=500, maxlen=5500):
        """Estimate the beat interval and drop unreliable peaks.

        The interval is the largest gap between consecutive peaks that is
        shorter than ``fs * limit`` samples; ``limit`` starts at 1.0 and grows
        by 0.1 until some gap qualifies.  Peaks closer than ``0.7 * interval``
        to a neighbour are thinned, then peaks within ``interval`` samples of
        either end of the record (length ``maxlen``) are removed.

        Returns:
            ``(kept_indices_list, interval)``; ``([], 0)`` when fewer than two
            distinct peaks are given.

        Raises:
            ValueError: if ``fs`` is not positive.
        """
        if fs <= 0:
            raise ValueError("fs must be positive")
        indices = list(indices)
        gaps = [abs(b - a) for a, b in zip(indices[:-1], indices[1:])]
        if not any(gap > 0 for gap in gaps):
            return [], 0

        limit = 1.0
        while True:
            candidates = [gap for gap in gaps if 0 < gap < (fs * limit)]
            if candidates:
                interval = max(candidates)
                break
            limit += 0.1

        # Walk from the last peak backwards.  When a pair is too close, drop the
        # first of the pair if its gap is smaller than the following gap (or no
        # following gap exists); otherwise drop the second.
        peaks = indices[::-1]
        i = 0
        while i + 1 < len(peaks):
            gap = abs(peaks[i] - peaks[i + 1])
            if gap >= interval * 0.7:
                i += 1
            elif i + 2 >= len(peaks) or gap < abs(peaks[i + 1] - peaks[i + 2]):
                peaks.pop(i)
            else:
                peaks.pop(i + 1)
        peaks.reverse()

        kept = [idx for idx in peaks if idx - interval >= 0 and idx + interval <= maxlen]
        return kept, interval

    def range_regularization(self, data):
        """Min-max scale ``data`` into [0, 1] over the whole array (all columns together).

        Returns a new float array.  A constant input yields all zeros instead of
        a division by zero.
        """
        shifted = np.asarray(data, dtype=float) - np.min(data)
        peak = np.max(shifted)
        return shifted / peak if peak > 0 else shifted

    def linear_interpolation(self, data, outdim):
        """Resample each column of ``data`` to ``outdim`` rows by linear interpolation.

        The first and last input samples map onto the first and last output rows;
        the rows in between are spaced evenly (``numpy.interp`` per column).

        Args:
            data: 2-D array whose rows are samples and columns are channels.
            outdim: number of output rows.

        Returns:
            float array of shape ``(outdim, data.shape[1])`` (all zeros when
            ``data`` has no rows).
        """
        data = np.asarray(data, dtype=float)
        nsamples, nchannels = data.shape
        outdata = np.zeros((outdim, nchannels))
        if nsamples == 0:
            return outdata
        src = np.arange(nsamples)
        dst = np.linspace(0, nsamples - 1, outdim)
        for ch in range(nchannels):
            outdata[:, ch] = np.interp(dst, src, data[:, ch])
        return outdata


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--rawpath', type=str, default="Raw_ECG", help='Path of rawdata.')
    parser.add_argument('--set', type=str, default="BP", help='Kind of dataset.')
    parser.add_argument('--fs', type=int, default=500, help='Sampling rate of raw data.')
    parser.add_argument('--nfft', type=int, default=4096, help='FFT point for Fourier transform.')
    parser.add_argument('--rmfreq', type=int, default=60, help='Frequency for Notch filter.')
    parser.add_argument('--outdim', type=int, default=600, help='Dimension of output.')

    FLAGS, unparsed = parser.parse_known_args()

    # Notch the given frequency and its next three multiples.
    rmfreqs = [FLAGS.rmfreq * (i + 1) for i in range(4)]

    process = Preprocess(rawpath=FLAGS.rawpath, setname=FLAGS.set.upper(), fs=FLAGS.fs, nfft=FLAGS.nfft, rmfreqs=rmfreqs, outdim=FLAGS.outdim)
    process.make()