├── README.md ├── RnnAttention └── attention.py ├── cnn_class.py ├── cram.py └── data_proc.py /README.md: -------------------------------------------------------------------------------- 1 | # CRAM 2 | 3 | [A Convolutional Recurrent Attention Model for Subject-Independent EEG Signal Analysis](https://ieeexplore.ieee.org/document/8675451) 4 | IEEE Signal Processing Letters (Volume: 26 , Issue: 5 , May 2019, pp715-719) 5 | -------------------------------------------------------------------------------- /RnnAttention/attention.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def attention(inputs, attention_size, time_major=False, return_alphas=False, train_phase=True): 5 | """ 6 | Attention mechanism layer which reduces RNN/Bi-RNN outputs with Attention vector. 7 | 8 | Args: 9 | inputs: The Attention inputs. 10 | Matches outputs of RNN/Bi-RNN layer (not final state): 11 | In case of RNN, this must be RNN outputs `Tensor`: 12 | If time_major == False (default), this must be a tensor of shape: 13 | `[batch_size, max_time, cell.output_size]`. 14 | If time_major == True, this must be a tensor of shape: 15 | `[max_time, batch_size, cell.output_size]`. 16 | In case of Bidirectional RNN, this must be a tuple (outputs_fw, outputs_bw) containing the forward and 17 | the backward RNN outputs `Tensor`. 18 | If time_major == False (default), 19 | outputs_fw is a `Tensor` shaped: 20 | `[batch_size, max_time, cell_fw.output_size]` 21 | and outputs_bw is a `Tensor` shaped: 22 | `[batch_size, max_time, cell_bw.output_size]`. 23 | If time_major == True, 24 | outputs_fw is a `Tensor` shaped: 25 | `[max_time, batch_size, cell_fw.output_size]` 26 | and outputs_bw is a `Tensor` shaped: 27 | `[max_time, batch_size, cell_bw.output_size]`. 28 | attention_size: Linear size of the Attention weights. 29 | time_major: The shape format of the `inputs` Tensors. 30 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 31 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 32 | Using `time_major = True` is a bit more efficient because it avoids 33 | transposes at the beginning and end of the RNN calculation. However, 34 | most TensorFlow data is batch-major, so by default this function 35 | accepts input and emits output in batch-major form. 36 | return_alphas: Whether to return attention coefficients variable along with layer's output. 37 | Used for visualization purpose. 38 | Returns: 39 | The Attention output `Tensor`. 40 | In case of RNN, this will be a `Tensor` shaped: 41 | `[batch_size, cell.output_size]`. 42 | In case of Bidirectional RNN, this will be a `Tensor` shaped: 43 | `[batch_size, cell_fw.output_size + cell_bw.output_size]`. 44 | """ 45 | 46 | if isinstance(inputs, tuple): 47 | # In case of Bi-RNN, concatenate the forward and the backward RNN outputs. 48 | inputs = tf.concat(inputs, 2) 49 | 50 | if time_major: 51 | # (T,B,D) => (B,T,D) 52 | inputs = tf.array_ops.transpose(inputs, [1, 0, 2]) 53 | 54 | hidden_size = inputs.shape[2].value # D value - hidden size of the RNN layer 55 | 56 | # Trainable parameters 57 | w_omega = tf.Variable(tf.random_normal([hidden_size, attention_size], stddev=0.1)) 58 | b_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1)) 59 | u_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1)) 60 | 61 | with tf.name_scope('v'): 62 | # Applying fully connected layer with non-linear activation to each of the B*T timestamps; 63 | # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size 64 | v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega) 65 | # For each of the timestamps its vector of size A from `v` is reduced with `u` vector 66 | vu = tf.tensordot(v, u_omega, axes=1, name='vu') # (B,T) shape 67 | alphas = tf.nn.softmax(vu, name='alphas') # (B,T) shape 68 | 69 | # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape 70 | output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1) 71 | 72 | if not return_alphas: 73 | return output 74 | else: 75 | return output, alphas 76 | -------------------------------------------------------------------------------- /cnn_class.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | 3 | import tensorflow as tf 4 | 5 | class cnn: 6 | def __init__( 7 | self, 8 | weight_stddev = 0.1, 9 | bias_constant = 0.1, 10 | padding = "SAME", 11 | ): 12 | self.weight_stddev = weight_stddev 13 | self.bias_constant = bias_constant 14 | self.padding = padding 15 | 16 | def weight_variable(self, shape): 17 | initial = tf.truncated_normal(shape, stddev = self.weight_stddev) 18 | return tf.Variable(initial) 19 | 20 | 21 | def bias_variable(self, shape): 22 | initial = tf.constant(self.bias_constant, shape = shape) 23 | return tf.Variable(initial) 24 | 25 | 26 | def conv1d(self, x, W, kernel_stride): 27 | # API: must strides[0]=strides[4]=1 28 | return tf.nn.conv1d(x, W, stride=kernel_stride, padding=self.padding) 29 | 30 | 31 | def conv2d(self, x, W, kernel_stride): 32 | # API: must strides[0]=strides[4]=1 33 | return tf.nn.conv2d(x, W, strides=[1, kernel_stride, kernel_stride, 1], padding=self.padding) 34 | 35 | 36 | def conv3d(self, x, W, kernel_stride): 37 | # API: must strides[0]=strides[4]=1 38 | return tf.nn.conv3d(x, W, strides=[1, kernel_stride, kernel_stride, kernel_stride, 1], padding=self.padding) 39 | 40 | 41 | def apply_conv1d(self, x, filter_width, in_channels, out_channels, kernel_stride, train_phase): 42 | weight = self.weight_variable([filter_width, in_channels, out_channels]) 43 | bias = self.bias_variable([out_channels]) # each feature map shares the same weight and bias 44 | conv_1d = tf.add(self.conv1d(x, weight, kernel_stride), bias) 45 | conv_1d_bn = self.batch_norm_cnv_1d(conv_1d, train_phase) 46 | return tf.nn.relu(conv_1d_bn) 47 | 48 | 49 | def apply_conv2d(self, x, filter_height, filter_width, in_channels, out_channels, kernel_stride, train_phase): 50 | weight = self.weight_variable([filter_height, filter_width, in_channels, out_channels]) 51 | bias = self.bias_variable([out_channels]) # each feature map shares the same weight and bias 52 | conv_2d = tf.add(self.conv2d(x, weight, kernel_stride), bias) 53 | conv_2d_bn = self.batch_norm_cnv_2d(conv_2d, train_phase) 54 | return tf.nn.relu(conv_2d_bn) 55 | 56 | 57 | 58 | def apply_conv3d(self, x, filter_depth, filter_height, filter_width, in_channels, out_channels, kernel_stride, train_phase): 59 | weight = self.weight_variable([filter_depth, filter_height, filter_width, in_channels, out_channels]) 60 | bias = self.bias_variable([out_channels]) # each feature map shares the same weight and bias 61 | conv_3d = tf.add(self.conv3d(x, weight, kernel_stride), bias) 62 | conv_3d_bn = self.batch_norm_cnv_3d(conv_3d, train_phase) 63 | return tf.nn.relu(conv_3d_bn) 64 | 65 | 66 | def batch_norm_cnv_3d(self, inputs, train_phase): 67 | return tf.layers.batch_normalization(inputs, axis=4, momentum=0.993, epsilon=1e-5, scale=False, training=train_phase) 68 | 69 | 70 | def batch_norm_cnv_2d(self, inputs, train_phase): 71 | return tf.layers.batch_normalization(inputs, axis=3, momentum=0.993, epsilon=1e-5, scale=False, training=train_phase) 72 | 73 | 74 | def batch_norm_cnv_1d(self, inputs, train_phase): 75 | return tf.layers.batch_normalization(inputs, axis=2, momentum=0.993, epsilon=1e-5, scale=False, training=train_phase) 76 | 77 | 78 | def batch_norm(self, inputs, train_phase): 79 | return tf.layers.batch_normalization(inputs, axis=1, momentum=0.993, epsilon=1e-5, scale=False, training=train_phase) 80 | 81 | 82 | def apply_max_pooling(self, x, pooling_height, pooling_width, pooling_stride): 83 | # API: must ksize[0]=ksize[4]=1, strides[0]=strides[4]=1 84 | return tf.nn.max_pool(x, ksize=[1, pooling_height, pooling_width, 1], strides=[1, pooling_stride, pooling_stride, 1], padding=self.padding) 85 | 86 | 87 | def apply_max_pooling3d(self, x, pooling_depth, pooling_height, pooling_width, pooling_stride): 88 | # API: must ksize[0]=ksize[4]=1, strides[0]=strides[4]=1 89 | return tf.nn.max_pool3d(x, ksize=[1, pooling_depth, pooling_height, pooling_width, 1], strides=[1, pooling_stride, pooling_stride, pooling_stride, 1], padding=self.padding) 90 | 91 | 92 | def apply_fully_connect(self, x, x_size, fc_size, train_phase): 93 | fc_weight = self.weight_variable([x_size, fc_size]) 94 | fc_bias = self.bias_variable([fc_size]) 95 | fc = tf.add(tf.matmul(x, fc_weight), fc_bias) 96 | fc_bn = self.batch_norm(fc, train_phase) 97 | return tf.nn.relu(fc_bn) 98 | 99 | 100 | def apply_readout(self, x, x_size, readout_size): 101 | readout_weight = self.weight_variable([x_size, readout_size]) 102 | readout_bias = self.bias_variable([readout_size]) 103 | return tf.add(tf.matmul(x, readout_weight), readout_bias) 104 | -------------------------------------------------------------------------------- /cram.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | import numpy as np 3 | import pandas as pd 4 | import tensorflow as tf 5 | from cnn_class import cnn 6 | import time 7 | import scipy.io as sio 8 | from sklearn.metrics import classification_report, roc_auc_score, auc, roc_curve, f1_score 9 | from RnnAttention.attention import attention 10 | from scipy import interp 11 | 12 | 13 | def multiclass_roc_auc_score(y_true, y_score): 14 | assert y_true.shape == y_score.shape 15 | fpr = dict() 16 | tpr = dict() 17 | roc_auc = dict() 18 | n_classes = y_true.shape[1] 19 | # compute ROC curve and ROC area for each class 20 | for i in range(n_classes): 21 | fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i]) 22 | roc_auc[i] = auc(fpr[i], tpr[i]) 23 | # compute micro-average ROC curve and ROC area 24 | fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel()) 25 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 26 | 27 | # compute macro-average ROC curve and ROC area 28 | # First aggregate all false probtive rates 29 | all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) 30 | # Then interpolate all ROC curves at this points 31 | mean_tpr = np.zeros_like(all_fpr) 32 | for i in range(n_classes): 33 | mean_tpr += interp(all_fpr, fpr[i], tpr[i]) 34 | # Finally average it and compute AUC 35 | mean_tpr /= n_classes 36 | fpr["macro"] = all_fpr 37 | tpr["macro"] = mean_tpr 38 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 39 | return roc_auc 40 | 41 | ########################################################################### 42 | # prepare raw data 43 | ########################################################################### 44 | subject_id = 1 45 | data_folder = '/home/dalinzhang/scratch/datasets/BCICIV_2a_gdf' 46 | data = sio.loadmat(data_folder+"/cross_sub/cross_subject_data_"+str(subject_id)+".mat") 47 | print("subject id ", subject_id) 48 | 49 | test_X = data["test_x"] # [trials, channels, time length] 50 | train_X = data["train_x"] 51 | 52 | test_y = data["test_y"].ravel() 53 | train_y = data["train_y"].ravel() 54 | 55 | 56 | train_y = np.asarray(pd.get_dummies(train_y), dtype = np.int8) 57 | test_y = np.asarray(pd.get_dummies(test_y), dtype = np.int8) 58 | 59 | ########################################################################### 60 | # crop data 61 | ########################################################################### 62 | 63 | window_size = 400 64 | step = 50 65 | n_channel = 22 66 | 67 | 68 | def windows(data, size, step): 69 | start = 0 70 | while ((start+size) < data.shape[0]): 71 | yield int(start), int(start + size) 72 | start += step 73 | 74 | 75 | def segment_signal_without_transition(data, window_size, step): 76 | segments = [] 77 | for (start, end) in windows(data, window_size, step): 78 | if(len(data[start:end]) == window_size): 79 | segments = segments + [data[start:end]] 80 | return np.array(segments) 81 | 82 | 83 | def segment_dataset(X, window_size, step): 84 | win_x = [] 85 | for i in range(X.shape[0]): 86 | win_x = win_x + [segment_signal_without_transition(X[i], window_size, step)] 87 | win_x = np.array(win_x) 88 | return win_x 89 | 90 | 91 | train_raw_x = np.transpose(train_X, [0, 2, 1]) 92 | test_raw_x = np.transpose(test_X, [0, 2, 1]) 93 | 94 | 95 | train_win_x = segment_dataset(train_raw_x, window_size, step) 96 | print("train_win_x shape: ", train_win_x.shape) 97 | test_win_x = segment_dataset(test_raw_x, window_size, step) 98 | print("test_win_x shape: ", test_win_x.shape) 99 | 100 | # [trial, window, channel, time_length] 101 | train_win_x = np.transpose(train_win_x, [0, 1, 3, 2]) 102 | print("train_win_x shape: ", train_win_x.shape) 103 | 104 | test_win_x = np.transpose(test_win_x, [0, 1, 3, 2]) 105 | print("test_win_x shape: ", test_win_x.shape) 106 | 107 | 108 | # [trial, window, channel, time_length, 1] 109 | train_x = np.expand_dims(train_win_x, axis = 4) 110 | test_x = np.expand_dims(test_win_x, axis = 4) 111 | 112 | num_timestep = train_x.shape[1] 113 | ########################################################################### 114 | # set model parameters 115 | ########################################################################### 116 | # kernel parameter 117 | kernel_height_1st = 22 118 | kernel_width_1st = 45 119 | 120 | kernel_stride = 1 121 | 122 | conv_channel_num = 40 123 | 124 | # pooling parameter 125 | pooling_height_1st = 1 126 | pooling_width_1st = 75 127 | 128 | pooling_stride_1st = 10 129 | 130 | # full connected parameter 131 | attention_size = 512 132 | n_hidden_state = 64 133 | 134 | ########################################################################### 135 | # set dataset parameters 136 | ########################################################################### 137 | # input channel 138 | input_channel_num = 1 139 | 140 | # input height 141 | input_height = train_x.shape[2] 142 | 143 | # input width 144 | input_width = train_x.shape[3] 145 | 146 | # prediction class 147 | num_labels = 4 148 | ########################################################################### 149 | # set training parameters 150 | ########################################################################### 151 | # set learning rate 152 | learning_rate = 1e-4 153 | 154 | # set maximum traing epochs 155 | training_epochs = 200 156 | 157 | # set batch size 158 | batch_size = 10 159 | 160 | # set dropout probability 161 | dropout_prob = 0.5 162 | 163 | # set train batch number per epoch 164 | batch_num_per_epoch = train_x.shape[0]//batch_size 165 | 166 | # instance cnn class 167 | padding = 'VALID' 168 | 169 | cnn_2d = cnn(padding=padding) 170 | 171 | # input placeholder 172 | X = tf.placeholder(tf.float32, shape=[None, input_height, input_width, input_channel_num], name = 'X') 173 | Y = tf.placeholder(tf.float32, shape=[None, num_labels], name = 'Y') 174 | train_phase = tf.placeholder(tf.bool, name = 'train_phase') 175 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 176 | 177 | # first CNN layer 178 | conv_1 = cnn_2d.apply_conv2d(X, kernel_height_1st, kernel_width_1st, input_channel_num, conv_channel_num, kernel_stride, train_phase) 179 | print("conv 1 shape: ", conv_1.get_shape().as_list()) 180 | pool_1 = cnn_2d.apply_max_pooling(conv_1, pooling_height_1st, pooling_width_1st, pooling_stride_1st) 181 | print("pool 1 shape: ", pool_1.get_shape().as_list()) 182 | 183 | pool1_shape = pool_1.get_shape().as_list() 184 | pool1_flat = tf.reshape(pool_1, [-1, pool1_shape[1]*pool1_shape[2]*pool1_shape[3]]) 185 | 186 | fc_drop = tf.nn.dropout(pool1_flat, keep_prob) 187 | 188 | lstm_in = tf.reshape(fc_drop, [-1, num_timestep, pool1_shape[1]*pool1_shape[2]*pool1_shape[3]]) 189 | 190 | ########################## RNN ######################## 191 | cells = [] 192 | for _ in range(2): 193 | cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_state, forget_bias=1.0, state_is_tuple=True) 194 | cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob) 195 | cells.append(cell) 196 | lstm_cell = tf.contrib.rnn.MultiRNNCell(cells) 197 | 198 | init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) 199 | 200 | # output ==> [batch, step, n_hidden_state] 201 | rnn_op, states = tf.nn.dynamic_rnn(lstm_cell, lstm_in, initial_state=init_state, time_major=False) 202 | 203 | ########################## attention ######################## 204 | with tf.name_scope('Attention_layer'): 205 | attention_op, alphas = attention(rnn_op, attention_size, time_major = False, return_alphas=True) 206 | 207 | attention_drop = tf.nn.dropout(attention_op, keep_prob) 208 | 209 | ########################## readout ######################## 210 | y_ = cnn_2d.apply_readout(attention_drop, rnn_op.shape[2].value, num_labels) 211 | 212 | # probability prediction 213 | y_prob = tf.nn.softmax(y_, name = "y_prob") 214 | 215 | # class prediction 216 | y_pred = tf.argmax(y_prob, 1, name = "y_pred") 217 | 218 | ########################## loss and optimizer ######################## 219 | # cross entropy cost function 220 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_, labels=Y), name = 'loss') 221 | 222 | 223 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 224 | with tf.control_dependencies(update_ops): 225 | # set training SGD optimizer 226 | optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost) 227 | 228 | # get correctly predicted object 229 | correct_prediction = tf.equal(tf.argmax(tf.nn.softmax(y_), 1), tf.argmax(Y, 1)) 230 | 231 | ########################## define accuracy ######################## 232 | # calculate prediction accuracy 233 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name = 'accuracy') 234 | 235 | 236 | ########################################################################### 237 | # train test and save result 238 | ########################################################################### 239 | 240 | # run with gpu memory growth 241 | config = tf.ConfigProto() 242 | config.gpu_options.allow_growth = True 243 | 244 | train_acc = [] 245 | test_acc = [] 246 | best_test_acc = [] 247 | train_loss = [] 248 | with tf.Session(config=config) as session: 249 | session.run(tf.global_variables_initializer()) 250 | best_acc = 0 251 | for epoch in range(training_epochs): 252 | pred_test = np.array([]) 253 | true_test = [] 254 | prob_test = [] 255 | ########################## training process ######################## 256 | for b in range(batch_num_per_epoch): 257 | offset = (b * batch_size) % (train_y.shape[0] - batch_size) 258 | batch_x = train_x[offset:(offset + batch_size), :, :, :, :] 259 | batch_x = batch_x.reshape([len(batch_x)*num_timestep, n_channel, window_size, 1]) 260 | batch_y = train_y[offset:(offset + batch_size), :] 261 | _, c = session.run([optimizer, cost], feed_dict={X: batch_x, Y: batch_y, keep_prob: 1-dropout_prob, train_phase: True}) 262 | # calculate train and test accuracy after each training epoch 263 | if(epoch%1 == 0): 264 | train_accuracy = np.zeros(shape=[0], dtype=float) 265 | test_accuracy = np.zeros(shape=[0], dtype=float) 266 | train_l = np.zeros(shape=[0], dtype=float) 267 | test_l = np.zeros(shape=[0], dtype=float) 268 | # calculate train accuracy after each training epoch 269 | for i in range(batch_num_per_epoch): 270 | ########################## prepare training data ######################## 271 | offset = (i * batch_size) % (train_y.shape[0] - batch_size) 272 | train_batch_x = train_x[offset:(offset + batch_size), :, :, :] 273 | train_batch_x = train_batch_x.reshape([len(train_batch_x)*num_timestep, n_channel, window_size, 1]) 274 | train_batch_y = train_y[offset:(offset + batch_size), :] 275 | 276 | ########################## calculate training results ######################## 277 | train_a, train_c = session.run([accuracy, cost], feed_dict={X: train_batch_x, Y: train_batch_y, keep_prob: 1.0, train_phase: False}) 278 | 279 | train_l = np.append(train_l, train_c) 280 | train_accuracy = np.append(train_accuracy, train_a) 281 | print("("+time.asctime(time.localtime(time.time()))+") Epoch: ", epoch+1, " Training Cost: ", np.mean(train_l), "Training Accuracy: ", np.mean(train_accuracy)) 282 | train_acc = train_acc + [np.mean(train_accuracy)] 283 | train_loss = train_loss + [np.mean(train_l)] 284 | # calculate test accuracy after each training epoch 285 | for j in range(batch_num_per_epoch): 286 | ########################## prepare test data ######################## 287 | offset = (j * batch_size) % (test_y.shape[0] - batch_size) 288 | test_batch_x = test_x[offset:(offset + batch_size), :, :, :] 289 | test_batch_x = test_batch_x.reshape([len(test_batch_x)*num_timestep, n_channel, window_size, 1]) 290 | test_batch_y = test_y[offset:(offset + batch_size), :] 291 | 292 | ########################## calculate test results ######################## 293 | test_a, test_c, prob_v, pred_v = session.run([accuracy, cost, y_prob, y_pred], feed_dict={X: test_batch_x, Y: test_batch_y, keep_prob: 1.0, train_phase: False}) 294 | 295 | test_accuracy = np.append(test_accuracy, test_a) 296 | test_l = np.append(test_l, test_c) 297 | pred_test = np.append(pred_test, pred_v) 298 | true_test.append(test_batch_y) 299 | prob_test.append(prob_v) 300 | if np.mean(test_accuracy) > best_acc : 301 | best_acc = np.mean(test_accuracy) 302 | true_test = np.array(true_test).reshape([-1, num_labels]) 303 | prob_test = np.array(prob_test).reshape([-1, num_labels]) 304 | auc_roc_test = multiclass_roc_auc_score(y_true=true_test, y_score=prob_test) 305 | f1 = f1_score(y_true=np.argmax(true_test, axis = 1), y_pred=pred_test, average = 'macro') 306 | print("("+time.asctime(time.localtime(time.time()))+") Epoch: ", epoch+1, "Test Cost: ", np.mean(test_l), 307 | "Test Accuracy: ", np.mean(test_accuracy), 308 | "Test f1: ", f1, 309 | "Test AUC: ", auc_roc_test['macro'], "\n") 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | -------------------------------------------------------------------------------- /data_proc.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import scipy.io as sio 3 | import pickle 4 | import numpy as np 5 | 6 | from collections import OrderedDict 7 | from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A 8 | from braindecode.mne_ext.signalproc import mne_apply 9 | from braindecode.datautil.signalproc import (bandpass_cnt, 10 | exponential_running_standardize) 11 | from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne 12 | from braindecode.datautil.splitters import concatenate_sets 13 | 14 | ####################### 15 | # reference: github: TNTLFreiburg: braindecode/examples/bcic_iv_2a.py 16 | ####################### 17 | 18 | 19 | def data_gen(subject, high_cut_hz=38, low_cut_hz=0): 20 | data_sub = {} 21 | for i in range(len(subject)): 22 | subject_id = subject[i] 23 | data_folder = '/home/dadafly/program/bci_data/data_folder' 24 | ival = [-500, 4000] 25 | factor_new = 1e-3 26 | init_block_size = 1000 27 | 28 | train_filename = 'A{:02d}T.gdf'.format(subject_id) 29 | test_filename = 'A{:02d}E.gdf'.format(subject_id) 30 | train_filepath = os.path.join(data_folder, train_filename) 31 | test_filepath = os.path.join(data_folder, test_filename) 32 | train_label_filepath = train_filepath.replace('.gdf', '.mat') 33 | test_label_filepath = test_filepath.replace('.gdf', '.mat') 34 | 35 | train_loader = BCICompetition4Set2A( 36 | train_filepath, labels_filename=train_label_filepath) 37 | test_loader = BCICompetition4Set2A( 38 | test_filepath, labels_filename=test_label_filepath) 39 | 40 | train_cnt = train_loader.load() 41 | test_cnt = test_loader.load() 42 | 43 | 44 | train_loader = BCICompetition4Set2A( 45 | train_filepath, labels_filename=train_label_filepath) 46 | test_loader = BCICompetition4Set2A( 47 | test_filepath, labels_filename=test_label_filepath) 48 | 49 | train_cnt = train_loader.load() 50 | test_cnt = test_loader.load() 51 | 52 | # train set process 53 | train_cnt = train_cnt.drop_channels(['STI 014', 'EOG-left', 54 | 'EOG-central', 'EOG-right']) 55 | assert len(train_cnt.ch_names) == 22 56 | 57 | train_cnt = mne_apply(lambda a: a * 1e6, train_cnt) 58 | train_cnt = mne_apply( 59 | lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, train_cnt.info['sfreq'], 60 | filt_order=3, axis=1), train_cnt) 61 | 62 | train_cnt = mne_apply( 63 | lambda a: exponential_running_standardize(a.T, factor_new=factor_new, 64 | init_block_size=init_block_size, 65 | eps=1e-4).T, train_cnt) 66 | 67 | # test set process 68 | test_cnt = test_cnt.drop_channels(['STI 014', 'EOG-left', 69 | 'EOG-central', 'EOG-right']) 70 | assert len(test_cnt.ch_names) == 22 71 | test_cnt = mne_apply(lambda a: a * 1e6, test_cnt) 72 | test_cnt = mne_apply( 73 | lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, test_cnt.info['sfreq'], 74 | filt_order=3, axis=1), test_cnt) 75 | test_cnt = mne_apply( 76 | lambda a: exponential_running_standardize(a.T, factor_new=factor_new, 77 | init_block_size=init_block_size, 78 | eps=1e-4).T, test_cnt) 79 | 80 | marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],), 81 | ('Foot', [3]), ('Tongue', [4])]) 82 | 83 | train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival) 84 | test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival) 85 | 86 | data_sub[str(subject_id)] = concatenate_sets([train_set, test_set]) 87 | if i == 0: 88 | dataset = data_sub[str(subject_id)] 89 | else: 90 | dataset = concatenate_sets([dataset, data_sub[str(subject_id)]]) 91 | assert len(data_sub) == len(subject) 92 | 93 | return dataset 94 | 95 | 96 | if __name__ == '__main__': 97 | for j in range(1,10): 98 | train_subject = [k for k in range(1,10) if k != j] 99 | test_subject = [j] 100 | train_dataset = data_gen(train_subject, high_cut_hz=125, low_cut_hz=0) 101 | test_dataset = data_gen(test_subject, high_cut_hz=125, low_cut_hz=0) 102 | 103 | train_X = train_dataset.X 104 | train_y = train_dataset.y 105 | test_X = test_dataset.X 106 | test_y = test_dataset.y 107 | 108 | idx = list(range(len(train_y))) 109 | np.random.shuffle(idx) 110 | train_X = train_X[idx] 111 | train_y = train_y[idx] 112 | sio.savemat('/home/dadafly/program/bci_data/data_folder/cross_sub/cross_subject_data_'+str(j)+'.mat', {"train_x": train_X, "train_y": train_y, "test_x": test_X, "test_y": test_y}) 113 | 114 | 115 | --------------------------------------------------------------------------------